diff options
-rw-r--r-- | python/vyos/remote.py | 98 |
1 files changed, 60 insertions, 38 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 6c98f3219..0bc2ee7f8 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -28,7 +28,10 @@ from vyos.version import get_version from paramiko import SSHClient, SSHException, MissingHostKeyPolicy -known_hosts_file = os.path.expanduser('~/.ssh/known_hosts') + +# This is a hardcoded path and no environment variable can change it. +KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts') +CHUNK_SIZE = 8192 class InteractivePolicy(MissingHostKeyPolicy): """ @@ -57,28 +60,40 @@ def print_error(str='', end='\n'): sys.stderr.write(end) sys.stderr.flush() -def make_progressbar(increment: float): +def make_progressbar(): """ - Return a generator that displays progressbar whose length is determined - by the width of the terminal with every iteration. - First call displays it at 0% and every subsequent iteration displays it - at `increment` increments where 0.0 < `increment` < 1.0 + Make a procedure that takes two arguments `done` and `total` and prints a + progressbar based on the ratio thereof, whose length is determined by the + width of the terminal. """ col, _ = shutil.get_terminal_size() - # Try for 20 columns if the terminal is too narrow. Let it overflow. col = max(col - 15, 20) - total = 0.0 - while True: - length = min(round(total * col), col) - percentage = str(math.floor(total * 100)).rjust(3) - print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') - if total >= 1.0: - # Print a newline so that the subsequent prints don't overwrite the bar. + def print_progressbar(done, total): + if done <= total: + increment = total / col + length = math.ceil(done / increment) + percentage = str(math.ceil(100 * done / total)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + # Print a newline so that the subsequent prints don't overwrite the full bar. + if done == total: print_error() - break - # Add a new increment with each iteration. + return print_progressbar + +def make_incremental_progressbar(increment: float): + """ + Make a generator that displays a progressbar that grows monotonically with + every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0. + Intended for FTP and HTTP transfers with stateless callbacks. + """ + print_progressbar = make_progressbar() + total = 0.0 + while total < 1.0: + print_progressbar(total, 1.0) yield - total = min(total + increment, 1.0) + total += increment + print_progressbar(1, 1) # Ignore further calls. while True: yield @@ -115,23 +130,21 @@ def upload_ftp(local_path, hostname, remote_path,\ username='anonymous', password='', port=21,\ source=None, progressbar=False): size = os.path.getsize(local_path) - blocksize = 8192 with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) with open(local_path, 'rb') as file: if progressbar and size: - progress = make_progressbar(blocksize / size) + progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) callback = lambda block: next(progress) else: callback = None - conn.storbinary(f'STOR {remote_path}', file, blocksize, callback) + conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback) def download_ftp(local_path, hostname, remote_path,\ username='anonymous', password='', port=21,\ source=None, progressbar=False): - blocksize = 8192 with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) @@ -139,12 +152,12 @@ def download_ftp(local_path, hostname, remote_path,\ with open(local_path, 'wb') as file: # No progressbar if we can't determine the size. if progressbar and size: - progress = make_progressbar(blocksize / size) + progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) callback = lambda block: (file.write(block), next(progress)) else: callback = file.write - conn.retrbinary(f'RETR {remote_path}', callback, blocksize) + conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE) def get_ftp_file_size(hostname, remote_path,\ username='anonymous', password='', port=21,\ @@ -170,18 +183,19 @@ def transfer_sftp(mode, local_path, hostname, remote_path,\ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind((source, 0)) sock.connect((hostname, port)) + callback = make_progressbar() if progressbar else None try: with SSHClient() as ssh: ssh.load_system_host_keys() - if os.path.exists(known_hosts_file): - ssh.load_host_keys(known_hosts_file) + if os.path.exists(KNOWN_HOSTS_FILE): + ssh.load_host_keys(KNOWN_HOSTS_FILE) ssh.set_missing_host_key_policy(InteractivePolicy()) ssh.connect(hostname, port, username, password, sock=sock) with ssh.open_sftp() as sftp: if mode == 'upload': - sftp.put(local_path, remote_path) + sftp.put(local_path, remote_path, callback=callback) elif mode == 'download': - sftp.get(remote_path, local_path) + sftp.get(remote_path, local_path, callback=callback) elif mode == 'size': return sftp.stat(remote_path).st_size finally: @@ -209,6 +223,7 @@ def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progres def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): source_option = f'--interface {source}' if source else '' + # Not really applicable but we pass it for the sake of uniformity. progress_flag = '--progress-bar' if progressbar else '-s' with open(local_path, 'wb') as file: file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ @@ -230,7 +245,7 @@ def install_request_opener(urlstring, username, password): # upload_http() is unimplemented. -def download_http(urlstring, local_path, username=None, password=None, progressbar=False): +def download_http(local_path, urlstring, username=None, password=None, progressbar=False): """ Download the file from from `urlstring` to `local_path`. Optionally takes `username` and `password` for authentication. @@ -238,9 +253,19 @@ def download_http(urlstring, local_path, username=None, password=None, progressb request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) if username: install_request_opener(urlstring, username, password) - with open(local_path, 'wb') as file: - with urlreq.urlopen(request) as response: - file.write(response.read()) + with open(local_path, 'wb') as file, urlreq.urlopen(request) as response: + size = response.getheader('Content-Length') + if progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / int(size)) + next(progress) + for chunk in iter(lambda: response.read(CHUNK_SIZE), b''): + file.write(chunk) + next(progress) + next(progress) + # If we can't determine the size or if a progress bar wasn't requested, + # we can let `shutil` take care of the copying. + else: + shutil.copyfileobj(response, file) def get_http_file_size(urlstring, username=None, password=None): """ @@ -274,7 +299,7 @@ def download(local_path, urlstring, source=None, progressbar=False): if url.scheme == 'http' or url.scheme == 'https': if source: print_error('Warning: Custom source address not supported for HTTP connections.') - download_http(urlstring, local_path, username, password) + download_http(local_path, urlstring, username, password, progressbar) elif url.scheme == 'ftp': username = username if username else 'anonymous' download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) @@ -333,12 +358,9 @@ def get_remote_config(urlstring, source=None): Download remote (config) file from `urlstring` and return the contents as a string. Args: remote file URI: - scp://<user>[:<passwd>]@<host>/<file> - sftp://<user>[:<passwd>]@<host>/<file> - http://<host>/<file> - https://<host>/<file> - ftp://[<user>[:<passwd>]@]<host>/<file> - tftp://<host>/<file> + tftp://<host>[:<port>]/<file> + http[s]://<host>[:<port>]/<file> + [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file> source address (optional): <interface> <IP address> |