summaryrefslogtreecommitdiff
path: root/src/services
diff options
context:
space:
mode:
Diffstat (limited to 'src/services')
-rw-r--r--src/services/api/graphql/bindings.py29
-rw-r--r--src/services/api/graphql/graphql/auth_token_mutation.py31
-rw-r--r--src/services/api/graphql/graphql/mutations.py71
-rw-r--r--src/services/api/graphql/graphql/queries.py71
-rw-r--r--src/services/api/graphql/libs/key_auth.py4
-rw-r--r--src/services/api/graphql/libs/token_auth.py10
-rw-r--r--src/services/api/graphql/routers.py41
-rw-r--r--src/services/api/graphql/session/session.py33
-rw-r--r--src/services/api/rest/models.py143
-rw-r--r--src/services/api/rest/routers.py209
-rw-r--r--src/services/api/session.py1
11 files changed, 370 insertions, 273 deletions
diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py
index 93dd0fbfb..ebf745f32 100644
--- a/src/services/api/graphql/bindings.py
+++ b/src/services/api/graphql/bindings.py
@@ -20,18 +20,18 @@ from ariadne import make_executable_schema
from ariadne import load_schema_from_path
from ariadne import snake_case_fallback_resolvers
-from . graphql.queries import query
-from . graphql.mutations import mutation
-from . graphql.directives import directives_dict
-from . graphql.errors import op_mode_error
-from . graphql.auth_token_mutation import auth_token_mutation
-from . libs.token_auth import init_secret
+from .graphql.queries import query
+from .graphql.mutations import mutation
+from .graphql.directives import directives_dict
+from .graphql.errors import op_mode_error
+from .graphql.auth_token_mutation import auth_token_mutation
+from .libs.token_auth import init_secret
-from .. session import SessionState
+from ..session import SessionState
def generate_schema():
- state = SessionState()
+ state = SessionState()
api_schema_dir = vyos.defaults.directories['api_schema']
if state.auth_type == 'token':
@@ -39,9 +39,14 @@ def generate_schema():
type_defs = load_schema_from_path(api_schema_dir)
- schema = make_executable_schema(type_defs, query, op_mode_error,
- mutation, auth_token_mutation,
- snake_case_fallback_resolvers,
- directives=directives_dict)
+ schema = make_executable_schema(
+ type_defs,
+ query,
+ op_mode_error,
+ mutation,
+ auth_token_mutation,
+ snake_case_fallback_resolvers,
+ directives=directives_dict,
+ )
return schema
diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py
index 164960217..c74364603 100644
--- a/src/services/api/graphql/graphql/auth_token_mutation.py
+++ b/src/services/api/graphql/graphql/auth_token_mutation.py
@@ -19,11 +19,12 @@ from typing import Dict
from ariadne import ObjectType
from graphql import GraphQLResolveInfo
-from .. libs.token_auth import generate_token
-from .. session.session import get_user_info
-from ... session import SessionState
+from ..libs.token_auth import generate_token
+from ..session.session import get_user_info
+from ...session import SessionState
+
+auth_token_mutation = ObjectType('Mutation')
-auth_token_mutation = ObjectType("Mutation")
@auth_token_mutation.field('AuthToken')
def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict):
@@ -35,8 +36,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict):
secret = getattr(state, 'secret', '')
exp_interval = int(state.token_exp)
- expiration = (datetime.datetime.now(tz=datetime.timezone.utc) +
- datetime.timedelta(seconds=exp_interval))
+ expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(
+ seconds=exp_interval
+ )
res = generate_token(user, passwd, secret, expiration)
try:
@@ -46,18 +48,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict):
pass
if 'token' in res:
data['result'] = res
- return {
- "success": True,
- "data": data
- }
+ return {'success': True, 'data': data}
if 'errors' in res:
- return {
- "success": False,
- "errors": res['errors']
- }
-
- return {
- "success": False,
- "errors": ['token generation failed']
- }
+ return {'success': False, 'errors': res['errors']}
+
+ return {'success': False, 'errors': ['token generation failed']}
diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py
index 62031ada3..0b391c070 100644
--- a/src/services/api/graphql/graphql/mutations.py
+++ b/src/services/api/graphql/graphql/mutations.py
@@ -14,20 +14,23 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module
-from ariadne import ObjectType, convert_camel_case_to_snake
-from makefun import with_signature
# used below by func_sig
-from typing import Any, Dict, Optional # pylint: disable=W0611
-from graphql import GraphQLResolveInfo # pylint: disable=W0611
+from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401
+from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401
+
+from ariadne import ObjectType, convert_camel_case_to_snake
+from makefun import with_signature
-from ... session import SessionState
-from .. libs import key_auth
-from .. session.session import Session
-from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
from vyos.opmode import Error as OpModeError
-mutation = ObjectType("Mutation")
+from ...session import SessionState
+from ..libs import key_auth
+from ..session.session import Session
+from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
+
+mutation = ObjectType('Mutation')
+
def make_mutation_resolver(mutation_name, class_name, session_func):
"""Dynamically generate a resolver for the mutation named in the
@@ -59,10 +62,7 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
auth = key_auth.auth_required(key)
if auth is None:
- return {
- "success": False,
- "errors": ['invalid API key']
- }
+ return {'success': False, 'errors': ['invalid API key']}
# We are finished with the 'key' entry, and may remove so as to
# pass the rest of data (if any) to function.
@@ -77,14 +77,8 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
if user is None:
error = info.context.get('error')
if error is not None:
- return {
- "success": False,
- "errors": [error]
- }
- return {
- "success": False,
- "errors": ['not authenticated']
- }
+ return {'success': False, 'errors': [error]}
+ return {'success': False, 'errors': ['not authenticated']}
else:
# AtrributeError will have already been raised if no
# auth_type; validation and defaultValue ensure it is
@@ -106,35 +100,36 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
result = method()
data['result'] = result
- return {
- "success": True,
- "data": data
- }
+ return {'success': True, 'data': data}
except OpModeError as e:
typename = type(e).__name__
msg = str(e)
return {
- "success": False,
- "errore": ['op_mode_error'],
- "op_mode_error": {"name": f"{typename}",
- "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"),
- "vyos_code": op_mode_err_code.get(typename, 9999)}
+ 'success': False,
+ 'errore': ['op_mode_error'],
+ 'op_mode_error': {
+ 'name': f'{typename}',
+ 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'),
+ 'vyos_code': op_mode_err_code.get(typename, 9999),
+ },
}
except Exception as error:
- return {
- "success": False,
- "errors": [repr(error)]
- }
+ return {'success': False, 'errors': [repr(error)]}
return func_impl
+
def make_config_session_mutation_resolver(mutation_name):
- return make_mutation_resolver(mutation_name, mutation_name,
- convert_camel_case_to_snake(mutation_name))
+ return make_mutation_resolver(
+ mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name)
+ )
+
def make_gen_op_mutation_resolver(mutation_name):
return make_mutation_resolver(mutation_name, mutation_name, 'gen_op_mutation')
+
def make_composite_mutation_resolver(mutation_name):
- return make_mutation_resolver(mutation_name, mutation_name,
- convert_camel_case_to_snake(mutation_name))
+ return make_mutation_resolver(
+ mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name)
+ )
diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py
index 1e9036574..9303fe909 100644
--- a/src/services/api/graphql/graphql/queries.py
+++ b/src/services/api/graphql/graphql/queries.py
@@ -14,20 +14,23 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module
-from ariadne import ObjectType, convert_camel_case_to_snake
-from makefun import with_signature
# used below by func_sig
-from typing import Any, Dict, Optional # pylint: disable=W0611
-from graphql import GraphQLResolveInfo # pylint: disable=W0611
+from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401
+from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401
+
+from ariadne import ObjectType, convert_camel_case_to_snake
+from makefun import with_signature
-from ... session import SessionState
-from .. libs import key_auth
-from .. session.session import Session
-from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
from vyos.opmode import Error as OpModeError
-query = ObjectType("Query")
+from ...session import SessionState
+from ..libs import key_auth
+from ..session.session import Session
+from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
+
+query = ObjectType('Query')
+
def make_query_resolver(query_name, class_name, session_func):
"""Dynamically generate a resolver for the query named in the
@@ -59,10 +62,7 @@ def make_query_resolver(query_name, class_name, session_func):
auth = key_auth.auth_required(key)
if auth is None:
- return {
- "success": False,
- "errors": ['invalid API key']
- }
+ return {'success': False, 'errors': ['invalid API key']}
# We are finished with the 'key' entry, and may remove so as to
# pass the rest of data (if any) to function.
@@ -77,14 +77,8 @@ def make_query_resolver(query_name, class_name, session_func):
if user is None:
error = info.context.get('error')
if error is not None:
- return {
- "success": False,
- "errors": [error]
- }
- return {
- "success": False,
- "errors": ['not authenticated']
- }
+ return {'success': False, 'errors': [error]}
+ return {'success': False, 'errors': ['not authenticated']}
else:
# AtrributeError will have already been raised if no
# auth_type; validation and defaultValue ensure it is
@@ -106,35 +100,36 @@ def make_query_resolver(query_name, class_name, session_func):
result = method()
data['result'] = result
- return {
- "success": True,
- "data": data
- }
+ return {'success': True, 'data': data}
except OpModeError as e:
typename = type(e).__name__
msg = str(e)
return {
- "success": False,
- "errors": ['op_mode_error'],
- "op_mode_error": {"name": f"{typename}",
- "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"),
- "vyos_code": op_mode_err_code.get(typename, 9999)}
+ 'success': False,
+ 'errors': ['op_mode_error'],
+ 'op_mode_error': {
+ 'name': f'{typename}',
+ 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'),
+ 'vyos_code': op_mode_err_code.get(typename, 9999),
+ },
}
except Exception as error:
- return {
- "success": False,
- "errors": [repr(error)]
- }
+ return {'success': False, 'errors': [repr(error)]}
return func_impl
+
def make_config_session_query_resolver(query_name):
- return make_query_resolver(query_name, query_name,
- convert_camel_case_to_snake(query_name))
+ return make_query_resolver(
+ query_name, query_name, convert_camel_case_to_snake(query_name)
+ )
+
def make_gen_op_query_resolver(query_name):
return make_query_resolver(query_name, query_name, 'gen_op_query')
+
def make_composite_query_resolver(query_name):
- return make_query_resolver(query_name, query_name,
- convert_camel_case_to_snake(query_name))
+ return make_query_resolver(
+ query_name, query_name, convert_camel_case_to_snake(query_name)
+ )
diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py
index 9e49a1203..ffd7f32b2 100644
--- a/src/services/api/graphql/libs/key_auth.py
+++ b/src/services/api/graphql/libs/key_auth.py
@@ -14,7 +14,8 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>.
-from ... session import SessionState
+from ...session import SessionState
+
def check_auth(key_list, key):
if not key_list:
@@ -25,6 +26,7 @@ def check_auth(key_list, key):
key_id = k['id']
return key_id
+
def auth_required(key):
state = SessionState()
api_keys = None
diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py
index 2d772e035..4f743a096 100644
--- a/src/services/api/graphql/libs/token_auth.py
+++ b/src/services/api/graphql/libs/token_auth.py
@@ -19,7 +19,7 @@ import uuid
import pam
from secrets import token_hex
-from ... session import SessionState
+from ...session import SessionState
def _check_passwd_pam(username: str, passwd: str) -> bool:
@@ -48,13 +48,13 @@ def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict:
payload_data = {'iss': user, 'sub': user_id, 'exp': exp}
secret = getattr(state, 'secret', None)
if secret is None:
- return {"errors": ['missing secret']}
- token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256")
+ return {'errors': ['missing secret']}
+ token = jwt.encode(payload=payload_data, key=secret, algorithm='HS256')
users |= {user_id: user}
return {'token': token}
else:
- return {"errors": ['failed pam authentication']}
+ return {'errors': ['failed pam authentication']}
def get_user_context(request):
@@ -70,7 +70,7 @@ def get_user_context(request):
try:
secret = getattr(state, 'secret', None)
- payload = jwt.decode(token, secret, algorithms=["HS256"])
+ payload = jwt.decode(token, secret, algorithms=['HS256'])
user_id: str = payload.get('sub')
if user_id is None:
return context
diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py
index f02380cdc..ed3ee1e8c 100644
--- a/src/services/api/graphql/routers.py
+++ b/src/services/api/graphql/routers.py
@@ -26,14 +26,15 @@ if typing.TYPE_CHECKING:
from fastapi import FastAPI
-def graphql_init(app: "FastAPI"):
- from .. session import SessionState
+def graphql_init(app: 'FastAPI'):
+ from ..session import SessionState
from .libs.token_auth import get_user_context
state = SessionState()
# import after initializaion of state
from .bindings import generate_schema
+
schema = generate_schema()
in_spec = state.introspection
@@ -44,21 +45,33 @@ def graphql_init(app: "FastAPI"):
if state.origins:
origins = state.origins
- app.add_route('/graphql', CORSMiddleware(GraphQL(schema,
- context_value=get_user_context,
- debug=True,
- introspection=in_spec),
- allow_origins=origins,
- allow_methods=("GET", "POST", "OPTIONS"),
- allow_headers=("Authorization",)))
+ app.add_route(
+ '/graphql',
+ CORSMiddleware(
+ GraphQL(
+ schema,
+ context_value=get_user_context,
+ debug=True,
+ introspection=in_spec,
+ ),
+ allow_origins=origins,
+ allow_methods=('GET', 'POST', 'OPTIONS'),
+ allow_headers=('Authorization',),
+ ),
+ )
else:
- app.add_route('/graphql', GraphQL(schema,
- context_value=get_user_context,
- debug=True,
- introspection=in_spec))
+ app.add_route(
+ '/graphql',
+ GraphQL(
+ schema,
+ context_value=get_user_context,
+ debug=True,
+ introspection=in_spec,
+ ),
+ )
-def graphql_clear(app: "FastAPI"):
+def graphql_clear(app: 'FastAPI'):
for r in app.routes:
if r.path == '/graphql':
app.routes.remove(r)
diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py
index 6e2875f3c..619534f43 100644
--- a/src/services/api/graphql/session/session.py
+++ b/src/services/api/graphql/session/session.py
@@ -28,34 +28,45 @@ from api.graphql.libs.op_mode import normalize_output
op_mode_include_file = os.path.join(directories['data'], 'op-mode-standardized.json')
-def get_config_dict(path=[], effective=False, key_mangling=None,
- get_first_key=False, no_multi_convert=False,
- no_tag_node_value_mangle=False):
+
+def get_config_dict(
+ path=[],
+ effective=False,
+ key_mangling=None,
+ get_first_key=False,
+ no_multi_convert=False,
+ no_tag_node_value_mangle=False,
+):
config = Config()
- return config.get_config_dict(path=path, effective=effective,
- key_mangling=key_mangling,
- get_first_key=get_first_key,
- no_multi_convert=no_multi_convert,
- no_tag_node_value_mangle=no_tag_node_value_mangle)
+ return config.get_config_dict(
+ path=path,
+ effective=effective,
+ key_mangling=key_mangling,
+ get_first_key=get_first_key,
+ no_multi_convert=no_multi_convert,
+ no_tag_node_value_mangle=no_tag_node_value_mangle,
+ )
+
def get_user_info(user):
user_info = {}
- info = get_config_dict(['system', 'login', 'user', user],
- get_first_key=True)
+ info = get_config_dict(['system', 'login', 'user', user], get_first_key=True)
if not info:
- raise ValueError("No such user")
+ raise ValueError('No such user')
user_info['user'] = user
user_info['full_name'] = info.get('full-name', '')
return user_info
+
class Session:
"""
Wrapper for calling configsession functions based on GraphQL requests.
Non-nullable fields in the respective schema allow avoiding a key check
in 'data'.
"""
+
def __init__(self, session, data):
self._session = session
self._data = data
diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py
index 034e3fcdb..d65d6e1ec 100644
--- a/src/services/api/rest/models.py
+++ b/src/services/api/rest/models.py
@@ -31,75 +31,88 @@ from fastapi.responses import HTMLResponse
def error(code, msg):
- resp = {"success": False, "error": msg, "data": None}
+ resp = {'success': False, 'error': msg, 'data': None}
resp = json.dumps(resp)
return HTMLResponse(resp, status_code=code)
+
def success(data):
- resp = {"success": True, "data": data, "error": None}
+ resp = {'success': True, 'data': data, 'error': None}
resp = json.dumps(resp)
return HTMLResponse(resp)
+
# Pydantic models for validation
# Pydantic will cast when possible, so use StrictStr validators added as
# needed for additional constraints
# json_schema_extra adds anotations to OpenAPI to add examples
+
class ApiModel(BaseModel):
key: StrictStr
+
class BasePathModel(BaseModel):
op: StrictStr
path: List[StrictStr]
- @field_validator("path")
+ @field_validator('path')
@classmethod
def check_non_empty(cls, path: str) -> str:
if not len(path) > 0:
raise ValueError('path must be non-empty')
return path
+
class BaseConfigureModel(BasePathModel):
value: StrictStr = None
+
class ConfigureModel(ApiModel, BaseConfigureModel):
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "set | delete | comment",
- "path": ["config", "mode", "path"],
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'set | delete | comment',
+ 'path': ['config', 'mode', 'path'],
}
}
+
class ConfigureListModel(ApiModel):
commands: List[BaseConfigureModel]
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "commands": "list of commands",
+ 'example': {
+ 'key': 'id_key',
+ 'commands': 'list of commands',
}
}
+
class BaseConfigSectionModel(BasePathModel):
section: Dict
+
class ConfigSectionModel(ApiModel, BaseConfigSectionModel):
pass
+
class ConfigSectionListModel(ApiModel):
commands: List[BaseConfigSectionModel]
+
class BaseConfigSectionTreeModel(BaseModel):
op: StrictStr
mask: Dict
config: Dict
+
class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel):
pass
+
class RetrieveModel(ApiModel):
op: StrictStr
path: List[StrictStr]
@@ -107,34 +120,34 @@ class RetrieveModel(ApiModel):
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "returnValue | returnValues | exists | showConfig",
- "path": ["config", "mode", "path"],
- "configFormat": "json (default) | json_ast | raw",
-
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'returnValue | returnValues | exists | showConfig',
+ 'path': ['config', 'mode', 'path'],
+ 'configFormat': 'json (default) | json_ast | raw',
}
}
+
class ConfigFileModel(ApiModel):
op: StrictStr
file: StrictStr = None
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "save | load",
- "file": "filename",
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'save | load',
+ 'file': 'filename',
}
}
class ImageOp(str, Enum):
- add = "add"
- delete = "delete"
- show = "show"
- set_default = "set_default"
+ add = 'add'
+ delete = 'delete'
+ show = 'show'
+ set_default = 'set_default'
class ImageModel(ApiModel):
@@ -146,23 +159,24 @@ class ImageModel(ApiModel):
def check_data(self) -> Self:
if self.op == 'add':
if not self.url:
- raise ValueError("Missing required field \"url\"")
+ raise ValueError('Missing required field "url"')
elif self.op in ['delete', 'set_default']:
if not self.name:
- raise ValueError("Missing required field \"name\"")
+ raise ValueError('Missing required field "name"')
return self
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "add | delete | show | set_default",
- "url": "imagelocation",
- "name": "imagename",
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'add | delete | show | set_default',
+ 'url': 'imagelocation',
+ 'name': 'imagename',
}
}
+
class ImportPkiModel(ApiModel):
op: StrictStr
path: List[StrictStr]
@@ -170,11 +184,11 @@ class ImportPkiModel(ApiModel):
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "import_pki",
- "path": ["op", "mode", "path"],
- "passphrase": "passphrase",
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'import_pki',
+ 'path': ['op', 'mode', 'path'],
+ 'passphrase': 'passphrase',
}
}
@@ -185,75 +199,80 @@ class ContainerImageModel(ApiModel):
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "add | delete | show",
- "name": "imagename",
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'add | delete | show',
+ 'name': 'imagename',
}
}
+
class GenerateModel(ApiModel):
op: StrictStr
path: List[StrictStr]
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "generate",
- "path": ["op", "mode", "path"],
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'generate',
+ 'path': ['op', 'mode', 'path'],
}
}
+
class ShowModel(ApiModel):
op: StrictStr
path: List[StrictStr]
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "show",
- "path": ["op", "mode", "path"],
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'show',
+ 'path': ['op', 'mode', 'path'],
}
}
+
class RebootModel(ApiModel):
op: StrictStr
path: List[StrictStr]
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "reboot",
- "path": ["op", "mode", "path"],
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'reboot',
+ 'path': ['op', 'mode', 'path'],
}
}
+
class ResetModel(ApiModel):
op: StrictStr
path: List[StrictStr]
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "reset",
- "path": ["op", "mode", "path"],
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'reset',
+ 'path': ['op', 'mode', 'path'],
}
}
+
class PoweroffModel(ApiModel):
op: StrictStr
path: List[StrictStr]
class Config:
json_schema_extra = {
- "example": {
- "key": "id_key",
- "op": "poweroff",
- "path": ["op", "mode", "path"],
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'poweroff',
+ 'path': ['op', 'mode', 'path'],
}
}
@@ -263,14 +282,16 @@ class Success(BaseModel):
data: Union[str, bool, Dict]
error: str
+
class Error(BaseModel):
success: bool = False
data: Union[str, bool, Dict]
error: str
+
responses = {
200: {'model': Success},
400: {'model': Error},
422: {'model': Error, 'description': 'Validation Error'},
- 500: {'model': Error}
+ 500: {'model': Error},
}
diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py
index 38b10ef7d..da981d5bf 100644
--- a/src/services/api/rest/routers.py
+++ b/src/services/api/rest/routers.py
@@ -18,10 +18,12 @@
# pylint: disable=wildcard-import,unused-wildcard-import
# pylint: disable=broad-exception-caught
+import json
import copy
import logging
import traceback
from threading import Lock
+from typing import Union
from typing import Callable
from typing import TYPE_CHECKING
@@ -43,8 +45,30 @@ from vyos.configtree import ConfigTree
from vyos.configdiff import get_config_diff
from vyos.configsession import ConfigSessionError
-from .. session import SessionState
-from . models import *
+from ..session import SessionState
+from .models import success
+from .models import error
+from .models import responses
+from .models import ApiModel
+from .models import ConfigureModel
+from .models import ConfigureListModel
+from .models import ConfigSectionModel
+from .models import ConfigSectionListModel
+from .models import ConfigSectionTreeModel
+from .models import BaseConfigSectionTreeModel
+from .models import BaseConfigureModel
+from .models import BaseConfigSectionModel
+from .models import RetrieveModel
+from .models import ConfigFileModel
+from .models import ImageModel
+from .models import ContainerImageModel
+from .models import GenerateModel
+from .models import ShowModel
+from .models import RebootModel
+from .models import ResetModel
+from .models import ImportPkiModel
+from .models import PoweroffModel
+
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -69,7 +93,7 @@ def auth_required(data: ApiModel):
api_keys = session.keys
key_id = check_auth(api_keys, key)
if not key_id:
- raise HTTPException(status_code=401, detail="Valid API key is required")
+ raise HTTPException(status_code=401, detail='Valid API key is required')
session.id = key_id
@@ -79,10 +103,12 @@ def auth_required(data: ApiModel):
# validation by FastAPI/Pydantic, as is used for application/json requests
class MultipartRequest(Request):
"""Override Request class to convert form request to json"""
+
# pylint: disable=attribute-defined-outside-init
# pylint: disable=too-many-branches,too-many-statements
_form_err = ()
+
@property
def form_err(self):
return self._form_err
@@ -104,16 +130,16 @@ class MultipartRequest(Request):
return self._headers
async def _get_form(
- self, *, max_files: int | float = 1000, max_fields: int | float = 1000
- ) -> FormData:
+ self, *, max_files: int | float = 1000, max_fields: int | float = 1000
+ ) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
- ), "The `python-multipart` library must be installed to use form parsing."
- content_type_header = self.orig_headers.get("Content-Type")
+ ), 'The `python-multipart` library must be installed to use form parsing.'
+ content_type_header = self.orig_headers.get('Content-Type')
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
- if content_type == b"multipart/form-data":
+ if content_type == b'multipart/form-data':
try:
multipart_parser = MultiPartParser(
self.orig_headers,
@@ -123,10 +149,10 @@ class MultipartRequest(Request):
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
- if "app" in self.scope:
+ if 'app' in self.scope:
raise HTTPException(status_code=400, detail=exc.message)
raise exc
- elif content_type == b"application/x-www-form-urlencoded":
+ elif content_type == b'application/x-www-form-urlencoded':
form_parser = FormParser(self.orig_headers, self.stream())
self._form = await form_parser.parse()
else:
@@ -134,7 +160,7 @@ class MultipartRequest(Request):
return self._form
async def body(self) -> bytes:
- if not hasattr(self, "_body"):
+ if not hasattr(self, '_body'):
forms = {}
merge = {}
body = await super().body()
@@ -143,12 +169,12 @@ class MultipartRequest(Request):
form_data = await self.form()
if form_data:
endpoint = self.url.path
- LOG.debug("processing form data")
+ LOG.debug('processing form data')
for k, v in form_data.multi_items():
forms[k] = v
if 'data' not in forms:
- self.form_err = (422, "Non-empty data field is required")
+ self.form_err = (422, 'Non-empty data field is required')
return self._body
try:
tmp = json.loads(forms['data'])
@@ -168,37 +194,57 @@ class MultipartRequest(Request):
for c in cmds:
if not isinstance(c, dict):
- self.form_err = (400,
- f"Malformed command '{c}': any command must be JSON of dict")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': any command must be JSON of dict",
+ )
return self._body
if 'op' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'op' field")
- if endpoint not in ('/config-file', '/container-image',
- '/image', '/configure-section'):
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'op' field",
+ )
+ if endpoint not in (
+ '/config-file',
+ '/container-image',
+ '/image',
+ '/configure-section',
+ ):
if 'path' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'path' field")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'path' field",
+ )
elif not isinstance(c['path'], list):
- self.form_err = (400,
- f"Malformed command '{c}': 'path' field must be a list")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'path' field must be a list",
+ )
elif not all(isinstance(el, str) for el in c['path']):
- self.form_err = (400,
- f"Malformed command '{0}': 'path' field must be a list of strings")
+ self.form_err = (
+ 400,
+ f"Malformed command '{0}': 'path' field must be a list of strings",
+ )
if endpoint in ('/configure'):
if not c['path']:
- self.form_err = (400,
- f"Malformed command '{c}': 'path' list must be non-empty")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'path' list must be non-empty",
+ )
if 'value' in c and not isinstance(c['value'], str):
- self.form_err = (400,
- f"Malformed command '{c}': 'value' field must be a string")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'value' field must be a string",
+ )
if endpoint in ('/configure-section'):
if 'section' not in c and 'config' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'section' or 'config' field")
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'section' or 'config' field",
+ )
if 'key' not in forms and 'key' not in merge:
- self.form_err = (401, "Valid API key is required")
+ self.form_err = (401, 'Valid API key is required')
if 'key' in forms and 'key' not in merge:
merge['key'] = forms['key']
@@ -232,12 +278,15 @@ class MultipartRoute(APIRoute):
return custom_route_handler
-router = APIRouter(route_class=MultipartRoute,
- responses={**responses},
- dependencies=[Depends(auth_required)])
+router = APIRouter(
+ route_class=MultipartRoute,
+ responses={**responses},
+ dependencies=[Depends(auth_required)],
+)
-self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background"
+self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background'
+
def call_commit(s: SessionState):
try:
@@ -245,15 +294,22 @@ def call_commit(s: SessionState):
except ConfigSessionError as e:
s.session.discard()
if s.debug:
- LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}")
+ LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}')
else:
- LOG.warning(f"ConfigSessionError: {e}")
-
-
-def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
- ConfigSectionModel, ConfigSectionListModel,
- ConfigSectionTreeModel],
- _request: Request, background_tasks: BackgroundTasks):
+ LOG.warning(f'ConfigSessionError: {e}')
+
+
+def _configure_op(
+ data: Union[
+ ConfigureModel,
+ ConfigureListModel,
+ ConfigSectionModel,
+ ConfigSectionListModel,
+ ConfigSectionTreeModel,
+ ],
+ _request: Request,
+ background_tasks: BackgroundTasks,
+):
# pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
# pylint: disable=consider-using-with
@@ -287,10 +343,10 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
if c.value:
value = c.value
else:
- value = ""
+ value = ''
# For vyos.configsession calls that have no separate value arguments,
# and for type checking too
- cfg_path = " ".join(path + [value]).strip()
+ cfg_path = ' '.join(path + [value]).strip()
elif isinstance(c, BaseConfigSectionModel):
section = c.section
@@ -304,7 +360,9 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
session.set(path, value=value)
elif op == 'delete':
if state.strict and not config.exists(cfg_path):
- raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist")
+ raise ConfigSessionError(
+ f'Cannot delete [{cfg_path}]: path/value does not exist'
+ )
session.delete(path, value=value)
elif op == 'comment':
session.comment(path, value=value)
@@ -343,7 +401,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
session.discard()
status = 400
if state.debug:
- LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}")
+ LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}')
error_msg = str(e)
except Exception:
session.discard()
@@ -351,7 +409,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
status = 500
# Don't give the details away to the outer world
- error_msg = "An internal error occured. Check the logs for details."
+ error_msg = 'An internal error occured. Check the logs for details.'
finally:
lock.release()
@@ -371,21 +429,24 @@ def create_path_import_pki_no_prompt(path):
@router.post('/configure')
-def configure_op(data: Union[ConfigureModel,
- ConfigureListModel],
- request: Request, background_tasks: BackgroundTasks):
+def configure_op(
+ data: Union[ConfigureModel, ConfigureListModel],
+ request: Request,
+ background_tasks: BackgroundTasks,
+):
return _configure_op(data, request, background_tasks)
@router.post('/configure-section')
-def configure_section_op(data: Union[ConfigSectionModel,
- ConfigSectionListModel,
- ConfigSectionTreeModel],
- request: Request, background_tasks: BackgroundTasks):
+def configure_section_op(
+ data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel],
+ request: Request,
+ background_tasks: BackgroundTasks,
+):
return _configure_op(data, request, background_tasks)
-@router.post("/retrieve")
+@router.post('/retrieve')
async def retrieve_op(data: RetrieveModel):
state = SessionState()
session = state.session
@@ -393,7 +454,7 @@ async def retrieve_op(data: RetrieveModel):
config = Config(session_env=env)
op = data.op
- path = " ".join(data.path)
+ path = ' '.join(data.path)
try:
if op == 'returnValue':
@@ -424,7 +485,7 @@ async def retrieve_op(data: RetrieveModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -448,7 +509,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):
if data.file:
path = data.file
else:
- return error(400, "Missing required field \"file\"")
+ return error(400, 'Missing required field "file"')
session.migrate_and_load_config(path)
@@ -466,7 +527,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(msg)
@@ -484,14 +545,14 @@ def image_op(data: ImageModel):
elif op == 'delete':
res = session.remove_image(data.name)
elif op == 'show':
- res = session.show(["system", "image"])
+ res = session.show(['system', 'image'])
elif op == 'set_default':
res = session.set_default_image(data.name)
except ConfigSessionError as e:
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -508,13 +569,13 @@ def container_image_op(data: ContainerImageModel):
if data.name:
name = data.name
else:
- return error(400, "Missing required field \"name\"")
+ return error(400, 'Missing required field "name"')
res = session.add_container_image(name)
elif op == 'delete':
if data.name:
name = data.name
else:
- return error(400, "Missing required field \"name\"")
+ return error(400, 'Missing required field "name"')
res = session.delete_container_image(name)
elif op == 'show':
res = session.show_container_image()
@@ -524,7 +585,7 @@ def container_image_op(data: ContainerImageModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -546,7 +607,7 @@ def generate_op(data: GenerateModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -568,7 +629,7 @@ def show_op(data: ShowModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -590,7 +651,7 @@ def reboot_op(data: RebootModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -612,7 +673,7 @@ def reset_op(data: ResetModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
@@ -652,7 +713,7 @@ def import_pki(data: ImportPkiModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
finally:
lock.release()
@@ -676,18 +737,18 @@ def poweroff_op(data: PoweroffModel):
return error(400, str(e))
except Exception:
LOG.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ return error(500, 'An internal error occured. Check the logs for details.')
return success(res)
-def rest_init(app: "FastAPI"):
+def rest_init(app: 'FastAPI'):
if all(r in app.routes for r in router.routes):
return
app.include_router(router)
-def rest_clear(app: "FastAPI"):
+def rest_clear(app: 'FastAPI'):
for r in router.routes:
if r in app.routes:
app.routes.remove(r)
diff --git a/src/services/api/session.py b/src/services/api/session.py
index dcdc7246c..ad3ef660c 100644
--- a/src/services/api/session.py
+++ b/src/services/api/session.py
@@ -13,6 +13,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
class SessionState:
# pylint: disable=attribute-defined-outside-init
# pylint: disable=too-many-instance-attributes,too-few-public-methods