From 9a5dfb4b7ec9e065a73511a38e1713aec03eee0e Mon Sep 17 00:00:00 2001
From: Nicolas Fort <nicolasfort1988@gmail.com>
Date: Fri, 28 Oct 2022 18:19:47 +0000
Subject: T4780: Firewall: add firewall groups in firewall. Extend matching
 criteria so this new group can be used in inbound and outbound matcher

---
 src/conf_mode/firewall.py             |   3 +-
 src/conf_mode/policy-route.py         | 109 ++--------------------------------
 src/migration-scripts/firewall/8-to-9 |  91 ++++++++++++++++++++++++++++
 3 files changed, 98 insertions(+), 105 deletions(-)
 create mode 100755 src/migration-scripts/firewall/8-to-9

(limited to 'src')

diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py
index cbd9cbe90..dcdbf8fab 100755
--- a/src/conf_mode/firewall.py
+++ b/src/conf_mode/firewall.py
@@ -67,7 +67,8 @@ valid_groups = [
     'address_group',
     'domain_group',
     'network_group',
-    'port_group'
+    'port_group',
+    'interface_group'
 ]
 
 nested_group_types = [
diff --git a/src/conf_mode/policy-route.py b/src/conf_mode/policy-route.py
index 00539b9c7..40a32efb3 100755
--- a/src/conf_mode/policy-route.py
+++ b/src/conf_mode/policy-route.py
@@ -15,7 +15,6 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 import os
-import re
 
 from json import loads
 from sys import exit
@@ -25,7 +24,6 @@ from vyos.config import Config
 from vyos.template import render
 from vyos.util import cmd
 from vyos.util import dict_search_args
-from vyos.util import dict_search_recursive
 from vyos.util import run
 from vyos import ConfigError
 from vyos import airbag
@@ -34,48 +32,14 @@ airbag.enable()
 mark_offset = 0x7FFFFFFF
 nftables_conf = '/run/nftables_policy.conf'
 
-ROUTE_PREFIX = 'VYOS_PBR_'
-ROUTE6_PREFIX = 'VYOS_PBR6_'
-
-preserve_chains = [
-    'VYOS_PBR_PREROUTING',
-    'VYOS_PBR_POSTROUTING',
-    'VYOS_PBR6_PREROUTING',
-    'VYOS_PBR6_POSTROUTING'
-]
-
 valid_groups = [
     'address_group',
+    'domain_group',
     'network_group',
-    'port_group'
+    'port_group',
+    'interface_group'
 ]
 
-group_set_prefix = {
-    'A_': 'address_group',
-    'A6_': 'ipv6_address_group',
-#    'D_': 'domain_group',
-    'M_': 'mac_group',
-    'N_': 'network_group',
-    'N6_': 'ipv6_network_group',
-    'P_': 'port_group'
-}
-
-def get_policy_interfaces(conf):
-    out = {}
-    interfaces = conf.get_config_dict(['interfaces'], key_mangling=('-', '_'), get_first_key=True,
-                                    no_tag_node_value_mangle=True)
-    def find_interfaces(iftype_conf, output={}, prefix=''):
-        for ifname, if_conf in iftype_conf.items():
-            if 'policy' in if_conf:
-                output[prefix + ifname] = if_conf['policy']
-            for vif in ['vif', 'vif_s', 'vif_c']:
-                if vif in if_conf:
-                    output.update(find_interfaces(if_conf[vif], output, f'{prefix}{ifname}.'))
-        return output
-    for iftype, iftype_conf in interfaces.items():
-        out.update(find_interfaces(iftype_conf))
-    return out
-
 def get_config(config=None):
     if config:
         conf = config
@@ -88,7 +52,6 @@ def get_config(config=None):
 
     policy['firewall_group'] = conf.get_config_dict(['firewall', 'group'], key_mangling=('-', '_'), get_first_key=True,
                                     no_tag_node_value_mangle=True)
-    policy['interfaces'] = get_policy_interfaces(conf)
 
     return policy
 
@@ -132,8 +95,8 @@ def verify_rule(policy, name, rule_conf, ipv6, rule_id):
             side_conf = rule_conf[side]
 
             if 'group' in side_conf:
-                if {'address_group', 'network_group'} <= set(side_conf['group']):
-                    raise ConfigError('Only one address-group or network-group can be specified')
+                if len({'address_group', 'domain_group', 'network_group'} & set(side_conf['group'])) > 1:
+                    raise ConfigError('Only one address-group, domain-group or network-group can be specified')
 
                 for group in valid_groups:
                     if group in side_conf['group']:
@@ -168,73 +131,11 @@ def verify(policy):
                     for rule_id, rule_conf in pol_conf['rule'].items():
                         verify_rule(policy, name, rule_conf, ipv6, rule_id)
 
-    for ifname, if_policy in policy['interfaces'].items():
-        name = dict_search_args(if_policy, 'route')
-        ipv6_name = dict_search_args(if_policy, 'route6')
-
-        if name and not dict_search_args(policy, 'route', name):
-            raise ConfigError(f'Policy route "{name}" is still referenced on interface {ifname}')
-
-        if ipv6_name and not dict_search_args(policy, 'route6', ipv6_name):
-            raise ConfigError(f'Policy route6 "{ipv6_name}" is still referenced on interface {ifname}')
-
     return None
 
-def cleanup_commands(policy):
-    commands = []
-    commands_chains = []
-    commands_sets = []
-    for table in ['ip mangle', 'ip6 mangle']:
-        route_node = 'route' if table == 'ip mangle' else 'route6'
-        chain_prefix = ROUTE_PREFIX if table == 'ip mangle' else ROUTE6_PREFIX
-
-        json_str = cmd(f'nft -t -j list table {table}')
-        obj = loads(json_str)
-        if 'nftables' not in obj:
-            continue
-        for item in obj['nftables']:
-            if 'chain' in item:
-                chain = item['chain']['name']
-                if chain in preserve_chains or not chain.startswith("VYOS_PBR"):
-                    continue
-
-                if dict_search_args(policy, route_node, chain.replace(chain_prefix, "", 1)) != None:
-                    commands.append(f'flush chain {table} {chain}')
-                else:
-                    commands_chains.append(f'delete chain {table} {chain}')
-
-            if 'rule' in item:
-                rule = item['rule']
-                chain = rule['chain']
-                handle = rule['handle']
-
-                if chain not in preserve_chains:
-                    continue
-
-                target, _ = next(dict_search_recursive(rule['expr'], 'target'))
-
-                if target.startswith(chain_prefix):
-                    if dict_search_args(policy, route_node, target.replace(chain_prefix, "", 1)) == None:
-                        commands.append(f'delete rule {table} {chain} handle {handle}')
-
-            if 'set' in item:
-                set_name = item['set']['name']
-
-                for prefix, group_type in group_set_prefix.items():
-                    if set_name.startswith(prefix):
-                        group_name = set_name.replace(prefix, "", 1)
-                        if dict_search_args(policy, 'firewall_group', group_type, group_name) != None:
-                            commands_sets.append(f'flush set {table} {set_name}')
-                        else:
-                            commands_sets.append(f'delete set {table} {set_name}')
-
-    return commands + commands_chains + commands_sets
-
 def generate(policy):
     if not os.path.exists(nftables_conf):
         policy['first_install'] = True
-    else:
-        policy['cleanup_commands'] = cleanup_commands(policy)
 
     render(nftables_conf, 'firewall/nftables-policy.j2', policy)
     return None
diff --git a/src/migration-scripts/firewall/8-to-9 b/src/migration-scripts/firewall/8-to-9
new file mode 100755
index 000000000..f7c1bb90d
--- /dev/null
+++ b/src/migration-scripts/firewall/8-to-9
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+# T4780: Add firewall interface group
+#  cli changes from:       
+#  set firewall [name | ipv6-name] <name> rule <number> [inbound-interface | outbound-interface] <interface_name>
+#  To
+#  set firewall [name | ipv6-name] <name> rule <number> [inbound-interface | outbound-interface]  [interface-name | interface-group] <interface_name | interface_group>
+
+import re
+
+from sys import argv
+from sys import exit
+
+from vyos.configtree import ConfigTree
+from vyos.ifconfig import Section
+
+if (len(argv) < 1):
+    print("Must specify file name!")
+    exit(1)
+
+file_name = argv[1]
+
+with open(file_name, 'r') as f:
+    config_file = f.read()
+
+base = ['firewall']
+config = ConfigTree(config_file)
+
+if not config.exists(base):
+    # Nothing to do
+    exit(0)
+
+if config.exists(base + ['name']):
+    for name in config.list_nodes(base + ['name']):
+        if not config.exists(base + ['name', name, 'rule']):
+            continue
+
+        for rule in config.list_nodes(base + ['name', name, 'rule']):
+            rule_iiface = base + ['name', name, 'rule', rule, 'inbound-interface']
+            rule_oiface = base + ['name', name, 'rule', rule, 'outbound-interface']
+
+            if config.exists(rule_iiface):
+                tmp = config.return_value(rule_iiface)
+                config.delete(rule_iiface)
+                config.set(rule_iiface + ['interface-name'], value=tmp)
+
+            if config.exists(rule_oiface):
+                tmp = config.return_value(rule_oiface)
+                config.delete(rule_oiface)
+                config.set(rule_oiface + ['interface-name'], value=tmp)
+
+
+if config.exists(base + ['ipv6-name']):
+    for name in config.list_nodes(base + ['ipv6-name']):
+        if not config.exists(base + ['ipv6-name', name, 'rule']):
+            continue
+
+        for rule in config.list_nodes(base + ['ipv6-name', name, 'rule']):
+            rule_iiface = base + ['ipv6-name', name, 'rule', rule, 'inbound-interface']
+            rule_oiface = base + ['ipv6-name', name, 'rule', rule, 'outbound-interface']
+
+            if config.exists(rule_iiface):
+                tmp = config.return_value(rule_iiface)
+                config.delete(rule_iiface)
+                config.set(rule_iiface + ['interface-name'], value=tmp)
+
+            if config.exists(rule_oiface):
+                tmp = config.return_value(rule_oiface)
+                config.delete(rule_oiface)
+                config.set(rule_oiface + ['interface-name'], value=tmp)
+
+try:
+    with open(file_name, 'w') as f:
+        f.write(config.to_string())
+except OSError as e:
+    print("Failed to save the modified config: {}".format(e))
+    exit(1)
\ No newline at end of file
-- 
cgit v1.2.3