#!/usr/bin/env python3
#
# Copyright (C) 2021-2022 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import csv
import gzip
import os
import re

from pathlib import Path
from socket import AF_INET
from socket import AF_INET6
from socket import getaddrinfo
from time import strftime

from vyos.remote import download
from vyos.template import is_ipv4
from vyos.template import render
from vyos.util import call
from vyos.util import cmd
from vyos.util import dict_search_args
from vyos.util import dict_search_recursive
from vyos.util import run

# Domain Resolver

def fqdn_config_parse(firewall):
    """Collect every configured FQDN match into per-family lookup tables.

    ``firewall['ip_fqdn']`` and ``firewall['ip6_fqdn']`` are reset to empty
    dicts and then filled with ``{set_name: domain}`` entries, where
    ``set_name`` is ``<ruleset name>_<rule id>_<s|d>``. The ``firewall`` dict
    is modified in place; nothing is returned.
    """
    ipv4_sets = firewall['ip_fqdn'] = {}
    ipv6_sets = firewall['ip6_fqdn'] = {}
    sets_by_tree = {'name': ipv4_sets, 'ipv6_name': ipv6_sets}

    for domain, path in dict_search_recursive(firewall, 'fqdn'):
        # path[0]: name/ipv6_name, path[1]: ruleset name, path[3]: rule id,
        # path[4]: source/destination (only its first character is used)
        tree, ruleset, rule_id, side = path[0], path[1], path[3], path[4]
        set_name = f'{ruleset}_{rule_id}_{side[0]}'

        target = sets_by_tree.get(tree)
        if target is not None:
            target[set_name] = domain

def fqdn_resolve(fqdn, ipv6=False):
    """Resolve a host name to the set of addresses of one address family.

    Args:
        fqdn: Host name (or literal address) to look up.
        ipv6: Look up IPv6 addresses when true, IPv4 addresses otherwise.

    Returns:
        A set of address strings, or None when the lookup fails.
    """
    family = AF_INET6 if ipv6 else AF_INET
    try:
        res = getaddrinfo(fqdn, None, family)
    except Exception:
        # Resolution is best-effort: any lookup failure is reported as None.
        # Unlike a bare "except", this lets KeyboardInterrupt and SystemExit
        # propagate.
        return None
    # Each result is (family, type, proto, canonname, sockaddr); the address
    # is the first element of sockaddr.
    return {item[4][0] for item in res}

# End Domain Resolver

def find_nftables_rule(table, chain, rule_matches=None):
    """Find a rule in table/chain that matches all criteria and return its handle.

    Args:
        table: nftables table, as passed to ``nft list chain``.
        chain: Chain name inside ``table``.
        rule_matches: Iterable of substrings that must all occur in the
            listed rule line. With no criteria the first line carrying a
            handle is returned.

    Returns:
        The handle as a string of digits, or None if no line matches.
    """
    # Materialise once so that any iterable can be re-checked for every line
    criteria = tuple(rule_matches) if rule_matches is not None else ()

    results = cmd(f'sudo nft -a list chain {table} {chain}').split("\n")
    for line in results:
        if all(rule_match in line for rule_match in criteria):
            # "nft -a" appends "# handle <n>" to every listed rule
            handle_search = re.search(r'handle (\d+)', line)
            if handle_search:
                return handle_search[1]
    return None

def remove_nftables_rule(table, chain, handle):
    """Delete the rule identified by ``handle`` from ``chain`` in ``table``."""
    delete_command = f'sudo nft delete rule {table} {chain} handle {handle}'
    cmd(delete_command)

# Functions below used by template generation

def nft_action(vyos_action):
    """Map a VyOS rule action to the verdict written into the nftables rule.

    'accept' is emitted as 'return'; every other action is passed through
    unchanged.
    """
    return 'return' if vyos_action == 'accept' else vyos_action

