From 6f7d1e15665655e37e8ca830e28d9650445c1217 Mon Sep 17 00:00:00 2001
From: sarthurdev <965089+sarthurdev@users.noreply.github.com>
Date: Tue, 27 Feb 2024 21:38:24 +0100
Subject: vrf: conntrack: T6073: Populate VRF zoning chains only while
 conntrack is required

---
 data/config-mode-dependencies/vyos-1x.json |  3 ++-
 data/vyos-firewall-init.conf               |  2 --
 python/vyos/firewall.py                    | 18 ++++++++++++++++++
 smoketest/scripts/cli/test_vrf.py          | 23 +++++++++++++++++++++++
 src/conf_mode/system_conntrack.py          |  4 ++++
 src/conf_mode/vrf.py                       | 18 ++++++++++++++++++
 6 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/data/config-mode-dependencies/vyos-1x.json b/data/config-mode-dependencies/vyos-1x.json
index b0586e0bb..6ab36005b 100644
--- a/data/config-mode-dependencies/vyos-1x.json
+++ b/data/config-mode-dependencies/vyos-1x.json
@@ -1,6 +1,7 @@
 {
     "system_conntrack": {
-        "conntrack_sync": ["service_conntrack-sync"]
+        "conntrack_sync": ["service_conntrack-sync"],
+        "vrf": ["vrf"]
     },
     "firewall": {
         "conntrack": ["system_conntrack"],
diff --git a/data/vyos-firewall-init.conf b/data/vyos-firewall-init.conf
index 5a4e03015..3929edf0b 100644
--- a/data/vyos-firewall-init.conf
+++ b/data/vyos-firewall-init.conf
@@ -65,11 +65,9 @@ table inet vrf_zones {
     # Chain for inbound traffic
     chain vrf_zones_ct_in {
         type filter hook prerouting priority raw; policy accept;
-        counter ct original zone set iifname map @ct_iface_map
     }
     # Chain for locally-generated traffic
     chain vrf_zones_ct_out {
         type filter hook output priority raw; policy accept;
-        counter ct original zone set oifname map @ct_iface_map
     }
 }
diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py
index eee11bd2d..49e095946 100644
--- a/python/vyos/firewall.py
+++ b/python/vyos/firewall.py
@@ -34,6 +34,24 @@ from vyos.utils.process import call
 from vyos.utils.process import cmd
 from vyos.utils.process import run
 
+# Conntrack
+
+def conntrack_required(conf):
+    required_nodes = ['nat', 'nat66', 'load-balancing wan']
+
+    for path in required_nodes:
+        if conf.exists(path):
+            return True
+
+    firewall = conf.get_config_dict(['firewall'], key_mangling=('-', '_'),
+                                    no_tag_node_value_mangle=True, get_first_key=True)
+
+    for rules, path in dict_search_recursive(firewall, 'rule'):
+        if any(('state' in rule_conf or 'connection_status' in rule_conf or 'offload_target' in rule_conf) for rule_conf in rules.values()):
+            return True
+
+    return False
+
 # Domain Resolver
 
 def fqdn_config_parse(firewall):
diff --git a/smoketest/scripts/cli/test_vrf.py b/smoketest/scripts/cli/test_vrf.py
index 438387f2d..c96b8e374 100755
--- a/smoketest/scripts/cli/test_vrf.py
+++ b/smoketest/scripts/cli/test_vrf.py
@@ -529,5 +529,28 @@ class VRFTest(VyOSUnitTestSHIM.TestCase):
             self.assertNotIn(f' no ip nht resolve-via-default', frrconfig)
             self.assertNotIn(f' no ipv6 nht resolve-via-default', frrconfig)
 
+    def test_vrf_conntrack(self):
+        table = '1000'
+        nftables_rules = {
+            'vrf_zones_ct_in': ['ct original zone set iifname map @ct_iface_map'],
+            'vrf_zones_ct_out': ['ct original zone set oifname map @ct_iface_map']
+        }
+
+        self.cli_set(base_path + ['name', 'blue', 'table', table])
+        self.cli_commit()
+
+        # Conntrack rules should not be present
+        for chain, rule in nftables_rules.items():
+            self.verify_nftables_chain(rule, 'inet vrf_zones', chain, inverse=True)
+
+        self.cli_set(['nat'])
+        self.cli_commit()
+
+        # Conntrack rules should now be present
+        for chain, rule in nftables_rules.items():
+            self.verify_nftables_chain(rule, 'inet vrf_zones', chain, inverse=False)
+
+        self.cli_delete(['nat'])
+
 if __name__ == '__main__':
     unittest.main(verbosity=2)
diff --git a/src/conf_mode/system_conntrack.py b/src/conf_mode/system_conntrack.py
index 7f6c71440..e075bc928 100755
--- a/src/conf_mode/system_conntrack.py
+++ b/src/conf_mode/system_conntrack.py
@@ -104,6 +104,10 @@ def get_config(config=None):
     if conf.exists(['service', 'conntrack-sync']):
         set_dependents('conntrack_sync', conf)
 
+    # If conntrack status changes, VRF zone rules need updating
+    if conf.exists(['vrf']):
+        set_dependents('vrf', conf)
+
     return conntrack
 
 def verify(conntrack):
diff --git a/src/conf_mode/vrf.py b/src/conf_mode/vrf.py
index a2f4956be..16908100f 100755
--- a/src/conf_mode/vrf.py
+++ b/src/conf_mode/vrf.py
@@ -23,6 +23,7 @@ from vyos.config import Config
 from vyos.configdict import dict_merge
 from vyos.configdict import node_changed
 from vyos.configverify import verify_route_map
+from vyos.firewall import conntrack_required
 from vyos.ifconfig import Interface
 from vyos.template import render
 from vyos.template import render_to_string
@@ -41,6 +42,12 @@ airbag.enable()
 config_file = '/etc/iproute2/rt_tables.d/vyos-vrf.conf'
 k_mod = ['vrf']
 
+nftables_table = 'inet vrf_zones'
+nftables_rules = {
+    'vrf_zones_ct_in': 'counter ct original zone set iifname map @ct_iface_map',
+    'vrf_zones_ct_out': 'counter ct original zone set oifname map @ct_iface_map'
+}
+
 def has_rule(af : str, priority : int, table : str=None):
     """
     Check if a given ip rule exists
@@ -114,6 +121,9 @@ def get_config(config=None):
         routes = vrf_routing(conf, name)
         if routes: vrf['vrf_remove'][name]['route'] = routes
 
+    if 'name' in vrf:
+        vrf['conntrack'] = conntrack_required(conf)
+
     # We also need the route-map information from the config
     #
     # XXX: one MUST always call this without the key_mangling() option! See
@@ -294,6 +304,14 @@ def apply(vrf):
             nft_add_element = f'add element inet vrf_zones ct_iface_map {{ "{name}" : {table} }}'
             cmd(f'nft {nft_add_element}')
 
+        if vrf['conntrack']:
+            for chain, rule in nftables_rules.items():
+                cmd(f'nft add rule inet vrf_zones {chain} {rule}')
+    
+    if 'name' not in vrf or not vrf['conntrack']:
+        for chain, rule in nftables_rules.items():
+            cmd(f'nft flush chain inet vrf_zones {chain}')
+
     # Apply FRR filters
     zebra_daemon = 'zebra'
     # Save original configuration prior to starting any commit actions
-- 
cgit v1.2.3