#!/usr/bin/env python3
#
# Copyright (C) 2020 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#

import os
import sys
import grp
import re
import json
import logging
import signal
import importlib.util
import zmq
from contextlib import redirect_stdout, redirect_stderr

from vyos.defaults import directories
from vyos.configsource import ConfigSourceString, ConfigSourceError
from vyos.config import Config
from vyos import ConfigError

# Group the daemon switches to (os.setgid) before serving requests.
CFG_GROUP = 'vyattacfg'

# Module logger verbosity: DEBUG when True, INFO otherwise.
debug = True

logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)

# ZeroMQ endpoint the daemon's REP socket binds to.
SOCKET_PATH = "ipc:///run/vyos-configd.sock"

# Response codes, sent back to the client as a single byte for "node" requests
R_SUCCESS = 1
R_ERROR_COMMIT = 2
R_ERROR_DAEMON = 4
R_PASS = 8

vyos_conf_scripts_dir = directories['conf_mode']
configd_include_file = os.path.join(directories['data'], 'configd-include.json')
configd_env_set_file = os.path.join(directories['data'], 'vyos-configd-env-set')
configd_env_unset_file = os.path.join(directories['data'], 'vyos-configd-env-unset')
# sourced on entering config session
configd_env_file = '/etc/default/vyos-configd-env'

# Output path of the current config session; filled in by initialization()
# and used by run_script() to redirect script output.
session_tty = None

def key_name_from_file_name(f):
    """Return file name *f* without its final extension ('a-b.py' -> 'a-b')."""
    stem, _extension = os.path.splitext(f)
    return stem

def module_name_from_key(k):
    """Map a script key to a module name by turning every '-' into '_'."""
    return '_'.join(k.split('-'))

def path_from_file_name(f):
    """Return the path of script file *f* inside the conf_mode scripts directory."""
    scripts_dir = vyos_conf_scripts_dir
    return os.path.join(scripts_dir, f)

# opt-in to be run by daemon: configd-include.json holds the file names of
# the conf_mode scripts this daemon imports and executes in-process.
# open() must be inside the try so a missing/unreadable file is reported
# through the OSError handler instead of escaping as a raw traceback.
try:
    with open(configd_include_file) as f:
        include = json.load(f)
except OSError as e:
    logger.critical(f"configd include file error: {e}")
    sys.exit(1)
except json.JSONDecodeError as e:
    logger.critical(f"JSON load error: {e}")
    sys.exit(1)

# import conf_mode scripts (top level of the directory only)
try:
    (_, _, filenames) = next(os.walk(vyos_conf_scripts_dir))
except StopIteration:
    # os.walk() yields nothing when the directory is missing or unreadable
    logger.critical(f"conf_mode directory not readable: {vyos_conf_scripts_dir}")
    sys.exit(1)
filenames.sort()

load_filenames = [f for f in filenames if f in include]
imports = [key_name_from_file_name(f) for f in load_filenames]
module_names = [module_name_from_key(k) for k in imports]
paths = [path_from_file_name(f) for f in load_filenames]
to_load = list(zip(module_names, paths))

modules = []

for x in to_load:
    spec = importlib.util.spec_from_file_location(x[0], x[1])
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    modules.append(module)

# key name (file name without extension) -> imported module
conf_mode_scripts = dict(zip(imports, modules))

exclude_set = {key_name_from_file_name(f) for f in filenames if f not in include}
include_set = {key_name_from_file_name(f) for f in filenames if f in include}


