diff options
-rw-r--r-- | interface-definitions/include/firewall/geoip.xml.i | 6 | ||||
-rw-r--r-- | python/vyos/firewall.py | 24 | ||||
-rwxr-xr-x | smoketest/scripts/cli/test_firewall.py | 22 |
3 files changed, 35 insertions, 17 deletions
diff --git a/interface-definitions/include/firewall/geoip.xml.i b/interface-definitions/include/firewall/geoip.xml.i index f6208f718..9fb37a574 100644 --- a/interface-definitions/include/firewall/geoip.xml.i +++ b/interface-definitions/include/firewall/geoip.xml.i @@ -17,6 +17,12 @@ <multi /> </properties> </leafNode> + <leafNode name="inverse-match"> + <properties> + <help>Inverse match of country-codes</help> + <valueless/> + </properties> + </leafNode> </children> </node> <!-- include end --> diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 7d1278d0e..3e2de4c3f 100644 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -152,7 +152,10 @@ def parse_rule(rule_conf, fw_name, rule_id, ip_name): output.append(f'{ip_name} {prefix}addr {suffix}') if dict_search_args(side_conf, 'geoip', 'country_code'): - output.append(f'{ip_name} {prefix}addr @GEOIP_CC_{fw_name}_{rule_id}') + operator = '' + if dict_search_args(side_conf, 'geoip', 'inverse_match') != None: + operator = '!=' + output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC_{fw_name}_{rule_id}') if 'mac_address' in side_conf: suffix = side_conf["mac_address"] @@ -429,22 +432,13 @@ def geoip_update(firewall, force=False): # Map country codes to set names for codes, path in dict_search_recursive(firewall, 'country_code'): + set_name = f'GEOIP_CC_{path[1]}_{path[3]}' if path[0] == 'name': - set_name = f'GEOIP_CC_{path[1]}_{path[3]}' - ipv4_sets[set_name] = [] for code in codes: - if code not in ipv4_codes: - ipv4_codes[code] = [set_name] - else: - ipv4_codes[code].append(set_n) + ipv4_codes.setdefault(code, []).append(set_name) elif path[0] == 'ipv6_name': - set_name = f'GEOIP_CC_{path[1]}_{path[3]}' - ipv6_sets[set_name] = [] for code in codes: - if code not in ipv6_codes: - ipv6_codes[code] = [set_name] - else: - ipv6_codes[code].append(set_name) + ipv6_codes.setdefault(code, []).append(set_name) if not ipv4_codes and not ipv6_codes: if force: @@ -459,11 +453,11 @@ def geoip_update(firewall, force=False): if code in ipv4_codes and ipv4: ip_range = f'{start}-{end}' if start != end else start for setname in ipv4_codes[code]: - ipv4_sets[setname].append(ip_range) + ipv4_sets.setdefault(setname, []).append(ip_range) if code in ipv6_codes and not ipv4: ip_range = f'{start}-{end}' if start != end else start for setname in ipv6_codes[code]: - ipv6_sets[setname].append(ip_range) + ipv6_sets.setdefault(setname, []).append(ip_range) render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', { 'ipv4_sets': ipv4_sets, diff --git a/smoketest/scripts/cli/test_firewall.py b/smoketest/scripts/cli/test_firewall.py index ce06b9074..4de90e1ec 100755 --- a/smoketest/scripts/cli/test_firewall.py +++ b/smoketest/scripts/cli/test_firewall.py @@ -69,8 +69,8 @@ class TestFirewall(VyOSUnitTestSHIM.TestCase): self.verify_nftables(nftables_search, 'ip filter', inverse=True) - def verify_nftables(self, nftables_search, table, inverse=False): - nftables_output = cmd(f'sudo nft list table {table}') + def verify_nftables(self, nftables_search, table, inverse=False, args=''): + nftables_output = cmd(f'sudo nft {args} list table {table}') for search in nftables_search: matched = False @@ -80,6 +80,24 @@ class TestFirewall(VyOSUnitTestSHIM.TestCase): break self.assertTrue(not matched if inverse else matched, msg=search) + def test_geoip(self): + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '1', 'action', 'drop']) + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '1', 'source', 'geoip', 'country-code', 'se']) + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '1', 'source', 'geoip', 'country-code', 'gb']) + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '2', 'action', 'accept']) + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '2', 'source', 'geoip', 'country-code', 'de']) + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '2', 'source', 'geoip', 'country-code', 'fr']) + self.cli_set(['firewall', 'name', 'smoketest', 'rule', '2', 'source', 'geoip', 'inverse-match']) + + self.cli_commit() + + nftables_search = [ + ['ip saddr @GEOIP_CC_smoketest_1', 'drop'], + ['ip saddr != @GEOIP_CC_smoketest_2', 'return'] + ] + # -t prevents 1000+ GeoIP elements being returned + self.verify_nftables(nftables_search, 'ip filter', args='-t') + def test_groups(self): hostmap_path = ['system', 'static-host-mapping', 'host-name'] example_org = ['192.0.2.8', '192.0.2.10', '192.0.2.11'] |