summaryrefslogtreecommitdiff
path: root/src/conf_mode/nat_cgnat.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/conf_mode/nat_cgnat.py')
-rwxr-xr-xsrc/conf_mode/nat_cgnat.py117
1 files changed, 112 insertions, 5 deletions
diff --git a/src/conf_mode/nat_cgnat.py b/src/conf_mode/nat_cgnat.py
index d429f6e21..3484e5873 100755
--- a/src/conf_mode/nat_cgnat.py
+++ b/src/conf_mode/nat_cgnat.py
@@ -16,11 +16,14 @@
import ipaddress
import jmespath
+import logging
import os
from sys import exit
+from logging.handlers import SysLogHandler
from vyos.config import Config
+from vyos.configdict import is_node_changed
from vyos.template import render
from vyos.utils.process import cmd
from vyos.utils.process import run
@@ -32,6 +35,18 @@ airbag.enable()
nftables_cgnat_config = '/run/nftables-cgnat.nft'
+# Logging
+logger = logging.getLogger('cgnat')
+logger.setLevel(logging.DEBUG)
+
+syslog_handler = SysLogHandler(address="/dev/log")
+syslog_handler.setLevel(logging.INFO)
+
+formatter = logging.Formatter('%(name)s: %(message)s')
+syslog_handler.setFormatter(formatter)
+
+logger.addHandler(syslog_handler)
+
class IPOperations:
def __init__(self, ip_prefix: str):
@@ -104,6 +119,38 @@ class IPOperations:
+ [self.ip_network.broadcast_address]
]
+ def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]:
+ """Return the common prefix for the address range
+
+ Example:
+ % ip = IPOperations('100.64.0.1-100.64.0.5')
+ % ip.get_prefix_by_ip_range()
+ [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')]
+ """
+ # We do not need to convert the IP range to network
+ # if it is already in network format
+ if self.ip_network:
+ return [self.ip_network]
+
+ # Raise an error if the IP range is not in the correct format
+ if '-' not in self.ip_prefix:
+ raise ValueError(
+ 'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.'
+ )
+ # Split the IP range and convert it to IP address objects
+ range_start, range_end = self.ip_prefix.split('-')
+ range_start = ipaddress.IPv4Address(range_start)
+ range_end = ipaddress.IPv4Address(range_end)
+
+ # Return the summarized IP networks list
+ return list(ipaddress.summarize_address_range(range_start, range_end))
+
+
+def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None:
+ """Delete all conntrack entries for the list of prefixes"""
+ for source_prefix in source_prefixes:
+ run(f'conntrack -D -s {source_prefix}')
+
def generate_port_rules(
external_hosts: list,
@@ -174,12 +221,31 @@ def get_config(config=None):
with_recursive_defaults=True,
)
+ effective_config = conf.get_config_dict(
+ base,
+ get_first_key=True,
+ key_mangling=('-', '_'),
+ no_tag_node_value_mangle=True,
+ effective=True,
+ )
+
+ # Check if the pool configuration has changed
+ if not conf.exists(base) or is_node_changed(conf, base + ['pool']):
+ config['delete_conntrack_entries'] = {}
+
+ # add running config
+ if effective_config:
+ config['effective'] = effective_config
+
+ if not conf.exists(base):
+ config['deleted'] = {}
+
return config
def verify(config):
# bail out early - looks like removal from running config
- if not config:
+ if 'deleted' in config:
return None
if 'pool' not in config:
@@ -283,7 +349,7 @@ def verify(config):
def generate(config):
- if not config:
+ if 'deleted' in config:
return None
proto_maps = []
@@ -348,13 +414,54 @@ def generate(config):
def apply(config):
- if not config:
+ if 'deleted' in config:
# Cleanup cgnat
cmd('nft delete table ip cgnat')
if os.path.isfile(nftables_cgnat_config):
os.unlink(nftables_cgnat_config)
- return None
- cmd(f'nft --file {nftables_cgnat_config}')
+ else:
+ cmd(f'nft --file {nftables_cgnat_config}')
+
+ # Delete conntrack entries
+ # if the pool configuration has changed
+ if 'delete_conntrack_entries' in config and 'effective' in config:
+ # Prepare the list of internal pool prefixes
+ internal_pool_prefix_list: list[ipaddress.IPv4Network] = []
+
+ # Get effective rules configurations
+ for rule_config in config['effective'].get('rule', {}).values():
+ # Get effective internal pool configuration
+ internal_pool = rule_config['source']['pool']
+ # Find the internal IP ranges for the internal pool
+ internal_ip_ranges: list[str] = config['effective']['pool']['internal'][
+ internal_pool
+ ]['range']
+ # Get the IP prefixes for the internal IP range
+ for internal_range in internal_ip_ranges:
+ ip_prefix: list[ipaddress.IPv4Network] = IPOperations(
+ internal_range
+ ).get_prefix_by_ip_range()
+ # Add the IP prefixes to the list of all internal pool prefixes
+ internal_pool_prefix_list += ip_prefix
+
+ # Delete required sources for conntrack
+ _delete_conntrack_entries(internal_pool_prefix_list)
+
+ # Logging allocations
+ if 'log_allocation' in config:
+ allocations = config['proto_map_elements']
+ allocations = allocations.split(',')
+ for allocation in allocations:
+ try:
+ # Split based on the delimiters used in the nft data format
+ internal_host, rest = allocation.split(' : ')
+ external_host, port_range = rest.split(' . ')
+ # Log the parsed data
+ logger.info(
+ f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}')
+ except ValueError as e:
+ # Log error message
+ logger.error(f"Error processing line '{allocation}': {e}")
if __name__ == '__main__':