#!/usr/bin/env python3
#
# Copyright (C) 2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import grp
import logging
import multiprocessing
import os
import queue
import signal
import socket
import threading
from datetime import timedelta
from pathlib import Path
from time import sleep
from typing import Dict, AnyStr

from pyroute2 import conntrack
from pyroute2.netlink import nfnetlink
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_CTNETLINK
from pyroute2.netlink.nfnetlink.nfctsocket import nfct_msg, \
    IPCTNL_MSG_CT_DELETE, IPCTNL_MSG_CT_NEW, IPS_SEEN_REPLY, \
    IPS_OFFLOAD, IPS_ASSURED

from vyos.utils.file import read_json


# Process-shared flag: set by sig_handler(), polled by the worker loops and by
# the supervising loop in the __main__ section to know when to stop.
shutdown_event = multiprocessing.Event()

# Default output is the bare message; set_log_level('debug') swaps the
# formatter of the first root handler for DebugFormatter.
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)


class DebugFormatter(logging.Formatter):
    """
    Formatter that always renders records as "[time] LEVEL: message",
    regardless of the format string the instance was created with.
    """
    _DEBUG_FMT = '[%(asctime)s] %(levelname)s: %(message)s'

    def format(self, record):
        # Force the verbose layout on the underlying style object right
        # before delegating to the stock implementation.
        style = self._style
        style._fmt = self._DEBUG_FMT
        return logging.Formatter.format(self, record)


def set_log_level(level: str) -> None:
    """
    Apply the configured verbosity to the module logger.

    'debug' enables DEBUG output and installs DebugFormatter on the first
    handler of the parent logger; any other value selects INFO.
    """
    if level != 'debug':
        logger.setLevel(logging.INFO)
        return

    logger.setLevel(logging.DEBUG)
    first_handler = logger.parent.handlers[0]
    first_handler.setFormatter(DebugFormatter())


# Maps the event names used as keys of the "event" configuration section to
# the nfnetlink conntrack multicast group passed to ct.add_membership().
EVENT_NAME_TO_GROUP = {
    'new': nfnetlink.NFNLGRP_CONNTRACK_NEW,
    'update': nfnetlink.NFNLGRP_CONNTRACK_UPDATE,
    'destroy': nfnetlink.NFNLGRP_CONNTRACK_DESTROY
}

# TCP conntrack states, mirroring `enum tcp_conntrack` from the kernel header:
#  https://github.com/torvalds/linux/blob/1dfe225e9af5bd3399a1dbc6a4df6a6041ff9c23/include/uapi/linux/netfilter/nf_conntrack_tcp.h#L9
# The enum starts at 0 with NONE; it is listed here so that state 0 is not
# reported as 'unknown'.
TCP_CONNTRACK_NONE = 0
TCP_CONNTRACK_SYN_SENT = 1
TCP_CONNTRACK_SYN_RECV = 2
TCP_CONNTRACK_ESTABLISHED = 3
TCP_CONNTRACK_FIN_WAIT = 4
TCP_CONNTRACK_CLOSE_WAIT = 5
TCP_CONNTRACK_LAST_ACK = 6
TCP_CONNTRACK_TIME_WAIT = 7
TCP_CONNTRACK_CLOSE = 8
TCP_CONNTRACK_LISTEN = 9
TCP_CONNTRACK_MAX = 10
TCP_CONNTRACK_IGNORE = 11
TCP_CONNTRACK_RETRANS = 12
TCP_CONNTRACK_UNACK = 13
TCP_CONNTRACK_TIMEOUT_MAX = 14

# Numeric TCP state -> name printed in the log line
TCP_CONNTRACK_TO_NAME = {
    TCP_CONNTRACK_NONE: "NONE",
    TCP_CONNTRACK_SYN_SENT: "SYN_SENT",
    TCP_CONNTRACK_SYN_RECV: "SYN_RECV",
    TCP_CONNTRACK_ESTABLISHED: "ESTABLISHED",
    TCP_CONNTRACK_FIN_WAIT: "FIN_WAIT",
    TCP_CONNTRACK_CLOSE_WAIT: "CLOSE_WAIT",
    TCP_CONNTRACK_LAST_ACK: "LAST_ACK",
    TCP_CONNTRACK_TIME_WAIT: "TIME_WAIT",
    TCP_CONNTRACK_CLOSE: "CLOSE",
    TCP_CONNTRACK_LISTEN: "LISTEN",
    TCP_CONNTRACK_MAX: "MAX",
    TCP_CONNTRACK_IGNORE: "IGNORE",
    TCP_CONNTRACK_RETRANS: "RETRANS",
    TCP_CONNTRACK_UNACK: "UNACK",
    TCP_CONNTRACK_TIMEOUT_MAX: "TIMEOUT_MAX",
}

