#!/usr/bin/env python3
#
# Copyright (C) 2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import grp
import logging
import multiprocessing
import os
import queue
import signal
import socket
import threading
from datetime import timedelta
from pathlib import Path
from time import sleep
from typing import Dict, AnyStr

from pyroute2 import conntrack
from pyroute2.netlink import nfnetlink
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_CTNETLINK
from pyroute2.netlink.nfnetlink.nfctsocket import nfct_msg, \
    IPCTNL_MSG_CT_DELETE, IPCTNL_MSG_CT_NEW, IPS_SEEN_REPLY, \
    IPS_OFFLOAD, IPS_ASSURED

from vyos.utils.file import read_json

# Set by the signal handler; the main loop and every worker poll it to know
# when to stop.
shutdown_event = multiprocessing.Event()

# Default (non-debug) output is the bare message.
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)


class DebugFormatter(logging.Formatter):
    """Formatter that prefixes every record with a timestamp and level name."""

    def format(self, record):
        """Format *record* as ``[asctime] LEVEL: message``."""
        # The layout is forced here, so whatever format string the formatter
        # was constructed with is ignored.
        self._style._fmt = '[%(asctime)s] %(levelname)s: %(message)s'
        return super().format(record)


def set_log_level(level: str) -> None:
    """
    Set the verbosity of the module logger.

    ``'debug'`` enables DEBUG output and switches the root handlers to
    DebugFormatter; any other value selects INFO.
    """
    if level == 'debug':
        logger.setLevel(logging.DEBUG)
        # basicConfig() installs its handler on the root logger. Iterating the
        # root handlers (instead of indexing logger.parent.handlers[0]) cannot
        # raise IndexError and does not depend on what this logger's parent is.
        for handler in logging.getLogger().handlers:
            handler.setFormatter(DebugFormatter())
    else:
        logger.setLevel(logging.INFO)


# Configuration event name -> netlink multicast group to subscribe to.
EVENT_NAME_TO_GROUP = {
    'new': nfnetlink.NFNLGRP_CONNTRACK_NEW,
    'update': nfnetlink.NFNLGRP_CONNTRACK_UPDATE,
    'destroy': nfnetlink.NFNLGRP_CONNTRACK_DESTROY
}

# https://github.com/torvalds/linux/blob/1dfe225e9af5bd3399a1dbc6a4df6a6041ff9c23/include/uapi/linux/netfilter/nf_conntrack_tcp.h#L9
TCP_CONNTRACK_SYN_SENT = 1
TCP_CONNTRACK_SYN_RECV = 2
TCP_CONNTRACK_ESTABLISHED = 3
TCP_CONNTRACK_FIN_WAIT = 4
TCP_CONNTRACK_CLOSE_WAIT = 5
TCP_CONNTRACK_LAST_ACK = 6
TCP_CONNTRACK_TIME_WAIT = 7
TCP_CONNTRACK_CLOSE = 8
TCP_CONNTRACK_LISTEN = 9
TCP_CONNTRACK_MAX = 10
TCP_CONNTRACK_IGNORE = 11
TCP_CONNTRACK_RETRANS = 12
TCP_CONNTRACK_UNACK = 13
TCP_CONNTRACK_TIMEOUT_MAX = 14

# TCP state names in enum order. The TCP_CONNTRACK_* values are contiguous
# and start at 1 (SYN_SENT), so the value -> name table is built by position.
_TCP_CONNTRACK_NAMES = (
    "SYN_SENT",
    "SYN_RECV",
    "ESTABLISHED",
    "FIN_WAIT",
    "CLOSE_WAIT",
    "LAST_ACK",
    "TIME_WAIT",
    "CLOSE",
    "LISTEN",
    "MAX",
    "IGNORE",
    "RETRANS",
    "UNACK",
    "TIMEOUT_MAX",
)
TCP_CONNTRACK_TO_NAME = dict(enumerate(_TCP_CONNTRACK_NAMES, start=1))

