diff options
Diffstat (limited to 'src/helpers/vyos-domain-resolver.py')
-rwxr-xr-x | src/helpers/vyos-domain-resolver.py | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/src/helpers/vyos-domain-resolver.py b/src/helpers/vyos-domain-resolver.py new file mode 100755 index 000000000..035c208b2 --- /dev/null +++ b/src/helpers/vyos-domain-resolver.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import json +import os +import time + +from vyos.configdict import dict_merge +from vyos.configquery import ConfigTreeQuery +from vyos.firewall import fqdn_config_parse +from vyos.firewall import fqdn_resolve +from vyos.util import cmd +from vyos.util import commit_in_progress +from vyos.util import dict_search_args +from vyos.util import run +from vyos.xml import defaults + +base = ['firewall'] +timeout = 300 +cache = False + +domain_state = {} + +ipv4_tables = { + 'ip mangle', + 'ip vyos_filter', + 'ip vyos_nat' +} + +ipv6_tables = { + 'ip6 mangle', + 'ip6 vyos_filter' +} + +def get_config(conf): + firewall = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, + no_tag_node_value_mangle=True) + + default_values = defaults(base) + for tmp in ['name', 'ipv6_name']: + if tmp in default_values: + del default_values[tmp] + + if 'zone' in default_values: + del default_values['zone'] + + firewall = dict_merge(default_values, firewall) + + global timeout, cache + + if 'resolver_interval' in firewall: + timeout = int(firewall['resolver_interval']) + + if 'resolver_cache' in firewall: + cache = True + + fqdn_config_parse(firewall) + + return firewall + +def resolve(domains, ipv6=False): + global domain_state + + ip_list = set() + + for domain in domains: + resolved = fqdn_resolve(domain, ipv6=ipv6) + + if resolved and cache: + domain_state[domain] = resolved + elif not resolved: + if domain not in domain_state: + continue + resolved = domain_state[domain] + + ip_list = ip_list | resolved + return ip_list + +def nft_output(table, set_name, ip_list): + output = [f'flush set {table} {set_name}'] + if ip_list: + ip_str = ','.join(ip_list) + output.append(f'add element {table} {set_name} {{ {ip_str} }}') + return output + +def nft_valid_sets(): + try: + valid_sets = [] + sets_json = cmd('nft -j list sets') + sets_obj = json.loads(sets_json) + + for obj in sets_obj['nftables']: + if 'set' in obj: + family = obj['set']['family'] + table = obj['set']['table'] + name = obj['set']['name'] + valid_sets.append((f'{family} {table}', name)) + + return valid_sets + except: + return [] + +def update(firewall): + conf_lines = [] + count = 0 + + valid_sets = nft_valid_sets() + + domain_groups = dict_search_args(firewall, 'group', 'domain_group') + if domain_groups: + for set_name, domain_config in domain_groups.items(): + if 'address' not in domain_config: + continue + + nft_set_name = f'D_{set_name}' + domains = domain_config['address'] + + ip_list = resolve(domains, ipv6=False) + for table in ipv4_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + + ip6_list = resolve(domains, ipv6=True) + for table in ipv6_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip6_list) + count += 1 + + for set_name, domain in firewall['ip_fqdn'].items(): + table = 'ip vyos_filter' + nft_set_name = f'FQDN_{set_name}' + + ip_list = resolve([domain], ipv6=False) + + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + for set_name, domain in firewall['ip6_fqdn'].items(): + table = 'ip6 vyos_filter' + nft_set_name = f'FQDN_{set_name}' + + ip_list = resolve([domain], ipv6=True) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + nft_conf_str = "\n".join(conf_lines) + "\n" + code = run(f'nft -f -', input=nft_conf_str) + + print(f'Updated {count} sets - result: {code}') + +if __name__ == '__main__': + print(f'VyOS domain resolver') + + count = 1 + while commit_in_progress(): + if ( count % 60 == 0 ): + print(f'Commit still in progress after {count}s - waiting') + count += 1 + time.sleep(1) + + conf = ConfigTreeQuery() + firewall = get_config(conf) + + print(f'interval: {timeout}s - cache: {cache}') + + while True: + update(firewall) + time.sleep(timeout) |