#!/usr/bin/env python3
#
# Copyright (C) 2021-2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# T2199: Remove unavailable nodes due to XML/Python implementation using nftables
#        monthdays: nftables does not have a monthdays equivalent
#        utc: nftables userspace uses localtime and calculates the UTC offset automatically
#        icmp/v6: migrate previously available `type-name` to valid type/code
# T4178: Update tcp flags to use multi value node
# T6071: CLI description limit of 256 characters

import re

from sys import argv
from sys import exit

from vyos.configtree import ConfigTree

# The config file to migrate is the only (mandatory) argument
if len(argv) <= 1:
    print("Must specify file name!")
    exit(1)

file_name = argv[1]

# Settings shared by all migration steps below
max_len_description = 255
base = ['firewall']

# Load the whole config into memory and parse it into a tree
with open(file_name, 'r') as config_fh:
    config_file = config_fh.read()

config = ConfigTree(config_file)

# No firewall section at all - leave the file untouched
if not config.exists(base):
    exit(0)

# Legacy ICMP `type-name` values that are simply dropped during migration
icmp_remove = ['any']

# Legacy ICMP `type-name` values and what they are migrated to:
#   str        -> replacement `type-name`
#   [int, int] -> numeric ICMP [type, code] pair
icmp_translations = {
    'ping': 'echo-request',
    'pong': 'echo-reply',
    'ttl-exceeded': 'time-exceeded',
    # Destination Unreachable
    'network-unreachable': [3, 0],
    'host-unreachable': [3, 1],
    'protocol-unreachable': [3, 2],
    'port-unreachable': [3, 3],
    'fragmentation-needed': [3, 4],
    'source-route-failed': [3, 5],
    'network-unknown': [3, 6],
    'host-unknown': [3, 7],
    'network-prohibited': [3, 9],
    'host-prohibited': [3, 10],
    'TOS-network-unreachable': [3, 11],
    'TOS-host-unreachable': [3, 12],
    'communication-prohibited': [3, 13],
    'host-precedence-violation': [3, 14],
    'precedence-cutoff': [3, 15],
    # Redirect
    'network-redirect': [5, 0],
    'host-redirect': [5, 1],
    'TOS-network-redirect': [5, 2],
    'TOS-host-redirect': [5, 3],
    # Time Exceeded
    'ttl-zero-during-transit': [11, 0],
    'ttl-zero-during-reassembly': [11, 1],
    # Parameter Problem
    'ip-header-bad': [12, 0],
    'required-option-missing': [12, 1],
}

# Legacy ICMPv6 `type` values that are simply dropped during migration
icmpv6_remove = []

# Legacy ICMPv6 `type` values and what they are migrated to:
#   str        -> `type-name` value
#   [int, int] -> numeric ICMPv6 [type, code] pair
icmpv6_translations = {
    'ping': 'echo-request',
    'pong': 'echo-reply',
    # Destination Unreachable
    'no-route': [1, 0],
    'communication-prohibited': [1, 1],
    'address-unreachable': [1, 3],
    'port-unreachable': [1, 4],
    # nd
    'redirect': 'nd-redirect',
    'router-solicitation': 'nd-router-solicit',
    'router-advertisement': 'nd-router-advert',
    'neighbour-solicitation': 'nd-neighbor-solicit',
    'neighbor-solicitation': 'nd-neighbor-solicit',
    'neighbour-advertisement': 'nd-neighbor-advert',
    'neighbor-advertisement': 'nd-neighbor-advert',
    # Time Exceeded
    'ttl-zero-during-transit': [3, 0],
    'ttl-zero-during-reassembly': [3, 1],
    # Parameter Problem
    'bad-header': [4, 0],
    'unknown-header-type': [4, 1],
    'unknown-option': [4, 2],
}

# T6071: clamp the description of every firewall group to the CLI limit
group_base = base + ['group']
if config.exists(group_base):
    for group_type in config.list_nodes(group_base):
        type_base = group_base + [group_type]
        for group_name in config.list_nodes(type_base):
            group_description = type_base + [group_name, 'description']
            if not config.exists(group_description):
                continue
            current_text = config.return_value(group_description)
            config.set(group_description,
                       value=current_text[:max_len_description])

