From 64668771d5f14fc4b68fff382d166238c164bdde Mon Sep 17 00:00:00 2001
From: sarthurdev <965089+sarthurdev@users.noreply.github.com>
Date: Sat, 15 Jan 2022 12:48:48 +0100
Subject: firewall: policy: T4178: Migrate and refactor tcp flags

* Add support for ECN and CWR flags
---
 src/conf_mode/firewall.py             | 12 ++++++++++--
 src/conf_mode/policy-route.py         | 14 +++++++++++---
 src/migration-scripts/firewall/6-to-7 | 21 +++++++++++++++++++++
 src/migration-scripts/policy/1-to-2   | 19 +++++++++++++++++++
 src/validators/tcp-flag               | 14 ++++++--------
 5 files changed, 67 insertions(+), 13 deletions(-)

(limited to 'src')

diff --git a/src/conf_mode/firewall.py b/src/conf_mode/firewall.py
index 853470fd8..906d477b0 100755
--- a/src/conf_mode/firewall.py
+++ b/src/conf_mode/firewall.py
@@ -142,8 +142,16 @@ def verify_rule(firewall, rule_conf, ipv6):
         if not {'count', 'time'} <= set(rule_conf['recent']):
             raise ConfigError('Recent "count" and "time" values must be defined')
 
-    if dict_search_args(rule_conf, 'tcp', 'flags') and dict_search_args(rule_conf, 'protocol') != 'tcp':
-        raise ConfigError('Protocol must be tcp when specifying tcp flags')
+    tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags')
+    if tcp_flags:
+        if dict_search_args(rule_conf, 'protocol') != 'tcp':
+            raise ConfigError('Protocol must be tcp when specifying tcp flags')
+
+        not_flags = dict_search_args(rule_conf, 'tcp', 'flags', 'not')
+        if not_flags:
+            duplicates = [flag for flag in tcp_flags if flag in not_flags]
+            if duplicates:
+                raise ConfigError(f'Cannot match a tcp flag as set and not set')
 
     for side in ['destination', 'source']:
         if side in rule_conf:
diff --git a/src/conf_mode/policy-route.py b/src/conf_mode/policy-route.py
index 30597ef4e..eb13788dd 100755
--- a/src/conf_mode/policy-route.py
+++ b/src/conf_mode/policy-route.py
@@ -97,11 +97,19 @@ def verify_rule(policy, name, rule_conf, ipv6):
     if 'set' in rule_conf:
         if 'tcp_mss' in rule_conf['set']:
             tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags')
-            if not tcp_flags or 'SYN' not in tcp_flags.split(","):
+            if not tcp_flags or 'syn' not in tcp_flags:
                 raise ConfigError(f'{name} rule {rule_id}: TCP SYN flag must be set to modify TCP-MSS')
 
-    if dict_search_args(rule_conf, 'tcp', 'flags') and dict_search_args(rule_conf, 'protocol') != 'tcp':
-                raise ConfigError(f'{name} rule {rule_id}: TCP flags can only be set if protocol is set to TCP')
+    tcp_flags = dict_search_args(rule_conf, 'tcp', 'flags')
+    if tcp_flags:
+        if dict_search_args(rule_conf, 'protocol') != 'tcp':
+            raise ConfigError('Protocol must be tcp when specifying tcp flags')
+
+        not_flags = dict_search_args(rule_conf, 'tcp', 'flags', 'not')
+        if not_flags:
+            duplicates = [flag for flag in tcp_flags if flag in not_flags]
+            if duplicates:
+                raise ConfigError(f'Cannot match a tcp flag as set and not set')
 
     for side in ['destination', 'source']:
         if side in rule_conf:
diff --git a/src/migration-scripts/firewall/6-to-7 b/src/migration-scripts/firewall/6-to-7
index 4a4097d56..bc0b19325 100755
--- a/src/migration-scripts/firewall/6-to-7
+++ b/src/migration-scripts/firewall/6-to-7
@@ -17,6 +17,7 @@
 # T2199: Remove unavailable nodes due to XML/Python implementation using nftables
 #        monthdays: nftables does not have a monthdays equivalent
 #        utc: nftables userspace uses localtime and calculates the UTC offset automatically
+# T4178: Update tcp flags to use multi value node
 
 from sys import argv
 from sys import exit
@@ -45,6 +46,7 @@ if config.exists(base + ['name']):
         if config.exists(base + ['name', name, 'rule']):
             for rule in config.list_nodes(base + ['name', name, 'rule']):
                 rule_time = base + ['name', name, 'rule', rule, 'time']
