From 701df0b70a8979249232d5ef60e86601c295098d Mon Sep 17 00:00:00 2001
From: John Estabrook <jestabro@vyos.io>
Date: Sat, 10 Jun 2023 16:45:30 -0500
Subject: http-api: T5263: simplify form errors

---
 src/services/vyos-http-api-server | 108 +++++++++++++++-----------------------
 1 file changed, 42 insertions(+), 66 deletions(-)

(limited to 'src')

diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server
index 89c685f32..dda137943 100755
--- a/src/services/vyos-http-api-server
+++ b/src/services/vyos-http-api-server
@@ -260,18 +260,15 @@ def auth_required(data: ApiModel):
 # the explicit validation may be dropped, if desired, in favor of native
 # validation by FastAPI/Pydantic, as is used for application/json requests
 class MultipartRequest(Request):
-    ERR_MISSING_KEY = False
-    ERR_MISSING_DATA = False
-    ERR_NOT_JSON = False
-    ERR_NOT_DICT = False
-    ERR_NO_OP = False
-    ERR_NO_PATH = False
-    ERR_EMPTY_PATH = False
-    ERR_PATH_NOT_LIST = False
-    ERR_VALUE_NOT_STRING = False
-    ERR_PATH_NOT_LIST_OF_STR = False
-    offending_command = {}
-    exception = None
+    _form_err = ()
+    @property
+    def form_err(self):
+        return self._form_err
+
+    @form_err.setter
+    def form_err(self, val):
+        if not self._form_err:
+            self._form_err = val
 
     @property
     def orig_headers(self):
@@ -310,19 +307,20 @@ class MultipartRequest(Request):
 
             form_data = await self.form()
             if form_data:
+                endpoint = self.url.path
                 logger.debug("processing form data")
                 for k, v in form_data.multi_items():
                     forms[k] = v
 
                 if 'data' not in forms:
-                    self.ERR_MISSING_DATA = True
+                    self.form_err = (422, "Non-empty data field is required")
+                    return self._body
                 else:
                     try:
                         tmp = json.loads(forms['data'])
                     except json.JSONDecodeError as e:
-                        self.ERR_NOT_JSON = True
-                        self.exception = e
-                        tmp = {}
+                        self.form_err = (400, f'Failed to parse JSON: {e}')
+                        return self._body
                     if isinstance(tmp, list):
                         merge['commands'] = tmp
                     else:
@@ -336,29 +334,33 @@ class MultipartRequest(Request):
 
                 for c in cmds:
                     if not isinstance(c, dict):
-                        self.ERR_NOT_DICT = True
-                        self.offending_command = c
-                    elif 'op' not in c:
-                        self.ERR_NO_OP = True
-                        self.offending_command = c
-                    elif 'path' not in c:
-                        self.ERR_NO_PATH = True
-                        self.offending_command = c
-                    elif not c['path']:
-                        self.ERR_EMPTY_PATH = True
-                        self.offending_command = c
-                    elif not isinstance(c['path'], list):
-                        self.ERR_PATH_NOT_LIST = True
-                        self.offending_command = c
-                    elif not all(isinstance(el, str) for el in c['path']):
-                        self.ERR_PATH_NOT_LIST_OF_STR = True
-                        self.offending_command = c
-                    elif 'value' in c and not isinstance(c['value'], str):
-                        self.ERR_VALUE_NOT_STRING = True
-                        self.offending_command = c
+                        self.form_err = (400,
+                        f"Malformed command '{c}': any command must be JSON of dict")
+                        return self._body
+                    if 'op' not in c:
+                        self.form_err = (400,
+                        f"Malformed command '{c}': missing 'op' field")
+                    if endpoint not in ('/config-file', '/container-image',
+                                        '/image'):
+                        if 'path' not in c:
+                            self.form_err = (400,
+                            f"Malformed command '{c}': missing 'path' field")
+                        elif not isinstance(c['path'], list):
+                            self.form_err = (400,
+                            f"Malformed command '{c}': 'path' field must be a list")
+                        elif not all(isinstance(el, str) for el in c['path']):
+                            self.form_err = (400,
+                            f"Malformed command '{0}': 'path' field must be a list of strings")
+                    if endpoint in ('/configure'):
+                        if not c['path']:
+                            self.form_err = (400,
+                            f"Malformed command '{c}': 'path' list must be non-empty")
+                        if 'value' in c and not isinstance(c['value'], str):
+                            self.form_err = (400,
+                            f"Malformed command '{c}': 'value' field must be a string")
 
                 if 'key' not in forms and 'key' not in merge:
-                    self.ERR_MISSING_KEY = True
+                    self.form_err = (401, "Valid API key is required")
                 if 'key' in forms and 'key' not in merge:
                     merge['key'] = forms['key']
 
@@ -374,40 +376,14 @@ class MultipartRoute(APIRoute):
 
         async def custom_route_handler(request: Request) -> Response:
             request = MultipartRequest(request.scope, request.receive)
-            endpoint = request.url.path
             try:
                 response: Response = await original_route_handler(request)
             except HTTPException as e:
                 return error(e.status_code, e.detail)
             except Exception as e:
-                if request.ERR_MISSING_KEY:
-                    return error(401, "Valid API key is required")
-                if request.ERR_MISSING_DATA:
-                    return error(422, "Non-empty data field is required")
-                if request.ERR_NOT_JSON:
-                    return error(400, "Failed to parse JSON: {0}".format(request.exception))
-                if endpoint == '/configure':
-                    if request.ERR_NOT_DICT:
-                        return error(400, "Malformed command \"{0}\": any command must be a dict".format(json.dumps(request.offending_command)))
-                    if request.ERR_NO_OP:
-                        return error(400, "Malformed command \"{0}\": missing \"op\" field".format(json.dumps(request.offending_command)))
-                    if request.ERR_NO_PATH:
-                        return error(400, "Malformed command \"{0}\": missing \"path\" field".format(json.dumps(request.offending_command)))
-                    if request.ERR_EMPTY_PATH:
-                        return error(400, "Malformed command \"{0}\": empty path".format(json.dumps(request.offending_command)))
-                    if request.ERR_PATH_NOT_LIST:
-                        return error(400, "Malformed command \"{0}\": \"path\" field must be a list".format(json.dumps(request.offending_command)))
-                    if request.ERR_VALUE_NOT_STRING:
-                        return error(400, "Malformed command \"{0}\": \"value\" field must be a string".format(json.dumps(request.offending_command)))
-                    if request.ERR_PATH_NOT_LIST_OF_STR:
-                        return error(400, "Malformed command \"{0}\": \"path\" field must be a list of strings".format(json.dumps(request.offending_command)))
-                if endpoint in ('/retrieve','/generate','/show','/reset'):
-                    if request.ERR_NO_OP or request.ERR_NO_PATH:
-                        return error(400, "Missing required field. \"op\" and \"path\" fields are required")
-                if endpoint in ('/config-file', '/image', '/container-image'):
-                    if request.ERR_NO_OP:
-                        return error(400, "Missing required field \"op\"")
-
+                form_err = request.form_err
+                if form_err:
+                    return error(*form_err)
                 raise e
 
             return response
-- 
cgit v1.2.3