From 063de842144ac95565a46df3da86dbc7f56643ae Mon Sep 17 00:00:00 2001
From: Nicolas Fort <nicolasfort1988@gmail.com>
Date: Mon, 11 Sep 2023 15:56:07 +0000
Subject: T4072: Firewall op-mode command: add bridge capabilities

---
 src/op_mode/firewall.py | 114 ++++++++++++++++++++++++++----------------------
 1 file changed, 62 insertions(+), 52 deletions(-)

(limited to 'src')

diff --git a/src/op_mode/firewall.py b/src/op_mode/firewall.py
index 11cbd977d..3434707ec 100755
--- a/src/op_mode/firewall.py
+++ b/src/op_mode/firewall.py
@@ -24,27 +24,39 @@ from vyos.config import Config
 from vyos.utils.process import cmd
 from vyos.utils.dict import dict_search_args
 
-def get_config_firewall(conf, hook=None, priority=None, ipv6=False):
+def get_config_firewall(conf, family=None, hook=None, priority=None):
     config_path = ['firewall']
-    if hook:
-        config_path += ['ipv6' if ipv6 else 'ipv4', hook]
-        if priority:
-            config_path += [priority]
+    if family:
+        config_path += [family]
+        if hook:
+            config_path += [hook]
+            if priority:
+                config_path += [priority]
 
     firewall = conf.get_config_dict(config_path, key_mangling=('-', '_'),
                                 get_first_key=True, no_tag_node_value_mangle=True)
 
     return firewall
 
-def get_nftables_details(hook, priority, ipv6=False):
-    suffix = '6' if ipv6 else ''
-    aux = 'IPV6_' if ipv6 else ''
-    name_prefix = 'NAME6_' if ipv6 else 'NAME_'
+def get_nftables_details(family, hook, priority):
+    if family == 'ipv6':
+        suffix = 'ip6'
+        name_prefix = 'NAME6_'
+        aux='IPV6_'
+    elif family == 'ipv4':
+        suffix = 'ip'
+        name_prefix = 'NAME_'
+        aux=''
+    else:
+        suffix = 'bridge'
+        name_prefix = 'NAME_'
+        aux=''
+
     if hook == 'name' or hook == 'ipv6-name':
-        command = f'sudo nft list chain ip{suffix} vyos_filter {name_prefix}{priority}'
+        command = f'sudo nft list chain {suffix} vyos_filter {name_prefix}{priority}'
     else:
         up_hook = hook.upper()
-        command = f'sudo nft list chain ip{suffix} vyos_filter VYOS_{aux}{up_hook}_{priority}'
+        command = f'sudo nft list chain {suffix} vyos_filter VYOS_{aux}{up_hook}_{priority}'
 
     try:
         results = cmd(command)
@@ -68,11 +80,10 @@ def get_nftables_details(hook, priority, ipv6=False):
         out[rule_id] = rule
     return out
 
-def output_firewall_name(hook, priority, firewall_conf, ipv6=False, single_rule_id=None):
-    ip_str = 'IPv6' if ipv6 else 'IPv4'
-    print(f'\n---------------------------------\n{ip_str} Firewall "{hook} {priority}"\n')
+def output_firewall_name(family, hook, priority, firewall_conf, single_rule_id=None):
+    print(f'\n---------------------------------\n{family} Firewall "{hook} {priority}"\n')
 
-    details = get_nftables_details(hook, priority, ipv6)
+    details = get_nftables_details(family, hook, priority)
     rows = []
 
     if 'rule' in firewall_conf:
