From 4c3d037f036e84c77333a400b35bb1a628a1a118 Mon Sep 17 00:00:00 2001
From: Nicolas Fort <nicolasfort1988@gmail.com>
Date: Fri, 30 Aug 2024 17:54:17 +0000
Subject: T6687: add fqdn support to nat rules.

---
 src/helpers/vyos-domain-resolver.py | 107 +++++++++++++++++++-----------------
 1 file changed, 58 insertions(+), 49 deletions(-)

(limited to 'src/helpers')

diff --git a/src/helpers/vyos-domain-resolver.py b/src/helpers/vyos-domain-resolver.py
index 57cfcabd7..f5a1d9297 100755
--- a/src/helpers/vyos-domain-resolver.py
+++ b/src/helpers/vyos-domain-resolver.py
@@ -30,6 +30,8 @@ from vyos.xml_ref import get_defaults
 base = ['firewall']
 timeout = 300
 cache = False
+base_firewall = ['firewall']
+base_nat = ['nat']
 
 domain_state = {}
 
@@ -46,25 +48,25 @@ ipv6_tables = {
     'ip6 raw'
 }
 
-def get_config(conf):
-    firewall = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True,
+def get_config(conf, node):
+    node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True,
                                     no_tag_node_value_mangle=True)
 
-    default_values = get_defaults(base, get_first_key=True)
+    default_values = get_defaults(node, get_first_key=True)
 
-    firewall = dict_merge(default_values, firewall)
+    node_config = dict_merge(default_values, node_config)
 
     global timeout, cache
 
-    if 'resolver_interval' in firewall:
-        timeout = int(firewall['resolver_interval'])
+    if 'resolver_interval' in node_config:
+        timeout = int(node_config['resolver_interval'])
 
-    if 'resolver_cache' in firewall:
+    if 'resolver_cache' in node_config:
         cache = True
 
-    fqdn_config_parse(firewall)
+    fqdn_config_parse(node_config, node[0])
 
-    return firewall
+    return node_config
 
 def resolve(domains, ipv6=False):
     global domain_state
@@ -108,55 +110,60 @@ def nft_valid_sets():
     except:
         return []
 
-def update(firewall):
+def update_fqdn(config, node):
     conf_lines = []
     count = 0
-
     valid_sets = nft_valid_sets()
 
-    domain_groups = dict_search_args(firewall, 'group', 'domain_group')
-    if domain_groups:
-        for set_name, domain_config in domain_groups.items():
-            if 'address' not in domain_config:
-                continue
-
-            nft_set_name = f'D_{set_name}'
-            domains = domain_config['address']
-
-            ip_list = resolve(domains, ipv6=False)
-            for table in ipv4_tables:
-                if (table, nft_set_name) in valid_sets:
-                    conf_lines += nft_output(table, nft_set_name, ip_list)
-
-            ip6_list = resolve(domains, ipv6=True)
-            for table in ipv6_tables:
-                if (table, nft_set_name) in valid_sets:
-                    conf_lines += nft_output(table, nft_set_name, ip6_list)
+    if node == 'firewall':
+        domain_groups = dict_search_args(config, 'group', 'domain_group')
+        if domain_groups:
+            for set_name, domain_config in domain_groups.items():
+                if 'address' not in domain_config:
+                    continue
+                nft_set_name = f'D_{set_name}'
+                domains = domain_config['address']
+
+                ip_list = resolve(domains, ipv6=False)
+                for table in ipv4_tables:
+                    if (table, nft_set_name) in valid_sets:
+                        conf_lines += nft_output(table, nft_set_name, ip_list)
+                ip6_list = resolve(domains, ipv6=True)
+                for table in ipv6_tables:
+                    if (table, nft_set_name) in valid_sets:
+                        conf_lines += nft_output(table, nft_set_name, ip6_list)
+                count += 1
+
+        for set_name, domain in config['ip_fqdn'].items():
+            table = 'ip vyos_filter'
+            nft_set_name = f'FQDN_{set_name}'
+            ip_list = resolve([domain], ipv6=False)
+            if (table, nft_set_name) in valid_sets:
+                conf_lines += nft_output(table, nft_set_name, ip_list)
             count += 1
 
-    for set_name, domain in firewall['ip_fqdn'].items():
-        table = 'ip vyos_filter'
-        nft_set_name = f'FQDN_{set_name}'
-
-        ip_list = resolve([domain], ipv6=False)
-
-        if (table, nft_set_name) in valid_sets:
-            conf_lines += nft_output(table, nft_set_name, ip_list)
-        count += 1
-
-    for set_name, domain in firewall['ip6_fqdn'].items():
-        table = 'ip6 vyos_filter'
-        nft_set_name = f'FQDN_{set_name}'
+        for set_name, domain in config['ip6_fqdn'].items():
+            table = 'ip6 vyos_filter'
+            nft_set_name = f'FQDN_{set_name}'
+            ip_list = resolve([domain], ipv6=True)
+            if (table, nft_set_name) in valid_sets:
+                conf_lines += nft_output(table, nft_set_name, ip_list)
+            count += 1
 
-        ip_list = resolve([domain], ipv6=True)
-        if (table, nft_set_name) in valid_sets:
-            conf_lines += nft_output(table, nft_set_name, ip_list)
-        count += 1
+    else:
+        # It's NAT
+        for set_name, domain in config['ip_fqdn'].items():
+            table = 'ip vyos_nat'
+            nft_set_name = f'FQDN_nat_{set_name}'
+            ip_list = resolve([domain], ipv6=False)
+            if (table, nft_set_name) in valid_sets:
+                conf_lines += nft_output(table, nft_set_name, ip_list)
+            count += 1
 
     nft_conf_str = "\n".join(conf_lines) + "\n"
     code = run(f'nft --file -', input=nft_conf_str)
 
-    print(f'Updated {count} sets - result: {code}')
+    print(f'Updated {count} sets in {node} - result: {code}')
 
 if __name__ == '__main__':
     print(f'VyOS domain resolver')
@@ -169,10 +176,12 @@ if __name__ == '__main__':
         time.sleep(1)
 
     conf = ConfigTreeQuery()
-    firewall = get_config(conf)
+    firewall = get_config(conf, base_firewall)
+    nat = get_config(conf, base_nat)
 
     print(f'interval: {timeout}s - cache: {cache}')
 
     while True:
-        update(firewall)
+        update_fqdn(firewall, 'firewall')
+        update_fqdn(nat, 'nat')
         time.sleep(timeout)
-- 
cgit v1.2.3