From 342db936a02a02ba04867f932137638485ef0a6f Mon Sep 17 00:00:00 2001
From: Nicolas Fort <nicolasfort1988@gmail.com>
Date: Mon, 29 May 2023 18:48:12 +0000
Subject: T5160: firewall refactor. Update op-mode commands to new syntax.

---
 src/op_mode/firewall.py | 186 ++++++++++++++++++++++--------------------------
 1 file changed, 85 insertions(+), 101 deletions(-)

(limited to 'src')

diff --git a/src/op_mode/firewall.py b/src/op_mode/firewall.py
index 8260bbb77..8eb883f81 100755
--- a/src/op_mode/firewall.py
+++ b/src/op_mode/firewall.py
@@ -24,62 +24,27 @@ from vyos.config import Config
 from vyos.utils.process import cmd
 from vyos.utils.dict import dict_search_args
 
-def get_firewall_interfaces(firewall, name=None, ipv6=False):
-    directions = ['in', 'out', 'local']
-
-    if 'interface' in firewall:
-        for ifname, if_conf in firewall['interface'].items():
-            for direction in directions:
-                if direction not in if_conf:
-                    continue
-
-                fw_conf = if_conf[direction]
-                name_str = f'({ifname},{direction})'
-
-                if 'name' in fw_conf:
-                    fw_name = fw_conf['name']
-
-                    if not name:
-                        firewall['name'][fw_name]['interface'].append(name_str)
-                    elif not ipv6 and name == fw_name:
-                        firewall['interface'].append(name_str)
-
-                if 'ipv6_name' in fw_conf:
-                    fw_name = fw_conf['ipv6_name']
-
-                    if not name:
-                        firewall['ipv6_name'][fw_name]['interface'].append(name_str)
-                    elif ipv6 and name == fw_name:
-                        firewall['interface'].append(name_str)
-
-    return firewall
-
-def get_config_firewall(conf, name=None, ipv6=False, interfaces=True):
+def get_config_firewall(conf, hook=None, priority=None, ipv6=False, interfaces=True):
     config_path = ['firewall']
-    if name:
-        config_path += ['ipv6-name' if ipv6 else 'name', name]
+    if hook:
+        config_path += ['ipv6' if ipv6 else 'ip', hook]
+        if priority:
+            config_path += [priority]
 
     firewall = conf.get_config_dict(config_path, key_mangling=('-', '_'),
                                 get_first_key=True, no_tag_node_value_mangle=True)
-    if firewall and interfaces:
-        if name:
-            firewall['interface'] = {}
-        else:
-            if 'name' in firewall:
-                for fw_name, name_conf in firewall['name'].items():
-                    name_conf['interface'] = []
-
-            if 'ipv6_name' in firewall:
-                for fw_name, name_conf in firewall['ipv6_name'].items():
-                    name_conf['interface'] = []
 
-        get_firewall_interfaces(firewall, name, ipv6)
     return firewall
 
-def get_nftables_details(name, ipv6=False):
+def get_nftables_details(hook, priority, ipv6=False):
     suffix = '6' if ipv6 else ''
     name_prefix = 'NAME6_' if ipv6 else 'NAME_'
-    command = f'sudo nft list chain ip{suffix} vyos_filter {name_prefix}{name}'
+    if hook == 'name' or hook == 'ipv6-name':
+        command = f'sudo nft list chain ip{suffix} vyos_filter {name_prefix}{priority}'
+    else:
+        up_hook = hook.upper()
+        command = f'sudo nft list chain ip{suffix} vyos_filter VYOS_{up_hook}_{priority}'
+
     try:
         results = cmd(command)
     except:
@@ -87,7 +52,7 @@ def get_nftables_details(name, ipv6=False):
 
     out = {}
     for line in results.split('\n'):
-        comment_search = re.search(rf'{name}[\- ](\d+|default-action)', line)
+        comment_search = re.search(rf'{priority}[\- ](\d+|default-action)', line)
         if not comment_search:
             continue
 
@@ -102,18 +67,15 @@ def get_nftables_details(name, ipv6=False):
         out[rule_id] = rule
     return out
 
-def output_firewall_name(name, name_conf, ipv6=False, single_rule_id=None):
+def output_firewall_name(hook, priority, firewall_conf, ipv6=False, single_rule_id=None):
     ip_str = 'IPv6' if ipv6 else 'IPv4'
-    print(f'\n---------------------------------\n{ip_str} Firewall "{name}"\n')
-
-    if name_conf['interface']:
-        print('Active on: {0}\n'.format(" ".join(name_conf['interface'])))
+    print(f'\n---------------------------------\n{ip_str} Firewall "{hook} {priority}"\n')
 
