summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rwxr-xr-xsrc/conf_mode/conntrack.py7
-rwxr-xr-xsrc/conf_mode/nat64.py209
-rwxr-xr-xsrc/migration-scripts/conntrack/4-to-559
-rwxr-xr-xsrc/migration-scripts/firewall/10-to-1118
-rwxr-xr-xsrc/migration-scripts/firewall/12-to-139
-rwxr-xr-xsrc/migration-scripts/interfaces/29-to-3011
-rwxr-xr-xsrc/migration-scripts/pppoe-server/6-to-736
-rwxr-xr-xsrc/op_mode/dhcp.py109
-rwxr-xr-xsrc/op_mode/image_installer.py87
9 files changed, 463 insertions, 82 deletions
diff --git a/src/conf_mode/conntrack.py b/src/conf_mode/conntrack.py
index 4cece6921..7f6c71440 100755
--- a/src/conf_mode/conntrack.py
+++ b/src/conf_mode/conntrack.py
@@ -159,6 +159,13 @@ def verify(conntrack):
if not group_obj:
Warning(f'{error_group} "{group_name}" has no members!')
+ if dict_search_args(conntrack, 'timeout', 'custom', inet, 'rule') != None:
+ for rule, rule_config in conntrack['timeout']['custom'][inet]['rule'].items():
+ if 'protocol' not in rule_config:
+ raise ConfigError(f'Conntrack custom timeout rule {rule} requires protocol tcp or udp')
+ else:
+ if 'tcp' in rule_config['protocol'] and 'udp' in rule_config['protocol']:
+ raise ConfigError(f'conntrack custom timeout rule {rule} - Cant use both tcp and udp protocol')
return None
def generate(conntrack):
diff --git a/src/conf_mode/nat64.py b/src/conf_mode/nat64.py
new file mode 100755
index 000000000..a8b90fb11
--- /dev/null
+++ b/src/conf_mode/nat64.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2023 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+# pylint: disable=empty-docstring,missing-module-docstring
+
+import csv
+import os
+import re
+
+from ipaddress import IPv6Network
+from json import dumps as json_write
+
+from vyos import ConfigError
+from vyos import airbag
+from vyos.config import Config
+from vyos.configdict import dict_merge
+from vyos.configdict import is_node_changed
+from vyos.utils.dict import dict_search
+from vyos.utils.file import write_file
+from vyos.utils.kernel import check_kmod
+from vyos.utils.process import cmd
+from vyos.utils.process import run
+
+airbag.enable()
+
+INSTANCE_REGEX = re.compile(r"instance-(\d+)")
+JOOL_CONFIG_DIR = "/run/jool"
+
+
+def get_config(config: Config | None = None) -> None:
+ if config is None:
+ config = Config()
+
+ base = ["nat64"]
+ nat64 = config.get_config_dict(base, key_mangling=("-", "_"), get_first_key=True)
+
+ base_src = base + ["source", "rule"]
+
+ # Load in existing instances so we can destroy any unknown
+ lines = cmd("jool instance display --csv").splitlines()
+ for _, instance, _ in csv.reader(lines):
+ match = INSTANCE_REGEX.fullmatch(instance)
+ if not match:
+ # FIXME: Instances that don't match should be ignored but WARN'ed to the user
+ continue
+ num = match.group(1)
+
+ rules = nat64.setdefault("source", {}).setdefault("rule", {})
+ # Mark it for deletion
+ if num not in rules:
+ rules[num] = {"deleted": True}
+ continue
+
+ # If the user changes the mode, recreate the instance else Jool fails with:
+ # Jool error: Sorry; you can't change an instance's framework for now.
+ if is_node_changed(config, base_src + [f"instance-{num}", "mode"]):
+ rules[num]["recreate"] = True
+
+ # If the user changes the pool6, recreate the instance else Jool fails with:
+ # Jool error: Sorry; you can't change a NAT64 instance's pool6 for now.
+ if dict_search("source.prefix", rules[num]) and is_node_changed(
+ config,
+ base_src + [num, "source", "prefix"],
+ ):
+ rules[num]["recreate"] = True
+
+ return nat64
+
+
+def verify(nat64) -> None:
+ if not nat64:
+ # no need to verify the CLI as nat64 is going to be deactivated
+ return
+
+ if dict_search("source.rule", nat64):
+ # Ensure only 1 netfilter instance per namespace
+ nf_rules = filter(
+ lambda i: "deleted" not in i and i.get('mode') == "netfilter",
+ nat64["source"]["rule"].values(),
+ )
+ next(nf_rules, None) # Discard the first element
+ if next(nf_rules, None) is not None:
+ raise ConfigError(
+ "Jool permits only 1 NAT64 netfilter instance (per network namespace)"
+ )
+
+ for rule, instance in nat64["source"]["rule"].items():
+ if "deleted" in instance:
+ continue
+
+ # Verify that source.prefix is set and is a /96
+ if not dict_search("source.prefix", instance):
+ raise ConfigError(f"Source NAT64 rule {rule} missing source prefix")
+ if IPv6Network(instance["source"]["prefix"]).prefixlen != 96:
+ raise ConfigError(f"Source NAT64 rule {rule} source prefix must be /96")
+
+ pools = dict_search("translation.pool", instance)
+ if pools:
+ for num, pool in pools.items():
+ if "address" not in pool:
+ raise ConfigError(
+ f"Source NAT64 rule {rule} translation pool "
+ f"{num} missing address/prefix"
+ )
+ if "port" not in pool:
+ raise ConfigError(
+ f"Source NAT64 rule {rule} translation pool "
+ f"{num} missing port(-range)"
+ )
+
+
+def generate(nat64) -> None:
+ os.makedirs(JOOL_CONFIG_DIR, exist_ok=True)
+
+ if dict_search("source.rule", nat64):
+ for rule, instance in nat64["source"]["rule"].items():
+ if "deleted" in instance:
+ # Delete the unused instance file
+ os.unlink(os.path.join(JOOL_CONFIG_DIR, f"instance-{rule}.json"))
+ continue
+
+ name = f"instance-{rule}"
+ config = {
+ "instance": name,
+ "framework": "netfilter",
+ "global": {
+ "pool6": instance["source"]["prefix"],
+ "manually-enabled": "disable" not in instance,
+ },
+ # "bib": [],
+ }
+
+ if "description" in instance:
+ config["comment"] = instance["description"]
+
+ if dict_search("translation.pool", instance):
+ pool4 = []
+ for pool in instance["translation"]["pool"].values():
+ if "disable" in pool:
+ continue
+
+ protos = pool.get("protocol", {}).keys() or ("tcp", "udp", "icmp")
+ for proto in protos:
+ obj = {
+ "protocol": proto.upper(),
+ "prefix": pool["address"],
+ "port range": pool["port"],
+ }
+ if "description" in pool:
+ obj["comment"] = pool["description"]
+
+ pool4.append(obj)
+
+ if pool4:
+ config["pool4"] = pool4
+
+ write_file(f'{JOOL_CONFIG_DIR}/{name}.json', json_write(config, indent=2))
+
+
+def apply(nat64) -> None:
+ if not nat64:
+ return
+
+ if dict_search("source.rule", nat64):
+ # Deletions first to avoid conflicts
+ for rule, instance in nat64["source"]["rule"].items():
+ if not any(k in instance for k in ("deleted", "recreate")):
+ continue
+
+ ret = run(f"jool instance remove instance-{rule}")
+ if ret != 0:
+ raise ConfigError(
+ f"Failed to remove nat64 source rule {rule} (jool instance instance-{rule})"
+ )
+
+ # Now creations
+ for rule, instance in nat64["source"]["rule"].items():
+ if "deleted" in instance:
+ continue
+
+ name = f"instance-{rule}"
+ ret = run(f"jool -i {name} file handle {JOOL_CONFIG_DIR}/{name}.json")
+ if ret != 0:
+ raise ConfigError(f"Failed to set jool instance {name}")
+
+
+if __name__ == "__main__":
+ try:
+ check_kmod(["jool"])
+ c = get_config()
+ verify(c)
+ generate(c)
+ apply(c)
+ except ConfigError as e:
+ print(e)
+ exit(1)
diff --git a/src/migration-scripts/conntrack/4-to-5 b/src/migration-scripts/conntrack/4-to-5
new file mode 100755
index 000000000..d2e5fc5fa
--- /dev/null
+++ b/src/migration-scripts/conntrack/4-to-5
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2023 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+# T5779: system conntrack timeout custom
+# Before:
+# Protocols tcp, udp and icmp allowed. When using udp it did not work
+# Only ipv4 custom timeout rules
+# Now:
+# Valid protocols are only tcp or udp.
+# Extend functionality to ipv6 and move ipv4 custom rules to new node:
+# set system conntrack timeout custom [ipv4 | ipv6] rule <rule> ...
+
+from sys import argv
+from sys import exit
+
+from vyos.configtree import ConfigTree
+
+if len(argv) < 2:
+ print("Must specify file name!")
+ exit(1)
+
+file_name = argv[1]
+
+with open(file_name, 'r') as f:
+ config_file = f.read()
+
+base = ['system', 'conntrack']
+config = ConfigTree(config_file)
+
+if not config.exists(base):
+ # Nothing to do
+ exit(0)
+
+if config.exists(base + ['timeout', 'custom', 'rule']):
+ for rule in config.list_nodes(base + ['timeout', 'custom', 'rule']):
+ if config.exists(base + ['timeout', 'custom', 'rule', rule, 'protocol', 'tcp']):
+ config.set(base + ['timeout', 'custom', 'ipv4', 'rule'])
+ config.copy(base + ['timeout', 'custom', 'rule', rule], base + ['timeout', 'custom', 'ipv4', 'rule', rule])
+ config.delete(base + ['timeout', 'custom', 'rule'])
+
+try:
+ with open(file_name, 'w') as f:
+ f.write(config.to_string())
+except OSError as e:
+ print("Failed to save the modified config: {}".format(e))
+ exit(1)
diff --git a/src/migration-scripts/firewall/10-to-11 b/src/migration-scripts/firewall/10-to-11
index b739fb139..e14ea0e51 100755
--- a/src/migration-scripts/firewall/10-to-11
+++ b/src/migration-scripts/firewall/10-to-11
@@ -63,19 +63,11 @@ if not config.exists(base):
### Migration of state policies
if config.exists(base + ['state-policy']):
- for family in ['ipv4', 'ipv6']:
- for hook in ['forward', 'input', 'output']:
- for priority in ['filter']:
- # Add default-action== accept for compatibility reasons:
- config.set(base + [family, hook, priority, 'default-action'], value='accept')
- position = 1
- for state in config.list_nodes(base + ['state-policy']):
- action = config.return_value(base + ['state-policy', state, 'action'])
- config.set(base + [family, hook, priority, 'rule'])
- config.set_tag(base + [family, hook, priority, 'rule'])
- config.set(base + [family, hook, priority, 'rule', position, 'state', state], value='enable')
- config.set(base + [family, hook, priority, 'rule', position, 'action'], value=action)
- position = position + 1
+ for state in config.list_nodes(base + ['state-policy']):
+ action = config.return_value(base + ['state-policy', state, 'action'])
+ config.set(base + ['global-options', 'state-policy', state, 'action'], value=action)
+ if config.exists(base + ['state-policy', state, 'log']):
+ config.set(base + ['global-options', 'state-policy', state, 'log'], value='enable')
config.delete(base + ['state-policy'])
## migration of global options:
diff --git a/src/migration-scripts/firewall/12-to-13 b/src/migration-scripts/firewall/12-to-13
index 4eaae779b..8396dd9d1 100755
--- a/src/migration-scripts/firewall/12-to-13
+++ b/src/migration-scripts/firewall/12-to-13
@@ -49,6 +49,15 @@ if not config.exists(base):
# Nothing to do
exit(0)
+# State Policy logs:
+if config.exists(base + ['global-options', 'state-policy']):
+ for state in config.list_nodes(base + ['global-options', 'state-policy']):
+ if config.exists(base + ['global-options', 'state-policy', state, 'log']):
+ log_value = config.return_value(base + ['global-options', 'state-policy', state, 'log'])
+ config.delete(base + ['global-options', 'state-policy', state, 'log'])
+ if log_value == 'enable':
+ config.set(base + ['global-options', 'state-policy', state, 'log'])
+
for family in ['ipv4', 'ipv6', 'bridge']:
if config.exists(base + [family]):
for hook in ['forward', 'input', 'output', 'name']:
diff --git a/src/migration-scripts/interfaces/29-to-30 b/src/migration-scripts/interfaces/29-to-30
index 97e1b329c..04e023e77 100755
--- a/src/migration-scripts/interfaces/29-to-30
+++ b/src/migration-scripts/interfaces/29-to-30
@@ -35,16 +35,19 @@ if __name__ == '__main__':
# Nothing to do
sys.exit(0)
for interface in config.list_nodes(base):
+ if not config.exists(base + [interface, 'private-key']):
+ continue
private_key = config.return_value(base + [interface, 'private-key'])
interface_base = base + [interface]
if config.exists(interface_base + ['peer']):
for peer in config.list_nodes(interface_base + ['peer']):
peer_base = interface_base + ['peer', peer]
+ if not config.exists(peer_base + ['public-key']):
+ continue
peer_public_key = config.return_value(peer_base + ['public-key'])
- if config.exists(peer_base + ['public-key']):
- if not config.exists(peer_base + ['disable']) \
- and is_wireguard_key_pair(private_key, peer_public_key):
- config.set(peer_base + ['disable'])
+ if not config.exists(peer_base + ['disable']) \
+ and is_wireguard_key_pair(private_key, peer_public_key):
+ config.set(peer_base + ['disable'])
try:
with open(file_name, 'w') as f:
diff --git a/src/migration-scripts/pppoe-server/6-to-7 b/src/migration-scripts/pppoe-server/6-to-7
index 8b5482705..34996d8fe 100755
--- a/src/migration-scripts/pppoe-server/6-to-7
+++ b/src/migration-scripts/pppoe-server/6-to-7
@@ -76,23 +76,25 @@ if config.exists(base + ['gateway-address']):
#named pool migration
namedpools_base = pool_base + ['name']
-if config.return_value(base + ['authentication', 'mode']) == 'local':
- if config.list_nodes(namedpools_base):
- default_pool = config.list_nodes(namedpools_base)[0]
-
-for pool_name in config.list_nodes(namedpools_base):
- pool_path = namedpools_base + [pool_name]
- if config.exists(pool_path + ['subnet']):
- subnet = config.return_value(pool_path + ['subnet'])
- config.set(pool_base + [pool_name, 'range'], value=subnet)
- if config.exists(pool_path + ['next-pool']):
- next_pool = config.return_value(pool_path + ['next-pool'])
- config.set(pool_base + [pool_name, 'next-pool'], value=next_pool)
- if not gateway:
- if config.exists(pool_path + ['gateway-address']):
- gateway = config.return_value(pool_path + ['gateway-address'])
-
-config.delete(namedpools_base)
+if config.exists(namedpools_base):
+ if config.exists(base + ['authentication', 'mode']):
+ if config.return_value(base + ['authentication', 'mode']) == 'local':
+ if config.list_nodes(namedpools_base):
+ default_pool = config.list_nodes(namedpools_base)[0]
+
+ for pool_name in config.list_nodes(namedpools_base):
+ pool_path = namedpools_base + [pool_name]
+ if config.exists(pool_path + ['subnet']):
+ subnet = config.return_value(pool_path + ['subnet'])
+ config.set(pool_base + [pool_name, 'range'], value=subnet)
+ if config.exists(pool_path + ['next-pool']):
+ next_pool = config.return_value(pool_path + ['next-pool'])
+ config.set(pool_base + [pool_name, 'next-pool'], value=next_pool)
+ if not gateway:
+ if config.exists(pool_path + ['gateway-address']):
+ gateway = config.return_value(pool_path + ['gateway-address'])
+
+ config.delete(namedpools_base)
if gateway:
config.set(base + ['gateway-address'], value=gateway)
diff --git a/src/op_mode/dhcp.py b/src/op_mode/dhcp.py
index 77f38992b..d6b8aa0b8 100755
--- a/src/op_mode/dhcp.py
+++ b/src/op_mode/dhcp.py
@@ -43,6 +43,7 @@ sort_valid_inet6 = ['end', 'iaid_duid', 'ip', 'last_communication', 'pool', 'rem
ArgFamily = typing.Literal['inet', 'inet6']
ArgState = typing.Literal['all', 'active', 'free', 'expired', 'released', 'abandoned', 'reset', 'backup']
+ArgOrigin = typing.Literal['local', 'remote']
def _utc_to_local(utc_dt):
return datetime.fromtimestamp((datetime.fromtimestamp(utc_dt) - datetime(1970, 1, 1)).total_seconds())
@@ -71,7 +72,7 @@ def _find_list_of_dict_index(lst, key='ip', value='') -> int:
return idx
-def _get_raw_server_leases(family='inet', pool=None, sorted=None, state=[]) -> list:
+def _get_raw_server_leases(family='inet', pool=None, sorted=None, state=[], origin=None) -> list:
"""
Get DHCP server leases
:return list
@@ -82,51 +83,61 @@ def _get_raw_server_leases(family='inet', pool=None, sorted=None, state=[]) -> l
if pool is None:
pool = _get_dhcp_pools(family=family)
+ aux = False
else:
pool = [pool]
-
- for lease in leases:
- data_lease = {}
- data_lease['ip'] = lease.ip
- data_lease['state'] = lease.binding_state
- data_lease['pool'] = lease.sets.get('shared-networkname', '')
- data_lease['end'] = lease.end.timestamp() if lease.end else None
-
- if family == 'inet':
- data_lease['mac'] = lease.ethernet
- data_lease['start'] = lease.start.timestamp()
- data_lease['hostname'] = lease.hostname
-
- if family == 'inet6':
- data_lease['last_communication'] = lease.last_communication.timestamp()
- data_lease['iaid_duid'] = _format_hex_string(lease.host_identifier_string)
- lease_types_long = {'na': 'non-temporary', 'ta': 'temporary', 'pd': 'prefix delegation'}
- data_lease['type'] = lease_types_long[lease.type]
-
- data_lease['remaining'] = '-'
-
- if lease.end:
- data_lease['remaining'] = lease.end - datetime.utcnow()
-
- if data_lease['remaining'].days >= 0:
- # substraction gives us a timedelta object which can't be formatted with strftime
- # so we use str(), split gets rid of the microseconds
- data_lease['remaining'] = str(data_lease["remaining"]).split('.')[0]
-
- # Do not add old leases
- if data_lease['remaining'] != '' and data_lease['pool'] in pool and data_lease['state'] != 'free':
- if not state or data_lease['state'] in state:
- data.append(data_lease)
-
- # deduplicate
- checked = []
- for entry in data:
- addr = entry.get('ip')
- if addr not in checked:
- checked.append(addr)
- else:
- idx = _find_list_of_dict_index(data, key='ip', value=addr)
- data.pop(idx)
+ aux = True
+
+ ## Search leases for every pool
+ for pool_name in pool:
+ for lease in leases:
+ if lease.sets.get('shared-networkname', '') == pool_name or lease.sets.get('shared-networkname', '') == '':
+ #if lease.sets.get('shared-networkname', '') == pool_name:
+ data_lease = {}
+ data_lease['ip'] = lease.ip
+ data_lease['state'] = lease.binding_state
+ #data_lease['pool'] = pool_name if lease.sets.get('shared-networkname', '') != '' else 'Fail-Over Server'
+ data_lease['pool'] = lease.sets.get('shared-networkname', '')
+ data_lease['end'] = lease.end.timestamp() if lease.end else None
+ data_lease['origin'] = 'local' if data_lease['pool'] != '' else 'remote'
+
+ if family == 'inet':
+ data_lease['mac'] = lease.ethernet
+ data_lease['start'] = lease.start.timestamp()
+ data_lease['hostname'] = lease.hostname
+
+ if family == 'inet6':
+ data_lease['last_communication'] = lease.last_communication.timestamp()
+ data_lease['iaid_duid'] = _format_hex_string(lease.host_identifier_string)
+ lease_types_long = {'na': 'non-temporary', 'ta': 'temporary', 'pd': 'prefix delegation'}
+ data_lease['type'] = lease_types_long[lease.type]
+
+ data_lease['remaining'] = '-'
+
+ if lease.end:
+ data_lease['remaining'] = lease.end - datetime.utcnow()
+
+ if data_lease['remaining'].days >= 0:
+ # substraction gives us a timedelta object which can't be formatted with strftime
+ # so we use str(), split gets rid of the microseconds
+ data_lease['remaining'] = str(data_lease["remaining"]).split('.')[0]
+
+ # Do not add old leases
+ if data_lease['remaining'] != '' and data_lease['state'] != 'free':
+ if not state or data_lease['state'] in state or state == 'all':
+ if not origin or data_lease['origin'] in origin:
+ if not aux or (aux and data_lease['pool'] == pool_name):
+ data.append(data_lease)
+
+ # deduplicate
+ checked = []
+ for entry in data:
+ addr = entry.get('ip')
+ if addr not in checked:
+ checked.append(addr)
+ else:
+ idx = _find_list_of_dict_index(data, key='ip', value=addr)
+ data.pop(idx)
if sorted:
if sorted == 'ip':
@@ -150,10 +161,11 @@ def _get_formatted_server_leases(raw_data, family='inet'):
remain = lease.get('remaining')
pool = lease.get('pool')
hostname = lease.get('hostname')
- data_entries.append([ipaddr, hw_addr, state, start, end, remain, pool, hostname])
+ origin = lease.get('origin')
+ data_entries.append([ipaddr, hw_addr, state, start, end, remain, pool, hostname, origin])
headers = ['IP Address', 'MAC address', 'State', 'Lease start', 'Lease expiration', 'Remaining', 'Pool',
- 'Hostname']
+ 'Hostname', 'Origin']
if family == 'inet6':
for lease in raw_data:
@@ -267,7 +279,8 @@ def show_pool_statistics(raw: bool, family: ArgFamily, pool: typing.Optional[str
@_verify
def show_server_leases(raw: bool, family: ArgFamily, pool: typing.Optional[str],
- sorted: typing.Optional[str], state: typing.Optional[ArgState]):
+ sorted: typing.Optional[str], state: typing.Optional[ArgState],
+ origin: typing.Optional[ArgOrigin] ):
# if dhcp server is down, inactive leases may still be shown as active, so warn the user.
v = '6' if family == 'inet6' else ''
service_name = 'DHCPv6' if family == 'inet6' else 'DHCP'
@@ -285,7 +298,7 @@ def show_server_leases(raw: bool, family: ArgFamily, pool: typing.Optional[str],
if sorted and sorted not in sort_valid:
raise vyos.opmode.IncorrectValue(f'DHCP{v} sort "{sorted}" is invalid!')
- lease_data = _get_raw_server_leases(family=family, pool=pool, sorted=sorted, state=state)
+ lease_data = _get_raw_server_leases(family=family, pool=pool, sorted=sorted, state=state, origin=origin)
if raw:
return lease_data
else:
diff --git a/src/op_mode/image_installer.py b/src/op_mode/image_installer.py
index cdb84a152..b3e6e518c 100755
--- a/src/op_mode/image_installer.py
+++ b/src/op_mode/image_installer.py
@@ -60,6 +60,8 @@ MSG_INPUT_PASSWORD: str = 'Please enter a password for the "vyos" user'
MSG_INPUT_ROOT_SIZE_ALL: str = 'Would you like to use all the free space on the drive?'
MSG_INPUT_ROOT_SIZE_SET: str = 'Please specify the size (in GB) of the root partition (min is 1.5 GB)?'
MSG_INPUT_CONSOLE_TYPE: str = 'What console should be used by default? (K: KVM, S: Serial, U: USB-Serial)?'
+MSG_INPUT_COPY_DATA: str = 'Would you like to copy data to the new image?'
+MSG_INPUT_CHOOSE_COPY_DATA: str = 'From which image would you like to save config information?'
MSG_WARN_ISO_SIGN_INVALID: str = 'Signature is not valid. Do you want to continue with installation?'
MSG_WARN_ISO_SIGN_UNAVAL: str = 'Signature is not available. Do you want to continue with installation?'
MSG_WARN_ROOT_SIZE_TOOBIG: str = 'The size is too big. Try again.'
@@ -184,6 +186,83 @@ def create_partitions(target_disk: str, target_size: int,
return disk_details
+def search_format_selection(image: tuple[str, str]) -> str:
+ """Format a string for selection of image
+
+ Args:
+ image (tuple[str, str]): a tuple of image name and drive
+
+ Returns:
+ str: formatted string
+ """
+ return f'{image[0]} on {image[1]}'
+
+
+def search_previous_installation(disks: list[str]) -> None:
+ """Search disks for previous installation config and SSH keys
+
+ Args:
+ disks (list[str]): a list of available disks
+ """
+ mnt_config = '/mnt/config'
+ mnt_ssh = '/mnt/ssh'
+ mnt_tmp = '/mnt/tmp'
+ rmtree(Path(mnt_config), ignore_errors=True)
+ rmtree(Path(mnt_ssh), ignore_errors=True)
+ Path(mnt_tmp).mkdir(exist_ok=True)
+
+ print('Searching for data from previous installations')
+ image_data = []
+ for disk_name in disks:
+ for partition in disk.partition_list(disk_name):
+ if disk.partition_mount(partition, mnt_tmp):
+ if Path(mnt_tmp + '/boot').exists():
+ for path in Path(mnt_tmp + '/boot').iterdir():
+ if path.joinpath('rw/config/.vyatta_config').exists():
+ image_data.append((path.name, partition))
+
+ disk.partition_umount(partition)
+
+ if len(image_data) == 1:
+ image_name, image_drive = image_data[0]
+ print('Found data from previous installation:')
+ print(f'\t{image_name} on {image_drive}')
+ if not ask_yes_no(MSG_INPUT_COPY_DATA, default=True):
+ return
+
+ elif len(image_data) > 1:
+ print('Found data from previous installations')
+ if not ask_yes_no(MSG_INPUT_COPY_DATA, default=True):
+ return
+
+ image_name, image_drive = select_entry(image_data,
+ 'Available versions:',
+ MSG_INPUT_CHOOSE_COPY_DATA,
+ search_format_selection)
+ else:
+ print('No previous installation found')
+ return
+
+ disk.partition_mount(image_drive, mnt_tmp)
+
+ copytree(f'{mnt_tmp}/boot/{image_name}/rw/config', mnt_config)
+ Path(mnt_ssh).mkdir()
+ host_keys: list[str] = glob(f'{mnt_tmp}/boot/{image_name}/rw/etc/ssh/ssh_host*')
+ for host_key in host_keys:
+ copy(host_key, mnt_ssh)
+
+ disk.partition_umount(image_drive)
+
+
+def copy_previous_installation_data(target_dir: str) -> None:
+ if Path('/mnt/config').exists():
+ copytree('/mnt/config', f'{target_dir}/opt/vyatta/etc/config',
+ dirs_exist_ok=True)
+ if Path('/mnt/ssh').exists():
+ copytree('/mnt/ssh', f'{target_dir}/etc/ssh',
+ dirs_exist_ok=True)
+
+
def ask_single_disk(disks_available: dict[str, int]) -> str:
"""Ask user to select a disk for installation
@@ -204,6 +283,8 @@ def ask_single_disk(disks_available: dict[str, int]) -> str:
print(MSG_INFO_INSTALL_EXIT)
exit()
+ search_previous_installation(list(disks_available))
+
disk_details: disk.DiskDetails = create_partitions(disk_selected,
disks_available[disk_selected])
@@ -260,6 +341,8 @@ def check_raid_install(disks_available: dict[str, int]) -> Union[str, None]:
print(MSG_INFO_INSTALL_EXIT)
exit()
+ search_previous_installation(list(disks_available))
+
disks: list[disk.DiskDetails] = []
for disk_selected in list(disks_selected):
print(f'Creating partitions on {disk_selected}')
@@ -581,6 +664,10 @@ def install_image() -> None:
copy(FILE_ROOTFS_SRC,
f'{DIR_DST_ROOT}/boot/{image_name}/{image_name}.squashfs')
+ # copy saved config data and SSH keys
+ # owner restored on copy of config data by chmod_2775, above
+ copy_previous_installation_data(f'{DIR_DST_ROOT}/boot/{image_name}/rw')
+
if is_raid_install(install_target):
write_dir: str = f'{DIR_DST_ROOT}/boot/{image_name}/rw'
raid.update_default(write_dir)