# https://github.com/torvalds/linux/blob/1dfe225e9af5bd3399a1dbc6a4df6a6041ff9c23/include/uapi/linux/netfilter/nf_conntrack_sctp.h#L8
SCTP_CONNTRACK_CLOSED = 1
SCTP_CONNTRACK_COOKIE_WAIT = 2
SCTP_CONNTRACK_COOKIE_ECHOED = 3
SCTP_CONNTRACK_ESTABLISHED = 4
SCTP_CONNTRACK_SHUTDOWN_SENT = 5
SCTP_CONNTRACK_SHUTDOWN_RECD = 6
SCTP_CONNTRACK_SHUTDOWN_ACK_SENT = 7
SCTP_CONNTRACK_HEARTBEAT_SENT = 8
SCTP_CONNTRACK_HEARTBEAT_ACKED = 9  # no longer used
SCTP_CONNTRACK_MAX = 10

# SCTP state names in enum order (values are contiguous, see above).
_SCTP_CONNTRACK_NAMES = (
    'CLOSED',
    'COOKIE_WAIT',
    'COOKIE_ECHOED',
    'ESTABLISHED',
    'SHUTDOWN_SENT',
    'SHUTDOWN_RECD',
    'SHUTDOWN_ACK_SENT',
    'HEARTBEAT_SENT',
    'HEARTBEAT_ACKED',
    'MAX',
)
SCTP_CONNTRACK_TO_NAME = dict(
    enumerate(_SCTP_CONNTRACK_NAMES, start=SCTP_CONNTRACK_CLOSED))

# Protocol tag used in CTA_PROTOINFO_<tag> -> state value/name table.
PROTO_CONNTRACK_TO_NAME = {
    'TCP': TCP_CONNTRACK_TO_NAME,
    'SCTP': SCTP_CONNTRACK_TO_NAME
}

# Protocols that can be selected by name in the event filter; every other
# protocol is matched by the filter key 'other' (see is_need_to_log).
SUPPORTED_PROTO_TO_NAME = {
    socket.IPPROTO_ICMP: 'icmp',
    socket.IPPROTO_TCP: 'tcp',
    socket.IPPROTO_UDP: 'udp',
}

# Protocol number -> name used in the log line.
PROTO_TO_NAME = {
    socket.IPPROTO_ICMPV6: 'icmpv6',
    socket.IPPROTO_SCTP: 'sctp',
    socket.IPPROTO_GRE: 'gre',
}
PROTO_TO_NAME.update(SUPPORTED_PROTO_TO_NAME)


def sig_handler(signum, frame):
    """
    Signal handler: log which signal arrived and request shutdown.

    SIGTERM is reported as "Shutdown", any other signal as "Reload"; in both
    cases the shared shutdown_event is set.
    """
    process_name = multiprocessing.current_process().name
    action = "Shutdown" if signum == signal.SIGTERM else "Reload"
    logger.debug(f'[{process_name}]: {action} signal received...')
    shutdown_event.set()


def format_flow_data(data: dict) -> str:
    """
    Formats the flow event data into a string suitable for logging.

    *data* is one direction of a parsed event: it must contain 'ADDR' and
    'PROTO' dicts and may contain a 'COUNTERS' dict. Fields whose value is
    None are left out.
    """
    addr = data['ADDR']
    proto = data['PROTO']
    parts = [f"src={addr.get('SRC')}", f"dst={addr.get('DST')}"]

    # (key in the parsed data, label in the log line)
    proto_fields = (
        ('SRC_PORT', 'sport'),
        ('DST_PORT', 'dport'),
        ('TYPE', 'type'),
        ('CODE', 'code'),
        ('ID', 'id'),
    )
    for key, label in proto_fields:
        value = proto.get(key)
        if value is not None:
            parts.append(f"{label}={value}")

    counters = data.get('COUNTERS') or {}
    for key in ('PACKETS', 'BYTES'):
        value = counters.get(key)
        if value is not None:
            parts.append(f"{key.lower()}={value}")

    return ' '.join(parts)