def run_script(script, config) -> int:
    """Run the get_config/verify/generate/apply sequence of a conf_mode script.

    While the script runs, its stdout and stderr are redirected to
    ``session_tty``, the output path recorded for the calling config session
    by initialization().

    :param script: an imported conf_mode module
    :param config: the Config object built for the session
    :return: R_SUCCESS when all four steps complete; R_ERROR_COMMIT when the
        script raises ConfigError (its message is also written to the
        session output); R_ERROR_DAEMON for any other failure, including
        when no session output path is known.
    """
    config.set_level([])

    if session_tty is None:
        # Nowhere to send the script's output: report a daemon error rather
        # than running the script.
        logger.critical("No session output available; script not run")
        return R_ERROR_DAEMON

    try:
        with open(session_tty, 'w') as out:
            with redirect_stdout(out), redirect_stderr(out):
                c = script.get_config(config)
                script.verify(c)
                script.generate(c)
                script.apply(c)
    except ConfigError as e:
        logger.critical(e)
        # Reporting to the session can itself fail (e.g. the terminal went
        # away); that must not escape and take down the request loop.
        try:
            with open(session_tty, 'w') as out:
                print(f"{e}\n", file=out)
        except OSError as write_err:
            logger.critical(f"Unable to write error to {session_tty}: {write_err}")
        return R_ERROR_COMMIT
    except Exception as e:
        # Keep the traceback; the bare message rarely locates the fault.
        logger.critical(e, exc_info=True)
        return R_ERROR_DAEMON

    return R_SUCCESS

def _is_init_message(msg: str) -> bool:
    """Return True if *msg* is a JSON object whose "type" is "init"."""
    try:
        message = json.loads(msg)
    except json.JSONDecodeError:
        return False
    return isinstance(message, dict) and message.get("type") == "init"


def initialization(socket):
    """Complete the handshake that follows an "init" request.

    The client then sends, in order, the running config text, the session
    config text and the config session's pid; they are acknowledged with
    "active", "session" and "pid" respectively. Repeated "init" messages
    received first (a client retrying after a timeout) are acknowledged
    with "init" and skipped.

    Side effect: sets the global ``session_tty`` to the target of
    ``/proc/<pid>/fd/1``, or None when it cannot be determined.

    :param socket: the daemon's zmq REP socket
    :return: a Config built from the two config texts, or None if the
        config source could not be created
    """
    global session_tty

    # check first for resent init msg, in case of client timeout.
    # Anything that is not an init message is the active config text.
    msg = socket.recv().decode("utf-8", "ignore")
    while _is_init_message(msg):
        socket.send("init".encode())
        msg = socket.recv().decode("utf-8", "ignore")

    # zmq synchronous for ipc from single client:
    active_string = msg
    socket.send("active".encode())
    session_string = socket.recv().decode("utf-8", "ignore")
    socket.send("session".encode())
    pid_string = socket.recv().decode("utf-8", "ignore")
    socket.send("pid".encode())

    logger.debug(f"config session pid is {pid_string}")
    session_tty = None
    # The pid comes from the client and is interpolated into a /proc path;
    # accept ASCII digits only so it cannot redirect readlink() elsewhere.
    # NOTE(review): when fd 1 is a pipe/socket, readlink() returns a relative
    # name such as 'pipe:[1234]' that run_script() would open as a file.
    if re.fullmatch(r'[0-9]+', pid_string):
        try:
            session_tty = os.readlink(f"/proc/{pid_string}/fd/1")
        except OSError as e:
            logger.debug(f"cannot resolve session output: {e}")
    else:
        logger.critical(f"invalid config session pid: {pid_string!r}")

    try:
        configsource = ConfigSourceString(running_config_text=active_string,
                                          session_config_text=session_string)
    except ConfigSourceError as e:
        logger.debug(e)
        return None

    return Config(config_source=configsource)

# Node data: "<dir>/<script>.py" optionally followed directly by
# "VYOS_TAGNODE_VALUE=<value>". The path is matched non-greedily up to the
# first literal ".py" that is followed by the marker or the end, so tag
# values containing "py", "/" or "=" cannot be mistaken for the script name.
_NODE_DATA_RE = re.compile(
    r'^(?P<path>.*?)\.py(?:VYOS_TAGNODE_VALUE=(?P<value>.*))?$')