-    details = get_nftables_details(name, ipv6)
+    details = get_nftables_details(hook, priority, ipv6)
     rows = []
 
-    if 'rule' in name_conf:
-        for rule_id, rule_conf in name_conf['rule'].items():
+    if 'rule' in firewall_conf:
+        for rule_id, rule_conf in firewall_conf['rule'].items():
             if single_rule_id and rule_id != single_rule_id:
                 continue
 
@@ -128,8 +90,8 @@ def output_firewall_name(name, name_conf, ipv6=False, single_rule_id=None):
                 row.append(rule_details['conditions'])
             rows.append(row)
 
-    if 'default_action' in name_conf and not single_rule_id:
-        row = ['default', name_conf['default_action'], 'all']
+    if 'default_action' in firewall_conf and not single_rule_id:
+        row = ['default', firewall_conf['default_action'], 'all']
         if 'default-action' in details:
             rule_details = details['default-action']
             row.append(rule_details.get('packets', 0))
@@ -140,18 +102,15 @@ def output_firewall_name(name, name_conf, ipv6=False, single_rule_id=None):
         header = ['Rule', 'Action', 'Protocol', 'Packets', 'Bytes', 'Conditions']
         print(tabulate.tabulate(rows, header) + '\n')
 
-def output_firewall_name_statistics(name, name_conf, ipv6=False, single_rule_id=None):
+def output_firewall_name_statistics(hook, prior, prior_conf, ipv6=False, single_rule_id=None):
     ip_str = 'IPv6' if ipv6 else 'IPv4'
-    print(f'\n---------------------------------\n{ip_str} Firewall "{name}"\n')
-
-    if name_conf['interface']:
-        print('Active on: {0}\n'.format(" ".join(name_conf['interface'])))
+    print(f'\n---------------------------------\n{ip_str} Firewall "{hook} {prior}"\n')
 
-    details = get_nftables_details(name, ipv6)
+    details = get_nftables_details(prior, ipv6)
     rows = []
 
-    if 'rule' in name_conf:
-        for rule_id, rule_conf in name_conf['rule'].items():
+    if 'rule' in prior_conf:
+        for rule_id, rule_conf in prior_conf['rule'].items():
             if single_rule_id and rule_id != single_rule_id:
                 continue
 
