summaryrefslogtreecommitdiff
path: root/src/services/vyos-domain-resolver
diff options
context:
space:
mode:
Diffstat (limited to 'src/services/vyos-domain-resolver')
-rwxr-xr-xsrc/services/vyos-domain-resolver313
1 files changed, 313 insertions, 0 deletions
diff --git a/src/services/vyos-domain-resolver b/src/services/vyos-domain-resolver
new file mode 100755
index 000000000..fb18724af
--- /dev/null
+++ b/src/services/vyos-domain-resolver
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022-2024 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+import json
+import time
+import logging
+import os
+
+from vyos.configdict import dict_merge
+from vyos.configquery import ConfigTreeQuery
+from vyos.firewall import fqdn_config_parse
+from vyos.firewall import fqdn_resolve
+from vyos.ifconfig import WireGuardIf
+from vyos.remote import download
+from vyos.utils.commit import commit_in_progress
+from vyos.utils.dict import dict_search_args
+from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME
+from vyos.utils.file import makedir, chmod_775, write_file, read_file
+from vyos.utils.network import is_valid_ipv4_address_or_range, is_valid_ipv6_address_or_range
+from vyos.utils.process import cmd
+from vyos.utils.process import run
+from vyos.xml_ref import get_defaults
+
+base = ['firewall']
+timeout = 300
+cache = False
+base_firewall = ['firewall']
+base_nat = ['nat']
+base_interfaces = ['interfaces']
+
+firewall_config_dir = "/config/firewall"
+
+domain_state = {}
+
+ipv4_tables = {
+ 'ip vyos_mangle',
+ 'ip vyos_filter',
+ 'ip vyos_nat',
+ 'ip raw'
+}
+
+ipv6_tables = {
+ 'ip6 vyos_mangle',
+ 'ip6 vyos_filter',
+ 'ip6 raw'
+}
+
+logger = logging.getLogger(__name__)
+logs_handler = logging.StreamHandler()
+logger.addHandler(logs_handler)
+logger.setLevel(logging.INFO)
+
+def get_config(conf, node):
+ node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True,
+ no_tag_node_value_mangle=True)
+
+ default_values = get_defaults(node, get_first_key=True)
+
+ node_config = dict_merge(default_values, node_config)
+
+ if node == base_firewall and 'global_options' in node_config:
+ global_config = node_config['global_options']
+ global timeout, cache
+
+ if 'resolver_interval' in global_config:
+ timeout = int(global_config['resolver_interval'])
+
+ if 'resolver_cache' in global_config:
+ cache = True
+
+ fqdn_config_parse(node_config, node[0])
+
+ return node_config
+
+def resolve(domains, ipv6=False):
+ global domain_state
+
+ ip_list = set()
+
+ for domain in domains:
+ resolved = fqdn_resolve(domain, ipv6=ipv6)
+
+ cache_key = f'{domain}_ipv6' if ipv6 else domain
+
+ if resolved and cache:
+ domain_state[cache_key] = resolved
+ elif not resolved:
+ if cache_key not in domain_state:
+ continue
+ resolved = domain_state[cache_key]
+
+ ip_list = ip_list | resolved
+ return ip_list
+
+def nft_output(table, set_name, ip_list):
+ output = [f'flush set {table} {set_name}']
+ if ip_list:
+ ip_str = ','.join(ip_list)
+ output.append(f'add element {table} {set_name} {{ {ip_str} }}')
+ return output
+
+def nft_valid_sets():
+ try:
+ valid_sets = []
+ sets_json = cmd('nft --json list sets')
+ sets_obj = json.loads(sets_json)
+
+ for obj in sets_obj['nftables']:
+ if 'set' in obj:
+ family = obj['set']['family']
+ table = obj['set']['table']
+ name = obj['set']['name']
+ valid_sets.append((f'{family} {table}', name))
+
+ return valid_sets
+ except:
+ return []
+
+def update_remote_group(config):
+ conf_lines = []
+ count = 0
+ valid_sets = nft_valid_sets()
+
+ remote_groups = dict_search_args(config, 'group', 'remote_group')
+ if remote_groups:
+ # Create directory for list files if necessary
+ if not os.path.isdir(firewall_config_dir):
+ makedir(firewall_config_dir, group='vyattacfg')
+ chmod_775(firewall_config_dir)
+
+ for set_name, remote_config in remote_groups.items():
+ if 'url' not in remote_config:
+ continue
+ nft_ip_set_name = f'R_{set_name}'
+ nft_ip6_set_name = f'R6_{set_name}'
+
+ # Create list file if necessary
+ list_file = os.path.join(firewall_config_dir, f"{nft_ip_set_name}.txt")
+ if not os.path.exists(list_file):
+ write_file(list_file, '', user="root", group="vyattacfg", mode=0o644)
+
+ # Attempt to download file, use cached version if download fails
+ try:
+ download(list_file, remote_config['url'], raise_error=True)
+ except:
+ logger.error(f'Failed to download list-file for {set_name} remote group')
+ logger.info(f'Using cached list-file for {set_name} remote group')
+
+ # Read list file
+ ip_list = []
+ ip6_list = []
+ invalid_list = []
+ for line in read_file(list_file).splitlines():
+ line_first_word = line.strip().partition(' ')[0]
+
+ if is_valid_ipv4_address_or_range(line_first_word):
+ ip_list.append(line_first_word)
+ elif is_valid_ipv6_address_or_range(line_first_word):
+ ip6_list.append(line_first_word)
+ else:
+ if line_first_word[0].isalnum():
+ invalid_list.append(line_first_word)
+
+ # Load ip tables
+ for table in ipv4_tables:
+ if (table, nft_ip_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_ip_set_name, ip_list)
+
+ # Load ip6 tables
+ for table in ipv6_tables:
+ if (table, nft_ip6_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_ip6_set_name, ip6_list)
+
+ invalid_str = ", ".join(invalid_list)
+ if invalid_str:
+ logger.info(f'Invalid address for set {set_name}: {invalid_str}')
+
+ count += 1
+
+ nft_conf_str = "\n".join(conf_lines) + "\n"
+ code = run(f'nft --file -', input=nft_conf_str)
+
+ logger.info(f'Updated {count} remote-groups in firewall - result: {code}')
+
+
+def update_fqdn(config, node):
+ conf_lines = []
+ count = 0
+ valid_sets = nft_valid_sets()
+
+ if node == 'firewall':
+ domain_groups = dict_search_args(config, 'group', 'domain_group')
+ if domain_groups:
+ for set_name, domain_config in domain_groups.items():
+ if 'address' not in domain_config:
+ continue
+ nft_set_name = f'D_{set_name}'
+ domains = domain_config['address']
+
+ ip_list = resolve(domains, ipv6=False)
+ for table in ipv4_tables:
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ ip6_list = resolve(domains, ipv6=True)
+ for table in ipv6_tables:
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip6_list)
+ count += 1
+
+ for set_name, domain in config['ip_fqdn'].items():
+ table = 'ip vyos_filter'
+ nft_set_name = f'FQDN_{set_name}'
+ ip_list = resolve([domain], ipv6=False)
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ count += 1
+
+ for set_name, domain in config['ip6_fqdn'].items():
+ table = 'ip6 vyos_filter'
+ nft_set_name = f'FQDN_{set_name}'
+ ip_list = resolve([domain], ipv6=True)
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ count += 1
+
+ else:
+ # It's NAT
+ for set_name, domain in config['ip_fqdn'].items():
+ table = 'ip vyos_nat'
+ nft_set_name = f'FQDN_nat_{set_name}'
+ ip_list = resolve([domain], ipv6=False)
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ count += 1
+
+ nft_conf_str = "\n".join(conf_lines) + "\n"
+ code = run(f'nft --file -', input=nft_conf_str)
+
+ logger.info(f'Updated {count} sets in {node} - result: {code}')
+
+def update_interfaces(config, node):
+ if node == 'interfaces':
+ wg_interfaces = dict_search_args(config, 'wireguard')
+ if wg_interfaces:
+
+ peer_public_keys = {}
+ # for each wireguard interfaces
+ for interface, wireguard in wg_interfaces.items():
+ peer_public_keys[interface] = []
+ for peer, peer_config in wireguard['peer'].items():
+ # check peer if peer host-name or address is set
+ if 'host_name' in peer_config or 'address' in peer_config:
+ # check latest handshake
+ peer_public_keys[interface].append(
+ peer_config['public_key']
+ )
+
+ now_time = time.time()
+ for (interface, check_peer_public_keys) in peer_public_keys.items():
+ if len(check_peer_public_keys) == 0:
+ continue
+
+ intf = WireGuardIf(interface, create=False, debug=False)
+ handshakes = intf.operational.get_latest_handshakes()
+
+ # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME
+ # if data is being transmitted between the peers. If no data is
+ # transmitted, the handshake will not be initiated unless new
+ # data begins to flow. Each handshake generates a new session
+ # key, and the key is rotated at least every 120 seconds or
+ # upon data transmission after a prolonged silence.
+ for public_key, handshake_time in handshakes.items():
+ if public_key in check_peer_public_keys and (
+ handshake_time == 0
+ or (now_time - handshake_time > 3*WIREGUARD_REKEY_AFTER_TIME)
+ ):
+ intf.operational.reset_peer(public_key=public_key)
+
+if __name__ == '__main__':
+ logger.info('VyOS domain resolver')
+
+ count = 1
+ while commit_in_progress():
+ if ( count % 60 == 0 ):
+ logger.info(f'Commit still in progress after {count}s - waiting')
+ count += 1
+ time.sleep(1)
+
+ conf = ConfigTreeQuery()
+ firewall = get_config(conf, base_firewall)
+ nat = get_config(conf, base_nat)
+ interfaces = get_config(conf, base_interfaces)
+
+ logger.info(f'interval: {timeout}s - cache: {cache}')
+
+ while True:
+ update_fqdn(firewall, 'firewall')
+ update_fqdn(nat, 'nat')
+ update_remote_group(firewall)
+ update_interfaces(interfaces, 'interfaces')
+ time.sleep(timeout)