#!/usr/bin/env python3
#
# Copyright (C) 2025 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#
import os
import sys
import grp
import json
import signal
import socket
import typing
import logging
import traceback
import importlib.util
import io
from contextlib import redirect_stdout
from dataclasses import dataclass
from dataclasses import fields
from dataclasses import field
from dataclasses import asdict
from pathlib import Path

import tomli

from google.protobuf.json_format import MessageToDict
from google.protobuf.json_format import ParseDict

from vyos.defaults import directories
from vyos.utils.boot import boot_configuration_complete
from vyos.configsource import ConfigSourceCache
from vyos.configsource import ConfigSourceError
from vyos.config import Config
from vyos.frrender import FRRender
from vyos.frrender import get_frrender_dict
from vyos import ConfigError

from vyos.proto import vycall_pb2


@dataclass
class Status:
    """Outcome of one processing step: a success flag plus any output text."""

    success: bool = False  # True when the step completed without error
    out: str = ''  # captured output and/or error message for the step


@dataclass
class Call:
    """One conf-mode script invocation within a commit session, plus its result."""

    # script file name; process_call_data() strips the extension before lookup
    script_name: str = ''
    # exported to the script as VYOS_TAGNODE_VALUE ('' when None)
    tag_value: typing.Optional[str] = None
    # whitespace-separated extra arguments appended to the script's argv
    arg_value: typing.Optional[str] = None
    # result of running the script; filled in by set_reply()
    reply: typing.Optional[Status] = None

    def set_reply(self, success: bool, out: str):
        """Record the outcome of running this script call."""
        self.reply = Status(success=success, out=out)


@dataclass
class Session:
    """One commit request and, after processing, its results.

    Built from a vycall Commit protobuf message by msg_to_commit_data() and
    serialized back by commit_data_to_msg(). Presumably the field names mirror
    the protobuf message's fields (Session(**d) requires every key of the
    decoded message to be a field here) — confirm against vycall.proto.
    """

    # pylint: disable=too-many-instance-attributes

    session_id: str = ''
    named_active: typing.Optional[str] = None
    named_proposed: typing.Optional[str] = None
    dry_run: bool = False
    atomic: bool = False
    background: bool = False
    # live Config built by initialization(); commit_data_to_msg() resets it
    # to None before serialization
    config: typing.Optional[Config] = None
    # outcome of initialization(), recorded via set_init()
    init: typing.Optional[Status] = None
    # script invocations, processed in order by process_session_data()
    calls: list[Call] = field(default_factory=list)

    def set_init(self, success: bool, out: str):
        """Record the outcome of session initialization."""
        self.init = Status(success=success, out=out)


@dataclass
class ServerConf:
    """Daemon settings taken from the ``appliance`` table of vyconfd.conf."""

    # path of the Unix socket the commit daemon binds and listens on
    commitd_socket: str = ''
    # directory that holds the two config cache files below
    session_dir: str = ''
    # cache file names, joined onto session_dir by initialization()
    running_cache: str = ''
    session_cache: str = ''


# Runtime state; all four names are assigned by run_server() at startup.
server_conf = None
SOCKET_PATH = None
conf_mode_scripts = None
frr = None

# Group the daemon switches to (os.setgid) before creating the commit socket.
CFG_GROUP = 'vyattacfg'

# write_stdout_log() appends script output here until boot configuration
# has completed.
script_stdout_log = '/tmp/vyos-commitd-script-stdout'

# Hard-coded verbosity switch for the module logger.
debug = True

logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)


# Directory scanned by load_conf_mode_scripts() for conf-mode scripts, and
# the JSON document naming which of those files may be imported.
vyos_conf_scripts_dir = directories['conf_mode']
commitd_include_file = os.path.join(directories['data'], 'configd-include.json')


def key_name_from_file_name(f):
    """Return *f* with its final extension removed (e.g. 'x-y.py' -> 'x-y')."""
    stem, _extension = os.path.splitext(f)
    return stem


def module_name_from_key(k):
    """Turn a script key into an importable module name by mapping '-' to '_'."""
    return '_'.join(k.split('-'))


def path_from_file_name(f):
    """Return the path of conf-mode script *f* within vyos_conf_scripts_dir."""
    scripts_dir = vyos_conf_scripts_dir
    return os.path.join(scripts_dir, f)


