diff options
Diffstat (limited to 'src/services/api/rest')
-rw-r--r-- | src/services/api/rest/models.py | 143 | ||||
-rw-r--r-- | src/services/api/rest/routers.py | 209 |
2 files changed, 217 insertions, 135 deletions
diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py index 034e3fcdb..d65d6e1ec 100644 --- a/src/services/api/rest/models.py +++ b/src/services/api/rest/models.py @@ -31,75 +31,88 @@ from fastapi.responses import HTMLResponse def error(code, msg): - resp = {"success": False, "error": msg, "data": None} + resp = {'success': False, 'error': msg, 'data': None} resp = json.dumps(resp) return HTMLResponse(resp, status_code=code) + def success(data): - resp = {"success": True, "data": data, "error": None} + resp = {'success': True, 'data': data, 'error': None} resp = json.dumps(resp) return HTMLResponse(resp) + # Pydantic models for validation # Pydantic will cast when possible, so use StrictStr validators added as # needed for additional constraints # json_schema_extra adds anotations to OpenAPI to add examples + class ApiModel(BaseModel): key: StrictStr + class BasePathModel(BaseModel): op: StrictStr path: List[StrictStr] - @field_validator("path") + @field_validator('path') @classmethod def check_non_empty(cls, path: str) -> str: if not len(path) > 0: raise ValueError('path must be non-empty') return path + class BaseConfigureModel(BasePathModel): value: StrictStr = None + class ConfigureModel(ApiModel, BaseConfigureModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "set | delete | comment", - "path": ["config", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'set | delete | comment', + 'path': ['config', 'mode', 'path'], } } + class ConfigureListModel(ApiModel): commands: List[BaseConfigureModel] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "commands": "list of commands", + 'example': { + 'key': 'id_key', + 'commands': 'list of commands', } } + class BaseConfigSectionModel(BasePathModel): section: Dict + class ConfigSectionModel(ApiModel, BaseConfigSectionModel): pass + class ConfigSectionListModel(ApiModel): commands: List[BaseConfigSectionModel] + class BaseConfigSectionTreeModel(BaseModel): op: StrictStr mask: Dict config: Dict + class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): pass + class RetrieveModel(ApiModel): op: StrictStr path: List[StrictStr] @@ -107,34 +120,34 @@ class RetrieveModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "returnValue | returnValues | exists | showConfig", - "path": ["config", "mode", "path"], - "configFormat": "json (default) | json_ast | raw", - + 'example': { + 'key': 'id_key', + 'op': 'returnValue | returnValues | exists | showConfig', + 'path': ['config', 'mode', 'path'], + 'configFormat': 'json (default) | json_ast | raw', } } + class ConfigFileModel(ApiModel): op: StrictStr file: StrictStr = None class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "save | load", - "file": "filename", + 'example': { + 'key': 'id_key', + 'op': 'save | load', + 'file': 'filename', } } class ImageOp(str, Enum): - add = "add" - delete = "delete" - show = "show" - set_default = "set_default" + add = 'add' + delete = 'delete' + show = 'show' + set_default = 'set_default' class ImageModel(ApiModel): @@ -146,23 +159,24 @@ class ImageModel(ApiModel): def check_data(self) -> Self: if self.op == 'add': if not self.url: - raise ValueError("Missing required field \"url\"") + raise ValueError('Missing required field "url"') elif self.op in ['delete', 'set_default']: if not self.name: - raise ValueError("Missing required field \"name\"") + raise ValueError('Missing required field "name"') return self class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show | set_default", - "url": "imagelocation", - "name": "imagename", + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show | set_default', + 'url': 'imagelocation', + 'name': 'imagename', } } + class ImportPkiModel(ApiModel): op: StrictStr path: List[StrictStr] @@ -170,11 +184,11 @@ class ImportPkiModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "import_pki", - "path": ["op", "mode", "path"], - "passphrase": "passphrase", + 'example': { + 'key': 'id_key', + 'op': 'import_pki', + 'path': ['op', 'mode', 'path'], + 'passphrase': 'passphrase', } } @@ -185,75 +199,80 @@ class ContainerImageModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show", - "name": "imagename", + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show', + 'name': 'imagename', } } + class GenerateModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "generate", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'generate', + 'path': ['op', 'mode', 'path'], } } + class ShowModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "show", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'show', + 'path': ['op', 'mode', 'path'], } } + class RebootModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "reboot", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'reboot', + 'path': ['op', 'mode', 'path'], } } + class ResetModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "reset", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'reset', + 'path': ['op', 'mode', 'path'], } } + class PoweroffModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "poweroff", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'poweroff', + 'path': ['op', 'mode', 'path'], } } @@ -263,14 +282,16 @@ class Success(BaseModel): data: Union[str, bool, Dict] error: str + class Error(BaseModel): success: bool = False data: Union[str, bool, Dict] error: str + responses = { 200: {'model': Success}, 400: {'model': Error}, 422: {'model': Error, 'description': 'Validation Error'}, - 500: {'model': Error} + 500: {'model': Error}, } diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py index 38b10ef7d..da981d5bf 100644 --- a/src/services/api/rest/routers.py +++ b/src/services/api/rest/routers.py @@ -18,10 +18,12 @@ # pylint: disable=wildcard-import,unused-wildcard-import # pylint: disable=broad-exception-caught +import json import copy import logging import traceback from threading import Lock +from typing import Union from typing import Callable from typing import TYPE_CHECKING @@ -43,8 +45,30 @@ from vyos.configtree import ConfigTree from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSessionError -from .. session import SessionState -from . models import * +from ..session import SessionState +from .models import success +from .models import error +from .models import responses +from .models import ApiModel +from .models import ConfigureModel +from .models import ConfigureListModel +from .models import ConfigSectionModel +from .models import ConfigSectionListModel +from .models import ConfigSectionTreeModel +from .models import BaseConfigSectionTreeModel +from .models import BaseConfigureModel +from .models import BaseConfigSectionModel +from .models import RetrieveModel +from .models import ConfigFileModel +from .models import ImageModel +from .models import ContainerImageModel +from .models import GenerateModel +from .models import ShowModel +from .models import RebootModel +from .models import ResetModel +from .models import ImportPkiModel +from .models import PoweroffModel + if TYPE_CHECKING: from fastapi import FastAPI @@ -69,7 +93,7 @@ def auth_required(data: ApiModel): api_keys = session.keys key_id = check_auth(api_keys, key) if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") + raise HTTPException(status_code=401, detail='Valid API key is required') session.id = key_id @@ -79,10 +103,12 @@ def auth_required(data: ApiModel): # validation by FastAPI/Pydantic, as is used for application/json requests class MultipartRequest(Request): """Override Request class to convert form request to json""" + # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-branches,too-many-statements _form_err = () + @property def form_err(self): return self._form_err @@ -104,16 +130,16 @@ class MultipartRequest(Request): return self._headers async def _get_form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000 - ) -> FormData: + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: if self._form is None: assert ( parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") + ), 'The `python-multipart` library must be installed to use form parsing.' + content_type_header = self.orig_headers.get('Content-Type') content_type: bytes content_type, _ = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": + if content_type == b'multipart/form-data': try: multipart_parser = MultiPartParser( self.orig_headers, @@ -123,10 +149,10 @@ class MultipartRequest(Request): ) self._form = await multipart_parser.parse() except MultiPartException as exc: - if "app" in self.scope: + if 'app' in self.scope: raise HTTPException(status_code=400, detail=exc.message) raise exc - elif content_type == b"application/x-www-form-urlencoded": + elif content_type == b'application/x-www-form-urlencoded': form_parser = FormParser(self.orig_headers, self.stream()) self._form = await form_parser.parse() else: @@ -134,7 +160,7 @@ class MultipartRequest(Request): return self._form async def body(self) -> bytes: - if not hasattr(self, "_body"): + if not hasattr(self, '_body'): forms = {} merge = {} body = await super().body() @@ -143,12 +169,12 @@ class MultipartRequest(Request): form_data = await self.form() if form_data: endpoint = self.url.path - LOG.debug("processing form data") + LOG.debug('processing form data') for k, v in form_data.multi_items(): forms[k] = v if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") + self.form_err = (422, 'Non-empty data field is required') return self._body try: tmp = json.loads(forms['data']) @@ -168,37 +194,57 @@ class MultipartRequest(Request): for c in cmds: if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") + self.form_err = ( + 400, + f"Malformed command '{c}': any command must be JSON of dict", + ) return self._body if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'op' field", + ) + if endpoint not in ( + '/config-file', + '/container-image', + '/image', + '/configure-section', + ): if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'path' field", + ) elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' field must be a list", + ) elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") + self.form_err = ( + 400, + f"Malformed command '{0}': 'path' field must be a list of strings", + ) if endpoint in ('/configure'): if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' list must be non-empty", + ) if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") + self.form_err = ( + 400, + f"Malformed command '{c}': 'value' field must be a string", + ) if endpoint in ('/configure-section'): if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'section' or 'config' field", + ) if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") + self.form_err = (401, 'Valid API key is required') if 'key' in forms and 'key' not in merge: merge['key'] = forms['key'] @@ -232,12 +278,15 @@ class MultipartRoute(APIRoute): return custom_route_handler -router = APIRouter(route_class=MultipartRoute, - responses={**responses}, - dependencies=[Depends(auth_required)]) +router = APIRouter( + route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)], +) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" +self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background' + def call_commit(s: SessionState): try: @@ -245,15 +294,22 @@ def call_commit(s: SessionState): except ConfigSessionError as e: s.session.discard() if s.debug: - LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}') else: - LOG.warning(f"ConfigSessionError: {e}") - - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - _request: Request, background_tasks: BackgroundTasks): + LOG.warning(f'ConfigSessionError: {e}') + + +def _configure_op( + data: Union[ + ConfigureModel, + ConfigureListModel, + ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel, + ], + _request: Request, + background_tasks: BackgroundTasks, +): # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements # pylint: disable=consider-using-with @@ -287,10 +343,10 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, if c.value: value = c.value else: - value = "" + value = '' # For vyos.configsession calls that have no separate value arguments, # and for type checking too - cfg_path = " ".join(path + [value]).strip() + cfg_path = ' '.join(path + [value]).strip() elif isinstance(c, BaseConfigSectionModel): section = c.section @@ -304,7 +360,9 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.set(path, value=value) elif op == 'delete': if state.strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") + raise ConfigSessionError( + f'Cannot delete [{cfg_path}]: path/value does not exist' + ) session.delete(path, value=value) elif op == 'comment': session.comment(path, value=value) @@ -343,7 +401,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.discard() status = 400 if state.debug: - LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}') error_msg = str(e) except Exception: session.discard() @@ -351,7 +409,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, status = 500 # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." + error_msg = 'An internal error occured. Check the logs for details.' finally: lock.release() @@ -371,21 +429,24 @@ def create_path_import_pki_no_prompt(path): @router.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): +def configure_op( + data: Union[ConfigureModel, ConfigureListModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) @router.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): +def configure_section_op( + data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) -@router.post("/retrieve") +@router.post('/retrieve') async def retrieve_op(data: RetrieveModel): state = SessionState() session = state.session @@ -393,7 +454,7 @@ async def retrieve_op(data: RetrieveModel): config = Config(session_env=env) op = data.op - path = " ".join(data.path) + path = ' '.join(data.path) try: if op == 'returnValue': @@ -424,7 +485,7 @@ async def retrieve_op(data: RetrieveModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -448,7 +509,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): if data.file: path = data.file else: - return error(400, "Missing required field \"file\"") + return error(400, 'Missing required field "file"') session.migrate_and_load_config(path) @@ -466,7 +527,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(msg) @@ -484,14 +545,14 @@ def image_op(data: ImageModel): elif op == 'delete': res = session.remove_image(data.name) elif op == 'show': - res = session.show(["system", "image"]) + res = session.show(['system', 'image']) elif op == 'set_default': res = session.set_default_image(data.name) except ConfigSessionError as e: return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -508,13 +569,13 @@ def container_image_op(data: ContainerImageModel): if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.add_container_image(name) elif op == 'delete': if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.delete_container_image(name) elif op == 'show': res = session.show_container_image() @@ -524,7 +585,7 @@ def container_image_op(data: ContainerImageModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -546,7 +607,7 @@ def generate_op(data: GenerateModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -568,7 +629,7 @@ def show_op(data: ShowModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -590,7 +651,7 @@ def reboot_op(data: RebootModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -612,7 +673,7 @@ def reset_op(data: ResetModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -652,7 +713,7 @@ def import_pki(data: ImportPkiModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') finally: lock.release() @@ -676,18 +737,18 @@ def poweroff_op(data: PoweroffModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) -def rest_init(app: "FastAPI"): +def rest_init(app: 'FastAPI'): if all(r in app.routes for r in router.routes): return app.include_router(router) -def rest_clear(app: "FastAPI"): +def rest_clear(app: 'FastAPI'): for r in router.routes: if r in app.routes: app.routes.remove(r) |