diff options
Diffstat (limited to 'python/vyos/firewall.py')
-rw-r--r-- | python/vyos/firewall.py | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 355ec44b0..a61d0a9f8 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -14,11 +14,22 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +import csv +import gzip +import os import re +from pathlib import Path +from time import strftime + +from vyos.remote import download +from vyos.template import is_ipv4 +from vyos.template import render from vyos.util import call from vyos.util import cmd from vyos.util import dict_search_args +from vyos.util import dict_search_recursive +from vyos.util import run # Functions for firewall group domain-groups @@ -139,6 +150,9 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): if suffix[0] == '!': suffix = f'!= {suffix[1:]}' output.append(f'{ip_name} {prefix}addr {suffix}') + + if dict_search_args(side_conf, 'geoip', 'country_code'): + output.append(f'{ip_name} {prefix}addr @GEOIP_CC_{fw_name}_{rule_id}') if 'mac_address' in side_conf: suffix = side_conf["mac_address"] @@ -338,3 +352,127 @@ def parse_policy_set(set_conf, def_suffix): mss = set_conf['tcp_mss'] out.append(f'tcp option maxseg size set {mss}') return " ".join(out) + +# GeoIP + +nftables_geoip_conf = '/run/nftables-geoip.conf' +geoip_database = '/usr/share/vyos-geoip/dbip-country-lite.csv.gz' +geoip_lock_file = '/run/vyos-geoip.lock' + +def geoip_load_data(codes=[]): + data = None + + if not os.path.exists(geoip_database): + return [] + + try: + with gzip.open(geoip_database, mode='rt') as csv_fh: + reader = csv.reader(csv_fh) + out = [] + for start, end, code in reader: + if code.lower() in codes: + out.append([start, end, code.lower()]) + return out + except: + print('Error: Failed to open GeoIP database') + return [] + +def geoip_download_data(): + url = 'https://download.db-ip.com/free/dbip-country-lite-{}.csv.gz'.format(strftime("%Y-%m")) + try: + dirname = os.path.dirname(geoip_database) + if not os.path.exists(dirname): + os.mkdir(dirname) + + download(geoip_database, url) + print("Downloaded GeoIP database") + return True + except: + print("Error: Failed to download GeoIP database") + return False + +class GeoIPLock(object): + def __init__(self, file): + self.file = file + + def __enter__(self): + if os.path.exists(self.file): + return False + + Path(self.file).touch() + return True + + def __exit__(self, exc_type, exc_value, tb): + os.unlink(self.file) + +def geoip_update(firewall, force=False): + with GeoIPLock(geoip_lock_file) as lock: + if not lock: + print("Script is already running") + return False + + if not firewall: + print("Firewall is not configured") + return True + + if not os.path.exists(geoip_database): + if not geoip_download_data(): + return False + elif force: + geoip_download_data() + + ipv4_codes = {} + ipv6_codes = {} + + ipv4_sets = {} + ipv6_sets = {} + + # Map country codes to set names + for codes, path in dict_search_recursive(firewall, 'country_code'): + if path[0] == 'name': + set_name = f'GEOIP_CC_{path[1]}_{path[3]}' + ipv4_sets[set_name] = [] + for code in codes: + if code not in ipv4_codes: + ipv4_codes[code] = [set_name] + else: + ipv4_codes[code].append(set_n) + elif path[0] == 'ipv6_name': + set_name = f'GEOIP_CC_{path[1]}_{path[3]}' + ipv6_sets[set_name] = [] + for code in codes: + if code not in ipv6_codes: + ipv6_codes[code] = [set_name] + else: + ipv6_codes[code].append(set_name) + + if not ipv4_codes and not ipv6_codes: + if force: + print("GeoIP not in use by firewall") + return True + + geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes]) + + # Iterate IP blocks to assign to sets + for start, end, code in geoip_data: + ipv4 = is_ipv4(start) + if code in ipv4_codes and ipv4: + ip_range = f'{start}-{end}' if start != end else start + for setname in ipv4_codes[code]: + ipv4_sets[setname].append(ip_range) + if code in ipv6_codes and not ipv4: + ip_range = f'{start}-{end}' if start != end else start + for setname in ipv6_codes[code]: + ipv6_sets[setname].append(ip_range) + + render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', { + 'ipv4_sets': ipv4_sets, + 'ipv6_sets': ipv6_sets + }) + + result = run(f'nft -f {nftables_geoip_conf}') + if result != 0: + print('Error: GeoIP failed to update firewall') + return False + + return True |