summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/vyos/remote.py61
1 files changed, 48 insertions, 13 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 52d234d4a..aef4dda7e 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -15,6 +15,7 @@
from ftplib import FTP
import os
+import shutil
import socket
import sys
import tempfile
@@ -28,14 +29,23 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy
known_hosts_file = os.path.expanduser('~/.ssh/known_hosts')
+def print_error(str):
+ """
+ Used for warnings and out-of-band messages to avoid mangling precious
+ stdout output.
+ """
+ sys.stderr.write(str)
+ sys.stderr.write('\n')
+ sys.stderr.flush()
+
class InteractivePolicy(MissingHostKeyPolicy):
"""
Policy for interactively querying the user on whether to proceed with
SSH connections to unknown hosts.
"""
def missing_host_key(self, client, hostname, key):
- print(f"Host '{hostname}' not found in known hosts.")
- print('Fingerprint: ' + key.get_fingerprint().hex())
+ print_error(f"Host '{hostname}' not found in known hosts.")
+ print_error('Fingerprint: ' + key.get_fingerprint().hex())
if ask_yes_no('Do you wish to continue?'):
if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
client._host_keys.add(hostname, key.get_name(), key)
@@ -186,7 +196,7 @@ def download(local_path, urlstring, authentication=None, source=None):
if url.scheme == 'http' or url.scheme == 'https':
if source:
- print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr)
+ print_error('Warning: Custom source address not supported for HTTP connections.')
download_http(urlstring, local_path, username, password)
elif url.scheme == 'ftp':
username = username if username else 'anonymous'
@@ -247,16 +257,6 @@ def get_remote_file_size(urlstring, authentication=None, source=None):
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
-def get_remote_file_size_maybe(urlstring, authentication=None, source=None):
- """
- Passes arguments to `get_remote_file_size()` but returns 0 if it fails.
- Intended to be used in shell scripts only.
- """
- try:
- return get_remote_file_size(urlstring, authentication, source)
- except ValueError:
- return 0
-
def get_remote_config(urlstring, authentication=None, source=None):
"""
Download remote (config) file from `urlstring` and return the contents as a string.
@@ -282,3 +282,38 @@ def get_remote_config(urlstring, authentication=None, source=None):
return file.read()
finally:
os.remove(temp)
+
+def friendly_download(local_path, urlstring, authentication=None, source=None):
+ """
+ Intended to be called from interactive, user-facing scripts.
+ """
+ destination_directory = os.path.dirname(local_path)
+ free_space = shutil.disk_usage(destination_directory).free
+ try:
+ try:
+ file_size = get_remote_file_size(urlstring, authentication, source)
+ if file_size < 1024 * 1024:
+ print_error(f'The file is {file_size / 1024.0:.3f} KiB.')
+ else:
+ print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.')
+ if file_size > free_space:
+ raise OSError(f'Not enough disk space available in "{destination_directory}".')
+ except ValueError:
+ print_error('Could not determine the file size in advance.')
+ else:
+ # TODO: Progress bar
+ print_error('Downloading...')
+ download(local_path, urlstring, authentication, source)
+ except KeyboardInterrupt:
+ print_error('Download aborted by user.')
+ sys.exit(1)
+ except:
+ import traceback
+ # There are a myriad different reasons a download could fail.
+ # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...)
+ # We omit the scary stack trace but print the error nevertheless.
+ print_error(f'Failed to download {urlstring}.')
+ traceback.print_exception(*sys.exc_info()[:2], None)
+ sys.exit(1)
+ else:
+ print_error('Download complete.')