#!/usr/bin/env python3
#
# Copyright (C) 2021 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#

import os
import re
import time
import logging
import tempfile
import threading
from sys import argv

from vyos.configtree import ConfigTree
from vyos.defaults import directories
from vyos.util import cmd, boot_configuration_complete
from vyos.migrator import VirtualMigrator

# Directory holding the rescan hints: one file per interface name, written by
# leave_rescan_hint() and read back by add_assigned_interfaces().
vyos_udev_dir = directories['vyos_udev_dir']
# Location of this script's own log file.
vyos_log_dir = '/run/udev/log'
vyos_log_file = os.path.join(vyos_log_dir, 'vyos-net-name')

# Boot configuration file consulted for existing hw-id -> interface mappings.
config_path = '/opt/vyatta/etc/config/config.boot'

# Held while the boot event is processed at the bottom of this script.
# NOTE(review): a threading.Lock only synchronizes threads of one process;
# confirm whether locking across separate invocations was intended.
lock = threading.Lock()

# The log directory has to exist before basicConfig() opens the log file;
# finding it already present is not an error.
try:
    os.mkdir(vyos_log_dir)
except FileExistsError:
    pass

logging.basicConfig(filename=vyos_log_file, level=logging.DEBUG)

def is_available(intfs: dict, intf_name: str) -> bool:
    """ Tell whether 'intf_name' is still free, i.e. not among the values
    of 'intfs'
    """
    return intf_name not in intfs.values()

def find_available(intfs: dict, prefix: str) -> str:
    """ Find an interface name '<prefix><index>' that is not assigned

    Only assigned names made up of exactly 'prefix' followed by decimal
    digits are taken into account. A gap between the lowest and the highest
    assigned index is filled first; without a gap, the index following the
    highest one is used. If no name with this prefix is assigned at all,
    index 0 is used.

    Args:
        intfs: mapping of hardware id to assigned interface name
        prefix: interface name without its trailing index, e.g. 'eth'

    Returns:
        A name that does not occur among the values of 'intfs'.
    """
    # Anchored match instead of a substring test, so that a name which merely
    # contains the prefix (say 'peth0' for 'eth') is neither counted nor
    # handed to int().
    pattern = re.compile(re.escape(prefix) + r'([0-9]+)')
    taken = set()
    for name in intfs.values():
        found = pattern.fullmatch(name)
        if found:
            taken.add(int(found.group(1)))

    if not taken:
        return f'{prefix}0'

    lowest, highest = min(taken), max(taken)
    # find 'holes' in the assigned range, if any
    for index in range(lowest, highest):
        if index not in taken:
            return f'{prefix}{index}'

    # The number of assigned names is only a free index when numbering
    # starts at 0; the successor of the highest index is always free.
    return f'{prefix}{highest + 1}'

def mod_ifname(ifname: str) -> str:
    """ Translate a short name of the form e<N> into eth<M> form

    For N >= 2 the result is eth<N - 2>; for N below 2 the digits are kept
    exactly as given. A name that is not of the form e<N> is returned
    unchanged.
    """
    if not re.match("^e[0-9]+$", ifname):
        return ifname

    # everything following the leading 'e'
    digits = ifname[1:]
    number = int(digits)
    if number < 2:
        return "eth" + str(digits)

    return "eth" + str(number - 2)

def get_biosdevname(ifname: str) -> str:
    """ Ask the legacy biosdevname tool for an interface name

    The name is first normalized with mod_ifname(). The normalized name is
    returned as is when it does not contain 'eth' or when /proc/xen exists;
    it is also the fallback when biosdevname fails or reports nothing.

    Kept for compatibility only, and likely to be dropped going forward.
    XXX: biosdevname throws an error, and likely has for a long time,
    unnoticed since vyatta_net_name redirected stderr to /dev/null.
    """
    intf = mod_ifname(ifname)

    # nothing to look up for non-ethernet names or when running under Xen
    if 'eth' not in intf or os.path.isdir('/proc/xen'):
        return intf

    time.sleep(1)

    try:
        biosname = cmd(f'/sbin/biosdevname --policy all_ethN -i {ifname}')
    except Exception as e:
        logging.error(f'biosdevname error: {e}')
        return intf

    if biosname == '':
        return intf
    return biosname

def leave_rescan_hint(intf_name: str, hwid: str):
    """Record the interface information reported by udev

    This script is called while the root mount is still read-only, so the
    information is left in the udev hint directory: the file is named after
    the interface and contains the hardware id.

    Failing to create the hint directory ends the script with status 1;
    failing to write the hint file is only logged.
    """
    try:
        os.mkdir(vyos_udev_dir)
    except Exception as e:
        # a directory that is already present is fine, anything else is fatal
        if not isinstance(e, FileExistsError):
            logging.critical(f"Error creating rescan hint directory: {e}")
            exit(1)

    hint_path = os.path.join(vyos_udev_dir, intf_name)
    try:
        with open(hint_path, 'w') as f:
            f.write(hwid)
    except OSError as e:
        logging.critical(f"OSError {e}")