@@ -174,7 +133,7 @@ def output_firewall_name_statistics(name, name_conf, ipv6=False, single_rule_id=
             row.append(dest_addr)
             rows.append(row)
 
-    if 'default_action' in name_conf and not single_rule_id:
+    if 'default_action' in prior_conf and not single_rule_id:
         row = ['default']
         if 'default-action' in details:
             rule_details = details['default-action']
@@ -183,7 +142,7 @@ def output_firewall_name_statistics(name, name_conf, ipv6=False, single_rule_id=
         else:
             row.append('0')
             row.append('0')
-        row.append(name_conf['default_action'])
+        row.append(prior_conf['default_action'])
         row.append('0.0.0.0/0') # Source
         row.append('0.0.0.0/0') # Dest
         rows.append(row)
@@ -201,29 +160,47 @@ def show_firewall():
     if not firewall:
         return
 
-    if 'name' in firewall:
-        for name, name_conf in firewall['name'].items():
-            output_firewall_name(name, name_conf, ipv6=False)
+    if 'ip' in firewall:
+        for hook, hook_conf in firewall['ip'].items():
+            for prior, prior_conf in firewall['ip'][hook].items():
+                output_firewall_name(hook, prior, prior_conf, ipv6=False)
+
+    if 'ipv6' in firewall:
+        for hook, hook_conf in firewall['ipv6'].items():
+            for prior, prior_conf in firewall['ipv6'][hook].items():
+                output_firewall_name(hook, prior, prior_conf, ipv6=True)
+
+def show_firewall_family(family):
+    print(f'Rulesets {family} Information')
+
+    conf = Config()
+    firewall = get_config_firewall(conf)
+
+    if not firewall:
+        return
 
-    if 'ipv6_name' in firewall:
-        for name, name_conf in firewall['ipv6_name'].items():
-            output_firewall_name(name, name_conf, ipv6=True)
+    for hook, hook_conf in firewall[family].items():
+        for prior, prior_conf in firewall[family][hook].items():
+            if family == 'ipv6':
+                output_firewall_name(hook, prior, prior_conf, ipv6=True)
+            else:
+                output_firewall_name(hook, prior, prior_conf, ipv6=False)
 
-def show_firewall_name(name, ipv6=False):
+def show_firewall_name(hook, priority, ipv6=False):
     print('Ruleset Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf, name, ipv6)
+    firewall = get_config_firewall(conf, hook, priority, ipv6)
     if firewall:
-        output_firewall_name(name, firewall, ipv6)
+        output_firewall_name(hook, priority, firewall, ipv6)
 
-def show_firewall_rule(name, rule_id, ipv6=False):
+def show_firewall_rule(hook, priority, rule_id, ipv6=False):
     print('Rule Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf, name, ipv6)
+    firewall = get_config_firewall(conf, hook, priority, ipv6)
     if firewall:
-        output_firewall_name(name, firewall, ipv6, rule_id)
+        output_firewall_name(hook, priority, firewall, ipv6, rule_id)
 
 def show_firewall_group(name=None):
     conf = Config()
@@ -284,28 +261,28 @@ def show_summary():
     if not firewall:
         return
 
-    header = ['Ruleset Name', 'Description', 'References']
+    header = ['Ruleset Hook', 'Ruleset Priority', 'Description', 'References']
     v4_out = []
     v6_out = []
 
-    if 'name' in firewall:
-        for name, name_conf in firewall['name'].items():
-            description = name_conf.get('description', '')
-            interfaces = ", ".join(name_conf['interface'])
-            v4_out.append([name, description, interfaces])
+    if 'ip' in firewall:
+        for hook, hook_conf in firewall['ip'].items():
+            for prior, prior_conf in firewall['ip'][hook].items():
+                description = prior_conf.get('description', '')
+                v4_out.append([hook, prior, description])
 
-    if 'ipv6_name' in firewall:
-        for name, name_conf in firewall['ipv6_name'].items():
-            description = name_conf.get('description', '')
-            interfaces = ", ".join(name_conf['interface'])
-            v6_out.append([name, description, interfaces or 'N/A'])
+    if 'ipv6' in firewall:
+        for hook, hook_conf in firewall['ipv6'].items():
+            for prior, prior_conf in firewall['ipv6'][hook].items():
+                description = prior_conf.get('description', '')
+                v6_out.append([hook, prior, description])
 
     if v6_out:
-        print('\nIPv6 name:\n')
+        print('\nIPv6 Ruleset:\n')
         print(tabulate.tabulate(v6_out, header) + '\n')
 
     if v4_out:
-        print('\nIPv4 name:\n')
+        print('\nIPv4 Ruleset:\n')
         print(tabulate.tabulate(v4_out, header) + '\n')
 
     show_firewall_group()
@@ -319,18 +296,23 @@ def show_statistics():
     if not firewall:
         return
 
-    if 'name' in firewall:
-        for name, name_conf in firewall['name'].items():
-            output_firewall_name_statistics(name, name_conf, ipv6=False)
+    if 'ip' in firewall:
+        for hook, hook_conf in firewall['ip'].items():
+            for prior, prior_conf in firewall['ip'][hook].items():
+                output_firewall_name_statistics(hook,prior, prior_conf, ipv6=False)
 
-    if 'ipv6_name' in firewall:
-        for name, name_conf in firewall['ipv6_name'].items():
-            output_firewall_name_statistics(name, name_conf, ipv6=True)
+    if 'ipv6' in firewall:
+        for hook, hook_conf in firewall['ipv6'].items():
+            for prior, prior_conf in firewall['ipv6'][hook].items():
+                output_firewall_name_statistics(hook,prior, prior_conf, ipv6=True)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--action', help='Action', required=False)
     parser.add_argument('--name', help='Firewall name', required=False, action='store', nargs='?', default='')
+    parser.add_argument('--family', help='IP family', required=False, action='store', nargs='?', default='')
+    parser.add_argument('--hook', help='Firewall hook', required=False, action='store', nargs='?', default='')
+    parser.add_argument('--priority', help='Firewall priority', required=False, action='store', nargs='?', default='')
     parser.add_argument('--rule', help='Firewall Rule ID', required=False)
     parser.add_argument('--ipv6', help='IPv6 toggle', action='store_true')
 
@@ -338,11 +320,13 @@ if __name__ == '__main__':
 
     if args.action == 'show':
         if not args.rule:
-            show_firewall_name(args.name, args.ipv6)
+            show_firewall_name(args.hook, args.priority, args.ipv6)
         else:
-            show_firewall_rule(args.name, args.rule, args.ipv6)
+            show_firewall_rule(args.hook, args.priority, args.rule, args.ipv6)
     elif args.action == 'show_all':
         show_firewall()
+    elif args.action == 'show_family':
+        show_firewall_family(args.family)
     elif args.action == 'show_group':
         show_firewall_group(args.name)
     elif args.action == 'show_statistics':
-- 
cgit v1.2.3