def parse_rule(rule_conf, fw_name, rule_id, ip_name):
    """Translate one firewall rule from the config dict into an nftables rule.

    Args:
        rule_conf: Dict with the rule's settings (match criteria, logging,
            'set' modifications, action, ...).
        fw_name: Name of the ruleset the rule belongs to; used in set names,
            the log prefix and the rule comment.
        rule_id: Rule identifier; used in set names, the log prefix and the
            rule comment.
        ip_name: nftables address keyword used for address matches. 'ip6'
            selects the IPv6 variants of group/set names and header fields,
            any other value the IPv4 ones.

    Returns:
        A single string: all match expressions, then 'counter', optional
        'set' statements, the verdict and a comment, joined by spaces.

    Raises:
        KeyError: if a present option depends on a setting that is missing,
            e.g. 'port' or 'port_group' without 'protocol', or a 'jump'
            action without 'jump_target'.
    """
    output = []
    # Suffix that selects the IPv6 flavour of groups, sets and header matches
    def_suffix = '6' if ip_name == 'ip6' else ''

    # Connection tracking state: only states set to 'enable' are matched
    if 'state' in rule_conf and rule_conf['state']:
        states = ",".join([s for s, v in rule_conf['state'].items() if v == 'enable'])

        if states:
            output.append(f'ct state {{{states}}}')

    if 'connection_status' in rule_conf and rule_conf['connection_status']:
        status = rule_conf['connection_status']
        if status['nat'] == 'destination':
            nat_status = '{dnat}'
            output.append(f'ct status {nat_status}')
        if status['nat'] == 'source':
            nat_status = '{snat}'
            output.append(f'ct status {nat_status}')

    if 'protocol' in rule_conf and rule_conf['protocol'] != 'all':
        proto = rule_conf['protocol']
        operator = ''
        # A leading '!' negates the protocol match
        if proto[0] == '!':
            operator = '!='
            proto = proto[1:]
        if proto == 'tcp_udp':
            proto = '{tcp, udp}'
        output.append(f'meta l4proto {operator} {proto}')

    for side in ['destination', 'source']:
        if side in rule_conf:
            # 'd' or 's', giving daddr/saddr and dport/sport
            prefix = side[0]
            side_conf = rule_conf[side]
            address_mask = side_conf.get('address_mask', None)

            if 'address' in side_conf:
                suffix = side_conf['address']
                operator = ''
                exclude = suffix[0] == '!'
                if exclude:
                    operator = '!= '
                    suffix = suffix[1:]
                if address_mask:
                    # With a mask the comparison is always spelled out
                    operator = '!=' if exclude else '=='
                    operator = f'& {address_mask} {operator} '
                output.append(f'{ip_name} {prefix}addr {operator}{suffix}')

            if 'fqdn' in side_conf:
                fqdn = side_conf['fqdn']
                operator = ''
                if fqdn[0] == '!':
                    operator = '!='
                # Set name follows the '<ruleset>_<rule id>_<s|d>' scheme
                output.append(f'{ip_name} {prefix}addr {operator} @FQDN_{fw_name}_{rule_id}_{prefix}')

            if dict_search_args(side_conf, 'geoip', 'country_code'):
                operator = ''
                if dict_search_args(side_conf, 'geoip', 'inverse_match') is not None:
                    operator = '!='
                output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC_{fw_name}_{rule_id}')

            if 'mac_address' in side_conf:
                suffix = side_conf["mac_address"]
                if suffix[0] == '!':
                    suffix = f'!= {suffix[1:]}'
                output.append(f'ether {prefix}addr {suffix}')

            if 'port' in side_conf:
                proto = rule_conf['protocol']
                port = side_conf['port'].split(',')

                ports = []
                negated_ports = []

                # Split the comma separated list into plain and '!' entries
                for p in port:
                    if p[0] == '!':
                        negated_ports.append(p[1:])
                    else:
                        ports.append(p)

                # 'th' is used to match the port for tcp_udp rules
                if proto == 'tcp_udp':
                    proto = 'th'

                if ports:
                    ports_str = ','.join(ports)
                    output.append(f'{proto} {prefix}port {{{ports_str}}}')

                if negated_ports:
                    negated_ports_str = ','.join(negated_ports)
                    output.append(f'{proto} {prefix}port != {{{negated_ports_str}}}')

            if 'group' in side_conf:
                group = side_conf['group']
                # Only one of address/domain/network group is used per side
                if 'address_group' in group:
                    group_name = group['address_group']
                    operator = ''
                    exclude = group_name[0] == "!"
                    if exclude:
                        operator = '!='
                        group_name = group_name[1:]
                    if address_mask:
                        operator = '!=' if exclude else '=='
                        operator = f'& {address_mask} {operator}'
                    output.append(f'{ip_name} {prefix}addr {operator} @A{def_suffix}_{group_name}')
                # Generate firewall group domain-group
                elif 'domain_group' in group:
                    group_name = group['domain_group']
                    operator = ''
                    if group_name[0] == '!':
                        operator = '!='
                        group_name = group_name[1:]
                    output.append(f'{ip_name} {prefix}addr {operator} @D_{group_name}')
                elif 'network_group' in group:
                    group_name = group['network_group']
                    operator = ''
                    if group_name[0] == '!':
                        operator = '!='
                        group_name = group_name[1:]
                    output.append(f'{ip_name} {prefix}addr {operator} @N{def_suffix}_{group_name}')
                if 'mac_group' in group:
                    group_name = group['mac_group']
                    operator = ''
                    if group_name[0] == '!':
                        operator = '!='
                        group_name = group_name[1:]
                    output.append(f'ether {prefix}addr {operator} @M_{group_name}')
                if 'port_group' in group:
                    proto = rule_conf['protocol']
                    group_name = group['port_group']

                    if proto == 'tcp_udp':
                        proto = 'th'

                    operator = ''
                    if group_name[0] == '!':
                        operator = '!='
                        group_name = group_name[1:]

                    output.append(f'{proto} {prefix}port {operator} @P_{group_name}')

    if 'log' in rule_conf and rule_conf['log'] == 'enable':
        action = rule_conf['action'] if 'action' in rule_conf else 'accept'
        # Prefix: ruleset name (cut to 19 chars), rule id, first letter of action
        output.append(f'log prefix "[{fw_name[:19]}-{rule_id}-{action[:1].upper()}]"')

        if 'log_level' in rule_conf:
            log_level = rule_conf['log_level']
            output.append(f'level {log_level}')


    if 'hop_limit' in rule_conf:
        operators = {'eq': '==', 'gt': '>', 'lt': '<'}
        for op, operator in operators.items():
            if op in rule_conf['hop_limit']:
                value = rule_conf['hop_limit'][op]
                output.append(f'ip6 hoplimit {operator} {value}')

    if 'inbound_interface' in rule_conf:
        if 'interface_name' in rule_conf['inbound_interface']:
            iiface = rule_conf['inbound_interface']['interface_name']
            output.append(f'iifname {{{iiface}}}')
        else:
            iiface = rule_conf['inbound_interface']['interface_group']
            output.append(f'iifname @I_{iiface}')

    if 'outbound_interface' in rule_conf:
        if 'interface_name' in rule_conf['outbound_interface']:
            oiface = rule_conf['outbound_interface']['interface_name']
            output.append(f'oifname {{{oiface}}}')
        else:
            oiface = rule_conf['outbound_interface']['interface_group']
            output.append(f'oifname @I_{oiface}')

    if 'ttl' in rule_conf:
        operators = {'eq': '==', 'gt': '>', 'lt': '<'}
        for op, operator in operators.items():
            if op in rule_conf['ttl']:
                value = rule_conf['ttl'][op]
                output.append(f'ip ttl {operator} {value}')

    for icmp in ['icmp', 'icmpv6']:
        if icmp in rule_conf:
            # A type name takes precedence over numeric code/type
            if 'type_name' in rule_conf[icmp]:
                output.append(icmp + ' type ' + rule_conf[icmp]['type_name'])
            else:
                if 'code' in rule_conf[icmp]:
                    output.append(icmp + ' code ' + rule_conf[icmp]['code'])
                if 'type' in rule_conf[icmp]:
                    output.append(icmp + ' type ' + rule_conf[icmp]['type'])


    if 'packet_length' in rule_conf:
        lengths_str = ','.join(rule_conf['packet_length'])
        output.append(f'ip{def_suffix} length {{{lengths_str}}}')

    if 'packet_length_exclude' in rule_conf:
        negated_lengths_str = ','.join(rule_conf['packet_length_exclude'])
        output.append(f'ip{def_suffix} length != {{{negated_lengths_str}}}')

    if 'dscp' in rule_conf:
        dscp_str = ','.join(rule_conf['dscp'])
        output.append(f'ip{def_suffix} dscp {{{dscp_str}}}')

    if 'dscp_exclude' in rule_conf:
        negated_dscp_str = ','.join(rule_conf['dscp_exclude'])
        output.append(f'ip{def_suffix} dscp != {{{negated_dscp_str}}}')

    if 'ipsec' in rule_conf:
        if 'match_ipsec' in rule_conf['ipsec']:
            output.append('meta ipsec == 1')
        if 'match_non_ipsec' in rule_conf['ipsec']:
            output.append('meta ipsec == 0')

    if 'fragment' in rule_conf:
        # Checking for fragmentation after priority -400 is not possible,
        # so we use a priority -450 hook to set a mark
        if 'match_frag' in rule_conf['fragment']:
            output.append('meta mark 0xffff1')
        if 'match_non_frag' in rule_conf['fragment']:
            output.append('meta mark != 0xffff1')

    if 'limit' in rule_conf:
        # A burst is only emitted together with a rate
        if 'rate' in rule_conf['limit']:
            output.append(f'limit rate {rule_conf["limit"]["rate"]}')
            if 'burst' in rule_conf['limit']:
                output.append(f'burst {rule_conf["limit"]["burst"]} packets')

    if 'recent' in rule_conf:
        count = rule_conf['recent']['count']
        time = rule_conf['recent']['time']
        output.append(f'add @RECENT{def_suffix}_{fw_name}_{rule_id} {{ {ip_name} saddr limit rate over {count}/{time} burst {count} packets }}')

    if 'time' in rule_conf:
        output.append(parse_time(rule_conf['time']))

    tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags')
    if tcp_flags:
        output.append(parse_tcp_flags(tcp_flags))

    # TCP MSS
    tcp_mss = dict_search_args(rule_conf, 'tcp', 'mss')
    if tcp_mss:
        output.append(f'tcp option maxseg size {tcp_mss}')

    if 'connection_mark' in rule_conf:
        conn_mark_str = ','.join(rule_conf['connection_mark'])
        output.append(f'ct mark {{{conn_mark_str}}}')

    # Everything above is a match; statements and the verdict follow
    output.append('counter')

    if 'set' in rule_conf:
        output.append(parse_policy_set(rule_conf['set'], def_suffix))

    if 'action' in rule_conf:
        output.append(nft_action(rule_conf['action']))
        if 'jump' in rule_conf['action']:
            target = rule_conf['jump_target']
            output.append(f'NAME{def_suffix}_{target}')

    else:
        # Rules without an action get the same verdict as 'accept'
        output.append('return')

    output.append(f'comment "{fw_name}-{rule_id}"')
    return " ".join(output)

