#!/usr/share/vyos-http-api-tools/bin/python3
#
# Copyright (C) 2019-2021 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#

import os
import sys
import grp
import copy
import json
import logging
import traceback
import threading
from typing import List, Union, Callable, Dict

import uvicorn
from fastapi import FastAPI, Depends, Request, Response, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
from pydantic import BaseModel, StrictStr, validator
from starlette.middleware.cors import CORSMiddleware
from starlette.datastructures import FormData
from starlette.formparsers import FormParser, MultiPartParser
from multipart.multipart import parse_options_header

from ariadne.asgi import GraphQL

import vyos.config
# Imported explicitly: the /retrieve handler calls vyos.configtree.ConfigTree,
# and should not depend on another vyos module importing it as a side effect.
import vyos.configtree
from vyos.configsession import ConfigSession, ConfigSessionError

import api.graphql.state

# Location of the JSON file holding the server settings (API keys, listen
# address, feature switches); read once at start-up.
DEFAULT_CONFIG_FILE = '/etc/vyos/http-api.conf'
# Unix group the process switches to before serving requests.
CFG_GROUP = 'vyattacfg'

# Module-level switch for the verbosity of this module's logger.
debug = True

logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)

# Giant lock!
lock = threading.Lock()


def load_server_config():
    """Read DEFAULT_CONFIG_FILE and return its parsed JSON content.

    Any error from opening or parsing the file propagates to the caller.
    """
    with open(DEFAULT_CONFIG_FILE) as f:
        return json.load(f)


def check_auth(key_list, key):
    """Return the 'id' of the entry in key_list whose 'key' equals key.

    key_list is an iterable of mappings with 'key' and 'id' items. Returns
    None when nothing matches; when several entries match, the last one wins.
    """
    # NOTE(review): plain `==` is not a constant-time comparison.
    key_id = None
    for k in key_list:
        if k['key'] == key:
            key_id = k['id']
    return key_id


def error(code, msg):
    """Return a response with HTTP status `code` carrying the failure envelope
    {"success": false, "error": msg, "data": null} as JSON text."""
    resp = json.dumps({"success": False, "error": msg, "data": None})
    return HTMLResponse(resp, status_code=code)


def success(data):
    """Return a response carrying the success envelope
    {"success": true, "data": data, "error": null} as JSON text."""
    resp = json.dumps({"success": True, "data": data, "error": None})
    return HTMLResponse(resp)


# Pydantic models for validation
# Pydantic will cast when possible, so use StrictStr
# validators added as needed for additional constraints
# schema_extra adds annotations to OpenAPI, to add examples
#
# The models and endpoint functions below are documented with comments rather
# than docstrings, so that the documentation is not picked up as schema or
# route descriptions.

# Base of every request body: the API key.
class ApiModel(BaseModel):
    key: StrictStr


# One command inside a "commands" list; carries no key of its own.
class BaseConfigureModel(BaseModel):
    op: StrictStr
    path: List[StrictStr]
    value: StrictStr = None

    # Runs before field validation (pre=True) and rejects an empty path.
    @validator("path", pre=True, always=True)
    def check_non_empty(cls, path):
        assert len(path) > 0
        return path


# Body of /configure when a single command is sent at top level.
class ConfigureModel(ApiModel):
    op: StrictStr
    path: List[StrictStr]
    value: StrictStr = None

    # Runs before field validation (pre=True) and rejects an empty path.
    @validator("path", pre=True, always=True)
    def check_non_empty(cls, path):
        assert len(path) > 0
        return path

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "op": "set | delete | comment",
                "path": ['config', 'mode', 'path'],
            }
        }


# Body of /configure when several commands are sent at once.
class ConfigureListModel(ApiModel):
    commands: List[BaseConfigureModel]

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "commands": "list of commands",
            }
        }


# Body of /retrieve.
class RetrieveModel(ApiModel):
    op: StrictStr
    path: List[StrictStr]
    configFormat: StrictStr = None

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "op": "returnValue | returnValues | exists | showConfig",
                "path": ['config', 'mode', 'path'],
                "configFormat": "json (default) | json_ast | raw",
            }
        }


# Body of /config-file.
class ConfigFileModel(ApiModel):
    op: StrictStr
    file: StrictStr = None

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "op": "save | load",
                "file": "filename",
            }
        }