def _parse_node_data(data):
    """Split node data into (script_name, tagnode_value); each may be None."""
    match = _NODE_DATA_RE.match(data)
    if match is None:
        return None, None
    script_name = os.path.basename(match.group('path')) or None
    return script_name, match.group('value') or None


def process_node_data(config, data) -> int:
    """Run the conf_mode script named in a "node" request.

    :param config: Config from the last initialization(), or None
    :param data: script path, optionally suffixed with
        "VYOS_TAGNODE_VALUE=<value>"
    :return: an R_* response code; R_PASS for scripts this daemon did not
        import (presumably run by the client itself — see vyshim)
    """
    if not config:
        logger.critical("Empty config")
        return R_ERROR_DAEMON

    script_name, tagnode_value = _parse_node_data(data)

    # Drop any value left from a previous request so it cannot leak into a
    # script that is not a tag node.
    os.environ.pop('VYOS_TAGNODE_VALUE', None)
    if tagnode_value:
        os.environ['VYOS_TAGNODE_VALUE'] = tagnode_value

    if not script_name:
        logger.critical(f"Missing script_name in node data: {data!r}")
        return R_ERROR_DAEMON

    # Covers scripts excluded via configd-include.json as well as names that
    # were never imported, which previously raised KeyError.
    if script_name in exclude_set or script_name not in conf_mode_scripts:
        return R_PASS

    return run_script(conf_mode_scripts[script_name], config)

def remove_if_file(f: str):
    """Remove the file at path *f*; a missing file is not an error.

    Any other OSError (e.g. permission denied, or *f* is a directory)
    propagates to the caller.
    """
    try:
        os.remove(f)
    except FileNotFoundError:
        pass

def shutdown():
    """Point the session env file at the 'unset' variant, then exit with 0."""
    link, target = configd_env_file, configd_env_unset_file
    # Clear whatever is at the link path first; os.symlink will not overwrite.
    remove_if_file(link)
    os.symlink(target, link)
    raise SystemExit(0)

if __name__ == '__main__':
    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Set the right permissions on the socket, then change it back.
    # NOTE(review): umask 0 leaves the socket file accessible to every local
    # user, and session configs received on it are applied by this process —
    # confirm this is intended, or restrict access (e.g. to CFG_GROUP).
    o_mask = os.umask(0)
    socket.bind(SOCKET_PATH)
    os.umask(o_mask)

    cfg_group = grp.getgrnam(CFG_GROUP)
    os.setgid(cfg_group.gr_gid)

    os.environ['SUDO_USER'] = 'vyos'
    os.environ['SUDO_GID'] = str(cfg_group.gr_gid)

    def sig_handler(signum, frame):
        shutdown()

    signal.signal(signal.SIGTERM, sig_handler)
    signal.signal(signal.SIGINT, sig_handler)

    # Define the vyshim environment variable
    remove_if_file(configd_env_file)
    os.symlink(configd_env_set_file, configd_env_file)

    # A REP socket must reply to every request before it may receive the
    # next one, so each branch below answers the request it received.
    error_reply = R_ERROR_DAEMON.to_bytes(1, byteorder=sys.byteorder)

    config = None

    while True:
        #  Wait for next request from client
        msg = socket.recv().decode("utf-8", "ignore")
        logger.debug(f"Received message: {msg}")

        try:
            message = json.loads(msg)
        except json.JSONDecodeError:
            message = None
        msg_type = message.get("type") if isinstance(message, dict) else None

        if msg_type == "init":
            resp = "init"
            socket.send(resp.encode())
            config = initialization(socket)
        elif msg_type == "node" and "data" in message:
            try:
                res = process_node_data(config, message["data"])
            except Exception:
                # One bad request must not kill the daemon; log and answer.
                logger.critical(f"Error processing node data: {message['data']}",
                                exc_info=True)
                res = R_ERROR_DAEMON
            response = res.to_bytes(1, byteorder=sys.byteorder)
            logger.debug(f"Sending response {res}")
            socket.send(response)
        else:
            logger.critical(f"Unexpected message: {msg}")
            socket.send(error_reply)