diff options
Diffstat (limited to 'python/vyos/firewall.py')
-rw-r--r-- | python/vyos/firewall.py | 252 |
1 files changed, 243 insertions, 9 deletions
diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 04fd44173..48263eef5 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (C) 2021 VyOS maintainers and contributors +# Copyright (C) 2021-2022 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as @@ -14,10 +14,51 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +import csv +import gzip +import os import re +from pathlib import Path +from socket import AF_INET +from socket import AF_INET6 +from socket import getaddrinfo +from time import strftime + +from vyos.remote import download +from vyos.template import is_ipv4 +from vyos.template import render +from vyos.util import call from vyos.util import cmd from vyos.util import dict_search_args +from vyos.util import dict_search_recursive +from vyos.util import run + +# Domain Resolver + +def fqdn_config_parse(firewall): + firewall['ip_fqdn'] = {} + firewall['ip6_fqdn'] = {} + + for domain, path in dict_search_recursive(firewall, 'fqdn'): + fw_name = path[1] # name/ipv6-name + rule = path[3] # rule id + suffix = path[4][0] # source/destination (1 char) + set_name = f'{fw_name}_{rule}_{suffix}' + + if path[0] == 'name': + firewall['ip_fqdn'][set_name] = domain + elif path[0] == 'ipv6_name': + firewall['ip6_fqdn'][set_name] = domain + +def fqdn_resolve(fqdn, ipv6=False): + try: + res = getaddrinfo(fqdn, None, AF_INET6 if ipv6 else AF_INET) + return set(item[4][0] for item in res) + except: + return None + +# End Domain Resolver def find_nftables_rule(table, chain, rule_matches=[]): # Find rule in table/chain that matches all criteria and return the handle @@ -72,12 +113,32 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if side in rule_conf: prefix = side[0] side_conf = rule_conf[side] + address_mask = side_conf.get('address_mask', None) if 'address' in side_conf: suffix = side_conf['address'] - if suffix[0] == '!': - suffix = f'!= {suffix[1:]}' - output.append(f'{ip_name} {prefix}addr {suffix}') + operator = '' + exclude = suffix[0] == '!' + if exclude: + operator = '!= ' + suffix = suffix[1:] + if address_mask: + operator = '!=' if exclude else '==' + operator = f'& {address_mask} {operator} ' + output.append(f'{ip_name} {prefix}addr {operator}{suffix}') + + if 'fqdn' in side_conf: + fqdn = side_conf['fqdn'] + operator = '' + if fqdn[0] == '!': + operator = '!=' + output.append(f'{ip_name} {prefix}addr {operator} @FQDN_{fw_name}_{rule_id}_{prefix}') + + if dict_search_args(side_conf, 'geoip', 'country_code'): + operator = '' + if dict_search_args(side_conf, 'geoip', 'inverse_match') != None: + operator = '!=' + output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC_{fw_name}_{rule_id}') if 'mac_address' in side_conf: suffix = side_conf["mac_address"] @@ -114,24 +175,36 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if 'address_group' in group: group_name = group['address_group'] operator = '' + exclude = group_name[0] == "!" + if exclude: + operator = '!=' + group_name = group_name[1:] + if address_mask: + operator = '!=' if exclude else '==' + operator = f'& {address_mask} {operator}' + output.append(f'{ip_name} {prefix}addr {operator} @A{def_suffix}_{group_name}') + # Generate firewall group domain-group + elif 'domain_group' in group: + group_name = group['domain_group'] + operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] - output.append(f'{ip_name} {prefix}addr {operator} $A{def_suffix}_{group_name}') + output.append(f'{ip_name} {prefix}addr {operator} @D_{group_name}') elif 'network_group' in group: group_name = group['network_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] - output.append(f'{ip_name} {prefix}addr {operator} $N{def_suffix}_{group_name}') + output.append(f'{ip_name} {prefix}addr {operator} @N{def_suffix}_{group_name}') if 'mac_group' in group: group_name = group['mac_group'] operator = '' if group_name[0] == '!': operator = '!=' group_name = group_name[1:] - output.append(f'ether {prefix}addr {operator} $M_{group_name}') + output.append(f'ether {prefix}addr {operator} @M_{group_name}') if 'port_group' in group: proto = rule_conf['protocol'] group_name = group['port_group'] @@ -144,11 +217,16 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): operator = '!=' group_name = group_name[1:] - output.append(f'{proto} {prefix}port {operator} $P_{group_name}') + output.append(f'{proto} {prefix}port {operator} @P_{group_name}') if 'log' in rule_conf and rule_conf['log'] == 'enable': action = rule_conf['action'] if 'action' in rule_conf else 'accept' - output.append(f'log prefix "[{fw_name[:19]}-{rule_id}-{action[:1].upper()}] "') + output.append(f'log prefix "[{fw_name[:19]}-{rule_id}-{action[:1].upper()}]"') + + if 'log_level' in rule_conf: + log_level = rule_conf['log_level'] + output.append(f'level {log_level}') + if 'hop_limit' in rule_conf: operators = {'eq': '==', 'gt': '>', 'lt': '<'} @@ -157,6 +235,21 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): value = rule_conf['hop_limit'][op] output.append(f'ip6 hoplimit {operator} {value}') + if 'inbound_interface' in rule_conf: + iiface = rule_conf['inbound_interface'] + output.append(f'iifname {iiface}') + + if 'outbound_interface' in rule_conf: + oiface = rule_conf['outbound_interface'] + output.append(f'oifname {oiface}') + + if 'ttl' in rule_conf: + operators = {'eq': '==', 'gt': '>', 'lt': '<'} + for op, operator in operators.items(): + if op in rule_conf['ttl']: + value = rule_conf['ttl'][op] + output.append(f'ip ttl {operator} {value}') + for icmp in ['icmp', 'icmpv6']: if icmp in rule_conf: if 'type_name' in rule_conf[icmp]: @@ -167,6 +260,23 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if 'type' in rule_conf[icmp]: output.append(icmp + ' type ' + rule_conf[icmp]['type']) + + if 'packet_length' in rule_conf: + lengths_str = ','.join(rule_conf['packet_length']) + output.append(f'ip{def_suffix} length {{{lengths_str}}}') + + if 'packet_length_exclude' in rule_conf: + negated_lengths_str = ','.join(rule_conf['packet_length_exclude']) + output.append(f'ip{def_suffix} length != {{{negated_lengths_str}}}') + + if 'dscp' in rule_conf: + dscp_str = ','.join(rule_conf['dscp']) + output.append(f'ip{def_suffix} dscp {{{dscp_str}}}') + + if 'dscp_exclude' in rule_conf: + negated_dscp_str = ','.join(rule_conf['dscp_exclude']) + output.append(f'ip{def_suffix} dscp != {{{negated_dscp_str}}}') + if 'ipsec' in rule_conf: if 'match_ipsec' in rule_conf['ipsec']: output.append('meta ipsec == 1') @@ -199,6 +309,11 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if tcp_flags: output.append(parse_tcp_flags(tcp_flags)) + # TCP MSS + tcp_mss = dict_search_args(rule_conf, 'tcp', 'mss') + if tcp_mss: + output.append(f'tcp option maxseg size {tcp_mss}') + output.append('counter') if 'set' in rule_conf: @@ -206,6 +321,10 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if 'action' in rule_conf: output.append(nft_action(rule_conf['action'])) + if 'jump' in rule_conf['action']: + target = rule_conf['jump_target'] + output.append(f'NAME{def_suffix}_{target}') + else: output.append('return') @@ -257,3 +376,118 @@ def parse_policy_set(set_conf, def_suffix): mss = set_conf['tcp_mss'] out.append(f'tcp option maxseg size set {mss}') return " ".join(out) + +# GeoIP + +nftables_geoip_conf = '/run/nftables-geoip.conf' +geoip_database = '/usr/share/vyos-geoip/dbip-country-lite.csv.gz' +geoip_lock_file = '/run/vyos-geoip.lock' + +def geoip_load_data(codes=[]): + data = None + + if not os.path.exists(geoip_database): + return [] + + try: + with gzip.open(geoip_database, mode='rt') as csv_fh: + reader = csv.reader(csv_fh) + out = [] + for start, end, code in reader: + if code.lower() in codes: + out.append([start, end, code.lower()]) + return out + except: + print('Error: Failed to open GeoIP database') + return [] + +def geoip_download_data(): + url = 'https://download.db-ip.com/free/dbip-country-lite-{}.csv.gz'.format(strftime("%Y-%m")) + try: + dirname = os.path.dirname(geoip_database) + if not os.path.exists(dirname): + os.mkdir(dirname) + + download(geoip_database, url) + print("Downloaded GeoIP database") + return True + except: + print("Error: Failed to download GeoIP database") + return False + +class GeoIPLock(object): + def __init__(self, file): + self.file = file + + def __enter__(self): + if os.path.exists(self.file): + return False + + Path(self.file).touch() + return True + + def __exit__(self, exc_type, exc_value, tb): + os.unlink(self.file) + +def geoip_update(firewall, force=False): + with GeoIPLock(geoip_lock_file) as lock: + if not lock: + print("Script is already running") + return False + + if not firewall: + print("Firewall is not configured") + return True + + if not os.path.exists(geoip_database): + if not geoip_download_data(): + return False + elif force: + geoip_download_data() + + ipv4_codes = {} + ipv6_codes = {} + + ipv4_sets = {} + ipv6_sets = {} + + # Map country codes to set names + for codes, path in dict_search_recursive(firewall, 'country_code'): + set_name = f'GEOIP_CC_{path[1]}_{path[3]}' + if path[0] == 'name': + for code in codes: + ipv4_codes.setdefault(code, []).append(set_name) + elif path[0] == 'ipv6_name': + for code in codes: + ipv6_codes.setdefault(code, []).append(set_name) + + if not ipv4_codes and not ipv6_codes: + if force: + print("GeoIP not in use by firewall") + return True + + geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes]) + + # Iterate IP blocks to assign to sets + for start, end, code in geoip_data: + ipv4 = is_ipv4(start) + if code in ipv4_codes and ipv4: + ip_range = f'{start}-{end}' if start != end else start + for setname in ipv4_codes[code]: + ipv4_sets.setdefault(setname, []).append(ip_range) + if code in ipv6_codes and not ipv4: + ip_range = f'{start}-{end}' if start != end else start + for setname in ipv6_codes[code]: + ipv6_sets.setdefault(setname, []).append(ip_range) + + render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', { + 'ipv4_sets': ipv4_sets, + 'ipv6_sets': ipv6_sets + }) + + result = run(f'nft -f {nftables_geoip_conf}') + if result != 0: + print('Error: GeoIP failed to update firewall') + return False + + return True |