From 804efa2ef6bfee84d13f633d863f6f22f9eec273 Mon Sep 17 00:00:00 2001
From: Viacheslav Hletenko <v.gletenko@vyos.io>
Date: Thu, 27 Jun 2024 22:09:19 +0300
Subject: T6497: CGNAT refactoring delete conntrack entries (#3699)

---
 src/conf_mode/nat_cgnat.py | 110 ++++++++++++++++++++++++++++-----------------
 1 file changed, 68 insertions(+), 42 deletions(-)

(limited to 'src')

diff --git a/src/conf_mode/nat_cgnat.py b/src/conf_mode/nat_cgnat.py
index 34ec64fce..3484e5873 100755
--- a/src/conf_mode/nat_cgnat.py
+++ b/src/conf_mode/nat_cgnat.py
@@ -119,37 +119,34 @@ class IPOperations:
                 + [self.ip_network.broadcast_address]
             ]
 
-    def get_prefix_by_ip_range(self):
+    def get_prefix_by_ip_range(self) -> list[ipaddress.IPv4Network]:
         """Return the common prefix for the address range
 
         Example:
             % ip = IPOperations('100.64.0.1-100.64.0.5')
             % ip.get_prefix_by_ip_range()
-            100.64.0.0/29
+            [IPv4Network('100.64.0.1/32'), IPv4Network('100.64.0.2/31'), IPv4Network('100.64.0.4/31')]
         """
-        if '-' in self.ip_prefix:
-            ip_start, ip_end = self.ip_prefix.split('-')
-            start_ip = ipaddress.IPv4Address(ip_start.strip())
-            end_ip = ipaddress.IPv4Address(ip_end.strip())
-
-            start_int = int(start_ip)
-            end_int = int(end_ip)
-
-            # XOR to find differing bits
-            xor = start_int ^ end_int
-
-            # Count the number of leading zeros in the XOR result to find the prefix length
-            prefix_length = 32 - xor.bit_length()
-
-            # Calculate the network address
-            network_int = start_int & (0xFFFFFFFF << (32 - prefix_length))
-            network_address = ipaddress.IPv4Address(network_int)
+        # We do not need to convert the IP range to network
+        # if it is already in network format
+        if self.ip_network:
+            return [self.ip_network]
+
+        # Raise an error if the IP range is not in the correct format
+        if '-' not in self.ip_prefix:
+            raise ValueError(
+                'Invalid IP range format. Please provide the IP range in CIDR format or with "-" separator.'
+            )
+        # Split the IP range and convert it to IP address objects
+        range_start, range_end = self.ip_prefix.split('-')
+        range_start = ipaddress.IPv4Address(range_start)
+        range_end = ipaddress.IPv4Address(range_end)
 
-            return f"{network_address}/{prefix_length}"
-        return self.ip_prefix
+        # Return the summarized IP networks list
+        return list(ipaddress.summarize_address_range(range_start, range_end))
 
 
-def _delete_conntrack_entries(source_prefixes: list) -> None:
+def _delete_conntrack_entries(source_prefixes: list[ipaddress.IPv4Network]) -> None:
     """Delete all conntrack entries for the list of prefixes"""
     for source_prefix in source_prefixes:
         run(f'conntrack -D -s {source_prefix}')
@@ -224,15 +221,31 @@ def get_config(config=None):
         with_recursive_defaults=True,
     )
 
-    if conf.exists(base) and is_node_changed(conf, base + ['pool']):
-        config.update({'delete_conntrack_entries': {}})
+    effective_config = conf.get_config_dict(
+        base,
+        get_first_key=True,
+        key_mangling=('-', '_'),
+        no_tag_node_value_mangle=True,
+        effective=True,
+    )
+
+    # Check if the pool configuration has changed
+    if not conf.exists(base) or is_node_changed(conf, base + ['pool']):
+        config['delete_conntrack_entries'] = {}
+
+    # add running config
+    if effective_config:
+        config['effective'] = effective_config
+
+    if not conf.exists(base):
+        config['deleted'] = {}
 
     return config
 
 
 def verify(config):
     # bail out early - looks like removal from running config
-    if not config:
+    if 'deleted' in config:
         return None
 
     if 'pool' not in config:
@@ -336,7 +349,7 @@ def verify(config):
 
 
 def generate(config):
-    if not config:
+    if 'deleted' in config:
         return None
 
     proto_maps = []
@@ -401,13 +414,38 @@ def generate(config):
 
 
 def apply(config):
-    if not config:
+    if 'deleted' in config:
         # Cleanup cgnat
         cmd('nft delete table ip cgnat')
         if os.path.isfile(nftables_cgnat_config):
             os.unlink(nftables_cgnat_config)
-        return None
-    cmd(f'nft --file {nftables_cgnat_config}')
+    else:
+        cmd(f'nft --file {nftables_cgnat_config}')
+
+    # Delete conntrack entries
+    # if the pool configuration has changed
+    if 'delete_conntrack_entries' in config and 'effective' in config:
+        # Prepare the list of internal pool prefixes
+        internal_pool_prefix_list: list[ipaddress.IPv4Network] = []
+
+        # Get effective rules configurations
+        for rule_config in config['effective'].get('rule', {}).values():
+            # Get effective internal pool configuration
+            internal_pool = rule_config['source']['pool']
+            # Find the internal IP ranges for the internal pool
+            internal_ip_ranges: list[str] = config['effective']['pool']['internal'][
+                internal_pool
+            ]['range']
+            # Get the IP prefixes for the internal IP range
+            for internal_range in internal_ip_ranges:
+                ip_prefix: list[ipaddress.IPv4Network] = IPOperations(
+                    internal_range
+                ).get_prefix_by_ip_range()
+                # Add the IP prefixes to the list of all internal pool prefixes
+                internal_pool_prefix_list += ip_prefix
+
+        # Delete required sources for conntrack
+        _delete_conntrack_entries(internal_pool_prefix_list)
 
     # Logging allocations
     if 'log_allocation' in config:
@@ -420,23 +458,11 @@ def apply(config):
                 external_host, port_range = rest.split(' . ')
                 # Log the parsed data
                 logger.info(
-                    f"Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}")
+                    f'Internal host: {internal_host.lstrip()}, external host: {external_host}, Port range: {port_range}')
             except ValueError as e:
                 # Log error message
                 logger.error(f"Error processing line '{allocation}': {e}")
 
-    # Delete conntrack entries
-    if 'delete_conntrack_entries' in config:
-        internal_pool_prefix_list = []
-        for rule, rule_config in config['rule'].items():
-            internal_pool = rule_config['source']['pool']
-            internal_ip_ranges: list = config['pool']['internal'][internal_pool]['range']
-            for internal_range in internal_ip_ranges:
-                ip_prefix = IPOperations(internal_range).get_prefix_by_ip_range()
-                internal_pool_prefix_list.append(ip_prefix)
-        # Deleta required sources for conntrack
-        _delete_conntrack_entries(internal_pool_prefix_list)
-
 
 if __name__ == '__main__':
     try:
-- 
cgit v1.2.3