diff options
| -rwxr-xr-x | src/services/api/graphql/generate/schema_from_op_mode.py | 39 | ||||
| -rw-r--r-- | src/services/api/graphql/libs/op_mode.py | 17 | 
2 files changed, 46 insertions, 10 deletions
| diff --git a/src/services/api/graphql/generate/schema_from_op_mode.py b/src/services/api/graphql/generate/schema_from_op_mode.py index 98b2ad7b7..5e00e66bc 100755 --- a/src/services/api/graphql/generate/schema_from_op_mode.py +++ b/src/services/api/graphql/generate/schema_from_op_mode.py @@ -26,6 +26,7 @@ from jinja2 import Template  from vyos.defaults import directories  from vyos.opmode import _is_op_mode_function_name as is_op_mode_function_name +from vyos.opmode import _get_literal_values as get_literal_values  from vyos.util import load_as_module  if __package__ is None or __package__ == '':      sys.path.append(os.path.join(directories['services'], 'api')) @@ -94,6 +95,14 @@ extend type Mutation {  }  """ +enum_template = """ +enum {{ enum_name }} { +    {%- for field_entry in enum_fields %} +    {{ field_entry }} +    {%- endfor %} +} +""" +  error_template = """  interface OpModeError {      name: String! @@ -109,12 +118,18 @@ type {{ name }} implements OpModeError {  {%- endfor %}  """ -def create_schema(func_name: str, base_name: str, func: callable) -> str: +def create_schema(func_name: str, base_name: str, func: callable, +                  enums: dict) -> str:      sig = signature(func) +    for k in sig.parameters: +        t = get_literal_values(sig.parameters[k].annotation) +        if t: +            enums[t] = snake_to_pascal_case(sig.parameters[k].name + '_' + base_name) +      field_dict = {}      for k in sig.parameters: -        field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation) +        field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation, enums)      # It is assumed that if one is generating a schema for a 'show_*'      # function, that 'get_raw_data' is present and 'raw' is desired. @@ -137,6 +152,20 @@ def create_schema(func_name: str, base_name: str, func: callable) -> str:      return res +def create_enums(enums: dict) -> str: +    enum_data = [] +    for k, v in enums.items(): +        enum = {'enum_name': v, 'enum_fields': list(k)} +        enum_data.append(enum) + +    out = '' +    j2_template = Template(enum_template) +    for el in enum_data: +        out += j2_template.render(el) +        out += '\n' + +    return out +  def create_error_schema():      from vyos import opmode @@ -176,11 +205,13 @@ def generate_op_mode_definitions():              funcs_dict[name] = thunk          results = [] +        enums = {} # gather enums from function Literal type args          for name,func in funcs_dict.items(): -            res = create_schema(name, basename, func) +            res = create_schema(name, basename, func, enums)              results.append(res) -        out = '\n'.join(results) +        out = create_enums(enums) +        out += '\n'.join(results)          with open(f'{SCHEMA_PATH}/{basename}.graphql', 'w') as f:              f.write(out) diff --git a/src/services/api/graphql/libs/op_mode.py b/src/services/api/graphql/libs/op_mode.py index c553bbd67..e91d8bd0f 100644 --- a/src/services/api/graphql/libs/op_mode.py +++ b/src/services/api/graphql/libs/op_mode.py @@ -16,13 +16,13 @@  import os  import re  import typing -import importlib.util -from typing import Union +from typing import Union, Tuple, Optional  from humps import decamelize  from vyos.defaults import directories  from vyos.util import load_as_module  from vyos.opmode import _normalize_field_names +from vyos.opmode import _is_literal_type, _get_literal_values  def load_op_mode_as_module(name: str):      path = os.path.join(directories['op_mode'], name) @@ -73,7 +73,7 @@ def snake_to_pascal_case(name: str) -> str:      res = ''.join(map(str.title, name.split('_')))      return res -def map_type_name(type_name: type, optional: bool = False) -> str: +def map_type_name(type_name: type, enums: Optional[dict] = None, optional: bool = False) -> str:      if type_name == str:          return 'String!' if not optional else 'String = null'      if type_name == int: @@ -82,12 +82,17 @@ def map_type_name(type_name: type, optional: bool = False) -> str:          return 'Boolean = false'      if typing.get_origin(type_name) == list:          if not optional: -            return f'[{map_type_name(typing.get_args(type_name)[0])}]!' -        return f'[{map_type_name(typing.get_args(type_name)[0])}]' +            return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]!' +        return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]' +    if _is_literal_type(type_name): +        mapped = enums.get(_get_literal_values(type_name), '') +        if not mapped: +            raise ValueError(typing.get_args(type_name)) +        return f'{mapped}!' if not optional else mapped      # typing.Optional is typing.Union[_, NoneType]      if (typing.get_origin(type_name) is typing.Union and              typing.get_args(type_name)[1] == type(None)): -        return f'{map_type_name(typing.get_args(type_name)[0], optional=True)}' +        return f'{map_type_name(typing.get_args(type_name)[0], enums=enums, optional=True)}'      # scalar 'Generic' is defined in schema.graphql      return 'Generic' | 
