summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/vyos/qos/base.py142
-rw-r--r--python/vyos/utils/process.py4
2 files changed, 79 insertions, 67 deletions
diff --git a/python/vyos/qos/base.py b/python/vyos/qos/base.py
index 3da9afe04..66df5d107 100644
--- a/python/vyos/qos/base.py
+++ b/python/vyos/qos/base.py
@@ -245,8 +245,6 @@ class QoSBase:
prio = cls_config['priority']
filter_cmd_base += f' prio {prio}'
- filter_cmd_base += ' protocol all'
-
if 'match' in cls_config:
has_filter = False
has_action_policy = any(tmp in ['exceed', 'bandwidth', 'burst'] for tmp in cls_config)
@@ -254,13 +252,17 @@ class QoSBase:
for index, (match, match_config) in enumerate(cls_config['match'].items(), start=1):
filter_cmd = filter_cmd_base
if not has_filter:
- for key in ['mark', 'vif', 'ip', 'ipv6', 'interface']:
+ for key in ['mark', 'vif', 'ip', 'ipv6', 'interface', 'ether']:
if key in match_config:
has_filter = True
break
+ tmp = dict_search(f'ether.protocol', match_config) or 'all'
+ filter_cmd += f' protocol {tmp}'
+
if self.qostype in ['shaper', 'shaper_hfsc'] and 'prio ' not in filter_cmd:
filter_cmd += f' prio {index}'
+
if 'mark' in match_config:
mark = match_config['mark']
filter_cmd += f' handle {mark} fw'
@@ -273,7 +275,7 @@ class QoSBase:
iif = Interface(iif_name).get_ifindex()
filter_cmd += f' basic match "meta(rt_iif eq {iif})"'
- for af in ['ip', 'ipv6']:
+ for af in ['ip', 'ipv6', 'ether']:
tc_af = af
if af == 'ipv6':
tc_af = 'ip6'
@@ -281,67 +283,77 @@ class QoSBase:
if af in match_config:
filter_cmd += ' u32'
- tmp = dict_search(f'{af}.source.address', match_config)
- if tmp: filter_cmd += f' match {tc_af} src {tmp}'
-
- tmp = dict_search(f'{af}.source.port', match_config)
- if tmp: filter_cmd += f' match {tc_af} sport {tmp} 0xffff'
-
- tmp = dict_search(f'{af}.destination.address', match_config)
- if tmp: filter_cmd += f' match {tc_af} dst {tmp}'
-
- tmp = dict_search(f'{af}.destination.port', match_config)
- if tmp: filter_cmd += f' match {tc_af} dport {tmp} 0xffff'
-
- tmp = dict_search(f'{af}.protocol', match_config)
- if tmp:
- tmp = get_protocol_by_name(tmp)
- filter_cmd += f' match {tc_af} protocol {tmp} 0xff'
-
- tmp = dict_search(f'{af}.dscp', match_config)
- if tmp:
- tmp = self._get_dsfield(tmp)
- if af == 'ip':
- filter_cmd += f' match {tc_af} dsfield {tmp} 0xff'
- elif af == 'ipv6':
- filter_cmd += f' match u16 {tmp} 0x0ff0 at 0'
-
- # Will match against total length of an IPv4 packet and
- # payload length of an IPv6 packet.
- #
- # IPv4 : match u16 0x0000 ~MAXLEN at 2
- # IPv6 : match u16 0x0000 ~MAXLEN at 4
- tmp = dict_search(f'{af}.max_length', match_config)
- if tmp:
- # We need the 16 bit two's complement of the maximum
- # packet length
- tmp = hex(0xffff & ~int(tmp))
-
- if af == 'ip':
- filter_cmd += f' match u16 0x0000 {tmp} at 2'
- elif af == 'ipv6':
- filter_cmd += f' match u16 0x0000 {tmp} at 4'
-
- # We match against specific TCP flags - we assume the IPv4
- # header length is 20 bytes and assume the IPv6 packet is
- # not using extension headers (hence a ip header length of 40 bytes)
- # TCP Flags are set on byte 13 of the TCP header.
- # IPv4 : match u8 X X at 33
- # IPv6 : match u8 X X at 53
- # with X = 0x02 for SYN and X = 0x10 for ACK
- tmp = dict_search(f'{af}.tcp', match_config)
- if tmp:
- mask = 0
- if 'ack' in tmp:
- mask |= 0x10
- if 'syn' in tmp:
- mask |= 0x02
- mask = hex(mask)
-
- if af == 'ip':
- filter_cmd += f' match u8 {mask} {mask} at 33'
- elif af == 'ipv6':
- filter_cmd += f' match u8 {mask} {mask} at 53'
+ if af == 'ether':
+ src = dict_search(f'{af}.source', match_config)
+ if src: filter_cmd += f' match {tc_af} src {src}'
+
+ dst = dict_search(f'{af}.destination', match_config)
+ if dst: filter_cmd += f' match {tc_af} dst {dst}'
+
+ if not src and not dst:
+ filter_cmd += f' match u32 0 0'
+ else:
+ tmp = dict_search(f'{af}.source.address', match_config)
+ if tmp: filter_cmd += f' match {tc_af} src {tmp}'
+
+ tmp = dict_search(f'{af}.source.port', match_config)
+ if tmp: filter_cmd += f' match {tc_af} sport {tmp} 0xffff'
+
+ tmp = dict_search(f'{af}.destination.address', match_config)
+ if tmp: filter_cmd += f' match {tc_af} dst {tmp}'
+
+ tmp = dict_search(f'{af}.destination.port', match_config)
+ if tmp: filter_cmd += f' match {tc_af} dport {tmp} 0xffff'
+ ###
+ tmp = dict_search(f'{af}.protocol', match_config)
+ if tmp:
+ tmp = get_protocol_by_name(tmp)
+ filter_cmd += f' match {tc_af} protocol {tmp} 0xff'
+
+ tmp = dict_search(f'{af}.dscp', match_config)
+ if tmp:
+ tmp = self._get_dsfield(tmp)
+ if af == 'ip':
+ filter_cmd += f' match {tc_af} dsfield {tmp} 0xff'
+ elif af == 'ipv6':
+ filter_cmd += f' match u16 {tmp} 0x0ff0 at 0'
+
+ # Will match against total length of an IPv4 packet and
+ # payload length of an IPv6 packet.
+ #
+ # IPv4 : match u16 0x0000 ~MAXLEN at 2
+ # IPv6 : match u16 0x0000 ~MAXLEN at 4
+ tmp = dict_search(f'{af}.max_length', match_config)
+ if tmp:
+ # We need the 16 bit two's complement of the maximum
+ # packet length
+ tmp = hex(0xffff & ~int(tmp))
+
+ if af == 'ip':
+ filter_cmd += f' match u16 0x0000 {tmp} at 2'
+ elif af == 'ipv6':
+ filter_cmd += f' match u16 0x0000 {tmp} at 4'
+
+ # We match against specific TCP flags - we assume the IPv4
+ # header length is 20 bytes and assume the IPv6 packet is
+ # not using extension headers (hence a ip header length of 40 bytes)
+ # TCP Flags are set on byte 13 of the TCP header.
+ # IPv4 : match u8 X X at 33
+ # IPv6 : match u8 X X at 53
+ # with X = 0x02 for SYN and X = 0x10 for ACK
+ tmp = dict_search(f'{af}.tcp', match_config)
+ if tmp:
+ mask = 0
+ if 'ack' in tmp:
+ mask |= 0x10
+ if 'syn' in tmp:
+ mask |= 0x02
+ mask = hex(mask)
+
+ if af == 'ip':
+ filter_cmd += f' match u8 {mask} {mask} at 33'
+ elif af == 'ipv6':
+ filter_cmd += f' match u8 {mask} {mask} at 53'
if index != max_index or not has_action_policy:
# avoid duplicate last match rule
diff --git a/python/vyos/utils/process.py b/python/vyos/utils/process.py
index ce880f4a4..d8aabb822 100644
--- a/python/vyos/utils/process.py
+++ b/python/vyos/utils/process.py
@@ -128,7 +128,7 @@ def run(command, flag='', shell=None, input=None, timeout=None, env=None,
def cmd(command, flag='', shell=None, input=None, timeout=None, env=None,
stdout=PIPE, stderr=PIPE, decode='utf-8', raising=None, message='',
- expect=[0]):
+ expect=[0], auth=''):
"""
A wrapper around popen, which returns the stdout and
will raise the error code of a command
@@ -139,7 +139,7 @@ def cmd(command, flag='', shell=None, input=None, timeout=None, env=None,
expect: a list of error codes to consider as normal
"""
decoded, code = popen(
- command, flag,
+ f'{auth} {command}'.strip(), flag,
stdout=stdout, stderr=stderr,
input=input, timeout=timeout,
env=env, shell=shell,