def _collect_hw_ids(config, base: list, interfaces: dict) -> None:
    """Add the hw-id -> interface name entries found below config node 'base'
    """
    if not config.exists(base):
        return

    for intf in config.list_nodes(base):
        path = base + [intf, 'hw-id']
        if not config.exists(path):
            logging.warning(f"no 'hw-id' entry for {intf}")
            continue
        hwid = config.return_value(path)
        # first entry wins; later duplicates are reported and skipped
        if hwid in interfaces:
            logging.warning(f"multiple entries for {hwid}: {interfaces[hwid]}, {intf}")
            continue
        interfaces[hwid] = intf

def get_configfile_interfaces() -> dict:
    """Read existing interfaces from config file

    Ethernet and wireless interfaces are considered. If the config file can
    not be parsed as it is, parsing is attempted once more after running
    VirtualMigrator on a temporary copy of it.

    Returns:
        Mapping of hardware id to interface name; empty if the config file
        does not exist.

    The script exits with status 1 if the config file can not be read, or
    can not be parsed in either attempt.
    """
    interfaces: dict = {}

    if not os.path.isfile(config_path):
        # If the case, then we are running off of livecd; return empty
        return interfaces

    try:
        with open(config_path) as f:
            config_file = f.read()
    except OSError as e:
        logging.critical(f"OSError {e}")
        exit(1)

    try:
        config = ConfigTree(config_file)
    except Exception:
        try:
            logging.debug("updating component version string syntax")
            # this will update the component version string syntax,
            # required for updates 1.2 --> 1.3/1.4
            with tempfile.NamedTemporaryFile() as fp:
                with open(fp.name, 'w') as fd:
                    fd.write(config_file)
                virtual_migration = VirtualMigrator(fp.name)
                virtual_migration.run()
                with open(fp.name) as fd:
                    config_file = fd.read()

            config = ConfigTree(config_file)

        except Exception as e:
            logging.critical(f"ConfigTree error: {e}")
            # Without a parsed config there is nothing to look up; carrying
            # on would only fail on the unbound 'config' below. Exit the
            # same way as for an unreadable file.
            exit(1)

    _collect_hw_ids(config, ['interfaces', 'ethernet'], interfaces)
    _collect_hw_ids(config, ['interfaces', 'wireless'], interfaces)

    logging.debug(f"config file entries: {interfaces}")

    return interfaces

def add_assigned_interfaces(intfs: dict):
    """Merge the names recorded in the udev hint directory into 'intfs'

    Each file in the hint directory stands for one assignment: the file name
    is the interface name, the file content (trailing whitespace removed) is
    the hardware id. 'intfs' is updated in place, replacing an entry that
    already exists for the same hardware id. Files that can not be read are
    logged and skipped; a missing hint directory leaves 'intfs' untouched.
    """
    if os.path.isdir(vyos_udev_dir):
        for intf in os.listdir(vyos_udev_dir):
            try:
                with open(os.path.join(vyos_udev_dir, intf)) as f:
                    content = f.read()
            except OSError as e:
                logging.error(f"OSError {e}")
            else:
                intfs[content.rstrip()] = intf

def on_boot_event(intf_name: str, hwid: str, predefined: str = '') -> str:
    """Choose the name of an interface seen on boot ('coldplug')

    A hardware id listed in the config file keeps its configured name.
    Otherwise the candidate is 'predefined' if given, else the result of
    get_biosdevname(); should the candidate be taken already, an available
    name with the same prefix is looked up instead. A name chosen this way
    is left as a rescan hint before it is returned.
    """
    logging.info(f"lookup {intf_name}, {hwid}")
    interfaces = get_configfile_interfaces()
    logging.debug(f"config file interfaces are {interfaces}")

    # a mapping from the config file always wins
    if hwid in interfaces:
        logging.info(f"use mapping from config file: '{hwid}' -> '{interfaces[hwid]}'")
        return interfaces[hwid]

    # names recorded by earlier invocations count as taken as well
    add_assigned_interfaces(interfaces)
    logging.debug(f"adding assigned interfaces: {interfaces}")

    if not predefined:
        newname = get_biosdevname(intf_name)
        logging.info(f"biosdevname returned '{newname}' for '{intf_name}'")
    else:
        newname = predefined
        logging.info(f"predefined interface name for '{intf_name}' is '{newname}'")

    # on a clash, strip the trailing index and search within the same prefix
    if newname in interfaces.values():
        newname = find_available(interfaces, re.sub(r'\d+$', '', newname))

    logging.info(f"new name for '{intf_name}' is '{newname}'")

    leave_rescan_hint(newname, hwid)

    return newname

def hotplug_event():
    """Placeholder for handling hotplug events

    Not yet implemented, since interface-rescan will only be run on boot.
    """
    return None

# Script arguments: interface name, hardware id and, optionally, a name
# chosen in advance for the interface.
predef_name = argv[3] if len(argv) > 3 else ''

# 'with' releases the lock on every way out of the block, including an
# exception or an exit() raised while the boot event is handled.
with lock:
    if not boot_configuration_complete():
        if len(argv) < 3:
            # log the problem instead of dying with a bare IndexError
            logging.critical(f"expected interface name and hardware id, got: {argv[1:]}")
            exit(1)
        res = on_boot_event(argv[1], argv[2], predefined=predef_name)
        logging.debug(f"on boot, returned name is {res}")
        # the chosen name is the script's only output on stdout
        print(res)
    else:
        logging.debug("boot configuration complete")