From b4b491d424fba6f3d417135adc1865e338a480a1 Mon Sep 17 00:00:00 2001
From: sarthurdev <965089+sarthurdev@users.noreply.github.com>
Date: Mon, 31 Oct 2022 21:08:42 +0100
Subject: nat: T1877: T970: Add firewall groups to NAT

---
 src/conf_mode/firewall.py           | 22 +++++++----
 src/conf_mode/nat.py                | 73 +++++++++++++++++++++++++++++++------
 src/helpers/vyos-domain-resolver.py |  1 +
 3 files changed, 78 insertions(+), 18 deletions(-)

(limited to 'src')

diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py
index 2bb765e65..783adec46 100755
--- a/src/conf_mode/firewall.py
+++ b/src/conf_mode/firewall.py
@@ -41,6 +41,7 @@ from vyos import ConfigError
 from vyos import airbag
 airbag.enable()
 
+nat_conf_script = '/usr/libexec/vyos/conf_mode/nat.py'
 policy_route_conf_script = '/usr/libexec/vyos/conf_mode/policy-route.py'
 
 nftables_conf = '/run/nftables.conf'
@@ -158,7 +159,7 @@ def get_config(config=None):
         for zone in firewall['zone']:
             firewall['zone'][zone] = dict_merge(default_values, firewall['zone'][zone])
 
-    firewall['policy_resync'] = bool('group' in firewall or node_changed(conf, base + ['group']))
+    firewall['group_resync'] = bool('group' in firewall or node_changed(conf, base + ['group']))
 
     if 'config_trap' in firewall and firewall['config_trap'] == 'enable':
         diff = get_config_diff(conf)
@@ -463,6 +464,12 @@ def post_apply_trap(firewall):
 
                 cmd(base_cmd + ' '.join(objects))
 
+def resync_nat():
+    # Update nat as firewall groups were updated
+    tmp, out = rc_cmd(nat_conf_script)
+    if tmp > 0:
+        Warning(f'Failed to re-apply nat configuration! {out}')
+
 def resync_policy_route():
     # Update policy route as firewall groups were updated
     tmp, out = rc_cmd(policy_route_conf_script)
@@ -474,19 +481,20 @@ def apply(firewall):
     if install_result == 1:
         raise ConfigError(f'Failed to apply firewall: {output}')
 
+    apply_sysfs(firewall)
+
+    if firewall['group_resync']:
+        resync_nat()
+        resync_policy_route()
+
     # T970 Enable a resolver (systemd daemon) that checks
-    # domain-group addresses and update entries for domains by timeout
+    # domain-group/fqdn addresses and update entries for domains by timeout
     # If router loaded without internet connection or for synchronization
     domain_action = 'stop'
     if dict_search_args(firewall, 'group', 'domain_group') or firewall['ip_fqdn'] or firewall['ip6_fqdn']:
         domain_action = 'restart'
     call(f'systemctl {domain_action} vyos-domain-resolver.service')
 
-    apply_sysfs(firewall)
-
-    if firewall['policy_resync']:
-        resync_policy_route()
-
     if firewall['geoip_updated']:
         # Call helper script to Update set contents
         if 'name' in firewall['geoip_updated'] or 'ipv6_name' in firewall['geoip_updated']:
diff --git a/src/conf_mode/nat.py b/src/conf_mode/nat.py
index 978c043e9..9f8221514 100755
--- a/src/conf_mode/nat.py
+++ b/src/conf_mode/nat.py
@@ -32,6 +32,7 @@ from vyos.util import cmd
 from vyos.util import run
 from vyos.util import check_kmod
 from vyos.util import dict_search
+from vyos.util import dict_search_args
 from vyos.validate import is_addr_assigned
 from vyos.xml import defaults
 from vyos import ConfigError
@@ -47,6 +48,13 @@ else:
 nftables_nat_config = '/run/nftables_nat.conf'
 nftables_static_nat_conf = '/run/nftables_static-nat-rules.nft'
 
+valid_groups = [
+    'address_group',
+    'domain_group',
+    'network_group',
+    'port_group'
+]
+
 def get_handler(json, chain, target):
     """ Get nftable rule handler number of given chain/target combination.
     Handler is required when adding NAT/Conntrack helper targets """
@@ -60,7 +68,7 @@ def get_handler(json, chain, target):
     return None
 
 
-def verify_rule(config, err_msg):
+def verify_rule(config, err_msg, groups_dict):
     """ Common verify steps used for both source and destination NAT """
 
     if (dict_search('translation.port', config) != None or
@@ -78,6 +86,45 @@ def verify_rule(config, err_msg):
                              'statically maps a whole network of addresses onto another\n' \
                              'network of addresses')
 
