From 1b3350788ceeace52e2d693a18d92d82464220c0 Mon Sep 17 00:00:00 2001
From: Christian Breunig <christian@breunig.cc>
Date: Sat, 20 Jul 2024 10:35:44 +0200
Subject: interfaces: T6592: moving an interface between VRF instances failed

To reproduce:

    set vrf name mgmt table '150'
    set vrf name no-mgmt table '151'
    set interfaces ethernet eth2 vrf 'mgmt'
    commit

    set interfaces ethernet eth2 vrf no-mgmt
    commit

This resulted in an error while interacting with nftables:
[Errno 1] failed to run command: nft add element inet vrf_zones ct_iface_map { "eth2" : 151 }

The reason is that the old mapping entry still exists and was not removed.

This commit adds a new utility function get_vrf_tableid() and compares the
current and new VRF table IDs assigned to an interface. If the IDs do not
match, the nftables ct_iface_map entry is removed before the new entry is added.

(cherry picked from commit 452068ce78581bb6fba2df4dba197e95b9aeb33d)
---
 python/vyos/ifconfig/interface.py             | 28 ++++++++++------
 python/vyos/utils/network.py                  | 23 +++++++++++---
 smoketest/scripts/cli/base_interfaces_test.py | 46 +++++++++++++++++++++++++++
 3 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/python/vyos/ifconfig/interface.py b/python/vyos/ifconfig/interface.py
index 117479ade..748830004 100644
--- a/python/vyos/ifconfig/interface.py
+++ b/python/vyos/ifconfig/interface.py
@@ -37,6 +37,7 @@ from vyos.utils.network import mac2eui64
 from vyos.utils.dict import dict_search
 from vyos.utils.network import get_interface_config
 from vyos.utils.network import get_interface_namespace
+from vyos.utils.network import get_vrf_tableid
 from vyos.utils.network import is_netns_interface
 from vyos.utils.process import is_systemd_service_active
 from vyos.utils.process import run
@@ -402,7 +403,7 @@ class Interface(Control):
         if netns: cmd = f'ip netns exec {netns} {cmd}'
         return self._cmd(cmd)
 
-    def _set_vrf_ct_zone(self, vrf):
+    def _set_vrf_ct_zone(self, vrf, old_vrf_tableid=None):
         """
         Add/Remove rules in nftables to associate traffic in VRF to an
         individual conntack zone
@@ -411,20 +412,27 @@ class Interface(Control):
         if 'netns' in self.config:
             return None
 
+        def nft_check_and_run(nft_command):
+            # Check if deleting is possible first to avoid raising errors
+            _, err = self._popen(f'nft --check {nft_command}')
+            if not err:
+                # Remove map element
+                self._cmd(f'nft {nft_command}')
+
         if vrf:
             # Get routing table ID for VRF
-            vrf_table_id = get_interface_config(vrf).get('linkinfo', {}).get(
-                'info_data', {}).get('table')
+            vrf_table_id = get_vrf_tableid(vrf)
             # Add map element with interface and zone ID
             if vrf_table_id:
+                # delete old table ID from nftables if it has changed, e.g. interface moved to a different VRF
+                if old_vrf_tableid and old_vrf_tableid != int(vrf_table_id):
+                    nft_del_element = f'delete element inet vrf_zones ct_iface_map {{ "{self.ifname}" }}'
+                    nft_check_and_run(nft_del_element)
+
                 self._cmd(f'nft add element inet vrf_zones ct_iface_map {{ "{self.ifname}" : {vrf_table_id} }}')
         else:
             nft_del_element = f'delete element inet vrf_zones ct_iface_map {{ "{self.ifname}" }}'
-            # Check if deleting is possible first to avoid raising errors
-            _, err = self._popen(f'nft --check {nft_del_element}')
-            if not err:
-                # Remove map element
-                self._cmd(f'nft {nft_del_element}')
+            nft_check_and_run(nft_del_element)
 
     def get_min_mtu(self):
         """
@@ -601,8 +609,10 @@ class Interface(Control):
         if tmp == vrf:
             return False
 
+        # Get current VRF table ID
+        old_vrf_tableid = get_vrf_tableid(self.ifname)
         self.set_interface('vrf', vrf)
-        self._set_vrf_ct_zone(vrf)
+        self._set_vrf_ct_zone(vrf, old_vrf_tableid)
         return True
 
     def set_arp_cache_tmo(self, tmo):