# Body of /image.
class ImageModel(ApiModel):
    op: StrictStr
    url: StrictStr = None
    name: StrictStr = None

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "op": "add | delete",
                "url": "imagelocation",
                "name": "imagename",
            }
        }


# Body of /generate.
class GenerateModel(ApiModel):
    op: StrictStr
    path: List[StrictStr]

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "op": "generate",
                "path": ["op", "mode", "path"],
            }
        }


# Body of /show.
class ShowModel(ApiModel):
    op: StrictStr
    path: List[StrictStr]

    class Config:
        schema_extra = {
            "example": {
                "key": "id_key",
                "op": "show",
                "path": ["op", "mode", "path"],
            }
        }


# Response envelopes; referenced only from the `responses` table below.
class Success(BaseModel):
    success: bool
    data: Union[str, bool, Dict]
    error: str


class Error(BaseModel):
    success: bool = False
    data: Union[str, bool, Dict]
    error: str


responses = {
    200: {'model': Success},
    400: {'model': Error},
    422: {'model': Error, 'description': 'Validation Error'},
    500: {'model': Error}
}


def auth_required(data: ApiModel):
    """Application-wide dependency: reject requests without a known API key.

    Looks the request's key up in app.state.vyos_keys, raises
    HTTPException(401) when check_auth yields a falsy id, and otherwise
    records the id in app.state.vyos_id (configure_op logs it).
    """
    key = data.key
    api_keys = app.state.vyos_keys
    key_id = check_auth(api_keys, key)
    if not key_id:
        raise HTTPException(status_code=401, detail="Valid API key is required")
    app.state.vyos_id = key_id


# override Request and APIRoute classes in order to convert form request to json;
# do all explicit validation here, for backwards compatability of error messages;
# the explicit validation may be dropped, if desired, in favor of native
# validation by FastAPI/Pydantic, as is used for application/json requests
class MultipartRequest(Request):
    """Request that presents form submissions to FastAPI as a JSON body.

    body() rewrites a form request (fields 'data' and optionally 'key') into
    JSON bytes and records what is wrong with it in the ERR_* flags;
    MultipartRoute reads those flags to build the error response.
    """

    # Class-level defaults; body() overrides them on the instance.
    ERR_MISSING_KEY = False
    ERR_MISSING_DATA = False
    ERR_NOT_JSON = False
    ERR_NOT_DICT = False
    ERR_NO_OP = False
    ERR_NO_PATH = False
    ERR_EMPTY_PATH = False
    ERR_PATH_NOT_LIST = False
    ERR_VALUE_NOT_STRING = False
    ERR_PATH_NOT_LIST_OF_STR = False
    # Last command that failed a check, and the JSON parsing exception, if any.
    offending_command = {}
    exception = None

    @property
    def orig_headers(self):
        """Headers as provided by the parent class; form() reads the
        Content-Type from here."""
        self._orig_headers = super().headers
        return self._orig_headers

    @property
    def headers(self):
        """Mutable copy of the parent's headers with content-type forced to
        application/json."""
        self._headers = super().headers.mutablecopy()
        self._headers['content-type'] = 'application/json'
        return self._headers

    async def form(self) -> FormData:
        """Parse the body as form data according to the original Content-Type.

        Returns an empty FormData for any content type other than
        multipart/form-data and application/x-www-form-urlencoded. The result
        is cached in self._form.
        """
        if not hasattr(self, "_form"):
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.orig_headers.get("Content-Type")
            content_type, options = parse_options_header(content_type_header)
            if content_type == b"multipart/form-data":
                multipart_parser = MultiPartParser(self.orig_headers, self.stream())
                self._form = await multipart_parser.parse()
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.orig_headers, self.stream())
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form

    async def body(self) -> bytes:
        """Return the request body, converting form data to JSON bytes.

        Non-form requests are returned unchanged and get no explicit
        validation. For form requests the 'data' field is parsed as JSON
        (a top-level list becomes {"commands": [...]}), every command is
        checked and the ERR_* flags are set accordingly, the form's 'key'
        field is merged in when the JSON has none, and the result is
        re-serialised. The result is cached in self._body.
        """
        if not hasattr(self, "_body"):
            forms = {}
            merge = {}
            body = await super().body()
            # Must be set before form() is awaited: form() reads the body
            # again through self.stream().
            self._body = body

            form_data = await self.form()
            if form_data:
                logger.debug("processing form data")
                for k, v in form_data.multi_items():
                    forms[k] = v

                cmds = None
                if 'data' not in forms:
                    self.ERR_MISSING_DATA = True
                else:
                    try:
                        tmp = json.loads(forms['data'])
                    except (json.JSONDecodeError, TypeError) as e:
                        # TypeError: the 'data' part is not text at all
                        self.ERR_NOT_JSON = True
                        self.exception = e
                        tmp = {}
                    if isinstance(tmp, list):
                        merge['commands'] = tmp
                    elif isinstance(tmp, dict):
                        merge = tmp
                    else:
                        # Valid JSON, but a bare scalar or null. Keep `merge`
                        # a dict so the key handling below cannot raise, and
                        # check the value as the only command so that it is
                        # reported through ERR_NOT_DICT.
                        cmds = [tmp]

                if cmds is None:
                    if 'commands' in merge:
                        cmds = merge['commands']
                    else:
                        cmds = [copy.deepcopy(merge)]

                # No early exit: when several commands are malformed, the
                # flags accumulate and offending_command holds the last one.
                for c in cmds:
                    if not isinstance(c, dict):
                        self.ERR_NOT_DICT = True
                        self.offending_command = c
                    elif 'op' not in c:
                        self.ERR_NO_OP = True
                        self.offending_command = c
                    elif 'path' not in c:
                        self.ERR_NO_PATH = True
                        self.offending_command = c
                    elif not c['path']:
                        self.ERR_EMPTY_PATH = True
                        self.offending_command = c
                    elif not isinstance(c['path'], list):
                        self.ERR_PATH_NOT_LIST = True
                        self.offending_command = c
                    elif not all(isinstance(el, str) for el in c['path']):
                        self.ERR_PATH_NOT_LIST_OF_STR = True
                        self.offending_command = c
                    elif 'value' in c and not isinstance(c['value'], str):
                        self.ERR_VALUE_NOT_STRING = True
                        self.offending_command = c

                if 'key' not in forms and 'key' not in merge:
                    self.ERR_MISSING_KEY = True
                if 'key' in forms and 'key' not in merge:
                    merge['key'] = forms['key']

                self._body = json.dumps(merge).encode()

        return self._body


