summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaniil Baturin <daniil@vyos.io>2023-03-07 19:51:18 +0000
committerGitHub <noreply@github.com>2023-03-07 19:51:18 +0000
commit06e810ffc398366f7d4cc00241e0692e6fe81620 (patch)
treea1e2d444b83eadbb8f0e7f43b6493a03c6f90c8a /src
parent9f0857c2e7821ea22d5132d6352ee8cbcb64b738 (diff)
parent1e72e1c68a708518526b8e090f5d6482671cbd57 (diff)
downloadvyos-1x-06e810ffc398366f7d4cc00241e0692e6fe81620.tar.gz
vyos-1x-06e810ffc398366f7d4cc00241e0692e6fe81620.zip
Merge pull request #1868 from jestabro/literal
op-mode: T5051: use Literal types to provide op-mode CLI choices and API enums
Diffstat (limited to 'src')
-rwxr-xr-xsrc/op_mode/bgp.py3
-rwxr-xr-xsrc/op_mode/conntrack.py4
-rwxr-xr-xsrc/op_mode/dhcp.py9
-rwxr-xr-xsrc/op_mode/nat.py10
-rwxr-xr-xsrc/op_mode/neighbor.py8
-rwxr-xr-xsrc/op_mode/openvpn.py7
-rwxr-xr-xsrc/op_mode/route.py6
-rwxr-xr-xsrc/services/api/graphql/generate/schema_from_op_mode.py39
-rw-r--r--src/services/api/graphql/libs/op_mode.py17
9 files changed, 78 insertions, 25 deletions
diff --git a/src/op_mode/bgp.py b/src/op_mode/bgp.py
index 23001a9d7..3f6d45dd7 100755
--- a/src/op_mode/bgp.py
+++ b/src/op_mode/bgp.py
@@ -30,6 +30,7 @@ from vyos.configquery import ConfigTreeQuery
import vyos.opmode
+ArgFamily = typing.Literal['inet', 'inet6']
frr_command_template = Template("""
{% if family %}
@@ -75,7 +76,7 @@ def _verify(func):
@_verify
def show_neighbors(raw: bool,
- family: str,
+ family: ArgFamily,
peer: typing.Optional[str],
vrf: typing.Optional[str]):
kwargs = dict(locals())
diff --git a/src/op_mode/conntrack.py b/src/op_mode/conntrack.py
index df213cc5a..ea7c4c208 100755
--- a/src/op_mode/conntrack.py
+++ b/src/op_mode/conntrack.py
@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys
+import typing
import xmltodict
from tabulate import tabulate
@@ -23,6 +24,7 @@ from vyos.util import run
import vyos.opmode
+ArgFamily = typing.Literal['inet', 'inet6']
def _get_xml_data(family):
"""
@@ -126,7 +128,7 @@ def get_formatted_output(dict_data):
return output
-def show(raw: bool, family: str):
+def show(raw: bool, family: ArgFamily):
family = 'ipv6' if family == 'inet6' else 'ipv4'
conntrack_data = _get_raw_data(family)
if raw:
diff --git a/src/op_mode/dhcp.py b/src/op_mode/dhcp.py
index b9e6e7bc9..587df4abb 100755
--- a/src/op_mode/dhcp.py
+++ b/src/op_mode/dhcp.py
@@ -36,6 +36,9 @@ lease_valid_states = ['all', 'active', 'free', 'expired', 'released', 'abandoned
sort_valid_inet = ['end', 'mac', 'hostname', 'ip', 'pool', 'remaining', 'start', 'state']
sort_valid_inet6 = ['end', 'iaid_duid', 'ip', 'last_communication', 'pool', 'remaining', 'state', 'type']
+ArgFamily = typing.Literal['inet', 'inet6']
+ArgState = typing.Literal['all', 'active', 'free', 'expired', 'released', 'abandoned', 'reset', 'backup']
+
def _utc_to_local(utc_dt):
return datetime.fromtimestamp((datetime.fromtimestamp(utc_dt) - datetime(1970, 1, 1)).total_seconds())
@@ -248,7 +251,7 @@ def _verify(func):
@_verify
-def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]):
+def show_pool_statistics(raw: bool, family: ArgFamily, pool: typing.Optional[str]):
pool_data = _get_raw_pool_statistics(family=family, pool=pool)
if raw:
return pool_data
@@ -257,8 +260,8 @@ def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]):
@_verify
-def show_server_leases(raw: bool, family: str, pool: typing.Optional[str],
- sorted: typing.Optional[str], state: typing.Optional[str]):
+def show_server_leases(raw: bool, family: ArgFamily, pool: typing.Optional[str],
+ sorted: typing.Optional[str], state: typing.Optional[ArgState]):
# if dhcp server is down, inactive leases may still be shown as active, so warn the user.
if not is_systemd_service_running('isc-dhcp-server.service'):
Warning('DHCP server is configured but not started. Data may be stale.')
diff --git a/src/op_mode/nat.py b/src/op_mode/nat.py
index cf06de0e9..c92795745 100755
--- a/src/op_mode/nat.py
+++ b/src/op_mode/nat.py
@@ -31,6 +31,8 @@ from vyos.util import dict_search
base = 'nat'
unconf_message = 'NAT is not configured'
+ArgDirection = typing.Literal['source', 'destination']
+ArgFamily = typing.Literal['inet', 'inet6']
def _get_xml_translation(direction, family, address=None):
"""
@@ -298,7 +300,7 @@ def _verify(func):
@_verify
-def show_rules(raw: bool, direction: str, family: str):
+def show_rules(raw: bool, direction: ArgDirection, family: ArgFamily):
nat_rules = _get_raw_data_rules(direction, family)
if raw:
return nat_rules
@@ -307,7 +309,7 @@ def show_rules(raw: bool, direction: str, family: str):
@_verify
-def show_statistics(raw: bool, direction: str, family: str):
+def show_statistics(raw: bool, direction: ArgDirection, family: ArgFamily):
nat_statistics = _get_raw_data_rules(direction, family)
if raw:
return nat_statistics
@@ -316,8 +318,8 @@ def show_statistics(raw: bool, direction: str, family: str):
@_verify
-def show_translations(raw: bool, direction:
- str, family: str,
+def show_translations(raw: bool, direction: ArgDirection,
+ family: ArgFamily,
address: typing.Optional[str],
verbose: typing.Optional[bool]):
family = 'ipv6' if family == 'inet6' else 'ipv4'
diff --git a/src/op_mode/neighbor.py b/src/op_mode/neighbor.py
index 264dbdc72..b329ea280 100755
--- a/src/op_mode/neighbor.py
+++ b/src/op_mode/neighbor.py
@@ -32,6 +32,9 @@ import typing
import vyos.opmode
+ArgFamily = typing.Literal['inet', 'inet6']
+ArgState = typing.Literal['reachable', 'stale', 'failed', 'permanent']
+
def interface_exists(interface):
import os
return os.path.exists(f'/sys/class/net/{interface}')
@@ -88,7 +91,8 @@ def format_neighbors(neighs, interface=None):
headers = ["Address", "Interface", "Link layer address", "State"]
return tabulate(neighs, headers)
-def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.Optional[str]):
+def show(raw: bool, family: ArgFamily, interface: typing.Optional[str],
+ state: typing.Optional[ArgState]):
""" Display neighbor table contents """
data = get_raw_data(family, interface, state=state)
@@ -97,7 +101,7 @@ def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.
else:
return format_neighbors(data, interface)
-def reset(family: str, interface: typing.Optional[str], address: typing.Optional[str]):
+def reset(family: ArgFamily, interface: typing.Optional[str], address: typing.Optional[str]):
from vyos.util import run
if address and interface:
diff --git a/src/op_mode/openvpn.py b/src/op_mode/openvpn.py
index 79130c7c0..8f88ab422 100755
--- a/src/op_mode/openvpn.py
+++ b/src/op_mode/openvpn.py
@@ -18,6 +18,7 @@
import os
import sys
+import typing
from tabulate import tabulate
import vyos.opmode
@@ -26,6 +27,8 @@ from vyos.util import commit_in_progress
from vyos.util import call
from vyos.config import Config
+ArgMode = typing.Literal['client', 'server', 'site_to_site']
+
def _get_tunnel_address(peer_host, peer_port, status_file):
peer = peer_host + ':' + peer_port
lst = []
@@ -155,7 +158,7 @@ def _get_raw_data(mode: str) -> dict:
d['local_port'] = conf_dict[intf].get('local-port', '')
if conf.exists(f'interfaces openvpn {intf} server client'):
d['configured_clients'] = conf.list_nodes(f'interfaces openvpn {intf} server client')
- if mode in ['client', 'site-to-site']:
+ if mode in ['client', 'site_to_site']:
for client in d['clients']:
if 'shared-secret-key-file' in list(conf_dict[intf]):
client['name'] = 'None (PSK)'
@@ -198,7 +201,7 @@ def _format_openvpn(data: dict) -> str:
return out
-def show(raw: bool, mode: str) -> str:
+def show(raw: bool, mode: ArgMode) -> str:
openvpn_data = _get_raw_data(mode)
if raw:
diff --git a/src/op_mode/route.py b/src/op_mode/route.py
index 7f0f9cbac..d6d6b7d6f 100755
--- a/src/op_mode/route.py
+++ b/src/op_mode/route.py
@@ -54,7 +54,9 @@ frr_command_template = Template("""
{% endif %}
""")
-def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typing.Optional[str]):
+ArgFamily = typing.Literal['inet', 'inet6']
+
+def show_summary(raw: bool, family: ArgFamily, table: typing.Optional[int], vrf: typing.Optional[str]):
from vyos.util import cmd
if family == 'inet':
@@ -94,7 +96,7 @@ def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typin
return output
def show(raw: bool,
- family: str,
+ family: ArgFamily,
net: typing.Optional[str],
table: typing.Optional[int],
protocol: typing.Optional[str],
diff --git a/src/services/api/graphql/generate/schema_from_op_mode.py b/src/services/api/graphql/generate/schema_from_op_mode.py
index 98b2ad7b7..5e00e66bc 100755
--- a/src/services/api/graphql/generate/schema_from_op_mode.py
+++ b/src/services/api/graphql/generate/schema_from_op_mode.py
@@ -26,6 +26,7 @@ from jinja2 import Template
from vyos.defaults import directories
from vyos.opmode import _is_op_mode_function_name as is_op_mode_function_name
+from vyos.opmode import _get_literal_values as get_literal_values
from vyos.util import load_as_module
if __package__ is None or __package__ == '':
sys.path.append(os.path.join(directories['services'], 'api'))
@@ -94,6 +95,14 @@ extend type Mutation {
}
"""
+enum_template = """
+enum {{ enum_name }} {
+ {%- for field_entry in enum_fields %}
+ {{ field_entry }}
+ {%- endfor %}
+}
+"""
+
error_template = """
interface OpModeError {
name: String!
@@ -109,12 +118,18 @@ type {{ name }} implements OpModeError {
{%- endfor %}
"""
-def create_schema(func_name: str, base_name: str, func: callable) -> str:
+def create_schema(func_name: str, base_name: str, func: callable,
+ enums: dict) -> str:
sig = signature(func)
+ for k in sig.parameters:
+ t = get_literal_values(sig.parameters[k].annotation)
+ if t:
+ enums[t] = snake_to_pascal_case(sig.parameters[k].name + '_' + base_name)
+
field_dict = {}
for k in sig.parameters:
- field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation)
+ field_dict[sig.parameters[k].name] = map_type_name(sig.parameters[k].annotation, enums)
# It is assumed that if one is generating a schema for a 'show_*'
# function, that 'get_raw_data' is present and 'raw' is desired.
@@ -137,6 +152,20 @@ def create_schema(func_name: str, base_name: str, func: callable) -> str:
return res
+def create_enums(enums: dict) -> str:
+ enum_data = []
+ for k, v in enums.items():
+ enum = {'enum_name': v, 'enum_fields': list(k)}
+ enum_data.append(enum)
+
+ out = ''
+ j2_template = Template(enum_template)
+ for el in enum_data:
+ out += j2_template.render(el)
+ out += '\n'
+
+ return out
+
def create_error_schema():
from vyos import opmode
@@ -176,11 +205,13 @@ def generate_op_mode_definitions():
funcs_dict[name] = thunk
results = []
+ enums = {} # gather enums from function Literal type args
for name,func in funcs_dict.items():
- res = create_schema(name, basename, func)
+ res = create_schema(name, basename, func, enums)
results.append(res)
- out = '\n'.join(results)
+ out = create_enums(enums)
+ out += '\n'.join(results)
with open(f'{SCHEMA_PATH}/{basename}.graphql', 'w') as f:
f.write(out)
diff --git a/src/services/api/graphql/libs/op_mode.py b/src/services/api/graphql/libs/op_mode.py
index c553bbd67..e91d8bd0f 100644
--- a/src/services/api/graphql/libs/op_mode.py
+++ b/src/services/api/graphql/libs/op_mode.py
@@ -16,13 +16,13 @@
import os
import re
import typing
-import importlib.util
-from typing import Union
+from typing import Union, Tuple, Optional
from humps import decamelize
from vyos.defaults import directories
from vyos.util import load_as_module
from vyos.opmode import _normalize_field_names
+from vyos.opmode import _is_literal_type, _get_literal_values
def load_op_mode_as_module(name: str):
path = os.path.join(directories['op_mode'], name)
@@ -73,7 +73,7 @@ def snake_to_pascal_case(name: str) -> str:
res = ''.join(map(str.title, name.split('_')))
return res
-def map_type_name(type_name: type, optional: bool = False) -> str:
+def map_type_name(type_name: type, enums: Optional[dict] = None, optional: bool = False) -> str:
if type_name == str:
return 'String!' if not optional else 'String = null'
if type_name == int:
@@ -82,12 +82,17 @@ def map_type_name(type_name: type, optional: bool = False) -> str:
return 'Boolean = false'
if typing.get_origin(type_name) == list:
if not optional:
- return f'[{map_type_name(typing.get_args(type_name)[0])}]!'
- return f'[{map_type_name(typing.get_args(type_name)[0])}]'
+ return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]!'
+ return f'[{map_type_name(typing.get_args(type_name)[0], enums=enums)}]'
+ if _is_literal_type(type_name):
+ mapped = enums.get(_get_literal_values(type_name), '')
+ if not mapped:
+ raise ValueError(typing.get_args(type_name))
+ return f'{mapped}!' if not optional else mapped
# typing.Optional is typing.Union[_, NoneType]
if (typing.get_origin(type_name) is typing.Union and
typing.get_args(type_name)[1] == type(None)):
- return f'{map_type_name(typing.get_args(type_name)[0], optional=True)}'
+ return f'{map_type_name(typing.get_args(type_name)[0], enums=enums, optional=True)}'
# scalar 'Generic' is defined in schema.graphql
return 'Generic'