From c509d0e6caae55106a2fbde3059652a493ed3903 Mon Sep 17 00:00:00 2001
From: khramshinr <khramshinr@gmail.com>
Date: Mon, 8 Jul 2024 16:38:22 +0600
Subject: T6362: Create conntrack logger daemon

---
 src/services/vyos-conntrack-logger | 458 +++++++++++++++++++++++++++++++++++++
 1 file changed, 458 insertions(+)
 create mode 100755 src/services/vyos-conntrack-logger

(limited to 'src/services')

diff --git a/src/services/vyos-conntrack-logger b/src/services/vyos-conntrack-logger
new file mode 100755
index 000000000..9c31b465f
--- /dev/null
+++ b/src/services/vyos-conntrack-logger
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2024 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import argparse
+import grp
+import logging
+import multiprocessing
+import os
+import queue
+import signal
+import socket
+import threading
+from datetime import timedelta
+from pathlib import Path
+from time import sleep
+from typing import Dict, AnyStr
+
+from pyroute2 import conntrack
+from pyroute2.netlink import nfnetlink
+from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_CTNETLINK
+from pyroute2.netlink.nfnetlink.nfctsocket import nfct_msg, \
+    IPCTNL_MSG_CT_DELETE, IPCTNL_MSG_CT_NEW, IPS_SEEN_REPLY, \
+    IPS_OFFLOAD, IPS_ASSURED
+
+from vyos.utils.file import read_json
+
+
+shutdown_event = multiprocessing.Event()  # set by sig_handler(); polled by the main loop and the workers
+# Plain '%(message)s' output by default; set_log_level() installs DebugFormatter in debug mode.
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+logger = logging.getLogger(__name__)
+
+
+class DebugFormatter(logging.Formatter):
+    """Debug-mode formatter; the layout is fixed once through the public constructor."""
+    def __init__(self, fmt=None, datefmt=None):  # fmt is accepted for compatibility but ignored
+        super().__init__('[%(asctime)s] %(levelname)s: %(message)s', datefmt)
+
+
+def set_log_level(level: str) -> None:
+    """Select DEBUG (with the timestamped formatter) for 'debug', INFO for anything else."""
+    debug_mode = level == 'debug'
+    logger.setLevel(logging.DEBUG if debug_mode else logging.INFO)
+    if debug_mode:
+        logger.parent.handlers[0].setFormatter(DebugFormatter())
+
+
+EVENT_NAME_TO_GROUP = {  # configured event name -> nfnetlink group passed to ct.add_membership()
+    'new': nfnetlink.NFNLGRP_CONNTRACK_NEW,
+    'update': nfnetlink.NFNLGRP_CONNTRACK_UPDATE,
+    'destroy': nfnetlink.NFNLGRP_CONNTRACK_DESTROY
+}
+
+# TCP conntrack state numbers, taken from https://github.com/torvalds/linux/blob/1dfe225e9af5bd3399a1dbc6a4df6a6041ff9c23/include/uapi/linux/netfilter/nf_conntrack_tcp.h#L9
+TCP_CONNTRACK_SYN_SENT = 1
+TCP_CONNTRACK_SYN_RECV = 2
+TCP_CONNTRACK_ESTABLISHED = 3
+TCP_CONNTRACK_FIN_WAIT = 4
+TCP_CONNTRACK_CLOSE_WAIT = 5
+TCP_CONNTRACK_LAST_ACK = 6
+TCP_CONNTRACK_TIME_WAIT = 7
+TCP_CONNTRACK_CLOSE = 8
+TCP_CONNTRACK_LISTEN = 9
+TCP_CONNTRACK_MAX = 10
+TCP_CONNTRACK_IGNORE = 11
+TCP_CONNTRACK_RETRANS = 12
+TCP_CONNTRACK_UNACK = 13
+TCP_CONNTRACK_TIMEOUT_MAX = 14
+
+TCP_CONNTRACK_TO_NAME = {  # state number -> STATE_NAME reported by parse_proto_info()
+    TCP_CONNTRACK_SYN_SENT: "SYN_SENT",
+    TCP_CONNTRACK_SYN_RECV: "SYN_RECV",
+    TCP_CONNTRACK_ESTABLISHED: "ESTABLISHED",
+    TCP_CONNTRACK_FIN_WAIT: "FIN_WAIT",
+    TCP_CONNTRACK_CLOSE_WAIT: "CLOSE_WAIT",
+    TCP_CONNTRACK_LAST_ACK: "LAST_ACK",
+    TCP_CONNTRACK_TIME_WAIT: "TIME_WAIT",
+    TCP_CONNTRACK_CLOSE: "CLOSE",
+    TCP_CONNTRACK_LISTEN: "LISTEN",
+    TCP_CONNTRACK_MAX: "MAX",
+    TCP_CONNTRACK_IGNORE: "IGNORE",
+    TCP_CONNTRACK_RETRANS: "RETRANS",
+    TCP_CONNTRACK_UNACK: "UNACK",
+    TCP_CONNTRACK_TIMEOUT_MAX: "TIMEOUT_MAX",
+}
+
+# SCTP conntrack state numbers, taken from https://github.com/torvalds/linux/blob/1dfe225e9af5bd3399a1dbc6a4df6a6041ff9c23/include/uapi/linux/netfilter/nf_conntrack_sctp.h#L8
+SCTP_CONNTRACK_CLOSED = 1
+SCTP_CONNTRACK_COOKIE_WAIT = 2
+SCTP_CONNTRACK_COOKIE_ECHOED = 3
+SCTP_CONNTRACK_ESTABLISHED = 4
+SCTP_CONNTRACK_SHUTDOWN_SENT = 5
+SCTP_CONNTRACK_SHUTDOWN_RECD = 6
+SCTP_CONNTRACK_SHUTDOWN_ACK_SENT = 7
+SCTP_CONNTRACK_HEARTBEAT_SENT = 8
+SCTP_CONNTRACK_HEARTBEAT_ACKED = 9  # no longer used
+SCTP_CONNTRACK_MAX = 10
+
+SCTP_CONNTRACK_TO_NAME = {  # state number -> STATE_NAME reported by parse_proto_info()
+    SCTP_CONNTRACK_CLOSED: 'CLOSED',
+    SCTP_CONNTRACK_COOKIE_WAIT: 'COOKIE_WAIT',
+    SCTP_CONNTRACK_COOKIE_ECHOED: 'COOKIE_ECHOED',
+    SCTP_CONNTRACK_ESTABLISHED: 'ESTABLISHED',
+    SCTP_CONNTRACK_SHUTDOWN_SENT: 'SHUTDOWN_SENT',
+    SCTP_CONNTRACK_SHUTDOWN_RECD: 'SHUTDOWN_RECD',
+    SCTP_CONNTRACK_SHUTDOWN_ACK_SENT: 'SHUTDOWN_ACK_SENT',
+    SCTP_CONNTRACK_HEARTBEAT_SENT: 'HEARTBEAT_SENT',
+    SCTP_CONNTRACK_HEARTBEAT_ACKED: 'HEARTBEAT_ACKED',
+    SCTP_CONNTRACK_MAX: 'MAX',
+}
+
+PROTO_CONNTRACK_TO_NAME = {  # keyed by the <PROTO> part of CTA_PROTOINFO_<PROTO>, see parse_proto_info()
+    'TCP': TCP_CONNTRACK_TO_NAME,
+    'SCTP': SCTP_CONNTRACK_TO_NAME
+}
+
+SUPPORTED_PROTO_TO_NAME = {  # protocols with their own filter key; is_need_to_log() maps the rest to 'other'
+    socket.IPPROTO_ICMP: 'icmp',
+    socket.IPPROTO_TCP: 'tcp',
+    socket.IPPROTO_UDP: 'udp',
+}
+
+PROTO_TO_NAME = {  # protocol number -> name printed in log lines, see parse_proto()
+    socket.IPPROTO_ICMPV6: 'icmpv6',
+    socket.IPPROTO_SCTP: 'sctp',
+    socket.IPPROTO_GRE: 'gre',
+}
+
+PROTO_TO_NAME.update(SUPPORTED_PROTO_TO_NAME)
+
+
+def sig_handler(signum, frame):  # SIGTERM/SIGHUP handler: log the signal, then request shutdown
+    action = 'Shutdown' if signum == signal.SIGTERM else 'Reload'
+    logger.debug(f'[{multiprocessing.current_process().name}]: {action} signal received...')
+    shutdown_event.set()
+
+
+def format_flow_data(data: Dict) -> AnyStr:
+    """
+    Render one flow direction as space separated "key=value" pairs:
+    addresses first, then protocol fields, then counters (when present).
+    """
+    # key in the parsed 'PROTO' dict -> label used in the log line
+    proto_labels = {
+        'SRC_PORT': 'sport',
+        'DST_PORT': 'dport',
+        'TYPE': 'type',
+        'CODE': 'code',
+        'ID': 'id',
+    }
+    addr = data['ADDR']
+    chunks = [f"src={addr.get('SRC')}", f"dst={addr.get('DST')}"]
+    for field, label in proto_labels.items():
+        if (value := data['PROTO'].get(field)) is not None:
+            chunks.append(f"{label}={value}")
+
+    for field in ('PACKETS', 'BYTES'):
+        if (value := data.get('COUNTERS', {}).get(field)) is not None:
+            chunks.append(f"{field.lower()}={value}")
+    return ' '.join(chunks)
+
+
+def format_event_message(event: Dict) -> AnyStr:
+    """
+    Formats the internal parsed event data into a string suitable for logging.
+    Optional fields (timeout, state, mark, portid, timestamps) are left out when absent.
+    """
+    common = event['COMMON']
+    status = common['STATUS'] or 0  # a missing CTA_STATUS counts as "no flag set"
+    event_type = f"[{common['EVENT_TYPE'].upper()}]"
+    message = f"{event_type:<{9}} {common['ID']} " \
+              f"{event['ORIG']['PROTO'].get('NAME'):<{8}} " \
+              f"{event['ORIG']['PROTO'].get('NUMBER')} "
+    if (timeout := common['TIME_OUT']) is not None:
+        message += f"{timeout} "
+    if proto_info := common.get('PROTO_INFO'):
+        message += f"{proto_info.get('STATE_NAME')} "
+
+    message += f"{format_flow_data(event['ORIG'])} "
+    if not status & IPS_SEEN_REPLY:
+        message += "[UNREPLIED] "
+    message += f"{format_flow_data(event['REPLY'])} "
+    if (mark := common['MARK']) is not None:
+        message += f"mark={mark} "
+    if status & IPS_OFFLOAD:  # no leading space in the flags: every fragment already ends with one
+        message += "[OFFLOAD] "
+    elif status & IPS_ASSURED:
+        message += "[ASSURED] "
+    if portid := common['PORTID']:
+        message += f"portid={portid} "
+    if tstamp := common.get('TIMESTAMP'):
+        start, stop = tstamp.get('START'), tstamp.get('STOP')
+        message += f"start={start} "
+        if start is not None and stop is not None:  # STOP is missing while the flow is still alive
+            seconds, nanoseconds = divmod(stop - start, 10 ** 9)
+            delta = timedelta(seconds=seconds, microseconds=nanoseconds / 1000)
+            message += f"stop={stop} delta={delta.total_seconds()} "
+    return message
+
+
+def parse_event_type(header: Dict) -> AnyStr:
+    """
+    Map a netlink header to 'new', 'update' or 'destroy' ('unknown' when
+    the message type is neither a conntrack NEW nor a DELETE message)
+    """
+    subsys = NFNL_SUBSYS_CTNETLINK << 8
+    msg_type = header['type']
+    if msg_type == IPCTNL_MSG_CT_DELETE | subsys:
+        return 'destroy'
+    if msg_type != IPCTNL_MSG_CT_NEW | subsys:
+        return 'unknown'
+    return 'new' if header['flags'] else 'update'  # a CT_NEW message with flags set is a creation
+
+
+def parse_proto(cta: nfct_msg.cta_tuple) -> Dict:
+    """
+    Collect the protocol part of a tuple: number, name and either the
+    ICMP type/code/id or the source/destination ports
+    """
+    cta_proto = cta.get_attr('CTA_TUPLE_PROTO')
+    proto_num = cta_proto.get_attr('CTA_PROTO_NUM')
+
+    # ICMP flavours carry type/code/id attributes, everything else is read as ports
+    if proto_num == socket.IPPROTO_ICMP:
+        prefix, fields = 'CTA_PROTO_ICMP', ('TYPE', 'CODE', 'ID')
+    elif proto_num == socket.IPPROTO_ICMPV6:
+        prefix, fields = 'CTA_PROTO_ICMPV6', ('TYPE', 'CODE', 'ID')
+    else:
+        prefix, fields = 'CTA_PROTO', ('SRC_PORT', 'DST_PORT')
+
+    parsed = {
+        'NUMBER': proto_num,
+        'NAME': PROTO_TO_NAME.get(proto_num, 'unknown'),
+    }
+    for field in fields:
+        parsed[field] = cta_proto.get_attr(f'{prefix}_{field}')
+
+    return parsed
+
+
+def parse_proto_info(cta: nfct_msg.cta_protoinfo) -> Dict:
+    """
+    Read the TCP/SCTP connection state and its printable name from nfct_msg
+    (an empty dict is returned when there is no protocol info attribute)
+    """
+    info = {}
+    if cta:
+        for proto in ('TCP', 'SCTP'):
+            nested = cta.get_attr(f'CTA_PROTOINFO_{proto}')
+            if nested:
+                state = info['STATE'] = nested.get_attr(f'CTA_PROTOINFO_{proto}_STATE')
+                info['STATE_NAME'] = PROTO_CONNTRACK_TO_NAME.get(proto, {}).get(state, 'unknown')
+    return info
+
+
+def parse_timestamp(cta: nfct_msg.cta_timestamp) -> Dict:
+    """
+    Return the flow START/STOP timestamps found in nfct_msg, or an empty
+    dict when the message has no timestamp attribute
+    """
+    if not cta:
+        return {}
+    return {
+        'START': cta.get_attr('CTA_TIMESTAMP_START'),
+        'STOP': cta.get_attr('CTA_TIMESTAMP_STOP'),
+    }
+
+
+def parse_ip_addr(family: int, cta: nfct_msg.cta_tuple) -> Dict:
+    """
+    Return the SRC/DST addresses of a tuple; the attribute names depend
+    on the address family (IPv4 or IPv6), other families are rejected
+    """
+    prefix_by_family = {
+        socket.AF_INET: 'CTA_IP_V4',
+        socket.AF_INET6: 'CTA_IP_V6',
+    }
+    cta_ip = cta.get_attr('CTA_TUPLE_IP')
+
+    prefix = prefix_by_family.get(family)
+    if prefix is None:
+        logger.error(f'Undefined INET: {family}')
+        raise NotImplementedError(family)
+
+    return {
+        side: cta_ip.get_attr(f'{prefix}_{side}') for side in ('SRC', 'DST')
+    }
+
+
+def parse_counters(cta: nfct_msg.cta_counters) -> Dict:
+    """
+    Extract counters from nfct_msg
+    :param cta: CTA_COUNTERS_ORIG/REPLY attribute of the message, may be None
+    :return: {'PACKETS': ..., 'BYTES': ...}, or an empty dict when cta is missing
+    """
+    data = dict()
+    if not cta:
+        return data
+    for key in ['PACKETS', 'BYTES']:
+        tmp = cta.get_attr(f'CTA_COUNTERS_{key}')
+        if tmp is None:  # fall back to the CTA_COUNTERS32_* attribute
+            tmp = cta.get_attr(f'CTA_COUNTERS32_{key}')
+        data[key] = tmp  # keyed by counter name: format_flow_data() looks up 'PACKETS'/'BYTES'
+    return data
+
+
+def is_need_to_log(event_type: AnyStr, proto_num: int, conf_event: Dict) -> bool:
+    """Filter message by event type and protocol; returns True when the event must be logged."""
+    conf = conf_event.get(event_type)
+    if conf is None:
+        # Event type missing from the config (e.g. 'unknown'): skip it instead of raising AttributeError
+        return False
+    # {} means "every protocol"; protocols other than icmp/tcp/udp are filtered under the 'other' key
+    return conf == {} or conf.get(SUPPORTED_PROTO_TO_NAME.get(proto_num, 'other')) is not None
+
+
+def parse_conntrack_event(msg: nfct_msg, conf_event: Dict) -> Dict:
+    """
+    Build the internal event dict ('COMMON', 'ORIG', 'REPLY') from nfct_msg.
+    An empty dict is returned for events rejected by the configured filter.
+    """
+    header = msg['header']
+    event_type = parse_event_type(header)
+    proto_num = msg.get_nested('CTA_TUPLE_ORIG', 'CTA_TUPLE_PROTO', 'CTA_PROTO_NUM')
+    if not is_need_to_log(event_type, proto_num, conf_event):
+        return dict()
+
+    common = {
+        'ID': msg.get_attr('CTA_ID'),
+        'EVENT_TYPE': event_type,
+        'TIME_OUT': msg.get_attr('CTA_TIMEOUT'),
+        'MARK': msg.get_attr('CTA_MARK'),
+        'PORTID': header.get('pid'),
+        'PROTO_INFO': parse_proto_info(msg.get_attr('CTA_PROTOINFO')),
+        'STATUS': msg.get_attr('CTA_STATUS'),
+        'TIMESTAMP': parse_timestamp(msg.get_attr('CTA_TIMESTAMP')),
+    }
+    event = {'COMMON': common}
+
+    # Both directions of the flow are parsed the same way
+    for direct in ('ORIG', 'REPLY'):
+        cta_tuple = msg.get_attr(f'CTA_TUPLE_{direct}')
+        event[direct] = {
+            'ADDR': parse_ip_addr(msg['nfgen_family'], cta_tuple),
+            'PROTO': parse_proto(cta_tuple),
+            'COUNTERS': parse_counters(msg.get_attr(f'CTA_COUNTERS_{direct}')),
+        }
+    return event
+
+
+def worker(ct: conntrack.Conntrack, shutdown_event: multiprocessing.Event, conf_event: Dict):
+    """
+    Main function of parser worker process: whenever ct.buffer_queue is not empty,
+    parse and log the messages returned by ct.get(), until shutdown_event is set
+    """
+    process_name = multiprocessing.current_process().name
+    logger.debug(f'[{process_name}] started')
+    timeout = 0.1
+    while not shutdown_event.is_set():
+        if ct.buffer_queue.empty():
+            sleep(timeout)  # nothing queued: poll again shortly instead of spinning
+            continue
+        try:
+            for msg in ct.get():
+                if parsed_event := parse_conntrack_event(msg, conf_event):
+                    message = format_event_message(parsed_event)
+                    if logger.level == logging.DEBUG:
+                        logger.debug(f"[{process_name}]: {message} raw: {msg}")
+                    else:
+                        logger.info(message)
+        except queue.Full:  # NOTE(review): presumably forwarded from the listener thread - confirm
+            logger.error("Conntrack message queue is full.")
+        except Exception as e:
+            logger.error(f"Error in queue: {e.__class__} {e}")
+
+
+if __name__ == '__main__':
+    # Entry point: load the JSON configuration, join the requested conntrack event
+    # groups and run one parser worker per CPU until SIGTERM/SIGHUP is received.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-c',
+                        '--config',
+                        action='store',
+                        help='Path to vyos-conntrack-logger configuration',
+                        required=True,
+                        type=Path)
+
+    args = parser.parse_args()
+    try:
+        config = read_json(args.config)
+    except Exception as err:
+        logger.error(f'Configuration file "{args.config}" does not exist or malformed: {err}')
+        raise SystemExit(1)  # not exit(): that is a site helper meant for interactive sessions
+
+    set_log_level(config.get('log_level', 'info'))
+
+    signal.signal(signal.SIGHUP, sig_handler)
+    signal.signal(signal.SIGTERM, sig_handler)
+
+    conf_event = config.get('event')
+    if not conf_event:  # missing or empty: there would be no event group to subscribe to
+        logger.error('Configuration is wrong. Event filter is empty.')
+        raise SystemExit(1)
+
+    qsize = config.get('queue_size')
+    ct = conntrack.Conntrack(async_qsize=int(qsize) if qsize else None)
+    ct.buffer_queue = multiprocessing.Queue(ct.async_qsize)
+    ct.bind(async_cache=True)
+
+    for name in conf_event:
+        if group := EVENT_NAME_TO_GROUP.get(name):
+            ct.add_membership(group)
+        else:
+            logger.error(f'Unexpected event group {name}')
+    processes = list()
+    try:
+        for _ in range(multiprocessing.cpu_count()):
+            p = multiprocessing.Process(target=worker, args=(ct, shutdown_event, conf_event))
+            processes.append(p)
+            p.start()
+        logger.info('Conntrack socket bound and listening for messages.')
+
+        while not shutdown_event.is_set():
+            # Restart the listener thread after a queue overload, once the workers
+            # have drained the queue below 90% of its capacity.
+            if not ct.pthread.is_alive() and ct.buffer_queue.qsize() / ct.async_qsize < 0.9:
+                logger.debug('Restart listener thread')
+                ct.pthread = threading.Thread(
+                    name="Netlink async cache", target=ct.async_recv
+                )
+                ct.pthread.daemon = True
+                ct.pthread.start()
+            else:
+                # Sleep while the queue drains too: spinning here would steal CPU from the workers
+                sleep(0.1)
+    finally:
+        # Stop the workers on any exit path, otherwise join() below would never return
+        shutdown_event.set()
+        for p in processes:
+            p.join()
+            if not p.is_alive():
+                logger.debug(f"[{p.name}]: finished")
+        ct.close()
+        logger.info("Conntrack socket closed.")
+    raise SystemExit(0)
-- 
cgit v1.2.3