diff options
Diffstat (limited to 'python/vyos/remote.py')
-rw-r--r-- | python/vyos/remote.py | 61 |
1 files changed, 48 insertions, 13 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 52d234d4a..aef4dda7e 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -15,6 +15,7 @@ from ftplib import FTP import os +import shutil import socket import sys import tempfile @@ -28,14 +29,23 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') +def print_error(str): + """ + Used for warnings and out-of-band messages to avoid mangling precious + stdout output. + """ + sys.stderr.write(str) + sys.stderr.write('\n') + sys.stderr.flush() + class InteractivePolicy(MissingHostKeyPolicy): """ Policy for interactively querying the user on whether to proceed with SSH connections to unknown hosts. """ def missing_host_key(self, client, hostname, key): - print(f"Host '{hostname}' not found in known hosts.") - print('Fingerprint: ' + key.get_fingerprint().hex()) + print_error(f"Host '{hostname}' not found in known hosts.") + print_error('Fingerprint: ' + key.get_fingerprint().hex()) if ask_yes_no('Do you wish to continue?'): if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): client._host_keys.add(hostname, key.get_name(), key) @@ -186,7 +196,7 @@ def download(local_path, urlstring, authentication=None, source=None): if url.scheme == 'http' or url.scheme == 'https': if source: - print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr) + print_error('Warning: Custom source address not supported for HTTP connections.') download_http(urlstring, local_path, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' @@ -247,16 +257,6 @@ def get_remote_file_size(urlstring, authentication=None, source=None): else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_file_size_maybe(urlstring, authentication=None, source=None): - """ - Passes arguments to `get_remote_file_size()` but returns 0 if it fails. - Intended to be used in shell scripts only. - """ - try: - return get_remote_file_size(urlstring, authentication, source) - except ValueError: - return 0 - def get_remote_config(urlstring, authentication=None, source=None): """ Download remote (config) file from `urlstring` and return the contents as a string. @@ -282,3 +282,38 @@ def get_remote_config(urlstring, authentication=None, source=None): return file.read() finally: os.remove(temp) + +def friendly_download(local_path, urlstring, authentication=None, source=None): + """ + Intended to be called from interactive, user-facing scripts. + """ + destination_directory = os.path.dirname(local_path) + free_space = shutil.disk_usage(destination_directory).free + try: + try: + file_size = get_remote_file_size(urlstring, authentication, source) + if file_size < 1024 * 1024: + print_error(f'The file is {file_size / 1024.0:.3f} KiB.') + else: + print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.') + if file_size > free_space: + raise OSError(f'Not enough disk space available in "{destination_directory}".') + except ValueError: + print_error('Could not determine the file size in advance.') + else: + # TODO: Progress bar + print_error('Downloading...') + download(local_path, urlstring, authentication, source) + except KeyboardInterrupt: + print_error('Download aborted by user.') + sys.exit(1) + except: + import traceback + # There are a myriad different reasons a download could fail. + # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...) + # We omit the scary stack trace but print the error nevertheless. + print_error(f'Failed to download {urlstring}.') + traceback.print_exception(*sys.exc_info()[:2], None) + sys.exit(1) + else: + print_error('Download complete.') |