From f2f6b963b755ca5da3321e84738bfec1d08fb1ea Mon Sep 17 00:00:00 2001
From: John Estabrook <jestabro@vyos.io>
Date: Fri, 2 Jun 2023 23:29:47 -0500
Subject: xml: T5218: fix error and simplify logic in recursive option

---
 python/vyos/xml_ref/__init__.py   |   7 ++-
 python/vyos/xml_ref/definition.py | 103 ++++++++++++++++++++------------------
 2 files changed, 60 insertions(+), 50 deletions(-)

diff --git a/python/vyos/xml_ref/__init__.py b/python/vyos/xml_ref/__init__.py
index ae5184746..d3fb4ab07 100644
--- a/python/vyos/xml_ref/__init__.py
+++ b/python/vyos/xml_ref/__init__.py
@@ -62,5 +62,8 @@ def get_config_defaults(rpath: list, conf: dict, get_first_key=False,
                                               get_first_key=get_first_key,
                                               recursive=recursive)
 
-def merge_defaults(path: list, conf: dict) -> dict:
-    return load_reference().merge_defaults(path, conf)
+def merge_defaults(path: list, conf: dict, get_first_key=False,
+                   recursive=False) -> dict:
+    return load_reference().merge_defaults(path, conf,
+                                           get_first_key=get_first_key,
+                                           recursive=recursive)
diff --git a/python/vyos/xml_ref/definition.py b/python/vyos/xml_ref/definition.py
index 429331577..970dd915f 100644
--- a/python/vyos/xml_ref/definition.py
+++ b/python/vyos/xml_ref/definition.py
@@ -13,7 +13,7 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this library.  If not, see <http://www.gnu.org/licenses/>.
 
-from typing import Union, Any
+from typing import Optional, Union, Any
 from vyos.configdict import dict_merge
 
 class Xml:
@@ -116,9 +116,17 @@ class Xml:
 
         return res
 
-    def _get_default_value(self, node: dict):
+    def _get_default_value(self, node: dict) -> Optional[str]:
         return self._get_ref_node_data(node, "default_value")
 
+    def _get_default(self, node: dict) -> Optional[Union[str, list]]:
+        default = self._get_default_value(node)
+        if default is None:
+            return None
+        if self._is_multi_node(node) and not isinstance(default, list):
+            return [default]
+        return default
+
     def get_defaults(self, path: list, get_first_key=False, recursive=False) -> dict:
         """Return dict containing default values below path
 
@@ -128,18 +136,23 @@ class Xml:
         'relative_defaults'
         """
         res: dict = {}
+        if self.is_tag(path):
+            return res
+
         d = self._get_ref_path(path)
+
+        if self._is_leaf_node(d):
+            default_value = self._get_default(d)
+            if default_value is not None:
+                return {path[-1]: default_value} if path else {}
+
         for k in list(d):
             if k in ('node_data', 'component_version') :
                 continue
-            d_k = d[k]
-            if self._is_leaf_node(d_k):
-                default_value = self._get_default_value(d_k)
+            if self._is_leaf_node(d[k]):
+                default_value = self._get_default(d[k])
                 if default_value is not None:
-                    pos = default_value
-                    if self._is_multi_node(d_k) and not isinstance(pos, list):
-                        pos = [pos]
-                    res |= {k: pos}
+                    res |= {k: default_value}
             elif self.is_tag(path + [k]):
                 # tag node defaults are used as suggestion, not default value;
                 # should this change, append to path and continue if recursive
@@ -150,8 +163,6 @@ class Xml:
                     res |= pos
         if res:
             if get_first_key or not path:
-                if not isinstance(res, dict):
-                    raise TypeError("Cannot get_first_key as data under node is not of type dict")
                 return res
             return {path[-1]: res}
 
@@ -163,7 +174,7 @@ class Xml:
             return [next(iter(c.keys()))] if c else []
         try:
             tmp = step(conf)
-            if self.is_tag_value(path + tmp):
+            if tmp and self.is_tag_value(path + tmp):
                 c = conf[tmp[0]]
                 if not isinstance(c, dict):
                     raise ValueError
@@ -175,57 +186,53 @@ class Xml:
             return False
         return True
 
-    def relative_defaults(self, rpath: list, conf: dict, get_first_key=False,
+    def _relative_defaults(self, rpath: list, conf: dict, recursive=False) -> dict:
+        res: dict = {}
+        res = self.get_defaults(rpath, recursive=recursive,
+                                get_first_key=True)
+        for k in list(conf):
+            if isinstance(conf[k], dict):
+                step = self._relative_defaults(rpath + [k], conf=conf[k],
+                                               recursive=recursive)
+                res |= step
+
+        if res:
+            return {rpath[-1]: res} if rpath else res
+
+        return {}
+
+    def relative_defaults(self, path: list, conf: dict, get_first_key=False,
                           recursive=False) -> dict:
         """Return dict containing defaults along paths of a config dict
         """
         if not conf:
-            return self.get_defaults(rpath, get_first_key=get_first_key,
+            return self.get_defaults(path, get_first_key=get_first_key,
                                      recursive=recursive)
-        if rpath and rpath[-1] in list(conf):
-            conf = conf[rpath[-1]]
-            if not isinstance(conf, dict):
-                raise TypeError('conf at path is not of type dict')
+        if path and path[-1] in list(conf):
+            conf = conf[path[-1]]
+            conf = {} if not isinstance(conf, dict) else conf
 
-        if not self._well_defined(rpath, conf):
+        if not self._well_defined(path, conf):
             print('path to config dict does not define full config paths')
             return {}
 
-        res: dict = {}
-        for k in list(conf):
-            pos = self.get_defaults(rpath + [k], recursive=recursive)
-            res |= pos
-
-            if isinstance(conf[k], dict):
-                step = self.relative_defaults(rpath + [k], conf=conf[k],
-                                              recursive=recursive)
-                res |= step
+        res = self._relative_defaults(path, conf, recursive=recursive)
 
-        if res:
-            if get_first_key:
-                return res
-            return {rpath[-1]: res} if rpath else res
+        if get_first_key and path:
+            if res.values():
+                res = next(iter(res.values()))
+            else:
+                res = {}
 
-        return {}
+        return res
 
-    def merge_defaults(self, path: list, conf: dict) -> dict:
+    def merge_defaults(self, path: list, conf: dict, get_first_key=False,
+                       recursive=False) -> dict:
         """Return config dict with defaults non-destructively merged
 
         This merges non-recursive defaults relative to the config dict.
         """
-        if path[-1] in list(conf):
-            config = conf[path[-1]]
-            if not isinstance(config, dict):
-                raise TypeError('conf at path is not of type dict')
-            shift = False
-        else:
-            config = conf
-            shift = True
-
-        if not self._well_defined(path, config):
-            print('path to config dict does not define config paths; conf returned unchanged')
-            return conf
-
-        d = self.relative_defaults(path, conf=config, get_first_key=shift)
+        d = self.relative_defaults(path, conf, get_first_key=get_first_key,
+                                   recursive=recursive)
         d = dict_merge(d, conf)
         return d
-- 
cgit v1.2.3