# Migrate every IPv4 rule-set found below "firewall name"
if config.exists(base + ['name']):
    for name in config.list_nodes(base + ['name']):
        # T6071: clamp the rule-set description to the CLI limit
        name_description = base + ['name', name, 'description']
        if config.exists(name_description):
            tmp = config.return_value(name_description)
            config.set(name_description, value=tmp[:max_len_description])

        if not config.exists(base + ['name', name, 'rule']):
            continue

        for rule in config.list_nodes(base + ['name', name, 'rule']):
            rule_base = base + ['name', name, 'rule', rule]

            # T6071: clamp the rule description to the CLI limit
            rule_description = rule_base + ['description']
            if config.exists(rule_description):
                tmp = config.return_value(rule_description)
                config.set(rule_description, value=tmp[:max_len_description])

            rule_recent = rule_base + ['recent']
            rule_time = rule_base + ['time']
            rule_tcp_flags = rule_base + ['tcp', 'flags']
            rule_icmp = rule_base + ['icmp']
            proto_base = rule_base + ['protocol']

            # T2199: nodes without an nftables equivalent are dropped
            if config.exists(rule_time + ['monthdays']):
                config.delete(rule_time + ['monthdays'])

            if config.exists(rule_time + ['utc']):
                config.delete(rule_time + ['utc'])

            # The numeric "recent time" value is replaced by a unit keyword:
            # above 600 -> hour, below 10 -> second, anything else -> minute
            if config.exists(rule_recent + ['time']):
                tmp = int(config.return_value(rule_recent + ['time']))
                unit = 'minute'
                if tmp > 600:
                    unit = 'hour'
                elif tmp < 10:
                    unit = 'second'
                config.set(rule_recent + ['time'], value=unit)

            # T4178: the comma separated flag list becomes one node per flag,
            # negated flags ("!FLAG") are placed below "not"
            if config.exists(rule_tcp_flags):
                tmp = config.return_value(rule_tcp_flags)
                config.delete(rule_tcp_flags)
                for flag in tmp.split(','):
                    flag = flag.strip()
                    negate = flag.startswith('!')
                    if negate:
                        flag = flag[1:].strip()
                    if not flag:
                        # Empty token (e.g. trailing comma or a lone "!"):
                        # there is no flag name to create a node for
                        continue
                    if negate:
                        config.set(rule_tcp_flags + ['not', flag.lower()])
                    else:
                        config.set(rule_tcp_flags + [flag.lower()])

            # T2199: translate legacy ICMP type names
            if config.exists(rule_icmp + ['type-name']):
                tmp = config.return_value(rule_icmp + ['type-name'])
                if tmp in icmp_remove:
                    config.delete(rule_icmp + ['type-name'])
                elif tmp in icmp_translations:
                    translate = icmp_translations[tmp]
                    if isinstance(translate, str):
                        # Renamed type-name
                        config.set(rule_icmp + ['type-name'], value=translate)
                    elif isinstance(translate, list):
                        # No named equivalent: use the numeric type/code pair
                        config.delete(rule_icmp + ['type-name'])
                        config.set(rule_icmp + ['type'], value=str(translate[0]))
                        config.set(rule_icmp + ['code'], value=str(translate[1]))

            # A rule matching on a port-group but lacking a protocol gets
            # "tcp_udp" as protocol
            for src_dst in ['destination', 'source']:
                pg_base = rule_base + [src_dst, 'group', 'port-group']
                if config.exists(pg_base) and not config.exists(proto_base):
                    config.set(proto_base, value='tcp_udp')