def parse_tcp_flags(flags):
    """Build the nftables 'tcp flags' match for a flags config dict.

    Every key except 'not' names a flag that must be set; the keys below
    'not' name flags that must be clear. All of them form the mask and only
    the flags that must be set form the compared value ('0x0' if there are
    none).
    """
    must_be_set = [name for name in flags if name != 'not']
    must_be_clear = list(flags['not']) if 'not' in flags else []

    mask = '|'.join(must_be_set + must_be_clear)
    expected = '|'.join(must_be_set) if must_be_set else '0x0'
    return f'tcp flags & ({mask}) == {expected}'

def parse_time(time):
    """Build the nftables time matches for a rule's 'time' config dict.

    Recognised keys: 'startdate', 'starttime', 'stopdate', 'stoptime' and
    'weekdays'. A start/stop time is appended to its date unless the date
    already carries a time part (contains 'T'); without a date it becomes an
    'hour' match. 'weekdays' is a comma separated list in which entries
    prefixed with '!' are excluded days.

    Returns:
        The match expressions joined by spaces ('' when nothing applies).
    """
    out = []
    if 'startdate' in time:
        start = time['startdate']
        if 'T' not in start and 'starttime' in time:
            start += f' {time["starttime"]}'
        out.append(f'time >= "{start}"')
    elif 'starttime' in time:
        out.append(f'hour >= "{time["starttime"]}"')

    if 'stopdate' in time:
        stop = time['stopdate']
        if 'T' not in stop and 'stoptime' in time:
            stop += f' {time["stoptime"]}'
        out.append(f'time < "{stop}"')
    elif 'stoptime' in time:
        out.append(f'hour < "{time["stoptime"]}"')

    if 'weekdays' in time:
        days = [day for day in time['weekdays'].split(",") if day]
        included = [f'"{day}"' for day in days if not day.startswith('!')]
        excluded = [f'"{day[1:]}"' for day in days
                    if day.startswith('!') and day[1:]]
        # Never emit an empty set, and honour negated days instead of
        # silently dropping them
        if included:
            out.append(f'day {{{",".join(included)}}}')
        if excluded:
            out.append(f'day != {{{",".join(excluded)}}}')
    return " ".join(out)

