diff options
Diffstat (limited to 'src')
-rwxr-xr-x | src/completion/list_interfaces.py | 54 | ||||
-rwxr-xr-x | src/conf_mode/container.py | 37 | ||||
-rwxr-xr-x | src/conf_mode/interfaces-tunnel.py | 2 | ||||
-rwxr-xr-x | src/conf_mode/protocols_ospfv3.py | 12 | ||||
-rwxr-xr-x | src/conf_mode/qos.py | 16 | ||||
-rwxr-xr-x | src/conf_mode/service_ipoe-server.py | 7 | ||||
-rw-r--r-- | src/etc/systemd/system/frr.service.d/override.conf | 2 | ||||
-rw-r--r-- | src/etc/systemd/system/keepalived.service.d/override.conf (renamed from src/systemd/keepalived.service) | 7 | ||||
-rwxr-xr-x | src/op_mode/bgp.py | 3 | ||||
-rwxr-xr-x | src/op_mode/conntrack.py | 4 | ||||
-rwxr-xr-x | src/op_mode/dhcp.py | 32 | ||||
-rwxr-xr-x | src/op_mode/nat.py | 10 | ||||
-rwxr-xr-x | src/op_mode/neighbor.py | 8 | ||||
-rwxr-xr-x | src/op_mode/openvpn.py | 7 | ||||
-rwxr-xr-x | src/op_mode/route.py | 6 | ||||
-rwxr-xr-x | src/services/api/graphql/generate/schema_from_op_mode.py | 126 | ||||
-rw-r--r-- | src/services/api/graphql/libs/op_mode.py | 17 |
17 files changed, 250 insertions, 100 deletions
diff --git a/src/completion/list_interfaces.py b/src/completion/list_interfaces.py deleted file mode 100755 index b19b90156..000000000 --- a/src/completion/list_interfaces.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright (C) 2019-2020 VyOS maintainers and contributors -# -# This program is free software; you can redistribute it and/or modify -# it under the terms of the GNU General Public License version 2 or later as -# published by the Free Software Foundation. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. - -import sys -import argparse -from vyos.ifconfig import Section - -def matching(feature): - for section in Section.feature(feature): - for intf in Section.interfaces(section): - yield intf - -parser = argparse.ArgumentParser() -group = parser.add_mutually_exclusive_group() -group.add_argument("-t", "--type", type=str, help="List interfaces of specific type") -group.add_argument("-b", "--broadcast", action="store_true", help="List all broadcast interfaces") -group.add_argument("-br", "--bridgeable", action="store_true", help="List all bridgeable interfaces") -group.add_argument("-bo", "--bondable", action="store_true", help="List all bondable interfaces") - -args = parser.parse_args() - -if args.type: - try: - interfaces = Section.interfaces(args.type) - print(" ".join(interfaces)) - except ValueError as e: - print(e, file=sys.stderr) - print("") - -elif args.broadcast: - print(" ".join(matching("broadcast"))) - -elif args.bridgeable: - print(" ".join(matching("bridgeable"))) - -elif args.bondable: - # we need to filter out VLAN interfaces identified by a dot (.) in their name - print(" ".join([intf for intf in matching("bondable") if '.' not in intf])) - -else: - print(" ".join(Section.interfaces())) diff --git a/src/conf_mode/container.py b/src/conf_mode/container.py index 10e9e9213..68070ea5b 100755 --- a/src/conf_mode/container.py +++ b/src/conf_mode/container.py @@ -18,7 +18,6 @@ import os from ipaddress import ip_address from ipaddress import ip_network -from time import sleep from json import dumps as json_write from vyos.base import Warning @@ -28,6 +27,7 @@ from vyos.configdict import node_changed from vyos.util import call from vyos.util import cmd from vyos.util import run +from vyos.util import rc_cmd from vyos.util import write_file from vyos.template import inc_ip from vyos.template import is_ipv4 @@ -68,6 +68,9 @@ def get_config(config=None): # container base default values can not be merged here - remove and add them later if 'name' in default_values: del default_values['name'] + # registry will be handled below + if 'registry' in default_values: + del default_values['registry'] container = dict_merge(default_values, container) # Merge per-container default values @@ -95,6 +98,15 @@ def get_config(config=None): container['name'][name]['volume'][volume] = dict_merge( default_values_volume, container['name'][name]['volume'][volume]) + # registry is a tagNode with default values - merge the list from + # default_values['registry'] into the tagNode variables + if 'registry' not in container: + container.update({'registry' : {}}) + default_values = defaults(base) + for registry in default_values['registry'].split(): + tmp = {registry : {}} + container['registry'] = dict_merge(tmp, container['registry']) + # Delete container network, delete containers tmp = node_changed(conf, base + ['network']) if tmp: container.update({'network_remove' : tmp}) @@ -226,6 +238,11 @@ def verify(container): if 'network' in container_config and network in container_config['network']: raise ConfigError(f'Can not remove network "{network}", used by container "{container}"!') + if 'registry' in container and 'authentication' in container['registry']: + for registry, registry_config in container['registry']['authentication'].items(): + if not {'username', 'password'} <= set(registry_config): + raise ConfigError('If registry username or or password is defined, so must be the other!') + return None def generate_run_arguments(name, container_config): @@ -355,6 +372,24 @@ def generate(container): write_file(f'/etc/cni/net.d/{network}.conflist', json_write(tmp, indent=2)) + if 'registry' in container: + cmd = f'podman logout --all' + rc, out = rc_cmd(cmd) + if rc != 0: + raise ConfigError(out) + + for registry, registry_config in container['registry'].items(): + if 'disable' in registry_config: + continue + if 'authentication' in registry_config: + if {'username', 'password'} <= set(registry_config['authentication']): + username = registry_config['authentication']['username'] + password = registry_config['authentication']['password'] + cmd = f'podman login --username {username} --password {password} {registry}' + rc, out = rc_cmd(cmd) + if rc != 0: + raise ConfigError(out) + render(config_containers_registry, 'container/registries.conf.j2', container) render(config_containers_storage, 'container/storage.conf.j2', container) diff --git a/src/conf_mode/interfaces-tunnel.py b/src/conf_mode/interfaces-tunnel.py index e2701d9d3..0a3726e94 100755 --- a/src/conf_mode/interfaces-tunnel.py +++ b/src/conf_mode/interfaces-tunnel.py @@ -136,7 +136,7 @@ def verify(tunnel): if our_key != None: if their_address == our_address and their_key == our_key: raise ConfigError(f'Key "{our_key}" for source-address "{our_address}" ' \ - f'is already used for tunnel "{tunnel_if}"!') + f'is already used for tunnel "{o_tunnel}"!') else: our_source_if = dict_search('source_interface', tunnel) their_source_if = dict_search('source_interface', o_tunnel_conf) diff --git a/src/conf_mode/protocols_ospfv3.py b/src/conf_mode/protocols_ospfv3.py index ed0a8fba2..1e2c02d03 100755 --- a/src/conf_mode/protocols_ospfv3.py +++ b/src/conf_mode/protocols_ospfv3.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (C) 2021 VyOS maintainers and contributors +# Copyright (C) 2021-2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as @@ -146,15 +146,25 @@ def generate(ospfv3): if not ospfv3 or 'deleted' in ospfv3: return None + ospfv3['protocol'] = 'ospf6' # required for frr/vrf.route-map.v6.frr.j2 + ospfv3['frr_zebra_config'] = render_to_string('frr/vrf.route-map.v6.frr.j2', ospfv3) ospfv3['new_frr_config'] = render_to_string('frr/ospf6d.frr.j2', ospfv3) return None def apply(ospfv3): ospf6_daemon = 'ospf6d' + zebra_daemon = 'zebra' # Save original configuration prior to starting any commit actions frr_cfg = frr.FRRConfig() + # The route-map used for the FIB (zebra) is part of the zebra daemon + frr_cfg.load_configuration(zebra_daemon) + frr_cfg.modify_section('(\s+)?ipv6 protocol ospf6 route-map [-a-zA-Z0-9.]+', stop_pattern='(\s|!)') + if 'frr_zebra_config' in ospfv3: + frr_cfg.add_before(frr.default_add_before, ospfv3['frr_zebra_config']) + frr_cfg.commit_configuration(zebra_daemon) + # Generate empty helper string which can be ammended to FRR commands, it # will be either empty (default VRF) or contain the "vrf <name" statement vrf = '' diff --git a/src/conf_mode/qos.py b/src/conf_mode/qos.py index dca713283..1be2c283f 100755 --- a/src/conf_mode/qos.py +++ b/src/conf_mode/qos.py @@ -21,7 +21,9 @@ from netifaces import interfaces from vyos.base import Warning from vyos.config import Config +from vyos.configdep import set_dependents, call_dependents from vyos.configdict import dict_merge +from vyos.ifconfig import Section from vyos.qos import CAKE from vyos.qos import DropTail from vyos.qos import FairQueue @@ -83,6 +85,18 @@ def get_config(config=None): get_first_key=True, no_tag_node_value_mangle=True) + if 'interface' in qos: + for ifname, if_conf in qos['interface'].items(): + if_node = Section.get_config_path(ifname) + + if not if_node: + continue + + path = f'interfaces {if_node}' + if conf.exists(f'{path} mirror') or conf.exists(f'{path} redirect'): + type_node = path.split(" ")[1] # return only interface type node + set_dependents(type_node, conf, ifname) + if 'policy' in qos: for policy in qos['policy']: # when calling defaults() we need to use the real CLI node, thus we @@ -245,6 +259,8 @@ def apply(qos): tmp = shaper_type(interface) tmp.update(shaper_config, direction) + call_dependents() + return None if __name__ == '__main__': diff --git a/src/conf_mode/service_ipoe-server.py b/src/conf_mode/service_ipoe-server.py index e9afd6a55..9cdfa08ef 100755 --- a/src/conf_mode/service_ipoe-server.py +++ b/src/conf_mode/service_ipoe-server.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (C) 2018-2022 VyOS maintainers and contributors +# Copyright (C) 2018-2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as @@ -53,8 +53,11 @@ def verify(ipoe): if 'interface' not in ipoe: raise ConfigError('No IPoE interface configured') - for interface in ipoe['interface']: + for interface, iface_config in ipoe['interface'].items(): verify_interface_exists(interface) + if 'client_subnet' in iface_config and 'vlan' in iface_config: + raise ConfigError('Option "client-subnet" incompatible with "vlan"!' + 'Use "ipoe client-ip-pool" instead.') #verify_accel_ppp_base_service(ipoe, local_users=False) diff --git a/src/etc/systemd/system/frr.service.d/override.conf b/src/etc/systemd/system/frr.service.d/override.conf index 69eb1a86a..2e2f67f70 100644 --- a/src/etc/systemd/system/frr.service.d/override.conf +++ b/src/etc/systemd/system/frr.service.d/override.conf @@ -3,6 +3,8 @@ Before= Before=vyos-router.service [Service] +LimitNOFILE=4096 +LimitNOFILESoft=4096 ExecStartPre=/bin/bash -c 'mkdir -p /run/frr/config; \ echo "log syslog" > /run/frr/config/frr.conf; \ echo "log facility local7" >> /run/frr/config/frr.conf; \ diff --git a/src/systemd/keepalived.service b/src/etc/systemd/system/keepalived.service.d/override.conf index a462d8614..d91a824b9 100644 --- a/src/systemd/keepalived.service +++ b/src/etc/systemd/system/keepalived.service.d/override.conf @@ -1,13 +1,14 @@ [Unit] -Description=Keepalive Daemon (LVS and VRRP) After=vyos-router.service -# Only start if there is a configuration file +# Only start if there is our configuration file - remove Debian default +# config file from the condition list +ConditionFileNotEmpty= ConditionFileNotEmpty=/run/keepalived/keepalived.conf [Service] KillMode=process Type=simple # Read configuration variable file if it is present +ExecStart= ExecStart=/usr/sbin/keepalived --use-file /run/keepalived/keepalived.conf --pid /run/keepalived/keepalived.pid --dont-fork --snmp -ExecReload=/bin/kill -HUP $MAINPID PIDFile=/run/keepalived/keepalived.pid diff --git a/src/op_mode/bgp.py b/src/op_mode/bgp.py index 23001a9d7..3f6d45dd7 100755 --- a/src/op_mode/bgp.py +++ b/src/op_mode/bgp.py @@ -30,6 +30,7 @@ from vyos.configquery import ConfigTreeQuery import vyos.opmode +ArgFamily = typing.Literal['inet', 'inet6'] frr_command_template = Template(""" {% if family %} @@ -75,7 +76,7 @@ def _verify(func): @_verify def show_neighbors(raw: bool, - family: str, + family: ArgFamily, peer: typing.Optional[str], vrf: typing.Optional[str]): kwargs = dict(locals()) diff --git a/src/op_mode/conntrack.py b/src/op_mode/conntrack.py index df213cc5a..ea7c4c208 100755 --- a/src/op_mode/conntrack.py +++ b/src/op_mode/conntrack.py @@ -15,6 +15,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. import sys +import typing import xmltodict from tabulate import tabulate @@ -23,6 +24,7 @@ from vyos.util import run import vyos.opmode +ArgFamily = typing.Literal['inet', 'inet6'] def _get_xml_data(family): """ @@ -126,7 +128,7 @@ def get_formatted_output(dict_data): return output -def show(raw: bool, family: str): +def show(raw: bool, family: ArgFamily): family = 'ipv6' if family == 'inet6' else 'ipv4' conntrack_data = _get_raw_data(family) if raw: diff --git a/src/op_mode/dhcp.py b/src/op_mode/dhcp.py index b9e6e7bc9..41da14065 100755 --- a/src/op_mode/dhcp.py +++ b/src/op_mode/dhcp.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright (C) 2022 VyOS maintainers and contributors +# Copyright (C) 2022-2023 VyOS maintainers and contributors # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 or later as @@ -36,6 +36,9 @@ lease_valid_states = ['all', 'active', 'free', 'expired', 'released', 'abandoned sort_valid_inet = ['end', 'mac', 'hostname', 'ip', 'pool', 'remaining', 'start', 'state'] sort_valid_inet6 = ['end', 'iaid_duid', 'ip', 'last_communication', 'pool', 'remaining', 'state', 'type'] +ArgFamily = typing.Literal['inet', 'inet6'] +ArgState = typing.Literal['all', 'active', 'free', 'expired', 'released', 'abandoned', 'reset', 'backup'] + def _utc_to_local(utc_dt): return datetime.fromtimestamp((datetime.fromtimestamp(utc_dt) - datetime(1970, 1, 1)).total_seconds()) @@ -82,7 +85,7 @@ def _get_raw_server_leases(family='inet', pool=None, sorted=None, state=[]) -> l data_lease['ip'] = lease.ip data_lease['state'] = lease.binding_state data_lease['pool'] = lease.sets.get('shared-networkname', '') - data_lease['end'] = lease.end.timestamp() + data_lease['end'] = lease.end.timestamp() if lease.end else None if family == 'inet': data_lease['mac'] = lease.ethernet @@ -95,17 +98,18 @@ def _get_raw_server_leases(family='inet', pool=None, sorted=None, state=[]) -> l lease_types_long = {'na': 'non-temporary', 'ta': 'temporary', 'pd': 'prefix delegation'} data_lease['type'] = lease_types_long[lease.type] - data_lease['remaining'] = lease.end - datetime.utcnow() + data_lease['remaining'] = '-' - if data_lease['remaining'].days >= 0: - # substraction gives us a timedelta object which can't be formatted with strftime - # so we use str(), split gets rid of the microseconds - data_lease['remaining'] = str(data_lease["remaining"]).split('.')[0] - else: - data_lease['remaining'] = '' + if lease.end: + data_lease['remaining'] = lease.end - datetime.utcnow() + + if data_lease['remaining'].days >= 0: + # substraction gives us a timedelta object which can't be formatted with strftime + # so we use str(), split gets rid of the microseconds + data_lease['remaining'] = str(data_lease["remaining"]).split('.')[0] # Do not add old leases - if data_lease['remaining'] != '' and data_lease['pool'] in pool: + if data_lease['remaining'] != '' and data_lease['pool'] in pool and data_lease['state'] != 'free': if not state or data_lease['state'] in state: data.append(data_lease) @@ -137,7 +141,7 @@ def _get_formatted_server_leases(raw_data, family='inet'): start = lease.get('start') start = _utc_to_local(start).strftime('%Y/%m/%d %H:%M:%S') end = lease.get('end') - end = _utc_to_local(end).strftime('%Y/%m/%d %H:%M:%S') + end = _utc_to_local(end).strftime('%Y/%m/%d %H:%M:%S') if end else '-' remain = lease.get('remaining') pool = lease.get('pool') hostname = lease.get('hostname') @@ -248,7 +252,7 @@ def _verify(func): @_verify -def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]): +def show_pool_statistics(raw: bool, family: ArgFamily, pool: typing.Optional[str]): pool_data = _get_raw_pool_statistics(family=family, pool=pool) if raw: return pool_data @@ -257,8 +261,8 @@ def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]): @_verify -def show_server_leases(raw: bool, family: str, pool: typing.Optional[str], - sorted: typing.Optional[str], state: typing.Optional[str]): +def show_server_leases(raw: bool, family: ArgFamily, pool: typing.Optional[str], + sorted: typing.Optional[str], state: typing.Optional[ArgState]): # if dhcp server is down, inactive leases may still be shown as active, so warn the user. if not is_systemd_service_running('isc-dhcp-server.service'): Warning('DHCP server is configured but not started. Data may be stale.') diff --git a/src/op_mode/nat.py b/src/op_mode/nat.py index cf06de0e9..c92795745 100755 --- a/src/op_mode/nat.py +++ b/src/op_mode/nat.py @@ -31,6 +31,8 @@ from vyos.util import dict_search base = 'nat' unconf_message = 'NAT is not configured' +ArgDirection = typing.Literal['source', 'destination'] +ArgFamily = typing.Literal['inet', 'inet6'] def _get_xml_translation(direction, family, address=None): """ @@ -298,7 +300,7 @@ def _verify(func): @_verify -def show_rules(raw: bool, direction: str, family: str): +def show_rules(raw: bool, direction: ArgDirection, family: ArgFamily): nat_rules = _get_raw_data_rules(direction, family) if raw: return nat_rules @@ -307,7 +309,7 @@ def show_rules(raw: bool, direction: str, family: str): @_verify -def show_statistics(raw: bool, direction: str, family: str): +def show_statistics(raw: bool, direction: ArgDirection, family: ArgFamily): nat_statistics = _get_raw_data_rules(direction, family) if raw: return nat_statistics @@ -316,8 +318,8 @@ def show_statistics(raw: bool, direction: str, family: str): @_verify -def show_translations(raw: bool, direction: - str, family: str, +def show_translations(raw: bool, direction: ArgDirection, + family: ArgFamily, address: typing.Optional[str], verbose: typing.Optional[bool]): family = 'ipv6' if family == 'inet6' else 'ipv4' diff --git a/src/op_mode/neighbor.py b/src/op_mode/neighbor.py index 264dbdc72..b329ea280 100755 --- a/src/op_mode/neighbor.py +++ b/src/op_mode/neighbor.py @@ -32,6 +32,9 @@ import typing import vyos.opmode +ArgFamily = typing.Literal['inet', 'inet6'] +ArgState = typing.Literal['reachable', 'stale', 'failed', 'permanent'] + def interface_exists(interface): import os return os.path.exists(f'/sys/class/net/{interface}') @@ -88,7 +91,8 @@ def format_neighbors(neighs, interface=None): headers = ["Address", "Interface", "Link layer address", "State"] return tabulate(neighs, headers) -def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.Optional[str]): +def show(raw: bool, family: ArgFamily, interface: typing.Optional[str], + state: typing.Optional[ArgState]): """ Display neighbor table contents """ data = get_raw_data(family, interface, state=state) @@ -97,7 +101,7 @@ def show(raw: bool, family: str, interface: typing.Optional[str], state: typing. else: return format_neighbors(data, interface) -def reset(family: str, interface: typing.Optional[str], address: typing.Optional[str]): +def reset(family: ArgFamily, interface: typing.Optional[str], address: typing.Optional[str]): from vyos.util import run if address and interface: diff --git a/src/op_mode/openvpn.py b/src/op_mode/openvpn.py index 79130c7c0..8f88ab422 100755 --- a/src/op_mode/openvpn.py +++ b/src/op_mode/openvpn.py @@ -18,6 +18,7 @@ import os import sys +import typing from tabulate import tabulate import vyos.opmode @@ -26,6 +27,8 @@ from vyos.util import commit_in_progress from vyos.util import call from vyos.config import Config +ArgMode = typing.Literal['client', 'server', 'site_to_site'] + def _get_tunnel_address(peer_host, peer_port, status_file): peer = peer_host + ':' + peer_port lst = [] @@ -155,7 +158,7 @@ def _get_raw_data(mode: str) -> dict: d['local_port'] = conf_dict[intf].get('local-port', '') if conf.exists(f'interfaces openvpn {intf} server client'): d['configured_clients'] = conf.list_nodes(f'interfaces openvpn {intf} server client') - if mode in ['client', 'site-to-site']: + if mode in ['client', 'site_to_site']: for client in d['clients']: if 'shared-secret-key-file' in list(conf_dict[intf]): client['name'] = 'None (PSK)' @@ -198,7 +201,7 @@ def _format_openvpn(data: dict) -> str: return out -def show(raw: bool, mode: str) -> str: +def show(raw: bool, mode: ArgMode) -> str: openvpn_data = _get_raw_data(mode) if raw: diff --git a/src/op_mode/route.py b/src/op_mode/route.py index 7f0f9cbac..d6d6b7d6f 100755 --- a/src/op_mode/route.py +++ b/src/op_mode/route.py @@ -54,7 +54,9 @@ frr_command_template = Template(""" {% endif %} """) -def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typing.Optional[str]): +ArgFamily = typing.Literal['inet', 'inet6'] + +def show_summary(raw: bool, family: ArgFamily, table: typing.Optional[int], vrf: typing.Optional[str]): from vyos.util import cmd if family == 'inet': @@ -94,7 +96,7 @@ def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typin return output def show(raw: bool, - family: str, + family: ArgFamily, net: typing.Optional[str], table: typing.Optional[int], protocol: typing.Optional[str], diff --git a/src/services/api/graphql/generate/schema_from_op_mode.py b/src/services/api/graphql/generate/schema_from_op_mode.py index 98b2ad7b7..cb7b0fd21 100755 --- a/src/services/api/graphql/generate/schema_from_op_mode.py +++ b/src/services/api/graphql/generate/schema_from_op_mode.py @@ -26,6 +26,7 @@ from jinja2 import Template from vyos.defaults import directories from vyos.opmode import _is_op_mode_function_name as is_op_mode_function_name +from vyos.opmode import _get_literal_values as get_literal_values from vyos.util import load_as_module if __package__ is None or __package__ == '': sys.path.append(os.path.join(directories['services'], 'api')) @@ -38,8 +39,10 @@ else: OP_MODE_PATH = directories['op_mode'] SCHEMA_PATH = directories['api_schema'] +CLIENT_OP_PATH = directories['api_client_op'] DATA_DIR = directories['data'] + op_mode_include_file = os.path.join(DATA_DIR, 'op-mode-standardized.json') op_mode_error_schema = 'op_mode_error.graphql' @@ -94,6 +97,14 @@ extend type Mutation { } """ +enum_template = """ +enum {{ enum_name }} { + {%- for field_entry in enum_fields %} + {{ field_entry }} + {%- endfor %} +} +""" + error_template = """ interface OpModeError { name: String! @@ -109,12 +120,52 @@ type {{ name }} implements OpModeError { {%- endfor %} """ -def create_schema(func_name: str, base_name: str, func: callable) -> str: +op_query_template = """ +query {{ op_name }} ({{ op_sig }}) { + {{ op_name }} (data: { {{ op_arg }} }) { + success + errors + op_mode_error { + name + message + vyos_code + } + data { + result + } + } +} +""" + +op_mutation_template = """ +mutation {{ op_name }} ({{ op_sig }}) { + {{ op_name }} (data: { {{ op_arg }} }) { + success + errors + op_mode_error { + name + message + vyos_code + } + data { + result + } + } +} +""" + +def create_schema(func_name: str, base_name: str, func: callable, + enums: dict) -> str: sig = signature(func) + for k in sig.parameters: + t = get_literal_values(sig.parameters[k].annotation) + if t: + enums[t] = snake_to_pascal_case(sig.parameters[k].name + '_' + base_name) + field_dict = {} for k in sig.parameters: - field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation) + field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation, enums) # It is assumed that if one is generating a schema for a 'show_*' # function, that 'get_raw_data' is present and 'raw' is desired. @@ -137,6 +188,58 @@ def create_schema(func_name: str, base_name: str, func: callable) -> str: return res +def create_client_op(func_name: str, base_name: str, func: callable, + enums: dict) -> str: + sig = signature(func) + + for k in sig.parameters: + t = get_literal_values(sig.parameters[k].annotation) + if t: + enums[t] = snake_to_pascal_case(sig.parameters[k].name + '_' + base_name) + + field_dict = {} + for k in sig.parameters: + field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation, enums) + + # It is assumed that if one is generating a schema for a 'show_*' + # function, that 'get_raw_data' is present and 'raw' is desired. + if 'raw' in list(field_dict): + del field_dict['raw'] + + op_sig = ['$key: String'] + op_arg = ['key: $key'] + for k,v in field_dict.items(): + op_sig.append('$'+k+': '+v) + op_arg.append(k+': $'+k) + + op_data = {} + op_data['op_name'] = snake_to_pascal_case(func_name + '_' + base_name) + op_data['op_sig'] = ', '.join(op_sig) + op_data['op_arg'] = ', '.join(op_arg) + + if is_show_function_name(func_name): + j2_template = Template(op_query_template) + else: + j2_template = Template(op_mutation_template) + + res = j2_template.render(op_data) + + return res + +def create_enums(enums: dict) -> str: + enum_data = [] + for k, v in enums.items(): + enum = {'enum_name': v, 'enum_fields': list(k)} + enum_data.append(enum) + + out = '' + j2_template = Template(enum_template) + for el in enum_data: + out += j2_template.render(el) + out += '\n' + + return out + def create_error_schema(): from vyos import opmode @@ -157,6 +260,8 @@ def create_error_schema(): return res def generate_op_mode_definitions(): + os.makedirs(CLIENT_OP_PATH, exist_ok=True) + out = create_error_schema() with open(f'{SCHEMA_PATH}/{op_mode_error_schema}', 'w') as f: f.write(out) @@ -175,14 +280,23 @@ def generate_op_mode_definitions(): for (name, thunk) in funcs: funcs_dict[name] = thunk - results = [] + schema = [] + client_op = [] + enums = {} # gather enums from function Literal type args for name,func in funcs_dict.items(): - res = create_schema(name, basename, func) - results.append(res) + res = create_schema(name, basename, func, enums) + schema.append(res) + res = create_client_op(name, basename, func, enums) + client_op.append(res) - out = '\n'.join(results) + out = create_enums(enums) + out += '\n'.join(schema) with open(f'{SCHEMA_PATH}/{basename}.graphql', 'w') as f: f.write(out) + out = '\n'.join(client_op) + with open(f'{CLIENT_OP_PATH}/{basename}.graphql', 'w') as f: + f.write(out) + if __name__ == '__main__': generate_op_mode_definitions() diff --git a/src/services/api/graphql/libs/op_mode.py b/src/services/api/graphql/libs/op_mode.py index c553bbd67..e91d8bd0f 100644 --- a/src/services/api/graphql/libs/op_mode.py +++ b/src/services/api/graphql/libs/op_mode.py @@ -16,13 +16,13 @@ import os import re import typing -import importlib.util -from typing import Union +from typing import Union, Tuple, Optional from humps import decamelize from vyos.defaults import directories from vyos.util import load_as_module from vyos.opmode import _normalize_field_names +from vyos.opmode import _is_literal_type, _get_literal_values def load_op_mode_as_module(name: str): path = os.path.join(directories['op_mode'], name) @@ -73,7 +73,7 @@ def snake_to_pascal_case(name: str) -> str: res = ''.join(map(str.title, name.split('_'))) return res -def map_type_name(type_name: type, optional: bool = False) -> str: +def map_type_name(type_name: type, enums: Optional[dict] = None, optional: bool = False) -> str: if type_name == str: return 'String!' if not optional else 'String = null' if type_name == int: @@ -82,12 +82,17 @@ def map_type_name(type_name: type, optional: bool = False) -> str: return 'Boolean = false' if typing.get_origin(type_name) == list: if not optional: - return f'[{map_type_name(typing.get_args(type_name)[0])}]!' - return f'[{map_type_name(typing.get_args(type_name)[0])}]' + return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]!' + return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]' + if _is_literal_type(type_name): + mapped = enums.get(_get_literal_values(type_name), '') + if not mapped: + raise ValueError(typing.get_args(type_name)) + return f'{mapped}!' if not optional else mapped # typing.Optional is typing.Union[_, NoneType] if (typing.get_origin(type_name) is typing.Union and typing.get_args(type_name)[1] == type(None)): - return f'{map_type_name(typing.get_args(type_name)[0], optional=True)}' + return f'{map_type_name(typing.get_args(type_name)[0], enums=enums, optional=True)}' # scalar 'Generic' is defined in schema.graphql return 'Generic' |