summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/vyos/airbag.py24
-rw-r--r--python/vyos/certbot_util.py2
-rw-r--r--python/vyos/configquery.py44
-rw-r--r--python/vyos/configsession.py7
-rw-r--r--python/vyos/configverify.py68
-rw-r--r--python/vyos/defaults.py5
-rw-r--r--python/vyos/ethtool.py6
-rw-r--r--python/vyos/ifconfig/__init__.py1
-rw-r--r--python/vyos/ifconfig/bond.py16
-rw-r--r--python/vyos/ifconfig/ethernet.py20
-rwxr-xr-x[-rw-r--r--]python/vyos/ifconfig/interface.py74
-rw-r--r--python/vyos/ifconfig/l2tpv3.py24
-rw-r--r--python/vyos/ifconfig/tunnel.py1
-rw-r--r--python/vyos/ifconfig/vrrp.py9
-rw-r--r--python/vyos/ifconfig/vti.py25
-rw-r--r--python/vyos/ifconfig/wireguard.py12
-rw-r--r--python/vyos/ifconfig/wwan.py28
-rw-r--r--python/vyos/pki.py333
-rw-r--r--python/vyos/remote.py375
-rw-r--r--python/vyos/template.py108
-rw-r--r--python/vyos/util.py174
-rw-r--r--python/vyos/validate.py6
-rw-r--r--python/vyos/xml/load.py3
-rw-r--r--python/vyos/xml/test_xml.py8
24 files changed, 1192 insertions, 181 deletions
diff --git a/python/vyos/airbag.py b/python/vyos/airbag.py
index 510ab7f46..a20f44207 100644
--- a/python/vyos/airbag.py
+++ b/python/vyos/airbag.py
@@ -18,7 +18,6 @@ from datetime import datetime
from vyos import debug
from vyos.logger import syslog
-from vyos.version import get_version
from vyos.version import get_full_version_data
@@ -78,7 +77,7 @@ def bug_report(dtype, value, trace):
information.update({
'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'trace': trace,
- 'instructions': COMMUNITY if 'rolling' in get_version() else SUPPORTED,
+ 'instructions': INSTRUCTIONS,
'note': note,
})
@@ -162,20 +161,13 @@ When reporting problems, please include as much information as possible:
"""
-COMMUNITY = """\
-- Make sure you are running the latest version of the code available at
- https://downloads.vyos.io/rolling/current/amd64/vyos-rolling-latest.iso
-- Consult the forum to see how to handle this issue
- https://forum.vyos.io
-- Join our community on slack where our users exchange help and advice
- https://vyos.slack.com
-""".strip()
-
-SUPPORTED = """\
-- Make sure you are running the latest stable version of VyOS
- the code is available at https://downloads.vyos.io/?dir=release/current
-- Contact us using the online help desk
+INSTRUCTIONS = """\
+- Contact us using the online help desk if you have a subscription:
https://support.vyos.io/
-- Join our community on slack where our users exchange help and advice
+- Make sure you are running the latest version of VyOS available at:
+ https://vyos.net/get/
+- Consult the community forum to see how to handle this issue:
+ https://forum.vyos.io
+- Join us on Slack where our users exchange help and advice:
https://vyos.slack.com
""".strip()
diff --git a/python/vyos/certbot_util.py b/python/vyos/certbot_util.py
index df42d4780..bcb78381f 100644
--- a/python/vyos/certbot_util.py
+++ b/python/vyos/certbot_util.py
@@ -1,7 +1,7 @@
# certbot_util -- adaptation of certbot_nginx name matching functions for VyOS
# https://github.com/certbot/certbot/blob/master/LICENSE.txt
-from certbot_nginx import parser
+from certbot_nginx._internal import parser
NAME_RANK = 0
START_WILDCARD_RANK = 1
diff --git a/python/vyos/configquery.py b/python/vyos/configquery.py
index ed7346f1f..1cdcbcf39 100644
--- a/python/vyos/configquery.py
+++ b/python/vyos/configquery.py
@@ -18,9 +18,16 @@ A small library that allows querying existence or value(s) of config
settings from op mode, and execution of arbitrary op mode commands.
'''
+import re
+import json
+from copy import deepcopy
from subprocess import STDOUT
-from vyos.util import popen
+import vyos.util
+import vyos.xml
+from vyos.config import Config
+from vyos.configtree import ConfigTree
+from vyos.configsource import ConfigSourceSession
class ConfigQueryError(Exception):
pass
@@ -51,32 +58,59 @@ class CliShellApiConfigQuery(GenericConfigQuery):
def exists(self, path: list):
cmd = ' '.join(path)
- (_, err) = popen(f'cli-shell-api existsActive {cmd}')
+ (_, err) = vyos.util.popen(f'cli-shell-api existsActive {cmd}')
if err:
return False
return True
def value(self, path: list):
cmd = ' '.join(path)
- (out, err) = popen(f'cli-shell-api returnActiveValue {cmd}')
+ (out, err) = vyos.util.popen(f'cli-shell-api returnActiveValue {cmd}')
if err:
raise ConfigQueryError('No value for given path')
return out
def values(self, path: list):
cmd = ' '.join(path)
- (out, err) = popen(f'cli-shell-api returnActiveValues {cmd}')
+ (out, err) = vyos.util.popen(f'cli-shell-api returnActiveValues {cmd}')
if err:
raise ConfigQueryError('No values for given path')
return out
+class ConfigTreeQuery(GenericConfigQuery):
+ def __init__(self):
+ super().__init__()
+
+ config_source = ConfigSourceSession()
+ self.configtree = Config(config_source=config_source)
+
+ def exists(self, path: list):
+ return self.configtree.exists(path)
+
+ def value(self, path: list):
+ return self.configtree.return_value(path)
+
+ def values(self, path: list):
+ return self.configtree.return_values(path)
+
+ def list_nodes(self, path: list):
+ return self.configtree.list_nodes(path)
+
+ def get_config_dict(self, path=[], effective=False, key_mangling=None,
+ get_first_key=False, no_multi_convert=False,
+ no_tag_node_value_mangle=False):
+ return self.configtree.get_config_dict(path, effective=effective,
+ key_mangling=key_mangling, get_first_key=get_first_key,
+ no_multi_convert=no_multi_convert,
+ no_tag_node_value_mangle=no_tag_node_value_mangle)
+
class VbashOpRun(GenericOpRun):
def __init__(self):
super().__init__()
def run(self, path: list, **kwargs):
cmd = ' '.join(path)
- (out, err) = popen(f'. /opt/vyatta/share/vyatta-op/functions/interpreter/vyatta-op-run; _vyatta_op_run {cmd}', stderr=STDOUT, **kwargs)
+ (out, err) = vyos.util.popen(f'. /opt/vyatta/share/vyatta-op/functions/interpreter/vyatta-op-run; _vyatta_op_run {cmd}', stderr=STDOUT, **kwargs)
if err:
raise ConfigQueryError(out)
return out
diff --git a/python/vyos/configsession.py b/python/vyos/configsession.py
index 670e6c7fc..f28ad09c5 100644
--- a/python/vyos/configsession.py
+++ b/python/vyos/configsession.py
@@ -10,14 +10,14 @@
# See the GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License along with this library;
-# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
import os
import re
import sys
import subprocess
-from vyos.util import call
+from vyos.util import is_systemd_service_running
CLI_SHELL_API = '/bin/cli-shell-api'
SET = '/opt/vyatta/sbin/my_set'
@@ -73,8 +73,7 @@ def inject_vyos_env(env):
env['vyos_validators_dir'] = '/usr/libexec/vyos/validators'
# if running the vyos-configd daemon, inject the vyshim env var
- ret = call('systemctl is-active --quiet vyos-configd.service')
- if not ret:
+ if is_systemd_service_running('vyos-configd.service'):
env['vyshim'] = '/usr/sbin/vyshim'
return env
diff --git a/python/vyos/configverify.py b/python/vyos/configverify.py
index 99c472582..4279e6982 100644
--- a/python/vyos/configverify.py
+++ b/python/vyos/configverify.py
@@ -1,4 +1,4 @@
-# Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io>
+# Copyright 2020-2021 VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
@@ -45,6 +45,16 @@ def verify_mtu(config):
raise ConfigError(f'Interface MTU too high, ' \
f'maximum supported MTU is {max_mtu}!')
+def verify_mtu_parent(config, parent):
+ if 'mtu' not in config or 'mtu' not in parent:
+ return
+
+ mtu = int(config['mtu'])
+ parent_mtu = int(parent['mtu'])
+ if mtu > parent_mtu:
+ raise ConfigError(f'Interface MTU ({mtu}) too high, ' \
+ f'parent interface MTU is {parent_mtu}!')
+
def verify_mtu_ipv6(config):
"""
Common helper function used by interface implementations to perform
@@ -139,9 +149,38 @@ def verify_eapol(config):
recurring validation of EAPoL configuration.
"""
if 'eapol' in config:
- if not {'cert_file', 'key_file'} <= set(config['eapol']):
- raise ConfigError('Both cert and key-file must be specified '\
- 'when using EAPoL!')
+ if 'certificate' not in config['eapol']:
+ raise ConfigError('Certificate must be specified when using EAPoL!')
+
+ if 'certificate' not in config['pki']:
+ raise ConfigError('Invalid certificate specified for EAPoL')
+
+ cert_name = config['eapol']['certificate']
+
+ if cert_name not in config['pki']['certificate']:
+ raise ConfigError('Invalid certificate specified for EAPoL')
+
+ cert = config['pki']['certificate'][cert_name]
+
+ if 'certificate' not in cert or 'private' not in cert or 'key' not in cert['private']:
+ raise ConfigError('Invalid certificate/private key specified for EAPoL')
+
+ if 'password_protected' in cert['private']:
+ raise ConfigError('Encrypted private key cannot be used for EAPoL')
+
+ if 'ca_certificate' in config['eapol']:
+ if 'ca' not in config['pki']:
+ raise ConfigError('Invalid CA certificate specified for EAPoL')
+
+ ca_cert_name = config['eapol']['ca_certificate']
+
+ if ca_cert_name not in config['pki']['ca']:
+ raise ConfigError('Invalid CA certificate specified for EAPoL')
+
+ ca_cert = config['pki']['ca'][cert_name]
+
+ if 'certificate' not in ca_cert:
+ raise ConfigError('Invalid CA certificate specified for EAPoL')
def verify_mirror(config):
"""
@@ -156,6 +195,19 @@ def verify_mirror(config):
raise ConfigError(f'Can not mirror "{direction}" traffic back ' \
'the originating interface!')
+def verify_authentication(config):
+ """
+ Common helper function used by interface implementations to perform
+ recurring validation of authentication for either PPPoE or WWAN interfaces.
+
+ If authentication CLI option is defined, both username and password must
+ be set!
+ """
+ if 'authentication' not in config:
+ return
+ if not {'user', 'password'} <= set(config['authentication']):
+ raise ConfigError('Authentication requires both username and ' \
+ 'password to be set!')
def verify_address(config):
"""
@@ -266,6 +318,7 @@ def verify_vlan_config(config):
verify_dhcpv6(vlan)
verify_address(vlan)
verify_vrf(vlan)
+ verify_mtu_parent(vlan, config)
# 802.1ad (Q-in-Q) VLANs
for s_vlan in config.get('vif_s', {}):
@@ -273,12 +326,15 @@ def verify_vlan_config(config):
verify_dhcpv6(s_vlan)
verify_address(s_vlan)
verify_vrf(s_vlan)
+ verify_mtu_parent(s_vlan, config)
for c_vlan in s_vlan.get('vif_c', {}):
c_vlan = s_vlan['vif_c'][c_vlan]
verify_dhcpv6(c_vlan)
verify_address(c_vlan)
verify_vrf(c_vlan)
+ verify_mtu_parent(c_vlan, config)
+ verify_mtu_parent(c_vlan, s_vlan)
def verify_accel_ppp_base_service(config):
"""
@@ -288,7 +344,7 @@ def verify_accel_ppp_base_service(config):
# vertify auth settings
if dict_search('authentication.mode', config) == 'local':
if not dict_search('authentication.local_users', config):
- raise ConfigError('PPPoE local auth mode requires local users to be configured!')
+ raise ConfigError('Authentication mode local requires local users to be configured!')
for user in dict_search('authentication.local_users.username', config):
user_config = config['authentication']['local_users']['username'][user]
@@ -312,7 +368,7 @@ def verify_accel_ppp_base_service(config):
raise ConfigError(f'Missing RADIUS secret key for server "{server}"')
if 'gateway_address' not in config:
- raise ConfigError('PPPoE server requires gateway-address to be configured!')
+ raise ConfigError('Server requires gateway-address to be configured!')
if 'name_server_ipv4' in config:
if len(config['name_server_ipv4']) > 2:
diff --git a/python/vyos/defaults.py b/python/vyos/defaults.py
index 9921e3b5f..03006c383 100644
--- a/python/vyos/defaults.py
+++ b/python/vyos/defaults.py
@@ -22,7 +22,10 @@ directories = {
"migrate": "/opt/vyatta/etc/config-migrate/migrate",
"log": "/var/log/vyatta",
"templates": "/usr/share/vyos/templates/",
- "certbot": "/config/auth/letsencrypt"
+ "certbot": "/config/auth/letsencrypt",
+ "api_schema": "/usr/libexec/vyos/services/api/graphql/graphql/schema/",
+ "api_templates": "/usr/libexec/vyos/services/api/graphql/recipes/templates/"
+
}
cfg_group = 'vyattacfg'
diff --git a/python/vyos/ethtool.py b/python/vyos/ethtool.py
index 136feae8d..bc103959a 100644
--- a/python/vyos/ethtool.py
+++ b/python/vyos/ethtool.py
@@ -57,7 +57,11 @@ class Ethtool:
if ':' in line:
key, value = [s.strip() for s in line.strip().split(":", 1)]
key = key.lower().replace(' ', '_')
- self.ring_buffers[key] = int(value)
+ # T3645: ethtool version used on Debian Bullseye changed the
+ # output format from 0 -> n/a. As we are only interested in the
+ # tx/rx keys we do not care about RX Mini/Jumbo.
+ if value.isdigit():
+ self.ring_buffers[key] = int(value)
def is_fixed_lro(self):
diff --git a/python/vyos/ifconfig/__init__.py b/python/vyos/ifconfig/__init__.py
index e9da1e9f5..2d3e406ac 100644
--- a/python/vyos/ifconfig/__init__.py
+++ b/python/vyos/ifconfig/__init__.py
@@ -35,3 +35,4 @@ from vyos.ifconfig.tunnel import TunnelIf
from vyos.ifconfig.wireless import WiFiIf
from vyos.ifconfig.l2tpv3 import L2TPv3If
from vyos.ifconfig.macsec import MACsecIf
+from vyos.ifconfig.wwan import WWANIf
diff --git a/python/vyos/ifconfig/bond.py b/python/vyos/ifconfig/bond.py
index 233d53688..2b9afe109 100644
--- a/python/vyos/ifconfig/bond.py
+++ b/python/vyos/ifconfig/bond.py
@@ -86,6 +86,9 @@ class BondIf(Interface):
_sysfs_get = {**Interface._sysfs_get, **{
'bond_arp_ip_target': {
'location': '/sys/class/net/{ifname}/bonding/arp_ip_target',
+ },
+ 'bond_mode': {
+ 'location': '/sys/class/net/{ifname}/bonding/mode',
}
}}
@@ -317,6 +320,19 @@ class BondIf(Interface):
return enslaved_ifs
+ def get_mode(self):
+ """
+ Return bond operation mode.
+
+ Example:
+ >>> from vyos.ifconfig import BondIf
+ >>> BondIf('bond0').get_mode()
+ '802.3ad'
+ """
+ mode = self.get_interface('bond_mode')
+ # mode is now "802.3ad 4", we are only interested in "802.3ad"
+ return mode.split()[0]
+
def set_primary(self, interface):
"""
A string (eth0, eth2, etc) specifying which slave is the primary
diff --git a/python/vyos/ifconfig/ethernet.py b/python/vyos/ifconfig/ethernet.py
index b89ca5a5c..07b31a12a 100644
--- a/python/vyos/ifconfig/ethernet.py
+++ b/python/vyos/ifconfig/ethernet.py
@@ -55,6 +55,11 @@ class EthernetIf(Interface):
'possible': lambda i, v: EthernetIf.feature(i, 'gso', v),
# 'shellcmd': 'ethtool -K {ifname} gso {value}',
},
+ 'lro': {
+ 'validate': lambda v: assert_list(v, ['on', 'off']),
+ 'possible': lambda i, v: EthernetIf.feature(i, 'lro', v),
+ # 'shellcmd': 'ethtool -K {ifname} lro {value}',
+ },
'sg': {
'validate': lambda v: assert_list(v, ['on', 'off']),
'possible': lambda i, v: EthernetIf.feature(i, 'sg', v),
@@ -238,6 +243,18 @@ class EthernetIf(Interface):
raise ValueError("Value out of range")
return self.set_interface('gso', 'on' if state else 'off')
+ def set_lro(self, state):
+ """
+ Enable Large Receive offload. State can be either True or False.
+ Example:
+ >>> from vyos.ifconfig import EthernetIf
+ >>> i = EthernetIf('eth0')
+ >>> i.set_lro(True)
+ """
+ if not isinstance(state, bool):
+ raise ValueError("Value out of range")
+ return self.set_interface('lro', 'on' if state else 'off')
+
def set_rps(self, state):
if not isinstance(state, bool):
raise ValueError("Value out of range")
@@ -328,6 +345,9 @@ class EthernetIf(Interface):
# GSO (generic segmentation offload)
self.set_gso(dict_search('offload.gso', config) != None)
+ # LRO (large receive offload)
+ self.set_lro(dict_search('offload.lro', config) != None)
+
# RPS - Receive Packet Steering
self.set_rps(dict_search('offload.rps', config) != None)
diff --git a/python/vyos/ifconfig/interface.py b/python/vyos/ifconfig/interface.py
index 048a2cd19..a1928ba51 100644..100755
--- a/python/vyos/ifconfig/interface.py
+++ b/python/vyos/ifconfig/interface.py
@@ -311,6 +311,28 @@ class Interface(Control):
cmd = 'ip link del dev {ifname}'.format(**self.config)
return self._cmd(cmd)
+ def _set_vrf_ct_zone(self, vrf):
+ """
+ Add/Remove rules in nftables to associate traffic in VRF to an
+ individual conntack zone
+ """
+ if vrf:
+ # Get routing table ID for VRF
+ vrf_table_id = get_interface_config(vrf).get('linkinfo', {}).get(
+ 'info_data', {}).get('table')
+ # Add map element with interface and zone ID
+ if vrf_table_id:
+ self._cmd(
+ f'nft add element inet vrf_zones ct_iface_map {{ "{self.ifname}" : {vrf_table_id} }}'
+ )
+ else:
+ nft_del_element = f'delete element inet vrf_zones ct_iface_map {{ "{self.ifname}" }}'
+ # Check if deleting is possible first to avoid raising errors
+ _, err = self._popen(f'nft -c {nft_del_element}')
+ if not err:
+ # Remove map element
+ self._cmd(f'nft {nft_del_element}')
+
def get_min_mtu(self):
"""
Get hardware minimum supported MTU
@@ -401,6 +423,7 @@ class Interface(Control):
>>> Interface('eth0').set_vrf()
"""
self.set_interface('vrf', vrf)
+ self._set_vrf_ct_zone(vrf)
def set_arp_cache_tmo(self, tmo):
"""
@@ -779,9 +802,7 @@ class Interface(Control):
# Note that currently expanded netmasks are not supported. That means
# 2001:db00::0/24 is a valid argument while 2001:db00::0/ffff:ff00:: not.
# see https://docs.python.org/3/library/ipaddress.html
- bits = bin(
- int(v6_addr['netmask'].replace(':', ''), 16)).count('1')
- prefix = '/' + str(bits)
+ prefix = '/' + v6_addr['netmask'].split('/')[-1]
# we alsoneed to remove the interface suffix on link local
# addresses
@@ -1345,12 +1366,55 @@ class Interface(Control):
# create/update 802.1q VLAN interfaces
for vif_id, vif_config in config.get('vif', {}).items():
+
+ vif_ifname = f'{ifname}.{vif_id}'
+ vif_config['ifname'] = vif_ifname
+
tmp = deepcopy(VLANIf.get_config())
tmp['source_interface'] = ifname
tmp['vlan_id'] = vif_id
- vif_ifname = f'{ifname}.{vif_id}'
- vif_config['ifname'] = vif_ifname
+ # We need to ensure that the string format is consistent, and we need to exclude redundant spaces.
+ sep = ' '
+ if 'egress_qos' in vif_config:
+ # Unwrap strings into arrays
+ egress_qos_array = vif_config['egress_qos'].split()
+ # The split array is spliced according to the fixed format
+ tmp['egress_qos'] = sep.join(egress_qos_array)
+
+ if 'ingress_qos' in vif_config:
+ # Unwrap strings into arrays
+ ingress_qos_array = vif_config['ingress_qos'].split()
+ # The split array is spliced according to the fixed format
+ tmp['ingress_qos'] = sep.join(ingress_qos_array)
+
+ # Since setting the QoS control parameters in the later stage will
+ # not completely delete the old settings,
+ # we still need to delete the VLAN encapsulation interface in order to
+ # ensure that the changed settings are effective.
+ cur_cfg = get_interface_config(vif_ifname)
+ qos_str = ''
+ tmp2 = dict_search('linkinfo.info_data.ingress_qos', cur_cfg)
+ if 'ingress_qos' in tmp and tmp2:
+ for item in tmp2:
+ from_key = item['from']
+ to_key = item['to']
+ qos_str += f'{from_key}:{to_key} '
+ if qos_str != tmp['ingress_qos']:
+ if self.exists(vif_ifname):
+ VLANIf(vif_ifname).remove()
+
+ qos_str = ''
+ tmp2 = dict_search('linkinfo.info_data.egress_qos', cur_cfg)
+ if 'egress_qos' in tmp and tmp2:
+ for item in tmp2:
+ from_key = item['from']
+ to_key = item['to']
+ qos_str += f'{from_key}:{to_key} '
+ if qos_str != tmp['egress_qos']:
+ if self.exists(vif_ifname):
+ VLANIf(vif_ifname).remove()
+
vlan = VLANIf(vif_ifname, **tmp)
vlan.update(vif_config)
diff --git a/python/vyos/ifconfig/l2tpv3.py b/python/vyos/ifconfig/l2tpv3.py
index 7ff0fdd0e..fcd1fbf81 100644
--- a/python/vyos/ifconfig/l2tpv3.py
+++ b/python/vyos/ifconfig/l2tpv3.py
@@ -13,8 +13,28 @@
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
+from time import sleep
+from time import time
+from vyos.util import run
from vyos.ifconfig.interface import Interface
+def wait_for_add_l2tpv3(timeout=10, sleep_interval=1, cmd=None):
+ '''
+ In some cases, we need to wait until local address is assigned.
+ And only then can the l2tpv3 tunnel be configured.
+ For example when ipv6 address in tentative state
+ or we wait for some routing daemon for remote address.
+ '''
+ start_time = time()
+ test_command = cmd
+ while True:
+ if (start_time + timeout) < time():
+ return None
+ result = run(test_command)
+ if result == 0:
+ return True
+ sleep(sleep_interval)
+
@Interface.register
class L2TPv3If(Interface):
"""
@@ -43,7 +63,9 @@ class L2TPv3If(Interface):
cmd += ' encap {encapsulation}'
cmd += ' local {source_address}'
cmd += ' remote {remote}'
- self._cmd(cmd.format(**self.config))
+ c = cmd.format(**self.config)
+ # wait until the local/remote address is available, but no more 10 sec.
+ wait_for_add_l2tpv3(cmd=c)
# setup session
cmd = 'ip l2tp add session name {ifname}'
diff --git a/python/vyos/ifconfig/tunnel.py b/python/vyos/ifconfig/tunnel.py
index 2a266fc9f..64c735824 100644
--- a/python/vyos/ifconfig/tunnel.py
+++ b/python/vyos/ifconfig/tunnel.py
@@ -62,6 +62,7 @@ class TunnelIf(Interface):
mapping_ipv4 = {
'parameters.ip.key' : 'key',
'parameters.ip.no_pmtu_discovery' : 'nopmtudisc',
+ 'parameters.ip.ignore_df' : 'ignore-df',
'parameters.ip.tos' : 'tos',
'parameters.ip.ttl' : 'ttl',
'parameters.erspan.direction' : 'erspan_dir',
diff --git a/python/vyos/ifconfig/vrrp.py b/python/vyos/ifconfig/vrrp.py
index d3e9d5df2..b522cc1ab 100644
--- a/python/vyos/ifconfig/vrrp.py
+++ b/python/vyos/ifconfig/vrrp.py
@@ -92,11 +92,14 @@ class VRRP(object):
try:
# send signal to generate the configuration file
pid = util.read_file(cls.location['pid'])
- os.kill(int(pid), cls._signal[what])
+ util.wait_for_file_write_complete(fname,
+ pre_hook=(lambda: os.kill(int(pid), cls._signal[what])),
+ timeout=30)
- # should look for file size change?
- sleep(0.2)
return util.read_file(fname)
+ except OSError:
+ # raised by vyos.util.read_file
+ raise VRRPNoData("VRRP data is not available (wait time exceeded)")
except FileNotFoundError:
raise VRRPNoData("VRRP data is not available (process not running or no active groups)")
except Exception:
diff --git a/python/vyos/ifconfig/vti.py b/python/vyos/ifconfig/vti.py
index e2090c889..470ebbff3 100644
--- a/python/vyos/ifconfig/vti.py
+++ b/python/vyos/ifconfig/vti.py
@@ -14,6 +14,7 @@
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
from vyos.ifconfig.interface import Interface
+from vyos.util import dict_search
@Interface.register
class VTIIf(Interface):
@@ -25,3 +26,27 @@ class VTIIf(Interface):
'prefixes': ['vti', ],
},
}
+
+ def _create(self):
+ # This table represents a mapping from VyOS internal config dict to
+ # arguments used by iproute2. For more information please refer to:
+ # - https://man7.org/linux/man-pages/man8/ip-link.8.html
+ # - https://man7.org/linux/man-pages/man8/ip-tunnel.8.html
+ mapping = {
+ 'source_interface' : 'dev',
+ }
+
+ if_id = self.ifname.lstrip('vti')
+ cmd = f'ip link add {self.ifname} type xfrm if_id {if_id}'
+ for vyos_key, iproute2_key in mapping.items():
+ # dict_search will return an empty dict "{}" for valueless nodes like
+ # "parameters.nolearning" - thus we need to test the nodes existence
+ # by using isinstance()
+ tmp = dict_search(vyos_key, self.config)
+ if isinstance(tmp, dict):
+ cmd += f' {iproute2_key}'
+ elif tmp != None:
+ cmd += f' {iproute2_key} {tmp}'
+
+ self._cmd(cmd.format(**self.config))
+ self.set_interface('admin_state', 'down')
diff --git a/python/vyos/ifconfig/wireguard.py b/python/vyos/ifconfig/wireguard.py
index e5b9c4408..c4cf2fbbf 100644
--- a/python/vyos/ifconfig/wireguard.py
+++ b/python/vyos/ifconfig/wireguard.py
@@ -95,7 +95,7 @@ class WireGuardOperational(Operational):
for peer in c.list_effective_nodes(["peer"]):
if wgdump['peers']:
- pubkey = c.return_effective_value(["peer", peer, "pubkey"])
+ pubkey = c.return_effective_value(["peer", peer, "public_key"])
if pubkey in wgdump['peers']:
wgpeer = wgdump['peers'][pubkey]
@@ -194,11 +194,15 @@ class WireGuardIf(Interface):
peer = config['peer_remove'][tmp]
peer['ifname'] = config['ifname']
- cmd = 'wg set {ifname} peer {pubkey} remove'
+ cmd = 'wg set {ifname} peer {public_key} remove'
self._cmd(cmd.format(**peer))
+ config['private_key_file'] = '/tmp/tmp.wireguard.key'
+ with open(config['private_key_file'], 'w') as f:
+ f.write(config['private_key'])
+
# Wireguard base command is identical for every peer
- base_cmd = 'wg set {ifname} private-key {private_key}'
+ base_cmd = 'wg set {ifname} private-key {private_key_file}'
if 'port' in config:
base_cmd += ' listen-port {port}'
if 'fwmark' in config:
@@ -210,7 +214,7 @@ class WireGuardIf(Interface):
peer = config['peer'][tmp]
# start of with a fresh 'wg' command
- cmd = base_cmd + ' peer {pubkey}'
+ cmd = base_cmd + ' peer {public_key}'
# If no PSK is given remove it by using /dev/null - passing keys via
# the shell (usually bash) is considered insecure, thus we use a file
diff --git a/python/vyos/ifconfig/wwan.py b/python/vyos/ifconfig/wwan.py
new file mode 100644
index 000000000..f18959a60
--- /dev/null
+++ b/python/vyos/ifconfig/wwan.py
@@ -0,0 +1,28 @@
+# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+from vyos.ifconfig.interface import Interface
+
+@Interface.register
+class WWANIf(Interface):
+ iftype = 'wwan'
+ definition = {
+ **Interface.definition,
+ **{
+ 'section': 'wwan',
+ 'prefixes': ['wwan', ],
+ 'eternal': 'wwan[0-9]+$',
+ },
+ }
diff --git a/python/vyos/pki.py b/python/vyos/pki.py
new file mode 100644
index 000000000..68ad73bf2
--- /dev/null
+++ b/python/vyos/pki.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python3
+#
+# Copyright (C) 2021 VyOS maintainers and contributors
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 or later as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import datetime
+import ipaddress
+
+from cryptography import x509
+from cryptography.exceptions import InvalidSignature
+from cryptography.x509.extensions import ExtensionNotFound
+from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID, ExtensionOID
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import dh
+from cryptography.hazmat.primitives.asymmetric import dsa
+from cryptography.hazmat.primitives.asymmetric import ec
+from cryptography.hazmat.primitives.asymmetric import padding
+from cryptography.hazmat.primitives.asymmetric import rsa
+
+CERT_BEGIN='-----BEGIN CERTIFICATE-----\n'
+CERT_END='\n-----END CERTIFICATE-----'
+KEY_BEGIN='-----BEGIN PRIVATE KEY-----\n'
+KEY_END='\n-----END PRIVATE KEY-----'
+KEY_ENC_BEGIN='-----BEGIN ENCRYPTED PRIVATE KEY-----\n'
+KEY_ENC_END='\n-----END ENCRYPTED PRIVATE KEY-----'
+KEY_PUB_BEGIN='-----BEGIN PUBLIC KEY-----\n'
+KEY_PUB_END='\n-----END PUBLIC KEY-----'
+CRL_BEGIN='-----BEGIN X509 CRL-----\n'
+CRL_END='\n-----END X509 CRL-----'
+CSR_BEGIN='-----BEGIN CERTIFICATE REQUEST-----\n'
+CSR_END='\n-----END CERTIFICATE REQUEST-----'
+DH_BEGIN='-----BEGIN DH PARAMETERS-----\n'
+DH_END='\n-----END DH PARAMETERS-----'
+OVPN_BEGIN = '-----BEGIN OpenVPN Static key V{0}-----\n'
+OVPN_END = '\n-----END OpenVPN Static key V{0}-----'
+
+# Print functions
+
+encoding_map = {
+ 'PEM': serialization.Encoding.PEM,
+ 'OpenSSH': serialization.Encoding.OpenSSH
+}
+
+public_format_map = {
+ 'SubjectPublicKeyInfo': serialization.PublicFormat.SubjectPublicKeyInfo,
+ 'OpenSSH': serialization.PublicFormat.OpenSSH
+}
+
+private_format_map = {
+ 'PKCS8': serialization.PrivateFormat.PKCS8,
+ 'OpenSSH': serialization.PrivateFormat.OpenSSH
+}
+
+def encode_certificate(cert):
+ return cert.public_bytes(encoding=serialization.Encoding.PEM).decode('utf-8')
+
+def encode_public_key(cert, encoding='PEM', key_format='SubjectPublicKeyInfo'):
+ if encoding not in encoding_map:
+ encoding = 'PEM'
+ if key_format not in public_format_map:
+ key_format = 'SubjectPublicKeyInfo'
+ return cert.public_bytes(
+ encoding=encoding_map[encoding],
+ format=public_format_map[key_format]).decode('utf-8')
+
+def encode_private_key(private_key, encoding='PEM', key_format='PKCS8', passphrase=None):
+ if encoding not in encoding_map:
+ encoding = 'PEM'
+ if key_format not in private_format_map:
+ key_format = 'PKCS8'
+ encryption = serialization.NoEncryption() if not passphrase else serialization.BestAvailableEncryption(bytes(passphrase, 'utf-8'))
+ return private_key.private_bytes(
+ encoding=encoding_map[encoding],
+ format=private_format_map[key_format],
+ encryption_algorithm=encryption).decode('utf-8')
+
+def encode_dh_parameters(dh_parameters):
+ return dh_parameters.parameter_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.ParameterFormat.PKCS3).decode('utf-8')
+
+# EC Helper
+
+def get_elliptic_curve(size):
+ curve_func = None
+ name = f'SECP{size}R1'
+ if hasattr(ec, name):
+ curve_func = getattr(ec, name)
+ else:
+ curve_func = ec.SECP256R1() # Default to SECP256R1
+ return curve_func()
+
+# Creation functions
+
+def create_private_key(key_type, key_size=None):
+ private_key = None
+ if key_type == 'rsa':
+ private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
+ elif key_type == 'dsa':
+ private_key = dsa.generate_private_key(key_size=key_size)
+ elif key_type == 'ec':
+ curve = get_elliptic_curve(key_size)
+ private_key = ec.generate_private_key(curve)
+ return private_key
+
+def create_certificate_request(subject, private_key, subject_alt_names=[]):
+ subject_obj = x509.Name([
+ x509.NameAttribute(NameOID.COUNTRY_NAME, subject['country']),
+ x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, subject['state']),
+ x509.NameAttribute(NameOID.LOCALITY_NAME, subject['locality']),
+ x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject['organization']),
+ x509.NameAttribute(NameOID.COMMON_NAME, subject['common_name'])])
+
+ builder = x509.CertificateSigningRequestBuilder() \
+ .subject_name(subject_obj)
+
+ if subject_alt_names:
+ alt_names = []
+ for obj in subject_alt_names:
+ if isinstance(obj, ipaddress.IPv4Address) or isinstance(obj, ipaddress.IPv6Address):
+ alt_names.append(x509.IPAddress(obj))
+ elif isinstance(obj, str):
+ alt_names.append(x509.DNSName(obj))
+ if alt_names:
+ builder = builder.add_extension(x509.SubjectAlternativeName(alt_names), critical=False)
+
+ return builder.sign(private_key, hashes.SHA256())
+
+def add_key_identifier(ca_cert):
+ try:
+ ski_ext = ca_cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier)
+ return x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(ski_ext.value)
+ except:
+ return x509.AuthorityKeyIdentifier.from_issuer_public_key(ca_cert.public_key())
+
+def create_certificate(cert_req, ca_cert, ca_private_key, valid_days=365, cert_type='server', is_ca=False, is_sub_ca=False):
+ ext_key_usage = []
+ if is_ca:
+ ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH]
+ elif cert_type == 'client':
+ ext_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH]
+ elif cert_type == 'server':
+ ext_key_usage = [ExtendedKeyUsageOID.SERVER_AUTH]
+
+ builder = x509.CertificateBuilder() \
+ .subject_name(cert_req.subject) \
+ .issuer_name(ca_cert.subject) \
+ .public_key(cert_req.public_key()) \
+ .serial_number(x509.random_serial_number()) \
+ .not_valid_before(datetime.datetime.utcnow()) \
+ .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=int(valid_days)))
+
+ builder = builder.add_extension(x509.BasicConstraints(ca=is_ca, path_length=0 if is_sub_ca else None), critical=True)
+ builder = builder.add_extension(x509.KeyUsage(
+ digital_signature=True,
+ content_commitment=False,
+ key_encipherment=False,
+ data_encipherment=False,
+ key_agreement=False,
+ key_cert_sign=is_ca,
+ crl_sign=is_ca,
+ encipher_only=False,
+ decipher_only=False), critical=True)
+ builder = builder.add_extension(x509.ExtendedKeyUsage(ext_key_usage), critical=False)
+ builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(cert_req.public_key()), critical=False)
+
+ if not is_ca or is_sub_ca:
+ builder = builder.add_extension(add_key_identifier(ca_cert), critical=False)
+
+ for ext in cert_req.extensions:
+ builder = builder.add_extension(ext.value, critical=False)
+
+ return builder.sign(ca_private_key, hashes.SHA256())
+
+def create_certificate_revocation_list(ca_cert, ca_private_key, serial_numbers=[]):
+ if not serial_numbers:
+ return False
+
+ builder = x509.CertificateRevocationListBuilder() \
+ .issuer_name(ca_cert.subject) \
+ .last_update(datetime.datetime.today()) \
+ .next_update(datetime.datetime.today() + datetime.timedelta(1, 0, 0))
+
+ for serial_number in serial_numbers:
+ revoked_cert = x509.RevokedCertificateBuilder() \
+ .serial_number(serial_number) \
+ .revocation_date(datetime.datetime.today()) \
+ .build()
+ builder = builder.add_revoked_certificate(revoked_cert)
+
+ return builder.sign(private_key=ca_private_key, algorithm=hashes.SHA256())
+
+def create_dh_parameters(bits=2048):
+ if not bits or bits < 512:
+ print("Invalid DH parameter key size")
+ return False
+
+ return dh.generate_parameters(generator=2, key_size=int(bits))
+
+# Wrap functions
+
+def wrap_public_key(raw_data):
+ return KEY_PUB_BEGIN + raw_data + KEY_PUB_END
+
+def wrap_private_key(raw_data, passphrase=None):
+ return (KEY_ENC_BEGIN if passphrase else KEY_BEGIN) + raw_data + (KEY_ENC_END if passphrase else KEY_END)
+
+def wrap_certificate_request(raw_data):
+ return CSR_BEGIN + raw_data + CSR_END
+
+def wrap_certificate(raw_data):
+ return CERT_BEGIN + raw_data + CERT_END
+
+def wrap_crl(raw_data):
+ return CRL_BEGIN + raw_data + CRL_END
+
+def wrap_dh_parameters(raw_data):
+ return DH_BEGIN + raw_data + DH_END
+
+def wrap_openvpn_key(raw_data, version='1'):
+ return OVPN_BEGIN.format(version) + raw_data + OVPN_END.format(version)
+
+# Load functions
+
+def load_public_key(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_public_key(raw_data)
+
+ try:
+ return serialization.load_pem_public_key(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_private_key(raw_data, passphrase=None, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_private_key(raw_data, passphrase)
+
+ if passphrase:
+ passphrase = bytes(passphrase, 'utf-8')
+
+ try:
+ return serialization.load_pem_private_key(bytes(raw_data, 'utf-8'), password=passphrase)
+ except ValueError:
+ return False
+
+def load_certificate_request(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_certificate_request(raw_data)
+
+ try:
+ return x509.load_pem_x509_csr(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_certificate(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_certificate(raw_data)
+
+ try:
+ return x509.load_pem_x509_certificate(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_crl(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_crl(raw_data)
+
+ try:
+ return x509.load_pem_x509_crl(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+def load_dh_parameters(raw_data, wrap_tags=True):
+ if wrap_tags:
+ raw_data = wrap_dh_parameters(raw_data)
+
+ try:
+ return serialization.load_pem_parameters(bytes(raw_data, 'utf-8'))
+ except ValueError:
+ return False
+
+# Verify
+
+def is_ca_certificate(cert):
+ if not cert:
+ return False
+
+ try:
+ ext = cert.extensions.get_extension_for_oid(ExtensionOID.BASIC_CONSTRAINTS)
+ return ext.value.ca
+ except ExtensionNotFound:
+ return False
+
+def verify_certificate(cert, ca_cert):
+ # Verify certificate was signed by specified CA
+ if ca_cert.subject != cert.issuer:
+ return False
+
+ ca_public_key = ca_cert.public_key()
+ try:
+ if isinstance(ca_public_key, rsa.RSAPublicKeyWithSerialization):
+ ca_public_key.verify(
+ cert.signature,
+ cert.tbs_certificate_bytes,
+ padding=padding.PKCS1v15(),
+ algorithm=cert.signature_hash_algorithm)
+ elif isinstance(ca_public_key, dsa.DSAPublicKeyWithSerialization):
+ ca_public_key.verify(
+ cert.signature,
+ cert.tbs_certificate_bytes,
+ algorithm=cert.signature_hash_algorithm)
+ elif isinstance(ca_public_key, ec.EllipticCurvePublicKeyWithSerialization):
+ ca_public_key.verify(
+ cert.signature,
+ cert.tbs_certificate_bytes,
+ signature_algorithm=ec.ECDSA(cert.signature_hash_algorithm))
+ else:
+ return False # We cannot verify it
+ return True
+ except InvalidSignature:
+ return False
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index f683a6d5a..e972050b7 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -15,90 +15,176 @@
from ftplib import FTP
import os
+import shutil
import socket
+import stat
import sys
import tempfile
import urllib.parse
-import urllib.request
+import urllib.request as urlreq
-from vyos.util import cmd, ask_yes_no
+from vyos.template import get_ip
+from vyos.template import ip_from_cidr
+from vyos.template import is_interface
+from vyos.template import is_ipv6
+from vyos.util import cmd
+from vyos.util import ask_yes_no
+from vyos.util import print_error
+from vyos.util import make_progressbar
+from vyos.util import make_incremental_progressbar
from vyos.version import get_version
-from paramiko import SSHClient, SSHException, MissingHostKeyPolicy
+from paramiko import SSHClient
+from paramiko import SSHException
+from paramiko import MissingHostKeyPolicy
-
-known_hosts_file = os.path.expanduser('~/.ssh/known_hosts')
+# This is a hardcoded path and no environment variable can change it.
+KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts')
+CHUNK_SIZE = 8192
class InteractivePolicy(MissingHostKeyPolicy):
"""
Policy for interactively querying the user on whether to proceed with
- SSH connections to unknown hosts.
+ SSH connections to unknown hosts.
"""
def missing_host_key(self, client, hostname, key):
- print(f"Host '{hostname}' not found in known hosts.")
- print('Fingerprint: ' + key.get_fingerprint().hex())
+ print_error(f"Host '{hostname}' not found in known hosts.")
+ print_error('Fingerprint: ' + key.get_fingerprint().hex())
if ask_yes_no('Do you wish to continue?'):
- if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
+ if client._host_keys_filename\
+ and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
client._host_keys.add(hostname, key.get_name(), key)
client.save_host_keys(client._host_keys_filename)
else:
raise SSHException(f"Cannot connect to unknown host '{hostname}'.")
+
+## Helper routines
+def get_authentication_variables(default_username=None, default_password=None):
+ """
+ Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and
+ return the defaults provided if environment variables are empty or nonexistent.
+ """
+ username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD')
+ # Fall back to defaults if the username variable doesn't exist or is an empty string.
+ # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`,
+ # as we want the username and the password to have the same behaviour.
+ if not username:
+ return default_username, default_password
+ else:
+ return username, password
+
+def get_source_address(source):
+ """
+ Take a string vaguely indicating an origin source (interface, hostname or IP address),
+ return a tuple in the format `(source_pair, address_family)` where
+ `source_pair` is `(source_address, source_port)`.
+ """
+ # TODO: Properly distinguish between IPv4 and IPv6.
+ port = 0
+ if is_interface(source):
+ source = ip_from_cidr(get_ip(source)[0])
+ if is_ipv6(source):
+ return (source, port), socket.AF_INET6
+ else:
+ return (socket.gethostbyname(source), port), socket.AF_INET
+
+def get_port_from_url(url):
+ """
+ Return the port number from the given `url` named tuple, fall back to
+ the default if there isn't one.
+ """
+ defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\
+ "ssh": 22, "scp": 22, "sftp": 22}
+ if url.port:
+ return url.port
+ else:
+ return defaults[url.scheme]
+
+
## FTP routines
-def transfer_ftp(mode, local_path, hostname, remote_path,\
- username='anonymous', password='', port=21, source=None):
- with FTP(source_address=source) as conn:
+def upload_ftp(local_path, hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source_pair=None, progressbar=False):
+ size = os.path.getsize(local_path)
+ with FTP(source_address=source_pair) as conn:
conn.connect(hostname, port)
conn.login(username, password)
- if mode == 'upload':
- with open(local_path, 'rb') as file:
- conn.storbinary(f'STOR {remote_path}', file)
- elif mode == 'download':
- with open(local_path, 'wb') as file:
- conn.retrbinary(f'RETR {remote_path}', file.write)
- elif mode == 'size':
- size = conn.size(remote_path)
- if size:
- return size
+ with open(local_path, 'rb') as file:
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / size)
+ next(progress)
+ callback = lambda block: next(progress)
else:
- # SIZE is an extension to the FTP specification, although it's extremely common.
- raise ValueError('Failed to receive file size from FTP server. \
- Perhaps the server does not implement the SIZE command?')
+ callback = None
+ conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback)
-def upload_ftp(*args, **kwargs):
- transfer_ftp('upload', *args, **kwargs)
+def download_ftp(local_path, hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source_pair=None, progressbar=False):
+ with FTP(source_address=source_pair) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ size = conn.size(remote_path)
+ with open(local_path, 'wb') as file:
+ # No progressbar if we can't determine the size.
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / size)
+ next(progress)
+ callback = lambda block: (file.write(block), next(progress))
+ else:
+ callback = file.write
+ conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE)
-def download_ftp(*args, **kwargs):
- transfer_ftp('download', *args, **kwargs)
+def get_ftp_file_size(hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source_pair=None):
+ with FTP(source_address=source) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ size = conn.size(remote_path)
+ if size:
+ return size
+ else:
+ # SIZE is an extension to the FTP specification, although it's extremely common.
+ raise ValueError('Failed to receive file size from FTP server. \
+ Perhaps the server does not implement the SIZE command?')
-def get_ftp_file_size(*args, **kwargs):
- return transfer_ftp('size', None, *args, **kwargs)
## SFTP/SCP routines
def transfer_sftp(mode, local_path, hostname, remote_path,\
- username=None, password=None, port=22, source=None):
+ username=None, password=None, port=22,\
+ source_tuple=None, progressbar=False):
sock = None
- if source:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind((source, 0))
+ if source_tuple:
+ (source_address, source_port), address_family = source_tuple
+ sock = socket.socket(address_family, socket.SOCK_STREAM)
+ sock.bind((source_address, source_port))
sock.connect((hostname, port))
- try:
- with SSHClient() as ssh:
- ssh.load_system_host_keys()
- if os.path.exists(known_hosts_file):
- ssh.load_host_keys(known_hosts_file)
- ssh.set_missing_host_key_policy(InteractivePolicy())
- ssh.connect(hostname, port, username, password, sock=sock)
- with ssh.open_sftp() as sftp:
- if mode == 'upload':
- sftp.put(local_path, remote_path)
- elif mode == 'download':
- sftp.get(remote_path, local_path)
- elif mode == 'size':
- return sftp.stat(remote_path).st_size
- finally:
- if sock:
- sock.shutdown()
- sock.close()
+ callback = make_progressbar() if progressbar else None
+ with SSHClient() as ssh:
+ ssh.load_system_host_keys()
+ if os.path.exists(KNOWN_HOSTS_FILE):
+ ssh.load_host_keys(KNOWN_HOSTS_FILE)
+ ssh.set_missing_host_key_policy(InteractivePolicy())
+ ssh.connect(hostname, port, username, password, sock=sock)
+ with ssh.open_sftp() as sftp:
+ if mode == 'upload':
+ try:
+ # If the remote path is a directory, use the original filename.
+ if stat.S_ISDIR(sftp.stat(remote_path).st_mode):
+ path = os.path.join(remote_path, os.path.basename(local_path))
+ # A file exists at this destination. We're simply going to clobber it.
+ else:
+ path = remote_path
+ # This path doesn't point at any existing file. We can freely use this filename.
+ except IOError:
+ path = remote_path
+ finally:
+ sftp.put(local_path, path, callback=callback)
+ elif mode == 'download':
+ sftp.get(remote_path, local_path, callback=callback)
+ elif mode == 'size':
+ return sftp.stat(remote_path).st_size
def upload_sftp(*args, **kwargs):
transfer_sftp('upload', *args, **kwargs)
@@ -109,32 +195,70 @@ def download_sftp(*args, **kwargs):
def get_sftp_file_size(*args, **kwargs):
return transfer_sftp('size', None, *args, **kwargs)
+
## TFTP routines
-def upload_tftp(local_path, hostname, remote_path, port=69, source=None):
+def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'rb') as file:
- cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\
+ cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\
stderr=None, input=file.read()).encode()
-def download_tftp(local_path, hostname, remote_path, port=69, source=None):
+def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ # Not really applicable but we pass it for the sake of uniformity.
+ progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'wb') as file:
- file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\
+ file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\
stderr=None).encode())
# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP,
-# as TFTP does not specify a SIZE command.
+# as TFTP does not specify a SIZE command.
+
## HTTP(S) routines
-def download_http(urlstring, local_path):
- request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
- with open(local_path, 'wb') as file:
- with urllib.request.urlopen(request) as response:
- file.write(response.read())
+def install_request_opener(urlstring, username, password):
+ """
+ Take `username` and `password` strings and install the appropriate
+ password manager to `urllib.request.urlopen()` for the given `urlstring`.
+ """
+ manager = urlreq.HTTPPasswordMgrWithDefaultRealm()
+ manager.add_password(None, urlstring, username, password)
+ urlreq.install_opener(urlreq.build_opener(urlreq.HTTPBasicAuthHandler(manager)))
+
+# upload_http() is unimplemented.
-def get_http_file_size(urlstring):
- request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
- with urllib.request.urlopen(request) as response:
+def download_http(local_path, urlstring, username=None, password=None, progressbar=False):
+ """
+ Download the file from from `urlstring` to `local_path`.
+ Optionally takes `username` and `password` for authentication.
+ """
+ request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
+ if username:
+ install_request_opener(urlstring, username, password)
+ with open(local_path, 'wb') as file, urlreq.urlopen(request) as response:
+ size = response.getheader('Content-Length')
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / int(size))
+ next(progress)
+ for chunk in iter(lambda: response.read(CHUNK_SIZE), b''):
+ file.write(chunk)
+ next(progress)
+ next(progress)
+ # If we can't determine the size or if a progress bar wasn't requested,
+ # we can let `shutil` take care of the copying.
+ else:
+ shutil.copyfileobj(response, file)
+
+def get_http_file_size(urlstring, username=None, password=None):
+ """
+ Return the size of the file from `urlstring` in terms of number of bytes.
+ Optionally takes `username` and `password` for authentication.
+ """
+ request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
+ if username:
+ install_request_opener(urlstring, username, password)
+ with urlreq.urlopen(request) as response:
size = response.getheader('Content-Length')
if size:
return int(size)
@@ -142,69 +266,96 @@ def get_http_file_size(urlstring):
else:
raise ValueError('Failed to receive file size from HTTP server.')
-# Dynamic dispatchers
-def download(local_path, urlstring, source=None):
+
+## Dynamic dispatchers
+def download(local_path, urlstring, source=None, progressbar=False):
"""
- Dispatch the appropriate download function for the given URL and save to local path.
+ Dispatch the appropriate download function for the given `urlstring` and save to `local_path`.
+ Optionally takes a `source` address or interface (not valid for HTTP(S)).
+ Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP.
+ Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
url = urllib.parse.urlparse(urlstring)
+ username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
+
if url.scheme == 'http' or url.scheme == 'https':
if source:
- print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr)
- download_http(urlstring, local_path)
+ print_error('Warning: Custom source address not supported for HTTP connections.')
+ download_http(local_path, urlstring, username, password, progressbar)
elif url.scheme == 'ftp':
- username = url.username if url.username else 'anonymous'
- download_ftp(local_path, url.hostname, url.path, username, url.password, source=source)
+ source = get_source_address(source)[0] if source else None
+ username = username if username else 'anonymous'
+ download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- download_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source)
+ source = get_source_address(source) if source else None
+ download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
- download_tftp(local_path, url.hostname, url.path, source=source)
+ download_tftp(local_path, url.hostname, url.path, port, source, progressbar)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
-def upload(local_path, urlstring, source=None):
+def upload(local_path, urlstring, source=None, progressbar=False):
"""
Dispatch the appropriate upload function for the given URL and upload from local path.
+ Optionally takes a `source` address.
+ Supports FTP, SFTP, SCP (through SFTP) and TFTP.
+ Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
url = urllib.parse.urlparse(urlstring)
+ username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
+
if url.scheme == 'ftp':
- username = url.username if url.username else 'anonymous'
- upload_ftp(local_path, url.hostname, url.path, username, url.password, source=source)
+ username = username if username else 'anonymous'
+ source = get_source_address(source)[0] if source else None
+ upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- upload_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source)
+ source = get_source_address(source) if source else None
+ upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
- upload_tftp(local_path, url.hostname, url.path, source=source)
+ upload_tftp(local_path, url.hostname, url.path, port, source, progressbar)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
def get_remote_file_size(urlstring, source=None):
"""
- Return the size of the remote file in bytes.
+ Dispatch the appropriate function to return the size of the remote file from `urlstring`
+ in terms of number of bytes.
+ Optionally takes a `source` address (not valid for HTTP(S)).
+ Supports HTTP, HTTPS, FTP and SFTP (through SFTP).
+ Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
url = urllib.parse.urlparse(urlstring)
+ username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
+
if url.scheme == 'http' or url.scheme == 'https':
- return get_http_file_size(urlstring)
+ if source:
+ print_error('Warning: Custom source address not supported for HTTP connections.')
+ return get_http_file_size(urlstring, username, password)
elif url.scheme == 'ftp':
- username = url.username if url.username else 'anonymous'
- return get_ftp_file_size(url.hostname, url.path, username, url.password, source=source)
+ source = get_source_address(source)[0] if source else None
+ username = username if username else 'anonymous'
+ return get_ftp_file_size(url.hostname, url.path, username, password, port, source)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- return get_sftp_file_size(url.hostname, url.path, url.username, url.password, source=source)
+ source = get_source_address(source) if source else None
+ return get_sftp_file_size(url.hostname, url.path, username, password, port, source)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
def get_remote_config(urlstring, source=None):
"""
- Download remote (config) file and return the contents.
+ Download remote (config) file from `urlstring` and return the contents as a string.
Args:
remote file URI:
- scp://<user>[:<passwd>]@<host>/<file>
- sftp://<user>[:<passwd>]@<host>/<file>
- http://<host>/<file>
- https://<host>/<file>
- ftp://[<user>[:<passwd>]@]<host>/<file>
- tftp://<host>/<file>
+ tftp://<host>[:<port>]/<file>
+ http[s]://<host>[:<port>]/<file>
+ [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file>
+ source address (optional):
+ <interface>
+ <IP address>
"""
- url = urllib.parse.urlparse(urlstring)
temp = tempfile.NamedTemporaryFile(delete=False).name
try:
download(temp, urlstring, source)
@@ -212,3 +363,41 @@ def get_remote_config(urlstring, source=None):
return file.read()
finally:
os.remove(temp)
+
+def friendly_download(local_path, urlstring, source=None):
+ """
+ Download from `urlstring` to `local_path` in an informative way.
+ Checks the storage space before attempting download.
+ Intended to be called from interactive, user-facing scripts.
+ """
+ destination_directory = os.path.dirname(local_path)
+ try:
+ free_space = shutil.disk_usage(destination_directory).free
+ try:
+ file_size = get_remote_file_size(urlstring, source)
+ if file_size < 1024 * 1024:
+ print_error(f'The file is {file_size / 1024.0:.3f} KiB.')
+ else:
+ print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.')
+ if file_size > free_space:
+ raise OSError(f'Not enough disk space available in "{destination_directory}".')
+ except ValueError:
+ # Can't do a storage check in this case, so we bravely continue.
+ file_size = 0
+ print_error('Could not determine the file size in advance.')
+ else:
+ print_error('Downloading...')
+ download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024)
+ except KeyboardInterrupt:
+ print_error('Download aborted by user.')
+ sys.exit(1)
+ except:
+ import traceback
+ # There are a myriad different reasons a download could fail.
+ # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...)
+ # We omit the scary stack trace but print the error nevertheless.
+ print_error(f'Failed to download {urlstring}.')
+ traceback.print_exception(*sys.exc_info()[:2], None)
+ sys.exit(1)
+ else:
+ print_error('Download complete.')
diff --git a/python/vyos/template.py b/python/vyos/template.py
index e1986b1e4..08a5712af 100644
--- a/python/vyos/template.py
+++ b/python/vyos/template.py
@@ -29,13 +29,17 @@ _FILTERS = {}
# reuse Environments with identical settings to improve performance
@functools.lru_cache(maxsize=2)
-def _get_environment():
+def _get_environment(location=None):
+ if location is None:
+ loc_loader=FileSystemLoader(directories["templates"])
+ else:
+ loc_loader=FileSystemLoader(location)
env = Environment(
# Don't check if template files were modified upon re-rendering
auto_reload=False,
# Cache up to this number of templates for quick re-rendering
cache_size=100,
- loader=FileSystemLoader(directories["templates"]),
+ loader=loc_loader,
trim_blocks=True,
)
env.filters.update(_FILTERS)
@@ -63,7 +67,7 @@ def register_filter(name, func=None):
return func
-def render_to_string(template, content, formater=None):
+def render_to_string(template, content, formater=None, location=None):
"""Render a template from the template directory, raise on any errors.
:param template: the path to the template relative to the template folder
@@ -78,7 +82,7 @@ def render_to_string(template, content, formater=None):
package is build (recovering the load time and overhead caused by having the
file out of the code).
"""
- template = _get_environment().get_template(template)
+ template = _get_environment(location).get_template(template)
rendered = template.render(content)
if formater is not None:
rendered = formater(rendered)
@@ -93,6 +97,7 @@ def render(
permission=None,
user=None,
group=None,
+ location=None,
):
"""Render a template from the template directory to a file, raise on any errors.
@@ -109,7 +114,7 @@ def render(
# As we are opening the file with 'w', we are performing the rendering before
# calling open() to not accidentally erase the file if rendering fails
- rendered = render_to_string(template, content, formater)
+ rendered = render_to_string(template, content, formater, location)
# Write to file
with open(destination, "w") as file:
@@ -375,3 +380,96 @@ def get_ipv4(interface):
""" Get interface IPv4 addresses"""
from vyos.ifconfig import Interface
return Interface(interface).get_addr_v4()
+
+@register_filter('get_ipv6')
+def get_ipv6(interface):
+ """ Get interface IPv6 addresses"""
+ from vyos.ifconfig import Interface
+ return Interface(interface).get_addr_v6()
+
+@register_filter('get_ip')
+def get_ip(interface):
+ """ Get interface IP addresses"""
+ from vyos.ifconfig import Interface
+ return Interface(interface).get_addr()
+
+@register_filter('get_esp_ike_cipher')
+def get_esp_ike_cipher(group_config):
+ pfs_lut = {
+ 'dh-group1' : 'modp768',
+ 'dh-group2' : 'modp1024',
+ 'dh-group5' : 'modp1536',
+ 'dh-group14' : 'modp2048',
+ 'dh-group15' : 'modp3072',
+ 'dh-group16' : 'modp4096',
+ 'dh-group17' : 'modp6144',
+ 'dh-group18' : 'modp8192',
+ 'dh-group19' : 'ecp256',
+ 'dh-group20' : 'ecp384',
+ 'dh-group21' : 'ecp512',
+ 'dh-group22' : 'modp1024s160',
+ 'dh-group23' : 'modp2048s224',
+ 'dh-group24' : 'modp2048s256',
+ 'dh-group25' : 'ecp192',
+ 'dh-group26' : 'ecp224',
+ 'dh-group27' : 'ecp224bp',
+ 'dh-group28' : 'ecp256bp',
+ 'dh-group29' : 'ecp384bp',
+ 'dh-group30' : 'ecp512bp',
+ 'dh-group31' : 'curve25519',
+ 'dh-group32' : 'curve448'
+ }
+
+ ciphers = []
+ if 'proposal' in group_config:
+ for priority, proposal in group_config['proposal'].items():
+ # both encryption and hash need to be specified for a proposal
+ if not {'encryption', 'hash'} <= set(proposal):
+ continue
+
+ tmp = '{encryption}-{hash}'.format(**proposal)
+ if 'dh_group' in proposal:
+ tmp += '-' + pfs_lut[ 'dh-group' + proposal['dh_group'] ]
+ elif 'pfs' in group_config and group_config['pfs'] != 'disable':
+ group = group_config['pfs']
+ if group_config['pfs'] == 'enable':
+ group = 'dh-group2'
+ tmp += '-' + pfs_lut[group]
+
+ ciphers.append(tmp)
+ return ciphers
+
+@register_filter('get_uuid')
+def get_uuid(interface):
+ """ Get interface IP addresses"""
+ from uuid import uuid1
+ return uuid1()
+
+openvpn_translate = {
+ 'des': 'des-cbc',
+ '3des': 'des-ede3-cbc',
+ 'bf128': 'bf-cbc',
+ 'bf256': 'bf-cbc',
+ 'aes128gcm': 'aes-128-gcm',
+ 'aes128': 'aes-128-cbc',
+ 'aes192gcm': 'aes-192-gcm',
+ 'aes192': 'aes-192-cbc',
+ 'aes256gcm': 'aes-256-gcm',
+ 'aes256': 'aes-256-cbc'
+}
+
+@register_filter('openvpn_cipher')
+def get_openvpn_cipher(cipher):
+ if cipher in openvpn_translate:
+ return openvpn_translate[cipher].upper()
+ return cipher.upper()
+
+@register_filter('openvpn_ncp_ciphers')
+def get_openvpn_ncp_ciphers(ciphers):
+ out = []
+ for cipher in ciphers:
+ if cipher in openvpn_translate:
+ out.append(openvpn_translate[cipher])
+ else:
+ out.append(cipher)
+ return ':'.join(out).upper()
diff --git a/python/vyos/util.py b/python/vyos/util.py
index 2a3f6a228..59f9f1c44 100644
--- a/python/vyos/util.py
+++ b/python/vyos/util.py
@@ -1,4 +1,4 @@
-# Copyright 2020 VyOS maintainers and contributors <maintainers@vyos.io>
+# Copyright 2020-2021 VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
@@ -22,25 +22,13 @@ import sys
# where it is used so it is as local as possible to the execution
#
-
-def _need_sudo(command):
- return os.path.basename(command.split()[0]) in ('systemctl', )
-
-
-def _add_sudo(command):
- if _need_sudo(command):
- return 'sudo ' + command
- return command
-
-
from subprocess import Popen
from subprocess import PIPE
from subprocess import STDOUT
from subprocess import DEVNULL
-
def popen(command, flag='', shell=None, input=None, timeout=None, env=None,
- stdout=PIPE, stderr=PIPE, decode='utf-8', autosudo=True):
+ stdout=PIPE, stderr=PIPE, decode='utf-8'):
"""
popen is a wrapper helper aound subprocess.Popen
with it default setting it will return a tuple (out, err)
@@ -79,9 +67,6 @@ def popen(command, flag='', shell=None, input=None, timeout=None, env=None,
if not debug.enabled(flag):
flag = 'command'
- if autosudo:
- command = _add_sudo(command)
-
cmd_msg = f"cmd '{command}'"
debug.message(cmd_msg, flag)
@@ -98,11 +83,8 @@ def popen(command, flag='', shell=None, input=None, timeout=None, env=None,
stdin = PIPE
input = input.encode() if type(input) is str else input
- p = Popen(
- command,
- stdin=stdin, stdout=stdout, stderr=stderr,
- env=env, shell=use_shell,
- )
+ p = Popen(command, stdin=stdin, stdout=stdout, stderr=stderr,
+ env=env, shell=use_shell)
pipe = p.communicate(input, timeout)
@@ -135,7 +117,7 @@ def popen(command, flag='', shell=None, input=None, timeout=None, env=None,
def run(command, flag='', shell=None, input=None, timeout=None, env=None,
- stdout=DEVNULL, stderr=PIPE, decode='utf-8', autosudo=True):
+ stdout=DEVNULL, stderr=PIPE, decode='utf-8'):
"""
A wrapper around popen, which discard the stdout and
will return the error code of a command
@@ -151,8 +133,8 @@ def run(command, flag='', shell=None, input=None, timeout=None, env=None,
def cmd(command, flag='', shell=None, input=None, timeout=None, env=None,
- stdout=PIPE, stderr=PIPE, decode='utf-8', autosudo=True,
- raising=None, message='', expect=[0]):
+ stdout=PIPE, stderr=PIPE, decode='utf-8', raising=None, message='',
+ expect=[0]):
"""
A wrapper around popen, which returns the stdout and
will raise the error code of a command
@@ -183,7 +165,7 @@ def cmd(command, flag='', shell=None, input=None, timeout=None, env=None,
def call(command, flag='', shell=None, input=None, timeout=None, env=None,
- stdout=PIPE, stderr=PIPE, decode='utf-8', autosudo=True):
+ stdout=PIPE, stderr=PIPE, decode='utf-8'):
"""
A wrapper around popen, which print the stdout and
will return the error code of a command
@@ -239,7 +221,6 @@ def write_file(fname, data, defaultonfailure=None, user=None, group=None):
return defaultonfailure
raise e
-
def read_json(fname, defaultonfailure=None):
"""
read and json decode the content of a file
@@ -459,7 +440,6 @@ def process_running(pid_file):
pid = f.read().strip()
return pid_exists(int(pid))
-
def process_named_running(name):
""" Checks if process with given name is running and returns its PID.
If Process is not running, return None
@@ -470,7 +450,6 @@ def process_named_running(name):
return p.pid
return None
-
def seconds_to_human(s, separator=""):
""" Converts number of seconds passed to a human-readable
interval such as 1w4d18h35m59s
@@ -525,6 +504,46 @@ def file_is_persistent(path):
absolute = os.path.abspath(os.path.dirname(path))
return re.match(location,absolute)
+def wait_for_inotify(file_path, pre_hook=None, event_type=None, timeout=None, sleep_interval=0.1):
+ """ Waits for an inotify event to occur """
+ if not os.path.dirname(file_path):
+ raise ValueError(
+ "File path {} does not have a directory part (required for inotify watching)".format(file_path))
+ if not os.path.basename(file_path):
+ raise ValueError(
+ "File path {} does not have a file part, do not know what to watch for".format(file_path))
+
+ from inotify.adapters import Inotify
+ from time import time
+ from time import sleep
+
+ time_start = time()
+
+ i = Inotify()
+ i.add_watch(os.path.dirname(file_path))
+
+ if pre_hook:
+ pre_hook()
+
+ for event in i.event_gen(yield_nones=True):
+ if (timeout is not None) and ((time() - time_start) > timeout):
+ # If the function didn't return until this point,
+ # the file failed to have been written to and closed within the timeout
+ raise OSError("Waiting for file {} to be written has failed".format(file_path))
+
+ # Most such events don't take much time, so it's better to check right away
+ # and sleep later.
+ if event is not None:
+ (_, type_names, path, filename) = event
+ if filename == os.path.basename(file_path):
+ if event_type in type_names:
+ return
+ sleep(sleep_interval)
+
+def wait_for_file_write_complete(file_path, pre_hook=None, timeout=None, sleep_interval=0.1):
+ """ Waits for a process to close a file after opening it in write mode. """
+ wait_for_inotify(file_path,
+ event_type='IN_CLOSE_WRITE', pre_hook=pre_hook, timeout=timeout, sleep_interval=sleep_interval)
def commit_in_progress():
""" Not to be used in normal op mode scripts! """
@@ -571,6 +590,25 @@ def wait_for_commit_lock():
while commit_in_progress():
sleep(1)
+def ask_input(question, default='', numeric_only=False, valid_responses=[]):
+ question_out = question
+ if default:
+ question_out += f' (Default: {default})'
+ response = ''
+ while True:
+ response = input(question_out + ' ').strip()
+ if not response and default:
+ return default
+ if numeric_only:
+ if not response.isnumeric():
+ print("Invalid value, try again.")
+ continue
+ response = int(response)
+ if valid_responses and response not in valid_responses:
+ print("Invalid value, try again.")
+ continue
+ break
+ return response
def ask_yes_no(question, default=False) -> bool:
"""Ask a yes/no question via input() and return their answer."""
@@ -672,6 +710,19 @@ def dict_search(path, my_dict):
c = c.get(p, {})
return c.get(parts[-1], None)
+def dict_search_args(dict_object, *path):
+ # Traverse dictionary using variable arguments
+ # Added due to above function not allowing for '.' in the key names
+ # Example: dict_search_args(some_dict, 'key', 'subkey', 'subsubkey', ...)
+ if not isinstance(dict_object, dict) or not path:
+ return None
+
+ for item in path:
+ if item not in dict_object:
+ return None
+ dict_object = dict_object[item]
+ return dict_object
+
def get_interface_config(interface):
""" Returns the used encapsulation protocol for given interface.
If interface does not exist, None is returned.
@@ -682,6 +733,16 @@ def get_interface_config(interface):
tmp = loads(cmd(f'ip -d -j link show {interface}'))[0]
return tmp
+def get_interface_address(interface):
+ """ Returns the used encapsulation protocol for given interface.
+ If interface does not exist, None is returned.
+ """
+ if not os.path.exists(f'/sys/class/net/{interface}'):
+ return None
+ from json import loads
+ tmp = loads(cmd(f'ip -d -j addr show {interface}'))[0]
+ return tmp
+
def get_all_vrfs():
""" Return a dictionary of all system wide known VRF instances """
from json import loads
@@ -694,3 +755,58 @@ def get_all_vrfs():
name = entry.pop('name')
data[name] = entry
return data
+
+def print_error(str='', end='\n'):
+ """
+ Print `str` to stderr, terminated with `end`.
+ Used for warnings and out-of-band messages to avoid mangling precious
+ stdout output.
+ """
+ sys.stderr.write(str)
+ sys.stderr.write(end)
+ sys.stderr.flush()
+
+def make_progressbar():
+ """
+ Make a procedure that takes two arguments `done` and `total` and prints a
+ progressbar based on the ratio thereof, whose length is determined by the
+ width of the terminal.
+ """
+ import shutil, math
+ col, _ = shutil.get_terminal_size()
+ col = max(col - 15, 20)
+ def print_progressbar(done, total):
+ if done <= total:
+ increment = total / col
+ length = math.ceil(done / increment)
+ percentage = str(math.ceil(100 * done / total)).rjust(3)
+ print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r')
+ # Print a newline so that the subsequent prints don't overwrite the full bar.
+ if done == total:
+ print_error()
+ return print_progressbar
+
+def make_incremental_progressbar(increment: float):
+ """
+ Make a generator that displays a progressbar that grows monotonically with
+ every iteration.
+ First call displays it at 0% and every subsequent iteration displays it
+ at `increment` increments where 0.0 < `increment` < 1.0.
+ Intended for FTP and HTTP transfers with stateless callbacks.
+ """
+ print_progressbar = make_progressbar()
+ total = 0.0
+ while total < 1.0:
+ print_progressbar(total, 1.0)
+ yield
+ total += increment
+ print_progressbar(1, 1)
+ # Ignore further calls.
+ while True:
+ yield
+
+def is_systemd_service_running(service):
+ """ Test is a specified systemd service is actually running.
+ Returns True if service is running, false otherwise. """
+ tmp = run(f'systemctl is-active --quiet {service}')
+ return bool((tmp == 0))
diff --git a/python/vyos/validate.py b/python/vyos/validate.py
index 23e88b5ac..0dad2a6cb 100644
--- a/python/vyos/validate.py
+++ b/python/vyos/validate.py
@@ -49,7 +49,7 @@ def is_intf_addr_assigned(intf, addr):
return _is_intf_addr_assigned(intf, ip, mask)
return _is_intf_addr_assigned(intf, addr)
-def _is_intf_addr_assigned(intf, address, netmask=''):
+def _is_intf_addr_assigned(intf, address, netmask=None):
"""
Verify if the given IPv4/IPv6 address is assigned to specific interface.
It can check both a single IP address (e.g. 192.0.2.1 or a assigned CIDR
@@ -85,14 +85,14 @@ def _is_intf_addr_assigned(intf, address, netmask=''):
continue
# we do not have a netmask to compare against, they are the same
- if netmask == '':
+ if not netmask:
return True
prefixlen = ''
if is_ipv4(ip_addr):
prefixlen = sum([bin(int(_)).count('1') for _ in ip['netmask'].split('.')])
else:
- prefixlen = sum([bin(int(_,16)).count('1') for _ in ip['netmask'].split(':') if _])
+ prefixlen = sum([bin(int(_,16)).count('1') for _ in ip['netmask'].split('/')[0].split(':') if _])
if str(prefixlen) == netmask:
return True
diff --git a/python/vyos/xml/load.py b/python/vyos/xml/load.py
index 0965d4220..37479c6e1 100644
--- a/python/vyos/xml/load.py
+++ b/python/vyos/xml/load.py
@@ -225,6 +225,9 @@ def _format_node(inside, conf, xml):
else:
_fatal(constraint)
+ elif 'constraintGroup' in properties:
+ properties.pop('constraintGroup')
+
elif 'constraintErrorMessage' in properties:
r[kw.error] = properties.pop('constraintErrorMessage')
diff --git a/python/vyos/xml/test_xml.py b/python/vyos/xml/test_xml.py
index ff55151d2..3a6f0132d 100644
--- a/python/vyos/xml/test_xml.py
+++ b/python/vyos/xml/test_xml.py
@@ -59,7 +59,7 @@ class TestSearch(TestCase):
last = self.xml.traverse("interfaces")
self.assertEqual(last, '')
self.assertEqual(self.xml.inside, ['interfaces'])
- self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wirelessmodem'])
+ self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wwan'])
self.assertEqual(self.xml.filling, False)
self.assertEqual(self.xml.word, '')
self.assertEqual(self.xml.check, False)
@@ -72,7 +72,7 @@ class TestSearch(TestCase):
last = self.xml.traverse("interfaces ")
self.assertEqual(last, '')
self.assertEqual(self.xml.inside, ['interfaces'])
- self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wirelessmodem'])
+ self.assertEqual(self.xml.options, ['bonding', 'bridge', 'dummy', 'ethernet', 'geneve', 'l2tpv3', 'loopback', 'macsec', 'openvpn', 'pppoe', 'pseudo-ethernet', 'tunnel', 'vxlan', 'wireguard', 'wireless', 'wwan'])
self.assertEqual(self.xml.filling, False)
self.assertEqual(self.xml.word, last)
self.assertEqual(self.xml.check, False)
@@ -85,7 +85,7 @@ class TestSearch(TestCase):
last = self.xml.traverse("interfaces w")
self.assertEqual(last, 'w')
self.assertEqual(self.xml.inside, ['interfaces'])
- self.assertEqual(self.xml.options, ['wireguard', 'wireless', 'wirelessmodem'])
+ self.assertEqual(self.xml.options, ['wireguard', 'wireless', 'wwan'])
self.assertEqual(self.xml.filling, True)
self.assertEqual(self.xml.word, last)
self.assertEqual(self.xml.check, True)
@@ -276,4 +276,4 @@ class TestSearch(TestCase):
self.assertEqual(self.xml.filled, True)
self.assertEqual(self.xml.plain, False)
- # Need to add a check for a valuless leafNode \ No newline at end of file
+ # Need to add a check for a valuless leafNode