diff options
Diffstat (limited to 'src/conf_mode')
-rwxr-xr-x | src/conf_mode/interfaces_macsec.py | 10 | ||||
-rwxr-xr-x | src/conf_mode/nat_cgnat.py | 51 | ||||
-rwxr-xr-x | src/conf_mode/system_option.py | 5 | ||||
-rwxr-xr-x | src/conf_mode/vpn_openconnect.py | 43 |
4 files changed, 90 insertions, 19 deletions
diff --git a/src/conf_mode/interfaces_macsec.py b/src/conf_mode/interfaces_macsec.py index eb0ca9a8b..3ede4377a 100755 --- a/src/conf_mode/interfaces_macsec.py +++ b/src/conf_mode/interfaces_macsec.py @@ -103,9 +103,9 @@ def verify(macsec): # Logic to check static configuration if dict_search('security.static', macsec) != None: - # tx-key must be defined + # key must be defined if dict_search('security.static.key', macsec) == None: - raise ConfigError('Static MACsec tx-key must be defined.') + raise ConfigError('Static MACsec key must be defined.') tx_len = len(dict_search('security.static.key', macsec)) @@ -119,12 +119,12 @@ def verify(macsec): if 'peer' not in macsec['security']['static']: raise ConfigError('Must have at least one peer defined for static MACsec') - # For every enabled peer, make sure a MAC and rx-key is defined + # For every enabled peer, make sure a MAC and key is defined for peer, peer_config in macsec['security']['static']['peer'].items(): if 'disable' not in peer_config and ('mac' not in peer_config or 'key' not in peer_config): - raise ConfigError('Every enabled MACsec static peer must have a MAC address and rx-key defined.') + raise ConfigError('Every enabled MACsec static peer must have a MAC address and key defined!') - # check rx-key length against cipher suite + # check key length against cipher suite rx_len = len(peer_config['key']) if dict_search('security.cipher', macsec) == 'gcm-aes-128' and rx_len != GCM_AES_128_LEN: diff --git a/src/conf_mode/nat_cgnat.py b/src/conf_mode/nat_cgnat.py index cb336a35c..34ec64fce 100755 --- a/src/conf_mode/nat_cgnat.py +++ b/src/conf_mode/nat_cgnat.py @@ -23,6 +23,7 @@ from sys import exit from logging.handlers import SysLogHandler from vyos.config import Config +from vyos.configdict import is_node_changed from vyos.template import render from vyos.utils.process import cmd from vyos.utils.process import run @@ -118,6 +119,41 @@ class IPOperations: + [self.ip_network.broadcast_address] ] + def get_prefix_by_ip_range(self): + """Return the common prefix for the address range + + Example: + % ip = IPOperations('100.64.0.1-100.64.0.5') + % ip.get_prefix_by_ip_range() + 100.64.0.0/29 + """ + if '-' in self.ip_prefix: + ip_start, ip_end = self.ip_prefix.split('-') + start_ip = ipaddress.IPv4Address(ip_start.strip()) + end_ip = ipaddress.IPv4Address(ip_end.strip()) + + start_int = int(start_ip) + end_int = int(end_ip) + + # XOR to find differing bits + xor = start_int ^ end_int + + # Count the number of leading zeros in the XOR result to find the prefix length + prefix_length = 32 - xor.bit_length() + + # Calculate the network address + network_int = start_int & (0xFFFFFFFF << (32 - prefix_length)) + network_address = ipaddress.IPv4Address(network_int) + + return f"{network_address}/{prefix_length}" + return self.ip_prefix + + +def _delete_conntrack_entries(source_prefixes: list) -> None: + """Delete all conntrack entries for the list of prefixes""" + for source_prefix in source_prefixes: + run(f'conntrack -D -s {source_prefix}') + def generate_port_rules( external_hosts: list, @@ -188,6 +224,9 @@ def get_config(config=None): with_recursive_defaults=True, ) + if conf.exists(base) and is_node_changed(conf, base + ['pool']): + config.update({'delete_conntrack_entries': {}}) + return config @@ -386,6 +425,18 @@ def apply(config): # Log error message logger.error(f"Error processing line '{allocation}': {e}") + # Delete conntrack entries + if 'delete_conntrack_entries' in config: + internal_pool_prefix_list = [] + for rule, rule_config in config['rule'].items(): + internal_pool = rule_config['source']['pool'] + internal_ip_ranges: list = config['pool']['internal'][internal_pool]['range'] + for internal_range in internal_ip_ranges: + ip_prefix = IPOperations(internal_range).get_prefix_by_ip_range() + internal_pool_prefix_list.append(ip_prefix) + # Deleta required sources for conntrack + _delete_conntrack_entries(internal_pool_prefix_list) + if __name__ == '__main__': try: diff --git a/src/conf_mode/system_option.py b/src/conf_mode/system_option.py index a2e5db575..2c31703e9 100755 --- a/src/conf_mode/system_option.py +++ b/src/conf_mode/system_option.py @@ -35,6 +35,7 @@ airbag.enable() curlrc_config = r'/etc/curlrc' ssh_config = r'/etc/ssh/ssh_config.d/91-vyos-ssh-client-options.conf' systemd_action_file = '/lib/systemd/system/ctrl-alt-del.target' +usb_autosuspend = r'/etc/udev/rules.d/40-usb-autosuspend.rules' time_format_to_locale = { '12-hour': 'en_US.UTF-8', '24-hour': 'en_GB.UTF-8' @@ -85,6 +86,7 @@ def verify(options): def generate(options): render(curlrc_config, 'system/curlrc.j2', options) render(ssh_config, 'system/ssh_config.j2', options) + render(usb_autosuspend, 'system/40_usb_autosuspend.j2', options) cmdline_options = [] if 'kernel' in options: @@ -155,6 +157,9 @@ def apply(options): time_format = time_format_to_locale.get(options['time_format']) cmd(f'localectl set-locale LC_TIME={time_format}') + cmd('udevadm control --reload-rules') + + if __name__ == '__main__': try: c = get_config() diff --git a/src/conf_mode/vpn_openconnect.py b/src/conf_mode/vpn_openconnect.py index 8159fedea..42785134f 100755 --- a/src/conf_mode/vpn_openconnect.py +++ b/src/conf_mode/vpn_openconnect.py @@ -21,14 +21,17 @@ from vyos.base import Warning from vyos.config import Config from vyos.configverify import verify_pki_certificate from vyos.configverify import verify_pki_ca_certificate -from vyos.pki import wrap_certificate +from vyos.pki import find_chain +from vyos.pki import encode_certificate +from vyos.pki import load_certificate from vyos.pki import wrap_private_key from vyos.template import render -from vyos.utils.process import call +from vyos.utils.dict import dict_search +from vyos.utils.file import write_file from vyos.utils.network import check_port_availability -from vyos.utils.process import is_systemd_service_running from vyos.utils.network import is_listen_port_bind_service -from vyos.utils.dict import dict_search +from vyos.utils.process import call +from vyos.utils.process import is_systemd_service_running from vyos import ConfigError from passlib.hash import sha512_crypt from time import sleep @@ -142,7 +145,8 @@ def verify(ocserv): verify_pki_certificate(ocserv, ocserv['ssl']['certificate']) if 'ca_certificate' in ocserv['ssl']: - verify_pki_ca_certificate(ocserv, ocserv['ssl']['ca_certificate']) + for ca_cert in ocserv['ssl']['ca_certificate']: + verify_pki_ca_certificate(ocserv, ca_cert) # Check network settings if "network_settings" in ocserv: @@ -219,25 +223,36 @@ def generate(ocserv): if "ssl" in ocserv: cert_file_path = os.path.join(cfg_dir, 'cert.pem') cert_key_path = os.path.join(cfg_dir, 'cert.key') - ca_cert_file_path = os.path.join(cfg_dir, 'ca.pem') + if 'certificate' in ocserv['ssl']: cert_name = ocserv['ssl']['certificate'] pki_cert = ocserv['pki']['certificate'][cert_name] - with open(cert_file_path, 'w') as f: - f.write(wrap_certificate(pki_cert['certificate'])) + loaded_pki_cert = load_certificate(pki_cert['certificate']) + loaded_ca_certs = {load_certificate(c['certificate']) + for c in ocserv['pki']['ca'].values()} if 'ca' in ocserv['pki'] else {} + + cert_full_chain = find_chain(loaded_pki_cert, loaded_ca_certs) + + write_file(cert_file_path, + '\n'.join(encode_certificate(c) for c in cert_full_chain)) if 'private' in pki_cert and 'key' in pki_cert['private']: - with open(cert_key_path, 'w') as f: - f.write(wrap_private_key(pki_cert['private']['key'])) + write_file(cert_key_path, wrap_private_key(pki_cert['private']['key'])) if 'ca_certificate' in ocserv['ssl']: - ca_name = ocserv['ssl']['ca_certificate'] - pki_ca_cert = ocserv['pki']['ca'][ca_name] + ca_cert_file_path = os.path.join(cfg_dir, 'ca.pem') + ca_chains = [] + + for ca_name in ocserv['ssl']['ca_certificate']: + pki_ca_cert = ocserv['pki']['ca'][ca_name] + loaded_ca_cert = load_certificate(pki_ca_cert['certificate']) + ca_full_chain = find_chain(loaded_ca_cert, loaded_ca_certs) + ca_chains.append( + '\n'.join(encode_certificate(c) for c in ca_full_chain)) - with open(ca_cert_file_path, 'w') as f: - f.write(wrap_certificate(pki_ca_cert['certificate'])) + write_file(ca_cert_file_path, '\n'.join(ca_chains)) # Render config render(ocserv_conf, 'ocserv/ocserv_config.j2', ocserv) |