+                rule_tcp_flags = base + ['name', name, 'rule', rule, 'tcp', 'flags']
 
                 if config.exists(rule_time + ['monthdays']):
                     config.delete(rule_time + ['monthdays'])
@@ -52,11 +54,21 @@ if config.exists(base + ['name']):
                 if config.exists(rule_time + ['utc']):
                     config.delete(rule_time + ['utc'])
 
+                if config.exists(rule_tcp_flags):
+                    tmp = config.return_value(rule_tcp_flags)
+                    config.delete(rule_tcp_flags)
+                    for flag in tmp.split(","):
+                        if flag[0] == '!':
+                            config.set(rule_tcp_flags + ['not', flag[1:].lower()])
+                        else:
+                            config.set(rule_tcp_flags + [flag.lower()])
+
 if config.exists(base + ['ipv6-name']):
     for name in config.list_nodes(base + ['ipv6-name']):
         if config.exists(base + ['ipv6-name', name, 'rule']):
             for rule in config.list_nodes(base + ['ipv6-name', name, 'rule']):
                 rule_time = base + ['ipv6-name', name, 'rule', rule, 'time']
+                rule_tcp_flags = base + ['ipv6-name', name, 'rule', rule, 'tcp', 'flags']
 
                 if config.exists(rule_time + ['monthdays']):
                     config.delete(rule_time + ['monthdays'])
@@ -64,6 +76,15 @@ if config.exists(base + ['ipv6-name']):
                 if config.exists(rule_time + ['utc']):
                     config.delete(rule_time + ['utc'])
 
+                if config.exists(rule_tcp_flags):
+                    tmp = config.return_value(rule_tcp_flags)
+                    config.delete(rule_tcp_flags)
+                    for flag in tmp.split(","):
+                        if flag[0] == '!':
+                            config.set(rule_tcp_flags + ['not', flag[1:].lower()])
+                        else:
+                            config.set(rule_tcp_flags + [flag.lower()])
+
 try:
     with open(file_name, 'w') as f:
         f.write(config.to_string())
diff --git a/src/migration-scripts/policy/1-to-2 b/src/migration-scripts/policy/1-to-2
index 7ffceef22..eebbf9d41 100755
--- a/src/migration-scripts/policy/1-to-2
+++ b/src/migration-scripts/policy/1-to-2
@@ -16,6 +16,7 @@
 
 # T4170: rename "policy ipv6-route" to "policy route6" to match common
 #        IPv4/IPv6 schema
+# T4178: Update tcp flags to use multi value node
 
 from sys import argv
 from sys import exit
@@ -41,6 +42,24 @@ if not config.exists(base):
 config.rename(base, 'route6')
 config.set_tag(['policy', 'route6'])
 
+for route in ['route', 'route6']:
+    route_path = ['policy', route]
+    if config.exists(route_path):
+        for name in config.list_nodes(route_path):
+            if config.exists(route_path + [name, 'rule']):
+                for rule in config.list_nodes(route_path + [name, 'rule']):
+                    rule_tcp_flags = route_path + [name, 'rule', rule, 'tcp', 'flags']
+
+                    if config.exists(rule_tcp_flags):
+                        tmp = config.return_value(rule_tcp_flags)
+                        config.delete(rule_tcp_flags)
+                        for flag in tmp.split(","):
+                            for flag in tmp.split(","):
+                                if flag[0] == '!':
+                                    config.set(rule_tcp_flags + ['not', flag[1:].lower()])
+                                else:
+                                    config.set(rule_tcp_flags + [flag.lower()])
+
 if config.exists(['interfaces']):
     def if_policy_rename(config, path):
         if config.exists(path + ['policy', 'ipv6-route']):
diff --git a/src/validators/tcp-flag b/src/validators/tcp-flag
index 86ebec189..1496b904a 100755
--- a/src/validators/tcp-flag
+++ b/src/validators/tcp-flag
@@ -5,14 +5,12 @@ import re
 
 if __name__ == '__main__':
     if len(sys.argv)>1:
-        flags = sys.argv[1].split(",")
-
-        for flag in flags:
-            if flag and flag[0] == '!':
-                flag = flag[1:]
-            if flag.lower() not in ['syn', 'ack', 'rst', 'fin', 'urg', 'psh']:
-                print(f'Error: {flag} is not a valid TCP flag')
-                sys.exit(1)
+        flag = sys.argv[1]
+        if flag and flag[0] == '!':
+            flag = flag[1:]
+        if flag not in ['syn', 'ack', 'rst', 'fin', 'urg', 'psh', 'ecn', 'cwr']:
+            print(f'Error: {flag} is not a valid TCP flag')
+            sys.exit(1)
     else:
         sys.exit(2)
 
-- 
cgit v1.2.3