summaryrefslogtreecommitdiff
path: root/src/services/vyos-domain-resolver
diff options
context:
space:
mode:
Diffstat (limited to 'src/services/vyos-domain-resolver')
-rwxr-xr-xsrc/services/vyos-domain-resolver160
1 files changed, 119 insertions, 41 deletions
diff --git a/src/services/vyos-domain-resolver b/src/services/vyos-domain-resolver
index fe0f40a07..fb18724af 100755
--- a/src/services/vyos-domain-resolver
+++ b/src/services/vyos-domain-resolver
@@ -13,19 +13,22 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
import json
import time
import logging
+import os
from vyos.configdict import dict_merge
from vyos.configquery import ConfigTreeQuery
from vyos.firewall import fqdn_config_parse
from vyos.firewall import fqdn_resolve
from vyos.ifconfig import WireGuardIf
+from vyos.remote import download
from vyos.utils.commit import commit_in_progress
from vyos.utils.dict import dict_search_args
from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME
+from vyos.utils.file import makedir, chmod_775, write_file, read_file
+from vyos.utils.network import is_valid_ipv4_address_or_range, is_valid_ipv6_address_or_range
from vyos.utils.process import cmd
from vyos.utils.process import run
from vyos.xml_ref import get_defaults
@@ -37,6 +40,8 @@ base_firewall = ['firewall']
base_nat = ['nat']
base_interfaces = ['interfaces']
+firewall_config_dir = "/config/firewall"
+
domain_state = {}
ipv4_tables = {
@@ -65,13 +70,15 @@ def get_config(conf, node):
node_config = dict_merge(default_values, node_config)
- global timeout, cache
+ if node == base_firewall and 'global_options' in node_config:
+ global_config = node_config['global_options']
+ global timeout, cache
- if 'resolver_interval' in node_config:
- timeout = int(node_config['resolver_interval'])
+ if 'resolver_interval' in global_config:
+ timeout = int(global_config['resolver_interval'])
- if 'resolver_cache' in node_config:
- cache = True
+ if 'resolver_cache' in global_config:
+ cache = True
fqdn_config_parse(node_config, node[0])
@@ -85,12 +92,14 @@ def resolve(domains, ipv6=False):
for domain in domains:
resolved = fqdn_resolve(domain, ipv6=ipv6)
+ cache_key = f'{domain}_ipv6' if ipv6 else domain
+
if resolved and cache:
- domain_state[domain] = resolved
+ domain_state[cache_key] = resolved
elif not resolved:
- if domain not in domain_state:
+ if cache_key not in domain_state:
continue
- resolved = domain_state[domain]
+ resolved = domain_state[cache_key]
ip_list = ip_list | resolved
return ip_list
@@ -119,6 +128,73 @@ def nft_valid_sets():
except:
return []
+def update_remote_group(config):
+ conf_lines = []
+ count = 0
+ valid_sets = nft_valid_sets()
+
+ remote_groups = dict_search_args(config, 'group', 'remote_group')
+ if remote_groups:
+ # Create directory for list files if necessary
+ if not os.path.isdir(firewall_config_dir):
+ makedir(firewall_config_dir, group='vyattacfg')
+ chmod_775(firewall_config_dir)
+
+ for set_name, remote_config in remote_groups.items():
+ if 'url' not in remote_config:
+ continue
+ nft_ip_set_name = f'R_{set_name}'
+ nft_ip6_set_name = f'R6_{set_name}'
+
+ # Create list file if necessary
+ list_file = os.path.join(firewall_config_dir, f"{nft_ip_set_name}.txt")
+ if not os.path.exists(list_file):
+ write_file(list_file, '', user="root", group="vyattacfg", mode=0o644)
+
+ # Attempt to download file, use cached version if download fails
+ try:
+ download(list_file, remote_config['url'], raise_error=True)
+ except:
+ logger.error(f'Failed to download list-file for {set_name} remote group')
+ logger.info(f'Using cached list-file for {set_name} remote group')
+
+ # Read list file
+ ip_list = []
+ ip6_list = []
+ invalid_list = []
+ for line in read_file(list_file).splitlines():
+ line_first_word = line.strip().partition(' ')[0]
+
+ if is_valid_ipv4_address_or_range(line_first_word):
+ ip_list.append(line_first_word)
+ elif is_valid_ipv6_address_or_range(line_first_word):
+ ip6_list.append(line_first_word)
+ else:
+ if line_first_word[0].isalnum():
+ invalid_list.append(line_first_word)
+
+ # Load ip tables
+ for table in ipv4_tables:
+ if (table, nft_ip_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_ip_set_name, ip_list)
+
+ # Load ip6 tables
+ for table in ipv6_tables:
+ if (table, nft_ip6_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_ip6_set_name, ip6_list)
+
+ invalid_str = ", ".join(invalid_list)
+ if invalid_str:
+ logger.info(f'Invalid address for set {set_name}: {invalid_str}')
+
+ count += 1
+
+ nft_conf_str = "\n".join(conf_lines) + "\n"
+ code = run(f'nft --file -', input=nft_conf_str)
+
+ logger.info(f'Updated {count} remote-groups in firewall - result: {code}')
+
+
def update_fqdn(config, node):
conf_lines = []
count = 0
@@ -177,39 +253,40 @@ def update_fqdn(config, node):
def update_interfaces(config, node):
if node == 'interfaces':
wg_interfaces = dict_search_args(config, 'wireguard')
+ if wg_interfaces:
+
+ peer_public_keys = {}
+ # for each wireguard interfaces
+ for interface, wireguard in wg_interfaces.items():
+ peer_public_keys[interface] = []
+ for peer, peer_config in wireguard['peer'].items():
+ # check peer if peer host-name or address is set
+ if 'host_name' in peer_config or 'address' in peer_config:
+ # check latest handshake
+ peer_public_keys[interface].append(
+ peer_config['public_key']
+ )
+
+ now_time = time.time()
+ for (interface, check_peer_public_keys) in peer_public_keys.items():
+ if len(check_peer_public_keys) == 0:
+ continue
- peer_public_keys = {}
- # for each wireguard interfaces
- for interface, wireguard in wg_interfaces.items():
- peer_public_keys[interface] = []
- for peer, peer_config in wireguard['peer'].items():
- # check peer if peer host-name or address is set
- if 'host_name' in peer_config or 'address' in peer_config:
- # check latest handshake
- peer_public_keys[interface].append(
- peer_config['public_key']
- )
-
- now_time = time.time()
- for (interface, check_peer_public_keys) in peer_public_keys.items():
- if len(check_peer_public_keys) == 0:
- continue
-
- intf = WireGuardIf(interface, create=False, debug=False)
- handshakes = intf.operational.get_latest_handshakes()
-
- # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME
- # if data is being transmitted between the peers. If no data is
- # transmitted, the handshake will not be initiated unless new
- # data begins to flow. Each handshake generates a new session
- # key, and the key is rotated at least every 120 seconds or
- # upon data transmission after a prolonged silence.
- for public_key, handshake_time in handshakes.items():
- if public_key in check_peer_public_keys and (
- handshake_time == 0
- or (now_time - handshake_time > 3*WIREGUARD_REKEY_AFTER_TIME)
- ):
- intf.operational.reset_peer(public_key=public_key)
+ intf = WireGuardIf(interface, create=False, debug=False)
+ handshakes = intf.operational.get_latest_handshakes()
+
+ # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME
+ # if data is being transmitted between the peers. If no data is
+ # transmitted, the handshake will not be initiated unless new
+ # data begins to flow. Each handshake generates a new session
+ # key, and the key is rotated at least every 120 seconds or
+ # upon data transmission after a prolonged silence.
+ for public_key, handshake_time in handshakes.items():
+ if public_key in check_peer_public_keys and (
+ handshake_time == 0
+ or (now_time - handshake_time > 3*WIREGUARD_REKEY_AFTER_TIME)
+ ):
+ intf.operational.reset_peer(public_key=public_key)
if __name__ == '__main__':
logger.info('VyOS domain resolver')
@@ -231,5 +308,6 @@ if __name__ == '__main__':
while True:
update_fqdn(firewall, 'firewall')
update_fqdn(nat, 'nat')
+ update_remote_group(firewall)
update_interfaces(interfaces, 'interfaces')
time.sleep(timeout)