def parse_policy_set(set_conf, def_suffix):
    """Build the nftables statements for a rule's 'set' config dict.

    Args:
        set_conf: Dict that may hold 'connection_mark', 'dscp', 'mark',
            'table' and 'tcp_mss'.
        def_suffix: '' or '6'; appended to 'ip' for the dscp statement.

    Returns:
        The statements joined by spaces, in the key order listed above.
    """
    statements = []

    if 'connection_mark' in set_conf:
        statements.append(f'ct mark set {set_conf["connection_mark"]}')

    if 'dscp' in set_conf:
        statements.append(f'ip{def_suffix} dscp set {set_conf["dscp"]}')

    if 'mark' in set_conf:
        statements.append(f'meta mark set {set_conf["mark"]}')

    if 'table' in set_conf:
        # A routing table is encoded as a mark counted down from 0x7FFFFFFF;
        # the 'main' table is table 254
        table_id = 254 if set_conf['table'] == 'main' else int(set_conf['table'])
        statements.append(f'meta mark set {0x7FFFFFFF - table_id}')

    if 'tcp_mss' in set_conf:
        statements.append(f'tcp option maxseg size set {set_conf["tcp_mss"]}')

    return " ".join(statements)

# GeoIP

# nftables file rendered by geoip_update() and loaded with "nft -f"
nftables_geoip_conf = '/run/nftables-geoip.conf'
# Local gzip-compressed CSV copy of the country database
geoip_database = '/usr/share/vyos-geoip/dbip-country-lite.csv.gz'
# Marker file handed to GeoIPLock by geoip_update()
geoip_lock_file = '/run/vyos-geoip.lock'

