diff options
Diffstat (limited to 'src/conf_mode/nat_cgnat.py')
| -rw-r--r-- | src/conf_mode/nat_cgnat.py | 475 |
1 files changed, 475 insertions, 0 deletions
diff --git a/src/conf_mode/nat_cgnat.py b/src/conf_mode/nat_cgnat.py new file mode 100644 index 0000000..3484e58 --- /dev/null +++ b/src/conf_mode/nat_cgnat.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import ipaddress +import jmespath +import logging +import os + +from sys import exit +from logging.handlers import SysLogHandler + +from vyos.config import Config +from vyos.configdict import is_node_changed +from vyos.template import render +from vyos.utils.process import cmd +from vyos.utils.process import run +from vyos import ConfigError +from vyos import airbag + +airbag.enable() + + +nftables_cgnat_config = '/run/nftables-cgnat.nft' + +# Logging +logger = logging.getLogger('cgnat') +logger.setLevel(logging.DEBUG) + +syslog_handler = SysLogHandler(address="/dev/log") +syslog_handler.setLevel(logging.INFO) + +formatter = logging.Formatter('%(name)s: %(message)s') +syslog_handler.setFormatter(formatter) + +logger.addHandler(syslog_handler) + + +class IPOperations: + def __init__(self, ip_prefix: str): + self.ip_prefix = ip_prefix + self.ip_network = ipaddress.ip_network(ip_prefix) if '/' in ip_prefix else None + + def get_ips_count(self) -> int: + """Returns the number of IPs in a prefix or range. + + Example: + % ip = IPOperations('192.0.2.0/30') + % ip.get_ips_count() + 4 + % ip = IPOperations('192.0.2.0-192.0.2.2') + % ip.get_ips_count() + 3 + """ + if '-' in self.ip_prefix: + start_ip, end_ip = self.ip_prefix.split('-') + start_ip = ipaddress.ip_address(start_ip) + end_ip = ipaddress.ip_address(end_ip) + return int(end_ip) - int(start_ip) + 1 + elif '/31' in self.ip_prefix: + return 2 + elif '/32' in self.ip_prefix: + return 1 + else: + return sum( + 1 + for _ in [self.ip_network.network_address] + + list(self.ip_network.hosts()) + + [self.ip_network.broadcast_address] + ) + + def convert_prefix_to_list_ips(self) -> list: + """Converts a prefix or IP range to a list of IPs including the network and broadcast addresses. + + Example: + % ip = IPOperations('192.0.2.0/30') + % ip.convert_prefix_to_list_ips() + ['192.0.2.0', '192.0.2.1', '192.0.2.2', '192.0.2.3'] + % + % ip = IPOperations('192.0.0.1-192.0.2.5') + % ip.convert_prefix_to_list_ips() + ['192.0.2.1', '192.0.2.2', '192.0.2.3', '192.0.2.4', '192.0.2.5'] + """ + if '-' in self.ip_prefix: + start_ip, end_ip = self.ip_prefix.split('-') + start_ip = ipaddress.ip_address(start_ip) + end_ip = ipaddress.ip_address(end_ip) + return [ + str(ipaddress.ip_address(ip)) + for ip in range(int(start_ip), int(end_ip) + 1) + ] + elif '/31' in self.ip_prefix: + return [ + str(ip) + for ip in [ + self.ip_network.network_address, + self.ip_network.broadcast_address, + ] + ] + elif '/32' in self.ip_prefix: + return [str(self.ip_network.network_address)] + else: + return [ + str(ip) + for ip in [self.ip_network.network_address] + + list(self.ip_network.hosts()) + + [self.ip_network.broadcast_address] + ] + + def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]: + """Return the common prefix for the address range + + Example: + % ip = IPOperations('100.64.0.1-100.64.0.5') + % ip.get_prefix_by_ip_range() + [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')] + """ + # We do not need to convert the IP range to network + # if it is already in network format + if self.ip_network: + return [self.ip_network] + + # Raise an error if the IP range is not in the correct format + if '-' not in self.ip_prefix: + raise ValueError( + 'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.' + ) + # Split the IP range and convert it to IP address objects + range_start, range_end = self.ip_prefix.split('-') + range_start = ipaddress.IPv4Address(range_start) + range_end = ipaddress.IPv4Address(range_end) + + # Return the summarized IP networks list + return list(ipaddress.summarize_address_range(range_start, range_end)) + + +def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None: + """Delete all conntrack entries for the list of prefixes""" + for source_prefix in source_prefixes: + run(f'conntrack -D -s {source_prefix}') + + +def generate_port_rules( + external_hosts: list, + internal_hosts: list, + port_count: int, + global_port_range: str = '1024-65535', +) -> list: + """Generates a list of nftables option rules for the batch file. + + Args: + external_hosts (list): A list of external host IPs. + internal_hosts (list): A list of internal host IPs. + port_count (int): The number of ports required per host. + global_port_range (str): The global port range to be used. Default is '1024-65535'. + + Returns: + list: A list containing two elements: + - proto_map_elements (list): A list of proto map elements. + - other_map_elements (list): A list of other map elements. + """ + rules = [] + proto_map_elements = [] + other_map_elements = [] + start_port, end_port = map(int, global_port_range.split('-')) + total_possible_ports = (end_port - start_port) + 1 + + # Calculate the required number of ports per host + required_ports_per_host = port_count + current_port = start_port + current_external_index = 0 + + for internal_host in internal_hosts: + external_host = external_hosts[current_external_index] + next_end_port = current_port + required_ports_per_host - 1 + + # If the port range exceeds the end_port, move to the next external host + while next_end_port > end_port: + current_external_index = (current_external_index + 1) % len(external_hosts) + external_host = external_hosts[current_external_index] + current_port = start_port + next_end_port = current_port + required_ports_per_host - 1 + + proto_map_elements.append( + f'{internal_host} : {external_host} . {current_port}-{next_end_port}' + ) + other_map_elements.append(f'{internal_host} : {external_host}') + + current_port = next_end_port + 1 + if current_port > end_port: + current_port = start_port + current_external_index += 1 # Move to the next external host + + return [proto_map_elements, other_map_elements] + + +def get_config(config=None): + if config: + conf = config + else: + conf = Config() + + base = ['nat', 'cgnat'] + config = conf.get_config_dict( + base, + get_first_key=True, + key_mangling=('-', '_'), + no_tag_node_value_mangle=True, + with_recursive_defaults=True, + ) + + effective_config = conf.get_config_dict( + base, + get_first_key=True, + key_mangling=('-', '_'), + no_tag_node_value_mangle=True, + effective=True, + ) + + # Check if the pool configuration has changed + if not conf.exists(base) or is_node_changed(conf, base + ['pool']): + config['delete_conntrack_entries'] = {} + + # add running config + if effective_config: + config['effective'] = effective_config + + if not conf.exists(base): + config['deleted'] = {} + + return config + + +def verify(config): + # bail out early - looks like removal from running config + if 'deleted' in config: + return None + + if 'pool' not in config: + raise ConfigError(f'Pool must be defined!') + if 'rule' not in config: + raise ConfigError(f'Rule must be defined!') + + for pool in ('external', 'internal'): + if pool not in config['pool']: + raise ConfigError(f'{pool} pool must be defined!') + for pool_name, pool_config in config['pool'][pool].items(): + if 'range' not in pool_config: + raise ConfigError( + f'Range for "{pool} pool {pool_name}" must be defined!' + ) + + external_pools_query = "keys(pool.external)" + external_pools: list = jmespath.search(external_pools_query, config) + internal_pools_query = "keys(pool.internal)" + internal_pools: list = jmespath.search(internal_pools_query, config) + + used_external_pools = {} + used_internal_pools = {} + for rule, rule_config in config['rule'].items(): + if 'source' not in rule_config: + raise ConfigError(f'Rule "{rule}" source pool must be defined!') + if 'pool' not in rule_config['source']: + raise ConfigError(f'Rule "{rule}" source pool must be defined!') + + if 'translation' not in rule_config: + raise ConfigError(f'Rule "{rule}" translation pool must be defined!') + + # Check if pool exists + internal_pool = rule_config['source']['pool'] + if internal_pool not in internal_pools: + raise ConfigError(f'Internal pool "{internal_pool}" does not exist!') + external_pool = rule_config['translation']['pool'] + if external_pool not in external_pools: + raise ConfigError(f'External pool "{external_pool}" does not exist!') + + # Check pool duplication in different rules + if external_pool in used_external_pools: + raise ConfigError( + f'External pool "{external_pool}" is already used in rule ' + f'{used_external_pools[external_pool]} and cannot be used in ' + f'rule {rule}!' + ) + + if internal_pool in used_internal_pools: + raise ConfigError( + f'Internal pool "{internal_pool}" is already used in rule ' + f'{used_internal_pools[internal_pool]} and cannot be used in ' + f'rule {rule}!' + ) + + used_external_pools[external_pool] = rule + used_internal_pools[internal_pool] = rule + + # Check calculation for allocation + external_port_range: str = config['pool']['external'][external_pool]['external_port_range'] + + external_ip_ranges: list = list( + config['pool']['external'][external_pool]['range'] + ) + internal_ip_ranges: list = config['pool']['internal'][internal_pool]['range'] + start_port, end_port = map(int, external_port_range.split('-')) + ports_per_range_count: int = (end_port - start_port) + 1 + + external_list_hosts_count = [] + external_list_hosts = [] + internal_list_hosts_count = [] + internal_list_hosts = [] + for ext_range in external_ip_ranges: + # External hosts count + e_count = IPOperations(ext_range).get_ips_count() + external_list_hosts_count.append(e_count) + # External hosts list + e_hosts = IPOperations(ext_range).convert_prefix_to_list_ips() + external_list_hosts.extend(e_hosts) + for int_range in internal_ip_ranges: + # Internal hosts count + i_count = IPOperations(int_range).get_ips_count() + internal_list_hosts_count.append(i_count) + # Internal hosts list + i_hosts = IPOperations(int_range).convert_prefix_to_list_ips() + internal_list_hosts.extend(i_hosts) + + external_host_count = sum(external_list_hosts_count) + internal_host_count = sum(internal_list_hosts_count) + ports_per_user: int = int( + config['pool']['external'][external_pool]['per_user_limit']['port'] + ) + users_per_extip = ports_per_range_count // ports_per_user + max_users = users_per_extip * external_host_count + + if internal_host_count > max_users: + raise ConfigError( + f'Rule "{rule}" does not have enough ports available for the ' + f'specified parameters' + ) + + +def generate(config): + if 'deleted' in config: + return None + + proto_maps = [] + other_maps = [] + + for rule, rule_config in config['rule'].items(): + ext_pool_name: str = rule_config['translation']['pool'] + int_pool_name: str = rule_config['source']['pool'] + + # Sort the external ranges by sequence + external_ranges: list = sorted( + config['pool']['external'][ext_pool_name]['range'], + key=lambda r: int(config['pool']['external'][ext_pool_name]['range'][r].get('seq', 999999)) + ) + internal_ranges: list = [range for range in config['pool']['internal'][int_pool_name]['range']] + external_list_hosts_count = [] + external_list_hosts = [] + internal_list_hosts_count = [] + internal_list_hosts = [] + + for ext_range in external_ranges: + # External hosts count + e_count = IPOperations(ext_range).get_ips_count() + external_list_hosts_count.append(e_count) + # External hosts list + e_hosts = IPOperations(ext_range).convert_prefix_to_list_ips() + external_list_hosts.extend(e_hosts) + + for int_range in internal_ranges: + # Internal hosts count + i_count = IPOperations(int_range).get_ips_count() + internal_list_hosts_count.append(i_count) + # Internal hosts list + i_hosts = IPOperations(int_range).convert_prefix_to_list_ips() + internal_list_hosts.extend(i_hosts) + + external_host_count = sum(external_list_hosts_count) + internal_host_count = sum(internal_list_hosts_count) + ports_per_user = int( + jmespath.search(f'pool.external."{ext_pool_name}".per_user_limit.port', config) + ) + external_port_range: str = jmespath.search( + f'pool.external."{ext_pool_name}".external_port_range', config + ) + + rule_proto_maps, rule_other_maps = generate_port_rules( + external_list_hosts, internal_list_hosts, ports_per_user, external_port_range + ) + + proto_maps.extend(rule_proto_maps) + other_maps.extend(rule_other_maps) + + config['proto_map_elements'] = ', '.join(proto_maps) + config['other_map_elements'] = ', '.join(other_maps) + + render(nftables_cgnat_config, 'firewall/nftables-cgnat.j2', config) + + # dry-run newly generated configuration + tmp = run(f'nft --check --file {nftables_cgnat_config}') + if tmp > 0: + raise ConfigError('Configuration file errors encountered!') + + +def apply(config): + if 'deleted' in config: + # Cleanup cgnat + cmd('nft delete table ip cgnat') + if os.path.isfile(nftables_cgnat_config): + os.unlink(nftables_cgnat_config) + else: + cmd(f'nft --file {nftables_cgnat_config}') + + # Delete conntrack entries + # if the pool configuration has changed + if 'delete_conntrack_entries' in config and 'effective' in config: + # Prepare the list of internal pool prefixes + internal_pool_prefix_list: list[ipaddress.IPv4Network] = [] + + # Get effective rules configurations + for rule_config in config['effective'].get('rule', {}).values(): + # Get effective internal pool configuration + internal_pool = rule_config['source']['pool'] + # Find the internal IP ranges for the internal pool + internal_ip_ranges: list[str] = config['effective']['pool']['internal'][ + internal_pool + ]['range'] + # Get the IP prefixes for the internal IP range + for internal_range in internal_ip_ranges: + ip_prefix: list[ipaddress.IPv4Network] = IPOperations( + internal_range + ).get_prefix_by_ip_range() + # Add the IP prefixes to the list of all internal pool prefixes + internal_pool_prefix_list += ip_prefix + + # Delete required sources for conntrack + _delete_conntrack_entries(internal_pool_prefix_list) + + # Logging allocations + if 'log_allocation' in config: + allocations = config['proto_map_elements'] + allocations = allocations.split(',') + for allocation in allocations: + try: + # Split based on the delimiters used in the nft data format + internal_host, rest = allocation.split(' : ') + external_host, port_range = rest.split(' . ') + # Log the parsed data + logger.info( + f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}') + except ValueError as e: + # Log error message + logger.error(f"Error processing line '{allocation}': {e}") + + +if __name__ == '__main__': + try: + c = get_config() + verify(c) + generate(c) + apply(c) + except ConfigError as e: + print(e) + exit(1) |
