From ff43733074675b94ce4ead83fe63870b6cf953c5 Mon Sep 17 00:00:00 2001
From: Viacheslav Hletenko <v.gletenko@vyos.io>
Date: Fri, 6 Oct 2023 09:18:35 +0000
Subject: T5165: Implement policy local-route source and destination port

Add `policy local-route` source and destination port

set policy local-route rule 23 destination port '222'
set policy local-route rule 23 protocol 'tcp'
set policy local-route rule 23 set table '123'
set policy local-route rule 23 source port '8888'

% ip rule show prio 23
23:	from all ipproto tcp sport 8888 dport 222 lookup 123
---
 interface-definitions/policy-local-route.xml.in |  5 ++
 smoketest/scripts/cli/test_policy.py            | 50 ++++++++++++++++
 src/conf_mode/policy-local-route.py             | 79 +++++++++++++++++++++++--
 3 files changed, 129 insertions(+), 5 deletions(-)

diff --git a/interface-definitions/policy-local-route.xml.in b/interface-definitions/policy-local-route.xml.in
index 6827bd64e..15be099c9 100644
--- a/interface-definitions/policy-local-route.xml.in
+++ b/interface-definitions/policy-local-route.xml.in
@@ -60,6 +60,7 @@
                 </properties>
                 <children>
                   #include <include/policy/local-route_rule_ipv4_address.xml.i>
+                  #include <include/port-number.xml.i>
                 </children>
               </node>
               <node name="destination">
@@ -68,6 +69,7 @@
                 </properties>
                 <children>
                   #include <include/policy/local-route_rule_ipv4_address.xml.i>
+                  #include <include/port-number.xml.i>
                 </children>
               </node>
               #include <include/interface/inbound-interface.xml.i>
@@ -125,12 +127,14 @@
                   </constraint>
                 </properties>
               </leafNode>
+              #include <include/policy/local-route_rule_protocol.xml.i>
               <node name="source">
                 <properties>
                   <help>Source parameters</help>
                 </properties>
                 <children>
                   #include <include/policy/local-route_rule_ipv6_address.xml.i>
+                  #include <include/port-number.xml.i>
                 </children>
               </node>
               <node name="destination">
@@ -139,6 +143,7 @@
                 </properties>
                 <children>
                   #include <include/policy/local-route_rule_ipv6_address.xml.i>
+                  #include <include/port-number.xml.i>
                 </children>
               </node>
               #include <include/interface/inbound-interface.xml.i>
diff --git a/smoketest/scripts/cli/test_policy.py b/smoketest/scripts/cli/test_policy.py
index 4ac422d5f..51a33f978 100755
--- a/smoketest/scripts/cli/test_policy.py
+++ b/smoketest/scripts/cli/test_policy.py
@@ -1541,6 +1541,56 @@ class TestPolicy(VyOSUnitTestSHIM.TestCase):
 
         self.assertEqual(sort_ip(tmp), sort_ip(original))
 
+    # Test set table for destination, source, protocol, fwmark and port
+    def test_protocol_port_address_fwmark_table_id(self):
+        path = base_path + ['local-route']
+
+        dst = '203.0.113.5'
+        src_list = ['203.0.113.1', '203.0.113.2']
+        rule = '23'
+        fwmark = '123456'
+        table = '123'
+        new_table = '111'
+        proto = 'udp'
+        new_proto = 'tcp'
+        src_port = '5555'
+        dst_port = '8888'
+
+        self.cli_set(path + ['rule', rule, 'set', 'table', table])
+        self.cli_set(path + ['rule', rule, 'destination', 'address', dst])
+        self.cli_set(path + ['rule', rule, 'source', 'port', src_port])
+        self.cli_set(path + ['rule', rule, 'protocol', proto])
+        self.cli_set(path + ['rule', rule, 'fwmark', fwmark])
+        self.cli_set(path + ['rule', rule, 'destination', 'port', dst_port])
+        for src in src_list:
+            self.cli_set(path + ['rule', rule, 'source', 'address', src])
+
+        self.cli_commit()
+
+        original = """
+        23:	from 203.0.113.1 to 203.0.113.5 fwmark 0x1e240 ipproto udp sport 5555 dport 8888 lookup 123
+        23:	from 203.0.113.2 to 203.0.113.5 fwmark 0x1e240 ipproto udp sport 5555 dport 8888 lookup 123
+        """
+        tmp = cmd(f'ip rule show prio {rule}')
+
+        self.assertEqual(sort_ip(tmp), sort_ip(original))
+
+        # Change table and protocol, delete fwmark and source port
+        self.cli_delete(path + ['rule', rule, 'fwmark'])
+        self.cli_delete(path + ['rule', rule, 'source', 'port'])
+        self.cli_set(path + ['rule', rule, 'set', 'table', new_table])
+        self.cli_set(path + ['rule', rule, 'protocol', new_proto])
+
+        self.cli_commit()
+
+        original = """
+        23:	from 203.0.113.1 to 203.0.113.5 ipproto tcp dport 8888 lookup 111
+        23:	from 203.0.113.2 to 203.0.113.5 ipproto tcp dport 8888 lookup 111
+        """
+        tmp = cmd(f'ip rule show prio {rule}')
+
+        self.assertEqual(sort_ip(tmp), sort_ip(original))
+
     # Test set table for sources with fwmark
     def test_fwmark_sources_table_id(self):
         path = base_path + ['local-route']
