summaryrefslogtreecommitdiff
path: root/src/conf_mode/policy-route.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/conf_mode/policy-route.py')
-rwxr-xr-xsrc/conf_mode/policy-route.py72
1 files changed, 52 insertions, 20 deletions
diff --git a/src/conf_mode/policy-route.py b/src/conf_mode/policy-route.py
index 5de341beb..9fddbd2c6 100755
--- a/src/conf_mode/policy-route.py
+++ b/src/conf_mode/policy-route.py
@@ -25,6 +25,7 @@ from vyos.config import Config
from vyos.template import render
from vyos.util import cmd
from vyos.util import dict_search_args
+from vyos.util import dict_search_recursive
from vyos.util import run
from vyos import ConfigError
from vyos import airbag
@@ -33,6 +34,9 @@ airbag.enable()
mark_offset = 0x7FFFFFFF
nftables_conf = '/run/nftables_policy.conf'
+ROUTE_PREFIX = 'VYOS_PBR_'
+ROUTE6_PREFIX = 'VYOS_PBR6_'
+
preserve_chains = [
'VYOS_PBR_PREROUTING',
'VYOS_PBR_POSTROUTING',
@@ -46,6 +50,16 @@ valid_groups = [
'port_group'
]
+group_set_prefix = {
+ 'A_': 'address_group',
+ 'A6_': 'ipv6_address_group',
+# 'D_': 'domain_group',
+ 'M_': 'mac_group',
+ 'N_': 'network_group',
+ 'N6_': 'ipv6_network_group',
+ 'P_': 'port_group'
+}
+
def get_policy_interfaces(conf):
out = {}
interfaces = conf.get_config_dict(['interfaces'], key_mangling=('-', '_'), get_first_key=True,
@@ -166,37 +180,55 @@ def verify(policy):
return None
-def cleanup_rule(table, jump_chain):
- commands = []
- results = cmd(f'nft -a list table {table}').split("\n")
- for line in results:
- if f'jump {jump_chain}' in line:
- handle_search = re.search('handle (\d+)', line)
- if handle_search:
- commands.append(f'delete rule {table} {chain} handle {handle_search[1]}')
- return commands
-
def cleanup_commands(policy):
commands = []
+ commands_chains = []
+ commands_sets = []
for table in ['ip mangle', 'ip6 mangle']:
- json_str = cmd(f'nft -j list table {table}')
+ route_node = 'route' if table == 'ip mangle' else 'route6'
+ chain_prefix = ROUTE_PREFIX if table == 'ip mangle' else ROUTE6_PREFIX
+
+ json_str = cmd(f'nft -t -j list table {table}')
obj = loads(json_str)
if 'nftables' not in obj:
continue
for item in obj['nftables']:
if 'chain' in item:
chain = item['chain']['name']
- if not chain.startswith("VYOS_PBR"):
+ if chain in preserve_chains or not chain.startswith("VYOS_PBR"):
continue
+
+ if dict_search_args(policy, route_node, chain.replace(chain_prefix, "", 1)) != None:
+ commands.append(f'flush chain {table} {chain}')
+ else:
+ commands_chains.append(f'delete chain {table} {chain}')
+
+ if 'rule' in item:
+ rule = item['rule']
+ chain = rule['chain']
+ handle = rule['handle']
+
if chain not in preserve_chains:
- if table == 'ip mangle' and dict_search_args(policy, 'route', chain.replace("VYOS_PBR_", "", 1)):
- commands.append(f'flush chain {table} {chain}')
- elif table == 'ip6 mangle' and dict_search_args(policy, 'route6', chain.replace("VYOS_PBR6_", "", 1)):
- commands.append(f'flush chain {table} {chain}')
- else:
- commands += cleanup_rule(table, chain)
- commands.append(f'delete chain {table} {chain}')
- return commands
+ continue
+
+ target, _ = next(dict_search_recursive(rule['expr'], 'target'))
+
+ if target.startswith(chain_prefix):
+ if dict_search_args(policy, route_node, target.replace(chain_prefix, "", 1)) == None:
+ commands.append(f'delete rule {table} {chain} handle {handle}')
+
+ if 'set' in item:
+ set_name = item['set']['name']
+
+ for prefix, group_type in group_set_prefix.items():
+ if set_name.startswith(prefix):
+ group_name = set_name.replace(prefix, "", 1)
+ if dict_search_args(policy, 'firewall_group', group_type, group_name) != None:
+ commands_sets.append(f'flush set {table} {set_name}')
+ else:
+ commands_sets.append(f'delete set {table} {set_name}')
+
+ return commands + commands_chains + commands_sets
def generate(policy):
if not os.path.exists(nftables_conf):