summaryrefslogtreecommitdiff
path: root/python/vyos/remote.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/vyos/remote.py')
-rw-r--r--python/vyos/remote.py337
1 files changed, 268 insertions, 69 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index f683a6d5a..0bc2ee7f8 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -14,28 +14,33 @@
# License along with this library. If not, see <http://www.gnu.org/licenses/>.
from ftplib import FTP
+import math
import os
+import shutil
import socket
import sys
import tempfile
import urllib.parse
-import urllib.request
+import urllib.request as urlreq
from vyos.util import cmd, ask_yes_no
from vyos.version import get_version
from paramiko import SSHClient, SSHException, MissingHostKeyPolicy
-known_hosts_file = os.path.expanduser('~/.ssh/known_hosts')
+
+# This is a hardcoded path and no environment variable can change it.
+KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts')
+CHUNK_SIZE = 8192
class InteractivePolicy(MissingHostKeyPolicy):
"""
Policy for interactively querying the user on whether to proceed with
- SSH connections to unknown hosts.
+ SSH connections to unknown hosts.
"""
def missing_host_key(self, client, hostname, key):
- print(f"Host '{hostname}' not found in known hosts.")
- print('Fingerprint: ' + key.get_fingerprint().hex())
+ print_error(f"Host '{hostname}' not found in known hosts.")
+ print_error('Fingerprint: ' + key.get_fingerprint().hex())
if ask_yes_no('Do you wish to continue?'):
if client._host_keys_filename and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
client._host_keys.add(hostname, key.get_name(), key)
@@ -43,56 +48,154 @@ class InteractivePolicy(MissingHostKeyPolicy):
else:
raise SSHException(f"Cannot connect to unknown host '{hostname}'.")
+
+## Helper routines
+def print_error(str='', end='\n'):
+ """
+ Print `str` to stderr, terminated with `end`.
+ Used for warnings and out-of-band messages to avoid mangling precious
+ stdout output.
+ """
+ sys.stderr.write(str)
+ sys.stderr.write(end)
+ sys.stderr.flush()
+
+def make_progressbar():
+ """
+ Make a procedure that takes two arguments `done` and `total` and prints a
+ progressbar based on the ratio thereof, whose length is determined by the
+ width of the terminal.
+ """
+ col, _ = shutil.get_terminal_size()
+ col = max(col - 15, 20)
+ def print_progressbar(done, total):
+ if done <= total:
+ increment = total / col
+ length = math.ceil(done / increment)
+ percentage = str(math.ceil(100 * done / total)).rjust(3)
+ print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r')
+ # Print a newline so that the subsequent prints don't overwrite the full bar.
+ if done == total:
+ print_error()
+ return print_progressbar
+
+def make_incremental_progressbar(increment: float):
+ """
+ Make a generator that displays a progressbar that grows monotonically with
+ every iteration.
+ First call displays it at 0% and every subsequent iteration displays it
+ at `increment` increments where 0.0 < `increment` < 1.0.
+ Intended for FTP and HTTP transfers with stateless callbacks.
+ """
+ print_progressbar = make_progressbar()
+ total = 0.0
+ while total < 1.0:
+ print_progressbar(total, 1.0)
+ yield
+ total += increment
+ print_progressbar(1, 1)
+ # Ignore further calls.
+ while True:
+ yield
+
+def get_authentication_variables(default_username=None, default_password=None):
+ """
+ Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and
+ return the defaults provided if environment variables are empty or nonexistent.
+ """
+ username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD')
+ # Fall back to defaults if the username variable doesn't exist or is an empty string.
+ # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`,
+ # as we want the username and the password to have the same behaviour.
+ if not username:
+ return (default_username, default_password)
+ else:
+ return (username, password)
+
+def get_port_from_url(url):
+ """
+ Return the port number from the given `url` named tuple, fall back to
+ the default if there isn't one.
+ """
+ defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\
+ "ssh": 22, "scp": 22, "sftp": 22}
+ if url.port:
+ return url.port
+ else:
+ return defaults[url.scheme]
+
+
## FTP routines
-def transfer_ftp(mode, local_path, hostname, remote_path,\
- username='anonymous', password='', port=21, source=None):
+def upload_ftp(local_path, hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source=None, progressbar=False):
+ size = os.path.getsize(local_path)
with FTP(source_address=source) as conn:
conn.connect(hostname, port)
conn.login(username, password)
- if mode == 'upload':
- with open(local_path, 'rb') as file:
- conn.storbinary(f'STOR {remote_path}', file)
- elif mode == 'download':
- with open(local_path, 'wb') as file:
- conn.retrbinary(f'RETR {remote_path}', file.write)
- elif mode == 'size':
- size = conn.size(remote_path)
- if size:
- return size
+ with open(local_path, 'rb') as file:
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / size)
+ next(progress)
+ callback = lambda block: next(progress)
else:
- # SIZE is an extension to the FTP specification, although it's extremely common.
- raise ValueError('Failed to receive file size from FTP server. \
- Perhaps the server does not implement the SIZE command?')
+ callback = None
+ conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback)
-def upload_ftp(*args, **kwargs):
- transfer_ftp('upload', *args, **kwargs)
+def download_ftp(local_path, hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source=None, progressbar=False):
+ with FTP(source_address=source) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ size = conn.size(remote_path)
+ with open(local_path, 'wb') as file:
+ # No progressbar if we can't determine the size.
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / size)
+ next(progress)
+ callback = lambda block: (file.write(block), next(progress))
+ else:
+ callback = file.write
+ conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE)
-def download_ftp(*args, **kwargs):
- transfer_ftp('download', *args, **kwargs)
+def get_ftp_file_size(hostname, remote_path,\
+ username='anonymous', password='', port=21,\
+ source=None):
+ with FTP(source_address=source) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ size = conn.size(remote_path)
+ if size:
+ return size
+ else:
+ # SIZE is an extension to the FTP specification, although it's extremely common.
+ raise ValueError('Failed to receive file size from FTP server. \
+ Perhaps the server does not implement the SIZE command?')
-def get_ftp_file_size(*args, **kwargs):
- return transfer_ftp('size', None, *args, **kwargs)
## SFTP/SCP routines
def transfer_sftp(mode, local_path, hostname, remote_path,\
- username=None, password=None, port=22, source=None):
+ username=None, password=None, port=22,\
+ source=None, progressbar=False):
sock = None
if source:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((source, 0))
sock.connect((hostname, port))
+ callback = make_progressbar() if progressbar else None
try:
with SSHClient() as ssh:
ssh.load_system_host_keys()
- if os.path.exists(known_hosts_file):
- ssh.load_host_keys(known_hosts_file)
+ if os.path.exists(KNOWN_HOSTS_FILE):
+ ssh.load_host_keys(KNOWN_HOSTS_FILE)
ssh.set_missing_host_key_policy(InteractivePolicy())
ssh.connect(hostname, port, username, password, sock=sock)
with ssh.open_sftp() as sftp:
if mode == 'upload':
- sftp.put(local_path, remote_path)
+ sftp.put(local_path, remote_path, callback=callback)
elif mode == 'download':
- sftp.get(remote_path, local_path)
+ sftp.get(remote_path, local_path, callback=callback)
elif mode == 'size':
return sftp.stat(remote_path).st_size
finally:
@@ -109,32 +212,70 @@ def download_sftp(*args, **kwargs):
def get_sftp_file_size(*args, **kwargs):
return transfer_sftp('size', None, *args, **kwargs)
+
## TFTP routines
-def upload_tftp(local_path, hostname, remote_path, port=69, source=None):
+def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'rb') as file:
- cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\
+ cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\
stderr=None, input=file.read()).encode()
-def download_tftp(local_path, hostname, remote_path, port=69, source=None):
+def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
source_option = f'--interface {source}' if source else ''
+ # Not really applicable but we pass it for the sake of uniformity.
+ progress_flag = '--progress-bar' if progressbar else '-s'
with open(local_path, 'wb') as file:
- file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\
+ file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\
stderr=None).encode())
# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP,
-# as TFTP does not specify a SIZE command.
+# as TFTP does not specify a SIZE command.
+
## HTTP(S) routines
-def download_http(urlstring, local_path):
- request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
- with open(local_path, 'wb') as file:
- with urllib.request.urlopen(request) as response:
- file.write(response.read())
+def install_request_opener(urlstring, username, password):
+ """
+ Take `username` and `password` strings and install the appropriate
+ password manager to `urllib.request.urlopen()` for the given `urlstring`.
+ """
+ manager = urlreq.HTTPPasswordMgrWithDefaultRealm()
+ manager.add_password(None, urlstring, username, password)
+ urlreq.install_opener(urlreq.build_opener(urlreq.HTTPBasicAuthHandler(manager)))
+
+# upload_http() is unimplemented.
+
+def download_http(local_path, urlstring, username=None, password=None, progressbar=False):
+ """
+ Download the file from from `urlstring` to `local_path`.
+ Optionally takes `username` and `password` for authentication.
+ """
+ request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
+ if username:
+ install_request_opener(urlstring, username, password)
+ with open(local_path, 'wb') as file, urlreq.urlopen(request) as response:
+ size = response.getheader('Content-Length')
+ if progressbar and size:
+ progress = make_incremental_progressbar(CHUNK_SIZE / int(size))
+ next(progress)
+ for chunk in iter(lambda: response.read(CHUNK_SIZE), b''):
+ file.write(chunk)
+ next(progress)
+ next(progress)
+ # If we can't determine the size or if a progress bar wasn't requested,
+ # we can let `shutil` take care of the copying.
+ else:
+ shutil.copyfileobj(response, file)
-def get_http_file_size(urlstring):
- request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
- with urllib.request.urlopen(request) as response:
+def get_http_file_size(urlstring, username=None, password=None):
+ """
+ Return the size of the file from `urlstring` in terms of number of bytes.
+ Optionally takes `username` and `password` for authentication.
+ """
+ request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
+ if username:
+ install_request_opener(urlstring, username, password)
+ with urlreq.urlopen(request) as response:
size = response.getheader('Content-Length')
if size:
return int(size)
@@ -142,67 +283,87 @@ def get_http_file_size(urlstring):
else:
raise ValueError('Failed to receive file size from HTTP server.')
+
# Dynamic dispatchers
-def download(local_path, urlstring, source=None):
+def download(local_path, urlstring, source=None, progressbar=False):
"""
- Dispatch the appropriate download function for the given URL and save to local path.
+ Dispatch the appropriate download function for the given `urlstring` and save to `local_path`.
+ Optionally takes a `source` address (not valid for HTTP(S)).
+ Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP.
+ Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
url = urllib.parse.urlparse(urlstring)
+ username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
+
if url.scheme == 'http' or url.scheme == 'https':
if source:
- print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr)
- download_http(urlstring, local_path)
+ print_error('Warning: Custom source address not supported for HTTP connections.')
+ download_http(local_path, urlstring, username, password, progressbar)
elif url.scheme == 'ftp':
- username = url.username if url.username else 'anonymous'
- download_ftp(local_path, url.hostname, url.path, username, url.password, source=source)
+ username = username if username else 'anonymous'
+ download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- download_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source)
+ download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
- download_tftp(local_path, url.hostname, url.path, source=source)
+ download_tftp(local_path, url.hostname, url.path, port, source, progressbar)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
-def upload(local_path, urlstring, source=None):
+def upload(local_path, urlstring, source=None, progressbar=False):
"""
Dispatch the appropriate upload function for the given URL and upload from local path.
+ Optionally takes a `source` address.
+ Supports FTP, SFTP, SCP (through SFTP) and TFTP.
+ Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
url = urllib.parse.urlparse(urlstring)
+ username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
+
if url.scheme == 'ftp':
- username = url.username if url.username else 'anonymous'
- upload_ftp(local_path, url.hostname, url.path, username, url.password, source=source)
+ username = username if username else 'anonymous'
+ upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- upload_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source)
+ upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
- upload_tftp(local_path, url.hostname, url.path, source=source)
+ upload_tftp(local_path, url.hostname, url.path, port, source, progressbar)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
def get_remote_file_size(urlstring, source=None):
"""
- Return the size of the remote file in bytes.
+ Dispatch the appropriate function to return the size of the remote file from `urlstring`
+ in terms of number of bytes.
+ Optionally takes a `source` address (not valid for HTTP(S)).
+ Supports HTTP, HTTPS, FTP and SFTP (through SFTP).
+ Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
url = urllib.parse.urlparse(urlstring)
+ username, password = get_authentication_variables(url.username, url.password)
+ port = get_port_from_url(url)
+
if url.scheme == 'http' or url.scheme == 'https':
- return get_http_file_size(urlstring)
+ return get_http_file_size(urlstring, username, password)
elif url.scheme == 'ftp':
- username = url.username if url.username else 'anonymous'
- return get_ftp_file_size(url.hostname, url.path, username, url.password, source=source)
+ username = username if username else 'anonymous'
+ return get_ftp_file_size(url.hostname, url.path, username, password, port, source)
elif url.scheme == 'sftp' or url.scheme == 'scp':
- return get_sftp_file_size(url.hostname, url.path, url.username, url.password, source=source)
+ return get_sftp_file_size(url.hostname, url.path, username, password, port, source)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')
def get_remote_config(urlstring, source=None):
"""
- Download remote (config) file and return the contents.
+ Download remote (config) file from `urlstring` and return the contents as a string.
Args:
remote file URI:
- scp://<user>[:<passwd>]@<host>/<file>
- sftp://<user>[:<passwd>]@<host>/<file>
- http://<host>/<file>
- https://<host>/<file>
- ftp://[<user>[:<passwd>]@]<host>/<file>
- tftp://<host>/<file>
+ tftp://<host>[:<port>]/<file>
+ http[s]://<host>[:<port>]/<file>
+ [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file>
+ source address (optional):
+ <interface>
+ <IP address>
"""
url = urllib.parse.urlparse(urlstring)
temp = tempfile.NamedTemporaryFile(delete=False).name
@@ -212,3 +373,41 @@ def get_remote_config(urlstring, source=None):
return file.read()
finally:
os.remove(temp)
+
+def friendly_download(local_path, urlstring, source=None):
+ """
+ Download from `urlstring` to `local_path` in an informative way.
+ Checks the storage space before attempting download.
+ Intended to be called from interactive, user-facing scripts.
+ """
+ destination_directory = os.path.dirname(local_path)
+ try:
+ free_space = shutil.disk_usage(destination_directory).free
+ try:
+ file_size = get_remote_file_size(urlstring, source)
+ if file_size < 1024 * 1024:
+ print_error(f'The file is {file_size / 1024.0:.3f} KiB.')
+ else:
+ print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.')
+ if file_size > free_space:
+ raise OSError(f'Not enough disk space available in "{destination_directory}".')
+ except ValueError:
+ # Can't do a storage check in this case, so we bravely continue.
+ file_size = 0
+ print_error('Could not determine the file size in advance.')
+ else:
+ print_error('Downloading...')
+ download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024)
+ except KeyboardInterrupt:
+ print_error('Download aborted by user.')
+ sys.exit(1)
+ except:
+ import traceback
+ # There are a myriad different reasons a download could fail.
+ # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...)
+ # We omit the scary stack trace but print the error nevertheless.
+ print_error(f'Failed to download {urlstring}.')
+ traceback.print_exception(*sys.exc_info()[:2], None)
+ sys.exit(1)
+ else:
+ print_error('Download complete.')