diff --git a/src/conf_mode/policy-local-route.py b/src/conf_mode/policy-local-route.py
index 2e8aabb80..91e4fce2c 100755
--- a/src/conf_mode/policy-local-route.py
+++ b/src/conf_mode/policy-local-route.py
@@ -52,19 +52,28 @@ def get_config(config=None):
         if tmp:
             for rule in (tmp or []):
                 src = leaf_node_changed(conf, base_rule + [rule, 'source', 'address'])
+                src_port = leaf_node_changed(conf, base_rule + [rule, 'source', 'port'])
                 fwmk = leaf_node_changed(conf, base_rule + [rule, 'fwmark'])
                 iif = leaf_node_changed(conf, base_rule + [rule, 'inbound-interface'])
                 dst = leaf_node_changed(conf, base_rule + [rule, 'destination', 'address'])
+                dst_port = leaf_node_changed(conf, base_rule + [rule, 'destination', 'port'])
+                table = leaf_node_changed(conf, base_rule + [rule, 'set', 'table'])
                 proto = leaf_node_changed(conf, base_rule + [rule, 'protocol'])
                 rule_def = {}
                 if src:
                     rule_def = dict_merge({'source': {'address': src}}, rule_def)
+                if src_port:
+                    rule_def = dict_merge({'source': {'port': src_port}}, rule_def)
                 if fwmk:
                     rule_def = dict_merge({'fwmark' : fwmk}, rule_def)
                 if iif:
                     rule_def = dict_merge({'inbound_interface' : iif}, rule_def)
                 if dst:
                     rule_def = dict_merge({'destination': {'address': dst}}, rule_def)
+                if dst_port:
+                    rule_def = dict_merge({'destination': {'port': dst_port}}, rule_def)
+                if table:
+                    rule_def = dict_merge({'table' : table}, rule_def)
                 if proto:
                     rule_def = dict_merge({'protocol' : proto}, rule_def)
                 dict = dict_merge({dict_id : {rule : rule_def}}, dict)
@@ -79,9 +88,12 @@ def get_config(config=None):
         if 'rule' in pbr[route]:
             for rule, rule_config in pbr[route]['rule'].items():
                 src = leaf_node_changed(conf, base_rule + [rule, 'source', 'address'])
+                src_port = leaf_node_changed(conf, base_rule + [rule, 'source', 'port'])
                 fwmk = leaf_node_changed(conf, base_rule + [rule, 'fwmark'])
                 iif = leaf_node_changed(conf, base_rule + [rule, 'inbound-interface'])
                 dst = leaf_node_changed(conf, base_rule + [rule, 'destination', 'address'])
+                dst_port = leaf_node_changed(conf, base_rule + [rule, 'destination', 'port'])
+                table = leaf_node_changed(conf, base_rule + [rule, 'set', 'table'])
                 proto = leaf_node_changed(conf, base_rule + [rule, 'protocol'])
                 # keep track of changes in configuration
                 # otherwise we might remove an existing node although nothing else has changed
@@ -105,14 +117,32 @@ def get_config(config=None):
                     if len(src) > 0:
                         rule_def = dict_merge({'source': {'address': src}}, rule_def)
 
+                # source port
+                if src_port is None:
+                    if 'source' in rule_config:
+                        if 'port' in rule_config['source']:
+                            tmp = rule_config['source']['port']
+                            if isinstance(tmp, str):
+                                tmp = [tmp]
+                            rule_def = dict_merge({'source': {'port': tmp}}, rule_def)
+                else:
+                    changed = True
+                    if len(src_port) > 0:
+                        rule_def = dict_merge({'source': {'port': src_port}}, rule_def)
+
+                # fwmark
                 if fwmk is None:
                     if 'fwmark' in rule_config:
-                        rule_def = dict_merge({'fwmark': rule_config['fwmark']}, rule_def)
+                        tmp = rule_config['fwmark']
+                        if isinstance(tmp, str):
+                            tmp = [tmp]
+                        rule_def = dict_merge({'fwmark': tmp}, rule_def)
                 else:
                     changed = True
                     if len(fwmk) > 0:
                         rule_def = dict_merge({'fwmark' : fwmk}, rule_def)
 
+                # inbound-interface
                 if iif is None:
                     if 'inbound_interface' in rule_config:
                         rule_def = dict_merge({'inbound_interface': rule_config['inbound_interface']}, rule_def)
@@ -121,6 +151,7 @@ def get_config(config=None):
                     if len(iif) > 0:
                         rule_def = dict_merge({'inbound_interface' : iif}, rule_def)
 
