diff options
| author | Christian Poessinger <christian@poessinger.com> | 2021-11-26 07:39:45 +0100 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-11-26 07:39:45 +0100 | 
| commit | 1df8ba611f04c46d49f4cf70d6fa22cfef089392 (patch) | |
| tree | 4d6ce5a6e4fbe0bba93336a75ac2d16dfe98fe5a /python | |
| parent | fed88cda4722024fdfcb5c2951f074fa6caf4d7c (diff) | |
| parent | 6577c9725e74c564f23a87e1f91050900f258fe9 (diff) | |
| download | vyos-1x-1df8ba611f04c46d49f4cf70d6fa22cfef089392.tar.gz vyos-1x-1df8ba611f04c46d49f4cf70d6fa22cfef089392.zip  | |
Merge pull request #1080 from erkin/current
remote: T3356: Rewrite remote.py
Diffstat (limited to 'python')
| -rw-r--r-- | python/vyos/remote.py | 572 | ||||
| -rw-r--r-- | python/vyos/util.py | 14 | 
2 files changed, 268 insertions, 318 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index e972050b7..2419f8873 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -13,38 +13,40 @@  # You should have received a copy of the GNU Lesser General Public  # License along with this library.  If not, see <http://www.gnu.org/licenses/>. -from ftplib import FTP  import os  import shutil  import socket +import ssl  import stat  import sys  import tempfile  import urllib.parse -import urllib.request as urlreq -from vyos.template import get_ip -from vyos.template import ip_from_cidr -from vyos.template import is_interface -from vyos.template import is_ipv6 -from vyos.util import cmd +from ftplib import FTP +from ftplib import FTP_TLS + +from paramiko import SSHClient +from paramiko import MissingHostKeyPolicy + +from requests import Session +from requests.adapters import HTTPAdapter +from requests.packages.urllib3 import PoolManager +  from vyos.util import ask_yes_no -from vyos.util import print_error -from vyos.util import make_progressbar +from vyos.util import begin +from vyos.util import cmd  from vyos.util import make_incremental_progressbar +from vyos.util import make_progressbar +from vyos.util import print_error  from vyos.version import get_version -from paramiko import SSHClient -from paramiko import SSHException -from paramiko import MissingHostKeyPolicy -# This is a hardcoded path and no environment variable can change it. -KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts') +  CHUNK_SIZE = 8192  class InteractivePolicy(MissingHostKeyPolicy):      """ -    Policy for interactively querying the user on whether to proceed with -     SSH connections to unknown hosts. +    Paramiko policy for interactively querying the user on whether to proceed +     with SSH connections to unknown hosts.      """      def missing_host_key(self, client, hostname, key):          print_error(f"Host '{hostname}' not found in known hosts.") @@ -57,337 +59,270 @@ class InteractivePolicy(MissingHostKeyPolicy):          else:              raise SSHException(f"Cannot connect to unknown host '{hostname}'.") - -## Helper routines -def get_authentication_variables(default_username=None, default_password=None): +class SourceAdapter(HTTPAdapter):      """ -    Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and -     return the defaults provided if environment variables are empty or nonexistent. +    urllib3 transport adapter for setting source addresses per session.      """ -    username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD') -    # Fall back to defaults if the username variable doesn't exist or is an empty string. -    # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`, -    #  as we want the username and the password to have the same behaviour. -    if not username: -        return default_username, default_password -    else: -        return username, password - -def get_source_address(source): +    def __init__(self, source_pair, *args, **kwargs): +        # A source pair is a tuple of a source host string and source port respectively. +        # Supply '' and 0 respectively for default values. +        self._source_pair = source_pair +        super(SourceAdapter, self).__init__(*args, **kwargs) + +    def init_poolmanager(self, connections, maxsize, block=False): +        self.poolmanager = PoolManager( +            num_pools=connections, maxsize=maxsize, +            block=block, source_address=self._source_pair) + +class WrappedFile: +    def __init__(self, obj, size=None, chunk_size=CHUNK_SIZE): +        self._obj = obj +        self._progress = size and make_incremental_progressbar(chunk_size / size) +    def read(self, size=-1): +        if self._progress: +            next(self._progress) +        self._obj.read(size) +    def write(self, size=-1): +        if self._progress: +            next(self._progress) +        self._obj.write(size) +    def __getattr__(self, attr): +         return getattr(self._obj, attr) + +def check_storage(path, size):      """ -    Take a string vaguely indicating an origin source (interface, hostname or IP address), -     return a tuple in the format `(source_pair, address_family)` where -     `source_pair` is `(source_address, source_port)`. +    Check whether `path` has enough storage space for a transfer of `size` bytes.      """ -    # TODO: Properly distinguish between IPv4 and IPv6. -    port = 0 -    if is_interface(source): -        source = ip_from_cidr(get_ip(source)[0]) -    if is_ipv6(source): -        return (source, port), socket.AF_INET6 +    path = os.path.abspath(os.path.expanduser(path)) +    directory = path if os.path.isdir(path) else (os.path.dirname(os.path.expanduser(path)) or os.getcwd()) +    # `size` can be None or 0 to indicate unknown size. +    if not size: +        print_error('Warning: Cannot determine size of remote file.') +        print_error('Bravely continuing regardless.') +        return + +    if size < 1024 * 1024: +        print_error(f'The file is {size / 1024.0:.3f} KiB.')      else: -        return (socket.gethostbyname(source), port), socket.AF_INET - -def get_port_from_url(url): -    """ -    Return the port number from the given `url` named tuple, fall back to -     the default if there isn't one. -    """ -    defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\ -                "ssh": 22, "scp": 22, "sftp": 22} -    if url.port: -        return url.port -    else: -        return defaults[url.scheme] - - -## FTP routines -def upload_ftp(local_path, hostname, remote_path,\ -               username='anonymous', password='', port=21,\ -               source_pair=None, progressbar=False): -    size = os.path.getsize(local_path) -    with FTP(source_address=source_pair) as conn: -        conn.connect(hostname, port) -        conn.login(username, password) -        with open(local_path, 'rb') as file: -            if progressbar and size: +        print_error(f'The file is {size / (1024.0 * 1024.0):.3f} MiB.') + +    # Will throw `FileNotFoundError' if `directory' is absent. +    if size > shutil.disk_usage(directory).free: +        raise OSError(f'Not enough disk space available in "{directory}".') + + +class FtpC: +    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): +        self.secure = url.scheme == 'ftps' +        self.hostname = url.hostname +        self.path = url.path +        self.username = url.username or os.getenv('REMOTE_USERNAME', 'anonymous') +        self.password = url.password or os.getenv('REMOTE_PASSWORD', '') +        self.port = url.port or 21 +        self.source = (source_host, source_port) +        self.progressbar = progressbar +        self.check_space = check_space + +    def _establish(self): +        if self.secure: +            return FTP_TLS(source_address=self.source, context=ssl.create_default_context()) +        else: +            return FTP(source_address=self.source) + +    def download(self, location: str): +        # Open the file upfront before establishing connection. +        with open(location, 'wb') as f, self._establish() as conn: +            conn.connect(self.hostname, self.port) +            conn.login(self.username, self.password) +            # Set secure connection over TLS. +            if self.secure: +                conn.prot_p() +            # Almost all FTP servers support the `SIZE' command. +            if self.check_space: +                check_storage(path, conn.size(self.path)) +            # No progressbar if we can't determine the size or if the file is too small. +            if self.progressbar and size and size > CHUNK_SIZE:                  progress = make_incremental_progressbar(CHUNK_SIZE / size)                  next(progress) -                callback = lambda block: next(progress) +                callback = lambda block: begin(f.write(block), next(progress))              else: -                callback = None -            conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback) - -def download_ftp(local_path, hostname, remote_path,\ -                 username='anonymous', password='', port=21,\ -                 source_pair=None, progressbar=False): -    with FTP(source_address=source_pair) as conn: -        conn.connect(hostname, port) -        conn.login(username, password) -        size = conn.size(remote_path) -        with open(local_path, 'wb') as file: -            # No progressbar if we can't determine the size. -            if progressbar and size: +                callback = f.write +            conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) + +    def upload(self, location: str): +        size = os.path.getsize(location) +        with open(location, 'rb') as f, self._establish() as conn: +            conn.connect(self.hostname, self.port) +            conn.login(self.username, self.password) +            if self.secure: +                conn.prot_p() +            if self.progressbar and size and size > CHUNK_SIZE:                  progress = make_incremental_progressbar(CHUNK_SIZE / size)                  next(progress) -                callback = lambda block: (file.write(block), next(progress)) +                callback = lambda block: next(progress)              else: -                callback = file.write -            conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE) - -def get_ftp_file_size(hostname, remote_path,\ -                      username='anonymous', password='', port=21,\ -                      source_pair=None): -    with FTP(source_address=source) as conn: -        conn.connect(hostname, port) -        conn.login(username, password) -        size = conn.size(remote_path) -        if size: -            return size -        else: -            # SIZE is an extension to the FTP specification, although it's extremely common. -            raise ValueError('Failed to receive file size from FTP server. \ -            Perhaps the server does not implement the SIZE command?') - - -## SFTP/SCP routines -def transfer_sftp(mode, local_path, hostname, remote_path,\ -                  username=None, password=None, port=22,\ -                  source_tuple=None, progressbar=False): -    sock = None -    if source_tuple: -        (source_address, source_port), address_family = source_tuple -        sock = socket.socket(address_family, socket.SOCK_STREAM) -        sock.bind((source_address, source_port)) -        sock.connect((hostname, port)) -    callback = make_progressbar() if progressbar else None -    with SSHClient() as ssh: +                callback = None +            conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, callback) + +class SshC: +    known_hosts = os.path.expanduser('~/.ssh/known_hosts') +    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): +        self.hostname = url.hostname +        self.path = url.path +        self.username = url.username or os.getenv('REMOTE_USERNAME') +        self.password = url.password or os.getenv('REMOTE_PASSWORD') +        self.port = url.port or 22 +        self.source = (source_host, source_port) +        self.progressbar = progressbar +        self.check_space = check_space + +    def _establish(self): +        ssh = SSHClient()          ssh.load_system_host_keys() -        if os.path.exists(KNOWN_HOSTS_FILE): -            ssh.load_host_keys(KNOWN_HOSTS_FILE) +        # Try to load from a user-local known hosts file if one exists. +        if os.path.exists(self.known_hosts): +            ssh.load_host_keys(self.known_hosts)          ssh.set_missing_host_key_policy(InteractivePolicy()) -        ssh.connect(hostname, port, username, password, sock=sock) -        with ssh.open_sftp() as sftp: -            if mode == 'upload': +        # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family +        #  for us on dual-stack systems. +        sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source) +        ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock) +        return ssh + +    def download(self, location: str): +        callback = make_progressbar() if self.progressbar else None +        with self._establish() as ssh, ssh.open_sftp() as sftp: +            if self.check_space: +                check_storage(location, sftp.stat(self.path).st_size) +            sftp.get(self.path, location, callback=callback) + +    def upload(self, location: str): +        callback = make_progressbar() if self.progressbar else None +        with self._establish() as ssh, ssh.open_sftp() as sftp: +            try: +                # If the remote path is a directory, use the original filename. +                if stat.S_ISDIR(sftp.stat(self.path).st_mode): +                    path = os.path.join(self.path, os.path.basename(location)) +                # A file exists at this destination. We're simply going to clobber it. +                else: +                    path = self.path +            # This path doesn't point at any existing file. We can freely use this filename. +            except IOError: +                path = self.path +            finally: +                sftp.put(location, path, callback=callback) + + +class HttpC: +    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): +        self.urlstring = urllib.parse.urlunsplit(url) +        self.progressbar = progressbar +        self.check_space = check_space +        self.source_pair = (source_host, source_port) +        self.username = url.username or os.getenv('REMOTE_USERNAME') +        self.password = url.password or os.getenv('REMOTE_PASSWORD') + +    def _establish(self): +        session = Session() +        session.mount(self.urlstring, SourceAdapter(self.source_pair)) +        session.headers.update({'User-Agent': 'VyOS/' + get_version()}) +        if self.username: +            session.auth = self.username, self.password +        return session + +    def download(self, location: str): +        with self._establish() as s: +            # We ask for uncompressed downloads so that we don't have to deal with decoding. +            # Not only would it potentially mess up with the progress bar but +            # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding. +            s.headers.update({'Accept-Encoding': 'identity'}) +            with s.head(self.urlstring) as r: +                # Abort early if the destination is inaccessible. +                r.raise_for_status() +                # Check for the prospective file size.                  try: -                    # If the remote path is a directory, use the original filename. -                    if stat.S_ISDIR(sftp.stat(remote_path).st_mode): -                        path = os.path.join(remote_path, os.path.basename(local_path)) -                    # A file exists at this destination. We're simply going to clobber it. -                    else: -                        path = remote_path -                # This path doesn't point at any existing file. We can freely use this filename. -                except IOError: -                    path = remote_path -                finally: -                    sftp.put(local_path, path, callback=callback) -            elif mode == 'download': -                sftp.get(remote_path, local_path, callback=callback) -            elif mode == 'size': -                return sftp.stat(remote_path).st_size - -def upload_sftp(*args, **kwargs): -    transfer_sftp('upload', *args, **kwargs) - -def download_sftp(*args, **kwargs): -    transfer_sftp('download', *args, **kwargs) - -def get_sftp_file_size(*args, **kwargs): -    return transfer_sftp('size', None, *args, **kwargs) - - -## TFTP routines -def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): -    source_option = f'--interface {source}' if source else '' -    progress_flag = '--progress-bar' if progressbar else '-s' -    with open(local_path, 'rb') as file: -        cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\ -            stderr=None, input=file.read()).encode() - -def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): -    source_option = f'--interface {source}' if source else '' -    # Not really applicable but we pass it for the sake of uniformity. -    progress_flag = '--progress-bar' if progressbar else '-s' -    with open(local_path, 'wb') as file: -        file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ -                       stderr=None).encode()) - -# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP, -#  as TFTP does not specify a SIZE command. - - -## HTTP(S) routines -def install_request_opener(urlstring, username, password): -    """ -    Take `username` and `password` strings and install the appropriate -     password manager to `urllib.request.urlopen()` for the given `urlstring`. -    """ -    manager = urlreq.HTTPPasswordMgrWithDefaultRealm() -    manager.add_password(None, urlstring, username, password) -    urlreq.install_opener(urlreq.build_opener(urlreq.HTTPBasicAuthHandler(manager))) - -# upload_http() is unimplemented. - -def download_http(local_path, urlstring, username=None, password=None, progressbar=False): -    """ -    Download the file from from `urlstring` to `local_path`. -    Optionally takes `username` and `password` for authentication. -    """ -    request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) -    if username: -        install_request_opener(urlstring, username, password) -    with open(local_path, 'wb') as file, urlreq.urlopen(request) as response: -        size = response.getheader('Content-Length') -        if progressbar and size: -            progress = make_incremental_progressbar(CHUNK_SIZE / int(size)) -            next(progress) -            for chunk in iter(lambda: response.read(CHUNK_SIZE), b''): -                file.write(chunk) -                next(progress) -            next(progress) -        # If we can't determine the size or if a progress bar wasn't requested, -        #  we can let `shutil` take care of the copying. -        else: -            shutil.copyfileobj(response, file) - -def get_http_file_size(urlstring, username=None, password=None): -    """ -    Return the size of the file from `urlstring` in terms of number of bytes. -    Optionally takes `username` and `password` for authentication. -    """ -    request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) -    if username: -        install_request_opener(urlstring, username, password) -    with urlreq.urlopen(request) as response: -        size = response.getheader('Content-Length') -        if size: -            return int(size) -        # The server didn't send 'Content-Length' in the response headers. -        else: -            raise ValueError('Failed to receive file size from HTTP server.') - - -## Dynamic dispatchers -def download(local_path, urlstring, source=None, progressbar=False): +                    size = int(r.headers['Content-Length']) +                # In case the server does not supply the header. +                except KeyError: +                    size = None +            if self.check_space: +                check_storage(location, size) +            with s.get(self.urlstring, stream=True) as r, open(location, 'wb') as f: +                if self.progressbar and size: +                    progress = make_incremental_progressbar(CHUNK_SIZE / size) +                    next(progress) +                    for chunk in iter(lambda: begin(next(progress), r.raw.read(CHUNK_SIZE)), b''): +                        f.write(chunk) +                else: +                    # We'll try to stream the download directly with `copyfileobj()` so that large +                    #  files (like entire VyOS images) don't occupy much memory. +                    shutil.copyfileobj(r.raw, f) + +    def upload(self, location: str): +        size = os.path.getsize(location) if self.progressbar else None +        # Keep in mind that `data` can be a file-like or iterable object. +        with self._establish() as s, file(location, 'rb') as f: +            s.post(self.urlstring, data=WrappedFile(f, size)) + + +class TftpC: +    # We simply allow `curl` to take over because +    # 1. TFTP is rather simple. +    # 2. Since there's no concept authentication, we don't need to deal with keys/passwords. +    # 3. It would be a waste to import, audit and maintain a third-party library for TFTP. +    # 4. I'd rather not implement the entire protocol here, no matter how simple it is. +    def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0): +        source_option = f'--interface {source_host} --local-port {source_port}' if source_host else '' +        progress_flag = '--progress-bar' if progressbar else '-s' +        self.command = f'curl {source_option} {progress_flag}' +        self.urlstring = urllib.parse.urlunsplit(url) + +    def download(self, location: str): +        with open(location, 'wb') as f: +            f.write(cmd(f'{self.command} "{self.urlstring}"').encode()) + +    def upload(self, location: str): +        with open(location, 'rb') as f: +            cmd(f'{self.command} -T - "{self.urlstring}"', input=f.read()) + + +def urlc(urlstring, *args, **kwargs):      """ -    Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. -    Optionally takes a `source` address or interface (not valid for HTTP(S)). -    Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP. -    Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. +    Dynamically dispatch the appropriate protocol class.      """ -    url = urllib.parse.urlparse(urlstring) -    username, password = get_authentication_variables(url.username, url.password) -    port = get_port_from_url(url) - -    if url.scheme == 'http' or url.scheme == 'https': -        if source: -            print_error('Warning: Custom source address not supported for HTTP connections.') -        download_http(local_path, urlstring, username, password, progressbar) -    elif url.scheme == 'ftp': -        source = get_source_address(source)[0] if source else None -        username = username if username else 'anonymous' -        download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) -    elif url.scheme == 'sftp' or url.scheme == 'scp': -        source = get_source_address(source) if source else None -        download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) -    elif url.scheme == 'tftp': -        download_tftp(local_path, url.hostname, url.path, port, source, progressbar) -    else: -        raise ValueError(f'Unsupported URL scheme: {url.scheme}') +    url_classes = {'http': HttpC, 'https': HttpC, 'ftp': FtpC, 'ftps': FtpC, \ +                   'sftp': SshC, 'ssh': SshC, 'scp': SshC, 'tftp': TftpC} +    url = urllib.parse.urlsplit(urlstring) +    try: +        return url_classes[url.scheme](url, *args, **kwargs) +    except KeyError: +        raise ValueError(f'Unsupported URL scheme: "{url.scheme}"') -def upload(local_path, urlstring, source=None, progressbar=False): -    """ -    Dispatch the appropriate upload function for the given URL and upload from local path. -    Optionally takes a `source` address. -    Supports FTP, SFTP, SCP (through SFTP) and TFTP. -    Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. -    """ -    url = urllib.parse.urlparse(urlstring) -    username, password = get_authentication_variables(url.username, url.password) -    port = get_port_from_url(url) - -    if url.scheme == 'ftp': -        username = username if username else 'anonymous' -        source = get_source_address(source)[0] if source else None -        upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) -    elif url.scheme == 'sftp' or url.scheme == 'scp': -        source = get_source_address(source) if source else None -        upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) -    elif url.scheme == 'tftp': -        upload_tftp(local_path, url.hostname, url.path, port, source, progressbar) -    else: -        raise ValueError(f'Unsupported URL scheme: {url.scheme}') +def download(local_path, urlstring, *args, **kwargs): +    urlc(urlstring, *args, **kwargs).download(local_path) -def get_remote_file_size(urlstring, source=None): -    """ -    Dispatch the appropriate function to return the size of the remote file from `urlstring` -     in terms of number of bytes. -    Optionally takes a `source` address (not valid for HTTP(S)). -    Supports HTTP, HTTPS, FTP and SFTP (through SFTP). -    Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. -    """ -    url = urllib.parse.urlparse(urlstring) -    username, password = get_authentication_variables(url.username, url.password) -    port = get_port_from_url(url) - -    if url.scheme == 'http' or url.scheme == 'https': -        if source: -            print_error('Warning: Custom source address not supported for HTTP connections.') -        return get_http_file_size(urlstring, username, password) -    elif url.scheme == 'ftp': -        source = get_source_address(source)[0] if source else None -        username = username if username else 'anonymous' -        return get_ftp_file_size(url.hostname, url.path, username, password, port, source) -    elif url.scheme == 'sftp' or url.scheme == 'scp': -        source = get_source_address(source) if source else None -        return get_sftp_file_size(url.hostname, url.path, username, password, port, source) -    else: -        raise ValueError(f'Unsupported URL scheme: {url.scheme}') +def upload(local_path, urlstring, *args, **kwargs): +    urlc(urlstring, *args, **kwargs).upload(local_path) -def get_remote_config(urlstring, source=None): +def get_remote_config(urlstring, source_host='', source_port=0):      """ -    Download remote (config) file from `urlstring` and return the contents as a string. -        Args: -            remote file URI: -                tftp://<host>[:<port>]/<file> -                http[s]://<host>[:<port>]/<file> -                [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file> -            source address (optional): -                <interface> -                <IP address> +    Quietly download a file and return it as a string.      """      temp = tempfile.NamedTemporaryFile(delete=False).name      try: -        download(temp, urlstring, source) -        with open(temp, 'r') as file: -            return file.read() +        download(temp, urlstring, False, False, source_host, source_port) +        with open(temp, 'r') as f: +            return f.read()      finally:          os.remove(temp) -def friendly_download(local_path, urlstring, source=None): +def friendly_download(local_path, urlstring, source_host='', source_port=0):      """ -    Download from `urlstring` to `local_path` in an informative way. -    Checks the storage space before attempting download. -    Intended to be called from interactive, user-facing scripts. +    Download with a progress bar, reassuring messages and free space checks.      """ -    destination_directory = os.path.dirname(local_path)      try: -        free_space = shutil.disk_usage(destination_directory).free -        try: -            file_size = get_remote_file_size(urlstring, source) -            if file_size < 1024 * 1024: -                print_error(f'The file is {file_size / 1024.0:.3f} KiB.') -            else: -                print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.') -            if file_size > free_space: -                raise OSError(f'Not enough disk space available in "{destination_directory}".') -        except ValueError: -            # Can't do a storage check in this case, so we bravely continue. -            file_size = 0 -            print_error('Could not determine the file size in advance.') -        else: -            print_error('Downloading...') -            download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024) +        print_error('Downloading...') +        download(local_path, urlstring, True, True, source_host, source_port)      except KeyboardInterrupt:          print_error('Download aborted by user.')          sys.exit(1) @@ -401,3 +336,4 @@ def friendly_download(local_path, urlstring, source=None):          sys.exit(1)      else:          print_error('Download complete.') +        sys.exit(0) diff --git a/python/vyos/util.py b/python/vyos/util.py index 9aa1f98d2..d8e83ab8d 100644 --- a/python/vyos/util.py +++ b/python/vyos/util.py @@ -856,6 +856,20 @@ def make_incremental_progressbar(increment: float):      while True:          yield +def begin(*args): +    """ +    Evaluate arguments in order and return the result of the *last* argument. +    For combining multiple expressions in one statement. Useful for lambdas. +    """ +    return args[-1] + +def begin0(*args): +    """ +    Evaluate arguments in order and return the result of the *first* argument. +    For combining multiple expressions in one statement. Useful for lambdas. +    """ +    return args[0] +  def is_systemd_service_active(service):      """ Test is a specified systemd service is activated.      Returns True if service is active, false otherwise.  | 
