diff options
Diffstat (limited to 'src/services/api/rest/routers.py')
-rw-r--r-- | src/services/api/rest/routers.py | 209 |
1 files changed, 135 insertions, 74 deletions
diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py index 38b10ef7d..da981d5bf 100644 --- a/src/services/api/rest/routers.py +++ b/src/services/api/rest/routers.py @@ -18,10 +18,12 @@ # pylint: disable=wildcard-import,unused-wildcard-import # pylint: disable=broad-exception-caught +import json import copy import logging import traceback from threading import Lock +from typing import Union from typing import Callable from typing import TYPE_CHECKING @@ -43,8 +45,30 @@ from vyos.configtree import ConfigTree from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSessionError -from .. session import SessionState -from . models import * +from ..session import SessionState +from .models import success +from .models import error +from .models import responses +from .models import ApiModel +from .models import ConfigureModel +from .models import ConfigureListModel +from .models import ConfigSectionModel +from .models import ConfigSectionListModel +from .models import ConfigSectionTreeModel +from .models import BaseConfigSectionTreeModel +from .models import BaseConfigureModel +from .models import BaseConfigSectionModel +from .models import RetrieveModel +from .models import ConfigFileModel +from .models import ImageModel +from .models import ContainerImageModel +from .models import GenerateModel +from .models import ShowModel +from .models import RebootModel +from .models import ResetModel +from .models import ImportPkiModel +from .models import PoweroffModel + if TYPE_CHECKING: from fastapi import FastAPI @@ -69,7 +93,7 @@ def auth_required(data: ApiModel): api_keys = session.keys key_id = check_auth(api_keys, key) if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") + raise HTTPException(status_code=401, detail='Valid API key is required') session.id = key_id @@ -79,10 +103,12 @@ def auth_required(data: ApiModel): # validation by FastAPI/Pydantic, as is used for application/json requests class MultipartRequest(Request): """Override Request class to convert form request to json""" + # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-branches,too-many-statements _form_err = () + @property def form_err(self): return self._form_err @@ -104,16 +130,16 @@ class MultipartRequest(Request): return self._headers async def _get_form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000 - ) -> FormData: + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: if self._form is None: assert ( parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") + ), 'The `python-multipart` library must be installed to use form parsing.' + content_type_header = self.orig_headers.get('Content-Type') content_type: bytes content_type, _ = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": + if content_type == b'multipart/form-data': try: multipart_parser = MultiPartParser( self.orig_headers, @@ -123,10 +149,10 @@ class MultipartRequest(Request): ) self._form = await multipart_parser.parse() except MultiPartException as exc: - if "app" in self.scope: + if 'app' in self.scope: raise HTTPException(status_code=400, detail=exc.message) raise exc - elif content_type == b"application/x-www-form-urlencoded": + elif content_type == b'application/x-www-form-urlencoded': form_parser = FormParser(self.orig_headers, self.stream()) self._form = await form_parser.parse() else: @@ -134,7 +160,7 @@ class MultipartRequest(Request): return self._form async def body(self) -> bytes: - if not hasattr(self, "_body"): + if not hasattr(self, '_body'): forms = {} merge = {} body = await super().body() @@ -143,12 +169,12 @@ class MultipartRequest(Request): form_data = await self.form() if form_data: endpoint = self.url.path - LOG.debug("processing form data") + LOG.debug('processing form data') for k, v in form_data.multi_items(): forms[k] = v if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") + self.form_err = (422, 'Non-empty data field is required') return self._body try: tmp = json.loads(forms['data']) @@ -168,37 +194,57 @@ class MultipartRequest(Request): for c in cmds: if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") + self.form_err = ( + 400, + f"Malformed command '{c}': any command must be JSON of dict", + ) return self._body if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'op' field", + ) + if endpoint not in ( + '/config-file', + '/container-image', + '/image', + '/configure-section', + ): if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'path' field", + ) elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' field must be a list", + ) elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") + self.form_err = ( + 400, + f"Malformed command '{0}': 'path' field must be a list of strings", + ) if endpoint in ('/configure'): if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' list must be non-empty", + ) if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") + self.form_err = ( + 400, + f"Malformed command '{c}': 'value' field must be a string", + ) if endpoint in ('/configure-section'): if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'section' or 'config' field", + ) if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") + self.form_err = (401, 'Valid API key is required') if 'key' in forms and 'key' not in merge: merge['key'] = forms['key'] @@ -232,12 +278,15 @@ class MultipartRoute(APIRoute): return custom_route_handler -router = APIRouter(route_class=MultipartRoute, - responses={**responses}, - dependencies=[Depends(auth_required)]) +router = APIRouter( + route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)], +) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" +self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background' + def call_commit(s: SessionState): try: @@ -245,15 +294,22 @@ def call_commit(s: SessionState): except ConfigSessionError as e: s.session.discard() if s.debug: - LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}') else: - LOG.warning(f"ConfigSessionError: {e}") - - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - _request: Request, background_tasks: BackgroundTasks): + LOG.warning(f'ConfigSessionError: {e}') + + +def _configure_op( + data: Union[ + ConfigureModel, + ConfigureListModel, + ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel, + ], + _request: Request, + background_tasks: BackgroundTasks, +): # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements # pylint: disable=consider-using-with @@ -287,10 +343,10 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, if c.value: value = c.value else: - value = "" + value = '' # For vyos.configsession calls that have no separate value arguments, # and for type checking too - cfg_path = " ".join(path + [value]).strip() + cfg_path = ' '.join(path + [value]).strip() elif isinstance(c, BaseConfigSectionModel): section = c.section @@ -304,7 +360,9 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.set(path, value=value) elif op == 'delete': if state.strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") + raise ConfigSessionError( + f'Cannot delete [{cfg_path}]: path/value does not exist' + ) session.delete(path, value=value) elif op == 'comment': session.comment(path, value=value) @@ -343,7 +401,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.discard() status = 400 if state.debug: - LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}') error_msg = str(e) except Exception: session.discard() @@ -351,7 +409,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, status = 500 # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." + error_msg = 'An internal error occured. Check the logs for details.' finally: lock.release() @@ -371,21 +429,24 @@ def create_path_import_pki_no_prompt(path): @router.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): +def configure_op( + data: Union[ConfigureModel, ConfigureListModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) @router.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): +def configure_section_op( + data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) -@router.post("/retrieve") +@router.post('/retrieve') async def retrieve_op(data: RetrieveModel): state = SessionState() session = state.session @@ -393,7 +454,7 @@ async def retrieve_op(data: RetrieveModel): config = Config(session_env=env) op = data.op - path = " ".join(data.path) + path = ' '.join(data.path) try: if op == 'returnValue': @@ -424,7 +485,7 @@ async def retrieve_op(data: RetrieveModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -448,7 +509,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): if data.file: path = data.file else: - return error(400, "Missing required field \"file\"") + return error(400, 'Missing required field "file"') session.migrate_and_load_config(path) @@ -466,7 +527,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(msg) @@ -484,14 +545,14 @@ def image_op(data: ImageModel): elif op == 'delete': res = session.remove_image(data.name) elif op == 'show': - res = session.show(["system", "image"]) + res = session.show(['system', 'image']) elif op == 'set_default': res = session.set_default_image(data.name) except ConfigSessionError as e: return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -508,13 +569,13 @@ def container_image_op(data: ContainerImageModel): if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.add_container_image(name) elif op == 'delete': if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.delete_container_image(name) elif op == 'show': res = session.show_container_image() @@ -524,7 +585,7 @@ def container_image_op(data: ContainerImageModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -546,7 +607,7 @@ def generate_op(data: GenerateModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -568,7 +629,7 @@ def show_op(data: ShowModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -590,7 +651,7 @@ def reboot_op(data: RebootModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -612,7 +673,7 @@ def reset_op(data: ResetModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -652,7 +713,7 @@ def import_pki(data: ImportPkiModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') finally: lock.release() @@ -676,18 +737,18 @@ def poweroff_op(data: PoweroffModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) -def rest_init(app: "FastAPI"): +def rest_init(app: 'FastAPI'): if all(r in app.routes for r in router.routes): return app.include_router(router) -def rest_clear(app: "FastAPI"): +def rest_clear(app: 'FastAPI'): for r in router.routes: if r in app.routes: app.routes.remove(r) |