diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/vyos/airbag.py | 24 | ||||
-rw-r--r-- | python/vyos/certbot_util.py | 2 | ||||
-rw-r--r-- | python/vyos/configquery.py | 44 | ||||
-rw-r--r-- | python/vyos/configsession.py | 7 | ||||
-rw-r--r-- | python/vyos/configverify.py | 68 | ||||
-rw-r--r-- | python/vyos/defaults.py | 5 | ||||
-rw-r--r-- | python/vyos/ethtool.py | 6 | ||||
-rw-r--r-- | python/vyos/ifconfig/__init__.py | 1 | ||||
-rw-r--r-- | python/vyos/ifconfig/bond.py | 16 | ||||
-rw-r--r-- | python/vyos/ifconfig/ethernet.py | 20 | ||||
-rwxr-xr-x[-rw-r--r--] | python/vyos/ifconfig/interface.py | 74 | ||||
-rw-r--r-- | python/vyos/ifconfig/l2tpv3.py | 24 | ||||
-rw-r--r-- | python/vyos/ifconfig/tunnel.py | 1 | ||||
-rw-r--r-- | python/vyos/ifconfig/vrrp.py | 9 | ||||
-rw-r--r-- | python/vyos/ifconfig/vti.py | 25 | ||||
-rw-r--r-- | python/vyos/ifconfig/wireguard.py | 12 | ||||
-rw-r--r-- | python/vyos/ifconfig/wwan.py | 28 | ||||
-rw-r--r-- | python/vyos/pki.py | 333 | ||||
-rw-r--r-- | python/vyos/remote.py | 375 | ||||
-rw-r--r-- | python/vyos/template.py | 108 | ||||
-rw-r--r-- | python/vyos/util.py | 174 | ||||
-rw-r--r-- | python/vyos/validate.py | 6 | ||||
-rw-r--r-- | python/vyos/xml/load.py | 3 | ||||
-rw-r--r-- | python/vyos/xml/test_xml.py | 8 |
24 files changed, 1192 insertions, 181 deletions
diff --git a/python/vyos/airbag.py b/python/vyos/airbag.py index 510ab7f46..a20f44207 100644 --- a/python/vyos/airbag.py +++ b/python/vyos/airbag.py @@ -18,7 +18,6 @@ from datetime import datetime from vyos import debug from vyos.logger import syslog -from vyos.version import get_version from vyos.version import get_full_version_data @@ -78,7 +77,7 @@ def bug_report(dtype, value, trace): information.update({ 'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'trace': trace, - 'instructions': COMMUNITY if 'rolling' in get_version() else SUPPORTED, + 'instructions': INSTRUCTIONS, 'note': note, }) @@ -162,20 +161,13 @@ When reporting problems, please include as much information as possible: """ -COMMUNITY = """\ -- Make sure you are running the latest version of the code available at - https://downloads.vyos.io/rolling/current/amd64/vyos-rolling-latest.iso -- Consult the forum to see how to handle this issue - https://forum.vyos.io -- Join our community on slack where our users exchange help and advice - https://vyos.slack.com -""".strip() - -SUPPORTED = """\ -- Make sure you are running the latest stable version of VyOS - the code is available at https://downloads.vyos.io/?dir=release/current -- Contact us using the online help desk +INSTRUCTIONS = """\ +- Contact us using the online help desk if you have a subscription: https://support.vyos.io/ -- Join our community on slack where our users exchange help and advice +- Make sure you are running the latest version of VyOS available at: + https://vyos.net/get/ +- Consult the community forum to see how to handle this issue: + https://forum.vyos.io +- Join us on Slack where our users exchange help and advice: https://vyos.slack.com """.strip() diff --git a/python/vyos/certbot_util.py b/python/vyos/certbot_util.py index df42d4780..bcb78381f 100644 --- a/python/vyos/certbot_util.py +++ b/python/vyos/certbot_util.py @@ -1,7 +1,7 @@ # certbot_util -- adaptation of certbot_nginx name matching functions for VyOS # https://github.com/certbot/certbot/blob/master/LICENSE.txt -from certbot_nginx import parser +from certbot_nginx._internal import parser NAME_RANK = 0 START_WILDCARD_RANK = 1 diff --git a/python/vyos/configquery.py b/python/vyos/configquery.py index ed7346f1f..1cdcbcf39 100644 --- a/python/vyos/configquery.py +++ b/python/vyos/configquery.py @@ -18,9 +18,16 @@ A small library that allows querying existence or value(s) of config settings from op mode, and execution of arbitrary op mode commands. ''' +import re +import json +from copy import deepcopy from subprocess import STDOUT -from vyos.util import popen +import vyos.util +import vyos.xml +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configsource import ConfigSourceSession class ConfigQueryError(Exception): pass @@ -51,32 +58,59 @@ class CliShellApiConfigQuery(GenericConfigQuery): def exists(self, path: list): cmd = ' '.join(path) - (_, err) = popen(f'cli-shell-api existsActive {cmd}') + (_, err) = vyos.util.popen(f'cli-shell-api existsActive {cmd}') if err: return False return True def value(self, path: list): cmd = ' '.join(path) - (out, err) = popen(f'cli-shell-api returnActiveValue {cmd}') + (out, err) = vyos.util.popen(f'cli-shell-api returnActiveValue {cmd}') if err: raise ConfigQueryError('No value for given path') return out def values(self, path: list): cmd = ' '.join(path) - (out, err) = popen(f'cli-shell-api returnActiveValues {cmd}') + (out, err) = vyos.util.popen(f'cli-shell-api returnActiveValues {cmd}') if err: raise ConfigQueryError('No values for given path') return out +class ConfigTreeQuery(GenericConfigQuery): + def __init__(self): + super().__init__() + + config_source = ConfigSourceSession() + self.configtree = Config(config_source=config_source) + + def exists(self, path: list): + return self.configtree.exists(path) + + def value(self, path: list): + return self.configtree.return_value(path) + + def values(self, path: list): + return self.configtree.return_values(path) + + def list_nodes(self, path: list): + return self.configtree.list_nodes(path) + + def get_config_dict(self, path=[], effective=False, key_mangling=None, + get_first_key=False, no_multi_convert=False, + no_tag_node_value_mangle=False): + return self.configtree.get_config_dict(path, effective=effective, + key_mangling=key_mangling, get_first_key=get_first_key, + no_multi_convert=no_multi_convert, + no_tag_node_value_mangle=no_tag_node_value_mangle) + class VbashOpRun(GenericOpRun): def __init__(self): super().__init__() def run(self, path: list, **kwargs): cmd = ' '.join(path) - (out, err) = popen(f'. /opt/vyatta/share/vyatta-op/functions/interpreter/vyatta-op-run; _vyatta_op_run {cmd}', stderr=STDOUT, **kwargs) + (out, err) = vyos.util.popen(f'. /opt/vyatta/share/vyatta-op/functions/interpreter/vyatta-op-run; _vyatta_op_run {cmd}', stderr=STDOUT, **kwargs) if err: raise ConfigQueryError(out) return out diff --git a/python/vyos/configsession.py b/python/vyos/configsession.py index 670e6c7fc..f28ad09c5 100644 --- a/python/vyos/configsession.py +++ b/python/vyos/configsession.py @@ -10,14 +10,14 @@ # See the GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License along with this library; -# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA import os import re import sys import subprocess -from vyos.util import call +from vyos.util import is_systemd_service_running CLI_SHELL_API = '/bin/cli-shell-api' SET = '/opt/vyatta/sbin/my_set' @@ -73,8 +73,7 @@ def inject_vyos_env(env): env['vyos_validators_dir'] = '/usr/libexec/vyos/validators' # if running the vyos-configd daemon, inject the vyshim env var - ret = call('systemctl is-active --quiet vyos-configd.service') - if not ret: + if is_systemd_service_running('vyos-configd.service'): env['vyshim'] = '/usr/sbin/vyshim' return env diff --git a/python/vyos/configverify.py b/python/vyos/configverify.py index 99c472582..4279e6982 100644 --- a/python/vyos/configverify.py +++ b/python/vyos/configverify.py @@ -1,4 +1,4 @@ -# Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2020-2021 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -45,6 +45,16 @@ def verify_mtu(config): raise ConfigError(f'Interface MTU too high, ' \ f'maximum supported MTU is {max_mtu}!') +def verify_mtu_parent(config, parent): + if 'mtu' not in config or 'mtu' not in parent: + return + + mtu = int(config['mtu']) + parent_mtu = int(parent['mtu']) + if mtu > parent_mtu: + raise ConfigError(f'Interface MTU ({mtu}) too high, ' \ + f'parent interface MTU is {parent_mtu}!') + def verify_mtu_ipv6(config): """ Common helper function used by interface implementations to perform @@ -139,9 +149,38 @@ def verify_eapol(config): recurring validation of EAPoL configuration. """ if 'eapol' in config: - if not {'cert_file', 'key_file'} <= set(config['eapol']): - raise ConfigError('Both cert and key-file must be specified '\ - 'when using EAPoL!') + if 'certificate' not in config['eapol']: + raise ConfigError('Certificate must be specified when using EAPoL!') + + if 'certificate' not in config['pki']: + raise ConfigError('Invalid certificate specified for EAPoL') + + cert_name = config['eapol']['certificate'] + + if cert_name not in config['pki']['certificate']: + raise ConfigError('Invalid certificate specified for EAPoL') + + cert = config['pki']['certificate'][cert_name] + + if 'certificate' not in cert or 'private' not in cert or 'key' not in cert['private']: + raise ConfigError('Invalid certificate/private key specified for EAPoL') + + if 'password_protected' in cert['private']: + raise ConfigError('Encrypted private key cannot be used for EAPoL') + + if 'ca_certificate' in config['eapol']: + if 'ca' not in config['pki']: + raise ConfigError('Invalid CA certificate specified for EAPoL') + + ca_cert_name = config['eapol']['ca_certificate'] + + if ca_cert_name not in config['pki']['ca']: + raise ConfigError('Invalid CA certificate specified for EAPoL') + + ca_cert = config['pki']['ca'][cert_name] + + if 'certificate' not in ca_cert: + raise ConfigError('Invalid CA certificate specified for EAPoL') def verify_mirror(config): """ @@ -156,6 +195,19 @@ def verify_mirror(config): raise ConfigError(f'Can not mirror "{direction}" traffic back ' \ 'the originating interface!') +def verify_authentication(config): + """ + Common helper function used by interface implementations to perform + recurring validation of authentication for either PPPoE or WWAN interfaces. + + If authentication CLI option is defined, both username and password must + be set! + """ + if 'authentication' not in config: + return + if not {'user', 'password'} <= set(config['authentication']): + raise ConfigError('Authentication requires both username and ' \ + 'password to be set!') def verify_address(config): """ @@ -266,6 +318,7 @@ def verify_vlan_config(config): verify_dhcpv6(vlan) verify_address(vlan) verify_vrf(vlan) + verify_mtu_parent(vlan, config) # 802.1ad (Q-in-Q) VLANs for s_vlan in config.get('vif_s', {}): @@ -273,12 +326,15 @@ def verify_vlan_config(config): verify_dhcpv6(s_vlan) verify_address(s_vlan) verify_vrf(s_vlan) + verify_mtu_parent(s_vlan, config) for c_vlan in s_vlan.get('vif_c', {}): c_vlan = s_vlan['vif_c'][c_vlan] verify_dhcpv6(c_vlan) verify_address(c_vlan) verify_vrf(c_vlan) + verify_mtu_parent(c_vlan, config) + verify_mtu_parent(c_vlan, s_vlan) def verify_accel_ppp_base_service(config): """ @@ -288,7 +344,7 @@ def verify_accel_ppp_base_service(config): # vertify auth settings if dict_search('authentication.mode', config) == 'local': if not dict_search('authentication.local_users', config): - raise ConfigError('PPPoE local auth mode requires local users to be configured!') + raise ConfigError('Authentication mode local requires local users to be configured!') for user in dict_search('authentication.local_users.username', config): user_config = config['authentication']['local_users']['username'][user] @@ -312,7 +368,7 @@ def verify_accel_ppp_base_service(config): raise ConfigError(f'Missing RADIUS secret key for server "{server}"') if 'gateway_address' not in config: - raise ConfigError('PPPoE server requires gateway-address to be configured!') + raise ConfigError('Server requires gateway-address to be configured!') if 'name_server_ipv4' in config: if len(config['name_server_ipv4']) > 2: diff --git a/python/vyos/defaults.py b/python/vyos/defaults.py index 9921e3b5f..03006c383 100644 --- a/python/vyos/defaults.py +++ b/python/vyos/defaults.py @@ -22,7 +22,10 @@ directories = { "migrate": "/opt/vyatta/etc/config-migrate/migrate", "log": "/var/log/vyatta", "templates": "/usr/share/vyos/templates/", - "certbot": "/config/auth/letsencrypt" + "certbot": "/config/auth/letsencrypt", + "api_schema": "/usr/libexec/vyos/services/api/graphql/graphql/schema/", + "api_templates": "/usr/libexec/vyos/services/api/graphql/recipes/templates/" + } cfg_group = 'vyattacfg' diff --git a/python/vyos/ethtool.py b/python/vyos/ethtool.py index 136feae8d..bc103959a 100644 --- a/python/vyos/ethtool.py +++ b/python/vyos/ethtool.py @@ -57,7 +57,11 @@ class Ethtool: if ':' in line: key, value = [s.strip() for s in line.strip().split(":", 1)] key = key.lower().replace(' ', '_') - self.ring_buffers[key] = int(value) + # T3645: ethtool version used on Debian Bullseye changed the + # output format from 0 -> n/a. As we are only interested in the + # tx/rx keys we do not care about RX Mini/Jumbo. + if value.isdigit(): + self.ring_buffers[key] = int(value) def is_fixed_lro(self): diff --git a/python/vyos/ifconfig/__init__.py b/python/vyos/ifconfig/__init__.py index e9da1e9f5..2d3e406ac 100644 --- a/python/vyos/ifconfig/__init__.py +++ b/python/vyos/ifconfig/__init__.py @@ -35,3 +35,4 @@ from vyos.ifconfig.tunnel import TunnelIf from vyos.ifconfig.wireless import WiFiIf from vyos.ifconfig.l2tpv3 import L2TPv3If from vyos.ifconfig.macsec import MACsecIf +from vyos.ifconfig.wwan import WWANIf diff --git a/python/vyos/ifconfig/bond.py b/python/vyos/ifconfig/bond.py index 233d53688..2b9afe109 100644 --- a/python/vyos/ifconfig/bond.py +++ b/python/vyos/ifconfig/bond.py @@ -86,6 +86,9 @@ class BondIf(Interface): _sysfs_get = {**Interface._sysfs_get, **{ 'bond_arp_ip_target': { 'location': '/sys/class/net/{ifname}/bonding/arp_ip_target', + }, + 'bond_mode': { + 'location': '/sys/class/net/{ifname}/bonding/mode', } }} @@ -317,6 +320,19 @@ class BondIf(Interface): return enslaved_ifs + def get_mode(self): + """ + Return bond operation mode. + + Example: + >>> from vyos.ifconfig import BondIf + >>> BondIf('bond0').get_mode() + '802.3ad' + """ + mode = self.get_interface('bond_mode') + # mode is now "802.3ad 4", we are only interested in "802.3ad" + return mode.split()[0] + def set_primary(self, interface): """ A string (eth0, eth2, etc) specifying which slave is the primary diff --git a/python/vyos/ifconfig/ethernet.py b/python/vyos/ifconfig/ethernet.py index b89ca5a5c..07b31a12a 100644 --- a/python/vyos/ifconfig/ethernet.py +++ b/python/vyos/ifconfig/ethernet.py @@ -55,6 +55,11 @@ class EthernetIf(Interface): 'possible': lambda i, v: EthernetIf.feature(i, 'gso', v), # 'shellcmd': 'ethtool -K {ifname} gso {value}', }, + 'lro': { + 'validate': lambda v: assert_list(v, ['on', 'off']), + 'possible': lambda i, v: EthernetIf.feature(i, 'lro', v), + # 'shellcmd': 'ethtool -K {ifname} lro {value}', + }, 'sg': { 'validate': lambda v: assert_list(v, ['on', 'off']), 'possible': lambda i, v: EthernetIf.feature(i, 'sg', v), @@ -238,6 +243,18 @@ class EthernetIf(Interface): raise ValueError("Value out of range") return self.set_interface('gso', 'on' if state else 'off') + def set_lro(self, state): + """ + Enable Large Receive offload. State can be either True or False. + Example: + >>> from vyos.ifconfig import EthernetIf + >>> i = EthernetIf('eth0') + >>> i.set_lro(True) + """ + if not isinstance(state, bool): + raise ValueError("Value out of range") + return self.set_interface('lro', 'on' if state else 'off') + def set_rps(self, state): if not isinstance(state, bool): raise ValueError("Value out of range") @@ -328,6 +345,9 @@ class EthernetIf(Interface): # GSO (generic segmentation offload) self.set_gso(dict_search('offload.gso', config) != None) + # LRO (large receive offload) + self.set_lro(dict_search('offload.lro', config) != None) + # RPS - Receive Packet Steering self.set_rps(dict_search('offload.rps', config) != None) diff --git a/python/vyos/ifconfig/interface.py b/python/vyos/ifconfig/interface.py index 048a2cd19..a1928ba51 100644..100755 --- a/python/vyos/ifconfig/interface.py +++ b/python/vyos/ifconfig/interface.py @@ -311,6 +311,28 @@ class Interface(Control): cmd = 'ip link del dev {ifname}'.format(**self.config) return self._cmd(cmd) + def _set_vrf_ct_zone(self, vrf): + """ + Add/Remove rules in nftables to associate traffic in VRF to an + individual conntack zone + """ + if vrf: + # Get routing table ID for VRF + vrf_table_id = get_interface_config(vrf).get('linkinfo', {}).get( + 'info_data', {}).get('table') + # Add map element with interface and zone ID + if vrf_table_id: + self._cmd( + f'nft add element inet vrf_zones ct_iface_map {{ "{self.ifname}" : {vrf_table_id} }}' + ) + else: + nft_del_element = f'delete element inet vrf_zones ct_iface_map {{ "{self.ifname}" }}' + # Check if deleting is possible first to avoid raising errors + _, err = self._popen(f'nft -c {nft_del_element}') + if not err: + # Remove map element + self._cmd(f'nft {nft_del_element}') + def get_min_mtu(self): """ Get hardware minimum supported MTU @@ -401,6 +423,7 @@ class Interface(Control): >>> Interface('eth0').set_vrf() """ self.set_interface('vrf', vrf) + self._set_vrf_ct_zone(vrf) def set_arp_cache_tmo(self, tmo): """ @@ -779,9 +802,7 @@ class Interface(Control): # Note that currently expanded netmasks are not supported. That means # 2001:db00::0/24 is a valid argument while 2001:db00::0/ffff:ff00:: not. # see https://docs.python.org/3/library/ipaddress.html - bits = bin( - int(v6_addr['netmask'].replace(':', ''), 16)).count('1') - prefix = '/' + str(bits) + prefix = '/' + v6_addr['netmask'].split('/')[-1] # we alsoneed to remove the interface suffix on link local # addresses @@ -1345,12 +1366,55 @@ class Interface(Control): # create/update 802.1q VLAN interfaces for vif_id, vif_config in config.get('vif', {}).items(): + + vif_ifname = f'{ifname}.{vif_id}' + vif_config['ifname'] = vif_ifname + tmp = deepcopy(VLANIf.get_config()) tmp['source_interface'] = ifname tmp['vlan_id'] = vif_id - vif_ifname = f'{ifname}.{vif_id}' - vif_config['ifname'] = vif_ifname + # We need to ensure that the string format is consistent, and we need to exclude redundant spaces. + sep = ' ' + if 'egress_qos' in vif_config: + # Unwrap strings into arrays + egress_qos_array = vif_config['egress_qos'].split() + # The split array is spliced according to the fixed format + tmp['egress_qos'] = sep.join(egress_qos_array) + + if 'ingress_qos' in vif_config: + # Unwrap strings into arrays + ingress_qos_array = vif_config['ingress_qos'].split() + # The split array is spliced according to the fixed format + tmp['ingress_qos'] = sep.join(ingress_qos_array) + + # Since setting the QoS control parameters in the later stage will + # not completely delete the old settings, + # we still need to delete the VLAN encapsulation interface in order to + # ensure that the changed settings are effective. + cur_cfg = get_interface_config(vif_ifname) + qos_str = '' + tmp2 = dict_search('linkinfo.info_data.ingress_qos', cur_cfg) + if 'ingress_qos' in tmp and tmp2: + for item in tmp2: + from_key = item['from'] + to_key = item['to'] + qos_str += f'{from_key}:{to_key} ' + if qos_str != tmp['ingress_qos']: + if self.exists(vif_ifname): + VLANIf(vif_ifname).remove() + + qos_str = '' + tmp2 = dict_search('linkinfo.info_data.egress_qos', cur_cfg) + if 'egress_qos' in tmp and tmp2: + for item in tmp2: + from_key = item['from'] + to_key = item['to'] + qos_str += f'{from_key}:{to_key} ' + if qos_str != tmp['egress_qos']: + if self.exists(vif_ifname): + VLANIf(vif_ifname).remove() + vlan = VLANIf(vif_ifname, **tmp) vlan.update(vif_config) diff --git a/python/vyos/ifconfig/l2tpv3.py b/python/vyos/ifconfig/l2tpv3.py index 7ff0fdd0e..fcd1fbf81 100644 --- a/python/vyos/ifconfig/l2tpv3.py +++ b/python/vyos/ifconfig/l2tpv3.py @@ -13,8 +13,28 @@ # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. +from time import sleep +from time import time +from vyos.util import run from vyos.ifconfig.interface import Interface +def wait_for_add_l2tpv3(timeout=10, sleep_interval=1, cmd=None): + ''' + In some cases, we need to wait until local address is assigned. + And only then can the l2tpv3 tunnel be configured. + For example when ipv6 address in tentative state + or we wait for some routing daemon for remote address. + ''' + start_time = time() + test_command = cmd + while True: + if (start_time + timeout) < time(): + return None + result = run(test_command) + if result == 0: + return True + sleep(sleep_interval) + @Interface.register class L2TPv3If(Interface): """ @@ -43,7 +63,9 @@ class L2TPv3If(Interface): cmd += ' encap {encapsulation}' cmd += ' local {source_address}' cmd += ' remote {remote}' - self._cmd(cmd.format(**self.config)) + c = cmd.format(**self.config) + # wait until the local/remote address is available, but no more 10 sec. + wait_for_add_l2tpv3(cmd=c) # setup session cmd = 'ip l2tp add session name {ifname}' diff --git a/python/vyos/ifconfig/tunnel.py b/python/vyos/ifconfig/tunnel.py index 2a266fc9f..64c735824 100644 --- a/python/vyos/ifconfig/tunnel.py +++ b/python/vyos/ifconfig/tunnel.py @@ -62,6 +62,7 @@ class TunnelIf(Interface): mapping_ipv4 = { 'parameters.ip.key' : 'key', 'parameters.ip.no_pmtu_discovery' : 'nopmtudisc', + 'parameters.ip.ignore_df' : 'ignore-df', 'parameters.ip.tos' : 'tos', 'parameters.ip.ttl' : 'ttl', 'parameters.erspan.direction' : 'erspan_dir', diff --git a/python/vyos/ifconfig/vrrp.py b/python/vyos/ifconfig/vrrp.py index d3e9d5df2..b522cc1ab 100644 --- a/python/vyos/ifconfig/vrrp.py +++ b/python/vyos/ifconfig/vrrp.py @@ -92,11 +92,14 @@ class VRRP(object): try: # send signal to generate the configuration file pid = util.read_file(cls.location['pid']) - os.kill(int(pid), cls._signal[what]) + util.wait_for_file_write_complete(fname, + pre_hook=(lambda: os.kill(int(pid), cls._signal[what])), + timeout=30) - # should look for file size change? - sleep(0.2) return util.read_file(fname) + except OSError: + # raised by vyos.util.read_file + raise VRRPNoData("VRRP data is not available (wait time exceeded)") except FileNotFoundError: raise VRRPNoData("VRRP data is not available (process not running or no active groups)") except Exception: diff --git a/python/vyos/ifconfig/vti.py b/python/vyos/ifconfig/vti.py index e2090c889..470ebbff3 100644 --- a/python/vyos/ifconfig/vti.py +++ b/python/vyos/ifconfig/vti.py @@ -14,6 +14,7 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. from vyos.ifconfig.interface import Interface +from vyos.util import dict_search @Interface.register class VTIIf(Interface): @@ -25,3 +26,27 @@ class VTIIf(Interface): 'prefixes': ['vti', ], }, } + + def _create(self): + # This table represents a mapping from VyOS internal config dict to + # arguments used by iproute2. For more information please refer to: + # - https://man7.org/linux/man-pages/man8/ip-link.8.html + # - https://man7.org/linux/man-pages/man8/ip-tunnel.8.html + mapping = { + 'source_interface' : 'dev', + } + + if_id = self.ifname.lstrip('vti') + cmd = f'ip link add {self.ifname} type xfrm if_id {if_id}' + for vyos_key, iproute2_key in mapping.items(): + # dict_search will return an empty dict "{}" for valueless nodes like + # "parameters.nolearning" - thus we need to test the nodes existence + # by using isinstance() + tmp = dict_search(vyos_key, self.config) + if isinstance(tmp, dict): + cmd += f' {iproute2_key}' + elif tmp != None: + cmd += f' {iproute2_key} {tmp}' + + self._cmd(cmd.format(**self.config)) + self.set_interface('admin_state', 'down') diff --git a/python/vyos/ifconfig/wireguard.py b/python/vyos/ifconfig/wireguard.py index e5b9c4408..c4cf2fbbf 100644 --- a/python/vyos/ifconfig/wireguard.py +++ b/python/vyos/ifconfig/wireguard.py @@ -95,7 +95,7 @@ class WireGuardOperational(Operational): for peer in c.list_effective_nodes(["peer"]): if wgdump['peers']: - pubkey = c.return_effective_value(["peer", peer, "pubkey"]) + pubkey = c.return_effective_value(["peer", peer, "public_key"]) if pubkey in wgdump['peers']: wgpeer = wgdump['peers'][pubkey] @@ -194,11 +194,15 @@ class WireGuardIf(Interface): peer = config['peer_remove'][tmp] peer['ifname'] = config['ifname'] - cmd = 'wg set {ifname} peer {pubkey} remove' + cmd = 'wg set {ifname} peer {public_key} remove' self._cmd(cmd.format(**peer)) + config['private_key_file'] = '/tmp/tmp.wireguard.key' + with open(config['private_key_file'], 'w') as f: + f.write(config['private_key']) + # Wireguard base command is identical for every peer - base_cmd = 'wg set {ifname} private-key {private_key}' + base_cmd = 'wg set {ifname} private-key {private_key_file}' if 'port' in config: base_cmd += ' listen-port {port}' if 'fwmark' in config: @@ -210,7 +214,7 @@ class WireGuardIf(Interface): peer = config['peer'][tmp] # start of with a fresh 'wg' command - cmd = base_cmd + ' peer {pubkey}' + cmd = base_cmd + ' peer {public_key}' # If no PSK is given remove it by using /dev/null - passing keys via # the shell (usually bash) is considered insecure, thus we use a file diff --git a/python/vyos/ifconfig/wwan.py b/python/vyos/ifconfig/wwan.py new file mode 100644 index 000000000..f18959a60 --- /dev/null +++ b/python/vyos/ifconfig/wwan.py @@ -0,0 +1,28 @@ +# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library. If not, see <http://www.gnu.org/licenses/>. + +from vyos.ifconfig.interface import Interface + +@Interface.register +class WWANIf(Interface): + iftype = 'wwan' + definition = { + **Interface.definition, + **{ + 'section': 'wwan', + 'prefixes': ['wwan', ], + 'eternal': 'wwan[0-9]+$', + }, + } diff --git a/python/vyos/pki.py b/python/vyos/pki.py new file mode 100644 index 000000000..68ad73bf2 --- /dev/null +++ b/python/vyos/pki.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2021 VyOS maintainers and contributors +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 or later as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import datetime +import ipaddress + +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.x509.extensions import ExtensionNotFound +from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID, ExtensionOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import dh +from cryptography.hazmat.primitives.asymmetric import dsa +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric import rsa + +CERT_BEGIN='-----BEGIN CERTIFICATE-----\n' +CERT_END='\n-----END CERTIFICATE-----' +KEY_BEGIN='-----BEGIN PRIVATE KEY-----\n' +KEY_END='\n-----END PRIVATE KEY-----' +KEY_ENC_BEGIN='-----BEGIN ENCRYPTED PRIVATE KEY-----\n' +KEY_ENC_END='\n-----END ENCRYPTED PRIVATE KEY-----' +KEY_PUB_BEGIN='-----BEGIN PUBLIC KEY-----\n' +KEY_PUB_END='\n-----END PUBLIC KEY-----' +CRL_BEGIN='-----BEGIN X509 CRL-----\n' +CRL_END='\n-----END X509 CRL-----' +CSR_BEGIN='-----BEGIN CERTIFICATE REQUEST-----\n' +CSR_END='\n-----END CERTIFICATE REQUEST-----' +DH_BEGIN='-----BEGIN DH PARAMETERS-----\n' +DH_END='\n-----END DH PARAMETERS-----' +OVPN_BEGIN = '-----BEGIN OpenVPN Static key V{0}-----\n' +OVPN_END = '\n-----END OpenVPN Static key V{0}-----' + +# Print functions + +encoding_map = { + 'PEM': serialization.Encoding.PEM, + 'OpenSSH': serialization.Encoding.OpenSSH +} + +public_format_map = { + 'SubjectPublicKeyInfo': serialization.PublicFormat.SubjectPublicKeyInfo, + 'OpenSSH': serialization.PublicFormat.OpenSSH +} + +private_format_map = { + 'PKCS8': serialization.PrivateFormat.PKCS8, + 'OpenSSH': serialization.PrivateFormat.OpenSSH +} + +def encode_certificate(cert): + return cert.public_bytes(encoding=serialization.Encoding.PEM).decode('utf-8') + +def encode_public_key(cert, encoding='PEM', key_format='SubjectPublicKeyInfo'): + if encoding not in encoding_map: + encoding = 'PEM' + if key_format not in public_format_map: + key_format = 'SubjectPublicKeyInfo' + return cert.public_bytes( + encoding=encoding_map[encoding], + format=public_format_map[key_format]).decode('utf-8') + +def encode_private_key(private_key, encoding='PEM', key_format='PKCS8', passphrase=None): + if encoding not in encoding_map: + encoding = 'PEM' + if key_format not in private_format_map: + key_format = 'PKCS8' + encryption = serialization.NoEncryption() if not passphrase else serialization.BestAvailableEncryption(bytes(passphrase, 'utf-8')) + return private_key.private_bytes( + encoding=encoding_map[encoding], + format=private_format_map[key_format], + encryption_algorithm=encryption).decode('utf-8') + +def encode_dh_parameters(dh_parameters): + return dh_parameters.parameter_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.ParameterFormat.PKCS3).decode('utf-8') + +# EC Helper + +def get_elliptic_curve(size): + curve_func = None + name = f'SECP{size}R1' + if hasattr(ec, name): + curve_func = getattr(ec, name) + else: + curve_func = ec.SECP256R1() # Default to SECP256R1 + return curve_func() + +# Creation functions + +def create_private_key(key_type, key_size=None): + private_key = None + if key_type == 'rsa': + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) + elif key_type == 'dsa': + private_key = dsa.generate_private_key(key_size=key_size) + elif key_type == 'ec': + curve = get_elliptic_curve(key_size) + private_key = ec.generate_private_key(curve) + return private_key + +def create_certificate_request(subject, private_key, subject_alt_names=[]): + subject_obj = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, subject['country']), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, subject['state']), + x509.NameAttribute(NameOID.LOCALITY_NAME, subject['locality']), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject['organization']), + x509.NameAttribute(NameOID.COMMON_NAME, subject['common_name'])]) + + builder = x509.CertificateSigningRequestBuilder() \ + .subject_name(subject_obj) + + if subject_alt_names: + alt_names = [] + for obj in subject_alt_names: + if isinstance(obj, ipaddress.IPv4Address) or isinstance(obj, ipaddress.IPv6Address): + alt_names.append(x509.IPAddress(obj)) + elif isinstance(obj, str): + alt_names.append(x509.DNSName(obj)) + if alt_names: + builder = builder.add_extension(x509.SubjectAlternativeName(alt_names), critical=False) + + return builder.sign(private_key, hashes.SHA256()) + +def add_key_identifier(ca_cert): + try: + ski_ext = ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier) + return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ski_ext.value) + except: + return x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key()) + +def create_certificate(cert_req, ca_cert, ca_private_key, valid_days=365, cert_type='server', is_ca=False, is_sub_ca=False): + ext_key_usage = [] + if is_ca: + ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH] + elif cert_type == 'client': + ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH] + elif cert_type == 'server': + ext_key_usage = [ExtendedKeyUsageOID.SERVER_AUTH] + + builder = x509.CertificateBuilder() \ + .subject_name(cert_req.subject) \ + .issuer_name(ca_cert.subject) \ + .public_key(cert_req.public_key()) \ + .serial_number(x509.random_serial_number()) \ + .not_valid_before(datetime.datetime.utcnow()) \ + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=int(valid_days))) + + builder = builder.add_extension(x509.BasicConstraints(ca=is_ca, path_length=0 if is_sub_ca else None), critical=True) + builder = builder.add_extension(x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=is_ca, + crl_sign=is_ca, + encipher_only=False, + decipher_only=False), critical=True) + builder = builder.add_extension(x509.ExtendedKeyUsage(ext_key_usage), critical=False) + builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(cert_req.public_key()), critical=False) + + if not is_ca or is_sub_ca: + builder = builder.add_extension(add_key_identifier(ca_cert), critical=False) + + for ext in cert_req.extensions: + builder = builder.add_extension(ext.value, critical=False) + + return builder.sign(ca_private_key, hashes.SHA256()) + +def create_certificate_revocation_list(ca_cert, ca_private_key, serial_numbers=[]): + if not serial_numbers: + return False + + builder = x509.CertificateRevocationListBuilder() \ + .issuer_name(ca_cert.subject) \ + .last_update(datetime.datetime.today()) \ + .next_update(datetime.datetime.today() + datetime.timedelta(1, 0, 0)) + + for serial_number in serial_numbers: + revoked_cert = x509.RevokedCertificateBuilder() \ + .serial_number(serial_number) \ + .revocation_date(datetime.datetime.today()) \ + .build() + builder = builder.add_revoked_certificate(revoked_cert) + + return builder.sign(private_key=ca_private_key, algorithm=hashes.SHA256()) + +def create_dh_parameters(bits=2048): + if not bits or bits < 512: + print("Invalid DH parameter key size") + return False + + return dh.generate_parameters(generator=2, key_size=int(bits)) + +# Wrap functions + +def wrap_public_key(raw_data): + return KEY_PUB_BEGIN + raw_data + KEY_PUB_END + +def wrap_private_key(raw_data, passphrase=None): + return (KEY_ENC_BEGIN if passphrase else KEY_BEGIN) + raw_data + (KEY_ENC_END if passphrase else KEY_END) + +def wrap_certificate_request(raw_data): + return CSR_BEGIN + raw_data + CSR_END + +def wrap_certificate(raw_data): + return CERT_BEGIN + raw_data + CERT_END + +def wrap_crl(raw_data): + return CRL_BEGIN + raw_data + CRL_END + +def wrap_dh_parameters(raw_data): + return DH_BEGIN + raw_data + DH_END + +def wrap_openvpn_key(raw_data, version='1'): + return OVPN_BEGIN.format(version) + raw_data + OVPN_END.format(version) + +# Load functions + +def load_public_key(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_public_key(raw_data) + + try: + return serialization.load_pem_public_key(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_private_key(raw_data, passphrase=None, wrap_tags=True): + if wrap_tags: + raw_data = wrap_private_key(raw_data, passphrase) + + if passphrase: + passphrase = bytes(passphrase, 'utf-8') + + try: + return serialization.load_pem_private_key(bytes(raw_data, 'utf-8'), password=passphrase) + except ValueError: + return False + +def load_certificate_request(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_certificate_request(raw_data) + + try: + return x509.load_pem_x509_csr(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_certificate(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_certificate(raw_data) + + try: + return x509.load_pem_x509_certificate(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_crl(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_crl(raw_data) + + try: + return x509.load_pem_x509_crl(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +def load_dh_parameters(raw_data, wrap_tags=True): + if wrap_tags: + raw_data = wrap_dh_parameters(raw_data) + + try: + return serialization.load_pem_parameters(bytes(raw_data, 'utf-8')) + except ValueError: + return False + +# Verify + +def is_ca_certificate(cert): + if not cert: + return False + + try: + ext = cert.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS) + return ext.value.ca + except ExtensionNotFound: + return False + +def verify_certificate(cert, ca_cert): + # Verify certificate was signed by specified CA + if ca_cert.subject != cert.issuer: + return False + + ca_public_key = ca_cert.public_key() + try: + if isinstance(ca_public_key, rsa.RSAPublicKeyWithSerialization): + ca_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + padding=padding.PKCS1v15(), + algorithm=cert.signature_hash_algorithm) + elif isinstance(ca_public_key, dsa.DSAPublicKeyWithSerialization): + ca_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + algorithm=cert.signature_hash_algorithm) + elif isinstance(ca_public_key, ec.EllipticCurvePublicKeyWithSerialization): + ca_public_key.verify( + cert.signature, + cert.tbs_certificate_bytes, + signature_algorithm=ec.ECDSA(cert.signature_hash_algorithm)) + else: + return False # We cannot verify it + return True + except InvalidSignature: + return False diff --git a/python/vyos/remote.py b/python/vyos/remote.py index f683a6d5a..e972050b7 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -15,90 +15,176 @@ from ftplib import FTP import os +import shutil import socket +import stat import sys import tempfile import urllib.parse -import urllib.request +import urllib.request as urlreq -from vyos.util import cmd, ask_yes_no +from vyos.template import get_ip +from vyos.template import ip_from_cidr +from vyos.template import is_interface +from vyos.template import is_ipv6 +from vyos.util import cmd +from vyos.util import ask_yes_no +from vyos.util import print_error +from vyos.util import make_progressbar +from vyos.util import make_incremental_progressbar from vyos.version import get_version -from paramiko import SSHClient, SSHException, MissingHostKeyPolicy +from paramiko import SSHClient +from paramiko import SSHException +from paramiko import MissingHostKeyPolicy - -known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') +# This is a hardcoded path and no environment variable can change it. +KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts') +CHUNK_SIZE = 8192 class InteractivePolicy(MissingHostKeyPolicy): """ Policy for interactively querying the user on whether to proceed with - SSH connections to unknown hosts. + SSH connections to unknown hosts. """ def missing_host_key(self, client, hostname, key): - print(f"Host '{hostname}' not found in known hosts.") - print('Fingerprint: ' + key.get_fingerprint().hex()) + print_error(f"Host '{hostname}' not found in known hosts.") + print_error('Fingerprint: ' + key.get_fingerprint().hex()) if ask_yes_no('Do you wish to continue?'): - if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): + if client._host_keys_filename\ + and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): client._host_keys.add(hostname, key.get_name(), key) client.save_host_keys(client._host_keys_filename) else: raise SSHException(f"Cannot connect to unknown host '{hostname}'.") + +## Helper routines +def get_authentication_variables(default_username=None, default_password=None): + """ + Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and + return the defaults provided if environment variables are empty or nonexistent. + """ + username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD') + # Fall back to defaults if the username variable doesn't exist or is an empty string. + # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`, + # as we want the username and the password to have the same behaviour. + if not username: + return default_username, default_password + else: + return username, password + +def get_source_address(source): + """ + Take a string vaguely indicating an origin source (interface, hostname or IP address), + return a tuple in the format `(source_pair, address_family)` where + `source_pair` is `(source_address, source_port)`. + """ + # TODO: Properly distinguish between IPv4 and IPv6. + port = 0 + if is_interface(source): + source = ip_from_cidr(get_ip(source)[0]) + if is_ipv6(source): + return (source, port), socket.AF_INET6 + else: + return (socket.gethostbyname(source), port), socket.AF_INET + +def get_port_from_url(url): + """ + Return the port number from the given `url` named tuple, fall back to + the default if there isn't one. + """ + defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\ + "ssh": 22, "scp": 22, "sftp": 22} + if url.port: + return url.port + else: + return defaults[url.scheme] + + ## FTP routines -def transfer_ftp(mode, local_path, hostname, remote_path,\ - username='anonymous', password='', port=21, source=None): - with FTP(source_address=source) as conn: +def upload_ftp(local_path, hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source_pair=None, progressbar=False): + size = os.path.getsize(local_path) + with FTP(source_address=source_pair) as conn: conn.connect(hostname, port) conn.login(username, password) - if mode == 'upload': - with open(local_path, 'rb') as file: - conn.storbinary(f'STOR {remote_path}', file) - elif mode == 'download': - with open(local_path, 'wb') as file: - conn.retrbinary(f'RETR {remote_path}', file.write) - elif mode == 'size': - size = conn.size(remote_path) - if size: - return size + with open(local_path, 'rb') as file: + if progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / size) + next(progress) + callback = lambda block: next(progress) else: - # SIZE is an extension to the FTP specification, although it's extremely common. - raise ValueError('Failed to receive file size from FTP server. \ - Perhaps the server does not implement the SIZE command?') + callback = None + conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback) -def upload_ftp(*args, **kwargs): - transfer_ftp('upload', *args, **kwargs) +def download_ftp(local_path, hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source_pair=None, progressbar=False): + with FTP(source_address=source_pair) as conn: + conn.connect(hostname, port) + conn.login(username, password) + size = conn.size(remote_path) + with open(local_path, 'wb') as file: + # No progressbar if we can't determine the size. + if progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / size) + next(progress) + callback = lambda block: (file.write(block), next(progress)) + else: + callback = file.write + conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE) -def download_ftp(*args, **kwargs): - transfer_ftp('download', *args, **kwargs) +def get_ftp_file_size(hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source_pair=None): + with FTP(source_address=source) as conn: + conn.connect(hostname, port) + conn.login(username, password) + size = conn.size(remote_path) + if size: + return size + else: + # SIZE is an extension to the FTP specification, although it's extremely common. + raise ValueError('Failed to receive file size from FTP server. \ + Perhaps the server does not implement the SIZE command?') -def get_ftp_file_size(*args, **kwargs): - return transfer_ftp('size', None, *args, **kwargs) ## SFTP/SCP routines def transfer_sftp(mode, local_path, hostname, remote_path,\ - username=None, password=None, port=22, source=None): + username=None, password=None, port=22,\ + source_tuple=None, progressbar=False): sock = None - if source: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind((source, 0)) + if source_tuple: + (source_address, source_port), address_family = source_tuple + sock = socket.socket(address_family, socket.SOCK_STREAM) + sock.bind((source_address, source_port)) sock.connect((hostname, port)) - try: - with SSHClient() as ssh: - ssh.load_system_host_keys() - if os.path.exists(known_hosts_file): - ssh.load_host_keys(known_hosts_file) - ssh.set_missing_host_key_policy(InteractivePolicy()) - ssh.connect(hostname, port, username, password, sock=sock) - with ssh.open_sftp() as sftp: - if mode == 'upload': - sftp.put(local_path, remote_path) - elif mode == 'download': - sftp.get(remote_path, local_path) - elif mode == 'size': - return sftp.stat(remote_path).st_size - finally: - if sock: - sock.shutdown() - sock.close() + callback = make_progressbar() if progressbar else None + with SSHClient() as ssh: + ssh.load_system_host_keys() + if os.path.exists(KNOWN_HOSTS_FILE): + ssh.load_host_keys(KNOWN_HOSTS_FILE) + ssh.set_missing_host_key_policy(InteractivePolicy()) + ssh.connect(hostname, port, username, password, sock=sock) + with ssh.open_sftp() as sftp: + if mode == 'upload': + try: + # If the remote path is a directory, use the original filename. + if stat.S_ISDIR(sftp.stat(remote_path).st_mode): + path = os.path.join(remote_path, os.path.basename(local_path)) + # A file exists at this destination. We're simply going to clobber it. + else: + path = remote_path + # This path doesn't point at any existing file. We can freely use this filename. + except IOError: + path = remote_path + finally: + sftp.put(local_path, path, callback=callback) + elif mode == 'download': + sftp.get(remote_path, local_path, callback=callback) + elif mode == 'size': + return sftp.stat(remote_path).st_size def upload_sftp(*args, **kwargs): transfer_sftp('upload', *args, **kwargs) @@ -109,32 +195,70 @@ def download_sftp(*args, **kwargs): def get_sftp_file_size(*args, **kwargs): return transfer_sftp('size', None, *args, **kwargs) + ## TFTP routines -def upload_tftp(local_path, hostname, remote_path, port=69, source=None): +def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'rb') as file: - cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\ + cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\ stderr=None, input=file.read()).encode() -def download_tftp(local_path, hostname, remote_path, port=69, source=None): +def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + # Not really applicable but we pass it for the sake of uniformity. + progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'wb') as file: - file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\ + file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ stderr=None).encode()) # get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP, -# as TFTP does not specify a SIZE command. +# as TFTP does not specify a SIZE command. + ## HTTP(S) routines -def download_http(urlstring, local_path): - request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) - with open(local_path, 'wb') as file: - with urllib.request.urlopen(request) as response: - file.write(response.read()) +def install_request_opener(urlstring, username, password): + """ + Take `username` and `password` strings and install the appropriate + password manager to `urllib.request.urlopen()` for the given `urlstring`. + """ + manager = urlreq.HTTPPasswordMgrWithDefaultRealm() + manager.add_password(None, urlstring, username, password) + urlreq.install_opener(urlreq.build_opener(urlreq.HTTPBasicAuthHandler(manager))) + +# upload_http() is unimplemented. -def get_http_file_size(urlstring): - request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) - with urllib.request.urlopen(request) as response: +def download_http(local_path, urlstring, username=None, password=None, progressbar=False): + """ + Download the file from from `urlstring` to `local_path`. + Optionally takes `username` and `password` for authentication. + """ + request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) + if username: + install_request_opener(urlstring, username, password) + with open(local_path, 'wb') as file, urlreq.urlopen(request) as response: + size = response.getheader('Content-Length') + if progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / int(size)) + next(progress) + for chunk in iter(lambda: response.read(CHUNK_SIZE), b''): + file.write(chunk) + next(progress) + next(progress) + # If we can't determine the size or if a progress bar wasn't requested, + # we can let `shutil` take care of the copying. + else: + shutil.copyfileobj(response, file) + +def get_http_file_size(urlstring, username=None, password=None): + """ + Return the size of the file from `urlstring` in terms of number of bytes. + Optionally takes `username` and `password` for authentication. + """ + request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) + if username: + install_request_opener(urlstring, username, password) + with urlreq.urlopen(request) as response: size = response.getheader('Content-Length') if size: return int(size) @@ -142,69 +266,96 @@ def get_http_file_size(urlstring): else: raise ValueError('Failed to receive file size from HTTP server.') -# Dynamic dispatchers -def download(local_path, urlstring, source=None): + +## Dynamic dispatchers +def download(local_path, urlstring, source=None, progressbar=False): """ - Dispatch the appropriate download function for the given URL and save to local path. + Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. + Optionally takes a `source` address or interface (not valid for HTTP(S)). + Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP. + Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ url = urllib.parse.urlparse(urlstring) + username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) + if url.scheme == 'http' or url.scheme == 'https': if source: - print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr) - download_http(urlstring, local_path) + print_error('Warning: Custom source address not supported for HTTP connections.') + download_http(local_path, urlstring, username, password, progressbar) elif url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - download_ftp(local_path, url.hostname, url.path, username, url.password, source=source) + source = get_source_address(source)[0] if source else None + username = username if username else 'anonymous' + download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': - download_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source) + source = get_source_address(source) if source else None + download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': - download_tftp(local_path, url.hostname, url.path, source=source) + download_tftp(local_path, url.hostname, url.path, port, source, progressbar) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def upload(local_path, urlstring, source=None): +def upload(local_path, urlstring, source=None, progressbar=False): """ Dispatch the appropriate upload function for the given URL and upload from local path. + Optionally takes a `source` address. + Supports FTP, SFTP, SCP (through SFTP) and TFTP. + Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ url = urllib.parse.urlparse(urlstring) + username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) + if url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - upload_ftp(local_path, url.hostname, url.path, username, url.password, source=source) + username = username if username else 'anonymous' + source = get_source_address(source)[0] if source else None + upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': - upload_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source) + source = get_source_address(source) if source else None + upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': - upload_tftp(local_path, url.hostname, url.path, source=source) + upload_tftp(local_path, url.hostname, url.path, port, source, progressbar) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') def get_remote_file_size(urlstring, source=None): """ - Return the size of the remote file in bytes. + Dispatch the appropriate function to return the size of the remote file from `urlstring` + in terms of number of bytes. + Optionally takes a `source` address (not valid for HTTP(S)). + Supports HTTP, HTTPS, FTP and SFTP (through SFTP). + Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ url = urllib.parse.urlparse(urlstring) + username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) + if url.scheme == 'http' or url.scheme == 'https': - return get_http_file_size(urlstring) + if source: + print_error('Warning: Custom source address not supported for HTTP connections.') + return get_http_file_size(urlstring, username, password) elif url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - return get_ftp_file_size(url.hostname, url.path, username, url.password, source=source) + source = get_source_address(source)[0] if source else None + username = username if username else 'anonymous' + return get_ftp_file_size(url.hostname, url.path, username, password, port, source) elif url.scheme == 'sftp' or url.scheme == 'scp': - return get_sftp_file_size(url.hostname, url.path, url.username, url.password, source=source) + source = get_source_address(source) if source else None + return get_sftp_file_size(url.hostname, url.path, username, password, port, source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') def get_remote_config(urlstring, source=None): """ - Download remote (config) file and return the contents. + Download remote (config) file from `urlstring` and return the contents as a string. Args: remote file URI: - scp://<user>[:<passwd>]@<host>/<file> - sftp://<user>[:<passwd>]@<host>/<file> - http://<host>/<file> - https://<host>/<file> - ftp://[<user>[:<passwd>]@]<host>/<file> - tftp://<host>/<file> + tftp://<host>[:<port>]/<file> + http[s]://<host>[:<port>]/<file> + [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file> + source address (optional): + <interface> + <IP address> """ - url = urllib.parse.urlparse(urlstring) temp = tempfile.NamedTemporaryFile(delete=False).name try: download(temp, urlstring, source) @@ -212,3 +363,41 @@ def get_remote_config(urlstring, source=None): return file.read() finally: os.remove(temp) + +def friendly_download(local_path, urlstring, source=None): + """ + Download from `urlstring` to `local_path` in an informative way. + Checks the storage space before attempting download. + Intended to be called from interactive, user-facing scripts. + """ + destination_directory = os.path.dirname(local_path) + try: + free_space = shutil.disk_usage(destination_directory).free + try: + file_size = get_remote_file_size(urlstring, source) + if file_size < 1024 * 1024: + print_error(f'The file is {file_size / 1024.0:.3f} KiB.') + else: + print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.') + if file_size > free_space: + raise OSError(f'Not enough disk space available in "{destination_directory}".') + except ValueError: + # Can't do a storage check in this case, so we bravely continue. + file_size = 0 + print_error('Could not determine the file size in advance.') + else: + print_error('Downloading...') + download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024) + except KeyboardInterrupt: + print_error('Download aborted by user.') + sys.exit(1) + except: + import traceback + # There are a myriad different reasons a download could fail. + # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...) + # We omit the scary stack trace but print the error nevertheless. + print_error(f'Failed to download {urlstring}.') + traceback.print_exception(*sys.exc_info()[:2], None) + sys.exit(1) + else: + print_error('Download complete.') diff --git a/python/vyos/template.py b/python/vyos/template.py index e1986b1e4..08a5712af 100644 --- a/python/vyos/template.py +++ b/python/vyos/template.py @@ -29,13 +29,17 @@ _FILTERS = {} # reuse Environments with identical settings to improve performance @functools.lru_cache(maxsize=2) -def _get_environment(): +def _get_environment(location=None): + if location is None: + loc_loader=FileSystemLoader(directories["templates"]) + else: + loc_loader=FileSystemLoader(location) env = Environment( # Don't check if template files were modified upon re-rendering auto_reload=False, # Cache up to this number of templates for quick re-rendering cache_size=100, - loader=FileSystemLoader(directories["templates"]), + loader=loc_loader, trim_blocks=True, ) env.filters.update(_FILTERS) @@ -63,7 +67,7 @@ def register_filter(name, func=None): return func -def render_to_string(template, content, formater=None): +def render_to_string(template, content, formater=None, location=None): """Render a template from the template directory, raise on any errors. :param template: the path to the template relative to the template folder @@ -78,7 +82,7 @@ def render_to_string(template, content, formater=None): package is build (recovering the load time and overhead caused by having the file out of the code). """ - template = _get_environment().get_template(template) + template = _get_environment(location).get_template(template) rendered = template.render(content) if formater is not None: rendered = formater(rendered) @@ -93,6 +97,7 @@ def render( permission=None, user=None, group=None, + location=None, ): """Render a template from the template directory to a file, raise on any errors. @@ -109,7 +114,7 @@ def render( # As we are opening the file with 'w', we are performing the rendering before # calling open() to not accidentally erase the file if rendering fails - rendered = render_to_string(template, content, formater) + rendered = render_to_string(template, content, formater, location) # Write to file with open(destination, "w") as file: @@ -375,3 +380,96 @@ def get_ipv4(interface): """ Get interface IPv4 addresses""" from vyos.ifconfig import Interface return Interface(interface).get_addr_v4() + +@register_filter('get_ipv6') +def get_ipv6(interface): + """ Get interface IPv6 addresses""" + from vyos.ifconfig import Interface + return Interface(interface).get_addr_v6() + +@register_filter('get_ip') +def get_ip(interface): + """ Get interface IP addresses""" + from vyos.ifconfig import Interface + return Interface(interface).get_addr() + +@register_filter('get_esp_ike_cipher') +def get_esp_ike_cipher(group_config): + pfs_lut = { + 'dh-group1' : 'modp768', + 'dh-group2' : 'modp1024', + 'dh-group5' : 'modp1536', + 'dh-group14' : 'modp2048', + 'dh-group15' : 'modp3072', + 'dh-group16' : 'modp4096', + 'dh-group17' : 'modp6144', + 'dh-group18' : 'modp8192', + 'dh-group19' : 'ecp256', + 'dh-group20' : 'ecp384', + 'dh-group21' : 'ecp512', + 'dh-group22' : 'modp1024s160', + 'dh-group23' : 'modp2048s224', + 'dh-group24' : 'modp2048s256', + 'dh-group25' : 'ecp192', + 'dh-group26' : 'ecp224', + 'dh-group27' : 'ecp224bp', + 'dh-group28' : 'ecp256bp', + 'dh-group29' : 'ecp384bp', + 'dh-group30' : 'ecp512bp', + 'dh-group31' : 'curve25519', + 'dh-group32' : 'curve448' + } + + ciphers = [] + if 'proposal' in group_config: + for priority, proposal in group_config['proposal'].items(): + # both encryption and hash need to be specified for a proposal + if not {'encryption', 'hash'} <= set(proposal): + continue + + tmp = '{encryption}-{hash}'.format(**proposal) + if 'dh_group' in proposal: + tmp += '-' + pfs_lut[ 'dh-group' + proposal['dh_group'] ] + elif 'pfs' in group_config and group_config['pfs'] != 'disable': + group = group_config['pfs'] + if group_config['pfs'] == 'enable': + group = 'dh-group2' + tmp += '-' + pfs_lut[group] + + ciphers.append(tmp) + return ciphers + +@register_filter('get_uuid') +def get_uuid(interface): + """ Get interface IP addresses""" + from uuid import uuid1 + return uuid1() + +openvpn_translate = { + 'des': 'des-cbc', + '3des': 'des-ede3-cbc', + 'bf128': 'bf-cbc', + 'bf256': 'bf-cbc', + 'aes128gcm': 'aes-128-gcm', + 'aes128': 'aes-128-cbc', + 'aes192gcm': 'aes-192-gcm', + 'aes192': 'aes-192-cbc', + 'aes256gcm': 'aes-256-gcm', + 'aes256': 'aes-256-cbc' +} + +@register_filter('openvpn_cipher') +def get_openvpn_cipher(cipher): + if cipher in openvpn_translate: + return openvpn_translate[cipher].upper() + return cipher.upper() + +@register_filter('openvpn_ncp_ciphers') +def get_openvpn_ncp_ciphers(ciphers): + out = [] + for cipher in ciphers: + if cipher in openvpn_translate: + out.append(openvpn_translate[cipher]) + else: + out.append(cipher) + return ':'.join(out).upper() diff --git a/python/vyos/util.py b/python/vyos/util.py index 2a3f6a228..59f9f1c44 100644 --- a/python/vyos/util.py +++ b/python/vyos/util.py @@ -1,4 +1,4 @@ -# Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2020-2021 VyOS maintainers and contributors <maintainers@vyos.io> # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -22,25 +22,13 @@ import sys # where it is used so it is as local as possible to the execution # - -def _need_sudo(command): - return os.path.basename(command.split()[0]) in ('systemctl', ) - - -def _add_sudo(command): - if _need_sudo(command): - return 'sudo ' + command - return command - - from subprocess import Popen from subprocess import PIPE from subprocess import STDOUT from subprocess import DEVNULL - def popen(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=PIPE, decode='utf-8', autosudo=True): + stdout=PIPE, stderr=PIPE, decode='utf-8'): """ popen is a wrapper helper aound subprocess.Popen with it default setting it will return a tuple (out, err) @@ -79,9 +67,6 @@ def popen(command, flag='', shell=None, input=None, timeout=None, env=None, if not debug.enabled(flag): flag = 'command' - if autosudo: - command = _add_sudo(command) - cmd_msg = f"cmd '{command}'" debug.message(cmd_msg, flag) @@ -98,11 +83,8 @@ def popen(command, flag='', shell=None, input=None, timeout=None, env=None, stdin = PIPE input = input.encode() if type(input) is str else input - p = Popen( - command, - stdin=stdin, stdout=stdout, stderr=stderr, - env=env, shell=use_shell, - ) + p = Popen(command, stdin=stdin, stdout=stdout, stderr=stderr, + env=env, shell=use_shell) pipe = p.communicate(input, timeout) @@ -135,7 +117,7 @@ def popen(command, flag='', shell=None, input=None, timeout=None, env=None, def run(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=DEVNULL, stderr=PIPE, decode='utf-8', autosudo=True): + stdout=DEVNULL, stderr=PIPE, decode='utf-8'): """ A wrapper around popen, which discard the stdout and will return the error code of a command @@ -151,8 +133,8 @@ def run(command, flag='', shell=None, input=None, timeout=None, env=None, def cmd(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=PIPE, decode='utf-8', autosudo=True, - raising=None, message='', expect=[0]): + stdout=PIPE, stderr=PIPE, decode='utf-8', raising=None, message='', + expect=[0]): """ A wrapper around popen, which returns the stdout and will raise the error code of a command @@ -183,7 +165,7 @@ def cmd(command, flag='', shell=None, input=None, timeout=None, env=None, def call(command, flag='', shell=None, input=None, timeout=None, env=None, - stdout=PIPE, stderr=PIPE, decode='utf-8', autosudo=True): + stdout=PIPE, stderr=PIPE, decode='utf-8'): """ A wrapper around popen, which print the stdout and will return the error code of a command @@ -239,7 +221,6 @@ def write_file(fname, data, defaultonfailure=None, user=None, group=None): return defaultonfailure raise e - def read_json(fname, defaultonfailure=None): """ read and json decode the content of a file @@ -459,7 +440,6 @@ def process_running(pid_file): pid = f.read().strip() return pid_exists(int(pid)) - def process_named_running(name): """ Checks if process with given name is running and returns its PID. If Process is not running, return None @@ -470,7 +450,6 @@ def process_named_running(name): return p.pid return None - def seconds_to_human(s, separator=""): """ Converts number of seconds passed to a human-readable interval such as 1w4d18h35m59s @@ -525,6 +504,46 @@ def file_is_persistent(path): absolute = os.path.abspath(os.path.dirname(path)) return re.match(location,absolute) +def wait_for_inotify(file_path, pre_hook=None, event_type=None, timeout=None, sleep_interval=0.1): + """ Waits for an inotify event to occur """ + if not os.path.dirname(file_path): + raise ValueError( + "File path {} does not have a directory part (required for inotify watching)".format(file_path)) + if not os.path.basename(file_path): + raise ValueError( + "File path {} does not have a file part, do not know what to watch for".format(file_path)) + + from inotify.adapters import Inotify + from time import time + from time import sleep + + time_start = time() + + i = Inotify() + i.add_watch(os.path.dirname(file_path)) + + if pre_hook: + pre_hook() + + for event in i.event_gen(yield_nones=True): + if (timeout is not None) and ((time() - time_start) > timeout): + # If the function didn't return until this point, + # the file failed to have been written to and closed within the timeout + raise OSError("Waiting for file {} to be written has failed".format(file_path)) + + # Most such events don't take much time, so it's better to check right away + # and sleep later. + if event is not None: + (_, type_names, path, filename) = event + if filename == os.path.basename(file_path): + if event_type in type_names: + return + sleep(sleep_interval) + +def wait_for_file_write_complete(file_path, pre_hook=None, timeout=None, sleep_interval=0.1): + """ Waits for a process to close a file after opening it in write mode. """ + wait_for_inotify(file_path, + event_type='IN_CLOSE_WRITE', pre_hook=pre_hook, timeout=timeout, sleep_interval=sleep_interval) def commit_in_progress(): """ Not to be used in normal op mode scripts! """ @@ -571,6 +590,25 @@ def wait_for_commit_lock(): while commit_in_progress(): sleep(1) +def ask_input(question, default='', numeric_only=False, valid_responses=[]): + question_out = question + if default: + question_out += f' (Default: {default})' + response = '' + while True: + response = input(question_out + ' ').strip() + if not response and default: + return default + if numeric_only: + if not response.isnumeric(): + print("Invalid value, try again.") + continue + response = int(response) + if valid_responses and response not in valid_responses: + print("Invalid value, try again.") + continue + break + return response def ask_yes_no(question, default=False) -> bool: """Ask a yes/no question via input() and return their answer.""" @@ -672,6 +710,19 @@ def dict_search(path, my_dict): c = c.get(p, {}) return c.get(parts[-1], None) +def dict_search_args(dict_object, *path): + # Traverse dictionary using variable arguments + # Added due to above function not allowing for '.' in the key names + # Example: dict_search_args(some_dict, 'key', 'subkey', 'subsubkey', ...) + if not isinstance(dict_object, dict) or not path: + return None + + for item in path: + if item not in dict_object: + return None + dict_object = dict_object[item] + return dict_object + def get_interface_config(interface): """ Returns the used encapsulation protocol for given interface. If interface does not exist, None is returned. @@ -682,6 +733,16 @@ def get_interface_config(interface): tmp = loads(cmd(f'ip -d -j link show {interface}'))[0] return tmp +def get_interface_address(interface): + """ Returns the used encapsulation protocol for given interface. + If interface does not exist, None is returned. + """ + if not os.path.exists(f'/sys/class/net/{interface}'): + return None + from json import loads + tmp = loads(cmd(f'ip -d -j addr show {interface}'))[0] + return tmp + def get_all_vrfs(): """ Return a dictionary of all system wide known VRF instances """ from json import loads @@ -694,3 +755,58 @@ def get_all_vrfs(): name = entry.pop('name') data[name] = entry return data + +def print_error(str='', end='\n'): + """ + Print `str` to stderr, terminated with `end`. + Used for warnings and out-of-band messages to avoid mangling precious + stdout output. + """ + sys.stderr.write(str) + sys.stderr.write(end) + sys.stderr.flush() + +def make_progressbar(): + """ + Make a procedure that takes two arguments `done` and `total` and prints a + progressbar based on the ratio thereof, whose length is determined by the + width of the terminal. + """ + import shutil, math + col, _ = shutil.get_terminal_size() + col = max(col - 15, 20) + def print_progressbar(done, total): + if done <= total: + increment = total / col + length = math.ceil(done / increment) + percentage = str(math.ceil(100 * done / total)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + # Print a newline so that the subsequent prints don't overwrite the full bar. + if done == total: + print_error() + return print_progressbar + +def make_incremental_progressbar(increment: float): + """ + Make a generator that displays a progressbar that grows monotonically with + every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0. + Intended for FTP and HTTP transfers with stateless callbacks. + """ + print_progressbar = make_progressbar() + total = 0.0 + while total < 1.0: + print_progressbar(total, 1.0) + yield + total += increment + print_progressbar(1, 1) + # Ignore further calls. + while True: + yield + +def is_systemd_service_running(service): + """ Test is a specified systemd service is actually running. + Returns True if service is running, false otherwise. """ + tmp = run(f'systemctl is-active --quiet {service}') + return bool((tmp == 0)) diff --git a/python/vyos/validate.py b/python/vyos/validate.py index 23e88b5ac..0dad2a6cb 100644 --- a/python/vyos/validate.py +++ b/python/vyos/validate.py @@ -49,7 +49,7 @@ def is_intf_addr_assigned(intf, addr): return _is_intf_addr_assigned(intf, ip, mask) return _is_intf_addr_assigned(intf, addr) -def _is_intf_addr_assigned(intf, address, netmask=''): +def _is_intf_addr_assigned(intf, address, netmask=None): """ Verify if the given IPv4/IPv6 address is assigned to specific interface. It can check both a single IP address (e.g. 192.0.2.1 or a assigned CIDR @@ -85,14 +85,14 @@ def _is_intf_addr_assigned(intf, address, netmask=''): continue # we do not have a netmask to compare against, they are the same - if netmask == '': + if not netmask: return True prefixlen = '' if is_ipv4(ip_addr): prefixlen = sum([bin(int(_)).count('1') for _ in ip['netmask'].split('.')]) else: - prefixlen = sum([bin(int(_,16)).count('1') for _ in ip['netmask'].split(':') if _]) + prefixlen = sum([bin(int(_,16)).count('1') for _ in ip['netmask'].split('/')[0].split(':') if _]) if str(prefixlen) == netmask: return True diff --git a/python/vyos/xml/load.py b/python/vyos/xml/load.py index 0965d4220..37479c6e1 100644 --- a/python/vyos/xml/load.py +++ b/python/vyos/xml/load.py @@ -225,6 +225,9 @@ def _format_node(inside, conf, xml): else: _fatal(constraint) + elif 'constraintGroup' in properties: + properties.pop('constraintGroup') + elif 'constraintErrorMessage' in properties: r[kw.error] = properties.pop('constraintErrorMessage') diff --git a/python/vyos/xml/test_xml.py b/python/vyos/xml/test_xml.py index ff55151d2..3a6f0132d 100644 --- a/python/vyos/xml/test_xml.py +++ b/python/vyos/xml/test_xml.py @@ -59,7 +59,7 @@ class TestSearch(TestCase): last = self.xml.traverse("interfaces") self.assertEqual(last, '') self.assertEqual(self.xml.inside, ['interfaces']) - self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wirelessmodem']) + self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wwan']) self.assertEqual(self.xml.filling, False) self.assertEqual(self.xml.word, '') self.assertEqual(self.xml.check, False) @@ -72,7 +72,7 @@ class TestSearch(TestCase): last = self.xml.traverse("interfaces ") self.assertEqual(last, '') self.assertEqual(self.xml.inside, ['interfaces']) - self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wirelessmodem']) + self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wwan']) self.assertEqual(self.xml.filling, False) self.assertEqual(self.xml.word, last) self.assertEqual(self.xml.check, False) @@ -85,7 +85,7 @@ class TestSearch(TestCase): last = self.xml.traverse("interfaces w") self.assertEqual(last, 'w') self.assertEqual(self.xml.inside, ['interfaces']) - self.assertEqual(self.xml.options, ['wireguard', 'wireless', 'wirelessmodem']) + self.assertEqual(self.xml.options, ['wireguard', 'wireless', 'wwan']) self.assertEqual(self.xml.filling, True) self.assertEqual(self.xml.word, last) self.assertEqual(self.xml.check, True) @@ -276,4 +276,4 @@ class TestSearch(TestCase): self.assertEqual(self.xml.filled, True) self.assertEqual(self.xml.plain, False) - # Need to add a check for a valuless leafNode
\ No newline at end of file + # Need to add a check for a valuless leafNode |