#!/usr/bin/env python3
#
# Copyright (C) 2020-2023 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import sys
import grp
import re
import json
import logging
import signal
import importlib.util
import zmq
from contextlib import contextmanager

from vyos.defaults import directories
from vyos.utils.boot import boot_configuration_complete
from vyos.configsource import ConfigSourceString
from vyos.configsource import ConfigSourceError
from vyos.config import Config
from vyos import ConfigError

# Group the daemon switches to (setgid) at start-up.
CFG_GROUP = 'vyattacfg'

# Fallback file for script output when there is no live config session
# (for example during boot); opened in append mode.
script_stdout_log = '/tmp/vyos-configd-script-stdout'

debug = True

# Module logger with a StreamHandler (stderr); its verbosity follows `debug`.
logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)

# ZeroMQ endpoint the daemon binds its REP socket to.
SOCKET_PATH = "ipc:///run/vyos-configd.sock"

# Response error codes: one byte sent back to the client per "node" request,
# each a distinct bit value.
R_SUCCESS = 1
R_ERROR_COMMIT = 2
R_ERROR_DAEMON = 4
R_PASS = 8

# Directory scanned for conf_mode scripts to import.
vyos_conf_scripts_dir = directories['conf_mode']

# Data files: the opt-in list of scripts run in-process, and the two env
# variants that configd_env_file is symlinked to (set on start, unset on exit).
configd_include_file = os.path.join(directories['data'], 'configd-include.json')
configd_env_set_file = os.path.join(directories['data'], 'vyos-configd-env-set')
configd_env_unset_file = os.path.join(directories['data'], 'vyos-configd-env-unset')
# sourced on entering config session
configd_env_file = '/etc/default/vyos-configd-env'

# Output target for the current session's scripts and its open() mode;
# both are assigned by initialization().
session_out, session_mode = None, None

def key_name_from_file_name(f):
    """Return file name *f* without its last extension ('a-b.py' -> 'a-b')."""
    stem, _suffix = os.path.splitext(f)
    return stem

def module_name_from_key(k):
    """Turn a script key into an importable module name by mapping '-' to '_'."""
    return '_'.join(k.split('-'))

def path_from_file_name(f):
    """Return the path of conf_mode script file *f* inside vyos_conf_scripts_dir."""
    script_path = os.path.join(vyos_conf_scripts_dir, f)
    return script_path

# opt-in to be run by daemon: the include file lists the conf_mode script
# file names that are imported and run in-process.
# open() must be inside the try: otherwise a missing or unreadable include
# file escapes as an unhandled OSError instead of being logged and exiting 1.
try:
    with open(configd_include_file) as f:
        include = json.load(f)
except OSError as e:
    logger.critical(f"configd include file error: {e}")
    sys.exit(1)
except json.JSONDecodeError as e:
    logger.critical(f"JSON load error: {e}")
    sys.exit(1)

# import conf_mode scripts
# Only the first os.walk() entry (the top level of vyos_conf_scripts_dir) is
# used, so scripts in subdirectories are not considered.
# NOTE(review): if the directory is missing, os.walk() yields nothing and
# next() raises StopIteration at start-up.
(_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir)))
# Sorted so scripts are imported in a deterministic (alphabetical) order.
filenames.sort()

# Only scripts listed in the include file are loaded. For each one derive
# its key (file name without extension), module name ('-' -> '_') and path.
load_filenames = [f for f in filenames if f in include]
imports = [key_name_from_file_name(f) for f in load_filenames]
module_names = [module_name_from_key(k) for k in imports]
paths = [path_from_file_name(f) for f in load_filenames]
to_load = list(zip(module_names, paths))

modules = []

# Import each script directly from its file path. module_from_spec() plus
# exec_module() do not register the module in sys.modules; executing it runs
# the script's top-level code once, here at daemon start-up. Any exception
# raised by a script propagates and aborts start-up.
for x in to_load:
    spec = importlib.util.spec_from_file_location(x[0], x[1])
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    modules.append(module)

# Script key (e.g. 'foo-bar' for 'foo-bar.py') -> loaded module.
conf_mode_scripts = dict(zip(imports, modules))

# Keys of the scripts in the conf_mode directory that are not / are listed
# in the include file, i.e. not handled / handled by the daemon.
exclude_set = {key_name_from_file_name(f) for f in filenames if f not in include}
include_set = {key_name_from_file_name(f) for f in filenames if f in include}

@contextmanager
def stdout_redirected(filename, mode):
    """Temporarily point the process's stdout file descriptor at *filename*.

    The redirection is done with dup/dup2 on ``sys.stdout.fileno()``, so it
    covers Python-level print() as well as anything writing to that fd.

    Args:
        filename: path to write to while the context is active.
        mode: open() mode for *filename* (e.g. 'w' or 'a').

    On exit, stdout is flushed while still redirected, then the original
    descriptor is restored and *filename* is closed.
    """
    saved_stdout_fd = None
    destination_file = None
    try:
        # Flush pending output so it still goes to the original target.
        sys.stdout.flush()
        saved_stdout_fd = os.dup(sys.stdout.fileno())
        destination_file = open(filename, mode)
        os.dup2(destination_file.fileno(), sys.stdout.fileno())
        yield
    finally:
        if saved_stdout_fd is not None:
            # Flush what was printed inside the block while the fd still
            # points at *filename*; otherwise it would stay in sys.stdout's
            # buffer and later be written to the restored (original) stdout.
            try:
                sys.stdout.flush()
            except OSError:
                # best effort: a failed flush must not prevent restoring stdout
                logger.debug("stdout flush failed before restoring stdout")
            os.dup2(saved_stdout_fd, sys.stdout.fileno())
            os.close(saved_stdout_fd)
        if destination_file is not None:
            destination_file.close()

def explicit_print(path, mode, msg):
    """Write *msg* to *path*, preceded by a newline and followed by two.

    *mode* is handed to open() unchanged. An OSError while opening or
    writing is logged rather than raised.
    """
    payload = f"\n{msg}\n\n"
    try:
        with open(path, mode) as out:
            out.write(payload)
    except OSError:
        logger.critical("error explicit_print")

def run_script(script, config, args) -> int:
    """Run one conf_mode script's get_config/verify/generate/apply sequence.

    Args:
        script: the imported conf_mode module to run.
        config: Config for the current session; its level is reset to the
            root before the script sees it.
        args: argument vector stored on the module as ``script.argv``.

    Returns:
        R_SUCCESS on success; R_ERROR_COMMIT if the script raises ConfigError
        (the message is also written to the session output); R_ERROR_DAEMON
        for any other exception.
    """
    script.argv = args
    config.set_level([])
    try:
        c = script.get_config(config)
        script.verify(c)
        script.generate(c)
        script.apply(c)
    except ConfigError as e:
        # Expected validation/commit failure: show the message to the user.
        logger.critical(e)
        explicit_print(session_out, session_mode, str(e))
        return R_ERROR_COMMIT
    except Exception as e:
        # Unexpected failure inside the script: include the traceback so the
        # failing line can be found from the daemon log.
        logger.critical(e, exc_info=True)
        return R_ERROR_DAEMON

    return R_SUCCESS

def initialization(socket):
    """Run the session handshake with the client and build the session Config.

    Called after the caller has already acknowledged an "init" message. Each
    message below is answered before the next one is read:

    1. re-sent JSON ``{"type": "init"}`` messages (client timeout) are
       acknowledged with "init";
    2. the first message that is not a JSON object with a "type" key is the
       running (active) config text -> reply "active";
    3. the session config text -> reply "session";
    4. the PID of the config session process -> reply "pid".

    Side effects: sets the globals ``session_out``/``session_mode`` to the
    session's ``/proc/<pid>/fd/1`` target in mode 'w', or to
    ``script_stdout_log`` in mode 'a' when that link cannot be read or boot
    configuration is not complete.

    Returns:
        A Config built from the active and session config texts, or None if
        ConfigSourceString rejects them (ConfigSourceError).
    """
    global session_out
    global session_mode

    # check first for resent init msg, in case of client timeout.
    # Only parse/shape errors end the loop: a bare `except:` would also
    # swallow the SystemExit raised by the SIGTERM/SIGINT handler.
    while True:
        msg = socket.recv().decode("utf-8", "ignore")
        try:
            message = json.loads(msg)
            is_init = message["type"] == "init"
        except (ValueError, KeyError, TypeError):
            break
        if is_init:
            resp = "init"
            socket.send(resp.encode())
        # NOTE(review): a JSON message whose "type" is not "init" is neither
        # answered nor used as config; the following recv() on the REP socket
        # will then fail. Kept unchanged from the original protocol handling.

    # zmq synchronous for ipc from single client:
    active_string = msg
    resp = "active"
    socket.send(resp.encode())
    session_string = socket.recv().decode("utf-8", "ignore")
    resp = "session"
    socket.send(resp.encode())
    pid_string = socket.recv().decode("utf-8", "ignore")
    resp = "pid"
    socket.send(resp.encode())

    logger.debug(f"config session pid is {pid_string}")
    try:
        session_out = os.readlink(f"/proc/{pid_string}/fd/1")
        session_mode = 'w'
    except (OSError, ValueError):
        # Missing process, not a link, no permission, or an unusable pid
        # string (e.g. embedded NUL): fall back to the log file below.
        session_out = None

    # if not a 'live' session, for example on boot, write to file
    if not session_out or not boot_configuration_complete():
        session_out = script_stdout_log
        session_mode = 'a'

    try:
        configsource = ConfigSourceString(running_config_text=active_string,
                                          session_config_text=session_string)
    except ConfigSourceError as e:
        logger.debug(e)
        return None

    config = Config(config_source=configsource)

    return config