# Migrate every IPv6 rule-set found below "firewall ipv6-name"
if config.exists(base + ['ipv6-name']):
    for name in config.list_nodes(base + ['ipv6-name']):
        # T6071: clamp the rule-set description to the CLI limit
        name_description = base + ['ipv6-name', name, 'description']
        if config.exists(name_description):
            tmp = config.return_value(name_description)
            config.set(name_description, value=tmp[:max_len_description])

        if not config.exists(base + ['ipv6-name', name, 'rule']):
            continue

        for rule in config.list_nodes(base + ['ipv6-name', name, 'rule']):
            rule_base = base + ['ipv6-name', name, 'rule', rule]

            # T6071: clamp the rule description to the CLI limit
            rule_description = rule_base + ['description']
            if config.exists(rule_description):
                tmp = config.return_value(rule_description)
                config.set(rule_description, value=tmp[:max_len_description])

            rule_recent = rule_base + ['recent']
            rule_time = rule_base + ['time']
            rule_tcp_flags = rule_base + ['tcp', 'flags']
            rule_icmp = rule_base + ['icmpv6']
            proto_base = rule_base + ['protocol']

            # T2199: nodes without an nftables equivalent are dropped
            if config.exists(rule_time + ['monthdays']):
                config.delete(rule_time + ['monthdays'])

            if config.exists(rule_time + ['utc']):
                config.delete(rule_time + ['utc'])

            # The numeric "recent time" value is replaced by a unit keyword:
            # above 600 -> hour, below 10 -> second, anything else -> minute
            if config.exists(rule_recent + ['time']):
                tmp = int(config.return_value(rule_recent + ['time']))
                unit = 'minute'
                if tmp > 600:
                    unit = 'hour'
                elif tmp < 10:
                    unit = 'second'
                config.set(rule_recent + ['time'], value=unit)

            # T4178: the comma separated flag list becomes one node per flag,
            # negated flags ("!FLAG") are placed below "not"
            if config.exists(rule_tcp_flags):
                tmp = config.return_value(rule_tcp_flags)
                config.delete(rule_tcp_flags)
                for flag in tmp.split(','):
                    flag = flag.strip()
                    negate = flag.startswith('!')
                    if negate:
                        flag = flag[1:].strip()
                    if not flag:
                        # Empty token (e.g. trailing comma or a lone "!"):
                        # there is no flag name to create a node for
                        continue
                    if negate:
                        config.set(rule_tcp_flags + ['not', flag.lower()])
                    else:
                        config.set(rule_tcp_flags + [flag.lower()])

            # Protocol "icmpv6" is rewritten to "ipv6-icmp"
            if config.exists(proto_base):
                tmp = config.return_value(proto_base)
                if tmp == 'icmpv6':
                    config.set(proto_base, value='ipv6-icmp')

            # T2199: the legacy "type" node held either "<type>[/<code>]" in
            # numeric form or a type name
            if config.exists(rule_icmp + ['type']):
                tmp = config.return_value(rule_icmp + ['type'])
                type_code_match = re.match(r'^(\d+)(?:/(\d+))?$', tmp)

                if type_code_match:
                    # Numeric form: split into separate type and code nodes
                    config.set(rule_icmp + ['type'], value=type_code_match[1])
                    if type_code_match[2]:
                        config.set(rule_icmp + ['code'], value=type_code_match[2])
                elif tmp in icmpv6_remove:
                    config.delete(rule_icmp + ['type'])
                elif tmp in icmpv6_translations:
                    translate = icmpv6_translations[tmp]
                    if isinstance(translate, str):
                        # Named replacement lives under "type-name"
                        config.delete(rule_icmp + ['type'])
                        config.set(rule_icmp + ['type-name'], value=translate)
                    elif isinstance(translate, list):
                        # No named equivalent: use the numeric type/code pair
                        config.set(rule_icmp + ['type'], value=str(translate[0]))
                        config.set(rule_icmp + ['code'], value=str(translate[1]))
                else:
                    # Any other name is carried over unchanged as "type-name"
                    config.rename(rule_icmp + ['type'], 'type-name')

            # A rule matching on a port-group but lacking a protocol gets
            # "tcp_udp" as protocol
            for src_dst in ['destination', 'source']:
                pg_base = rule_base + [src_dst, 'group', 'port-group']
                if config.exists(pg_base) and not config.exists(proto_base):
                    config.set(proto_base, value='tcp_udp')

# Write the migrated config back to the file it was read from
try:
    # Render before opening: mode 'w' truncates the file immediately, so a
    # failure while rendering must not leave an empty config file behind
    rendered_config = config.to_string()
    with open(file_name, 'w') as f:
        f.write(rendered_config)
except OSError as e:
    print("Failed to save the modified config: {}".format(e))
    exit(1)