#!/usr/bin/env python3
#
# Copyright (C) 2022 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Periodically resolve firewall domain groups and FQDNs into nftables sets."""

import json
import os
import time

from vyos.configdict import dict_merge
from vyos.configquery import ConfigTreeQuery
from vyos.firewall import fqdn_config_parse
from vyos.firewall import fqdn_resolve
from vyos.util import cmd
from vyos.util import commit_in_progress
from vyos.util import dict_search_args
from vyos.util import run
from vyos.xml import defaults

base = ['firewall']

# Seconds to sleep between update passes; replaced by the firewall
# 'resolver_interval' value in get_config() when it is present.
timeout = 300

# When True, successful lookups are remembered so that a later failed lookup
# can fall back to the last known addresses; set by get_config() when
# 'resolver_cache' is present in the firewall config.
cache = False

# (domain, ipv6) -> last successfully resolved address set. The address
# family is part of the key so IPv4 results are never reused for IPv6
# lookups (or vice versa).
domain_state = {}

# "<family> <table>" names that may contain IPv4 / IPv6 domain-group sets.
ipv4_tables = {
    'ip vyos_mangle',
    'ip vyos_filter',
    'ip vyos_nat'
}

ipv6_tables = {
    'ip6 vyos_mangle',
    'ip6 vyos_filter'
}


def get_config(conf):
    """Read the firewall config and apply the resolver settings.

    Defaults from the XML definitions are merged into the config. The
    'name', 'ipv6_name' and 'zone' subtrees are deliberately excluded from
    that merge. As a side effect, the module-level ``timeout`` and ``cache``
    are updated from 'resolver_interval' / 'resolver_cache'.

    Args:
        conf: a config query object providing ``get_config_dict``.

    Returns:
        The merged firewall config dict, after ``fqdn_config_parse`` has
        processed it. update() later reads its 'ip_fqdn' / 'ip6_fqdn' keys,
        presumably filled in by that call.
    """
    firewall = conf.get_config_dict(base, key_mangling=('-', '_'),
                                    get_first_key=True,
                                    no_tag_node_value_mangle=True)

    default_values = defaults(base)
    # Do not merge defaults for these subtrees.
    for tmp in ('name', 'ipv6_name', 'zone'):
        default_values.pop(tmp, None)

    firewall = dict_merge(default_values, firewall)

    global timeout, cache

    if 'resolver_interval' in firewall:
        timeout = int(firewall['resolver_interval'])

    if 'resolver_cache' in firewall:
        cache = True

    fqdn_config_parse(firewall)

    return firewall


def resolve(domains, ipv6=False):
    """Resolve every domain in *domains* and return the union of addresses.

    Caching is enabled when the module-level ``cache`` is True. In that case
    each successful lookup is stored in ``domain_state`` under
    ``(domain, ipv6)``. A failed lookup falls back to the cached result for
    the same address family. If there is no cached result, the domain
    contributes nothing.

    Args:
        domains: iterable of domain names.
        ipv6: resolve IPv6 addresses instead of IPv4.

    Returns:
        A set of address strings (possibly empty).
    """
    ip_list = set()

    for domain in domains:
        resolved = fqdn_resolve(domain, ipv6=ipv6)
        # Key by family: an A-only domain fails its IPv6 lookup, and must not
        # fall back to the cached IPv4 addresses (which would end up in an
        # ip6 set and break the whole nft batch).
        cache_key = (domain, ipv6)

        if resolved and cache:
            domain_state[cache_key] = resolved
        elif not resolved:
            if cache_key not in domain_state:
                continue
            resolved = domain_state[cache_key]

        ip_list = ip_list | resolved

    return ip_list


def nft_output(table, set_name, ip_list):
    """Return nft commands that replace the contents of one set.

    The set is always flushed. Elements are added only when *ip_list* is
    non-empty, so an empty resolution leaves the set empty. Elements are
    sorted so the generated commands are deterministic.

    Args:
        table: "<family> <table>" string, e.g. 'ip vyos_filter'.
        set_name: nftables set name.
        ip_list: iterable of address strings.

    Returns:
        A list of nft command strings.
    """
    output = [f'flush set {table} {set_name}']
    if ip_list:
        ip_str = ','.join(sorted(ip_list))
        output.append(f'add element {table} {set_name} {{ {ip_str} }}')
    return output