def format_event_message(event: dict) -> str:
    """
    Formats the internal parsed event data into a string suitable for logging.

    *event* is the dict produced by parse_conntrack_event (keys 'COMMON',
    'ORIG' and 'REPLY').
    """
    common = event['COMMON']
    orig_proto = event['ORIG']['PROTO']
    # A missing CTA_STATUS attribute is parsed as None; treat it as "no flags
    # set" instead of failing on `None & flag`.
    status = common.get('STATUS') or 0

    event_type = f"[{common['EVENT_TYPE'].upper()}]"
    parts = [
        f"{event_type:<9} {common['ID']} "
        f"{orig_proto.get('NAME'):<8} "
        f"{orig_proto.get('NUMBER')} "
    ]

    timeout = common['TIME_OUT']
    if timeout is not None:
        parts.append(f"{timeout} ")

    if proto_info := common.get('PROTO_INFO'):
        parts.append(f"{proto_info.get('STATE_NAME')} ")

    for direction in ('ORIG', 'REPLY'):
        parts.append(f"{format_flow_data(event[direction])} ")
        if direction == 'ORIG' and not status & IPS_SEEN_REPLY:
            parts.append("[UNREPLIED] ")

    mark = common['MARK']
    if mark is not None:
        parts.append(f"mark={mark} ")

    if status & IPS_OFFLOAD:
        parts.append(" [OFFLOAD] ")
    elif status & IPS_ASSURED:
        parts.append(" [ASSURED] ")

    if portid := common['PORTID']:
        parts.append(f"portid={portid} ")

    if tstamp := common.get('TIMESTAMP'):
        start = tstamp.get('START')
        stop = tstamp.get('STOP')
        parts.append(f"start={start} stop={stop} ")
        # Either attribute may be absent (parsed as None); a duration can only
        # be computed when both are present.
        if start is not None and stop is not None:
            # The arithmetic treats the difference as nanoseconds.
            seconds, remaining_ns = divmod(stop - start, 1_000_000_000)
            delta = timedelta(seconds=seconds,
                              microseconds=remaining_ns / 1000)
            parts.append(f"delta={delta.total_seconds()} ")

    return ''.join(parts)


def parse_event_type(header: dict) -> str:
    """
    Extract event type from nfct_msg.

    Returns 'new', 'update', 'destroy' or 'unknown'.
    """
    msg_type = header['type']
    subsys = NFNL_SUBSYS_CTNETLINK << 8

    if msg_type == (IPCTNL_MSG_CT_DELETE | subsys):
        return 'destroy'
    if msg_type == (IPCTNL_MSG_CT_NEW | subsys):
        # The same message type carries both kinds of event; it is reported
        # as 'new' when any header flag is set and as 'update' otherwise.
        return 'new' if header['flags'] else 'update'
    return 'unknown'


def parse_proto(cta: 'nfct_msg.cta_tuple') -> dict:
    """
    Extract proto info from nfct_msg.

    Always returns 'NUMBER' and 'NAME'; ICMP/ICMPv6 tuples additionally yield
    'TYPE', 'CODE' and 'ID', every other protocol 'SRC_PORT' and 'DST_PORT'
    (None for attributes the message does not carry).
    """
    cta_proto = cta.get_attr('CTA_TUPLE_PROTO')
    proto_num = cta_proto.get_attr('CTA_PROTO_NUM')

    data = {
        'NUMBER': proto_num,
        'NAME': PROTO_TO_NAME.get(proto_num, 'unknown'),
    }

    if proto_num == socket.IPPROTO_ICMP:
        pref, keys = 'CTA_PROTO_ICMP', ('TYPE', 'CODE', 'ID')
    elif proto_num == socket.IPPROTO_ICMPV6:
        pref, keys = 'CTA_PROTO_ICMPV6', ('TYPE', 'CODE', 'ID')
    else:
        pref, keys = 'CTA_PROTO', ('SRC_PORT', 'DST_PORT')

    for key in keys:
        data[key] = cta_proto.get_attr(f'{pref}_{key}')
    return data


def parse_proto_info(cta: 'nfct_msg.cta_protoinfo') -> dict:
    """
    Extract proto state and state name from nfct_msg.

    Returns an empty dict when the attribute is missing or carries neither
    TCP nor SCTP information.
    """
    data = dict()
    if not cta:
        return data

    for proto in ('TCP', 'SCTP'):
        if proto_info := cta.get_attr(f'CTA_PROTOINFO_{proto}'):
            state = proto_info.get_attr(f'CTA_PROTOINFO_{proto}_STATE')
            data['STATE'] = state
            data['STATE_NAME'] = \
                PROTO_CONNTRACK_TO_NAME.get(proto, {}).get(state, 'unknown')
    return data


