From cd5316c266656e1532c0a2166ba892bd1c807743 Mon Sep 17 00:00:00 2001
From: Nicolas Fort <nicolasfort1988@gmail.com>
Date: Thu, 26 Oct 2023 10:11:14 +0000
Subject: T5513: T5564: update op-mode command show firewall. Counter available
 for default actions and extend references for firewall groups

---
 op-mode-definitions/nat66.xml.in |   8 +--
 src/op_mode/firewall.py          | 147 ++++++++++++++++++++++++++++++++-------
 2 files changed, 124 insertions(+), 31 deletions(-)

diff --git a/op-mode-definitions/nat66.xml.in b/op-mode-definitions/nat66.xml.in
index 6a8a39000..4df20d847 100644
--- a/op-mode-definitions/nat66.xml.in
+++ b/op-mode-definitions/nat66.xml.in
@@ -16,7 +16,7 @@
                 <properties>
                   <help>Show configured source NAT66 rules</help>
                 </properties>
-                <command>${vyos_op_scripts_dir}/nat.py show_rules --direction source --family inet6</command>
+                <command>sudo ${vyos_op_scripts_dir}/nat.py show_rules --direction source --family inet6</command>
               </node>
               <node name="statistics">
                 <properties>
@@ -39,7 +39,7 @@
                     <command>sudo ${vyos_op_scripts_dir}/nat.py show_translations --direction source --family inet6 --address "$6"</command>
                   </tagNode>
                 </children>
-                <command>${vyos_op_scripts_dir}/nat.py show_translations --direction source --family inet6</command>
+                <command>sudo ${vyos_op_scripts_dir}/nat.py show_translations --direction source --family inet6</command>
               </node>
             </children>
           </node>
@@ -52,7 +52,7 @@
                 <properties>
                   <help>Show configured destination NAT66 rules</help>
                 </properties>
-                <command>${vyos_op_scripts_dir}/nat.py show_rules --direction destination --family inet6</command>
+                <command>sudo ${vyos_op_scripts_dir}/nat.py show_rules --direction destination --family inet6</command>
               </node>
               <node name="statistics">
                 <properties>
@@ -75,7 +75,7 @@
                     <command>sudo ${vyos_op_scripts_dir}/nat.py show_translations --direction destination --family inet6 --address "$6"</command>
                   </tagNode>
                 </children>
-                <command>${vyos_op_scripts_dir}/nat.py show_translations --direction destination --family inet6</command>
+                <command>sudo ${vyos_op_scripts_dir}/nat.py show_translations --direction destination --family inet6</command>
               </node>
             </children>
           </node>
diff --git a/src/op_mode/firewall.py b/src/op_mode/firewall.py
index 3434707ec..20f54b9ba 100755
--- a/src/op_mode/firewall.py
+++ b/src/op_mode/firewall.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 #
-# Copyright (C) 2021 VyOS maintainers and contributors
+# Copyright (C) 2023 VyOS maintainers and contributors
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License version 2 or later as
@@ -24,19 +24,28 @@ from vyos.config import Config
 from vyos.utils.process import cmd
 from vyos.utils.dict import dict_search_args
 
-def get_config_firewall(conf, family=None, hook=None, priority=None):
-    config_path = ['firewall']
-    if family:
-        config_path += [family]
-        if hook:
-            config_path += [hook]
-            if priority:
-                config_path += [priority]
+def get_config_node(conf, node=None, family=None, hook=None, priority=None):
+    if node == 'nat':
+        if family == 'ipv6':
+            config_path = ['nat66']
+        else:
+            config_path = ['nat']
 