# SCTP conntrack states, mirroring `enum sctp_conntrack` from the kernel header:
# https://github.com/torvalds/linux/blob/1dfe225e9af5bd3399a1dbc6a4df6a6041ff9c23/include/uapi/linux/netfilter/nf_conntrack_sctp.h#L8
SCTP_CONNTRACK_NONE = 0
SCTP_CONNTRACK_CLOSED = 1
SCTP_CONNTRACK_COOKIE_WAIT = 2
SCTP_CONNTRACK_COOKIE_ECHOED = 3
SCTP_CONNTRACK_ESTABLISHED = 4
SCTP_CONNTRACK_SHUTDOWN_SENT = 5
SCTP_CONNTRACK_SHUTDOWN_RECD = 6
SCTP_CONNTRACK_SHUTDOWN_ACK_SENT = 7
SCTP_CONNTRACK_HEARTBEAT_SENT = 8
SCTP_CONNTRACK_HEARTBEAT_ACKED = 9  # no longer used
SCTP_CONNTRACK_MAX = 10

# Numeric SCTP state -> name printed in the log line
SCTP_CONNTRACK_TO_NAME = {
    SCTP_CONNTRACK_NONE: 'NONE',
    SCTP_CONNTRACK_CLOSED: 'CLOSED',
    SCTP_CONNTRACK_COOKIE_WAIT: 'COOKIE_WAIT',
    SCTP_CONNTRACK_COOKIE_ECHOED: 'COOKIE_ECHOED',
    SCTP_CONNTRACK_ESTABLISHED: 'ESTABLISHED',
    SCTP_CONNTRACK_SHUTDOWN_SENT: 'SHUTDOWN_SENT',
    SCTP_CONNTRACK_SHUTDOWN_RECD: 'SHUTDOWN_RECD',
    SCTP_CONNTRACK_SHUTDOWN_ACK_SENT: 'SHUTDOWN_ACK_SENT',
    SCTP_CONNTRACK_HEARTBEAT_SENT: 'HEARTBEAT_SENT',
    SCTP_CONNTRACK_HEARTBEAT_ACKED: 'HEARTBEAT_ACKED',
    SCTP_CONNTRACK_MAX: 'MAX',
}

# Keyed by the protocol suffix used in the CTA_PROTOINFO_<proto> attribute
PROTO_CONNTRACK_TO_NAME = {
    'TCP': TCP_CONNTRACK_TO_NAME,
    'SCTP': SCTP_CONNTRACK_TO_NAME
}

# Protocols that can be selected individually in the event filter; every
# other protocol is matched through the 'other' filter key.
SUPPORTED_PROTO_TO_NAME = {
    socket.IPPROTO_ICMP: 'icmp',
    socket.IPPROTO_TCP: 'tcp',
    socket.IPPROTO_UDP: 'udp',
}

# Every protocol number that is rendered with a symbolic name in log lines:
# the extra ones first, followed by the individually selectable ones.
PROTO_TO_NAME = {
    socket.IPPROTO_ICMPV6: 'icmpv6',
    socket.IPPROTO_SCTP: 'sctp',
    socket.IPPROTO_GRE: 'gre',
    **SUPPORTED_PROTO_TO_NAME,
}


def sig_handler(signum, frame):
    """
    Signal handler: log which kind of signal arrived and raise the shared
    shutdown flag. Every signal other than SIGTERM is logged as "Reload".
    """
    action = "Shutdown" if signum == signal.SIGTERM else "Reload"
    current = multiprocessing.current_process()
    logger.debug(f'[{current.name}]: {action} signal received...')
    shutdown_event.set()


def format_flow_data(data: Dict) -> AnyStr:
    """
    Render one direction of a flow (addresses, protocol fields and optional
    counters) as space separated "key=value" pairs.

    Protocol fields and counters whose value is None are left out.
    """
    renamed = {'SRC_PORT': 'sport', 'DST_PORT': 'dport'}

    addr = data['ADDR']
    parts = [f"src={addr.get('SRC')}", f"dst={addr.get('DST')}"]

    proto = data['PROTO']
    for field in ('SRC_PORT', 'DST_PORT', 'TYPE', 'CODE', 'ID'):
        value = proto.get(field)
        if value is None:
            continue
        label = renamed.get(field, field).lower()
        parts.append(f"{label}={value}")

    if 'COUNTERS' in data:
        counters = data['COUNTERS']
        for field in ('PACKETS', 'BYTES'):
            value = counters.get(field)
            if value is None:
                continue
            parts.append(f"{field.lower()}={value}")

    return " ".join(parts)


