summaryrefslogtreecommitdiff
path: root/src/services/api/rest/routers.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/services/api/rest/routers.py')
-rw-r--r--src/services/api/rest/routers.py209
1 files changed, 135 insertions, 74 deletions
diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py
index 38b10ef7d..da981d5bf 100644
--- a/src/services/api/rest/routers.py
+++ b/src/services/api/rest/routers.py
@@ -18,10 +18,12 @@
# pylint: disable=wildcard-import,unused-wildcard-import
# pylint: disable=broad-exception-caught
+import json
import copy
import logging
import traceback
from threading import Lock
+from typing import Union
from typing import Callable
from typing import TYPE_CHECKING
@@ -43,8 +45,30 @@ from vyos.configtree import ConfigTree
from vyos.configdiff import get_config_diff
from vyos.configsession import ConfigSessionError
-from .. session import SessionState
-from . models import *
+from ..session import SessionState
+from .models import success
+from .models import error
+from .models import responses
+from .models import ApiModel
+from .models import ConfigureModel
+from .models import ConfigureListModel
+from .models import ConfigSectionModel
+from .models import ConfigSectionListModel
+from .models import ConfigSectionTreeModel
+from .models import BaseConfigSectionTreeModel
+from .models import BaseConfigureModel
+from .models import BaseConfigSectionModel
+from .models import RetrieveModel
+from .models import ConfigFileModel
+from .models import ImageModel
+from .models import ContainerImageModel
+from .models import GenerateModel
+from .models import ShowModel
+from .models import RebootModel
+from .models import ResetModel
+from .models import ImportPkiModel
+from .models import PoweroffModel
+
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -69,7 +93,7 @@ def auth_required(data: ApiModel):
api_keys = session.keys
key_id = check_auth(api_keys, key)
if not key_id:
- raise HTTPException(status_code=401, detail="Valid API key is required")
+ raise HTTPException(status_code=401, detail='Valid API key is required')
session.id = key_id
@@ -79,10 +103,12 @@ def auth_required(data: ApiModel):
# validation by FastAPI/Pydantic, as is used for application/json requests
class MultipartRequest(Request):
"""Override Request class to convert form request to json"""
+
# pylint: disable=attribute-defined-outside-init
# pylint: disable=too-many-branches,too-many-statements
_form_err = ()
+
@property
def form_err(self):
return self._form_err
@@ -104,16 +130,16 @@ class MultipartRequest(Request):
return self._headers
async def _get_form(
- self, *, max_files: int | float = 1000, max_fields: int | float = 1000
- ) -> FormData:
+ self, *, max_files: int | float = 1000, max_fields: int | float = 1000
+ ) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
- ), "The `python-multipart` library must be installed to use form parsing."
- content_type_header = self.orig_headers.get("Content-Type")
+ ), 'The `python-multipart` library must be installed to use form parsing.'
+ content_type_header = self.orig_headers.get('Content-Type')
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
- if content_type == b"multipart/form-data":
+ if content_type == b'multipart/form-data':
try:
multipart_parser = MultiPartParser(
self.orig_headers,
@@ -123,10 +149,10 @@ class MultipartRequest(Request):
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
- if "app" in self.scope:
+ if 'app' in self.scope:
raise HTTPException(status_code=400, detail=exc.message)
raise exc
- elif content_type == b"application/x-www-form-urlencoded":
+ elif content_type == b'application/x-www-form-urlencoded':
form_parser = FormParser(self.orig_headers, self.stream())
self._form = await form_parser.parse()
else:
@@ -134,7 +160,7 @@ class MultipartRequest(Request):
return self._form
async def body(self) -> bytes:
- if not hasattr(self, "_body"):
+ if not hasattr(self, '_body'):
forms = {}
merge = {}
body = await super().body()
@@ -143,12 +169,12 @@ class MultipartRequest(Request):
form_data = await self.form()
if form_data:
endpoint = self.url.path
- LOG.debug("processing form data")
+ LOG.debug('processing form data')
for k, v in form_data.multi_items():
forms[k] = v
if 'data' not in forms:
- self.form_err = (422, "Non-empty data field is required")
+ self.form_err = (422, 'Non-empty data field is required')
return self._body
try:
tmp = json.loads(forms['data'])
@@ -168,37 +194,57 @@ class MultipartRequest(Request):
for c in cmds:
if not isinstance(c, dict):
- self.form_err = (400,
- f"Malformed command '{c}': any command must be JSON of dict")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': any command must be JSON of dict",
+ )
return self._body
if 'op' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'op' field")
- if endpoint not in ('/config-file', '/container-image',
- '/image', '/configure-section'):
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'op' field",
+ )
+ if endpoint not in (
+ '/config-file',
+ '/container-image',
+ '/image',
+ '/configure-section',
+ ):
if 'path' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'path' field")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'path' field",
+ )
elif not isinstance(c['path'], list):
- self.form_err = (400,
- f"Malformed command '{c}': 'path' field must be a list")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'path' field must be a list",
+ )
elif not all(isinstance(el, str) for el in c['path']):
- self.form_err = (400,
- f"Malformed command '{0}': 'path' field must be a list of strings")
+ self.form_err = (
+ 400,
+ f"Malformed command '{0}': 'path' field must be a list of strings",
+ )
if endpoint in ('/configure'):
if not c['path']:
- self.form_err = (400,
- f"Malformed command '{c}': 'path' list must be non-empty")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'path' list must be non-empty",
+ )
if 'value' in c and not isinstance(c['value'], str):
- self.form_err = (400,
- f"Malformed command '{c}': 'value' field must be a string")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'value' field must be a string",
+ )
if endpoint in ('/configure-section'):
if 'section' not in c and 'config' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'section' or 'config' field")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'section' or 'config' field",
+ )
if 'key' not in forms and 'key' not in merge:
- self.form_err = (401, "Valid API key is required")
+ self.form_err = (401, 'Valid API key is required')
if 'key' in forms and 'key' not in merge:
merge['key'] = forms['key']
@@ -232,12 +278,15 @@ class MultipartRoute(APIRoute):
return custom_route_handler
-router = APIRouter(route_class=MultipartRoute,
- responses={**responses},
- dependencies=[Depends(auth_required)])
+router = APIRouter(
+ route_class=MultipartRoute,
+ responses={**responses},
+ dependencies=[Depends(auth_required)],
+)
-self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background"
+self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background'
+
def call_commit(s: SessionState):
try:
@@ -245,15 +294,22 @@ def call_commit(s: SessionState):
except ConfigSessionError as e:
s.session.discard()
if s.debug:
- LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}")
+ LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}')
else:
- LOG.warning(f"ConfigSessionError: {e}")
-
-
-def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
- ConfigSectionModel, ConfigSectionListModel,
- ConfigSectionTreeModel],
- _request: Request, background_tasks: BackgroundTasks):
+ LOG.warning(f'ConfigSessionError: {e}')
+
+
+def _configure_op(
+ data: Union[
+ ConfigureModel,
+ ConfigureListModel,
+ ConfigSectionModel,
+ ConfigSectionListModel,
+ ConfigSectionTreeModel,
+ ],
+ _request: Request,
+ background_tasks: BackgroundTasks,
+):
# pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
# pylint: disable=consider-using-with
@@ -287,10 +343,10 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
if c.value:
value = c.value
else:
- value = ""
+ value = ''
# For vyos.configsession calls that have no separate value arguments,
# and for type checking too
- cfg_path = " ".join(path + [value]).strip()
+ cfg_path = ' '.join(path + [value]).strip()
elif isinstance(c, BaseConfigSectionModel):
section = c.section
@@ -304,7 +360,9 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
session.set(path, value=value)
elif op == 'delete':
if state.strict and not config.exists(cfg_path):
- raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist")
+ raise ConfigSessionError(
+ f'Cannot delete [{cfg_path}]: path/value does not exist'
+ )
session.delete(path, value=value)
elif op == 'comment':
session.comment(path, value=value)
@@ -343,7 +401,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
session.discard()
status = 400
if state.debug:
- LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}")
+ LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}')
error_msg = str(e)
except Exception:
session.discard()
@@ -351,7 +409,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
status = 500
# Don't give the details away to the outer world
- error_msg = "An internal error occured. Check the logs for details."
+ error_msg = 'An internal error occured. Check the logs for details.'
finally:
lock.release()
@@ -371,21 +429,24 @@ def create_path_import_pki_no_prompt(path):
@router.post('/configure')
-def configure_op(data: Union[ConfigureModel,
- ConfigureListModel],
- request: Request, background_tasks: BackgroundTasks):
+def configure_op(
+ data: Union[ConfigureModel, ConfigureListModel],
+ request: Request,
+ background_tasks: BackgroundTasks,
+):
return _configure_op(data, request, background_tasks)
@router.post('/configure-section')
-def configure_section_op(data: Union[ConfigSectionModel,
- ConfigSectionListModel,
- ConfigSectionTreeModel],
- request: Request, background_tasks: BackgroundTasks):
+def configure_section_op(
+ data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel],
+ request: Request,
+ background_tasks: BackgroundTasks,
+):
return _configure_op(data, request, background_tasks)
-@router.post("/retrieve")
+@router.post('/retrieve')
async def retrieve_op(data: RetrieveModel):
state = SessionState()
session = state.session
@@ -393,7 +454,7 @@ async def retrieve_op(data: RetrieveModel):
config = Config(session_env=env)
op = data.op
- path = " ".join(data.path)
+ path = ' '.join(data.path)
try:
if op == 'returnValue':
@@ -424,7 +485,7 @@ async def retrieve_op(data: RetrieveModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -448,7 +509,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):
if data.file:
path = data.file
else:
- return error(400, "Missing required field \"file\"")
+ return error(400, 'Missing required field "file"')
session.migrate_and_load_config(path)
@@ -466,7 +527,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(msg)
@@ -484,14 +545,14 @@ def image_op(data: ImageModel):
elif op == 'delete':
res = session.remove_image(data.name)
elif op == 'show':
- res = session.show(["system", "image"])
+ res = session.show(['system', 'image'])
elif op == 'set_default':
res = session.set_default_image(data.name)
except ConfigSessionError as e:
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -508,13 +569,13 @@ def container_image_op(data: ContainerImageModel):
if data.name:
name = data.name
else:
- return error(400, "Missing required field \"name\"")
+ return error(400, 'Missing required field "name"')
res = session.add_container_image(name)
elif op == 'delete':
if data.name:
name = data.name
else:
- return error(400, "Missing required field \"name\"")
+ return error(400, 'Missing required field "name"')
res = session.delete_container_image(name)
elif op == 'show':
res = session.show_container_image()
@@ -524,7 +585,7 @@ def container_image_op(data: ContainerImageModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -546,7 +607,7 @@ def generate_op(data: GenerateModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -568,7 +629,7 @@ def show_op(data: ShowModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -590,7 +651,7 @@ def reboot_op(data: RebootModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -612,7 +673,7 @@ def reset_op(data: ResetModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -652,7 +713,7 @@ def import_pki(data: ImportPkiModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
finally:
lock.release()
@@ -676,18 +737,18 @@ def poweroff_op(data: PoweroffModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
-def rest_init(app: "FastAPI"):
+def rest_init(app: 'FastAPI'):
if all(r in app.routes for r in router.routes):
return
app.include_router(router)
-def rest_clear(app: "FastAPI"):
+def rest_clear(app: 'FastAPI'):
for r in router.routes:
if r in app.routes:
app.routes.remove(r)