diff options
Diffstat (limited to 'python/vyos')
-rw-r--r-- | python/vyos/remote.py | 193 |
1 files changed, 133 insertions, 60 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 7d371b3c0..6c98f3219 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -14,6 +14,7 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. from ftplib import FTP +import math import os import shutil import socket @@ -29,77 +30,141 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') -def print_error(str): +class InteractivePolicy(MissingHostKeyPolicy): + """ + Policy for interactively querying the user on whether to proceed with + SSH connections to unknown hosts. + """ + def missing_host_key(self, client, hostname, key): + print_error(f"Host '{hostname}' not found in known hosts.") + print_error('Fingerprint: ' + key.get_fingerprint().hex()) + if ask_yes_no('Do you wish to continue?'): + if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): + client._host_keys.add(hostname, key.get_name(), key) + client.save_host_keys(client._host_keys_filename) + else: + raise SSHException(f"Cannot connect to unknown host '{hostname}'.") + + +## Helper routines +def print_error(str='', end='\n'): """ + Print `str` to stderr, terminated with `end`. Used for warnings and out-of-band messages to avoid mangling precious - stdout output. + stdout output. """ sys.stderr.write(str) - sys.stderr.write('\n') + sys.stderr.write(end) sys.stderr.flush() +def make_progressbar(increment: float): + """ + Return a generator that displays progressbar whose length is determined + by the width of the terminal with every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0 + """ + col, _ = shutil.get_terminal_size() + # Try for 20 columns if the terminal is too narrow. Let it overflow. + col = max(col - 15, 20) + total = 0.0 + while True: + length = min(round(total * col), col) + percentage = str(math.floor(total * 100)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + if total >= 1.0: + # Print a newline so that the subsequent prints don't overwrite the bar. + print_error() + break + # Add a new increment with each iteration. + yield + total = min(total + increment, 1.0) + # Ignore further calls. + while True: + yield + def get_authentication_variables(default_username=None, default_password=None): """ - Returns the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and - returns the defaults provided if environment variables are empty or nonexistent. + Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and + return the defaults provided if environment variables are empty or nonexistent. """ username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD') # Fall back to defaults if the username variable doesn't exist or is an empty string. + # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`, + # as we want the username and the password to have the same behaviour. if not username: return (default_username, default_password) else: return (username, password) -class InteractivePolicy(MissingHostKeyPolicy): +def get_port_from_url(url): """ - Policy for interactively querying the user on whether to proceed with - SSH connections to unknown hosts. + Return the port number from the given `url` named tuple, fall back to + the default if there isn't one. """ - def missing_host_key(self, client, hostname, key): - print_error(f"Host '{hostname}' not found in known hosts.") - print_error('Fingerprint: ' + key.get_fingerprint().hex()) - if ask_yes_no('Do you wish to continue?'): - if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): - client._host_keys.add(hostname, key.get_name(), key) - client.save_host_keys(client._host_keys_filename) - else: - raise SSHException(f"Cannot connect to unknown host '{hostname}'.") + defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\ + "ssh": 22, "scp": 22, "sftp": 22} + if url.port: + return url.port + else: + return defaults[url.scheme] ## FTP routines -def transfer_ftp(mode, local_path, hostname, remote_path,\ - username='anonymous', password='', port=21, source=None): +def upload_ftp(local_path, hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source=None, progressbar=False): + size = os.path.getsize(local_path) + blocksize = 8192 with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) - if mode == 'upload': - with open(local_path, 'rb') as file: - conn.storbinary(f'STOR {remote_path}', file) - elif mode == 'download': - with open(local_path, 'wb') as file: - conn.retrbinary(f'RETR {remote_path}', file.write) - elif mode == 'size': - size = conn.size(remote_path) - if size: - return size + with open(local_path, 'rb') as file: + if progressbar and size: + progress = make_progressbar(blocksize / size) + next(progress) + callback = lambda block: next(progress) else: - # SIZE is an extension to the FTP specification, although it's extremely common. - raise ValueError('Failed to receive file size from FTP server. \ - Perhaps the server does not implement the SIZE command?') + callback = None + conn.storbinary(f'STOR {remote_path}', file, blocksize, callback) -def upload_ftp(*args, **kwargs): - transfer_ftp('upload', *args, **kwargs) - -def download_ftp(*args, **kwargs): - transfer_ftp('download', *args, **kwargs) +def download_ftp(local_path, hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source=None, progressbar=False): + blocksize = 8192 + with FTP(source_address=source) as conn: + conn.connect(hostname, port) + conn.login(username, password) + size = conn.size(remote_path) + with open(local_path, 'wb') as file: + # No progressbar if we can't determine the size. + if progressbar and size: + progress = make_progressbar(blocksize / size) + next(progress) + callback = lambda block: (file.write(block), next(progress)) + else: + callback = file.write + conn.retrbinary(f'RETR {remote_path}', callback, blocksize) -def get_ftp_file_size(*args, **kwargs): - return transfer_ftp('size', None, *args, **kwargs) +def get_ftp_file_size(hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source=None): + with FTP(source_address=source) as conn: + conn.connect(hostname, port) + conn.login(username, password) + size = conn.size(remote_path) + if size: + return size + else: + # SIZE is an extension to the FTP specification, although it's extremely common. + raise ValueError('Failed to receive file size from FTP server. \ + Perhaps the server does not implement the SIZE command?') ## SFTP/SCP routines def transfer_sftp(mode, local_path, hostname, remote_path,\ - username=None, password=None, port=22, source=None): + username=None, password=None, port=22,\ + source=None, progressbar=False): sock = None if source: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -135,27 +200,29 @@ def get_sftp_file_size(*args, **kwargs): ## TFTP routines -def upload_tftp(local_path, hostname, remote_path, port=69, source=None): +def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'rb') as file: - cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\ + cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\ stderr=None, input=file.read()).encode() -def download_tftp(local_path, hostname, remote_path, port=69, source=None): +def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'wb') as file: - file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\ + file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ stderr=None).encode()) # get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP, -# as TFTP does not specify a SIZE command. +# as TFTP does not specify a SIZE command. ## HTTP(S) routines def install_request_opener(urlstring, username, password): """ - Take`username` and `password` strings and install the appropriate - password manager to `urllib.request.urlopen()` for the given `urlstring`. + Take `username` and `password` strings and install the appropriate + password manager to `urllib.request.urlopen()` for the given `urlstring`. """ manager = urlreq.HTTPPasswordMgrWithDefaultRealm() manager.add_password(None, urlstring, username, password) @@ -163,7 +230,7 @@ def install_request_opener(urlstring, username, password): # upload_http() is unimplemented. -def download_http(urlstring, local_path, username=None, password=None): +def download_http(urlstring, local_path, username=None, password=None, progressbar=False): """ Download the file from from `urlstring` to `local_path`. Optionally takes `username` and `password` for authentication. @@ -193,7 +260,7 @@ def get_http_file_size(urlstring, username=None, password=None): # Dynamic dispatchers -def download(local_path, urlstring, source=None): +def download(local_path, urlstring, source=None, progressbar=False): """ Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. Optionally takes a `source` address (not valid for HTTP(S)). @@ -202,6 +269,7 @@ def download(local_path, urlstring, source=None): """ url = urllib.parse.urlparse(urlstring) username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) if url.scheme == 'http' or url.scheme == 'https': if source: @@ -209,15 +277,15 @@ def download(local_path, urlstring, source=None): download_http(urlstring, local_path, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' - download_ftp(local_path, url.hostname, url.path, username, password, source=source) + download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': - download_sftp(local_path, url.hostname, url.path, username, password, source=source) + download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': - download_tftp(local_path, url.hostname, url.path, source=source) + download_tftp(local_path, url.hostname, url.path, port, source, progressbar) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def upload(local_path, urlstring, source=None): +def upload(local_path, urlstring, source=None, progressbar=False): """ Dispatch the appropriate upload function for the given URL and upload from local path. Optionally takes a `source` address. @@ -226,14 +294,15 @@ def upload(local_path, urlstring, source=None): """ url = urllib.parse.urlparse(urlstring) username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) if url.scheme == 'ftp': username = username if username else 'anonymous' - upload_ftp(local_path, url.hostname, url.path, username, password, source=source) + upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': - upload_sftp(local_path, url.hostname, url.path, username, password, source=source) + upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': - upload_tftp(local_path, url.hostname, url.path, source=source) + upload_tftp(local_path, url.hostname, url.path, port, source, progressbar) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') @@ -247,14 +316,15 @@ def get_remote_file_size(urlstring, source=None): """ url = urllib.parse.urlparse(urlstring) username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) if url.scheme == 'http' or url.scheme == 'https': return get_http_file_size(urlstring, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' - return get_ftp_file_size(url.hostname, url.path, username, password, source=source) + return get_ftp_file_size(url.hostname, url.path, username, password, port, source) elif url.scheme == 'sftp' or url.scheme == 'scp': - return get_sftp_file_size(url.hostname, url.path, username, password, source=source) + return get_sftp_file_size(url.hostname, url.path, username, password, port, source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') @@ -284,6 +354,8 @@ def get_remote_config(urlstring, source=None): def friendly_download(local_path, urlstring, source=None): """ + Download from `urlstring` to `local_path` in an informative way. + Checks the storage space before attempting download. Intended to be called from interactive, user-facing scripts. """ destination_directory = os.path.dirname(local_path) @@ -298,11 +370,12 @@ def friendly_download(local_path, urlstring, source=None): if file_size > free_space: raise OSError(f'Not enough disk space available in "{destination_directory}".') except ValueError: + # Can't do a storage check in this case, so we bravely continue. + file_size = 0 print_error('Could not determine the file size in advance.') else: - # TODO: Progress bar print_error('Downloading...') - download(local_path, urlstring, source) + download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024) except KeyboardInterrupt: print_error('Download aborted by user.') sys.exit(1) |