diff options
Diffstat (limited to 'src/services/vyos-domain-resolver')
-rwxr-xr-x | src/services/vyos-domain-resolver | 313 |
1 files changed, 313 insertions, 0 deletions
diff --git a/src/services/vyos-domain-resolver b/src/services/vyos-domain-resolver new file mode 100755 index 000000000..fb18724af --- /dev/null +++ b/src/services/vyos-domain-resolver @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022-2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import json +import time +import logging +import os + +from vyos.configdict import dict_merge +from vyos.configquery import ConfigTreeQuery +from vyos.firewall import fqdn_config_parse +from vyos.firewall import fqdn_resolve +from vyos.ifconfig import WireGuardIf +from vyos.remote import download +from vyos.utils.commit import commit_in_progress +from vyos.utils.dict import dict_search_args +from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME +from vyos.utils.file import makedir, chmod_775, write_file, read_file +from vyos.utils.network import is_valid_ipv4_address_or_range, is_valid_ipv6_address_or_range +from vyos.utils.process import cmd +from vyos.utils.process import run +from vyos.xml_ref import get_defaults + +base = ['firewall'] +timeout = 300 +cache = False +base_firewall = ['firewall'] +base_nat = ['nat'] +base_interfaces = ['interfaces'] + +firewall_config_dir = "/config/firewall" + +domain_state = {} + +ipv4_tables = { + 'ip vyos_mangle', + 'ip vyos_filter', + 'ip vyos_nat', + 'ip raw' +} + +ipv6_tables = { + 'ip6 vyos_mangle', + 'ip6 vyos_filter', + 'ip6 raw' +} + +logger = logging.getLogger(__name__) +logs_handler = logging.StreamHandler() +logger.addHandler(logs_handler) +logger.setLevel(logging.INFO) + +def get_config(conf, node): + node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True, + no_tag_node_value_mangle=True) + + default_values = get_defaults(node, get_first_key=True) + + node_config = dict_merge(default_values, node_config) + + if node == base_firewall and 'global_options' in node_config: + global_config = node_config['global_options'] + global timeout, cache + + if 'resolver_interval' in global_config: + timeout = int(global_config['resolver_interval']) + + if 'resolver_cache' in global_config: + cache = True + + fqdn_config_parse(node_config, node[0]) + + return node_config + +def resolve(domains, ipv6=False): + global domain_state + + ip_list = set() + + for domain in domains: + resolved = fqdn_resolve(domain, ipv6=ipv6) + + cache_key = f'{domain}_ipv6' if ipv6 else domain + + if resolved and cache: + domain_state[cache_key] = resolved + elif not resolved: + if cache_key not in domain_state: + continue + resolved = domain_state[cache_key] + + ip_list = ip_list | resolved + return ip_list + +def nft_output(table, set_name, ip_list): + output = [f'flush set {table} {set_name}'] + if ip_list: + ip_str = ','.join(ip_list) + output.append(f'add element {table} {set_name} {{ {ip_str} }}') + return output + +def nft_valid_sets(): + try: + valid_sets = [] + sets_json = cmd('nft --json list sets') + sets_obj = json.loads(sets_json) + + for obj in sets_obj['nftables']: + if 'set' in obj: + family = obj['set']['family'] + table = obj['set']['table'] + name = obj['set']['name'] + valid_sets.append((f'{family} {table}', name)) + + return valid_sets + except: + return [] + +def update_remote_group(config): + conf_lines = [] + count = 0 + valid_sets = nft_valid_sets() + + remote_groups = dict_search_args(config, 'group', 'remote_group') + if remote_groups: + # Create directory for list files if necessary + if not os.path.isdir(firewall_config_dir): + makedir(firewall_config_dir, group='vyattacfg') + chmod_775(firewall_config_dir) + + for set_name, remote_config in remote_groups.items(): + if 'url' not in remote_config: + continue + nft_ip_set_name = f'R_{set_name}' + nft_ip6_set_name = f'R6_{set_name}' + + # Create list file if necessary + list_file = os.path.join(firewall_config_dir, f"{nft_ip_set_name}.txt") + if not os.path.exists(list_file): + write_file(list_file, '', user="root", group="vyattacfg", mode=0o644) + + # Attempt to download file, use cached version if download fails + try: + download(list_file, remote_config['url'], raise_error=True) + except: + logger.error(f'Failed to download list-file for {set_name} remote group') + logger.info(f'Using cached list-file for {set_name} remote group') + + # Read list file + ip_list = [] + ip6_list = [] + invalid_list = [] + for line in read_file(list_file).splitlines(): + line_first_word = line.strip().partition(' ')[0] + + if is_valid_ipv4_address_or_range(line_first_word): + ip_list.append(line_first_word) + elif is_valid_ipv6_address_or_range(line_first_word): + ip6_list.append(line_first_word) + else: + if line_first_word[0].isalnum(): + invalid_list.append(line_first_word) + + # Load ip tables + for table in ipv4_tables: + if (table, nft_ip_set_name) in valid_sets: + conf_lines += nft_output(table, nft_ip_set_name, ip_list) + + # Load ip6 tables + for table in ipv6_tables: + if (table, nft_ip6_set_name) in valid_sets: + conf_lines += nft_output(table, nft_ip6_set_name, ip6_list) + + invalid_str = ", ".join(invalid_list) + if invalid_str: + logger.info(f'Invalid address for set {set_name}: {invalid_str}') + + count += 1 + + nft_conf_str = "\n".join(conf_lines) + "\n" + code = run(f'nft --file -', input=nft_conf_str) + + logger.info(f'Updated {count} remote-groups in firewall - result: {code}') + + +def update_fqdn(config, node): + conf_lines = [] + count = 0 + valid_sets = nft_valid_sets() + + if node == 'firewall': + domain_groups = dict_search_args(config, 'group', 'domain_group') + if domain_groups: + for set_name, domain_config in domain_groups.items(): + if 'address' not in domain_config: + continue + nft_set_name = f'D_{set_name}' + domains = domain_config['address'] + + ip_list = resolve(domains, ipv6=False) + for table in ipv4_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + ip6_list = resolve(domains, ipv6=True) + for table in ipv6_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip6_list) + count += 1 + + for set_name, domain in config['ip_fqdn'].items(): + table = 'ip vyos_filter' + nft_set_name = f'FQDN_{set_name}' + ip_list = resolve([domain], ipv6=False) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + for set_name, domain in config['ip6_fqdn'].items(): + table = 'ip6 vyos_filter' + nft_set_name = f'FQDN_{set_name}' + ip_list = resolve([domain], ipv6=True) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + else: + # It's NAT + for set_name, domain in config['ip_fqdn'].items(): + table = 'ip vyos_nat' + nft_set_name = f'FQDN_nat_{set_name}' + ip_list = resolve([domain], ipv6=False) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + nft_conf_str = "\n".join(conf_lines) + "\n" + code = run(f'nft --file -', input=nft_conf_str) + + logger.info(f'Updated {count} sets in {node} - result: {code}') + +def update_interfaces(config, node): + if node == 'interfaces': + wg_interfaces = dict_search_args(config, 'wireguard') + if wg_interfaces: + + peer_public_keys = {} + # for each wireguard interfaces + for interface, wireguard in wg_interfaces.items(): + peer_public_keys[interface] = [] + for peer, peer_config in wireguard['peer'].items(): + # check peer if peer host-name or address is set + if 'host_name' in peer_config or 'address' in peer_config: + # check latest handshake + peer_public_keys[interface].append( + peer_config['public_key'] + ) + + now_time = time.time() + for (interface, check_peer_public_keys) in peer_public_keys.items(): + if len(check_peer_public_keys) == 0: + continue + + intf = WireGuardIf(interface, create=False, debug=False) + handshakes = intf.operational.get_latest_handshakes() + + # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME + # if data is being transmitted between the peers. If no data is + # transmitted, the handshake will not be initiated unless new + # data begins to flow. Each handshake generates a new session + # key, and the key is rotated at least every 120 seconds or + # upon data transmission after a prolonged silence. + for public_key, handshake_time in handshakes.items(): + if public_key in check_peer_public_keys and ( + handshake_time == 0 + or (now_time - handshake_time > 3*WIREGUARD_REKEY_AFTER_TIME) + ): + intf.operational.reset_peer(public_key=public_key) + +if __name__ == '__main__': + logger.info('VyOS domain resolver') + + count = 1 + while commit_in_progress(): + if ( count % 60 == 0 ): + logger.info(f'Commit still in progress after {count}s - waiting') + count += 1 + time.sleep(1) + + conf = ConfigTreeQuery() + firewall = get_config(conf, base_firewall) + nat = get_config(conf, base_nat) + interfaces = get_config(conf, base_interfaces) + + logger.info(f'interval: {timeout}s - cache: {cache}') + + while True: + update_fqdn(firewall, 'firewall') + update_fqdn(nat, 'nat') + update_remote_group(firewall) + update_interfaces(interfaces, 'interfaces') + time.sleep(timeout) |