From 1e72e1c68a708518526b8e090f5d6482671cbd57 Mon Sep 17 00:00:00 2001
From: John Estabrook <jestabro@vyos.io>
Date: Fri, 3 Mar 2023 11:30:27 -0600
Subject: op-mode: T5051: use typing.Literal in op-mode scripts

---
 src/op_mode/bgp.py       |  3 ++-
 src/op_mode/conntrack.py |  4 +++-
 src/op_mode/dhcp.py      |  9 ++++++---
 src/op_mode/nat.py       | 10 ++++++----
 src/op_mode/neighbor.py  |  8 ++++++--
 src/op_mode/openvpn.py   |  7 +++++--
 src/op_mode/route.py     |  6 ++++--
 7 files changed, 32 insertions(+), 15 deletions(-)

(limited to 'src/op_mode')

diff --git a/src/op_mode/bgp.py b/src/op_mode/bgp.py
index 23001a9d7..3f6d45dd7 100755
--- a/src/op_mode/bgp.py
+++ b/src/op_mode/bgp.py
@@ -30,6 +30,7 @@ from vyos.configquery import ConfigTreeQuery
 
 import vyos.opmode
 
+ArgFamily = typing.Literal['inet', 'inet6']
 
 frr_command_template = Template("""
 {% if family %}
@@ -75,7 +76,7 @@ def _verify(func):
 
 @_verify
 def show_neighbors(raw: bool,
-                   family: str,
+                   family: ArgFamily,
                    peer: typing.Optional[str],
                    vrf: typing.Optional[str]):
     kwargs = dict(locals())
diff --git a/src/op_mode/conntrack.py b/src/op_mode/conntrack.py
index df213cc5a..ea7c4c208 100755
--- a/src/op_mode/conntrack.py
+++ b/src/op_mode/conntrack.py
@@ -15,6 +15,7 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 import sys
+import typing
 import xmltodict
 
 from tabulate import tabulate
@@ -23,6 +24,7 @@ from vyos.util import run
 
 import vyos.opmode
 
+ArgFamily = typing.Literal['inet', 'inet6']
 
 def _get_xml_data(family):
     """
@@ -126,7 +128,7 @@ def get_formatted_output(dict_data):
     return output
 
 
-def show(raw: bool, family: str):
+def show(raw: bool, family: ArgFamily):
     family = 'ipv6' if family == 'inet6' else 'ipv4'
     conntrack_data = _get_raw_data(family)
     if raw:
diff --git a/src/op_mode/dhcp.py b/src/op_mode/dhcp.py
index b9e6e7bc9..587df4abb 100755
--- a/src/op_mode/dhcp.py
+++ b/src/op_mode/dhcp.py
@@ -36,6 +36,9 @@ lease_valid_states = ['all', 'active', 'free', 'expired', 'released', 'abandoned
 sort_valid_inet = ['end', 'mac', 'hostname', 'ip', 'pool', 'remaining', 'start', 'state']
 sort_valid_inet6 = ['end', 'iaid_duid', 'ip', 'last_communication', 'pool', 'remaining', 'state', 'type']
 
+ArgFamily = typing.Literal['inet', 'inet6']
+ArgState = typing.Literal['all', 'active', 'free', 'expired', 'released', 'abandoned', 'reset', 'backup']
+
 def _utc_to_local(utc_dt):
     return datetime.fromtimestamp((datetime.fromtimestamp(utc_dt) - datetime(1970, 1, 1)).total_seconds())
 
@@ -248,7 +251,7 @@ def _verify(func):
 
 
 @_verify
-def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]):
+def show_pool_statistics(raw: bool, family: ArgFamily, pool: typing.Optional[str]):
     pool_data = _get_raw_pool_statistics(family=family, pool=pool)
     if raw:
         return pool_data
@@ -257,8 +260,8 @@ def show_pool_statistics(raw: bool, family: str, pool: typing.Optional[str]):
 
 
 @_verify
-def show_server_leases(raw: bool, family: str, pool: typing.Optional[str],
-                       sorted: typing.Optional[str], state: typing.Optional[str]):
+def show_server_leases(raw: bool, family: ArgFamily, pool: typing.Optional[str],
+                       sorted: typing.Optional[str], state: typing.Optional[ArgState]):
     # if dhcp server is down, inactive leases may still be shown as active, so warn the user.
     if not is_systemd_service_running('isc-dhcp-server.service'):
         Warning('DHCP server is configured but not started. Data may be stale.')
diff --git a/src/op_mode/nat.py b/src/op_mode/nat.py
index cf06de0e9..c92795745 100755
--- a/src/op_mode/nat.py
+++ b/src/op_mode/nat.py
@@ -31,6 +31,8 @@ from vyos.util import dict_search
 base = 'nat'
 unconf_message = 'NAT is not configured'
 