def geoip_load_data(codes=None):
    """Read the address blocks of the given countries from the GeoIP database.

    Args:
        codes: Iterable of lower-case country codes to keep. With no codes
            nothing is selected.

    Returns:
        A list of ``[start, end, code]`` rows (code lower-cased) in database
        order. An empty list is returned when the database file is missing
        or cannot be read/parsed; in the latter case an error is printed.
    """
    if not os.path.exists(geoip_database):
        return []

    # Set membership keeps the per-row check cheap for a large database
    wanted = set(codes) if codes else set()

    try:
        with gzip.open(geoip_database, mode='rt') as csv_fh:
            out = []
            for start, end, code in csv.reader(csv_fh):
                code = code.lower()
                if code in wanted:
                    out.append([start, end, code])
            return out
    except Exception:
        # Best-effort: a damaged or malformed database yields no data rather
        # than an exception. Unlike a bare "except", KeyboardInterrupt and
        # SystemExit still propagate.
        print('Error: Failed to open GeoIP database')
    return []

def geoip_download_data():
    """Download the current month's country database to ``geoip_database``.

    The target directory is created if it does not exist yet.

    Returns:
        True when the download succeeded, False otherwise (an error is
        printed in that case).
    """
    url = f'https://download.db-ip.com/free/dbip-country-lite-{strftime("%Y-%m")}.csv.gz'
    try:
        # exist_ok avoids failing when the directory appears in the meantime
        os.makedirs(os.path.dirname(geoip_database), exist_ok=True)
        download(geoip_database, url)
    except Exception:
        # Any download problem is reported through the return value. Unlike a
        # bare "except", KeyboardInterrupt and SystemExit still propagate.
        print("Error: Failed to download GeoIP database")
        return False

    print("Downloaded GeoIP database")
    return True

