diff options
| -rw-r--r-- | op-mode-definitions/openvpn.xml.in | 2 | ||||
| -rw-r--r-- | python/vyos/opmode.py | 35 | ||||
| -rwxr-xr-x | src/op_mode/bgp.py | 3 | ||||
| -rwxr-xr-x | src/op_mode/conntrack.py | 4 | ||||
| -rwxr-xr-x | src/op_mode/dhcp.py | 9 | ||||
| -rwxr-xr-x | src/op_mode/nat.py | 10 | ||||
| -rwxr-xr-x | src/op_mode/neighbor.py | 8 | ||||
| -rwxr-xr-x | src/op_mode/openvpn.py | 7 | ||||
| -rwxr-xr-x | src/op_mode/route.py | 6 | ||||
| -rwxr-xr-x | src/services/api/graphql/generate/schema_from_op_mode.py | 39 | ||||
| -rw-r--r-- | src/services/api/graphql/libs/op_mode.py | 17 | 
11 files changed, 112 insertions, 28 deletions
| diff --git a/op-mode-definitions/openvpn.xml.in b/op-mode-definitions/openvpn.xml.in index 0aa9c3209..94647af02 100644 --- a/op-mode-definitions/openvpn.xml.in +++ b/op-mode-definitions/openvpn.xml.in @@ -122,7 +122,7 @@              <properties>                <help>Show tunnel status for OpenVPN site-to-site interfaces</help>              </properties> -            <command>sudo ${vyos_op_scripts_dir}/openvpn.py show --mode site-to-site</command> +            <command>sudo ${vyos_op_scripts_dir}/openvpn.py show --mode site_to_site</command>            </leafNode>          </children>        </node> diff --git a/python/vyos/opmode.py b/python/vyos/opmode.py index d02ad4de6..d7172a0b5 100644 --- a/python/vyos/opmode.py +++ b/python/vyos/opmode.py @@ -128,6 +128,25 @@ def _get_arg_type(t):      else:          return t +def _is_literal_type(t): +    if _is_optional_type(t): +        t = _get_arg_type(t) + +    if typing.get_origin(t) == typing.Literal: +        return True + +    return False + +def _get_literal_values(t): +    """ Returns the tuple of allowed values for a Literal type +    """ +    if not _is_literal_type(t): +        return tuple() +    if _is_optional_type(t): +        t = _get_arg_type(t) + +    return typing.get_args(t) +  def _normalize_field_name(name):      # Convert the name to string if it is not      # (in some cases they may be numbers) @@ -194,9 +213,21 @@ def run(module):                  subparser.add_argument(f"--{opt}", action='store_true')              else:                  if _is_optional_type(th): -                    subparser.add_argument(f"--{opt}", type=_get_arg_type(th), default=None) +                    if _is_literal_type(th): +                        subparser.add_argument(f"--{opt}", +                                               choices=list(_get_literal_values(th)), +                                               default=None) +                    else: +                        subparser.add_argument(f"--{opt}", +                                               type=_get_arg_type(th), default=None)                  else: -                    subparser.add_argument(f"--{opt}", type=_get_arg_type(th), required=True) +                    if _is_literal_type(th): +                        subparser.add_argument(f"--{opt}", +                                               choices=list(_get_literal_values(th)), +                                               required=True) +                    else: +                        subparser.add_argument(f"--{opt}", +                                               type=_get_arg_type(th), required=True)      # Get options as a dict rather than a namespace,      # so that we can modify it and pack for passing to functions diff --git a/src/op_mode/bgp.py b/src/op_mode/bgp.py index 23001a9d7..3f6d45dd7 100755 --- a/src/op_mode/bgp.py +++ b/src/op_mode/bgp.py @@ -30,6 +30,7 @@ from vyos.configquery import ConfigTreeQuery  import vyos.opmode +ArgFamily = typing.Literal['inet', 'inet6']  frr_command_template = Template("""  {% if family %} @@ -75,7 +76,7 @@ def _verify(func):  @_verify  def show_neighbors(raw: bool, -                   family: str, +                   family: ArgFamily,                     peer: typing.Optional[str],                     vrf: typing.Optional[str]):      kwargs = dict(locals()) diff --git a/src/op_mode/conntrack.py b/src/op_mode/conntrack.py index df213cc5a..ea7c4c208 100755 --- a/src/op_mode/conntrack.py +++ b/src/op_mode/conntrack.py @@ -15,6 +15,7 @@  # along with this program.  If not, see <http://www.gnu.org/licenses/>.  import sys +import typing  import xmltodict  from tabulate import tabulate @@ -23,6 +24,7 @@ from vyos.util import run  import vyos.opmode +ArgFamily = typing.Literal['inet', 'inet6']  def _get_xml_data(family):      """ @@ -126,7 +128,7 @@ def get_formatted_output(dict_data):      return output -def show(raw: bool, family: str): +def show(raw: bool, family: ArgFamily):      family = 'ipv6' if family == 'inet6' else 'ipv4'      conntrack_data = _get_raw_data(family)      if raw: diff --git a/src/op_mode/dhcp.py b/src/op_mode/dhcp.py index b9e6e7bc9..587df4abb 100755 --- a/src/op_mode/dhcp.py +++ b/src/op_mode/dhcp.py @@ -36,6 +36,9 @@ lease_valid_states = ['all', 'active', 'free', 'expired', 'released', 'abandoned  sort_valid_inet = ['end', 'mac', 'hostname', 'ip', 'pool', 'remaining', 'start', 'state']  sort_valid_inet6 = ['end', 'iaid_duid', 'ip', 'last_communication', 'pool', 'remaining', 'state', 'type'] +ArgFamily = typing.Literal['inet', 'inet6'] +ArgState = typing.Literal['all', 'active', 'free', 'expired', 'released', 'abandoned', 'reset', 'backup'] +  def _utc_to_local(utc_dt):      return datetime.fromtimestamp((datetime.fromtimestamp(utc_dt) - datetime(1970, 1, 1)).total_seconds()) @@ -248,7 +251,7 @@ def _verify(func):  @_verify -def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]): +def show_pool_statistics(raw: bool, family: ArgFamily, pool: typing.Optional[str]):      pool_data = _get_raw_pool_statistics(family=family, pool=pool)      if raw:          return pool_data @@ -257,8 +260,8 @@ def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]):  @_verify -def show_server_leases(raw: bool, family: str, pool: typing.Optional[str], -                       sorted: typing.Optional[str], state: typing.Optional[str]): +def show_server_leases(raw: bool, family: ArgFamily, pool: typing.Optional[str], +                       sorted: typing.Optional[str], state: typing.Optional[ArgState]):      # if dhcp server is down, inactive leases may still be shown as active, so warn the user.      if not is_systemd_service_running('isc-dhcp-server.service'):          Warning('DHCP server is configured but not started. Data may be stale.') diff --git a/src/op_mode/nat.py b/src/op_mode/nat.py index cf06de0e9..c92795745 100755 --- a/src/op_mode/nat.py +++ b/src/op_mode/nat.py @@ -31,6 +31,8 @@ from vyos.util import dict_search  base = 'nat'  unconf_message = 'NAT is not configured' +ArgDirection = typing.Literal['source', 'destination'] +ArgFamily = typing.Literal['inet', 'inet6']  def _get_xml_translation(direction, family, address=None):      """ @@ -298,7 +300,7 @@ def _verify(func):  @_verify -def show_rules(raw: bool, direction: str, family: str): +def show_rules(raw: bool, direction: ArgDirection, family: ArgFamily):      nat_rules = _get_raw_data_rules(direction, family)      if raw:          return nat_rules @@ -307,7 +309,7 @@ def show_rules(raw: bool, direction: str, family: str):  @_verify -def show_statistics(raw: bool, direction: str, family: str): +def show_statistics(raw: bool, direction: ArgDirection, family: ArgFamily):      nat_statistics = _get_raw_data_rules(direction, family)      if raw:          return nat_statistics @@ -316,8 +318,8 @@ def show_statistics(raw: bool, direction: str, family: str):  @_verify -def show_translations(raw: bool, direction: -                      str, family: str, +def show_translations(raw: bool, direction: ArgDirection, +                      family: ArgFamily,                        address: typing.Optional[str],                        verbose: typing.Optional[bool]):      family = 'ipv6' if family == 'inet6' else 'ipv4' diff --git a/src/op_mode/neighbor.py b/src/op_mode/neighbor.py index 264dbdc72..b329ea280 100755 --- a/src/op_mode/neighbor.py +++ b/src/op_mode/neighbor.py @@ -32,6 +32,9 @@ import typing  import vyos.opmode +ArgFamily = typing.Literal['inet', 'inet6'] +ArgState = typing.Literal['reachable', 'stale', 'failed', 'permanent'] +  def interface_exists(interface):      import os      return os.path.exists(f'/sys/class/net/{interface}') @@ -88,7 +91,8 @@ def format_neighbors(neighs, interface=None):      headers = ["Address", "Interface", "Link layer address",  "State"]      return tabulate(neighs, headers) -def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.Optional[str]): +def show(raw: bool, family: ArgFamily, interface: typing.Optional[str], +         state: typing.Optional[ArgState]):      """ Display neighbor table contents """      data = get_raw_data(family, interface, state=state) @@ -97,7 +101,7 @@ def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.      else:          return format_neighbors(data, interface) -def reset(family: str, interface: typing.Optional[str], address: typing.Optional[str]): +def reset(family: ArgFamily, interface: typing.Optional[str], address: typing.Optional[str]):      from vyos.util import run      if address and interface: diff --git a/src/op_mode/openvpn.py b/src/op_mode/openvpn.py index 79130c7c0..8f88ab422 100755 --- a/src/op_mode/openvpn.py +++ b/src/op_mode/openvpn.py @@ -18,6 +18,7 @@  import os  import sys +import typing  from tabulate import tabulate  import vyos.opmode @@ -26,6 +27,8 @@ from vyos.util import commit_in_progress  from vyos.util import call  from vyos.config import Config +ArgMode = typing.Literal['client', 'server', 'site_to_site'] +  def _get_tunnel_address(peer_host, peer_port, status_file):      peer = peer_host + ':' + peer_port      lst = [] @@ -155,7 +158,7 @@ def _get_raw_data(mode: str) -> dict:          d['local_port'] = conf_dict[intf].get('local-port', '')          if conf.exists(f'interfaces openvpn {intf} server client'):              d['configured_clients'] = conf.list_nodes(f'interfaces openvpn {intf} server client') -        if mode in ['client', 'site-to-site']: +        if mode in ['client', 'site_to_site']:              for client in d['clients']:                  if 'shared-secret-key-file' in list(conf_dict[intf]):                      client['name'] = 'None (PSK)' @@ -198,7 +201,7 @@ def _format_openvpn(data: dict) -> str:      return out -def show(raw: bool, mode: str) -> str: +def show(raw: bool, mode: ArgMode) -> str:      openvpn_data = _get_raw_data(mode)      if raw: diff --git a/src/op_mode/route.py b/src/op_mode/route.py index 7f0f9cbac..d6d6b7d6f 100755 --- a/src/op_mode/route.py +++ b/src/op_mode/route.py @@ -54,7 +54,9 @@ frr_command_template = Template("""  {% endif %}  """) -def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typing.Optional[str]): +ArgFamily = typing.Literal['inet', 'inet6'] + +def show_summary(raw: bool, family: ArgFamily, table: typing.Optional[int], vrf: typing.Optional[str]):      from vyos.util import cmd      if family == 'inet': @@ -94,7 +96,7 @@ def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typin          return output  def show(raw: bool, -         family: str, +         family: ArgFamily,           net: typing.Optional[str],           table: typing.Optional[int],           protocol: typing.Optional[str], diff --git a/src/services/api/graphql/generate/schema_from_op_mode.py b/src/services/api/graphql/generate/schema_from_op_mode.py index 98b2ad7b7..5e00e66bc 100755 --- a/src/services/api/graphql/generate/schema_from_op_mode.py +++ b/src/services/api/graphql/generate/schema_from_op_mode.py @@ -26,6 +26,7 @@ from jinja2 import Template  from vyos.defaults import directories  from vyos.opmode import _is_op_mode_function_name as is_op_mode_function_name +from vyos.opmode import _get_literal_values as get_literal_values  from vyos.util import load_as_module  if __package__ is None or __package__ == '':      sys.path.append(os.path.join(directories['services'], 'api')) @@ -94,6 +95,14 @@ extend type Mutation {  }  """ +enum_template = """ +enum {{ enum_name }} { +    {%- for field_entry in enum_fields %} +    {{ field_entry }} +    {%- endfor %} +} +""" +  error_template = """  interface OpModeError {      name: String! @@ -109,12 +118,18 @@ type {{ name }} implements OpModeError {  {%- endfor %}  """ -def create_schema(func_name: str, base_name: str, func: callable) -> str: +def create_schema(func_name: str, base_name: str, func: callable, +                  enums: dict) -> str:      sig = signature(func) +    for k in sig.parameters: +        t = get_literal_values(sig.parameters[k].annotation) +        if t: +            enums[t] = snake_to_pascal_case(sig.parameters[k].name + '_' + base_name) +      field_dict = {}      for k in sig.parameters: -        field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation) +        field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation, enums)      # It is assumed that if one is generating a schema for a 'show_*'      # function, that 'get_raw_data' is present and 'raw' is desired. @@ -137,6 +152,20 @@ def create_schema(func_name: str, base_name: str, func: callable) -> str:      return res +def create_enums(enums: dict) -> str: +    enum_data = [] +    for k, v in enums.items(): +        enum = {'enum_name': v, 'enum_fields': list(k)} +        enum_data.append(enum) + +    out = '' +    j2_template = Template(enum_template) +    for el in enum_data: +        out += j2_template.render(el) +        out += '\n' + +    return out +  def create_error_schema():      from vyos import opmode @@ -176,11 +205,13 @@ def generate_op_mode_definitions():              funcs_dict[name] = thunk          results = [] +        enums = {} # gather enums from function Literal type args          for name,func in funcs_dict.items(): -            res = create_schema(name, basename, func) +            res = create_schema(name, basename, func, enums)              results.append(res) -        out = '\n'.join(results) +        out = create_enums(enums) +        out += '\n'.join(results)          with open(f'{SCHEMA_PATH}/{basename}.graphql', 'w') as f:              f.write(out) diff --git a/src/services/api/graphql/libs/op_mode.py b/src/services/api/graphql/libs/op_mode.py index c553bbd67..e91d8bd0f 100644 --- a/src/services/api/graphql/libs/op_mode.py +++ b/src/services/api/graphql/libs/op_mode.py @@ -16,13 +16,13 @@  import os  import re  import typing -import importlib.util -from typing import Union +from typing import Union, Tuple, Optional  from humps import decamelize  from vyos.defaults import directories  from vyos.util import load_as_module  from vyos.opmode import _normalize_field_names +from vyos.opmode import _is_literal_type, _get_literal_values  def load_op_mode_as_module(name: str):      path = os.path.join(directories['op_mode'], name) @@ -73,7 +73,7 @@ def snake_to_pascal_case(name: str) -> str:      res = ''.join(map(str.title, name.split('_')))      return res -def map_type_name(type_name: type, optional: bool = False) -> str: +def map_type_name(type_name: type, enums: Optional[dict] = None, optional: bool = False) -> str:      if type_name == str:          return 'String!' if not optional else 'String = null'      if type_name == int: @@ -82,12 +82,17 @@ def map_type_name(type_name: type, optional: bool = False) -> str:          return 'Boolean = false'      if typing.get_origin(type_name) == list:          if not optional: -            return f'[{map_type_name(typing.get_args(type_name)[0])}]!' -        return f'[{map_type_name(typing.get_args(type_name)[0])}]' +            return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]!' +        return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]' +    if _is_literal_type(type_name): +        mapped = enums.get(_get_literal_values(type_name), '') +        if not mapped: +            raise ValueError(typing.get_args(type_name)) +        return f'{mapped}!' if not optional else mapped      # typing.Optional is typing.Union[_, NoneType]      if (typing.get_origin(type_name) is typing.Union and              typing.get_args(type_name)[1] == type(None)): -        return f'{map_type_name(typing.get_args(type_name)[0], optional=True)}' +        return f'{map_type_name(typing.get_args(type_name)[0], enums=enums, optional=True)}'      # scalar 'Generic' is defined in schema.graphql      return 'Generic' | 