def format_event_message(event: Dict) -> AnyStr:
    """
    Formats the internal parsed event data into a string suitable for logging.

    :param event: dict with the 'COMMON', 'ORIG' and 'REPLY' sections as
                  built by parse_conntrack_event()
    :return: single log line; fields whose value is missing are skipped
    """
    common = event['COMMON']
    orig_proto = event['ORIG']['PROTO']
    # A missing status attribute is treated as "no status bits set" instead
    # of failing on `None & flag`.
    status = common['STATUS'] or 0

    event_type = f"[{common['EVENT_TYPE'].upper()}]"
    message = f"{event_type:<{9}} {common['ID']} " \
              f"{orig_proto.get('NAME'):<{8}} " \
              f"{orig_proto.get('NUMBER')} "

    timeout = common['TIME_OUT']
    if timeout is not None:
        message += f"{timeout} "

    if proto_info := common.get('PROTO_INFO'):
        message += f"{proto_info.get('STATE_NAME')} "

    for direction in ['ORIG', 'REPLY']:
        message += f"{format_flow_data(event[direction])} "
        if direction == 'ORIG' and not (status & IPS_SEEN_REPLY):
            message += "[UNREPLIED] "

    mark = common['MARK']
    if mark is not None:
        message += f"mark={mark} "

    if status & IPS_OFFLOAD:
        message += " [OFFLOAD] "
    elif status & IPS_ASSURED:
        message += " [ASSURED] "

    if portid := common['PORTID']:
        message += f"portid={portid} "

    if tstamp := common.get('TIMESTAMP'):
        # Either value is None when the attribute was absent from the message
        # (presumably STOP for flows that have not ended yet - verify against
        # kernel behaviour). Only print what is known and only compute the
        # duration when both ends are available.
        start = tstamp.get('START')
        stop = tstamp.get('STOP')
        if start is not None:
            message += f"start={start} "
        if stop is not None:
            message += f"stop={stop} "
        if start is not None and stop is not None:
            # Values are subtracted as nanoseconds and reported in seconds.
            seconds, remaining_ns = divmod(stop - start, 1_000_000_000)
            delta = timedelta(seconds=seconds, microseconds=remaining_ns / 1000)
            message += f"delta={delta.total_seconds()} "

    return message


def parse_event_type(header: Dict) -> AnyStr:
    """
    Classify a conntrack netlink message header.

    Returns 'destroy' for delete messages, 'new' for new-messages that carry
    any flags, 'update' for new-messages without flags and 'unknown' for
    everything else.
    """
    subsys_bits = NFNL_SUBSYS_CTNETLINK << 8
    msg_type = header['type']

    if msg_type == IPCTNL_MSG_CT_DELETE | subsys_bits:
        return 'destroy'
    if msg_type == IPCTNL_MSG_CT_NEW | subsys_bits:
        return 'new' if header['flags'] else 'update'
    return 'unknown'


def parse_proto(cta: nfct_msg.cta_tuple) -> Dict:
    """
    Read the protocol part of a conntrack tuple.

    The result always holds 'NUMBER' and 'NAME'; ICMP/ICMPv6 tuples add
    'TYPE', 'CODE' and 'ID', all other protocols add 'SRC_PORT' and
    'DST_PORT'.
    """
    proto_attrs = cta.get_attr('CTA_TUPLE_PROTO')
    number = proto_attrs.get_attr('CTA_PROTO_NUM')

    if number == socket.IPPROTO_ICMP:
        prefix, fields = 'CTA_PROTO_ICMP', ('TYPE', 'CODE', 'ID')
    elif number == socket.IPPROTO_ICMPV6:
        prefix, fields = 'CTA_PROTO_ICMPV6', ('TYPE', 'CODE', 'ID')
    else:
        prefix, fields = 'CTA_PROTO', ('SRC_PORT', 'DST_PORT')

    parsed = {
        'NUMBER': number,
        'NAME': PROTO_TO_NAME.get(number, 'unknown'),
    }
    for field in fields:
        parsed[field] = proto_attrs.get_attr(f'{prefix}_{field}')
    return parsed


