diff options
Diffstat (limited to 'src')
-rwxr-xr-x | src/conf_mode/http-api.py | 6 | ||||
-rwxr-xr-x | src/conf_mode/policy-local-route.py | 79 | ||||
-rwxr-xr-x | src/helpers/config_dependency.py | 79 | ||||
-rwxr-xr-x | src/services/vyos-http-api-server | 173 |
4 files changed, 275 insertions, 62 deletions
diff --git a/src/conf_mode/http-api.py b/src/conf_mode/http-api.py index 793a90d88..d8fe3b736 100755 --- a/src/conf_mode/http-api.py +++ b/src/conf_mode/http-api.py @@ -27,6 +27,7 @@ from vyos.config import Config from vyos.configdep import set_dependents, call_dependents from vyos.template import render from vyos.utils.process import call +from vyos.utils.process import is_systemd_service_running from vyos import ConfigError from vyos import airbag airbag.enable() @@ -130,7 +131,10 @@ def apply(http_api): service_name = 'vyos-http-api.service' if http_api is not None: - call(f'systemctl restart {service_name}') + if is_systemd_service_running(f'{service_name}'): + call(f'systemctl reload {service_name}') + else: + call(f'systemctl restart {service_name}') else: call(f'systemctl stop {service_name}') diff --git a/src/conf_mode/policy-local-route.py b/src/conf_mode/policy-local-route.py index 2e8aabb80..91e4fce2c 100755 --- a/src/conf_mode/policy-local-route.py +++ b/src/conf_mode/policy-local-route.py @@ -52,19 +52,28 @@ def get_config(config=None): if tmp: for rule in (tmp or []): src = leaf_node_changed(conf, base_rule + [rule, 'source', 'address']) + src_port = leaf_node_changed(conf, base_rule + [rule, 'source', 'port']) fwmk = leaf_node_changed(conf, base_rule + [rule, 'fwmark']) iif = leaf_node_changed(conf, base_rule + [rule, 'inbound-interface']) dst = leaf_node_changed(conf, base_rule + [rule, 'destination', 'address']) + dst_port = leaf_node_changed(conf, base_rule + [rule, 'destination', 'port']) + table = leaf_node_changed(conf, base_rule + [rule, 'set', 'table']) proto = leaf_node_changed(conf, base_rule + [rule, 'protocol']) rule_def = {} if src: rule_def = dict_merge({'source': {'address': src}}, rule_def) + if src_port: + rule_def = dict_merge({'source': {'port': src_port}}, rule_def) if fwmk: rule_def = dict_merge({'fwmark' : fwmk}, rule_def) if iif: rule_def = dict_merge({'inbound_interface' : iif}, rule_def) if dst: rule_def = dict_merge({'destination': {'address': dst}}, rule_def) + if dst_port: + rule_def = dict_merge({'destination': {'port': dst_port}}, rule_def) + if table: + rule_def = dict_merge({'table' : table}, rule_def) if proto: rule_def = dict_merge({'protocol' : proto}, rule_def) dict = dict_merge({dict_id : {rule : rule_def}}, dict) @@ -79,9 +88,12 @@ def get_config(config=None): if 'rule' in pbr[route]: for rule, rule_config in pbr[route]['rule'].items(): src = leaf_node_changed(conf, base_rule + [rule, 'source', 'address']) + src_port = leaf_node_changed(conf, base_rule + [rule, 'source', 'port']) fwmk = leaf_node_changed(conf, base_rule + [rule, 'fwmark']) iif = leaf_node_changed(conf, base_rule + [rule, 'inbound-interface']) dst = leaf_node_changed(conf, base_rule + [rule, 'destination', 'address']) + dst_port = leaf_node_changed(conf, base_rule + [rule, 'destination', 'port']) + table = leaf_node_changed(conf, base_rule + [rule, 'set', 'table']) proto = leaf_node_changed(conf, base_rule + [rule, 'protocol']) # keep track of changes in configuration # otherwise we might remove an existing node although nothing else has changed @@ -105,14 +117,32 @@ def get_config(config=None): if len(src) > 0: rule_def = dict_merge({'source': {'address': src}}, rule_def) + # source port + if src_port is None: + if 'source' in rule_config: + if 'port' in rule_config['source']: + tmp = rule_config['source']['port'] + if isinstance(tmp, str): + tmp = [tmp] + rule_def = dict_merge({'source': {'port': tmp}}, rule_def) + else: + changed = True + if len(src_port) > 0: + rule_def = dict_merge({'source': {'port': src_port}}, rule_def) + + # fwmark if fwmk is None: if 'fwmark' in rule_config: - rule_def = dict_merge({'fwmark': rule_config['fwmark']}, rule_def) + tmp = rule_config['fwmark'] + if isinstance(tmp, str): + tmp = [tmp] + rule_def = dict_merge({'fwmark': tmp}, rule_def) else: changed = True if len(fwmk) > 0: rule_def = dict_merge({'fwmark' : fwmk}, rule_def) + # inbound-interface if iif is None: if 'inbound_interface' in rule_config: rule_def = dict_merge({'inbound_interface': rule_config['inbound_interface']}, rule_def) @@ -121,6 +151,7 @@ def get_config(config=None): if len(iif) > 0: rule_def = dict_merge({'inbound_interface' : iif}, rule_def) + # destination address if dst is None: if 'destination' in rule_config: if 'address' in rule_config['destination']: @@ -130,9 +161,35 @@ def get_config(config=None): if len(dst) > 0: rule_def = dict_merge({'destination': {'address': dst}}, rule_def) + # destination port + if dst_port is None: + if 'destination' in rule_config: + if 'port' in rule_config['destination']: + tmp = rule_config['destination']['port'] + if isinstance(tmp, str): + tmp = [tmp] + rule_def = dict_merge({'destination': {'port': tmp}}, rule_def) + else: + changed = True + if len(dst_port) > 0: + rule_def = dict_merge({'destination': {'port': dst_port}}, rule_def) + + # table + if table is None: + if 'set' in rule_config and 'table' in rule_config['set']: + rule_def = dict_merge({'table': [rule_config['set']['table']]}, rule_def) + else: + changed = True + if len(table) > 0: + rule_def = dict_merge({'table' : table}, rule_def) + + # protocol if proto is None: if 'protocol' in rule_config: - rule_def = dict_merge({'protocol': rule_config['protocol']}, rule_def) + tmp = rule_config['protocol'] + if isinstance(tmp, str): + tmp = [tmp] + rule_def = dict_merge({'protocol': tmp}, rule_def) else: changed = True if len(proto) > 0: @@ -192,19 +249,27 @@ def apply(pbr): for rule, rule_config in pbr[rule_rm].items(): source = rule_config.get('source', {}).get('address', ['']) + source_port = rule_config.get('source', {}).get('port', ['']) destination = rule_config.get('destination', {}).get('address', ['']) + destination_port = rule_config.get('destination', {}).get('port', ['']) fwmark = rule_config.get('fwmark', ['']) inbound_interface = rule_config.get('inbound_interface', ['']) protocol = rule_config.get('protocol', ['']) + table = rule_config.get('table', ['']) - for src, dst, fwmk, iif, proto in product(source, destination, fwmark, inbound_interface, protocol): + for src, dst, src_port, dst_port, fwmk, iif, proto, table in product( + source, destination, source_port, destination_port, + fwmark, inbound_interface, protocol, table): f_src = '' if src == '' else f' from {src} ' + f_src_port = '' if src_port == '' else f' sport {src_port} ' f_dst = '' if dst == '' else f' to {dst} ' + f_dst_port = '' if dst_port == '' else f' dport {dst_port} ' f_fwmk = '' if fwmk == '' else f' fwmark {fwmk} ' f_iif = '' if iif == '' else f' iif {iif} ' f_proto = '' if proto == '' else f' ipproto {proto} ' + f_table = '' if table == '' else f' lookup {table} ' - call(f'ip{v6} rule del prio {rule} {f_src}{f_dst}{f_fwmk}{f_iif}') + call(f'ip{v6} rule del prio {rule} {f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif}{f_table}') # Generate new config for route in ['local_route', 'local_route6']: @@ -218,7 +283,9 @@ def apply(pbr): for rule, rule_config in pbr_route['rule'].items(): table = rule_config['set'].get('table', '') source = rule_config.get('source', {}).get('address', ['all']) + source_port = rule_config.get('source', {}).get('port', '') destination = rule_config.get('destination', {}).get('address', ['all']) + destination_port = rule_config.get('destination', {}).get('port', '') fwmark = rule_config.get('fwmark', '') inbound_interface = rule_config.get('inbound_interface', '') protocol = rule_config.get('protocol', '') @@ -227,11 +294,13 @@ def apply(pbr): f_src = f' from {src} ' if src else '' for dst in destination: f_dst = f' to {dst} ' if dst else '' + f_src_port = f' sport {source_port} ' if source_port else '' + f_dst_port = f' dport {destination_port} ' if destination_port else '' f_fwmk = f' fwmark {fwmark} ' if fwmark else '' f_iif = f' iif {inbound_interface} ' if inbound_interface else '' f_proto = f' ipproto {protocol} ' if protocol else '' - call(f'ip{v6} rule add prio {rule}{f_src}{f_dst}{f_proto}{f_fwmk}{f_iif} lookup {table}') + call(f'ip{v6} rule add prio {rule}{f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif} lookup {table}') return None diff --git a/src/helpers/config_dependency.py b/src/helpers/config_dependency.py index 50c72956e..817bcc65a 100755 --- a/src/helpers/config_dependency.py +++ b/src/helpers/config_dependency.py @@ -18,22 +18,75 @@ import os import sys +import json from argparse import ArgumentParser from argparse import ArgumentTypeError - -try: - from vyos.configdep import check_dependency_graph - from vyos.defaults import directories -except ImportError: - # allow running during addon package build - _here = os.path.dirname(__file__) - sys.path.append(os.path.join(_here, '../../python/vyos')) - from configdep import check_dependency_graph - from defaults import directories +from graphlib import TopologicalSorter, CycleError # addon packages will need to specify the dependency directory -dependency_dir = os.path.join(directories['data'], - 'config-mode-dependencies') +data_dir = '/usr/share/vyos/' +dependency_dir = os.path.join(data_dir, 'config-mode-dependencies') + +def dict_merge(source, destination): + from copy import deepcopy + tmp = deepcopy(destination) + + for key, value in source.items(): + if key not in tmp: + tmp[key] = value + elif isinstance(source[key], dict): + tmp[key] = dict_merge(source[key], tmp[key]) + + return tmp + +def read_dependency_dict(dependency_dir: str = dependency_dir) -> dict: + res = {} + for dep_file in os.listdir(dependency_dir): + if not dep_file.endswith('.json'): + continue + path = os.path.join(dependency_dir, dep_file) + with open(path) as f: + d = json.load(f) + if dep_file == 'vyos-1x.json': + res = dict_merge(res, d) + else: + res = dict_merge(d, res) + + return res + +def graph_from_dependency_dict(d: dict) -> dict: + g = {} + for k in list(d): + g[k] = set() + # add the dependencies for every sub-case; should there be cases + # that are mutally exclusive in the future, the graphs will be + # distinguished + for el in list(d[k]): + g[k] |= set(d[k][el]) + + return g + +def is_acyclic(d: dict) -> bool: + g = graph_from_dependency_dict(d) + ts = TopologicalSorter(g) + try: + # get node iterator + order = ts.static_order() + # try iteration + _ = [*order] + except CycleError: + return False + + return True + +def check_dependency_graph(dependency_dir: str = dependency_dir, + supplement: str = None) -> bool: + d = read_dependency_dict(dependency_dir=dependency_dir) + if supplement is not None: + with open(supplement) as f: + d = dict_merge(json.load(f), d) + + return is_acyclic(d) def path_exists(s): if not os.path.exists(s): @@ -50,8 +103,10 @@ def main(): args = vars(parser.parse_args()) if not check_dependency_graph(**args): + print("dependency error: cycle exists") sys.exit(1) + print("dependency graph acyclic") sys.exit(0) if __name__ == '__main__': diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 66e80ced5..3a9efb73e 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -22,12 +22,14 @@ import grp import copy import json import logging +import signal import traceback import threading +from time import sleep from typing import List, Union, Callable, Dict -import uvicorn from fastapi import FastAPI, Depends, Request, Response, HTTPException +from fastapi import BackgroundTasks from fastapi.responses import HTMLResponse from fastapi.exceptions import RequestValidationError from fastapi.routing import APIRoute @@ -36,10 +38,14 @@ from starlette.middleware.cors import CORSMiddleware from starlette.datastructures import FormData from starlette.formparsers import FormParser, MultiPartParser from multipart.multipart import parse_options_header +from uvicorn import Config as UvicornConfig +from uvicorn import Server as UvicornServer from ariadne.asgi import GraphQL -import vyos.config +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSession, ConfigSessionError import api.graphql.state @@ -410,12 +416,24 @@ app.router.route_class = MultipartRoute async def validation_exception_handler(request, exc): return error(400, str(exc.errors()[0])) +self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" + +def call_commit(s: ConfigSession): + try: + s.commit() + except ConfigSessionError as e: + s.discard() + if app.state.vyos_debug: + logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + else: + logger.warning(f"ConfigSessionError: {e}") + def _configure_op(data: Union[ConfigureModel, ConfigureListModel, ConfigSectionModel, ConfigSectionListModel], - request: Request): + request: Request, background_tasks: BackgroundTasks): session = app.state.vyos_session env = session.get_session_env() - config = vyos.config.Config(session_env=env) + config = Config(session_env=env) endpoint = request.url.path @@ -470,7 +488,15 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, else: raise ConfigSessionError(f"'{op}' is not a valid operation") # end for - session.commit() + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, session) + msg = self_ref_msg + else: + session.commit() + logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'") except ConfigSessionError as e: session.discard() @@ -495,21 +521,21 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, @app.post('/configure') def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request): - return _configure_op(data, request) + ConfigureListModel], + request: Request, background_tasks: BackgroundTasks): + return _configure_op(data, request, background_tasks) @app.post('/configure-section') def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel], - request: Request): - return _configure_op(data, request) + ConfigSectionListModel], + request: Request, background_tasks: BackgroundTasks): + return _configure_op(data, request, background_tasks) @app.post("/retrieve") async def retrieve_op(data: RetrieveModel): session = app.state.vyos_session env = session.get_session_env() - config = vyos.config.Config(session_env=env) + config = Config(session_env=env) op = data.op path = " ".join(data.path) @@ -528,10 +554,10 @@ async def retrieve_op(data: RetrieveModel): res = session.show_config(path=data.path) if config_format == 'json': - config_tree = vyos.configtree.ConfigTree(res) + config_tree = ConfigTree(res) res = json.loads(config_tree.to_json()) elif config_format == 'json_ast': - config_tree = vyos.configtree.ConfigTree(res) + config_tree = ConfigTree(res) res = json.loads(config_tree.to_json_ast()) elif config_format == 'raw': pass @@ -548,10 +574,11 @@ async def retrieve_op(data: RetrieveModel): return success(res) @app.post('/config-file') -def config_file_op(data: ConfigFileModel): +def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): session = app.state.vyos_session - + env = session.get_session_env() op = data.op + msg = None try: if op == 'save': @@ -559,14 +586,23 @@ def config_file_op(data: ConfigFileModel): path = data.file else: path = '/config/config.boot' - res = session.save_config(path) + msg = session.save_config(path) elif op == 'load': if data.file: path = data.file else: return error(400, "Missing required field \"file\"") - res = session.migrate_and_load_config(path) - res = session.commit() + + session.migrate_and_load_config(path) + + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, session) + msg = self_ref_msg + else: + session.commit() else: return error(400, f"'{op}' is not a valid operation") except ConfigSessionError as e: @@ -575,7 +611,7 @@ def config_file_op(data: ConfigFileModel): logger.critical(traceback.format_exc()) return error(500, "An internal error occured. Check the logs for details.") - return success(res) + return success(msg) @app.post('/image') def image_op(data: ImageModel): @@ -607,7 +643,7 @@ def image_op(data: ImageModel): return success(res) @app.post('/container-image') -def image_op(data: ContainerImageModel): +def container_image_op(data: ContainerImageModel): session = app.state.vyos_session op = data.op @@ -702,7 +738,7 @@ def reset_op(data: ResetModel): # GraphQL integration ### -def graphql_init(fast_api_app): +def graphql_init(app: FastAPI = app): from api.graphql.libs.token_auth import get_user_context api.graphql.state.init() api.graphql.state.settings['app'] = app @@ -728,26 +764,45 @@ def graphql_init(fast_api_app): debug=True, introspection=in_spec)) ### +# Modify uvicorn to allow reloading server within the configsession +### -if __name__ == '__main__': - # systemd's user and group options don't work, do it by hand here, - # else no one else will be able to commit - cfg_group = grp.getgrnam(CFG_GROUP) - os.setgid(cfg_group.gr_gid) +server = None +shutdown = False - # Need to set file permissions to 775 too so that every vyattacfg group member - # has write access to the running config - os.umask(0o002) +class ApiServerConfig(UvicornConfig): + pass + +class ApiServer(UvicornServer): + def install_signal_handlers(self): + pass + +def reload_handler(signum, frame): + global server + logger.debug('Reload signal received...') + if server is not None: + server.handle_exit(signum, frame) + server = None + logger.info('Server stopping for reload...') + else: + logger.warning('Reload called for non-running server...') +def shutdown_handler(signum, frame): + global shutdown + logger.debug('Shutdown signal received...') + server.handle_exit(signum, frame) + logger.info('Server shutdown...') + shutdown = True + +def initialization(session: ConfigSession, app: FastAPI = app): + global server try: server_config = load_server_config() - except Exception as err: - logger.critical(f"Failed to load the HTTP API server config: {err}") + except Exception as e: + logger.critical(f'Failed to load the HTTP API server config: {e}') sys.exit(1) - config_session = ConfigSession(os.getpid()) - - app.state.vyos_session = config_session + app.state.vyos_session = session app.state.vyos_keys = server_config['api_keys'] app.state.vyos_debug = server_config['debug'] @@ -770,14 +825,44 @@ if __name__ == '__main__': if app.state.vyos_graphql: graphql_init(app) + if not server_config['socket']: + config = ApiServerConfig(app, + host=server_config["listen_address"], + port=int(server_config["port"]), + proxy_headers=True) + else: + config = ApiServerConfig(app, + uds="/run/api.sock", + proxy_headers=True) + server = ApiServer(config) + +def run_server(): try: - if not server_config['socket']: - uvicorn.run(app, host=server_config["listen_address"], - port=int(server_config["port"]), - proxy_headers=True) - else: - uvicorn.run(app, uds="/run/api.sock", - proxy_headers=True) - except OSError as err: - logger.critical(f"OSError {err}") + server.run() + except OSError as e: + logger.critical(e) sys.exit(1) + +if __name__ == '__main__': + # systemd's user and group options don't work, do it by hand here, + # else no one else will be able to commit + cfg_group = grp.getgrnam(CFG_GROUP) + os.setgid(cfg_group.gr_gid) + + # Need to set file permissions to 775 too so that every vyattacfg group member + # has write access to the running config + os.umask(0o002) + + signal.signal(signal.SIGHUP, reload_handler) + signal.signal(signal.SIGTERM, shutdown_handler) + + config_session = ConfigSession(os.getpid()) + + while True: + logger.debug('Enter main loop...') + if shutdown: + break + if server is None: + initialization(config_session) + server.run() + sleep(1) |