diff options
| -rwxr-xr-x | src/op_mode/nat.py | 56 | 
1 files changed, 37 insertions, 19 deletions
| diff --git a/src/op_mode/nat.py b/src/op_mode/nat.py index 16a545cda..c6cf4770a 100755 --- a/src/op_mode/nat.py +++ b/src/op_mode/nat.py @@ -31,6 +31,7 @@ from vyos.utils.dict import dict_search  ArgDirection = typing.Literal['source', 'destination']  ArgFamily = typing.Literal['inet', 'inet6'] +  def _get_xml_translation(direction, family, address=None):      """      Get conntrack XML output --src-nat|--dst-nat @@ -99,22 +100,35 @@ def _get_raw_translation(direction, family, address=None):  def _get_formatted_output_rules(data, direction, family): -    def _get_ports_for_output(my_dict): -        # Get and insert all configured ports or port ranges into output string -        for index, port in enumerate(my_dict['set']): -            if 'range' in str(my_dict['set'][index]): -                output = my_dict['set'][index]['range'] -                output = '-'.join(map(str, output)) -            else: -                output = str(port) -            if index == 0: -                output = str(output) -            else: -                output = ','.join([output,output]) -        # Handle case where configured ports are a negated list -        if my_dict['op'] == '!=': -            output = '!' + output -        return(output) + + +    def _get_ports_for_output(rules): +        """ +        Return: string of configured ports +        """ +        ports = [] +        if 'set' in rules: +            for index, port in enumerate(rules['set']): +                if 'range' in str(rules['set'][index]): +                    output = rules['set'][index]['range'] +                    output = '-'.join(map(str, output)) +                else: +                    output = str(port) +                ports.append(output) +        # When NAT rule contains port range or single port +        # JSON will not contain keyword 'set' +        elif 'range' in rules: +            output = rules['range'] +            output = '-'.join(map(str, output)) +            ports.append(output) +        else: +            output = rules['right'] +            ports.append(str(output)) +        result = ','.join(ports) +        # Handle case where ports in NAT rule are negated +        if rules['op'] == '!=': +            result = '!' + result +        return(result)      # Add default values before loop      sport, dport, proto = 'any', 'any', 'any' @@ -132,7 +146,10 @@ def _get_formatted_output_rules(data, direction, family):                  if jmespath.search('rule.expr[*].match.left.meta', rule) else 'any'          for index, match in enumerate(jmespath.search('rule.expr[*].match', rule)):              if 'payload' in match['left']: -                if isinstance(match['right'], dict) and ('prefix' in match['right'] or 'set' in match['right']): +                # Handle NAT rule containing comma-seperated list of ports +                if (isinstance(match['right'], dict) and +                    ('prefix' in match['right'] or 'set' in match['right'] or +                     'range' in match['right'])):                      # Merge dict src/dst l3_l4 parameters                      my_dict = {**match['left']['payload'], **match['right']}                      my_dict['op'] = match['op'] @@ -146,6 +163,7 @@ def _get_formatted_output_rules(data, direction, family):                          sport = _get_ports_for_output(my_dict)                      elif my_dict['field'] == 'dport':                          dport = _get_ports_for_output(my_dict) +                # Handle NAT rule containing a single port                  else:                      field = jmespath.search('left.payload.field', match)                      if field == 'saddr': @@ -153,9 +171,9 @@ def _get_formatted_output_rules(data, direction, family):                      elif field == 'daddr':                          daddr = match.get('right')                      elif field == 'sport': -                        sport = match.get('right') +                        sport = _get_ports_for_output(match)                      elif field == 'dport': -                        dport = match.get('right') +                        dport = _get_ports_for_output(match)              else:                  saddr = '::/0' if family == 'inet6' else '0.0.0.0/0'                  daddr = '::/0' if family == 'inet6' else '0.0.0.0/0' | 