+    for side in ['destination', 'source']:
+        if side in config:
+            side_conf = config[side]
+
+            if len({'address', 'fqdn'} & set(side_conf)) > 1:
+                raise ConfigError('Only one of address, fqdn or geoip can be specified')
+
+            if 'group' in side_conf:
+                if len({'address_group', 'network_group', 'domain_group'} & set(side_conf['group'])) > 1:
+                    raise ConfigError('Only one address-group, network-group or domain-group can be specified')
+
+                for group in valid_groups:
+                    if group in side_conf['group']:
+                        group_name = side_conf['group'][group]
+                        error_group = group.replace("_", "-")
+
+                        if group in ['address_group', 'network_group', 'domain_group']:
+                            types = [t for t in ['address', 'fqdn'] if t in side_conf]
+                            if types:
+                                raise ConfigError(f'{error_group} and {types[0]} cannot both be defined')
+
+                        if group_name and group_name[0] == '!':
+                            group_name = group_name[1:]
+
+                        group_obj = dict_search_args(groups_dict, group, group_name)
+
+                        if group_obj is None:
+                            raise ConfigError(f'Invalid {error_group} "{group_name}" on firewall rule')
+
+                        if not group_obj:
+                            Warning(f'{error_group} "{group_name}" has no members!')
+
+            if dict_search_args(side_conf, 'group', 'port_group'):
+                if 'protocol' not in config:
+                    raise ConfigError('Protocol must be defined if specifying a port-group')
+
+                if config['protocol'] not in ['tcp', 'udp', 'tcp_udp']:
+                    raise ConfigError('Protocol must be tcp, udp, or tcp_udp when specifying a port-group')
+
 def get_config(config=None):
     if config:
         conf = config
@@ -105,16 +152,20 @@ def get_config(config=None):
     condensed_json = jmespath.search(pattern, nftable_json)
 
     if not conf.exists(base):
-        nat['helper_functions'] = 'remove'
-
-        # Retrieve current table handler positions
-        nat['pre_ct_ignore'] = get_handler(condensed_json, 'PREROUTING', 'VYOS_CT_HELPER')
-        nat['pre_ct_conntrack'] = get_handler(condensed_json, 'PREROUTING', 'NAT_CONNTRACK')
-        nat['out_ct_ignore'] = get_handler(condensed_json, 'OUTPUT', 'VYOS_CT_HELPER')
-        nat['out_ct_conntrack'] = get_handler(condensed_json, 'OUTPUT', 'NAT_CONNTRACK')
+        if get_handler(condensed_json, 'PREROUTING', 'VYOS_CT_HELPER'):
+            nat['helper_functions'] = 'remove'
+
+            # Retrieve current table handler positions
+            nat['pre_ct_ignore'] = get_handler(condensed_json, 'PREROUTING', 'VYOS_CT_HELPER')
+            nat['pre_ct_conntrack'] = get_handler(condensed_json, 'PREROUTING', 'NAT_CONNTRACK')
+            nat['out_ct_ignore'] = get_handler(condensed_json, 'OUTPUT', 'VYOS_CT_HELPER')
+            nat['out_ct_conntrack'] = get_handler(condensed_json, 'OUTPUT', 'NAT_CONNTRACK')
         nat['deleted'] = ''
         return nat
 
+    nat['firewall_group'] = conf.get_config_dict(['firewall', 'group'], key_mangling=('-', '_'), get_first_key=True,
+                                    no_tag_node_value_mangle=True)
+
     # check if NAT connection tracking helpers need to be set up - this has to
     # be done only once
     if not get_handler(condensed_json, 'PREROUTING', 'NAT_CONNTRACK'):
@@ -157,7 +208,7 @@ def verify(nat):
                         Warning(f'IP address {ip} does not exist on the system!')
 
             # common rule verification
-            verify_rule(config, err_msg)
+            verify_rule(config, err_msg, nat['firewall_group'])
 
 
     if dict_search('destination.rule', nat):
@@ -175,7 +226,7 @@ def verify(nat):
                     raise ConfigError(f'{err_msg} translation requires address and/or port')
 
             # common rule verification
-            verify_rule(config, err_msg)
+            verify_rule(config, err_msg, nat['firewall_group'])
 
     if dict_search('static.rule', nat):
         for rule, config in dict_search('static.rule', nat).items():
@@ -186,7 +237,7 @@ def verify(nat):
                                   'inbound-interface not specified')
 
             # common rule verification
-            verify_rule(config, err_msg)
+            verify_rule(config, err_msg, nat['firewall_group'])
 
     return None
 
diff --git a/src/helpers/vyos-domain-resolver.py b/src/helpers/vyos-domain-resolver.py
index 2f71f15db..035c208b2 100755
--- a/src/helpers/vyos-domain-resolver.py
+++ b/src/helpers/vyos-domain-resolver.py
@@ -37,6 +37,7 @@ domain_state = {}
 ipv4_tables = {
     'ip mangle',
     'ip vyos_filter',
+    'ip vyos_nat'
 }
 
 ipv6_tables = {
-- 
cgit v1.2.3