diff options
| -rw-r--r-- | data/templates/firewall/nftables.tmpl | 6 | ||||
| -rw-r--r-- | python/vyos/template.py | 8 | ||||
| -rwxr-xr-x | smoketest/scripts/cli/test_firewall.py | 17 | ||||
| -rwxr-xr-x | src/conf_mode/firewall-interface.py | 27 | ||||
| -rwxr-xr-x | src/conf_mode/firewall.py | 28 | 
5 files changed, 76 insertions, 10 deletions
| diff --git a/data/templates/firewall/nftables.tmpl b/data/templates/firewall/nftables.tmpl index 34bd9b71e..bbb111b1f 100644 --- a/data/templates/firewall/nftables.tmpl +++ b/data/templates/firewall/nftables.tmpl @@ -147,13 +147,13 @@ table ip6 filter {  {% if state_policy is defined %}      chain VYOS_STATE_POLICY6 {  {%   if state_policy.established is defined %} -        {{ state_policy.established | nft_state_policy('established') }} +        {{ state_policy.established | nft_state_policy('established', ipv6=True) }}  {%   endif %}  {%   if state_policy.invalid is defined %} -        {{ state_policy.invalid | nft_state_policy('invalid') }} +        {{ state_policy.invalid | nft_state_policy('invalid', ipv6=True) }}  {%   endif %}  {%   if state_policy.related is defined %} -        {{ state_policy.related | nft_state_policy('related') }} +        {{ state_policy.related | nft_state_policy('related', ipv6=True) }}  {%   endif %}          return      } diff --git a/python/vyos/template.py b/python/vyos/template.py index 2987fcd0e..7671bf377 100644 --- a/python/vyos/template.py +++ b/python/vyos/template.py @@ -517,7 +517,7 @@ def nft_rule(rule_conf, fw_name, rule_id, ip_name='ip'):      return parse_rule(rule_conf, fw_name, rule_id, ip_name)  @register_filter('nft_state_policy') -def nft_state_policy(conf, state): +def nft_state_policy(conf, state, ipv6=False):      out = [f'ct state {state}']      if 'log' in conf and 'enable' in conf['log']: @@ -526,7 +526,11 @@ def nft_state_policy(conf, state):      out.append('counter')      if 'action' in conf: -        out.append(conf['action']) +        if conf['action'] == 'accept': +            jump_target = 'VYOS_POST_FW6' if ipv6 else 'VYOS_POST_FW' +            out.append(f'jump {jump_target}') +        else: +            out.append(conf['action'])      return " ".join(out) diff --git a/smoketest/scripts/cli/test_firewall.py b/smoketest/scripts/cli/test_firewall.py index 1520020fd..5f728f0cd 100755 --- a/smoketest/scripts/cli/test_firewall.py +++ b/smoketest/scripts/cli/test_firewall.py @@ -134,6 +134,23 @@ class TestFirewall(VyOSUnitTestSHIM.TestCase):                      break              self.assertTrue(matched) +    def test_state_policy(self): +        self.cli_set(['firewall', 'state-policy', 'established', 'action', 'accept']) +        self.cli_set(['firewall', 'state-policy', 'related', 'action', 'accept']) +        self.cli_set(['firewall', 'state-policy', 'invalid', 'action', 'drop']) + +        self.cli_commit() + +        chains = { +            'ip filter': ['VYOS_FW_IN', 'VYOS_FW_OUTPUT', 'VYOS_FW_LOCAL'], +            'ip6 filter': ['VYOS_FW6_IN', 'VYOS_FW6_OUTPUT', 'VYOS_FW6_LOCAL'] +        } + +        for table in ['ip filter', 'ip6 filter']: +            for chain in chains[table]: +                nftables_output = cmd(f'sudo nft list chain {table} {chain}') +                self.assertTrue('jump VYOS_STATE_POLICY' in nftables_output) +      def test_sysfs(self):          for name, conf in sysfs_config.items():              paths = glob(conf['sysfs']) diff --git a/src/conf_mode/firewall-interface.py b/src/conf_mode/firewall-interface.py index 3a17dc5a4..516fa6c48 100755 --- a/src/conf_mode/firewall-interface.py +++ b/src/conf_mode/firewall-interface.py @@ -107,6 +107,15 @@ def cleanup_rule(table, chain, ifname, new_name=None):                  run(f'nft delete rule {table} {chain} handle {handle_search[1]}')      return retval +def state_policy_handle(table, chain): +    results = cmd(f'nft -a list chain {table} {chain}').split("\n") +    for line in results: +        if 'jump VYOS_STATE_POLICY' in line: +            handle_search = re.search('handle (\d+)', line) +            if handle_search: +                return handle_search[1] +    return None +  def apply(if_firewall):      ifname = if_firewall['ifname'] @@ -118,18 +127,32 @@ def apply(if_firewall):          name = dict_search_args(if_firewall, direction, 'name')          if name:              rule_exists = cleanup_rule('ip filter', chain, ifname, name) +            rule_action = 'insert' +            rule_prefix = ''              if not rule_exists: -                run(f'nft insert rule ip filter {chain} {if_prefix}ifname {ifname} counter jump {name}') +                handle = state_policy_handle('ip filter', chain) +                if handle: +                    rule_action = 'add' +                    rule_prefix = f'position {handle}' + +                run(f'nft {rule_action} rule ip filter {chain} {rule_prefix} {if_prefix}ifname {ifname} counter jump {name}')          else:              cleanup_rule('ip filter', chain, ifname)          ipv6_name = dict_search_args(if_firewall, direction, 'ipv6_name')          if ipv6_name:              rule_exists = cleanup_rule('ip6 filter', ipv6_chain, ifname, ipv6_name) +            rule_action = 'insert' +            rule_prefix = ''              if not rule_exists: -                run(f'nft insert rule ip6 filter {ipv6_chain} {if_prefix}ifname {ifname} counter jump {ipv6_name}') +                handle = state_policy_handle('ip filter', chain) +                if handle: +                    rule_action = 'add' +                    rule_prefix = f'position {handle}' + +                run(f'nft {rule_action} rule ip6 filter {ipv6_chain} {rule_prefix} {if_prefix}ifname {ifname} counter jump {ipv6_name}')          else:              cleanup_rule('ip6 filter', ipv6_chain, ifname) diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py index 5ac48c9ba..8e037c679 100755 --- a/src/conf_mode/firewall.py +++ b/src/conf_mode/firewall.py @@ -205,13 +205,20 @@ def verify(firewall):  def cleanup_commands(firewall):      commands = []      for table in ['ip filter', 'ip6 filter']: +        state_chain = 'VYOS_STATE_POLICY' if table == 'ip filter' else 'VYOS_STATE_POLICY6'          json_str = cmd(f'nft -j list table {table}')          obj = loads(json_str)          if 'nftables' not in obj:              continue          for item in obj['nftables']:              if 'chain' in item: -                if item['chain']['name'] not in preserve_chains: +                if item['chain']['name'] in ['VYOS_STATE_POLICY', 'VYOS_STATE_POLICY6']: +                    chain = item['chain']['name'] +                    if 'state_policy' not in firewall: +                        commands.append(f'delete chain {table} {chain}') +                    else: +                        commands.append(f'flush chain {table} {chain}') +                elif item['chain']['name'] not in preserve_chains:                      chain = item['chain']['name']                      if table == 'ip filter' and dict_search_args(firewall, 'name', chain):                          commands.append(f'flush chain {table} {chain}') @@ -219,6 +226,14 @@ def cleanup_commands(firewall):                          commands.append(f'flush chain {table} {chain}')                      else:                          commands.append(f'delete chain {table} {chain}') +            elif 'rule' in item: +                rule = item['rule'] +                if rule['chain'] in ['VYOS_FW_IN', 'VYOS_FW_OUTPUT', 'VYOS_FW_LOCAL', 'VYOS_FW6_IN', 'VYOS_FW6_OUTPUT', 'VYOS_FW6_LOCAL']: +                    if 'expr' in rule and any([True for expr in rule['expr'] if dict_search_args(expr, 'jump', 'target') == state_chain]): +                        if 'state_policy' not in firewall: +                            chain = rule['chain'] +                            handle = rule['handle'] +                            commands.append(f'delete rule {table} {chain} handle {handle}')      return commands  def generate(firewall): @@ -286,6 +301,11 @@ def post_apply_trap(firewall):                  cmd(base_cmd + ' '.join(objects)) +def state_policy_rule_exists(): +    # Determine if state policy rules already exist in nft +    search_str = cmd(f'nft list chain ip filter VYOS_FW_IN') +    return 'VYOS_STATE_POLICY' in search_str +  def apply(firewall):      if 'first_install' in firewall:          run('nfct helper add rpc inet tcp') @@ -296,9 +316,11 @@ def apply(firewall):      if install_result == 1:          raise ConfigError('Failed to apply firewall') -    if 'state_policy' in firewall: -        for chain in ['INPUT', 'OUTPUT', 'FORWARD']: +    if 'state_policy' in firewall and not state_policy_rule_exists(): +        for chain in ['VYOS_FW_IN', 'VYOS_FW_OUTPUT', 'VYOS_FW_LOCAL']:              cmd(f'nft insert rule ip filter {chain} jump VYOS_STATE_POLICY') + +        for chain in ['VYOS_FW6_IN', 'VYOS_FW6_OUTPUT', 'VYOS_FW6_LOCAL']:              cmd(f'nft insert rule ip6 filter {chain} jump VYOS_STATE_POLICY6')      apply_sysfs(firewall) | 