def parse_proto_info(cta: nfct_msg.cta_protoinfo) -> Dict:
    """
    Read the protocol state of a TCP or SCTP flow.

    Returns an empty dict when no protocol info is present, otherwise a dict
    with the numeric 'STATE' and its 'STATE_NAME' ('unknown' if unmapped).
    """
    info = {}
    if not cta:
        return info

    for proto in ('TCP', 'SCTP'):
        nested = cta.get_attr(f'CTA_PROTOINFO_{proto}')
        if not nested:
            continue
        state = nested.get_attr(f'CTA_PROTOINFO_{proto}_STATE')
        state_names = PROTO_CONNTRACK_TO_NAME.get(proto, {})
        info['STATE'] = state
        info['STATE_NAME'] = state_names.get(state, 'unknown')
    return info


def parse_timestamp(cta: nfct_msg.cta_timestamp) -> Dict:
    """
    Read the start/stop timestamps of a flow.

    Returns an empty dict when no timestamp attribute is present, otherwise
    a dict with 'START' and 'STOP' as returned by get_attr().
    """
    if not cta:
        return {}

    return {
        'START': cta.get_attr('CTA_TIMESTAMP_START'),
        'STOP': cta.get_attr('CTA_TIMESTAMP_STOP'),
    }


def parse_ip_addr(family: int, cta: nfct_msg.cta_tuple) -> Dict:
    """
    Read the source and destination address of a conntrack tuple.

    Returns a dict with 'SRC' and 'DST'. Raises NotImplementedError (after
    logging an error) for any family other than AF_INET / AF_INET6.
    """
    ip_attrs = cta.get_attr('CTA_TUPLE_IP')

    prefix_by_family = {
        socket.AF_INET: 'CTA_IP_V4',
        socket.AF_INET6: 'CTA_IP_V6',
    }
    prefix = prefix_by_family.get(family)
    if prefix is None:
        logger.error(f'Undefined INET: {family}')
        raise NotImplementedError(family)

    return {
        side: ip_attrs.get_attr(f'{prefix}_{side}')
        for side in ('SRC', 'DST')
    }


def parse_counters(cta: nfct_msg.cta_counters) -> Dict:
    """
    Extract packet and byte counters from nfct_msg.

    :param cta: counters attribute of one flow direction, or a falsy value
                when the message carries no counters
    :return: empty dict when `cta` is falsy, otherwise a dict with the keys
             'PACKETS' and 'BYTES' (value is None if neither the regular nor
             the 32-bit counter attribute is present)
    """
    data = dict()
    if not cta:
        return data

    for key in ['PACKETS', 'BYTES']:
        tmp = cta.get_attr(f'CTA_COUNTERS_{key}')
        if tmp is None:
            # fall back to the 32-bit variant of the attribute
            tmp = cta.get_attr(f'CTA_COUNTERS32_{key}')
        # Store under the counter name itself; the former literal 'key' made
        # both counters overwrite each other and never reach the log line.
        data[key] = tmp

    return data


def is_need_to_log(event_type: AnyStr, proto_num: int, conf_event: Dict) -> bool:
    """
    Filter message by event type and protocols.

    :param event_type: event name as returned by parse_event_type()
    :param proto_num: IP protocol number of the original direction
    :param conf_event: the 'event' section of the configuration, keyed by
                       event name
    :return: True when the event has to be logged, False otherwise
    """
    conf = conf_event.get(event_type)
    if conf is None:
        # Event type is not configured (e.g. 'unknown'): nothing to log.
        # Previously this ended in AttributeError on `None.get`.
        return False

    if conf == {}:
        # No protocol restriction for this event type
        return True

    # Protocols without a dedicated filter key are matched through 'other'
    proto_name = SUPPORTED_PROTO_TO_NAME.get(proto_num, 'other')
    return conf.get(proto_name) is not None


def parse_conntrack_event(msg: nfct_msg, conf_event: Dict) -> Dict:
    """
    Turn a conntrack netlink message into the internal event dict.

    Returns an empty dict when the event is filtered out by the
    configuration, otherwise a dict with the sections 'COMMON', 'ORIG' and
    'REPLY'.
    """
    header = msg['header']
    event_type = parse_event_type(header)
    proto_num = msg.get_nested('CTA_TUPLE_ORIG', 'CTA_TUPLE_PROTO', 'CTA_PROTO_NUM')

    if not is_need_to_log(event_type, proto_num, conf_event):
        return {}

    common = {
        'ID': msg.get_attr('CTA_ID'),
        'EVENT_TYPE': event_type,
        'TIME_OUT': msg.get_attr('CTA_TIMEOUT'),
        'MARK': msg.get_attr('CTA_MARK'),
        'PORTID': header.get('pid'),
        'PROTO_INFO': parse_proto_info(msg.get_attr('CTA_PROTOINFO')),
        'STATUS': msg.get_attr('CTA_STATUS'),
        'TIMESTAMP': parse_timestamp(msg.get_attr('CTA_TIMESTAMP'))
    }
    event = {'COMMON': common}

    # Both directions share the same layout: addresses, protocol, counters
    for direction in ('ORIG', 'REPLY'):
        flow = {}
        flow['ADDR'] = parse_ip_addr(msg['nfgen_family'], msg.get_attr(f'CTA_TUPLE_{direction}'))
        flow['PROTO'] = parse_proto(msg.get_attr(f'CTA_TUPLE_{direction}'))
        flow['COUNTERS'] = parse_counters(msg.get_attr(f'CTA_COUNTERS_{direction}'))
        event[direction] = flow

    return event