class MultipartRoute(APIRoute):
    """Route class that wraps every request in MultipartRequest and turns the
    flags it records into the API's error responses."""

    # (flag on MultipartRequest, reason text), in the order they are reported
    # for the /configure endpoint.
    _CONFIGURE_ERRORS = (
        ('ERR_NOT_DICT', 'any command must be a dict'),
        ('ERR_NO_OP', 'missing "op" field'),
        ('ERR_NO_PATH', 'missing "path" field'),
        ('ERR_EMPTY_PATH', 'empty path'),
        ('ERR_PATH_NOT_LIST', '"path" field must be a list'),
        ('ERR_VALUE_NOT_STRING', '"value" field must be a string'),
        ('ERR_PATH_NOT_LIST_OF_STR', '"path" field must be a list of strings'),
    )

    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()
        configure_errors = self._CONFIGURE_ERRORS

        async def custom_route_handler(request: Request) -> Response:
            """Run the stock handler on a MultipartRequest.

            HTTPException becomes an error envelope with its status code.
            For any other exception the request's ERR_* flags select the
            response; if none applies the exception is re-raised.
            """
            request = MultipartRequest(request.scope, request.receive)
            endpoint = request.url.path
            try:
                response: Response = await original_route_handler(request)
            except HTTPException as e:
                return error(e.status_code, e.detail)
            except Exception:
                if request.ERR_MISSING_KEY:
                    return error(401, "Valid API key is required")
                if request.ERR_MISSING_DATA:
                    return error(422, "Non-empty data field is required")
                if request.ERR_NOT_JSON:
                    return error(400, "Failed to parse JSON: {0}".format(request.exception))
                if endpoint == '/configure':
                    for flag, reason in configure_errors:
                        if getattr(request, flag):
                            offending = json.dumps(request.offending_command)
                            return error(400, "Malformed command \"{0}\": {1}".format(offending, reason))
                if endpoint in ('/retrieve', '/generate', '/show'):
                    if request.ERR_NO_OP or request.ERR_NO_PATH:
                        return error(400, "Missing required field. \"op\" and \"path\" fields are required")
                if endpoint in ('/config-file', '/image'):
                    if request.ERR_NO_OP:
                        return error(400, "Missing required field \"op\"")

                raise

            return response

        return custom_route_handler


app = FastAPI(debug=True,
              title="VyOS API",
              version="0.1.0",
              responses={**responses},
              dependencies=[Depends(auth_required)])

app.router.route_class = MultipartRoute


# Report the first validation error as a 400 in the API's error envelope.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return error(400, str(exc.errors()[0]))


# Apply one command or a list of commands (set / delete / comment) to the
# shared config session and commit them. Any failure discards the pending
# changes: ConfigSessionError is reported as 400 with its message, anything
# else as 500 with a generic message.
@app.post('/configure')
def configure_op(data: Union[ConfigureModel, ConfigureListModel]):
    session = app.state.vyos_session
    env = session.get_session_env()
    config = vyos.config.Config(session_env=env)

    # Allow users to pass just one command
    if isinstance(data, ConfigureListModel):
        commands = data.commands
    else:
        commands = [data]

    status = 200
    error_msg = None

    # We don't want multiple people/apps to be able to commit at once,
    # or modify the shared session while someone else is doing the same,
    # so the lock is really global
    with lock:
        try:
            for c in commands:
                op = c.op
                path = c.path
                value = c.value or ""

                # For vyos.configsession calls that have no separate value arguments,
                # and for type checking too
                cfg_path = " ".join(path + [value]).strip()

                if op == 'set':
                    # XXX: it would be nice to do a strict check for "path already exists",
                    # but there's probably no way to do that
                    session.set(path, value=value)
                elif op == 'delete':
                    if app.state.vyos_strict and not config.exists(cfg_path):
                        raise ConfigSessionError("Cannot delete [{0}]: path/value does not exist".format(cfg_path))
                    session.delete(path, value=value)
                elif op == 'comment':
                    session.comment(path, value=value)
                else:
                    raise ConfigSessionError("\"{0}\" is not a valid operation".format(op))
            # end for
            session.commit()
            logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'")
        except ConfigSessionError as e:
            session.discard()
            status = 400
            if app.state.vyos_debug:
                logger.critical(f"ConfigSessionError:\n {traceback.format_exc()}")
            error_msg = str(e)
        except Exception:
            session.discard()
            logger.critical(traceback.format_exc())
            status = 500

            # Don't give the details away to the outer world
            error_msg = "An internal error occured. Check the logs for details."

    if status != 200:
        return error(status, error_msg)

    return success(None)


# Read from the configuration: returnValue / returnValues / exists operate on
# the space-joined path; showConfig returns the subtree as parsed JSON
# ('json', the default), its AST ('json_ast') or the text as returned by the
# session ('raw').
@app.post("/retrieve")
def retrieve_op(data: RetrieveModel):
    session = app.state.vyos_session
    env = session.get_session_env()
    config = vyos.config.Config(session_env=env)

    op = data.op
    path = " ".join(data.path)

    try:
        if op == 'returnValue':
            res = config.return_value(path)
        elif op == 'returnValues':
            res = config.return_values(path)
        elif op == 'exists':
            res = config.exists(path)
        elif op == 'showConfig':
            config_format = data.configFormat or 'json'

            res = session.show_config(path=data.path)
            if config_format == 'json':
                config_tree = vyos.configtree.ConfigTree(res)
                res = json.loads(config_tree.to_json())
            elif config_format == 'json_ast':
                config_tree = vyos.configtree.ConfigTree(res)
                res = json.loads(config_tree.to_json_ast())
            elif config_format == 'raw':
                pass
            else:
                return error(400, "\"{0}\" is not a valid config format".format(config_format))
        else:
            return error(400, "\"{0}\" is not a valid operation".format(op))
    except ConfigSessionError as e:
        return error(400, str(e))
    except Exception:
        logger.critical(traceback.format_exc())
        return error(500, "An internal error occured. Check the logs for details.")

    return success(res)


# Save the configuration to a file (default /config/config.boot), or load a
# configuration file and commit it. 'load' requires the "file" field.
@app.post('/config-file')
def config_file_op(data: ConfigFileModel):
    session = app.state.vyos_session

    op = data.op

    try:
        if op == 'save':
            path = data.file if data.file else '/config/config.boot'
            res = session.save_config(path)
        elif op == 'load':
            if not data.file:
                return error(400, "Missing required field \"file\"")
            path = data.file
            res = session.migrate_and_load_config(path)
            # The response carries the result of the commit, not of the load.
            res = session.commit()
        else:
            return error(400, "\"{0}\" is not a valid operation".format(op))
    except ConfigSessionError as e:
        return error(400, str(e))
    except Exception:
        logger.critical(traceback.format_exc())
        return error(500, "An internal error occured. Check the logs for details.")

    return success(res)


