From 8dcb042bb2352717395ba3c17bc5437534c83af5 Mon Sep 17 00:00:00 2001
From: Nicolas Fort <nicolasfort1988@gmail.com>
Date: Fri, 30 Aug 2024 17:54:17 +0000
Subject: T6687: add fqdn support to nat rules.

(cherry picked from commit 4c3d037f036e84c77333a400b35bb1a628a1a118)
---
 data/templates/firewall/nftables-nat.j2      |  13 ++++
 interface-definitions/include/nat-rule.xml.i |   2 +
 python/vyos/firewall.py                      |  47 ++++++------
 python/vyos/nat.py                           |   7 ++
 smoketest/scripts/cli/test_nat.py            |  26 +++++++
 src/conf_mode/firewall.py                    |  21 ++++--
 src/conf_mode/nat.py                         |  20 +++++
 src/helpers/vyos-domain-resolver.py          | 107 +++++++++++++++------------
 src/systemd/vyos-domain-resolver.service     |   1 +
 9 files changed, 167 insertions(+), 77 deletions(-)

diff --git a/data/templates/firewall/nftables-nat.j2 b/data/templates/firewall/nftables-nat.j2
index 4254f6a0e..8c8dd3a8b 100644
--- a/data/templates/firewall/nftables-nat.j2
+++ b/data/templates/firewall/nftables-nat.j2
@@ -19,6 +19,12 @@ table ip vyos_nat {
 {%         endfor %}
 {%     endif %}
     }
+{%     for set_name in ip_fqdn %}
+    set FQDN_nat_{{ set_name }} {
+        type ipv4_addr
+        flags interval
+    }
+{%     endfor %}
 
     #
     # Source NAT rules build up here
@@ -31,7 +37,14 @@ table ip vyos_nat {
         {{ config | nat_rule(rule, 'source') }}
 {%         endfor %}
 {%     endif %}
+
+    }
+{%     for set_name in ip_fqdn %}
+    set FQDN_nat_{{ set_name }} {
+        type ipv4_addr
+        flags interval
     }
