#!/usr/bin/env python3
#
# Copyright (C) 2020-2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# pylint: disable=redefined-outer-name

import os
import sys
import grp
import re
import json
import typing
import logging
import signal
import traceback
import importlib.util
import io
from contextlib import redirect_stdout

import zmq

from vyos.defaults import directories
from vyos.utils.boot import boot_configuration_complete
from vyos.configsource import ConfigSourceString
from vyos.configsource import ConfigSourceError
from vyos.configdiff import get_commit_scripts
from vyos.config import Config
from vyos.frrender import FRRender
from vyos.frrender import get_frrender_dict
from vyos import ConfigError

# Group whose gid the daemon switches to after binding its socket
CFG_GROUP = 'vyattacfg'

# File that send_result() appends reply text to (via write_stdout_log())
script_stdout_log = '/tmp/vyos-configd-script-stdout'

# Hard-wired verbosity switch: True selects DEBUG logging, False selects INFO
debug = True

logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)

# ZeroMQ endpoint the daemon binds to
SOCKET_PATH = 'ipc:///run/vyos-configd.sock'
# Upper bound on the message size reported in a reply
MAX_MSG_SIZE = 65535
# Width of the hex-formatted size field in a reply, '0x' prefix included
PAD_MSG_SIZE = 6

# Response error codes
R_SUCCESS = 1
R_ERROR_COMMIT = 2
R_ERROR_DAEMON = 4
R_PASS = 8

# Directory holding the conf_mode scripts
vyos_conf_scripts_dir = directories['conf_mode']
# JSON file listing the scripts that opt in to being run by the daemon
configd_include_file = os.path.join(directories['data'], 'configd-include.json')
# Targets for the configd_env_file symlink: 'set' is linked at daemon start,
# 'unset' at shutdown
configd_env_set_file = os.path.join(directories['data'], 'vyos-configd-env-set')
configd_env_unset_file = os.path.join(directories['data'], 'vyos-configd-env-unset')
# sourced on entering config session
configd_env_file = '/etc/default/vyos-configd-env'

def key_name_from_file_name(f):
    """Return file name *f* with its last extension removed."""
    stem, _extension = os.path.splitext(f)
    return stem

def module_name_from_key(k):
    """Return key *k* with every '-' replaced by '_'."""
    return '_'.join(k.split('-'))

def path_from_file_name(f):
    """Return the path of file name *f* inside the conf_mode scripts directory."""
    script_path = os.path.join(vyos_conf_scripts_dir, f)
    return script_path


# opt-in to be run by daemon
# The open() call is inside the try block so that a missing or unreadable
# include file is reported by the OSError handler below rather than escaping
# as an uncaught traceback.
try:
    with open(configd_include_file) as f:
        include = json.load(f)
except OSError as e:
    logger.critical(f'configd include file error: {e}')
    sys.exit(1)
except json.JSONDecodeError as e:
    logger.critical(f'JSON load error: {e}')
    sys.exit(1)


# import conf_mode scripts
# Only the files directly inside the conf_mode directory are considered:
# the first os.walk() entry describes the top-level directory itself.
_, _, filenames = next(os.walk(vyos_conf_scripts_dir))
filenames.sort()

# Files that appear in the include list, and the names/paths derived from them
load_filenames = [name for name in filenames if name in include]
imports = list(map(key_name_from_file_name, load_filenames))
module_names = list(map(module_name_from_key, imports))
paths = list(map(path_from_file_name, load_filenames))
to_load = list(zip(module_names, paths))

# Load every opted-in script as a module, in sorted file-name order
modules = []
for module_name, module_path in to_load:
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    modules.append(module)

# Maps script key (file name without extension) to its loaded module
conf_mode_scripts = dict(zip(imports, modules))

# Script keys partitioned by whether the file is named in the include list
include_set = set(imports)
exclude_set = {
    key_name_from_file_name(name) for name in filenames if name not in include
}


def write_stdout_log(file_name, msg):
    """Append *msg* to *file_name* while boot configuration is incomplete.

    Nothing is written once boot_configuration_complete() reports true.
    """
    if not boot_configuration_complete():
        with open(file_name, 'a') as log_file:
            log_file.write(msg)


def run_script(script_name, config, args) -> tuple[int, str]:
    """Run the four stages of one loaded conf_mode script.

    The module registered under *script_name* has its ``argv`` attribute set
    to *args*, the config level is reset to the root, and then get_config,
    verify, generate and apply are called in that order.

    Returns:
        (R_SUCCESS, '') when all stages complete,
        (R_ERROR_COMMIT, message) when a stage raises ConfigError,
        (R_ERROR_COMMIT, traceback text) when a stage raises any other
        Exception.

    Raises:
        KeyError: if *script_name* is not a loaded conf_mode script.
    """
    # pylint: disable=broad-exception-caught

    module = conf_mode_scripts[script_name]
    module.argv = args
    config.set_level([])

    try:
        conf = module.get_config(config)
        module.verify(conf)
        module.generate(conf)
        module.apply(conf)
    except ConfigError as err:
        logger.error(err)
        return R_ERROR_COMMIT, str(err)
    except Exception:
        # Any other failure is reported to the caller as a commit error,
        # carrying the full traceback as its message
        trace = traceback.format_exc()
        logger.error(trace)
        return R_ERROR_COMMIT, trace

    return R_SUCCESS, ''