class GeoIPLock(object):
    """Lock-file based guard against concurrent GeoIP updates.

    Used as a context manager: the ``with`` target is True when the lock file
    was created by this instance and False when it already existed. The lock
    file is removed on exit only by the instance that created it.

    A lock file left behind by a killed process is not detected and has to be
    removed manually.
    """

    def __init__(self, file):
        # Path of the lock file
        self.file = file
        # True while this instance owns the lock file
        self.acquired = False

    def __enter__(self):
        try:
            # Exclusive creation is atomic, so two processes can never both
            # see "no lock" and then both create the file
            Path(self.file).touch(exist_ok=False)
        except FileExistsError:
            self.acquired = False
            return False

        self.acquired = True
        return True

    def __exit__(self, exc_type, exc_value, tb):
        # Never remove a lock that belongs to somebody else
        if self.acquired:
            self.acquired = False
            os.unlink(self.file)

def geoip_update(firewall, force=False):
    """Rebuild the GeoIP nftables sets used by the firewall and load them.

    Args:
        firewall: Firewall config dict; searched for 'country_code' entries.
        force: Re-download the database even if a local copy exists, and
            report when GeoIP is not used by the firewall.

    Returns:
        False when another update holds the lock, the initial database
        download fails, or nft fails to load the generated file. True
        otherwise, including when there is nothing to do.
    """
    with GeoIPLock(geoip_lock_file) as lock:
        if not lock:
            print("Script is already running")
            return False

        if not firewall:
            print("Firewall is not configured")
            return True

        if os.path.exists(geoip_database):
            if force:
                # Result ignored: the existing database stays usable
                geoip_download_data()
        elif not geoip_download_data():
            return False

        # Map country codes to set names
        codes_v4 = {}
        codes_v6 = {}
        for codes, path in dict_search_recursive(firewall, 'country_code'):
            set_name = f'GEOIP_CC_{path[1]}_{path[3]}'
            if path[0] == 'name':
                wanted = codes_v4
            elif path[0] == 'ipv6_name':
                wanted = codes_v6
            else:
                continue
            for code in codes:
                wanted.setdefault(code, []).append(set_name)

        if not (codes_v4 or codes_v6):
            if force:
                print("GeoIP not in use by firewall")
            return True

        geoip_data = geoip_load_data([*codes_v4, *codes_v6])

        # Iterate IP blocks to assign to sets
        sets_v4 = {}
        sets_v6 = {}
        for start, end, code in geoip_data:
            if is_ipv4(start):
                code_map, set_map = codes_v4, sets_v4
            else:
                code_map, set_map = codes_v6, sets_v6

            if code not in code_map:
                continue

            # Single addresses are written as-is, everything else as a range
            ip_range = start if start == end else f'{start}-{end}'
            for setname in code_map[code]:
                set_map.setdefault(setname, []).append(ip_range)

        template_data = {
            'ipv4_sets': sets_v4,
            'ipv6_sets': sets_v6
        }
        render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', template_data)

        if run(f'nft -f {nftables_geoip_conf}') != 0:
            print('Error: GeoIP failed to update firewall')
            return False

        return True