+{%     endfor %}
 
     chain VYOS_PRE_DNAT_HOOK {
         return
diff --git a/interface-definitions/include/nat-rule.xml.i b/interface-definitions/include/nat-rule.xml.i
index deb13529d..0a7179ff1 100644
--- a/interface-definitions/include/nat-rule.xml.i
+++ b/interface-definitions/include/nat-rule.xml.i
@@ -18,6 +18,7 @@
         <help>NAT destination parameters</help>
       </properties>
       <children>
+        #include <include/firewall/fqdn.xml.i>
         #include <include/nat-address.xml.i>
         #include <include/nat-port.xml.i>
         #include <include/firewall/source-destination-group.xml.i>
@@ -315,6 +316,7 @@
         <help>NAT source parameters</help>
       </properties>
       <children>
+        #include <include/firewall/fqdn.xml.i>
         #include <include/nat-address.xml.i>
         #include <include/nat-port.xml.i>
         #include <include/firewall/source-destination-group.xml.i>
diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py
index 8913ba152..fe4326807 100755
--- a/python/vyos/firewall.py
+++ b/python/vyos/firewall.py
@@ -50,25 +50,32 @@ def conntrack_required(conf):
 
 # Domain Resolver
 
-def fqdn_config_parse(firewall):
-    firewall['ip_fqdn'] = {}
-    firewall['ip6_fqdn'] = {}
-
-    for domain, path in dict_search_recursive(firewall, 'fqdn'):
-        hook_name = path[1]
-        priority = path[2]
-
-        fw_name = path[2]
-        rule = path[4]
-        suffix = path[5][0]
-        set_name = f'{hook_name}_{priority}_{rule}_{suffix}'
-
-        if (path[0] == 'ipv4') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'):
-            firewall['ip_fqdn'][set_name] = domain
-        elif (path[0] == 'ipv6') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'):
-            if path[1] == 'name':
-                set_name = f'name6_{priority}_{rule}_{suffix}'
-            firewall['ip6_fqdn'][set_name] = domain
+def fqdn_config_parse(config, node):
+    config['ip_fqdn'] = {}
+    config['ip6_fqdn'] = {}
+
+    for domain, path in dict_search_recursive(config, 'fqdn'):
+        if node != 'nat':
+            hook_name = path[1]
+            priority = path[2]
+
+            rule = path[4]
+            suffix = path[5][0]
+            set_name = f'{hook_name}_{priority}_{rule}_{suffix}'
+
+            if (path[0] == 'ipv4') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'):
+                config['ip_fqdn'][set_name] = domain
+            elif (path[0] == 'ipv6') and (path[1] == 'forward' or path[1] == 'input' or path[1] == 'output' or path[1] == 'name'):
+                if path[1] == 'name':
+                    set_name = f'name6_{priority}_{rule}_{suffix}'
+                config['ip6_fqdn'][set_name] = domain
+        else:
+            # Parse FQDN for NAT
+            nat_direction = path[0]
+            nat_rule = path[2]
+            suffix = path[3][0]
+            set_name = f'{nat_direction}_{nat_rule}_{suffix}'
+            config['ip_fqdn'][set_name] = domain
 
 def fqdn_resolve(fqdn, ipv6=False):
     try:
@@ -77,8 +84,6 @@ def fqdn_resolve(fqdn, ipv6=False):
     except:
         return None
 
-# End Domain Resolver
-
 def find_nftables_rule(table, chain, rule_matches=[]):
     # Find rule in table/chain that matches all criteria and return the handle
     results = cmd(f'sudo nft --handle list chain {table} {chain}').split("\n")
diff --git a/python/vyos/nat.py b/python/vyos/nat.py
index e54548788..4fe21ef13 100644
--- a/python/vyos/nat.py
+++ b/python/vyos/nat.py
@@ -236,6 +236,13 @@ def parse_nat_rule(rule_conf, rule_id, nat_type, ipv6=False):
 
                 output.append(f'{proto} {prefix}port {operator} @P_{group_name}')
 
+        if 'fqdn' in side_conf:
+            fqdn = side_conf['fqdn']
+            operator = ''
+            if fqdn[0] == '!':
+                operator = '!='
+            output.append(f' ip {prefix}addr {operator} @FQDN_nat_{nat_type}_{rule_id}_{prefix}')
+
     output.append('counter')
 
     if 'log' in rule_conf:
diff --git a/smoketest/scripts/cli/test_nat.py b/smoketest/scripts/cli/test_nat.py
index 5161e47fd..0beafcc6c 100755
--- a/smoketest/scripts/cli/test_nat.py
+++ b/smoketest/scripts/cli/test_nat.py
@@ -304,5 +304,31 @@ class TestNAT(VyOSUnitTestSHIM.TestCase):
 
         self.verify_nftables(nftables_search, 'ip vyos_nat')
 
+    def test_nat_fqdn(self):
+        source_domain = 'vyos.dev'
+        destination_domain = 'vyos.io'
+
+        self.cli_set(src_path + ['rule', '1', 'outbound-interface', 'name', 'eth0'])
+        self.cli_set(src_path + ['rule', '1', 'source', 'fqdn', source_domain])
+        self.cli_set(src_path + ['rule', '1', 'translation', 'address', 'masquerade'])
+
+        self.cli_set(dst_path + ['rule', '1', 'destination', 'fqdn', destination_domain])
+        self.cli_set(dst_path + ['rule', '1', 'source', 'fqdn', source_domain])
+        self.cli_set(dst_path + ['rule', '1', 'destination', 'port', '5122'])
+        self.cli_set(dst_path + ['rule', '1', 'protocol', 'tcp'])
+        self.cli_set(dst_path + ['rule', '1', 'translation', 'address', '198.51.100.1'])
+        self.cli_set(dst_path + ['rule', '1', 'translation', 'port', '22'])
+
+
+        self.cli_commit()
+
+        nftables_search = [
+            ['set FQDN_nat_destination_1_d'],
+            ['set FQDN_nat_source_1_s'],
+            ['oifname "eth0"', 'ip saddr @FQDN_nat_source_1_s', 'masquerade', 'comment "SRC-NAT-1"'],
+            ['tcp dport 5122', 'ip saddr @FQDN_nat_destination_1_s', 'ip daddr @FQDN_nat_destination_1_d', 'dnat to 198.51.100.1:22', 'comment "DST-NAT-1"']
+        ]
+
+        self.verify_nftables(nftables_search, 'ip vyos_nat')
 if __name__ == '__main__':
     unittest.main(verbosity=2)
diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py
index 9974a1466..f575843f3 100755
--- a/src/conf_mode/firewall.py
+++ b/src/conf_mode/firewall.py
@@ -36,10 +36,14 @@ from vyos.utils.process import cmd
 from vyos.utils.process import rc_cmd
 from vyos import ConfigError
 from vyos import airbag
+from pathlib import Path
 
 airbag.enable()
 
 nftables_conf = '/run/nftables.conf'
+domain_resolver_usage = '/run/use-vyos-domain-resolver-firewall'
+domain_resolver_usage_nat = '/run/use-vyos-domain-resolver-nat'
+
 sysctl_file = r'/run/sysctl/10-vyos-firewall.conf'
 
 valid_groups = [
@@ -122,7 +126,7 @@ def get_config(config=None):
 
     firewall['geoip_updated'] = geoip_updated(conf, firewall)
 
-    fqdn_config_parse(firewall)
+    fqdn_config_parse(firewall, 'firewall')
 
     set_dependents('conntrack', conf)
 
@@ -467,12 +471,15 @@ def apply(firewall):
 
     call_dependents()
 
-    # T970 Enable a resolver (systemd daemon) that checks
-    # domain-group/fqdn addresses and update entries for domains by timeout
-    # If router loaded without internet connection or for synchronization
-    domain_action = 'stop'
-    if dict_search_args(firewall, 'group', 'domain_group') or firewall['ip_fqdn'] or firewall['ip6_fqdn']:
-        domain_action = 'restart'
+    ## DOMAIN RESOLVER
+    domain_action = 'restart'
+    if dict_search_args(firewall, 'group', 'domain_group') or firewall['ip_fqdn'].items() or firewall['ip6_fqdn'].items():
+        text = f'# Automatically generated by firewall.py\nThis file indicates that vyos-domain-resolver service is used by the firewall.\n'
+        Path(domain_resolver_usage).write_text(text)
+    else:
+        Path(domain_resolver_usage).unlink(missing_ok=True)
+        if not Path('/run').glob('use-vyos-domain-resolver*'):
+            domain_action = 'stop'
     call(f'systemctl {domain_action} vyos-domain-resolver.service')
 
     if firewall['geoip_updated']:
diff --git a/src/conf_mode/nat.py b/src/conf_mode/nat.py
index 39803fa02..98b2f3f29 100755
--- a/src/conf_mode/nat.py
+++ b/src/conf_mode/nat.py
@@ -26,10 +26,13 @@ from vyos.template import is_ip_network
 from vyos.utils.kernel import check_kmod
 from vyos.utils.dict import dict_search
 from vyos.utils.dict import dict_search_args
+from vyos.utils.file import write_file
 from vyos.utils.process import cmd
 from vyos.utils.process import run
+from vyos.utils.process import call
 from vyos.utils.network import is_addr_assigned
 from vyos.utils.network import interface_exists
+from vyos.firewall import fqdn_config_parse
 from vyos import ConfigError
 
 from vyos import airbag
@@ -39,6 +42,8 @@ k_mod = ['nft_nat', 'nft_chain_nat']
 
 nftables_nat_config = '/run/nftables_nat.conf'
 nftables_static_nat_conf = '/run/nftables_static-nat-rules.nft'
+domain_resolver_usage = '/run/use-vyos-domain-resolver-nat'
+domain_resolver_usage_firewall = '/run/use-vyos-domain-resolver-firewall'
 
 valid_groups = [
     'address_group',
@@ -71,6 +76,8 @@ def get_config(config=None):
     if 'dynamic_group' in nat['firewall_group']:
         del nat['firewall_group']['dynamic_group']
 
+    fqdn_config_parse(nat, 'nat')
+
     return nat
 
 def verify_rule(config, err_msg, groups_dict):
@@ -251,6 +258,19 @@ def apply(nat):
 
     call_dependents()
 
+    # DOMAIN RESOLVER
+    if nat and 'deleted' not in nat:
+        domain_action = 'restart'
+        if nat['ip_fqdn'].items():
+            text = f'# Automatically generated by nat.py\nThis file indicates that vyos-domain-resolver service is used by nat.\n'
+            write_file(domain_resolver_usage, text)
+        elif os.path.exists(domain_resolver_usage):
+            os.unlink(domain_resolver_usage)
+            if not os.path.exists(domain_resolver_usage_firewall):
+                # Firewall not using domain resolver
+                domain_action = 'stop'
+        call(f'systemctl {domain_action} vyos-domain-resolver.service')
+
     return None
 
 if __name__ == '__main__':
diff --git a/src/helpers/vyos-domain-resolver.py b/src/helpers/vyos-domain-resolver.py
index 57cfcabd7..f5a1d9297 100755
--- a/src/helpers/vyos-domain-resolver.py
+++ b/src/helpers/vyos-domain-resolver.py
@@ -30,6 +30,8 @@ from vyos.xml_ref import get_defaults
 base = ['firewall']
 timeout = 300
 cache = False
+base_firewall = ['firewall']
+base_nat = ['nat']
 
 domain_state = {}
 
@@ -46,25 +48,25 @@ ipv6_tables = {
     'ip6 raw'
 }
 
-def get_config(conf):
-    firewall = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True,
+def get_config(conf, node):
+    node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True,
                                     no_tag_node_value_mangle=True)
 
-    default_values = get_defaults(base, get_first_key=True)
+    default_values = get_defaults(node, get_first_key=True)
 
-    firewall = dict_merge(default_values, firewall)
+    node_config = dict_merge(default_values, node_config)
 
     global timeout, cache
 
-    if 'resolver_interval' in firewall:
-        timeout = int(firewall['resolver_interval'])
+    if 'resolver_interval' in node_config:
+        timeout = int(node_config['resolver_interval'])
 
-    if 'resolver_cache' in firewall:
+    if 'resolver_cache' in node_config:
         cache = True
 
-    fqdn_config_parse(firewall)
+    fqdn_config_parse(node_config, node[0])
 
-    return firewall
+    return node_config
 
 def resolve(domains, ipv6=False):
     global domain_state
@@ -108,55 +110,60 @@ def nft_valid_sets():
     except:
         return []
 
-def update(firewall):
+def update_fqdn(config, node):
     conf_lines = []
     count = 0
-
     valid_sets = nft_valid_sets()
 
-    domain_groups = dict_search_args(firewall, 'group', 'domain_group')
-    if domain_groups:
-        for set_name, domain_config in domain_groups.items():
-            if 'address' not in domain_config:
-                continue
-
-            nft_set_name = f'D_{set_name}'
-            domains = domain_config['address']
-
-            ip_list = resolve(domains, ipv6=False)
-            for table in ipv4_tables:
-                if (table, nft_set_name) in valid_sets:
-                    conf_lines += nft_output(table, nft_set_name, ip_list)
-
-            ip6_list = resolve(domains, ipv6=True)
-            for table in ipv6_tables:
-                if (table, nft_set_name) in valid_sets:
-                    conf_lines += nft_output(table, nft_set_name, ip6_list)
+    if node == 'firewall':
+        domain_groups = dict_search_args(config, 'group', 'domain_group')
+        if domain_groups:
+            for set_name, domain_config in domain_groups.items():
+                if 'address' not in domain_config:
+                    continue
+                nft_set_name = f'D_{set_name}'
+                domains = domain_config['address']
+
+                ip_list = resolve(domains, ipv6=False)
+                for table in ipv4_tables:
+                    if (table, nft_set_name) in valid_sets:
+                        conf_lines += nft_output(table, nft_set_name, ip_list)
+                ip6_list = resolve(domains, ipv6=True)
+                for table in ipv6_tables:
+                    if (table, nft_set_name) in valid_sets:
+                        conf_lines += nft_output(table, nft_set_name, ip6_list)
+                count += 1
+
+        for set_name, domain in config['ip_fqdn'].items():
+            table = 'ip vyos_filter'
+            nft_set_name = f'FQDN_{set_name}'
+            ip_list = resolve([domain], ipv6=False)
+            if (table, nft_set_name) in valid_sets:
+                conf_lines += nft_output(table, nft_set_name, ip_list)
             count += 1
 
-    for set_name, domain in firewall['ip_fqdn'].items():
-        table = 'ip vyos_filter'
-        nft_set_name = f'FQDN_{set_name}'
-
-        ip_list = resolve([domain], ipv6=False)
-
-        if (table, nft_set_name) in valid_sets:
-            conf_lines += nft_output(table, nft_set_name, ip_list)
-        count += 1
-
-    for set_name, domain in firewall['ip6_fqdn'].items():
-        table = 'ip6 vyos_filter'
-        nft_set_name = f'FQDN_{set_name}'
+        for set_name, domain in config['ip6_fqdn'].items():
+            table = 'ip6 vyos_filter'
+            nft_set_name = f'FQDN_{set_name}'
+            ip_list = resolve([domain], ipv6=True)
+            if (table, nft_set_name) in valid_sets:
+                conf_lines += nft_output(table, nft_set_name, ip_list)
+            count += 1
 
-        ip_list = resolve([domain], ipv6=True)
-        if (table, nft_set_name) in valid_sets:
-            conf_lines += nft_output(table, nft_set_name, ip_list)
-        count += 1
+    else:
+        # It's NAT
+        for set_name, domain in config['ip_fqdn'].items():
+            table = 'ip vyos_nat'
+            nft_set_name = f'FQDN_nat_{set_name}'
+            ip_list = resolve([domain], ipv6=False)
+            if (table, nft_set_name) in valid_sets:
+                conf_lines += nft_output(table, nft_set_name, ip_list)
+            count += 1
 
     nft_conf_str = "\n".join(conf_lines) + "\n"
     code = run(f'nft --file -', input=nft_conf_str)
 
-    print(f'Updated {count} sets - result: {code}')
+    print(f'Updated {count} sets in {node} - result: {code}')
 
 if __name__ == '__main__':
     print(f'VyOS domain resolver')
@@ -169,10 +176,12 @@ if __name__ == '__main__':
         time.sleep(1)
 
     conf = ConfigTreeQuery()
-    firewall = get_config(conf)
+    firewall = get_config(conf, base_firewall)
+    nat = get_config(conf, base_nat)
 
     print(f'interval: {timeout}s - cache: {cache}')
 
     while True:
-        update(firewall)
+        update_fqdn(firewall, 'firewall')
+        update_fqdn(nat, 'nat')
         time.sleep(timeout)
diff --git a/src/systemd/vyos-domain-resolver.service b/src/systemd/vyos-domain-resolver.service
index c56b51f0c..e63ae5e34 100644
--- a/src/systemd/vyos-domain-resolver.service
+++ b/src/systemd/vyos-domain-resolver.service
@@ -1,6 +1,7 @@
 [Unit]
 Description=VyOS firewall domain resolver
 After=vyos-router.service
+ConditionPathExistsGlob=/run/use-vyos-domain-resolver*
 
 [Service]
 Type=simple
-- 
cgit v1.2.3