def process_node_data(config, data) -> int:
    """Handle a "node" request by running one conf_mode script in-process.

    *data* is the script invocation string sent by the client: an optional
    ``VYOS_TAGNODE_VALUE=<value>`` prefix, a path ending in ``<script>.py``,
    and optional whitespace-separated arguments.

    A tag node value, if present, is exported to ``os.environ``. If it is
    absent, any value left over from an earlier request is removed, so a
    script never sees another invocation's tag node. The script receives
    ``['<script>.py', *args]`` as its argv via run_script().

    Returns:
        R_ERROR_DAEMON if there is no config or *data* cannot be parsed;
        R_PASS if the script is not in the daemon's include set (presumably
        the client then runs it itself -- TODO confirm); otherwise the
        result of run_script().
    """
    if not config:
        logger.critical(f"Empty config")
        return R_ERROR_DAEMON

    res = None
    if isinstance(data, str):
        res = re.match(r'^(VYOS_TAGNODE_VALUE=[^/]+)?.*\/([^/]+)\.py(.*)', data)
    if res is None:
        # Unparseable request: report it instead of crashing on res.group().
        logger.critical(f"Missing script_name")
        return R_ERROR_DAEMON

    tagnode = res.group(1)
    if tagnode:
        # partition keeps any '=' that is part of the value itself
        name, _, value = tagnode.partition('=')
        os.environ[name] = value
    else:
        # do not leak a previous request's tag node into this script
        os.environ.pop('VYOS_TAGNODE_VALUE', None)

    script_name = res.group(2)
    args = res.group(3).split()
    args.insert(0, f'{script_name}.py')

    if script_name not in include_set:
        return R_PASS

    with stdout_redirected(session_out, session_mode):
        result = run_script(conf_mode_scripts[script_name], config, args)

    return result

def remove_if_file(f: str):
    """Delete the file at *f*; a file that does not exist is not an error.

    Any other OSError (permission denied, *f* is a directory, ...) is left
    to propagate to the caller.
    """
    try:
        os.remove(f)
    except FileNotFoundError:
        return

def shutdown():
    """Repoint configd_env_file at the 'unset' env file, then exit with status 0."""
    remove_if_file(configd_env_file)
    os.symlink(configd_env_unset_file, configd_env_file)
    raise SystemExit(0)

if __name__ == '__main__':
    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Bind with umask 0 so the socket file gets permissive mode bits, then
    # restore the previous umask.
    o_mask = os.umask(0)
    socket.bind(SOCKET_PATH)
    os.umask(o_mask)

    # Switch to the config group and export SUDO_USER/SUDO_GID for the
    # scripts (presumably mirroring the sudo environment of the non-daemon
    # path -- TODO confirm).
    cfg_group = grp.getgrnam(CFG_GROUP)
    os.setgid(cfg_group.gr_gid)

    os.environ['SUDO_USER'] = 'vyos'
    os.environ['SUDO_GID'] = str(cfg_group.gr_gid)

    def sig_handler(signum, frame):
        shutdown()

    signal.signal(signal.SIGTERM, sig_handler)
    signal.signal(signal.SIGINT, sig_handler)

    # Define the vyshim environment variable
    remove_if_file(configd_env_file)
    os.symlink(configd_env_set_file, configd_env_file)

    def send_code(code):
        """Answer the current request with a one-byte response code."""
        logger.debug(f"Sending response {code}")
        socket.send(code.to_bytes(1, byteorder=sys.byteorder))

    config = None

    while True:
        #  Wait for next request from client
        raw = socket.recv()
        # A REP socket must reply to every request before it can receive the
        # next one, so malformed or unexpected messages get an error reply
        # rather than no reply (which would make the next recv() fail).
        try:
            msg = raw.decode()
            logger.debug(f"Received message: {msg}")
            message = json.loads(msg)
            msg_type = message["type"]
        except (ValueError, KeyError, TypeError) as e:
            logger.critical(f"Malformed message: {raw!r}: {e}")
            send_code(R_ERROR_DAEMON)
            continue

        if msg_type == "init":
            resp = "init"
            socket.send(resp.encode())
            config = initialization(socket)
        elif msg_type == "node":
            data = message.get("data")
            if not isinstance(data, str):
                logger.critical(f"Unexpected message: {message}")
                send_code(R_ERROR_DAEMON)
                continue
            res = process_node_data(config, data)
            send_code(res)
        else:
            logger.critical(f"Unexpected message: {message}")
            send_code(R_ERROR_DAEMON)