diff options
Diffstat (limited to 'python')
| -rw-r--r-- | python/vyos/firewall.py | 138 | 
1 files changed, 138 insertions, 0 deletions
| diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 355ec44b0..a61d0a9f8 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -14,11 +14,22 @@  # You should have received a copy of the GNU General Public License  # along with this program.  If not, see <http://www.gnu.org/licenses/>. +import csv +import gzip +import os  import re +from pathlib import Path +from time import strftime + +from vyos.remote import download +from vyos.template import is_ipv4 +from vyos.template import render  from vyos.util import call  from vyos.util import cmd  from vyos.util import dict_search_args +from vyos.util import dict_search_recursive +from vyos.util import run  # Functions for firewall group domain-groups @@ -139,6 +150,9 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name):                  if suffix[0] == '!':                      suffix = f'!= {suffix[1:]}'                  output.append(f'{ip_name} {prefix}addr {suffix}') +             +            if dict_search_args(side_conf, 'geoip', 'country_code'): +                output.append(f'{ip_name} {prefix}addr @GEOIP_CC_{fw_name}_{rule_id}')              if 'mac_address' in side_conf:                  suffix = side_conf["mac_address"] @@ -338,3 +352,127 @@ def parse_policy_set(set_conf, def_suffix):          mss = set_conf['tcp_mss']          out.append(f'tcp option maxseg size set {mss}')      return " ".join(out) + +# GeoIP + +nftables_geoip_conf = '/run/nftables-geoip.conf' +geoip_database = '/usr/share/vyos-geoip/dbip-country-lite.csv.gz' +geoip_lock_file = '/run/vyos-geoip.lock' + +def geoip_load_data(codes=[]): +    data = None + +    if not os.path.exists(geoip_database): +        return [] + +    try: +        with gzip.open(geoip_database, mode='rt') as csv_fh: +            reader = csv.reader(csv_fh) +            out = [] +            for start, end, code in reader: +                if code.lower() in codes: +                    out.append([start, end, code.lower()]) +            return out +    except: +        print('Error: Failed to open GeoIP database') +    return [] + +def geoip_download_data(): +    url = 'https://download.db-ip.com/free/dbip-country-lite-{}.csv.gz'.format(strftime("%Y-%m")) +    try: +        dirname = os.path.dirname(geoip_database) +        if not os.path.exists(dirname): +            os.mkdir(dirname) + +        download(geoip_database, url) +        print("Downloaded GeoIP database") +        return True +    except: +        print("Error: Failed to download GeoIP database") +    return False + +class GeoIPLock(object): +    def __init__(self, file): +        self.file = file + +    def __enter__(self): +        if os.path.exists(self.file): +            return False + +        Path(self.file).touch() +        return True + +    def __exit__(self, exc_type, exc_value, tb): +        os.unlink(self.file) + +def geoip_update(firewall, force=False): +    with GeoIPLock(geoip_lock_file) as lock: +        if not lock: +            print("Script is already running") +            return False + +        if not firewall: +            print("Firewall is not configured") +            return True + +        if not os.path.exists(geoip_database): +            if not geoip_download_data(): +                return False +        elif force: +            geoip_download_data() + +        ipv4_codes = {} +        ipv6_codes = {} + +        ipv4_sets = {} +        ipv6_sets = {} + +        # Map country codes to set names +        for codes, path in dict_search_recursive(firewall, 'country_code'): +            if path[0] == 'name': +                set_name = f'GEOIP_CC_{path[1]}_{path[3]}' +                ipv4_sets[set_name] = [] +                for code in codes: +                    if code not in ipv4_codes: +                        ipv4_codes[code] = [set_name] +                    else: +                        ipv4_codes[code].append(set_n) +            elif path[0] == 'ipv6_name': +                set_name = f'GEOIP_CC_{path[1]}_{path[3]}' +                ipv6_sets[set_name] = [] +                for code in codes: +                    if code not in ipv6_codes: +                        ipv6_codes[code] = [set_name] +                    else: +                        ipv6_codes[code].append(set_name) + +        if not ipv4_codes and not ipv6_codes: +            if force: +                print("GeoIP not in use by firewall") +            return True + +        geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes]) + +        # Iterate IP blocks to assign to sets +        for start, end, code in geoip_data: +            ipv4 = is_ipv4(start) +            if code in ipv4_codes and ipv4: +                ip_range = f'{start}-{end}' if start != end else start +                for setname in ipv4_codes[code]: +                    ipv4_sets[setname].append(ip_range) +            if code in ipv6_codes and not ipv4: +                ip_range = f'{start}-{end}' if start != end else start +                for setname in ipv6_codes[code]: +                    ipv6_sets[setname].append(ip_range) + +        render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', { +            'ipv4_sets': ipv4_sets, +            'ipv6_sets': ipv6_sets +        }) + +        result = run(f'nft -f {nftables_geoip_conf}') +        if result != 0: +            print('Error: GeoIP failed to update firewall') +            return False + +        return True | 