def initialization(socket):
    """Receive the session data from the client and build a Config from it.

    After the initial 'init' exchange the client is expected to send, one
    message at a time: the active config, the session config, the session
    pid, the sudo user, the temp config dir and the changes-only dir. Each
    message is acknowledged with a fixed label before the next one is read.

    Returns:
        A Config built from the received active/session strings, with the
        attributes ``dependent_func``, ``scripts_called`` and
        ``frrender_cls`` attached, or None if the config source cannot be
        constructed.
    """
    # pylint: disable=broad-exception-caught,too-many-locals

    def recv_text():
        # Undecodable bytes are dropped rather than raising
        return socket.recv().decode('utf-8', 'ignore')

    def ack(label):
        socket.send(label.encode())

    # check first for resent init msg, in case of client timeout
    while True:
        first = recv_text()
        try:
            if json.loads(first)['type'] == 'init':
                ack('init')
        except Exception:
            # Anything that is not a JSON message with a 'type' field is
            # taken to be the active config string
            break

    # zmq synchronous for ipc from single client:
    active_string = first
    ack('active')

    received = {}
    for label in ('session', 'pid', 'sudo_user',
                  'temp_config_dir', 'changes_only_dir'):
        received[label] = recv_text()
        ack(label)

    session_string = received['session']
    pid_string = received['pid']
    sudo_user_string = received['sudo_user']

    logger.debug(f'config session pid is {pid_string}')
    logger.debug(f'config session sudo_user is {sudo_user_string}')

    os.environ['SUDO_USER'] = sudo_user_string
    # The directory variables are only exported when the client sent a value
    optional_env = (
        ('VYATTA_TEMP_CONFIG_DIR', received['temp_config_dir']),
        ('VYATTA_CHANGES_ONLY_DIR', received['changes_only_dir']),
    )
    for env_name, env_value in optional_env:
        if env_value:
            os.environ[env_name] = env_value

    try:
        configsource = ConfigSourceString(running_config_text=active_string,
                                          session_config_text=session_string)
    except ConfigSourceError as err:
        logger.debug(err)
        return None

    config = Config(config_source=configsource)

    dependent_func: dict[str, list[typing.Callable]] = {}
    setattr(config, 'dependent_func', dependent_func)

    commit_scripts = get_commit_scripts(config)
    logger.debug(f'commit_scripts: {commit_scripts}')

    # Filled in by process_node_data() with one record per script request
    setattr(config, 'scripts_called', [])

    if not hasattr(config, 'frrender_cls'):
        setattr(config, 'frrender_cls', FRRender())

    return config


def process_node_data(config, data, _last: bool = False) -> tuple[int, str]:
    """Run the conf_mode script named in the data of a 'node' request.

    *data* is expected to look like
    ``[VYOS_TAGNODE_VALUE=<value>]<path>/<script>.py[ <args>]``. The optional
    tag node value is exported as the VYOS_TAGNODE_VALUE environment variable
    (it is reset to '' on every call), the script name is recorded in
    ``config.scripts_called``, and the script is run through run_script() if
    it is one of the scripts loaded by the daemon.

    Args:
        config: Config object from initialization(); a falsy value is
            reported as a daemon error.
        data: request string naming the script to run.
        _last: unused here; accepted for the caller's convenience.

    Returns:
        (R_ERROR_DAEMON, message) if there is no config or *data* does not
        name a script,
        (R_PASS, '') if the script is not in the daemon's include set,
        otherwise the run_script() result code paired with the script's
        captured stdout followed by its error text.
    """
    if not config:
        out = 'Empty config'
        logger.critical(out)
        return R_ERROR_DAEMON, out

    os.environ['VYOS_TAGNODE_VALUE'] = ''
    config.dependency_list.clear()

    # The dot before 'py' is escaped: unescaped, it matches any character, so
    # an argument such as '--copy' would be absorbed into the script name.
    res = re.match(r'^(VYOS_TAGNODE_VALUE=[^/]+)?.*\/([^/]+)\.py(.*)', data)
    if res is None:
        # No '<dir>/<name>.py' in the request: report it to the client
        # instead of failing on res.group() and taking the daemon down.
        out = 'Missing script_name'
        logger.critical(out)
        return R_ERROR_DAEMON, out

    if res.group(1):
        # Split on the first '=' only, so a value containing '=' stays intact
        env_name, _, env_value = res.group(1).partition('=')
        os.environ[env_name] = env_value

    script_name = res.group(2)
    args = [f'{script_name}.py', *res.group(3).split()]

    tag_value = os.getenv('VYOS_TAGNODE_VALUE', '')
    tag_ext = f'_{tag_value}' if tag_value else ''
    script_record = f'{script_name}{tag_ext}'
    scripts_called = getattr(config, 'scripts_called', [])
    scripts_called.append(script_record)

    if script_name not in include_set:
        return R_PASS, ''

    # Capture whatever the script prints so it can be returned to the client
    with redirect_stdout(io.StringIO()) as captured:
        result, err_out = run_script(script_name, config, args)
        amb_out = captured.getvalue()

    return result, amb_out + err_out


