diff options
Diffstat (limited to 'python/vyos/remote.py')
| -rw-r--r-- | python/vyos/remote.py | 61 | 
1 files changed, 48 insertions, 13 deletions
| diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 52d234d4a..aef4dda7e 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -15,6 +15,7 @@  from ftplib import FTP  import os +import shutil  import socket  import sys  import tempfile @@ -28,14 +29,23 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy  known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') +def print_error(str): +    """ +    Used for warnings and out-of-band messages to avoid mangling precious +    stdout output. +    """ +    sys.stderr.write(str) +    sys.stderr.write('\n') +    sys.stderr.flush() +  class InteractivePolicy(MissingHostKeyPolicy):      """      Policy for interactively querying the user on whether to proceed with      SSH connections to unknown hosts.      """      def missing_host_key(self, client, hostname, key): -        print(f"Host '{hostname}' not found in known hosts.") -        print('Fingerprint: ' + key.get_fingerprint().hex()) +        print_error(f"Host '{hostname}' not found in known hosts.") +        print_error('Fingerprint: ' + key.get_fingerprint().hex())          if ask_yes_no('Do you wish to continue?'):              if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):                  client._host_keys.add(hostname, key.get_name(), key) @@ -186,7 +196,7 @@ def download(local_path, urlstring, authentication=None, source=None):      if url.scheme == 'http' or url.scheme == 'https':          if source: -            print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr) +            print_error('Warning: Custom source address not supported for HTTP connections.')          download_http(urlstring, local_path, username, password)      elif url.scheme == 'ftp':          username = username if username else 'anonymous' @@ -247,16 +257,6 @@ def get_remote_file_size(urlstring, authentication=None, source=None):      else:          raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_file_size_maybe(urlstring, authentication=None, source=None): -    """ -    Passes arguments to `get_remote_file_size()` but returns 0 if it fails. -    Intended to be used in shell scripts only. -    """ -    try: -        return get_remote_file_size(urlstring, authentication, source) -    except ValueError: -        return 0 -  def get_remote_config(urlstring, authentication=None, source=None):      """      Download remote (config) file from `urlstring` and return the contents as a string. @@ -282,3 +282,38 @@ def get_remote_config(urlstring, authentication=None, source=None):              return file.read()      finally:          os.remove(temp) + +def friendly_download(local_path, urlstring, authentication=None, source=None): +    """ +    Intended to be called from interactive, user-facing scripts. +    """ +    destination_directory = os.path.dirname(local_path) +    free_space = shutil.disk_usage(destination_directory).free +    try: +        try: +            file_size = get_remote_file_size(urlstring, authentication, source) +            if file_size < 1024 * 1024: +                print_error(f'The file is {file_size / 1024.0:.3f} KiB.') +            else: +                print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.') +            if file_size > free_space: +                raise OSError(f'Not enough disk space available in "{destination_directory}".') +        except ValueError: +            print_error('Could not determine the file size in advance.') +        else: +            # TODO: Progress bar +            print_error('Downloading...') +            download(local_path, urlstring, authentication, source) +    except KeyboardInterrupt: +        print_error('Download aborted by user.') +        sys.exit(1) +    except: +        import traceback +        # There are a myriad different reasons a download could fail. +        # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...) +        # We omit the scary stack trace but print the error nevertheless. +        print_error(f'Failed to download {urlstring}.') +        traceback.print_exception(*sys.exc_info()[:2], None) +        sys.exit(1) +    else: +        print_error('Download complete.') | 