def parse_timestamp(cta: 'nfct_msg.cta_timestamp') -> dict:
    """
    Extract timestamp from nfct_msg.

    Returns an empty dict when the attribute is missing, otherwise 'START'
    and 'STOP' (each None if the message does not carry it).
    """
    if not cta:
        return dict()
    return {
        'START': cta.get_attr('CTA_TIMESTAMP_START'),
        'STOP': cta.get_attr('CTA_TIMESTAMP_STOP'),
    }


def parse_ip_addr(family: int, cta: 'nfct_msg.cta_tuple') -> dict:
    """
    Extract source and destination IP address from nfct_msg.

    Raises NotImplementedError for an address family other than
    AF_INET / AF_INET6.
    """
    cta_ip = cta.get_attr('CTA_TUPLE_IP')

    if family == socket.AF_INET:
        pref = 'CTA_IP_V4'
    elif family == socket.AF_INET6:
        pref = 'CTA_IP_V6'
    else:
        logger.error(f'Undefined INET: {family}')
        raise NotImplementedError(family)

    return {
        direct: cta_ip.get_attr(f'{pref}_{direct}')
        for direct in ('SRC', 'DST')
    }


def parse_counters(cta: 'nfct_msg.cta_counters') -> dict:
    """
    Extract counters from nfct_msg.

    Returns an empty dict when the attribute is missing, otherwise 'PACKETS'
    and 'BYTES'. CTA_COUNTERS_<name> is preferred and CTA_COUNTERS32_<name>
    is used as a fallback.
    """
    data = dict()
    if not cta:
        return data

    for key in ('PACKETS', 'BYTES'):
        value = cta.get_attr(f'CTA_COUNTERS_{key}')
        if value is None:
            value = cta.get_attr(f'CTA_COUNTERS32_{key}')
        # Store under the counter's own name: format_flow_data looks up
        # 'PACKETS' / 'BYTES'.
        data[key] = value
    return data


def is_need_to_log(event_type: str, proto_num: int, conf_event: dict) -> bool:
    """
    Filter message by event type and protocols.

    An event type is logged when it is present in *conf_event* and its filter
    is either empty (all protocols) or contains the protocol's name; protocols
    outside SUPPORTED_PROTO_TO_NAME are looked up as 'other'.
    """
    conf = conf_event.get(event_type)
    if conf is None:
        # Event type (e.g. 'unknown') is not configured at all.
        return False
    proto_name = SUPPORTED_PROTO_TO_NAME.get(proto_num, 'other')
    return conf == {} or conf.get(proto_name) is not None


def parse_conntrack_event(msg: 'nfct_msg', conf_event: dict) -> dict:
    """
    Convert nfct_msg to internal data dict.

    Returns an empty dict when the event is filtered out by *conf_event*.
    """
    event_type = parse_event_type(msg['header'])
    proto_num = msg.get_nested('CTA_TUPLE_ORIG',
                               'CTA_TUPLE_PROTO',
                               'CTA_PROTO_NUM')

    if not is_need_to_log(event_type, proto_num, conf_event):
        return dict()

    data = {
        'COMMON': {
            'ID': msg.get_attr('CTA_ID'),
            'EVENT_TYPE': event_type,
            'TIME_OUT': msg.get_attr('CTA_TIMEOUT'),
            'MARK': msg.get_attr('CTA_MARK'),
            'PORTID': msg['header'].get('pid'),
            'PROTO_INFO': parse_proto_info(msg.get_attr('CTA_PROTOINFO')),
            'STATUS': msg.get_attr('CTA_STATUS'),
            'TIMESTAMP': parse_timestamp(msg.get_attr('CTA_TIMESTAMP'))
        },
        'ORIG': {},
        'REPLY': {},
    }

    for direct in ('ORIG', 'REPLY'):
        cta_tuple = msg.get_attr(f'CTA_TUPLE_{direct}')
        data[direct]['ADDR'] = parse_ip_addr(msg['nfgen_family'], cta_tuple)
        data[direct]['PROTO'] = parse_proto(cta_tuple)
        data[direct]['COUNTERS'] = \
            parse_counters(msg.get_attr(f'CTA_COUNTERS_{direct}'))

    return data