@@ -103,11 +114,10 @@ def output_firewall_name(hook, priority, firewall_conf, ipv6=False, single_rule_
         header = ['Rule', 'Action', 'Protocol', 'Packets', 'Bytes', 'Conditions']
         print(tabulate.tabulate(rows, header) + '\n')
 
-def output_firewall_name_statistics(hook, prior, prior_conf, ipv6=False, single_rule_id=None):
-    ip_str = 'IPv6' if ipv6 else 'IPv4'
-    print(f'\n---------------------------------\n{ip_str} Firewall "{hook} {prior}"\n')
+def output_firewall_name_statistics(family, hook, prior, prior_conf, single_rule_id=None):
+    print(f'\n---------------------------------\n{family} Firewall "{hook} {prior}"\n')
 
-    details = get_nftables_details(hook, prior, ipv6)
+    details = get_nftables_details(family, hook, prior)
     rows = []
 
     if 'rule' in prior_conf:
@@ -210,8 +220,8 @@ def output_firewall_name_statistics(hook, prior, prior_conf, ipv6=False, single_
             row.append('0')
             row.append('0')
         row.append(prior_conf['default_action'])
-        row.append('any') # Source
-        row.append('any') # Dest
+        row.append('any')   # Source
+        row.append('any')   # Dest
         row.append('any')   # inbound-interface
         row.append('any')   # outbound-interface
         rows.append(row)
@@ -229,15 +239,11 @@ def show_firewall():
     if not firewall:
         return
 
-    if 'ipv4' in firewall:
-        for hook, hook_conf in firewall['ipv4'].items():
-            for prior, prior_conf in firewall['ipv4'][hook].items():
-                output_firewall_name(hook, prior, prior_conf, ipv6=False)
-
-    if 'ipv6' in firewall:
-        for hook, hook_conf in firewall['ipv6'].items():
-            for prior, prior_conf in firewall['ipv6'][hook].items():
-                output_firewall_name(hook, prior, prior_conf, ipv6=True)
+    for family in ['ipv4', 'ipv6', 'bridge']:
+        if family in firewall:
+            for hook, hook_conf in firewall[family].items():
+                for prior, prior_conf in firewall[family][hook].items():
+                    output_firewall_name(family, hook, prior, prior_conf)
 
 def show_firewall_family(family):
     print(f'Rulesets {family} Information')
@@ -245,31 +251,28 @@ def show_firewall_family(family):
     conf = Config()
     firewall = get_config_firewall(conf)
 
-    if not firewall:
+    if not firewall or family not in firewall:
         return
 
     for hook, hook_conf in firewall[family].items():
         for prior, prior_conf in firewall[family][hook].items():
-            if family == 'ipv6':
-                output_firewall_name(hook, prior, prior_conf, ipv6=True)
-            else:
-                output_firewall_name(hook, prior, prior_conf, ipv6=False)
+            output_firewall_name(family, hook, prior, prior_conf)
 
-def show_firewall_name(hook, priority, ipv6=False):
+def show_firewall_name(family, hook, priority):
     print('Ruleset Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf, hook, priority, ipv6)
+    firewall = get_config_firewall(conf, family, hook, priority)
     if firewall:
-        output_firewall_name(hook, priority, firewall, ipv6)
+        output_firewall_name(family, hook, priority, firewall)
 
-def show_firewall_rule(hook, priority, rule_id, ipv6=False):
+def show_firewall_rule(family, hook, priority, rule_id):
     print('Rule Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf, hook, priority, ipv6)
+    firewall = get_config_firewall(conf, family, hook, priority)
     if firewall:
-        output_firewall_name(hook, priority, firewall, ipv6, rule_id)
+        output_firewall_name(family, hook, priority, firewall, rule_id)
 
 def show_firewall_group(name=None):
     conf = Config()
@@ -369,6 +372,7 @@ def show_summary():
     header = ['Ruleset Hook', 'Ruleset Priority', 'Description', 'References']
     v4_out = []
     v6_out = []
+    br_out = []
 
     if 'ipv4' in firewall:
         for hook, hook_conf in firewall['ipv4'].items():
@@ -382,6 +386,12 @@ def show_summary():
                 description = prior_conf.get('description', '')
                 v6_out.append([hook, prior, description])
 
+    if 'bridge' in firewall:
+        for hook, hook_conf in firewall['bridge'].items():
+            for prior, prior_conf in firewall['bridge'][hook].items():
+                description = prior_conf.get('description', '')
+                br_out.append([hook, prior, description])
+
     if v6_out:
         print('\nIPv6 Ruleset:\n')
         print(tabulate.tabulate(v6_out, header) + '\n')
@@ -390,6 +400,10 @@ def show_summary():
         print('\nIPv4 Ruleset:\n')
         print(tabulate.tabulate(v4_out, header) + '\n')
 
+    if br_out:
+        print('\nBridge Ruleset:\n')
+        print(tabulate.tabulate(br_out, header) + '\n')
+
     show_firewall_group()
 
 def show_statistics():
@@ -401,15 +415,11 @@ def show_statistics():
     if not firewall:
         return
 
-    if 'ipv4' in firewall:
-        for hook, hook_conf in firewall['ipv4'].items():
-            for prior, prior_conf in firewall['ipv4'][hook].items():
-                output_firewall_name_statistics(hook,prior, prior_conf, ipv6=False)
-
-    if 'ipv6' in firewall:
-        for hook, hook_conf in firewall['ipv6'].items():
-            for prior, prior_conf in firewall['ipv6'][hook].items():
-                output_firewall_name_statistics(hook,prior, prior_conf, ipv6=True)
+    for family in ['ipv4', 'ipv6', 'bridge']:
+        if family in firewall:
+            for hook, hook_conf in firewall[family].items():
+                for prior, prior_conf in firewall[family][hook].items():
+                    output_firewall_name_statistics(family, hook,prior, prior_conf)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -425,9 +435,9 @@ if __name__ == '__main__':
 
     if args.action == 'show':
         if not args.rule:
-            show_firewall_name(args.hook, args.priority, args.ipv6)
+            show_firewall_name(args.family, args.hook, args.priority)
         else:
-            show_firewall_rule(args.hook, args.priority, args.rule, args.ipv6)
+            show_firewall_rule(args.family, args.hook, args.priority, args.rule)
     elif args.action == 'show_all':
         show_firewall()
     elif args.action == 'show_family':
-- 
cgit v1.2.3