#!/usr/bin/env python3
#
# Copyright (C) 2022-2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import json
import time
import logging

from vyos.configdict import dict_merge
from vyos.configquery import ConfigTreeQuery
from vyos.firewall import fqdn_config_parse
from vyos.firewall import fqdn_resolve
from vyos.ifconfig import WireGuardIf
from vyos.utils.commit import commit_in_progress
from vyos.utils.dict import dict_search_args
from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME
from vyos.utils.process import cmd
from vyos.utils.process import run
from vyos.xml_ref import get_defaults

base = ['firewall']

# Seconds to sleep between refresh rounds; overridden by the firewall
# 'global_options resolver_interval' setting in get_config().
timeout = 300
# When True, the last successful lookup of a domain is reused whenever a
# later lookup returns nothing (set from 'global_options resolver_cache').
cache = False

base_firewall = ['firewall']
base_nat = ['nat']
base_interfaces = ['interfaces']

# Last successful resolution per (domain, is_ipv6) pair, see resolve().
domain_state = {}

ipv4_tables = {
    'ip vyos_mangle',
    'ip vyos_filter',
    'ip vyos_nat',
    'ip raw'
}

ipv6_tables = {
    'ip6 vyos_mangle',
    'ip6 vyos_filter',
    'ip6 raw'
}

logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.INFO)


def get_config(conf, node):
    """Load the configuration below ``node`` merged with its defaults.

    Args:
        conf: Config query object providing ``get_config_dict()``.
        node: Config path as a list, e.g. ``['firewall']``.

    Returns:
        The merged configuration dict after ``fqdn_config_parse()`` has
        processed it.

    Side effects:
        For the firewall node, the module globals ``timeout`` and ``cache``
        are updated from ``global_options`` when the matching keys exist.
    """
    global timeout, cache

    node_config = conf.get_config_dict(node, key_mangling=('-', '_'),
                                       get_first_key=True,
                                       no_tag_node_value_mangle=True)

    default_values = get_defaults(node, get_first_key=True)
    node_config = dict_merge(default_values, node_config)

    if node == base_firewall and 'global_options' in node_config:
        global_config = node_config['global_options']

        if 'resolver_interval' in global_config:
            timeout = int(global_config['resolver_interval'])

        if 'resolver_cache' in global_config:
            cache = True

    # NOTE(review): update_fqdn() reads config['ip_fqdn'] / ['ip6_fqdn'];
    # presumably this call is what populates them - confirm in vyos.firewall.
    fqdn_config_parse(node_config, node[0])

    return node_config


def resolve(domains, ipv6=False):
    """Resolve ``domains`` and return the union of all addresses found.

    Args:
        domains: Iterable of domain names.
        ipv6: Resolve IPv6 addresses instead of IPv4 ones.

    Returns:
        A set with the addresses of every domain that resolved (or that has
        a cached result for the same address family).

    When caching is enabled, the last successful answer is remembered and
    reused if a later lookup returns nothing.  The cache is keyed by domain
    *and* address family so that a failed IPv6 lookup can never fall back to
    cached IPv4 addresses (or vice versa) and poison a set of the other
    family.
    """
    ip_list = set()

    for domain in domains:
        resolved = fqdn_resolve(domain, ipv6=ipv6)
        cache_key = (domain, bool(ipv6))

        if resolved and cache:
            domain_state[cache_key] = resolved
        elif not resolved:
            if cache_key not in domain_state:
                continue
            resolved = domain_state[cache_key]

        ip_list = ip_list | resolved

    return ip_list


def nft_output(table, set_name, ip_list):
    """Build the nft commands that replace the content of one set.

    Args:
        table: Family and table name, e.g. ``'ip vyos_filter'``.
        set_name: Name of the nftables set.
        ip_list: Iterable of address strings; may be empty.

    Returns:
        A list of nft command lines: always a ``flush set``, followed by an
        ``add element`` only when ``ip_list`` is non-empty.
    """
    output = [f'flush set {table} {set_name}']
    if ip_list:
        ip_str = ','.join(ip_list)
        output.append(f'add element {table} {set_name} {{ {ip_str} }}')
    return output


def nft_valid_sets():
    """Return the sets currently known to nftables.

    Returns:
        A list of ``('<family> <table>', '<set name>')`` tuples, or an empty
        list if the sets could not be listed or parsed.
    """
    try:
        valid_sets = []
        sets_json = cmd('nft --json list sets')
        sets_obj = json.loads(sets_json)

        for obj in sets_obj['nftables']:
            if 'set' in obj:
                family = obj['set']['family']
                table = obj['set']['table']
                name = obj['set']['name']
                valid_sets.append((f'{family} {table}', name))

        return valid_sets
    except Exception as err:
        # Best effort: without a set listing nothing is updated this round,
        # but the daemon keeps running.  Do not swallow KeyboardInterrupt or
        # SystemExit, and leave a trace of what went wrong.
        logger.warning('Unable to list nftables sets: %s', err)
        return []