def worker(ct: 'conntrack.Conntrack', shutdown_event: multiprocessing.Event,
           conf_event: dict):
    """
    Main function of parser worker process.

    Polls the conntrack buffer queue until *shutdown_event* is set, converts
    every received message and logs the ones that pass the event filter.
    """
    process_name = multiprocessing.current_process().name
    logger.debug(f'[{process_name}] started')
    timeout = 0.1

    while not shutdown_event.is_set():
        if ct.buffer_queue.empty():
            sleep(timeout)
            continue

        try:
            for msg in ct.get():
                parsed_event = parse_conntrack_event(msg, conf_event)
                if not parsed_event:
                    continue

                message = format_event_message(parsed_event)
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"[{process_name}]: {message} raw: {msg}")
                else:
                    logger.info(message)
        except queue.Full:
            logger.error("Conntrack message queue is full.")
        except Exception as e:
            # Keep the worker alive: report the failure and go on polling.
            logger.error(f"Error in queue: {e.__class__} {e}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--config',
                        action='store',
                        help='Path to vyos-conntrack-logger configuration',
                        required=True,
                        type=Path)

    args = parser.parse_args()
    try:
        config = read_json(args.config)
    except Exception as err:
        logger.error(
            f'Configuration file "{args.config}" does not exist or malformed: {err}')
        raise SystemExit(1)

    set_log_level(config.get('log_level', 'info'))

    signal.signal(signal.SIGHUP, sig_handler)
    signal.signal(signal.SIGTERM, sig_handler)

    if 'event' not in config:
        logger.error('Configuration is wrong. Event filter is empty.')
        raise SystemExit(1)

    conf_event = config['event']
    event_groups = list(conf_event)

    qsize = config.get('queue_size')
    ct = conntrack.Conntrack(async_qsize=int(qsize) if qsize else None)
    # Replace the socket's buffer queue with a multiprocessing queue of the
    # same size so that the worker processes can consume from it.
    ct.buffer_queue = multiprocessing.Queue(ct.async_qsize)
    ct.bind(async_cache=True)

    for name in event_groups:
        if group := EVENT_NAME_TO_GROUP.get(name):
            ct.add_membership(group)
        else:
            logger.error(f'Unexpected event group {name}')

    processes = list()
    try:
        for _ in range(multiprocessing.cpu_count()):
            p = multiprocessing.Process(target=worker,
                                        args=(ct, shutdown_event, conf_event))
            p.start()
            # Track only processes that really started: join() on a process
            # that was never started raises.
            processes.append(p)
        logger.info('Conntrack socket bound and listening for messages.')

        while not shutdown_event.is_set():
            listener_dead = not ct.pthread.is_alive()
            if listener_dead \
                    and ct.buffer_queue.qsize() / ct.async_qsize < 0.9 \
                    and not shutdown_event.is_set():
                # Restart the listener thread (it stops after the queue was
                # overloaded) once the queue is filled to less than 90%.
                logger.debug('Restart listener thread')
                ct.pthread = threading.Thread(
                    name="Netlink async cache", target=ct.async_recv
                )
                ct.pthread.daemon = True
                ct.pthread.start()
            else:
                # Also sleep while waiting for the queue to drain, so a dead
                # listener with a full queue does not turn into a busy loop.
                sleep(0.1)
    finally:
        # The workers only leave their loop once the event is set. Set it
        # here as well, so that an exception above cannot leave join()
        # waiting forever.
        shutdown_event.set()
        for p in processes:
            p.join()
            if not p.is_alive():
                logger.debug(f"[{p.name}]: finished")
        ct.close()
        logger.info("Conntrack socket closed.")
        # No exit() here: leaving the process from inside `finally` would
        # discard an in-flight exception and report success.