diff options
-rwxr-xr-x | src/conf_mode/nat_cgnat.py | 89 |
1 files changed, 83 insertions, 6 deletions
diff --git a/src/conf_mode/nat_cgnat.py b/src/conf_mode/nat_cgnat.py index cb336a35c..3484e5873 100755 --- a/src/conf_mode/nat_cgnat.py +++ b/src/conf_mode/nat_cgnat.py @@ -23,6 +23,7 @@ from sys import exit from logging.handlers import SysLogHandler from vyos.config import Config +from vyos.configdict import is_node_changed from vyos.template import render from vyos.utils.process import cmd from vyos.utils.process import run @@ -118,6 +119,38 @@ class IPOperations: + [self.ip_network.broadcast_address] ] + def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]: + """Return the common prefix for the address range + + Example: + % ip = IPOperations('100.64.0.1-100.64.0.5') + % ip.get_prefix_by_ip_range() + [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')] + """ + # We do not need to convert the IP range to network + # if it is already in network format + if self.ip_network: + return [self.ip_network] + + # Raise an error if the IP range is not in the correct format + if '-' not in self.ip_prefix: + raise ValueError( + 'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.' + ) + # Split the IP range and convert it to IP address objects + range_start, range_end = self.ip_prefix.split('-') + range_start = ipaddress.IPv4Address(range_start) + range_end = ipaddress.IPv4Address(range_end) + + # Return the summarized IP networks list + return list(ipaddress.summarize_address_range(range_start, range_end)) + + +def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None: + """Delete all conntrack entries for the list of prefixes""" + for source_prefix in source_prefixes: + run(f'conntrack -D -s {source_prefix}') + def generate_port_rules( external_hosts: list, @@ -188,12 +221,31 @@ def get_config(config=None): with_recursive_defaults=True, ) + effective_config = conf.get_config_dict( + base, + get_first_key=True, + key_mangling=('-', '_'), + no_tag_node_value_mangle=True, + effective=True, + ) + + # Check if the pool configuration has changed + if not conf.exists(base) or is_node_changed(conf, base + ['pool']): + config['delete_conntrack_entries'] = {} + + # add running config + if effective_config: + config['effective'] = effective_config + + if not conf.exists(base): + config['deleted'] = {} + return config def verify(config): # bail out early - looks like removal from running config - if not config: + if 'deleted' in config: return None if 'pool' not in config: @@ -297,7 +349,7 @@ def verify(config): def generate(config): - if not config: + if 'deleted' in config: return None proto_maps = [] @@ -362,13 +414,38 @@ def generate(config): def apply(config): - if not config: + if 'deleted' in config: # Cleanup cgnat cmd('nft delete table ip cgnat') if os.path.isfile(nftables_cgnat_config): os.unlink(nftables_cgnat_config) - return None - cmd(f'nft --file {nftables_cgnat_config}') + else: + cmd(f'nft --file {nftables_cgnat_config}') + + # Delete conntrack entries + # if the pool configuration has changed + if 'delete_conntrack_entries' in config and 'effective' in config: + # Prepare the list of internal pool prefixes + internal_pool_prefix_list: list[ipaddress.IPv4Network] = [] + + # Get effective rules configurations + for rule_config in config['effective'].get('rule', {}).values(): + # Get effective internal pool configuration + internal_pool = rule_config['source']['pool'] + # Find the internal IP ranges for the internal pool + internal_ip_ranges: list[str] = config['effective']['pool']['internal'][ + internal_pool + ]['range'] + # Get the IP prefixes for the internal IP range + for internal_range in internal_ip_ranges: + ip_prefix: list[ipaddress.IPv4Network] = IPOperations( + internal_range + ).get_prefix_by_ip_range() + # Add the IP prefixes to the list of all internal pool prefixes + internal_pool_prefix_list += ip_prefix + + # Delete required sources for conntrack + _delete_conntrack_entries(internal_pool_prefix_list) # Logging allocations if 'log_allocation' in config: @@ -381,7 +458,7 @@ def apply(config): external_host, port_range = rest.split(' . ') # Log the parsed data logger.info( - f"Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}") + f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}') except ValueError as e: # Log error message logger.error(f"Error processing line '{allocation}': {e}") |