+                # destination address
                 if dst is None:
                     if 'destination' in rule_config:
                         if 'address' in rule_config['destination']:
@@ -130,9 +161,35 @@ def get_config(config=None):
                     if len(dst) > 0:
                         rule_def = dict_merge({'destination': {'address': dst}}, rule_def)
 
+                # destination port
+                if dst_port is None:
+                    if 'destination' in rule_config:
+                        if 'port' in rule_config['destination']:
+                            tmp = rule_config['destination']['port']
+                            if isinstance(tmp, str):
+                                tmp = [tmp]
+                            rule_def = dict_merge({'destination': {'port': tmp}}, rule_def)
+                else:
+                    changed = True
+                    if len(dst_port) > 0:
+                        rule_def = dict_merge({'destination': {'port': dst_port}}, rule_def)
+
+                # table
+                if table is None:
+                    if 'set' in rule_config and 'table' in rule_config['set']:
+                        rule_def = dict_merge({'table': [rule_config['set']['table']]}, rule_def)
+                else:
+                    changed = True
+                    if len(table) > 0:
+                        rule_def = dict_merge({'table' : table}, rule_def)
+
+                # protocol
                 if proto is None:
                     if 'protocol' in rule_config:
-                        rule_def = dict_merge({'protocol': rule_config['protocol']}, rule_def)
+                        tmp = rule_config['protocol']
+                        if isinstance(tmp, str):
+                            tmp = [tmp]
+                        rule_def = dict_merge({'protocol': tmp}, rule_def)
                 else:
                     changed = True
                     if len(proto) > 0:
@@ -192,19 +249,27 @@ def apply(pbr):
 
             for rule, rule_config in pbr[rule_rm].items():
                 source = rule_config.get('source', {}).get('address', [''])
+                source_port = rule_config.get('source', {}).get('port', [''])
                 destination = rule_config.get('destination', {}).get('address', [''])
+                destination_port = rule_config.get('destination', {}).get('port', [''])
                 fwmark = rule_config.get('fwmark', [''])
                 inbound_interface = rule_config.get('inbound_interface', [''])
                 protocol = rule_config.get('protocol', [''])
+                table = rule_config.get('table', [''])
 
-                for src, dst, fwmk, iif, proto in product(source, destination, fwmark, inbound_interface, protocol):
+                for src, dst, src_port, dst_port, fwmk, iif, proto, table in product(
+                        source, destination, source_port, destination_port,
+                        fwmark, inbound_interface, protocol, table):
                     f_src = '' if src == '' else f' from {src} '
+                    f_src_port = '' if src_port == '' else f' sport {src_port} '
                     f_dst = '' if dst == '' else f' to {dst} '
+                    f_dst_port = '' if dst_port == '' else f' dport {dst_port} '
                     f_fwmk = '' if fwmk == '' else f' fwmark {fwmk} '
                     f_iif = '' if iif == '' else f' iif {iif} '
                     f_proto = '' if proto == '' else f' ipproto {proto} '
+                    f_table = '' if table == '' else f' lookup {table} '
 
-                    call(f'ip{v6} rule del prio {rule} {f_src}{f_dst}{f_fwmk}{f_iif}')
+                    call(f'ip{v6} rule del prio {rule} {f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif}{f_table}')
 
     # Generate new config
     for route in ['local_route', 'local_route6']:
@@ -218,7 +283,9 @@ def apply(pbr):
             for rule, rule_config in pbr_route['rule'].items():
                 table = rule_config['set'].get('table', '')
                 source = rule_config.get('source', {}).get('address', ['all'])
+                source_port = rule_config.get('source', {}).get('port', '')
                 destination = rule_config.get('destination', {}).get('address', ['all'])
+                destination_port = rule_config.get('destination', {}).get('port', '')
                 fwmark = rule_config.get('fwmark', '')
                 inbound_interface = rule_config.get('inbound_interface', '')
                 protocol = rule_config.get('protocol', '')
@@ -227,11 +294,13 @@ def apply(pbr):
                     f_src = f' from {src} ' if src else ''
                     for dst in destination:
                         f_dst = f' to {dst} ' if dst else ''
+                        f_src_port = f' sport {source_port} ' if source_port else ''
+                        f_dst_port = f' dport {destination_port} ' if destination_port else ''
                         f_fwmk = f' fwmark {fwmark} ' if fwmark else ''
                         f_iif = f' iif {inbound_interface} ' if inbound_interface else ''
                         f_proto = f' ipproto {protocol} ' if protocol else ''
 
-                        call(f'ip{v6} rule add prio {rule}{f_src}{f_dst}{f_proto}{f_fwmk}{f_iif} lookup {table}')
+                        call(f'ip{v6} rule add prio {rule}{f_src}{f_dst}{f_proto}{f_src_port}{f_dst_port}{f_fwmk}{f_iif} lookup {table}')
 
     return None
 
-- 
cgit v1.2.3