summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/vyos/remote.py98
1 files changed, 60 insertions, 38 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 6c98f3219..0bc2ee7f8 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -28,7 +28,10 @@ from vyos.version import get_version
from paramiko import SSHClient, SSHException, MissingHostKeyPolicy
-known_hosts_file = os.path.expanduser('~/.ssh/known_hosts')
+
+# This is a hardcoded path and no environment variable can change it.
+KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts')
+CHUNK_SIZE = 8192
class InteractivePolicy(MissingHostKeyPolicy):
"""
@@ -57,28 +60,40 @@ def print_error(str='', end='\n'):
sys.stderr.write(end)
sys.stderr.flush()
-def make_progressbar(increment: float):
+def make_progressbar():
"""
- Return a generator that displays progressbar whose length is determined
- by the width of the terminal with every iteration.
- First call displays it at 0% and every subsequent iteration displays it
- at `increment` increments where 0.0 < `increment` < 1.0
+ Make a procedure that takes two arguments `done` and `total` and prints a
+ progressbar based on the ratio thereof, whose length is determined by the
+ width of the terminal.
"""
col, _ = shutil.get_terminal_size()
- # Try for 20 columns if the terminal is too narrow. Let it overflow.
col = max(col - 15, 20)
- total = 0.0
- while True:
- length = min(round(total * col), col)
- percentage = str(math.floor(total * 100)).rjust(3)
- print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r')
- if total >= 1.0:
- # Print a newline so that the subsequent prints don't overwrite the bar.
+ def print_progressbar(done, total):
+ if done <= total:
+ increment = total / col
+ length = math.ceil(done / increment)
+ percentage = str(math.ceil(100 * done / total)).rjust(3)
+ print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r')
+ # Print a newline so that the subsequent prints don't overwrite the full bar.
+ if done == total:
print_error()
- break
- # Add a new increment with each iteration.
+ return print_progressbar
+
+def make_incremental_progressbar(increment: float):
+ """
+ Make a generator that displays a progressbar that grows monotonically with
+ every iteration.
+ First call displays it at 0% and every subsequent iteration displays it
+ at `increment` increments where 0.0 < `increment` < 1.0.
+ Intended for FTP and HTTP transfers with stateless callbacks.
+ """
+ print_progressbar = make_progressbar()
+ total = 0.0
+ while total < 1.0:
+ print_progressbar(total, 1.0)
yield
- total = min(total + increment, 1.0)
+ total += increment
+ print_progressbar(1, 1)
# Ignore further calls.
while True:
yield
@@ -115,23 +130,21 @@ def upload_ftp(local_path, hostname, remote_path,\
username='anonymous', password='', port=21,\
source=None, progressbar=False):
size = os.path.getsize(local_path)
- blocksize = 8192
with FTP(source_address=source) as conn:
conn.connect(hostname, port)
conn.login(username, password)
with open(local_path, 'rb') as file:
if progressbar and size:
- progress = make_progressbar(blocksize / size)
+ progress = make_incremental_progressbar(CHUNK_SIZE / size)
next(progress)
callback = lambda block: next(progress)
else:
callback = None
- conn.storbinary(f'STOR {remote_path}', file, blocksize, callback)
+ conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback)
def download_ftp(local_path, hostname, remote_path,\
username='anonymous', password='', port=21,\
source=None, progressbar=False):
- blocksize = 8192
with FTP(source_address=source) as conn:
conn.connect(hostname, port)
conn.login(username, password)
@@ -139,12 +152,12 @@ def download_ftp(local_path, hostname, remote_path,\
with open(local_path, 'wb') as file:
# No progressbar if we can't determine the size.
if progressbar and size:
- progress = make_progressbar(blocksize / size)
+ progress = make_incremental_progressbar(CHUNK_SIZE / size)
next(progress)
callback = lambda block: (file.write(block), next(progress))
else:
callback = file.write
- conn.retrbinary(f'RETR {remote_path}', callback, blocksize)
+ conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE)
def get_ftp_file_size(hostname, remote_path,\
username='anonymous', password='', port=21,\
@@ -170,18 +183,19 @@ def transfer_sftp(mode, local_path, hostname, remote_path,\
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((source, 0))
sock.connect((hostname, port))
+ callback = make_progressbar() if progressbar else None
try:
with SSHClient() as ssh:
ssh.load_system_host_keys()
- if os.path.exists(known_hosts_file):
- ssh.load_host_keys(known_hosts_file)
+ if os.path.exists(KNOWN_HOSTS_FILE):
+ ssh.load_host_keys(KNOWN_HOSTS_FILE)
ssh.set_missing_host_key_policy(InteractivePolicy())
ssh.connect(hostname, port, username, password, sock=sock)
with ssh.open_sftp() as sftp:
if mode == 'upload':
- sftp.put(local_path, remote_path)
+ sftp.put(local_path, remote_path, callback=callback)
elif mode == 'download':
- sftp.get(remote_path, local_path)
+ sftp.get(remote_path, local_path, callback=callback)
elif mode == 'size':
return sftp.stat(remote_path).st_size
finally:
@@ -209,6 +223,7 @@ def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progres
def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ # Not really applicable but we pass it for the sake of uniformity.
progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'wb') as file:
file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\
@@ -230,7 +245,7 @@ def install_request_opener(urlstring, username, password):
# upload_http() is unimplemented.
-def download_http(urlstring, local_path, username=None, password=None, progressbar=False):
+def download_http(local_path, urlstring, username=None, password=None, progressbar=False):
"""
Download the file from from `urlstring` to `local_path`.
Optionally takes `username` and `password` for authentication.
@@ -238,9 +253,19 @@ def download_http(urlstring, local_path, username=None, password=None, progressb
request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
if username:
install_request_opener(urlstring, username, password)
- with open(local_path, 'wb') as file:
- with urlreq.urlopen(request) as response:
- file.write(response.read())
+ with open(local_path, 'wb') as file, urlreq.urlopen(request) as response:
+ size = response.getheader('Content-Length')
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / int(size))
+ next(progress)
+ for chunk in iter(lambda: response.read(CHUNK_SIZE), b''):
+ file.write(chunk)
+ next(progress)
+ next(progress)
+ # If we can't determine the size or if a progress bar wasn't requested,
+ # we can let `shutil` take care of the copying.
+ else:
+ shutil.copyfileobj(response, file)
def get_http_file_size(urlstring, username=None, password=None):
"""
@@ -274,7 +299,7 @@ def download(local_path, urlstring, source=None, progressbar=False):
if url.scheme == 'http' or url.scheme == 'https':
if source:
print_error('Warning: Custom source address not supported for HTTP connections.')
- download_http(urlstring, local_path, username, password)
+ download_http(local_path, urlstring, username, password, progressbar)
elif url.scheme == 'ftp':
username = username if username else 'anonymous'
download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
@@ -333,12 +358,9 @@ def get_remote_config(urlstring, source=None):
Download remote (config) file from `urlstring` and return the contents as a string.
Args:
remote file URI:
- scp://<user>[:<passwd>]@<host>/<file>
- sftp://<user>[:<passwd>]@<host>/<file>
- http://<host>/<file>
- https://<host>/<file>
- ftp://[<user>[:<passwd>]@]<host>/<file>
- tftp://<host>/<file>
+ tftp://<host>[:<port>]/<file>
+ http[s]://<host>[:<port>]/<file>
+ [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file>
source address (optional):
<interface>
<IP address>