summaryrefslogtreecommitdiff
path: root/src/conf_mode/nat_cgnat.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/conf_mode/nat_cgnat.py')
-rw-r--r--src/conf_mode/nat_cgnat.py475
1 files changed, 475 insertions, 0 deletions
diff --git a/src/conf_mode/nat_cgnat.py b/src/conf_mode/nat_cgnat.py
new file mode 100644
index 0000000..3484e58
--- /dev/null
+++ b/src/conf_mode/nat_cgnat.py
@@ -0,0 +1,475 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2024 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import ipaddress
+import jmespath
+import logging
+import os
+
+from sys import exit
+from logging.handlers import SysLogHandler
+
+from vyos.config import Config
+from vyos.configdict import is_node_changed
+from vyos.template import render
+from vyos.utils.process import cmd
+from vyos.utils.process import run
+from vyos import ConfigError
+from vyos import airbag
+
+airbag.enable()
+
+
+nftables_cgnat_config = '/run/nftables-cgnat.nft'
+
+# Logging
+logger = logging.getLogger('cgnat')
+logger.setLevel(logging.DEBUG)
+
+syslog_handler = SysLogHandler(address="/dev/log")
+syslog_handler.setLevel(logging.INFO)
+
+formatter = logging.Formatter('%(name)s: %(message)s')
+syslog_handler.setFormatter(formatter)
+
+logger.addHandler(syslog_handler)
+
+
+class IPOperations:
+ def __init__(self, ip_prefix: str):
+ self.ip_prefix = ip_prefix
+ self.ip_network = ipaddress.ip_network(ip_prefix) if '/' in ip_prefix else None
+
+ def get_ips_count(self) -> int:
+ """Returns the number of IPs in a prefix or range.
+
+ Example:
+ % ip = IPOperations('192.0.2.0/30')
+ % ip.get_ips_count()
+ 4
+ % ip = IPOperations('192.0.2.0-192.0.2.2')
+ % ip.get_ips_count()
+ 3
+ """
+ if '-' in self.ip_prefix:
+ start_ip, end_ip = self.ip_prefix.split('-')
+ start_ip = ipaddress.ip_address(start_ip)
+ end_ip = ipaddress.ip_address(end_ip)
+ return int(end_ip) - int(start_ip) + 1
+ elif '/31' in self.ip_prefix:
+ return 2
+ elif '/32' in self.ip_prefix:
+ return 1
+ else:
+ return sum(
+ 1
+ for _ in [self.ip_network.network_address]
+ + list(self.ip_network.hosts())
+ + [self.ip_network.broadcast_address]
+ )
+
+ def convert_prefix_to_list_ips(self) -> list:
+ """Converts a prefix or IP range to a list of IPs including the network and broadcast addresses.
+
+ Example:
+ % ip = IPOperations('192.0.2.0/30')
+ % ip.convert_prefix_to_list_ips()
+ ['192.0.2.0', '192.0.2.1', '192.0.2.2', '192.0.2.3']
+ %
+ % ip = IPOperations('192.0.0.1-192.0.2.5')
+ % ip.convert_prefix_to_list_ips()
+ ['192.0.2.1', '192.0.2.2', '192.0.2.3', '192.0.2.4', '192.0.2.5']
+ """
+ if '-' in self.ip_prefix:
+ start_ip, end_ip = self.ip_prefix.split('-')
+ start_ip = ipaddress.ip_address(start_ip)
+ end_ip = ipaddress.ip_address(end_ip)
+ return [
+ str(ipaddress.ip_address(ip))
+ for ip in range(int(start_ip), int(end_ip) + 1)
+ ]
+ elif '/31' in self.ip_prefix:
+ return [
+ str(ip)
+ for ip in [
+ self.ip_network.network_address,
+ self.ip_network.broadcast_address,
+ ]
+ ]
+ elif '/32' in self.ip_prefix:
+ return [str(self.ip_network.network_address)]
+ else:
+ return [
+ str(ip)
+ for ip in [self.ip_network.network_address]
+ + list(self.ip_network.hosts())
+ + [self.ip_network.broadcast_address]
+ ]
+
+ def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]:
+ """Return the common prefix for the address range
+
+ Example:
+ % ip = IPOperations('100.64.0.1-100.64.0.5')
+ % ip.get_prefix_by_ip_range()
+ [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')]
+ """
+ # We do not need to convert the IP range to network
+ # if it is already in network format
+ if self.ip_network:
+ return [self.ip_network]
+
+ # Raise an error if the IP range is not in the correct format
+ if '-' not in self.ip_prefix:
+ raise ValueError(
+ 'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.'
+ )
+ # Split the IP range and convert it to IP address objects
+ range_start, range_end = self.ip_prefix.split('-')
+ range_start = ipaddress.IPv4Address(range_start)
+ range_end = ipaddress.IPv4Address(range_end)
+
+ # Return the summarized IP networks list
+ return list(ipaddress.summarize_address_range(range_start, range_end))
+
+
+def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None:
+ """Delete all conntrack entries for the list of prefixes"""
+ for source_prefix in source_prefixes:
+ run(f'conntrack -D -s {source_prefix}')
+
+
+def generate_port_rules(
+ external_hosts: list,
+ internal_hosts: list,
+ port_count: int,
+ global_port_range: str = '1024-65535',
+) -> list:
+ """Generates a list of nftables option rules for the batch file.
+
+ Args:
+ external_hosts (list): A list of external host IPs.
+ internal_hosts (list): A list of internal host IPs.
+ port_count (int): The number of ports required per host.
+ global_port_range (str): The global port range to be used. Default is '1024-65535'.
+
+ Returns:
+ list: A list containing two elements:
+ - proto_map_elements (list): A list of proto map elements.
+ - other_map_elements (list): A list of other map elements.
+ """
+ rules = []
+ proto_map_elements = []
+ other_map_elements = []
+ start_port, end_port = map(int, global_port_range.split('-'))
+ total_possible_ports = (end_port - start_port) + 1
+
+ # Calculate the required number of ports per host
+ required_ports_per_host = port_count
+ current_port = start_port
+ current_external_index = 0
+
+ for internal_host in internal_hosts:
+ external_host = external_hosts[current_external_index]
+ next_end_port = current_port + required_ports_per_host - 1
+
+ # If the port range exceeds the end_port, move to the next external host
+ while next_end_port > end_port:
+ current_external_index = (current_external_index + 1) % len(external_hosts)
+ external_host = external_hosts[current_external_index]
+ current_port = start_port
+ next_end_port = current_port + required_ports_per_host - 1
+
+ proto_map_elements.append(
+ f'{internal_host} : {external_host} . {current_port}-{next_end_port}'
+ )
+ other_map_elements.append(f'{internal_host} : {external_host}')
+
+ current_port = next_end_port + 1
+ if current_port > end_port:
+ current_port = start_port
+ current_external_index += 1 # Move to the next external host
+
+ return [proto_map_elements, other_map_elements]
+
+
+def get_config(config=None):
+ if config:
+ conf = config
+ else:
+ conf = Config()
+
+ base = ['nat', 'cgnat']
+ config = conf.get_config_dict(
+ base,
+ get_first_key=True,
+ key_mangling=('-', '_'),
+ no_tag_node_value_mangle=True,
+ with_recursive_defaults=True,
+ )
+
+ effective_config = conf.get_config_dict(
+ base,
+ get_first_key=True,
+ key_mangling=('-', '_'),
+ no_tag_node_value_mangle=True,
+ effective=True,
+ )
+
+ # Check if the pool configuration has changed
+ if not conf.exists(base) or is_node_changed(conf, base + ['pool']):
+ config['delete_conntrack_entries'] = {}
+
+ # add running config
+ if effective_config:
+ config['effective'] = effective_config
+
+ if not conf.exists(base):
+ config['deleted'] = {}
+
+ return config
+
+
+def verify(config):
+ # bail out early - looks like removal from running config
+ if 'deleted' in config:
+ return None
+
+ if 'pool' not in config:
+ raise ConfigError(f'Pool must be defined!')
+ if 'rule' not in config:
+ raise ConfigError(f'Rule must be defined!')
+
+ for pool in ('external', 'internal'):
+ if pool not in config['pool']:
+ raise ConfigError(f'{pool} pool must be defined!')
+ for pool_name, pool_config in config['pool'][pool].items():
+ if 'range' not in pool_config:
+ raise ConfigError(
+ f'Range for "{pool} pool {pool_name}" must be defined!'
+ )
+
+ external_pools_query = "keys(pool.external)"
+ external_pools: list = jmespath.search(external_pools_query, config)
+ internal_pools_query = "keys(pool.internal)"
+ internal_pools: list = jmespath.search(internal_pools_query, config)
+
+ used_external_pools = {}
+ used_internal_pools = {}
+ for rule, rule_config in config['rule'].items():
+ if 'source' not in rule_config:
+ raise ConfigError(f'Rule "{rule}" source pool must be defined!')
+ if 'pool' not in rule_config['source']:
+ raise ConfigError(f'Rule "{rule}" source pool must be defined!')
+
+ if 'translation' not in rule_config:
+ raise ConfigError(f'Rule "{rule}" translation pool must be defined!')
+
+ # Check if pool exists
+ internal_pool = rule_config['source']['pool']
+ if internal_pool not in internal_pools:
+ raise ConfigError(f'Internal pool "{internal_pool}" does not exist!')
+ external_pool = rule_config['translation']['pool']
+ if external_pool not in external_pools:
+ raise ConfigError(f'External pool "{external_pool}" does not exist!')
+
+ # Check pool duplication in different rules
+ if external_pool in used_external_pools:
+ raise ConfigError(
+ f'External pool "{external_pool}" is already used in rule '
+ f'{used_external_pools[external_pool]} and cannot be used in '
+ f'rule {rule}!'
+ )
+
+ if internal_pool in used_internal_pools:
+ raise ConfigError(
+ f'Internal pool "{internal_pool}" is already used in rule '
+ f'{used_internal_pools[internal_pool]} and cannot be used in '
+ f'rule {rule}!'
+ )
+
+ used_external_pools[external_pool] = rule
+ used_internal_pools[internal_pool] = rule
+
+ # Check calculation for allocation
+ external_port_range: str = config['pool']['external'][external_pool]['external_port_range']
+
+ external_ip_ranges: list = list(
+ config['pool']['external'][external_pool]['range']
+ )
+ internal_ip_ranges: list = config['pool']['internal'][internal_pool]['range']
+ start_port, end_port = map(int, external_port_range.split('-'))
+ ports_per_range_count: int = (end_port - start_port) + 1
+
+ external_list_hosts_count = []
+ external_list_hosts = []
+ internal_list_hosts_count = []
+ internal_list_hosts = []
+ for ext_range in external_ip_ranges:
+ # External hosts count
+ e_count = IPOperations(ext_range).get_ips_count()
+ external_list_hosts_count.append(e_count)
+ # External hosts list
+ e_hosts = IPOperations(ext_range).convert_prefix_to_list_ips()
+ external_list_hosts.extend(e_hosts)
+ for int_range in internal_ip_ranges:
+ # Internal hosts count
+ i_count = IPOperations(int_range).get_ips_count()
+ internal_list_hosts_count.append(i_count)
+ # Internal hosts list
+ i_hosts = IPOperations(int_range).convert_prefix_to_list_ips()
+ internal_list_hosts.extend(i_hosts)
+
+ external_host_count = sum(external_list_hosts_count)
+ internal_host_count = sum(internal_list_hosts_count)
+ ports_per_user: int = int(
+ config['pool']['external'][external_pool]['per_user_limit']['port']
+ )
+ users_per_extip = ports_per_range_count // ports_per_user
+ max_users = users_per_extip * external_host_count
+
+ if internal_host_count > max_users:
+ raise ConfigError(
+ f'Rule "{rule}" does not have enough ports available for the '
+ f'specified parameters'
+ )
+
+
+def generate(config):
+ if 'deleted' in config:
+ return None
+
+ proto_maps = []
+ other_maps = []
+
+ for rule, rule_config in config['rule'].items():
+ ext_pool_name: str = rule_config['translation']['pool']
+ int_pool_name: str = rule_config['source']['pool']
+
+ # Sort the external ranges by sequence
+ external_ranges: list = sorted(
+ config['pool']['external'][ext_pool_name]['range'],
+ key=lambda r: int(config['pool']['external'][ext_pool_name]['range'][r].get('seq', 999999))
+ )
+ internal_ranges: list = [range for range in config['pool']['internal'][int_pool_name]['range']]
+ external_list_hosts_count = []
+ external_list_hosts = []
+ internal_list_hosts_count = []
+ internal_list_hosts = []
+
+ for ext_range in external_ranges:
+ # External hosts count
+ e_count = IPOperations(ext_range).get_ips_count()
+ external_list_hosts_count.append(e_count)
+ # External hosts list
+ e_hosts = IPOperations(ext_range).convert_prefix_to_list_ips()
+ external_list_hosts.extend(e_hosts)
+
+ for int_range in internal_ranges:
+ # Internal hosts count
+ i_count = IPOperations(int_range).get_ips_count()
+ internal_list_hosts_count.append(i_count)
+ # Internal hosts list
+ i_hosts = IPOperations(int_range).convert_prefix_to_list_ips()
+ internal_list_hosts.extend(i_hosts)
+
+ external_host_count = sum(external_list_hosts_count)
+ internal_host_count = sum(internal_list_hosts_count)
+ ports_per_user = int(
+ jmespath.search(f'pool.external."{ext_pool_name}".per_user_limit.port', config)
+ )
+ external_port_range: str = jmespath.search(
+ f'pool.external."{ext_pool_name}".external_port_range', config
+ )
+
+ rule_proto_maps, rule_other_maps = generate_port_rules(
+ external_list_hosts, internal_list_hosts, ports_per_user, external_port_range
+ )
+
+ proto_maps.extend(rule_proto_maps)
+ other_maps.extend(rule_other_maps)
+
+ config['proto_map_elements'] = ', '.join(proto_maps)
+ config['other_map_elements'] = ', '.join(other_maps)
+
+ render(nftables_cgnat_config, 'firewall/nftables-cgnat.j2', config)
+
+ # dry-run newly generated configuration
+ tmp = run(f'nft --check --file {nftables_cgnat_config}')
+ if tmp > 0:
+ raise ConfigError('Configuration file errors encountered!')
+
+
+def apply(config):
+ if 'deleted' in config:
+ # Cleanup cgnat
+ cmd('nft delete table ip cgnat')
+ if os.path.isfile(nftables_cgnat_config):
+ os.unlink(nftables_cgnat_config)
+ else:
+ cmd(f'nft --file {nftables_cgnat_config}')
+
+ # Delete conntrack entries
+ # if the pool configuration has changed
+ if 'delete_conntrack_entries' in config and 'effective' in config:
+ # Prepare the list of internal pool prefixes
+ internal_pool_prefix_list: list[ipaddress.IPv4Network] = []
+
+ # Get effective rules configurations
+ for rule_config in config['effective'].get('rule', {}).values():
+ # Get effective internal pool configuration
+ internal_pool = rule_config['source']['pool']
+ # Find the internal IP ranges for the internal pool
+ internal_ip_ranges: list[str] = config['effective']['pool']['internal'][
+ internal_pool
+ ]['range']
+ # Get the IP prefixes for the internal IP range
+ for internal_range in internal_ip_ranges:
+ ip_prefix: list[ipaddress.IPv4Network] = IPOperations(
+ internal_range
+ ).get_prefix_by_ip_range()
+ # Add the IP prefixes to the list of all internal pool prefixes
+ internal_pool_prefix_list += ip_prefix
+
+ # Delete required sources for conntrack
+ _delete_conntrack_entries(internal_pool_prefix_list)
+
+ # Logging allocations
+ if 'log_allocation' in config:
+ allocations = config['proto_map_elements']
+ allocations = allocations.split(',')
+ for allocation in allocations:
+ try:
+ # Split based on the delimiters used in the nft data format
+ internal_host, rest = allocation.split(' : ')
+ external_host, port_range = rest.split(' . ')
+ # Log the parsed data
+ logger.info(
+ f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}')
+ except ValueError as e:
+ # Log error message
+ logger.error(f"Error processing line '{allocation}': {e}")
+
+
+if __name__ == '__main__':
+ try:
+ c = get_config()
+ verify(c)
+ generate(c)
+ apply(c)
+ except ConfigError as e:
+ print(e)
+ exit(1)