summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rwxr-xr-xsrc/conf_mode/http-api.py6
-rwxr-xr-xsrc/services/vyos-http-api-server18
2 files changed, 19 insertions, 5 deletions
diff --git a/src/conf_mode/http-api.py b/src/conf_mode/http-api.py
index cd0191599..ea0743cd5 100755
--- a/src/conf_mode/http-api.py
+++ b/src/conf_mode/http-api.py
@@ -67,6 +67,12 @@ def get_config(config=None):
port = conf.return_value('port')
http_api['port'] = port
+ if conf.exists('cors'):
+ http_api['cors'] = {}
+ if conf.exists('cors allow-origin'):
+ origins = conf.return_values('cors allow-origin')
+ http_api['cors']['origins'] = origins[:]
+
if conf.exists('keys'):
for name in conf.list_nodes('keys id'):
if conf.exists('keys id {0} key'.format(name)):
diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server
index f79058683..06871f1d6 100755
--- a/src/services/vyos-http-api-server
+++ b/src/services/vyos-http-api-server
@@ -32,6 +32,7 @@ from fastapi.responses import HTMLResponse
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
from pydantic import BaseModel, StrictStr, validator
+from starlette.middleware.cors import CORSMiddleware
from starlette.datastructures import FormData
from starlette.formparsers import FormParser, MultiPartParser
from multipart.multipart import parse_options_header
@@ -610,13 +611,19 @@ def show_op(data: ShowModel):
# GraphQL integration
###
-from api.graphql.bindings import generate_schema
+def graphql_init(fast_api_app):
+ from api.graphql.bindings import generate_schema
-api.graphql.state.init()
+ api.graphql.state.init()
+ api.graphql.state.settings['app'] = app
-schema = generate_schema()
+ schema = generate_schema()
-app.add_route('/graphql', GraphQL(schema, debug=True))
+ if app.state.vyos_origins:
+ origins = app.state.vyos_origins
+ app.add_route('/graphql', CORSMiddleware(GraphQL(schema, debug=True), allow_origins=origins, allow_methods=("GET", "POST", "OPTIONS")))
+ else:
+ app.add_route('/graphql', GraphQL(schema, debug=True))
###
@@ -642,8 +649,9 @@ if __name__ == '__main__':
app.state.vyos_debug = server_config['debug']
app.state.vyos_strict = server_config['strict']
+ app.state.vyos_origins = server_config.get('cors', {}).get('origins', [])
- api.graphql.state.settings['app'] = app
+ graphql_init(app)
try:
if not server_config['socket']: