From dd94e6f1cf767d2ee85df8b42e16bc92537c82ef Mon Sep 17 00:00:00 2001 From: erkin Date: Tue, 11 May 2021 14:47:12 +0300 Subject: T3356: remote: Add authentication support Add docstrings --- python/vyos/remote.py | 112 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 91 insertions(+), 21 deletions(-) (limited to 'python/vyos') diff --git a/python/vyos/remote.py b/python/vyos/remote.py index f683a6d5a..52d234d4a 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -43,6 +43,7 @@ class InteractivePolicy(MissingHostKeyPolicy): else: raise SSHException(f"Cannot connect to unknown host '{hostname}'.") + ## FTP routines def transfer_ftp(mode, local_path, hostname, remote_path,\ username='anonymous', password='', port=21, source=None): @@ -73,6 +74,7 @@ def download_ftp(*args, **kwargs): def get_ftp_file_size(*args, **kwargs): return transfer_ftp('size', None, *args, **kwargs) + ## SFTP/SCP routines def transfer_sftp(mode, local_path, hostname, remote_path,\ username=None, password=None, port=22, source=None): @@ -109,6 +111,7 @@ def download_sftp(*args, **kwargs): def get_sftp_file_size(*args, **kwargs): return transfer_sftp('size', None, *args, **kwargs) + ## TFTP routines def upload_tftp(local_path, hostname, remote_path, port=69, source=None): source_option = f'--interface {source}' if source else '' @@ -125,15 +128,39 @@ def download_tftp(local_path, hostname, remote_path, port=69, source=None): # get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP, # as TFTP does not specify a SIZE command. + ## HTTP(S) routines -def download_http(urlstring, local_path): +def install_request_opener(urlstring, username, password): + """ + Take`username` and `password` strings and install the appropriate + password manager to `urllib.request.urlopen()` for the given `urlstring`. + """ + manager = urllib.request.HTTPPasswordMgrWithDefaultRealm() + manager.add_password(None, urlstring, username, password) + urllib.request.install_opener(urllib.request.build_opener(manager)) + +# upload_http() is unimplemented. + +def download_http(urlstring, local_path, username=None, password=None): + """ + Download the file from from `urlstring` to `local_path`. + Optionally takes `username` and `password` for authentication. + """ request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) + if username: + install_request_opener(urlstring, username, password) with open(local_path, 'wb') as file: with urllib.request.urlopen(request) as response: file.write(response.read()) -def get_http_file_size(urlstring): +def get_http_file_size(urlstring, username=None, password=None): + """ + Return the size of the file from `urlstring` in terms of number of bytes. + Optionally takes `username` and `password` for authentication. + """ request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) + if username: + install_request_opener(urlstring, username, password) with urllib.request.urlopen(request) as response: size = response.getheader('Content-Length') if size: @@ -142,59 +169,97 @@ def get_http_file_size(urlstring): else: raise ValueError('Failed to receive file size from HTTP server.') + # Dynamic dispatchers -def download(local_path, urlstring, source=None): +def download(local_path, urlstring, authentication=None, source=None): """ - Dispatch the appropriate download function for the given URL and save to local path. + Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. + Optionally takes a `source` address (not valid for HTTP(S)) and an `authentication` tuple + in the form of `(username, password)`. + Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP. """ url = urllib.parse.urlparse(urlstring) + if authentication: + username, password = authentication + else: + username, password = url.username, url.password + if url.scheme == 'http' or url.scheme == 'https': if source: print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr) - download_http(urlstring, local_path) + download_http(urlstring, local_path, username, password) elif url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - download_ftp(local_path, url.hostname, url.path, username, url.password, source=source) + username = username if username else 'anonymous' + download_ftp(local_path, url.hostname, url.path, username, password, source=source) elif url.scheme == 'sftp' or url.scheme == 'scp': - download_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source) + download_sftp(local_path, url.hostname, url.path, username, password, source=source) elif url.scheme == 'tftp': download_tftp(local_path, url.hostname, url.path, source=source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def upload(local_path, urlstring, source=None): +def upload(local_path, urlstring, authentication=None, source=None): """ Dispatch the appropriate upload function for the given URL and upload from local path. + Optionally takes a `source` address and an `authentication` tuple + in the form of `(username, password)`. + `authentication` takes precedence over credentials in `urlstring`. + Supports FTP, SFTP, SCP (through SFTP) and TFTP. """ url = urllib.parse.urlparse(urlstring) + if authentication: + username, password = authentication + else: + username, password = url.username, url.password + if url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - upload_ftp(local_path, url.hostname, url.path, username, url.password, source=source) + username = username if username else 'anonymous' + upload_ftp(local_path, url.hostname, url.path, username, password, source=source) elif url.scheme == 'sftp' or url.scheme == 'scp': - upload_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source) + upload_sftp(local_path, url.hostname, url.path, username, password, source=source) elif url.scheme == 'tftp': upload_tftp(local_path, url.hostname, url.path, source=source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_file_size(urlstring, source=None): +def get_remote_file_size(urlstring, authentication=None, source=None): """ - Return the size of the remote file in bytes. + Dispatch the appropriate function to return the size of the remote file from `urlstring` + in terms of number of bytes. + Optionally takes a `source` address (not valid for HTTP(S)) and an `authentication` tuple + in the form of `(username, password)`. + `authentication` takes precedence over credentials in `urlstring`. + Supports HTTP, HTTPS, FTP and SFTP (through SFTP). """ url = urllib.parse.urlparse(urlstring) + if authentication: + username, password = authentication + else: + username, password = url.username, url.password + if url.scheme == 'http' or url.scheme == 'https': - return get_http_file_size(urlstring) + return get_http_file_size(urlstring, authentication) elif url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - return get_ftp_file_size(url.hostname, url.path, username, url.password, source=source) + username = username if username else 'anonymous' + return get_ftp_file_size(url.hostname, url.path, username, password, source=source) elif url.scheme == 'sftp' or url.scheme == 'scp': - return get_sftp_file_size(url.hostname, url.path, url.username, url.password, source=source) + return get_sftp_file_size(url.hostname, url.path, username, password, source=source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_config(urlstring, source=None): +def get_remote_file_size_maybe(urlstring, authentication=None, source=None): + """ + Passes arguments to `get_remote_file_size()` but returns 0 if it fails. + Intended to be used in shell scripts only. + """ + try: + return get_remote_file_size(urlstring, authentication, source) + except ValueError: + return 0 + +def get_remote_config(urlstring, authentication=None, source=None): """ - Download remote (config) file and return the contents. + Download remote (config) file from `urlstring` and return the contents as a string. Args: remote file URI: scp://[:]@/ @@ -203,11 +268,16 @@ def get_remote_config(urlstring, source=None): https:/// ftp://[[:]@]/ tftp:/// + authentication tuple (optional): + (, ) + source address (optional): + + """ url = urllib.parse.urlparse(urlstring) temp = tempfile.NamedTemporaryFile(delete=False).name try: - download(temp, urlstring, source) + download(temp, urlstring, authentication, source) with open(temp, 'r') as file: return file.read() finally: -- cgit v1.2.3 From ace6cc3b51659f62d581ff785dee9a9888cc33e4 Mon Sep 17 00:00:00 2001 From: erkin Date: Tue, 18 May 2021 14:02:08 +0300 Subject: T3356: remote: Add friendly download procedure for user-facing scripts --- python/vyos/remote.py | 61 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 13 deletions(-) (limited to 'python/vyos') diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 52d234d4a..aef4dda7e 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -15,6 +15,7 @@ from ftplib import FTP import os +import shutil import socket import sys import tempfile @@ -28,14 +29,23 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') +def print_error(str): + """ + Used for warnings and out-of-band messages to avoid mangling precious + stdout output. + """ + sys.stderr.write(str) + sys.stderr.write('\n') + sys.stderr.flush() + class InteractivePolicy(MissingHostKeyPolicy): """ Policy for interactively querying the user on whether to proceed with SSH connections to unknown hosts. """ def missing_host_key(self, client, hostname, key): - print(f"Host '{hostname}' not found in known hosts.") - print('Fingerprint: ' + key.get_fingerprint().hex()) + print_error(f"Host '{hostname}' not found in known hosts.") + print_error('Fingerprint: ' + key.get_fingerprint().hex()) if ask_yes_no('Do you wish to continue?'): if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): client._host_keys.add(hostname, key.get_name(), key) @@ -186,7 +196,7 @@ def download(local_path, urlstring, authentication=None, source=None): if url.scheme == 'http' or url.scheme == 'https': if source: - print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr) + print_error('Warning: Custom source address not supported for HTTP connections.') download_http(urlstring, local_path, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' @@ -247,16 +257,6 @@ def get_remote_file_size(urlstring, authentication=None, source=None): else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_file_size_maybe(urlstring, authentication=None, source=None): - """ - Passes arguments to `get_remote_file_size()` but returns 0 if it fails. - Intended to be used in shell scripts only. - """ - try: - return get_remote_file_size(urlstring, authentication, source) - except ValueError: - return 0 - def get_remote_config(urlstring, authentication=None, source=None): """ Download remote (config) file from `urlstring` and return the contents as a string. @@ -282,3 +282,38 @@ def get_remote_config(urlstring, authentication=None, source=None): return file.read() finally: os.remove(temp) + +def friendly_download(local_path, urlstring, authentication=None, source=None): + """ + Intended to be called from interactive, user-facing scripts. + """ + destination_directory = os.path.dirname(local_path) + free_space = shutil.disk_usage(destination_directory).free + try: + try: + file_size = get_remote_file_size(urlstring, authentication, source) + if file_size < 1024 * 1024: + print_error(f'The file is {file_size / 1024.0:.3f} KiB.') + else: + print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.') + if file_size > free_space: + raise OSError(f'Not enough disk space available in "{destination_directory}".') + except ValueError: + print_error('Could not determine the file size in advance.') + else: + # TODO: Progress bar + print_error('Downloading...') + download(local_path, urlstring, authentication, source) + except KeyboardInterrupt: + print_error('Download aborted by user.') + sys.exit(1) + except: + import traceback + # There are a myriad different reasons a download could fail. + # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...) + # We omit the scary stack trace but print the error nevertheless. + print_error(f'Failed to download {urlstring}.') + traceback.print_exception(*sys.exc_info()[:2], None) + sys.exit(1) + else: + print_error('Download complete.') -- cgit v1.2.3 From 0aa64e0c260d6da6cf6050a5679695c63672ee74 Mon Sep 17 00:00:00 2001 From: erkin Date: Tue, 18 May 2021 15:51:16 +0300 Subject: T3356: remote: Read username and password from environment variables --- python/vyos/remote.py | 77 +++++++++++++++++++++++++-------------------------- 1 file changed, 38 insertions(+), 39 deletions(-) (limited to 'python/vyos') diff --git a/python/vyos/remote.py b/python/vyos/remote.py index aef4dda7e..7d371b3c0 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -20,7 +20,7 @@ import socket import sys import tempfile import urllib.parse -import urllib.request +import urllib.request as urlreq from vyos.util import cmd, ask_yes_no from vyos.version import get_version @@ -38,6 +38,18 @@ def print_error(str): sys.stderr.write('\n') sys.stderr.flush() +def get_authentication_variables(default_username=None, default_password=None): + """ + Returns the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and + returns the defaults provided if environment variables are empty or nonexistent. + """ + username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD') + # Fall back to defaults if the username variable doesn't exist or is an empty string. + if not username: + return (default_username, default_password) + else: + return (username, password) + class InteractivePolicy(MissingHostKeyPolicy): """ Policy for interactively querying the user on whether to proceed with @@ -145,9 +157,9 @@ def install_request_opener(urlstring, username, password): Take`username` and `password` strings and install the appropriate password manager to `urllib.request.urlopen()` for the given `urlstring`. """ - manager = urllib.request.HTTPPasswordMgrWithDefaultRealm() + manager = urlreq.HTTPPasswordMgrWithDefaultRealm() manager.add_password(None, urlstring, username, password) - urllib.request.install_opener(urllib.request.build_opener(manager)) + urlreq.install_opener(urlreq.build_opener(urlreq.HTTPBasicAuthHandler(manager))) # upload_http() is unimplemented. @@ -156,11 +168,11 @@ def download_http(urlstring, local_path, username=None, password=None): Download the file from from `urlstring` to `local_path`. Optionally takes `username` and `password` for authentication. """ - request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) + request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) if username: install_request_opener(urlstring, username, password) with open(local_path, 'wb') as file: - with urllib.request.urlopen(request) as response: + with urlreq.urlopen(request) as response: file.write(response.read()) def get_http_file_size(urlstring, username=None, password=None): @@ -168,10 +180,10 @@ def get_http_file_size(urlstring, username=None, password=None): Return the size of the file from `urlstring` in terms of number of bytes. Optionally takes `username` and `password` for authentication. """ - request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) + request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) if username: install_request_opener(urlstring, username, password) - with urllib.request.urlopen(request) as response: + with urlreq.urlopen(request) as response: size = response.getheader('Content-Length') if size: return int(size) @@ -181,18 +193,15 @@ def get_http_file_size(urlstring, username=None, password=None): # Dynamic dispatchers -def download(local_path, urlstring, authentication=None, source=None): +def download(local_path, urlstring, source=None): """ Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. - Optionally takes a `source` address (not valid for HTTP(S)) and an `authentication` tuple - in the form of `(username, password)`. + Optionally takes a `source` address (not valid for HTTP(S)). Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP. + Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ url = urllib.parse.urlparse(urlstring) - if authentication: - username, password = authentication - else: - username, password = url.username, url.password + username, password = get_authentication_variables(url.username, url.password) if url.scheme == 'http' or url.scheme == 'https': if source: @@ -208,19 +217,15 @@ def download(local_path, urlstring, authentication=None, source=None): else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def upload(local_path, urlstring, authentication=None, source=None): +def upload(local_path, urlstring, source=None): """ Dispatch the appropriate upload function for the given URL and upload from local path. - Optionally takes a `source` address and an `authentication` tuple - in the form of `(username, password)`. - `authentication` takes precedence over credentials in `urlstring`. + Optionally takes a `source` address. Supports FTP, SFTP, SCP (through SFTP) and TFTP. + Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ url = urllib.parse.urlparse(urlstring) - if authentication: - username, password = authentication - else: - username, password = url.username, url.password + username, password = get_authentication_variables(url.username, url.password) if url.scheme == 'ftp': username = username if username else 'anonymous' @@ -232,23 +237,19 @@ def upload(local_path, urlstring, authentication=None, source=None): else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_file_size(urlstring, authentication=None, source=None): +def get_remote_file_size(urlstring, source=None): """ Dispatch the appropriate function to return the size of the remote file from `urlstring` in terms of number of bytes. - Optionally takes a `source` address (not valid for HTTP(S)) and an `authentication` tuple - in the form of `(username, password)`. - `authentication` takes precedence over credentials in `urlstring`. + Optionally takes a `source` address (not valid for HTTP(S)). Supports HTTP, HTTPS, FTP and SFTP (through SFTP). + Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ url = urllib.parse.urlparse(urlstring) - if authentication: - username, password = authentication - else: - username, password = url.username, url.password + username, password = get_authentication_variables(url.username, url.password) if url.scheme == 'http' or url.scheme == 'https': - return get_http_file_size(urlstring, authentication) + return get_http_file_size(urlstring, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' return get_ftp_file_size(url.hostname, url.path, username, password, source=source) @@ -257,7 +258,7 @@ def get_remote_file_size(urlstring, authentication=None, source=None): else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def get_remote_config(urlstring, authentication=None, source=None): +def get_remote_config(urlstring, source=None): """ Download remote (config) file from `urlstring` and return the contents as a string. Args: @@ -268,8 +269,6 @@ def get_remote_config(urlstring, authentication=None, source=None): https:/// ftp://[[:]@]/ tftp:/// - authentication tuple (optional): - (, ) source address (optional): @@ -277,21 +276,21 @@ def get_remote_config(urlstring, authentication=None, source=None): url = urllib.parse.urlparse(urlstring) temp = tempfile.NamedTemporaryFile(delete=False).name try: - download(temp, urlstring, authentication, source) + download(temp, urlstring, source) with open(temp, 'r') as file: return file.read() finally: os.remove(temp) -def friendly_download(local_path, urlstring, authentication=None, source=None): +def friendly_download(local_path, urlstring, source=None): """ Intended to be called from interactive, user-facing scripts. """ destination_directory = os.path.dirname(local_path) - free_space = shutil.disk_usage(destination_directory).free try: + free_space = shutil.disk_usage(destination_directory).free try: - file_size = get_remote_file_size(urlstring, authentication, source) + file_size = get_remote_file_size(urlstring, source) if file_size < 1024 * 1024: print_error(f'The file is {file_size / 1024.0:.3f} KiB.') else: @@ -303,7 +302,7 @@ def friendly_download(local_path, urlstring, authentication=None, source=None): else: # TODO: Progress bar print_error('Downloading...') - download(local_path, urlstring, authentication, source) + download(local_path, urlstring, source) except KeyboardInterrupt: print_error('Download aborted by user.') sys.exit(1) -- cgit v1.2.3 From f3072a64a8075f4f5a730983b0c7752396b5c0d7 Mon Sep 17 00:00:00 2001 From: erkin Date: Sat, 29 May 2021 13:24:54 +0300 Subject: T3356: Add progressbars to FTP transfers Allow ports to be specified in URL strings --- python/vyos/remote.py | 193 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 133 insertions(+), 60 deletions(-) (limited to 'python/vyos') diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 7d371b3c0..6c98f3219 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -14,6 +14,7 @@ # License along with this library. If not, see . from ftplib import FTP +import math import os import shutil import socket @@ -29,77 +30,141 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') -def print_error(str): +class InteractivePolicy(MissingHostKeyPolicy): + """ + Policy for interactively querying the user on whether to proceed with + SSH connections to unknown hosts. + """ + def missing_host_key(self, client, hostname, key): + print_error(f"Host '{hostname}' not found in known hosts.") + print_error('Fingerprint: ' + key.get_fingerprint().hex()) + if ask_yes_no('Do you wish to continue?'): + if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): + client._host_keys.add(hostname, key.get_name(), key) + client.save_host_keys(client._host_keys_filename) + else: + raise SSHException(f"Cannot connect to unknown host '{hostname}'.") + + +## Helper routines +def print_error(str='', end='\n'): """ + Print `str` to stderr, terminated with `end`. Used for warnings and out-of-band messages to avoid mangling precious - stdout output. + stdout output. """ sys.stderr.write(str) - sys.stderr.write('\n') + sys.stderr.write(end) sys.stderr.flush() +def make_progressbar(increment: float): + """ + Return a generator that displays progressbar whose length is determined + by the width of the terminal with every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0 + """ + col, _ = shutil.get_terminal_size() + # Try for 20 columns if the terminal is too narrow. Let it overflow. + col = max(col - 15, 20) + total = 0.0 + while True: + length = min(round(total * col), col) + percentage = str(math.floor(total * 100)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + if total >= 1.0: + # Print a newline so that the subsequent prints don't overwrite the bar. + print_error() + break + # Add a new increment with each iteration. + yield + total = min(total + increment, 1.0) + # Ignore further calls. + while True: + yield + def get_authentication_variables(default_username=None, default_password=None): """ - Returns the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and - returns the defaults provided if environment variables are empty or nonexistent. + Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and + return the defaults provided if environment variables are empty or nonexistent. """ username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD') # Fall back to defaults if the username variable doesn't exist or is an empty string. + # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`, + # as we want the username and the password to have the same behaviour. if not username: return (default_username, default_password) else: return (username, password) -class InteractivePolicy(MissingHostKeyPolicy): +def get_port_from_url(url): """ - Policy for interactively querying the user on whether to proceed with - SSH connections to unknown hosts. + Return the port number from the given `url` named tuple, fall back to + the default if there isn't one. """ - def missing_host_key(self, client, hostname, key): - print_error(f"Host '{hostname}' not found in known hosts.") - print_error('Fingerprint: ' + key.get_fingerprint().hex()) - if ask_yes_no('Do you wish to continue?'): - if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): - client._host_keys.add(hostname, key.get_name(), key) - client.save_host_keys(client._host_keys_filename) - else: - raise SSHException(f"Cannot connect to unknown host '{hostname}'.") + defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\ + "ssh": 22, "scp": 22, "sftp": 22} + if url.port: + return url.port + else: + return defaults[url.scheme] ## FTP routines -def transfer_ftp(mode, local_path, hostname, remote_path,\ - username='anonymous', password='', port=21, source=None): +def upload_ftp(local_path, hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source=None, progressbar=False): + size = os.path.getsize(local_path) + blocksize = 8192 with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) - if mode == 'upload': - with open(local_path, 'rb') as file: - conn.storbinary(f'STOR {remote_path}', file) - elif mode == 'download': - with open(local_path, 'wb') as file: - conn.retrbinary(f'RETR {remote_path}', file.write) - elif mode == 'size': - size = conn.size(remote_path) - if size: - return size + with open(local_path, 'rb') as file: + if progressbar and size: + progress = make_progressbar(blocksize / size) + next(progress) + callback = lambda block: next(progress) else: - # SIZE is an extension to the FTP specification, although it's extremely common. - raise ValueError('Failed to receive file size from FTP server. \ - Perhaps the server does not implement the SIZE command?') + callback = None + conn.storbinary(f'STOR {remote_path}', file, blocksize, callback) -def upload_ftp(*args, **kwargs): - transfer_ftp('upload', *args, **kwargs) - -def download_ftp(*args, **kwargs): - transfer_ftp('download', *args, **kwargs) +def download_ftp(local_path, hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source=None, progressbar=False): + blocksize = 8192 + with FTP(source_address=source) as conn: + conn.connect(hostname, port) + conn.login(username, password) + size = conn.size(remote_path) + with open(local_path, 'wb') as file: + # No progressbar if we can't determine the size. + if progressbar and size: + progress = make_progressbar(blocksize / size) + next(progress) + callback = lambda block: (file.write(block), next(progress)) + else: + callback = file.write + conn.retrbinary(f'RETR {remote_path}', callback, blocksize) -def get_ftp_file_size(*args, **kwargs): - return transfer_ftp('size', None, *args, **kwargs) +def get_ftp_file_size(hostname, remote_path,\ + username='anonymous', password='', port=21,\ + source=None): + with FTP(source_address=source) as conn: + conn.connect(hostname, port) + conn.login(username, password) + size = conn.size(remote_path) + if size: + return size + else: + # SIZE is an extension to the FTP specification, although it's extremely common. + raise ValueError('Failed to receive file size from FTP server. \ + Perhaps the server does not implement the SIZE command?') ## SFTP/SCP routines def transfer_sftp(mode, local_path, hostname, remote_path,\ - username=None, password=None, port=22, source=None): + username=None, password=None, port=22,\ + source=None, progressbar=False): sock = None if source: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -135,27 +200,29 @@ def get_sftp_file_size(*args, **kwargs): ## TFTP routines -def upload_tftp(local_path, hostname, remote_path, port=69, source=None): +def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'rb') as file: - cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\ + cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\ stderr=None, input=file.read()).encode() -def download_tftp(local_path, hostname, remote_path, port=69, source=None): +def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'wb') as file: - file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\ + file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ stderr=None).encode()) # get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP, -# as TFTP does not specify a SIZE command. +# as TFTP does not specify a SIZE command. ## HTTP(S) routines def install_request_opener(urlstring, username, password): """ - Take`username` and `password` strings and install the appropriate - password manager to `urllib.request.urlopen()` for the given `urlstring`. + Take `username` and `password` strings and install the appropriate + password manager to `urllib.request.urlopen()` for the given `urlstring`. """ manager = urlreq.HTTPPasswordMgrWithDefaultRealm() manager.add_password(None, urlstring, username, password) @@ -163,7 +230,7 @@ def install_request_opener(urlstring, username, password): # upload_http() is unimplemented. -def download_http(urlstring, local_path, username=None, password=None): +def download_http(urlstring, local_path, username=None, password=None, progressbar=False): """ Download the file from from `urlstring` to `local_path`. Optionally takes `username` and `password` for authentication. @@ -193,7 +260,7 @@ def get_http_file_size(urlstring, username=None, password=None): # Dynamic dispatchers -def download(local_path, urlstring, source=None): +def download(local_path, urlstring, source=None, progressbar=False): """ Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. Optionally takes a `source` address (not valid for HTTP(S)). @@ -202,6 +269,7 @@ def download(local_path, urlstring, source=None): """ url = urllib.parse.urlparse(urlstring) username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) if url.scheme == 'http' or url.scheme == 'https': if source: @@ -209,15 +277,15 @@ def download(local_path, urlstring, source=None): download_http(urlstring, local_path, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' - download_ftp(local_path, url.hostname, url.path, username, password, source=source) + download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': - download_sftp(local_path, url.hostname, url.path, username, password, source=source) + download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': - download_tftp(local_path, url.hostname, url.path, source=source) + download_tftp(local_path, url.hostname, url.path, port, source, progressbar) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') -def upload(local_path, urlstring, source=None): +def upload(local_path, urlstring, source=None, progressbar=False): """ Dispatch the appropriate upload function for the given URL and upload from local path. Optionally takes a `source` address. @@ -226,14 +294,15 @@ def upload(local_path, urlstring, source=None): """ url = urllib.parse.urlparse(urlstring) username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) if url.scheme == 'ftp': username = username if username else 'anonymous' - upload_ftp(local_path, url.hostname, url.path, username, password, source=source) + upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': - upload_sftp(local_path, url.hostname, url.path, username, password, source=source) + upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': - upload_tftp(local_path, url.hostname, url.path, source=source) + upload_tftp(local_path, url.hostname, url.path, port, source, progressbar) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') @@ -247,14 +316,15 @@ def get_remote_file_size(urlstring, source=None): """ url = urllib.parse.urlparse(urlstring) username, password = get_authentication_variables(url.username, url.password) + port = get_port_from_url(url) if url.scheme == 'http' or url.scheme == 'https': return get_http_file_size(urlstring, username, password) elif url.scheme == 'ftp': username = username if username else 'anonymous' - return get_ftp_file_size(url.hostname, url.path, username, password, source=source) + return get_ftp_file_size(url.hostname, url.path, username, password, port, source) elif url.scheme == 'sftp' or url.scheme == 'scp': - return get_sftp_file_size(url.hostname, url.path, username, password, source=source) + return get_sftp_file_size(url.hostname, url.path, username, password, port, source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') @@ -284,6 +354,8 @@ def get_remote_config(urlstring, source=None): def friendly_download(local_path, urlstring, source=None): """ + Download from `urlstring` to `local_path` in an informative way. + Checks the storage space before attempting download. Intended to be called from interactive, user-facing scripts. """ destination_directory = os.path.dirname(local_path) @@ -298,11 +370,12 @@ def friendly_download(local_path, urlstring, source=None): if file_size > free_space: raise OSError(f'Not enough disk space available in "{destination_directory}".') except ValueError: + # Can't do a storage check in this case, so we bravely continue. + file_size = 0 print_error('Could not determine the file size in advance.') else: - # TODO: Progress bar print_error('Downloading...') - download(local_path, urlstring, source) + download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024) except KeyboardInterrupt: print_error('Download aborted by user.') sys.exit(1) -- cgit v1.2.3 From 0856dd2f2584d2c41b8ddf70a5a2751ee446be5a Mon Sep 17 00:00:00 2001 From: erkin Date: Sun, 30 May 2021 10:41:16 +0300 Subject: T3356: Add progressbars to SFTP and HTTP transfers --- python/vyos/remote.py | 98 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 38 deletions(-) (limited to 'python/vyos') diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 6c98f3219..0bc2ee7f8 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -28,7 +28,10 @@ from vyos.version import get_version from paramiko import SSHClient, SSHException, MissingHostKeyPolicy -known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') + +# This is a hardcoded path and no environment variable can change it. +KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts') +CHUNK_SIZE = 8192 class InteractivePolicy(MissingHostKeyPolicy): """ @@ -57,28 +60,40 @@ def print_error(str='', end='\n'): sys.stderr.write(end) sys.stderr.flush() -def make_progressbar(increment: float): +def make_progressbar(): """ - Return a generator that displays progressbar whose length is determined - by the width of the terminal with every iteration. - First call displays it at 0% and every subsequent iteration displays it - at `increment` increments where 0.0 < `increment` < 1.0 + Make a procedure that takes two arguments `done` and `total` and prints a + progressbar based on the ratio thereof, whose length is determined by the + width of the terminal. """ col, _ = shutil.get_terminal_size() - # Try for 20 columns if the terminal is too narrow. Let it overflow. col = max(col - 15, 20) - total = 0.0 - while True: - length = min(round(total * col), col) - percentage = str(math.floor(total * 100)).rjust(3) - print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') - if total >= 1.0: - # Print a newline so that the subsequent prints don't overwrite the bar. + def print_progressbar(done, total): + if done <= total: + increment = total / col + length = math.ceil(done / increment) + percentage = str(math.ceil(100 * done / total)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + # Print a newline so that the subsequent prints don't overwrite the full bar. + if done == total: print_error() - break - # Add a new increment with each iteration. + return print_progressbar + +def make_incremental_progressbar(increment: float): + """ + Make a generator that displays a progressbar that grows monotonically with + every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0. + Intended for FTP and HTTP transfers with stateless callbacks. + """ + print_progressbar = make_progressbar() + total = 0.0 + while total < 1.0: + print_progressbar(total, 1.0) yield - total = min(total + increment, 1.0) + total += increment + print_progressbar(1, 1) # Ignore further calls. while True: yield @@ -115,23 +130,21 @@ def upload_ftp(local_path, hostname, remote_path,\ username='anonymous', password='', port=21,\ source=None, progressbar=False): size = os.path.getsize(local_path) - blocksize = 8192 with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) with open(local_path, 'rb') as file: if progressbar and size: - progress = make_progressbar(blocksize / size) + progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) callback = lambda block: next(progress) else: callback = None - conn.storbinary(f'STOR {remote_path}', file, blocksize, callback) + conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback) def download_ftp(local_path, hostname, remote_path,\ username='anonymous', password='', port=21,\ source=None, progressbar=False): - blocksize = 8192 with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) @@ -139,12 +152,12 @@ def download_ftp(local_path, hostname, remote_path,\ with open(local_path, 'wb') as file: # No progressbar if we can't determine the size. if progressbar and size: - progress = make_progressbar(blocksize / size) + progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) callback = lambda block: (file.write(block), next(progress)) else: callback = file.write - conn.retrbinary(f'RETR {remote_path}', callback, blocksize) + conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE) def get_ftp_file_size(hostname, remote_path,\ username='anonymous', password='', port=21,\ @@ -170,18 +183,19 @@ def transfer_sftp(mode, local_path, hostname, remote_path,\ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind((source, 0)) sock.connect((hostname, port)) + callback = make_progressbar() if progressbar else None try: with SSHClient() as ssh: ssh.load_system_host_keys() - if os.path.exists(known_hosts_file): - ssh.load_host_keys(known_hosts_file) + if os.path.exists(KNOWN_HOSTS_FILE): + ssh.load_host_keys(KNOWN_HOSTS_FILE) ssh.set_missing_host_key_policy(InteractivePolicy()) ssh.connect(hostname, port, username, password, sock=sock) with ssh.open_sftp() as sftp: if mode == 'upload': - sftp.put(local_path, remote_path) + sftp.put(local_path, remote_path, callback=callback) elif mode == 'download': - sftp.get(remote_path, local_path) + sftp.get(remote_path, local_path, callback=callback) elif mode == 'size': return sftp.stat(remote_path).st_size finally: @@ -209,6 +223,7 @@ def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progres def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + # Not really applicable but we pass it for the sake of uniformity. progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'wb') as file: file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ @@ -230,7 +245,7 @@ def install_request_opener(urlstring, username, password): # upload_http() is unimplemented. -def download_http(urlstring, local_path, username=None, password=None, progressbar=False): +def download_http(local_path, urlstring, username=None, password=None, progressbar=False): """ Download the file from from `urlstring` to `local_path`. Optionally takes `username` and `password` for authentication. @@ -238,9 +253,19 @@ def download_http(urlstring, local_path, username=None, password=None, progressb request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) if username: install_request_opener(urlstring, username, password) - with open(local_path, 'wb') as file: - with urlreq.urlopen(request) as response: - file.write(response.read()) + with open(local_path, 'wb') as file, urlreq.urlopen(request) as response: + size = response.getheader('Content-Length') + if progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / int(size)) + next(progress) + for chunk in iter(lambda: response.read(CHUNK_SIZE), b''): + file.write(chunk) + next(progress) + next(progress) + # If we can't determine the size or if a progress bar wasn't requested, + # we can let `shutil` take care of the copying. + else: + shutil.copyfileobj(response, file) def get_http_file_size(urlstring, username=None, password=None): """ @@ -274,7 +299,7 @@ def download(local_path, urlstring, source=None, progressbar=False): if url.scheme == 'http' or url.scheme == 'https': if source: print_error('Warning: Custom source address not supported for HTTP connections.') - download_http(urlstring, local_path, username, password) + download_http(local_path, urlstring, username, password, progressbar) elif url.scheme == 'ftp': username = username if username else 'anonymous' download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) @@ -333,12 +358,9 @@ def get_remote_config(urlstring, source=None): Download remote (config) file from `urlstring` and return the contents as a string. Args: remote file URI: - scp://[:]@/ - sftp://[:]@/ - http:/// - https:/// - ftp://[[:]@]/ - tftp:/// + tftp://[:]/ + http[s]://[:]/ + [scp|sftp|ftp]://[[:]@][:port]/ source address (optional): -- cgit v1.2.3