def nft_valid_sets():
    """Return ("<family> <table>", set_name) pairs for all existing nft sets.

    Parses the JSON output of ``nft -j list sets``. On any failure it prints
    a diagnostic and returns an empty list. update() then writes nothing, so
    a failed listing only skips that pass.
    """
    try:
        sets_obj = json.loads(cmd('nft -j list sets'))
        valid_sets = []
        for obj in sets_obj['nftables']:
            # The listing also contains non-set entries (e.g. metainfo).
            if 'set' in obj:
                nft_set = obj['set']
                valid_sets.append(
                    (f"{nft_set['family']} {nft_set['table']}", nft_set['name']))
        return valid_sets
    except Exception as err:
        # Best effort: never crash the daemon here, but do not hide the
        # cause and do not swallow KeyboardInterrupt/SystemExit either.
        print(f'Failed to list nftables sets: {err}')
        return []


def update(firewall):
    """Re-resolve all configured domains and load the results into nftables.

    Two kinds of set are handled:

    - Domain groups (firewall group domain_group): set ``D_<group>`` in each
      IPv4 and IPv6 table from ``ipv4_tables`` / ``ipv6_tables``.
    - Per-rule FQDNs ('ip_fqdn' / 'ip6_fqdn'): set ``FQDN_<name>`` in
      'ip vyos_filter' / 'ip6 vyos_filter'.

    Only sets that currently exist in nftables are written. All commands go
    to a single ``nft -f -`` invocation, and nothing is run when there is
    nothing to write.

    Args:
        firewall: config dict as returned by get_config().
    """
    conf_lines = []
    count = 0  # number of nft sets written in this pass

    # A set gives O(1) membership for the repeated checks below.
    valid_sets = set(nft_valid_sets())

    domain_groups = dict_search_args(firewall, 'group', 'domain_group')
    if domain_groups:
        for set_name, domain_config in domain_groups.items():
            if 'address' not in domain_config:
                continue
            nft_set_name = f'D_{set_name}'
            domains = domain_config['address']

            ip_list = resolve(domains, ipv6=False)
            for table in ipv4_tables:
                if (table, nft_set_name) in valid_sets:
                    conf_lines += nft_output(table, nft_set_name, ip_list)
                    count += 1

            ip6_list = resolve(domains, ipv6=True)
            for table in ipv6_tables:
                if (table, nft_set_name) in valid_sets:
                    conf_lines += nft_output(table, nft_set_name, ip6_list)
                    count += 1

    # (config key, nft table, ipv6) for per-rule FQDN sets, in the original
    # processing order: IPv4 first, then IPv6.
    fqdn_targets = (
        ('ip_fqdn', 'ip vyos_filter', False),
        ('ip6_fqdn', 'ip6 vyos_filter', True),
    )
    for key, table, ipv6 in fqdn_targets:
        for set_name, domain in firewall.get(key, {}).items():
            nft_set_name = f'FQDN_{set_name}'
            ip_list = resolve([domain], ipv6=ipv6)
            if (table, nft_set_name) in valid_sets:
                conf_lines += nft_output(table, nft_set_name, ip_list)
                count += 1

    if not conf_lines:
        print('No sets to update')
        return

    nft_conf_str = '\n'.join(conf_lines) + '\n'
    code = run('nft -f -', input=nft_conf_str)

    print(f'Updated {count} sets - result: {code}')


if __name__ == '__main__':
    print('VyOS domain resolver')

    # Poll once per second until commit_in_progress() returns False,
    # logging a notice every 60 seconds.
    count = 1
    while commit_in_progress():
        if count % 60 == 0:
            print(f'Commit still in progress after {count}s - waiting')
        count += 1
        time.sleep(1)

    conf = ConfigTreeQuery()
    firewall = get_config(conf)

    print(f'interval: {timeout}s - cache: {cache}')

    while True:
        update(firewall)
        time.sleep(timeout)