def load_conf_mode_scripts():
    """Import the conf-mode scripts listed in the configd include file.

    Every file directly inside ``vyos_conf_scripts_dir`` (subdirectories are
    not scanned) whose name appears in ``commitd_include_file`` is imported
    as a module named after the file, with '-' mapped to '_'. Files are
    imported in sorted order.

    Returns:
        dict mapping each script's file name without extension to its
        loaded module.

    Exits the process with status 1 if the include file cannot be opened,
    read, or parsed as JSON.
    """
    # open() is inside the try so a missing/unreadable include file takes the
    # same critical-log-and-exit path as a read error, instead of escaping
    # as an uncaught OSError
    try:
        with open(commitd_include_file, encoding='utf-8') as f:
            include = json.load(f)
    except OSError as e:
        logger.critical(f'configd include file error: {e}')
        sys.exit(1)
    except json.JSONDecodeError as e:
        logger.critical(f'JSON load error: {e}')
        sys.exit(1)

    # import conf_mode scripts: only the top level of the directory
    (_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir)))
    filenames.sort()

    scripts = {}
    for file_name in filenames:
        # this filter is redundant, as all scripts are currently in the
        # include file; leave it as an inexpensive check for future changes
        if file_name not in include:
            continue
        key = key_name_from_file_name(file_name)
        spec = importlib.util.spec_from_file_location(
            module_name_from_key(key), path_from_file_name(file_name)
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        scripts[key] = module

    return scripts


def get_session_out(session: Session) -> str:
    """Collect the non-empty output of a session's init step and script calls.

    Each entry becomes one ``<label>: <output>`` line. The label is ``init``
    for the initialization step and the call's script name for each call.
    Entries without output are skipped. Returns '' when nothing produced
    output.
    """
    lines = []
    if session.init and session.init.out:
        lines.append(f'init: {session.init.out}\n')
    for call in session.calls:
        reply = call.reply
        if reply and reply.out:
            lines.append(f'{call.script_name}: {reply.out}\n')
    # join once instead of re-building the string for every entry
    return ''.join(lines)


def write_stdout_log(file_name, session):
    """Append the collected output of *session* to *file_name*.

    Only done while boot configuration is still in progress; once
    boot_configuration_complete() is true this is a no-op.
    """
    if not boot_configuration_complete():
        with open(file_name, 'a') as log_file:
            log_file.write(get_session_out(session))


def msg_to_commit_data(msg: vycall_pb2.Commit) -> Session:
    """Convert a vycall Commit protobuf message into a Session.

    The message is decoded with its proto field names preserved. The nested
    init status, each call, and each call's reply are then wrapped in their
    dataclasses. Empty or absent sub-messages become None.
    """
    # pylint: disable=no-member

    decoded = MessageToDict(msg, preserving_proto_field_name=True)
    session = Session(**decoded)

    if session.init:
        session.init = Status(**session.init)
    else:
        session.init = None

    session.calls = [Call(**raw_call) for raw_call in session.calls]
    for call in session.calls:
        if call.reply:
            call.reply = Status(**call.reply)
        else:
            call.reply = None

    return session


def commit_data_to_msg(obj: Session) -> vycall_pb2.Commit:
    """Serialize *obj* back into a vycall Commit protobuf message.

    Side effect: ``obj.config`` is reset to None, because ``asdict`` would
    otherwise attempt to deep-copy the Config object. Dataclass fields the
    message does not define are ignored.
    """
    # pylint: disable=no-member

    obj.config = None
    return ParseDict(asdict(obj), vycall_pb2.Commit(), ignore_unknown_fields=True)


def initialization(session: Session) -> Session:
    """Attach a Config built from the cached configs to *session*.

    The running and session config caches are looked up under
    ``server_conf.session_dir``. If they cannot be read (ConfigSourceError),
    the failure is recorded in ``session.init`` and the session is returned
    without a config. Otherwise ``session.init`` is marked successful and
    ``session.config`` holds a new Config carrying the extra attributes
    ``dependent_func``, ``scripts_called`` and ``dry_run``.
    """
    cache_dir = server_conf.session_dir
    running_cache = os.path.join(cache_dir, server_conf.running_cache)
    session_cache = os.path.join(cache_dir, server_conf.session_cache)

    try:
        configsource = ConfigSourceCache(
            running_config_cache=running_cache,
            session_config_cache=session_cache,
        )
    except ConfigSourceError as e:
        fail_msg = f'Failed to read config caches: {e}'
        logger.critical(fail_msg)
        session.set_init(False, fail_msg)
        return session

    session.set_init(True, '')

    config = Config(config_source=configsource)

    # per-commit bookkeeping stored on the Config instance;
    # scripts_called is appended to by process_call_data()
    dependent_func: dict[str, list[typing.Callable]] = {}
    config.dependent_func = dependent_func
    config.scripts_called = []
    config.dry_run = False

    session.config = config
    return session


def run_script(script_name: str, config: Config, args: list) -> tuple[bool, str]:
    """Run the get_config/verify/generate/apply stages of a conf-mode module.

    The module's ``argv`` is set to *args* and the config level is reset to
    the root before the stages run.

    Returns:
        (True, '') on success; (False, message) when a stage raises
        ConfigError; (False, traceback text) for any other Exception.
    """
    # pylint: disable=broad-exception-caught

    module = conf_mode_scripts[script_name]
    module.argv = args
    config.set_level([])

    try:
        script_conf = module.get_config(config)
        module.verify(script_conf)
        module.generate(script_conf)
        module.apply(script_conf)
    except ConfigError as err:
        logger.error(err)
        return False, str(err)
    except Exception:
        trace = traceback.format_exc()
        logger.error(trace)
        return False, trace
    else:
        return True, ''


def process_call_data(call: Call, config: Config, last: bool = False) -> None:
    """Run one conf-mode script call and store its result in ``call.reply``.

    The script is looked up by ``call.script_name`` with the extension
    stripped; an unknown name produces a failed reply. Before the script
    runs, ``VYOS_TAGNODE_VALUE`` is exported ('' when there is no tag value)
    and the call is recorded in ``config.scripts_called``. Anything the
    script prints to stdout is captured and placed before its error text in
    the reply.

    When *last* is true and the script succeeded, the FRR configuration is
    rendered and applied only if ``frr.generate`` reports a change.
    """
    # pylint: disable=too-many-locals

    script_name = key_name_from_file_name(call.script_name)

    if script_name not in conf_mode_scripts:
        fail_msg = f'No such script: {call.script_name}'
        logger.critical(fail_msg)
        call.set_reply(False, fail_msg)
        return

    config.dependency_list.clear()

    tag_value = '' if call.tag_value is None else call.tag_value
    os.environ['VYOS_TAGNODE_VALUE'] = tag_value

    # argv mirrors a command-line invocation: program name first
    extra_args = call.arg_value.split() if call.arg_value else []
    args = [f'{script_name}.py', *extra_args]

    record = f'{script_name}_{tag_value}' if tag_value else script_name
    getattr(config, 'scripts_called', []).append(record)

    captured = io.StringIO()
    with redirect_stdout(captured):
        success, err_out = run_script(script_name, config, args)
    out = captured.getvalue() + err_out
    captured.close()

    call.set_reply(success, out)
    logger.info(f'[{script_name}] {out}')

    if not last:
        return

    called = getattr(config, 'scripts_called', [])
    logger.debug(f'scripts_called: {called}')

    if success:
        # only apply a new FRR configuration if anything changed
        # in comparison to the previous applied configuration
        if frr.generate(get_frrender_dict(config)):
            frr.apply()


def process_session_data(session: Session) -> Session:
    """Run every call of an initialized session, flagging the final one.

    Sessions whose initialization is missing or failed are returned
    untouched.
    """
    init = session.init
    if init is None or not init.success:
        return session

    config = session.config
    final_index = len(session.calls) - 1
    for position, call in enumerate(session.calls):
        process_call_data(call, config, last=position == final_index)

    return session


def read_message(msg: bytes) -> Session:
    """Decode a serialized Commit message and process it into a Session.

    The decoded session is initialized, its calls are executed, and the
    collected output is passed to write_stdout_log() before returning.
    """
    proto = vycall_pb2.Commit()  # pylint: disable=no-member
    proto.ParseFromString(msg)

    session = process_session_data(initialization(msg_to_commit_data(proto)))

    write_stdout_log(script_stdout_log, session)
    return session


def write_reply(session: Session) -> bytearray:
    """Serialize *session* into a length-prefixed reply frame.

    The frame is a 4-byte big-endian payload length followed by the
    serialized vycall Commit message. run_server() reads requests using the
    same layout.
    """
    encoded_data = commit_data_to_msg(session).SerializeToString()
    # explicit byteorder: the argument-less int.to_bytes() form only exists
    # on Python 3.11+, and the wire format should not hinge on a default
    arr = bytearray(len(encoded_data).to_bytes(4, 'big'))
    arr.extend(encoded_data)
    return arr


def load_server_conf() -> ServerConf:
    """Build a ServerConf from the ``appliance`` table of vyconfd.conf.

    Keys that are not ServerConf fields are ignored, and missing ones keep
    the dataclass defaults. Exits with status 1 if the file cannot be opened
    or parsed.
    """
    # pylint: disable=import-outside-toplevel
    # pylint: disable=broad-exception-caught
    from vyos.defaults import vyconfd_conf

    try:
        with open(vyconfd_conf, 'rb') as f:
            vyconfd_conf_d = tomli.load(f)
    except Exception as e:
        logger.critical(f'Failed to open the vyconfd.conf file {vyconfd_conf}: {e}')
        sys.exit(1)

    known_keys = {fld.name for fld in fields(ServerConf)}
    appliance = vyconfd_conf_d.get('appliance', {})
    selected = {key: value for key, value in appliance.items() if key in known_keys}
    return ServerConf(**selected)


def remove_if_exists(f: str):
    """Delete the file at *f*; a missing file is not an error."""
    Path(f).unlink(missing_ok=True)


def sig_handler(_signum, _frame):
    """Log shutdown and raise KeyboardInterrupt.

    run_server() installs this for SIGTERM and SIGINT; the raised
    KeyboardInterrupt unwinds its accept loop so the server can exit.
    """
    logger.info('stopping server')
    raise KeyboardInterrupt()


def _recv_exact(conn: socket.socket, size: int) -> bytes:
    """Read *size* bytes from *conn*, looping over short reads.

    socket.recv() on a stream socket may return fewer bytes than requested,
    so one call cannot be relied on to deliver a whole frame. The result is
    shorter than *size* only if the peer closed the connection first (b''
    means it closed before any byte arrived).
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = conn.recv(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def _serve_connection(conn: socket.socket) -> None:
    """Answer length-prefixed commit requests on *conn* until the peer closes it."""
    while True:
        # each request: 4-byte big-endian payload length, then the payload
        header = _recv_exact(conn, 4)
        if not header:
            logger.debug('no data')
            return
        if len(header) < 4:
            logger.error('connection closed inside message header')
            return

        length = int.from_bytes(header, 'big')
        data = _recv_exact(conn, length)
        if len(data) < length:
            logger.error(
                f'connection closed after {len(data)} of {length} message bytes'
            )
            return

        session = read_message(data)
        conn.sendall(write_reply(session))


def run_server():
    """Start the commit daemon and serve clients until SIGTERM/SIGINT.

    Loads the server settings and conf-mode scripts, switches to the
    CFG_GROUP group, binds the Unix socket from the settings (mode 0o775),
    then accepts one connection at a time. A failure while serving a client
    is logged and the connection dropped; the daemon keeps running. Exits
    with status 0 after a signal-triggered KeyboardInterrupt.
    """
    # pylint: disable=global-statement
    # pylint: disable=broad-exception-caught

    global server_conf
    global SOCKET_PATH
    global conf_mode_scripts
    global frr

    signal.signal(signal.SIGTERM, sig_handler)
    signal.signal(signal.SIGINT, sig_handler)

    logger.info('starting server')

    server_conf = load_server_conf()
    SOCKET_PATH = server_conf.commitd_socket
    conf_mode_scripts = load_conf_mode_scripts()

    cfg_group = grp.getgrnam(CFG_GROUP)
    os.setgid(cfg_group.gr_gid)

    server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)

    remove_if_exists(SOCKET_PATH)
    server_socket.bind(SOCKET_PATH)
    Path(SOCKET_PATH).chmod(0o775)

    # We only need one long-lived instance of FRRender
    frr = FRRender()

    server_socket.listen(2)
    try:
        while True:
            conn, _ = server_socket.accept()
            logger.debug('connection accepted')
            # closes conn on every exit path, including the KeyboardInterrupt
            # raised by sig_handler
            with conn:
                try:
                    _serve_connection(conn)
                except Exception:
                    # a failure while serving one client must not take the
                    # whole daemon down
                    logger.error(traceback.format_exc())
            logger.debug('connection closed')
    except KeyboardInterrupt:
        pass
    finally:
        server_socket.close()

    sys.exit(0)


if __name__ == '__main__':
    # start the daemon when this file is executed directly
    run_server()