#!/usr/bin/env python3
# Copyright (C) 2021 VyOS maintainers and contributors
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import re

from json import loads
from sys import exit

from vyos.config import Config
from vyos.template import render
from vyos.util import cmd
from vyos.util import dict_search_args
from vyos.util import run
from vyos import ConfigError
from vyos import airbag

# Base value for the firewall marks tied to routing tables: the mark used for
# a table is computed as mark_offset minus the numeric table id.
mark_offset = 0x7FFFFFFF
# Path of the rendered nftables ruleset; written by generate() and loaded with
# `nft -f` by apply().
nftables_conf = '/run/nftables_policy.conf'

# Chain names that cleanup_commands() never flushes or deletes.
# NOTE(review): the entries and the closing bracket of this list are not
# present in this copy of the file - restore them before use.
preserve_chains = [

# Group keys that verify_rule() looks up under the `group` node of a rule's
# source/destination.
# NOTE(review): the entries and the closing bracket of this list are not
# present in this copy of the file - restore them before use.
valid_groups = [

def get_policy_interfaces(conf):
    """Collect the `policy` node of every interface that has one.

    Walks the whole `interfaces` config tree, including the VLAN
    sub-interface nodes (`vif`, `vif_s`, `vif_c`). Sub-interface names are
    built by joining parent and child names with a dot, e.g. `eth0.10` or
    `eth0.100.200`.

    Args:
        conf: config object providing `get_config_dict()`.

    Returns:
        dict mapping interface name to that interface's `policy` config.
    """
    # NOTE(review): the continuation line of this call was missing from the
    # file. no_tag_node_value_mangle=True is presumed so that tag node values
    # (interface names) are not rewritten by key mangling - confirm.
    interfaces = conf.get_config_dict(['interfaces'], key_mangling=('-', '_'), get_first_key=True,
                                      no_tag_node_value_mangle=True)
    out = {}

    def find_interfaces(iftype_conf, prefix=''):
        # Results are written straight into `out`; no mutable default
        # argument is needed to carry state through the recursion.
        for ifname, if_conf in iftype_conf.items():
            if 'policy' in if_conf:
                out[prefix + ifname] = if_conf['policy']
            for vif in ('vif', 'vif_s', 'vif_c'):
                if vif in if_conf:
                    find_interfaces(if_conf[vif], f'{prefix}{ifname}.')

    for iftype_conf in interfaces.values():
        find_interfaces(iftype_conf)
    return out

def get_config(config=None):
    """Build the policy dict consumed by verify(), generate() and apply().

    Args:
        config: optional, already constructed config object. A new `Config()`
            is created only when none is supplied.

    Returns:
        The `policy` config dict, extended with two extra keys:
        `firewall_group` (the `firewall group` config subtree) and
        `interfaces` (result of get_policy_interfaces()).
    """
    # Use the caller's config object when given instead of always replacing it.
    conf = config if config else Config()
    base = ['policy']

    # NOTE(review): the continuation lines of both get_config_dict() calls
    # were missing from the file. no_tag_node_value_mangle=True is presumed so
    # that policy and group names are not rewritten by key mangling - confirm.
    policy = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True,
                                  no_tag_node_value_mangle=True)

    policy['firewall_group'] = conf.get_config_dict(['firewall', 'group'], key_mangling=('-', '_'), get_first_key=True,
                                                    no_tag_node_value_mangle=True)
    policy['interfaces'] = get_policy_interfaces(conf)

    return policy

def verify_rule(policy, name, rule_conf, ipv6, rule_id=None):
    """Validate a single policy route rule.

    Args:
        policy: full policy dict; `policy['firewall_group']` is used to
            resolve group references.
        name: name of the policy route the rule belongs to (for messages).
        rule_conf: config dict of the rule.
        ipv6: True when the rule belongs to a `route6` policy.
        rule_id: identifier of the rule, used in error messages. Optional so
            that existing four-argument callers keep working.

    Raises:
        ConfigError: if the ICMP, TCP-MSS, TCP flag, group or port settings
            of the rule are inconsistent.
    """
    rule_label = name if rule_id is None else f'{name} rule {rule_id}'

    icmp = 'icmpv6' if ipv6 else 'icmp'
    if icmp in rule_conf:
        icmp_conf = rule_conf[icmp]
        has_type_name = 'type_name' in icmp_conf
        has_code = 'code' in icmp_conf
        has_type = 'type' in icmp_conf

        if has_type_name and (has_code or has_type):
            raise ConfigError(f'{rule_label}: Cannot use ICMP type/code with ICMP type-name')
        if has_code and not has_type:
            raise ConfigError(f'{rule_label}: ICMP code can only be defined if ICMP type is defined')

        # The protocol requirement applies only when an ICMP match is really
        # configured. Without the grouping, `and` binds tighter than `or`, so a
        # rule with no protocol raises KeyError and any non-ICMP protocol is
        # rejected even if no ICMP match is set.
        icmp_defined = has_type_name or has_code or has_type
        if icmp_defined and rule_conf.get('protocol') != icmp:
            raise ConfigError(f'{rule_label}: ICMP type/code or type-name can only be defined if protocol is ICMP')

    tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags')

    if 'set' in rule_conf and 'tcp_mss' in rule_conf['set']:
        if not tcp_flags or 'syn' not in tcp_flags:
            raise ConfigError(f'{rule_label}: TCP SYN flag must be set to modify TCP-MSS')

    if tcp_flags:
        if dict_search_args(rule_conf, 'protocol') != 'tcp':
            raise ConfigError('Protocol must be tcp when specifying tcp flags')

        not_flags = dict_search_args(rule_conf, 'tcp', 'flags', 'not')
        if not_flags:
            # A flag listed both directly and under `not` can never match.
            duplicates = [flag for flag in tcp_flags if flag in not_flags]
            if duplicates:
                raise ConfigError('Cannot match a tcp flag as set and not set')

    for side in ['destination', 'source']:
        if side not in rule_conf:
            continue
        side_conf = rule_conf[side]

        if 'group' in side_conf:
            if {'address_group', 'network_group'} <= set(side_conf['group']):
                raise ConfigError('Only one address-group or network-group can be specified')

            for group in valid_groups:
                if group not in side_conf['group']:
                    continue
                group_name = side_conf['group'][group]
                # For IPv6 rules, address/network groups are looked up under
                # their ipv6_-prefixed key in the firewall group config.
                fw_group = f'ipv6_{group}' if ipv6 and group in ['address_group', 'network_group'] else group
                error_group = fw_group.replace("_", "-")
                group_obj = dict_search_args(policy['firewall_group'], fw_group, group_name)

                # None: lookup found nothing. Empty but not None: the group
                # exists and only warrants a warning.
                if group_obj is None:
                    raise ConfigError(f'Invalid {error_group} "{group_name}" on policy route rule')

                if not group_obj:
                    print(f'WARNING: {error_group} "{group_name}" has no members')

        if 'port' in side_conf or dict_search_args(side_conf, 'group', 'port_group'):
            if 'protocol' not in rule_conf:
                raise ConfigError('Protocol must be defined if specifying a port or port-group')

            if rule_conf['protocol'] not in ['tcp', 'udp', 'tcp_udp']:
                raise ConfigError('Protocol must be tcp, udp, or tcp_udp when specifying a port or port-group')

def verify(policy):
    """Validate every policy route rule and every interface reference.

    Args:
        policy: dict as returned by get_config().

    Raises:
        ConfigError: if a rule is invalid, or an interface references a
            `route` / `route6` policy that is not (or no longer) configured.
    """
    for route in ['route', 'route6']:
        ipv6 = route == 'route6'
        if route in policy:
            for name, pol_conf in policy[route].items():
                if 'rule' in pol_conf:
                    for rule_id, rule_conf in pol_conf['rule'].items():
                        # Pass the rule id along so error messages can name it.
                        verify_rule(policy, name, rule_conf, ipv6, rule_id)

    for ifname, if_policy in policy['interfaces'].items():
        name = dict_search_args(if_policy, 'route')
        ipv6_name = dict_search_args(if_policy, 'route6')

        if name and not dict_search_args(policy, 'route', name):
            raise ConfigError(f'Policy route "{name}" is still referenced on interface {ifname}')

        if ipv6_name and not dict_search_args(policy, 'route6', ipv6_name):
            raise ConfigError(f'Policy route6 "{ipv6_name}" is still referenced on interface {ifname}')

    return None

def cleanup_rule(table, jump_chain):
    """Build nft commands deleting every rule that jumps to `jump_chain`.

    Parses the output of `nft -a list table <table>`, which annotates each
    rule with its handle, and emits one `delete rule` command per rule that
    jumps to `jump_chain`.

    Args:
        table: nft table specification, e.g. `ip mangle`.
        jump_chain: name of the chain that is the jump target.

    Returns:
        list of nft command strings (possibly empty).
    """
    commands = []
    chain_header = re.compile(r'^\s*chain\s+(\S+)\s*\{')
    # Require the chain name to end right after the match so that a longer
    # name sharing the same prefix is not picked up.
    jump_target = re.compile(rf'jump {re.escape(jump_chain)}(?!\S)')
    handle = re.compile(r'handle (\d+)')

    # `delete rule` needs the chain that CONTAINS the rule, so remember which
    # chain body the current line belongs to.
    current_chain = None
    for line in cmd(f'nft -a list table {table}').split("\n"):
        header = chain_header.match(line)
        if header:
            current_chain = header[1]
            continue
        if current_chain is None or not jump_target.search(line):
            continue
        handle_search = handle.search(line)
        if handle_search:
            commands.append(f'delete rule {table} {current_chain} handle {handle_search[1]}')
    return commands

def cleanup_commands(policy):
    """Build the nft commands that reset previously installed PBR chains.

    For every `VYOS_PBR*` chain found in the `ip mangle` / `ip6 mangle`
    tables that is not listed in `preserve_chains`:
      * if the matching policy route still exists in `policy`, the chain is
        only flushed;
      * otherwise the rules jumping to it are deleted and the chain itself
        is deleted.

    Args:
        policy: dict as returned by get_config().

    Returns:
        list of nft command strings (possibly empty).
    """
    # nft table -> (key in `policy`, chain-name prefix stripped to get the
    # policy route name)
    tables = {
        'ip mangle': ('route', "VYOS_PBR_"),
        'ip6 mangle': ('route6', "VYOS_PBR6_"),
    }

    commands = []
    for table, (route, prefix) in tables.items():
        obj = loads(cmd(f'nft -j list table {table}'))
        if 'nftables' not in obj:
            continue
        for item in obj['nftables']:
            if 'chain' not in item:
                continue
            chain = item['chain']['name']
            if not chain.startswith("VYOS_PBR") or chain in preserve_chains:
                continue
            if dict_search_args(policy, route, chain.replace(prefix, "", 1)):
                # Policy still configured: keep the chain, drop its rules.
                commands.append(f'flush chain {table} {chain}')
            else:
                # NOTE(review): this `else` was missing from the file, which
                # made the route6 branch delete chains still in use - confirm.
                # The jumps must be removed before the chain can be deleted.
                commands += cleanup_rule(table, chain)
                commands.append(f'delete chain {table} {chain}')
    return commands

def generate(policy):
    """Render the nftables policy ruleset to `nftables_conf`.

    Adds one of two keys to `policy` before rendering:
      * `first_install` when no rendered ruleset file exists yet;
      * `cleanup_commands` (see cleanup_commands()) otherwise.

    Args:
        policy: dict as returned by get_config(); modified in place.
    """
    if not os.path.exists(nftables_conf):
        policy['first_install'] = True
    else:
        # NOTE(review): this `else` was missing from the file. Cleanup is
        # presumed to be needed only when an earlier ruleset was rendered -
        # confirm against the template.
        policy['cleanup_commands'] = cleanup_commands(policy)

    render(nftables_conf, 'firewall/nftables-policy.tmpl', policy)
    return None

def apply_table_marks(policy):
    """Install the `ip rule` entries for rules that set a routing table.

    For each table referenced via `set table`, one rule is added per address
    family with preference equal to the table id, matching the firewall mark
    `mark_offset - <table id>` and looking up that table. The table name
    `main` is translated to id 254.

    Args:
        policy: dict as returned by get_config().
    """
    for route in ['route', 'route6']:
        if route not in policy:
            continue
        cmd_str = 'ip' if route == 'route' else 'ip -6'
        # Several rules may point at the same table. The ip rule is identical
        # for all of them, so add it only once: a duplicate is rejected by the
        # kernel and would make cmd() fail.
        seen_tables = set()
        for name, pol_conf in policy[route].items():
            if 'rule' not in pol_conf:
                continue
            for rule_id, rule_conf in pol_conf['rule'].items():
                set_table = dict_search_args(rule_conf, 'set', 'table')
                if not set_table:
                    continue
                if set_table == 'main':
                    set_table = '254'
                if set_table in seen_tables:
                    continue
                seen_tables.add(set_table)
                table_mark = mark_offset - int(set_table)
                cmd(f'{cmd_str} rule add pref {set_table} fwmark {table_mark} table {set_table}')

def cleanup_table_marks():
    """Remove the `ip rule` entries previously added by apply_table_marks().

    Lists the IPv4 and IPv6 rules as JSON and deletes every rule whose
    firewall mark equals `mark_offset - <table id>`, i.e. the mark/table
    pairing that apply_table_marks() creates.
    """
    for cmd_str in ['ip', 'ip -6']:
        rules = loads(cmd(f'{cmd_str} -j -N rule list'))
        for rule in rules:
            # Only rules carrying both a mark and a table can be ours.
            if 'fwmark' not in rule or 'table' not in rule:
                continue
            fwmark = rule['fwmark']
            table = int(rule['table'])
            # The mark may be reported as a hex string.
            if fwmark[:2] == '0x':
                fwmark = int(fwmark, 16)
            if int(fwmark) == (mark_offset - table):
                cmd(f'{cmd_str} rule del fwmark {fwmark} table {table}')

def apply(policy):
    """Load the rendered ruleset and synchronise the table-mark ip rules.

    Args:
        policy: dict as returned by get_config() and processed by generate().

    Raises:
        ConfigError: if `nft -f` fails to load the ruleset.
    """
    install_result = run(f'nft -f {nftables_conf}')
    # Any non-zero exit status means the ruleset was not loaded.
    if install_result != 0:
        raise ConfigError('Failed to apply policy based routing')

    # NOTE(review): the body of this branch and the call that follows were
    # missing from the file; reconstructed from the two otherwise unused
    # helpers. On a first install there are no earlier ip rules to remove.
    if 'first_install' not in policy:
        cleanup_table_marks()

    apply_table_marks(policy)

    return None

if __name__ == '__main__':
        c = get_config()
    except ConfigError as e: