summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--interface-definitions/https.xml.in13
-rwxr-xr-xsrc/conf_mode/http-api.py6
-rwxr-xr-xsrc/services/vyos-http-api-server18
3 files changed, 32 insertions, 5 deletions
diff --git a/interface-definitions/https.xml.in b/interface-definitions/https.xml.in
index 33e43a432..6fea2f1f6 100644
--- a/interface-definitions/https.xml.in
+++ b/interface-definitions/https.xml.in
@@ -107,6 +107,19 @@
<valueless/>
</properties>
</leafNode>
+ <node name="cors">
+ <properties>
+ <help>Set CORS options</help>
+ </properties>
+ <children>
+ <leafNode name="allow-origin">
+ <properties>
+ <help>Allow resource request from origin</help>
+ <multi/>
+ </properties>
+ </leafNode>
+ </children>
+ </node>
</children>
</node>
<node name="api-restrict">
diff --git a/src/conf_mode/http-api.py b/src/conf_mode/http-api.py
index cd0191599..ea0743cd5 100755
--- a/src/conf_mode/http-api.py
+++ b/src/conf_mode/http-api.py
@@ -67,6 +67,12 @@ def get_config(config=None):
port = conf.return_value('port')
http_api['port'] = port
+ if conf.exists('cors'):
+ http_api['cors'] = {}
+ if conf.exists('cors allow-origin'):
+ origins = conf.return_values('cors allow-origin')
+ http_api['cors']['origins'] = origins[:]
+
if conf.exists('keys'):
for name in conf.list_nodes('keys id'):
if conf.exists('keys id {0} key'.format(name)):
diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server
index f79058683..06871f1d6 100755
--- a/src/services/vyos-http-api-server
+++ b/src/services/vyos-http-api-server
@@ -32,6 +32,7 @@ from fastapi.responses import HTMLResponse
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
from pydantic import BaseModel, StrictStr, validator
+from starlette.middleware.cors import CORSMiddleware
from starlette.datastructures import FormData
from starlette.formparsers import FormParser, MultiPartParser
from multipart.multipart import parse_options_header
@@ -610,13 +611,19 @@ def show_op(data: ShowModel):
# GraphQL integration
###
-from api.graphql.bindings import generate_schema
+def graphql_init(fast_api_app):
+ from api.graphql.bindings import generate_schema
-api.graphql.state.init()
+ api.graphql.state.init()
+ api.graphql.state.settings['app'] = app
-schema = generate_schema()
+ schema = generate_schema()
-app.add_route('/graphql', GraphQL(schema, debug=True))
+ if app.state.vyos_origins:
+ origins = app.state.vyos_origins
+ app.add_route('/graphql', CORSMiddleware(GraphQL(schema, debug=True), allow_origins=origins, allow_methods=("GET", "POST", "OPTIONS")))
+ else:
+ app.add_route('/graphql', GraphQL(schema, debug=True))
###
@@ -642,8 +649,9 @@ if __name__ == '__main__':
app.state.vyos_debug = server_config['debug']
app.state.vyos_strict = server_config['strict']
+ app.state.vyos_origins = server_config.get('cors', {}).get('origins', [])
- api.graphql.state.settings['app'] = app
+ graphql_init(app)
try:
if not server_config['socket']: