summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/vyos/remote.py193
1 files changed, 133 insertions, 60 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 7d371b3c0..6c98f3219 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -14,6 +14,7 @@
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
from ftplib import FTP
+import math
import os
import shutil
import socket
@@ -29,77 +30,141 @@ from paramiko import SSHClient, SSHException, MissingHostKeyPolicy
known_hosts_file = os.path.expanduser('~/.ssh/known_hosts')
-def print_error(str):
+class InteractivePolicy(MissingHostKeyPolicy):
+ """
+ Policy for interactively querying the user on whether to proceed with
+ SSH connections to unknown hosts.
+ """
+ def missing_host_key(self, client, hostname, key):
+ print_error(f"Host '{hostname}' not found in known hosts.")
+ print_error('Fingerprint: ' + key.get_fingerprint().hex())
+ if ask_yes_no('Do you wish to continue?'):
+ if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
+ client._host_keys.add(hostname, key.get_name(), key)
+ client.save_host_keys(client._host_keys_filename)
+ else:
+ raise SSHException(f"Cannot connect to unknown host '{hostname}'.")
+
+
+## Helper routines
+def print_error(str='', end='\n'):
"""
+ Print `str` to stderr, terminated with `end`.
Used for warnings and out-of-band messages to avoid mangling precious
- stdout output.
+ stdout output.
"""
sys.stderr.write(str)
- sys.stderr.write('\n')
+ sys.stderr.write(end)
sys.stderr.flush()
+def make_progressbar(increment: float):
+ """
+ Return a generator that displays progressbar whose length is determined
+ by the width of the terminal with every iteration.
+ First call displays it at 0% and every subsequent iteration displays it
+ at `increment` increments where 0.0 < `increment` < 1.0
+ """
+ col, _ = shutil.get_terminal_size()
+ # Try for 20 columns if the terminal is too narrow. Let it overflow.
+ col = max(col - 15, 20)
+ total = 0.0
+ while True:
+ length = min(round(total * col), col)
+ percentage = str(math.floor(total * 100)).rjust(3)
+ print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r')
+ if total >= 1.0:
+ # Print a newline so that the subsequent prints don't overwrite the bar.
+ print_error()
+ break
+ # Add a new increment with each iteration.
+ yield
+ total = min(total + increment, 1.0)
+ # Ignore further calls.
+ while True:
+ yield
+
def get_authentication_variables(default_username=None, default_password=None):
"""
- Returns the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and
- returns the defaults provided if environment variables are empty or nonexistent.
+ Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and
+ return the defaults provided if environment variables are empty or nonexistent.
"""
username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD')
# Fall back to defaults if the username variable doesn't exist or is an empty string.
+ # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`,
+ # as we want the username and the password to have the same behaviour.
if not username:
return (default_username, default_password)
else:
return (username, password)
-class InteractivePolicy(MissingHostKeyPolicy):
+def get_port_from_url(url):
"""
- Policy for interactively querying the user on whether to proceed with
- SSH connections to unknown hosts.
+ Return the port number from the given `url` named tuple, fall back to
+ the default if there isn't one.
"""
- def missing_host_key(self, client, hostname, key):
- print_error(f"Host '{hostname}' not found in known hosts.")
- print_error('Fingerprint: ' + key.get_fingerprint().hex())
- if ask_yes_no('Do you wish to continue?'):
- if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
- client._host_keys.add(hostname, key.get_name(), key)
- client.save_host_keys(client._host_keys_filename)
- else:
- raise SSHException(f"Cannot connect to unknown host '{hostname}'.")
+ defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\
+ "ssh": 22, "scp": 22, "sftp": 22}
+ if url.port:
+ return url.port
+ else:
+ return defaults[url.scheme]
## FTP routines
-def transfer_ftp(mode, local_path, hostname, remote_path,\
- username='anonymous', password='', port=21, source=None):
+def upload_ftp(local_path, hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source=None, progressbar=False):
+ size = os.path.getsize(local_path)
+ blocksize = 8192
with FTP(source_address=source) as conn:
conn.connect(hostname, port)
conn.login(username, password)
- if mode == 'upload':
- with open(local_path, 'rb') as file:
- conn.storbinary(f'STOR {remote_path}', file)
- elif mode == 'download':
- with open(local_path, 'wb') as file:
- conn.retrbinary(f'RETR {remote_path}', file.write)
- elif mode == 'size':
- size = conn.size(remote_path)
- if size:
- return size
+ with open(local_path, 'rb') as file:
+ if progressbar and size:
+ progress = make_progressbar(blocksize / size)
+ next(progress)
+ callback = lambda block: next(progress)
else:
- # SIZE is an extension to the FTP specification, although it's extremely common.
- raise ValueError('Failed to receive file size from FTP server. \
- Perhaps the server does not implement the SIZE command?')
+ callback = None
+ conn.storbinary(f'STOR {remote_path}', file, blocksize, callback)
-def upload_ftp(*args, **kwargs):
- transfer_ftp('upload', *args, **kwargs)
-
-def download_ftp(*args, **kwargs):
- transfer_ftp('download', *args, **kwargs)
+def download_ftp(local_path, hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source=None, progressbar=False):
+ blocksize = 8192
+ with FTP(source_address=source) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ size = conn.size(remote_path)
+ with open(local_path, 'wb') as file:
+ # No progressbar if we can't determine the size.
+ if progressbar and size:
+ progress = make_progressbar(blocksize / size)
+ next(progress)
+ callback = lambda block: (file.write(block), next(progress))
+ else:
+ callback = file.write
+ conn.retrbinary(f'RETR {remote_path}', callback, blocksize)
-def get_ftp_file_size(*args, **kwargs):
- return transfer_ftp('size', None, *args, **kwargs)
+def get_ftp_file_size(hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source=None):
+ with FTP(source_address=source) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ size = conn.size(remote_path)
+ if size:
+ return size
+ else:
+ # SIZE is an extension to the FTP specification, although it's extremely common.
+ raise ValueError('Failed to receive file size from FTP server. \
+ Perhaps the server does not implement the SIZE command?')
## SFTP/SCP routines
def transfer_sftp(mode, local_path, hostname, remote_path,\
- username=None, password=None, port=22, source=None):
+ username=None, password=None, port=22,\
+ source=None, progressbar=False):
sock = None
if source:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -135,27 +200,29 @@ def get_sftp_file_size(*args, **kwargs):
## TFTP routines
-def upload_tftp(local_path, hostname, remote_path, port=69, source=None):
+def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'rb') as file:
- cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\
+ cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\
stderr=None, input=file.read()).encode()
-def download_tftp(local_path, hostname, remote_path, port=69, source=None):
+def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'wb') as file:
- file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\
+ file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\
stderr=None).encode())
# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP,
-# as TFTP does not specify a SIZE command.
+# as TFTP does not specify a SIZE command.
## HTTP(S) routines
def install_request_opener(urlstring, username, password):
"""
- Take`username` and `password` strings and install the appropriate
- password manager to `urllib.request.urlopen()` for the given `urlstring`.
+ Take `username` and `password` strings and install the appropriate
+ password manager to `urllib.request.urlopen()` for the given `urlstring`.
"""
manager = urlreq.HTTPPasswordMgrWithDefaultRealm()
manager.add_password(None, urlstring, username, password)
@@ -163,7 +230,7 @@ def install_request_opener(urlstring, username, password):
# upload_http() is unimplemented.
-def download_http(urlstring, local_path, username=None, password=None):
+def download_http(urlstring, local_path, username=None, password=None, progressbar=False):
"""
Download the file from from `urlstring` to `local_path`.
Optionally takes `username` and `password` for authentication.
@@ -193,7 +260,7 @@ def get_http_file_size(urlstring, username=None, password=None):
# Dynamic dispatchers
-def download(local_path, urlstring, source=None):
+def download(local_path, urlstring, source=None, progressbar=False):
"""
Dispatch the appropriate download function for the given `urlstring` and save to `local_path`.
Optionally takes a `source` address (not valid for HTTP(S)).
@@ -202,6 +269,7 @@ def download(local_path, urlstring, source=None):
"""
url = urllib.parse.urlparse(urlstring)
username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
if url.scheme == 'http' or url.scheme == 'https':
if source:
@@ -209,15 +277,15 @@ def download(local_path, urlstring, source=None):
download_http(urlstring, local_path, username, password)
elif url.scheme == 'ftp':
username = username if username else 'anonymous'
- download_ftp(local_path, url.hostname, url.path, username, password, source=source)
+ download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- download_sftp(local_path, url.hostname, url.path, username, password, source=source)
+ download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
- download_tftp(local_path, url.hostname, url.path, source=source)
+ download_tftp(local_path, url.hostname, url.path, port, source, progressbar)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
-def upload(local_path, urlstring, source=None):
+def upload(local_path, urlstring, source=None, progressbar=False):
"""
Dispatch the appropriate upload function for the given URL and upload from local path.
Optionally takes a `source` address.
@@ -226,14 +294,15 @@ def upload(local_path, urlstring, source=None):
"""
url = urllib.parse.urlparse(urlstring)
username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
if url.scheme == 'ftp':
username = username if username else 'anonymous'
- upload_ftp(local_path, url.hostname, url.path, username, password, source=source)
+ upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- upload_sftp(local_path, url.hostname, url.path, username, password, source=source)
+ upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
- upload_tftp(local_path, url.hostname, url.path, source=source)
+ upload_tftp(local_path, url.hostname, url.path, port, source, progressbar)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
@@ -247,14 +316,15 @@ def get_remote_file_size(urlstring, source=None):
"""
url = urllib.parse.urlparse(urlstring)
username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
if url.scheme == 'http' or url.scheme == 'https':
return get_http_file_size(urlstring, username, password)
elif url.scheme == 'ftp':
username = username if username else 'anonymous'
- return get_ftp_file_size(url.hostname, url.path, username, password, source=source)
+ return get_ftp_file_size(url.hostname, url.path, username, password, port, source)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- return get_sftp_file_size(url.hostname, url.path, username, password, source=source)
+ return get_sftp_file_size(url.hostname, url.path, username, password, port, source)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
@@ -284,6 +354,8 @@ def get_remote_config(urlstring, source=None):
def friendly_download(local_path, urlstring, source=None):
"""
+ Download from `urlstring` to `local_path` in an informative way.
+ Checks the storage space before attempting download.
Intended to be called from interactive, user-facing scripts.
"""
destination_directory = os.path.dirname(local_path)
@@ -298,11 +370,12 @@ def friendly_download(local_path, urlstring, source=None):
if file_size > free_space:
raise OSError(f'Not enough disk space available in "{destination_directory}".')
except ValueError:
+ # Can't do a storage check in this case, so we bravely continue.
+ file_size = 0
print_error('Could not determine the file size in advance.')
else:
- # TODO: Progress bar
print_error('Downloading...')
- download(local_path, urlstring, source)
+ download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024)
except KeyboardInterrupt:
print_error('Download aborted by user.')
sys.exit(1)