diff options
-rw-r--r-- | data/templates/firewall/nftables-nat.tmpl | 7 | ||||
-rwxr-xr-x | smoketest/scripts/cli/test_nat.py | 109 |
2 files changed, 68 insertions, 48 deletions
diff --git a/data/templates/firewall/nftables-nat.tmpl b/data/templates/firewall/nftables-nat.tmpl index 922f3dcb4..7a925b264 100644 --- a/data/templates/firewall/nftables-nat.tmpl +++ b/data/templates/firewall/nftables-nat.tmpl @@ -138,8 +138,9 @@ {% endif %} {% endmacro %} -# Start with clean NAT table -flush table ip nat +# Start with clean SNAT and DNAT chains +flush chain ip nat PREROUTING +flush chain ip nat POSTROUTING {% if helper_functions is vyos_defined('remove') %} {# NAT if going to be disabled - remove rules and targets from nftables #} {% set base_command = 'delete rule ip raw' %} @@ -164,6 +165,7 @@ add rule ip raw NAT_CONNTRACK counter accept # # Destination NAT rules build up here # +add rule ip nat PREROUTING counter jump VYOS_PRE_DNAT_HOOK {% if destination.rule is vyos_defined %} {% for rule, config in destination.rule.items() if config.disable is not vyos_defined %} {{ nat_rule(rule, config, 'PREROUTING') }} @@ -172,6 +174,7 @@ add rule ip raw NAT_CONNTRACK counter accept # # Source NAT rules build up here # +add rule ip nat POSTROUTING counter jump VYOS_PRE_SNAT_HOOK {% if source.rule is vyos_defined %} {% for rule, config in source.rule.items() if config.disable is not vyos_defined %} {{ nat_rule(rule, config, 'POSTROUTING') }} diff --git a/smoketest/scripts/cli/test_nat.py b/smoketest/scripts/cli/test_nat.py index 2e1b8d431..1f8e4c307 100755 --- a/smoketest/scripts/cli/test_nat.py +++ b/smoketest/scripts/cli/test_nat.py @@ -59,36 +59,44 @@ class TestNAT(VyOSUnitTestSHIM.TestCase): self.cli_commit() - tmp = cmd('sudo nft -j list table nat') + tmp = cmd('sudo nft -j list chain ip nat POSTROUTING') data_json = jmespath.search('nftables[?rule].rule[?chain]', json.loads(tmp)) for idx in range(0, len(data_json)): - rule = str(rules[idx]) data = data_json[idx] - network = f'192.168.{rule}.0/24' - - self.assertEqual(data['chain'], 'POSTROUTING') - self.assertEqual(data['comment'], f'SRC-NAT-{rule}') - self.assertEqual(data['family'], 'ip') - self.assertEqual(data['table'], 'nat') + if idx == 0: + self.assertEqual(data['chain'], 'POSTROUTING') + self.assertEqual(data['family'], 'ip') + self.assertEqual(data['table'], 'nat') - iface = dict_search('match.right', data['expr'][0]) - direction = dict_search('match.left.payload.field', data['expr'][1]) - address = dict_search('match.right.prefix.addr', data['expr'][1]) - mask = dict_search('match.right.prefix.len', data['expr'][1]) - - if int(rule) < 200: - self.assertEqual(direction, 'saddr') - self.assertEqual(iface, outbound_iface_100) - # check for masquerade keyword - self.assertIn('masquerade', data['expr'][3]) + jump_target = dict_search('jump.target', data['expr'][1]) + self.assertEqual(jump_target,'VYOS_PRE_SNAT_HOOK') else: - self.assertEqual(direction, 'daddr') - self.assertEqual(iface, outbound_iface_200) - # check for return keyword due to 'exclude' - self.assertIn('return', data['expr'][3]) - - self.assertEqual(f'{address}/{mask}', network) + rule = str(rules[idx - 1]) + network = f'192.168.{rule}.0/24' + + self.assertEqual(data['chain'], 'POSTROUTING') + self.assertEqual(data['comment'], f'SRC-NAT-{rule}') + self.assertEqual(data['family'], 'ip') + self.assertEqual(data['table'], 'nat') + + iface = dict_search('match.right', data['expr'][0]) + direction = dict_search('match.left.payload.field', data['expr'][1]) + address = dict_search('match.right.prefix.addr', data['expr'][1]) + mask = dict_search('match.right.prefix.len', data['expr'][1]) + + if int(rule) < 200: + self.assertEqual(direction, 'saddr') + self.assertEqual(iface, outbound_iface_100) + # check for masquerade keyword + self.assertIn('masquerade', data['expr'][3]) + else: + self.assertEqual(direction, 'daddr') + self.assertEqual(iface, outbound_iface_200) + # check for return keyword due to 'exclude' + self.assertIn('return', data['expr'][3]) + + self.assertEqual(f'{address}/{mask}', network) def test_dnat(self): rules = ['100', '110', '120', '130', '200', '210', '220', '230'] @@ -111,33 +119,42 @@ class TestNAT(VyOSUnitTestSHIM.TestCase): self.cli_commit() - tmp = cmd('sudo nft -j list table nat') + tmp = cmd('sudo nft -j list chain ip nat PREROUTING') data_json = jmespath.search('nftables[?rule].rule[?chain]', json.loads(tmp)) for idx in range(0, len(data_json)): - rule = str(rules[idx]) data = data_json[idx] - port = int(f'10{rule}') - - self.assertEqual(data['chain'], 'PREROUTING') - self.assertEqual(data['comment'].split()[0], f'DST-NAT-{rule}') - self.assertEqual(data['family'], 'ip') - self.assertEqual(data['table'], 'nat') - - iface = dict_search('match.right', data['expr'][0]) - direction = dict_search('match.left.payload.field', data['expr'][1]) - protocol = dict_search('match.left.payload.protocol', data['expr'][1]) - dnat_addr = dict_search('dnat.addr', data['expr'][3]) - dnat_port = dict_search('dnat.port', data['expr'][3]) - - self.assertEqual(direction, 'sport') - self.assertEqual(dnat_addr, '192.0.2.1') - self.assertEqual(dnat_port, port) - if int(rule) < 200: - self.assertEqual(iface, inbound_iface_100) - self.assertEqual(protocol, inbound_proto_100) + if idx == 0: + self.assertEqual(data['chain'], 'PREROUTING') + self.assertEqual(data['family'], 'ip') + self.assertEqual(data['table'], 'nat') + + jump_target = dict_search('jump.target', data['expr'][1]) + self.assertEqual(jump_target,'VYOS_PRE_DNAT_HOOK') else: - self.assertEqual(iface, inbound_iface_200) + + rule = str(rules[idx - 1]) + port = int(f'10{rule}') + + self.assertEqual(data['chain'], 'PREROUTING') + self.assertEqual(data['comment'].split()[0], f'DST-NAT-{rule}') + self.assertEqual(data['family'], 'ip') + self.assertEqual(data['table'], 'nat') + + iface = dict_search('match.right', data['expr'][0]) + direction = dict_search('match.left.payload.field', data['expr'][1]) + protocol = dict_search('match.left.payload.protocol', data['expr'][1]) + dnat_addr = dict_search('dnat.addr', data['expr'][3]) + dnat_port = dict_search('dnat.port', data['expr'][3]) + + self.assertEqual(direction, 'sport') + self.assertEqual(dnat_addr, '192.0.2.1') + self.assertEqual(dnat_port, port) + if int(rule) < 200: + self.assertEqual(iface, inbound_iface_100) + self.assertEqual(protocol, inbound_proto_100) + else: + self.assertEqual(iface, inbound_iface_200) def test_snat_required_translation_address(self): # T2813: Ensure translation address is specified |