diff options
Diffstat (limited to 'src')
65 files changed, 3587 insertions, 1705 deletions
diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py index 5638a9668..ffbd915a2 100755 --- a/src/conf_mode/firewall.py +++ b/src/conf_mode/firewall.py @@ -36,11 +36,15 @@ from vyos.utils.process import cmd from vyos.utils.process import rc_cmd from vyos import ConfigError from vyos import airbag +from pathlib import Path from subprocess import run as subp_run airbag.enable() nftables_conf = '/run/nftables.conf' +domain_resolver_usage = '/run/use-vyos-domain-resolver-firewall' +domain_resolver_usage_nat = '/run/use-vyos-domain-resolver-nat' + sysctl_file = r'/run/sysctl/10-vyos-firewall.conf' valid_groups = [ @@ -128,7 +132,7 @@ def get_config(config=None): firewall['geoip_updated'] = geoip_updated(conf, firewall) - fqdn_config_parse(firewall) + fqdn_config_parse(firewall, 'firewall') set_dependents('conntrack', conf) @@ -570,12 +574,15 @@ def apply(firewall): call_dependents() - # T970 Enable a resolver (systemd daemon) that checks - # domain-group/fqdn addresses and update entries for domains by timeout - # If router loaded without internet connection or for synchronization - domain_action = 'stop' - if dict_search_args(firewall, 'group', 'domain_group') or firewall['ip_fqdn'] or firewall['ip6_fqdn']: - domain_action = 'restart' + ## DOMAIN RESOLVER + domain_action = 'restart' + if dict_search_args(firewall, 'group', 'domain_group') or firewall['ip_fqdn'].items() or firewall['ip6_fqdn'].items(): + text = f'# Automatically generated by firewall.py\nThis file indicates that vyos-domain-resolver service is used by the firewall.\n' + Path(domain_resolver_usage).write_text(text) + else: + Path(domain_resolver_usage).unlink(missing_ok=True) + if not Path('/run').glob('use-vyos-domain-resolver*'): + domain_action = 'stop' call(f'systemctl {domain_action} vyos-domain-resolver.service') if firewall['geoip_updated']: diff --git a/src/conf_mode/interfaces_bridge.py b/src/conf_mode/interfaces_bridge.py index 7b2c1ee0b..637db442a 100755 --- a/src/conf_mode/interfaces_bridge.py +++ b/src/conf_mode/interfaces_bridge.py @@ -53,20 +53,22 @@ def get_config(config=None): tmp = node_changed(conf, base + [ifname, 'member', 'interface']) if tmp: if 'member' in bridge: - bridge['member'].update({'interface_remove' : tmp }) + bridge['member'].update({'interface_remove': {t: {} for t in tmp}}) else: - bridge.update({'member' : {'interface_remove' : tmp }}) - for interface in tmp: - # When using VXLAN member interfaces that are configured for Single - # VXLAN Device (SVD) we need to call the VXLAN conf-mode script to - # re-create VLAN to VNI mappings if required, but only if the interface - # is already live on the system - this must not be done on first commit - if interface.startswith('vxlan') and interface_exists(interface): - set_dependents('vxlan', conf, interface) - # When using Wireless member interfaces we need to inform hostapd - # to properly set-up the bridge - elif interface.startswith('wlan') and interface_exists(interface): - set_dependents('wlan', conf, interface) + bridge.update({'member': {'interface_remove': {t: {} for t in tmp}}}) + for interface in tmp: + # When using VXLAN member interfaces that are configured for Single + # VXLAN Device (SVD) we need to call the VXLAN conf-mode script to + # re-create VLAN to VNI mappings if required, but only if the interface + # is already live on the system - this must not be done on first commit + if interface.startswith('vxlan') and interface_exists(interface): + set_dependents('vxlan', conf, interface) + _, vxlan = get_interface_dict(conf, ['interfaces', 'vxlan'], ifname=interface) + bridge['member']['interface_remove'].update({interface: vxlan}) + # When using Wireless member interfaces we need to inform hostapd + # to properly set-up the bridge + elif interface.startswith('wlan') and interface_exists(interface): + set_dependents('wlan', conf, interface) if dict_search('member.interface', bridge) is not None: for interface in list(bridge['member']['interface']): @@ -118,6 +120,16 @@ def get_config(config=None): return bridge def verify(bridge): + # to delete interface or remove a member interface VXLAN first need to check if + # VXLAN does not require to be a member of a bridge interface + if dict_search('member.interface_remove', bridge): + for iface, iface_config in bridge['member']['interface_remove'].items(): + if iface.startswith('vxlan') and dict_search('parameters.neighbor_suppress', iface_config) != None: + raise ConfigError( + f'To detach interface {iface} from bridge you must first ' + f'disable "neighbor-suppress" parameter in the VXLAN interface {iface}' + ) + if 'deleted' in bridge: return None @@ -192,7 +204,7 @@ def apply(bridge): try: call_dependents() except ConfigError: - raise ConfigError('Error updating member interface configuration after changing bridge!') + raise ConfigError(f'Error updating member interface {interface} configuration after changing bridge!') return None diff --git a/src/conf_mode/load-balancing_reverse-proxy.py b/src/conf_mode/load-balancing_haproxy.py index 17226efe9..45042dd52 100755..100644 --- a/src/conf_mode/load-balancing_reverse-proxy.py +++ b/src/conf_mode/load-balancing_haproxy.py @@ -48,7 +48,7 @@ def get_config(config=None): else: conf = Config() - base = ['load-balancing', 'reverse-proxy'] + base = ['load-balancing', 'haproxy'] if not conf.exists(base): return None lb = conf.get_config_dict(base, diff --git a/src/conf_mode/nat.py b/src/conf_mode/nat.py index 39803fa02..98b2f3f29 100755 --- a/src/conf_mode/nat.py +++ b/src/conf_mode/nat.py @@ -26,10 +26,13 @@ from vyos.template import is_ip_network from vyos.utils.kernel import check_kmod from vyos.utils.dict import dict_search from vyos.utils.dict import dict_search_args +from vyos.utils.file import write_file from vyos.utils.process import cmd from vyos.utils.process import run +from vyos.utils.process import call from vyos.utils.network import is_addr_assigned from vyos.utils.network import interface_exists +from vyos.firewall import fqdn_config_parse from vyos import ConfigError from vyos import airbag @@ -39,6 +42,8 @@ k_mod = ['nft_nat', 'nft_chain_nat'] nftables_nat_config = '/run/nftables_nat.conf' nftables_static_nat_conf = '/run/nftables_static-nat-rules.nft' +domain_resolver_usage = '/run/use-vyos-domain-resolver-nat' +domain_resolver_usage_firewall = '/run/use-vyos-domain-resolver-firewall' valid_groups = [ 'address_group', @@ -71,6 +76,8 @@ def get_config(config=None): if 'dynamic_group' in nat['firewall_group']: del nat['firewall_group']['dynamic_group'] + fqdn_config_parse(nat, 'nat') + return nat def verify_rule(config, err_msg, groups_dict): @@ -251,6 +258,19 @@ def apply(nat): call_dependents() + # DOMAIN RESOLVER + if nat and 'deleted' not in nat: + domain_action = 'restart' + if nat['ip_fqdn'].items(): + text = f'# Automatically generated by nat.py\nThis file indicates that vyos-domain-resolver service is used by nat.\n' + write_file(domain_resolver_usage, text) + elif os.path.exists(domain_resolver_usage): + os.unlink(domain_resolver_usage) + if not os.path.exists(domain_resolver_usage_firewall): + # Firewall not using domain resolver + domain_action = 'stop' + call(f'systemctl {domain_action} vyos-domain-resolver.service') + return None if __name__ == '__main__': diff --git a/src/conf_mode/pki.py b/src/conf_mode/pki.py index 215b22b37..45e0129a3 100755 --- a/src/conf_mode/pki.py +++ b/src/conf_mode/pki.py @@ -27,6 +27,7 @@ from vyos.configdict import node_changed from vyos.configdiff import Diff from vyos.configdiff import get_config_diff from vyos.defaults import directories +from vyos.pki import encode_certificate from vyos.pki import is_ca_certificate from vyos.pki import load_certificate from vyos.pki import load_public_key @@ -36,9 +37,11 @@ from vyos.pki import load_private_key from vyos.pki import load_crl from vyos.pki import load_dh_parameters from vyos.utils.boot import boot_configuration_complete +from vyos.utils.configfs import add_cli_node from vyos.utils.dict import dict_search from vyos.utils.dict import dict_search_args from vyos.utils.dict import dict_search_recursive +from vyos.utils.file import read_file from vyos.utils.process import call from vyos.utils.process import cmd from vyos.utils.process import is_systemd_service_active @@ -68,7 +71,7 @@ sync_search = [ }, { 'keys': ['certificate', 'ca_certificate'], - 'path': ['load_balancing', 'reverse_proxy'], + 'path': ['load_balancing', 'haproxy'], }, { 'keys': ['key'], @@ -446,9 +449,37 @@ def generate(pki): # Get foldernames under vyos_certbot_dir which each represent a certbot cert if os.path.exists(f'{vyos_certbot_dir}/live'): for cert in certbot_list_on_disk: + # ACME certificate is no longer in use by CLI remove it if cert not in certbot_list: - # certificate is no longer active on the CLI - remove it certbot_delete(cert) + continue + # ACME not enabled for individual certificate - bail out early + if 'acme' not in pki['certificate'][cert]: + continue + + # Read in ACME certificate chain information + tmp = read_file(f'{vyos_certbot_dir}/live/{cert}/chain.pem') + tmp = load_certificate(tmp, wrap_tags=False) + cert_chain_base64 = "".join(encode_certificate(tmp).strip().split("\n")[1:-1]) + + # Check if CA chain certificate is already present on CLI to avoid adding + # a duplicate. This only checks for manual added CA certificates and not + # auto added ones with the AUTOCHAIN_ prefix + autochain_prefix = 'AUTOCHAIN_' + ca_cert_present = False + if 'ca' in pki: + for ca_base64, cli_path in dict_search_recursive(pki['ca'], 'certificate'): + # Ignore automatic added CA certificates + if any(item.startswith(autochain_prefix) for item in cli_path): + continue + if cert_chain_base64 == ca_base64: + ca_cert_present = True + + if not ca_cert_present: + tmp = dict_search_args(pki, 'ca', f'{autochain_prefix}{cert}', 'certificate') + if not bool(tmp) or tmp != cert_chain_base64: + print(f'Adding/replacing automatically imported CA certificate for "{cert}" ...') + add_cli_node(['pki', 'ca', f'{autochain_prefix}{cert}', 'certificate'], value=cert_chain_base64) return None diff --git a/src/conf_mode/policy_local-route.py b/src/conf_mode/policy_local-route.py index 331fd972d..9be2bc227 100755 --- a/src/conf_mode/policy_local-route.py +++ b/src/conf_mode/policy_local-route.py @@ -54,6 +54,7 @@ def get_config(config=None): dst = leaf_node_changed(conf, base_rule + [rule, 'destination', 'address']) dst_port = leaf_node_changed(conf, base_rule + [rule, 'destination', 'port']) table = leaf_node_changed(conf, base_rule + [rule, 'set', 'table']) + vrf = leaf_node_changed(conf, base_rule + [rule, 'set', 'vrf']) proto = leaf_node_changed(conf, base_rule + [rule, 'protocol']) rule_def = {} if src: @@ -70,6 +71,8 @@ def get_config(config=None): rule_def = dict_merge({'destination': {'port': dst_port}}, rule_def) if table: rule_def = dict_merge({'table' : table}, rule_def) + if vrf: + rule_def = dict_merge({'vrf' : vrf}, rule_def) if proto: rule_def = dict_merge({'protocol' : proto}, rule_def) dict = dict_merge({dict_id : {rule : rule_def}}, dict) @@ -90,6 +93,7 @@ def get_config(config=None): dst = leaf_node_changed(conf, base_rule + [rule, 'destination', 'address']) dst_port = leaf_node_changed(conf, base_rule + [rule, 'destination', 'port']) table = leaf_node_changed(conf, base_rule + [rule, 'set', 'table']) + vrf = leaf_node_changed(conf, base_rule + [rule, 'set', 'vrf']) proto = leaf_node_changed(conf, base_rule + [rule, 'protocol']) # keep track of changes in configuration # otherwise we might remove an existing node although nothing else has changed @@ -179,6 +183,15 @@ def get_config(config=None): if len(table) > 0: rule_def = dict_merge({'table' : table}, rule_def) + # vrf + if vrf is None: + if 'set' in rule_config and 'vrf' in rule_config['set']: + rule_def = dict_merge({'vrf': [rule_config['set']['vrf']]}, rule_def) + else: + changed = True + if len(vrf) > 0: + rule_def = dict_merge({'vrf' : vrf}, rule_def) + # protocol if proto is None: if 'protocol' in rule_config: @@ -218,8 +231,15 @@ def verify(pbr): ): raise ConfigError('Source or destination address or fwmark or inbound-interface or protocol is required!') - if 'set' not in pbr_route['rule'][rule] or 'table' not in pbr_route['rule'][rule]['set']: - raise ConfigError('Table set is required!') + if 'set' not in pbr_route['rule'][rule]: + raise ConfigError('Either set table or set vrf is required!') + + set_tgts = pbr_route['rule'][rule]['set'] + if 'table' not in set_tgts and 'vrf' not in set_tgts: + raise ConfigError('Either set table or set vrf is required!') + + if 'table' in set_tgts and 'vrf' in set_tgts: + raise ConfigError('set table and set vrf cannot both be set!') if 'inbound_interface' in pbr_route['rule'][rule]: interface = pbr_route['rule'][rule]['inbound_interface'] @@ -250,11 +270,14 @@ def apply(pbr): fwmark = rule_config.get('fwmark', ['']) inbound_interface = rule_config.get('inbound_interface', ['']) protocol = rule_config.get('protocol', ['']) - table = rule_config.get('table', ['']) + # VRF 'default' is actually table 'main' for RIB rules + vrf = [ 'main' if x == 'default' else x for x in rule_config.get('vrf', ['']) ] + # See generate section below for table/vrf overlap explanation + table_or_vrf = rule_config.get('table', vrf) - for src, dst, src_port, dst_port, fwmk, iif, proto, table in product( + for src, dst, src_port, dst_port, fwmk, iif, proto, table_or_vrf in product( source, destination, source_port, destination_port, - fwmark, inbound_interface, protocol, table): + fwmark, inbound_interface, protocol, table_or_vrf): f_src = '' if src == '' else f' from {src} ' f_src_port = '' if src_port == '' else f' sport {src_port} ' f_dst = '' if dst == '' else f' to {dst} ' @@ -262,7 +285,7 @@ def apply(pbr): f_fwmk = '' if fwmk == '' else f' fwmark {fwmk} ' f_iif = '' if iif == '' else f' iif {iif} ' f_proto = '' if proto == '' else f' ipproto {proto} ' - f_table = '' if table == '' else f' lookup {table} ' + f_table = '' if table_or_vrf == '' else f' lookup {table_or_vrf} ' call(f'ip{v6} rule del prio {rule} {f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif}{f_table}') @@ -276,7 +299,13 @@ def apply(pbr): if 'rule' in pbr_route: for rule, rule_config in pbr_route['rule'].items(): - table = rule_config['set'].get('table', '') + # VRFs get configred as route table alias names for iproute2 and only + # one 'set' can get past validation. Either can be fed to lookup. + vrf = rule_config['set'].get('vrf', '') + if vrf == 'default': + table_or_vrf = 'main' + else: + table_or_vrf = rule_config['set'].get('table', vrf) source = rule_config.get('source', {}).get('address', ['all']) source_port = rule_config.get('source', {}).get('port', '') destination = rule_config.get('destination', {}).get('address', ['all']) @@ -295,7 +324,7 @@ def apply(pbr): f_iif = f' iif {inbound_interface} ' if inbound_interface else '' f_proto = f' ipproto {protocol} ' if protocol else '' - call(f'ip{v6} rule add prio {rule}{f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif} lookup {table}') + call(f'ip{v6} rule add prio {rule}{f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif} lookup {table_or_vrf}') return None diff --git a/src/conf_mode/protocols_static.py b/src/conf_mode/protocols_static.py index a2373218a..430cc69d4 100755 --- a/src/conf_mode/protocols_static.py +++ b/src/conf_mode/protocols_static.py @@ -88,7 +88,7 @@ def verify(static): if {'blackhole', 'reject'} <= set(prefix_options): raise ConfigError(f'Can not use both blackhole and reject for '\ - 'prefix "{prefix}"!') + f'prefix "{prefix}"!') return None diff --git a/src/conf_mode/service_monitoring_frr-exporter.py b/src/conf_mode/service_monitoring_frr-exporter.py new file mode 100755 index 000000000..01527d579 --- /dev/null +++ b/src/conf_mode/service_monitoring_frr-exporter.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import os + +from sys import exit + +from vyos.config import Config +from vyos.configdict import is_node_changed +from vyos.configverify import verify_vrf +from vyos.template import render +from vyos.utils.process import call +from vyos import ConfigError +from vyos import airbag + + +airbag.enable() + +service_file = '/etc/systemd/system/frr_exporter.service' +systemd_service = 'frr_exporter.service' + + +def get_config(config=None): + if config: + conf = config + else: + conf = Config() + base = ['service', 'monitoring', 'frr-exporter'] + if not conf.exists(base): + return None + + config_data = conf.get_config_dict( + base, key_mangling=('-', '_'), get_first_key=True + ) + config_data = conf.merge_defaults(config_data, recursive=True) + + tmp = is_node_changed(conf, base + ['vrf']) + if tmp: + config_data.update({'restart_required': {}}) + + return config_data + + +def verify(config_data): + # bail out early - looks like removal from running config + if not config_data: + return None + + verify_vrf(config_data) + return None + + +def generate(config_data): + if not config_data: + # Delete systemd files + if os.path.isfile(service_file): + os.unlink(service_file) + return None + + # Render frr_exporter service_file + render(service_file, 'frr_exporter/frr_exporter.service.j2', config_data) + return None + + +def apply(config_data): + # Reload systemd manager configuration + call('systemctl daemon-reload') + if not config_data: + call(f'systemctl stop {systemd_service}') + return + + # we need to restart the service if e.g. the VRF name changed + systemd_action = 'reload-or-restart' + if 'restart_required' in config_data: + systemd_action = 'restart' + + call(f'systemctl {systemd_action} {systemd_service}') + + +if __name__ == '__main__': + try: + c = get_config() + verify(c) + generate(c) + apply(c) + except ConfigError as e: + print(e) + exit(1) diff --git a/src/conf_mode/service_monitoring_node-exporter.py b/src/conf_mode/service_monitoring_node-exporter.py new file mode 100755 index 000000000..db34bb5d0 --- /dev/null +++ b/src/conf_mode/service_monitoring_node-exporter.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import os + +from sys import exit + +from vyos.config import Config +from vyos.configdict import is_node_changed +from vyos.configverify import verify_vrf +from vyos.template import render +from vyos.utils.process import call +from vyos import ConfigError +from vyos import airbag + + +airbag.enable() + +service_file = '/etc/systemd/system/node_exporter.service' +systemd_service = 'node_exporter.service' + + +def get_config(config=None): + if config: + conf = config + else: + conf = Config() + base = ['service', 'monitoring', 'node-exporter'] + if not conf.exists(base): + return None + + config_data = conf.get_config_dict( + base, key_mangling=('-', '_'), get_first_key=True + ) + config_data = conf.merge_defaults(config_data, recursive=True) + + tmp = is_node_changed(conf, base + ['vrf']) + if tmp: + config_data.update({'restart_required': {}}) + + return config_data + + +def verify(config_data): + # bail out early - looks like removal from running config + if not config_data: + return None + + verify_vrf(config_data) + return None + + +def generate(config_data): + if not config_data: + # Delete systemd files + if os.path.isfile(service_file): + os.unlink(service_file) + return None + + # Render node_exporter service_file + render(service_file, 'node_exporter/node_exporter.service.j2', config_data) + return None + + +def apply(config_data): + # Reload systemd manager configuration + call('systemctl daemon-reload') + if not config_data: + call(f'systemctl stop {systemd_service}') + return + + # we need to restart the service if e.g. the VRF name changed + systemd_action = 'reload-or-restart' + if 'restart_required' in config_data: + systemd_action = 'restart' + + call(f'systemctl {systemd_action} {systemd_service}') + + +if __name__ == '__main__': + try: + c = get_config() + verify(c) + generate(c) + apply(c) + except ConfigError as e: + print(e) + exit(1) diff --git a/src/conf_mode/service_ntp.py b/src/conf_mode/service_ntp.py index 83880fd72..32563aa0e 100755 --- a/src/conf_mode/service_ntp.py +++ b/src/conf_mode/service_ntp.py @@ -17,6 +17,7 @@ import os from vyos.config import Config +from vyos.config import config_dict_merge from vyos.configdict import is_node_changed from vyos.configverify import verify_vrf from vyos.configverify import verify_interface_exists @@ -42,13 +43,21 @@ def get_config(config=None): if not conf.exists(base): return None - ntp = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, with_defaults=True) + ntp = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True) ntp['config_file'] = config_file ntp['user'] = user_group tmp = is_node_changed(conf, base + ['vrf']) if tmp: ntp.update({'restart_required': {}}) + # We have gathered the dict representation of the CLI, but there are default + # options which we need to update into the dictionary retrived. + default_values = conf.get_config_defaults(**ntp.kwargs, recursive=True) + # Only defined PTP default port, if PTP feature is in use + if 'ptp' not in ntp: + del default_values['ptp'] + + ntp = config_dict_merge(default_values, ntp) return ntp def verify(ntp): @@ -87,6 +96,15 @@ def verify(ntp): if ipv6_addresses > 1: raise ConfigError(f'NTP Only admits one ipv6 value for listen-address parameter ') + if 'server' in ntp: + for host, server in ntp['server'].items(): + if 'ptp' in server: + if 'ptp' not in ntp: + raise ConfigError('PTP must be enabled for the NTP service '\ + f'before it can be used for server "{host}"') + else: + break + return None def generate(ntp): diff --git a/src/conf_mode/system_config-management.py b/src/conf_mode/system_config-management.py index c681a8405..8de4e5342 100755 --- a/src/conf_mode/system_config-management.py +++ b/src/conf_mode/system_config-management.py @@ -22,6 +22,7 @@ from vyos.config import Config from vyos.config_mgmt import ConfigMgmt from vyos.config_mgmt import commit_post_hook_dir, commit_hooks + def get_config(config=None): if config: conf = config @@ -36,22 +37,29 @@ def get_config(config=None): return mgmt -def verify(_mgmt): + +def verify(mgmt): + d = mgmt.config_dict + confirm = d.get('commit_confirm', {}) + if confirm.get('action', '') == 'reload' and 'commit_revisions' not in d: + raise ConfigError('commit-confirm reload requires non-zero commit-revisions') + return + def generate(mgmt): if mgmt is None: return mgmt.initialize_revision() + def apply(mgmt): if mgmt is None: return locations = mgmt.locations - archive_target = os.path.join(commit_post_hook_dir, - commit_hooks['commit_archive']) + archive_target = os.path.join(commit_post_hook_dir, commit_hooks['commit_archive']) if locations: try: os.symlink('/usr/bin/config-mgmt', archive_target) @@ -68,8 +76,9 @@ def apply(mgmt): raise ConfigError from exc revisions = mgmt.max_revisions - revision_target = os.path.join(commit_post_hook_dir, - commit_hooks['commit_revision']) + revision_target = os.path.join( + commit_post_hook_dir, commit_hooks['commit_revision'] + ) if revisions > 0: try: os.symlink('/usr/bin/config-mgmt', revision_target) @@ -85,6 +94,7 @@ def apply(mgmt): except OSError as exc: raise ConfigError from exc + if __name__ == '__main__': try: c = get_config() diff --git a/src/conf_mode/system_login_banner.py b/src/conf_mode/system_login_banner.py index 923e1bf57..5826d8042 100755 --- a/src/conf_mode/system_login_banner.py +++ b/src/conf_mode/system_login_banner.py @@ -28,6 +28,7 @@ airbag.enable() PRELOGIN_FILE = r'/etc/issue' PRELOGIN_NET_FILE = r'/etc/issue.net' POSTLOGIN_FILE = r'/etc/motd' +POSTLOGIN_VYOS_FILE = r'/run/motd.d/01-vyos-nonproduction' default_config_data = { 'issue': 'Welcome to VyOS - \\n \\l\n\n', @@ -94,6 +95,9 @@ def apply(banner): render(POSTLOGIN_FILE, 'login/default_motd.j2', banner, permission=0o644, user='root', group='root') + render(POSTLOGIN_VYOS_FILE, 'login/motd_vyos_nonproduction.j2', banner, + permission=0o644, user='root', group='root') + return None if __name__ == '__main__': diff --git a/src/conf_mode/system_option.py b/src/conf_mode/system_option.py index a84572f83..e2832cde6 100755 --- a/src/conf_mode/system_option.py +++ b/src/conf_mode/system_option.py @@ -46,6 +46,13 @@ systemd_action_file = '/lib/systemd/system/ctrl-alt-del.target' usb_autosuspend = r'/etc/udev/rules.d/40-usb-autosuspend.rules' kernel_dynamic_debug = r'/sys/kernel/debug/dynamic_debug/control' time_format_to_locale = {'12-hour': 'en_US.UTF-8', '24-hour': 'en_GB.UTF-8'} +tuned_profiles = { + 'power-save': 'powersave', + 'network-latency': 'network-latency', + 'network-throughput': 'network-throughput', + 'virtual-guest': 'virtual-guest', + 'virtual-host': 'virtual-host', +} def get_config(config=None): @@ -171,7 +178,10 @@ def apply(options): # wait until daemon has started before sending configuration while not is_systemd_service_running('tuned.service'): sleep(0.250) - cmd('tuned-adm profile network-{performance}'.format(**options)) + performance = ' '.join( + list(tuned_profiles[profile] for profile in options['performance']) + ) + cmd(f'tuned-adm profile {performance}') else: cmd('systemctl stop tuned.service') diff --git a/src/conf_mode/system_syslog.py b/src/conf_mode/system_syslog.py index 2497c5bb6..eb2f02eb3 100755 --- a/src/conf_mode/system_syslog.py +++ b/src/conf_mode/system_syslog.py @@ -53,6 +53,17 @@ def get_config(config=None): if syslog.from_defaults(['global']): del syslog['global'] + if ( + 'global' in syslog + and 'preserve_fqdn' in syslog['global'] + and conf.exists(['system', 'host-name']) + and conf.exists(['system', 'domain-name']) + ): + hostname = conf.return_value(['system', 'host-name']) + domain = conf.return_value(['system', 'domain-name']) + fqdn = f'{hostname}.{domain}' + syslog['global']['local_host_name'] = fqdn + return syslog def verify(syslog): diff --git a/src/conf_mode/vpn_ipsec.py b/src/conf_mode/vpn_ipsec.py index ca0c3657f..e22b7550c 100755 --- a/src/conf_mode/vpn_ipsec.py +++ b/src/conf_mode/vpn_ipsec.py @@ -214,6 +214,19 @@ def verify(ipsec): else: verify_interface_exists(ipsec, interface) + # need to use a pseudo-random function (PRF) with an authenticated encryption algorithm. + # If a hash algorithm is defined then it will be mapped to an equivalent PRF + if 'ike_group' in ipsec: + for _, ike_config in ipsec['ike_group'].items(): + for proposal, proposal_config in ike_config.get('proposal', {}).items(): + if 'encryption' in proposal_config and 'prf' not in proposal_config: + # list of hash algorithms that cannot be mapped to an equivalent PRF + algs = ['aes128gmac', 'aes192gmac', 'aes256gmac', 'sha256_96'] + if 'hash' in proposal_config and proposal_config['hash'] in algs: + raise ConfigError( + f"A PRF algorithm is mandatory in IKE proposal {proposal}" + ) + if 'l2tp' in ipsec: if 'esp_group' in ipsec['l2tp']: if 'esp_group' not in ipsec or ipsec['l2tp']['esp_group'] not in ipsec['esp_group']: diff --git a/src/helpers/commit-confirm-notify.py b/src/helpers/commit-confirm-notify.py index 8d7626c78..af6167651 100755 --- a/src/helpers/commit-confirm-notify.py +++ b/src/helpers/commit-confirm-notify.py @@ -2,30 +2,56 @@ import os import sys import time +from argparse import ArgumentParser # Minutes before reboot to trigger notification. intervals = [1, 5, 15, 60] -def notify(interval): - s = "" if interval == 1 else "s" +parser = ArgumentParser() +parser.add_argument( + 'minutes', type=int, help='minutes before rollback to trigger notification' +) +parser.add_argument( + '--reboot', action='store_true', help="use 'soft' rollback instead of reboot" +) + + +def notify(interval, reboot=False): + s = '' if interval == 1 else 's' time.sleep((minutes - interval) * 60) - message = ('"[commit-confirm] System is going to reboot in ' - f'{interval} minute{s} to rollback the last commit.\n' - 'Confirm your changes to cancel the reboot."') - os.system("wall -n " + message) + if reboot: + message = ( + '"[commit-confirm] System will reboot in ' + f'{interval} minute{s}\nto rollback the last commit.\n' + 'Confirm your changes to cancel the reboot."' + ) + os.system('wall -n ' + message) + else: + message = ( + '"[commit-confirm] System will reload previous config in ' + f'{interval} minute{s}\nto rollback the last commit.\n' + 'Confirm your changes to cancel the reload."' + ) + os.system('wall -n ' + message) + -if __name__ == "__main__": +if __name__ == '__main__': # Must be run as root to call wall(1) without a banner. - if len(sys.argv) != 2 or os.getuid() != 0: + if os.getuid() != 0: print('This script requires superuser privileges.', file=sys.stderr) exit(1) - minutes = int(sys.argv[1]) + + args = parser.parse_args() + + minutes = args.minutes + reboot = args.reboot + # Drop the argument from the list so that the notification # doesn't kick in immediately. if minutes in intervals: intervals.remove(minutes) for interval in sorted(intervals, reverse=True): if minutes >= interval: - notify(interval) - minutes -= (minutes - interval) + notify(interval, reboot=reboot) + minutes -= minutes - interval exit(0) diff --git a/src/helpers/vyos-domain-resolver.py b/src/helpers/vyos-domain-resolver.py index 57cfcabd7..f5a1d9297 100755 --- a/src/helpers/vyos-domain-resolver.py +++ b/src/helpers/vyos-domain-resolver.py @@ -30,6 +30,8 @@ from vyos.xml_ref import get_defaults base = ['firewall'] timeout = 300 cache = False +base_firewall = ['firewall'] +base_nat = ['nat'] domain_state = {} @@ -46,25 +48,25 @@ ipv6_tables = { 'ip6 raw' } -def get_config(conf): - firewall = conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, +def get_config(conf, node): + node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True) - default_values = get_defaults(base, get_first_key=True) + default_values = get_defaults(node, get_first_key=True) - firewall = dict_merge(default_values, firewall) + node_config = dict_merge(default_values, node_config) global timeout, cache - if 'resolver_interval' in firewall: - timeout = int(firewall['resolver_interval']) + if 'resolver_interval' in node_config: + timeout = int(node_config['resolver_interval']) - if 'resolver_cache' in firewall: + if 'resolver_cache' in node_config: cache = True - fqdn_config_parse(firewall) + fqdn_config_parse(node_config, node[0]) - return firewall + return node_config def resolve(domains, ipv6=False): global domain_state @@ -108,55 +110,60 @@ def nft_valid_sets(): except: return [] -def update(firewall): +def update_fqdn(config, node): conf_lines = [] count = 0 - valid_sets = nft_valid_sets() - domain_groups = dict_search_args(firewall, 'group', 'domain_group') - if domain_groups: - for set_name, domain_config in domain_groups.items(): - if 'address' not in domain_config: - continue - - nft_set_name = f'D_{set_name}' - domains = domain_config['address'] - - ip_list = resolve(domains, ipv6=False) - for table in ipv4_tables: - if (table, nft_set_name) in valid_sets: - conf_lines += nft_output(table, nft_set_name, ip_list) - - ip6_list = resolve(domains, ipv6=True) - for table in ipv6_tables: - if (table, nft_set_name) in valid_sets: - conf_lines += nft_output(table, nft_set_name, ip6_list) + if node == 'firewall': + domain_groups = dict_search_args(config, 'group', 'domain_group') + if domain_groups: + for set_name, domain_config in domain_groups.items(): + if 'address' not in domain_config: + continue + nft_set_name = f'D_{set_name}' + domains = domain_config['address'] + + ip_list = resolve(domains, ipv6=False) + for table in ipv4_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + ip6_list = resolve(domains, ipv6=True) + for table in ipv6_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip6_list) + count += 1 + + for set_name, domain in config['ip_fqdn'].items(): + table = 'ip vyos_filter' + nft_set_name = f'FQDN_{set_name}' + ip_list = resolve([domain], ipv6=False) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) count += 1 - for set_name, domain in firewall['ip_fqdn'].items(): - table = 'ip vyos_filter' - nft_set_name = f'FQDN_{set_name}' - - ip_list = resolve([domain], ipv6=False) - - if (table, nft_set_name) in valid_sets: - conf_lines += nft_output(table, nft_set_name, ip_list) - count += 1 - - for set_name, domain in firewall['ip6_fqdn'].items(): - table = 'ip6 vyos_filter' - nft_set_name = f'FQDN_{set_name}' + for set_name, domain in config['ip6_fqdn'].items(): + table = 'ip6 vyos_filter' + nft_set_name = f'FQDN_{set_name}' + ip_list = resolve([domain], ipv6=True) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 - ip_list = resolve([domain], ipv6=True) - if (table, nft_set_name) in valid_sets: - conf_lines += nft_output(table, nft_set_name, ip_list) - count += 1 + else: + # It's NAT + for set_name, domain in config['ip_fqdn'].items(): + table = 'ip vyos_nat' + nft_set_name = f'FQDN_nat_{set_name}' + ip_list = resolve([domain], ipv6=False) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 nft_conf_str = "\n".join(conf_lines) + "\n" code = run(f'nft --file -', input=nft_conf_str) - print(f'Updated {count} sets - result: {code}') + print(f'Updated {count} sets in {node} - result: {code}') if __name__ == '__main__': print(f'VyOS domain resolver') @@ -169,10 +176,12 @@ if __name__ == '__main__': time.sleep(1) conf = ConfigTreeQuery() - firewall = get_config(conf) + firewall = get_config(conf, base_firewall) + nat = get_config(conf, base_nat) print(f'interval: {timeout}s - cache: {cache}') while True: - update(firewall) + update_fqdn(firewall, 'firewall') + update_fqdn(nat, 'nat') time.sleep(timeout) diff --git a/src/migration-scripts/https/6-to-7 b/src/migration-scripts/https/6-to-7 new file mode 100644 index 000000000..571f3b6ae --- /dev/null +++ b/src/migration-scripts/https/6-to-7 @@ -0,0 +1,43 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +# T6736: move REST API to distinct node + + +from vyos.configtree import ConfigTree + + +base = ['service', 'https', 'api'] + +def migrate(config: ConfigTree) -> None: + if not config.exists(base): + # Nothing to do + return + + # Move REST API configuration to new node + # REST API was previously enabled if base path exists + config.set(['service', 'https', 'api', 'rest']) + for entry in ('debug', 'strict'): + if config.exists(base + [entry]): + config.set(base + ['rest', entry]) + config.delete(base + [entry]) + + # Move CORS settings under GraphQL + # CORS is not implemented for REST API + if config.exists(base + ['cors']): + old_base = base + ['cors'] + new_base = base + ['graphql', 'cors'] + config.copy(old_base, new_base) + config.delete(old_base) diff --git a/src/migration-scripts/reverse-proxy/1-to-2 b/src/migration-scripts/reverse-proxy/1-to-2 new file mode 100755 index 000000000..61612bc36 --- /dev/null +++ b/src/migration-scripts/reverse-proxy/1-to-2 @@ -0,0 +1,27 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +# T6745: Rename base node to haproxy + +from vyos.configtree import ConfigTree + +base = ['load-balancing', 'reverse-proxy'] + +def migrate(config: ConfigTree) -> None: + if not config.exists(base): + # Nothing to do + return + + config.rename(base, 'haproxy') diff --git a/src/migration-scripts/system/27-to-28 b/src/migration-scripts/system/27-to-28 new file mode 100644 index 000000000..0a5be48ab --- /dev/null +++ b/src/migration-scripts/system/27-to-28 @@ -0,0 +1,33 @@ +# Copyright 2023-2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +# rename 'system option performance' leaf nodes to new names + +from vyos.configtree import ConfigTree + +base = ['system', 'option', 'performance'] + +def migrate(config: ConfigTree) -> None: + if not config.exists(base): + return + + replace = { + 'throughput' : 'network-throughput', + 'latency' : 'network-latency' + } + + for old_name, new_name in replace.items(): + if config.return_value(base) == old_name: + config.set(base, new_name) diff --git a/src/op_mode/interfaces_wireguard.py b/src/op_mode/interfaces_wireguard.py new file mode 100644 index 000000000..627af0579 --- /dev/null +++ b/src/op_mode/interfaces_wireguard.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import sys +import vyos.opmode + +from vyos.ifconfig import WireGuardIf +from vyos.configquery import ConfigTreeQuery + + +def _verify(func): + """Decorator checks if WireGuard interface config exists""" + from functools import wraps + + @wraps(func) + def _wrapper(*args, **kwargs): + config = ConfigTreeQuery() + interface = kwargs.get('intf_name') + if not config.exists(['interfaces', 'wireguard', interface]): + unconf_message = f'WireGuard interface {interface} is not configured' + raise vyos.opmode.UnconfiguredSubsystem(unconf_message) + return func(*args, **kwargs) + + return _wrapper + + +@_verify +def show_summary(raw: bool, intf_name: str): + intf = WireGuardIf(intf_name, create=False, debug=False) + return intf.operational.show_interface() + + +if __name__ == '__main__': + try: + res = vyos.opmode.run(sys.modules[__name__]) + if res: + print(res) + except (ValueError, vyos.opmode.Error) as e: + print(e) + sys.exit(1) diff --git a/src/op_mode/reverseproxy.py b/src/op_mode/load-balancing_haproxy.py index 19704182a..ae6734e16 100755 --- a/src/op_mode/reverseproxy.py +++ b/src/op_mode/load-balancing_haproxy.py @@ -217,8 +217,8 @@ def _get_formatted_output(data): def show(raw: bool): config = ConfigTreeQuery() - if not config.exists('load-balancing reverse-proxy'): - raise vyos.opmode.UnconfiguredSubsystem('Reverse-proxy is not configured') + if not config.exists('load-balancing haproxy'): + raise vyos.opmode.UnconfiguredSubsystem('Haproxy is not configured') data = _get_raw_data() if raw: diff --git a/src/op_mode/mtr.py b/src/op_mode/mtr.py index de139f2fa..522cbe008 100644 --- a/src/op_mode/mtr.py +++ b/src/op_mode/mtr.py @@ -23,161 +23,162 @@ from vyos.utils.network import vrf_list from vyos.utils.process import call options = { - 'report': { + 'report-mode': { 'mtr': '{command} --report', 'type': 'noarg', - 'help': 'This option puts mtr into report mode. When in this mode, mtr will run for the number of cycles specified by the -c option, and then print statistics and exit.' + 'help': 'This option puts mtr into report mode. When in this mode, mtr will run for the number of cycles specified by the -c option, and then print statistics and exit.', }, 'report-wide': { 'mtr': '{command} --report-wide', 'type': 'noarg', - 'help': 'This option puts mtr into wide report mode. When in this mode, mtr will not cut hostnames in the report.' + 'help': 'This option puts mtr into wide report mode. When in this mode, mtr will not cut hostnames in the report.', }, 'raw': { 'mtr': '{command} --raw', 'type': 'noarg', - 'help': 'Use the raw output format. This format is better suited for archival of the measurement results.' + 'help': 'Use the raw output format. This format is better suited for archival of the measurement results.', }, 'json': { 'mtr': '{command} --json', 'type': 'noarg', - 'help': 'Use this option to tell mtr to use the JSON output format.' + 'help': 'Use this option to tell mtr to use the JSON output format.', }, 'split': { 'mtr': '{command} --split', 'type': 'noarg', - 'help': 'Use this option to set mtr to spit out a format that is suitable for a split-user interface.' + 'help': 'Use this option to set mtr to spit out a format that is suitable for a split-user interface.', }, 'no-dns': { 'mtr': '{command} --no-dns', 'type': 'noarg', - 'help': 'Use this option to force mtr to display numeric IP numbers and not try to resolve the host names.' + 'help': 'Use this option to force mtr to display numeric IP numbers and not try to resolve the host names.', }, 'show-ips': { 'mtr': '{command} --show-ips {value}', 'type': '<num>', - 'help': 'Use this option to tell mtr to display both the host names and numeric IP numbers.' + 'help': 'Use this option to tell mtr to display both the host names and numeric IP numbers.', }, 'ipinfo': { 'mtr': '{command} --ipinfo {value}', 'type': '<num>', - 'help': 'Displays information about each IP hop.' + 'help': 'Displays information about each IP hop.', }, 'aslookup': { 'mtr': '{command} --aslookup', 'type': 'noarg', - 'help': 'Displays the Autonomous System (AS) number alongside each hop. Equivalent to --ipinfo 0.' + 'help': 'Displays the Autonomous System (AS) number alongside each hop. Equivalent to --ipinfo 0.', }, 'interval': { 'mtr': '{command} --interval {value}', 'type': '<num>', - 'help': 'Use this option to specify the positive number of seconds between ICMP ECHO requests. The default value for this parameter is one second. The root user may choose values between zero and one.' + 'help': 'Use this option to specify the positive number of seconds between ICMP ECHO requests. The default value for this parameter is one second. The root user may choose values between zero and one.', }, 'report-cycles': { 'mtr': '{command} --report-cycles {value}', 'type': '<num>', - 'help': 'Use this option to set the number of pings sent to determine both the machines on the network and the reliability of those machines. Each cycle lasts one second.' + 'help': 'Use this option to set the number of pings sent to determine both the machines on the network and the reliability of those machines. Each cycle lasts one second.', }, 'psize': { 'mtr': '{command} --psize {value}', 'type': '<num>', - 'help': 'This option sets the packet size used for probing. It is in bytes, inclusive IP and ICMP headers. If set to a negative number, every iteration will use a different, random packet size up to that number.' + 'help': 'This option sets the packet size used for probing. It is in bytes, inclusive IP and ICMP headers. If set to a negative number, every iteration will use a different, random packet size up to that number.', }, 'bitpattern': { 'mtr': '{command} --bitpattern {value}', 'type': '<num>', - 'help': 'Specifies bit pattern to use in payload. Should be within range 0 - 255. If NUM is greater than 255, a random pattern is used.' + 'help': 'Specifies bit pattern to use in payload. Should be within range 0 - 255. If NUM is greater than 255, a random pattern is used.', }, 'gracetime': { 'mtr': '{command} --gracetime {value}', 'type': '<num>', - 'help': 'Use this option to specify the positive number of seconds to wait for responses after the final request. The default value is five seconds.' + 'help': 'Use this option to specify the positive number of seconds to wait for responses after the final request. The default value is five seconds.', }, 'tos': { 'mtr': '{command} --tos {value}', 'type': '<tos>', - 'help': 'Specifies value for type of service field in IP header. Should be within range 0 - 255.' + 'help': 'Specifies value for type of service field in IP header. Should be within range 0 - 255.', }, 'mpls': { 'mtr': '{command} --mpls {value}', 'type': 'noarg', - 'help': 'Use this option to tell mtr to display information from ICMP extensions for MPLS (RFC 4950) that are encoded in the response packets.' + 'help': 'Use this option to tell mtr to display information from ICMP extensions for MPLS (RFC 4950) that are encoded in the response packets.', }, 'interface': { 'mtr': '{command} --interface {value}', 'type': '<interface>', 'helpfunction': interface_list, - 'help': 'Use the network interface with a specific name for sending network probes. This can be useful when you have multiple network interfaces with routes to your destination, for example both wired Ethernet and WiFi, and wish to test a particular interface.' + 'help': 'Use the network interface with a specific name for sending network probes. This can be useful when you have multiple network interfaces with routes to your destination, for example both wired Ethernet and WiFi, and wish to test a particular interface.', }, 'address': { 'mtr': '{command} --address {value}', 'type': '<x.x.x.x> <h:h:h:h:h:h:h:h>', - 'help': 'Use this option to bind the outgoing socket to ADDRESS, so that all packets will be sent with ADDRESS as source address.' + 'help': 'Use this option to bind the outgoing socket to ADDRESS, so that all packets will be sent with ADDRESS as source address.', }, 'first-ttl': { 'mtr': '{command} --first-ttl {value}', 'type': '<num>', - 'help': 'Specifies with what TTL to start. Defaults to 1.' + 'help': 'Specifies with what TTL to start. Defaults to 1.', }, 'max-ttl': { 'mtr': '{command} --max-ttl {value}', 'type': '<num>', - 'help': 'Specifies the maximum number of hops or max time-to-live value mtr will probe. Default is 30.' + 'help': 'Specifies the maximum number of hops or max time-to-live value mtr will probe. Default is 30.', }, 'max-unknown': { 'mtr': '{command} --max-unknown {value}', 'type': '<num>', - 'help': 'Specifies the maximum unknown host. Default is 5.' + 'help': 'Specifies the maximum unknown host. Default is 5.', }, 'udp': { 'mtr': '{command} --udp', 'type': 'noarg', - 'help': 'Use UDP datagrams instead of ICMP ECHO.' + 'help': 'Use UDP datagrams instead of ICMP ECHO.', }, 'tcp': { 'mtr': '{command} --tcp', 'type': 'noarg', - 'help': ' Use TCP SYN packets instead of ICMP ECHO. PACKETSIZE is ignored, since SYN packets can not contain data.' + 'help': ' Use TCP SYN packets instead of ICMP ECHO. PACKETSIZE is ignored, since SYN packets can not contain data.', }, 'sctp': { 'mtr': '{command} --sctp', 'type': 'noarg', - 'help': 'Use Stream Control Transmission Protocol packets instead of ICMP ECHO.' + 'help': 'Use Stream Control Transmission Protocol packets instead of ICMP ECHO.', }, 'port': { 'mtr': '{command} --port {value}', 'type': '<port>', - 'help': 'The target port number for TCP/SCTP/UDP traces.' + 'help': 'The target port number for TCP/SCTP/UDP traces.', }, 'localport': { 'mtr': '{command} --localport {value}', 'type': '<port>', - 'help': 'The source port number for UDP traces.' + 'help': 'The source port number for UDP traces.', }, 'timeout': { 'mtr': '{command} --timeout {value}', 'type': '<num>', - 'help': ' The number of seconds to keep probe sockets open before giving up on the connection.' + 'help': ' The number of seconds to keep probe sockets open before giving up on the connection.', }, 'mark': { 'mtr': '{command} --mark {value}', 'type': '<num>', - 'help': ' Set the mark for each packet sent through this socket similar to the netfilter MARK target but socket-based. MARK is 32 unsigned integer.' + 'help': ' Set the mark for each packet sent through this socket similar to the netfilter MARK target but socket-based. MARK is 32 unsigned integer.', }, 'vrf': { 'mtr': 'sudo ip vrf exec {value} {command}', 'type': '<vrf>', 'help': 'Use specified VRF table', 'helpfunction': vrf_list, - 'dflt': 'default' - } - } + 'dflt': 'default', + }, +} mtr = { 4: '/bin/mtr -4', 6: '/bin/mtr -6', } + class List(list): def first(self): return self.pop(0) if self else '' @@ -203,8 +204,8 @@ def completion_failure(option: str) -> None: def expension_failure(option, completions): reason = 'Ambiguous' if completions else 'Invalid' sys.stderr.write( - '\n\n {} command: {} [{}]\n\n'.format(reason, ' '.join(sys.argv), - option)) + '\n\n {} command: {} [{}]\n\n'.format(reason, ' '.join(sys.argv), option) + ) if completions: sys.stderr.write(' Possible completions:\n ') sys.stderr.write('\n '.join(completions)) @@ -218,21 +219,24 @@ def complete(prefix): def convert(command, args): + to_json = False while args: shortname = args.first() longnames = complete(shortname) if len(longnames) != 1: expension_failure(shortname, longnames) longname = longnames[0] + if longname == 'json': + to_json = True if options[longname]['type'] == 'noarg': - command = options[longname]['mtr'].format( - command=command, value='') + command = options[longname]['mtr'].format(command=command, value='') elif not args: sys.exit(f'mtr: missing argument for {longname} option') else: command = options[longname]['mtr'].format( - command=command, value=args.first()) - return command + command=command, value=args.first() + ) + return command, to_json if __name__ == '__main__': @@ -240,8 +244,7 @@ if __name__ == '__main__': host = args.first() if not host: - sys.exit("mtr: Missing host") - + sys.exit('mtr: Missing host') if host == '--get-options' or host == '--get-options-nested': if host == '--get-options-nested': @@ -302,5 +305,8 @@ if __name__ == '__main__': except ValueError: sys.exit(f'mtr: Unknown host: {host}') - command = convert(mtr[version], args) - call(f'{command} --curses --displaymode 0 {host}') + command, to_json = convert(mtr[version], args) + if to_json: + call(f'{command} {host}') + else: + call(f'{command} --curses --displaymode 0 {host}') diff --git a/src/op_mode/mtr_execute.py b/src/op_mode/mtr_execute.py new file mode 100644 index 000000000..2585a7ee4 --- /dev/null +++ b/src/op_mode/mtr_execute.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import ipaddress +import socket +import sys +import typing + +from json import loads + +from vyos.utils.network import interface_list +from vyos.utils.network import vrf_list +from vyos.utils.process import cmd +from vyos.utils.process import call + +import vyos.opmode + +ArgProtocol = typing.Literal['tcp', 'udp', 'sctp'] +noargs_list = [ + 'report_mode', + 'json', + 'report_wide', + 'split', + 'raw', + 'no_dns', + 'aslookup', +] + + +def vrf_list_default(): + return vrf_list() + ['default'] + + +options = { + 'report_mode': { + 'mtr': '{command} --report', + }, + 'protocol': { + 'mtr': '{command} --{value}', + }, + 'json': { + 'mtr': '{command} --json', + }, + 'report_wide': { + 'mtr': '{command} --report-wide', + }, + 'raw': { + 'mtr': '{command} --raw', + }, + 'split': { + 'mtr': '{command} --split', + }, + 'no_dns': { + 'mtr': '{command} --no-dns', + }, + 'show_ips': { + 'mtr': '{command} --show-ips {value}', + }, + 'ipinfo': { + 'mtr': '{command} --ipinfo {value}', + }, + 'aslookup': { + 'mtr': '{command} --aslookup', + }, + 'interval': { + 'mtr': '{command} --interval {value}', + }, + 'report_cycles': { + 'mtr': '{command} --report-cycles {value}', + }, + 'psize': { + 'mtr': '{command} --psize {value}', + }, + 'bitpattern': { + 'mtr': '{command} --bitpattern {value}', + }, + 'gracetime': { + 'mtr': '{command} --gracetime {value}', + }, + 'tos': { + 'mtr': '{command} --tos {value}', + }, + 'mpls': { + 'mtr': '{command} --mpls {value}', + }, + 'interface': { + 'mtr': '{command} --interface {value}', + 'helpfunction': interface_list, + }, + 'address': { + 'mtr': '{command} --address {value}', + }, + 'first_ttl': { + 'mtr': '{command} --first-ttl {value}', + }, + 'max_ttl': { + 'mtr': '{command} --max-ttl {value}', + }, + 'max_unknown': { + 'mtr': '{command} --max-unknown {value}', + }, + 'port': { + 'mtr': '{command} --port {value}', + }, + 'localport': { + 'mtr': '{command} --localport {value}', + }, + 'timeout': { + 'mtr': '{command} --timeout {value}', + }, + 'mark': { + 'mtr': '{command} --mark {value}', + }, + 'vrf': { + 'mtr': 'sudo ip vrf exec {value} {command}', + 'helpfunction': vrf_list_default, + 'dflt': 'default', + }, +} + +mtr_command = { + 4: '/bin/mtr -4', + 6: '/bin/mtr -6', +} + + +def mtr( + host: str, + for_api: typing.Optional[bool], + report_mode: typing.Optional[bool], + protocol: typing.Optional[ArgProtocol], + report_wide: typing.Optional[bool], + raw: typing.Optional[bool], + json: typing.Optional[bool], + split: typing.Optional[bool], + no_dns: typing.Optional[bool], + show_ips: typing.Optional[str], + ipinfo: typing.Optional[str], + aslookup: typing.Optional[bool], + interval: typing.Optional[str], + report_cycles: typing.Optional[str], + psize: typing.Optional[str], + bitpattern: typing.Optional[str], + gracetime: typing.Optional[str], + tos: typing.Optional[str], + mpl: typing.Optional[bool], + interface: typing.Optional[str], + address: typing.Optional[str], + first_ttl: typing.Optional[str], + max_ttl: typing.Optional[str], + max_unknown: typing.Optional[str], + port: typing.Optional[str], + localport: typing.Optional[str], + timeout: typing.Optional[str], + mark: typing.Optional[str], + vrf: typing.Optional[str], +): + args = locals() + for name, option in options.items(): + if 'dflt' in option and not args[name]: + args[name] = option['dflt'] + + try: + ip = socket.gethostbyname(host) + except UnicodeError: + raise vyos.opmode.InternalError(f'Unknown host: {host}') + except socket.gaierror: + ip = host + + try: + version = ipaddress.ip_address(ip).version + except ValueError: + raise vyos.opmode.InternalError(f'Unknown host: {host}') + + command = mtr_command[version] + + for key, val in args.items(): + if key in options and val: + if 'helpfunction' in options[key]: + allowed_values = options[key]['helpfunction']() + if val not in allowed_values: + raise vyos.opmode.InternalError( + f'Invalid argument for option {key} - {val}' + ) + value = '' if key in noargs_list else val + command = options[key]['mtr'].format(command=command, value=val) + + if json: + output = cmd(f'{command} {host}') + if for_api: + output = loads(output) + print(output) + else: + call(f'{command} --curses --displaymode 0 {host}') + + +if __name__ == '__main__': + try: + res = vyos.opmode.run(sys.modules[__name__]) + if res: + print(res) + except (ValueError, vyos.opmode.Error) as e: + print(e) + sys.exit(1) diff --git a/src/op_mode/pki.py b/src/op_mode/pki.py index ab613e5c4..49a461e9e 100755 --- a/src/op_mode/pki.py +++ b/src/op_mode/pki.py @@ -14,25 +14,36 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. -import argparse import ipaddress import os import re import sys import tabulate +import typing from cryptography import x509 from cryptography.x509.oid import ExtendedKeyUsageOID +import vyos.opmode + from vyos.config import Config from vyos.config import config_dict_mangle_acme -from vyos.pki import encode_certificate, encode_public_key, encode_private_key, encode_dh_parameters +from vyos.pki import encode_certificate +from vyos.pki import encode_public_key +from vyos.pki import encode_private_key +from vyos.pki import encode_dh_parameters from vyos.pki import get_certificate_fingerprint -from vyos.pki import create_certificate, create_certificate_request, create_certificate_revocation_list +from vyos.pki import create_certificate +from vyos.pki import create_certificate_request +from vyos.pki import create_certificate_revocation_list from vyos.pki import create_private_key from vyos.pki import create_dh_parameters -from vyos.pki import load_certificate, load_certificate_request, load_private_key -from vyos.pki import load_crl, load_dh_parameters, load_public_key +from vyos.pki import load_certificate +from vyos.pki import load_certificate_request +from vyos.pki import load_private_key +from vyos.pki import load_crl +from vyos.pki import load_dh_parameters +from vyos.pki import load_public_key from vyos.pki import verify_certificate from vyos.utils.io import ask_input from vyos.utils.io import ask_yes_no @@ -42,18 +53,50 @@ from vyos.utils.process import cmd CERT_REQ_END = '-----END CERTIFICATE REQUEST-----' auth_dir = '/config/auth' +ArgsPkiType = typing.Literal['ca', 'certificate', 'dh', 'key-pair', 'openvpn', 'crl'] +ArgsPkiTypeGen = typing.Literal[ArgsPkiType, typing.Literal['ssh', 'wireguard']] +ArgsFingerprint = typing.Literal['sha256', 'sha384', 'sha512'] + # Helper Functions conf = Config() + + +def _verify(target): + """Decorator checks if config for PKI exists""" + from functools import wraps + + if target not in ['ca', 'certificate']: + raise ValueError('Invalid PKI') + + def _verify_target(func): + @wraps(func) + def _wrapper(*args, **kwargs): + name = kwargs.get('name') + unconf_message = f'PKI {target} "{name}" does not exist!' + if name: + if not conf.exists(['pki', target, name]): + raise vyos.opmode.UnconfiguredSubsystem(unconf_message) + return func(*args, **kwargs) + + return _wrapper + + return _verify_target + + def get_default_values(): # Fetch default x509 values base = ['pki', 'x509', 'default'] - x509_defaults = conf.get_config_dict(base, key_mangling=('-', '_'), - no_tag_node_value_mangle=True, - get_first_key=True, - with_recursive_defaults=True) + x509_defaults = conf.get_config_dict( + base, + key_mangling=('-', '_'), + no_tag_node_value_mangle=True, + get_first_key=True, + with_recursive_defaults=True, + ) return x509_defaults + def get_config_ca_certificate(name=None): # Fetch ca certificates from config base = ['pki', 'ca'] @@ -62,12 +105,15 @@ def get_config_ca_certificate(name=None): if name: base = base + [name] - if not conf.exists(base + ['private', 'key']) or not conf.exists(base + ['certificate']): + if not conf.exists(base + ['private', 'key']) or not conf.exists( + base + ['certificate'] + ): return False - return conf.get_config_dict(base, key_mangling=('-', '_'), - get_first_key=True, - no_tag_node_value_mangle=True) + return conf.get_config_dict( + base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True + ) + def get_config_certificate(name=None): # Get certificates from config @@ -77,18 +123,21 @@ def get_config_certificate(name=None): if name: base = base + [name] - if not conf.exists(base + ['private', 'key']) or not conf.exists(base + ['certificate']): + if not conf.exists(base + ['private', 'key']) or not conf.exists( + base + ['certificate'] + ): return False - pki = conf.get_config_dict(base, key_mangling=('-', '_'), - get_first_key=True, - no_tag_node_value_mangle=True) + pki = conf.get_config_dict( + base, key_mangling=('-', '_'), get_first_key=True, no_tag_node_value_mangle=True + ) if pki: for certificate in pki: pki[certificate] = config_dict_mangle_acme(certificate, pki[certificate]) return pki + def get_certificate_ca(cert, ca_certs): # Find CA certificate for given certificate if not ca_certs: @@ -107,6 +156,7 @@ def get_certificate_ca(cert, ca_certs): return ca_name return None + def get_config_revoked_certificates(): # Fetch revoked certificates from config ca_base = ['pki', 'ca'] @@ -115,19 +165,26 @@ def get_config_revoked_certificates(): certs = [] if conf.exists(ca_base): - ca_certificates = conf.get_config_dict(ca_base, key_mangling=('-', '_'), - get_first_key=True, - no_tag_node_value_mangle=True) + ca_certificates = conf.get_config_dict( + ca_base, + key_mangling=('-', '_'), + get_first_key=True, + no_tag_node_value_mangle=True, + ) certs.extend(ca_certificates.values()) if conf.exists(cert_base): - certificates = conf.get_config_dict(cert_base, key_mangling=('-', '_'), - get_first_key=True, - no_tag_node_value_mangle=True) + certificates = conf.get_config_dict( + cert_base, + key_mangling=('-', '_'), + get_first_key=True, + no_tag_node_value_mangle=True, + ) certs.extend(certificates.values()) return [cert_dict for cert_dict in certs if 'revoke' in cert_dict] + def get_revoked_by_serial_numbers(serial_numbers=[]): # Return serial numbers of revoked certificates certs_out = [] @@ -151,113 +208,153 @@ def get_revoked_by_serial_numbers(serial_numbers=[]): certs_out.append(cert_name) return certs_out -def install_certificate(name, cert='', private_key=None, key_type=None, key_passphrase=None, is_ca=False): + +def install_certificate( + name, cert='', private_key=None, key_type=None, key_passphrase=None, is_ca=False +): # Show/install conf commands for certificate prefix = 'ca' if is_ca else 'certificate' - base = f"pki {prefix} {name}" + base = f'pki {prefix} {name}' config_paths = [] if cert: - cert_pem = "".join(encode_certificate(cert).strip().split("\n")[1:-1]) + cert_pem = ''.join(encode_certificate(cert).strip().split('\n')[1:-1]) config_paths.append(f"{base} certificate '{cert_pem}'") if private_key: - key_pem = "".join(encode_private_key(private_key, passphrase=key_passphrase).strip().split("\n")[1:-1]) + key_pem = ''.join( + encode_private_key(private_key, passphrase=key_passphrase) + .strip() + .split('\n')[1:-1] + ) config_paths.append(f"{base} private key '{key_pem}'") if key_passphrase: - config_paths.append(f"{base} private password-protected") + config_paths.append(f'{base} private password-protected') install_into_config(conf, config_paths) + def install_crl(ca_name, crl): # Show/install conf commands for crl - crl_pem = "".join(encode_certificate(crl).strip().split("\n")[1:-1]) + crl_pem = ''.join(encode_certificate(crl).strip().split('\n')[1:-1]) install_into_config(conf, [f"pki ca {ca_name} crl '{crl_pem}'"]) + def install_dh_parameters(name, params): # Show/install conf commands for dh params - dh_pem = "".join(encode_dh_parameters(params).strip().split("\n")[1:-1]) + dh_pem = ''.join(encode_dh_parameters(params).strip().split('\n')[1:-1]) install_into_config(conf, [f"pki dh {name} parameters '{dh_pem}'"]) + def install_ssh_key(name, public_key, private_key, passphrase=None): # Show/install conf commands for ssh key - key_openssh = encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH') + key_openssh = encode_public_key( + public_key, encoding='OpenSSH', key_format='OpenSSH' + ) username = os.getlogin() - type_key_split = key_openssh.split(" ") - - base = f"system login user {username} authentication public-keys {name}" - install_into_config(conf, [ - f"{base} key '{type_key_split[1]}'", - f"{base} type '{type_key_split[0]}'" - ]) - print(encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase)) - -def install_keypair(name, key_type, private_key=None, public_key=None, passphrase=None, prompt=True): + type_key_split = key_openssh.split(' ') + + base = f'system login user {username} authentication public-keys {name}' + install_into_config( + conf, + [f"{base} key '{type_key_split[1]}'", f"{base} type '{type_key_split[0]}'"], + ) + print( + encode_private_key( + private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase + ) + ) + + +def install_keypair( + name, key_type, private_key=None, public_key=None, passphrase=None, prompt=True +): # Show/install conf commands for key-pair config_paths = [] if public_key: - install_public_key = not prompt or ask_yes_no('Do you want to install the public key?', default=True) + install_public_key = not prompt or ask_yes_no( + 'Do you want to install the public key?', default=True + ) public_key_pem = encode_public_key(public_key) if install_public_key: - install_public_pem = "".join(public_key_pem.strip().split("\n")[1:-1]) - config_paths.append(f"pki key-pair {name} public key '{install_public_pem}'") + install_public_pem = ''.join(public_key_pem.strip().split('\n')[1:-1]) + config_paths.append( + f"pki key-pair {name} public key '{install_public_pem}'" + ) else: - print("Public key:") + print('Public key:') print(public_key_pem) if private_key: - install_private_key = not prompt or ask_yes_no('Do you want to install the private key?', default=True) + install_private_key = not prompt or ask_yes_no( + 'Do you want to install the private key?', default=True + ) private_key_pem = encode_private_key(private_key, passphrase=passphrase) if install_private_key: - install_private_pem = "".join(private_key_pem.strip().split("\n")[1:-1]) - config_paths.append(f"pki key-pair {name} private key '{install_private_pem}'") + install_private_pem = ''.join(private_key_pem.strip().split('\n')[1:-1]) + config_paths.append( + f"pki key-pair {name} private key '{install_private_pem}'" + ) if passphrase: - config_paths.append(f"pki key-pair {name} private password-protected") + config_paths.append(f'pki key-pair {name} private password-protected') else: - print("Private key:") + print('Private key:') print(private_key_pem) install_into_config(conf, config_paths) + def install_openvpn_key(name, key_data, key_version='1'): config_paths = [ f"pki openvpn shared-secret {name} key '{key_data}'", - f"pki openvpn shared-secret {name} version '{key_version}'" + f"pki openvpn shared-secret {name} version '{key_version}'", ] install_into_config(conf, config_paths) + def install_wireguard_key(interface, private_key, public_key): # Show conf commands for installing wireguard key pairs from vyos.ifconfig import Section + if Section.section(interface) != 'wireguard': print(f'"{interface}" is not a WireGuard interface name!') exit(1) # Check if we are running in a config session - if yes, we can directly write to the CLI - install_into_config(conf, [f"interfaces wireguard {interface} private-key '{private_key}'"]) + install_into_config( + conf, [f"interfaces wireguard {interface} private-key '{private_key}'"] + ) print(f"Corresponding public-key to use on peer system is: '{public_key}'") + def install_wireguard_psk(interface, peer, psk): from vyos.ifconfig import Section + if Section.section(interface) != 'wireguard': print(f'"{interface}" is not a WireGuard interface name!') exit(1) # Check if we are running in a config session - if yes, we can directly write to the CLI - install_into_config(conf, [f"interfaces wireguard {interface} peer {peer} preshared-key '{psk}'"]) + install_into_config( + conf, [f"interfaces wireguard {interface} peer {peer} preshared-key '{psk}'"] + ) + def ask_passphrase(): passphrase = None - print("Note: If you plan to use the generated key on this router, do not encrypt the private key.") + print( + 'Note: If you plan to use the generated key on this router, do not encrypt the private key.' + ) if ask_yes_no('Do you want to encrypt the private key with a passphrase?'): passphrase = ask_input('Enter passphrase:') return passphrase + def write_file(filename, contents): full_path = os.path.join(auth_dir, filename) directory = os.path.dirname(full_path) @@ -266,7 +363,9 @@ def write_file(filename, contents): print('Failed to write file: directory does not exist') return False - if os.path.exists(full_path) and not ask_yes_no('Do you want to overwrite the existing file?'): + if os.path.exists(full_path) and not ask_yes_no( + 'Do you want to overwrite the existing file?' + ): return False with open(full_path, 'w') as f: @@ -274,10 +373,14 @@ def write_file(filename, contents): print(f'File written to {full_path}') -# Generation functions +# Generation functions def generate_private_key(): - key_type = ask_input('Enter private key type: [rsa, dsa, ec]', default='rsa', valid_responses=['rsa', 'dsa', 'ec']) + key_type = ask_input( + 'Enter private key type: [rsa, dsa, ec]', + default='rsa', + valid_responses=['rsa', 'dsa', 'ec'], + ) size_valid = [] size_default = 0 @@ -289,28 +392,43 @@ def generate_private_key(): size_default = 256 size_valid = [224, 256, 384, 521] - size = ask_input('Enter private key bits:', default=size_default, numeric_only=True, valid_responses=size_valid) + size = ask_input( + 'Enter private key bits:', + default=size_default, + numeric_only=True, + valid_responses=size_valid, + ) return create_private_key(key_type, size), key_type + def parse_san_string(san_string): if not san_string: return None output = [] - san_split = san_string.strip().split(",") + san_split = san_string.strip().split(',') for pair_str in san_split: - tag, value = pair_str.strip().split(":", 1) + tag, value = pair_str.strip().split(':', 1) if tag == 'ipv4': output.append(ipaddress.IPv4Address(value)) elif tag == 'ipv6': output.append(ipaddress.IPv6Address(value)) elif tag == 'dns' or tag == 'rfc822': output.append(value) - return output - -def generate_certificate_request(private_key=None, key_type=None, return_request=False, name=None, install=False, file=False, ask_san=True): + return + + +def generate_certificate_request( + private_key=None, + key_type=None, + return_request=False, + name=None, + install=False, + file=False, + ask_san=True, +): if not private_key: private_key, key_type = generate_private_key() @@ -319,18 +437,24 @@ def generate_certificate_request(private_key=None, key_type=None, return_request while True: country = ask_input('Enter country code:', default=default_values['country']) if len(country) != 2: - print("Country name must be a 2 character country code") + print('Country name must be a 2 character country code') continue subject['country'] = country break subject['state'] = ask_input('Enter state:', default=default_values['state']) - subject['locality'] = ask_input('Enter locality:', default=default_values['locality']) - subject['organization'] = ask_input('Enter organization name:', default=default_values['organization']) + subject['locality'] = ask_input( + 'Enter locality:', default=default_values['locality'] + ) + subject['organization'] = ask_input( + 'Enter organization name:', default=default_values['organization'] + ) subject['common_name'] = ask_input('Enter common name:', default='vyos.io') subject_alt_names = None if ask_san and ask_yes_no('Do you want to configure Subject Alternative Names?'): - print("Enter alternative names in a comma separate list, example: ipv4:1.1.1.1,ipv6:fe80::1,dns:vyos.net,rfc822:user@vyos.net") + print( + 'Enter alternative names in a comma separate list, example: ipv4:1.1.1.1,ipv6:fe80::1,dns:vyos.net,rfc822:user@vyos.net' + ) san_string = ask_input('Enter Subject Alternative Names:') subject_alt_names = parse_san_string(san_string) @@ -347,24 +471,48 @@ def generate_certificate_request(private_key=None, key_type=None, return_request return None if install: - print("Certificate request:") - print(encode_certificate(cert_req) + "\n") - install_certificate(name, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False) + print('Certificate request:') + print(encode_certificate(cert_req) + '\n') + install_certificate( + name, + private_key=private_key, + key_type=key_type, + key_passphrase=passphrase, + is_ca=False, + ) if file: write_file(f'{name}.csr', encode_certificate(cert_req)) - write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) - -def generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False, is_sub_ca=False): - valid_days = ask_input('Enter how many days certificate will be valid:', default='365' if not is_ca else '1825', numeric_only=True) + write_file( + f'{name}.key', encode_private_key(private_key, passphrase=passphrase) + ) + + +def generate_certificate( + cert_req, ca_cert, ca_private_key, is_ca=False, is_sub_ca=False +): + valid_days = ask_input( + 'Enter how many days certificate will be valid:', + default='365' if not is_ca else '1825', + numeric_only=True, + ) cert_type = None if not is_ca: - cert_type = ask_input('Enter certificate type: (client, server)', default='server', valid_responses=['client', 'server']) - return create_certificate(cert_req, ca_cert, ca_private_key, valid_days, cert_type, is_ca, is_sub_ca) + cert_type = ask_input( + 'Enter certificate type: (client, server)', + default='server', + valid_responses=['client', 'server'], + ) + return create_certificate( + cert_req, ca_cert, ca_private_key, valid_days, cert_type, is_ca, is_sub_ca + ) + def generate_ca_certificate(name, install=False, file=False): private_key, key_type = generate_private_key() - cert_req = generate_certificate_request(private_key, key_type, return_request=True, ask_san=False) + cert_req = generate_certificate_request( + private_key, key_type, return_request=True, ask_san=False + ) cert = generate_certificate(cert_req, cert_req, private_key, is_ca=True) passphrase = ask_passphrase() @@ -374,11 +522,16 @@ def generate_ca_certificate(name, install=False, file=False): return None if install: - install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True) + install_certificate( + name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True + ) if file: write_file(f'{name}.pem', encode_certificate(cert)) - write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) + write_file( + f'{name}.key', encode_private_key(private_key, passphrase=passphrase) + ) + def generate_ca_certificate_sign(name, ca_name, install=False, file=False): ca_dict = get_config_ca_certificate(ca_name) @@ -390,17 +543,19 @@ def generate_ca_certificate_sign(name, ca_name, install=False, file=False): ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: - print("Failed to load signing CA certificate, aborting") + print('Failed to load signing CA certificate, aborting') return None ca_private = ca_dict['private'] ca_private_passphrase = None if 'password_protected' in ca_private: ca_private_passphrase = ask_input('Enter signing CA private key passphrase:') - ca_private_key = load_private_key(ca_private['key'], passphrase=ca_private_passphrase) + ca_private_key = load_private_key( + ca_private['key'], passphrase=ca_private_passphrase + ) if not ca_private_key: - print("Failed to load signing CA private key, aborting") + print('Failed to load signing CA private key, aborting') return None private_key = None @@ -409,9 +564,11 @@ def generate_ca_certificate_sign(name, ca_name, install=False, file=False): cert_req = None if not ask_yes_no('Do you already have a certificate request?'): private_key, key_type = generate_private_key() - cert_req = generate_certificate_request(private_key, key_type, return_request=True, ask_san=False) + cert_req = generate_certificate_request( + private_key, key_type, return_request=True, ask_san=False + ) else: - print("Paste certificate request and press enter:") + print('Paste certificate request and press enter:') lines = [] curr_line = '' while True: @@ -421,17 +578,21 @@ def generate_ca_certificate_sign(name, ca_name, install=False, file=False): lines.append(curr_line) if not lines: - print("Aborted") + print('Aborted') return None - wrap = lines[0].find('-----') < 0 # Only base64 pasted, add the CSR tags for parsing - cert_req = load_certificate_request("\n".join(lines), wrap) + wrap = ( + lines[0].find('-----') < 0 + ) # Only base64 pasted, add the CSR tags for parsing + cert_req = load_certificate_request('\n'.join(lines), wrap) if not cert_req: - print("Invalid certificate request") + print('Invalid certificate request') return None - cert = generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=True, is_sub_ca=True) + cert = generate_certificate( + cert_req, ca_cert, ca_private_key, is_ca=True, is_sub_ca=True + ) passphrase = None if private_key is not None: @@ -444,12 +605,17 @@ def generate_ca_certificate_sign(name, ca_name, install=False, file=False): return None if install: - install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True) + install_certificate( + name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=True + ) if file: write_file(f'{name}.pem', encode_certificate(cert)) if private_key is not None: - write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) + write_file( + f'{name}.key', encode_private_key(private_key, passphrase=passphrase) + ) + def generate_certificate_sign(name, ca_name, install=False, file=False): ca_dict = get_config_ca_certificate(ca_name) @@ -461,17 +627,19 @@ def generate_certificate_sign(name, ca_name, install=False, file=False): ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: - print("Failed to load CA certificate, aborting") + print('Failed to load CA certificate, aborting') return None ca_private = ca_dict['private'] ca_private_passphrase = None if 'password_protected' in ca_private: ca_private_passphrase = ask_input('Enter CA private key passphrase:') - ca_private_key = load_private_key(ca_private['key'], passphrase=ca_private_passphrase) + ca_private_key = load_private_key( + ca_private['key'], passphrase=ca_private_passphrase + ) if not ca_private_key: - print("Failed to load CA private key, aborting") + print('Failed to load CA private key, aborting') return None private_key = None @@ -480,9 +648,11 @@ def generate_certificate_sign(name, ca_name, install=False, file=False): cert_req = None if not ask_yes_no('Do you already have a certificate request?'): private_key, key_type = generate_private_key() - cert_req = generate_certificate_request(private_key, key_type, return_request=True) + cert_req = generate_certificate_request( + private_key, key_type, return_request=True + ) else: - print("Paste certificate request and press enter:") + print('Paste certificate request and press enter:') lines = [] curr_line = '' while True: @@ -492,18 +662,20 @@ def generate_certificate_sign(name, ca_name, install=False, file=False): lines.append(curr_line) if not lines: - print("Aborted") + print('Aborted') return None - wrap = lines[0].find('-----') < 0 # Only base64 pasted, add the CSR tags for parsing - cert_req = load_certificate_request("\n".join(lines), wrap) + wrap = ( + lines[0].find('-----') < 0 + ) # Only base64 pasted, add the CSR tags for parsing + cert_req = load_certificate_request('\n'.join(lines), wrap) if not cert_req: - print("Invalid certificate request") + print('Invalid certificate request') return None cert = generate_certificate(cert_req, ca_cert, ca_private_key, is_ca=False) - + passphrase = None if private_key is not None: passphrase = ask_passphrase() @@ -515,12 +687,17 @@ def generate_certificate_sign(name, ca_name, install=False, file=False): return None if install: - install_certificate(name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=False) + install_certificate( + name, cert, private_key, key_type, key_passphrase=passphrase, is_ca=False + ) if file: write_file(f'{name}.pem', encode_certificate(cert)) if private_key is not None: - write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) + write_file( + f'{name}.key', encode_private_key(private_key, passphrase=passphrase) + ) + def generate_certificate_selfsign(name, install=False, file=False): private_key, key_type = generate_private_key() @@ -534,11 +711,21 @@ def generate_certificate_selfsign(name, install=False, file=False): return None if install: - install_certificate(name, cert, private_key=private_key, key_type=key_type, key_passphrase=passphrase, is_ca=False) + install_certificate( + name, + cert, + private_key=private_key, + key_type=key_type, + key_passphrase=passphrase, + is_ca=False, + ) if file: write_file(f'{name}.pem', encode_certificate(cert)) - write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) + write_file( + f'{name}.key', encode_private_key(private_key, passphrase=passphrase) + ) + def generate_certificate_revocation_list(ca_name, install=False, file=False): ca_dict = get_config_ca_certificate(ca_name) @@ -550,17 +737,19 @@ def generate_certificate_revocation_list(ca_name, install=False, file=False): ca_cert = load_certificate(ca_dict['certificate']) if not ca_cert: - print("Failed to load CA certificate, aborting") + print('Failed to load CA certificate, aborting') return None ca_private = ca_dict['private'] ca_private_passphrase = None if 'password_protected' in ca_private: ca_private_passphrase = ask_input('Enter CA private key passphrase:') - ca_private_key = load_private_key(ca_private['key'], passphrase=ca_private_passphrase) + ca_private_key = load_private_key( + ca_private['key'], passphrase=ca_private_passphrase + ) if not ca_private_key: - print("Failed to load CA private key, aborting") + print('Failed to load CA private key, aborting') return None revoked_certs = get_config_revoked_certificates() @@ -581,13 +770,13 @@ def generate_certificate_revocation_list(ca_name, install=False, file=False): continue if not to_revoke: - print("No revoked certificates to add to the CRL") + print('No revoked certificates to add to the CRL') return None crl = create_certificate_revocation_list(ca_cert, ca_private_key, to_revoke) if not crl: - print("Failed to create CRL") + print('Failed to create CRL') return None if not install and not file: @@ -598,7 +787,8 @@ def generate_certificate_revocation_list(ca_name, install=False, file=False): install_crl(ca_name, crl) if file: - write_file(f'{name}.crl', encode_certificate(crl)) + write_file(f'{ca_name}.crl', encode_certificate(crl)) + def generate_ssh_keypair(name, install=False, file=False): private_key, key_type = generate_private_key() @@ -607,29 +797,42 @@ def generate_ssh_keypair(name, install=False, file=False): if not install and not file: print(encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH')) - print("") - print(encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase)) + print('') + print( + encode_private_key( + private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase + ) + ) return None if install: install_ssh_key(name, public_key, private_key, passphrase) if file: - write_file(f'{name}.pem', encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH')) - write_file(f'{name}.key', encode_private_key(private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase)) + write_file( + f'{name}.pem', + encode_public_key(public_key, encoding='OpenSSH', key_format='OpenSSH'), + ) + write_file( + f'{name}.key', + encode_private_key( + private_key, encoding='PEM', key_format='OpenSSH', passphrase=passphrase + ), + ) + def generate_dh_parameters(name, install=False, file=False): bits = ask_input('Enter DH parameters key size:', default=2048, numeric_only=True) - print("Generating parameters...") + print('Generating parameters...') dh_params = create_dh_parameters(bits) if not dh_params: - print("Failed to create DH parameters") + print('Failed to create DH parameters') return None if not install and not file: - print("DH Parameters:") + print('DH Parameters:') print(encode_dh_parameters(dh_params)) if install: @@ -638,6 +841,7 @@ def generate_dh_parameters(name, install=False, file=False): if file: write_file(f'{name}.pem', encode_dh_parameters(dh_params)) + def generate_keypair(name, install=False, file=False): private_key, key_type = generate_private_key() public_key = private_key.public_key() @@ -645,7 +849,7 @@ def generate_keypair(name, install=False, file=False): if not install and not file: print(encode_public_key(public_key)) - print("") + print('') print(encode_private_key(private_key, passphrase=passphrase)) return None @@ -654,13 +858,16 @@ def generate_keypair(name, install=False, file=False): if file: write_file(f'{name}.pem', encode_public_key(public_key)) - write_file(f'{name}.key', encode_private_key(private_key, passphrase=passphrase)) + write_file( + f'{name}.key', encode_private_key(private_key, passphrase=passphrase) + ) + def generate_openvpn_key(name, install=False, file=False): result = cmd('openvpn --genkey secret /dev/stdout | grep -o "^[^#]*"') if not result: - print("Failed to generate OpenVPN key") + print('Failed to generate OpenVPN key') return None if not install and not file: @@ -668,11 +875,13 @@ def generate_openvpn_key(name, install=False, file=False): return None if install: - key_lines = result.split("\n") - key_data = "".join(key_lines[1:-1]) # Remove wrapper tags and line endings + key_lines = result.split('\n') + key_data = ''.join(key_lines[1:-1]) # Remove wrapper tags and line endings key_version = '1' - version_search = re.search(r'BEGIN OpenVPN Static key V(\d+)', result) # Future-proofing (hopefully) + version_search = re.search( + r'BEGIN OpenVPN Static key V(\d+)', result + ) # Future-proofing (hopefully) if version_search: key_version = version_search[1] @@ -681,6 +890,7 @@ def generate_openvpn_key(name, install=False, file=False): if file: write_file(f'{name}.key', result) + def generate_wireguard_key(interface=None, install=False): private_key = cmd('wg genkey') public_key = cmd('wg pubkey', input=private_key) @@ -691,6 +901,7 @@ def generate_wireguard_key(interface=None, install=False): print(f'Private key: {private_key}') print(f'Public key: {public_key}', end='\n\n') + def generate_wireguard_psk(interface=None, peer=None, install=False): psk = cmd('wg genpsk') if interface and peer and install: @@ -698,8 +909,11 @@ def generate_wireguard_psk(interface=None, peer=None, install=False): else: print(f'Pre-shared key: {psk}') + # Import functions -def import_ca_certificate(name, path=None, key_path=None, no_prompt=False, passphrase=None): +def import_ca_certificate( + name, path=None, key_path=None, no_prompt=False, passphrase=None +): if path: if not os.path.exists(path): print(f'File not found: {path}') @@ -736,7 +950,10 @@ def import_ca_certificate(name, path=None, key_path=None, no_prompt=False, passp install_certificate(name, private_key=key, is_ca=True) -def import_certificate(name, path=None, key_path=None, no_prompt=False, passphrase=None): + +def import_certificate( + name, path=None, key_path=None, no_prompt=False, passphrase=None +): if path: if not os.path.exists(path): print(f'File not found: {path}') @@ -773,6 +990,7 @@ def import_certificate(name, path=None, key_path=None, no_prompt=False, passphra install_certificate(name, private_key=key, is_ca=False) + def import_crl(name, path): if not os.path.exists(path): print(f'File not found: {path}') @@ -790,6 +1008,7 @@ def import_crl(name, path): install_crl(name, crl) + def import_dh_parameters(name, path): if not os.path.exists(path): print(f'File not found: {path}') @@ -807,6 +1026,7 @@ def import_dh_parameters(name, path): install_dh_parameters(name, dh) + def import_keypair(name, path=None, key_path=None, no_prompt=False, passphrase=None): if path: if not os.path.exists(path): @@ -844,6 +1064,7 @@ def import_keypair(name, path=None, key_path=None, no_prompt=False, passphrase=N install_keypair(name, None, private_key=key, prompt=False) + def import_openvpn_secret(name, path): if not os.path.exists(path): print(f'File not found: {path}') @@ -853,19 +1074,134 @@ def import_openvpn_secret(name, path): key_version = '1' with open(path) as f: - key_lines = f.read().strip().split("\n") - key_lines = list(filter(lambda line: not line.strip().startswith('#'), key_lines)) # Remove commented lines - key_data = "".join(key_lines[1:-1]) # Remove wrapper tags and line endings - - version_search = re.search(r'BEGIN OpenVPN Static key V(\d+)', key_lines[0]) # Future-proofing (hopefully) + key_lines = f.read().strip().split('\n') + key_lines = list( + filter(lambda line: not line.strip().startswith('#'), key_lines) + ) # Remove commented lines + key_data = ''.join(key_lines[1:-1]) # Remove wrapper tags and line endings + + version_search = re.search( + r'BEGIN OpenVPN Static key V(\d+)', key_lines[0] + ) # Future-proofing (hopefully) if version_search: key_version = version_search[1] install_openvpn_key(name, key_data, key_version) -# Show functions -def show_certificate_authority(name=None, pem=False): - headers = ['Name', 'Subject', 'Issuer CN', 'Issued', 'Expiry', 'Private Key', 'Parent'] + +def generate_pki( + raw: bool, + pki_type: ArgsPkiTypeGen, + name: typing.Optional[str], + file: typing.Optional[bool], + install: typing.Optional[bool], + sign: typing.Optional[str], + self_sign: typing.Optional[bool], + key: typing.Optional[bool], + psk: typing.Optional[bool], + interface: typing.Optional[str], + peer: typing.Optional[str], +): + try: + if pki_type == 'ca': + if sign: + generate_ca_certificate_sign(name, sign, install=install, file=file) + else: + generate_ca_certificate(name, install=install, file=file) + elif pki_type == 'certificate': + if sign: + generate_certificate_sign(name, sign, install=install, file=file) + elif self_sign: + generate_certificate_selfsign(name, install=install, file=file) + else: + generate_certificate_request(name=name, install=install, file=file) + + elif pki_type == 'crl': + generate_certificate_revocation_list(name, install=install, file=file) + + elif pki_type == 'ssh': + generate_ssh_keypair(name, install=install, file=file) + + elif pki_type == 'dh': + generate_dh_parameters(name, install=install, file=file) + + elif pki_type == 'key-pair': + generate_keypair(name, install=install, file=file) + + elif pki_type == 'openvpn': + generate_openvpn_key(name, install=install, file=file) + + elif pki_type == 'wireguard': + # WireGuard supports writing key directly into the CLI, but this + # requires the vyos_libexec_dir environment variable to be set + os.environ['vyos_libexec_dir'] = '/usr/libexec/vyos' + + if key: + generate_wireguard_key(interface, install=install) + if psk: + generate_wireguard_psk(interface, peer=peer, install=install) + except KeyboardInterrupt: + print('Aborted') + sys.exit(0) + + +def import_pki( + name: str, + pki_type: ArgsPkiType, + filename: typing.Optional[str], + key_filename: typing.Optional[str], + no_prompt: typing.Optional[bool], + passphrase: typing.Optional[str], +): + try: + if pki_type == 'ca': + import_ca_certificate( + name, + path=filename, + key_path=key_filename, + no_prompt=no_prompt, + passphrase=passphrase, + ) + elif pki_type == 'certificate': + import_certificate( + name, + path=filename, + key_path=key_filename, + no_prompt=no_prompt, + passphrase=passphrase, + ) + elif pki_type == 'crl': + import_crl(name, filename) + elif pki_type == 'dh': + import_dh_parameters(name, filename) + elif pki_type == 'key-pair': + import_keypair( + name, + path=filename, + key_path=key_filename, + no_prompt=no_prompt, + passphrase=passphrase, + ) + elif pki_type == 'openvpn': + import_openvpn_secret(name, filename) + except KeyboardInterrupt: + print('Aborted') + sys.exit(0) + + +@_verify('ca') +def show_certificate_authority( + raw: bool, name: typing.Optional[str] = None, pem: typing.Optional[bool] = False +): + headers = [ + 'Name', + 'Subject', + 'Issuer CN', + 'Issued', + 'Expiry', + 'Private Key', + 'Parent', + ] data = [] certs = get_config_ca_certificate() if certs: @@ -882,7 +1218,7 @@ def show_certificate_authority(name=None, pem=False): return parent_ca_name = get_certificate_ca(cert, certs) - cert_issuer_cn = cert.issuer.rfc4514_string().split(",")[0] + cert_issuer_cn = cert.issuer.rfc4514_string().split(',')[0] if not parent_ca_name or parent_ca_name == cert_name: parent_ca_name = 'N/A' @@ -890,14 +1226,45 @@ def show_certificate_authority(name=None, pem=False): if not cert: continue - have_private = 'Yes' if 'private' in cert_dict and 'key' in cert_dict['private'] else 'No' - data.append([cert_name, cert.subject.rfc4514_string(), cert_issuer_cn, cert.not_valid_before, cert.not_valid_after, have_private, parent_ca_name]) - - print("Certificate Authorities:") + have_private = ( + 'Yes' + if 'private' in cert_dict and 'key' in cert_dict['private'] + else 'No' + ) + data.append( + [ + cert_name, + cert.subject.rfc4514_string(), + cert_issuer_cn, + cert.not_valid_before, + cert.not_valid_after, + have_private, + parent_ca_name, + ] + ) + + print('Certificate Authorities:') print(tabulate.tabulate(data, headers)) -def show_certificate(name=None, pem=False, fingerprint_hash=None): - headers = ['Name', 'Type', 'Subject CN', 'Issuer CN', 'Issued', 'Expiry', 'Revoked', 'Private Key', 'CA Present'] + +@_verify('certificate') +def show_certificate( + raw: bool, + name: typing.Optional[str] = None, + pem: typing.Optional[bool] = False, + fingerprint: typing.Optional[ArgsFingerprint] = None, +): + headers = [ + 'Name', + 'Type', + 'Subject CN', + 'Issuer CN', + 'Issued', + 'Expiry', + 'Revoked', + 'Private Key', + 'CA Present', + ] data = [] certs = get_config_certificate() if certs: @@ -917,13 +1284,13 @@ def show_certificate(name=None, pem=False, fingerprint_hash=None): if name and pem: print(encode_certificate(cert)) return - elif name and fingerprint_hash: - print(get_certificate_fingerprint(cert, fingerprint_hash)) + elif name and fingerprint: + print(get_certificate_fingerprint(cert, fingerprint)) return ca_name = get_certificate_ca(cert, ca_certs) - cert_subject_cn = cert.subject.rfc4514_string().split(",")[0] - cert_issuer_cn = cert.issuer.rfc4514_string().split(",")[0] + cert_subject_cn = cert.subject.rfc4514_string().split(',')[0] + cert_issuer_cn = cert.issuer.rfc4514_string().split(',')[0] cert_type = 'Unknown' try: @@ -932,21 +1299,37 @@ def show_certificate(name=None, pem=False, fingerprint_hash=None): cert_type = 'Server' elif ext and ExtendedKeyUsageOID.CLIENT_AUTH in ext.value: cert_type = 'Client' - except: + except Exception: pass revoked = 'Yes' if 'revoke' in cert_dict else 'No' - have_private = 'Yes' if 'private' in cert_dict and 'key' in cert_dict['private'] else 'No' + have_private = ( + 'Yes' + if 'private' in cert_dict and 'key' in cert_dict['private'] + else 'No' + ) have_ca = f'Yes ({ca_name})' if ca_name else 'No' - data.append([ - cert_name, cert_type, cert_subject_cn, cert_issuer_cn, - cert.not_valid_before, cert.not_valid_after, - revoked, have_private, have_ca]) - - print("Certificates:") + data.append( + [ + cert_name, + cert_type, + cert_subject_cn, + cert_issuer_cn, + cert.not_valid_before, + cert.not_valid_after, + revoked, + have_private, + have_ca, + ] + ) + + print('Certificates:') print(tabulate.tabulate(data, headers)) -def show_crl(name=None, pem=False): + +def show_crl( + raw: bool, name: typing.Optional[str] = None, pem: typing.Optional[bool] = False +): headers = ['CA Name', 'Updated', 'Revokes'] data = [] certs = get_config_ca_certificate() @@ -971,141 +1354,31 @@ def show_crl(name=None, pem=False): print(encode_certificate(crl)) continue - certs = get_revoked_by_serial_numbers([revoked.serial_number for revoked in crl]) - data.append([cert_name, crl.last_update, ", ".join(certs)]) + certs = get_revoked_by_serial_numbers( + [revoked.serial_number for revoked in crl] + ) + data.append([cert_name, crl.last_update, ', '.join(certs)]) if name and pem: return - print("Certificate Revocation Lists:") + print('Certificate Revocation Lists:') print(tabulate.tabulate(data, headers)) -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--action', help='PKI action', required=True) - - # X509 - parser.add_argument('--ca', help='Certificate Authority', required=False) - parser.add_argument('--certificate', help='Certificate', required=False) - parser.add_argument('--crl', help='Certificate Revocation List', required=False) - parser.add_argument('--sign', help='Sign certificate with specified CA', required=False) - parser.add_argument('--self-sign', help='Self-sign the certificate', action='store_true') - parser.add_argument('--pem', help='Output using PEM encoding', action='store_true') - parser.add_argument('--fingerprint', help='Show fingerprint and exit', action='store') - # SSH - parser.add_argument('--ssh', help='SSH Key', required=False) +def show_all(raw: bool): + show_certificate_authority(raw) + print('\n') + show_certificate(raw) + print('\n') + show_crl(raw) - # DH - parser.add_argument('--dh', help='DH Parameters', required=False) - - # Key pair - parser.add_argument('--keypair', help='Key pair', required=False) - - # OpenVPN - parser.add_argument('--openvpn', help='OpenVPN TLS key', required=False) - - # WireGuard - parser.add_argument('--wireguard', help='Wireguard', action='store_true') - group = parser.add_mutually_exclusive_group() - group.add_argument('--key', help='Wireguard key pair', action='store_true', required=False) - group.add_argument('--psk', help='Wireguard pre shared key', action='store_true', required=False) - parser.add_argument('--interface', help='Install generated keys into running-config for named interface', action='store') - parser.add_argument('--peer', help='Install generated keys into running-config for peer', action='store') - - # Global - parser.add_argument('--file', help='Write generated keys into specified filename', action='store_true') - parser.add_argument('--install', help='Install generated keys into running-config', action='store_true') - - parser.add_argument('--filename', help='Write certificate into specified filename', action='store') - parser.add_argument('--key-filename', help='Write key into specified filename', action='store') - - parser.add_argument('--no-prompt', action='store_true', help='Perform action non-interactively') - parser.add_argument('--passphrase', help='A passphrase to decrypt the private key') - - args = parser.parse_args() +if __name__ == '__main__': try: - if args.action == 'generate': - if args.ca: - if args.sign: - generate_ca_certificate_sign(args.ca, args.sign, install=args.install, file=args.file) - else: - generate_ca_certificate(args.ca, install=args.install, file=args.file) - elif args.certificate: - if args.sign: - generate_certificate_sign(args.certificate, args.sign, install=args.install, file=args.file) - elif args.self_sign: - generate_certificate_selfsign(args.certificate, install=args.install, file=args.file) - else: - generate_certificate_request(name=args.certificate, install=args.install, file=args.file) - - elif args.crl: - generate_certificate_revocation_list(args.crl, install=args.install, file=args.file) - - elif args.ssh: - generate_ssh_keypair(args.ssh, install=args.install, file=args.file) - - elif args.dh: - generate_dh_parameters(args.dh, install=args.install, file=args.file) - - elif args.keypair: - generate_keypair(args.keypair, install=args.install, file=args.file) - - elif args.openvpn: - generate_openvpn_key(args.openvpn, install=args.install, file=args.file) - - elif args.wireguard: - # WireGuard supports writing key directly into the CLI, but this - # requires the vyos_libexec_dir environment variable to be set - os.environ["vyos_libexec_dir"] = "/usr/libexec/vyos" - - if args.key: - generate_wireguard_key(args.interface, install=args.install) - if args.psk: - generate_wireguard_psk(args.interface, peer=args.peer, install=args.install) - elif args.action == 'import': - if args.ca: - import_ca_certificate(args.ca, path=args.filename, key_path=args.key_filename, - no_prompt=args.no_prompt, passphrase=args.passphrase) - elif args.certificate: - import_certificate(args.certificate, path=args.filename, key_path=args.key_filename, - no_prompt=args.no_prompt, passphrase=args.passphrase) - elif args.crl: - import_crl(args.crl, args.filename) - elif args.dh: - import_dh_parameters(args.dh, args.filename) - elif args.keypair: - import_keypair(args.keypair, path=args.filename, key_path=args.key_filename, - no_prompt=args.no_prompt, passphrase=args.passphrase) - elif args.openvpn: - import_openvpn_secret(args.openvpn, args.filename) - elif args.action == 'show': - if args.ca: - ca_name = None if args.ca == 'all' else args.ca - if ca_name: - if not conf.exists(['pki', 'ca', ca_name]): - print(f'CA "{ca_name}" does not exist!') - exit(1) - show_certificate_authority(ca_name, args.pem) - elif args.certificate: - cert_name = None if args.certificate == 'all' else args.certificate - if cert_name: - if not conf.exists(['pki', 'certificate', cert_name]): - print(f'Certificate "{cert_name}" does not exist!') - exit(1) - if args.fingerprint is None: - show_certificate(None if args.certificate == 'all' else args.certificate, args.pem) - else: - show_certificate(args.certificate, fingerprint_hash=args.fingerprint) - elif args.crl: - show_crl(None if args.crl == 'all' else args.crl, args.pem) - else: - show_certificate_authority() - print('\n') - show_certificate() - print('\n') - show_crl() - except KeyboardInterrupt: - print("Aborted") - sys.exit(0) + res = vyos.opmode.run(sys.modules[__name__]) + if res: + print(res) + except (ValueError, vyos.opmode.Error) as e: + print(e) + sys.exit(1) diff --git a/src/op_mode/restart.py b/src/op_mode/restart.py index a83c8b9d8..3b0031f34 100755 --- a/src/op_mode/restart.py +++ b/src/op_mode/restart.py @@ -41,6 +41,10 @@ service_map = { 'systemd_service': 'pdns-recursor', 'path': ['service', 'dns', 'forwarding'], }, + 'haproxy': { + 'systemd_service': 'haproxy', + 'path': ['load-balancing', 'haproxy'], + }, 'igmp_proxy': { 'systemd_service': 'igmpproxy', 'path': ['protocols', 'igmp-proxy'], @@ -53,10 +57,6 @@ service_map = { 'systemd_service': 'avahi-daemon', 'path': ['service', 'mdns', 'repeater'], }, - 'reverse_proxy': { - 'systemd_service': 'haproxy', - 'path': ['load-balancing', 'reverse-proxy'], - }, 'router_advert': { 'systemd_service': 'radvd', 'path': ['service', 'router-advert'], @@ -83,10 +83,10 @@ services = typing.Literal[ 'dhcpv6', 'dns_dynamic', 'dns_forwarding', + 'haproxy', 'igmp_proxy', 'ipsec', 'mdns_repeater', - 'reverse_proxy', 'router_advert', 'snmp', 'ssh', diff --git a/src/op_mode/vrrp.py b/src/op_mode/vrrp.py index 60be86065..ef1338e23 100755 --- a/src/op_mode/vrrp.py +++ b/src/op_mode/vrrp.py @@ -13,47 +13,324 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - +import json import sys -import argparse +import typing + +from jinja2 import Template -from vyos.configquery import ConfigTreeQuery -from vyos.ifconfig.vrrp import VRRP +import vyos.opmode +from vyos.ifconfig import VRRP from vyos.ifconfig.vrrp import VRRPNoData -parser = argparse.ArgumentParser() -group = parser.add_mutually_exclusive_group() -group.add_argument("-s", "--summary", action="store_true", help="Print VRRP summary") -group.add_argument("-t", "--statistics", action="store_true", help="Print VRRP statistics") -group.add_argument("-d", "--data", action="store_true", help="Print detailed VRRP data") - -args = parser.parse_args() - -def is_configured(): - """ Check if VRRP is configured """ - config = ConfigTreeQuery() - if not config.exists(['high-availability', 'vrrp', 'group']): - return False - return True - -# Exit early if VRRP is dead or not configured -if is_configured() == False: - print('VRRP not configured!') - exit(0) -if not VRRP.is_running(): - print('VRRP is not running') - sys.exit(0) - -try: - if args.summary: - print(VRRP.format(VRRP.collect('json'))) - elif args.statistics: - print(VRRP.collect('stats')) - elif args.data: - print(VRRP.collect('state')) - else: - parser.print_help() + +stat_template = Template(""" +{% for rec in instances %} +VRRP Instance: {{rec.instance}} + Advertisements: + Received: {{rec.advert_rcvd}} + Sent: {{rec.advert_sent}} + Became master: {{rec.become_master}} + Released master: {{rec.release_master}} + Packet Errors: + Length: {{rec.packet_len_err}} + TTL: {{rec.ip_ttl_err}} + Invalid Type: {{rec.invalid_type_rcvd}} + Advertisement Interval: {{rec.advert_interval_err}} + Address List: {{rec.addr_list_err}} + Authentication Errors: + Invalid Type: {{rec.invalid_authtype}} + Type Mismatch: {{rec.authtype_mismatch}} + Failure: {{rec.auth_failure}} + Priority Zero: + Received: {{rec.pri_zero_rcvd}} + Sent: {{rec.pri_zero_sent}} +{% endfor %} +""") + +detail_template = Template(""" +{%- for rec in instances %} + VRRP Instance: {{rec.iname}} + VRRP Version: {{rec.version}} + State: {{rec.state}} + {% if rec.state == 'BACKUP' -%} + Master priority: {{ rec.master_priority }} + {% if rec.version == 3 -%} + Master advert interval: {{ rec.master_adver_int }} + {% endif -%} + {% endif -%} + Wantstate: {{rec.wantstate}} + Last transition: {{rec.last_transition}} + Interface: {{rec.ifp_ifname}} + {% if rec.dont_track_primary > 0 -%} + VRRP interface tracking disabled + {% endif -%} + {% if rec.skip_check_adv_addr > 0 -%} + Skip checking advert IP addresses + {% endif -%} + {% if rec.strict_mode > 0 -%} + Enforcing strict VRRP compliance + {% endif -%} + Gratuitous ARP delay: {{rec.garp_delay}} + Gratuitous ARP repeat: {{rec.garp_rep}} + Gratuitous ARP refresh: {{rec.garp_refresh}} + Gratuitous ARP refresh repeat: {{rec.garp_refresh_rep}} + Gratuitous ARP lower priority delay: {{rec.garp_lower_prio_delay}} + Gratuitous ARP lower priority repeat: {{rec.garp_lower_prio_rep}} + Send advert after receive lower priority advert: {{rec.lower_prio_no_advert}} + Send advert after receive higher priority advert: {{rec.higher_prio_send_advert}} + Virtual Router ID: {{rec.vrid}} + Priority: {{rec.base_priority}} + Effective priority: {{rec.effective_priority}} + Advert interval: {{rec.adver_int}} sec + Accept: {{rec.accept}} + Preempt: {{rec.nopreempt}} + {% if rec.preempt_delay -%} + Preempt delay: {{rec.preempt_delay}} + {% endif -%} + Promote secondaries: {{rec.promote_secondaries}} + Authentication type: {{rec.auth_type}} + {% if rec.vips %} + Virtual IP ({{ rec.vips | length }}): + {% for ip in rec.vips -%} + {{ip}} + {% endfor -%} + {% endif -%} + {% if rec.evips %} + Virtual IP Excluded: + {% for ip in rec.evips -%} + {{ip}} + {% endfor -%} + {% endif -%} + {% if rec.vroutes %} + Virtual Routes: + {% for route in rec.vroutes -%} + {{route}} + {% endfor -%} + {% endif -%} + {% if rec.vrules %} + Virtual Rules: + {% for rule in rec.vrules -%} + {{rule}} + {% endfor -%} + {% endif -%} + {% if rec.track_ifp %} + Tracked interfaces: + {% for ifp in rec.track_ifp -%} + {{ifp}} + {% endfor -%} + {% endif -%} + {% if rec.track_script %} + Tracked scripts: + {% for script in rec.track_script -%} + {{script}} + {% endfor -%} + {% endif %} + Using smtp notification: {{rec.smtp_alert}} + Notify deleted: {{rec.notify_deleted}} +{% endfor %} +""") + +# https://github.com/acassen/keepalived/blob/59c39afe7410f927c9894a1bafb87e398c6f02be/keepalived/include/vrrp.h#L126 +VRRP_AUTH_NONE = 0 +VRRP_AUTH_PASS = 1 +VRRP_AUTH_AH = 2 + +# https://github.com/acassen/keepalived/blob/59c39afe7410f927c9894a1bafb87e398c6f02be/keepalived/include/vrrp.h#L417 +VRRP_STATE_INIT = 0 +VRRP_STATE_BACK = 1 +VRRP_STATE_MAST = 2 +VRRP_STATE_FAULT = 3 + +VRRP_AUTH_TO_NAME = { + VRRP_AUTH_NONE: 'NONE', + VRRP_AUTH_PASS: 'SIMPLE_PASSWORD', + VRRP_AUTH_AH: 'IPSEC_AH', +} + +VRRP_STATE_TO_NAME = { + VRRP_STATE_INIT: 'INIT', + VRRP_STATE_BACK: 'BACKUP', + VRRP_STATE_MAST: 'MASTER', + VRRP_STATE_FAULT: 'FAULT', +} + + +def _get_raw_data(group_name: str = None) -> list: + """ + Retrieve raw JSON data for all VRRP groups. + + Args: + group_name (str, optional): If provided, filters the data to only + include the specified vrrp group. + + Returns: + list: A list of raw JSON data for VRRP groups, filtered by group_name + if specified. + """ + try: + output = VRRP.collect('json') + except VRRPNoData as e: + raise vyos.opmode.DataUnavailable(f'{e}') + + data = json.loads(output) + + if not data: + return [] + + if group_name is not None: + for rec in data: + if rec['data'].get('iname') == group_name: + return [rec] + return [] + return data + + +def _get_formatted_statistics_output(data: list) -> str: + """ + Prepare formatted statistics output from the given data. + + Args: + data (list): A list of dictionaries containing vrrp grop information + and statistics. + + Returns: + str: Rendered statistics output based on the provided data. + """ + instances = list() + for instance in data: + instances.append( + {'instance': instance['data'].get('iname'), **instance['stats']} + ) + + return stat_template.render(instances=instances) + + +def _process_field(data: dict, field: str, true_value: str, false_value: str): + """ + Updates the given field in the data dictionary with a specified value based + on its truthiness. + + Args: + data (dict): The dictionary containing the field to be processed. + field (str): The key representing the field in the dictionary. + true_value (str): The value to set if the field's value is truthy. + false_value (str): The value to set if the field's value is falsy. + + Returns: + None: The function modifies the dictionary in place. + """ + data[field] = true_value if data.get(field) else false_value + + +def _get_formatted_detail_output(data: list) -> str: + """ + Prepare formatted detail information output from the given data. + + Args: + data (list): A list of dictionaries containing vrrp grop information + and statistics. + + Returns: + str: Rendered detail info output based on the provided data. + """ + instances = list() + for instance in data: + instance['data']['state'] = VRRP_STATE_TO_NAME.get( + instance['data'].get('state'), 'unknown' + ) + instance['data']['wantstate'] = VRRP_STATE_TO_NAME.get( + instance['data'].get('wantstate'), 'unknown' + ) + instance['data']['auth_type'] = VRRP_AUTH_TO_NAME.get( + instance['data'].get('auth_type'), 'unknown' + ) + _process_field(instance['data'], 'lower_prio_no_advert', 'false', 'true') + _process_field(instance['data'], 'higher_prio_send_advert', 'true', 'false') + _process_field(instance['data'], 'accept', 'Enabled', 'Disabled') + _process_field(instance['data'], 'notify_deleted', 'Deleted', 'Fault') + _process_field(instance['data'], 'smtp_alert', 'yes', 'no') + _process_field(instance['data'], 'nopreempt', 'Disabled', 'Enabled') + _process_field(instance['data'], 'promote_secondaries', 'Enabled', 'Disabled') + instance['data']['vips'] = instance['data'].get('vips', False) + instance['data']['evips'] = instance['data'].get('evips', False) + instance['data']['vroutes'] = instance['data'].get('vroutes', False) + instance['data']['vrules'] = instance['data'].get('vrules', False) + + instances.append(instance['data']) + + return detail_template.render(instances=instances) + + +def show_detail( + raw: bool, group_name: typing.Optional[str] = None +) -> typing.Union[list, str]: + """ + Display detailed information about the VRRP group. + + Args: + raw (bool): If True, return raw data instead of formatted output. + group_name (str, optional): Filter the data by a specific group name, + if provided. + + Returns: + list or str: Raw data if `raw` is True, otherwise a formatted detail + output. + """ + data = _get_raw_data(group_name) + + if raw: + return data + + return _get_formatted_detail_output(data) + + +def show_statistics( + raw: bool, group_name: typing.Optional[str] = None +) -> typing.Union[list, str]: + """ + Display VRRP group statistics. + + Args: + raw (bool): If True, return raw data instead of formatted output. + group_name (str, optional): Filter the data by a specific group name, + if provided. + + Returns: + list or str: Raw data if `raw` is True, otherwise a formatted statistic + output. + """ + data = _get_raw_data(group_name) + + if raw: + return data + + return _get_formatted_statistics_output(data) + + +def show_summary(raw: bool) -> typing.Union[list, str]: + """ + Display a summary of VRRP group. + + Args: + raw (bool): If True, return raw data instead of formatted output. + + Returns: + list or str: Raw data if `raw` is True, otherwise a formatted summary output. + """ + data = _get_raw_data() + + if raw: + return data + + return VRRP.format(data) + + +if __name__ == '__main__': + try: + res = vyos.opmode.run(sys.modules[__name__]) + if res: + print(res) + except (ValueError, vyos.opmode.Error) as e: + print(e) sys.exit(1) -except VRRPNoData as e: - print(e) - sys.exit(1) diff --git a/src/services/api/__init__.py b/src/services/api/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/services/api/__init__.py diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py index ef4966466..ebf745f32 100644 --- a/src/services/api/graphql/bindings.py +++ b/src/services/api/graphql/bindings.py @@ -1,4 +1,4 @@ -# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -13,24 +13,40 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see <http://www.gnu.org/licenses/>. + import vyos.defaults -from . graphql.queries import query -from . graphql.mutations import mutation -from . graphql.directives import directives_dict -from . graphql.errors import op_mode_error -from . graphql.auth_token_mutation import auth_token_mutation -from . libs.token_auth import init_secret -from . import state -from ariadne import make_executable_schema, load_schema_from_path, snake_case_fallback_resolvers + +from ariadne import make_executable_schema +from ariadne import load_schema_from_path +from ariadne import snake_case_fallback_resolvers + +from .graphql.queries import query +from .graphql.mutations import mutation +from .graphql.directives import directives_dict +from .graphql.errors import op_mode_error +from .graphql.auth_token_mutation import auth_token_mutation +from .libs.token_auth import init_secret + +from ..session import SessionState + def generate_schema(): + state = SessionState() api_schema_dir = vyos.defaults.directories['api_schema'] - if state.settings['app'].state.vyos_auth_type == 'token': + if state.auth_type == 'token': init_secret() type_defs = load_schema_from_path(api_schema_dir) - schema = make_executable_schema(type_defs, query, op_mode_error, mutation, auth_token_mutation, snake_case_fallback_resolvers, directives=directives_dict) + schema = make_executable_schema( + type_defs, + query, + op_mode_error, + mutation, + auth_token_mutation, + snake_case_fallback_resolvers, + directives=directives_dict, + ) return schema diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py index a53fa4d60..c74364603 100644 --- a/src/services/api/graphql/graphql/auth_token_mutation.py +++ b/src/services/api/graphql/graphql/auth_token_mutation.py @@ -19,11 +19,12 @@ from typing import Dict from ariadne import ObjectType from graphql import GraphQLResolveInfo -from .. libs.token_auth import generate_token -from .. session.session import get_user_info -from .. import state +from ..libs.token_auth import generate_token +from ..session.session import get_user_info +from ...session import SessionState + +auth_token_mutation = ObjectType('Mutation') -auth_token_mutation = ObjectType("Mutation") @auth_token_mutation.field('AuthToken') def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): @@ -31,10 +32,13 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): user = data['username'] passwd = data['password'] - secret = state.settings['secret'] - exp_interval = int(state.settings['app'].state.vyos_token_exp) - expiration = (datetime.datetime.now(tz=datetime.timezone.utc) + - datetime.timedelta(seconds=exp_interval)) + state = SessionState() + + secret = getattr(state, 'secret', '') + exp_interval = int(state.token_exp) + expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=exp_interval + ) res = generate_token(user, passwd, secret, expiration) try: @@ -44,18 +48,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): pass if 'token' in res: data['result'] = res - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} if 'errors' in res: - return { - "success": False, - "errors": res['errors'] - } - - return { - "success": False, - "errors": ['token generation failed'] - } + return {'success': False, 'errors': res['errors']} + + return {'success': False, 'errors': ['token generation failed']} diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py index d115a8e94..0b391c070 100644 --- a/src/services/api/graphql/graphql/mutations.py +++ b/src/services/api/graphql/graphql/mutations.py @@ -14,20 +14,23 @@ # along with this library. If not, see <http://www.gnu.org/licenses/>. from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from .. import state -from .. libs import key_auth -from api.graphql.session.session import Session -from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -mutation = ObjectType("Mutation") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +mutation = ObjectType('Mutation') + def make_mutation_resolver(mutation_name, class_name, session_func): """Dynamically generate a resolver for the mutation named in the @@ -45,12 +48,13 @@ def make_mutation_resolver(mutation_name, class_name, session_func): func_base_name = convert_camel_case_to_snake(class_name) resolver_name = f'resolve_{func_base_name}' func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)' + state = SessionState() @mutation.field(mutation_name) @with_signature(func_sig, func_name=resolver_name) async def func_impl(*args, **kwargs): try: - auth_type = state.settings['app'].state.vyos_auth_type + auth_type = state.auth_type if auth_type == 'key': data = kwargs['data'] @@ -58,10 +62,7 @@ def make_mutation_resolver(mutation_name, class_name, session_func): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -76,21 +77,15 @@ def make_mutation_resolver(mutation_name, class_name, session_func): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no - # vyos_auth_type; validation and defaultValue ensure it is + # auth_type; validation and defaultValue ensure it is # one of the previous cases, so this is never reached. pass - session = state.settings['app'].state.vyos_session + session = state.session # one may override the session functions with a local subclass try: @@ -105,35 +100,36 @@ def make_mutation_resolver(mutation_name, class_name, session_func): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errore": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errore': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) + def make_gen_op_mutation_resolver(mutation_name): return make_mutation_resolver(mutation_name, mutation_name, 'gen_op_mutation') + def make_composite_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py index 717098259..9303fe909 100644 --- a/src/services/api/graphql/graphql/queries.py +++ b/src/services/api/graphql/graphql/queries.py @@ -14,20 +14,23 @@ # along with this library. If not, see <http://www.gnu.org/licenses/>. from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from .. import state -from .. libs import key_auth -from api.graphql.session.session import Session -from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -query = ObjectType("Query") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +query = ObjectType('Query') + def make_query_resolver(query_name, class_name, session_func): """Dynamically generate a resolver for the query named in the @@ -45,12 +48,13 @@ def make_query_resolver(query_name, class_name, session_func): func_base_name = convert_camel_case_to_snake(class_name) resolver_name = f'resolve_{func_base_name}' func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)' + state = SessionState() @query.field(query_name) @with_signature(func_sig, func_name=resolver_name) async def func_impl(*args, **kwargs): try: - auth_type = state.settings['app'].state.vyos_auth_type + auth_type = state.auth_type if auth_type == 'key': data = kwargs['data'] @@ -58,10 +62,7 @@ def make_query_resolver(query_name, class_name, session_func): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -76,21 +77,15 @@ def make_query_resolver(query_name, class_name, session_func): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no - # vyos_auth_type; validation and defaultValue ensure it is + # auth_type; validation and defaultValue ensure it is # one of the previous cases, so this is never reached. pass - session = state.settings['app'].state.vyos_session + session = state.session # one may override the session functions with a local subclass try: @@ -105,35 +100,36 @@ def make_query_resolver(query_name, class_name, session_func): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errors": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errors': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) + def make_gen_op_query_resolver(query_name): return make_query_resolver(query_name, query_name, 'gen_op_query') + def make_composite_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) diff --git a/src/services/api/graphql/libs/__init__.py b/src/services/api/graphql/libs/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/services/api/graphql/libs/__init__.py diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py index 2db0f7d48..ffd7f32b2 100644 --- a/src/services/api/graphql/libs/key_auth.py +++ b/src/services/api/graphql/libs/key_auth.py @@ -1,5 +1,21 @@ +# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +from ...session import SessionState -from .. import state def check_auth(key_list, key): if not key_list: @@ -10,9 +26,11 @@ def check_auth(key_list, key): key_id = k['id'] return key_id + def auth_required(key): + state = SessionState() api_keys = None - api_keys = state.settings['app'].state.vyos_keys + api_keys = state.keys key_id = check_auth(api_keys, key) - state.settings['app'].state.vyos_id = key_id + state.id = key_id return key_id diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py index 8585485c9..4f743a096 100644 --- a/src/services/api/graphql/libs/token_auth.py +++ b/src/services/api/graphql/libs/token_auth.py @@ -1,46 +1,67 @@ +# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + import jwt import uuid import pam from secrets import token_hex -from .. import state +from ...session import SessionState + def _check_passwd_pam(username: str, passwd: str) -> bool: if pam.authenticate(username, passwd): return True return False + def init_secret(): - length = int(state.settings['app'].state.vyos_secret_len) + state = SessionState() + length = int(state.secret_len) secret = token_hex(length) - state.settings['secret'] = secret + state.secret = secret + def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict: if user is None or passwd is None: return {} + state = SessionState() if _check_passwd_pam(user, passwd): - app = state.settings['app'] try: - users = app.state.vyos_token_users + users = state.token_users except AttributeError: - app.state.vyos_token_users = {} - users = app.state.vyos_token_users + users = state.token_users = {} user_id = uuid.uuid1().hex payload_data = {'iss': user, 'sub': user_id, 'exp': exp} - secret = state.settings.get('secret') + secret = getattr(state, 'secret', None) if secret is None: - return {"errors": ['missing secret']} - token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256") + return {'errors': ['missing secret']} + token = jwt.encode(payload=payload_data, key=secret, algorithm='HS256') users |= {user_id: user} return {'token': token} else: - return {"errors": ['failed pam authentication']} + return {'errors': ['failed pam authentication']} + def get_user_context(request): context = {} context['request'] = request context['user'] = None + state = SessionState() if 'Authorization' in request.headers: auth = request.headers['Authorization'] scheme, token = auth.split() @@ -48,8 +69,8 @@ def get_user_context(request): return context try: - secret = state.settings.get('secret') - payload = jwt.decode(token, secret, algorithms=["HS256"]) + secret = getattr(state, 'secret', None) + payload = jwt.decode(token, secret, algorithms=['HS256']) user_id: str = payload.get('sub') if user_id is None: return context @@ -59,7 +80,7 @@ def get_user_context(request): except jwt.PyJWTError: return context try: - users = state.settings['app'].state.vyos_token_users + users = state.token_users except AttributeError: return context diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py new file mode 100644 index 000000000..ed3ee1e8c --- /dev/null +++ b/src/services/api/graphql/routers.py @@ -0,0 +1,77 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +# pylint: disable=import-outside-toplevel + + +import typing + +from ariadne.asgi import GraphQL +from starlette.middleware.cors import CORSMiddleware + + +if typing.TYPE_CHECKING: + from fastapi import FastAPI + + +def graphql_init(app: 'FastAPI'): + from ..session import SessionState + from .libs.token_auth import get_user_context + + state = SessionState() + + # import after initializaion of state + from .bindings import generate_schema + + schema = generate_schema() + + in_spec = state.introspection + + # remove route and reinstall below, for any changes; alternatively, test + # for config_diff before proceeding + graphql_clear(app) + + if state.origins: + origins = state.origins + app.add_route( + '/graphql', + CORSMiddleware( + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + allow_origins=origins, + allow_methods=('GET', 'POST', 'OPTIONS'), + allow_headers=('Authorization',), + ), + ) + else: + app.add_route( + '/graphql', + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + ) + + +def graphql_clear(app: 'FastAPI'): + for r in app.routes: + if r.path == '/graphql': + app.routes.remove(r) diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py index 6ae44b9bf..619534f43 100644 --- a/src/services/api/graphql/session/session.py +++ b/src/services/api/graphql/session/session.py @@ -28,34 +28,45 @@ from api.graphql.libs.op_mode import normalize_output op_mode_include_file = os.path.join(directories['data'], 'op-mode-standardized.json') -def get_config_dict(path=[], effective=False, key_mangling=None, - get_first_key=False, no_multi_convert=False, - no_tag_node_value_mangle=False): + +def get_config_dict( + path=[], + effective=False, + key_mangling=None, + get_first_key=False, + no_multi_convert=False, + no_tag_node_value_mangle=False, +): config = Config() - return config.get_config_dict(path=path, effective=effective, - key_mangling=key_mangling, - get_first_key=get_first_key, - no_multi_convert=no_multi_convert, - no_tag_node_value_mangle=no_tag_node_value_mangle) + return config.get_config_dict( + path=path, + effective=effective, + key_mangling=key_mangling, + get_first_key=get_first_key, + no_multi_convert=no_multi_convert, + no_tag_node_value_mangle=no_tag_node_value_mangle, + ) + def get_user_info(user): user_info = {} - info = get_config_dict(['system', 'login', 'user', user], - get_first_key=True) + info = get_config_dict(['system', 'login', 'user', user], get_first_key=True) if not info: - raise ValueError("No such user") + raise ValueError('No such user') user_info['user'] = user user_info['full_name'] = info.get('full-name', '') return user_info + class Session: """ Wrapper for calling configsession functions based on GraphQL requests. Non-nullable fields in the respective schema allow avoiding a key check in 'data'. """ + def __init__(self, session, data): self._session = session self._data = data @@ -138,7 +149,6 @@ class Session: return res def show_user_info(self): - session = self._session data = self._data user_info = {} @@ -151,10 +161,9 @@ class Session: return user_info def system_status(self): - import api.graphql.session.composite.system_status as system_status + from api.graphql.session.composite import system_status session = self._session - data = self._data status = {} status['host_name'] = session.show(['host', 'name']).strip() @@ -165,7 +174,6 @@ class Session: return status def gen_op_query(self): - session = self._session data = self._data name = self._name op_mode_list = self._op_mode_list @@ -189,7 +197,6 @@ class Session: return res def gen_op_mutation(self): - session = self._session data = self._data name = self._name op_mode_list = self._op_mode_list diff --git a/src/services/api/graphql/state.py b/src/services/api/graphql/state.py deleted file mode 100644 index 63db9f4ef..000000000 --- a/src/services/api/graphql/state.py +++ /dev/null @@ -1,4 +0,0 @@ - -def init(): - global settings - settings = {} diff --git a/src/services/api/rest/__init__.py b/src/services/api/rest/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/services/api/rest/__init__.py diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py new file mode 100644 index 000000000..27d9fb5ee --- /dev/null +++ b/src/services/api/rest/models.py @@ -0,0 +1,313 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +# pylint: disable=too-few-public-methods + +import json +from html import escape +from enum import Enum +from typing import List +from typing import Union +from typing import Dict +from typing import Self + +from pydantic import BaseModel +from pydantic import StrictStr +from pydantic import field_validator +from pydantic import model_validator +from fastapi.responses import HTMLResponse + + +def error(code, msg): + msg = escape(msg, quote=False) + resp = {'success': False, 'error': msg, 'data': None} + resp = json.dumps(resp) + return HTMLResponse(resp, status_code=code) + + +def success(data): + resp = {'success': True, 'data': data, 'error': None} + resp = json.dumps(resp) + return HTMLResponse(resp) + + +# Pydantic models for validation +# Pydantic will cast when possible, so use StrictStr validators added as +# needed for additional constraints +# json_schema_extra adds anotations to OpenAPI to add examples + + +class ApiModel(BaseModel): + key: StrictStr + + +class BasePathModel(BaseModel): + op: StrictStr + path: List[StrictStr] + + @field_validator('path') + @classmethod + def check_non_empty(cls, path: str) -> str: + if not len(path) > 0: + raise ValueError('path must be non-empty') + return path + + +class BaseConfigureModel(BasePathModel): + value: StrictStr = None + + +class ConfigureModel(ApiModel, BaseConfigureModel): + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'set | delete | comment', + 'path': ['config', 'mode', 'path'], + } + } + + +class ConfigureListModel(ApiModel): + commands: List[BaseConfigureModel] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'commands': 'list of commands', + } + } + + +class BaseConfigSectionModel(BasePathModel): + section: Dict + + +class ConfigSectionModel(ApiModel, BaseConfigSectionModel): + pass + + +class ConfigSectionListModel(ApiModel): + commands: List[BaseConfigSectionModel] + + +class BaseConfigSectionTreeModel(BaseModel): + op: StrictStr + mask: Dict + config: Dict + + +class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): + pass + + +class RetrieveModel(ApiModel): + op: StrictStr + path: List[StrictStr] + configFormat: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'returnValue | returnValues | exists | showConfig', + 'path': ['config', 'mode', 'path'], + 'configFormat': 'json (default) | json_ast | raw', + } + } + + +class ConfigFileModel(ApiModel): + op: StrictStr + file: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'save | load', + 'file': 'filename', + } + } + + +class ImageOp(str, Enum): + add = 'add' + delete = 'delete' + show = 'show' + set_default = 'set_default' + + +class ImageModel(ApiModel): + op: ImageOp + url: StrictStr = None + name: StrictStr = None + + @model_validator(mode='after') + def check_data(self) -> Self: + if self.op == 'add': + if not self.url: + raise ValueError('Missing required field "url"') + elif self.op in ['delete', 'set_default']: + if not self.name: + raise ValueError('Missing required field "name"') + + return self + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show | set_default', + 'url': 'imagelocation', + 'name': 'imagename', + } + } + + +class ImportPkiModel(ApiModel): + op: StrictStr + path: List[StrictStr] + passphrase: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'import_pki', + 'path': ['op', 'mode', 'path'], + 'passphrase': 'passphrase', + } + } + + +class ContainerImageModel(ApiModel): + op: StrictStr + name: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show', + 'name': 'imagename', + } + } + + +class GenerateModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'generate', + 'path': ['op', 'mode', 'path'], + } + } + + +class ShowModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'show', + 'path': ['op', 'mode', 'path'], + } + } + + +class RebootModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'reboot', + 'path': ['op', 'mode', 'path'], + } + } + + +class ResetModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'reset', + 'path': ['op', 'mode', 'path'], + } + } + + +class PoweroffModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'poweroff', + 'path': ['op', 'mode', 'path'], + } + } + + +class TracerouteModel(ApiModel): + op: StrictStr + host: StrictStr + + class Config: + schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'traceroute', + 'host': 'host', + } + } + + +class Success(BaseModel): + success: bool + data: Union[str, bool, Dict] + error: str + + +class Error(BaseModel): + success: bool = False + data: Union[str, bool, Dict] + error: str + + +responses = { + 200: {'model': Success}, + 400: {'model': Error}, + 422: {'model': Error, 'description': 'Validation Error'}, + 500: {'model': Error}, +} diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py new file mode 100644 index 000000000..e52c77fda --- /dev/null +++ b/src/services/api/rest/routers.py @@ -0,0 +1,778 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +# pylint: disable=line-too-long,raise-missing-from,invalid-name +# pylint: disable=wildcard-import,unused-wildcard-import +# pylint: disable=broad-exception-caught + +import json +import copy +import logging +import traceback +from threading import Lock +from typing import Union +from typing import Callable +from typing import TYPE_CHECKING + +from fastapi import Depends +from fastapi import Request +from fastapi import Response +from fastapi import HTTPException +from fastapi import APIRouter +from fastapi import BackgroundTasks +from fastapi.routing import APIRoute +from starlette.datastructures import FormData +from starlette.formparsers import FormParser +from starlette.formparsers import MultiPartParser +from starlette.formparsers import MultiPartException +from multipart.multipart import parse_options_header + +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configdiff import get_config_diff +from vyos.configsession import ConfigSessionError + +from ..session import SessionState +from .models import success +from .models import error +from .models import responses +from .models import ApiModel +from .models import ConfigureModel +from .models import ConfigureListModel +from .models import ConfigSectionModel +from .models import ConfigSectionListModel +from .models import ConfigSectionTreeModel +from .models import BaseConfigSectionTreeModel +from .models import BaseConfigureModel +from .models import BaseConfigSectionModel +from .models import RetrieveModel +from .models import ConfigFileModel +from .models import ImageModel +from .models import ContainerImageModel +from .models import GenerateModel +from .models import ShowModel +from .models import RebootModel +from .models import ResetModel +from .models import ImportPkiModel +from .models import PoweroffModel +from .models import TracerouteModel + + +if TYPE_CHECKING: + from fastapi import FastAPI + + +LOG = logging.getLogger('http_api.routers') + +lock = Lock() + + +def check_auth(key_list, key): + key_id = None + for k in key_list: + if k['key'] == key: + key_id = k['id'] + return key_id + + +def auth_required(data: ApiModel): + session = SessionState() + key = data.key + api_keys = session.keys + key_id = check_auth(api_keys, key) + if not key_id: + raise HTTPException(status_code=401, detail='Valid API key is required') + session.id = key_id + + +# override Request and APIRoute classes in order to convert form request to json; +# do all explicit validation here, for backwards compatability of error messages; +# the explicit validation may be dropped, if desired, in favor of native +# validation by FastAPI/Pydantic, as is used for application/json requests +class MultipartRequest(Request): + """Override Request class to convert form request to json""" + + # pylint: disable=attribute-defined-outside-init + # pylint: disable=too-many-branches,too-many-statements + + _form_err = () + + @property + def form_err(self): + return self._form_err + + @form_err.setter + def form_err(self, val): + if not self._form_err: + self._form_err = val + + @property + def orig_headers(self): + self._orig_headers = super().headers + return self._orig_headers + + @property + def headers(self): + self._headers = super().headers.mutablecopy() + self._headers['content-type'] = 'application/json' + return self._headers + + async def _get_form( + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: + if self._form is None: + assert ( + parse_options_header is not None + ), 'The `python-multipart` library must be installed to use form parsing.' + content_type_header = self.orig_headers.get('Content-Type') + content_type: bytes + content_type, _ = parse_options_header(content_type_header) + if content_type == b'multipart/form-data': + try: + multipart_parser = MultiPartParser( + self.orig_headers, + self.stream(), + max_files=max_files, + max_fields=max_fields, + ) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if 'app' in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc + elif content_type == b'application/x-www-form-urlencoded': + form_parser = FormParser(self.orig_headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() + return self._form + + async def body(self) -> bytes: + if not hasattr(self, '_body'): + forms = {} + merge = {} + body = await super().body() + self._body = body + + form_data = await self.form() + if form_data: + endpoint = self.url.path + LOG.debug('processing form data') + for k, v in form_data.multi_items(): + forms[k] = v + + if 'data' not in forms: + self.form_err = (422, 'Non-empty data field is required') + return self._body + try: + tmp = json.loads(forms['data']) + except json.JSONDecodeError as e: + self.form_err = (400, f'Failed to parse JSON: {e}') + return self._body + if isinstance(tmp, list): + merge['commands'] = tmp + else: + merge = tmp + + if 'commands' in merge: + cmds = merge['commands'] + else: + cmds = copy.deepcopy(merge) + cmds = [cmds] + + for c in cmds: + if not isinstance(c, dict): + self.form_err = ( + 400, + f"Malformed command '{c}': any command must be JSON of dict", + ) + return self._body + if 'op' not in c: + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'op' field", + ) + if endpoint not in ( + '/config-file', + '/container-image', + '/image', + '/configure-section', + '/traceroute', + ): + if 'path' not in c: + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'path' field", + ) + elif not isinstance(c['path'], list): + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' field must be a list", + ) + elif not all(isinstance(el, str) for el in c['path']): + self.form_err = ( + 400, + f"Malformed command '{0}': 'path' field must be a list of strings", + ) + if endpoint in ('/configure'): + if not c['path']: + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' list must be non-empty", + ) + if 'value' in c and not isinstance(c['value'], str): + self.form_err = ( + 400, + f"Malformed command '{c}': 'value' field must be a string", + ) + if endpoint in ('/configure-section'): + if 'section' not in c and 'config' not in c: + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'section' or 'config' field", + ) + + if 'key' not in forms and 'key' not in merge: + self.form_err = (401, 'Valid API key is required') + if 'key' in forms and 'key' not in merge: + merge['key'] = forms['key'] + + new_body = json.dumps(merge) + new_body = new_body.encode() + self._body = new_body + + return self._body + + +class MultipartRoute(APIRoute): + """Override APIRoute class to convert form request to json""" + + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + request = MultipartRequest(request.scope, request.receive) + try: + response: Response = await original_route_handler(request) + except HTTPException as e: + return error(e.status_code, e.detail) + except Exception as e: + form_err = request.form_err + if form_err: + return error(*form_err) + raise e + + return response + + return custom_route_handler + + +router = APIRouter( + route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)], +) + + +self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background' + + +def call_commit(s: SessionState): + try: + s.session.commit() + except ConfigSessionError as e: + s.session.discard() + if s.debug: + LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}') + else: + LOG.warning(f'ConfigSessionError: {e}') + + +def _configure_op( + data: Union[ + ConfigureModel, + ConfigureListModel, + ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel, + ], + _request: Request, + background_tasks: BackgroundTasks, +): + # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements + # pylint: disable=consider-using-with + + state = SessionState() + session = state.session + env = session.get_session_env() + + # Allow users to pass just one command + if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)): + data = [data] + else: + data = data.commands + + # We don't want multiple people/apps to be able to commit at once, + # or modify the shared session while someone else is doing the same, + # so the lock is really global + lock.acquire() + + config = Config(session_env=env) + + status = 200 + msg = None + error_msg = None + try: + for c in data: + op = c.op + if not isinstance(c, BaseConfigSectionTreeModel): + path = c.path + + if isinstance(c, BaseConfigureModel): + if c.value: + value = c.value + else: + value = '' + # For vyos.configsession calls that have no separate value arguments, + # and for type checking too + cfg_path = ' '.join(path + [value]).strip() + + elif isinstance(c, BaseConfigSectionModel): + section = c.section + + elif isinstance(c, BaseConfigSectionTreeModel): + mask = c.mask + config = c.config + + if isinstance(c, BaseConfigureModel): + if op == 'set': + session.set(path, value=value) + elif op == 'delete': + if state.strict and not config.exists(cfg_path): + raise ConfigSessionError( + f'Cannot delete [{cfg_path}]: path/value does not exist' + ) + session.delete(path, value=value) + elif op == 'comment': + session.comment(path, value=value) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + + elif isinstance(c, BaseConfigSectionModel): + if op == 'set': + session.set_section(path, section) + elif op == 'load': + session.load_section(path, section) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + + elif isinstance(c, BaseConfigSectionTreeModel): + if op == 'set': + session.set_section_tree(config) + elif op == 'load': + session.load_section_tree(mask, config) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + # end for + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, state) + msg = self_ref_msg + else: + # capture non-fatal warnings + out = session.commit() + msg = out if out else msg + + LOG.info(f"Configuration modified via HTTP API using key '{state.id}'") + except ConfigSessionError as e: + session.discard() + status = 400 + if state.debug: + LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}') + error_msg = str(e) + except Exception: + session.discard() + LOG.critical(traceback.format_exc()) + status = 500 + + # Don't give the details away to the outer world + error_msg = 'An internal error occured. Check the logs for details.' + finally: + lock.release() + + if status != 200: + return error(status, error_msg) + + return success(msg) + + +def create_path_import_pki_no_prompt(path): + correct_paths = ['ca', 'certificate', 'key-pair'] + if path[1] not in correct_paths: + return False + path[3] = '--key-filename' + path.insert(2, '--name') + return ['--pki-type'] + path[1:] + + +@router.post('/configure') +def configure_op( + data: Union[ConfigureModel, ConfigureListModel], + request: Request, + background_tasks: BackgroundTasks, +): + return _configure_op(data, request, background_tasks) + + +@router.post('/configure-section') +def configure_section_op( + data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel], + request: Request, + background_tasks: BackgroundTasks, +): + return _configure_op(data, request, background_tasks) + + +@router.post('/retrieve') +async def retrieve_op(data: RetrieveModel): + state = SessionState() + session = state.session + env = session.get_session_env() + config = Config(session_env=env) + + op = data.op + path = ' '.join(data.path) + + try: + if op == 'returnValue': + res = config.return_value(path) + elif op == 'returnValues': + res = config.return_values(path) + elif op == 'exists': + res = config.exists(path) + elif op == 'showConfig': + config_format = 'json' + if data.configFormat: + config_format = data.configFormat + + res = session.show_config(path=data.path) + if config_format == 'json': + config_tree = ConfigTree(res) + res = json.loads(config_tree.to_json()) + elif config_format == 'json_ast': + config_tree = ConfigTree(res) + res = json.loads(config_tree.to_json_ast()) + elif config_format == 'raw': + pass + else: + return error(400, f"'{config_format}' is not a valid config format") + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/config-file') +def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): + state = SessionState() + session = state.session + env = session.get_session_env() + op = data.op + msg = None + + try: + if op == 'save': + if data.file: + path = data.file + else: + path = '/config/config.boot' + msg = session.save_config(path) + elif op == 'load': + if data.file: + path = data.file + else: + return error(400, 'Missing required field "file"') + + session.migrate_and_load_config(path) + + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, state) + msg = self_ref_msg + else: + session.commit() + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(msg) + + +@router.post('/image') +def image_op(data: ImageModel): + state = SessionState() + session = state.session + + op = data.op + + try: + if op == 'add': + res = session.install_image(data.url) + elif op == 'delete': + res = session.remove_image(data.name) + elif op == 'show': + res = session.show(['system', 'image']) + elif op == 'set_default': + res = session.set_default_image(data.name) + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/container-image') +def container_image_op(data: ContainerImageModel): + state = SessionState() + session = state.session + + op = data.op + + try: + if op == 'add': + if data.name: + name = data.name + else: + return error(400, 'Missing required field "name"') + res = session.add_container_image(name) + elif op == 'delete': + if data.name: + name = data.name + else: + return error(400, 'Missing required field "name"') + res = session.delete_container_image(name) + elif op == 'show': + res = session.show_container_image() + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/generate') +def generate_op(data: GenerateModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'generate': + res = session.generate(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/show') +def show_op(data: ShowModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'show': + res = session.show(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/reboot') +def reboot_op(data: RebootModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'reboot': + res = session.reboot(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/reset') +def reset_op(data: ResetModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'reset': + res = session.reset(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/import-pki') +def import_pki(data: ImportPkiModel): + # pylint: disable=consider-using-with + + state = SessionState() + session = state.session + + op = data.op + path = data.path + + lock.acquire() + + try: + if op == 'import-pki': + # need to get rid or interactive mode for private key + if len(path) == 5 and path[3] in ['key-file', 'private-key']: + path_no_prompt = create_path_import_pki_no_prompt(path) + if not path_no_prompt: + return error(400, f"Invalid command: {' '.join(path)}") + if data.passphrase: + path_no_prompt += ['--passphrase', data.passphrase] + res = session.import_pki_no_prompt(path_no_prompt) + else: + res = session.import_pki(path) + if not res[0].isdigit(): + return error(400, res) + # commit changes + session.commit() + res = res.split('. ')[0] + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + finally: + lock.release() + + return success(res) + + +@router.post('/poweroff') +def poweroff_op(data: PoweroffModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'poweroff': + res = session.poweroff(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/traceroute') +def traceroute_op(data: TracerouteModel): + state = SessionState() + session = state.session + + op = data.op + host = data.host + + try: + if op == 'traceroute': + res = session.traceroute(host) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occurred. Check the logs for details.') + + return success(res) + + +def rest_init(app: 'FastAPI'): + if all(r in app.routes for r in router.routes): + return + app.include_router(router) + + +def rest_clear(app: 'FastAPI'): + for r in router.routes: + if r in app.routes: + app.routes.remove(r) diff --git a/src/services/api/session.py b/src/services/api/session.py new file mode 100644 index 000000000..ad3ef660c --- /dev/null +++ b/src/services/api/session.py @@ -0,0 +1,41 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +class SessionState: + # pylint: disable=attribute-defined-outside-init + # pylint: disable=too-many-instance-attributes,too-few-public-methods + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(SessionState, cls).__new__(cls) + cls._instance._initialize() + return cls._instance + + def _initialize(self): + self.session = None + self.keys = [] + self.id = None + self.rest = False + self.debug = False + self.strict = False + self.graphql = False + self.origins = [] + self.introspection = False + self.auth_type = None + self.token_exp = None + self.secret_len = None diff --git a/src/services/vyos-configd b/src/services/vyos-configd index 3674d9627..cb23642dc 100755 --- a/src/services/vyos-configd +++ b/src/services/vyos-configd @@ -14,6 +14,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +# pylint: disable=redefined-outer-name + import os import sys import grp @@ -22,9 +24,12 @@ import json import typing import logging import signal +import traceback import importlib.util +import io +from contextlib import redirect_stdout + import zmq -from contextlib import contextmanager from vyos.defaults import directories from vyos.utils.boot import boot_configuration_complete @@ -49,7 +54,8 @@ if debug: else: logger.setLevel(logging.INFO) -SOCKET_PATH = "ipc:///run/vyos-configd.sock" +SOCKET_PATH = 'ipc:///run/vyos-configd.sock' +MAX_MSG_SIZE = 65535 # Response error codes R_SUCCESS = 1 @@ -64,9 +70,6 @@ configd_env_unset_file = os.path.join(directories['data'], 'vyos-configd-env-uns # sourced on entering config session configd_env_file = '/etc/default/vyos-configd-env' -session_out = None -session_mode = None - def key_name_from_file_name(f): return os.path.splitext(f)[0] @@ -76,17 +79,19 @@ def module_name_from_key(k): def path_from_file_name(f): return os.path.join(vyos_conf_scripts_dir, f) + # opt-in to be run by daemon with open(configd_include_file) as f: try: include = json.load(f) except OSError as e: - logger.critical(f"configd include file error: {e}") + logger.critical(f'configd include file error: {e}') sys.exit(1) except json.JSONDecodeError as e: - logger.critical(f"JSON load error: {e}") + logger.critical(f'JSON load error: {e}') sys.exit(1) + # import conf_mode scripts (_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir))) filenames.sort() @@ -110,31 +115,17 @@ conf_mode_scripts = dict(zip(imports, modules)) exclude_set = {key_name_from_file_name(f) for f in filenames if f not in include} include_set = {key_name_from_file_name(f) for f in filenames if f in include} -@contextmanager -def stdout_redirected(filename, mode): - saved_stdout_fd = None - destination_file = None - try: - sys.stdout.flush() - saved_stdout_fd = os.dup(sys.stdout.fileno()) - destination_file = open(filename, mode) - os.dup2(destination_file.fileno(), sys.stdout.fileno()) - yield - finally: - if saved_stdout_fd is not None: - os.dup2(saved_stdout_fd, sys.stdout.fileno()) - os.close(saved_stdout_fd) - if destination_file is not None: - destination_file.close() - -def explicit_print(path, mode, msg): - try: - with open(path, mode) as f: - f.write(f"\n{msg}\n\n") - except OSError: - logger.critical("error explicit_print") -def run_script(script_name, config, args) -> int: +def write_stdout_log(file_name, msg): + if boot_configuration_complete(): + return + with open(file_name, 'a') as f: + f.write(msg) + + +def run_script(script_name, config, args) -> tuple[int, str]: + # pylint: disable=broad-exception-caught + script = conf_mode_scripts[script_name] script.argv = args config.set_level([]) @@ -145,64 +136,54 @@ def run_script(script_name, config, args) -> int: script.apply(c) except ConfigError as e: logger.error(e) - explicit_print(session_out, session_mode, str(e)) - return R_ERROR_COMMIT - except Exception as e: - logger.critical(e) - return R_ERROR_DAEMON + return R_ERROR_COMMIT, str(e) + except Exception: + tb = traceback.format_exc() + logger.error(tb) + return R_ERROR_COMMIT, tb + + return R_SUCCESS, '' - return R_SUCCESS def initialization(socket): - global session_out - global session_mode + # pylint: disable=broad-exception-caught,too-many-locals + # Reset config strings: active_string = '' session_string = '' # check first for resent init msg, in case of client timeout while True: - msg = socket.recv().decode("utf-8", "ignore") + msg = socket.recv().decode('utf-8', 'ignore') try: message = json.loads(msg) - if message["type"] == "init": - resp = "init" + if message['type'] == 'init': + resp = 'init' socket.send(resp.encode()) - except: + except Exception: break # zmq synchronous for ipc from single client: active_string = msg - resp = "active" + resp = 'active' socket.send(resp.encode()) - session_string = socket.recv().decode("utf-8", "ignore") - resp = "session" + session_string = socket.recv().decode('utf-8', 'ignore') + resp = 'session' socket.send(resp.encode()) - pid_string = socket.recv().decode("utf-8", "ignore") - resp = "pid" + pid_string = socket.recv().decode('utf-8', 'ignore') + resp = 'pid' socket.send(resp.encode()) - sudo_user_string = socket.recv().decode("utf-8", "ignore") - resp = "sudo_user" + sudo_user_string = socket.recv().decode('utf-8', 'ignore') + resp = 'sudo_user' socket.send(resp.encode()) - temp_config_dir_string = socket.recv().decode("utf-8", "ignore") - resp = "temp_config_dir" + temp_config_dir_string = socket.recv().decode('utf-8', 'ignore') + resp = 'temp_config_dir' socket.send(resp.encode()) - changes_only_dir_string = socket.recv().decode("utf-8", "ignore") - resp = "changes_only_dir" + changes_only_dir_string = socket.recv().decode('utf-8', 'ignore') + resp = 'changes_only_dir' socket.send(resp.encode()) - logger.debug(f"config session pid is {pid_string}") - logger.debug(f"config session sudo_user is {sudo_user_string}") - - try: - session_out = os.readlink(f"/proc/{pid_string}/fd/1") - session_mode = 'w' - except FileNotFoundError: - session_out = None - - # if not a 'live' session, for example on boot, write to file - if not session_out or not boot_configuration_complete(): - session_out = script_stdout_log - session_mode = 'a' + logger.debug(f'config session pid is {pid_string}') + logger.debug(f'config session sudo_user is {sudo_user_string}') os.environ['SUDO_USER'] = sudo_user_string if temp_config_dir_string: @@ -229,10 +210,12 @@ def initialization(socket): return config -def process_node_data(config, data, last: bool = False) -> int: + +def process_node_data(config, data, _last: bool = False) -> tuple[int, str]: if not config: - logger.critical(f"Empty config") - return R_ERROR_DAEMON + out = 'Empty config' + logger.critical(out) + return R_ERROR_DAEMON, out script_name = None os.environ['VYOS_TAGNODE_VALUE'] = '' @@ -246,8 +229,9 @@ def process_node_data(config, data, last: bool = False) -> int: if res.group(2): script_name = res.group(2) if not script_name: - logger.critical(f"Missing script_name") - return R_ERROR_DAEMON + out = 'Missing script_name' + logger.critical(out) + return R_ERROR_DAEMON, out if res.group(3): args = res.group(3).split() args.insert(0, f'{script_name}.py') @@ -259,26 +243,55 @@ def process_node_data(config, data, last: bool = False) -> int: scripts_called.append(script_record) if script_name not in include_set: - return R_PASS + return R_PASS, '' + + with redirect_stdout(io.StringIO()) as o: + result, err_out = run_script(script_name, config, args) + amb_out = o.getvalue() + o.close() + + out = amb_out + err_out + + return result, out + - with stdout_redirected(session_out, session_mode): - result = run_script(script_name, config, args) +def send_result(sock, err, msg): + msg_size = min(MAX_MSG_SIZE, len(msg)) if msg else 0 + + err_rep = err.to_bytes(1, byteorder=sys.byteorder) + logger.debug(f'Sending reply: {err}') + sock.send(err_rep) + + # size req from vyshim client + size_req = sock.recv().decode() + logger.debug(f'Received request: {size_req}') + msg_size_rep = hex(msg_size).encode() + sock.send(msg_size_rep) + logger.debug(f'Sending reply: {msg_size}') + + if msg_size > 0: + # send req is sent from vyshim client only if msg_size > 0 + send_req = sock.recv().decode() + logger.debug(f'Received request: {send_req}') + sock.send(msg.encode()) + logger.debug('Sending reply with output') + + write_stdout_log(script_stdout_log, msg) - return result def remove_if_file(f: str): try: os.remove(f) except FileNotFoundError: pass - except OSError: - raise + def shutdown(): remove_if_file(configd_env_file) os.symlink(configd_env_unset_file, configd_env_file) sys.exit(0) + if __name__ == '__main__': context = zmq.Context() socket = context.socket(zmq.REP) @@ -294,6 +307,7 @@ if __name__ == '__main__': os.environ['VYOS_CONFIGD'] = 't' def sig_handler(signum, frame): + # pylint: disable=unused-argument shutdown() signal.signal(signal.SIGTERM, sig_handler) @@ -308,20 +322,19 @@ if __name__ == '__main__': while True: # Wait for next request from client msg = socket.recv().decode() - logger.debug(f"Received message: {msg}") + logger.debug(f'Received message: {msg}') message = json.loads(msg) - if message["type"] == "init": - resp = "init" + if message['type'] == 'init': + resp = 'init' socket.send(resp.encode()) config = initialization(socket) - elif message["type"] == "node": - res = process_node_data(config, message["data"], message["last"]) - response = res.to_bytes(1, byteorder=sys.byteorder) - logger.debug(f"Sending response {res}") - socket.send(response) - if message["last"] and config: + elif message['type'] == 'node': + res, out = process_node_data(config, message['data'], message['last']) + send_result(socket, res, out) + + if message['last'] and config: scripts_called = getattr(config, 'scripts_called', []) logger.debug(f'scripts_called: {scripts_called}') else: - logger.critical(f"Unexpected message: {message}") + logger.critical(f'Unexpected message: {message}') diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 97633577d..558561182 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -17,915 +17,51 @@ import os import sys import grp -import copy import json import logging import signal -import traceback -import threading -from enum import Enum - from time import sleep -from typing import List, Union, Callable, Dict, Self -from fastapi import FastAPI, Depends, Request, Response, HTTPException -from fastapi import BackgroundTasks -from fastapi.responses import HTMLResponse +from fastapi import FastAPI from fastapi.exceptions import RequestValidationError -from fastapi.routing import APIRoute -from pydantic import BaseModel, StrictStr, validator, model_validator -from starlette.middleware.cors import CORSMiddleware -from starlette.datastructures import FormData -from starlette.formparsers import FormParser, MultiPartParser -from multipart.multipart import parse_options_header from uvicorn import Config as UvicornConfig from uvicorn import Server as UvicornServer -from ariadne.asgi import GraphQL - -from vyos.config import Config -from vyos.configtree import ConfigTree -from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSession -from vyos.configsession import ConfigSessionError from vyos.defaults import api_config_state -import api.graphql.state +from api.session import SessionState +from api.rest.models import error CFG_GROUP = 'vyattacfg' debug = True -logger = logging.getLogger(__name__) +LOG = logging.getLogger('http_api') logs_handler = logging.StreamHandler() -logger.addHandler(logs_handler) +LOG.addHandler(logs_handler) if debug: - logger.setLevel(logging.DEBUG) + LOG.setLevel(logging.DEBUG) else: - logger.setLevel(logging.INFO) + LOG.setLevel(logging.INFO) -# Giant lock! -lock = threading.Lock() def load_server_config(): with open(api_config_state) as f: config = json.load(f) return config -def check_auth(key_list, key): - key_id = None - for k in key_list: - if k['key'] == key: - key_id = k['id'] - return key_id - -def error(code, msg): - resp = {"success": False, "error": msg, "data": None} - resp = json.dumps(resp) - return HTMLResponse(resp, status_code=code) - -def success(data): - resp = {"success": True, "data": data, "error": None} - resp = json.dumps(resp) - return HTMLResponse(resp) - -# Pydantic models for validation -# Pydantic will cast when possible, so use StrictStr -# validators added as needed for additional constraints -# schema_extra adds anotations to OpenAPI, to add examples - -class ApiModel(BaseModel): - key: StrictStr - -class BasePathModel(BaseModel): - op: StrictStr - path: List[StrictStr] - - @validator("path") - def check_non_empty(cls, path): - if not len(path) > 0: - raise ValueError('path must be non-empty') - return path - -class BaseConfigureModel(BasePathModel): - value: StrictStr = None - -class ConfigureModel(ApiModel, BaseConfigureModel): - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "set | delete | comment", - "path": ['config', 'mode', 'path'], - } - } - -class ConfigureListModel(ApiModel): - commands: List[BaseConfigureModel] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "commands": "list of commands", - } - } - -class BaseConfigSectionModel(BasePathModel): - section: Dict - -class ConfigSectionModel(ApiModel, BaseConfigSectionModel): - pass - -class ConfigSectionListModel(ApiModel): - commands: List[BaseConfigSectionModel] - -class BaseConfigSectionTreeModel(BaseModel): - op: StrictStr - mask: Dict - config: Dict - -class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): - pass - -class RetrieveModel(ApiModel): - op: StrictStr - path: List[StrictStr] - configFormat: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "returnValue | returnValues | exists | showConfig", - "path": ['config', 'mode', 'path'], - "configFormat": "json (default) | json_ast | raw", - - } - } - -class ConfigFileModel(ApiModel): - op: StrictStr - file: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "save | load", - "file": "filename", - } - } - - -class ImageOp(str, Enum): - add = "add" - delete = "delete" - show = "show" - set_default = "set_default" - - -class ImageModel(ApiModel): - op: ImageOp - url: StrictStr = None - name: StrictStr = None - - @model_validator(mode='after') - def check_data(self) -> Self: - if self.op == 'add': - if not self.url: - raise ValueError("Missing required field \"url\"") - elif self.op in ['delete', 'set_default']: - if not self.name: - raise ValueError("Missing required field \"name\"") - - return self - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show | set_default", - "url": "imagelocation", - "name": "imagename", - } - } - -class ImportPkiModel(ApiModel): - op: StrictStr - path: List[StrictStr] - passphrase: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "import_pki", - "path": ["op", "mode", "path"], - "passphrase": "passphrase", - } - } - - -class ContainerImageModel(ApiModel): - op: StrictStr - name: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show", - "name": "imagename", - } - } - -class GenerateModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "generate", - "path": ["op", "mode", "path"], - } - } - -class ShowModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "show", - "path": ["op", "mode", "path"], - } - } - -class RebootModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "reboot", - "path": ["op", "mode", "path"], - } - } - -class ResetModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "reset", - "path": ["op", "mode", "path"], - } - } - -class PoweroffModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "poweroff", - "path": ["op", "mode", "path"], - } - } - - -class Success(BaseModel): - success: bool - data: Union[str, bool, Dict] - error: str - -class Error(BaseModel): - success: bool = False - data: Union[str, bool, Dict] - error: str - -responses = { - 200: {'model': Success}, - 400: {'model': Error}, - 422: {'model': Error, 'description': 'Validation Error'}, - 500: {'model': Error} -} - -def auth_required(data: ApiModel): - key = data.key - api_keys = app.state.vyos_keys - key_id = check_auth(api_keys, key) - if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") - app.state.vyos_id = key_id - -# override Request and APIRoute classes in order to convert form request to json; -# do all explicit validation here, for backwards compatability of error messages; -# the explicit validation may be dropped, if desired, in favor of native -# validation by FastAPI/Pydantic, as is used for application/json requests -class MultipartRequest(Request): - _form_err = () - @property - def form_err(self): - return self._form_err - - @form_err.setter - def form_err(self, val): - if not self._form_err: - self._form_err = val - - @property - def orig_headers(self): - self._orig_headers = super().headers - return self._orig_headers - - @property - def headers(self): - self._headers = super().headers.mutablecopy() - self._headers['content-type'] = 'application/json' - return self._headers - - async def form(self) -> FormData: - if self._form is None: - assert ( - parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.orig_headers, self.stream()) - self._form = await multipart_parser.parse() - elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.orig_headers, self.stream()) - self._form = await form_parser.parse() - else: - self._form = FormData() - return self._form - - async def body(self) -> bytes: - if not hasattr(self, "_body"): - forms = {} - merge = {} - body = await super().body() - self._body = body - - form_data = await self.form() - if form_data: - endpoint = self.url.path - logger.debug("processing form data") - for k, v in form_data.multi_items(): - forms[k] = v - - if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") - return self._body - else: - try: - tmp = json.loads(forms['data']) - except json.JSONDecodeError as e: - self.form_err = (400, f'Failed to parse JSON: {e}') - return self._body - if isinstance(tmp, list): - merge['commands'] = tmp - else: - merge = tmp - - if 'commands' in merge: - cmds = merge['commands'] - else: - cmds = copy.deepcopy(merge) - cmds = [cmds] - - for c in cmds: - if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") - return self._body - if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): - if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") - elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") - elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") - if endpoint in ('/configure'): - if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") - if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") - if endpoint in ('/configure-section'): - if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") - - if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") - if 'key' in forms and 'key' not in merge: - merge['key'] = forms['key'] - - new_body = json.dumps(merge) - new_body = new_body.encode() - self._body = new_body - - return self._body - -class MultipartRoute(APIRoute): - def get_route_handler(self) -> Callable: - original_route_handler = super().get_route_handler() - - async def custom_route_handler(request: Request) -> Response: - request = MultipartRequest(request.scope, request.receive) - try: - response: Response = await original_route_handler(request) - except HTTPException as e: - return error(e.status_code, e.detail) - except Exception as e: - form_err = request.form_err - if form_err: - return error(*form_err) - raise e - - return response - - return custom_route_handler app = FastAPI(debug=True, title="VyOS API", - version="0.1.0", - responses={**responses}, - dependencies=[Depends(auth_required)]) - -app.router.route_class = MultipartRoute + version="0.1.0") @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request, exc): +async def validation_exception_handler(_request, exc): return error(400, str(exc.errors()[0])) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" - -def call_commit(s: ConfigSession): - try: - s.commit() - except ConfigSessionError as e: - s.discard() - if app.state.vyos_debug: - logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}") - else: - logger.warning(f"ConfigSessionError: {e}") - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): - session = app.state.vyos_session - env = session.get_session_env() - - endpoint = request.url.path - - # Allow users to pass just one command - if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)): - data = [data] - else: - data = data.commands - - # We don't want multiple people/apps to be able to commit at once, - # or modify the shared session while someone else is doing the same, - # so the lock is really global - lock.acquire() - - config = Config(session_env=env) - - status = 200 - msg = None - error_msg = None - try: - for c in data: - op = c.op - if not isinstance(c, BaseConfigSectionTreeModel): - path = c.path - - if isinstance(c, BaseConfigureModel): - if c.value: - value = c.value - else: - value = "" - # For vyos.configsession calls that have no separate value arguments, - # and for type checking too - cfg_path = " ".join(path + [value]).strip() - - elif isinstance(c, BaseConfigSectionModel): - section = c.section - - elif isinstance(c, BaseConfigSectionTreeModel): - mask = c.mask - config = c.config - - if isinstance(c, BaseConfigureModel): - if op == 'set': - session.set(path, value=value) - elif op == 'delete': - if app.state.vyos_strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") - session.delete(path, value=value) - elif op == 'comment': - session.comment(path, value=value) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - - elif isinstance(c, BaseConfigSectionModel): - if op == 'set': - session.set_section(path, section) - elif op == 'load': - session.load_section(path, section) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - - elif isinstance(c, BaseConfigSectionTreeModel): - if op == 'set': - session.set_section_tree(config) - elif op == 'load': - session.load_section_tree(mask, config) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - # end for - config = Config(session_env=env) - d = get_config_diff(config) - - if d.is_node_changed(['service', 'https']): - background_tasks.add_task(call_commit, session) - msg = self_ref_msg - else: - session.commit() - - logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'") - except ConfigSessionError as e: - session.discard() - status = 400 - if app.state.vyos_debug: - logger.critical(f"ConfigSessionError:\n {traceback.format_exc()}") - error_msg = str(e) - except Exception as e: - session.discard() - logger.critical(traceback.format_exc()) - status = 500 - - # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." - finally: - lock.release() - - if status != 200: - return error(status, error_msg) - - return success(msg) - -def create_path_import_pki_no_prompt(path): - correct_paths = ['ca', 'certificate', 'key-pair'] - if path[1] not in correct_paths: - return False - path[1] = '--' + path[1].replace('-', '') - path[3] = '--key-filename' - return path[1:] - -@app.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): - return _configure_op(data, request, background_tasks) - -@app.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): - return _configure_op(data, request, background_tasks) - -@app.post("/retrieve") -async def retrieve_op(data: RetrieveModel): - session = app.state.vyos_session - env = session.get_session_env() - config = Config(session_env=env) - - op = data.op - path = " ".join(data.path) - - try: - if op == 'returnValue': - res = config.return_value(path) - elif op == 'returnValues': - res = config.return_values(path) - elif op == 'exists': - res = config.exists(path) - elif op == 'showConfig': - config_format = 'json' - if data.configFormat: - config_format = data.configFormat - - res = session.show_config(path=data.path) - if config_format == 'json': - config_tree = ConfigTree(res) - res = json.loads(config_tree.to_json()) - elif config_format == 'json_ast': - config_tree = ConfigTree(res) - res = json.loads(config_tree.to_json_ast()) - elif config_format == 'raw': - pass - else: - return error(400, f"'{config_format}' is not a valid config format") - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/config-file') -def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): - session = app.state.vyos_session - env = session.get_session_env() - op = data.op - msg = None - - try: - if op == 'save': - if data.file: - path = data.file - else: - path = '/config/config.boot' - msg = session.save_config(path) - elif op == 'load': - if data.file: - path = data.file - else: - return error(400, "Missing required field \"file\"") - - session.migrate_and_load_config(path) - - config = Config(session_env=env) - d = get_config_diff(config) - - if d.is_node_changed(['service', 'https']): - background_tasks.add_task(call_commit, session) - msg = self_ref_msg - else: - session.commit() - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(msg) - -@app.post('/image') -def image_op(data: ImageModel): - session = app.state.vyos_session - - op = data.op - - try: - if op == 'add': - res = session.install_image(data.url) - elif op == 'delete': - res = session.remove_image(data.name) - elif op == 'show': - res = session.show(["system", "image"]) - elif op == 'set_default': - res = session.set_default_image(data.name) - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/container-image') -def container_image_op(data: ContainerImageModel): - session = app.state.vyos_session - - op = data.op - - try: - if op == 'add': - if data.name: - name = data.name - else: - return error(400, "Missing required field \"name\"") - res = session.add_container_image(name) - elif op == 'delete': - if data.name: - name = data.name - else: - return error(400, "Missing required field \"name\"") - res = session.delete_container_image(name) - elif op == 'show': - res = session.show_container_image() - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/generate') -def generate_op(data: GenerateModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'generate': - res = session.generate(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/show') -def show_op(data: ShowModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'show': - res = session.show(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/reboot') -def reboot_op(data: RebootModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'reboot': - res = session.reboot(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/reset') -def reset_op(data: ResetModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'reset': - res = session.reset(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/import-pki') -def import_pki(data: ImportPkiModel): - session = app.state.vyos_session - - op = data.op - path = data.path - lock.acquire() - - try: - if op == 'import-pki': - # need to get rid or interactive mode for private key - if len(path) == 5 and path[3] in ['key-file', 'private-key']: - path_no_prompt = create_path_import_pki_no_prompt(path) - if not path_no_prompt: - return error(400, f"Invalid command: {' '.join(path)}") - if data.passphrase: - path_no_prompt += ['--passphrase', data.passphrase] - res = session.import_pki_no_prompt(path_no_prompt) - else: - res = session.import_pki(path) - if not res[0].isdigit(): - return error(400, res) - # commit changes - session.commit() - res = res.split('. ')[0] - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - finally: - lock.release() - - return success(res) - -@app.post('/poweroff') -def poweroff_op(data: PoweroffModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'poweroff': - res = session.poweroff(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - - -### -# GraphQL integration -### - -def graphql_init(app: FastAPI = app): - from api.graphql.libs.token_auth import get_user_context - api.graphql.state.init() - api.graphql.state.settings['app'] = app - - # import after initializaion of state - from api.graphql.bindings import generate_schema - schema = generate_schema() - - in_spec = app.state.vyos_introspection - - if app.state.vyos_origins: - origins = app.state.vyos_origins - app.add_route('/graphql', CORSMiddleware(GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec), - allow_origins=origins, - allow_methods=("GET", "POST", "OPTIONS"), - allow_headers=("Authorization",))) - else: - app.add_route('/graphql', GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec)) ### # Modify uvicorn to allow reloading server within the configsession ### @@ -933,30 +69,41 @@ def graphql_init(app: FastAPI = app): server = None shutdown = False + class ApiServerConfig(UvicornConfig): pass + class ApiServer(UvicornServer): def install_signal_handlers(self): pass + def reload_handler(signum, frame): + # pylint: disable=global-statement + global server - logger.debug('Reload signal received...') + LOG.debug('Reload signal received...') if server is not None: server.handle_exit(signum, frame) server = None - logger.info('Server stopping for reload...') + LOG.info('Server stopping for reload...') else: - logger.warning('Reload called for non-running server...') + LOG.warning('Reload called for non-running server...') + def shutdown_handler(signum, frame): + # pylint: disable=global-statement + global shutdown - logger.debug('Shutdown signal received...') + LOG.debug('Shutdown signal received...') server.handle_exit(signum, frame) - logger.info('Server shutdown...') + LOG.info('Server shutdown...') shutdown = True +# end modify uvicorn + + def flatten_keys(d: dict) -> list[dict]: keys_list = [] for el in list(d['keys'].get('id', {})): @@ -965,49 +112,87 @@ def flatten_keys(d: dict) -> list[dict]: keys_list.append({'id': el, 'key': key}) return keys_list -def initialization(session: ConfigSession, app: FastAPI = app): + +def regenerate_docs(app: FastAPI) -> None: + docs = ('/openapi.json', '/docs', '/docs/oauth2-redirect', '/redoc') + remove = [] + for r in app.routes: + if r.path in docs: + remove.append(r) + for r in remove: + app.routes.remove(r) + + app.openapi_schema = None + app.setup() + + +def initialization(session: SessionState, app: FastAPI = app): + # pylint: disable=global-statement,broad-exception-caught,import-outside-toplevel + global server try: server_config = load_server_config() except Exception as e: - logger.critical(f'Failed to load the HTTP API server config: {e}') + LOG.critical(f'Failed to load the HTTP API server config: {e}') sys.exit(1) - app.state.vyos_session = session - app.state.vyos_keys = [] - if 'keys' in server_config: - app.state.vyos_keys = flatten_keys(server_config) + session.keys = flatten_keys(server_config) + + rest_config = server_config.get('rest', {}) + session.debug = bool('debug' in rest_config) + session.strict = bool('strict' in rest_config) + + graphql_config = server_config.get('graphql', {}) + session.origins = graphql_config.get('cors', {}).get('allow_origin', []) + + if 'rest' in server_config: + session.rest = True + else: + session.rest = False - app.state.vyos_debug = bool('debug' in server_config) - app.state.vyos_strict = bool('strict' in server_config) - app.state.vyos_origins = server_config.get('cors', {}).get('allow_origin', []) if 'graphql' in server_config: - app.state.vyos_graphql = True + session.graphql = True if isinstance(server_config['graphql'], dict): if 'introspection' in server_config['graphql']: - app.state.vyos_introspection = True + session.introspection = True else: - app.state.vyos_introspection = False + session.introspection = False # default values if not set explicitly - app.state.vyos_auth_type = server_config['graphql']['authentication']['type'] - app.state.vyos_token_exp = server_config['graphql']['authentication']['expiration'] - app.state.vyos_secret_len = server_config['graphql']['authentication']['secret_length'] + session.auth_type = server_config['graphql']['authentication']['type'] + session.token_exp = server_config['graphql']['authentication']['expiration'] + session.secret_len = server_config['graphql']['authentication']['secret_length'] + else: + session.graphql = False + + # pass session state + app.state = session + + # add REST routes + if session.rest: + from api.rest.routers import rest_init + rest_init(app) else: - app.state.vyos_graphql = False + from api.rest.routers import rest_clear + rest_clear(app) - if app.state.vyos_graphql: + # add GraphQL route + if session.graphql: + from api.graphql.routers import graphql_init graphql_init(app) + else: + from api.graphql.routers import graphql_clear + graphql_clear(app) + + regenerate_docs(app) + + LOG.debug('Active routes are:') + for r in app.routes: + LOG.debug(f'{r.path}') config = ApiServerConfig(app, uds="/run/api.sock", proxy_headers=True) server = ApiServer(config) -def run_server(): - try: - server.run() - except OSError as e: - logger.critical(e) - sys.exit(1) if __name__ == '__main__': # systemd's user and group options don't work, do it by hand here, @@ -1022,13 +207,14 @@ if __name__ == '__main__': signal.signal(signal.SIGHUP, reload_handler) signal.signal(signal.SIGTERM, shutdown_handler) - config_session = ConfigSession(os.getpid()) + session_state = SessionState() + session_state.session = ConfigSession(os.getpid()) while True: - logger.debug('Enter main loop...') + LOG.debug('Enter main loop...') if shutdown: break if server is None: - initialization(config_session) + initialization(session_state) server.run() sleep(1) diff --git a/src/shim/vyshim.c b/src/shim/vyshim.c index a78f62a7b..68e6c4015 100644 --- a/src/shim/vyshim.c +++ b/src/shim/vyshim.c @@ -67,6 +67,8 @@ void timer_handler(int); double get_posix_clock_time(void); +static char * s_recv_string (void *, int); + int main(int argc, char* argv[]) { // string for node data: conf_mode script and tagnode, if applicable @@ -119,31 +121,44 @@ int main(int argc, char* argv[]) zmq_recv(requester, error_code, 1, 0); debug_print("Received node data receipt\n"); - int err = (int)error_code[0]; + char msg_size_str[7]; + zmq_send(requester, "msg_size", 8, 0); + zmq_recv(requester, msg_size_str, 6, 0); + msg_size_str[6] = '\0'; + int msg_size = (int)strtol(msg_size_str, NULL, 16); + debug_print("msg_size: %d\n", msg_size); + + if (msg_size > 0) { + zmq_send(requester, "send", 4, 0); + char *msg = s_recv_string(requester, msg_size); + printf("%s", msg); + free(msg); + } free(string_node_data_msg); - zmq_close(requester); - zmq_ctx_destroy(context); + int err = (int)error_code[0]; + int ret = 0; if (err & PASS) { debug_print("Received PASS\n"); - int ret = pass_through(argv, ex_index); - return ret; + ret = pass_through(argv, ex_index); } if (err & ERROR_DAEMON) { debug_print("Received ERROR_DAEMON\n"); - int ret = pass_through(argv, ex_index); - return ret; + ret = pass_through(argv, ex_index); } if (err & ERROR_COMMIT) { debug_print("Received ERROR_COMMIT\n"); - return -1; + ret = -1; } - return 0; + zmq_close(requester); + zmq_ctx_destroy(context); + + return ret; } int initialization(void* Requester) @@ -342,3 +357,15 @@ double get_posix_clock_time(void) double get_posix_clock_time(void) {return (double)0;} #endif + +// Receive string from socket and convert into C string +static char * s_recv_string (void *socket, int bufsize) { + char * buffer = (char *)malloc(bufsize+1); + int size = zmq_recv(socket, buffer, bufsize, 0); + if (size == -1) + return NULL; + if (size > bufsize) + size = bufsize; + buffer[size] = '\0'; + return buffer; +} diff --git a/src/systemd/vyos-domain-resolver.service b/src/systemd/vyos-domain-resolver.service index c56b51f0c..e63ae5e34 100644 --- a/src/systemd/vyos-domain-resolver.service +++ b/src/systemd/vyos-domain-resolver.service @@ -1,6 +1,7 @@ [Unit] Description=VyOS firewall domain resolver After=vyos-router.service +ConditionPathExistsGlob=/run/use-vyos-domain-resolver* [Service] Type=simple diff --git a/src/utils/vyos-commands-to-config b/src/utils/vyos-commands-to-config new file mode 100755 index 000000000..927d9bd70 --- /dev/null +++ b/src/utils/vyos-commands-to-config @@ -0,0 +1,53 @@ +#! /usr/bin/python3 +# +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +import sys +import json + +from vyos.configtree import ConfigTree +from vyos.utils.config import parse_commands +from vyos.utils.config import set_tags + +def commands_to_config(cmds): + ct = ConfigTree('') + cmds = parse_commands(cmds) + + for c in cmds: + if c["op"] == "set": + if c["is_leaf"]: + replace = False if c["is_multi"] else True + ct.set(c["path"], value=c["value"], replace=replace) + set_tags(ct, c["path"]) + else: + ct.create_node(c["path"]) + set_tags(ct, c["path"]) + else: + raise ValueError( + f"\"{c['op']}\" is not a supported config operation") + + return ct + + +if __name__ == '__main__': + try: + cmds = sys.stdin.read() + ct = commands_to_config(cmds) + out = ConfigTree(ct.to_string()) + print(str(out)) + except Exception as e: + print(e) + sys.exit(1) diff --git a/src/utils/vyos-show-config b/src/utils/vyos-show-config new file mode 100755 index 000000000..152322fc1 --- /dev/null +++ b/src/utils/vyos-show-config @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (C) 2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +import os +import sys +import argparse + +from signal import signal, SIGPIPE, SIG_DFL + +def get_config(path): + from vyos.utils.process import rc_cmd + res, out = rc_cmd(f"cli-shell-api showCfg {path}") + if res > 0: + print("Error: failed to retrieve the config", file=sys.stderr) + sys.exit(1) + else: + return out + +def strip_config(config): + from vyos.utils.strip_config import strip_config_source + return strip_config_source(config) + +if __name__ == '__main__': + signal(SIGPIPE,SIG_DFL) + + parser = argparse.ArgumentParser() + parser.add_argument("--strip-private", + help="Strip private information from the config", + action="store_true") + + args, path_args = parser.parse_known_args() + + config = get_config(" ".join(path_args)) + + if args.strip_private: + edit_level = os.getenv("VYATTA_EDIT_LEVEL") + if (edit_level != "/") or (len(path_args) > 0): + print("Error: show --strip-private only works at the top level", + file=sys.stderr) + sys.exit(1) + else: + print(strip_config(config)) + else: + print(config) diff --git a/src/validators/interface-address b/src/validators/interface-address index 4c203956b..2a2583fc3 100755 --- a/src/validators/interface-address +++ b/src/validators/interface-address @@ -1,3 +1,3 @@ #!/bin/sh -ipaddrcheck --is-ipv4-host $1 || ipaddrcheck --is-ipv6-host $1 +ipaddrcheck --is-any-host "$1" diff --git a/src/validators/ip-address b/src/validators/ip-address index 11d6df09e..351f728a6 100755 --- a/src/validators/ip-address +++ b/src/validators/ip-address @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-any-single $1 +ipaddrcheck --is-any-single "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IP address" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ip-cidr b/src/validators/ip-cidr index 60d2ac295..8a01e7ad9 100755 --- a/src/validators/ip-cidr +++ b/src/validators/ip-cidr @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-any-cidr $1 +ipaddrcheck --is-any-cidr "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IP CIDR" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ip-host b/src/validators/ip-host index 77c578fa2..7c5ad2612 100755 --- a/src/validators/ip-host +++ b/src/validators/ip-host @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-any-host $1 +ipaddrcheck --is-any-host "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IP host" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ip-prefix b/src/validators/ip-prefix index e5a64fea8..25204ace5 100755 --- a/src/validators/ip-prefix +++ b/src/validators/ip-prefix @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-any-net $1 +ipaddrcheck --is-any-net "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IP prefix" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv4 b/src/validators/ipv4 index 8676d5800..11f854cf1 100755 --- a/src/validators/ipv4 +++ b/src/validators/ipv4 @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv4 $1 +ipaddrcheck --is-ipv4 "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not IPv4" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv4-address b/src/validators/ipv4-address index 058db088b..1cfd961ba 100755 --- a/src/validators/ipv4-address +++ b/src/validators/ipv4-address @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv4-single $1 +ipaddrcheck --is-ipv4-single "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv4 address" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv4-host b/src/validators/ipv4-host index 74b8c36a7..eb8faaa2a 100755 --- a/src/validators/ipv4-host +++ b/src/validators/ipv4-host @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv4-host $1 +ipaddrcheck --is-ipv4-host "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv4 host" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv4-multicast b/src/validators/ipv4-multicast index 3f28c51db..cf871bd59 100755 --- a/src/validators/ipv4-multicast +++ b/src/validators/ipv4-multicast @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv4-multicast $1 && ipaddrcheck --is-ipv4-single $1 +ipaddrcheck --is-ipv4-multicast "$1" && ipaddrcheck --is-ipv4-single "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv4 multicast address" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv4-prefix b/src/validators/ipv4-prefix index 7e1e0e8dd..f8d46c69c 100755 --- a/src/validators/ipv4-prefix +++ b/src/validators/ipv4-prefix @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv4-net $1 +ipaddrcheck --is-ipv4-net "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv4 prefix" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv4-range b/src/validators/ipv4-range index 6492bfc52..7bf271bbb 100755 --- a/src/validators/ipv4-range +++ b/src/validators/ipv4-range @@ -1,40 +1,10 @@ -#!/bin/bash +#!/bin/sh -# snippet from https://stackoverflow.com/questions/10768160/ip-address-converter -ip2dec () { - local a b c d ip=$@ - IFS=. read -r a b c d <<< "$ip" - printf '%d\n' "$((a * 256 ** 3 + b * 256 ** 2 + c * 256 + d))" -} +ipaddrcheck --verbose --is-ipv4-range "$1" -error_exit() { - echo "Error: $1 is not a valid IPv4 address range" - exit 1 -} - -# Only run this if there is a hypen present in $1 -if [[ "$1" =~ "-" ]]; then - # This only works with real bash (<<<) - split IP addresses into array with - # hyphen as delimiter - readarray -d - -t strarr <<< $1 - - ipaddrcheck --is-ipv4-single ${strarr[0]} - if [ $? -gt 0 ]; then - error_exit $1 - fi - - ipaddrcheck --is-ipv4-single ${strarr[1]} - if [ $? -gt 0 ]; then - error_exit $1 - fi - - start=$(ip2dec ${strarr[0]}) - stop=$(ip2dec ${strarr[1]}) - if [ $start -ge $stop ]; then - error_exit $1 - fi - - exit 0 +if [ $? -gt 0 ]; then + echo "Error: $1 is not a valid IPv4 address range" + exit 1 fi -error_exit $1 +exit 0 diff --git a/src/validators/ipv6 b/src/validators/ipv6 index 4ae130eb5..57696add7 100755 --- a/src/validators/ipv6 +++ b/src/validators/ipv6 @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv6 $1 +ipaddrcheck --is-ipv6 "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not IPv6" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv6-address b/src/validators/ipv6-address index 1fca77668..460639090 100755 --- a/src/validators/ipv6-address +++ b/src/validators/ipv6-address @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv6-single $1 +ipaddrcheck --is-ipv6-single "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv6 address" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv6-host b/src/validators/ipv6-host index 7085809a9..1eb4d8e35 100755 --- a/src/validators/ipv6-host +++ b/src/validators/ipv6-host @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv6-host $1 +ipaddrcheck --is-ipv6-host "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv6 host" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv6-multicast b/src/validators/ipv6-multicast index 5aa7d734a..746ff7edf 100755 --- a/src/validators/ipv6-multicast +++ b/src/validators/ipv6-multicast @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv6-multicast $1 && ipaddrcheck --is-ipv6-single $1 +ipaddrcheck --is-ipv6-multicast "$1" && ipaddrcheck --is-ipv6-single "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv6 multicast address" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv6-prefix b/src/validators/ipv6-prefix index 890dda723..1bb9b42fe 100755 --- a/src/validators/ipv6-prefix +++ b/src/validators/ipv6-prefix @@ -1,10 +1,10 @@ #!/bin/sh -ipaddrcheck --is-ipv6-net $1 +ipaddrcheck --is-ipv6-net "$1" if [ $? -gt 0 ]; then echo "Error: $1 is not a valid IPv6 prefix" exit 1 fi -exit 0
\ No newline at end of file +exit 0 diff --git a/src/validators/ipv6-range b/src/validators/ipv6-range index 7080860c4..0d2eb6384 100755 --- a/src/validators/ipv6-range +++ b/src/validators/ipv6-range @@ -1,20 +1,10 @@ -#!/usr/bin/env python3 +#!/bin/sh -from ipaddress import IPv6Address -from sys import argv, exit +ipaddrcheck --verbose --is-ipv6-range "$1" -if __name__ == '__main__': - if len(argv) > 1: - # try to pass validation and raise an error if failed - try: - ipv6_range = argv[1] - range_left = ipv6_range.split('-')[0] - range_right = ipv6_range.split('-')[1] - if not IPv6Address(range_left) < IPv6Address(range_right): - raise ValueError(f'left element {range_left} must be less than right element {range_right}') - except Exception as err: - print(f'Error: {ipv6_range} is not a valid IPv6 range: {err}') - exit(1) - else: - print('Error: an IPv6 range argument must be provided') - exit(1) +if [ $? -gt 0 ]; then + echo "Error: $1 is not a valid IPv6 address range" + exit 1 +fi + +exit 0 |