diff --git a/python/vyos/utils/network.py b/python/vyos/utils/network.py
index 829124b57..8406a5638 100644
--- a/python/vyos/utils/network.py
+++ b/python/vyos/utils/network.py
@@ -83,6 +83,19 @@ def get_interface_vrf(interface):
         return tmp['master']
     return 'default'
 
+def get_vrf_tableid(interface: str):
+    """ Return VRF table ID for given interface name or None """
+    from vyos.utils.dict import dict_search
+    table = None
+    tmp = get_interface_config(interface)
+    # Check if we are "the" VRF interface
+    if dict_search('linkinfo.info_kind', tmp) == 'vrf':
+        table = tmp['linkinfo']['info_data']['table']
+    # or an interface bound to a VRF
+    elif dict_search('linkinfo.info_slave_kind', tmp) == 'vrf':
+        table = tmp['linkinfo']['info_slave_data']['table']
+    return table
+
 def get_interface_config(interface):
     """ Returns the used encapsulation protocol for given interface.
         If interface does not exist, None is returned.
@@ -537,21 +550,21 @@ def ipv6_prefix_length(low, high):
         return None
 
     xor = bytearray(a ^ b for a, b in zip(lo, hi))
-        
+
     plen = 0
     while plen < 128 and xor[plen // 8] == 0:
         plen += 8
-        
+
     if plen == 128:
         return plen
-    
+
     for i in range((plen // 8) + 1, 16):
         if xor[i] != 0:
             return None
-    
+
     for i in range(8):
         msk = ~xor[plen // 8] & 0xff
-        
+
         if msk == bytemasks[i]:
             return plen + i + 1
 
diff --git a/smoketest/scripts/cli/base_interfaces_test.py b/smoketest/scripts/cli/base_interfaces_test.py
index 9be2c2f1a..4072fd5c2 100644
--- a/smoketest/scripts/cli/base_interfaces_test.py
+++ b/smoketest/scripts/cli/base_interfaces_test.py
@@ -28,6 +28,7 @@ from vyos.utils.dict import dict_search
 from vyos.utils.process import process_named_running
 from vyos.utils.network import get_interface_config
 from vyos.utils.network import get_interface_vrf
+from vyos.utils.network import get_vrf_tableid
 from vyos.utils.process import cmd
 from vyos.utils.network import is_intf_addr_assigned
 from vyos.utils.network import is_ipv6_link_local
@@ -257,6 +258,51 @@ class BasicInterfaceTest:
 
             self.cli_delete(['vrf', 'name', vrf_name])
 
+        def test_move_interface_between_vrf_instances(self):
+            if not self._test_vrf:
+                self.skipTest('not supported')
+
+            vrf1_name = 'smoketest_mgmt1'
+            vrf1_table = '5424'
+            vrf2_name = 'smoketest_mgmt2'
+            vrf2_table = '7412'
+
+            self.cli_set(['vrf', 'name', vrf1_name, 'table', vrf1_table])
+            self.cli_set(['vrf', 'name', vrf2_name, 'table', vrf2_table])
+
+            # move interface into first VRF
+            for interface in self._interfaces:
+                for option in self._options.get(interface, []):
+                    self.cli_set(self._base_path + [interface] + option.split())
+                self.cli_set(self._base_path + [interface, 'vrf', vrf1_name])
+
+            self.cli_commit()
+
+            # check that interface belongs to proper VRF
+            for interface in self._interfaces:
+                tmp = get_interface_vrf(interface)
+                self.assertEqual(tmp, vrf1_name)
+
+                tmp = get_interface_config(vrf1_name)
+                self.assertEqual(int(vrf1_table), get_vrf_tableid(interface))
+
+            # move interface into second VRF
+            for interface in self._interfaces:
+                self.cli_set(self._base_path + [interface, 'vrf', vrf2_name])
+
+            self.cli_commit()
+
+            # check that interface belongs to proper VRF
+            for interface in self._interfaces:
+                tmp = get_interface_vrf(interface)
+                self.assertEqual(tmp, vrf2_name)
+
+                tmp = get_interface_config(vrf2_name)
+                self.assertEqual(int(vrf2_table), get_vrf_tableid(interface))
+
+            self.cli_delete(['vrf', 'name', vrf1_name])
+            self.cli_delete(['vrf', 'name', vrf2_name])
+
         def test_span_mirror(self):
             if not self._mirror_interfaces:
                 self.skipTest('not supported')
-- 
cgit v1.2.3