From fc9885f859617bab36c971f4eaa56240741f52c4 Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Tue, 24 Sep 2024 22:48:25 -0500 Subject: http-api: T6736: separate REST API and GraphQL API activation The GraphQL API was implemented as an addition to the existing REST API. As there is no necessary dependency, separate the initialization of the respective endpoints. Factor out the REST Pydantic models and FastAPI routes for symmetry and clarity. --- src/services/api/graphql/bindings.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'src/services/api/graphql/bindings.py') diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py index ef4966466..93dd0fbfb 100644 --- a/src/services/api/graphql/bindings.py +++ b/src/services/api/graphql/bindings.py @@ -1,4 +1,4 @@ -# Copyright 2021 VyOS maintainers and contributors +# Copyright 2021-2024 VyOS maintainers and contributors # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -13,24 +13,35 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . + import vyos.defaults + +from ariadne import make_executable_schema +from ariadne import load_schema_from_path +from ariadne import snake_case_fallback_resolvers + from . graphql.queries import query from . graphql.mutations import mutation from . graphql.directives import directives_dict from . graphql.errors import op_mode_error from . graphql.auth_token_mutation import auth_token_mutation from . libs.token_auth import init_secret -from . import state -from ariadne import make_executable_schema, load_schema_from_path, snake_case_fallback_resolvers + +from .. session import SessionState + def generate_schema(): + state = SessionState() api_schema_dir = vyos.defaults.directories['api_schema'] - if state.settings['app'].state.vyos_auth_type == 'token': + if state.auth_type == 'token': init_secret() type_defs = load_schema_from_path(api_schema_dir) - schema = make_executable_schema(type_defs, query, op_mode_error, mutation, auth_token_mutation, snake_case_fallback_resolvers, directives=directives_dict) + schema = make_executable_schema(type_defs, query, op_mode_error, + mutation, auth_token_mutation, + snake_case_fallback_resolvers, + directives=directives_dict) return schema -- cgit v1.2.3 From 7e23fd9da028b3c623b69fda8a6bcfd887f1c18c Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Mon, 30 Sep 2024 20:40:02 -0500 Subject: http-api: T6736: normalize formatting --- smoketest/scripts/cli/test_service_https.py | 75 ++++++-- src/services/api/graphql/bindings.py | 29 +-- .../api/graphql/graphql/auth_token_mutation.py | 31 ++- src/services/api/graphql/graphql/mutations.py | 71 ++++--- src/services/api/graphql/graphql/queries.py | 71 ++++--- src/services/api/graphql/libs/key_auth.py | 4 +- src/services/api/graphql/libs/token_auth.py | 10 +- src/services/api/graphql/routers.py | 41 ++-- src/services/api/graphql/session/session.py | 33 ++-- src/services/api/rest/models.py | 143 ++++++++------ src/services/api/rest/routers.py | 209 +++++++++++++-------- src/services/api/session.py | 1 + 12 files changed, 427 insertions(+), 291 deletions(-) (limited to 'src/services/api/graphql/bindings.py') diff --git a/smoketest/scripts/cli/test_service_https.py b/smoketest/scripts/cli/test_service_https.py index 2f4fcd6ab..04c4a2e51 100755 --- a/smoketest/scripts/cli/test_service_https.py +++ b/smoketest/scripts/cli/test_service_https.py @@ -89,6 +89,7 @@ server { PROCESS_NAME = 'nginx' + class TestHTTPSService(VyOSUnitTestSHIM.TestCase): @classmethod def setUpClass(cls): @@ -120,19 +121,29 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): # verify() - certificates do not exist (yet) with self.assertRaises(ConfigSessionError): self.cli_commit() - self.cli_set(pki_base + ['certificate', cert_name, 'certificate', cert_data.replace('\n','')]) - self.cli_set(pki_base + ['certificate', cert_name, 'private', 'key', key_data.replace('\n','')]) + self.cli_set( + pki_base + + ['certificate', cert_name, 'certificate', cert_data.replace('\n', '')] + ) + self.cli_set( + pki_base + + ['certificate', cert_name, 'private', 'key', key_data.replace('\n', '')] + ) self.cli_set(base_path + ['certificates', 'dh-params', dh_name]) # verify() - dh-params do not exist (yet) with self.assertRaises(ConfigSessionError): self.cli_commit() - self.cli_set(pki_base + ['dh', dh_name, 'parameters', dh_1024.replace('\n','')]) + self.cli_set( + pki_base + ['dh', dh_name, 'parameters', dh_1024.replace('\n', '')] + ) # verify() - dh-param minimum length is 2048 bit with self.assertRaises(ConfigSessionError): self.cli_commit() - self.cli_set(pki_base + ['dh', dh_name, 'parameters', dh_2048.replace('\n','')]) + self.cli_set( + pki_base + ['dh', dh_name, 'parameters', dh_2048.replace('\n', '')] + ) self.cli_commit() self.assertTrue(process_named_running(PROCESS_NAME)) @@ -162,7 +173,7 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): nginx_config = read_file('/etc/nginx/sites-enabled/default') self.assertIn(f'listen {address}:{port} ssl;', nginx_config) - self.assertIn(f'ssl_protocols TLSv1.2 TLSv1.3;', nginx_config) # default + self.assertIn('ssl_protocols TLSv1.2 TLSv1.3;', nginx_config) # default url = f'https://{address}/retrieve' payload = {'data': '{"op": "showConfig", "path": []}', 'key': f'{key}'} @@ -182,11 +193,16 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): self.assertEqual(r.status_code, 401) # Check path config - payload = {'data': '{"op": "showConfig", "path": ["system", "login"]}', 'key': f'{key}'} + payload = { + 'data': '{"op": "showConfig", "path": ["system", "login"]}', + 'key': f'{key}', + } r = request('POST', url, verify=False, headers=headers, data=payload) response = r.json() vyos_user_exists = 'vyos' in response.get('data', {}).get('user', {}) - self.assertTrue(vyos_user_exists, "The 'vyos' user does not exist in the response.") + self.assertTrue( + vyos_user_exists, "The 'vyos' user does not exist in the response." + ) # GraphQL auth test: a missing key will return status code 400, as # 'key' is a non-nullable field in the schema; an incorrect key is @@ -210,7 +226,13 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): }} """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query_valid_key}) + r = request( + 'POST', + graphql_url, + verify=False, + headers=headers, + json={'query': query_valid_key}, + ) success = r.json()['data']['SystemStatus']['success'] self.assertTrue(success) @@ -226,7 +248,13 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query_invalid_key}) + r = request( + 'POST', + graphql_url, + verify=False, + headers=headers, + json={'query': query_invalid_key}, + ) success = r.json()['data']['SystemStatus']['success'] self.assertFalse(success) @@ -242,7 +270,13 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query_no_key}) + r = request( + 'POST', + graphql_url, + verify=False, + headers=headers, + json={'query': query_no_key}, + ) success = r.json()['data']['SystemStatus']['success'] self.assertFalse(success) @@ -263,7 +297,9 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): } } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': mutation}) + r = request( + 'POST', graphql_url, verify=False, headers=headers, json={'query': mutation} + ) token = r.json()['data']['AuthToken']['data']['result']['token'] @@ -286,7 +322,9 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query}) + r = request( + 'POST', graphql_url, verify=False, headers=headers, json={'query': query} + ) success = r.json()['data']['ShowVersion']['success'] self.assertTrue(success) @@ -371,14 +409,14 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): self.cli_commit() payload_path = [ - "interfaces", - "dummy", - f"{conf_interface}", - "address", - f"{conf_address}", + 'interfaces', + 'dummy', + f'{conf_interface}', + 'address', + f'{conf_address}', ] - payload = {'data': json.dumps({"op": "set", "path": payload_path}), 'key': key} + payload = {'data': json.dumps({'op': 'set', 'path': payload_path}), 'key': key} r = request('POST', url, verify=False, headers=headers, data=payload) self.assertEqual(r.status_code, 200) @@ -508,5 +546,6 @@ class TestHTTPSService(VyOSUnitTestSHIM.TestCase): call(f'sudo rm -f {nginx_tmp_site}') call('sudo systemctl reload nginx') + if __name__ == '__main__': unittest.main(verbosity=5) diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py index 93dd0fbfb..ebf745f32 100644 --- a/src/services/api/graphql/bindings.py +++ b/src/services/api/graphql/bindings.py @@ -20,18 +20,18 @@ from ariadne import make_executable_schema from ariadne import load_schema_from_path from ariadne import snake_case_fallback_resolvers -from . graphql.queries import query -from . graphql.mutations import mutation -from . graphql.directives import directives_dict -from . graphql.errors import op_mode_error -from . graphql.auth_token_mutation import auth_token_mutation -from . libs.token_auth import init_secret +from .graphql.queries import query +from .graphql.mutations import mutation +from .graphql.directives import directives_dict +from .graphql.errors import op_mode_error +from .graphql.auth_token_mutation import auth_token_mutation +from .libs.token_auth import init_secret -from .. session import SessionState +from ..session import SessionState def generate_schema(): - state = SessionState() + state = SessionState() api_schema_dir = vyos.defaults.directories['api_schema'] if state.auth_type == 'token': @@ -39,9 +39,14 @@ def generate_schema(): type_defs = load_schema_from_path(api_schema_dir) - schema = make_executable_schema(type_defs, query, op_mode_error, - mutation, auth_token_mutation, - snake_case_fallback_resolvers, - directives=directives_dict) + schema = make_executable_schema( + type_defs, + query, + op_mode_error, + mutation, + auth_token_mutation, + snake_case_fallback_resolvers, + directives=directives_dict, + ) return schema diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py index 164960217..c74364603 100644 --- a/src/services/api/graphql/graphql/auth_token_mutation.py +++ b/src/services/api/graphql/graphql/auth_token_mutation.py @@ -19,11 +19,12 @@ from typing import Dict from ariadne import ObjectType from graphql import GraphQLResolveInfo -from .. libs.token_auth import generate_token -from .. session.session import get_user_info -from ... session import SessionState +from ..libs.token_auth import generate_token +from ..session.session import get_user_info +from ...session import SessionState + +auth_token_mutation = ObjectType('Mutation') -auth_token_mutation = ObjectType("Mutation") @auth_token_mutation.field('AuthToken') def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): @@ -35,8 +36,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): secret = getattr(state, 'secret', '') exp_interval = int(state.token_exp) - expiration = (datetime.datetime.now(tz=datetime.timezone.utc) + - datetime.timedelta(seconds=exp_interval)) + expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=exp_interval + ) res = generate_token(user, passwd, secret, expiration) try: @@ -46,18 +48,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): pass if 'token' in res: data['result'] = res - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} if 'errors' in res: - return { - "success": False, - "errors": res['errors'] - } - - return { - "success": False, - "errors": ['token generation failed'] - } + return {'success': False, 'errors': res['errors']} + + return {'success': False, 'errors': ['token generation failed']} diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py index 62031ada3..0b391c070 100644 --- a/src/services/api/graphql/graphql/mutations.py +++ b/src/services/api/graphql/graphql/mutations.py @@ -14,20 +14,23 @@ # along with this library. If not, see . from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from ... session import SessionState -from .. libs import key_auth -from .. session.session import Session -from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -mutation = ObjectType("Mutation") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +mutation = ObjectType('Mutation') + def make_mutation_resolver(mutation_name, class_name, session_func): """Dynamically generate a resolver for the mutation named in the @@ -59,10 +62,7 @@ def make_mutation_resolver(mutation_name, class_name, session_func): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -77,14 +77,8 @@ def make_mutation_resolver(mutation_name, class_name, session_func): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no # auth_type; validation and defaultValue ensure it is @@ -106,35 +100,36 @@ def make_mutation_resolver(mutation_name, class_name, session_func): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errore": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errore': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) + def make_gen_op_mutation_resolver(mutation_name): return make_mutation_resolver(mutation_name, mutation_name, 'gen_op_mutation') + def make_composite_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py index 1e9036574..9303fe909 100644 --- a/src/services/api/graphql/graphql/queries.py +++ b/src/services/api/graphql/graphql/queries.py @@ -14,20 +14,23 @@ # along with this library. If not, see . from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from ... session import SessionState -from .. libs import key_auth -from .. session.session import Session -from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -query = ObjectType("Query") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +query = ObjectType('Query') + def make_query_resolver(query_name, class_name, session_func): """Dynamically generate a resolver for the query named in the @@ -59,10 +62,7 @@ def make_query_resolver(query_name, class_name, session_func): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -77,14 +77,8 @@ def make_query_resolver(query_name, class_name, session_func): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no # auth_type; validation and defaultValue ensure it is @@ -106,35 +100,36 @@ def make_query_resolver(query_name, class_name, session_func): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errors": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errors': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) + def make_gen_op_query_resolver(query_name): return make_query_resolver(query_name, query_name, 'gen_op_query') + def make_composite_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py index 9e49a1203..ffd7f32b2 100644 --- a/src/services/api/graphql/libs/key_auth.py +++ b/src/services/api/graphql/libs/key_auth.py @@ -14,7 +14,8 @@ # along with this library. If not, see . -from ... session import SessionState +from ...session import SessionState + def check_auth(key_list, key): if not key_list: @@ -25,6 +26,7 @@ def check_auth(key_list, key): key_id = k['id'] return key_id + def auth_required(key): state = SessionState() api_keys = None diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py index 2d772e035..4f743a096 100644 --- a/src/services/api/graphql/libs/token_auth.py +++ b/src/services/api/graphql/libs/token_auth.py @@ -19,7 +19,7 @@ import uuid import pam from secrets import token_hex -from ... session import SessionState +from ...session import SessionState def _check_passwd_pam(username: str, passwd: str) -> bool: @@ -48,13 +48,13 @@ def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict: payload_data = {'iss': user, 'sub': user_id, 'exp': exp} secret = getattr(state, 'secret', None) if secret is None: - return {"errors": ['missing secret']} - token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256") + return {'errors': ['missing secret']} + token = jwt.encode(payload=payload_data, key=secret, algorithm='HS256') users |= {user_id: user} return {'token': token} else: - return {"errors": ['failed pam authentication']} + return {'errors': ['failed pam authentication']} def get_user_context(request): @@ -70,7 +70,7 @@ def get_user_context(request): try: secret = getattr(state, 'secret', None) - payload = jwt.decode(token, secret, algorithms=["HS256"]) + payload = jwt.decode(token, secret, algorithms=['HS256']) user_id: str = payload.get('sub') if user_id is None: return context diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py index f02380cdc..ed3ee1e8c 100644 --- a/src/services/api/graphql/routers.py +++ b/src/services/api/graphql/routers.py @@ -26,14 +26,15 @@ if typing.TYPE_CHECKING: from fastapi import FastAPI -def graphql_init(app: "FastAPI"): - from .. session import SessionState +def graphql_init(app: 'FastAPI'): + from ..session import SessionState from .libs.token_auth import get_user_context state = SessionState() # import after initializaion of state from .bindings import generate_schema + schema = generate_schema() in_spec = state.introspection @@ -44,21 +45,33 @@ def graphql_init(app: "FastAPI"): if state.origins: origins = state.origins - app.add_route('/graphql', CORSMiddleware(GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec), - allow_origins=origins, - allow_methods=("GET", "POST", "OPTIONS"), - allow_headers=("Authorization",))) + app.add_route( + '/graphql', + CORSMiddleware( + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + allow_origins=origins, + allow_methods=('GET', 'POST', 'OPTIONS'), + allow_headers=('Authorization',), + ), + ) else: - app.add_route('/graphql', GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec)) + app.add_route( + '/graphql', + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + ) -def graphql_clear(app: "FastAPI"): +def graphql_clear(app: 'FastAPI'): for r in app.routes: if r.path == '/graphql': app.routes.remove(r) diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py index 6e2875f3c..619534f43 100644 --- a/src/services/api/graphql/session/session.py +++ b/src/services/api/graphql/session/session.py @@ -28,34 +28,45 @@ from api.graphql.libs.op_mode import normalize_output op_mode_include_file = os.path.join(directories['data'], 'op-mode-standardized.json') -def get_config_dict(path=[], effective=False, key_mangling=None, - get_first_key=False, no_multi_convert=False, - no_tag_node_value_mangle=False): + +def get_config_dict( + path=[], + effective=False, + key_mangling=None, + get_first_key=False, + no_multi_convert=False, + no_tag_node_value_mangle=False, +): config = Config() - return config.get_config_dict(path=path, effective=effective, - key_mangling=key_mangling, - get_first_key=get_first_key, - no_multi_convert=no_multi_convert, - no_tag_node_value_mangle=no_tag_node_value_mangle) + return config.get_config_dict( + path=path, + effective=effective, + key_mangling=key_mangling, + get_first_key=get_first_key, + no_multi_convert=no_multi_convert, + no_tag_node_value_mangle=no_tag_node_value_mangle, + ) + def get_user_info(user): user_info = {} - info = get_config_dict(['system', 'login', 'user', user], - get_first_key=True) + info = get_config_dict(['system', 'login', 'user', user], get_first_key=True) if not info: - raise ValueError("No such user") + raise ValueError('No such user') user_info['user'] = user user_info['full_name'] = info.get('full-name', '') return user_info + class Session: """ Wrapper for calling configsession functions based on GraphQL requests. Non-nullable fields in the respective schema allow avoiding a key check in 'data'. """ + def __init__(self, session, data): self._session = session self._data = data diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py index 034e3fcdb..d65d6e1ec 100644 --- a/src/services/api/rest/models.py +++ b/src/services/api/rest/models.py @@ -31,75 +31,88 @@ from fastapi.responses import HTMLResponse def error(code, msg): - resp = {"success": False, "error": msg, "data": None} + resp = {'success': False, 'error': msg, 'data': None} resp = json.dumps(resp) return HTMLResponse(resp, status_code=code) + def success(data): - resp = {"success": True, "data": data, "error": None} + resp = {'success': True, 'data': data, 'error': None} resp = json.dumps(resp) return HTMLResponse(resp) + # Pydantic models for validation # Pydantic will cast when possible, so use StrictStr validators added as # needed for additional constraints # json_schema_extra adds anotations to OpenAPI to add examples + class ApiModel(BaseModel): key: StrictStr + class BasePathModel(BaseModel): op: StrictStr path: List[StrictStr] - @field_validator("path") + @field_validator('path') @classmethod def check_non_empty(cls, path: str) -> str: if not len(path) > 0: raise ValueError('path must be non-empty') return path + class BaseConfigureModel(BasePathModel): value: StrictStr = None + class ConfigureModel(ApiModel, BaseConfigureModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "set | delete | comment", - "path": ["config", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'set | delete | comment', + 'path': ['config', 'mode', 'path'], } } + class ConfigureListModel(ApiModel): commands: List[BaseConfigureModel] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "commands": "list of commands", + 'example': { + 'key': 'id_key', + 'commands': 'list of commands', } } + class BaseConfigSectionModel(BasePathModel): section: Dict + class ConfigSectionModel(ApiModel, BaseConfigSectionModel): pass + class ConfigSectionListModel(ApiModel): commands: List[BaseConfigSectionModel] + class BaseConfigSectionTreeModel(BaseModel): op: StrictStr mask: Dict config: Dict + class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): pass + class RetrieveModel(ApiModel): op: StrictStr path: List[StrictStr] @@ -107,34 +120,34 @@ class RetrieveModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "returnValue | returnValues | exists | showConfig", - "path": ["config", "mode", "path"], - "configFormat": "json (default) | json_ast | raw", - + 'example': { + 'key': 'id_key', + 'op': 'returnValue | returnValues | exists | showConfig', + 'path': ['config', 'mode', 'path'], + 'configFormat': 'json (default) | json_ast | raw', } } + class ConfigFileModel(ApiModel): op: StrictStr file: StrictStr = None class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "save | load", - "file": "filename", + 'example': { + 'key': 'id_key', + 'op': 'save | load', + 'file': 'filename', } } class ImageOp(str, Enum): - add = "add" - delete = "delete" - show = "show" - set_default = "set_default" + add = 'add' + delete = 'delete' + show = 'show' + set_default = 'set_default' class ImageModel(ApiModel): @@ -146,23 +159,24 @@ class ImageModel(ApiModel): def check_data(self) -> Self: if self.op == 'add': if not self.url: - raise ValueError("Missing required field \"url\"") + raise ValueError('Missing required field "url"') elif self.op in ['delete', 'set_default']: if not self.name: - raise ValueError("Missing required field \"name\"") + raise ValueError('Missing required field "name"') return self class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show | set_default", - "url": "imagelocation", - "name": "imagename", + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show | set_default', + 'url': 'imagelocation', + 'name': 'imagename', } } + class ImportPkiModel(ApiModel): op: StrictStr path: List[StrictStr] @@ -170,11 +184,11 @@ class ImportPkiModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "import_pki", - "path": ["op", "mode", "path"], - "passphrase": "passphrase", + 'example': { + 'key': 'id_key', + 'op': 'import_pki', + 'path': ['op', 'mode', 'path'], + 'passphrase': 'passphrase', } } @@ -185,75 +199,80 @@ class ContainerImageModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show", - "name": "imagename", + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show', + 'name': 'imagename', } } + class GenerateModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "generate", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'generate', + 'path': ['op', 'mode', 'path'], } } + class ShowModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "show", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'show', + 'path': ['op', 'mode', 'path'], } } + class RebootModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "reboot", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'reboot', + 'path': ['op', 'mode', 'path'], } } + class ResetModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "reset", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'reset', + 'path': ['op', 'mode', 'path'], } } + class PoweroffModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "poweroff", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'poweroff', + 'path': ['op', 'mode', 'path'], } } @@ -263,14 +282,16 @@ class Success(BaseModel): data: Union[str, bool, Dict] error: str + class Error(BaseModel): success: bool = False data: Union[str, bool, Dict] error: str + responses = { 200: {'model': Success}, 400: {'model': Error}, 422: {'model': Error, 'description': 'Validation Error'}, - 500: {'model': Error} + 500: {'model': Error}, } diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py index 38b10ef7d..da981d5bf 100644 --- a/src/services/api/rest/routers.py +++ b/src/services/api/rest/routers.py @@ -18,10 +18,12 @@ # pylint: disable=wildcard-import,unused-wildcard-import # pylint: disable=broad-exception-caught +import json import copy import logging import traceback from threading import Lock +from typing import Union from typing import Callable from typing import TYPE_CHECKING @@ -43,8 +45,30 @@ from vyos.configtree import ConfigTree from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSessionError -from .. session import SessionState -from . models import * +from ..session import SessionState +from .models import success +from .models import error +from .models import responses +from .models import ApiModel +from .models import ConfigureModel +from .models import ConfigureListModel +from .models import ConfigSectionModel +from .models import ConfigSectionListModel +from .models import ConfigSectionTreeModel +from .models import BaseConfigSectionTreeModel +from .models import BaseConfigureModel +from .models import BaseConfigSectionModel +from .models import RetrieveModel +from .models import ConfigFileModel +from .models import ImageModel +from .models import ContainerImageModel +from .models import GenerateModel +from .models import ShowModel +from .models import RebootModel +from .models import ResetModel +from .models import ImportPkiModel +from .models import PoweroffModel + if TYPE_CHECKING: from fastapi import FastAPI @@ -69,7 +93,7 @@ def auth_required(data: ApiModel): api_keys = session.keys key_id = check_auth(api_keys, key) if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") + raise HTTPException(status_code=401, detail='Valid API key is required') session.id = key_id @@ -79,10 +103,12 @@ def auth_required(data: ApiModel): # validation by FastAPI/Pydantic, as is used for application/json requests class MultipartRequest(Request): """Override Request class to convert form request to json""" + # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-branches,too-many-statements _form_err = () + @property def form_err(self): return self._form_err @@ -104,16 +130,16 @@ class MultipartRequest(Request): return self._headers async def _get_form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000 - ) -> FormData: + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: if self._form is None: assert ( parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") + ), 'The `python-multipart` library must be installed to use form parsing.' + content_type_header = self.orig_headers.get('Content-Type') content_type: bytes content_type, _ = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": + if content_type == b'multipart/form-data': try: multipart_parser = MultiPartParser( self.orig_headers, @@ -123,10 +149,10 @@ class MultipartRequest(Request): ) self._form = await multipart_parser.parse() except MultiPartException as exc: - if "app" in self.scope: + if 'app' in self.scope: raise HTTPException(status_code=400, detail=exc.message) raise exc - elif content_type == b"application/x-www-form-urlencoded": + elif content_type == b'application/x-www-form-urlencoded': form_parser = FormParser(self.orig_headers, self.stream()) self._form = await form_parser.parse() else: @@ -134,7 +160,7 @@ class MultipartRequest(Request): return self._form async def body(self) -> bytes: - if not hasattr(self, "_body"): + if not hasattr(self, '_body'): forms = {} merge = {} body = await super().body() @@ -143,12 +169,12 @@ class MultipartRequest(Request): form_data = await self.form() if form_data: endpoint = self.url.path - LOG.debug("processing form data") + LOG.debug('processing form data') for k, v in form_data.multi_items(): forms[k] = v if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") + self.form_err = (422, 'Non-empty data field is required') return self._body try: tmp = json.loads(forms['data']) @@ -168,37 +194,57 @@ class MultipartRequest(Request): for c in cmds: if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") + self.form_err = ( + 400, + f"Malformed command '{c}': any command must be JSON of dict", + ) return self._body if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'op' field", + ) + if endpoint not in ( + '/config-file', + '/container-image', + '/image', + '/configure-section', + ): if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'path' field", + ) elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' field must be a list", + ) elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") + self.form_err = ( + 400, + f"Malformed command '{0}': 'path' field must be a list of strings", + ) if endpoint in ('/configure'): if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' list must be non-empty", + ) if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") + self.form_err = ( + 400, + f"Malformed command '{c}': 'value' field must be a string", + ) if endpoint in ('/configure-section'): if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'section' or 'config' field", + ) if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") + self.form_err = (401, 'Valid API key is required') if 'key' in forms and 'key' not in merge: merge['key'] = forms['key'] @@ -232,12 +278,15 @@ class MultipartRoute(APIRoute): return custom_route_handler -router = APIRouter(route_class=MultipartRoute, - responses={**responses}, - dependencies=[Depends(auth_required)]) +router = APIRouter( + route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)], +) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" +self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background' + def call_commit(s: SessionState): try: @@ -245,15 +294,22 @@ def call_commit(s: SessionState): except ConfigSessionError as e: s.session.discard() if s.debug: - LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}') else: - LOG.warning(f"ConfigSessionError: {e}") - - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - _request: Request, background_tasks: BackgroundTasks): + LOG.warning(f'ConfigSessionError: {e}') + + +def _configure_op( + data: Union[ + ConfigureModel, + ConfigureListModel, + ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel, + ], + _request: Request, + background_tasks: BackgroundTasks, +): # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements # pylint: disable=consider-using-with @@ -287,10 +343,10 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, if c.value: value = c.value else: - value = "" + value = '' # For vyos.configsession calls that have no separate value arguments, # and for type checking too - cfg_path = " ".join(path + [value]).strip() + cfg_path = ' '.join(path + [value]).strip() elif isinstance(c, BaseConfigSectionModel): section = c.section @@ -304,7 +360,9 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.set(path, value=value) elif op == 'delete': if state.strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") + raise ConfigSessionError( + f'Cannot delete [{cfg_path}]: path/value does not exist' + ) session.delete(path, value=value) elif op == 'comment': session.comment(path, value=value) @@ -343,7 +401,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.discard() status = 400 if state.debug: - LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}') error_msg = str(e) except Exception: session.discard() @@ -351,7 +409,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, status = 500 # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." + error_msg = 'An internal error occured. Check the logs for details.' finally: lock.release() @@ -371,21 +429,24 @@ def create_path_import_pki_no_prompt(path): @router.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): +def configure_op( + data: Union[ConfigureModel, ConfigureListModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) @router.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): +def configure_section_op( + data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) -@router.post("/retrieve") +@router.post('/retrieve') async def retrieve_op(data: RetrieveModel): state = SessionState() session = state.session @@ -393,7 +454,7 @@ async def retrieve_op(data: RetrieveModel): config = Config(session_env=env) op = data.op - path = " ".join(data.path) + path = ' '.join(data.path) try: if op == 'returnValue': @@ -424,7 +485,7 @@ async def retrieve_op(data: RetrieveModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -448,7 +509,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): if data.file: path = data.file else: - return error(400, "Missing required field \"file\"") + return error(400, 'Missing required field "file"') session.migrate_and_load_config(path) @@ -466,7 +527,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(msg) @@ -484,14 +545,14 @@ def image_op(data: ImageModel): elif op == 'delete': res = session.remove_image(data.name) elif op == 'show': - res = session.show(["system", "image"]) + res = session.show(['system', 'image']) elif op == 'set_default': res = session.set_default_image(data.name) except ConfigSessionError as e: return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -508,13 +569,13 @@ def container_image_op(data: ContainerImageModel): if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.add_container_image(name) elif op == 'delete': if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.delete_container_image(name) elif op == 'show': res = session.show_container_image() @@ -524,7 +585,7 @@ def container_image_op(data: ContainerImageModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -546,7 +607,7 @@ def generate_op(data: GenerateModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -568,7 +629,7 @@ def show_op(data: ShowModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -590,7 +651,7 @@ def reboot_op(data: RebootModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -612,7 +673,7 @@ def reset_op(data: ResetModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -652,7 +713,7 @@ def import_pki(data: ImportPkiModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') finally: lock.release() @@ -676,18 +737,18 @@ def poweroff_op(data: PoweroffModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) -def rest_init(app: "FastAPI"): +def rest_init(app: 'FastAPI'): if all(r in app.routes for r in router.routes): return app.include_router(router) -def rest_clear(app: "FastAPI"): +def rest_clear(app: 'FastAPI'): for r in router.routes: if r in app.routes: app.routes.remove(r) diff --git a/src/services/api/session.py b/src/services/api/session.py index dcdc7246c..ad3ef660c 100644 --- a/src/services/api/session.py +++ b/src/services/api/session.py @@ -13,6 +13,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . + class SessionState: # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-instance-attributes,too-few-public-methods -- cgit v1.2.3