Diffstat (limited to 'src/services')
-rw-r--r--  src/services/api/__init__.py                             |    0
-rw-r--r--  src/services/api/graphql/bindings.py                     |   38
-rw-r--r--  src/services/api/graphql/graphql/auth_token_mutation.py  |   37
-rw-r--r--  src/services/api/graphql/graphql/mutations.py            |   78
-rw-r--r--  src/services/api/graphql/graphql/queries.py              |   78
-rw-r--r--  src/services/api/graphql/libs/__init__.py                |    0
-rw-r--r--  src/services/api/graphql/libs/key_auth.py                |   24
-rw-r--r--  src/services/api/graphql/libs/token_auth.py              |   49
-rw-r--r--  src/services/api/graphql/routers.py                      |   77
-rw-r--r--  src/services/api/graphql/session/session.py              |   39
-rw-r--r--  src/services/api/graphql/state.py                        |    4
-rw-r--r--  src/services/api/rest/__init__.py                        |    0
-rw-r--r--  src/services/api/rest/models.py                          |  320
-rw-r--r--  src/services/api/rest/routers.py                         |  778
-rw-r--r--  src/services/api/session.py                              |   41
-rwxr-xr-x  src/services/vyos-commitd                                |  457
-rwxr-xr-x  src/services/vyos-configd                                |  212
-rwxr-xr-x  src/services/vyos-conntrack-logger                       |    2
-rwxr-xr-x  src/services/vyos-domain-resolver                        |  313
-rwxr-xr-x  src/services/vyos-hostsd                                 |    4
-rwxr-xr-x  src/services/vyos-http-api-server                        | 1012
-rw-r--r--  src/services/vyos-network-event-logger                   | 1218
22 files changed, 3638 insertions, 1143 deletions
diff --git a/src/services/api/__init__.py b/src/services/api/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/services/api/__init__.py
diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py
index ef4966466..ebf745f32 100644
--- a/src/services/api/graphql/bindings.py
+++ b/src/services/api/graphql/bindings.py
@@ -1,4 +1,4 @@
-# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io>
+# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
@@ -13,24 +13,40 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
import vyos.defaults
-from . graphql.queries import query
-from . graphql.mutations import mutation
-from . graphql.directives import directives_dict
-from . graphql.errors import op_mode_error
-from . graphql.auth_token_mutation import auth_token_mutation
-from . libs.token_auth import init_secret
-from . import state
-from ariadne import make_executable_schema, load_schema_from_path, snake_case_fallback_resolvers
+
+from ariadne import make_executable_schema
+from ariadne import load_schema_from_path
+from ariadne import snake_case_fallback_resolvers
+
+from .graphql.queries import query
+from .graphql.mutations import mutation
+from .graphql.directives import directives_dict
+from .graphql.errors import op_mode_error
+from .graphql.auth_token_mutation import auth_token_mutation
+from .libs.token_auth import init_secret
+
+from ..session import SessionState
+
def generate_schema():
+ state = SessionState()
api_schema_dir = vyos.defaults.directories['api_schema']
- if state.settings['app'].state.vyos_auth_type == 'token':
+ if state.auth_type == 'token':
init_secret()
type_defs = load_schema_from_path(api_schema_dir)
- schema = make_executable_schema(type_defs, query, op_mode_error, mutation, auth_token_mutation, snake_case_fallback_resolvers, directives=directives_dict)
+ schema = make_executable_schema(
+ type_defs,
+ query,
+ op_mode_error,
+ mutation,
+ auth_token_mutation,
+ snake_case_fallback_resolvers,
+ directives=directives_dict,
+ )
return schema
diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py
index a53fa4d60..c74364603 100644
--- a/src/services/api/graphql/graphql/auth_token_mutation.py
+++ b/src/services/api/graphql/graphql/auth_token_mutation.py
@@ -19,11 +19,12 @@ from typing import Dict
from ariadne import ObjectType
from graphql import GraphQLResolveInfo
-from .. libs.token_auth import generate_token
-from .. session.session import get_user_info
-from .. import state
+from ..libs.token_auth import generate_token
+from ..session.session import get_user_info
+from ...session import SessionState
+
+auth_token_mutation = ObjectType('Mutation')
-auth_token_mutation = ObjectType("Mutation")
@auth_token_mutation.field('AuthToken')
def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict):
@@ -31,10 +32,13 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict):
user = data['username']
passwd = data['password']
- secret = state.settings['secret']
- exp_interval = int(state.settings['app'].state.vyos_token_exp)
- expiration = (datetime.datetime.now(tz=datetime.timezone.utc) +
- datetime.timedelta(seconds=exp_interval))
+ state = SessionState()
+
+ secret = getattr(state, 'secret', '')
+ exp_interval = int(state.token_exp)
+ expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(
+ seconds=exp_interval
+ )
res = generate_token(user, passwd, secret, expiration)
try:
@@ -44,18 +48,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict):
pass
if 'token' in res:
data['result'] = res
- return {
- "success": True,
- "data": data
- }
+ return {'success': True, 'data': data}
if 'errors' in res:
- return {
- "success": False,
- "errors": res['errors']
- }
-
- return {
- "success": False,
- "errors": ['token generation failed']
- }
+ return {'success': False, 'errors': res['errors']}
+
+ return {'success': False, 'errors': ['token generation failed']}
diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py
index d115a8e94..0b391c070 100644
--- a/src/services/api/graphql/graphql/mutations.py
+++ b/src/services/api/graphql/graphql/mutations.py
@@ -14,20 +14,23 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module
-from ariadne import ObjectType, convert_camel_case_to_snake
-from makefun import with_signature
# used below by func_sig
-from typing import Any, Dict, Optional # pylint: disable=W0611
-from graphql import GraphQLResolveInfo # pylint: disable=W0611
+from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401
+from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401
+
+from ariadne import ObjectType, convert_camel_case_to_snake
+from makefun import with_signature
-from .. import state
-from .. libs import key_auth
-from api.graphql.session.session import Session
-from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
from vyos.opmode import Error as OpModeError
-mutation = ObjectType("Mutation")
+from ...session import SessionState
+from ..libs import key_auth
+from ..session.session import Session
+from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
+
+mutation = ObjectType('Mutation')
+
def make_mutation_resolver(mutation_name, class_name, session_func):
"""Dynamically generate a resolver for the mutation named in the
@@ -45,12 +48,13 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
func_base_name = convert_camel_case_to_snake(class_name)
resolver_name = f'resolve_{func_base_name}'
func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)'
+ state = SessionState()
@mutation.field(mutation_name)
@with_signature(func_sig, func_name=resolver_name)
async def func_impl(*args, **kwargs):
try:
- auth_type = state.settings['app'].state.vyos_auth_type
+ auth_type = state.auth_type
if auth_type == 'key':
data = kwargs['data']
@@ -58,10 +62,7 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
auth = key_auth.auth_required(key)
if auth is None:
- return {
- "success": False,
- "errors": ['invalid API key']
- }
+ return {'success': False, 'errors': ['invalid API key']}
# We are finished with the 'key' entry, and may remove it so as to
# pass the rest of the data (if any) to the function.
@@ -76,21 +77,15 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
if user is None:
error = info.context.get('error')
if error is not None:
- return {
- "success": False,
- "errors": [error]
- }
- return {
- "success": False,
- "errors": ['not authenticated']
- }
+ return {'success': False, 'errors': [error]}
+ return {'success': False, 'errors': ['not authenticated']}
else:
# AttributeError will have already been raised if no
- # vyos_auth_type; validation and defaultValue ensure it is
+ # auth_type; validation and defaultValue ensure it is
# one of the previous cases, so this is never reached.
pass
- session = state.settings['app'].state.vyos_session
+ session = state.session
# one may override the session functions with a local subclass
try:
@@ -105,35 +100,36 @@ def make_mutation_resolver(mutation_name, class_name, session_func):
result = method()
data['result'] = result
- return {
- "success": True,
- "data": data
- }
+ return {'success': True, 'data': data}
except OpModeError as e:
typename = type(e).__name__
msg = str(e)
return {
- "success": False,
- "errore": ['op_mode_error'],
- "op_mode_error": {"name": f"{typename}",
- "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"),
- "vyos_code": op_mode_err_code.get(typename, 9999)}
+ 'success': False,
+                'errors': ['op_mode_error'],
+ 'op_mode_error': {
+ 'name': f'{typename}',
+ 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'),
+ 'vyos_code': op_mode_err_code.get(typename, 9999),
+ },
}
except Exception as error:
- return {
- "success": False,
- "errors": [repr(error)]
- }
+ return {'success': False, 'errors': [repr(error)]}
return func_impl
+
def make_config_session_mutation_resolver(mutation_name):
- return make_mutation_resolver(mutation_name, mutation_name,
- convert_camel_case_to_snake(mutation_name))
+ return make_mutation_resolver(
+ mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name)
+ )
+
def make_gen_op_mutation_resolver(mutation_name):
return make_mutation_resolver(mutation_name, mutation_name, 'gen_op_mutation')
+
def make_composite_mutation_resolver(mutation_name):
- return make_mutation_resolver(mutation_name, mutation_name,
- convert_camel_case_to_snake(mutation_name))
+ return make_mutation_resolver(
+ mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name)
+ )
diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py
index 717098259..9303fe909 100644
--- a/src/services/api/graphql/graphql/queries.py
+++ b/src/services/api/graphql/graphql/queries.py
@@ -14,20 +14,23 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module
-from ariadne import ObjectType, convert_camel_case_to_snake
-from makefun import with_signature
# used below by func_sig
-from typing import Any, Dict, Optional # pylint: disable=W0611
-from graphql import GraphQLResolveInfo # pylint: disable=W0611
+from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401
+from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401
+
+from ariadne import ObjectType, convert_camel_case_to_snake
+from makefun import with_signature
-from .. import state
-from .. libs import key_auth
-from api.graphql.session.session import Session
-from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
from vyos.opmode import Error as OpModeError
-query = ObjectType("Query")
+from ...session import SessionState
+from ..libs import key_auth
+from ..session.session import Session
+from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code
+
+query = ObjectType('Query')
+
def make_query_resolver(query_name, class_name, session_func):
"""Dynamically generate a resolver for the query named in the
@@ -45,12 +48,13 @@ def make_query_resolver(query_name, class_name, session_func):
func_base_name = convert_camel_case_to_snake(class_name)
resolver_name = f'resolve_{func_base_name}'
func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)'
+ state = SessionState()
@query.field(query_name)
@with_signature(func_sig, func_name=resolver_name)
async def func_impl(*args, **kwargs):
try:
- auth_type = state.settings['app'].state.vyos_auth_type
+ auth_type = state.auth_type
if auth_type == 'key':
data = kwargs['data']
@@ -58,10 +62,7 @@ def make_query_resolver(query_name, class_name, session_func):
auth = key_auth.auth_required(key)
if auth is None:
- return {
- "success": False,
- "errors": ['invalid API key']
- }
+ return {'success': False, 'errors': ['invalid API key']}
# We are finished with the 'key' entry, and may remove it so as to
# pass the rest of the data (if any) to the function.
@@ -76,21 +77,15 @@ def make_query_resolver(query_name, class_name, session_func):
if user is None:
error = info.context.get('error')
if error is not None:
- return {
- "success": False,
- "errors": [error]
- }
- return {
- "success": False,
- "errors": ['not authenticated']
- }
+ return {'success': False, 'errors': [error]}
+ return {'success': False, 'errors': ['not authenticated']}
else:
# AttributeError will have already been raised if no
- # vyos_auth_type; validation and defaultValue ensure it is
+ # auth_type; validation and defaultValue ensure it is
# one of the previous cases, so this is never reached.
pass
- session = state.settings['app'].state.vyos_session
+ session = state.session
# one may override the session functions with a local subclass
try:
@@ -105,35 +100,36 @@ def make_query_resolver(query_name, class_name, session_func):
result = method()
data['result'] = result
- return {
- "success": True,
- "data": data
- }
+ return {'success': True, 'data': data}
except OpModeError as e:
typename = type(e).__name__
msg = str(e)
return {
- "success": False,
- "errors": ['op_mode_error'],
- "op_mode_error": {"name": f"{typename}",
- "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"),
- "vyos_code": op_mode_err_code.get(typename, 9999)}
+ 'success': False,
+ 'errors': ['op_mode_error'],
+ 'op_mode_error': {
+ 'name': f'{typename}',
+ 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'),
+ 'vyos_code': op_mode_err_code.get(typename, 9999),
+ },
}
except Exception as error:
- return {
- "success": False,
- "errors": [repr(error)]
- }
+ return {'success': False, 'errors': [repr(error)]}
return func_impl
+
def make_config_session_query_resolver(query_name):
- return make_query_resolver(query_name, query_name,
- convert_camel_case_to_snake(query_name))
+ return make_query_resolver(
+ query_name, query_name, convert_camel_case_to_snake(query_name)
+ )
+
def make_gen_op_query_resolver(query_name):
return make_query_resolver(query_name, query_name, 'gen_op_query')
+
def make_composite_query_resolver(query_name):
- return make_query_resolver(query_name, query_name,
- convert_camel_case_to_snake(query_name))
+ return make_query_resolver(
+ query_name, query_name, convert_camel_case_to_snake(query_name)
+ )
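
Note for readers less familiar with the pattern above: queries.py and mutations.py attach dynamically generated resolvers to the Ariadne ObjectType, using makefun's with_signature to give each wrapper an explicit signature and name. A minimal, self-contained sketch of that pattern follows; the 'ShowVersion' field name and the resolver body are illustrative only and not part of this change.

# Standalone sketch of the dynamic-resolver pattern used in queries.py and
# mutations.py above; the field name and resolver body are hypothetical.
from typing import Any, Dict, Optional

from ariadne import ObjectType, convert_camel_case_to_snake
from graphql import GraphQLResolveInfo  # referenced by the signature string
from makefun import with_signature

example_query = ObjectType('Query')

def make_example_resolver(query_name: str):
    resolver_name = f'resolve_{convert_camel_case_to_snake(query_name)}'
    func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)'

    @example_query.field(query_name)
    @with_signature(func_sig, func_name=resolver_name)
    async def func_impl(*args, **kwargs):
        # A real resolver authenticates first, then dispatches to a Session
        # method; here we just echo the input to keep the sketch short.
        return {'success': True, 'data': kwargs.get('data') or {}}

    return func_impl

make_example_resolver('ShowVersion')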
diff --git a/src/services/api/graphql/libs/__init__.py b/src/services/api/graphql/libs/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/services/api/graphql/libs/__init__.py
diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py
index 2db0f7d48..ffd7f32b2 100644
--- a/src/services/api/graphql/libs/key_auth.py
+++ b/src/services/api/graphql/libs/key_auth.py
@@ -1,5 +1,21 @@
+# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+
+from ...session import SessionState
-from .. import state
def check_auth(key_list, key):
if not key_list:
@@ -10,9 +26,11 @@ def check_auth(key_list, key):
key_id = k['id']
return key_id
+
def auth_required(key):
+ state = SessionState()
api_keys = None
- api_keys = state.settings['app'].state.vyos_keys
+ api_keys = state.keys
key_id = check_auth(api_keys, key)
- state.settings['app'].state.vyos_id = key_id
+ state.id = key_id
return key_id
diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py
index 8585485c9..4f743a096 100644
--- a/src/services/api/graphql/libs/token_auth.py
+++ b/src/services/api/graphql/libs/token_auth.py
@@ -1,46 +1,67 @@
+# Copyright 2021-2024 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+
import jwt
import uuid
import pam
from secrets import token_hex
-from .. import state
+from ...session import SessionState
+
def _check_passwd_pam(username: str, passwd: str) -> bool:
if pam.authenticate(username, passwd):
return True
return False
+
def init_secret():
- length = int(state.settings['app'].state.vyos_secret_len)
+ state = SessionState()
+ length = int(state.secret_len)
secret = token_hex(length)
- state.settings['secret'] = secret
+ state.secret = secret
+
def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict:
if user is None or passwd is None:
return {}
+ state = SessionState()
if _check_passwd_pam(user, passwd):
- app = state.settings['app']
try:
- users = app.state.vyos_token_users
+ users = state.token_users
except AttributeError:
- app.state.vyos_token_users = {}
- users = app.state.vyos_token_users
+ users = state.token_users = {}
user_id = uuid.uuid1().hex
payload_data = {'iss': user, 'sub': user_id, 'exp': exp}
- secret = state.settings.get('secret')
+ secret = getattr(state, 'secret', None)
if secret is None:
- return {"errors": ['missing secret']}
- token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256")
+ return {'errors': ['missing secret']}
+ token = jwt.encode(payload=payload_data, key=secret, algorithm='HS256')
users |= {user_id: user}
return {'token': token}
else:
- return {"errors": ['failed pam authentication']}
+ return {'errors': ['failed pam authentication']}
+
def get_user_context(request):
context = {}
context['request'] = request
context['user'] = None
+ state = SessionState()
if 'Authorization' in request.headers:
auth = request.headers['Authorization']
scheme, token = auth.split()
@@ -48,8 +69,8 @@ def get_user_context(request):
return context
try:
- secret = state.settings.get('secret')
- payload = jwt.decode(token, secret, algorithms=["HS256"])
+ secret = getattr(state, 'secret', None)
+ payload = jwt.decode(token, secret, algorithms=['HS256'])
user_id: str = payload.get('sub')
if user_id is None:
return context
@@ -59,7 +80,7 @@ def get_user_context(request):
except jwt.PyJWTError:
return context
try:
- users = state.settings['app'].state.vyos_token_users
+ users = state.token_users
except AttributeError:
return context
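
As background on the calls above: generate_token signs a payload with the per-process secret stored on SessionState, and get_user_context later decodes the bearer token from the Authorization header. A minimal PyJWT round trip, with placeholder values, looks like this:

# Minimal PyJWT round trip mirroring generate_token/get_user_context above;
# the issuer, subject id and lifetime are placeholders.
import datetime
from secrets import token_hex

import jwt

secret = token_hex(32)  # analogous to init_secret() populating state.secret
exp = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(seconds=3600)

payload = {'iss': 'vyos', 'sub': 'example-user-id', 'exp': exp}
token = jwt.encode(payload=payload, key=secret, algorithm='HS256')

decoded = jwt.decode(token, secret, algorithms=['HS256'])
assert decoded['sub'] == 'example-user-id'

# An expired or tampered token raises jwt.PyJWTError, which get_user_context
# catches, returning an unauthenticated context.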
diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py
new file mode 100644
index 000000000..ed3ee1e8c
--- /dev/null
+++ b/src/services/api/graphql/routers.py
@@ -0,0 +1,77 @@
+# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+# pylint: disable=import-outside-toplevel
+
+
+import typing
+
+from ariadne.asgi import GraphQL
+from starlette.middleware.cors import CORSMiddleware
+
+
+if typing.TYPE_CHECKING:
+ from fastapi import FastAPI
+
+
+def graphql_init(app: 'FastAPI'):
+ from ..session import SessionState
+ from .libs.token_auth import get_user_context
+
+ state = SessionState()
+
+    # import after initialization of state
+ from .bindings import generate_schema
+
+ schema = generate_schema()
+
+ in_spec = state.introspection
+
+    # remove the route and reinstall it below to pick up any changes;
+    # alternatively, test for a config diff before proceeding
+ graphql_clear(app)
+
+ if state.origins:
+ origins = state.origins
+ app.add_route(
+ '/graphql',
+ CORSMiddleware(
+ GraphQL(
+ schema,
+ context_value=get_user_context,
+ debug=True,
+ introspection=in_spec,
+ ),
+ allow_origins=origins,
+ allow_methods=('GET', 'POST', 'OPTIONS'),
+ allow_headers=('Authorization',),
+ ),
+ )
+ else:
+ app.add_route(
+ '/graphql',
+ GraphQL(
+ schema,
+ context_value=get_user_context,
+ debug=True,
+ introspection=in_spec,
+ ),
+ )
+
+
+def graphql_clear(app: 'FastAPI'):
+ for r in app.routes:
+ if r.path == '/graphql':
+ app.routes.remove(r)
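
A short wiring sketch for these helpers, assuming the FastAPI app object and the SessionState attributes that vyos-http-api-server sets at startup; the import paths assume the installed 'api' package layout, and key-based auth is assumed so that no token secret is needed:

# Hypothetical wiring sketch; in the tree this is driven by vyos-http-api-server.
from fastapi import FastAPI

from api.session import SessionState
from api.graphql.routers import graphql_init, graphql_clear

app = FastAPI()

state = SessionState()
state.auth_type = 'key'        # 'key' avoids init_secret() in generate_schema()
state.introspection = False
state.origins = []             # empty: mount GraphQL without the CORS wrapper

graphql_init(app)              # builds the schema and mounts it at /graphql
# ... later, when the API service is reconfigured or disabled:
graphql_clear(app)             # removes the /graphql route again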
diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py
index 6ae44b9bf..619534f43 100644
--- a/src/services/api/graphql/session/session.py
+++ b/src/services/api/graphql/session/session.py
@@ -28,34 +28,45 @@ from api.graphql.libs.op_mode import normalize_output
op_mode_include_file = os.path.join(directories['data'], 'op-mode-standardized.json')
-def get_config_dict(path=[], effective=False, key_mangling=None,
- get_first_key=False, no_multi_convert=False,
- no_tag_node_value_mangle=False):
+
+def get_config_dict(
+ path=[],
+ effective=False,
+ key_mangling=None,
+ get_first_key=False,
+ no_multi_convert=False,
+ no_tag_node_value_mangle=False,
+):
config = Config()
- return config.get_config_dict(path=path, effective=effective,
- key_mangling=key_mangling,
- get_first_key=get_first_key,
- no_multi_convert=no_multi_convert,
- no_tag_node_value_mangle=no_tag_node_value_mangle)
+ return config.get_config_dict(
+ path=path,
+ effective=effective,
+ key_mangling=key_mangling,
+ get_first_key=get_first_key,
+ no_multi_convert=no_multi_convert,
+ no_tag_node_value_mangle=no_tag_node_value_mangle,
+ )
+
def get_user_info(user):
user_info = {}
- info = get_config_dict(['system', 'login', 'user', user],
- get_first_key=True)
+ info = get_config_dict(['system', 'login', 'user', user], get_first_key=True)
if not info:
- raise ValueError("No such user")
+ raise ValueError('No such user')
user_info['user'] = user
user_info['full_name'] = info.get('full-name', '')
return user_info
+
class Session:
"""
Wrapper for calling configsession functions based on GraphQL requests.
Non-nullable fields in the respective schema allow avoiding a key check
in 'data'.
"""
+
def __init__(self, session, data):
self._session = session
self._data = data
@@ -138,7 +149,6 @@ class Session:
return res
def show_user_info(self):
- session = self._session
data = self._data
user_info = {}
@@ -151,10 +161,9 @@ class Session:
return user_info
def system_status(self):
- import api.graphql.session.composite.system_status as system_status
+ from api.graphql.session.composite import system_status
session = self._session
- data = self._data
status = {}
status['host_name'] = session.show(['host', 'name']).strip()
@@ -165,7 +174,6 @@ class Session:
return status
def gen_op_query(self):
- session = self._session
data = self._data
name = self._name
op_mode_list = self._op_mode_list
@@ -189,7 +197,6 @@ class Session:
return res
def gen_op_mutation(self):
- session = self._session
data = self._data
name = self._name
op_mode_list = self._op_mode_list
diff --git a/src/services/api/graphql/state.py b/src/services/api/graphql/state.py
deleted file mode 100644
index 63db9f4ef..000000000
--- a/src/services/api/graphql/state.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-def init():
- global settings
- settings = {}
diff --git a/src/services/api/rest/__init__.py b/src/services/api/rest/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/services/api/rest/__init__.py
diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py
new file mode 100644
index 000000000..dda50010f
--- /dev/null
+++ b/src/services/api/rest/models.py
@@ -0,0 +1,320 @@
+# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+
+# pylint: disable=too-few-public-methods
+
+import json
+from html import escape
+from enum import Enum
+from typing import List
+from typing import Union
+from typing import Dict
+from typing import Self
+
+from pydantic import BaseModel
+from pydantic import StrictStr
+from pydantic import field_validator
+from pydantic import model_validator
+from fastapi.responses import HTMLResponse
+
+
+def error(code, msg):
+ msg = escape(msg, quote=False)
+ resp = {'success': False, 'error': msg, 'data': None}
+ resp = json.dumps(resp)
+ return HTMLResponse(resp, status_code=code)
+
+
+def success(data):
+ resp = {'success': True, 'data': data, 'error': None}
+ resp = json.dumps(resp)
+ return HTMLResponse(resp)
+
+
+# Pydantic models for validation
+# Pydantic will cast when possible, so use StrictStr; validators are added as
+# needed for additional constraints
+# json_schema_extra adds annotations to OpenAPI to provide examples
+
+
+class ApiModel(BaseModel):
+ key: StrictStr
+
+
+class BasePathModel(BaseModel):
+ op: StrictStr
+ path: List[StrictStr]
+
+ @field_validator('path')
+ @classmethod
+ def check_non_empty(cls, path: str) -> str:
+ if not len(path) > 0:
+ raise ValueError('path must be non-empty')
+ return path
+
+
+class BaseConfigureModel(BasePathModel):
+ value: StrictStr = None
+
+
+class ConfigureModel(ApiModel, BaseConfigureModel):
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'set | delete | comment',
+ 'path': ['config', 'mode', 'path'],
+ }
+ }
+
+
+class ConfigureListModel(ApiModel):
+ commands: List[BaseConfigureModel]
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'commands': 'list of commands',
+ }
+ }
+
+
+class BaseConfigSectionModel(BasePathModel):
+ section: Dict
+
+
+class ConfigSectionModel(ApiModel, BaseConfigSectionModel):
+ pass
+
+
+class ConfigSectionListModel(ApiModel):
+ commands: List[BaseConfigSectionModel]
+
+
+class BaseConfigSectionTreeModel(BaseModel):
+ op: StrictStr
+ mask: Dict
+ config: Dict
+
+
+class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel):
+ pass
+
+
+class RetrieveModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+ configFormat: StrictStr = None
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'returnValue | returnValues | exists | showConfig',
+ 'path': ['config', 'mode', 'path'],
+ 'configFormat': 'json (default) | json_ast | raw',
+ }
+ }
+
+
+class ConfigFileModel(ApiModel):
+ op: StrictStr
+ file: StrictStr = None
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'save | load',
+ 'file': 'filename',
+ }
+ }
+
+
+class ImageOp(str, Enum):
+ add = 'add'
+ delete = 'delete'
+ show = 'show'
+ set_default = 'set_default'
+
+
+class ImageModel(ApiModel):
+ op: ImageOp
+ url: StrictStr = None
+ name: StrictStr = None
+
+ @model_validator(mode='after')
+ def check_data(self) -> Self:
+ if self.op == 'add':
+ if not self.url:
+ raise ValueError('Missing required field "url"')
+ elif self.op in ['delete', 'set_default']:
+ if not self.name:
+ raise ValueError('Missing required field "name"')
+
+ return self
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'add | delete | show | set_default',
+ 'url': 'imagelocation',
+ 'name': 'imagename',
+ }
+ }
+
+
+class ImportPkiModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+ passphrase: StrictStr = None
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'import_pki',
+ 'path': ['op', 'mode', 'path'],
+ 'passphrase': 'passphrase',
+ }
+ }
+
+
+class ContainerImageModel(ApiModel):
+ op: StrictStr
+ name: StrictStr = None
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'add | delete | show',
+ 'name': 'imagename',
+ }
+ }
+
+
+class GenerateModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'generate',
+ 'path': ['op', 'mode', 'path'],
+ }
+ }
+
+
+class ShowModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'show',
+ 'path': ['op', 'mode', 'path'],
+ }
+ }
+
+
+class RebootModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'reboot',
+ 'path': ['op', 'mode', 'path'],
+ }
+ }
+
+
+class ResetModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'reset',
+ 'path': ['op', 'mode', 'path'],
+ }
+ }
+
+
+class PoweroffModel(ApiModel):
+ op: StrictStr
+ path: List[StrictStr]
+
+ class Config:
+ json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'poweroff',
+ 'path': ['op', 'mode', 'path'],
+ }
+ }
+
+
+class TracerouteModel(ApiModel):
+ op: StrictStr
+ host: StrictStr
+
+ class Config:
+        json_schema_extra = {
+ 'example': {
+ 'key': 'id_key',
+ 'op': 'traceroute',
+ 'host': 'host',
+ }
+ }
+
+
+class InfoQueryParams(BaseModel):
+ model_config = {"extra": "forbid"}
+
+ version: bool = True
+ hostname: bool = True
+
+
+class Success(BaseModel):
+ success: bool
+ data: Union[str, bool, Dict]
+ error: str
+
+
+class Error(BaseModel):
+ success: bool = False
+ data: Union[str, bool, Dict]
+ error: str
+
+
+responses = {
+ 200: {'model': Success},
+ 400: {'model': Error},
+ 422: {'model': Error, 'description': 'Validation Error'},
+ 500: {'model': Error},
+}
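
To illustrate what these models buy the REST endpoints, here is a small validation sketch; the import path assumes the installed 'api' package, and the key and path values are placeholders:

# Sketch of the request validation the models above perform; values are placeholders.
from pydantic import ValidationError

from api.rest.models import ConfigureModel, ImageModel  # installed-tree import

cmd = ConfigureModel(key='id_key', op='set',
                     path=['interfaces', 'ethernet', 'eth0', 'address'],
                     value='192.0.2.1/24')
print(cmd.op, cmd.path)

try:
    ConfigureModel(key='id_key', op='set', path=[])
except ValidationError as e:
    print(e)  # check_non_empty rejects an empty path

try:
    ImageModel(key='id_key', op='add')  # 'add' requires 'url'
except ValidationError as e:
    print(e)  # check_data reports: Missing required field "url"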
diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py
new file mode 100644
index 000000000..e52c77fda
--- /dev/null
+++ b/src/services/api/rest/routers.py
@@ -0,0 +1,778 @@
+# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+
+# pylint: disable=line-too-long,raise-missing-from,invalid-name
+# pylint: disable=wildcard-import,unused-wildcard-import
+# pylint: disable=broad-exception-caught
+
+import json
+import copy
+import logging
+import traceback
+from threading import Lock
+from typing import Union
+from typing import Callable
+from typing import TYPE_CHECKING
+
+from fastapi import Depends
+from fastapi import Request
+from fastapi import Response
+from fastapi import HTTPException
+from fastapi import APIRouter
+from fastapi import BackgroundTasks
+from fastapi.routing import APIRoute
+from starlette.datastructures import FormData
+from starlette.formparsers import FormParser
+from starlette.formparsers import MultiPartParser
+from starlette.formparsers import MultiPartException
+from multipart.multipart import parse_options_header
+
+from vyos.config import Config
+from vyos.configtree import ConfigTree
+from vyos.configdiff import get_config_diff
+from vyos.configsession import ConfigSessionError
+
+from ..session import SessionState
+from .models import success
+from .models import error
+from .models import responses
+from .models import ApiModel
+from .models import ConfigureModel
+from .models import ConfigureListModel
+from .models import ConfigSectionModel
+from .models import ConfigSectionListModel
+from .models import ConfigSectionTreeModel
+from .models import BaseConfigSectionTreeModel
+from .models import BaseConfigureModel
+from .models import BaseConfigSectionModel
+from .models import RetrieveModel
+from .models import ConfigFileModel
+from .models import ImageModel
+from .models import ContainerImageModel
+from .models import GenerateModel
+from .models import ShowModel
+from .models import RebootModel
+from .models import ResetModel
+from .models import ImportPkiModel
+from .models import PoweroffModel
+from .models import TracerouteModel
+
+
+if TYPE_CHECKING:
+ from fastapi import FastAPI
+
+
+LOG = logging.getLogger('http_api.routers')
+
+lock = Lock()
+
+
+def check_auth(key_list, key):
+ key_id = None
+ for k in key_list:
+ if k['key'] == key:
+ key_id = k['id']
+ return key_id
+
+
+def auth_required(data: ApiModel):
+ session = SessionState()
+ key = data.key
+ api_keys = session.keys
+ key_id = check_auth(api_keys, key)
+ if not key_id:
+ raise HTTPException(status_code=401, detail='Valid API key is required')
+ session.id = key_id
+
+
+# override Request and APIRoute classes in order to convert form requests to JSON;
+# do all explicit validation here, for backwards compatibility of error messages;
+# the explicit validation may be dropped, if desired, in favor of native
+# validation by FastAPI/Pydantic, as is used for application/json requests
+class MultipartRequest(Request):
+ """Override Request class to convert form request to json"""
+
+ # pylint: disable=attribute-defined-outside-init
+ # pylint: disable=too-many-branches,too-many-statements
+
+ _form_err = ()
+
+ @property
+ def form_err(self):
+ return self._form_err
+
+ @form_err.setter
+ def form_err(self, val):
+ if not self._form_err:
+ self._form_err = val
+
+ @property
+ def orig_headers(self):
+ self._orig_headers = super().headers
+ return self._orig_headers
+
+ @property
+ def headers(self):
+ self._headers = super().headers.mutablecopy()
+ self._headers['content-type'] = 'application/json'
+ return self._headers
+
+ async def _get_form(
+ self, *, max_files: int | float = 1000, max_fields: int | float = 1000
+ ) -> FormData:
+ if self._form is None:
+ assert (
+ parse_options_header is not None
+ ), 'The `python-multipart` library must be installed to use form parsing.'
+ content_type_header = self.orig_headers.get('Content-Type')
+ content_type: bytes
+ content_type, _ = parse_options_header(content_type_header)
+ if content_type == b'multipart/form-data':
+ try:
+ multipart_parser = MultiPartParser(
+ self.orig_headers,
+ self.stream(),
+ max_files=max_files,
+ max_fields=max_fields,
+ )
+ self._form = await multipart_parser.parse()
+ except MultiPartException as exc:
+ if 'app' in self.scope:
+ raise HTTPException(status_code=400, detail=exc.message)
+ raise exc
+ elif content_type == b'application/x-www-form-urlencoded':
+ form_parser = FormParser(self.orig_headers, self.stream())
+ self._form = await form_parser.parse()
+ else:
+ self._form = FormData()
+ return self._form
+
+ async def body(self) -> bytes:
+ if not hasattr(self, '_body'):
+ forms = {}
+ merge = {}
+ body = await super().body()
+ self._body = body
+
+ form_data = await self.form()
+ if form_data:
+ endpoint = self.url.path
+ LOG.debug('processing form data')
+ for k, v in form_data.multi_items():
+ forms[k] = v
+
+ if 'data' not in forms:
+ self.form_err = (422, 'Non-empty data field is required')
+ return self._body
+ try:
+ tmp = json.loads(forms['data'])
+ except json.JSONDecodeError as e:
+ self.form_err = (400, f'Failed to parse JSON: {e}')
+ return self._body
+ if isinstance(tmp, list):
+ merge['commands'] = tmp
+ else:
+ merge = tmp
+
+ if 'commands' in merge:
+ cmds = merge['commands']
+ else:
+ cmds = copy.deepcopy(merge)
+ cmds = [cmds]
+
+ for c in cmds:
+ if not isinstance(c, dict):
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': any command must be JSON of dict",
+ )
+ return self._body
+ if 'op' not in c:
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'op' field",
+ )
+ if endpoint not in (
+ '/config-file',
+ '/container-image',
+ '/image',
+ '/configure-section',
+ '/traceroute',
+ ):
+ if 'path' not in c:
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'path' field",
+ )
+ elif not isinstance(c['path'], list):
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'path' field must be a list",
+ )
+ elif not all(isinstance(el, str) for el in c['path']):
+ self.form_err = (
+ 400,
+                                f"Malformed command '{c}': 'path' field must be a list of strings",
+ )
+ if endpoint in ('/configure'):
+ if not c['path']:
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'path' list must be non-empty",
+ )
+ if 'value' in c and not isinstance(c['value'], str):
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': 'value' field must be a string",
+ )
+ if endpoint in ('/configure-section'):
+ if 'section' not in c and 'config' not in c:
+ self.form_err = (
+ 400,
+ f"Malformed command '{c}': missing 'section' or 'config' field",
+ )
+
+ if 'key' not in forms and 'key' not in merge:
+ self.form_err = (401, 'Valid API key is required')
+ if 'key' in forms and 'key' not in merge:
+ merge['key'] = forms['key']
+
+ new_body = json.dumps(merge)
+ new_body = new_body.encode()
+ self._body = new_body
+
+ return self._body
+
+
+class MultipartRoute(APIRoute):
+ """Override APIRoute class to convert form request to json"""
+
+ def get_route_handler(self) -> Callable:
+ original_route_handler = super().get_route_handler()
+
+ async def custom_route_handler(request: Request) -> Response:
+ request = MultipartRequest(request.scope, request.receive)
+ try:
+ response: Response = await original_route_handler(request)
+ except HTTPException as e:
+ return error(e.status_code, e.detail)
+ except Exception as e:
+ form_err = request.form_err
+ if form_err:
+ return error(*form_err)
+ raise e
+
+ return response
+
+ return custom_route_handler
+
+
+router = APIRouter(
+ route_class=MultipartRoute,
+ responses={**responses},
+ dependencies=[Depends(auth_required)],
+)
+
+
+self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background'
+
+
+def call_commit(s: SessionState):
+ try:
+ s.session.commit()
+ except ConfigSessionError as e:
+ s.session.discard()
+ if s.debug:
+ LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}')
+ else:
+ LOG.warning(f'ConfigSessionError: {e}')
+
+
+def _configure_op(
+ data: Union[
+ ConfigureModel,
+ ConfigureListModel,
+ ConfigSectionModel,
+ ConfigSectionListModel,
+ ConfigSectionTreeModel,
+ ],
+ _request: Request,
+ background_tasks: BackgroundTasks,
+):
+ # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
+ # pylint: disable=consider-using-with
+
+ state = SessionState()
+ session = state.session
+ env = session.get_session_env()
+
+ # Allow users to pass just one command
+ if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)):
+ data = [data]
+ else:
+ data = data.commands
+
+ # We don't want multiple people/apps to be able to commit at once,
+ # or modify the shared session while someone else is doing the same,
+ # so the lock is really global
+ lock.acquire()
+
+ config = Config(session_env=env)
+
+ status = 200
+ msg = None
+ error_msg = None
+ try:
+ for c in data:
+ op = c.op
+ if not isinstance(c, BaseConfigSectionTreeModel):
+ path = c.path
+
+ if isinstance(c, BaseConfigureModel):
+ if c.value:
+ value = c.value
+ else:
+ value = ''
+ # For vyos.configsession calls that have no separate value arguments,
+ # and for type checking too
+ cfg_path = ' '.join(path + [value]).strip()
+
+ elif isinstance(c, BaseConfigSectionModel):
+ section = c.section
+
+ elif isinstance(c, BaseConfigSectionTreeModel):
+ mask = c.mask
+ config = c.config
+
+ if isinstance(c, BaseConfigureModel):
+ if op == 'set':
+ session.set(path, value=value)
+ elif op == 'delete':
+ if state.strict and not config.exists(cfg_path):
+ raise ConfigSessionError(
+ f'Cannot delete [{cfg_path}]: path/value does not exist'
+ )
+ session.delete(path, value=value)
+ elif op == 'comment':
+ session.comment(path, value=value)
+ else:
+ raise ConfigSessionError(f"'{op}' is not a valid operation")
+
+ elif isinstance(c, BaseConfigSectionModel):
+ if op == 'set':
+ session.set_section(path, section)
+ elif op == 'load':
+ session.load_section(path, section)
+ else:
+ raise ConfigSessionError(f"'{op}' is not a valid operation")
+
+ elif isinstance(c, BaseConfigSectionTreeModel):
+ if op == 'set':
+ session.set_section_tree(config)
+ elif op == 'load':
+ session.load_section_tree(mask, config)
+ else:
+ raise ConfigSessionError(f"'{op}' is not a valid operation")
+ # end for
+ config = Config(session_env=env)
+ d = get_config_diff(config)
+
+ if d.is_node_changed(['service', 'https']):
+ background_tasks.add_task(call_commit, state)
+ msg = self_ref_msg
+ else:
+ # capture non-fatal warnings
+ out = session.commit()
+ msg = out if out else msg
+
+ LOG.info(f"Configuration modified via HTTP API using key '{state.id}'")
+ except ConfigSessionError as e:
+ session.discard()
+ status = 400
+ if state.debug:
+ LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}')
+ error_msg = str(e)
+ except Exception:
+ session.discard()
+ LOG.critical(traceback.format_exc())
+ status = 500
+
+ # Don't give the details away to the outer world
+        error_msg = 'An internal error occurred. Check the logs for details.'
+ finally:
+ lock.release()
+
+ if status != 200:
+ return error(status, error_msg)
+
+ return success(msg)
+
+
+def create_path_import_pki_no_prompt(path):
+ correct_paths = ['ca', 'certificate', 'key-pair']
+ if path[1] not in correct_paths:
+ return False
+ path[3] = '--key-filename'
+ path.insert(2, '--name')
+ return ['--pki-type'] + path[1:]
+
+
+@router.post('/configure')
+def configure_op(
+ data: Union[ConfigureModel, ConfigureListModel],
+ request: Request,
+ background_tasks: BackgroundTasks,
+):
+ return _configure_op(data, request, background_tasks)
+
+
+@router.post('/configure-section')
+def configure_section_op(
+ data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel],
+ request: Request,
+ background_tasks: BackgroundTasks,
+):
+ return _configure_op(data, request, background_tasks)
+
+
+@router.post('/retrieve')
+async def retrieve_op(data: RetrieveModel):
+ state = SessionState()
+ session = state.session
+ env = session.get_session_env()
+ config = Config(session_env=env)
+
+ op = data.op
+ path = ' '.join(data.path)
+
+ try:
+ if op == 'returnValue':
+ res = config.return_value(path)
+ elif op == 'returnValues':
+ res = config.return_values(path)
+ elif op == 'exists':
+ res = config.exists(path)
+ elif op == 'showConfig':
+ config_format = 'json'
+ if data.configFormat:
+ config_format = data.configFormat
+
+ res = session.show_config(path=data.path)
+ if config_format == 'json':
+ config_tree = ConfigTree(res)
+ res = json.loads(config_tree.to_json())
+ elif config_format == 'json_ast':
+ config_tree = ConfigTree(res)
+ res = json.loads(config_tree.to_json_ast())
+ elif config_format == 'raw':
+ pass
+ else:
+ return error(400, f"'{config_format}' is not a valid config format")
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/config-file')
+def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):
+ state = SessionState()
+ session = state.session
+ env = session.get_session_env()
+ op = data.op
+ msg = None
+
+ try:
+ if op == 'save':
+ if data.file:
+ path = data.file
+ else:
+ path = '/config/config.boot'
+ msg = session.save_config(path)
+ elif op == 'load':
+ if data.file:
+ path = data.file
+ else:
+ return error(400, 'Missing required field "file"')
+
+ session.migrate_and_load_config(path)
+
+ config = Config(session_env=env)
+ d = get_config_diff(config)
+
+ if d.is_node_changed(['service', 'https']):
+ background_tasks.add_task(call_commit, state)
+ msg = self_ref_msg
+ else:
+ session.commit()
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(msg)
+
+
+@router.post('/image')
+def image_op(data: ImageModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+
+ try:
+ if op == 'add':
+ res = session.install_image(data.url)
+ elif op == 'delete':
+ res = session.remove_image(data.name)
+ elif op == 'show':
+ res = session.show(['system', 'image'])
+ elif op == 'set_default':
+ res = session.set_default_image(data.name)
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/container-image')
+def container_image_op(data: ContainerImageModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+
+ try:
+ if op == 'add':
+ if data.name:
+ name = data.name
+ else:
+ return error(400, 'Missing required field "name"')
+ res = session.add_container_image(name)
+ elif op == 'delete':
+ if data.name:
+ name = data.name
+ else:
+ return error(400, 'Missing required field "name"')
+ res = session.delete_container_image(name)
+ elif op == 'show':
+ res = session.show_container_image()
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/generate')
+def generate_op(data: GenerateModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ path = data.path
+
+ try:
+ if op == 'generate':
+ res = session.generate(path)
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/show')
+def show_op(data: ShowModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ path = data.path
+
+ try:
+ if op == 'show':
+ res = session.show(path)
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/reboot')
+def reboot_op(data: RebootModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ path = data.path
+
+ try:
+ if op == 'reboot':
+ res = session.reboot(path)
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/reset')
+def reset_op(data: ResetModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ path = data.path
+
+ try:
+ if op == 'reset':
+ res = session.reset(path)
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/import-pki')
+def import_pki(data: ImportPkiModel):
+ # pylint: disable=consider-using-with
+
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ path = data.path
+
+ lock.acquire()
+
+ try:
+ if op == 'import-pki':
+            # need to get rid of interactive mode for private key
+ if len(path) == 5 and path[3] in ['key-file', 'private-key']:
+ path_no_prompt = create_path_import_pki_no_prompt(path)
+ if not path_no_prompt:
+ return error(400, f"Invalid command: {' '.join(path)}")
+ if data.passphrase:
+ path_no_prompt += ['--passphrase', data.passphrase]
+ res = session.import_pki_no_prompt(path_no_prompt)
+ else:
+ res = session.import_pki(path)
+ if not res[0].isdigit():
+ return error(400, res)
+ # commit changes
+ session.commit()
+ res = res.split('. ')[0]
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+ finally:
+ lock.release()
+
+ return success(res)
+
+
+@router.post('/poweroff')
+def poweroff_op(data: PoweroffModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ path = data.path
+
+ try:
+ if op == 'poweroff':
+ res = session.poweroff(path)
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+@router.post('/traceroute')
+def traceroute_op(data: TracerouteModel):
+ state = SessionState()
+ session = state.session
+
+ op = data.op
+ host = data.host
+
+ try:
+ if op == 'traceroute':
+ res = session.traceroute(host)
+ else:
+ return error(400, f"'{op}' is not a valid operation")
+ except ConfigSessionError as e:
+ return error(400, str(e))
+ except Exception:
+ LOG.critical(traceback.format_exc())
+ return error(500, 'An internal error occurred. Check the logs for details.')
+
+ return success(res)
+
+
+def rest_init(app: 'FastAPI'):
+ if all(r in app.routes for r in router.routes):
+ return
+ app.include_router(router)
+
+
+def rest_clear(app: 'FastAPI'):
+ for r in router.routes:
+ if r in app.routes:
+ app.routes.remove(r)
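
Since MultipartRequest rewrites form posts into the same JSON body shape, a legacy form-encoded client and a JSON client end up on the same validation path. A hedged client-side sketch follows; the router URL and API key are placeholders, and the 'requests' library is assumed to be available on the client:

# Client-side sketch only; URL and API key are placeholders, TLS verification
# is disabled purely for the example.
import json
import requests

url = 'https://router.example.net/retrieve'
command = {'op': 'showConfig', 'path': ['interfaces'], 'key': 'id_key'}

# application/json body, validated natively by FastAPI/Pydantic:
r1 = requests.post(url, json=command, verify=False)

# legacy form-encoded body: 'data' carries the command, 'key' the API key;
# MultipartRequest above converts this into the same JSON shape.
r2 = requests.post(url,
                   data={'data': json.dumps({'op': 'showConfig',
                                             'path': ['interfaces']}),
                         'key': 'id_key'},
                   verify=False)

print(r1.json()['success'], r2.json()['success'])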
diff --git a/src/services/api/session.py b/src/services/api/session.py
new file mode 100644
index 000000000..ad3ef660c
--- /dev/null
+++ b/src/services/api/session.py
@@ -0,0 +1,41 @@
+# Copyright 2024 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+
+class SessionState:
+ # pylint: disable=attribute-defined-outside-init
+ # pylint: disable=too-many-instance-attributes,too-few-public-methods
+
+ _instance = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super(SessionState, cls).__new__(cls)
+ cls._instance._initialize()
+ return cls._instance
+
+ def _initialize(self):
+ self.session = None
+ self.keys = []
+ self.id = None
+ self.rest = False
+ self.debug = False
+ self.strict = False
+ self.graphql = False
+ self.origins = []
+ self.introspection = False
+ self.auth_type = None
+ self.token_exp = None
+ self.secret_len = None
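
SessionState is a process-wide singleton: every construction returns the same instance, which is how attributes set once at server start become visible to the REST and GraphQL routers above. A small sketch; the import path assumes the installed 'api' package and the key list is a placeholder:

# Demonstration of the singleton behaviour the routers above rely on.
from api.session import SessionState  # installed-tree import

a = SessionState()
a.auth_type = 'key'
a.keys = [{'id': 'admin', 'key': 'id_key'}]  # shape expected by check_auth()

b = SessionState()
assert a is b                  # both names refer to the same instance
assert b.auth_type == 'key'    # state set through 'a' is visible through 'b'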
diff --git a/src/services/vyos-commitd b/src/services/vyos-commitd
new file mode 100755
index 000000000..e7f2d82c7
--- /dev/null
+++ b/src/services/vyos-commitd
@@ -0,0 +1,457 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2025 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+#
+import os
+import sys
+import grp
+import json
+import signal
+import socket
+import typing
+import logging
+import traceback
+import importlib.util
+import io
+from contextlib import redirect_stdout
+from dataclasses import dataclass
+from dataclasses import fields
+from dataclasses import field
+from dataclasses import asdict
+from pathlib import Path
+
+import tomli
+
+from google.protobuf.json_format import MessageToDict
+from google.protobuf.json_format import ParseDict
+
+from vyos.defaults import directories
+from vyos.utils.boot import boot_configuration_complete
+from vyos.configsource import ConfigSourceCache
+from vyos.configsource import ConfigSourceError
+from vyos.config import Config
+from vyos.frrender import FRRender
+from vyos.frrender import get_frrender_dict
+from vyos import ConfigError
+
+from vyos.proto import vycall_pb2
+
+
+@dataclass
+class Status:
+ success: bool = False
+ out: str = ''
+
+
+@dataclass
+class Call:
+ script_name: str = ''
+ tag_value: str = None
+ arg_value: str = None
+ reply: Status = None
+
+ def set_reply(self, success: bool, out: str):
+ self.reply = Status(success=success, out=out)
+
+
+@dataclass
+class Session:
+ # pylint: disable=too-many-instance-attributes
+
+ session_id: str = ''
+ dry_run: bool = False
+ atomic: bool = False
+ background: bool = False
+ config: Config = None
+ init: Status = None
+ calls: list[Call] = field(default_factory=list)
+
+ def set_init(self, success: bool, out: str):
+ self.init = Status(success=success, out=out)
+
+
+@dataclass
+class ServerConf:
+ commitd_socket: str = ''
+ session_dir: str = ''
+ running_cache: str = ''
+ session_cache: str = ''
+
+
+server_conf = None
+SOCKET_PATH = None
+conf_mode_scripts = None
+frr = None
+
+CFG_GROUP = 'vyattacfg'
+
+script_stdout_log = '/tmp/vyos-commitd-script-stdout'
+
+debug = True
+
+logger = logging.getLogger(__name__)
+logs_handler = logging.StreamHandler()
+logger.addHandler(logs_handler)
+
+if debug:
+ logger.setLevel(logging.DEBUG)
+else:
+ logger.setLevel(logging.INFO)
+
+
+vyos_conf_scripts_dir = directories['conf_mode']
+commitd_include_file = os.path.join(directories['data'], 'configd-include.json')
+
+
+def key_name_from_file_name(f):
+ return os.path.splitext(f)[0]
+
+
+def module_name_from_key(k):
+ return k.replace('-', '_')
+
+
+def path_from_file_name(f):
+ return os.path.join(vyos_conf_scripts_dir, f)
+
+
+def load_conf_mode_scripts():
+    try:
+        with open(commitd_include_file) as f:
+            include = json.load(f)
+    except OSError as e:
+        logger.critical(f'configd include file error: {e}')
+        sys.exit(1)
+    except json.JSONDecodeError as e:
+        logger.critical(f'JSON load error: {e}')
+        sys.exit(1)
+
+ # import conf_mode scripts
+ (_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir)))
+ filenames.sort()
+
+ # this is redundant, as all scripts are currently in the include file;
+ # leave it as an inexpensive check for future changes
+ load_filenames = [f for f in filenames if f in include]
+ imports = [key_name_from_file_name(f) for f in load_filenames]
+ module_names = [module_name_from_key(k) for k in imports]
+ paths = [path_from_file_name(f) for f in load_filenames]
+ to_load = list(zip(module_names, paths))
+
+ modules = []
+
+ for x in to_load:
+ spec = importlib.util.spec_from_file_location(x[0], x[1])
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ modules.append(module)
+
+ scripts = dict(zip(imports, modules))
+
+ return scripts
+
+
+def get_session_out(session: Session) -> str:
+ out = ''
+ if session.init and session.init.out:
+ out = f'{out} + init: {session.init.out} + \n'
+ for call in session.calls:
+ reply = call.reply
+ if reply and reply.out:
+ out = f'{out} + {call.script_name}: {reply.out} + \n'
+ return out
+
+
+def write_stdout_log(file_name, session):
+ if boot_configuration_complete():
+ return
+ with open(file_name, 'a') as f:
+ f.write(get_session_out(session))
+
+
+def msg_to_commit_data(msg: vycall_pb2.Commit) -> Session:
+ # pylint: disable=no-member
+
+ d = MessageToDict(msg, preserving_proto_field_name=True)
+
+ # wrap in dataclasses
+ session = Session(**d)
+ session.init = Status(**session.init) if session.init else None
+ session.calls = list(map(lambda x: Call(**x), session.calls))
+ for call in session.calls:
+ call.reply = Status(**call.reply) if call.reply else None
+
+ return session
+
+
+def commit_data_to_msg(obj: Session) -> vycall_pb2.Commit:
+ # pylint: disable=no-member
+
+    # drop the Config object so asdict() does not attempt to deepcopy it
+ obj.config = None
+
+ msg = vycall_pb2.Commit()
+ msg = ParseDict(asdict(obj), msg, ignore_unknown_fields=True)
+
+ return msg
+
+
+def initialization(session: Session) -> Session:
+ running_cache = os.path.join(server_conf.session_dir, server_conf.running_cache)
+ session_cache = os.path.join(server_conf.session_dir, server_conf.session_cache)
+ try:
+ configsource = ConfigSourceCache(
+ running_config_cache=running_cache,
+ session_config_cache=session_cache,
+ )
+ except ConfigSourceError as e:
+ fail_msg = f'Failed to read config caches: {e}'
+ logger.critical(fail_msg)
+ session.set_init(False, fail_msg)
+ return session
+
+ session.set_init(True, '')
+
+ config = Config(config_source=configsource)
+
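+    # bookkeeping attached to the shared Config object so later calls in this commit can accumulate state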
+ dependent_func: dict[str, list[typing.Callable]] = {}
+ setattr(config, 'dependent_func', dependent_func)
+
+ scripts_called = []
+ setattr(config, 'scripts_called', scripts_called)
+
+ dry_run = session.dry_run
+ config.set_bool_attr('dry_run', dry_run)
+ logger.debug(f'commit dry_run is {dry_run}')
+
+ session.config = config
+
+ return session
+
+
+def run_script(script_name: str, config: Config, args: list) -> tuple[bool, str]:
+ # pylint: disable=broad-exception-caught
+
+ script = conf_mode_scripts[script_name]
+ script.argv = args
+ config.set_level([])
+ dry_run = config.get_bool_attr('dry_run')
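+    # standard conf_mode flow: get_config -> verify -> generate -> apply; generate/apply are skipped on dry-run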
+ try:
+ c = script.get_config(config)
+ script.verify(c)
+ if not dry_run:
+ script.generate(c)
+ script.apply(c)
+ else:
+ if hasattr(script, 'call_dependents'):
+ script.call_dependents()
+ except ConfigError as e:
+ logger.error(e)
+ return False, str(e)
+ except Exception:
+ tb = traceback.format_exc()
+ logger.error(tb)
+ return False, tb
+
+ return True, ''
+
+
+def process_call_data(call: Call, config: Config, last: bool = False) -> None:
+ # pylint: disable=too-many-locals
+
+ script_name = key_name_from_file_name(call.script_name)
+
+ if script_name not in conf_mode_scripts:
+ fail_msg = f'No such script: {call.script_name}'
+ logger.critical(fail_msg)
+ call.set_reply(False, fail_msg)
+ return
+
+ config.dependency_list.clear()
+
+ tag_value = call.tag_value if call.tag_value is not None else ''
+ os.environ['VYOS_TAGNODE_VALUE'] = tag_value
+
+ args = call.arg_value.split() if call.arg_value else []
+ args.insert(0, f'{script_name}.py')
+
+ tag_ext = f'_{tag_value}' if tag_value else ''
+ script_record = f'{script_name}{tag_ext}'
+ scripts_called = getattr(config, 'scripts_called', [])
+ scripts_called.append(script_record)
+
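+    # capture anything the script prints to stdout and return it together with any error text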
+ with redirect_stdout(io.StringIO()) as o:
+ success, err_out = run_script(script_name, config, args)
+ amb_out = o.getvalue()
+ o.close()
+
+ out = amb_out + err_out
+
+ call.set_reply(success, out)
+
+ logger.info(f'[{script_name}] {out}')
+
+ if last:
+ scripts_called = getattr(config, 'scripts_called', [])
+ logger.debug(f'scripts_called: {scripts_called}')
+
+ if last and success:
+ tmp = get_frrender_dict(config)
+ if frr.generate(tmp):
+ # only apply a new FRR configuration if anything changed
+            # in comparison to the previously applied configuration
+ frr.apply()
+
+
+def process_session_data(session: Session) -> Session:
+ if session.init is None or not session.init.success:
+ return session
+
+ config = session.config
+ len_calls = len(session.calls)
+ for index, call in enumerate(session.calls):
+ process_call_data(call, config, last=len_calls == index + 1)
+
+ return session
+
+
+def read_message(msg: bytes) -> Session:
+ """Read message into Session instance"""
+
+ message = vycall_pb2.Commit() # pylint: disable=no-member
+ message.ParseFromString(msg)
+ session = msg_to_commit_data(message)
+
+ session = initialization(session)
+ session = process_session_data(session)
+
+ write_stdout_log(script_stdout_log, session)
+
+ return session
+
+
+def write_reply(session: Session) -> bytearray:
+ """Serialize modified object to bytearray, prepending data length
+ header"""
+
+ reply = commit_data_to_msg(session)
+ encoded_data = reply.SerializeToString()
+ byte_size = reply.ByteSize()
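+    # 4-byte length header; int.to_bytes() defaults to big-endian byte order (Python >= 3.11)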
+ length_bytes = byte_size.to_bytes(4)
+ arr = bytearray(length_bytes)
+ arr.extend(encoded_data)
+
+ return arr
+
+
+def load_server_conf() -> ServerConf:
+ # pylint: disable=import-outside-toplevel
+ # pylint: disable=broad-exception-caught
+ from vyos.defaults import vyconfd_conf
+
+ try:
+ with open(vyconfd_conf, 'rb') as f:
+ vyconfd_conf_d = tomli.load(f)
+
+ except Exception as e:
+ logger.critical(f'Failed to open the vyconfd.conf file {vyconfd_conf}: {e}')
+ sys.exit(1)
+
+ app = vyconfd_conf_d.get('appliance', {})
+
+ conf_data = {
+ k: v for k, v in app.items() if k in [_.name for _ in fields(ServerConf)]
+ }
+
+ conf = ServerConf(**conf_data)
+
+ return conf
+
+
+def remove_if_exists(f: str):
+ try:
+ os.unlink(f)
+ except FileNotFoundError:
+ pass
+
+
+def sig_handler(_signum, _frame):
+ logger.info('stopping server')
+ raise KeyboardInterrupt
+
+
+def run_server():
+ # pylint: disable=global-statement
+
+ global server_conf
+ global SOCKET_PATH
+ global conf_mode_scripts
+ global frr
+
+ signal.signal(signal.SIGTERM, sig_handler)
+ signal.signal(signal.SIGINT, sig_handler)
+
+ logger.info('starting server')
+
+ server_conf = load_server_conf()
+ SOCKET_PATH = server_conf.commitd_socket
+ conf_mode_scripts = load_conf_mode_scripts()
+
+ cfg_group = grp.getgrnam(CFG_GROUP)
+ os.setgid(cfg_group.gr_gid)
+
+ server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+ remove_if_exists(SOCKET_PATH)
+ server_socket.bind(SOCKET_PATH)
+ Path(SOCKET_PATH).chmod(0o775)
+
+ # We only need one long-lived instance of FRRender
+ frr = FRRender()
+
+ server_socket.listen(2)
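+    # each request is a 4-byte length header followed by a serialized vycall_pb2.Commit
+    # message; replies use the same framing (see write_reply)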
+ while True:
+ try:
+ conn, _ = server_socket.accept()
+ logger.debug('connection accepted')
+ while True:
+ # receive size of data
+ data_length = conn.recv(4)
+ if not data_length:
+ logger.debug('no data')
+                    # empty read: the client closed the connection
+ break
+
+ length = int.from_bytes(data_length)
+ # receive data
+ data = conn.recv(length)
+
+ session = read_message(data)
+ reply = write_reply(session)
+ conn.sendall(reply)
+
+ conn.close()
+ logger.debug('connection closed')
+
+ except KeyboardInterrupt:
+ break
+
+ server_socket.close()
+ sys.exit(0)
+
+
+if __name__ == '__main__':
+ run_server()
diff --git a/src/services/vyos-configd b/src/services/vyos-configd
index 3674d9627..28acccd2c 100755
--- a/src/services/vyos-configd
+++ b/src/services/vyos-configd
@@ -14,6 +14,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+# pylint: disable=redefined-outer-name
+
import os
import sys
import grp
@@ -22,9 +24,13 @@ import json
import typing
import logging
import signal
+import traceback
import importlib.util
+import io
+from contextlib import redirect_stdout
+from enum import Enum
+
import zmq
-from contextlib import contextmanager
from vyos.defaults import directories
from vyos.utils.boot import boot_configuration_complete
@@ -32,6 +38,8 @@ from vyos.configsource import ConfigSourceString
from vyos.configsource import ConfigSourceError
from vyos.configdiff import get_commit_scripts
from vyos.config import Config
+from vyos.frrender import FRRender
+from vyos.frrender import get_frrender_dict
from vyos import ConfigError
CFG_GROUP = 'vyattacfg'
@@ -49,13 +57,18 @@ if debug:
else:
logger.setLevel(logging.INFO)
-SOCKET_PATH = "ipc:///run/vyos-configd.sock"
+SOCKET_PATH = 'ipc:///run/vyos-configd.sock'
+MAX_MSG_SIZE = 65535
+PAD_MSG_SIZE = 6
+
# Response error codes
-R_SUCCESS = 1
-R_ERROR_COMMIT = 2
-R_ERROR_DAEMON = 4
-R_PASS = 8
+class Response(Enum):
+ SUCCESS = 1
+ ERROR_COMMIT = 2
+ ERROR_DAEMON = 4
+ PASS = 8
+
vyos_conf_scripts_dir = directories['conf_mode']
configd_include_file = os.path.join(directories['data'], 'configd-include.json')
@@ -64,29 +77,31 @@ configd_env_unset_file = os.path.join(directories['data'], 'vyos-configd-env-uns
# sourced on entering config session
configd_env_file = '/etc/default/vyos-configd-env'
-session_out = None
-session_mode = None
def key_name_from_file_name(f):
return os.path.splitext(f)[0]
+
def module_name_from_key(k):
return k.replace('-', '_')
+
def path_from_file_name(f):
return os.path.join(vyos_conf_scripts_dir, f)
+
# opt-in to be run by daemon
with open(configd_include_file) as f:
try:
include = json.load(f)
except OSError as e:
- logger.critical(f"configd include file error: {e}")
+ logger.critical(f'configd include file error: {e}')
sys.exit(1)
except json.JSONDecodeError as e:
- logger.critical(f"JSON load error: {e}")
+ logger.critical(f'JSON load error: {e}')
sys.exit(1)
+
# import conf_mode scripts
(_, _, filenames) = next(iter(os.walk(vyos_conf_scripts_dir)))
filenames.sort()
@@ -110,31 +125,17 @@ conf_mode_scripts = dict(zip(imports, modules))
exclude_set = {key_name_from_file_name(f) for f in filenames if f not in include}
include_set = {key_name_from_file_name(f) for f in filenames if f in include}
-@contextmanager
-def stdout_redirected(filename, mode):
- saved_stdout_fd = None
- destination_file = None
- try:
- sys.stdout.flush()
- saved_stdout_fd = os.dup(sys.stdout.fileno())
- destination_file = open(filename, mode)
- os.dup2(destination_file.fileno(), sys.stdout.fileno())
- yield
- finally:
- if saved_stdout_fd is not None:
- os.dup2(saved_stdout_fd, sys.stdout.fileno())
- os.close(saved_stdout_fd)
- if destination_file is not None:
- destination_file.close()
-
-def explicit_print(path, mode, msg):
- try:
- with open(path, mode) as f:
- f.write(f"\n{msg}\n\n")
- except OSError:
- logger.critical("error explicit_print")
-def run_script(script_name, config, args) -> int:
+def write_stdout_log(file_name, msg):
+ if boot_configuration_complete():
+ return
+ with open(file_name, 'a') as f:
+ f.write(msg)
+
+
+def run_script(script_name, config, args) -> tuple[Response, str]:
+ # pylint: disable=broad-exception-caught
+
script = conf_mode_scripts[script_name]
script.argv = args
config.set_level([])
@@ -145,64 +146,54 @@ def run_script(script_name, config, args) -> int:
script.apply(c)
except ConfigError as e:
logger.error(e)
- explicit_print(session_out, session_mode, str(e))
- return R_ERROR_COMMIT
- except Exception as e:
- logger.critical(e)
- return R_ERROR_DAEMON
+ return Response.ERROR_COMMIT, str(e)
+ except Exception:
+ tb = traceback.format_exc()
+ logger.error(tb)
+ return Response.ERROR_COMMIT, tb
+
+ return Response.SUCCESS, ''
- return R_SUCCESS
def initialization(socket):
- global session_out
- global session_mode
+ # pylint: disable=broad-exception-caught,too-many-locals
+
# Reset config strings:
active_string = ''
session_string = ''
# check first for resent init msg, in case of client timeout
while True:
- msg = socket.recv().decode("utf-8", "ignore")
+ msg = socket.recv().decode('utf-8', 'ignore')
try:
message = json.loads(msg)
- if message["type"] == "init":
- resp = "init"
+ if message['type'] == 'init':
+ resp = 'init'
socket.send(resp.encode())
- except:
+ except Exception:
break
# zmq synchronous for ipc from single client:
active_string = msg
- resp = "active"
+ resp = 'active'
socket.send(resp.encode())
- session_string = socket.recv().decode("utf-8", "ignore")
- resp = "session"
+ session_string = socket.recv().decode('utf-8', 'ignore')
+ resp = 'session'
socket.send(resp.encode())
- pid_string = socket.recv().decode("utf-8", "ignore")
- resp = "pid"
+ pid_string = socket.recv().decode('utf-8', 'ignore')
+ resp = 'pid'
socket.send(resp.encode())
- sudo_user_string = socket.recv().decode("utf-8", "ignore")
- resp = "sudo_user"
+ sudo_user_string = socket.recv().decode('utf-8', 'ignore')
+ resp = 'sudo_user'
socket.send(resp.encode())
- temp_config_dir_string = socket.recv().decode("utf-8", "ignore")
- resp = "temp_config_dir"
+ temp_config_dir_string = socket.recv().decode('utf-8', 'ignore')
+ resp = 'temp_config_dir'
socket.send(resp.encode())
- changes_only_dir_string = socket.recv().decode("utf-8", "ignore")
- resp = "changes_only_dir"
+ changes_only_dir_string = socket.recv().decode('utf-8', 'ignore')
+ resp = 'changes_only_dir'
socket.send(resp.encode())
- logger.debug(f"config session pid is {pid_string}")
- logger.debug(f"config session sudo_user is {sudo_user_string}")
-
- try:
- session_out = os.readlink(f"/proc/{pid_string}/fd/1")
- session_mode = 'w'
- except FileNotFoundError:
- session_out = None
-
- # if not a 'live' session, for example on boot, write to file
- if not session_out or not boot_configuration_complete():
- session_out = script_stdout_log
- session_mode = 'a'
+ logger.debug(f'config session pid is {pid_string}')
+ logger.debug(f'config session sudo_user is {sudo_user_string}')
os.environ['SUDO_USER'] = sudo_user_string
if temp_config_dir_string:
@@ -211,8 +202,9 @@ def initialization(socket):
os.environ['VYATTA_CHANGES_ONLY_DIR'] = changes_only_dir_string
try:
- configsource = ConfigSourceString(running_config_text=active_string,
- session_config_text=session_string)
+ configsource = ConfigSourceString(
+ running_config_text=active_string, session_config_text=session_string
+ )
except ConfigSourceError as e:
logger.debug(e)
return None
@@ -229,10 +221,12 @@ def initialization(socket):
return config
-def process_node_data(config, data, last: bool = False) -> int:
+
+def process_node_data(config, data, _last: bool = False) -> tuple[Response, str]:
if not config:
- logger.critical(f"Empty config")
- return R_ERROR_DAEMON
+ out = 'Empty config'
+ logger.critical(out)
+ return Response.ERROR_DAEMON, out
script_name = None
os.environ['VYOS_TAGNODE_VALUE'] = ''
@@ -246,8 +240,9 @@ def process_node_data(config, data, last: bool = False) -> int:
if res.group(2):
script_name = res.group(2)
if not script_name:
- logger.critical(f"Missing script_name")
- return R_ERROR_DAEMON
+ out = 'Missing script_name'
+ logger.critical(out)
+ return Response.ERROR_DAEMON, out
if res.group(3):
args = res.group(3).split()
args.insert(0, f'{script_name}.py')
@@ -259,26 +254,46 @@ def process_node_data(config, data, last: bool = False) -> int:
scripts_called.append(script_record)
if script_name not in include_set:
- return R_PASS
+ return Response.PASS, ''
+
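+    # capture the script's stdout so it can be returned to the client along with any error text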
+ with redirect_stdout(io.StringIO()) as o:
+ result, err_out = run_script(script_name, config, args)
+ amb_out = o.getvalue()
+ o.close()
+
+ out = amb_out + err_out
+
+ return result, out
+
+
+def send_result(sock, err, msg):
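+    # reply frame: 1-byte error code, zero-padded hex message length, message text, sent as a zmq multipart message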
+ err_no = err.value
+ err_name = err.name
+ msg = msg if msg else ''
+ msg_size = min(MAX_MSG_SIZE, len(msg))
- with stdout_redirected(session_out, session_mode):
- result = run_script(script_name, config, args)
+ err_rep = err_no.to_bytes(1)
+ msg_size_rep = f'{msg_size:#0{PAD_MSG_SIZE}x}'
+
+ logger.debug(f'Sending reply: {err_name} with output')
+ sock.send_multipart([err_rep, msg_size_rep.encode(), msg.encode()])
+
+ write_stdout_log(script_stdout_log, msg)
- return result
def remove_if_file(f: str):
try:
os.remove(f)
except FileNotFoundError:
pass
- except OSError:
- raise
+
def shutdown():
remove_if_file(configd_env_file)
os.symlink(configd_env_unset_file, configd_env_file)
sys.exit(0)
+
if __name__ == '__main__':
context = zmq.Context()
socket = context.socket(zmq.REP)
@@ -294,6 +309,7 @@ if __name__ == '__main__':
os.environ['VYOS_CONFIGD'] = 't'
def sig_handler(signum, frame):
+ # pylint: disable=unused-argument
shutdown()
signal.signal(signal.SIGTERM, sig_handler)
@@ -303,25 +319,33 @@ if __name__ == '__main__':
remove_if_file(configd_env_file)
os.symlink(configd_env_set_file, configd_env_file)
- config = None
+ # We only need one long-lived instance of FRRender
+ frr = FRRender()
+ config = None
while True:
# Wait for next request from client
msg = socket.recv().decode()
- logger.debug(f"Received message: {msg}")
+ logger.debug(f'Received message: {msg}')
message = json.loads(msg)
- if message["type"] == "init":
- resp = "init"
+ if message['type'] == 'init':
+ resp = 'init'
socket.send(resp.encode())
config = initialization(socket)
- elif message["type"] == "node":
- res = process_node_data(config, message["data"], message["last"])
- response = res.to_bytes(1, byteorder=sys.byteorder)
- logger.debug(f"Sending response {res}")
- socket.send(response)
- if message["last"] and config:
+ elif message['type'] == 'node':
+ res, out = process_node_data(config, message['data'], message['last'])
+ send_result(socket, res, out)
+
+ if message['last'] and config:
scripts_called = getattr(config, 'scripts_called', [])
logger.debug(f'scripts_called: {scripts_called}')
+
+ if res == Response.SUCCESS:
+ tmp = get_frrender_dict(config)
+ if frr.generate(tmp):
+ # only apply a new FRR configuration if anything changed
+                        # in comparison to the previously applied configuration
+ frr.apply()
else:
- logger.critical(f"Unexpected message: {message}")
+ logger.critical(f'Unexpected message: {message}')
diff --git a/src/services/vyos-conntrack-logger b/src/services/vyos-conntrack-logger
index 9c31b465f..ec0e1f717 100755
--- a/src/services/vyos-conntrack-logger
+++ b/src/services/vyos-conntrack-logger
@@ -15,10 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import argparse
-import grp
import logging
import multiprocessing
-import os
import queue
import signal
import socket
diff --git a/src/services/vyos-domain-resolver b/src/services/vyos-domain-resolver
new file mode 100755
index 000000000..fb18724af
--- /dev/null
+++ b/src/services/vyos-domain-resolver
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2022-2024 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+import json
+import time
+import logging
+import os
+
+from vyos.configdict import dict_merge
+from vyos.configquery import ConfigTreeQuery
+from vyos.firewall import fqdn_config_parse
+from vyos.firewall import fqdn_resolve
+from vyos.ifconfig import WireGuardIf
+from vyos.remote import download
+from vyos.utils.commit import commit_in_progress
+from vyos.utils.dict import dict_search_args
+from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME
+from vyos.utils.file import makedir, chmod_775, write_file, read_file
+from vyos.utils.network import is_valid_ipv4_address_or_range, is_valid_ipv6_address_or_range
+from vyos.utils.process import cmd
+from vyos.utils.process import run
+from vyos.xml_ref import get_defaults
+
+base = ['firewall']
+timeout = 300
+cache = False
+base_firewall = ['firewall']
+base_nat = ['nat']
+base_interfaces = ['interfaces']
+
+firewall_config_dir = "/config/firewall"
+
+domain_state = {}
+
+ipv4_tables = {
+ 'ip vyos_mangle',
+ 'ip vyos_filter',
+ 'ip vyos_nat',
+ 'ip raw'
+}
+
+ipv6_tables = {
+ 'ip6 vyos_mangle',
+ 'ip6 vyos_filter',
+ 'ip6 raw'
+}
+
+logger = logging.getLogger(__name__)
+logs_handler = logging.StreamHandler()
+logger.addHandler(logs_handler)
+logger.setLevel(logging.INFO)
+
+def get_config(conf, node):
+ node_config = conf.get_config_dict(node, key_mangling=('-', '_'), get_first_key=True,
+ no_tag_node_value_mangle=True)
+
+ default_values = get_defaults(node, get_first_key=True)
+
+ node_config = dict_merge(default_values, node_config)
+
+ if node == base_firewall and 'global_options' in node_config:
+ global_config = node_config['global_options']
+ global timeout, cache
+
+ if 'resolver_interval' in global_config:
+ timeout = int(global_config['resolver_interval'])
+
+ if 'resolver_cache' in global_config:
+ cache = True
+
+ fqdn_config_parse(node_config, node[0])
+
+ return node_config
+
+def resolve(domains, ipv6=False):
+ global domain_state
+
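+    # resolve each domain; with the resolver cache enabled, results are remembered and the
+    # last known addresses are reused if a later resolution attempt fails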
+ ip_list = set()
+
+ for domain in domains:
+ resolved = fqdn_resolve(domain, ipv6=ipv6)
+
+ cache_key = f'{domain}_ipv6' if ipv6 else domain
+
+ if resolved and cache:
+ domain_state[cache_key] = resolved
+ elif not resolved:
+ if cache_key not in domain_state:
+ continue
+ resolved = domain_state[cache_key]
+
+ ip_list = ip_list | resolved
+ return ip_list
+
+def nft_output(table, set_name, ip_list):
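+    # emit nft commands that flush the set and re-add the full element list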
+ output = [f'flush set {table} {set_name}']
+ if ip_list:
+ ip_str = ','.join(ip_list)
+ output.append(f'add element {table} {set_name} {{ {ip_str} }}')
+ return output
+
+def nft_valid_sets():
+ try:
+ valid_sets = []
+ sets_json = cmd('nft --json list sets')
+ sets_obj = json.loads(sets_json)
+
+ for obj in sets_obj['nftables']:
+ if 'set' in obj:
+ family = obj['set']['family']
+ table = obj['set']['table']
+ name = obj['set']['name']
+ valid_sets.append((f'{family} {table}', name))
+
+ return valid_sets
+    except Exception:
+ return []
+
+def update_remote_group(config):
+ conf_lines = []
+ count = 0
+ valid_sets = nft_valid_sets()
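+    # only generate updates for sets that currently exist in the nftables ruleset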
+
+ remote_groups = dict_search_args(config, 'group', 'remote_group')
+ if remote_groups:
+ # Create directory for list files if necessary
+ if not os.path.isdir(firewall_config_dir):
+ makedir(firewall_config_dir, group='vyattacfg')
+ chmod_775(firewall_config_dir)
+
+ for set_name, remote_config in remote_groups.items():
+ if 'url' not in remote_config:
+ continue
+ nft_ip_set_name = f'R_{set_name}'
+ nft_ip6_set_name = f'R6_{set_name}'
+
+ # Create list file if necessary
+ list_file = os.path.join(firewall_config_dir, f"{nft_ip_set_name}.txt")
+ if not os.path.exists(list_file):
+ write_file(list_file, '', user="root", group="vyattacfg", mode=0o644)
+
+ # Attempt to download file, use cached version if download fails
+ try:
+ download(list_file, remote_config['url'], raise_error=True)
+            except Exception as e:
+                logger.error(f'Failed to download list-file for {set_name} remote group: {e}')
+ logger.info(f'Using cached list-file for {set_name} remote group')
+
+ # Read list file
+ ip_list = []
+ ip6_list = []
+ invalid_list = []
+ for line in read_file(list_file).splitlines():
+ line_first_word = line.strip().partition(' ')[0]
+
+ if is_valid_ipv4_address_or_range(line_first_word):
+ ip_list.append(line_first_word)
+ elif is_valid_ipv6_address_or_range(line_first_word):
+ ip6_list.append(line_first_word)
+ else:
+                    if line_first_word and line_first_word[0].isalnum():
+ invalid_list.append(line_first_word)
+
+ # Load ip tables
+ for table in ipv4_tables:
+ if (table, nft_ip_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_ip_set_name, ip_list)
+
+ # Load ip6 tables
+ for table in ipv6_tables:
+ if (table, nft_ip6_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_ip6_set_name, ip6_list)
+
+ invalid_str = ", ".join(invalid_list)
+ if invalid_str:
+ logger.info(f'Invalid address for set {set_name}: {invalid_str}')
+
+ count += 1
+
+ nft_conf_str = "\n".join(conf_lines) + "\n"
+    code = run('nft --file -', input=nft_conf_str)
+
+ logger.info(f'Updated {count} remote-groups in firewall - result: {code}')
+
+
+def update_fqdn(config, node):
+ conf_lines = []
+ count = 0
+ valid_sets = nft_valid_sets()
+
+ if node == 'firewall':
+ domain_groups = dict_search_args(config, 'group', 'domain_group')
+ if domain_groups:
+ for set_name, domain_config in domain_groups.items():
+ if 'address' not in domain_config:
+ continue
+ nft_set_name = f'D_{set_name}'
+ domains = domain_config['address']
+
+ ip_list = resolve(domains, ipv6=False)
+ for table in ipv4_tables:
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ ip6_list = resolve(domains, ipv6=True)
+ for table in ipv6_tables:
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip6_list)
+ count += 1
+
+ for set_name, domain in config['ip_fqdn'].items():
+ table = 'ip vyos_filter'
+ nft_set_name = f'FQDN_{set_name}'
+ ip_list = resolve([domain], ipv6=False)
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ count += 1
+
+ for set_name, domain in config['ip6_fqdn'].items():
+ table = 'ip6 vyos_filter'
+ nft_set_name = f'FQDN_{set_name}'
+ ip_list = resolve([domain], ipv6=True)
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ count += 1
+
+ else:
+ # It's NAT
+ for set_name, domain in config['ip_fqdn'].items():
+ table = 'ip vyos_nat'
+ nft_set_name = f'FQDN_nat_{set_name}'
+ ip_list = resolve([domain], ipv6=False)
+ if (table, nft_set_name) in valid_sets:
+ conf_lines += nft_output(table, nft_set_name, ip_list)
+ count += 1
+
+ nft_conf_str = "\n".join(conf_lines) + "\n"
+    code = run('nft --file -', input=nft_conf_str)
+
+ logger.info(f'Updated {count} sets in {node} - result: {code}')
+
+def update_interfaces(config, node):
+ if node == 'interfaces':
+ wg_interfaces = dict_search_args(config, 'wireguard')
+ if wg_interfaces:
+
+ peer_public_keys = {}
+            # for each WireGuard interface
+ for interface, wireguard in wg_interfaces.items():
+ peer_public_keys[interface] = []
+ for peer, peer_config in wireguard['peer'].items():
+                    # check whether the peer has a host-name or address configured
+ if 'host_name' in peer_config or 'address' in peer_config:
+                        # collect the public key so its latest handshake can be checked below
+ peer_public_keys[interface].append(
+ peer_config['public_key']
+ )
+
+ now_time = time.time()
+ for (interface, check_peer_public_keys) in peer_public_keys.items():
+ if len(check_peer_public_keys) == 0:
+ continue
+
+ intf = WireGuardIf(interface, create=False, debug=False)
+ handshakes = intf.operational.get_latest_handshakes()
+
+ # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME
+ # if data is being transmitted between the peers. If no data is
+ # transmitted, the handshake will not be initiated unless new
+ # data begins to flow. Each handshake generates a new session
+ # key, and the key is rotated at least every 120 seconds or
+ # upon data transmission after a prolonged silence.
+ for public_key, handshake_time in handshakes.items():
+ if public_key in check_peer_public_keys and (
+ handshake_time == 0
+ or (now_time - handshake_time > 3*WIREGUARD_REKEY_AFTER_TIME)
+ ):
+ intf.operational.reset_peer(public_key=public_key)
+
+if __name__ == '__main__':
+ logger.info('VyOS domain resolver')
+
+ count = 1
+ while commit_in_progress():
+        if count % 60 == 0:
+ logger.info(f'Commit still in progress after {count}s - waiting')
+ count += 1
+ time.sleep(1)
+
+ conf = ConfigTreeQuery()
+ firewall = get_config(conf, base_firewall)
+ nat = get_config(conf, base_nat)
+ interfaces = get_config(conf, base_interfaces)
+
+ logger.info(f'interval: {timeout}s - cache: {cache}')
+
+ while True:
+ update_fqdn(firewall, 'firewall')
+ update_fqdn(nat, 'nat')
+ update_remote_group(firewall)
+ update_interfaces(interfaces, 'interfaces')
+ time.sleep(timeout)
diff --git a/src/services/vyos-hostsd b/src/services/vyos-hostsd
index 1ba90471e..44f03586c 100755
--- a/src/services/vyos-hostsd
+++ b/src/services/vyos-hostsd
@@ -233,10 +233,7 @@
# }
import os
-import sys
-import time
import json
-import signal
import traceback
import re
import logging
@@ -245,7 +242,6 @@ import zmq
from voluptuous import Schema, MultipleInvalid, Required, Any
from collections import OrderedDict
from vyos.utils.file import makedir
-from vyos.utils.permission import chown
from vyos.utils.permission import chmod_755
from vyos.utils.process import popen
from vyos.utils.process import process_named_running
diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server
index 97633577d..be3dd5051 100755
--- a/src/services/vyos-http-api-server
+++ b/src/services/vyos-http-api-server
@@ -17,946 +17,135 @@
import os
import sys
import grp
-import copy
import json
import logging
import signal
import traceback
-import threading
-from enum import Enum
-
from time import sleep
-from typing import List, Union, Callable, Dict, Self
+from typing import Annotated
-from fastapi import FastAPI, Depends, Request, Response, HTTPException
-from fastapi import BackgroundTasks
-from fastapi.responses import HTMLResponse
+from fastapi import FastAPI, Query
from fastapi.exceptions import RequestValidationError
-from fastapi.routing import APIRoute
-from pydantic import BaseModel, StrictStr, validator, model_validator
-from starlette.middleware.cors import CORSMiddleware
-from starlette.datastructures import FormData
-from starlette.formparsers import FormParser, MultiPartParser
-from multipart.multipart import parse_options_header
from uvicorn import Config as UvicornConfig
from uvicorn import Server as UvicornServer
-from ariadne.asgi import GraphQL
-
-from vyos.config import Config
-from vyos.configtree import ConfigTree
-from vyos.configdiff import get_config_diff
from vyos.configsession import ConfigSession
-from vyos.configsession import ConfigSessionError
from vyos.defaults import api_config_state
+from vyos.utils.file import read_file
+from vyos.version import get_version
-import api.graphql.state
+from api.session import SessionState
+from api.rest.models import error, InfoQueryParams, success
CFG_GROUP = 'vyattacfg'
debug = True
-logger = logging.getLogger(__name__)
+LOG = logging.getLogger('http_api')
logs_handler = logging.StreamHandler()
-logger.addHandler(logs_handler)
+LOG.addHandler(logs_handler)
if debug:
- logger.setLevel(logging.DEBUG)
+ LOG.setLevel(logging.DEBUG)
else:
- logger.setLevel(logging.INFO)
+ LOG.setLevel(logging.INFO)
-# Giant lock!
-lock = threading.Lock()
def load_server_config():
with open(api_config_state) as f:
config = json.load(f)
return config
-def check_auth(key_list, key):
- key_id = None
- for k in key_list:
- if k['key'] == key:
- key_id = k['id']
- return key_id
-
-def error(code, msg):
- resp = {"success": False, "error": msg, "data": None}
- resp = json.dumps(resp)
- return HTMLResponse(resp, status_code=code)
-
-def success(data):
- resp = {"success": True, "data": data, "error": None}
- resp = json.dumps(resp)
- return HTMLResponse(resp)
-
-# Pydantic models for validation
-# Pydantic will cast when possible, so use StrictStr
-# validators added as needed for additional constraints
-# schema_extra adds anotations to OpenAPI, to add examples
-
-class ApiModel(BaseModel):
- key: StrictStr
-
-class BasePathModel(BaseModel):
- op: StrictStr
- path: List[StrictStr]
-
- @validator("path")
- def check_non_empty(cls, path):
- if not len(path) > 0:
- raise ValueError('path must be non-empty')
- return path
-
-class BaseConfigureModel(BasePathModel):
- value: StrictStr = None
-
-class ConfigureModel(ApiModel, BaseConfigureModel):
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "set | delete | comment",
- "path": ['config', 'mode', 'path'],
- }
- }
-
-class ConfigureListModel(ApiModel):
- commands: List[BaseConfigureModel]
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "commands": "list of commands",
- }
- }
-
-class BaseConfigSectionModel(BasePathModel):
- section: Dict
-
-class ConfigSectionModel(ApiModel, BaseConfigSectionModel):
- pass
-
-class ConfigSectionListModel(ApiModel):
- commands: List[BaseConfigSectionModel]
-
-class BaseConfigSectionTreeModel(BaseModel):
- op: StrictStr
- mask: Dict
- config: Dict
-
-class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel):
- pass
-
-class RetrieveModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
- configFormat: StrictStr = None
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "returnValue | returnValues | exists | showConfig",
- "path": ['config', 'mode', 'path'],
- "configFormat": "json (default) | json_ast | raw",
-
- }
- }
-
-class ConfigFileModel(ApiModel):
- op: StrictStr
- file: StrictStr = None
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "save | load",
- "file": "filename",
- }
- }
-
-
-class ImageOp(str, Enum):
- add = "add"
- delete = "delete"
- show = "show"
- set_default = "set_default"
-
-
-class ImageModel(ApiModel):
- op: ImageOp
- url: StrictStr = None
- name: StrictStr = None
-
- @model_validator(mode='after')
- def check_data(self) -> Self:
- if self.op == 'add':
- if not self.url:
- raise ValueError("Missing required field \"url\"")
- elif self.op in ['delete', 'set_default']:
- if not self.name:
- raise ValueError("Missing required field \"name\"")
-
- return self
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "add | delete | show | set_default",
- "url": "imagelocation",
- "name": "imagename",
- }
- }
-
-class ImportPkiModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
- passphrase: StrictStr = None
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "import_pki",
- "path": ["op", "mode", "path"],
- "passphrase": "passphrase",
- }
- }
-
-
-class ContainerImageModel(ApiModel):
- op: StrictStr
- name: StrictStr = None
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "add | delete | show",
- "name": "imagename",
- }
- }
-
-class GenerateModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "generate",
- "path": ["op", "mode", "path"],
- }
- }
-
-class ShowModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "show",
- "path": ["op", "mode", "path"],
- }
- }
-
-class RebootModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "reboot",
- "path": ["op", "mode", "path"],
- }
- }
-
-class ResetModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "reset",
- "path": ["op", "mode", "path"],
- }
- }
-
-class PoweroffModel(ApiModel):
- op: StrictStr
- path: List[StrictStr]
-
- class Config:
- schema_extra = {
- "example": {
- "key": "id_key",
- "op": "poweroff",
- "path": ["op", "mode", "path"],
- }
- }
-
-
-class Success(BaseModel):
- success: bool
- data: Union[str, bool, Dict]
- error: str
-
-class Error(BaseModel):
- success: bool = False
- data: Union[str, bool, Dict]
- error: str
-
-responses = {
- 200: {'model': Success},
- 400: {'model': Error},
- 422: {'model': Error, 'description': 'Validation Error'},
- 500: {'model': Error}
-}
-
-def auth_required(data: ApiModel):
- key = data.key
- api_keys = app.state.vyos_keys
- key_id = check_auth(api_keys, key)
- if not key_id:
- raise HTTPException(status_code=401, detail="Valid API key is required")
- app.state.vyos_id = key_id
-
-# override Request and APIRoute classes in order to convert form request to json;
-# do all explicit validation here, for backwards compatability of error messages;
-# the explicit validation may be dropped, if desired, in favor of native
-# validation by FastAPI/Pydantic, as is used for application/json requests
-class MultipartRequest(Request):
- _form_err = ()
- @property
- def form_err(self):
- return self._form_err
-
- @form_err.setter
- def form_err(self, val):
- if not self._form_err:
- self._form_err = val
-
- @property
- def orig_headers(self):
- self._orig_headers = super().headers
- return self._orig_headers
-
- @property
- def headers(self):
- self._headers = super().headers.mutablecopy()
- self._headers['content-type'] = 'application/json'
- return self._headers
-
- async def form(self) -> FormData:
- if self._form is None:
- assert (
- parse_options_header is not None
- ), "The `python-multipart` library must be installed to use form parsing."
- content_type_header = self.orig_headers.get("Content-Type")
- content_type, options = parse_options_header(content_type_header)
- if content_type == b"multipart/form-data":
- multipart_parser = MultiPartParser(self.orig_headers, self.stream())
- self._form = await multipart_parser.parse()
- elif content_type == b"application/x-www-form-urlencoded":
- form_parser = FormParser(self.orig_headers, self.stream())
- self._form = await form_parser.parse()
- else:
- self._form = FormData()
- return self._form
-
- async def body(self) -> bytes:
- if not hasattr(self, "_body"):
- forms = {}
- merge = {}
- body = await super().body()
- self._body = body
-
- form_data = await self.form()
- if form_data:
- endpoint = self.url.path
- logger.debug("processing form data")
- for k, v in form_data.multi_items():
- forms[k] = v
-
- if 'data' not in forms:
- self.form_err = (422, "Non-empty data field is required")
- return self._body
- else:
- try:
- tmp = json.loads(forms['data'])
- except json.JSONDecodeError as e:
- self.form_err = (400, f'Failed to parse JSON: {e}')
- return self._body
- if isinstance(tmp, list):
- merge['commands'] = tmp
- else:
- merge = tmp
-
- if 'commands' in merge:
- cmds = merge['commands']
- else:
- cmds = copy.deepcopy(merge)
- cmds = [cmds]
-
- for c in cmds:
- if not isinstance(c, dict):
- self.form_err = (400,
- f"Malformed command '{c}': any command must be JSON of dict")
- return self._body
- if 'op' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'op' field")
- if endpoint not in ('/config-file', '/container-image',
- '/image', '/configure-section'):
- if 'path' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'path' field")
- elif not isinstance(c['path'], list):
- self.form_err = (400,
- f"Malformed command '{c}': 'path' field must be a list")
- elif not all(isinstance(el, str) for el in c['path']):
- self.form_err = (400,
- f"Malformed command '{0}': 'path' field must be a list of strings")
- if endpoint in ('/configure'):
- if not c['path']:
- self.form_err = (400,
- f"Malformed command '{c}': 'path' list must be non-empty")
- if 'value' in c and not isinstance(c['value'], str):
- self.form_err = (400,
- f"Malformed command '{c}': 'value' field must be a string")
- if endpoint in ('/configure-section'):
- if 'section' not in c and 'config' not in c:
- self.form_err = (400,
- f"Malformed command '{c}': missing 'section' or 'config' field")
-
- if 'key' not in forms and 'key' not in merge:
- self.form_err = (401, "Valid API key is required")
- if 'key' in forms and 'key' not in merge:
- merge['key'] = forms['key']
-
- new_body = json.dumps(merge)
- new_body = new_body.encode()
- self._body = new_body
-
- return self._body
-
-class MultipartRoute(APIRoute):
- def get_route_handler(self) -> Callable:
- original_route_handler = super().get_route_handler()
-
- async def custom_route_handler(request: Request) -> Response:
- request = MultipartRequest(request.scope, request.receive)
- try:
- response: Response = await original_route_handler(request)
- except HTTPException as e:
- return error(e.status_code, e.detail)
- except Exception as e:
- form_err = request.form_err
- if form_err:
- return error(*form_err)
- raise e
-
- return response
-
- return custom_route_handler
app = FastAPI(debug=True,
title="VyOS API",
- version="0.1.0",
- responses={**responses},
- dependencies=[Depends(auth_required)])
+ version="0.1.0")
-app.router.route_class = MultipartRoute
@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(request, exc):
+async def validation_exception_handler(_request, exc):
return error(400, str(exc.errors()[0]))
-self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background"
-
-def call_commit(s: ConfigSession):
- try:
- s.commit()
- except ConfigSessionError as e:
- s.discard()
- if app.state.vyos_debug:
- logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}")
- else:
- logger.warning(f"ConfigSessionError: {e}")
-
-def _configure_op(data: Union[ConfigureModel, ConfigureListModel,
- ConfigSectionModel, ConfigSectionListModel,
- ConfigSectionTreeModel],
- request: Request, background_tasks: BackgroundTasks):
- session = app.state.vyos_session
- env = session.get_session_env()
-
- endpoint = request.url.path
-
- # Allow users to pass just one command
- if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)):
- data = [data]
- else:
- data = data.commands
-
- # We don't want multiple people/apps to be able to commit at once,
- # or modify the shared session while someone else is doing the same,
- # so the lock is really global
- lock.acquire()
-
- config = Config(session_env=env)
-
- status = 200
- msg = None
- error_msg = None
- try:
- for c in data:
- op = c.op
- if not isinstance(c, BaseConfigSectionTreeModel):
- path = c.path
-
- if isinstance(c, BaseConfigureModel):
- if c.value:
- value = c.value
- else:
- value = ""
- # For vyos.configsession calls that have no separate value arguments,
- # and for type checking too
- cfg_path = " ".join(path + [value]).strip()
-
- elif isinstance(c, BaseConfigSectionModel):
- section = c.section
-
- elif isinstance(c, BaseConfigSectionTreeModel):
- mask = c.mask
- config = c.config
-
- if isinstance(c, BaseConfigureModel):
- if op == 'set':
- session.set(path, value=value)
- elif op == 'delete':
- if app.state.vyos_strict and not config.exists(cfg_path):
- raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist")
- session.delete(path, value=value)
- elif op == 'comment':
- session.comment(path, value=value)
- else:
- raise ConfigSessionError(f"'{op}' is not a valid operation")
-
- elif isinstance(c, BaseConfigSectionModel):
- if op == 'set':
- session.set_section(path, section)
- elif op == 'load':
- session.load_section(path, section)
- else:
- raise ConfigSessionError(f"'{op}' is not a valid operation")
-
- elif isinstance(c, BaseConfigSectionTreeModel):
- if op == 'set':
- session.set_section_tree(config)
- elif op == 'load':
- session.load_section_tree(mask, config)
- else:
- raise ConfigSessionError(f"'{op}' is not a valid operation")
- # end for
- config = Config(session_env=env)
- d = get_config_diff(config)
-
- if d.is_node_changed(['service', 'https']):
- background_tasks.add_task(call_commit, session)
- msg = self_ref_msg
- else:
- session.commit()
-
- logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'")
- except ConfigSessionError as e:
- session.discard()
- status = 400
- if app.state.vyos_debug:
- logger.critical(f"ConfigSessionError:\n {traceback.format_exc()}")
- error_msg = str(e)
- except Exception as e:
- session.discard()
- logger.critical(traceback.format_exc())
- status = 500
-
- # Don't give the details away to the outer world
- error_msg = "An internal error occured. Check the logs for details."
- finally:
- lock.release()
-
- if status != 200:
- return error(status, error_msg)
-
- return success(msg)
-
-def create_path_import_pki_no_prompt(path):
- correct_paths = ['ca', 'certificate', 'key-pair']
- if path[1] not in correct_paths:
- return False
- path[1] = '--' + path[1].replace('-', '')
- path[3] = '--key-filename'
- return path[1:]
-
-@app.post('/configure')
-def configure_op(data: Union[ConfigureModel,
- ConfigureListModel],
- request: Request, background_tasks: BackgroundTasks):
- return _configure_op(data, request, background_tasks)
-
-@app.post('/configure-section')
-def configure_section_op(data: Union[ConfigSectionModel,
- ConfigSectionListModel,
- ConfigSectionTreeModel],
- request: Request, background_tasks: BackgroundTasks):
- return _configure_op(data, request, background_tasks)
-@app.post("/retrieve")
-async def retrieve_op(data: RetrieveModel):
- session = app.state.vyos_session
- env = session.get_session_env()
- config = Config(session_env=env)
+@app.get('/info')
+def info(q: Annotated[InfoQueryParams, Query()]):
+ show_version = q.version
+ show_hostname = q.hostname
- op = data.op
- path = " ".join(data.path)
+ prelogin_file = r'/etc/issue'
+ hostname_file = r'/etc/hostname'
+ default = 'Welcome to VyOS'
try:
- if op == 'returnValue':
- res = config.return_value(path)
- elif op == 'returnValues':
- res = config.return_values(path)
- elif op == 'exists':
- res = config.exists(path)
- elif op == 'showConfig':
- config_format = 'json'
- if data.configFormat:
- config_format = data.configFormat
-
- res = session.show_config(path=data.path)
- if config_format == 'json':
- config_tree = ConfigTree(res)
- res = json.loads(config_tree.to_json())
- elif config_format == 'json_ast':
- config_tree = ConfigTree(res)
- res = json.loads(config_tree.to_json_ast())
- elif config_format == 'raw':
- pass
- else:
- return error(400, f"'{config_format}' is not a valid config format")
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/config-file')
-def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):
- session = app.state.vyos_session
- env = session.get_session_env()
- op = data.op
- msg = None
-
- try:
- if op == 'save':
- if data.file:
- path = data.file
- else:
- path = '/config/config.boot'
- msg = session.save_config(path)
- elif op == 'load':
- if data.file:
- path = data.file
- else:
- return error(400, "Missing required field \"file\"")
-
- session.migrate_and_load_config(path)
-
- config = Config(session_env=env)
- d = get_config_diff(config)
-
- if d.is_node_changed(['service', 'https']):
- background_tasks.add_task(call_commit, session)
- msg = self_ref_msg
- else:
- session.commit()
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(msg)
-
-@app.post('/image')
-def image_op(data: ImageModel):
- session = app.state.vyos_session
-
- op = data.op
-
- try:
- if op == 'add':
- res = session.install_image(data.url)
- elif op == 'delete':
- res = session.remove_image(data.name)
- elif op == 'show':
- res = session.show(["system", "image"])
- elif op == 'set_default':
- res = session.set_default_image(data.name)
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/container-image')
-def container_image_op(data: ContainerImageModel):
- session = app.state.vyos_session
-
- op = data.op
-
- try:
- if op == 'add':
- if data.name:
- name = data.name
- else:
- return error(400, "Missing required field \"name\"")
- res = session.add_container_image(name)
- elif op == 'delete':
- if data.name:
- name = data.name
- else:
- return error(400, "Missing required field \"name\"")
- res = session.delete_container_image(name)
- elif op == 'show':
- res = session.show_container_image()
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/generate')
-def generate_op(data: GenerateModel):
- session = app.state.vyos_session
-
- op = data.op
- path = data.path
-
- try:
- if op == 'generate':
- res = session.generate(path)
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/show')
-def show_op(data: ShowModel):
- session = app.state.vyos_session
-
- op = data.op
- path = data.path
-
- try:
- if op == 'show':
- res = session.show(path)
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/reboot')
-def reboot_op(data: RebootModel):
- session = app.state.vyos_session
-
- op = data.op
- path = data.path
-
- try:
- if op == 'reboot':
- res = session.reboot(path)
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/reset')
-def reset_op(data: ResetModel):
- session = app.state.vyos_session
-
- op = data.op
- path = data.path
-
- try:
- if op == 'reset':
- res = session.reset(path)
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
-
- return success(res)
-
-@app.post('/import-pki')
-def import_pki(data: ImportPkiModel):
- session = app.state.vyos_session
-
- op = data.op
- path = data.path
-
- lock.acquire()
-
- try:
- if op == 'import-pki':
- # need to get rid or interactive mode for private key
- if len(path) == 5 and path[3] in ['key-file', 'private-key']:
- path_no_prompt = create_path_import_pki_no_prompt(path)
- if not path_no_prompt:
- return error(400, f"Invalid command: {' '.join(path)}")
- if data.passphrase:
- path_no_prompt += ['--passphrase', data.passphrase]
- res = session.import_pki_no_prompt(path_no_prompt)
- else:
- res = session.import_pki(path)
- if not res[0].isdigit():
- return error(400, res)
- # commit changes
- session.commit()
- res = res.split('. ')[0]
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
- finally:
- lock.release()
+ res = {
+ 'banner': '',
+ 'hostname': '',
+ 'version': ''
+ }
+ if show_version:
+ res.update(version=get_version())
- return success(res)
+ if show_hostname:
+ try:
+ hostname = read_file(hostname_file)
+ except Exception:
+ hostname = 'vyos'
+ res.update(hostname=hostname)
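+        # the stock /etc/issue ends with agetty escape codes ('\n \l'); strip them and return only the greeting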
-@app.post('/poweroff')
-def poweroff_op(data: PoweroffModel):
- session = app.state.vyos_session
+ banner = read_file(prelogin_file, defaultonfailure=default)
+ if banner == f'{default} - \\n \\l':
+ banner = banner.partition(default)[1]
- op = data.op
- path = data.path
-
- try:
- if op == 'poweroff':
- res = session.poweroff(path)
- else:
- return error(400, f"'{op}' is not a valid operation")
- except ConfigSessionError as e:
- return error(400, str(e))
- except Exception as e:
- logger.critical(traceback.format_exc())
- return error(500, "An internal error occured. Check the logs for details.")
+ res.update(banner=banner)
+ except Exception:
+ LOG.critical(traceback.format_exc())
+        return error(500, 'An internal error occurred. Check the logs for details.')
return success(res)
###
-# GraphQL integration
-###
-
-def graphql_init(app: FastAPI = app):
- from api.graphql.libs.token_auth import get_user_context
- api.graphql.state.init()
- api.graphql.state.settings['app'] = app
-
- # import after initializaion of state
- from api.graphql.bindings import generate_schema
- schema = generate_schema()
-
- in_spec = app.state.vyos_introspection
-
- if app.state.vyos_origins:
- origins = app.state.vyos_origins
- app.add_route('/graphql', CORSMiddleware(GraphQL(schema,
- context_value=get_user_context,
- debug=True,
- introspection=in_spec),
- allow_origins=origins,
- allow_methods=("GET", "POST", "OPTIONS"),
- allow_headers=("Authorization",)))
- else:
- app.add_route('/graphql', GraphQL(schema,
- context_value=get_user_context,
- debug=True,
- introspection=in_spec))
-###
# Modify uvicorn to allow reloading server within the configsession
###
server = None
shutdown = False
+
class ApiServerConfig(UvicornConfig):
pass
+
class ApiServer(UvicornServer):
def install_signal_handlers(self):
pass
+
def reload_handler(signum, frame):
+ # pylint: disable=global-statement
+
global server
- logger.debug('Reload signal received...')
+ LOG.debug('Reload signal received...')
if server is not None:
server.handle_exit(signum, frame)
server = None
- logger.info('Server stopping for reload...')
+ LOG.info('Server stopping for reload...')
else:
- logger.warning('Reload called for non-running server...')
+ LOG.warning('Reload called for non-running server...')
+
def shutdown_handler(signum, frame):
+ # pylint: disable=global-statement
+
global shutdown
- logger.debug('Shutdown signal received...')
+ LOG.debug('Shutdown signal received...')
server.handle_exit(signum, frame)
- logger.info('Server shutdown...')
+ LOG.info('Server shutdown...')
shutdown = True
+# end modify uvicorn
+
+
def flatten_keys(d: dict) -> list[dict]:
keys_list = []
for el in list(d['keys'].get('id', {})):
@@ -965,49 +154,87 @@ def flatten_keys(d: dict) -> list[dict]:
keys_list.append({'id': el, 'key': key})
return keys_list
-def initialization(session: ConfigSession, app: FastAPI = app):
+
+def regenerate_docs(app: FastAPI) -> None:
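+    # drop the auto-generated OpenAPI/docs routes and let FastAPI rebuild them to reflect the currently mounted routers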
+ docs = ('/openapi.json', '/docs', '/docs/oauth2-redirect', '/redoc')
+ remove = []
+ for r in app.routes:
+ if r.path in docs:
+ remove.append(r)
+ for r in remove:
+ app.routes.remove(r)
+
+ app.openapi_schema = None
+ app.setup()
+
+
+def initialization(session: SessionState, app: FastAPI = app):
+ # pylint: disable=global-statement,broad-exception-caught,import-outside-toplevel
+
global server
try:
server_config = load_server_config()
except Exception as e:
- logger.critical(f'Failed to load the HTTP API server config: {e}')
+ LOG.critical(f'Failed to load the HTTP API server config: {e}')
sys.exit(1)
- app.state.vyos_session = session
- app.state.vyos_keys = []
-
if 'keys' in server_config:
- app.state.vyos_keys = flatten_keys(server_config)
+ session.keys = flatten_keys(server_config)
+
+ rest_config = server_config.get('rest', {})
+ session.debug = bool('debug' in rest_config)
+ session.strict = bool('strict' in rest_config)
+
+ graphql_config = server_config.get('graphql', {})
+ session.origins = graphql_config.get('cors', {}).get('allow_origin', [])
+
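+    # the presence of 'rest'/'graphql' sections in the server config decides which routers are mounted below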
+ if 'rest' in server_config:
+ session.rest = True
+ else:
+ session.rest = False
- app.state.vyos_debug = bool('debug' in server_config)
- app.state.vyos_strict = bool('strict' in server_config)
- app.state.vyos_origins = server_config.get('cors', {}).get('allow_origin', [])
if 'graphql' in server_config:
- app.state.vyos_graphql = True
+ session.graphql = True
if isinstance(server_config['graphql'], dict):
if 'introspection' in server_config['graphql']:
- app.state.vyos_introspection = True
+ session.introspection = True
else:
- app.state.vyos_introspection = False
+ session.introspection = False
# default values if not set explicitly
- app.state.vyos_auth_type = server_config['graphql']['authentication']['type']
- app.state.vyos_token_exp = server_config['graphql']['authentication']['expiration']
- app.state.vyos_secret_len = server_config['graphql']['authentication']['secret_length']
+ session.auth_type = server_config['graphql']['authentication']['type']
+ session.token_exp = server_config['graphql']['authentication']['expiration']
+ session.secret_len = server_config['graphql']['authentication']['secret_length']
else:
- app.state.vyos_graphql = False
+ session.graphql = False
+
+ # pass session state
+ app.state = session
- if app.state.vyos_graphql:
+ # add REST routes
+ if session.rest:
+ from api.rest.routers import rest_init
+ rest_init(app)
+ else:
+ from api.rest.routers import rest_clear
+ rest_clear(app)
+
+ # add GraphQL route
+ if session.graphql:
+ from api.graphql.routers import graphql_init
graphql_init(app)
+ else:
+ from api.graphql.routers import graphql_clear
+ graphql_clear(app)
+
+ regenerate_docs(app)
+
+ LOG.debug('Active routes are:')
+ for r in app.routes:
+ LOG.debug(f'{r.path}')
config = ApiServerConfig(app, uds="/run/api.sock", proxy_headers=True)
server = ApiServer(config)
-def run_server():
- try:
- server.run()
- except OSError as e:
- logger.critical(e)
- sys.exit(1)
if __name__ == '__main__':
# systemd's user and group options don't work, do it by hand here,
@@ -1022,13 +249,14 @@ if __name__ == '__main__':
signal.signal(signal.SIGHUP, reload_handler)
signal.signal(signal.SIGTERM, shutdown_handler)
- config_session = ConfigSession(os.getpid())
+ session_state = SessionState()
+ session_state.session = ConfigSession(os.getpid())
while True:
- logger.debug('Enter main loop...')
+ LOG.debug('Enter main loop...')
if shutdown:
break
if server is None:
- initialization(config_session)
+ initialization(session_state)
server.run()
sleep(1)
diff --git a/src/services/vyos-network-event-logger b/src/services/vyos-network-event-logger
new file mode 100644
index 000000000..840ff3cda
--- /dev/null
+++ b/src/services/vyos-network-event-logger
@@ -0,0 +1,1218 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2025 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import argparse
+import logging
+import multiprocessing
+import queue
+import signal
+import socket
+import threading
+from pathlib import Path
+from time import sleep
+from typing import Dict, AnyStr, List, Union
+
+from pyroute2.common import AF_MPLS
+from pyroute2.iproute import IPRoute
+from pyroute2.netlink import rtnl, nlmsg
+from pyroute2.netlink.nfnetlink.nfctsocket import nfct_msg
+from pyroute2.netlink.rtnl import (rt_proto as RT_PROTO, rt_type as RT_TYPES,
+ rtypes as RTYPES
+ )
+from pyroute2.netlink.rtnl.fibmsg import FR_ACT_GOTO, FR_ACT_NOP, FR_ACT_TO_TBL, \
+ fibmsg
+from pyroute2.netlink.rtnl import ifaddrmsg
+from pyroute2.netlink.rtnl import ifinfmsg
+from pyroute2.netlink.rtnl import ndmsg
+from pyroute2.netlink.rtnl import rtmsg
+from pyroute2.netlink.rtnl.rtmsg import nh, rtmsg_base
+
+from vyos.include.uapi.linux.fib_rules import *
+from vyos.include.uapi.linux.icmpv6 import *
+from vyos.include.uapi.linux.if_arp import *
+from vyos.include.uapi.linux.lwtunnel import *
+from vyos.include.uapi.linux.neighbour import *
+from vyos.include.uapi.linux.rtnetlink import *
+
+from vyos.utils.file import read_json
+
+
+manager = multiprocessing.Manager()
+cache = manager.dict()
+
+
+class UnsupportedMessageType(Exception):
+ pass
+
+shutdown_event = multiprocessing.Event()
+
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+logger = logging.getLogger(__name__)
+
+
+class DebugFormatter(logging.Formatter):
+ def format(self, record):
+ self._style._fmt = '[%(asctime)s] %(levelname)s: %(message)s'
+ return super().format(record)
+
+
+def set_log_level(level: str) -> None:
+ if level == 'debug':
+ logger.setLevel(logging.DEBUG)
+ logger.parent.handlers[0].setFormatter(DebugFormatter())
+ else:
+ logger.setLevel(logging.INFO)
+
+IFF_FLAGS = {
+ 'RUNNING': ifinfmsg.IFF_RUNNING,
+ 'LOOPBACK': ifinfmsg.IFF_LOOPBACK,
+ 'BROADCAST': ifinfmsg.IFF_BROADCAST,
+ 'POINTOPOINT': ifinfmsg.IFF_POINTOPOINT,
+ 'MULTICAST': ifinfmsg.IFF_MULTICAST,
+ 'NOARP': ifinfmsg.IFF_NOARP,
+ 'ALLMULTI': ifinfmsg.IFF_ALLMULTI,
+ 'PROMISC': ifinfmsg.IFF_PROMISC,
+ 'MASTER': ifinfmsg.IFF_MASTER,
+ 'SLAVE': ifinfmsg.IFF_SLAVE,
+ 'DEBUG': ifinfmsg.IFF_DEBUG,
+ 'DYNAMIC': ifinfmsg.IFF_DYNAMIC,
+ 'AUTOMEDIA': ifinfmsg.IFF_AUTOMEDIA,
+ 'PORTSEL': ifinfmsg.IFF_PORTSEL,
+ 'NOTRAILERS': ifinfmsg.IFF_NOTRAILERS,
+ 'UP': ifinfmsg.IFF_UP,
+ 'LOWER_UP': ifinfmsg.IFF_LOWER_UP,
+ 'DORMANT': ifinfmsg.IFF_DORMANT,
+ 'ECHO': ifinfmsg.IFF_ECHO,
+}
+
+NEIGH_STATE_FLAGS = {
+ 'INCOMPLETE': ndmsg.NUD_INCOMPLETE,
+ 'REACHABLE': ndmsg.NUD_REACHABLE,
+ 'STALE': ndmsg.NUD_STALE,
+ 'DELAY': ndmsg.NUD_DELAY,
+ 'PROBE': ndmsg.NUD_PROBE,
+ 'FAILED': ndmsg.NUD_FAILED,
+ 'NOARP': ndmsg.NUD_NOARP,
+ 'PERMANENT': ndmsg.NUD_PERMANENT,
+}
+
+IFA_FLAGS = {
+ 'secondary': ifaddrmsg.IFA_F_SECONDARY,
+ 'temporary': ifaddrmsg.IFA_F_SECONDARY,
+ 'nodad': ifaddrmsg.IFA_F_NODAD,
+ 'optimistic': ifaddrmsg.IFA_F_OPTIMISTIC,
+ 'dadfailed': ifaddrmsg.IFA_F_DADFAILED,
+ 'home': ifaddrmsg.IFA_F_HOMEADDRESS,
+ 'deprecated': ifaddrmsg.IFA_F_DEPRECATED,
+ 'tentative': ifaddrmsg.IFA_F_TENTATIVE,
+ 'permanent': ifaddrmsg.IFA_F_PERMANENT,
+ 'mngtmpaddr': ifaddrmsg.IFA_F_MANAGETEMPADDR,
+ 'noprefixroute': ifaddrmsg.IFA_F_NOPREFIXROUTE,
+ 'autojoin': ifaddrmsg.IFA_F_MCAUTOJOIN,
+ 'stable-privacy': ifaddrmsg.IFA_F_STABLE_PRIVACY,
+}
+
+RT_SCOPE_TO_NAME = {
+ rtmsg.RT_SCOPE_UNIVERSE: 'global',
+ rtmsg.RT_SCOPE_SITE: 'site',
+ rtmsg.RT_SCOPE_LINK: 'link',
+ rtmsg.RT_SCOPE_HOST: 'host',
+ rtmsg.RT_SCOPE_NOWHERE: 'nowhere',
+}
+
+FAMILY_TO_NAME = {
+ socket.AF_INET: 'inet',
+ socket.AF_INET6: 'inet6',
+ socket.AF_PACKET: 'link',
+ AF_MPLS: 'mpls',
+ socket.AF_BRIDGE: 'bridge',
+}
+
+_INFINITY = 4294967295
+
+
+def _get_iif_name(idx: int) -> str:
+ """
+ Retrieves the interface name associated with a given index.
+ """
+ try:
+ if_info = IPRoute().link("get", index=idx)
+ if if_info:
+ return if_info[0].get_attr('IFLA_IFNAME')
+    except Exception:
+        pass
+
+ return ''
+
+
+def remember_if_index(idx: int, event_type: int) -> None:
+ """
+ Manages the caching of network interface names based on their index and event type.
+
+    - For an RTM_DELLINK event, the interface name is removed from the cache if it exists.
+    - For an RTM_NEWLINK event, the interface name is retrieved and updated in the cache.
+ """
+ name = cache.get(idx)
+ if name:
+ if event_type == rtnl.RTM_DELLINK:
+ del cache[idx]
+ else:
+ name = _get_iif_name(idx)
+ if name:
+ cache[idx] = name
+ else:
+ cache[idx] = _get_iif_name(idx)
+
+
+class BaseFormatter:
+ """
+ A base class providing utility methods for formatting network message data.
+ """
+ def _get_if_name_by_index(self, idx: int) -> str:
+ """
+ Retrieves the name of a network interface based on its index.
+
+ Uses a cached lookup for efficiency. If the name is not found in the cache,
+ it queries the system and updates the cache.
+ """
+ if_name = cache.get(idx)
+ if not if_name:
+ if_name = _get_iif_name(idx)
+ cache[idx] = if_name
+
+ return if_name
+
+ def _format_rttable(self, idx: int) -> str:
+ """
+ Formats a route table identifier into a readable name.
+ """
+ return f'{RT_TABLE_TO_NAME.get(idx, idx)}'
+
+ def _parse_flag(self, data: int, flags: dict) -> list:
+ """
+        Extracts and returns the flag names corresponding to the bits set in a numeric value.
+ """
+ result = list()
+ if data:
+ for key, val in flags.items():
+ if data & val:
+ result.append(key)
+ data &= ~val
+
+ if data:
+ result.append(f"{data:#x}")
+
+ return result
+
+ def af_bit_len(self, af: int) -> int:
+ """
+ Gets the bit length of a given address family.
+ Supports common address families like IPv4, IPv6, and MPLS.
+ """
+ _map = {
+ socket.AF_INET6: 128,
+ socket.AF_INET: 32,
+ AF_MPLS: 20,
+ }
+
+ return _map.get(af)
+
+ def _format_simple_field(self, data: str, prefix: str='') -> str:
+ """
+ Formats a simple field with an optional prefix.
+
+ A simple field represents a value that does not require additional
+ parsing and is used as is.
+ """
+ return self._output(f'{prefix} {data}') if data is not None else ''
+
+ def _output(self, data: str) -> str:
+ """
+ Standardizes the output format.
+
+ Ensures that the output is enclosed with single spaces and has no leading
+ or trailing whitespace.
+ """
+ return f' {data.strip()} ' if data else ''
+
+
+class BaseMSGFormatter(BaseFormatter):
+ """
+ A base formatter class for network messages.
+ This class provides common methods for formatting network-related messages,
+ """
+
+ def _prepare_start_message(self, event: str) -> str:
+ """
+ Prepares a starting message string based on the event type.
+ """
+ if event in ['RTM_DELROUTE', 'RTM_DELLINK', 'RTM_DELNEIGH',
+ 'RTM_DELADDR', 'RTM_DELADDRLABEL', 'RTM_DELRULE',
+ 'RTM_DELNETCONF']:
+ return 'Deleted '
+ if event == 'RTM_GETNEIGH':
+ return 'Miss '
+ return ''
+
+ def _format_flow_field(self, data: int) -> str:
+ """
+ Formats a flow field to represent traffic realms.
+ """
+ to = data & 0xFFFF
+ from_ = data >> 16
+ result = f"realm{'s' if from_ else ''} "
+ if from_:
+ result += f'{from_}/'
+ result += f'{to}'
+
+ return self._output(result)
+
+ def format(self, msg: nlmsg) -> str:
+ """
+ Abstract method to format a complete message.
+
+ This method must be implemented by subclasses to provide specific formatting
+ logic for different types of messages.
+ """
+ raise NotImplementedError(f'{msg.get("event")}: {msg}')
+
+
+class LinkFormatter(BaseMSGFormatter):
+ """
+ A formatter class for handling link-related network messages
+ `RTM_NEWLINK` and `RTM_DELLINK`.
+ """
+ def _format_iff_flags(self, flags: int) -> str:
+ """
+ Formats interface flags into a human-readable string.
+ """
+ result = list()
+ if flags:
+ if flags & IFF_FLAGS['UP'] and not flags & IFF_FLAGS['RUNNING']:
+ result.append('NO-CARRIER')
+
+ flags &= ~IFF_FLAGS['RUNNING']
+
+ result.extend(self._parse_flag(flags, IFF_FLAGS))
+
+ return self._output(f'<{(",").join(result)}>')
+
+ def _format_if_props(self, data: ifinfmsg.ifinfbase.proplist) -> str:
+ """
+ Formats interface alternative name properties.
+ """
+ result = ''
+ for rec in data.altnames():
+ result += f'[altname {rec}] '
+ return self._output(result)
+
+ def _format_link(self, msg: ifinfmsg.ifinfmsg) -> str:
+ """
+ Formats the link attribute of a network interface message.
+ """
+ if msg.get_attr("IFLA_LINK") is not None:
+ iflink = msg.get_attr("IFLA_LINK")
+ if iflink:
+ if msg.get_attr("IFLA_LINK_NETNSID"):
+ return f'if{iflink}'
+ else:
+ return self._get_if_name_by_index(iflink)
+ return 'NONE'
+
+ def _format_link_info(self, msg: ifinfmsg.ifinfmsg) -> str:
+ """
+ Formats detailed information about the link, including type, address,
+ broadcast address, and permanent address.
+ """
+ result = f'link/{ARPHRD_TO_NAME.get(msg.get("ifi_type"), msg.get("ifi_type"))}'
+ result += self._format_simple_field(msg.get_attr('IFLA_ADDRESS'))
+
+ if msg.get_attr("IFLA_BROADCAST"):
+ if msg.get('flags') & ifinfmsg.IFF_POINTOPOINT:
+ result += f' peer'
+ else:
+ result += f' brd'
+ result += f' {msg.get_attr("IFLA_BROADCAST")}'
+
+ if msg.get_attr("IFLA_PERM_ADDRESS"):
+ if not msg.get_attr("IFLA_ADDRESS") or \
+ msg.get_attr("IFLA_ADDRESS") != msg.get_attr("IFLA_PERM_ADDRESS"):
+ result += f' permaddr {msg.get_attr("IFLA_PERM_ADDRESS")}'
+
+ return self._output(result)
+
+ def format(self, msg: ifinfmsg.ifinfmsg):
+ """
+ Formats a network link message into a structured output string.
+ """
+ if msg.get("family") not in [socket.AF_UNSPEC, socket.AF_BRIDGE]:
+ return None
+
+ message = self._prepare_start_message(msg.get('event'))
+
+ link = self._format_link(msg)
+
+ message += f'{msg.get("index")}: {msg.get_attr("IFLA_IFNAME")}'
+ message += f'@{link}' if link else ''
+ message += f': {self._format_iff_flags(msg.get("flags"))}'
+
+ message += self._format_simple_field(msg.get_attr('IFLA_MTU'), prefix='mtu')
+ message += self._format_simple_field(msg.get_attr('IFLA_QDISC'), prefix='qdisc')
+ message += self._format_simple_field(msg.get_attr('IFLA_OPERSTATE'), prefix='state')
+ message += self._format_simple_field(msg.get_attr('IFLA_GROUP'), prefix='group')
+ message += self._format_simple_field(msg.get_attr('IFLA_MASTER'), prefix='master')
+
+ message += self._format_link_info(msg)
+
+ if msg.get_attr('IFLA_PROP_LIST'):
+ message += self._format_if_props(msg.get_attr('IFLA_PROP_LIST'))
+
+ return self._output(message)
+
+
+class EncapFormatter(BaseFormatter):
+ """
+ A formatter class for handling encapsulation attributes in routing messages.
+ """
+ # TODO: implement other lwtunnel decoder in pyroute2
+ # https://github.com/svinota/pyroute2/blob/78cfe838bec8d96324811a3962bda15fb028e0ce/pyroute2/netlink/rtnl/rtmsg.py#L657
+ def __init__(self):
+ """
+ Initializes the EncapFormatter with supported encapsulation types.
+ """
+ self.formatters = {
+ rtmsg.LWTUNNEL_ENCAP_MPLS: self.mpls_format,
+ rtmsg.LWTUNNEL_ENCAP_SEG6: self.seg6_format,
+ rtmsg.LWTUNNEL_ENCAP_BPF: self.bpf_format,
+ rtmsg.LWTUNNEL_ENCAP_SEG6_LOCAL: self.seg6local_format,
+ }
+
+ def _format_srh(self, data: rtmsg_base.seg6_encap_info.ipv6_sr_hdr):
+ """
+ Formats Segment Routing Header (SRH) attributes.
+ """
+ result = ''
+        # pyroute2 decodes the mode only as inline or encap (encap, l2encap, encap.red, l2encap.red)
+ # https://github.com/svinota/pyroute2/blob/78cfe838bec8d96324811a3962bda15fb028e0ce/pyroute2/netlink/rtnl/rtmsg.py#L220
+ for key in ['mode', 'segs']:
+
+ val = data.get(key)
+
+ if val:
+ if key == 'segs':
+ result += f'{key} {len(val)} {val} '
+ else:
+ result += f'{key} {val} '
+
+ return self._output(result)
+
+ def _format_bpf_object(self, data: rtmsg_base.bpf_encap_info, attr_name: str, attr_key: str):
+ """
+ Formats eBPF program attributes.
+ """
+ attr = data.get_attr(attr_name)
+ if not attr:
+ return ''
+ result = ''
+ if attr.get_attr("LWT_BPF_PROG_NAME"):
+ result += f'{attr.get_attr("LWT_BPF_PROG_NAME")} '
+ if attr.get_attr("LWT_BPF_PROG_FD"):
+ result += f'{attr.get_attr("LWT_BPF_PROG_FD")} '
+
+ return self._output(f'{attr_key} {result.strip()}')
+
+ def mpls_format(self, data: rtmsg_base.mpls_encap_info):
+ """
+ Formats MPLS encapsulation attributes.
+ """
+ result = ''
+ if data.get_attr("MPLS_IPTUNNEL_DST"):
+ for rec in data.get_attr("MPLS_IPTUNNEL_DST"):
+ for key, val in rec.items():
+ if val:
+ result += f'{key} {val} '
+
+ if data.get_attr("MPLS_IPTUNNEL_TTL"):
+ result += f' ttl {data.get_attr("MPLS_IPTUNNEL_TTL")}'
+
+ return self._output(result)
+
+ def bpf_format(self, data: rtmsg_base.bpf_encap_info):
+ """
+ Formats eBPF encapsulation attributes.
+ """
+ result = ''
+ result += self._format_bpf_object(data, 'LWT_BPF_IN', 'in')
+ result += self._format_bpf_object(data, 'LWT_BPF_OUT', 'out')
+ result += self._format_bpf_object(data, 'LWT_BPF_XMIT', 'xmit')
+
+ if data.get_attr('LWT_BPF_XMIT_HEADROOM'):
+ result += f'headroom {data.get_attr("LWT_BPF_XMIT_HEADROOM")} '
+
+ return self._output(result)
+
+ def seg6_format(self, data: rtmsg_base.seg6_encap_info):
+ """
+ Formats Segment Routing (SEG6) encapsulation attributes.
+ """
+ result = ''
+ if data.get_attr("SEG6_IPTUNNEL_SRH"):
+ result += self._format_srh(data.get_attr("SEG6_IPTUNNEL_SRH"))
+
+ return self._output(result)
+
+ def seg6local_format(self, data: rtmsg_base.seg6local_encap_info):
+ """
+ Formats SEG6 local encapsulation attributes.
+ """
+ result = ''
+ formatters = {
+ 'SEG6_LOCAL_ACTION': lambda val: f' action {next((k for k, v in data.action.actions.items() if v == val), "unknown")}',
+ 'SEG6_LOCAL_SRH': lambda val: f' {self._format_srh(val)}',
+ 'SEG6_LOCAL_TABLE': lambda val: f' table {self._format_rttable(val)}',
+ 'SEG6_LOCAL_NH4': lambda val: f' nh4 {val}',
+ 'SEG6_LOCAL_NH6': lambda val: f' nh6 {val}',
+ 'SEG6_LOCAL_IIF': lambda val: f' iif {self._get_if_name_by_index(val)}',
+ 'SEG6_LOCAL_OIF': lambda val: f' oif {self._get_if_name_by_index(val)}',
+ 'SEG6_LOCAL_BPF': lambda val: f' endpoint {val.get("LWT_BPF_PROG_NAME")}',
+ 'SEG6_LOCAL_VRFTABLE': lambda val: f' vrftable {self._format_rttable(val)}',
+ }
+
+ for rec in data.get('attrs'):
+ if rec[0] in formatters:
+ result += formatters[rec[0]](rec[1])
+
+ return self._output(result)
+
+ def format(self, type: int, data: Union[rtmsg_base.mpls_encap_info,
+ rtmsg_base.bpf_encap_info,
+ rtmsg_base.seg6_encap_info,
+ rtmsg_base.seg6local_encap_info]):
+ """
+ Formats encapsulation attributes based on their type.
+ """
+ result = ''
+ formatter = self.formatters.get(type)
+
+ result += f'encap {ENCAP_TO_NAME.get(type, "unknown")}'
+
+ if formatter:
+ result += f' {formatter(data)}'
+
+ return self._output(result)
+
+
+class RouteFormatter(BaseMSGFormatter):
+ """
+ A formatter class for handling network routing messages
+ `RTM_NEWROUTE` and `RTM_DELROUTE`.
+ """
+
+ def _format_rt_flags(self, flags: int) -> str:
+ """
+ Formats route flags into a comma-separated string.
+ """
+ result = list()
+ result.extend(self._parse_flag(flags, RT_FlAGS))
+
+ return self._output(",".join(result))
+
+ def _format_rta_encap(self, type: int, data: Union[rtmsg_base.mpls_encap_info,
+ rtmsg_base.bpf_encap_info,
+ rtmsg_base.seg6_encap_info,
+ rtmsg_base.seg6local_encap_info]) -> str:
+ """
+ Formats encapsulation attributes.
+ """
+ return EncapFormatter().format(type, data)
+
+ def _format_rta_newdest(self, data: str) -> str:
+ """
+ Formats a new destination attribute.
+ """
+ return self._output(f'as to {data}')
+
+ def _format_rta_gateway(self, data: str) -> str:
+ """
+ Formats a gateway attribute.
+ """
+ return self._output(f'via {data}')
+
+ def _format_rta_via(self, data: str) -> str:
+ """
+ Formats a 'via' route attribute.
+ """
+ return self._output(f'{data}')
+
+ def _format_rta_metrics(self, data: rtmsg_base.metrics):
+ """
+ Formats routing metrics.
+ """
+ result = ''
+
+ def __format_metric_time(_val: int) -> str:
+ """Formats metric time values into seconds or milliseconds."""
+ return f"{_val / 1000}s" if _val >= 1000 else f"{_val}ms"
+
+        def __format_features(_val: int) -> str:
+            """Parses and formats routing feature flags."""
+ result = self._parse_flag(_val, {'ecn': RTAX_FEATURE_ECN,
+ 'tcp_usec_ts': RTAX_FEATURE_TCP_USEC_TS})
+ return ",".join(result)
+
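+        # RTAX_RTT and RTAX_RTTVAR are divided by 8 and 4 respectively below,
+        # matching how iproute2 presents these kernel-scaled metrics.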
+ formatters = {
+ 'RTAX_MTU': lambda val: f' mtu {val}',
+ 'RTAX_WINDOW': lambda val: f' window {val}',
+ 'RTAX_RTT': lambda val: f' rtt {__format_metric_time(val / 8)}',
+ 'RTAX_RTTVAR': lambda val: f' rttvar {__format_metric_time(val / 4)}',
+ 'RTAX_SSTHRESH': lambda val: f' ssthresh {val}',
+ 'RTAX_CWND': lambda val: f' cwnd {val}',
+ 'RTAX_ADVMSS': lambda val: f' advmss {val}',
+ 'RTAX_REORDERING': lambda val: f' reordering {val}',
+ 'RTAX_HOPLIMIT': lambda val: f' hoplimit {val}',
+ 'RTAX_INITCWND': lambda val: f' initcwnd {val}',
+            'RTAX_FEATURES': lambda val: f' features {__format_features(val)}',
+ 'RTAX_RTO_MIN': lambda val: f' rto_min {__format_metric_time(val)}',
+ 'RTAX_INITRWND': lambda val: f' initrwnd {val}',
+ 'RTAX_QUICKACK': lambda val: f' quickack {val}',
+ }
+
+ for rec in data.get('attrs'):
+ if rec[0] in formatters:
+ result += formatters[rec[0]](rec[1])
+
+ return self._output(result)
+
+ def _format_rta_pref(self, data: int) -> str:
+ """
+ Formats a pref attribute.
+ """
+ pref = {
+ ICMPV6_ROUTER_PREF_LOW: "low",
+ ICMPV6_ROUTER_PREF_MEDIUM: "medium",
+ ICMPV6_ROUTER_PREF_HIGH: "high",
+ }
+
+ return self._output(f' pref {pref.get(data, data)}')
+
+ def _format_rta_multipath(self, mcast_cloned: bool, family: int, data: List[nh]) -> str:
+ """
+ Formats multipath route attributes.
+ """
+ result = ''
+ first = True
+ for rec in data:
+ if mcast_cloned:
+ if first:
+ result += ' Oifs: '
+ first = False
+ else:
+ result += ' '
+ else:
+ result += ' nexthop '
+
+ if rec.get_attr('RTA_ENCAP'):
+ result += self._format_rta_encap(rec.get_attr('RTA_ENCAP_TYPE'),
+ rec.get_attr('RTA_ENCAP'))
+
+ if rec.get_attr('RTA_NEWDST'):
+ result += self._format_rta_newdest(rec.get_attr('RTA_NEWDST'))
+
+ if rec.get_attr('RTA_GATEWAY'):
+ result += self._format_rta_gateway(rec.get_attr('RTA_GATEWAY'))
+
+ if rec.get_attr('RTA_VIA'):
+ result += self._format_rta_via(rec.get_attr('RTA_VIA'))
+
+ if rec.get_attr('RTA_FLOW'):
+ result += self._format_flow_field(rec.get_attr('RTA_FLOW'))
+
+ result += f' dev {self._get_if_name_by_index(rec.get("oif"))}'
+ if mcast_cloned:
+ if rec.get("hops") != 1:
+ result += f' (ttl>{rec.get("hops")})'
+ else:
+ if family != AF_MPLS:
+ result += f' weight {rec.get("hops") + 1}'
+
+ result += self._format_rt_flags(rec.get("flags"))
+
+ return self._output(result)
+
+ def format(self, msg: rtmsg.rtmsg) -> str:
+ """
+ Formats a network route message into a human-readable string representation.
+ """
+ message = self._prepare_start_message(msg.get('event'))
+
+ message += RT_TYPES.get(msg.get('type'))
+
+ if msg.get_attr('RTA_DST'):
+ host_len = self.af_bit_len(msg.get('family'))
+ if msg.get('dst_len') != host_len:
+ message += f' {msg.get_attr("RTA_DST")}/{msg.get("dst_len")}'
+ else:
+ message += f' {msg.get_attr("RTA_DST")}'
+ elif msg.get('dst_len'):
+ message += f' 0/{msg.get("dst_len")}'
+ else:
+ message += ' default'
+
+ if msg.get_attr('RTA_SRC'):
+ message += f' from {msg.get_attr("RTA_SRC")}'
+ elif msg.get('src_len'):
+ message += f' from 0/{msg.get("src_len")}'
+
+ message += self._format_simple_field(msg.get_attr('RTA_NH_ID'), prefix='nhid')
+
+ if msg.get_attr('RTA_NEWDST'):
+ message += self._format_rta_newdest(msg.get_attr('RTA_NEWDST'))
+
+ if msg.get_attr('RTA_ENCAP'):
+ message += self._format_rta_encap(msg.get_attr('RTA_ENCAP_TYPE'),
+ msg.get_attr('RTA_ENCAP'))
+
+ message += self._format_simple_field(msg.get('tos'), prefix='tos')
+
+ if msg.get_attr('RTA_GATEWAY'):
+ message += self._format_rta_gateway(msg.get_attr('RTA_GATEWAY'))
+
+ if msg.get_attr('RTA_VIA'):
+ message += self._format_rta_via(msg.get_attr('RTA_VIA'))
+
+ if msg.get_attr('RTA_OIF') is not None:
+ message += f' dev {self._get_if_name_by_index(msg.get_attr("RTA_OIF"))}'
+
+ if msg.get_attr("RTA_TABLE"):
+ message += f' table {self._format_rttable(msg.get_attr("RTA_TABLE"))}'
+
+ if not msg.get('flags') & RTM_F_CLONED:
+ message += f' proto {RT_PROTO.get(msg.get("proto"))}'
+
+ if not msg.get('scope') == rtmsg.RT_SCOPE_UNIVERSE:
+ message += f' scope {RT_SCOPE_TO_NAME.get(msg.get("scope"))}'
+
+ message += self._format_simple_field(msg.get_attr('RTA_PREFSRC'), prefix='src')
+ message += self._format_simple_field(msg.get_attr('RTA_PRIORITY'), prefix='metric')
+
+ message += self._format_rt_flags(msg.get("flags"))
+
+ if msg.get_attr('RTA_MARK'):
+ mark = msg.get_attr("RTA_MARK")
+ if mark >= 16:
+ message += f' mark 0x{mark:x}'
+ else:
+ message += f' mark {mark}'
+
+ if msg.get_attr('RTA_FLOW'):
+ message += self._format_flow_field(msg.get_attr('RTA_FLOW'))
+
+ message += self._format_simple_field(msg.get_attr('RTA_UID'), prefix='uid')
+
+ if msg.get_attr('RTA_METRICS'):
+ message += self._format_rta_metrics(msg.get_attr("RTA_METRICS"))
+
+ if msg.get_attr('RTA_IIF') is not None:
+ message += f' iif {self._get_if_name_by_index(msg.get_attr("RTA_IIF"))}'
+
+ if msg.get_attr('RTA_PREF') is not None:
+ message += self._format_rta_pref(msg.get_attr("RTA_PREF"))
+
+ if msg.get_attr('RTA_TTL_PROPAGATE') is not None:
+ message += f' ttl-propogate {"enabled" if msg.get_attr("RTA_TTL_PROPAGATE") else "disabled"}'
+
+ if msg.get_attr('RTA_MULTIPATH') is not None:
+ _tmp = self._format_rta_multipath(
+ mcast_cloned=msg.get('flags') & RTM_F_CLONED and msg.get('type') == RTYPES['RTN_MULTICAST'],
+ family=msg.get('family'),
+ data=msg.get_attr("RTA_MULTIPATH"))
+ message += f' {_tmp}'
+
+ return self._output(message)
+
+
+class AddrFormatter(BaseMSGFormatter):
+ """
+ A formatter class for handling address-related network messages
+ `RTM_NEWADDR` and `RTM_DELADDR`.
+ """
+ INFINITY_LIFE_TIME = _INFINITY
+
+ def _format_ifa_flags(self, flags: int, family: int) -> str:
+ """
+ Formats address flags into a human-readable string.
+ """
+ result = list()
+ if flags:
+ if not flags & IFA_FLAGS['permanent']:
+ result.append('dynamic')
+ flags &= ~IFA_FLAGS['permanent']
+
+ if flags & IFA_FLAGS['temporary'] and family == socket.AF_INET6:
+ result.append('temporary')
+ flags &= ~IFA_FLAGS['temporary']
+
+ result.extend(self._parse_flag(flags, IFA_FLAGS))
+
+ return self._output(",".join(result))
+
+ def _format_ifa_addr(self, local: str, addr: str, preflen: int, priority: int) -> str:
+ """
+        Formats address information into a human-readable string.
+ """
+ result = ''
+ local = local or addr
+ addr = addr or local
+
+ if local:
+ result += f'{local}'
+ if addr and addr != local:
+ result += f' peer {addr}'
+ result += f'/{preflen}'
+
+ if priority:
+ result += f' {priority}'
+
+ return self._output(result)
+
+ def _format_ifa_cacheinfo(self, data: ifaddrmsg.ifaddrmsg.cacheinfo) -> str:
+ """
+ Formats cache information for an address.
+ """
+ result = ''
+ _map = {
+ 'ifa_valid': 'valid_lft',
+ 'ifa_preferred': 'preferred_lft',
+ }
+
+ for key in ['ifa_valid', 'ifa_preferred']:
+ val = data.get(key)
+ if val == self.INFINITY_LIFE_TIME:
+ result += f'{_map.get(key)} forever '
+ else:
+ result += f'{_map.get(key)} {val}sec '
+
+ return self._output(result)
+
+ def format(self, msg: ifaddrmsg.ifaddrmsg) -> str:
+ """
+ Formats a full network address message.
+        Combines attributes such as index, family, address, flags, and cache
+ information into a structured output string.
+ """
+ message = self._prepare_start_message(msg.get('event'))
+
+ message += f'{msg.get("index")}: {self._get_if_name_by_index(msg.get("index"))} '
+ message += f'{FAMILY_TO_NAME.get(msg.get("family"), msg.get("family"))} '
+
+ message += self._format_ifa_addr(
+ msg.get_attr('IFA_LOCAL'),
+ msg.get_attr('IFA_ADDRESS'),
+ msg.get('prefixlen'),
+ msg.get_attr('IFA_RT_PRIORITY')
+ )
+ message += self._format_simple_field(msg.get_attr('IFA_BROADCAST'), prefix='brd')
+ message += self._format_simple_field(msg.get_attr('IFA_ANYCAST'), prefix='any')
+
+ if msg.get('scope') is not None:
+ message += f' scope {RT_SCOPE_TO_NAME.get(msg.get("scope"))}'
+
+ message += self._format_ifa_flags(msg.get_attr("IFA_FLAGS"), msg.get("family"))
+ message += self._format_simple_field(msg.get_attr('IFA_LABEL'), prefix='label:')
+
+ if msg.get_attr('IFA_CACHEINFO'):
+ message += self._format_ifa_cacheinfo(msg.get_attr('IFA_CACHEINFO'))
+
+ return self._output(message)
+
+
+class NeighFormatter(BaseMSGFormatter):
+ """
+ A formatter class for handling neighbor-related network messages
+ `RTM_NEWNEIGH`, `RTM_DELNEIGH` and `RTM_GETNEIGH`
+ """
+ def _format_ntf_flags(self, flags: int) -> str:
+ """
+ Formats neighbor table entry flags into a human-readable string.
+ """
+ result = list()
+ result.extend(self._parse_flag(flags, NTF_FlAGS))
+
+ return self._output(",".join(result))
+
+ def _format_neigh_state(self, data: int) -> str:
+ """
+ Formats the state of a neighbor entry.
+ """
+ result = list()
+ result.extend(self._parse_flag(data, NEIGH_STATE_FLAGS))
+
+ return self._output(",".join(result))
+
+ def format(self, msg: ndmsg.ndmsg) -> str:
+ """
+ Formats a full neighbor-related network message.
+        Combines attributes such as destination, device, link-layer address,
+ flags, state, and protocol into a structured output string.
+ """
+ message = self._prepare_start_message(msg.get('event'))
+ message += self._format_simple_field(msg.get_attr('NDA_DST'), prefix='')
+
+ if msg.get("ifindex") is not None:
+ message += f' dev {self._get_if_name_by_index(msg.get("ifindex"))}'
+
+ message += self._format_simple_field(msg.get_attr('NDA_LLADDR'), prefix='lladdr')
+ message += f' {self._format_ntf_flags(msg.get("flags"))}'
+ message += f' {self._format_neigh_state(msg.get("state"))}'
+
+ if msg.get_attr('NDA_PROTOCOL'):
+ message += f' proto {RT_PROTO.get(msg.get_attr("NDA_PROTOCOL"), msg.get_attr("NDA_PROTOCOL"))}'
+
+ return self._output(message)
+
+
+class RuleFormatter(BaseMSGFormatter):
+ """
+    A formatter class for handling routing rule network messages
+ `RTM_NEWRULE` and `RTM_DELRULE`
+ """
+ def _format_direction(self, data: str, length: int, host_len: int):
+ """
+ Formats the direction of traffic based on source or destination and prefix length.
+ """
+ result = ''
+ if data:
+ result += f' {data}'
+ if length != host_len:
+ result += f'/{length}'
+ elif length:
+ result += f' 0/{length}'
+
+ return self._output(result)
+
+ def _format_fra_interface(self, data: str, flags: int, prefix: str):
+ """
+ Formats interface-related attributes.
+ """
+ result = f'{prefix} {data}'
+ if flags & FIB_RULE_IIF_DETACHED:
+ result += '[detached]'
+
+ return self._output(result)
+
+    def _format_fra_range(self, data: Union[str, dict], prefix: str):
+ """
+ Formats a range of values (e.g., UID, sport, or dport).
+ """
+ result = ''
+ if data:
+ if isinstance(data, str):
+ result += f' {prefix} {data}'
+ else:
+ result += f' {prefix} {data.get("start")}:{data.get("end")}'
+ return self._output(result)
+
+ def _format_fra_table(self, msg: fibmsg):
+ """
+ Formats the lookup table and associated attributes in the message.
+ """
+ def __format_field(data: int, prefix: str):
+ if data and data not in [-1, _INFINITY]:
+ return f' {prefix} {data}'
+ return ''
+
+ result = ''
+ table = msg.get_attr('FRA_TABLE') or msg.get('table')
+ if table:
+ result += f' lookup {self._format_rttable(table)}'
+ result += __format_field(msg.get_attr('FRA_SUPPRESS_PREFIXLEN'), 'suppress_prefixlength')
+ result += __format_field(msg.get_attr('FRA_SUPPRESS_IFGROUP'), 'suppress_ifgroup')
+
+ return self._output(result)
+
+ def _format_fra_action(self, msg: fibmsg):
+ """
+ Formats the action associated with the rule.
+ """
+ result = ''
+ if msg.get('action') == RTYPES.get('RTN_NAT'):
+            if msg.get_attr('RTA_GATEWAY'):  # looks deprecated but still used by iproute2
+ result += f' map-to {msg.get_attr("RTA_GATEWAY")}'
+ else:
+ result += ' masquerade'
+
+ elif msg.get('action') == FR_ACT_GOTO:
+ result += f' goto {msg.get_attr("FRA_GOTO") or "none"}'
+ if msg.get('flags') & FIB_RULE_UNRESOLVED:
+ result += ' [unresolved]'
+
+ elif msg.get('action') == FR_ACT_NOP:
+ result += ' nop'
+
+ elif msg.get('action') != FR_ACT_TO_TBL:
+ result += f' {RTYPES.get(msg.get("action"))}'
+
+ return self._output(result)
+
+ def format(self, msg: fibmsg):
+ """
+ Formats a complete routing rule message.
+ Combines information about source, destination, interfaces, actions,
+ and other attributes into a single formatted string.
+ """
+ message = self._prepare_start_message(msg.get('event'))
+ host_len = self.af_bit_len(msg.get('family'))
+ message += self._format_simple_field(msg.get_attr('FRA_PRIORITY'), prefix='')
+
+ if msg.get('flags') & FIB_RULE_INVERT:
+ message += ' not'
+
+ tmp = self._format_direction(msg.get_attr('FRA_SRC'), msg.get('src_len'), host_len)
+ message += ' from' + (tmp if tmp else ' all ')
+
+ if msg.get_attr('FRA_DST'):
+ tmp = self._format_direction(msg.get_attr('FRA_DST'), msg.get('dst_len'), host_len)
+ message += ' to' + tmp
+
+ if msg.get('tos'):
+ message += f' tos {hex(msg.get("tos"))}'
+
+ if msg.get_attr('FRA_FWMARK') or msg.get_attr('FRA_FWMASK'):
+ mark = msg.get_attr('FRA_FWMARK') or 0
+ mask = msg.get_attr('FRA_FWMASK') or 0
+ if mask != 0xFFFFFFFF:
+ message += f' fwmark {mark}/{mask}'
+ else:
+ message += f' fwmark {mark}'
+
+ if msg.get_attr('FRA_IIFNAME'):
+ message += self._format_fra_interface(
+ msg.get_attr('FRA_IIFNAME'),
+ msg.get('flags'),
+ 'iif'
+ )
+
+ if msg.get_attr('FRA_OIFNAME'):
+ message += self._format_fra_interface(
+ msg.get_attr('FRA_OIFNAME'),
+ msg.get('flags'),
+ 'oif'
+ )
+
+ if msg.get_attr('FRA_L3MDEV'):
+ message += f' lookup [l3mdev-table]'
+
+ if msg.get_attr('FRA_UID_RANGE'):
+ message += self._format_fra_range(msg.get_attr('FRA_UID_RANGE'), 'uidrange')
+
+ message += self._format_simple_field(msg.get_attr('FRA_IP_PROTO'), prefix='ipproto')
+
+ if msg.get_attr('FRA_SPORT_RANGE'):
+ message += self._format_fra_range(msg.get_attr('FRA_SPORT_RANGE'), 'sport')
+
+ if msg.get_attr('FRA_DPORT_RANGE'):
+ message += self._format_fra_range(msg.get_attr('FRA_DPORT_RANGE'), 'dport')
+
+ message += self._format_simple_field(msg.get_attr('FRA_TUN_ID'), prefix='tun_id')
+
+ message += self._format_fra_table(msg)
+
+ if msg.get_attr('FRA_FLOW'):
+ message += self._format_flow_field(msg.get_attr('FRA_FLOW'))
+
+ message += self._format_fra_action(msg)
+
+ if msg.get_attr('FRA_PROTOCOL'):
+ message += f' proto {RT_PROTO.get(msg.get_attr("FRA_PROTOCOL"), msg.get_attr("FRA_PROTOCOL"))}'
+
+ return self._output(message)
+
+
+class AddrlabelFormatter(BaseMSGFormatter):
+ # Not implemented decoder on pytroute2 but ip monitor use it message
+ pass
+
+
+class PrefixFormatter(BaseMSGFormatter):
+    # Decoder not yet implemented in pyroute2, but 'ip monitor' handles this message type
+ pass
+
+
+class NetconfFormatter(BaseMSGFormatter):
+    # Decoder not yet implemented in pyroute2, but 'ip monitor' handles this message type
+ pass
+
+
+EVENT_MAP = {
+ rtnl.RTM_NEWROUTE: {'parser': RouteFormatter, 'event': 'route'},
+ rtnl.RTM_DELROUTE: {'parser': RouteFormatter, 'event': 'route'},
+ rtnl.RTM_NEWLINK: {'parser': LinkFormatter, 'event': 'link'},
+ rtnl.RTM_DELLINK: {'parser': LinkFormatter, 'event': 'link'},
+ rtnl.RTM_NEWADDR: {'parser': AddrFormatter, 'event': 'addr'},
+ rtnl.RTM_DELADDR: {'parser': AddrFormatter, 'event': 'addr'},
+ # rtnl.RTM_NEWADDRLABEL: {'parser': AddrlabelFormatter, 'event': 'addrlabel'},
+ # rtnl.RTM_DELADDRLABEL: {'parser': AddrlabelFormatter, 'event': 'addrlabel'},
+ rtnl.RTM_NEWNEIGH: {'parser': NeighFormatter, 'event': 'neigh'},
+ rtnl.RTM_DELNEIGH: {'parser': NeighFormatter, 'event': 'neigh'},
+ rtnl.RTM_GETNEIGH: {'parser': NeighFormatter, 'event': 'neigh'},
+ # rtnl.RTM_NEWPREFIX: {'parser': PrefixFormatter, 'event': 'prefix'},
+ rtnl.RTM_NEWRULE: {'parser': RuleFormatter, 'event': 'rule'},
+ rtnl.RTM_DELRULE: {'parser': RuleFormatter, 'event': 'rule'},
+ # rtnl.RTM_NEWNETCONF: {'parser': NetconfFormatter, 'event': 'netconf'},
+ # rtnl.RTM_DELNETCONF: {'parser': NetconfFormatter, 'event': 'netconf'},
+}
+
+
+def sig_handler(signum, frame):
+ process_name = multiprocessing.current_process().name
+ logger.debug(
+ f'[{process_name}]: {"Shutdown" if signum == signal.SIGTERM else "Reload"} signal received...'
+ )
+ shutdown_event.set()
+
+
+def parse_event_type(header: Dict) -> tuple:
+ """
+ Extract event type and parser.
+ """
+ event_type = EVENT_MAP.get(header['type'], {}).get('event', 'unknown')
+ _parser = EVENT_MAP.get(header['type'], {}).get('parser')
+
+ if _parser is None:
+ raise UnsupportedMessageType(f'Unsupported message type: {header["type"]}')
+
+ return event_type, _parser
+
+
+def is_need_to_log(event_type: AnyStr, conf_event: Dict) -> bool:
+    """
+    Filter messages by event type: log an event only when its type is enabled
+    (present with no extra filters) in the configuration.
+    """
+ conf = conf_event.get(event_type)
+ if conf == {}:
+ return True
+ return False
+
+
+def parse_event(msg: nfct_msg, conf_event: Dict) -> str:
+ """
+    Parse a netlink message and format it into a log string.
+ """
+ data = ''
+ event_type, parser = parse_event_type(msg['header'])
+ if event_type == 'link':
+ remember_if_index(idx=msg.get('index'), event_type=msg['header'].get('type'))
+
+ if not is_need_to_log(event_type, conf_event):
+ return data
+
+ message = parser().format(msg)
+ if message:
+ data = f'{f"[{event_type}]".upper():<{7}} {message}'
+
+ return data
+
+
+def worker(ct: IPRoute, shutdown_event: multiprocessing.Event, conf_event: Dict) -> None:
+ """
+ Main function of parser worker process
+ """
+ process_name = multiprocessing.current_process().name
+ logger.debug(f'[{process_name}] started')
+ timeout = 0.1
+ while not shutdown_event.is_set():
+ if not ct.buffer_queue.empty():
+ msg = None
+ try:
+ for msg in ct.get():
+ message = parse_event(msg, conf_event)
+ if message:
+ if logger.level == logging.DEBUG:
+ logger.debug(f'[{process_name}]: {message} raw: {msg}')
+ else:
+ logger.info(message)
+ except queue.Full:
+                logger.error('IPRoute message queue is full.')
+ except UnsupportedMessageType as e:
+ logger.debug(f'{e} =====> raw msg: {msg}')
+ except Exception as e:
+ logger.error(f'Unexpected error: {e.__class__} {e} [{msg}]')
+ else:
+ sleep(timeout)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-c',
+ '--config',
+ action='store',
+ help='Path to vyos-network-event-logger configuration',
+ required=True,
+ type=Path,
+ )
+
+ args = parser.parse_args()
+ try:
+ config = read_json(args.config)
+ except Exception as err:
+        logger.error(f'Configuration file "{args.config}" does not exist or is malformed: {err}')
+ exit(1)
+
+ set_log_level(config.get('log_level', 'info'))
+
+ signal.signal(signal.SIGHUP, sig_handler)
+ signal.signal(signal.SIGTERM, sig_handler)
+
+ if 'event' in config:
+ event_groups = list(config.get('event').keys())
+ else:
+        logger.error('Invalid configuration: event filter is empty.')
+ exit(1)
+
+ conf_event = config['event']
+ qsize = config.get('queue_size')
+ ct = IPRoute(async_qsize=int(qsize) if qsize else None)
+ ct.buffer_queue = multiprocessing.Queue(ct.async_qsize)
+ ct.bind(async_cache=True)
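+# bind() with async_cache=True starts a listener thread that pushes raw netlink
+# messages into ct.buffer_queue (replaced above with a multiprocessing.Queue so
+# the worker processes can consume them in parallel).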
+
+ processes = list()
+ try:
+ for _ in range(multiprocessing.cpu_count()):
+ p = multiprocessing.Process(target=worker, args=(ct, shutdown_event, conf_event))
+ processes.append(p)
+ p.start()
+ logger.info('IPRoute socket bound and listening for messages.')
+
+ while not shutdown_event.is_set():
+ if not ct.pthread.is_alive():
+ if ct.buffer_queue.qsize() / ct.async_qsize < 0.9:
+ if not shutdown_event.is_set():
+ logger.debug('Restart listener thread')
+                        # restart the listener thread after a queue overflow, once the queue drains below 90% of capacity
+ ct.pthread = threading.Thread(name='Netlink async cache', target=ct.async_recv)
+ ct.pthread.daemon = True
+ ct.pthread.start()
+ else:
+ sleep(0.1)
+ finally:
+ for p in processes:
+ p.join()
+ if not p.is_alive():
+ logger.debug(f'[{p.name}]: finished')
+ ct.close()
+        logger.info('IPRoute socket closed.')
+ exit()