def _queue_set_update(conf_lines, valid_sets, table, set_name, ip_list):
    """Append the nft commands for ``set_name`` if it exists in ``table``."""
    if (table, set_name) in valid_sets:
        conf_lines.extend(nft_output(table, set_name, ip_list))


def update_fqdn(config, node):
    """Re-resolve all domains of ``config`` and load the result into nft.

    Args:
        config: Configuration dict as returned by ``get_config()``.
        node: ``'firewall'`` for the firewall configuration; any other value
            is treated as NAT.

    The logged count is the number of configured sets processed, whether or
    not the corresponding nftables set currently exists.
    """
    conf_lines = []
    count = 0
    # Membership is tested once per table and set, so use a set for lookups.
    valid_sets = set(nft_valid_sets())

    if node == 'firewall':
        domain_groups = dict_search_args(config, 'group', 'domain_group')
        if domain_groups:
            for set_name, domain_config in domain_groups.items():
                if 'address' not in domain_config:
                    continue

                nft_set_name = f'D_{set_name}'
                domains = domain_config['address']

                ip_list = resolve(domains, ipv6=False)
                for table in ipv4_tables:
                    _queue_set_update(conf_lines, valid_sets, table,
                                      nft_set_name, ip_list)

                ip6_list = resolve(domains, ipv6=True)
                for table in ipv6_tables:
                    _queue_set_update(conf_lines, valid_sets, table,
                                      nft_set_name, ip6_list)
                count += 1

        for set_name, domain in config['ip_fqdn'].items():
            ip_list = resolve([domain], ipv6=False)
            _queue_set_update(conf_lines, valid_sets, 'ip vyos_filter',
                              f'FQDN_{set_name}', ip_list)
            count += 1

        for set_name, domain in config['ip6_fqdn'].items():
            ip_list = resolve([domain], ipv6=True)
            _queue_set_update(conf_lines, valid_sets, 'ip6 vyos_filter',
                              f'FQDN_{set_name}', ip_list)
            count += 1
    else:
        # It's NAT
        for set_name, domain in config['ip_fqdn'].items():
            ip_list = resolve([domain], ipv6=False)
            _queue_set_update(conf_lines, valid_sets, 'ip vyos_nat',
                              f'FQDN_nat_{set_name}', ip_list)
            count += 1

    nft_conf_str = "\n".join(conf_lines) + "\n"
    code = run('nft --file -', input=nft_conf_str)

    logger.info(f'Updated {count} sets in {node} - result: {code}')


def update_interfaces(config, node):
    """Reset WireGuard peers whose latest handshake is missing or stale.

    Args:
        config: Interfaces configuration dict as returned by ``get_config()``.
        node: Only ``'interfaces'`` is acted upon; other values are ignored.

    Only peers that have ``host_name`` or ``address`` configured are
    considered.
    """
    if node != 'interfaces':
        return

    wg_interfaces = dict_search_args(config, 'wireguard')
    if not wg_interfaces:
        return

    # Per interface: public keys of the peers that have an endpoint
    # (host-name or address) configured.  An interface without any peer
    # simply yields an empty list instead of raising KeyError.
    peer_public_keys = {}
    for interface, wireguard in wg_interfaces.items():
        peer_public_keys[interface] = [
            peer_config['public_key']
            for peer_config in wireguard.get('peer', {}).values()
            if 'host_name' in peer_config or 'address' in peer_config
        ]

    now_time = time.time()
    for interface, check_peer_public_keys in peer_public_keys.items():
        if not check_peer_public_keys:
            continue

        intf = WireGuardIf(interface, create=False, debug=False)
        handshakes = intf.operational.get_latest_handshakes()

        # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME
        # if data is being transmitted between the peers. If no data is
        # transmitted, the handshake will not be initiated unless new
        # data begins to flow. Each handshake generates a new session
        # key, and the key is rotated at least every 120 seconds or
        # upon data transmission after a prolonged silence.
        for public_key, handshake_time in handshakes.items():
            if public_key not in check_peer_public_keys:
                continue

            # Reset when there is no handshake at all (0) or when it is
            # older than three rekey intervals.
            stale = now_time - handshake_time > 3 * WIREGUARD_REKEY_AFTER_TIME
            if handshake_time == 0 or stale:
                intf.operational.reset_peer(public_key=public_key)


if __name__ == '__main__':
    logger.info('VyOS domain resolver')

    # Wait for a running commit to finish before reading the configuration;
    # report once a minute while waiting.
    count = 1
    while commit_in_progress():
        if count % 60 == 0:
            logger.info(f'Commit still in progress after {count}s - waiting')
        count += 1
        time.sleep(1)

    conf = ConfigTreeQuery()
    firewall = get_config(conf, base_firewall)
    nat = get_config(conf, base_nat)
    interfaces = get_config(conf, base_interfaces)

    logger.info(f'interval: {timeout}s - cache: {cache}')

    while True:
        update_fqdn(firewall, 'firewall')
        update_fqdn(nat, 'nat')
        update_interfaces(interfaces, 'interfaces')
        time.sleep(timeout)