diff options
Diffstat (limited to 'python')
| -rw-r--r-- | python/vyos/base.py | 3 | ||||
| -rw-r--r-- | python/vyos/ethtool.py | 15 | ||||
| -rw-r--r-- | python/vyos/ifconfig/ethernet.py | 14 | ||||
| -rw-r--r-- | python/vyos/ifconfig/interface.py | 10 | ||||
| -rw-r--r-- | python/vyos/opmode.py | 5 | ||||
| -rw-r--r-- | python/vyos/template.py | 1 | ||||
| -rw-r--r-- | python/vyos/utils/__init__.py | 0 | ||||
| -rw-r--r-- | python/vyos/utils/convert.py | 145 | ||||
| -rw-r--r-- | python/vyos/utils/dict.py | 256 | ||||
| -rw-r--r-- | python/vyos/utils/file.py | 171 | ||||
| -rw-r--r-- | python/vyos/utils/io.py | 103 | ||||
| -rw-r--r-- | python/vyos/xml/load.py | 18 | 
12 files changed, 721 insertions, 20 deletions
diff --git a/python/vyos/base.py b/python/vyos/base.py index 9b93cb2f2..c1acfd060 100644 --- a/python/vyos/base.py +++ b/python/vyos/base.py @@ -1,4 +1,4 @@ -# Copyright 2018-2022 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2018-2023 VyOS maintainers and contributors <maintainers@vyos.io>  #  # This library is free software; you can redistribute it and/or  # modify it under the terms of the GNU Lesser General Public @@ -41,7 +41,6 @@ class BaseWarning:                  isfirstmessage = False                  initial_indent = self.standardindent              print(f'{mes}') -        print('')  class Warning(): diff --git a/python/vyos/ethtool.py b/python/vyos/ethtool.py index bc3402059..1b1e54dfb 100644 --- a/python/vyos/ethtool.py +++ b/python/vyos/ethtool.py @@ -51,6 +51,7 @@ class Ethtool:      _ring_buffers_max = { }      _driver_name = None      _auto_negotiation = False +    _auto_negotiation_supported = None      _flow_control = False      _flow_control_enabled = None @@ -80,7 +81,13 @@ class Ethtool:                              self._speed_duplex.update({ speed : {}})                          if duplex not in self._speed_duplex[speed]:                              self._speed_duplex[speed].update({ duplex : ''}) -            if 'Auto-negotiation:' in line: +            if 'Supports auto-negotiation:' in line: +                # Split the following string: Auto-negotiation: off +                # we are only interested in off or on +                tmp = line.split()[-1] +                self._auto_negotiation_supported = bool(tmp == 'Yes') +            # Only read in if Auto-negotiation is supported +            if self._auto_negotiation_supported and 'Auto-negotiation:' in line:                  # Split the following string: Auto-negotiation: off                  # we are only interested in off or on                  tmp = line.split()[-1] @@ -132,8 +139,12 @@ class Ethtool:              # ['Autonegotiate:', 'on']              self._flow_control_enabled = out.splitlines()[1].split()[-1] +    def check_auto_negotiation_supported(self): +        """ Check if the NIC supports changing auto-negotiation """ +        return self._auto_negotiation_supported +      def get_auto_negotiation(self): -        return self._auto_negotiation +        return self._auto_negotiation_supported and self._auto_negotiation      def get_driver_name(self):          return self._driver_name diff --git a/python/vyos/ifconfig/ethernet.py b/python/vyos/ifconfig/ethernet.py index 5080144ff..6a49c022a 100644 --- a/python/vyos/ifconfig/ethernet.py +++ b/python/vyos/ifconfig/ethernet.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 VyOS maintainers and contributors <maintainers@vyos.io> +# Copyright 2019-2023 VyOS maintainers and contributors <maintainers@vyos.io>  #  # This library is free software; you can redistribute it and/or  # modify it under the terms of the GNU Lesser General Public @@ -14,9 +14,10 @@  # License along with this library.  If not, see <http://www.gnu.org/licenses/>.  import os -import re  from glob import glob + +from vyos.base import Warning  from vyos.ethtool import Ethtool  from vyos.ifconfig.interface import Interface  from vyos.util import run @@ -118,7 +119,7 @@ class EthernetIf(Interface):              cmd = f'ethtool --pause {ifname} autoneg {enable} tx {enable} rx {enable}'              output, code = self._popen(cmd)              if code: -                print(f'Could not set flowcontrol for {ifname}') +                Warning(f'could not change "{ifname}" flow control setting!')              return output          return None @@ -134,6 +135,7 @@ class EthernetIf(Interface):          >>> i = EthernetIf('eth0')          >>> i.set_speed_duplex('auto', 'auto')          """ +        ifname = self.config['ifname']          if speed not in ['auto', '10', '100', '1000', '2500', '5000', '10000',                           '25000', '40000', '50000', '100000', '400000']: @@ -143,7 +145,11 @@ class EthernetIf(Interface):              raise ValueError("Value out of range (duplex)")          if not self.ethtool.check_speed_duplex(speed, duplex): -            self._debug_msg(f'NIC driver does not support changing speed/duplex settings!') +            Warning(f'changing speed/duplex setting on "{ifname}" is unsupported!') +            return + +        if not self.ethtool.check_auto_negotiation_supported(): +            Warning(f'changing auto-negotiation setting on "{ifname}" is unsupported!')              return          # Get current speed and duplex settings: diff --git a/python/vyos/ifconfig/interface.py b/python/vyos/ifconfig/interface.py index fc33430eb..2f1d5eb96 100644 --- a/python/vyos/ifconfig/interface.py +++ b/python/vyos/ifconfig/interface.py @@ -532,7 +532,7 @@ class Interface(Control):              return None          # As a PoC we only allow 'dummy' interfaces -        if 'dum' not in self.ifname: +        if not ('dum' in self.ifname or 'veth' in self.ifname):              return None          # Check if interface realy exists in namespace @@ -1709,6 +1709,14 @@ class VLANIf(Interface):          if self.exists(f'{self.ifname}'):              return +        # If source_interface or vlan_id was not explicitly defined (e.g. when +        # calling  VLANIf('eth0.1').remove() we can define source_interface and +        # vlan_id here, as it's quiet obvious that it would be eth0 in that case. +        if 'source_interface' not in self.config: +            self.config['source_interface'] = '.'.join(self.ifname.split('.')[:-1]) +        if 'vlan_id' not in self.config: +            self.config['vlan_id'] = self.ifname.split('.')[-1] +          cmd = 'ip link add link {source_interface} name {ifname} type vlan id {vlan_id}'          if 'protocol' in self.config:              cmd += ' protocol {protocol}' diff --git a/python/vyos/opmode.py b/python/vyos/opmode.py index d7172a0b5..230a85541 100644 --- a/python/vyos/opmode.py +++ b/python/vyos/opmode.py @@ -209,6 +209,11 @@ def run(module):          for opt in type_hints:              th = type_hints[opt] +            # Function argument names use underscores as separators +            # but command-line options should use hyphens +            # Without this, we'd get options like "--foo_bar" +            opt = re.sub(r'_', '-', opt) +              if _get_arg_type(th) == bool:                  subparser.add_argument(f"--{opt}", action='store_true')              else: diff --git a/python/vyos/template.py b/python/vyos/template.py index 06a292706..254a15e3a 100644 --- a/python/vyos/template.py +++ b/python/vyos/template.py @@ -44,6 +44,7 @@ def _get_environment(location=None):          loader=loc_loader,          trim_blocks=True,          undefined=ChainableUndefined, +        extensions=['jinja2.ext.loopcontrols']      )      env.filters.update(_FILTERS)      env.tests.update(_TESTS) diff --git a/python/vyos/utils/__init__.py b/python/vyos/utils/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/python/vyos/utils/__init__.py diff --git a/python/vyos/utils/convert.py b/python/vyos/utils/convert.py new file mode 100644 index 000000000..975c67e0a --- /dev/null +++ b/python/vyos/utils/convert.py @@ -0,0 +1,145 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library.  If not, see <http://www.gnu.org/licenses/>. + +def seconds_to_human(s, separator=""): +    """ Converts number of seconds passed to a human-readable +    interval such as 1w4d18h35m59s +    """ +    s = int(s) + +    week = 60 * 60 * 24 * 7 +    day = 60 * 60 * 24 +    hour = 60 * 60 + +    remainder = 0 +    result = "" + +    weeks = s // week +    if weeks > 0: +        result = "{0}w".format(weeks) +        s = s % week + +    days = s // day +    if days > 0: +        result = "{0}{1}{2}d".format(result, separator, days) +        s = s % day + +    hours = s // hour +    if hours > 0: +        result = "{0}{1}{2}h".format(result, separator, hours) +        s = s % hour + +    minutes = s // 60 +    if minutes > 0: +        result = "{0}{1}{2}m".format(result, separator, minutes) +        s = s % 60 + +    seconds = s +    if seconds > 0: +        result = "{0}{1}{2}s".format(result, separator, seconds) + +    return result + +def bytes_to_human(bytes, initial_exponent=0, precision=2): +    """ Converts a value in bytes to a human-readable size string like 640 KB + +    The initial_exponent parameter is the exponent of 2, +    e.g. 10 (1024) for kilobytes, 20 (1024 * 1024) for megabytes. +    """ + +    if bytes == 0: +        return "0 B" + +    from math import log2 + +    bytes = bytes * (2**initial_exponent) + +    # log2 is a float, while range checking requires an int +    exponent = int(log2(bytes)) + +    if exponent < 10: +        value = bytes +        suffix = "B" +    elif exponent in range(10, 20): +        value = bytes / 1024 +        suffix = "KB" +    elif exponent in range(20, 30): +        value = bytes / 1024**2 +        suffix = "MB" +    elif exponent in range(30, 40): +        value = bytes / 1024**3 +        suffix = "GB" +    else: +        value = bytes / 1024**4 +        suffix = "TB" +    # Add a new case when the first machine with petabyte RAM +    # hits the market. + +    size_string = "{0:.{1}f} {2}".format(value, precision, suffix) +    return size_string + +def human_to_bytes(value): +    """ Converts a data amount with a unit suffix to bytes, like 2K to 2048 """ + +    from re import match as re_match + +    res = re_match(r'^\s*(\d+(?:\.\d+)?)\s*([a-zA-Z]+)\s*$', value) + +    if not res: +        raise ValueError(f"'{value}' is not a valid data amount") +    else: +        amount = float(res.group(1)) +        unit = res.group(2).lower() + +        if unit == 'b': +            res = amount +        elif (unit == 'k') or (unit == 'kb'): +            res = amount * 1024 +        elif (unit == 'm') or (unit == 'mb'): +            res = amount * 1024**2 +        elif (unit == 'g') or (unit == 'gb'): +            res = amount * 1024**3 +        elif (unit == 't') or (unit == 'tb'): +            res = amount * 1024**4 +        else: +            raise ValueError(f"Unsupported data unit '{unit}'") + +    # There cannot be fractional bytes, so we convert them to integer. +    # However, truncating causes problems with conversion back to human unit, +    # so we round instead -- that seems to work well enough. +    return round(res) + +def mac_to_eui64(mac, prefix=None): +    """ +    Convert a MAC address to a EUI64 address or, with prefix provided, a full +    IPv6 address. +    Thankfully copied from https://gist.github.com/wido/f5e32576bb57b5cc6f934e177a37a0d3 +    """ +    import re +    from ipaddress import ip_network +    # http://tools.ietf.org/html/rfc4291#section-2.5.1 +    eui64 = re.sub(r'[.:-]', '', mac).lower() +    eui64 = eui64[0:6] + 'fffe' + eui64[6:] +    eui64 = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:] + +    if prefix is None: +        return ':'.join(re.findall(r'.{4}', eui64)) +    else: +        try: +            net = ip_network(prefix, strict=False) +            euil = int('0x{0}'.format(eui64), 16) +            return str(net[euil]) +        except:  # pylint: disable=bare-except +            return diff --git a/python/vyos/utils/dict.py b/python/vyos/utils/dict.py new file mode 100644 index 000000000..4afc9f54e --- /dev/null +++ b/python/vyos/utils/dict.py @@ -0,0 +1,256 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library.  If not, see <http://www.gnu.org/licenses/>. + + +def colon_separated_to_dict(data_string, uniquekeys=False): +    """ Converts a string containing newline-separated entries +        of colon-separated key-value pairs into a dict. + +        Such files are common in Linux /proc filesystem + +    Args: +        data_string (str): data string +        uniquekeys (bool): whether to insist that keys are unique or not + +    Returns: dict + +    Raises: +        ValueError: if uniquekeys=True and the data string has +            duplicate keys. + +    Note: +        If uniquekeys=True, then dict entries are always strings, +        otherwise they are always lists of strings. +    """ +    import re +    key_value_re = re.compile('([^:]+)\s*\:\s*(.*)') + +    data_raw = re.split('\n', data_string) + +    data = {} + +    for l in data_raw: +        l = l.strip() +        if l: +            match = re.match(key_value_re, l) +            if match and (len(match.groups()) == 2): +                key = match.groups()[0].strip() +                value = match.groups()[1].strip() +            else: +                raise ValueError(f"""Line "{l}" could not be parsed a colon-separated pair """, l) +            if key in data.keys(): +                if uniquekeys: +                    raise ValueError("Data string has duplicate keys: {0}".format(key)) +                else: +                    data[key].append(value) +            else: +                if uniquekeys: +                    data[key] = value +                else: +                    data[key] = [value] +        else: +            pass + +    return data + +def _mangle_dict_keys(data, regex, replacement, abs_path=[], no_tag_node_value_mangle=False, mod=0): +    """ Mangles dict keys according to a regex and replacement character. +    Some libraries like Jinja2 do not like certain characters in dict keys. +    This function can be used for replacing all offending characters +    with something acceptable. + +    Args: +        data (dict): Original dict to mangle + +    Returns: dict +    """ +    from vyos.xml import is_tag + +    new_dict = {} + +    for key in data.keys(): +        save_mod = mod +        save_path = abs_path[:] + +        abs_path.append(key) + +        if not is_tag(abs_path): +            new_key = re.sub(regex, replacement, key) +        else: +            if mod%2: +                new_key = key +            else: +                new_key = re.sub(regex, replacement, key) +            if no_tag_node_value_mangle: +                mod += 1 + +        value = data[key] + +        if isinstance(value, dict): +            new_dict[new_key] = _mangle_dict_keys(value, regex, replacement, abs_path=abs_path, mod=mod, no_tag_node_value_mangle=no_tag_node_value_mangle) +        else: +            new_dict[new_key] = value + +        mod = save_mod +        abs_path = save_path[:] + +    return new_dict + +def mangle_dict_keys(data, regex, replacement, abs_path=[], no_tag_node_value_mangle=False): +    return _mangle_dict_keys(data, regex, replacement, abs_path=abs_path, no_tag_node_value_mangle=no_tag_node_value_mangle, mod=0) + +def _get_sub_dict(d, lpath): +    k = lpath[0] +    if k not in d.keys(): +        return {} +    c = {k: d[k]} +    lpath = lpath[1:] +    if not lpath: +        return c +    elif not isinstance(c[k], dict): +        return {} +    return _get_sub_dict(c[k], lpath) + +def get_sub_dict(source, lpath, get_first_key=False): +    """ Returns the sub-dict of a nested dict, defined by path of keys. + +    Args: +        source (dict): Source dict to extract from +        lpath (list[str]): sequence of keys + +    Returns: source, if lpath is empty, else +             {key : source[..]..[key]} for key the last element of lpath, if exists +             {} otherwise +    """ +    if not isinstance(source, dict): +        raise TypeError("source must be of type dict") +    if not isinstance(lpath, list): +        raise TypeError("path must be of type list") +    if not lpath: +        return source + +    ret =  _get_sub_dict(source, lpath) + +    if get_first_key and lpath and ret: +        tmp = next(iter(ret.values())) +        if not isinstance(tmp, dict): +            raise TypeError("Data under node is not of type dict") +        ret = tmp + +    return ret + +def dict_search(path, dict_object): +    """ Traverse Python dictionary (dict_object) delimited by dot (.). +    Return value of key if found, None otherwise. + +    This is faster implementation then jmespath.search('foo.bar', dict_object)""" +    if not isinstance(dict_object, dict) or not path: +        return None + +    parts = path.split('.') +    inside = parts[:-1] +    if not inside: +        if path not in dict_object: +            return None +        return dict_object[path] +    c = dict_object +    for p in parts[:-1]: +        c = c.get(p, {}) +    return c.get(parts[-1], None) + +def dict_search_args(dict_object, *path): +    # Traverse dictionary using variable arguments +    # Added due to above function not allowing for '.' in the key names +    # Example: dict_search_args(some_dict, 'key', 'subkey', 'subsubkey', ...) +    if not isinstance(dict_object, dict) or not path: +        return None + +    for item in path: +        if item not in dict_object: +            return None +        dict_object = dict_object[item] +    return dict_object + +def dict_search_recursive(dict_object, key, path=[]): +    """ Traverse a dictionary recurisvely and return the value of the key +    we are looking for. + +    Thankfully copied from https://stackoverflow.com/a/19871956 + +    Modified to yield optional path to found keys +    """ +    if isinstance(dict_object, list): +        for i in dict_object: +            new_path = path + [i] +            for x in dict_search_recursive(i, key, new_path): +                yield x +    elif isinstance(dict_object, dict): +        if key in dict_object: +            new_path = path + [key] +            yield dict_object[key], new_path +        for k, j in dict_object.items(): +            new_path = path + [k] +            for x in dict_search_recursive(j, key, new_path): +                yield x + +def dict_to_list(d, save_key_to=None): +    """ Convert a dict to a list of dicts. + +    Optionally, save the original key of the dict inside +    dicts stores in that list. +    """ +    def save_key(i, k): +        if isinstance(i, dict): +            i[save_key_to] = k +            return +        elif isinstance(i, list): +            for _i in i: +                save_key(_i, k) +        else: +            raise ValueError(f"Cannot save the key: the item is {type(i)}, not a dict") + +    collect = [] + +    for k,_ in d.items(): +        item = d[k] +        if save_key_to is not None: +            save_key(item, k) +        if isinstance(item, list): +            collect += item +        else: +            collect.append(item) + +    return collect + +def check_mutually_exclusive_options(d, keys, required=False): +    """ Checks if a dict has at most one or only one of +    mutually exclusive keys. +    """ +    present_keys = [] + +    for k in d: +        if k in keys: +            present_keys.append(k) + +    # Un-mangle the keys to make them match CLI option syntax +    from re import sub +    orig_keys = list(map(lambda s: sub(r'_', '-', s), keys)) +    orig_present_keys = list(map(lambda s: sub(r'_', '-', s), present_keys)) + +    if len(present_keys) > 1: +        raise ValueError(f"Options {orig_keys} are mutually-exclusive but more than one of them is present: {orig_present_keys}") + +    if required and (len(present_keys) < 1): +        raise ValueError(f"At least one of the following options is required: {orig_present_keys}") diff --git a/python/vyos/utils/file.py b/python/vyos/utils/file.py new file mode 100644 index 000000000..2560a35be --- /dev/null +++ b/python/vyos/utils/file.py @@ -0,0 +1,171 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library.  If not, see <http://www.gnu.org/licenses/>. + +import os + + +def read_file(fname, defaultonfailure=None): +    """ +    read the content of a file, stripping any end characters (space, newlines) +    should defaultonfailure be not None, it is returned on failure to read +    """ +    try: +        """ Read a file to string """ +        with open(fname, 'r') as f: +            data = f.read().strip() +        return data +    except Exception as e: +        if defaultonfailure is not None: +            return defaultonfailure +        raise e + +def write_file(fname, data, defaultonfailure=None, user=None, group=None, mode=None, append=False): +    """ +    Write content of data to given fname, should defaultonfailure be not None, +    it is returned on failure to read. + +    If directory of file is not present, it is auto-created. +    """ +    dirname = os.path.dirname(fname) +    if not os.path.isdir(dirname): +        os.makedirs(dirname, mode=0o755, exist_ok=False) +        chown(dirname, user, group) + +    try: +        """ Write a file to string """ +        bytes = 0 +        with open(fname, 'w' if not append else 'a') as f: +            bytes = f.write(data) +        chown(fname, user, group) +        chmod(fname, mode) +        return bytes +    except Exception as e: +        if defaultonfailure is not None: +            return defaultonfailure +        raise e + +def read_json(fname, defaultonfailure=None): +    """ +    read and json decode the content of a file +    should defaultonfailure be not None, it is returned on failure to read +    """ +    import json +    try: +        with open(fname, 'r') as f: +            data = json.load(f) +        return data +    except Exception as e: +        if defaultonfailure is not None: +            return defaultonfailure +        raise e + +def chown(path, user, group): +    """ change file/directory owner """ +    from pwd import getpwnam +    from grp import getgrnam + +    if user is None or group is None: +        return False + +    # path may also be an open file descriptor +    if not isinstance(path, int) and not os.path.exists(path): +        return False + +    uid = getpwnam(user).pw_uid +    gid = getgrnam(group).gr_gid +    os.chown(path, uid, gid) +    return True + + +def chmod(path, bitmask): +    # path may also be an open file descriptor +    if not isinstance(path, int) and not os.path.exists(path): +        return +    if bitmask is None: +        return +    os.chmod(path, bitmask) + + +def chmod_600(path): +    """ Make file only read/writable by owner """ +    from stat import S_IRUSR, S_IWUSR + +    bitmask = S_IRUSR | S_IWUSR +    chmod(path, bitmask) + + +def chmod_750(path): +    """ Make file/directory only executable to user and group """ +    from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP + +    bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP +    chmod(path, bitmask) + + +def chmod_755(path): +    """ Make file executable by all """ +    from stat import S_IRUSR, S_IWUSR, S_IXUSR, S_IRGRP, S_IXGRP, S_IROTH, S_IXOTH + +    bitmask = S_IRUSR | S_IWUSR | S_IXUSR | S_IRGRP | S_IXGRP | \ +              S_IROTH | S_IXOTH +    chmod(path, bitmask) + + +def makedir(path, user=None, group=None): +    if os.path.exists(path): +        return +    os.makedirs(path, mode=0o755) +    chown(path, user, group) + +def wait_for_inotify(file_path, pre_hook=None, event_type=None, timeout=None, sleep_interval=0.1): +    """ Waits for an inotify event to occur """ +    if not os.path.dirname(file_path): +        raise ValueError( +          "File path {} does not have a directory part (required for inotify watching)".format(file_path)) +    if not os.path.basename(file_path): +        raise ValueError( +          "File path {} does not have a file part, do not know what to watch for".format(file_path)) + +    from inotify.adapters import Inotify +    from time import time +    from time import sleep + +    time_start = time() + +    i = Inotify() +    i.add_watch(os.path.dirname(file_path)) + +    if pre_hook: +        pre_hook() + +    for event in i.event_gen(yield_nones=True): +        if (timeout is not None) and ((time() - time_start) > timeout): +            # If the function didn't return until this point, +            # the file failed to have been written to and closed within the timeout +            raise OSError("Waiting for file {} to be written has failed".format(file_path)) + +        # Most such events don't take much time, so it's better to check right away +        # and sleep later. +        if event is not None: +            (_, type_names, path, filename) = event +            if filename == os.path.basename(file_path): +                if event_type in type_names: +                    return +        sleep(sleep_interval) + +def wait_for_file_write_complete(file_path, pre_hook=None, timeout=None, sleep_interval=0.1): +    """ Waits for a process to close a file after opening it in write mode. """ +    wait_for_inotify(file_path, +      event_type='IN_CLOSE_WRITE', pre_hook=pre_hook, timeout=timeout, sleep_interval=sleep_interval) diff --git a/python/vyos/utils/io.py b/python/vyos/utils/io.py new file mode 100644 index 000000000..843494855 --- /dev/null +++ b/python/vyos/utils/io.py @@ -0,0 +1,103 @@ +# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io> +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library.  If not, see <http://www.gnu.org/licenses/>. + +def print_error(str='', end='\n'): +    """ +    Print `str` to stderr, terminated with `end`. +    Used for warnings and out-of-band messages to avoid mangling precious +     stdout output. +    """ +    import sys +    sys.stderr.write(str) +    sys.stderr.write(end) +    sys.stderr.flush() + +def make_progressbar(): +    """ +    Make a procedure that takes two arguments `done` and `total` and prints a +     progressbar based on the ratio thereof, whose length is determined by the +     width of the terminal. +    """ +    import shutil, math +    col, _ = shutil.get_terminal_size() +    col = max(col - 15, 20) +    def print_progressbar(done, total): +        if done <= total: +            increment = total / col +            length = math.ceil(done / increment) +            percentage = str(math.ceil(100 * done / total)).rjust(3) +            print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') +            # Print a newline so that the subsequent prints don't overwrite the full bar. +        if done == total: +            print_error() +    return print_progressbar + +def make_incremental_progressbar(increment: float): +    """ +    Make a generator that displays a progressbar that grows monotonically with +     every iteration. +    First call displays it at 0% and every subsequent iteration displays it +     at `increment` increments where 0.0 < `increment` < 1.0. +    Intended for FTP and HTTP transfers with stateless callbacks. +    """ +    print_progressbar = make_progressbar() +    total = 0.0 +    while total < 1.0: +        print_progressbar(total, 1.0) +        yield +        total += increment +    print_progressbar(1, 1) +    # Ignore further calls. +    while True: +        yield + +def ask_input(question, default='', numeric_only=False, valid_responses=[]): +    question_out = question +    if default: +        question_out += f' (Default: {default})' +    response = '' +    while True: +        response = input(question_out + ' ').strip() +        if not response and default: +            return default +        if numeric_only: +            if not response.isnumeric(): +                print("Invalid value, try again.") +                continue +            response = int(response) +        if valid_responses and response not in valid_responses: +            print("Invalid value, try again.") +            continue +        break +    return response + +def ask_yes_no(question, default=False) -> bool: +    """Ask a yes/no question via input() and return their answer.""" +    from sys import stdout +    default_msg = "[Y/n]" if default else "[y/N]" +    while True: +        try: +            stdout.write("%s %s " % (question, default_msg)) +            c = input().lower() +            if c == '': +                return default +            elif c in ("y", "ye", "yes"): +                return True +            elif c in ("n", "no"): +                return False +            else: +                stdout.write("Please respond with yes/y or no/n\n") +        except EOFError: +            stdout.write("\nPlease respond with yes/y or no/n\n") diff --git a/python/vyos/xml/load.py b/python/vyos/xml/load.py index c3022f3d6..f842ff9ce 100644 --- a/python/vyos/xml/load.py +++ b/python/vyos/xml/load.py @@ -71,16 +71,12 @@ def _merge(dict1, dict2):              continue          if isinstance(dict1[k], dict) and isinstance(dict2[k], dict):              dict1[k] = _merge(dict1[k], dict2[k]) -        elif isinstance(dict1[k], dict) and isinstance(dict2[k], dict): +        elif isinstance(dict1[k], list) and isinstance(dict2[k], list):              dict1[k].extend(dict2[k])          elif dict1[k] == dict2[k]: -            # A definition shared between multiple files -            if k in (kw.valueless, kw.multi, kw.hidden, kw.node, kw.summary, kw.owner, kw.priority): -                continue -            _fatal() -            raise RuntimeError('parsing issue - undefined leaf?') +            continue          else: -            raise RuntimeError('parsing issue - we messed up?') +            dict1[k] = dict2[k]      return dict1 @@ -131,7 +127,7 @@ def _format_nodes(inside, conf, xml):                  name = node.pop('@name')                  into = inside + [name]                  if name in r: -                    r[name].update(_format_node(into, node, xml)) +                    _merge(r[name], _format_node(into, node, xml))                  else:                      r[name] = _format_node(into, node, xml)                  r[name][kw.node] = nodename @@ -141,7 +137,7 @@ def _format_nodes(inside, conf, xml):              name = node.pop('@name')              into = inside + [name]              if name in r: -                r[name].update(_format_node(inside + [name], node, xml)) +                _merge(r[name], _format_node(inside + [name], node, xml))              else:                  r[name] = _format_node(inside + [name], node, xml)              r[name][kw.node] = nodename @@ -180,10 +176,10 @@ def _format_node(inside, conf, xml):              if isinstance(conf, list):                  for child in children: -                    r = _safe_update(r, _format_nodes(inside, child, xml)) +                    _merge(r, _format_nodes(inside, child, xml))              else:                  child = children -                r = _safe_update(r, _format_nodes(inside, child, xml)) +                _merge(r, _format_nodes(inside, child, xml))          elif 'properties' in keys:              properties = conf.pop('properties')  | 
