diff options
Diffstat (limited to 'src/services/vyos-http-api-server')
| -rwxr-xr-x | src/services/vyos-http-api-server | 173 | 
1 files changed, 129 insertions, 44 deletions
| diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 66e80ced5..3a9efb73e 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -22,12 +22,14 @@ import grp  import copy  import json  import logging +import signal  import traceback  import threading +from time import sleep  from typing import List, Union, Callable, Dict -import uvicorn  from fastapi import FastAPI, Depends, Request, Response, HTTPException +from fastapi import BackgroundTasks  from fastapi.responses import HTMLResponse  from fastapi.exceptions import RequestValidationError  from fastapi.routing import APIRoute @@ -36,10 +38,14 @@ from starlette.middleware.cors import CORSMiddleware  from starlette.datastructures import FormData  from starlette.formparsers import FormParser, MultiPartParser  from multipart.multipart import parse_options_header +from uvicorn import Config as UvicornConfig +from uvicorn import Server as UvicornServer  from ariadne.asgi import GraphQL -import vyos.config +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configdiff import get_config_diff  from vyos.configsession import ConfigSession, ConfigSessionError  import api.graphql.state @@ -410,12 +416,24 @@ app.router.route_class = MultipartRoute  async def validation_exception_handler(request, exc):      return error(400, str(exc.errors()[0])) +self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" + +def call_commit(s: ConfigSession): +    try: +        s.commit() +    except ConfigSessionError as e: +        s.discard() +        if app.state.vyos_debug: +            logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}") +        else: +            logger.warning(f"ConfigSessionError: {e}") +  def _configure_op(data: Union[ConfigureModel, ConfigureListModel,                                ConfigSectionModel, ConfigSectionListModel], -                  request: Request): +                  request: Request, background_tasks: BackgroundTasks):      session = app.state.vyos_session      env = session.get_session_env() -    config = vyos.config.Config(session_env=env) +    config = Config(session_env=env)      endpoint = request.url.path @@ -470,7 +488,15 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,                  else:                      raise ConfigSessionError(f"'{op}' is not a valid operation")          # end for -        session.commit() +        config = Config(session_env=env) +        d = get_config_diff(config) + +        if d.is_node_changed(['service', 'https']): +            background_tasks.add_task(call_commit, session) +            msg = self_ref_msg +        else: +            session.commit() +          logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'")      except ConfigSessionError as e:          session.discard() @@ -495,21 +521,21 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel,  @app.post('/configure')  def configure_op(data: Union[ConfigureModel, -                                   ConfigureListModel], -                       request: Request): -    return _configure_op(data, request) +                             ConfigureListModel], +                       request: Request, background_tasks: BackgroundTasks): +    return _configure_op(data, request, background_tasks)  @app.post('/configure-section')  def configure_section_op(data: Union[ConfigSectionModel, -                                           ConfigSectionListModel], -                               request: Request): -    return _configure_op(data, request) +                                     ConfigSectionListModel], +                               request: Request, background_tasks: BackgroundTasks): +    return _configure_op(data, request, background_tasks)  @app.post("/retrieve")  async def retrieve_op(data: RetrieveModel):      session = app.state.vyos_session      env = session.get_session_env() -    config = vyos.config.Config(session_env=env) +    config = Config(session_env=env)      op = data.op      path = " ".join(data.path) @@ -528,10 +554,10 @@ async def retrieve_op(data: RetrieveModel):              res = session.show_config(path=data.path)              if config_format == 'json': -                config_tree = vyos.configtree.ConfigTree(res) +                config_tree = ConfigTree(res)                  res = json.loads(config_tree.to_json())              elif config_format == 'json_ast': -                config_tree = vyos.configtree.ConfigTree(res) +                config_tree = ConfigTree(res)                  res = json.loads(config_tree.to_json_ast())              elif config_format == 'raw':                  pass @@ -548,10 +574,11 @@ async def retrieve_op(data: RetrieveModel):      return success(res)  @app.post('/config-file') -def config_file_op(data: ConfigFileModel): +def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks):      session = app.state.vyos_session - +    env = session.get_session_env()      op = data.op +    msg = None      try:          if op == 'save': @@ -559,14 +586,23 @@ def config_file_op(data: ConfigFileModel):                  path = data.file              else:                  path = '/config/config.boot' -            res = session.save_config(path) +            msg = session.save_config(path)          elif op == 'load':              if data.file:                  path = data.file              else:                  return error(400, "Missing required field \"file\"") -            res = session.migrate_and_load_config(path) -            res = session.commit() + +            session.migrate_and_load_config(path) + +            config = Config(session_env=env) +            d = get_config_diff(config) + +            if d.is_node_changed(['service', 'https']): +                background_tasks.add_task(call_commit, session) +                msg = self_ref_msg +            else: +                session.commit()          else:              return error(400, f"'{op}' is not a valid operation")      except ConfigSessionError as e: @@ -575,7 +611,7 @@ def config_file_op(data: ConfigFileModel):          logger.critical(traceback.format_exc())          return error(500, "An internal error occured. Check the logs for details.") -    return success(res) +    return success(msg)  @app.post('/image')  def image_op(data: ImageModel): @@ -607,7 +643,7 @@ def image_op(data: ImageModel):      return success(res)  @app.post('/container-image') -def image_op(data: ContainerImageModel): +def container_image_op(data: ContainerImageModel):      session = app.state.vyos_session      op = data.op @@ -702,7 +738,7 @@ def reset_op(data: ResetModel):  # GraphQL integration  ### -def graphql_init(fast_api_app): +def graphql_init(app: FastAPI = app):      from api.graphql.libs.token_auth import get_user_context      api.graphql.state.init()      api.graphql.state.settings['app'] = app @@ -728,26 +764,45 @@ def graphql_init(fast_api_app):                                            debug=True,                                            introspection=in_spec))  ### +# Modify uvicorn to allow reloading server within the configsession +### -if __name__ == '__main__': -    # systemd's user and group options don't work, do it by hand here, -    # else no one else will be able to commit -    cfg_group = grp.getgrnam(CFG_GROUP) -    os.setgid(cfg_group.gr_gid) +server = None +shutdown = False -    # Need to set file permissions to 775 too so that every vyattacfg group member -    # has write access to the running config -    os.umask(0o002) +class ApiServerConfig(UvicornConfig): +    pass + +class ApiServer(UvicornServer): +    def install_signal_handlers(self): +        pass + +def reload_handler(signum, frame): +    global server +    logger.debug('Reload signal received...') +    if server is not None: +        server.handle_exit(signum, frame) +        server = None +        logger.info('Server stopping for reload...') +    else: +        logger.warning('Reload called for non-running server...') +def shutdown_handler(signum, frame): +    global shutdown +    logger.debug('Shutdown signal received...') +    server.handle_exit(signum, frame) +    logger.info('Server shutdown...') +    shutdown = True + +def initialization(session: ConfigSession, app: FastAPI = app): +    global server      try:          server_config = load_server_config() -    except Exception as err: -        logger.critical(f"Failed to load the HTTP API server config: {err}") +    except Exception as e: +        logger.critical(f'Failed to load the HTTP API server config: {e}')          sys.exit(1) -    config_session = ConfigSession(os.getpid()) - -    app.state.vyos_session = config_session +    app.state.vyos_session = session      app.state.vyos_keys = server_config['api_keys']      app.state.vyos_debug = server_config['debug'] @@ -770,14 +825,44 @@ if __name__ == '__main__':      if app.state.vyos_graphql:          graphql_init(app) +    if not server_config['socket']: +        config = ApiServerConfig(app, +                                 host=server_config["listen_address"], +                                 port=int(server_config["port"]), +                                 proxy_headers=True) +    else: +        config = ApiServerConfig(app, +                                 uds="/run/api.sock", +                                 proxy_headers=True) +    server = ApiServer(config) + +def run_server():      try: -        if not server_config['socket']: -            uvicorn.run(app, host=server_config["listen_address"], -                             port=int(server_config["port"]), -                             proxy_headers=True) -        else: -            uvicorn.run(app, uds="/run/api.sock", -                             proxy_headers=True) -    except OSError as err: -        logger.critical(f"OSError {err}") +        server.run() +    except OSError as e: +        logger.critical(e)          sys.exit(1) + +if __name__ == '__main__': +    # systemd's user and group options don't work, do it by hand here, +    # else no one else will be able to commit +    cfg_group = grp.getgrnam(CFG_GROUP) +    os.setgid(cfg_group.gr_gid) + +    # Need to set file permissions to 775 too so that every vyattacfg group member +    # has write access to the running config +    os.umask(0o002) + +    signal.signal(signal.SIGHUP, reload_handler) +    signal.signal(signal.SIGTERM, shutdown_handler) + +    config_session = ConfigSession(os.getpid()) + +    while True: +        logger.debug('Enter main loop...') +        if shutdown: +            break +        if server is None: +            initialization(config_session) +            server.run() +        sleep(1) | 
