From 2f368dd235ef100eb3731871cbe88d11290122ec Mon Sep 17 00:00:00 2001
From: Giggum <152240782+Giggum@users.noreply.github.com>
Date: Wed, 3 Jul 2024 21:44:17 -0400
Subject: op-mode: T6371: fix output of NAT rules with single port range

---
 src/op_mode/nat.py | 56 ++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 19 deletions(-)

(limited to 'src')

diff --git a/src/op_mode/nat.py b/src/op_mode/nat.py
index 16a545cda..c6cf4770a 100755
--- a/src/op_mode/nat.py
+++ b/src/op_mode/nat.py
@@ -31,6 +31,7 @@ from vyos.utils.dict import dict_search
 ArgDirection = typing.Literal['source', 'destination']
 ArgFamily = typing.Literal['inet', 'inet6']
 
+
 def _get_xml_translation(direction, family, address=None):
     """
     Get conntrack XML output --src-nat|--dst-nat
@@ -99,22 +100,35 @@ def _get_raw_translation(direction, family, address=None):
 
 
 def _get_formatted_output_rules(data, direction, family):
-    def _get_ports_for_output(my_dict):
-        # Get and insert all configured ports or port ranges into output string
-        for index, port in enumerate(my_dict['set']):
-            if 'range' in str(my_dict['set'][index]):
-                output = my_dict['set'][index]['range']
-                output = '-'.join(map(str, output))
-            else:
-                output = str(port)
-            if index == 0:
-                output = str(output)
-            else:
-                output = ','.join([output,output])
-        # Handle case where configured ports are a negated list
-        if my_dict['op'] == '!=':
-            output = '!' + output
-        return(output)
+
+
+    def _get_ports_for_output(rules):
+        """
+        Return: string of configured ports
+        """
+        ports = []
+        if 'set' in rules:
+            for index, port in enumerate(rules['set']):
+                if 'range' in str(rules['set'][index]):
+                    output = rules['set'][index]['range']
+                    output = '-'.join(map(str, output))
+                else:
+                    output = str(port)
+                ports.append(output)
+        # When NAT rule contains port range or single port
+        # JSON will not contain keyword 'set'
+        elif 'range' in rules:
+            output = rules['range']
+            output = '-'.join(map(str, output))
+            ports.append(output)
+        else:
+            output = rules['right']
+            ports.append(str(output))
+        result = ','.join(ports)
+        # Handle case where ports in NAT rule are negated
+        if rules['op'] == '!=':
+            result = '!' + result
+        return(result)
 
     # Add default values before loop
     sport, dport, proto = 'any', 'any', 'any'
@@ -132,7 +146,10 @@ def _get_formatted_output_rules(data, direction, family):
                 if jmespath.search('rule.expr[*].match.left.meta', rule) else 'any'
         for index, match in enumerate(jmespath.search('rule.expr[*].match', rule)):
             if 'payload' in match['left']:
-                if isinstance(match['right'], dict) and ('prefix' in match['right'] or 'set' in match['right']):
+                # Handle NAT rule containing comma-seperated list of ports
+                if (isinstance(match['right'], dict) and
+                    ('prefix' in match['right'] or 'set' in match['right'] or
+                     'range' in match['right'])):
                     # Merge dict src/dst l3_l4 parameters
                     my_dict = {**match['left']['payload'], **match['right']}
                     my_dict['op'] = match['op']
@@ -146,6 +163,7 @@ def _get_formatted_output_rules(data, direction, family):
                         sport = _get_ports_for_output(my_dict)
                     elif my_dict['field'] == 'dport':
                         dport = _get_ports_for_output(my_dict)
+                # Handle NAT rule containing a single port
                 else:
                     field = jmespath.search('left.payload.field', match)
                     if field == 'saddr':
@@ -153,9 +171,9 @@ def _get_formatted_output_rules(data, direction, family):
                     elif field == 'daddr':
                         daddr = match.get('right')
                     elif field == 'sport':
-                        sport = match.get('right')
+                        sport = _get_ports_for_output(match)
                     elif field == 'dport':
-                        dport = match.get('right')
+                        dport = _get_ports_for_output(match)
             else:
                 saddr = '::/0' if family == 'inet6' else '0.0.0.0/0'
                 daddr = '::/0' if family == 'inet6' else '0.0.0.0/0'
-- 
cgit v1.2.3