def send_result(sock, err, msg):
    """Send a three-part reply: error code, message size, message text.

    Args:
        sock: socket providing send_multipart().
        err: integer response code; must fit in one byte.
        msg: reply text; None or '' is sent as an empty message.

    The size field is a hex string padded to PAD_MSG_SIZE characters and
    capped at MAX_MSG_SIZE. It is computed from the UTF-8 encoded payload, so
    that it counts the bytes actually sent rather than the characters in
    *msg* (the two differ for non-ASCII output). The payload itself is sent
    in full even when its size exceeds the cap.

    The message text is also passed to write_stdout_log().
    """
    msg = msg if msg else ''
    payload = msg.encode()
    msg_size = min(MAX_MSG_SIZE, len(payload))

    # Explicit byteorder: the argument only became optional in Python 3.11,
    # and is irrelevant for a single byte
    err_rep = err.to_bytes(1, 'big')
    msg_size_rep = f'{msg_size:#0{PAD_MSG_SIZE}x}'

    logger.debug(f'Sending reply: error_code {err} with output')
    sock.send_multipart([err_rep, msg_size_rep.encode(), payload])

    write_stdout_log(script_stdout_log, msg)


def remove_if_file(f: str):
    """Delete path *f*; a path that does not exist is silently ignored.

    Other OS errors (for example, permission problems) propagate.
    """
    try:
        os.unlink(f)
    except FileNotFoundError:
        return


def shutdown():
    """Point the configd env file at the 'unset' variant, then exit with 0."""
    env_link = configd_env_file
    # Replace whatever is currently at the env file path with a fresh symlink
    remove_if_file(env_link)
    os.symlink(configd_env_unset_file, env_link)
    sys.exit(0)


if __name__ == '__main__':
    # Daemon entry point: bind the reply socket, switch to the config group,
    # install signal handlers, publish the env file symlink, then serve
    # 'init' and 'node' requests in an endless loop.
    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Set the right permissions on the socket, then change it back
    o_mask = os.umask(0)
    socket.bind(SOCKET_PATH)
    os.umask(o_mask)

    # Switch the process gid to that of the configuration group
    cfg_group = grp.getgrnam(CFG_GROUP)
    os.setgid(cfg_group.gr_gid)

    # Set in this process's environment; presumably lets code run by the
    # daemon detect that it is running under vyos-configd — TODO confirm
    os.environ['VYOS_CONFIGD'] = 't'

    def sig_handler(signum, frame):
        # pylint: disable=unused-argument
        # shutdown() re-points the env file symlink and exits the process
        shutdown()

    signal.signal(signal.SIGTERM, sig_handler)
    signal.signal(signal.SIGINT, sig_handler)

    # Define the vyshim environment variable
    remove_if_file(configd_env_file)
    os.symlink(configd_env_set_file, configd_env_file)

    # Config for the current session; stays None until an 'init' request has
    # been handled successfully
    config = None

    while True:
        # Wait for next request from client
        msg = socket.recv().decode()
        logger.debug(f'Received message: {msg}')
        # NOTE(review): a message that is not valid JSON, or lacks 'type',
        # raises here and is not caught in this loop
        message = json.loads(msg)

        if message['type'] == 'init':
            # Acknowledge, then run the multi-message session handshake
            resp = 'init'
            socket.send(resp.encode())
            config = initialization(socket)
        elif message['type'] == 'node':
            res, out = process_node_data(config, message['data'], message['last'])
            send_result(socket, res, out)

            # The reply has already been sent at this point; the steps below
            # run only for the request flagged as the last one
            if message['last'] and config:
                scripts_called = getattr(config, 'scripts_called', [])
                logger.debug(f'scripts_called: {scripts_called}')

                # FRR configuration is rendered and applied only when the
                # last script reported success
                # NOTE(review): an exception raised here is not caught in
                # this loop — confirm that is intended
                if hasattr(config, 'frrender_cls') and res == R_SUCCESS:
                    frrender_cls = getattr(config, 'frrender_cls')
                    tmp = get_frrender_dict(config)
                    frrender_cls.generate(tmp)
                    frrender_cls.apply()
        else:
            # NOTE(review): no reply is sent on this path; a REP socket
            # normally has to answer before it can receive again — confirm
            # how the client and the next recv() behave here
            logger.critical(f'Unexpected message: {message}')