From 84a83ecc4c78bf2e0954658ea539e42b4c015fa2 Mon Sep 17 00:00:00 2001
From: sarthurdev <965089+sarthurdev@users.noreply.github.com>
Date: Mon, 3 Jan 2022 22:17:08 +0100
Subject: firewall: T4130: Fix firewall state-policy errors

Also fixes:
* Issue with multiple state-policy rules being created on firewall updates
* Prevents interface rules being inserted before state-policy
---
 src/conf_mode/firewall-interface.py | 27 +++++++++++++++++++++++++--
 src/conf_mode/firewall.py           | 28 +++++++++++++++++++++++++---
 2 files changed, 50 insertions(+), 5 deletions(-)

(limited to 'src')

diff --git a/src/conf_mode/firewall-interface.py b/src/conf_mode/firewall-interface.py
index 3a17dc5a4..516fa6c48 100755
--- a/src/conf_mode/firewall-interface.py
+++ b/src/conf_mode/firewall-interface.py
@@ -107,6 +107,15 @@ def cleanup_rule(table, chain, ifname, new_name=None):
                 run(f'nft delete rule {table} {chain} handle {handle_search[1]}')
     return retval
 
+def state_policy_handle(table, chain):
+    results = cmd(f'nft -a list chain {table} {chain}').split("\n")
+    for line in results:
+        if 'jump VYOS_STATE_POLICY' in line:
+            handle_search = re.search('handle (\d+)', line)
+            if handle_search:
+                return handle_search[1]
+    return None
+
 def apply(if_firewall):
     ifname = if_firewall['ifname']
 
@@ -118,18 +127,32 @@ def apply(if_firewall):
         name = dict_search_args(if_firewall, direction, 'name')
         if name:
             rule_exists = cleanup_rule('ip filter', chain, ifname, name)
+            rule_action = 'insert'
+            rule_prefix = ''
 
             if not rule_exists:
-                run(f'nft insert rule ip filter {chain} {if_prefix}ifname {ifname} counter jump {name}')
+                handle = state_policy_handle('ip filter', chain)
+                if handle:
+                    rule_action = 'add'
+                    rule_prefix = f'position {handle}'
+
+                run(f'nft {rule_action} rule ip filter {chain} {rule_prefix} {if_prefix}ifname {ifname} counter jump {name}')
         else:
             cleanup_rule('ip filter', chain, ifname)
 
         ipv6_name = dict_search_args(if_firewall, direction, 'ipv6_name')
         if ipv6_name:
             rule_exists = cleanup_rule('ip6 filter', ipv6_chain, ifname, ipv6_name)
+            rule_action = 'insert'
+            rule_prefix = ''
 
             if not rule_exists:
-                run(f'nft insert rule ip6 filter {ipv6_chain} {if_prefix}ifname {ifname} counter jump {ipv6_name}')
+                handle = state_policy_handle('ip filter', chain)
+                if handle:
+                    rule_action = 'add'
+                    rule_prefix = f'position {handle}'
+
+                run(f'nft {rule_action} rule ip6 filter {ipv6_chain} {rule_prefix} {if_prefix}ifname {ifname} counter jump {ipv6_name}')
         else:
             cleanup_rule('ip6 filter', ipv6_chain, ifname)
 
diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py
index 5ac48c9ba..8e037c679 100755
--- a/src/conf_mode/firewall.py
+++ b/src/conf_mode/firewall.py
@@ -205,13 +205,20 @@ def verify(firewall):
 def cleanup_commands(firewall):
     commands = []
     for table in ['ip filter', 'ip6 filter']:
+        state_chain = 'VYOS_STATE_POLICY' if table == 'ip filter' else 'VYOS_STATE_POLICY6'
         json_str = cmd(f'nft -j list table {table}')
         obj = loads(json_str)
         if 'nftables' not in obj:
             continue
         for item in obj['nftables']:
             if 'chain' in item:
-                if item['chain']['name'] not in preserve_chains:
+                if item['chain']['name'] in ['VYOS_STATE_POLICY', 'VYOS_STATE_POLICY6']:
+                    chain = item['chain']['name']
+                    if 'state_policy' not in firewall:
+                        commands.append(f'delete chain {table} {chain}')
+                    else:
+                        commands.append(f'flush chain {table} {chain}')
+                elif item['chain']['name'] not in preserve_chains:
                     chain = item['chain']['name']
                     if table == 'ip filter' and dict_search_args(firewall, 'name', chain):
                         commands.append(f'flush chain {table} {chain}')
@@ -219,6 +226,14 @@ def cleanup_commands(firewall):
                         commands.append(f'flush chain {table} {chain}')
                     else:
                         commands.append(f'delete chain {table} {chain}')
+            elif 'rule' in item:
+                rule = item['rule']
+                if rule['chain'] in ['VYOS_FW_IN', 'VYOS_FW_OUTPUT', 'VYOS_FW_LOCAL', 'VYOS_FW6_IN', 'VYOS_FW6_OUTPUT', 'VYOS_FW6_LOCAL']:
+                    if 'expr' in rule and any([True for expr in rule['expr'] if dict_search_args(expr, 'jump', 'target') == state_chain]):
+                        if 'state_policy' not in firewall:
+                            chain = rule['chain']
+                            handle = rule['handle']
+                            commands.append(f'delete rule {table} {chain} handle {handle}')
     return commands
 
 def generate(firewall):
@@ -286,6 +301,11 @@ def post_apply_trap(firewall):
 
                 cmd(base_cmd + ' '.join(objects))
 
+def state_policy_rule_exists():
+    # Determine if state policy rules already exist in nft
+    search_str = cmd(f'nft list chain ip filter VYOS_FW_IN')
+    return 'VYOS_STATE_POLICY' in search_str
+
 def apply(firewall):
     if 'first_install' in firewall:
         run('nfct helper add rpc inet tcp')
@@ -296,9 +316,11 @@ def apply(firewall):
     if install_result == 1:
         raise ConfigError('Failed to apply firewall')
 
-    if 'state_policy' in firewall:
-        for chain in ['INPUT', 'OUTPUT', 'FORWARD']:
+    if 'state_policy' in firewall and not state_policy_rule_exists():
+        for chain in ['VYOS_FW_IN', 'VYOS_FW_OUTPUT', 'VYOS_FW_LOCAL']:
             cmd(f'nft insert rule ip filter {chain} jump VYOS_STATE_POLICY')
+
+        for chain in ['VYOS_FW6_IN', 'VYOS_FW6_OUTPUT', 'VYOS_FW6_LOCAL']:
             cmd(f'nft insert rule ip6 filter {chain} jump VYOS_STATE_POLICY6')
 
     apply_sysfs(firewall)
-- 
cgit v1.2.3