summaryrefslogtreecommitdiff
path: root/src/conf_mode/policy_route.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/conf_mode/policy_route.py')
-rwxr-xr-xsrc/conf_mode/policy_route.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/src/conf_mode/policy_route.py b/src/conf_mode/policy_route.py
index 223175b8a..521764896 100755
--- a/src/conf_mode/policy_route.py
+++ b/src/conf_mode/policy_route.py
@@ -21,13 +21,16 @@ from sys import exit
from vyos.base import Warning
from vyos.config import Config
+from vyos.configdiff import get_config_diff, Diff
from vyos.template import render
from vyos.utils.dict import dict_search_args
+from vyos.utils.dict import dict_search_recursive
from vyos.utils.process import cmd
from vyos.utils.process import run
from vyos.utils.network import get_vrf_tableid
from vyos.defaults import rt_global_table
from vyos.defaults import rt_global_vrf
+from vyos.firewall import geoip_update
from vyos import ConfigError
from vyos import airbag
airbag.enable()
@@ -43,6 +46,43 @@ valid_groups = [
'interface_group'
]
+def geoip_updated(conf, policy):
+ diff = get_config_diff(conf)
+ node_diff = diff.get_child_nodes_diff(['policy'], expand_nodes=Diff.DELETE, recursive=True)
+
+ out = {
+ 'name': [],
+ 'ipv6_name': [],
+ 'deleted_name': [],
+ 'deleted_ipv6_name': []
+ }
+ updated = False
+
+ for key, path in dict_search_recursive(policy, 'geoip'):
+ set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}'
+ if (path[0] == 'route'):
+ out['name'].append(set_name)
+ elif (path[0] == 'route6'):
+ set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}'
+ out['ipv6_name'].append(set_name)
+
+ updated = True
+
+ if 'delete' in node_diff:
+ for key, path in dict_search_recursive(node_diff['delete'], 'geoip'):
+ set_name = f'GEOIP_CC_{path[0]}_{path[1]}_{path[3]}'
+ if (path[0] == 'route'):
+ out['deleted_name'].append(set_name)
+ elif (path[0] == 'route6'):
+ set_name = f'GEOIP_CC6_{path[0]}_{path[1]}_{path[3]}'
+ out['deleted_ipv6_name'].append(set_name)
+ updated = True
+
+ if updated:
+ return out
+
+ return False
+
def get_config(config=None):
if config:
conf = config
@@ -60,6 +100,7 @@ def get_config(config=None):
if 'dynamic_group' in policy['firewall_group']:
del policy['firewall_group']['dynamic_group']
+ policy['geoip_updated'] = geoip_updated(conf, policy)
return policy
def verify_rule(policy, name, rule_conf, ipv6, rule_id):
@@ -203,6 +244,12 @@ def apply(policy):
apply_table_marks(policy)
+ if policy['geoip_updated']:
+ # Call helper script to Update set contents
+ if 'name' in policy['geoip_updated'] or 'ipv6_name' in policy['geoip_updated']:
+ print('Updating GeoIP. Please wait...')
+ geoip_update(policy=policy)
+
return None
if __name__ == '__main__':