# Install an image from "url" ('add') or remove the image called "name"
# ('delete'); the respective field is required.
@app.post('/image')
def image_op(data: ImageModel):
    session = app.state.vyos_session

    op = data.op

    try:
        if op == 'add':
            if not data.url:
                return error(400, "Missing required field \"url\"")
            res = session.install_image(data.url)
        elif op == 'delete':
            if not data.name:
                return error(400, "Missing required field \"name\"")
            res = session.remove_image(data.name)
        else:
            return error(400, "\"{0}\" is not a valid operation".format(op))
    except ConfigSessionError as e:
        return error(400, str(e))
    except Exception:
        logger.critical(traceback.format_exc())
        return error(500, "An internal error occured. Check the logs for details.")

    return success(res)


# Pass the path to session.generate(); 'generate' is the only accepted op.
@app.post('/generate')
def generate_op(data: GenerateModel):
    session = app.state.vyos_session

    op = data.op
    path = data.path

    try:
        if op == 'generate':
            res = session.generate(path)
        else:
            return error(400, "\"{0}\" is not a valid operation".format(op))
    except ConfigSessionError as e:
        return error(400, str(e))
    except Exception:
        logger.critical(traceback.format_exc())
        return error(500, "An internal error occured. Check the logs for details.")

    return success(res)


# Pass the path to session.show(); 'show' is the only accepted op.
@app.post('/show')
def show_op(data: ShowModel):
    session = app.state.vyos_session

    op = data.op
    path = data.path

    try:
        if op == 'show':
            res = session.show(path)
        else:
            return error(400, "\"{0}\" is not a valid operation".format(op))
    except ConfigSessionError as e:
        return error(400, str(e))
    except Exception:
        logger.critical(traceback.format_exc())
        return error(500, "An internal error occured. Check the logs for details.")

    return success(res)


###
# GraphQL integration
###

def graphql_init(fast_api_app):
    """Mount the GraphQL endpoint at /graphql on fast_api_app.

    Initialises api.graphql.state, stores the application in its settings,
    and wraps the GraphQL app in CORSMiddleware when
    fast_api_app.state.vyos_origins is non-empty.
    """
    # Imported here rather than at module level, so the bindings are only
    # loaded when GraphQL is actually enabled.
    from api.graphql.bindings import generate_schema

    api.graphql.state.init()
    api.graphql.state.settings['app'] = fast_api_app

    schema = generate_schema()

    origins = fast_api_app.state.vyos_origins
    if origins:
        fast_api_app.add_route('/graphql',
                               CORSMiddleware(GraphQL(schema, debug=True),
                                              allow_origins=origins,
                                              allow_methods=("GET", "POST", "OPTIONS")))
    else:
        fast_api_app.add_route('/graphql', GraphQL(schema, debug=True))

###

if __name__ == '__main__':
    # systemd's user and group options don't work, do it by hand here,
    # else no one else will be able to commit
    cfg_group = grp.getgrnam(CFG_GROUP)
    os.setgid(cfg_group.gr_gid)

    # Need to set file permissions to 775 too so that every vyattacfg group member
    # has write access to the running config
    os.umask(0o002)

    try:
        server_config = load_server_config()
    except Exception as err:
        logger.critical(f"Failed to load the HTTP API server config: {err}")
        # Nothing below can work without the settings.
        sys.exit(1)

    config_session = ConfigSession(os.getpid())

    app.state.vyos_session = config_session
    app.state.vyos_keys = server_config['api_keys']

    app.state.vyos_debug = server_config['debug']
    app.state.vyos_gql = server_config['gql']
    app.state.vyos_strict = server_config['strict']
    app.state.vyos_origins = server_config.get('cors', {}).get('origins', [])

    if app.state.vyos_gql:
        graphql_init(app)

    try:
        if not server_config['socket']:
            uvicorn.run(app, host=server_config["listen_address"],
                        port=int(server_config["port"]),
                        proxy_headers=True)
        else:
            uvicorn.run(app, uds="/run/api.sock",
                        proxy_headers=True)
    except OSError as err:
        logger.critical(f"OSError {err}")
        sys.exit(1)