def worker(ct: conntrack.Conntrack, shutdown_event: multiprocessing.Event, conf_event: Dict):
    """
    Main function of parser worker process.

    Polls the conntrack buffer queue until `shutdown_event` is set, converts
    every received message with parse_conntrack_event() and logs the ones
    that pass the event filter.

    :param ct: bound conntrack socket whose `buffer_queue` is shared between
               the processes
    :param shutdown_event: flag that stops the loop once set
    :param conf_event: the 'event' section of the configuration
    """
    process_name = multiprocessing.current_process().name
    logger.debug(f'[{process_name}] started')
    timeout = 0.1
    while not shutdown_event.is_set():
        if ct.buffer_queue.empty():
            # nothing queued, do not spin
            sleep(timeout)
            continue

        try:
            for msg in ct.get():
                parsed_event = parse_conntrack_event(msg, conf_event)
                if not parsed_event:
                    # filtered out by configuration
                    continue
                message = format_event_message(parsed_event)
                # In debug mode the raw netlink message is appended
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"[{process_name}]: {message} raw: {msg}")
                else:
                    logger.info(message)
        except queue.Full:
            logger.error("Conntrack message queue is full.")
        except Exception as e:
            # Keep the worker alive; a single bad message must not end it
            logger.error(f"Error in queue: {e.__class__} {e}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--config',
                        action='store',
                        help='Path to vyos-conntrack-logger configuration',
                        required=True,
                        type=Path)

    args = parser.parse_args()
    try:
        config = read_json(args.config)
    except Exception as err:
        logger.error(f'Configuration file "{args.config}" does not exist or malformed: {err}')
        # SystemExit instead of the site-provided exit() helper
        raise SystemExit(1)

    set_log_level(config.get('log_level', 'info'))

    # Installed before the workers are created
    signal.signal(signal.SIGHUP, sig_handler)
    signal.signal(signal.SIGTERM, sig_handler)

    if 'event' not in config:
        logger.error('Configuration is wrong. Event filter is empty.')
        raise SystemExit(1)

    conf_event = config['event']
    event_groups = list(conf_event.keys())

    qsize = config.get('queue_size')
    ct = conntrack.Conntrack(async_qsize=int(qsize) if qsize else None)
    # Replace the socket's queue with a multiprocessing one so that the
    # worker processes can consume what the listener thread receives.
    ct.buffer_queue = multiprocessing.Queue(ct.async_qsize)
    ct.bind(async_cache=True)

    for name in event_groups:
        if group := EVENT_NAME_TO_GROUP.get(name):
            ct.add_membership(group)
        else:
            logger.error(f'Unexpected event group {name}')

    processes = list()
    try:
        for _ in range(multiprocessing.cpu_count()):
            p = multiprocessing.Process(target=worker, args=(ct,
                                                             shutdown_event,
                                                             conf_event))
            p.start()
            # only track processes that really started, join() needs that
            processes.append(p)
        logger.info('Conntrack socket bound and listening for messages.')

        while not shutdown_event.is_set():
            if ct.pthread.is_alive():
                sleep(0.1)
                continue

            # The listener thread has stopped. Restart it once the queue
            # has drained below 90% of its capacity.
            queue_drained = ct.buffer_queue.qsize() / ct.async_qsize < 0.9
            if queue_drained and not shutdown_event.is_set():
                logger.debug('Restart listener thread')
                ct.pthread = threading.Thread(
                    name="Netlink async cache", target=ct.async_recv
                )
                ct.pthread.daemon = True
                ct.pthread.start()
            else:
                # Queue still overloaded: wait for the workers instead of
                # spinning at full speed.
                sleep(0.1)
    finally:
        # Make sure the workers leave their loops even when we get here
        # through an unexpected exception; otherwise join() never returns.
        shutdown_event.set()
        for p in processes:
            p.join()
            if not p.is_alive():
                logger.debug(f"[{p.name}]: finished")
        ct.close()
        logger.info("Conntrack socket closed.")
    raise SystemExit(0)