+ArgDirection = typing.Literal['source', 'destination']
+ArgFamily = typing.Literal['inet', 'inet6']
 
 def _get_xml_translation(direction, family, address=None):
     """
@@ -298,7 +300,7 @@ def _verify(func):
 
 
 @_verify
-def show_rules(raw: bool, direction: str, family: str):
+def show_rules(raw: bool, direction: ArgDirection, family: ArgFamily):
     nat_rules = _get_raw_data_rules(direction, family)
     if raw:
         return nat_rules
@@ -307,7 +309,7 @@ def show_rules(raw: bool, direction: str, family: str):
 
 
 @_verify
-def show_statistics(raw: bool, direction: str, family: str):
+def show_statistics(raw: bool, direction: ArgDirection, family: ArgFamily):
     nat_statistics = _get_raw_data_rules(direction, family)
     if raw:
         return nat_statistics
@@ -316,8 +318,8 @@ def show_statistics(raw: bool, direction: str, family: str):
 
 
 @_verify
-def show_translations(raw: bool, direction:
-                      str, family: str,
+def show_translations(raw: bool, direction: ArgDirection,
+                      family: ArgFamily,
                       address: typing.Optional[str],
                       verbose: typing.Optional[bool]):
     family = 'ipv6' if family == 'inet6' else 'ipv4'
diff --git a/src/op_mode/neighbor.py b/src/op_mode/neighbor.py
index 264dbdc72..b329ea280 100755
--- a/src/op_mode/neighbor.py
+++ b/src/op_mode/neighbor.py
@@ -32,6 +32,9 @@ import typing
 
 import vyos.opmode
 
+ArgFamily = typing.Literal['inet', 'inet6']
+ArgState = typing.Literal['reachable', 'stale', 'failed', 'permanent']
+
 def interface_exists(interface):
     import os
     return os.path.exists(f'/sys/class/net/{interface}')
@@ -88,7 +91,8 @@ def format_neighbors(neighs, interface=None):
     headers = ["Address", "Interface", "Link layer address",  "State"]
     return tabulate(neighs, headers)
 
-def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.Optional[str]):
+def show(raw: bool, family: ArgFamily, interface: typing.Optional[str],
+         state: typing.Optional[ArgState]):
     """ Display neighbor table contents """
     data = get_raw_data(family, interface, state=state)
 
@@ -97,7 +101,7 @@ def show(raw: bool, family: str, interface: typing.Optional[str], state: typing.
     else:
         return format_neighbors(data, interface)
 
-def reset(family: str, interface: typing.Optional[str], address: typing.Optional[str]):
+def reset(family: ArgFamily, interface: typing.Optional[str], address: typing.Optional[str]):
     from vyos.util import run
 
     if address and interface:
diff --git a/src/op_mode/openvpn.py b/src/op_mode/openvpn.py
index 79130c7c0..8f88ab422 100755
--- a/src/op_mode/openvpn.py
+++ b/src/op_mode/openvpn.py
@@ -18,6 +18,7 @@
 
 import os
 import sys
+import typing
 from tabulate import tabulate
 
 import vyos.opmode
@@ -26,6 +27,8 @@ from vyos.util import commit_in_progress
 from vyos.util import call
 from vyos.config import Config
 
+ArgMode = typing.Literal['client', 'server', 'site_to_site']
+
 def _get_tunnel_address(peer_host, peer_port, status_file):
     peer = peer_host + ':' + peer_port
     lst = []
@@ -155,7 +158,7 @@ def _get_raw_data(mode: str) -> dict:
         d['local_port'] = conf_dict[intf].get('local-port', '')
         if conf.exists(f'interfaces openvpn {intf} server client'):
             d['configured_clients'] = conf.list_nodes(f'interfaces openvpn {intf} server client')
-        if mode in ['client', 'site-to-site']:
+        if mode in ['client', 'site_to_site']:
             for client in d['clients']:
                 if 'shared-secret-key-file' in list(conf_dict[intf]):
                     client['name'] = 'None (PSK)'
@@ -198,7 +201,7 @@ def _format_openvpn(data: dict) -> str:
 
     return out
 
-def show(raw: bool, mode: str) -> str:
+def show(raw: bool, mode: ArgMode) -> str:
     openvpn_data = _get_raw_data(mode)
 
     if raw:
diff --git a/src/op_mode/route.py b/src/op_mode/route.py
index 7f0f9cbac..d6d6b7d6f 100755
--- a/src/op_mode/route.py
+++ b/src/op_mode/route.py
@@ -54,7 +54,9 @@ frr_command_template = Template("""
 {% endif %}
 """)
 
-def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typing.Optional[str]):
+ArgFamily = typing.Literal['inet', 'inet6']
+
+def show_summary(raw: bool, family: ArgFamily, table: typing.Optional[int], vrf: typing.Optional[str]):
     from vyos.util import cmd
 
     if family == 'inet':
@@ -94,7 +96,7 @@ def show_summary(raw: bool, family: str, table: typing.Optional[int], vrf: typin
         return output
 
 def show(raw: bool,
-         family: str,
+         family: ArgFamily,
          net: typing.Optional[str],
          table: typing.Optional[int],
          protocol: typing.Optional[str],
-- 
cgit v1.2.3