diff options
-rw-r--r-- | Makefile | 4 | ||||
-rw-r--r-- | data/templates/https/vyos-http-api.service.j2 | 1 | ||||
-rw-r--r-- | data/templates/pmacct/override.conf.j2 | 4 | ||||
-rw-r--r-- | data/templates/pmacct/uacctd.conf.j2 | 2 | ||||
-rw-r--r-- | python/vyos/progressbar.py | 70 | ||||
-rw-r--r-- | python/vyos/remote.py | 44 | ||||
-rw-r--r-- | python/vyos/utils/io.py | 39 | ||||
-rwxr-xr-x | src/conf_mode/flow_accounting_conf.py | 34 | ||||
-rwxr-xr-x | src/conf_mode/http-api.py | 6 | ||||
-rw-r--r-- | src/etc/sysctl.d/30-vyos-router.conf | 3 | ||||
-rwxr-xr-x | src/services/vyos-http-api-server | 173 | ||||
-rwxr-xr-x | src/system/uacctd_stop.py | 67 |
12 files changed, 334 insertions, 113 deletions
@@ -26,10 +26,10 @@ interface_definitions: $(config_xml_obj) $(CURDIR)/scripts/override-default $(BUILD_DIR)/interface-definitions - $(CURDIR)/python/vyos/xml_ref/generate_cache.py --xml-dir $(BUILD_DIR)/interface-definitions || exit 1 - find $(BUILD_DIR)/interface-definitions -type f -name "*.xml" | xargs -I {} $(CURDIR)/scripts/build-command-templates {} $(CURDIR)/schema/interface_definition.rng $(TMPL_DIR) || exit 1 + $(CURDIR)/python/vyos/xml_ref/generate_cache.py --xml-dir $(BUILD_DIR)/interface-definitions || exit 1 + # XXX: delete top level node.def's that now live in other packages # IPSec VPN EAP-RADIUS does not support source-address rm -rf $(TMPL_DIR)/vpn/ipsec/remote-access/radius/source-address diff --git a/data/templates/https/vyos-http-api.service.j2 b/data/templates/https/vyos-http-api.service.j2 index fb424e06c..f620b3248 100644 --- a/data/templates/https/vyos-http-api.service.j2 +++ b/data/templates/https/vyos-http-api.service.j2 @@ -6,6 +6,7 @@ Requires=vyos-router.service [Service] ExecStart={{ vrf_command }}/usr/libexec/vyos/services/vyos-http-api-server +ExecReload=kill -HUP $MAINPID Type=idle SyslogIdentifier=vyos-http-api diff --git a/data/templates/pmacct/override.conf.j2 b/data/templates/pmacct/override.conf.j2 index 213569ddc..44a100bb6 100644 --- a/data/templates/pmacct/override.conf.j2 +++ b/data/templates/pmacct/override.conf.j2 @@ -9,9 +9,9 @@ ConditionPathExists=/run/pmacct/uacctd.conf EnvironmentFile= ExecStart= ExecStart={{ vrf_command }}/usr/sbin/uacctd -f /run/pmacct/uacctd.conf +ExecStop=/usr/libexec/vyos/system/uacctd_stop.py $MAINPID 60 WorkingDirectory= WorkingDirectory=/run/pmacct -PIDFile= -PIDFile=/run/pmacct/uacctd.pid Restart=always RestartSec=10 +KillMode=mixed diff --git a/data/templates/pmacct/uacctd.conf.j2 b/data/templates/pmacct/uacctd.conf.j2 index 1370f8121..aae0a0619 100644 --- a/data/templates/pmacct/uacctd.conf.j2 +++ b/data/templates/pmacct/uacctd.conf.j2 @@ -1,7 +1,7 @@ # Genereated from VyOS configuration daemonize: true promisc: false -pidfile: /run/pmacct/uacctd.pid +syslog: daemon uacctd_group: 2 uacctd_nl_size: 2097152 snaplen: {{ packet_length }} diff --git a/python/vyos/progressbar.py b/python/vyos/progressbar.py new file mode 100644 index 000000000..1793c445b --- /dev/null +++ b/python/vyos/progressbar.py @@ -0,0 +1,70 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +import math +import os +import signal +import subprocess +import sys + +from vyos.utils.io import print_error + +class Progressbar: + def __init__(self, step=None): + self.total = 0.0 + self.step = step + def __enter__(self): + # Recalculate terminal width with every window resize. + signal.signal(signal.SIGWINCH, lambda signum, frame: self._update_cols()) + # Disable line wrapping to prevent the staircase effect. + subprocess.run(['tput', 'rmam'], check=False) + self._update_cols() + # Print an empty progressbar with entry. + self.progress(0, 1) + return self + def __exit__(self, exc_type, kexc_val, exc_tb): + # Revert to the default SIGWINCH handler (ie nothing). + signal.signal(signal.SIGWINCH, signal.SIG_DFL) + # Reenable line wrapping. + subprocess.run(['tput', 'smam'], check=False) + def _update_cols(self): + # `os.get_terminal_size()' is fast enough for our purposes. + self.col = max(os.get_terminal_size().columns - 15, 20) + def increment(self): + """ + Stateful progressbar taking the step fraction at init and no input at + callback (for FTP) + """ + if self.step: + if self.total < 1.0: + self.total += self.step + if self.total >= 1.0: + self.total = 1.0 + # Ignore superfluous calls caused by fuzzy FTP size calculations. + self.step = None + self.progress(self.total, 1.0) + def progress(self, done, total): + """ + Stateless progressbar taking no input at init and current progress with + final size at callback (for SSH) + """ + if done <= total: + length = math.ceil(self.col * done / total) + percentage = str(math.ceil(100 * done / total)).rjust(3) + # Carriage return at the end will make sure the line will get overwritten. + print_error(f'[{length * "#"}{(self.col - length) * "_"}] {percentage}%', end='\r') + # Print a newline to make sure the full progressbar doesn't get overwritten by the next line. + if done == total: + print_error() diff --git a/python/vyos/remote.py b/python/vyos/remote.py index cf731c881..1ca8a9530 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -32,9 +32,8 @@ from requests import Session from requests.adapters import HTTPAdapter from requests.packages.urllib3 import PoolManager +from vyos.progressbar import Progressbar from vyos.utils.io import ask_yes_no -from vyos.utils.io import make_incremental_progressbar -from vyos.utils.io import make_progressbar from vyos.utils.io import print_error from vyos.utils.misc import begin from vyos.utils.process import cmd @@ -131,16 +130,16 @@ class FtpC: if self.secure: conn.prot_p() # Almost all FTP servers support the `SIZE' command. + size = conn.size(self.path) if self.check_space: - check_storage(path, conn.size(self.path)) + check_storage(path, size) # No progressbar if we can't determine the size or if the file is too small. if self.progressbar and size and size > CHUNK_SIZE: - progress = make_incremental_progressbar(CHUNK_SIZE / size) - next(progress) - callback = lambda block: begin(f.write(block), next(progress)) + with Progressbar(CHUNK_SIZE / size) as p: + callback = lambda block: begin(f.write(block), p.increment()) + conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) else: - callback = f.write - conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) + conn.retrbinary('RETR ' + self.path, f.write, CHUNK_SIZE) def upload(self, location: str): size = os.path.getsize(location) @@ -150,12 +149,10 @@ class FtpC: if self.secure: conn.prot_p() if self.progressbar and size and size > CHUNK_SIZE: - progress = make_incremental_progressbar(CHUNK_SIZE / size) - next(progress) - callback = lambda block: next(progress) + with Progressbar(CHUNK_SIZE / size) as p: + conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, lambda block: p.increment()) else: - callback = None - conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, callback) + conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE) class SshC: known_hosts = os.path.expanduser('~/.ssh/known_hosts') @@ -190,14 +187,16 @@ class SshC: return ssh def download(self, location: str): - callback = make_progressbar() if self.progressbar else None with self._establish() as ssh, ssh.open_sftp() as sftp: if self.check_space: check_storage(location, sftp.stat(self.path).st_size) - sftp.get(self.path, location, callback=callback) + if self.progressbar: + with Progressbar() as p: + sftp.get(self.path, location, callback=p.progress) + else: + sftp.get(self.path, location) def upload(self, location: str): - callback = make_progressbar() if self.progressbar else None with self._establish() as ssh, ssh.open_sftp() as sftp: try: # If the remote path is a directory, use the original filename. @@ -210,7 +209,11 @@ class SshC: except IOError: path = self.path finally: - sftp.put(location, path, callback=callback) + if self.progressbar: + with Progressbar() as p: + sftp.put(location, path, callback=p.progress) + else: + sftp.put(location, path) class HttpC: @@ -264,10 +267,9 @@ class HttpC: with s.get(final_urlstring, stream=True, timeout=self.timeout) as r, open(location, 'wb') as f: if self.progressbar and size: - progress = make_incremental_progressbar(CHUNK_SIZE / size) - next(progress) - for chunk in iter(lambda: begin(next(progress), r.raw.read(CHUNK_SIZE)), b''): - f.write(chunk) + with Progressbar(CHUNK_SIZE / size) as p: + for chunk in iter(lambda: begin(p.increment(), r.raw.read(CHUNK_SIZE)), b''): + f.write(chunk) else: # We'll try to stream the download directly with `copyfileobj()` so that large # files (like entire VyOS images) don't occupy much memory. diff --git a/python/vyos/utils/io.py b/python/vyos/utils/io.py index 843494855..5fffa62f8 100644 --- a/python/vyos/utils/io.py +++ b/python/vyos/utils/io.py @@ -24,45 +24,6 @@ def print_error(str='', end='\n'): sys.stderr.write(end) sys.stderr.flush() -def make_progressbar(): - """ - Make a procedure that takes two arguments `done` and `total` and prints a - progressbar based on the ratio thereof, whose length is determined by the - width of the terminal. - """ - import shutil, math - col, _ = shutil.get_terminal_size() - col = max(col - 15, 20) - def print_progressbar(done, total): - if done <= total: - increment = total / col - length = math.ceil(done / increment) - percentage = str(math.ceil(100 * done / total)).rjust(3) - print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') - # Print a newline so that the subsequent prints don't overwrite the full bar. - if done == total: - print_error() - return print_progressbar - -def make_incremental_progressbar(increment: float): - """ - Make a generator that displays a progressbar that grows monotonically with - every iteration. - First call displays it at 0% and every subsequent iteration displays it - at `increment` increments where 0.0 < `increment` < 1.0. - Intended for FTP and HTTP transfers with stateless callbacks. - """ - print_progressbar = make_progressbar() - total = 0.0 - while total < 1.0: - print_progressbar(total, 1.0) - yield - total += increment - print_progressbar(1, 1) - # Ignore further calls. - while True: - yield - def ask_input(question, default='', numeric_only=False, valid_responses=[]): question_out = question if default: diff --git a/src/conf_mode/flow_accounting_conf.py b/src/conf_mode/flow_accounting_conf.py index 71acd69fa..f29fc94fb 100755 --- a/src/conf_mode/flow_accounting_conf.py +++ b/src/conf_mode/flow_accounting_conf.py @@ -28,6 +28,7 @@ from vyos.ifconfig import Section from vyos.template import render from vyos.utils.process import call from vyos.utils.process import cmd +from vyos.utils.process import run from vyos.utils.network import is_addr_assigned from vyos import ConfigError from vyos import airbag @@ -116,6 +117,30 @@ def _nftables_config(configured_ifaces, direction, length=None): cmd(command, raising=ConfigError) +def _nftables_trigger_setup(operation: str) -> None: + """Add a dummy rule to unlock the main pmacct loop with a packet-trigger + + Args: + operation (str): 'add' or 'delete' a trigger + """ + # check if a chain exists + table_exists = False + if run('nft -snj list table ip pmacct') == 0: + table_exists = True + + if operation == 'delete' and table_exists: + nft_cmd: str = 'nft delete table ip pmacct' + cmd(nft_cmd, raising=ConfigError) + if operation == 'add' and not table_exists: + nft_cmds: list[str] = [ + 'nft add table ip pmacct', + 'nft add chain ip pmacct pmacct_out { type filter hook output priority raw - 50 \\; policy accept \\; }', + 'nft add rule ip pmacct pmacct_out oif lo ip daddr 127.0.254.0 counter log group 2 snaplen 1 queue-threshold 0 comment NFLOG_TRIGGER' + ] + for nft_cmd in nft_cmds: + cmd(nft_cmd, raising=ConfigError) + + def get_config(config=None): if config: conf = config @@ -252,7 +277,6 @@ def generate(flow_config): call('systemctl daemon-reload') def apply(flow_config): - action = 'restart' # Check if flow-accounting was removed and define command if not flow_config: _nftables_config([], 'ingress') @@ -262,6 +286,10 @@ def apply(flow_config): call(f'systemctl stop {systemd_service}') if os.path.exists(uacctd_conf_path): os.unlink(uacctd_conf_path) + + # must be done after systemctl + _nftables_trigger_setup('delete') + return # Start/reload flow-accounting daemon @@ -277,6 +305,10 @@ def apply(flow_config): else: _nftables_config([], 'egress') + # add a trigger for signal processing + _nftables_trigger_setup('add') + + if __name__ == '__main__': try: config = get_config() diff --git a/src/conf_mode/http-api.py b/src/conf_mode/http-api.py index 793a90d88..d8fe3b736 100755 --- a/src/conf_mode/http-api.py +++ b/src/conf_mode/http-api.py @@ -27,6 +27,7 @@ from vyos.config import Config from vyos.configdep import set_dependents, call_dependents from vyos.template import render from vyos.utils.process import call +from vyos.utils.process import is_systemd_service_running from vyos import ConfigError from vyos import airbag airbag.enable() @@ -130,7 +131,10 @@ def apply(http_api): service_name = 'vyos-http-api.service' if http_api is not None: - call(f'systemctl restart {service_name}') + if is_systemd_service_running(f'{service_name}'): + call(f'systemctl reload {service_name}') + else: + call(f'systemctl restart {service_name}') else: call(f'systemctl stop {service_name}') diff --git a/src/etc/sysctl.d/30-vyos-router.conf b/src/etc/sysctl.d/30-vyos-router.conf index fcdc1b21d..1c9b8999f 100644 --- a/src/etc/sysctl.d/30-vyos-router.conf +++ b/src/etc/sysctl.d/30-vyos-router.conf @@ -21,7 +21,6 @@ net.ipv4.conf.all.arp_filter=0 # https://vyos.dev/T300 net.ipv4.conf.all.arp_ignore=0 - net.ipv4.conf.all.arp_announce=2 # Enable packet forwarding for IPv4 @@ -103,6 +102,6 @@ net.ipv4.igmp_max_memberships = 512 net.core.rps_sock_flow_entries = 32768 # Congestion control -net.core.default_qdisc=fq +net.core.default_qdisc=fq_codel net.ipv4.tcp_congestion_control=bbr diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 66e80ced5..3a9efb73e 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -22,12 +22,14 @@ import grp import copy import json import logging +import signal import traceback import threading +from time import sleep from typing import List, Union, Callable, Dict -import uvicorn from fastapi import FastAPI, Depends, Request, Response, HTTPException +from fastapi import BackgroundTasks from fastapi.responses import HTMLResponse from fastapi.exceptions import RequestValidationError from fastapi.routing import APIRoute @@ -36,10 +38,14 @@ from starlette.middleware.cors import CORSMiddleware from starlette.datastructures import FormData from starlette.formparsers import FormParser, MultiPartParser from multipart.multipart import parse_options_header +from uvicorn import Config as UvicornConfig +from uvicorn import Server as UvicornServer from ariadne.asgi import GraphQL -import vyos.config +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSession, ConfigSessionError import api.graphql.state @@ -410,12 +416,24 @@ app.router.route_class = MultipartRoute async def validation_exception_handler(request, exc): return error(400, str(exc.errors()[0])) +self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" + +def call_commit(s: ConfigSession): + try: + s.commit() + except ConfigSessionError as e: + s.discard() + if app.state.vyos_debug: + logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + else: + logger.warning(f"ConfigSessionError: {e}") + def _configure_op(data: Union[ConfigureModel, ConfigureListModel, ConfigSectionModel, ConfigSectionListModel], - request: Request): + request: Request, background_tasks: BackgroundTasks): session = app.state.vyos_session env = session.get_session_env() - config = vyos.config.Config(session_env=env) + config = Config(session_env=env) endpoint = request.url.path @@ -470,7 +488,15 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, else: raise ConfigSessionError(f"'{op}' is not a valid operation") # end for - session.commit() + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, session) + msg = self_ref_msg + else: + session.commit() + logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'") except ConfigSessionError as e: session.discard() @@ -495,21 +521,21 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, @app.post('/configure') def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request): - return _configure_op(data, request) + ConfigureListModel], + request: Request, background_tasks: BackgroundTasks): + return _configure_op(data, request, background_tasks) @app.post('/configure-section') def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel], - request: Request): - return _configure_op(data, request) + ConfigSectionListModel], + request: Request, background_tasks: BackgroundTasks): + return _configure_op(data, request, background_tasks) @app.post("/retrieve") async def retrieve_op(data: RetrieveModel): session = app.state.vyos_session env = session.get_session_env() - config = vyos.config.Config(session_env=env) + config = Config(session_env=env) op = data.op path = " ".join(data.path) @@ -528,10 +554,10 @@ async def retrieve_op(data: RetrieveModel): res = session.show_config(path=data.path) if config_format == 'json': - config_tree = vyos.configtree.ConfigTree(res) + config_tree = ConfigTree(res) res = json.loads(config_tree.to_json()) elif config_format == 'json_ast': - config_tree = vyos.configtree.ConfigTree(res) + config_tree = ConfigTree(res) res = json.loads(config_tree.to_json_ast()) elif config_format == 'raw': pass @@ -548,10 +574,11 @@ async def retrieve_op(data: RetrieveModel): return success(res) @app.post('/config-file') -def config_file_op(data: ConfigFileModel): +def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): session = app.state.vyos_session - + env = session.get_session_env() op = data.op + msg = None try: if op == 'save': @@ -559,14 +586,23 @@ def config_file_op(data: ConfigFileModel): path = data.file else: path = '/config/config.boot' - res = session.save_config(path) + msg = session.save_config(path) elif op == 'load': if data.file: path = data.file else: return error(400, "Missing required field \"file\"") - res = session.migrate_and_load_config(path) - res = session.commit() + + session.migrate_and_load_config(path) + + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, session) + msg = self_ref_msg + else: + session.commit() else: return error(400, f"'{op}' is not a valid operation") except ConfigSessionError as e: @@ -575,7 +611,7 @@ def config_file_op(data: ConfigFileModel): logger.critical(traceback.format_exc()) return error(500, "An internal error occured. Check the logs for details.") - return success(res) + return success(msg) @app.post('/image') def image_op(data: ImageModel): @@ -607,7 +643,7 @@ def image_op(data: ImageModel): return success(res) @app.post('/container-image') -def image_op(data: ContainerImageModel): +def container_image_op(data: ContainerImageModel): session = app.state.vyos_session op = data.op @@ -702,7 +738,7 @@ def reset_op(data: ResetModel): # GraphQL integration ### -def graphql_init(fast_api_app): +def graphql_init(app: FastAPI = app): from api.graphql.libs.token_auth import get_user_context api.graphql.state.init() api.graphql.state.settings['app'] = app @@ -728,26 +764,45 @@ def graphql_init(fast_api_app): debug=True, introspection=in_spec)) ### +# Modify uvicorn to allow reloading server within the configsession +### -if __name__ == '__main__': - # systemd's user and group options don't work, do it by hand here, - # else no one else will be able to commit - cfg_group = grp.getgrnam(CFG_GROUP) - os.setgid(cfg_group.gr_gid) +server = None +shutdown = False - # Need to set file permissions to 775 too so that every vyattacfg group member - # has write access to the running config - os.umask(0o002) +class ApiServerConfig(UvicornConfig): + pass + +class ApiServer(UvicornServer): + def install_signal_handlers(self): + pass + +def reload_handler(signum, frame): + global server + logger.debug('Reload signal received...') + if server is not None: + server.handle_exit(signum, frame) + server = None + logger.info('Server stopping for reload...') + else: + logger.warning('Reload called for non-running server...') +def shutdown_handler(signum, frame): + global shutdown + logger.debug('Shutdown signal received...') + server.handle_exit(signum, frame) + logger.info('Server shutdown...') + shutdown = True + +def initialization(session: ConfigSession, app: FastAPI = app): + global server try: server_config = load_server_config() - except Exception as err: - logger.critical(f"Failed to load the HTTP API server config: {err}") + except Exception as e: + logger.critical(f'Failed to load the HTTP API server config: {e}') sys.exit(1) - config_session = ConfigSession(os.getpid()) - - app.state.vyos_session = config_session + app.state.vyos_session = session app.state.vyos_keys = server_config['api_keys'] app.state.vyos_debug = server_config['debug'] @@ -770,14 +825,44 @@ if __name__ == '__main__': if app.state.vyos_graphql: graphql_init(app) + if not server_config['socket']: + config = ApiServerConfig(app, + host=server_config["listen_address"], + port=int(server_config["port"]), + proxy_headers=True) + else: + config = ApiServerConfig(app, + uds="/run/api.sock", + proxy_headers=True) + server = ApiServer(config) + +def run_server(): try: - if not server_config['socket']: - uvicorn.run(app, host=server_config["listen_address"], - port=int(server_config["port"]), - proxy_headers=True) - else: - uvicorn.run(app, uds="/run/api.sock", - proxy_headers=True) - except OSError as err: - logger.critical(f"OSError {err}") + server.run() + except OSError as e: + logger.critical(e) sys.exit(1) + +if __name__ == '__main__': + # systemd's user and group options don't work, do it by hand here, + # else no one else will be able to commit + cfg_group = grp.getgrnam(CFG_GROUP) + os.setgid(cfg_group.gr_gid) + + # Need to set file permissions to 775 too so that every vyattacfg group member + # has write access to the running config + os.umask(0o002) + + signal.signal(signal.SIGHUP, reload_handler) + signal.signal(signal.SIGTERM, shutdown_handler) + + config_session = ConfigSession(os.getpid()) + + while True: + logger.debug('Enter main loop...') + if shutdown: + break + if server is None: + initialization(config_session) + server.run() + sleep(1) diff --git a/src/system/uacctd_stop.py b/src/system/uacctd_stop.py new file mode 100755 index 000000000..7fbac0566 --- /dev/null +++ b/src/system/uacctd_stop.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2023 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +# Control pmacct daemons in a tricky way. +# Pmacct has signal processing in a main loop, together with packet +# processing. Because of this, while it is waiting for packets, it cannot +# handle the control signal. We need to start the systemctl command and then +# send some packets to pmacct to wake it up + +from argparse import ArgumentParser +from socket import socket +from sys import exit +from time import sleep + +from psutil import Process + + +def stop_process(pid: int, timeout: int) -> None: + """Send a signal to uacctd + and then send packets to special address predefined in a firewall + to unlock main loop in uacctd and finish the process properly + + Args: + pid (int): uacctd PID + timeout (int): seconds to wait for a process end + """ + # find a process + uacctd = Process(pid) + uacctd.terminate() + + # create a socket + trigger = socket() + + first_cycle: bool = True + while uacctd.is_running() and timeout: + trigger.sendto(b'WAKEUP', ('127.0.254.0', 0)) + # do not sleep during first attempt + if not first_cycle: + sleep(1) + timeout -= 1 + first_cycle = False + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('process_id', + type=int, + help='PID file of uacctd core process') + parser.add_argument('timeout', + type=int, + help='time to wait for process end') + args = parser.parse_args() + stop_process(args.process_id, args.timeout) + exit() |