-    firewall = conf.get_config_dict(config_path, key_mangling=('-', '_'),
+    elif node == 'policy':
+        config_path = ['policy']
+    else:
+        config_path = ['firewall']
+        if family:
+            config_path += [family]
+            if hook:
+                config_path += [hook]
+                if priority:
+                    config_path += [priority]
+
+    node_config = conf.get_config_dict(config_path, key_mangling=('-', '_'),
                                 get_first_key=True, no_tag_node_value_mangle=True)
 
-    return firewall
+    return node_config
 
 def get_nftables_details(family, hook, priority):
     if family == 'ipv6':
@@ -102,7 +111,15 @@ def output_firewall_name(family, hook, priority, firewall_conf, single_rule_id=N
                 row.append(rule_details['conditions'])
             rows.append(row)
 
-    if 'default_action' in firewall_conf and not single_rule_id:
+    if hook in ['input', 'forward', 'output']:
+        def_action = firewall_conf['default_action'] if 'default_action' in firewall_conf else 'accept'
+        row = ['default', def_action, 'all']
+        rule_details = details['default-action']
+        row.append(rule_details.get('packets', 0))
+        row.append(rule_details.get('bytes', 0))
+        rows.append(row)
+
+    elif 'default_action' in firewall_conf and not single_rule_id:
         row = ['default', firewall_conf['default_action'], 'all']
         if 'default-action' in details:
             rule_details = details['default-action']
@@ -167,16 +184,16 @@ def output_firewall_name_statistics(family, hook, prior, prior_conf, single_rule
                                     dest_addr = 'any'
 
             # Get inbound interface
-            iiface = dict_search_args(rule_conf, 'inbound_interface', 'interface_name')
+            iiface = dict_search_args(rule_conf, 'inbound_interface', 'name')
             if not iiface:
-                iiface = dict_search_args(rule_conf, 'inbound_interface', 'interface_group')
+                iiface = dict_search_args(rule_conf, 'inbound_interface', 'group')
                 if not iiface:
                     iiface = 'any'
 
             # Get outbound interface
-            oiface = dict_search_args(rule_conf, 'outbound_interface', 'interface_name')
+            oiface = dict_search_args(rule_conf, 'outbound_interface', 'name')
             if not oiface:
-                oiface = dict_search_args(rule_conf, 'outbound_interface', 'interface_group')
+                oiface = dict_search_args(rule_conf, 'outbound_interface', 'group')
                 if not oiface:
                     oiface = 'any'
 
@@ -198,8 +215,9 @@ def output_firewall_name_statistics(family, hook, prior, prior_conf, single_rule
 
     if hook in ['input', 'forward', 'output']:
         row = ['default']
-        row.append('N/A')
-        row.append('N/A')
+        rule_details = details['default-action']
+        row.append(rule_details.get('packets', 0))
+        row.append(rule_details.get('bytes', 0))
         if 'default_action' in prior_conf:
             row.append(prior_conf['default_action'])
         else:
@@ -234,7 +252,7 @@ def show_firewall():
     print('Rulesets Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf)
+    firewall = get_config_node(conf)
 
     if not firewall:
         return
@@ -249,7 +267,7 @@ def show_firewall_family(family):
     print(f'Rulesets {family} Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf)
+    firewall = get_config_node(conf)
 
     if not firewall or family not in firewall:
         return
@@ -262,7 +280,7 @@ def show_firewall_name(family, hook, priority):
     print('Ruleset Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf, family, hook, priority)
+    firewall = get_config_node(conf, 'firewall', family, hook, priority)
     if firewall:
         output_firewall_name(family, hook, priority, firewall)
 
@@ -270,17 +288,20 @@ def show_firewall_rule(family, hook, priority, rule_id):
     print('Rule Information')
 
     conf = Config()
-    firewall = get_config_firewall(conf, family, hook, priority)
+    firewall = get_config_node(conf, 'firewall', family, hook, priority)
     if firewall:
         output_firewall_name(family, hook, priority, firewall, rule_id)
 
 def show_firewall_group(name=None):
     conf = Config()
-    firewall = get_config_firewall(conf)
+    firewall = get_config_node(conf, node='firewall')
 
     if 'group' not in firewall:
         return
 
+    nat = get_config_node(conf, node='nat')
+    policy = get_config_node(conf, node='policy')
+
     def find_references(group_type, group_name):
         out = []
         family = []
@@ -296,6 +317,7 @@ def show_firewall_group(name=None):
             family = ['ipv4', 'ipv6']
 
         for item in family:
+            # Look references in firewall
             for name_type in ['name', 'ipv6_name', 'forward', 'input', 'output']:
                 if item in firewall:
                     if name_type not in firewall[item]:
@@ -308,8 +330,8 @@ def show_firewall_group(name=None):
                         for rule_id, rule_conf in priority_conf['rule'].items():
                             source_group = dict_search_args(rule_conf, 'source', 'group', group_type)
                             dest_group = dict_search_args(rule_conf, 'destination', 'group', group_type)
-                            in_interface = dict_search_args(rule_conf, 'inbound_interface', 'interface_group')
-                            out_interface = dict_search_args(rule_conf, 'outbound_interface', 'interface_group')
+                            in_interface = dict_search_args(rule_conf, 'inbound_interface', 'group')
+                            out_interface = dict_search_args(rule_conf, 'outbound_interface', 'group')
                             if source_group:
                                 if source_group[0] == "!":
                                     source_group = source_group[1:]
@@ -330,6 +352,76 @@ def show_firewall_group(name=None):
                                     out_interface = out_interface[1:]
                                 if group_name == out_interface:
                                     out.append(f'{item}-{name_type}-{priority}-{rule_id}')
+
+            # Look references in route | route6
+            for name_type in ['route', 'route6']:
+                if name_type not in policy:
+                    continue
+                if name_type == 'route' and item == 'ipv6':
+                    continue
+                elif name_type == 'route6' and item == 'ipv4':
+                    continue
+                else:
+                    for policy_name, policy_conf in policy[name_type].items():
+                        if 'rule' not in policy_conf:
+                            continue
+                        for rule_id, rule_conf in policy_conf['rule'].items():
+                            source_group = dict_search_args(rule_conf, 'source', 'group', group_type)
+                            dest_group = dict_search_args(rule_conf, 'destination', 'group', group_type)
+                            in_interface = dict_search_args(rule_conf, 'inbound_interface', 'group')
+                            out_interface = dict_search_args(rule_conf, 'outbound_interface', 'group')
+                            if source_group:
+                                if source_group[0] == "!":
+                                    source_group = source_group[1:]
+                                if group_name == source_group:
+                                    out.append(f'{name_type}-{policy_name}-{rule_id}')
+                            if dest_group:
+                                if dest_group[0] == "!":
+                                    dest_group = dest_group[1:]
+                                if group_name == dest_group:
+                                    out.append(f'{name_type}-{policy_name}-{rule_id}')
+                            if in_interface:
+                                if in_interface[0] == "!":
+                                    in_interface = in_interface[1:]
+                                if group_name == in_interface:
+                                    out.append(f'{name_type}-{policy_name}-{rule_id}')
+                            if out_interface:
+                                if out_interface[0] == "!":
+                                    out_interface = out_interface[1:]
+                                if group_name == out_interface:
+                                    out.append(f'{name_type}-{policy_name}-{rule_id}')
+
+        ## Look references in nat table
+        for direction in ['source', 'destination']:
+            if direction in nat:
+                if 'rule' not in nat[direction]:
+                    continue
+                for rule_id, rule_conf in nat[direction]['rule'].items():
+                    source_group = dict_search_args(rule_conf, 'source', 'group', group_type)
+                    dest_group = dict_search_args(rule_conf, 'destination', 'group', group_type)
+                    in_interface = dict_search_args(rule_conf, 'inbound_interface', 'group')
+                    out_interface = dict_search_args(rule_conf, 'outbound_interface', 'group')
+                    if source_group:
+                        if source_group[0] == "!":
+                            source_group = source_group[1:]
+                        if group_name == source_group:
+                            out.append(f'nat-{direction}-{rule_id}')
+                    if dest_group:
+                        if dest_group[0] == "!":
+                            dest_group = dest_group[1:]
+                        if group_name == dest_group:
+                            out.append(f'nat-{direction}-{rule_id}')
+                    if in_interface:
+                        if in_interface[0] == "!":
+                            in_interface = in_interface[1:]
+                        if group_name == in_interface:
+                            out.append(f'nat-{direction}-{rule_id}')
+                    if out_interface:
+                        if out_interface[0] == "!":
+                            out_interface = out_interface[1:]
+                        if group_name == out_interface:
+                            out.append(f'nat-{direction}-{rule_id}')
+
         return out
 
     header = ['Name', 'Type', 'References', 'Members']
@@ -356,6 +448,7 @@ def show_firewall_group(name=None):
                 row.append('N/D')
             rows.append(row)
 
+
     if rows:
         print('Firewall Groups\n')
         print(tabulate.tabulate(rows, header))
@@ -364,7 +457,7 @@ def show_summary():
     print('Ruleset Summary')
 
     conf = Config()
-    firewall = get_config_firewall(conf)
+    firewall = get_config_node(conf)
 
     if not firewall:
         return
@@ -410,7 +503,7 @@ def show_statistics():
     print('Rulesets Statistics')
 
     conf = Config()
-    firewall = get_config_firewall(conf)
+    firewall = get_config_node(conf)
 
     if not firewall:
         return
-- 
cgit v1.2.3