diff options
Diffstat (limited to 'src/services')
22 files changed, 3638 insertions, 1143 deletions
diff --git a/src/services/api/__init__.py b/src/services/api/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/services/api/__init__.py diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py index ef4966466..ebf745f32 100644 --- a/src/services/api/graphql/bindings.py +++ b/src/services/api/graphql/bindings.py @@ -1,4 +1,4 @@ -# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -13,24 +13,40 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see <http://www.gnu.org/licenses/>. + import vyos.defaults -from . graphql.queries import query -from . graphql.mutations import mutation -from . graphql.directives import directives_dict -from . graphql.errors import op_mode_error -from . graphql.auth_token_mutation import auth_token_mutation -from . libs.token_auth import init_secret -from . import state -from ariadne import make_executable_schema, load_schema_from_path, snake_case_fallback_resolvers + +from ariadne import make_executable_schema +from ariadne import load_schema_from_path +from ariadne import snake_case_fallback_resolvers + +from .graphql.queries import query +from .graphql.mutations import mutation +from .graphql.directives import directives_dict +from .graphql.errors import op_mode_error +from .graphql.auth_token_mutation import auth_token_mutation +from .libs.token_auth import init_secret + +from ..session import SessionState + def generate_schema(): + state = SessionState() api_schema_dir = vyos.defaults.directories['api_schema'] - if state.settings['app'].state.vyos_auth_type == 'token': + if state.auth_type == 'token': init_secret() type_defs = load_schema_from_path(api_schema_dir) - schema = make_executable_schema(type_defs, query, op_mode_error, mutation, auth_token_mutation, snake_case_fallback_resolvers, directives=directives_dict) + schema = make_executable_schema( + type_defs, + query, + op_mode_error, + mutation, + auth_token_mutation, + snake_case_fallback_resolvers, + directives=directives_dict, + ) return schema diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py index a53fa4d60..c74364603 100644 --- a/src/services/api/graphql/graphql/auth_token_mutation.py +++ b/src/services/api/graphql/graphql/auth_token_mutation.py @@ -19,11 +19,12 @@ from typing import Dict from ariadne import ObjectType from graphql import GraphQLResolveInfo -from .. libs.token_auth import generate_token -from .. session.session import get_user_info -from .. import state +from ..libs.token_auth import generate_token +from ..session.session import get_user_info +from ...session import SessionState + +auth_token_mutation = ObjectType('Mutation') -auth_token_mutation = ObjectType("Mutation") @auth_token_mutation.field('AuthToken') def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): @@ -31,10 +32,13 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): user = data['username'] passwd = data['password'] - secret = state.settings['secret'] - exp_interval = int(state.settings['app'].state.vyos_token_exp) - expiration = (datetime.datetime.now(tz=datetime.timezone.utc) + - datetime.timedelta(seconds=exp_interval)) + state = SessionState() + + secret = getattr(state, 'secret', '') + exp_interval = int(state.token_exp) + expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=exp_interval + ) res = generate_token(user, passwd, secret, expiration) try: @@ -44,18 +48,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): pass if 'token' in res: data['result'] = res - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} if 'errors' in res: - return { - "success": False, - "errors": res['errors'] - } - - return { - "success": False, - "errors": ['token generation failed'] - } + return {'success': False, 'errors': res['errors']} + + return {'success': False, 'errors': ['token generation failed']} diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py index d115a8e94..0b391c070 100644 --- a/src/services/api/graphql/graphql/mutations.py +++ b/src/services/api/graphql/graphql/mutations.py @@ -14,20 +14,23 @@ # along with this library. If not, see <http://www.gnu.org/licenses/>. from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from .. import state -from .. libs import key_auth -from api.graphql.session.session import Session -from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -mutation = ObjectType("Mutation") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +mutation = ObjectType('Mutation') + def make_mutation_resolver(mutation_name, class_name, session_func): """Dynamically generate a resolver for the mutation named in the @@ -45,12 +48,13 @@ def make_mutation_resolver(mutation_name, class_name, session_func): func_base_name = convert_camel_case_to_snake(class_name) resolver_name = f'resolve_{func_base_name}' func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)' + state = SessionState() @mutation.field(mutation_name) @with_signature(func_sig, func_name=resolver_name) async def func_impl(*args, **kwargs): try: - auth_type = state.settings['app'].state.vyos_auth_type + auth_type = state.auth_type if auth_type == 'key': data = kwargs['data'] @@ -58,10 +62,7 @@ def make_mutation_resolver(mutation_name, class_name, session_func): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -76,21 +77,15 @@ def make_mutation_resolver(mutation_name, class_name, session_func): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no - # vyos_auth_type; validation and defaultValue ensure it is + # auth_type; validation and defaultValue ensure it is # one of the previous cases, so this is never reached. pass - session = state.settings['app'].state.vyos_session + session = state.session # one may override the session functions with a local subclass try: @@ -105,35 +100,36 @@ def make_mutation_resolver(mutation_name, class_name, session_func): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errore": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errore': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) + def make_gen_op_mutation_resolver(mutation_name): return make_mutation_resolver(mutation_name, mutation_name, 'gen_op_mutation') + def make_composite_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py index 717098259..9303fe909 100644 --- a/src/services/api/graphql/graphql/queries.py +++ b/src/services/api/graphql/graphql/queries.py @@ -14,20 +14,23 @@ # along with this library. If not, see <http://www.gnu.org/licenses/>. from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from .. import state -from .. libs import key_auth -from api.graphql.session.session import Session -from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -query = ObjectType("Query") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +query = ObjectType('Query') + def make_query_resolver(query_name, class_name, session_func): """Dynamically generate a resolver for the query named in the @@ -45,12 +48,13 @@ def make_query_resolver(query_name, class_name, session_func): func_base_name = convert_camel_case_to_snake(class_name) resolver_name = f'resolve_{func_base_name}' func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)' + state = SessionState() @query.field(query_name) @with_signature(func_sig, func_name=resolver_name) async def func_impl(*args, **kwargs): try: - auth_type = state.settings['app'].state.vyos_auth_type + auth_type = state.auth_type if auth_type == 'key': data = kwargs['data'] @@ -58,10 +62,7 @@ def make_query_resolver(query_name, class_name, session_func): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -76,21 +77,15 @@ def make_query_resolver(query_name, class_name, session_func): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no - # vyos_auth_type; validation and defaultValue ensure it is + # auth_type; validation and defaultValue ensure it is # one of the previous cases, so this is never reached. pass - session = state.settings['app'].state.vyos_session + session = state.session # one may override the session functions with a local subclass try: @@ -105,35 +100,36 @@ def make_query_resolver(query_name, class_name, session_func): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errors": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errors': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) + def make_gen_op_query_resolver(query_name): return make_query_resolver(query_name, query_name, 'gen_op_query') + def make_composite_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) diff --git a/src/services/api/graphql/libs/__init__.py b/src/services/api/graphql/libs/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/services/api/graphql/libs/__init__.py diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py index 2db0f7d48..ffd7f32b2 100644 --- a/src/services/api/graphql/libs/key_auth.py +++ b/src/services/api/graphql/libs/key_auth.py @@ -1,5 +1,21 @@ +# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +from ...session import SessionState -from .. import state def check_auth(key_list, key): if not key_list: @@ -10,9 +26,11 @@ def check_auth(key_list, key): key_id = k['id'] return key_id + def auth_required(key): + state = SessionState() api_keys = None - api_keys = state.settings['app'].state.vyos_keys + api_keys = state.keys key_id = check_auth(api_keys, key) - state.settings['app'].state.vyos_id = key_id + state.id = key_id return key_id diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py index 8585485c9..4f743a096 100644 --- a/src/services/api/graphql/libs/token_auth.py +++ b/src/services/api/graphql/libs/token_auth.py @@ -1,46 +1,67 @@ +# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + import jwt import uuid import pam from secrets import token_hex -from .. import state +from ...session import SessionState + def _check_passwd_pam(username: str, passwd: str) -> bool: if pam.authenticate(username, passwd): return True return False + def init_secret(): - length = int(state.settings['app'].state.vyos_secret_len) + state = SessionState() + length = int(state.secret_len) secret = token_hex(length) - state.settings['secret'] = secret + state.secret = secret + def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict: if user is None or passwd is None: return {} + state = SessionState() if _check_passwd_pam(user, passwd): - app = state.settings['app'] try: - users = app.state.vyos_token_users + users = state.token_users except AttributeError: - app.state.vyos_token_users = {} - users = app.state.vyos_token_users + users = state.token_users = {} user_id = uuid.uuid1().hex payload_data = {'iss': user, 'sub': user_id, 'exp': exp} - secret = state.settings.get('secret') + secret = getattr(state, 'secret', None) if secret is None: - return {"errors": ['missing secret']} - token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256") + return {'errors': ['missing secret']} + token = jwt.encode(payload=payload_data, key=secret, algorithm='HS256') users |= {user_id: user} return {'token': token} else: - return {"errors": ['failed pam authentication']} + return {'errors': ['failed pam authentication']} + def get_user_context(request): context = {} context['request'] = request context['user'] = None + state = SessionState() if 'Authorization' in request.headers: auth = request.headers['Authorization'] scheme, token = auth.split() @@ -48,8 +69,8 @@ def get_user_context(request): return context try: - secret = state.settings.get('secret') - payload = jwt.decode(token, secret, algorithms=["HS256"]) + secret = getattr(state, 'secret', None) + payload = jwt.decode(token, secret, algorithms=['HS256']) user_id: str = payload.get('sub') if user_id is None: return context @@ -59,7 +80,7 @@ def get_user_context(request): except jwt.PyJWTError: return context try: - users = state.settings['app'].state.vyos_token_users + users = state.token_users except AttributeError: return context diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py new file mode 100644 index 000000000..ed3ee1e8c --- /dev/null +++ b/src/services/api/graphql/routers.py @@ -0,0 +1,77 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +# pylint: disable=import-outside-toplevel + + +import typing + +from ariadne.asgi import GraphQL +from starlette.middleware.cors import CORSMiddleware + + +if typing.TYPE_CHECKING: + from fastapi import FastAPI + + +def graphql_init(app: 'FastAPI'): + from ..session import SessionState + from .libs.token_auth import get_user_context + + state = SessionState() + + # import after initializaion of state + from .bindings import generate_schema + + schema = generate_schema() + + in_spec = state.introspection + + # remove route and reinstall below, for any changes; alternatively, test + # for config_diff before proceeding + graphql_clear(app) + + if state.origins: + origins = state.origins + app.add_route( + '/graphql', + CORSMiddleware( + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + allow_origins=origins, + allow_methods=('GET', 'POST', 'OPTIONS'), + allow_headers=('Authorization',), + ), + ) + else: + app.add_route( + '/graphql', + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + ) + + +def graphql_clear(app: 'FastAPI'): + for r in app.routes: + if r.path == '/graphql': + app.routes.remove(r) diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py index 6ae44b9bf..619534f43 100644 --- a/src/services/api/graphql/session/session.py +++ b/src/services/api/graphql/session/session.py @@ -28,34 +28,45 @@ from api.graphql.libs.op_mode import normalize_output op_mode_include_file = os.path.join(directories['data'], 'op-mode-standardized.json') -def get_config_dict(path=[], effective=False, key_mangling=None, - get_first_key=False, no_multi_convert=False, - no_tag_node_value_mangle=False): + +def get_config_dict( + path=[], + effective=False, + key_mangling=None, + get_first_key=False, + no_multi_convert=False, + no_tag_node_value_mangle=False, +): config = Config() - return config.get_config_dict(path=path, effective=effective, - key_mangling=key_mangling, - get_first_key=get_first_key, - no_multi_convert=no_multi_convert, - no_tag_node_value_mangle=no_tag_node_value_mangle) + return config.get_config_dict( + path=path, + effective=effective, + key_mangling=key_mangling, + get_first_key=get_first_key, + no_multi_convert=no_multi_convert, + no_tag_node_value_mangle=no_tag_node_value_mangle, + ) + def get_user_info(user): user_info = {} - info = get_config_dict(['system', 'login', 'user', user], - get_first_key=True) + info = get_config_dict(['system', 'login', 'user', user], get_first_key=True) if not info: - raise ValueError("No such user") + raise ValueError('No such user') user_info['user'] = user user_info['full_name'] = info.get('full-name', '') return user_info + class Session: """ Wrapper for calling configsession functions based on GraphQL requests. Non-nullable fields in the respective schema allow avoiding a key check in 'data'. """ + def __init__(self, session, data): self._session = session self._data = data @@ -138,7 +149,6 @@ class Session: return res def show_user_info(self): - session = self._session data = self._data user_info = {} @@ -151,10 +161,9 @@ class Session: return user_info def system_status(self): - import api.graphql.session.composite.system_status as system_status + from api.graphql.session.composite import system_status session = self._session - data = self._data status = {} status['host_name'] = session.show(['host', 'name']).strip() @@ -165,7 +174,6 @@ class Session: return status def gen_op_query(self): - session = self._session data = self._data name = self._name op_mode_list = self._op_mode_list @@ -189,7 +197,6 @@ class Session: return res def gen_op_mutation(self): - session = self._session data = self._data name = self._name op_mode_list = self._op_mode_list diff --git a/src/services/api/graphql/state.py b/src/services/api/graphql/state.py deleted file mode 100644 index 63db9f4ef..000000000 --- a/src/services/api/graphql/state.py +++ /dev/null @@ -1,4 +0,0 @@ - -def init(): - global settings - settings = {} diff --git a/src/services/api/rest/__init__.py b/src/services/api/rest/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/services/api/rest/__init__.py diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py new file mode 100644 index 000000000..dda50010f --- /dev/null +++ b/src/services/api/rest/models.py @@ -0,0 +1,320 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +# pylint: disable=too-few-public-methods + +import json +from html import escape +from enum import Enum +from typing import List +from typing import Union +from typing import Dict +from typing import Self + +from pydantic import BaseModel +from pydantic import StrictStr +from pydantic import field_validator +from pydantic import model_validator +from fastapi.responses import HTMLResponse + + +def error(code, msg): + msg = escape(msg, quote=False) + resp = {'success': False, 'error': msg, 'data': None} + resp = json.dumps(resp) + return HTMLResponse(resp, status_code=code) + + +def success(data): + resp = {'success': True, 'data': data, 'error': None} + resp = json.dumps(resp) + return HTMLResponse(resp) + + +# Pydantic models for validation +# Pydantic will cast when possible, so use StrictStr validators added as +# needed for additional constraints +# json_schema_extra adds anotations to OpenAPI to add examples + + +class ApiModel(BaseModel): + key: StrictStr + + +class BasePathModel(BaseModel): + op: StrictStr + path: List[StrictStr] + + @field_validator('path') + @classmethod + def check_non_empty(cls, path: str) -> str: + if not len(path) > 0: + raise ValueError('path must be non-empty') + return path + + +class BaseConfigureModel(BasePathModel): + value: StrictStr = None + + +class ConfigureModel(ApiModel, BaseConfigureModel): + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'set | delete | comment', + 'path': ['config', 'mode', 'path'], + } + } + + +class ConfigureListModel(ApiModel): + commands: List[BaseConfigureModel] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'commands': 'list of commands', + } + } + + +class BaseConfigSectionModel(BasePathModel): + section: Dict + + +class ConfigSectionModel(ApiModel, BaseConfigSectionModel): + pass + + +class ConfigSectionListModel(ApiModel): + commands: List[BaseConfigSectionModel] + + +class BaseConfigSectionTreeModel(BaseModel): + op: StrictStr + mask: Dict + config: Dict + + +class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): + pass + + +class RetrieveModel(ApiModel): + op: StrictStr + path: List[StrictStr] + configFormat: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'returnValue | returnValues | exists | showConfig', + 'path': ['config', 'mode', 'path'], + 'configFormat': 'json (default) | json_ast | raw', + } + } + + +class ConfigFileModel(ApiModel): + op: StrictStr + file: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'save | load', + 'file': 'filename', + } + } + + +class ImageOp(str, Enum): + add = 'add' + delete = 'delete' + show = 'show' + set_default = 'set_default' + + +class ImageModel(ApiModel): + op: ImageOp + url: StrictStr = None + name: StrictStr = None + + @model_validator(mode='after') + def check_data(self) -> Self: + if self.op == 'add': + if not self.url: + raise ValueError('Missing required field "url"') + elif self.op in ['delete', 'set_default']: + if not self.name: + raise ValueError('Missing required field "name"') + + return self + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show | set_default', + 'url': 'imagelocation', + 'name': 'imagename', + } + } + + +class ImportPkiModel(ApiModel): + op: StrictStr + path: List[StrictStr] + passphrase: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'import_pki', + 'path': ['op', 'mode', 'path'], + 'passphrase': 'passphrase', + } + } + + +class ContainerImageModel(ApiModel): + op: StrictStr + name: StrictStr = None + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show', + 'name': 'imagename', + } + } + + +class GenerateModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'generate', + 'path': ['op', 'mode', 'path'], + } + } + + +class ShowModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'show', + 'path': ['op', 'mode', 'path'], + } + } + + +class RebootModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'reboot', + 'path': ['op', 'mode', 'path'], + } + } + + +class ResetModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'reset', + 'path': ['op', 'mode', 'path'], + } + } + + +class PoweroffModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'poweroff', + 'path': ['op', 'mode', 'path'], + } + } + + +class TracerouteModel(ApiModel): + op: StrictStr + host: StrictStr + + class Config: + schema_extra = { + 'example': { + 'key': 'id_key', + 'op': 'traceroute', + 'host': 'host', + } + } + + +class InfoQueryParams(BaseModel): + model_config = {"extra": "forbid"} + + version: bool = True + hostname: bool = True + + +class Success(BaseModel): + success: bool + data: Union[str, bool, Dict] + error: str + + +class Error(BaseModel): + success: bool = False + data: Union[str, bool, Dict] + error: str + + +responses = { + 200: {'model': Success}, + 400: {'model': Error}, + 422: {'model': Error, 'description': 'Validation Error'}, + 500: {'model': Error}, +} diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py new file mode 100644 index 000000000..e52c77fda --- /dev/null +++ b/src/services/api/rest/routers.py @@ -0,0 +1,778 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +# pylint: disable=line-too-long,raise-missing-from,invalid-name +# pylint: disable=wildcard-import,unused-wildcard-import +# pylint: disable=broad-exception-caught + +import json +import copy +import logging +import traceback +from threading import Lock +from typing import Union +from typing import Callable +from typing import TYPE_CHECKING + +from fastapi import Depends +from fastapi import Request +from fastapi import Response +from fastapi import HTTPException +from fastapi import APIRouter +from fastapi import BackgroundTasks +from fastapi.routing import APIRoute +from starlette.datastructures import FormData +from starlette.formparsers import FormParser +from starlette.formparsers import MultiPartParser +from starlette.formparsers import MultiPartException +from multipart.multipart import parse_options_header + +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configdiff import get_config_diff +from vyos.configsession import ConfigSessionError + +from ..session import SessionState +from .models import success +from .models import error +from .models import responses +from .models import ApiModel +from .models import ConfigureModel +from .models import ConfigureListModel +from .models import ConfigSectionModel +from .models import ConfigSectionListModel +from .models import ConfigSectionTreeModel +from .models import BaseConfigSectionTreeModel +from .models import BaseConfigureModel +from .models import BaseConfigSectionModel +from .models import RetrieveModel +from .models import ConfigFileModel +from .models import ImageModel +from .models import ContainerImageModel +from .models import GenerateModel +from .models import ShowModel +from .models import RebootModel +from .models import ResetModel +from .models import ImportPkiModel +from .models import PoweroffModel +from .models import TracerouteModel + + +if TYPE_CHECKING: + from fastapi import FastAPI + + +LOG = logging.getLogger('http_api.routers') + +lock = Lock() + + +def check_auth(key_list, key): + key_id = None + for k in key_list: + if k['key'] == key: + key_id = k['id'] + return key_id + + +def auth_required(data: ApiModel): + session = SessionState() + key = data.key + api_keys = session.keys + key_id = check_auth(api_keys, key) + if not key_id: + raise HTTPException(status_code=401, detail='Valid API key is required') + session.id = key_id + + +# override Request and APIRoute classes in order to convert form request to json; +# do all explicit validation here, for backwards compatability of error messages; +# the explicit validation may be dropped, if desired, in favor of native +# validation by FastAPI/Pydantic, as is used for application/json requests +class MultipartRequest(Request): + """Override Request class to convert form request to json""" + + # pylint: disable=attribute-defined-outside-init + # pylint: disable=too-many-branches,too-many-statements + + _form_err = () + + @property + def form_err(self): + return self._form_err + + @form_err.setter + def form_err(self, val): + if not self._form_err: + self._form_err = val + + @property + def orig_headers(self): + self._orig_headers = super().headers + return self._orig_headers + + @property + def headers(self): + self._headers = super().headers.mutablecopy() + self._headers['content-type'] = 'application/json' + return self._headers + + async def _get_form( + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: + if self._form is None: + assert ( + parse_options_header is not None + ), 'The `python-multipart` library must be installed to use form parsing.' + content_type_header = self.orig_headers.get('Content-Type') + content_type: bytes + content_type, _ = parse_options_header(content_type_header) + if content_type == b'multipart/form-data': + try: + multipart_parser = MultiPartParser( + self.orig_headers, + self.stream(), + max_files=max_files, + max_fields=max_fields, + ) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if 'app' in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc + elif content_type == b'application/x-www-form-urlencoded': + form_parser = FormParser(self.orig_headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() + return self._form + + async def body(self) -> bytes: + if not hasattr(self, '_body'): + forms = {} + merge = {} + body = await super().body() + self._body = body + + form_data = await self.form() + if form_data: + endpoint = self.url.path + LOG.debug('processing form data') + for k, v in form_data.multi_items(): + forms[k] = v + + if 'data' not in forms: + self.form_err = (422, 'Non-empty data field is required') + return self._body + try: + tmp = json.loads(forms['data']) + except json.JSONDecodeError as e: + self.form_err = (400, f'Failed to parse JSON: {e}') + return self._body + if isinstance(tmp, list): + merge['commands'] = tmp + else: + merge = tmp + + if 'commands' in merge: + cmds = merge['commands'] + else: + cmds = copy.deepcopy(merge) + cmds = [cmds] + + for c in cmds: + if not isinstance(c, dict): + self.form_err = ( + 400, + f"Malformed command '{c}': any command must be JSON of dict", + ) + return self._body + if 'op' not in c: + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'op' field", + ) + if endpoint not in ( + '/config-file', + '/container-image', + '/image', + '/configure-section', + '/traceroute', + ): + if 'path' not in c: + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'path' field", + ) + elif not isinstance(c['path'], list): + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' field must be a list", + ) + elif not all(isinstance(el, str) for el in c['path']): + self.form_err = ( + 400, + f"Malformed command '{0}': 'path' field must be a list of strings", + ) + if endpoint in ('/configure'): + if not c['path']: + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' list must be non-empty", + ) + if 'value' in c and not isinstance(c['value'], str): + self.form_err = ( + 400, + f"Malformed command '{c}': 'value' field must be a string", + ) + if endpoint in ('/configure-section'): + if 'section' not in c and 'config' not in c: + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'section' or 'config' field", + ) + + if 'key' not in forms and 'key' not in merge: + self.form_err = (401, 'Valid API key is required') + if 'key' in forms and 'key' not in merge: + merge['key'] = forms['key'] + + new_body = json.dumps(merge) + new_body = new_body.encode() + self._body = new_body + + return self._body + + +class MultipartRoute(APIRoute): + """Override APIRoute class to convert form request to json""" + + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + request = MultipartRequest(request.scope, request.receive) + try: + response: Response = await original_route_handler(request) + except HTTPException as e: + return error(e.status_code, e.detail) + except Exception as e: + form_err = request.form_err + if form_err: + return error(*form_err) + raise e + + return response + + return custom_route_handler + + +router = APIRouter( + route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)], +) + + +self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background' + + +def call_commit(s: SessionState): + try: + s.session.commit() + except ConfigSessionError as e: + s.session.discard() + if s.debug: + LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}') + else: + LOG.warning(f'ConfigSessionError: {e}') + + +def _configure_op( + data: Union[ + ConfigureModel, + ConfigureListModel, + ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel, + ], + _request: Request, + background_tasks: BackgroundTasks, +): + # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements + # pylint: disable=consider-using-with + + state = SessionState() + session = state.session + env = session.get_session_env() + + # Allow users to pass just one command + if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)): + data = [data] + else: + data = data.commands + + # We don't want multiple people/apps to be able to commit at once, + # or modify the shared session while someone else is doing the same, + # so the lock is really global + lock.acquire() + + config = Config(session_env=env) + + status = 200 + msg = None + error_msg = None + try: + for c in data: + op = c.op + if not isinstance(c, BaseConfigSectionTreeModel): + path = c.path + + if isinstance(c, BaseConfigureModel): + if c.value: + value = c.value + else: + value = '' + # For vyos.configsession calls that have no separate value arguments, + # and for type checking too + cfg_path = ' '.join(path + [value]).strip() + + elif isinstance(c, BaseConfigSectionModel): + section = c.section + + elif isinstance(c, BaseConfigSectionTreeModel): + mask = c.mask + config = c.config + + if isinstance(c, BaseConfigureModel): + if op == 'set': + session.set(path, value=value) + elif op == 'delete': + if state.strict and not config.exists(cfg_path): + raise ConfigSessionError( + f'Cannot delete [{cfg_path}]: path/value does not exist' + ) + session.delete(path, value=value) + elif op == 'comment': + session.comment(path, value=value) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + + elif isinstance(c, BaseConfigSectionModel): + if op == 'set': + session.set_section(path, section) + elif op == 'load': + session.load_section(path, section) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + + elif isinstance(c, BaseConfigSectionTreeModel): + if op == 'set': + session.set_section_tree(config) + elif op == 'load': + session.load_section_tree(mask, config) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + # end for + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, state) + msg = self_ref_msg + else: + # capture non-fatal warnings + out = session.commit() + msg = out if out else msg + + LOG.info(f"Configuration modified via HTTP API using key '{state.id}'") + except ConfigSessionError as e: + session.discard() + status = 400 + if state.debug: + LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}') + error_msg = str(e) + except Exception: + session.discard() + LOG.critical(traceback.format_exc()) + status = 500 + + # Don't give the details away to the outer world + error_msg = 'An internal error occured. Check the logs for details.' + finally: + lock.release() + + if status != 200: + return error(status, error_msg) + + return success(msg) + + +def create_path_import_pki_no_prompt(path): + correct_paths = ['ca', 'certificate', 'key-pair'] + if path[1] not in correct_paths: + return False + path[3] = '--key-filename' + path.insert(2, '--name') + return ['--pki-type'] + path[1:] + + +@router.post('/configure') +def configure_op( + data: Union[ConfigureModel, ConfigureListModel], + request: Request, + background_tasks: BackgroundTasks, +): + return _configure_op(data, request, background_tasks) + + +@router.post('/configure-section') +def configure_section_op( + data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel], + request: Request, + background_tasks: BackgroundTasks, +): + return _configure_op(data, request, background_tasks) + + +@router.post('/retrieve') +async def retrieve_op(data: RetrieveModel): + state = SessionState() + session = state.session + env = session.get_session_env() + config = Config(session_env=env) + + op = data.op + path = ' '.join(data.path) + + try: + if op == 'returnValue': + res = config.return_value(path) + elif op == 'returnValues': + res = config.return_values(path) + elif op == 'exists': + res = config.exists(path) + elif op == 'showConfig': + config_format = 'json' + if data.configFormat: + config_format = data.configFormat + + res = session.show_config(path=data.path) + if config_format == 'json': + config_tree = ConfigTree(res) + res = json.loads(config_tree.to_json()) + elif config_format == 'json_ast': + config_tree = ConfigTree(res) + res = json.loads(config_tree.to_json_ast()) + elif config_format == 'raw': + pass + else: + return error(400, f"'{config_format}' is not a valid config format") + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/config-file') +def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): + state = SessionState() + session = state.session + env = session.get_session_env() + op = data.op + msg = None + + try: + if op == 'save': + if data.file: + path = data.file + else: + path = '/config/config.boot' + msg = session.save_config(path) + elif op == 'load': + if data.file: + path = data.file + else: + return error(400, 'Missing required field "file"') + + session.migrate_and_load_config(path) + + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, state) + msg = self_ref_msg + else: + session.commit() + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(msg) + + +@router.post('/image') +def image_op(data: ImageModel): + state = SessionState() + session = state.session + + op = data.op + + try: + if op == 'add': + res = session.install_image(data.url) + elif op == 'delete': + res = session.remove_image(data.name) + elif op == 'show': + res = session.show(['system', 'image']) + elif op == 'set_default': + res = session.set_default_image(data.name) + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/container-image') +def container_image_op(data: ContainerImageModel): + state = SessionState() + session = state.session + + op = data.op + + try: + if op == 'add': + if data.name: + name = data.name + else: + return error(400, 'Missing required field "name"') + res = session.add_container_image(name) + elif op == 'delete': + if data.name: + name = data.name + else: + return error(400, 'Missing required field "name"') + res = session.delete_container_image(name) + elif op == 'show': + res = session.show_container_image() + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/generate') +def generate_op(data: GenerateModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'generate': + res = session.generate(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/show') +def show_op(data: ShowModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'show': + res = session.show(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/reboot') +def reboot_op(data: RebootModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'reboot': + res = session.reboot(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/reset') +def reset_op(data: ResetModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'reset': + res = session.reset(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/import-pki') +def import_pki(data: ImportPkiModel): + # pylint: disable=consider-using-with + + state = SessionState() + session = state.session + + op = data.op + path = data.path + + lock.acquire() + + try: + if op == 'import-pki': + # need to get rid or interactive mode for private key + if len(path) == 5 and path[3] in ['key-file', 'private-key']: + path_no_prompt = create_path_import_pki_no_prompt(path) + if not path_no_prompt: + return error(400, f"Invalid command: {' '.join(path)}") + if data.passphrase: + path_no_prompt += ['--passphrase', data.passphrase] + res = session.import_pki_no_prompt(path_no_prompt) + else: + res = session.import_pki(path) + if not res[0].isdigit(): + return error(400, res) + # commit changes + session.commit() + res = res.split('. ')[0] + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + finally: + lock.release() + + return success(res) + + +@router.post('/poweroff') +def poweroff_op(data: PoweroffModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'poweroff': + res = session.poweroff(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') + + return success(res) + + +@router.post('/traceroute') +def traceroute_op(data: TracerouteModel): + state = SessionState() + session = state.session + + op = data.op + host = data.host + + try: + if op == 'traceroute': + res = session.traceroute(host) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occurred. Check the logs for details.') + + return success(res) + + +def rest_init(app: 'FastAPI'): + if all(r in app.routes for r in router.routes): + return + app.include_router(router) + + +def rest_clear(app: 'FastAPI'): + for r in router.routes: + if r in app.routes: + app.routes.remove(r) diff --git a/src/services/api/session.py b/src/services/api/session.py new file mode 100644 index 000000000..ad3ef660c --- /dev/null +++ b/src/services/api/session.py @@ -0,0 +1,41 @@ +# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + + +class SessionState: + # pylint: disable=attribute-defined-outside-init + # pylint: disable=too-many-instance-attributes,too-few-public-methods + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(SessionState, cls).__new__(cls) + cls._instance._initialize() + return cls._instance + + def _initialize(self): + self.session = None + self.keys = [] + self.id = None + self.rest = False + self.debug = False + self.strict = False + self.graphql = False + self.origins = [] + self.introspection = False + self.auth_type = None + self.token_exp = None + self.secret_len = None diff --git a/src/services/vyos-commitd b/src/services/vyos-commitd new file mode 100755 index 000000000..e7f2d82c7 --- /dev/null +++ b/src/services/vyos-commitd @@ -0,0 +1,457 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2025 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# +import os +import sys +import grp +import json +import signal +import socket +import typing +import logging +import traceback +import importlib.util +import io +from contextlib import redirect_stdout +from dataclasses import dataclass +from dataclasses import fields +from dataclasses import field +from dataclasses import asdict +from pathlib import Path + +import tomli + +from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import ParseDict + +from vyos.defaults import directories +from vyos.utils.boot import boot_configuration_complete +from vyos.configsource import ConfigSourceCache +from vyos.configsource import ConfigSourceError +from vyos.config import Config +from vyos.frrender import FRRender +from vyos.frrender import get_frrender_dict +from vyos import ConfigError + +from vyos.proto import vycall_pb2 + + +@dataclass +class Status: + success: bool = False + out: str = '' + + +@dataclass +class Call: + script_name: str = '' + tag_value: str = None + arg_value: str = None + reply: Status = None + + def set_reply(self, success: bool, out: str): + self.reply = Status(success=success, out=out) + + +@dataclass +class Session: + # pylint: disable=too-many-instance-attributes + + session_id: str = '' + dry_run: bool = False + atomic: bool = False + background: bool = False + config: Config = None + init: Status = None + calls: list[Call] = field(default_factory=list) + + def set_init(self, success: bool, out: str): + self.init = Status(success=success, out=out) + + +@dataclass +class ServerConf: + commitd_socket: str = '' + session_dir: str = '' + running_cache: str = '' + session_cache: str = '' + + +server_conf = None +SOCKET_PATH = None +conf_mode_scripts = None +frr = None + +CFG_GROUP = 'vyattacfg' + +script_stdout_log = '/tmp/vyos-commitd-script-stdout' + +debug = True + +logger = logging.getLogger(__name__) +logs_handler = logging.StreamHandler() +logger.addHandler(logs_handler) + +if debug: + logger.setLevel(logging.DEBUG) +else: + logger.setLevel(logging.INFO) + + +vyos_conf_scripts_dir = directories['conf_mode'] +commitd_include_file = os.path.join(directories['data'], 'configd-include.json') + + +def key_name_from_file_name(f): + return os.path.splitext(f)[0] + + +def module_name_from_key(k): + return k.replace('-', '_') + + +def path_from_file_name(f): + return os.path.join(vyos_conf_scripts_dir, f) + + +def load_conf_mode_scripts(): + with open(commitd_include_file) as f: + try: + include = json.load(f) + except OSError as e: + logger.critical(f'configd include file error: {e}') + sys.exit(1) + except json.JSONDecodeError as e: + logger.critical(f'JSON load error: {e}') + sys.exit(1) + + # import conf_mode scripts + (_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir))) + filenames.sort() + + # this is redundant, as all scripts are currently in the include file; + # leave it as an inexpensive check for future changes + load_filenames = [f for f in filenames if f in include] + imports = [key_name_from_file_name(f) for f in load_filenames] + module_names = [module_name_from_key(k) for k in imports] + paths = [path_from_file_name(f) for f in load_filenames] + to_load = list(zip(module_names, paths)) + + modules = [] + + for x in to_load: + spec = importlib.util.spec_from_file_location(x[0], x[1]) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + modules.append(module) + + scripts = dict(zip(imports, modules)) + + return scripts + + +def get_session_out(session: Session) -> str: + out = '' + if session.init and session.init.out: + out = f'{out} + init: {session.init.out} + \n' + for call in session.calls: + reply = call.reply + if reply and reply.out: + out = f'{out} + {call.script_name}: {reply.out} + \n' + return out + + +def write_stdout_log(file_name, session): + if boot_configuration_complete(): + return + with open(file_name, 'a') as f: + f.write(get_session_out(session)) + + +def msg_to_commit_data(msg: vycall_pb2.Commit) -> Session: + # pylint: disable=no-member + + d = MessageToDict(msg, preserving_proto_field_name=True) + + # wrap in dataclasses + session = Session(**d) + session.init = Status(**session.init) if session.init else None + session.calls = list(map(lambda x: Call(**x), session.calls)) + for call in session.calls: + call.reply = Status(**call.reply) if call.reply else None + + return session + + +def commit_data_to_msg(obj: Session) -> vycall_pb2.Commit: + # pylint: disable=no-member + + # avoid asdict attempt of deepcopy on Config obj + obj.config = None + + msg = vycall_pb2.Commit() + msg = ParseDict(asdict(obj), msg, ignore_unknown_fields=True) + + return msg + + +def initialization(session: Session) -> Session: + running_cache = os.path.join(server_conf.session_dir, server_conf.running_cache) + session_cache = os.path.join(server_conf.session_dir, server_conf.session_cache) + try: + configsource = ConfigSourceCache( + running_config_cache=running_cache, + session_config_cache=session_cache, + ) + except ConfigSourceError as e: + fail_msg = f'Failed to read config caches: {e}' + logger.critical(fail_msg) + session.set_init(False, fail_msg) + return session + + session.set_init(True, '') + + config = Config(config_source=configsource) + + dependent_func: dict[str, list[typing.Callable]] = {} + setattr(config, 'dependent_func', dependent_func) + + scripts_called = [] + setattr(config, 'scripts_called', scripts_called) + + dry_run = session.dry_run + config.set_bool_attr('dry_run', dry_run) + logger.debug(f'commit dry_run is {dry_run}') + + session.config = config + + return session + + +def run_script(script_name: str, config: Config, args: list) -> tuple[bool, str]: + # pylint: disable=broad-exception-caught + + script = conf_mode_scripts[script_name] + script.argv = args + config.set_level([]) + dry_run = config.get_bool_attr('dry_run') + try: + c = script.get_config(config) + script.verify(c) + if not dry_run: + script.generate(c) + script.apply(c) + else: + if hasattr(script, 'call_dependents'): + script.call_dependents() + except ConfigError as e: + logger.error(e) + return False, str(e) + except Exception: + tb = traceback.format_exc() + logger.error(tb) + return False, tb + + return True, '' + + +def process_call_data(call: Call, config: Config, last: bool = False) -> None: + # pylint: disable=too-many-locals + + script_name = key_name_from_file_name(call.script_name) + + if script_name not in conf_mode_scripts: + fail_msg = f'No such script: {call.script_name}' + logger.critical(fail_msg) + call.set_reply(False, fail_msg) + return + + config.dependency_list.clear() + + tag_value = call.tag_value if call.tag_value is not None else '' + os.environ['VYOS_TAGNODE_VALUE'] = tag_value + + args = call.arg_value.split() if call.arg_value else [] + args.insert(0, f'{script_name}.py') + + tag_ext = f'_{tag_value}' if tag_value else '' + script_record = f'{script_name}{tag_ext}' + scripts_called = getattr(config, 'scripts_called', []) + scripts_called.append(script_record) + + with redirect_stdout(io.StringIO()) as o: + success, err_out = run_script(script_name, config, args) + amb_out = o.getvalue() + o.close() + + out = amb_out + err_out + + call.set_reply(success, out) + + logger.info(f'[{script_name}] {out}') + + if last: + scripts_called = getattr(config, 'scripts_called', []) + logger.debug(f'scripts_called: {scripts_called}') + + if last and success: + tmp = get_frrender_dict(config) + if frr.generate(tmp): + # only apply a new FRR configuration if anything changed + # in comparison to the previous applied configuration + frr.apply() + + +def process_session_data(session: Session) -> Session: + if session.init is None or not session.init.success: + return session + + config = session.config + len_calls = len(session.calls) + for index, call in enumerate(session.calls): + process_call_data(call, config, last=len_calls == index + 1) + + return session + + +def read_message(msg: bytes) -> Session: + """Read message into Session instance""" + + message = vycall_pb2.Commit() # pylint: disable=no-member + message.ParseFromString(msg) + session = msg_to_commit_data(message) + + session = initialization(session) + session = process_session_data(session) + + write_stdout_log(script_stdout_log, session) + + return session + + +def write_reply(session: Session) -> bytearray: + """Serialize modified object to bytearray, prepending data length + header""" + + reply = commit_data_to_msg(session) + encoded_data = reply.SerializeToString() + byte_size = reply.ByteSize() + length_bytes = byte_size.to_bytes(4) + arr = bytearray(length_bytes) + arr.extend(encoded_data) + + return arr + + +def load_server_conf() -> ServerConf: + # pylint: disable=import-outside-toplevel + # pylint: disable=broad-exception-caught + from vyos.defaults import vyconfd_conf + + try: + with open(vyconfd_conf, 'rb') as f: + vyconfd_conf_d = tomli.load(f) + + except Exception as e: + logger.critical(f'Failed to open the vyconfd.conf file {vyconfd_conf}: {e}') + sys.exit(1) + + app = vyconfd_conf_d.get('appliance', {}) + + conf_data = { + k: v for k, v in app.items() if k in [_.name for _ in fields(ServerConf)] + } + + conf = ServerConf(**conf_data) + + return conf + + +def remove_if_exists(f: str): + try: + os.unlink(f) + except FileNotFoundError: + pass + + +def sig_handler(_signum, _frame): + logger.info('stopping server') + raise KeyboardInterrupt + + +def run_server(): + # pylint: disable=global-statement + + global server_conf + global SOCKET_PATH + global conf_mode_scripts + global frr + + signal.signal(signal.SIGTERM, sig_handler) + signal.signal(signal.SIGINT, sig_handler) + + logger.info('starting server') + + server_conf = load_server_conf() + SOCKET_PATH = server_conf.commitd_socket + conf_mode_scripts = load_conf_mode_scripts() + + cfg_group = grp.getgrnam(CFG_GROUP) + os.setgid(cfg_group.gr_gid) + + server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + remove_if_exists(SOCKET_PATH) + server_socket.bind(SOCKET_PATH) + Path(SOCKET_PATH).chmod(0o775) + + # We only need one long-lived instance of FRRender + frr = FRRender() + + server_socket.listen(2) + while True: + try: + conn, _ = server_socket.accept() + logger.debug('connection accepted') + while True: + # receive size of data + data_length = conn.recv(4) + if not data_length: + logger.debug('no data') + # if no data break + break + + length = int.from_bytes(data_length) + # receive data + data = conn.recv(length) + + session = read_message(data) + reply = write_reply(session) + conn.sendall(reply) + + conn.close() + logger.debug('connection closed') + + except KeyboardInterrupt: + break + + server_socket.close() + sys.exit(0) + + +if __name__ == '__main__': + run_server() diff --git a/src/services/vyos-configd b/src/services/vyos-configd index 3674d9627..28acccd2c 100755 --- a/src/services/vyos-configd +++ b/src/services/vyos-configd @@ -14,6 +14,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +# pylint: disable=redefined-outer-name + import os import sys import grp @@ -22,9 +24,13 @@ import json import typing import logging import signal +import traceback import importlib.util +import io +from contextlib import redirect_stdout +from enum import Enum + import zmq -from contextlib import contextmanager from vyos.defaults import directories from vyos.utils.boot import boot_configuration_complete @@ -32,6 +38,8 @@ from vyos.configsource import ConfigSourceString from vyos.configsource import ConfigSourceError from vyos.configdiff import get_commit_scripts from vyos.config import Config +from vyos.frrender import FRRender +from vyos.frrender import get_frrender_dict from vyos import ConfigError CFG_GROUP = 'vyattacfg' @@ -49,13 +57,18 @@ if debug: else: logger.setLevel(logging.INFO) -SOCKET_PATH = "ipc:///run/vyos-configd.sock" +SOCKET_PATH = 'ipc:///run/vyos-configd.sock' +MAX_MSG_SIZE = 65535 +PAD_MSG_SIZE = 6 + # Response error codes -R_SUCCESS = 1 -R_ERROR_COMMIT = 2 -R_ERROR_DAEMON = 4 -R_PASS = 8 +class Response(Enum): + SUCCESS = 1 + ERROR_COMMIT = 2 + ERROR_DAEMON = 4 + PASS = 8 + vyos_conf_scripts_dir = directories['conf_mode'] configd_include_file = os.path.join(directories['data'], 'configd-include.json') @@ -64,29 +77,31 @@ configd_env_unset_file = os.path.join(directories['data'], 'vyos-configd-env-uns # sourced on entering config session configd_env_file = '/etc/default/vyos-configd-env' -session_out = None -session_mode = None def key_name_from_file_name(f): return os.path.splitext(f)[0] + def module_name_from_key(k): return k.replace('-', '_') + def path_from_file_name(f): return os.path.join(vyos_conf_scripts_dir, f) + # opt-in to be run by daemon with open(configd_include_file) as f: try: include = json.load(f) except OSError as e: - logger.critical(f"configd include file error: {e}") + logger.critical(f'configd include file error: {e}') sys.exit(1) except json.JSONDecodeError as e: - logger.critical(f"JSON load error: {e}") + logger.critical(f'JSON load error: {e}') sys.exit(1) + # import conf_mode scripts (_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir))) filenames.sort() @@ -110,31 +125,17 @@ conf_mode_scripts = dict(zip(imports, modules)) exclude_set = {key_name_from_file_name(f) for f in filenames if f not in include} include_set = {key_name_from_file_name(f) for f in filenames if f in include} -@contextmanager -def stdout_redirected(filename, mode): - saved_stdout_fd = None - destination_file = None - try: - sys.stdout.flush() - saved_stdout_fd = os.dup(sys.stdout.fileno()) - destination_file = open(filename, mode) - os.dup2(destination_file.fileno(), sys.stdout.fileno()) - yield - finally: - if saved_stdout_fd is not None: - os.dup2(saved_stdout_fd, sys.stdout.fileno()) - os.close(saved_stdout_fd) - if destination_file is not None: - destination_file.close() - -def explicit_print(path, mode, msg): - try: - with open(path, mode) as f: - f.write(f"\n{msg}\n\n") - except OSError: - logger.critical("error explicit_print") -def run_script(script_name, config, args) -> int: +def write_stdout_log(file_name, msg): + if boot_configuration_complete(): + return + with open(file_name, 'a') as f: + f.write(msg) + + +def run_script(script_name, config, args) -> tuple[Response, str]: + # pylint: disable=broad-exception-caught + script = conf_mode_scripts[script_name] script.argv = args config.set_level([]) @@ -145,64 +146,54 @@ def run_script(script_name, config, args) -> int: script.apply(c) except ConfigError as e: logger.error(e) - explicit_print(session_out, session_mode, str(e)) - return R_ERROR_COMMIT - except Exception as e: - logger.critical(e) - return R_ERROR_DAEMON + return Response.ERROR_COMMIT, str(e) + except Exception: + tb = traceback.format_exc() + logger.error(tb) + return Response.ERROR_COMMIT, tb + + return Response.SUCCESS, '' - return R_SUCCESS def initialization(socket): - global session_out - global session_mode + # pylint: disable=broad-exception-caught,too-many-locals + # Reset config strings: active_string = '' session_string = '' # check first for resent init msg, in case of client timeout while True: - msg = socket.recv().decode("utf-8", "ignore") + msg = socket.recv().decode('utf-8', 'ignore') try: message = json.loads(msg) - if message["type"] == "init": - resp = "init" + if message['type'] == 'init': + resp = 'init' socket.send(resp.encode()) - except: + except Exception: break # zmq synchronous for ipc from single client: active_string = msg - resp = "active" + resp = 'active' socket.send(resp.encode()) - session_string = socket.recv().decode("utf-8", "ignore") - resp = "session" + session_string = socket.recv().decode('utf-8', 'ignore') + resp = 'session' socket.send(resp.encode()) - pid_string = socket.recv().decode("utf-8", "ignore") - resp = "pid" + pid_string = socket.recv().decode('utf-8', 'ignore') + resp = 'pid' socket.send(resp.encode()) - sudo_user_string = socket.recv().decode("utf-8", "ignore") - resp = "sudo_user" + sudo_user_string = socket.recv().decode('utf-8', 'ignore') + resp = 'sudo_user' socket.send(resp.encode()) - temp_config_dir_string = socket.recv().decode("utf-8", "ignore") - resp = "temp_config_dir" + temp_config_dir_string = socket.recv().decode('utf-8', 'ignore') + resp = 'temp_config_dir' socket.send(resp.encode()) - changes_only_dir_string = socket.recv().decode("utf-8", "ignore") - resp = "changes_only_dir" + changes_only_dir_string = socket.recv().decode('utf-8', 'ignore') + resp = 'changes_only_dir' socket.send(resp.encode()) - logger.debug(f"config session pid is {pid_string}") - logger.debug(f"config session sudo_user is {sudo_user_string}") - - try: - session_out = os.readlink(f"/proc/{pid_string}/fd/1") - session_mode = 'w' - except FileNotFoundError: - session_out = None - - # if not a 'live' session, for example on boot, write to file - if not session_out or not boot_configuration_complete(): - session_out = script_stdout_log - session_mode = 'a' + logger.debug(f'config session pid is {pid_string}') + logger.debug(f'config session sudo_user is {sudo_user_string}') os.environ['SUDO_USER'] = sudo_user_string if temp_config_dir_string: @@ -211,8 +202,9 @@ def initialization(socket): os.environ['VYATTA_CHANGES_ONLY_DIR'] = changes_only_dir_string try: - configsource = ConfigSourceString(running_config_text=active_string, - session_config_text=session_string) + configsource = ConfigSourceString( + running_config_text=active_string, session_config_text=session_string + ) except ConfigSourceError as e: logger.debug(e) return None @@ -229,10 +221,12 @@ def initialization(socket): return config -def process_node_data(config, data, last: bool = False) -> int: + +def process_node_data(config, data, _last: bool = False) -> tuple[Response, str]: if not config: - logger.critical(f"Empty config") - return R_ERROR_DAEMON + out = 'Empty config' + logger.critical(out) + return Response.ERROR_DAEMON, out script_name = None os.environ['VYOS_TAGNODE_VALUE'] = '' @@ -246,8 +240,9 @@ def process_node_data(config, data, last: bool = False) -> int: if res.group(2): script_name = res.group(2) if not script_name: - logger.critical(f"Missing script_name") - return R_ERROR_DAEMON + out = 'Missing script_name' + logger.critical(out) + return Response.ERROR_DAEMON, out if res.group(3): args = res.group(3).split() args.insert(0, f'{script_name}.py') @@ -259,26 +254,46 @@ def process_node_data(config, data, last: bool = False) -> int: scripts_called.append(script_record) if script_name not in include_set: - return R_PASS + return Response.PASS, '' + + with redirect_stdout(io.StringIO()) as o: + result, err_out = run_script(script_name, config, args) + amb_out = o.getvalue() + o.close() + + out = amb_out + err_out + + return result, out + + +def send_result(sock, err, msg): + err_no = err.value + err_name = err.name + msg = msg if msg else '' + msg_size = min(MAX_MSG_SIZE, len(msg)) - with stdout_redirected(session_out, session_mode): - result = run_script(script_name, config, args) + err_rep = err_no.to_bytes(1) + msg_size_rep = f'{msg_size:#0{PAD_MSG_SIZE}x}' + + logger.debug(f'Sending reply: {err_name} with output') + sock.send_multipart([err_rep, msg_size_rep.encode(), msg.encode()]) + + write_stdout_log(script_stdout_log, msg) - return result def remove_if_file(f: str): try: os.remove(f) except FileNotFoundError: pass - except OSError: - raise + def shutdown(): remove_if_file(configd_env_file) os.symlink(configd_env_unset_file, configd_env_file) sys.exit(0) + if __name__ == '__main__': context = zmq.Context() socket = context.socket(zmq.REP) @@ -294,6 +309,7 @@ if __name__ == '__main__': os.environ['VYOS_CONFIGD'] = 't' def sig_handler(signum, frame): + # pylint: disable=unused-argument shutdown() signal.signal(signal.SIGTERM, sig_handler) @@ -303,25 +319,33 @@ if __name__ == '__main__': remove_if_file(configd_env_file) os.symlink(configd_env_set_file, configd_env_file) - config = None + # We only need one long-lived instance of FRRender + frr = FRRender() + config = None while True: # Wait for next request from client msg = socket.recv().decode() - logger.debug(f"Received message: {msg}") + logger.debug(f'Received message: {msg}') message = json.loads(msg) - if message["type"] == "init": - resp = "init" + if message['type'] == 'init': + resp = 'init' socket.send(resp.encode()) config = initialization(socket) - elif message["type"] == "node": - res = process_node_data(config, message["data"], message["last"]) - response = res.to_bytes(1, byteorder=sys.byteorder) - logger.debug(f"Sending response {res}") - socket.send(response) - if message["last"] and config: + elif message['type'] == 'node': + res, out = process_node_data(config, message['data'], message['last']) + send_result(socket, res, out) + + if message['last'] and config: scripts_called = getattr(config, 'scripts_called', []) logger.debug(f'scripts_called: {scripts_called}') + + if res == Response.SUCCESS: + tmp = get_frrender_dict(config) + if frr.generate(tmp): + # only apply a new FRR configuration if anything changed + # in comparison to the previous applied configuration + frr.apply() else: - logger.critical(f"Unexpected message: {message}") + logger.critical(f'Unexpected message: {message}') diff --git a/src/services/vyos-conntrack-logger b/src/services/vyos-conntrack-logger index 9c31b465f..ec0e1f717 100755 --- a/src/services/vyos-conntrack-logger +++ b/src/services/vyos-conntrack-logger @@ -15,10 +15,8 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. import argparse -import grp import logging import multiprocessing -import os import queue import signal import socket diff --git a/src/services/vyos-domain-resolver b/src/services/vyos-domain-resolver new file mode 100755 index 000000000..fb18724af --- /dev/null +++ b/src/services/vyos-domain-resolver @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2022-2024 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +import json +import time +import logging +import os + +from vyos.configdict import dict_merge +from vyos.configquery import ConfigTreeQuery +from vyos.firewall import fqdn_config_parse +from vyos.firewall import fqdn_resolve +from vyos.ifconfig import WireGuardIf +from vyos.remote import download +from vyos.utils.commit import commit_in_progress +from vyos.utils.dict import dict_search_args +from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME +from vyos.utils.file import makedir, chmod_775, write_file, read_file +from vyos.utils.network import is_valid_ipv4_address_or_range, is_valid_ipv6_address_or_range +from vyos.utils.process import cmd +from vyos.utils.process import run +from vyos.xml_ref import get_defaults + +base = ['firewall'] +timeout = 300 +cache = False +base_firewall = ['firewall'] +base_nat = ['nat'] +base_interfaces = ['interfaces'] + +firewall_config_dir = "/config/firewall" + +domain_state = {} + +ipv4_tables = { + 'ip vyos_mangle', + 'ip vyos_filter', + 'ip vyos_nat', + 'ip raw' +} + +ipv6_tables = { + 'ip6 vyos_mangle', + 'ip6 vyos_filter', + 'ip6 raw' +} + +logger = logging.getLogger(__name__) +logs_handler = logging.StreamHandler() +logger.addHandler(logs_handler) +logger.setLevel(logging.INFO) + +def get_config(conf, node): + node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True, + no_tag_node_value_mangle=True) + + default_values = get_defaults(node, get_first_key=True) + + node_config = dict_merge(default_values, node_config) + + if node == base_firewall and 'global_options' in node_config: + global_config = node_config['global_options'] + global timeout, cache + + if 'resolver_interval' in global_config: + timeout = int(global_config['resolver_interval']) + + if 'resolver_cache' in global_config: + cache = True + + fqdn_config_parse(node_config, node[0]) + + return node_config + +def resolve(domains, ipv6=False): + global domain_state + + ip_list = set() + + for domain in domains: + resolved = fqdn_resolve(domain, ipv6=ipv6) + + cache_key = f'{domain}_ipv6' if ipv6 else domain + + if resolved and cache: + domain_state[cache_key] = resolved + elif not resolved: + if cache_key not in domain_state: + continue + resolved = domain_state[cache_key] + + ip_list = ip_list | resolved + return ip_list + +def nft_output(table, set_name, ip_list): + output = [f'flush set {table} {set_name}'] + if ip_list: + ip_str = ','.join(ip_list) + output.append(f'add element {table} {set_name} {{ {ip_str} }}') + return output + +def nft_valid_sets(): + try: + valid_sets = [] + sets_json = cmd('nft --json list sets') + sets_obj = json.loads(sets_json) + + for obj in sets_obj['nftables']: + if 'set' in obj: + family = obj['set']['family'] + table = obj['set']['table'] + name = obj['set']['name'] + valid_sets.append((f'{family} {table}', name)) + + return valid_sets + except: + return [] + +def update_remote_group(config): + conf_lines = [] + count = 0 + valid_sets = nft_valid_sets() + + remote_groups = dict_search_args(config, 'group', 'remote_group') + if remote_groups: + # Create directory for list files if necessary + if not os.path.isdir(firewall_config_dir): + makedir(firewall_config_dir, group='vyattacfg') + chmod_775(firewall_config_dir) + + for set_name, remote_config in remote_groups.items(): + if 'url' not in remote_config: + continue + nft_ip_set_name = f'R_{set_name}' + nft_ip6_set_name = f'R6_{set_name}' + + # Create list file if necessary + list_file = os.path.join(firewall_config_dir, f"{nft_ip_set_name}.txt") + if not os.path.exists(list_file): + write_file(list_file, '', user="root", group="vyattacfg", mode=0o644) + + # Attempt to download file, use cached version if download fails + try: + download(list_file, remote_config['url'], raise_error=True) + except: + logger.error(f'Failed to download list-file for {set_name} remote group') + logger.info(f'Using cached list-file for {set_name} remote group') + + # Read list file + ip_list = [] + ip6_list = [] + invalid_list = [] + for line in read_file(list_file).splitlines(): + line_first_word = line.strip().partition(' ')[0] + + if is_valid_ipv4_address_or_range(line_first_word): + ip_list.append(line_first_word) + elif is_valid_ipv6_address_or_range(line_first_word): + ip6_list.append(line_first_word) + else: + if line_first_word[0].isalnum(): + invalid_list.append(line_first_word) + + # Load ip tables + for table in ipv4_tables: + if (table, nft_ip_set_name) in valid_sets: + conf_lines += nft_output(table, nft_ip_set_name, ip_list) + + # Load ip6 tables + for table in ipv6_tables: + if (table, nft_ip6_set_name) in valid_sets: + conf_lines += nft_output(table, nft_ip6_set_name, ip6_list) + + invalid_str = ", ".join(invalid_list) + if invalid_str: + logger.info(f'Invalid address for set {set_name}: {invalid_str}') + + count += 1 + + nft_conf_str = "\n".join(conf_lines) + "\n" + code = run(f'nft --file -', input=nft_conf_str) + + logger.info(f'Updated {count} remote-groups in firewall - result: {code}') + + +def update_fqdn(config, node): + conf_lines = [] + count = 0 + valid_sets = nft_valid_sets() + + if node == 'firewall': + domain_groups = dict_search_args(config, 'group', 'domain_group') + if domain_groups: + for set_name, domain_config in domain_groups.items(): + if 'address' not in domain_config: + continue + nft_set_name = f'D_{set_name}' + domains = domain_config['address'] + + ip_list = resolve(domains, ipv6=False) + for table in ipv4_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + ip6_list = resolve(domains, ipv6=True) + for table in ipv6_tables: + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip6_list) + count += 1 + + for set_name, domain in config['ip_fqdn'].items(): + table = 'ip vyos_filter' + nft_set_name = f'FQDN_{set_name}' + ip_list = resolve([domain], ipv6=False) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + for set_name, domain in config['ip6_fqdn'].items(): + table = 'ip6 vyos_filter' + nft_set_name = f'FQDN_{set_name}' + ip_list = resolve([domain], ipv6=True) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + else: + # It's NAT + for set_name, domain in config['ip_fqdn'].items(): + table = 'ip vyos_nat' + nft_set_name = f'FQDN_nat_{set_name}' + ip_list = resolve([domain], ipv6=False) + if (table, nft_set_name) in valid_sets: + conf_lines += nft_output(table, nft_set_name, ip_list) + count += 1 + + nft_conf_str = "\n".join(conf_lines) + "\n" + code = run(f'nft --file -', input=nft_conf_str) + + logger.info(f'Updated {count} sets in {node} - result: {code}') + +def update_interfaces(config, node): + if node == 'interfaces': + wg_interfaces = dict_search_args(config, 'wireguard') + if wg_interfaces: + + peer_public_keys = {} + # for each wireguard interfaces + for interface, wireguard in wg_interfaces.items(): + peer_public_keys[interface] = [] + for peer, peer_config in wireguard['peer'].items(): + # check peer if peer host-name or address is set + if 'host_name' in peer_config or 'address' in peer_config: + # check latest handshake + peer_public_keys[interface].append( + peer_config['public_key'] + ) + + now_time = time.time() + for (interface, check_peer_public_keys) in peer_public_keys.items(): + if len(check_peer_public_keys) == 0: + continue + + intf = WireGuardIf(interface, create=False, debug=False) + handshakes = intf.operational.get_latest_handshakes() + + # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME + # if data is being transmitted between the peers. If no data is + # transmitted, the handshake will not be initiated unless new + # data begins to flow. Each handshake generates a new session + # key, and the key is rotated at least every 120 seconds or + # upon data transmission after a prolonged silence. + for public_key, handshake_time in handshakes.items(): + if public_key in check_peer_public_keys and ( + handshake_time == 0 + or (now_time - handshake_time > 3*WIREGUARD_REKEY_AFTER_TIME) + ): + intf.operational.reset_peer(public_key=public_key) + +if __name__ == '__main__': + logger.info('VyOS domain resolver') + + count = 1 + while commit_in_progress(): + if ( count % 60 == 0 ): + logger.info(f'Commit still in progress after {count}s - waiting') + count += 1 + time.sleep(1) + + conf = ConfigTreeQuery() + firewall = get_config(conf, base_firewall) + nat = get_config(conf, base_nat) + interfaces = get_config(conf, base_interfaces) + + logger.info(f'interval: {timeout}s - cache: {cache}') + + while True: + update_fqdn(firewall, 'firewall') + update_fqdn(nat, 'nat') + update_remote_group(firewall) + update_interfaces(interfaces, 'interfaces') + time.sleep(timeout) diff --git a/src/services/vyos-hostsd b/src/services/vyos-hostsd index 1ba90471e..44f03586c 100755 --- a/src/services/vyos-hostsd +++ b/src/services/vyos-hostsd @@ -233,10 +233,7 @@ # } import os -import sys -import time import json -import signal import traceback import re import logging @@ -245,7 +242,6 @@ import zmq from voluptuous import Schema, MultipleInvalid, Required, Any from collections import OrderedDict from vyos.utils.file import makedir -from vyos.utils.permission import chown from vyos.utils.permission import chmod_755 from vyos.utils.process import popen from vyos.utils.process import process_named_running diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 97633577d..be3dd5051 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -17,946 +17,135 @@ import os import sys import grp -import copy import json import logging import signal import traceback -import threading -from enum import Enum - from time import sleep -from typing import List, Union, Callable, Dict, Self +from typing import Annotated -from fastapi import FastAPI, Depends, Request, Response, HTTPException -from fastapi import BackgroundTasks -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, Query from fastapi.exceptions import RequestValidationError -from fastapi.routing import APIRoute -from pydantic import BaseModel, StrictStr, validator, model_validator -from starlette.middleware.cors import CORSMiddleware -from starlette.datastructures import FormData -from starlette.formparsers import FormParser, MultiPartParser -from multipart.multipart import parse_options_header from uvicorn import Config as UvicornConfig from uvicorn import Server as UvicornServer -from ariadne.asgi import GraphQL - -from vyos.config import Config -from vyos.configtree import ConfigTree -from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSession -from vyos.configsession import ConfigSessionError from vyos.defaults import api_config_state +from vyos.utils.file import read_file +from vyos.version import get_version -import api.graphql.state +from api.session import SessionState +from api.rest.models import error, InfoQueryParams, success CFG_GROUP = 'vyattacfg' debug = True -logger = logging.getLogger(__name__) +LOG = logging.getLogger('http_api') logs_handler = logging.StreamHandler() -logger.addHandler(logs_handler) +LOG.addHandler(logs_handler) if debug: - logger.setLevel(logging.DEBUG) + LOG.setLevel(logging.DEBUG) else: - logger.setLevel(logging.INFO) + LOG.setLevel(logging.INFO) -# Giant lock! -lock = threading.Lock() def load_server_config(): with open(api_config_state) as f: config = json.load(f) return config -def check_auth(key_list, key): - key_id = None - for k in key_list: - if k['key'] == key: - key_id = k['id'] - return key_id - -def error(code, msg): - resp = {"success": False, "error": msg, "data": None} - resp = json.dumps(resp) - return HTMLResponse(resp, status_code=code) - -def success(data): - resp = {"success": True, "data": data, "error": None} - resp = json.dumps(resp) - return HTMLResponse(resp) - -# Pydantic models for validation -# Pydantic will cast when possible, so use StrictStr -# validators added as needed for additional constraints -# schema_extra adds anotations to OpenAPI, to add examples - -class ApiModel(BaseModel): - key: StrictStr - -class BasePathModel(BaseModel): - op: StrictStr - path: List[StrictStr] - - @validator("path") - def check_non_empty(cls, path): - if not len(path) > 0: - raise ValueError('path must be non-empty') - return path - -class BaseConfigureModel(BasePathModel): - value: StrictStr = None - -class ConfigureModel(ApiModel, BaseConfigureModel): - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "set | delete | comment", - "path": ['config', 'mode', 'path'], - } - } - -class ConfigureListModel(ApiModel): - commands: List[BaseConfigureModel] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "commands": "list of commands", - } - } - -class BaseConfigSectionModel(BasePathModel): - section: Dict - -class ConfigSectionModel(ApiModel, BaseConfigSectionModel): - pass - -class ConfigSectionListModel(ApiModel): - commands: List[BaseConfigSectionModel] - -class BaseConfigSectionTreeModel(BaseModel): - op: StrictStr - mask: Dict - config: Dict - -class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): - pass - -class RetrieveModel(ApiModel): - op: StrictStr - path: List[StrictStr] - configFormat: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "returnValue | returnValues | exists | showConfig", - "path": ['config', 'mode', 'path'], - "configFormat": "json (default) | json_ast | raw", - - } - } - -class ConfigFileModel(ApiModel): - op: StrictStr - file: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "save | load", - "file": "filename", - } - } - - -class ImageOp(str, Enum): - add = "add" - delete = "delete" - show = "show" - set_default = "set_default" - - -class ImageModel(ApiModel): - op: ImageOp - url: StrictStr = None - name: StrictStr = None - - @model_validator(mode='after') - def check_data(self) -> Self: - if self.op == 'add': - if not self.url: - raise ValueError("Missing required field \"url\"") - elif self.op in ['delete', 'set_default']: - if not self.name: - raise ValueError("Missing required field \"name\"") - - return self - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show | set_default", - "url": "imagelocation", - "name": "imagename", - } - } - -class ImportPkiModel(ApiModel): - op: StrictStr - path: List[StrictStr] - passphrase: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "import_pki", - "path": ["op", "mode", "path"], - "passphrase": "passphrase", - } - } - - -class ContainerImageModel(ApiModel): - op: StrictStr - name: StrictStr = None - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show", - "name": "imagename", - } - } - -class GenerateModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "generate", - "path": ["op", "mode", "path"], - } - } - -class ShowModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "show", - "path": ["op", "mode", "path"], - } - } - -class RebootModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "reboot", - "path": ["op", "mode", "path"], - } - } - -class ResetModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "reset", - "path": ["op", "mode", "path"], - } - } - -class PoweroffModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - schema_extra = { - "example": { - "key": "id_key", - "op": "poweroff", - "path": ["op", "mode", "path"], - } - } - - -class Success(BaseModel): - success: bool - data: Union[str, bool, Dict] - error: str - -class Error(BaseModel): - success: bool = False - data: Union[str, bool, Dict] - error: str - -responses = { - 200: {'model': Success}, - 400: {'model': Error}, - 422: {'model': Error, 'description': 'Validation Error'}, - 500: {'model': Error} -} - -def auth_required(data: ApiModel): - key = data.key - api_keys = app.state.vyos_keys - key_id = check_auth(api_keys, key) - if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") - app.state.vyos_id = key_id - -# override Request and APIRoute classes in order to convert form request to json; -# do all explicit validation here, for backwards compatability of error messages; -# the explicit validation may be dropped, if desired, in favor of native -# validation by FastAPI/Pydantic, as is used for application/json requests -class MultipartRequest(Request): - _form_err = () - @property - def form_err(self): - return self._form_err - - @form_err.setter - def form_err(self, val): - if not self._form_err: - self._form_err = val - - @property - def orig_headers(self): - self._orig_headers = super().headers - return self._orig_headers - - @property - def headers(self): - self._headers = super().headers.mutablecopy() - self._headers['content-type'] = 'application/json' - return self._headers - - async def form(self) -> FormData: - if self._form is None: - assert ( - parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.orig_headers, self.stream()) - self._form = await multipart_parser.parse() - elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.orig_headers, self.stream()) - self._form = await form_parser.parse() - else: - self._form = FormData() - return self._form - - async def body(self) -> bytes: - if not hasattr(self, "_body"): - forms = {} - merge = {} - body = await super().body() - self._body = body - - form_data = await self.form() - if form_data: - endpoint = self.url.path - logger.debug("processing form data") - for k, v in form_data.multi_items(): - forms[k] = v - - if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") - return self._body - else: - try: - tmp = json.loads(forms['data']) - except json.JSONDecodeError as e: - self.form_err = (400, f'Failed to parse JSON: {e}') - return self._body - if isinstance(tmp, list): - merge['commands'] = tmp - else: - merge = tmp - - if 'commands' in merge: - cmds = merge['commands'] - else: - cmds = copy.deepcopy(merge) - cmds = [cmds] - - for c in cmds: - if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") - return self._body - if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): - if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") - elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") - elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") - if endpoint in ('/configure'): - if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") - if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") - if endpoint in ('/configure-section'): - if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") - - if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") - if 'key' in forms and 'key' not in merge: - merge['key'] = forms['key'] - - new_body = json.dumps(merge) - new_body = new_body.encode() - self._body = new_body - - return self._body - -class MultipartRoute(APIRoute): - def get_route_handler(self) -> Callable: - original_route_handler = super().get_route_handler() - - async def custom_route_handler(request: Request) -> Response: - request = MultipartRequest(request.scope, request.receive) - try: - response: Response = await original_route_handler(request) - except HTTPException as e: - return error(e.status_code, e.detail) - except Exception as e: - form_err = request.form_err - if form_err: - return error(*form_err) - raise e - - return response - - return custom_route_handler app = FastAPI(debug=True, title="VyOS API", - version="0.1.0", - responses={**responses}, - dependencies=[Depends(auth_required)]) + version="0.1.0") -app.router.route_class = MultipartRoute @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request, exc): +async def validation_exception_handler(_request, exc): return error(400, str(exc.errors()[0])) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" - -def call_commit(s: ConfigSession): - try: - s.commit() - except ConfigSessionError as e: - s.discard() - if app.state.vyos_debug: - logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}") - else: - logger.warning(f"ConfigSessionError: {e}") - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): - session = app.state.vyos_session - env = session.get_session_env() - - endpoint = request.url.path - - # Allow users to pass just one command - if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)): - data = [data] - else: - data = data.commands - - # We don't want multiple people/apps to be able to commit at once, - # or modify the shared session while someone else is doing the same, - # so the lock is really global - lock.acquire() - - config = Config(session_env=env) - - status = 200 - msg = None - error_msg = None - try: - for c in data: - op = c.op - if not isinstance(c, BaseConfigSectionTreeModel): - path = c.path - - if isinstance(c, BaseConfigureModel): - if c.value: - value = c.value - else: - value = "" - # For vyos.configsession calls that have no separate value arguments, - # and for type checking too - cfg_path = " ".join(path + [value]).strip() - - elif isinstance(c, BaseConfigSectionModel): - section = c.section - - elif isinstance(c, BaseConfigSectionTreeModel): - mask = c.mask - config = c.config - - if isinstance(c, BaseConfigureModel): - if op == 'set': - session.set(path, value=value) - elif op == 'delete': - if app.state.vyos_strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") - session.delete(path, value=value) - elif op == 'comment': - session.comment(path, value=value) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - - elif isinstance(c, BaseConfigSectionModel): - if op == 'set': - session.set_section(path, section) - elif op == 'load': - session.load_section(path, section) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - - elif isinstance(c, BaseConfigSectionTreeModel): - if op == 'set': - session.set_section_tree(config) - elif op == 'load': - session.load_section_tree(mask, config) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - # end for - config = Config(session_env=env) - d = get_config_diff(config) - - if d.is_node_changed(['service', 'https']): - background_tasks.add_task(call_commit, session) - msg = self_ref_msg - else: - session.commit() - - logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'") - except ConfigSessionError as e: - session.discard() - status = 400 - if app.state.vyos_debug: - logger.critical(f"ConfigSessionError:\n {traceback.format_exc()}") - error_msg = str(e) - except Exception as e: - session.discard() - logger.critical(traceback.format_exc()) - status = 500 - - # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." - finally: - lock.release() - - if status != 200: - return error(status, error_msg) - - return success(msg) - -def create_path_import_pki_no_prompt(path): - correct_paths = ['ca', 'certificate', 'key-pair'] - if path[1] not in correct_paths: - return False - path[1] = '--' + path[1].replace('-', '') - path[3] = '--key-filename' - return path[1:] - -@app.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): - return _configure_op(data, request, background_tasks) - -@app.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): - return _configure_op(data, request, background_tasks) -@app.post("/retrieve") -async def retrieve_op(data: RetrieveModel): - session = app.state.vyos_session - env = session.get_session_env() - config = Config(session_env=env) +@app.get('/info') +def info(q: Annotated[InfoQueryParams, Query()]): + show_version = q.version + show_hostname = q.hostname - op = data.op - path = " ".join(data.path) + prelogin_file = r'/etc/issue' + hostname_file = r'/etc/hostname' + default = 'Welcome to VyOS' try: - if op == 'returnValue': - res = config.return_value(path) - elif op == 'returnValues': - res = config.return_values(path) - elif op == 'exists': - res = config.exists(path) - elif op == 'showConfig': - config_format = 'json' - if data.configFormat: - config_format = data.configFormat - - res = session.show_config(path=data.path) - if config_format == 'json': - config_tree = ConfigTree(res) - res = json.loads(config_tree.to_json()) - elif config_format == 'json_ast': - config_tree = ConfigTree(res) - res = json.loads(config_tree.to_json_ast()) - elif config_format == 'raw': - pass - else: - return error(400, f"'{config_format}' is not a valid config format") - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/config-file') -def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): - session = app.state.vyos_session - env = session.get_session_env() - op = data.op - msg = None - - try: - if op == 'save': - if data.file: - path = data.file - else: - path = '/config/config.boot' - msg = session.save_config(path) - elif op == 'load': - if data.file: - path = data.file - else: - return error(400, "Missing required field \"file\"") - - session.migrate_and_load_config(path) - - config = Config(session_env=env) - d = get_config_diff(config) - - if d.is_node_changed(['service', 'https']): - background_tasks.add_task(call_commit, session) - msg = self_ref_msg - else: - session.commit() - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(msg) - -@app.post('/image') -def image_op(data: ImageModel): - session = app.state.vyos_session - - op = data.op - - try: - if op == 'add': - res = session.install_image(data.url) - elif op == 'delete': - res = session.remove_image(data.name) - elif op == 'show': - res = session.show(["system", "image"]) - elif op == 'set_default': - res = session.set_default_image(data.name) - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/container-image') -def container_image_op(data: ContainerImageModel): - session = app.state.vyos_session - - op = data.op - - try: - if op == 'add': - if data.name: - name = data.name - else: - return error(400, "Missing required field \"name\"") - res = session.add_container_image(name) - elif op == 'delete': - if data.name: - name = data.name - else: - return error(400, "Missing required field \"name\"") - res = session.delete_container_image(name) - elif op == 'show': - res = session.show_container_image() - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/generate') -def generate_op(data: GenerateModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'generate': - res = session.generate(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/show') -def show_op(data: ShowModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'show': - res = session.show(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/reboot') -def reboot_op(data: RebootModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'reboot': - res = session.reboot(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/reset') -def reset_op(data: ResetModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'reset': - res = session.reset(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/import-pki') -def import_pki(data: ImportPkiModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - lock.acquire() - - try: - if op == 'import-pki': - # need to get rid or interactive mode for private key - if len(path) == 5 and path[3] in ['key-file', 'private-key']: - path_no_prompt = create_path_import_pki_no_prompt(path) - if not path_no_prompt: - return error(400, f"Invalid command: {' '.join(path)}") - if data.passphrase: - path_no_prompt += ['--passphrase', data.passphrase] - res = session.import_pki_no_prompt(path_no_prompt) - else: - res = session.import_pki(path) - if not res[0].isdigit(): - return error(400, res) - # commit changes - session.commit() - res = res.split('. ')[0] - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - finally: - lock.release() + res = { + 'banner': '', + 'hostname': '', + 'version': '' + } + if show_version: + res.update(version=get_version()) - return success(res) + if show_hostname: + try: + hostname = read_file(hostname_file) + except Exception: + hostname = 'vyos' + res.update(hostname=hostname) -@app.post('/poweroff') -def poweroff_op(data: PoweroffModel): - session = app.state.vyos_session + banner = read_file(prelogin_file, defaultonfailure=default) + if banner == f'{default} - \\n \\l': + banner = banner.partition(default)[1] - op = data.op - path = data.path - - try: - if op == 'poweroff': - res = session.poweroff(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + res.update(banner=banner) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, 'An internal error occured. Check the logs for details.') return success(res) ### -# GraphQL integration -### - -def graphql_init(app: FastAPI = app): - from api.graphql.libs.token_auth import get_user_context - api.graphql.state.init() - api.graphql.state.settings['app'] = app - - # import after initializaion of state - from api.graphql.bindings import generate_schema - schema = generate_schema() - - in_spec = app.state.vyos_introspection - - if app.state.vyos_origins: - origins = app.state.vyos_origins - app.add_route('/graphql', CORSMiddleware(GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec), - allow_origins=origins, - allow_methods=("GET", "POST", "OPTIONS"), - allow_headers=("Authorization",))) - else: - app.add_route('/graphql', GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec)) -### # Modify uvicorn to allow reloading server within the configsession ### server = None shutdown = False + class ApiServerConfig(UvicornConfig): pass + class ApiServer(UvicornServer): def install_signal_handlers(self): pass + def reload_handler(signum, frame): + # pylint: disable=global-statement + global server - logger.debug('Reload signal received...') + LOG.debug('Reload signal received...') if server is not None: server.handle_exit(signum, frame) server = None - logger.info('Server stopping for reload...') + LOG.info('Server stopping for reload...') else: - logger.warning('Reload called for non-running server...') + LOG.warning('Reload called for non-running server...') + def shutdown_handler(signum, frame): + # pylint: disable=global-statement + global shutdown - logger.debug('Shutdown signal received...') + LOG.debug('Shutdown signal received...') server.handle_exit(signum, frame) - logger.info('Server shutdown...') + LOG.info('Server shutdown...') shutdown = True +# end modify uvicorn + + def flatten_keys(d: dict) -> list[dict]: keys_list = [] for el in list(d['keys'].get('id', {})): @@ -965,49 +154,87 @@ def flatten_keys(d: dict) -> list[dict]: keys_list.append({'id': el, 'key': key}) return keys_list -def initialization(session: ConfigSession, app: FastAPI = app): + +def regenerate_docs(app: FastAPI) -> None: + docs = ('/openapi.json', '/docs', '/docs/oauth2-redirect', '/redoc') + remove = [] + for r in app.routes: + if r.path in docs: + remove.append(r) + for r in remove: + app.routes.remove(r) + + app.openapi_schema = None + app.setup() + + +def initialization(session: SessionState, app: FastAPI = app): + # pylint: disable=global-statement,broad-exception-caught,import-outside-toplevel + global server try: server_config = load_server_config() except Exception as e: - logger.critical(f'Failed to load the HTTP API server config: {e}') + LOG.critical(f'Failed to load the HTTP API server config: {e}') sys.exit(1) - app.state.vyos_session = session - app.state.vyos_keys = [] - if 'keys' in server_config: - app.state.vyos_keys = flatten_keys(server_config) + session.keys = flatten_keys(server_config) + + rest_config = server_config.get('rest', {}) + session.debug = bool('debug' in rest_config) + session.strict = bool('strict' in rest_config) + + graphql_config = server_config.get('graphql', {}) + session.origins = graphql_config.get('cors', {}).get('allow_origin', []) + + if 'rest' in server_config: + session.rest = True + else: + session.rest = False - app.state.vyos_debug = bool('debug' in server_config) - app.state.vyos_strict = bool('strict' in server_config) - app.state.vyos_origins = server_config.get('cors', {}).get('allow_origin', []) if 'graphql' in server_config: - app.state.vyos_graphql = True + session.graphql = True if isinstance(server_config['graphql'], dict): if 'introspection' in server_config['graphql']: - app.state.vyos_introspection = True + session.introspection = True else: - app.state.vyos_introspection = False + session.introspection = False # default values if not set explicitly - app.state.vyos_auth_type = server_config['graphql']['authentication']['type'] - app.state.vyos_token_exp = server_config['graphql']['authentication']['expiration'] - app.state.vyos_secret_len = server_config['graphql']['authentication']['secret_length'] + session.auth_type = server_config['graphql']['authentication']['type'] + session.token_exp = server_config['graphql']['authentication']['expiration'] + session.secret_len = server_config['graphql']['authentication']['secret_length'] else: - app.state.vyos_graphql = False + session.graphql = False + + # pass session state + app.state = session - if app.state.vyos_graphql: + # add REST routes + if session.rest: + from api.rest.routers import rest_init + rest_init(app) + else: + from api.rest.routers import rest_clear + rest_clear(app) + + # add GraphQL route + if session.graphql: + from api.graphql.routers import graphql_init graphql_init(app) + else: + from api.graphql.routers import graphql_clear + graphql_clear(app) + + regenerate_docs(app) + + LOG.debug('Active routes are:') + for r in app.routes: + LOG.debug(f'{r.path}') config = ApiServerConfig(app, uds="/run/api.sock", proxy_headers=True) server = ApiServer(config) -def run_server(): - try: - server.run() - except OSError as e: - logger.critical(e) - sys.exit(1) if __name__ == '__main__': # systemd's user and group options don't work, do it by hand here, @@ -1022,13 +249,14 @@ if __name__ == '__main__': signal.signal(signal.SIGHUP, reload_handler) signal.signal(signal.SIGTERM, shutdown_handler) - config_session = ConfigSession(os.getpid()) + session_state = SessionState() + session_state.session = ConfigSession(os.getpid()) while True: - logger.debug('Enter main loop...') + LOG.debug('Enter main loop...') if shutdown: break if server is None: - initialization(config_session) + initialization(session_state) server.run() sleep(1) diff --git a/src/services/vyos-network-event-logger b/src/services/vyos-network-event-logger new file mode 100644 index 000000000..840ff3cda --- /dev/null +++ b/src/services/vyos-network-event-logger @@ -0,0 +1,1218 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2025 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import argparse +import logging +import multiprocessing +import queue +import signal +import socket +import threading +from pathlib import Path +from time import sleep +from typing import Dict, AnyStr, List, Union + +from pyroute2.common import AF_MPLS +from pyroute2.iproute import IPRoute +from pyroute2.netlink import rtnl, nlmsg +from pyroute2.netlink.nfnetlink.nfctsocket import nfct_msg +from pyroute2.netlink.rtnl import (rt_proto as RT_PROTO, rt_type as RT_TYPES, + rtypes as RTYPES + ) +from pyroute2.netlink.rtnl.fibmsg import FR_ACT_GOTO, FR_ACT_NOP, FR_ACT_TO_TBL, \ + fibmsg +from pyroute2.netlink.rtnl import ifaddrmsg +from pyroute2.netlink.rtnl import ifinfmsg +from pyroute2.netlink.rtnl import ndmsg +from pyroute2.netlink.rtnl import rtmsg +from pyroute2.netlink.rtnl.rtmsg import nh, rtmsg_base + +from vyos.include.uapi.linux.fib_rules import * +from vyos.include.uapi.linux.icmpv6 import * +from vyos.include.uapi.linux.if_arp import * +from vyos.include.uapi.linux.lwtunnel import * +from vyos.include.uapi.linux.neighbour import * +from vyos.include.uapi.linux.rtnetlink import * + +from vyos.utils.file import read_json + + +manager = multiprocessing.Manager() +cache = manager.dict() + + +class UnsupportedMessageType(Exception): + pass + +shutdown_event = multiprocessing.Event() + +logging.basicConfig(level=logging.INFO, format='%(message)s') +logger = logging.getLogger(__name__) + + +class DebugFormatter(logging.Formatter): + def format(self, record): + self._style._fmt = '[%(asctime)s] %(levelname)s: %(message)s' + return super().format(record) + + +def set_log_level(level: str) -> None: + if level == 'debug': + logger.setLevel(logging.DEBUG) + logger.parent.handlers[0].setFormatter(DebugFormatter()) + else: + logger.setLevel(logging.INFO) + +IFF_FLAGS = { + 'RUNNING': ifinfmsg.IFF_RUNNING, + 'LOOPBACK': ifinfmsg.IFF_LOOPBACK, + 'BROADCAST': ifinfmsg.IFF_BROADCAST, + 'POINTOPOINT': ifinfmsg.IFF_POINTOPOINT, + 'MULTICAST': ifinfmsg.IFF_MULTICAST, + 'NOARP': ifinfmsg.IFF_NOARP, + 'ALLMULTI': ifinfmsg.IFF_ALLMULTI, + 'PROMISC': ifinfmsg.IFF_PROMISC, + 'MASTER': ifinfmsg.IFF_MASTER, + 'SLAVE': ifinfmsg.IFF_SLAVE, + 'DEBUG': ifinfmsg.IFF_DEBUG, + 'DYNAMIC': ifinfmsg.IFF_DYNAMIC, + 'AUTOMEDIA': ifinfmsg.IFF_AUTOMEDIA, + 'PORTSEL': ifinfmsg.IFF_PORTSEL, + 'NOTRAILERS': ifinfmsg.IFF_NOTRAILERS, + 'UP': ifinfmsg.IFF_UP, + 'LOWER_UP': ifinfmsg.IFF_LOWER_UP, + 'DORMANT': ifinfmsg.IFF_DORMANT, + 'ECHO': ifinfmsg.IFF_ECHO, +} + +NEIGH_STATE_FLAGS = { + 'INCOMPLETE': ndmsg.NUD_INCOMPLETE, + 'REACHABLE': ndmsg.NUD_REACHABLE, + 'STALE': ndmsg.NUD_STALE, + 'DELAY': ndmsg.NUD_DELAY, + 'PROBE': ndmsg.NUD_PROBE, + 'FAILED': ndmsg.NUD_FAILED, + 'NOARP': ndmsg.NUD_NOARP, + 'PERMANENT': ndmsg.NUD_PERMANENT, +} + +IFA_FLAGS = { + 'secondary': ifaddrmsg.IFA_F_SECONDARY, + 'temporary': ifaddrmsg.IFA_F_SECONDARY, + 'nodad': ifaddrmsg.IFA_F_NODAD, + 'optimistic': ifaddrmsg.IFA_F_OPTIMISTIC, + 'dadfailed': ifaddrmsg.IFA_F_DADFAILED, + 'home': ifaddrmsg.IFA_F_HOMEADDRESS, + 'deprecated': ifaddrmsg.IFA_F_DEPRECATED, + 'tentative': ifaddrmsg.IFA_F_TENTATIVE, + 'permanent': ifaddrmsg.IFA_F_PERMANENT, + 'mngtmpaddr': ifaddrmsg.IFA_F_MANAGETEMPADDR, + 'noprefixroute': ifaddrmsg.IFA_F_NOPREFIXROUTE, + 'autojoin': ifaddrmsg.IFA_F_MCAUTOJOIN, + 'stable-privacy': ifaddrmsg.IFA_F_STABLE_PRIVACY, +} + +RT_SCOPE_TO_NAME = { + rtmsg.RT_SCOPE_UNIVERSE: 'global', + rtmsg.RT_SCOPE_SITE: 'site', + rtmsg.RT_SCOPE_LINK: 'link', + rtmsg.RT_SCOPE_HOST: 'host', + rtmsg.RT_SCOPE_NOWHERE: 'nowhere', +} + +FAMILY_TO_NAME = { + socket.AF_INET: 'inet', + socket.AF_INET6: 'inet6', + socket.AF_PACKET: 'link', + AF_MPLS: 'mpls', + socket.AF_BRIDGE: 'bridge', +} + +_INFINITY = 4294967295 + + +def _get_iif_name(idx: int) -> str: + """ + Retrieves the interface name associated with a given index. + """ + try: + if_info = IPRoute().link("get", index=idx) + if if_info: + return if_info[0].get_attr('IFLA_IFNAME') + except Exception as e: + pass + + return '' + + +def remember_if_index(idx: int, event_type: int) -> None: + """ + Manages the caching of network interface names based on their index and event type. + + - For RTM_DELLINK event, the interface name is removed from the cache if exists. + - For RTM_NEWLINK event, the interface name is retrieved and updated in the cache. + """ + name = cache.get(idx) + if name: + if event_type == rtnl.RTM_DELLINK: + del cache[idx] + else: + name = _get_iif_name(idx) + if name: + cache[idx] = name + else: + cache[idx] = _get_iif_name(idx) + + +class BaseFormatter: + """ + A base class providing utility methods for formatting network message data. + """ + def _get_if_name_by_index(self, idx: int) -> str: + """ + Retrieves the name of a network interface based on its index. + + Uses a cached lookup for efficiency. If the name is not found in the cache, + it queries the system and updates the cache. + """ + if_name = cache.get(idx) + if not if_name: + if_name = _get_iif_name(idx) + cache[idx] = if_name + + return if_name + + def _format_rttable(self, idx: int) -> str: + """ + Formats a route table identifier into a readable name. + """ + return f'{RT_TABLE_TO_NAME.get(idx, idx)}' + + def _parse_flag(self, data: int, flags: dict) -> list: + """ + Extracts and returns flag names equal the bits set in a numeric value. + """ + result = list() + if data: + for key, val in flags.items(): + if data & val: + result.append(key) + data &= ~val + + if data: + result.append(f"{data:#x}") + + return result + + def af_bit_len(self, af: int) -> int: + """ + Gets the bit length of a given address family. + Supports common address families like IPv4, IPv6, and MPLS. + """ + _map = { + socket.AF_INET6: 128, + socket.AF_INET: 32, + AF_MPLS: 20, + } + + return _map.get(af) + + def _format_simple_field(self, data: str, prefix: str='') -> str: + """ + Formats a simple field with an optional prefix. + + A simple field represents a value that does not require additional + parsing and is used as is. + """ + return self._output(f'{prefix} {data}') if data is not None else '' + + def _output(self, data: str) -> str: + """ + Standardizes the output format. + + Ensures that the output is enclosed with single spaces and has no leading + or trailing whitespace. + """ + return f' {data.strip()} ' if data else '' + + +class BaseMSGFormatter(BaseFormatter): + """ + A base formatter class for network messages. + This class provides common methods for formatting network-related messages, + """ + + def _prepare_start_message(self, event: str) -> str: + """ + Prepares a starting message string based on the event type. + """ + if event in ['RTM_DELROUTE', 'RTM_DELLINK', 'RTM_DELNEIGH', + 'RTM_DELADDR', 'RTM_DELADDRLABEL', 'RTM_DELRULE', + 'RTM_DELNETCONF']: + return 'Deleted ' + if event == 'RTM_GETNEIGH': + return 'Miss ' + return '' + + def _format_flow_field(self, data: int) -> str: + """ + Formats a flow field to represent traffic realms. + """ + to = data & 0xFFFF + from_ = data >> 16 + result = f"realm{'s' if from_ else ''} " + if from_: + result += f'{from_}/' + result += f'{to}' + + return self._output(result) + + def format(self, msg: nlmsg) -> str: + """ + Abstract method to format a complete message. + + This method must be implemented by subclasses to provide specific formatting + logic for different types of messages. + """ + raise NotImplementedError(f'{msg.get("event")}: {msg}') + + +class LinkFormatter(BaseMSGFormatter): + """ + A formatter class for handling link-related network messages + `RTM_NEWLINK` and `RTM_DELLINK`. + """ + def _format_iff_flags(self, flags: int) -> str: + """ + Formats interface flags into a human-readable string. + """ + result = list() + if flags: + if flags & IFF_FLAGS['UP'] and not flags & IFF_FLAGS['RUNNING']: + result.append('NO-CARRIER') + + flags &= ~IFF_FLAGS['RUNNING'] + + result.extend(self._parse_flag(flags, IFF_FLAGS)) + + return self._output(f'<{(",").join(result)}>') + + def _format_if_props(self, data: ifinfmsg.ifinfbase.proplist) -> str: + """ + Formats interface alternative name properties. + """ + result = '' + for rec in data.altnames(): + result += f'[altname {rec}] ' + return self._output(result) + + def _format_link(self, msg: ifinfmsg.ifinfmsg) -> str: + """ + Formats the link attribute of a network interface message. + """ + if msg.get_attr("IFLA_LINK") is not None: + iflink = msg.get_attr("IFLA_LINK") + if iflink: + if msg.get_attr("IFLA_LINK_NETNSID"): + return f'if{iflink}' + else: + return self._get_if_name_by_index(iflink) + return 'NONE' + + def _format_link_info(self, msg: ifinfmsg.ifinfmsg) -> str: + """ + Formats detailed information about the link, including type, address, + broadcast address, and permanent address. + """ + result = f'link/{ARPHRD_TO_NAME.get(msg.get("ifi_type"), msg.get("ifi_type"))}' + result += self._format_simple_field(msg.get_attr('IFLA_ADDRESS')) + + if msg.get_attr("IFLA_BROADCAST"): + if msg.get('flags') & ifinfmsg.IFF_POINTOPOINT: + result += f' peer' + else: + result += f' brd' + result += f' {msg.get_attr("IFLA_BROADCAST")}' + + if msg.get_attr("IFLA_PERM_ADDRESS"): + if not msg.get_attr("IFLA_ADDRESS") or \ + msg.get_attr("IFLA_ADDRESS") != msg.get_attr("IFLA_PERM_ADDRESS"): + result += f' permaddr {msg.get_attr("IFLA_PERM_ADDRESS")}' + + return self._output(result) + + def format(self, msg: ifinfmsg.ifinfmsg): + """ + Formats a network link message into a structured output string. + """ + if msg.get("family") not in [socket.AF_UNSPEC, socket.AF_BRIDGE]: + return None + + message = self._prepare_start_message(msg.get('event')) + + link = self._format_link(msg) + + message += f'{msg.get("index")}: {msg.get_attr("IFLA_IFNAME")}' + message += f'@{link}' if link else '' + message += f': {self._format_iff_flags(msg.get("flags"))}' + + message += self._format_simple_field(msg.get_attr('IFLA_MTU'), prefix='mtu') + message += self._format_simple_field(msg.get_attr('IFLA_QDISC'), prefix='qdisc') + message += self._format_simple_field(msg.get_attr('IFLA_OPERSTATE'), prefix='state') + message += self._format_simple_field(msg.get_attr('IFLA_GROUP'), prefix='group') + message += self._format_simple_field(msg.get_attr('IFLA_MASTER'), prefix='master') + + message += self._format_link_info(msg) + + if msg.get_attr('IFLA_PROP_LIST'): + message += self._format_if_props(msg.get_attr('IFLA_PROP_LIST')) + + return self._output(message) + + +class EncapFormatter(BaseFormatter): + """ + A formatter class for handling encapsulation attributes in routing messages. + """ + # TODO: implement other lwtunnel decoder in pyroute2 + # https://github.com/svinota/pyroute2/blob/78cfe838bec8d96324811a3962bda15fb028e0ce/pyroute2/netlink/rtnl/rtmsg.py#L657 + def __init__(self): + """ + Initializes the EncapFormatter with supported encapsulation types. + """ + self.formatters = { + rtmsg.LWTUNNEL_ENCAP_MPLS: self.mpls_format, + rtmsg.LWTUNNEL_ENCAP_SEG6: self.seg6_format, + rtmsg.LWTUNNEL_ENCAP_BPF: self.bpf_format, + rtmsg.LWTUNNEL_ENCAP_SEG6_LOCAL: self.seg6local_format, + } + + def _format_srh(self, data: rtmsg_base.seg6_encap_info.ipv6_sr_hdr): + """ + Formats Segment Routing Header (SRH) attributes. + """ + result = '' + # pyroute2 decode mode only as inline or encap (encap, l2encap, encap.red, l2encap.red") + # https://github.com/svinota/pyroute2/blob/78cfe838bec8d96324811a3962bda15fb028e0ce/pyroute2/netlink/rtnl/rtmsg.py#L220 + for key in ['mode', 'segs']: + + val = data.get(key) + + if val: + if key == 'segs': + result += f'{key} {len(val)} {val} ' + else: + result += f'{key} {val} ' + + return self._output(result) + + def _format_bpf_object(self, data: rtmsg_base.bpf_encap_info, attr_name: str, attr_key: str): + """ + Formats eBPF program attributes. + """ + attr = data.get_attr(attr_name) + if not attr: + return '' + result = '' + if attr.get_attr("LWT_BPF_PROG_NAME"): + result += f'{attr.get_attr("LWT_BPF_PROG_NAME")} ' + if attr.get_attr("LWT_BPF_PROG_FD"): + result += f'{attr.get_attr("LWT_BPF_PROG_FD")} ' + + return self._output(f'{attr_key} {result.strip()}') + + def mpls_format(self, data: rtmsg_base.mpls_encap_info): + """ + Formats MPLS encapsulation attributes. + """ + result = '' + if data.get_attr("MPLS_IPTUNNEL_DST"): + for rec in data.get_attr("MPLS_IPTUNNEL_DST"): + for key, val in rec.items(): + if val: + result += f'{key} {val} ' + + if data.get_attr("MPLS_IPTUNNEL_TTL"): + result += f' ttl {data.get_attr("MPLS_IPTUNNEL_TTL")}' + + return self._output(result) + + def bpf_format(self, data: rtmsg_base.bpf_encap_info): + """ + Formats eBPF encapsulation attributes. + """ + result = '' + result += self._format_bpf_object(data, 'LWT_BPF_IN', 'in') + result += self._format_bpf_object(data, 'LWT_BPF_OUT', 'out') + result += self._format_bpf_object(data, 'LWT_BPF_XMIT', 'xmit') + + if data.get_attr('LWT_BPF_XMIT_HEADROOM'): + result += f'headroom {data.get_attr("LWT_BPF_XMIT_HEADROOM")} ' + + return self._output(result) + + def seg6_format(self, data: rtmsg_base.seg6_encap_info): + """ + Formats Segment Routing (SEG6) encapsulation attributes. + """ + result = '' + if data.get_attr("SEG6_IPTUNNEL_SRH"): + result += self._format_srh(data.get_attr("SEG6_IPTUNNEL_SRH")) + + return self._output(result) + + def seg6local_format(self, data: rtmsg_base.seg6local_encap_info): + """ + Formats SEG6 local encapsulation attributes. + """ + result = '' + formatters = { + 'SEG6_LOCAL_ACTION': lambda val: f' action {next((k for k, v in data.action.actions.items() if v == val), "unknown")}', + 'SEG6_LOCAL_SRH': lambda val: f' {self._format_srh(val)}', + 'SEG6_LOCAL_TABLE': lambda val: f' table {self._format_rttable(val)}', + 'SEG6_LOCAL_NH4': lambda val: f' nh4 {val}', + 'SEG6_LOCAL_NH6': lambda val: f' nh6 {val}', + 'SEG6_LOCAL_IIF': lambda val: f' iif {self._get_if_name_by_index(val)}', + 'SEG6_LOCAL_OIF': lambda val: f' oif {self._get_if_name_by_index(val)}', + 'SEG6_LOCAL_BPF': lambda val: f' endpoint {val.get("LWT_BPF_PROG_NAME")}', + 'SEG6_LOCAL_VRFTABLE': lambda val: f' vrftable {self._format_rttable(val)}', + } + + for rec in data.get('attrs'): + if rec[0] in formatters: + result += formatters[rec[0]](rec[1]) + + return self._output(result) + + def format(self, type: int, data: Union[rtmsg_base.mpls_encap_info, + rtmsg_base.bpf_encap_info, + rtmsg_base.seg6_encap_info, + rtmsg_base.seg6local_encap_info]): + """ + Formats encapsulation attributes based on their type. + """ + result = '' + formatter = self.formatters.get(type) + + result += f'encap {ENCAP_TO_NAME.get(type, "unknown")}' + + if formatter: + result += f' {formatter(data)}' + + return self._output(result) + + +class RouteFormatter(BaseMSGFormatter): + """ + A formatter class for handling network routing messages + `RTM_NEWROUTE` and `RTM_DELROUTE`. + """ + + def _format_rt_flags(self, flags: int) -> str: + """ + Formats route flags into a comma-separated string. + """ + result = list() + result.extend(self._parse_flag(flags, RT_FlAGS)) + + return self._output(",".join(result)) + + def _format_rta_encap(self, type: int, data: Union[rtmsg_base.mpls_encap_info, + rtmsg_base.bpf_encap_info, + rtmsg_base.seg6_encap_info, + rtmsg_base.seg6local_encap_info]) -> str: + """ + Formats encapsulation attributes. + """ + return EncapFormatter().format(type, data) + + def _format_rta_newdest(self, data: str) -> str: + """ + Formats a new destination attribute. + """ + return self._output(f'as to {data}') + + def _format_rta_gateway(self, data: str) -> str: + """ + Formats a gateway attribute. + """ + return self._output(f'via {data}') + + def _format_rta_via(self, data: str) -> str: + """ + Formats a 'via' route attribute. + """ + return self._output(f'{data}') + + def _format_rta_metrics(self, data: rtmsg_base.metrics): + """ + Formats routing metrics. + """ + result = '' + + def __format_metric_time(_val: int) -> str: + """Formats metric time values into seconds or milliseconds.""" + return f"{_val / 1000}s" if _val >= 1000 else f"{_val}ms" + + def __format_reatures(_val: int) -> str: + """Parse and formats routing feature flags.""" + result = self._parse_flag(_val, {'ecn': RTAX_FEATURE_ECN, + 'tcp_usec_ts': RTAX_FEATURE_TCP_USEC_TS}) + return ",".join(result) + + formatters = { + 'RTAX_MTU': lambda val: f' mtu {val}', + 'RTAX_WINDOW': lambda val: f' window {val}', + 'RTAX_RTT': lambda val: f' rtt {__format_metric_time(val / 8)}', + 'RTAX_RTTVAR': lambda val: f' rttvar {__format_metric_time(val / 4)}', + 'RTAX_SSTHRESH': lambda val: f' ssthresh {val}', + 'RTAX_CWND': lambda val: f' cwnd {val}', + 'RTAX_ADVMSS': lambda val: f' advmss {val}', + 'RTAX_REORDERING': lambda val: f' reordering {val}', + 'RTAX_HOPLIMIT': lambda val: f' hoplimit {val}', + 'RTAX_INITCWND': lambda val: f' initcwnd {val}', + 'RTAX_FEATURES': lambda val: f' features {__format_reatures(val)}', + 'RTAX_RTO_MIN': lambda val: f' rto_min {__format_metric_time(val)}', + 'RTAX_INITRWND': lambda val: f' initrwnd {val}', + 'RTAX_QUICKACK': lambda val: f' quickack {val}', + } + + for rec in data.get('attrs'): + if rec[0] in formatters: + result += formatters[rec[0]](rec[1]) + + return self._output(result) + + def _format_rta_pref(self, data: int) -> str: + """ + Formats a pref attribute. + """ + pref = { + ICMPV6_ROUTER_PREF_LOW: "low", + ICMPV6_ROUTER_PREF_MEDIUM: "medium", + ICMPV6_ROUTER_PREF_HIGH: "high", + } + + return self._output(f' pref {pref.get(data, data)}') + + def _format_rta_multipath(self, mcast_cloned: bool, family: int, data: List[nh]) -> str: + """ + Formats multipath route attributes. + """ + result = '' + first = True + for rec in data: + if mcast_cloned: + if first: + result += ' Oifs: ' + first = False + else: + result += ' ' + else: + result += ' nexthop ' + + if rec.get_attr('RTA_ENCAP'): + result += self._format_rta_encap(rec.get_attr('RTA_ENCAP_TYPE'), + rec.get_attr('RTA_ENCAP')) + + if rec.get_attr('RTA_NEWDST'): + result += self._format_rta_newdest(rec.get_attr('RTA_NEWDST')) + + if rec.get_attr('RTA_GATEWAY'): + result += self._format_rta_gateway(rec.get_attr('RTA_GATEWAY')) + + if rec.get_attr('RTA_VIA'): + result += self._format_rta_via(rec.get_attr('RTA_VIA')) + + if rec.get_attr('RTA_FLOW'): + result += self._format_flow_field(rec.get_attr('RTA_FLOW')) + + result += f' dev {self._get_if_name_by_index(rec.get("oif"))}' + if mcast_cloned: + if rec.get("hops") != 1: + result += f' (ttl>{rec.get("hops")})' + else: + if family != AF_MPLS: + result += f' weight {rec.get("hops") + 1}' + + result += self._format_rt_flags(rec.get("flags")) + + return self._output(result) + + def format(self, msg: rtmsg.rtmsg) -> str: + """ + Formats a network route message into a human-readable string representation. + """ + message = self._prepare_start_message(msg.get('event')) + + message += RT_TYPES.get(msg.get('type')) + + if msg.get_attr('RTA_DST'): + host_len = self.af_bit_len(msg.get('family')) + if msg.get('dst_len') != host_len: + message += f' {msg.get_attr("RTA_DST")}/{msg.get("dst_len")}' + else: + message += f' {msg.get_attr("RTA_DST")}' + elif msg.get('dst_len'): + message += f' 0/{msg.get("dst_len")}' + else: + message += ' default' + + if msg.get_attr('RTA_SRC'): + message += f' from {msg.get_attr("RTA_SRC")}' + elif msg.get('src_len'): + message += f' from 0/{msg.get("src_len")}' + + message += self._format_simple_field(msg.get_attr('RTA_NH_ID'), prefix='nhid') + + if msg.get_attr('RTA_NEWDST'): + message += self._format_rta_newdest(msg.get_attr('RTA_NEWDST')) + + if msg.get_attr('RTA_ENCAP'): + message += self._format_rta_encap(msg.get_attr('RTA_ENCAP_TYPE'), + msg.get_attr('RTA_ENCAP')) + + message += self._format_simple_field(msg.get('tos'), prefix='tos') + + if msg.get_attr('RTA_GATEWAY'): + message += self._format_rta_gateway(msg.get_attr('RTA_GATEWAY')) + + if msg.get_attr('RTA_VIA'): + message += self._format_rta_via(msg.get_attr('RTA_VIA')) + + if msg.get_attr('RTA_OIF') is not None: + message += f' dev {self._get_if_name_by_index(msg.get_attr("RTA_OIF"))}' + + if msg.get_attr("RTA_TABLE"): + message += f' table {self._format_rttable(msg.get_attr("RTA_TABLE"))}' + + if not msg.get('flags') & RTM_F_CLONED: + message += f' proto {RT_PROTO.get(msg.get("proto"))}' + + if not msg.get('scope') == rtmsg.RT_SCOPE_UNIVERSE: + message += f' scope {RT_SCOPE_TO_NAME.get(msg.get("scope"))}' + + message += self._format_simple_field(msg.get_attr('RTA_PREFSRC'), prefix='src') + message += self._format_simple_field(msg.get_attr('RTA_PRIORITY'), prefix='metric') + + message += self._format_rt_flags(msg.get("flags")) + + if msg.get_attr('RTA_MARK'): + mark = msg.get_attr("RTA_MARK") + if mark >= 16: + message += f' mark 0x{mark:x}' + else: + message += f' mark {mark}' + + if msg.get_attr('RTA_FLOW'): + message += self._format_flow_field(msg.get_attr('RTA_FLOW')) + + message += self._format_simple_field(msg.get_attr('RTA_UID'), prefix='uid') + + if msg.get_attr('RTA_METRICS'): + message += self._format_rta_metrics(msg.get_attr("RTA_METRICS")) + + if msg.get_attr('RTA_IIF') is not None: + message += f' iif {self._get_if_name_by_index(msg.get_attr("RTA_IIF"))}' + + if msg.get_attr('RTA_PREF') is not None: + message += self._format_rta_pref(msg.get_attr("RTA_PREF")) + + if msg.get_attr('RTA_TTL_PROPAGATE') is not None: + message += f' ttl-propogate {"enabled" if msg.get_attr("RTA_TTL_PROPAGATE") else "disabled"}' + + if msg.get_attr('RTA_MULTIPATH') is not None: + _tmp = self._format_rta_multipath( + mcast_cloned=msg.get('flags') & RTM_F_CLONED and msg.get('type') == RTYPES['RTN_MULTICAST'], + family=msg.get('family'), + data=msg.get_attr("RTA_MULTIPATH")) + message += f' {_tmp}' + + return self._output(message) + + +class AddrFormatter(BaseMSGFormatter): + """ + A formatter class for handling address-related network messages + `RTM_NEWADDR` and `RTM_DELADDR`. + """ + INFINITY_LIFE_TIME = _INFINITY + + def _format_ifa_flags(self, flags: int, family: int) -> str: + """ + Formats address flags into a human-readable string. + """ + result = list() + if flags: + if not flags & IFA_FLAGS['permanent']: + result.append('dynamic') + flags &= ~IFA_FLAGS['permanent'] + + if flags & IFA_FLAGS['temporary'] and family == socket.AF_INET6: + result.append('temporary') + flags &= ~IFA_FLAGS['temporary'] + + result.extend(self._parse_flag(flags, IFA_FLAGS)) + + return self._output(",".join(result)) + + def _format_ifa_addr(self, local: str, addr: str, preflen: int, priority: int) -> str: + """ + Formats address information into a shuman-readable string. + """ + result = '' + local = local or addr + addr = addr or local + + if local: + result += f'{local}' + if addr and addr != local: + result += f' peer {addr}' + result += f'/{preflen}' + + if priority: + result += f' {priority}' + + return self._output(result) + + def _format_ifa_cacheinfo(self, data: ifaddrmsg.ifaddrmsg.cacheinfo) -> str: + """ + Formats cache information for an address. + """ + result = '' + _map = { + 'ifa_valid': 'valid_lft', + 'ifa_preferred': 'preferred_lft', + } + + for key in ['ifa_valid', 'ifa_preferred']: + val = data.get(key) + if val == self.INFINITY_LIFE_TIME: + result += f'{_map.get(key)} forever ' + else: + result += f'{_map.get(key)} {val}sec ' + + return self._output(result) + + def format(self, msg: ifaddrmsg.ifaddrmsg) -> str: + """ + Formats a full network address message. + Combine attributes such as index, family, address, flags, and cache + information into a structured output string. + """ + message = self._prepare_start_message(msg.get('event')) + + message += f'{msg.get("index")}: {self._get_if_name_by_index(msg.get("index"))} ' + message += f'{FAMILY_TO_NAME.get(msg.get("family"), msg.get("family"))} ' + + message += self._format_ifa_addr( + msg.get_attr('IFA_LOCAL'), + msg.get_attr('IFA_ADDRESS'), + msg.get('prefixlen'), + msg.get_attr('IFA_RT_PRIORITY') + ) + message += self._format_simple_field(msg.get_attr('IFA_BROADCAST'), prefix='brd') + message += self._format_simple_field(msg.get_attr('IFA_ANYCAST'), prefix='any') + + if msg.get('scope') is not None: + message += f' scope {RT_SCOPE_TO_NAME.get(msg.get("scope"))}' + + message += self._format_ifa_flags(msg.get_attr("IFA_FLAGS"), msg.get("family")) + message += self._format_simple_field(msg.get_attr('IFA_LABEL'), prefix='label:') + + if msg.get_attr('IFA_CACHEINFO'): + message += self._format_ifa_cacheinfo(msg.get_attr('IFA_CACHEINFO')) + + return self._output(message) + + +class NeighFormatter(BaseMSGFormatter): + """ + A formatter class for handling neighbor-related network messages + `RTM_NEWNEIGH`, `RTM_DELNEIGH` and `RTM_GETNEIGH` + """ + def _format_ntf_flags(self, flags: int) -> str: + """ + Formats neighbor table entry flags into a human-readable string. + """ + result = list() + result.extend(self._parse_flag(flags, NTF_FlAGS)) + + return self._output(",".join(result)) + + def _format_neigh_state(self, data: int) -> str: + """ + Formats the state of a neighbor entry. + """ + result = list() + result.extend(self._parse_flag(data, NEIGH_STATE_FLAGS)) + + return self._output(",".join(result)) + + def format(self, msg: ndmsg.ndmsg) -> str: + """ + Formats a full neighbor-related network message. + Combine attributes such as destination, device, link-layer address, + flags, state, and protocol into a structured output string. + """ + message = self._prepare_start_message(msg.get('event')) + message += self._format_simple_field(msg.get_attr('NDA_DST'), prefix='') + + if msg.get("ifindex") is not None: + message += f' dev {self._get_if_name_by_index(msg.get("ifindex"))}' + + message += self._format_simple_field(msg.get_attr('NDA_LLADDR'), prefix='lladdr') + message += f' {self._format_ntf_flags(msg.get("flags"))}' + message += f' {self._format_neigh_state(msg.get("state"))}' + + if msg.get_attr('NDA_PROTOCOL'): + message += f' proto {RT_PROTO.get(msg.get_attr("NDA_PROTOCOL"), msg.get_attr("NDA_PROTOCOL"))}' + + return self._output(message) + + +class RuleFormatter(BaseMSGFormatter): + """ + A formatter class for handling ruting tule network messages + `RTM_NEWRULE` and `RTM_DELRULE` + """ + def _format_direction(self, data: str, length: int, host_len: int): + """ + Formats the direction of traffic based on source or destination and prefix length. + """ + result = '' + if data: + result += f' {data}' + if length != host_len: + result += f'/{length}' + elif length: + result += f' 0/{length}' + + return self._output(result) + + def _format_fra_interface(self, data: str, flags: int, prefix: str): + """ + Formats interface-related attributes. + """ + result = f'{prefix} {data}' + if flags & FIB_RULE_IIF_DETACHED: + result += '[detached]' + + return self._output(result) + + def _format_fra_range(self, data: [str, dict], prefix: str): + """ + Formats a range of values (e.g., UID, sport, or dport). + """ + result = '' + if data: + if isinstance(data, str): + result += f' {prefix} {data}' + else: + result += f' {prefix} {data.get("start")}:{data.get("end")}' + return self._output(result) + + def _format_fra_table(self, msg: fibmsg): + """ + Formats the lookup table and associated attributes in the message. + """ + def __format_field(data: int, prefix: str): + if data and data not in [-1, _INFINITY]: + return f' {prefix} {data}' + return '' + + result = '' + table = msg.get_attr('FRA_TABLE') or msg.get('table') + if table: + result += f' lookup {self._format_rttable(table)}' + result += __format_field(msg.get_attr('FRA_SUPPRESS_PREFIXLEN'), 'suppress_prefixlength') + result += __format_field(msg.get_attr('FRA_SUPPRESS_IFGROUP'), 'suppress_ifgroup') + + return self._output(result) + + def _format_fra_action(self, msg: fibmsg): + """ + Formats the action associated with the rule. + """ + result = '' + if msg.get('action') == RTYPES.get('RTN_NAT'): + if msg.get_attr('RTA_GATEWAY'): # looks like deprecated but still use in iproute2 + result += f' map-to {msg.get_attr("RTA_GATEWAY")}' + else: + result += ' masquerade' + + elif msg.get('action') == FR_ACT_GOTO: + result += f' goto {msg.get_attr("FRA_GOTO") or "none"}' + if msg.get('flags') & FIB_RULE_UNRESOLVED: + result += ' [unresolved]' + + elif msg.get('action') == FR_ACT_NOP: + result += ' nop' + + elif msg.get('action') != FR_ACT_TO_TBL: + result += f' {RTYPES.get(msg.get("action"))}' + + return self._output(result) + + def format(self, msg: fibmsg): + """ + Formats a complete routing rule message. + Combines information about source, destination, interfaces, actions, + and other attributes into a single formatted string. + """ + message = self._prepare_start_message(msg.get('event')) + host_len = self.af_bit_len(msg.get('family')) + message += self._format_simple_field(msg.get_attr('FRA_PRIORITY'), prefix='') + + if msg.get('flags') & FIB_RULE_INVERT: + message += ' not' + + tmp = self._format_direction(msg.get_attr('FRA_SRC'), msg.get('src_len'), host_len) + message += ' from' + (tmp if tmp else ' all ') + + if msg.get_attr('FRA_DST'): + tmp = self._format_direction(msg.get_attr('FRA_DST'), msg.get('dst_len'), host_len) + message += ' to' + tmp + + if msg.get('tos'): + message += f' tos {hex(msg.get("tos"))}' + + if msg.get_attr('FRA_FWMARK') or msg.get_attr('FRA_FWMASK'): + mark = msg.get_attr('FRA_FWMARK') or 0 + mask = msg.get_attr('FRA_FWMASK') or 0 + if mask != 0xFFFFFFFF: + message += f' fwmark {mark}/{mask}' + else: + message += f' fwmark {mark}' + + if msg.get_attr('FRA_IIFNAME'): + message += self._format_fra_interface( + msg.get_attr('FRA_IIFNAME'), + msg.get('flags'), + 'iif' + ) + + if msg.get_attr('FRA_OIFNAME'): + message += self._format_fra_interface( + msg.get_attr('FRA_OIFNAME'), + msg.get('flags'), + 'oif' + ) + + if msg.get_attr('FRA_L3MDEV'): + message += f' lookup [l3mdev-table]' + + if msg.get_attr('FRA_UID_RANGE'): + message += self._format_fra_range(msg.get_attr('FRA_UID_RANGE'), 'uidrange') + + message += self._format_simple_field(msg.get_attr('FRA_IP_PROTO'), prefix='ipproto') + + if msg.get_attr('FRA_SPORT_RANGE'): + message += self._format_fra_range(msg.get_attr('FRA_SPORT_RANGE'), 'sport') + + if msg.get_attr('FRA_DPORT_RANGE'): + message += self._format_fra_range(msg.get_attr('FRA_DPORT_RANGE'), 'dport') + + message += self._format_simple_field(msg.get_attr('FRA_TUN_ID'), prefix='tun_id') + + message += self._format_fra_table(msg) + + if msg.get_attr('FRA_FLOW'): + message += self._format_flow_field(msg.get_attr('FRA_FLOW')) + + message += self._format_fra_action(msg) + + if msg.get_attr('FRA_PROTOCOL'): + message += f' proto {RT_PROTO.get(msg.get_attr("FRA_PROTOCOL"), msg.get_attr("FRA_PROTOCOL"))}' + + return self._output(message) + + +class AddrlabelFormatter(BaseMSGFormatter): + # Not implemented decoder on pytroute2 but ip monitor use it message + pass + + +class PrefixFormatter(BaseMSGFormatter): + # Not implemented decoder on pytroute2 but ip monitor use it message + pass + + +class NetconfFormatter(BaseMSGFormatter): + # Not implemented decoder on pytroute2 but ip monitor use it message + pass + + +EVENT_MAP = { + rtnl.RTM_NEWROUTE: {'parser': RouteFormatter, 'event': 'route'}, + rtnl.RTM_DELROUTE: {'parser': RouteFormatter, 'event': 'route'}, + rtnl.RTM_NEWLINK: {'parser': LinkFormatter, 'event': 'link'}, + rtnl.RTM_DELLINK: {'parser': LinkFormatter, 'event': 'link'}, + rtnl.RTM_NEWADDR: {'parser': AddrFormatter, 'event': 'addr'}, + rtnl.RTM_DELADDR: {'parser': AddrFormatter, 'event': 'addr'}, + # rtnl.RTM_NEWADDRLABEL: {'parser': AddrlabelFormatter, 'event': 'addrlabel'}, + # rtnl.RTM_DELADDRLABEL: {'parser': AddrlabelFormatter, 'event': 'addrlabel'}, + rtnl.RTM_NEWNEIGH: {'parser': NeighFormatter, 'event': 'neigh'}, + rtnl.RTM_DELNEIGH: {'parser': NeighFormatter, 'event': 'neigh'}, + rtnl.RTM_GETNEIGH: {'parser': NeighFormatter, 'event': 'neigh'}, + # rtnl.RTM_NEWPREFIX: {'parser': PrefixFormatter, 'event': 'prefix'}, + rtnl.RTM_NEWRULE: {'parser': RuleFormatter, 'event': 'rule'}, + rtnl.RTM_DELRULE: {'parser': RuleFormatter, 'event': 'rule'}, + # rtnl.RTM_NEWNETCONF: {'parser': NetconfFormatter, 'event': 'netconf'}, + # rtnl.RTM_DELNETCONF: {'parser': NetconfFormatter, 'event': 'netconf'}, +} + + +def sig_handler(signum, frame): + process_name = multiprocessing.current_process().name + logger.debug( + f'[{process_name}]: {"Shutdown" if signum == signal.SIGTERM else "Reload"} signal received...' + ) + shutdown_event.set() + + +def parse_event_type(header: Dict) -> tuple: + """ + Extract event type and parser. + """ + event_type = EVENT_MAP.get(header['type'], {}).get('event', 'unknown') + _parser = EVENT_MAP.get(header['type'], {}).get('parser') + + if _parser is None: + raise UnsupportedMessageType(f'Unsupported message type: {header["type"]}') + + return event_type, _parser + + +def is_need_to_log(event_type: AnyStr, conf_event: Dict): + """ + Filter message by event type and protocols + """ + conf = conf_event.get(event_type) + if conf == {}: + return True + return False + + +def parse_event(msg: nfct_msg, conf_event: Dict) -> str: + """ + Convert nfct_msg to internal data dict. + """ + data = '' + event_type, parser = parse_event_type(msg['header']) + if event_type == 'link': + remember_if_index(idx=msg.get('index'), event_type=msg['header'].get('type')) + + if not is_need_to_log(event_type, conf_event): + return data + + message = parser().format(msg) + if message: + data = f'{f"[{event_type}]".upper():<{7}} {message}' + + return data + + +def worker(ct: IPRoute, shutdown_event: multiprocessing.Event, conf_event: Dict) -> None: + """ + Main function of parser worker process + """ + process_name = multiprocessing.current_process().name + logger.debug(f'[{process_name}] started') + timeout = 0.1 + while not shutdown_event.is_set(): + if not ct.buffer_queue.empty(): + msg = None + try: + for msg in ct.get(): + message = parse_event(msg, conf_event) + if message: + if logger.level == logging.DEBUG: + logger.debug(f'[{process_name}]: {message} raw: {msg}') + else: + logger.info(message) + except queue.Full: + logger.error('IPRoute message queue if full.') + except UnsupportedMessageType as e: + logger.debug(f'{e} =====> raw msg: {msg}') + except Exception as e: + logger.error(f'Unexpected error: {e.__class__} {e} [{msg}]') + else: + sleep(timeout) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '-c', + '--config', + action='store', + help='Path to vyos-network-event-logger configuration', + required=True, + type=Path, + ) + + args = parser.parse_args() + try: + config = read_json(args.config) + except Exception as err: + logger.error(f'Configuration file "{args.config}" does not exist or malformed: {err}') + exit(1) + + set_log_level(config.get('log_level', 'info')) + + signal.signal(signal.SIGHUP, sig_handler) + signal.signal(signal.SIGTERM, sig_handler) + + if 'event' in config: + event_groups = list(config.get('event').keys()) + else: + logger.error(f'Configuration is wrong. Event filter is empty.') + exit(1) + + conf_event = config['event'] + qsize = config.get('queue_size') + ct = IPRoute(async_qsize=int(qsize) if qsize else None) + ct.buffer_queue = multiprocessing.Queue(ct.async_qsize) + ct.bind(async_cache=True) + + processes = list() + try: + for _ in range(multiprocessing.cpu_count()): + p = multiprocessing.Process(target=worker, args=(ct, shutdown_event, conf_event)) + processes.append(p) + p.start() + logger.info('IPRoute socket bound and listening for messages.') + + while not shutdown_event.is_set(): + if not ct.pthread.is_alive(): + if ct.buffer_queue.qsize() / ct.async_qsize < 0.9: + if not shutdown_event.is_set(): + logger.debug('Restart listener thread') + # restart listener thread after queue overloaded when queue size low than 90% + ct.pthread = threading.Thread(name='Netlink async cache', target=ct.async_recv) + ct.pthread.daemon = True + ct.pthread.start() + else: + sleep(0.1) + finally: + for p in processes: + p.join() + if not p.is_alive(): + logger.debug(f'[{p.name}]: finished') + ct.close() + logging.info('IPRoute socket closed.') + exit() |