diff options
-rw-r--r-- | data/templates/firewall/nftables-geoip-update.j2 | 33 | ||||
-rw-r--r-- | data/templates/firewall/nftables-policy.j2 | 17 | ||||
-rw-r--r-- | interface-definitions/policy_route.xml.in | 4 | ||||
-rwxr-xr-x | python/vyos/firewall.py | 67 | ||||
-rwxr-xr-x | smoketest/scripts/cli/test_policy_route.py | 34 | ||||
-rwxr-xr-x | src/conf_mode/firewall.py | 2 | ||||
-rwxr-xr-x | src/conf_mode/policy_route.py | 47 | ||||
-rwxr-xr-x | src/helpers/geoip-update.py | 17 |
8 files changed, 193 insertions, 28 deletions
diff --git a/data/templates/firewall/nftables-geoip-update.j2 b/data/templates/firewall/nftables-geoip-update.j2 index 832ccc3e9..d8f80d1f5 100644 --- a/data/templates/firewall/nftables-geoip-update.j2 +++ b/data/templates/firewall/nftables-geoip-update.j2 @@ -31,3 +31,36 @@ table ip6 vyos_filter { {% endfor %} } {% endif %} + + +{% if ipv4_sets_policy is vyos_defined %} +{% for setname, ip_list in ipv4_sets_policy.items() %} +flush set ip vyos_mangle {{ setname }} +{% endfor %} + +table ip vyos_mangle { +{% for setname, ip_list in ipv4_sets_policy.items() %} + set {{ setname }} { + type ipv4_addr + flags interval + elements = { {{ ','.join(ip_list) }} } + } +{% endfor %} +} +{% endif %} + +{% if ipv6_sets_policy is vyos_defined %} +{% for setname, ip_list in ipv6_sets_policy.items() %} +flush set ip6 vyos_mangle {{ setname }} +{% endfor %} + +table ip6 vyos_mangle { +{% for setname, ip_list in ipv6_sets_policy.items() %} + set {{ setname }} { + type ipv6_addr + flags interval + elements = { {{ ','.join(ip_list) }} } + } +{% endfor %} +} +{% endif %} diff --git a/data/templates/firewall/nftables-policy.j2 b/data/templates/firewall/nftables-policy.j2 index 9e28899b0..00d0e8a62 100644 --- a/data/templates/firewall/nftables-policy.j2 +++ b/data/templates/firewall/nftables-policy.j2 @@ -33,6 +33,15 @@ table ip vyos_mangle { {% endif %} } {% endfor %} + +{% if geoip_updated.name is vyos_defined %} +{% for setname in geoip_updated.name %} + set {{ setname }} { + type ipv4_addr + flags interval + } +{% endfor %} +{% endif %} {% endif %} {{ group_tmpl.groups(firewall_group, False, True) }} @@ -65,6 +74,14 @@ table ip6 vyos_mangle { {% endif %} } {% endfor %} +{% if geoip_updated.ipv6_name is vyos_defined %} +{% for setname in geoip_updated.ipv6_name %} + set {{ setname }} { + type ipv6_addr + flags interval + } +{% endfor %} +{% endif %} {% endif %} {{ group_tmpl.groups(firewall_group, True, True) }} diff --git a/interface-definitions/policy_route.xml.in b/interface-definitions/policy_route.xml.in index 9cc22540b..48f728923 100644 --- a/interface-definitions/policy_route.xml.in +++ b/interface-definitions/policy_route.xml.in @@ -35,6 +35,7 @@ #include <include/firewall/address-ipv6.xml.i> #include <include/firewall/source-destination-group-ipv6.xml.i> #include <include/firewall/port.xml.i> + #include <include/firewall/geoip.xml.i> </children> </node> <node name="source"> @@ -45,6 +46,7 @@ #include <include/firewall/address-ipv6.xml.i> #include <include/firewall/source-destination-group-ipv6.xml.i> #include <include/firewall/port.xml.i> + #include <include/firewall/geoip.xml.i> </children> </node> #include <include/policy/route-common.xml.i> @@ -90,6 +92,7 @@ #include <include/firewall/address.xml.i> #include <include/firewall/source-destination-group.xml.i> #include <include/firewall/port.xml.i> + #include <include/firewall/geoip.xml.i> </children> </node> <node name="source"> @@ -100,6 +103,7 @@ #include <include/firewall/address.xml.i> #include <include/firewall/source-destination-group.xml.i> #include <include/firewall/port.xml.i> + #include <include/firewall/geoip.xml.i> </children> </node> #include <include/policy/route-common.xml.i> diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py index 9f01f8be1..9c320c82d 100755 --- a/python/vyos/firewall.py +++ b/python/vyos/firewall.py @@ -233,6 +233,9 @@ def parse_rule(rule_conf, hook, fw_name, rule_id, ip_name): hook_name = 'prerouting' if hook == 'NAM': hook_name = f'name' + # for policy + if hook == 'route' or hook == 'route6': + hook_name = hook output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC{def_suffix}_{hook_name}_{fw_name}_{rule_id}') if 'mac_address' in side_conf: @@ -738,14 +741,14 @@ class GeoIPLock(object): def __exit__(self, exc_type, exc_value, tb): os.unlink(self.file) -def geoip_update(firewall, force=False): +def geoip_update(firewall=None, policy=None, force=False): with GeoIPLock(geoip_lock_file) as lock: if not lock: print("Script is already running") return False - if not firewall: - print("Firewall is not configured") + if not firewall and not policy: + print("Firewall and policy are not configured") return True if not os.path.exists(geoip_database): @@ -760,23 +763,41 @@ def geoip_update(firewall, force=False): ipv4_sets = {} ipv6_sets = {} + ipv4_codes_policy = {} + ipv6_codes_policy = {} + + ipv4_sets_policy = {} + ipv6_sets_policy = {} + # Map country codes to set names - for codes, path in dict_search_recursive(firewall, 'country_code'): - set_name = f'GEOIP_CC_{path[1]}_{path[2]}_{path[4]}' - if ( path[0] == 'ipv4'): - for code in codes: - ipv4_codes.setdefault(code, []).append(set_name) - elif ( path[0] == 'ipv6' ): - set_name = f'GEOIP_CC6_{path[1]}_{path[2]}_{path[4]}' - for code in codes: - ipv6_codes.setdefault(code, []).append(set_name) - - if not ipv4_codes and not ipv6_codes: + if firewall: + for codes, path in dict_search_recursive(firewall, 'country_code'): + set_name = f'GEOIP_CC_{path[1]}_{path[2]}_{path[4]}' + if ( path[0] == 'ipv4'): + for code in codes: + ipv4_codes.setdefault(code, []).append(set_name) + elif ( path[0] == 'ipv6' ): + set_name = f'GEOIP_CC6_{path[1]}_{path[2]}_{path[4]}' + for code in codes: + ipv6_codes.setdefault(code, []).append(set_name) + + if policy: + for codes, path in dict_search_recursive(policy, 'country_code'): + set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}' + if ( path[0] == 'route'): + for code in codes: + ipv4_codes_policy.setdefault(code, []).append(set_name) + elif ( path[0] == 'route6' ): + set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}' + for code in codes: + ipv6_codes_policy.setdefault(code, []).append(set_name) + + if not ipv4_codes and not ipv6_codes and not ipv4_codes_policy and not ipv6_codes_policy: if force: - print("GeoIP not in use by firewall") + print("GeoIP not in use by firewall and policy") return True - geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes]) + geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes, *ipv4_codes_policy, *ipv6_codes_policy]) # Iterate IP blocks to assign to sets for start, end, code in geoip_data: @@ -785,19 +806,29 @@ def geoip_update(firewall, force=False): ip_range = f'{start}-{end}' if start != end else start for setname in ipv4_codes[code]: ipv4_sets.setdefault(setname, []).append(ip_range) + if code in ipv4_codes_policy and ipv4: + ip_range = f'{start}-{end}' if start != end else start + for setname in ipv4_codes_policy[code]: + ipv4_sets_policy.setdefault(setname, []).append(ip_range) if code in ipv6_codes and not ipv4: ip_range = f'{start}-{end}' if start != end else start for setname in ipv6_codes[code]: ipv6_sets.setdefault(setname, []).append(ip_range) + if code in ipv6_codes_policy and not ipv4: + ip_range = f'{start}-{end}' if start != end else start + for setname in ipv6_codes_policy[code]: + ipv6_sets_policy.setdefault(setname, []).append(ip_range) render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', { 'ipv4_sets': ipv4_sets, - 'ipv6_sets': ipv6_sets + 'ipv6_sets': ipv6_sets, + 'ipv4_sets_policy': ipv4_sets_policy, + 'ipv6_sets_policy': ipv6_sets_policy, }) result = run(f'nft --file {nftables_geoip_conf}') if result != 0: - print('Error: GeoIP failed to update firewall') + print('Error: GeoIP failed to update firewall/policy') return False return True diff --git a/smoketest/scripts/cli/test_policy_route.py b/smoketest/scripts/cli/test_policy_route.py index 53761b7d6..15ddd857e 100755 --- a/smoketest/scripts/cli/test_policy_route.py +++ b/smoketest/scripts/cli/test_policy_route.py @@ -307,5 +307,39 @@ class TestPolicyRoute(VyOSUnitTestSHIM.TestCase): self.verify_nftables(nftables6_search, 'ip6 vyos_mangle') + def test_geoip(self): + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'action', 'drop']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'source', 'geoip', 'country-code', 'se']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'source', 'geoip', 'country-code', 'gb']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'action', 'accept']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'source', 'geoip', 'country-code', 'de']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'source', 'geoip', 'country-code', 'fr']) + self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'source', 'geoip', 'inverse-match']) + + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'action', 'drop']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'source', 'geoip', 'country-code', 'se']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'source', 'geoip', 'country-code', 'gb']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'action', 'accept']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'source', 'geoip', 'country-code', 'de']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'source', 'geoip', 'country-code', 'fr']) + self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'source', 'geoip', 'inverse-match']) + + self.cli_commit() + + nftables_search = [ + ['ip saddr @GEOIP_CC_route_smoketest_1', 'drop'], + ['ip saddr != @GEOIP_CC_route_smoketest_2', 'accept'], + ] + + # -t prevents 1000+ GeoIP elements being returned + self.verify_nftables(nftables_search, 'ip vyos_mangle', args='-t') + + nftables_search = [ + ['ip6 saddr @GEOIP_CC6_route6_smoketest6_1', 'drop'], + ['ip6 saddr != @GEOIP_CC6_route6_smoketest6_2', 'accept'], + ] + + self.verify_nftables(nftables_search, 'ip6 vyos_mangle', args='-t') + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py index cebe57092..72f2d39f4 100755 --- a/src/conf_mode/firewall.py +++ b/src/conf_mode/firewall.py @@ -627,7 +627,7 @@ def apply(firewall): # Call helper script to Update set contents if 'name' in firewall['geoip_updated'] or 'ipv6_name' in firewall['geoip_updated']: print('Updating GeoIP. Please wait...') - geoip_update(firewall) + geoip_update(firewall=firewall) return None diff --git a/src/conf_mode/policy_route.py b/src/conf_mode/policy_route.py index 223175b8a..521764896 100755 --- a/src/conf_mode/policy_route.py +++ b/src/conf_mode/policy_route.py @@ -21,13 +21,16 @@ from sys import exit from vyos.base import Warning from vyos.config import Config +from vyos.configdiff import get_config_diff, Diff from vyos.template import render from vyos.utils.dict import dict_search_args +from vyos.utils.dict import dict_search_recursive from vyos.utils.process import cmd from vyos.utils.process import run from vyos.utils.network import get_vrf_tableid from vyos.defaults import rt_global_table from vyos.defaults import rt_global_vrf +from vyos.firewall import geoip_update from vyos import ConfigError from vyos import airbag airbag.enable() @@ -43,6 +46,43 @@ valid_groups = [ 'interface_group' ] +def geoip_updated(conf, policy): + diff = get_config_diff(conf) + node_diff = diff.get_child_nodes_diff(['policy'], expand_nodes=Diff.DELETE, recursive=True) + + out = { + 'name': [], + 'ipv6_name': [], + 'deleted_name': [], + 'deleted_ipv6_name': [] + } + updated = False + + for key, path in dict_search_recursive(policy, 'geoip'): + set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}' + if (path[0] == 'route'): + out['name'].append(set_name) + elif (path[0] == 'route6'): + set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}' + out['ipv6_name'].append(set_name) + + updated = True + + if 'delete' in node_diff: + for key, path in dict_search_recursive(node_diff['delete'], 'geoip'): + set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}' + if (path[0] == 'route'): + out['deleted_name'].append(set_name) + elif (path[0] == 'route6'): + set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}' + out['deleted_ipv6_name'].append(set_name) + updated = True + + if updated: + return out + + return False + def get_config(config=None): if config: conf = config @@ -60,6 +100,7 @@ def get_config(config=None): if 'dynamic_group' in policy['firewall_group']: del policy['firewall_group']['dynamic_group'] + policy['geoip_updated'] = geoip_updated(conf, policy) return policy def verify_rule(policy, name, rule_conf, ipv6, rule_id): @@ -203,6 +244,12 @@ def apply(policy): apply_table_marks(policy) + if policy['geoip_updated']: + # Call helper script to Update set contents + if 'name' in policy['geoip_updated'] or 'ipv6_name' in policy['geoip_updated']: + print('Updating GeoIP. Please wait...') + geoip_update(policy=policy) + return None if __name__ == '__main__': diff --git a/src/helpers/geoip-update.py b/src/helpers/geoip-update.py index 34accf2cc..061c95401 100755 --- a/src/helpers/geoip-update.py +++ b/src/helpers/geoip-update.py @@ -25,20 +25,19 @@ def get_config(config=None): conf = config else: conf = ConfigTreeQuery() - base = ['firewall'] - if not conf.exists(base): - return None - - return conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True, - no_tag_node_value_mangle=True) + return ( + conf.get_config_dict(['firewall'], key_mangling=('-', '_'), get_first_key=True, + no_tag_node_value_mangle=True) if conf.exists(['firewall']) else None, + conf.get_config_dict(['policy'], key_mangling=('-', '_'), get_first_key=True, + no_tag_node_value_mangle=True) if conf.exists(['policy']) else None, + ) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--force", help="Force update", action="store_true") args = parser.parse_args() - firewall = get_config() - - if not geoip_update(firewall, force=args.force): + firewall, policy = get_config() + if not geoip_update(firewall=firewall, policy=policy, force=args.force): sys.exit(1) |