summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/templates/firewall/nftables-geoip-update.j233
-rw-r--r--data/templates/firewall/nftables-policy.j217
-rw-r--r--interface-definitions/policy_route.xml.in4
-rwxr-xr-xpython/vyos/firewall.py67
-rwxr-xr-xsmoketest/scripts/cli/test_policy_route.py34
-rwxr-xr-xsrc/conf_mode/firewall.py2
-rwxr-xr-xsrc/conf_mode/policy_route.py47
-rwxr-xr-xsrc/helpers/geoip-update.py17
8 files changed, 193 insertions, 28 deletions
diff --git a/data/templates/firewall/nftables-geoip-update.j2 b/data/templates/firewall/nftables-geoip-update.j2
index 832ccc3e9..d8f80d1f5 100644
--- a/data/templates/firewall/nftables-geoip-update.j2
+++ b/data/templates/firewall/nftables-geoip-update.j2
@@ -31,3 +31,36 @@ table ip6 vyos_filter {
{% endfor %}
}
{% endif %}
+
+
+{% if ipv4_sets_policy is vyos_defined %}
+{% for setname, ip_list in ipv4_sets_policy.items() %}
+flush set ip vyos_mangle {{ setname }}
+{% endfor %}
+
+table ip vyos_mangle {
+{% for setname, ip_list in ipv4_sets_policy.items() %}
+ set {{ setname }} {
+ type ipv4_addr
+ flags interval
+ elements = { {{ ','.join(ip_list) }} }
+ }
+{% endfor %}
+}
+{% endif %}
+
+{% if ipv6_sets_policy is vyos_defined %}
+{% for setname, ip_list in ipv6_sets_policy.items() %}
+flush set ip6 vyos_mangle {{ setname }}
+{% endfor %}
+
+table ip6 vyos_mangle {
+{% for setname, ip_list in ipv6_sets_policy.items() %}
+ set {{ setname }} {
+ type ipv6_addr
+ flags interval
+ elements = { {{ ','.join(ip_list) }} }
+ }
+{% endfor %}
+}
+{% endif %}
diff --git a/data/templates/firewall/nftables-policy.j2 b/data/templates/firewall/nftables-policy.j2
index 9e28899b0..00d0e8a62 100644
--- a/data/templates/firewall/nftables-policy.j2
+++ b/data/templates/firewall/nftables-policy.j2
@@ -33,6 +33,15 @@ table ip vyos_mangle {
{% endif %}
}
{% endfor %}
+
+{% if geoip_updated.name is vyos_defined %}
+{% for setname in geoip_updated.name %}
+ set {{ setname }} {
+ type ipv4_addr
+ flags interval
+ }
+{% endfor %}
+{% endif %}
{% endif %}
{{ group_tmpl.groups(firewall_group, False, True) }}
@@ -65,6 +74,14 @@ table ip6 vyos_mangle {
{% endif %}
}
{% endfor %}
+{% if geoip_updated.ipv6_name is vyos_defined %}
+{% for setname in geoip_updated.ipv6_name %}
+ set {{ setname }} {
+ type ipv6_addr
+ flags interval
+ }
+{% endfor %}
+{% endif %}
{% endif %}
{{ group_tmpl.groups(firewall_group, True, True) }}
diff --git a/interface-definitions/policy_route.xml.in b/interface-definitions/policy_route.xml.in
index 9cc22540b..48f728923 100644
--- a/interface-definitions/policy_route.xml.in
+++ b/interface-definitions/policy_route.xml.in
@@ -35,6 +35,7 @@
#include <include/firewall/address-ipv6.xml.i>
#include <include/firewall/source-destination-group-ipv6.xml.i>
#include <include/firewall/port.xml.i>
+ #include <include/firewall/geoip.xml.i>
</children>
</node>
<node name="source">
@@ -45,6 +46,7 @@
#include <include/firewall/address-ipv6.xml.i>
#include <include/firewall/source-destination-group-ipv6.xml.i>
#include <include/firewall/port.xml.i>
+ #include <include/firewall/geoip.xml.i>
</children>
</node>
#include <include/policy/route-common.xml.i>
@@ -90,6 +92,7 @@
#include <include/firewall/address.xml.i>
#include <include/firewall/source-destination-group.xml.i>
#include <include/firewall/port.xml.i>
+ #include <include/firewall/geoip.xml.i>
</children>
</node>
<node name="source">
@@ -100,6 +103,7 @@
#include <include/firewall/address.xml.i>
#include <include/firewall/source-destination-group.xml.i>
#include <include/firewall/port.xml.i>
+ #include <include/firewall/geoip.xml.i>
</children>
</node>
#include <include/policy/route-common.xml.i>
diff --git a/python/vyos/firewall.py b/python/vyos/firewall.py
index 9f01f8be1..9c320c82d 100755
--- a/python/vyos/firewall.py
+++ b/python/vyos/firewall.py
@@ -233,6 +233,9 @@ def parse_rule(rule_conf, hook, fw_name, rule_id, ip_name):
hook_name = 'prerouting'
if hook == 'NAM':
hook_name = f'name'
+ # for policy
+ if hook == 'route' or hook == 'route6':
+ hook_name = hook
output.append(f'{ip_name} {prefix}addr {operator} @GEOIP_CC{def_suffix}_{hook_name}_{fw_name}_{rule_id}')
if 'mac_address' in side_conf:
@@ -738,14 +741,14 @@ class GeoIPLock(object):
def __exit__(self, exc_type, exc_value, tb):
os.unlink(self.file)
-def geoip_update(firewall, force=False):
+def geoip_update(firewall=None, policy=None, force=False):
with GeoIPLock(geoip_lock_file) as lock:
if not lock:
print("Script is already running")
return False
- if not firewall:
- print("Firewall is not configured")
+ if not firewall and not policy:
+ print("Firewall and policy are not configured")
return True
if not os.path.exists(geoip_database):
@@ -760,23 +763,41 @@ def geoip_update(firewall, force=False):
ipv4_sets = {}
ipv6_sets = {}
+ ipv4_codes_policy = {}
+ ipv6_codes_policy = {}
+
+ ipv4_sets_policy = {}
+ ipv6_sets_policy = {}
+
# Map country codes to set names
- for codes, path in dict_search_recursive(firewall, 'country_code'):
- set_name = f'GEOIP_CC_{path[1]}_{path[2]}_{path[4]}'
- if ( path[0] == 'ipv4'):
- for code in codes:
- ipv4_codes.setdefault(code, []).append(set_name)
- elif ( path[0] == 'ipv6' ):
- set_name = f'GEOIP_CC6_{path[1]}_{path[2]}_{path[4]}'
- for code in codes:
- ipv6_codes.setdefault(code, []).append(set_name)
-
- if not ipv4_codes and not ipv6_codes:
+ if firewall:
+ for codes, path in dict_search_recursive(firewall, 'country_code'):
+ set_name = f'GEOIP_CC_{path[1]}_{path[2]}_{path[4]}'
+ if ( path[0] == 'ipv4'):
+ for code in codes:
+ ipv4_codes.setdefault(code, []).append(set_name)
+ elif ( path[0] == 'ipv6' ):
+ set_name = f'GEOIP_CC6_{path[1]}_{path[2]}_{path[4]}'
+ for code in codes:
+ ipv6_codes.setdefault(code, []).append(set_name)
+
+ if policy:
+ for codes, path in dict_search_recursive(policy, 'country_code'):
+ set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}'
+ if ( path[0] == 'route'):
+ for code in codes:
+ ipv4_codes_policy.setdefault(code, []).append(set_name)
+ elif ( path[0] == 'route6' ):
+ set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}'
+ for code in codes:
+ ipv6_codes_policy.setdefault(code, []).append(set_name)
+
+ if not ipv4_codes and not ipv6_codes and not ipv4_codes_policy and not ipv6_codes_policy:
if force:
- print("GeoIP not in use by firewall")
+ print("GeoIP not in use by firewall and policy")
return True
- geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes])
+ geoip_data = geoip_load_data([*ipv4_codes, *ipv6_codes, *ipv4_codes_policy, *ipv6_codes_policy])
# Iterate IP blocks to assign to sets
for start, end, code in geoip_data:
@@ -785,19 +806,29 @@ def geoip_update(firewall, force=False):
ip_range = f'{start}-{end}' if start != end else start
for setname in ipv4_codes[code]:
ipv4_sets.setdefault(setname, []).append(ip_range)
+ if code in ipv4_codes_policy and ipv4:
+ ip_range = f'{start}-{end}' if start != end else start
+ for setname in ipv4_codes_policy[code]:
+ ipv4_sets_policy.setdefault(setname, []).append(ip_range)
if code in ipv6_codes and not ipv4:
ip_range = f'{start}-{end}' if start != end else start
for setname in ipv6_codes[code]:
ipv6_sets.setdefault(setname, []).append(ip_range)
+ if code in ipv6_codes_policy and not ipv4:
+ ip_range = f'{start}-{end}' if start != end else start
+ for setname in ipv6_codes_policy[code]:
+ ipv6_sets_policy.setdefault(setname, []).append(ip_range)
render(nftables_geoip_conf, 'firewall/nftables-geoip-update.j2', {
'ipv4_sets': ipv4_sets,
- 'ipv6_sets': ipv6_sets
+ 'ipv6_sets': ipv6_sets,
+ 'ipv4_sets_policy': ipv4_sets_policy,
+ 'ipv6_sets_policy': ipv6_sets_policy,
})
result = run(f'nft --file {nftables_geoip_conf}')
if result != 0:
- print('Error: GeoIP failed to update firewall')
+ print('Error: GeoIP failed to update firewall/policy')
return False
return True
diff --git a/smoketest/scripts/cli/test_policy_route.py b/smoketest/scripts/cli/test_policy_route.py
index 53761b7d6..15ddd857e 100755
--- a/smoketest/scripts/cli/test_policy_route.py
+++ b/smoketest/scripts/cli/test_policy_route.py
@@ -307,5 +307,39 @@ class TestPolicyRoute(VyOSUnitTestSHIM.TestCase):
self.verify_nftables(nftables6_search, 'ip6 vyos_mangle')
+ def test_geoip(self):
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'action', 'drop'])
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'source', 'geoip', 'country-code', 'se'])
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '1', 'source', 'geoip', 'country-code', 'gb'])
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'action', 'accept'])
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'source', 'geoip', 'country-code', 'de'])
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'source', 'geoip', 'country-code', 'fr'])
+ self.cli_set(['policy', 'route', 'smoketest', 'rule', '2', 'source', 'geoip', 'inverse-match'])
+
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'action', 'drop'])
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'source', 'geoip', 'country-code', 'se'])
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '1', 'source', 'geoip', 'country-code', 'gb'])
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'action', 'accept'])
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'source', 'geoip', 'country-code', 'de'])
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'source', 'geoip', 'country-code', 'fr'])
+ self.cli_set(['policy', 'route6', 'smoketest6', 'rule', '2', 'source', 'geoip', 'inverse-match'])
+
+ self.cli_commit()
+
+ nftables_search = [
+ ['ip saddr @GEOIP_CC_route_smoketest_1', 'drop'],
+ ['ip saddr != @GEOIP_CC_route_smoketest_2', 'accept'],
+ ]
+
+ # -t prevents 1000+ GeoIP elements being returned
+ self.verify_nftables(nftables_search, 'ip vyos_mangle', args='-t')
+
+ nftables_search = [
+ ['ip6 saddr @GEOIP_CC6_route6_smoketest6_1', 'drop'],
+ ['ip6 saddr != @GEOIP_CC6_route6_smoketest6_2', 'accept'],
+ ]
+
+ self.verify_nftables(nftables_search, 'ip6 vyos_mangle', args='-t')
+
if __name__ == '__main__':
unittest.main(verbosity=2)
diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py
index cebe57092..72f2d39f4 100755
--- a/src/conf_mode/firewall.py
+++ b/src/conf_mode/firewall.py
@@ -627,7 +627,7 @@ def apply(firewall):
# Call helper script to Update set contents
if 'name' in firewall['geoip_updated'] or 'ipv6_name' in firewall['geoip_updated']:
print('Updating GeoIP. Please wait...')
- geoip_update(firewall)
+ geoip_update(firewall=firewall)
return None
diff --git a/src/conf_mode/policy_route.py b/src/conf_mode/policy_route.py
index 223175b8a..521764896 100755
--- a/src/conf_mode/policy_route.py
+++ b/src/conf_mode/policy_route.py
@@ -21,13 +21,16 @@ from sys import exit
from vyos.base import Warning
from vyos.config import Config
+from vyos.configdiff import get_config_diff, Diff
from vyos.template import render
from vyos.utils.dict import dict_search_args
+from vyos.utils.dict import dict_search_recursive
from vyos.utils.process import cmd
from vyos.utils.process import run
from vyos.utils.network import get_vrf_tableid
from vyos.defaults import rt_global_table
from vyos.defaults import rt_global_vrf
+from vyos.firewall import geoip_update
from vyos import ConfigError
from vyos import airbag
airbag.enable()
@@ -43,6 +46,43 @@ valid_groups = [
'interface_group'
]
+def geoip_updated(conf, policy):
+ diff = get_config_diff(conf)
+ node_diff = diff.get_child_nodes_diff(['policy'], expand_nodes=Diff.DELETE, recursive=True)
+
+ out = {
+ 'name': [],
+ 'ipv6_name': [],
+ 'deleted_name': [],
+ 'deleted_ipv6_name': []
+ }
+ updated = False
+
+ for key, path in dict_search_recursive(policy, 'geoip'):
+ set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}'
+ if (path[0] == 'route'):
+ out['name'].append(set_name)
+ elif (path[0] == 'route6'):
+ set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}'
+ out['ipv6_name'].append(set_name)
+
+ updated = True
+
+ if 'delete' in node_diff:
+ for key, path in dict_search_recursive(node_diff['delete'], 'geoip'):
+ set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}'
+ if (path[0] == 'route'):
+ out['deleted_name'].append(set_name)
+ elif (path[0] == 'route6'):
+ set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}'
+ out['deleted_ipv6_name'].append(set_name)
+ updated = True
+
+ if updated:
+ return out
+
+ return False
+
def get_config(config=None):
if config:
conf = config
@@ -60,6 +100,7 @@ def get_config(config=None):
if 'dynamic_group' in policy['firewall_group']:
del policy['firewall_group']['dynamic_group']
+ policy['geoip_updated'] = geoip_updated(conf, policy)
return policy
def verify_rule(policy, name, rule_conf, ipv6, rule_id):
@@ -203,6 +244,12 @@ def apply(policy):
apply_table_marks(policy)
+ if policy['geoip_updated']:
+ # Call helper script to Update set contents
+ if 'name' in policy['geoip_updated'] or 'ipv6_name' in policy['geoip_updated']:
+ print('Updating GeoIP. Please wait...')
+ geoip_update(policy=policy)
+
return None
if __name__ == '__main__':
diff --git a/src/helpers/geoip-update.py b/src/helpers/geoip-update.py
index 34accf2cc..061c95401 100755
--- a/src/helpers/geoip-update.py
+++ b/src/helpers/geoip-update.py
@@ -25,20 +25,19 @@ def get_config(config=None):
conf = config
else:
conf = ConfigTreeQuery()
- base = ['firewall']
- if not conf.exists(base):
- return None
-
- return conf.get_config_dict(base, key_mangling=('-', '_'), get_first_key=True,
- no_tag_node_value_mangle=True)
+ return (
+ conf.get_config_dict(['firewall'], key_mangling=('-', '_'), get_first_key=True,
+ no_tag_node_value_mangle=True) if conf.exists(['firewall']) else None,
+ conf.get_config_dict(['policy'], key_mangling=('-', '_'), get_first_key=True,
+ no_tag_node_value_mangle=True) if conf.exists(['policy']) else None,
+ )
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--force", help="Force update", action="store_true")
args = parser.parse_args()
- firewall = get_config()
-
- if not geoip_update(firewall, force=args.force):
+ firewall, policy = get_config()
+ if not geoip_update(firewall=firewall, policy=policy, force=args.force):
sys.exit(1)