From 34db435e7a74ee8509777802e03927de2dd57627 Mon Sep 17 00:00:00 2001
From: sarthurdev <965089+sarthurdev@users.noreply.github.com>
Date: Mon, 13 Jun 2022 01:45:06 +0200
Subject: firewall: T4147: Use named sets for firewall groups

* Refactor nftables clean-up code
* Adds policy route test for using firewall groups
---
 python/vyos/firewall.py |  8 ++++----
 python/vyos/template.py | 39 +++++++++++++--------------------------
 2 files changed, 17 insertions(+), 30 deletions(-)

(limited to 'python')

diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py
index a61d0a9f8..f8f913944 100644
--- a/python/vyos/firewall.py
+++ b/python/vyos/firewall.py
@@ -192,7 +192,7 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name):
                     if group_name[0] == '!':
                         operator = '!='
                         group_name = group_name[1:]
-                    output.append(f'{ip_name} {prefix}addr {operator} $A{def_suffix}_{group_name}')
+                    output.append(f'{ip_name} {prefix}addr {operator} @A{def_suffix}_{group_name}')
                 # Generate firewall group domain-group
                 elif 'domain_group' in group:
                     group_name = group['domain_group']
@@ -207,14 +207,14 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name):
                     if group_name[0] == '!':
                         operator = '!='
                         group_name = group_name[1:]
-                    output.append(f'{ip_name} {prefix}addr {operator} $N{def_suffix}_{group_name}')
+                    output.append(f'{ip_name} {prefix}addr {operator} @N{def_suffix}_{group_name}')
                 if 'mac_group' in group:
                     group_name = group['mac_group']
                     operator = ''
                     if group_name[0] == '!':
                         operator = '!='
                         group_name = group_name[1:]
-                    output.append(f'ether {prefix}addr {operator} $M_{group_name}')
+                    output.append(f'ether {prefix}addr {operator} @M_{group_name}')
                 if 'port_group' in group:
                     proto = rule_conf['protocol']
                     group_name = group['port_group']
@@ -227,7 +227,7 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name):
                         operator = '!='
                         group_name = group_name[1:]
 
-                    output.append(f'{proto} {prefix}port {operator} $P_{group_name}')
+                    output.append(f'{proto} {prefix}port {operator} @P_{group_name}')
 
     if 'log' in rule_conf and rule_conf['log'] == 'enable':
         action = rule_conf['action'] if 'action' in rule_conf else 'accept'
diff --git a/python/vyos/template.py b/python/vyos/template.py
index 3feda47c8..eb7f06480 100644
--- a/python/vyos/template.py
+++ b/python/vyos/template.py
@@ -592,37 +592,24 @@ def nft_intra_zone_action(zone_conf, ipv6=False):
     return 'return'
 
 @register_filter('nft_nested_group')
-def nft_nested_group(out_list, includes, prefix):
+def nft_nested_group(out_list, includes, groups, key):
     if not vyos_defined(out_list):
         out_list = []
-    for name in includes:
-        out_list.append(f'${prefix}{name}')
-    return out_list
-
-@register_filter('sort_nested_groups')
-def sort_nested_groups(groups):
-    seen = []
-    out = {}
-
-    def include_iterate(group_name):
-        group = groups[group_name]
-        if 'include' not in group:
-            if group_name not in out:
-                out[group_name] = groups[group_name]
-            return
 
-        for inc_group_name in group['include']:
-            if inc_group_name not in seen:
-                seen.append(inc_group_name)
-                include_iterate(inc_group_name)
+    def add_includes(name):
+        if key in groups[name]:
+            for item in groups[name][key]:
+                if item in out_list:
+                    continue
+                out_list.append(item)
 
-        if group_name not in out:
-            out[group_name] = groups[group_name]
+        if 'include' in groups[name]:
+            for name_inc in groups[name]['include']:
+                add_includes(name_inc)
 
-    for group_name in groups:
-        include_iterate(group_name)
-
-    return out.items()
+    for name in includes:
+        add_includes(name)
+    return out_list
 
 @register_test('vyos_defined')
 def vyos_defined(value, test_value=None, var_type=None):
-- 
cgit v1.2.3