# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library. If not, see <http://www.gnu.org/licenses/>.

from ftplib import FTP
import os
import shutil
import socket
import stat
import sys
import tempfile
import urllib.parse
import urllib.request as urlreq

from vyos.template import get_ip
from vyos.template import ip_from_cidr
from vyos.template import is_interface
from vyos.template import is_ipv6
from vyos.util import cmd
from vyos.util import ask_yes_no
from vyos.util import print_error
from vyos.util import make_progressbar
from vyos.util import make_incremental_progressbar
from vyos.version import get_version
from paramiko import SSHClient
from paramiko import SSHException
from paramiko import MissingHostKeyPolicy

# This is a hardcoded path and no environment variable can change it.
KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts')
# Block size (in bytes) used for chunked FTP and HTTP transfers.
CHUNK_SIZE = 8192


class InteractivePolicy(MissingHostKeyPolicy):
    """
    Policy for interactively querying the user on whether to proceed with
    SSH connections to unknown hosts.
    """
    def missing_host_key(self, client, hostname, key):
        """
        Show the unknown host and its key fingerprint, then ask the user
        whether to continue. Raises `SSHException` if the user declines.
        If the client has a host keys file loaded, additionally offer to
        store the host/key pair in it.
        """
        print_error(f"Host '{hostname}' not found in known hosts.")
        print_error('Fingerprint: ' + key.get_fingerprint().hex())
        if not ask_yes_no('Do you wish to continue?'):
            raise SSHException(f"Cannot connect to unknown host '{hostname}'.")
        # Only offer to persist the key when there is a file to persist it to.
        if client._host_keys_filename\
           and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
            client._host_keys.add(hostname, key.get_name(), key)
            client.save_host_keys(client._host_keys_filename)


## Helper routines
def get_authentication_variables(default_username=None, default_password=None):
    """
    Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and
    return the defaults provided if environment variables are empty or nonexistent.
    """
    env_username = os.getenv('REMOTE_USERNAME')
    env_password = os.getenv('REMOTE_PASSWORD')
    # Fall back to defaults if the username variable doesn't exist or is an empty string.
    # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`,
    # as we want the username and the password to have the same behaviour:
    # the pair is taken either entirely from the environment or entirely from the defaults.
    if env_username:
        return env_username, env_password
    return default_username, default_password


def get_source_address(source):
    """
    Take a string vaguely indicating an origin source (interface, hostname or IP address),
    return a tuple in the format `(source_pair, address_family)` where
    `source_pair` is `(source_address, source_port)`.
    """
    # TODO: Properly distinguish between IPv4 and IPv6.
    # Port 0 is passed through unchanged as the source port.
    port = 0
    if is_interface(source):
        # Use the first address reported for the interface, without its prefix length.
        source = ip_from_cidr(get_ip(source)[0])
    if is_ipv6(source):
        return (source, port), socket.AF_INET6
    # Anything else is resolved as an IPv4 host name or address.
    return (socket.gethostbyname(source), port), socket.AF_INET


def get_port_from_url(url):
    """
    Return the port number from the given `url` named tuple, fall back to the
    default if there isn't one.

    Raises `KeyError` if the URL carries no port and its scheme is not one of
    the schemes listed in the defaults table below.
    """
    defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,
                "ssh": 22, "scp": 22, "sftp": 22}
    return url.port if url.port else defaults[url.scheme]


## FTP routines
def upload_ftp(local_path, hostname, remote_path,\
               username='anonymous', password='', port=21,\
               source_pair=None, progressbar=False):
    """
    Upload the file at `local_path` to `remote_path` on the FTP server
    `hostname`. `source_pair` is an optional `(address, port)` tuple to bind
    the connection to. A progress bar is only drawn if `progressbar` is set
    and the local file is not empty.
    """
    size = os.path.getsize(local_path)
    with FTP(source_address=source_pair) as conn:
        conn.connect(hostname, port)
        conn.login(username, password)
        with open(local_path, 'rb') as file:
            callback = None
            if progressbar and size:
                # Each transferred block advances the bar by the fraction
                # of the file that one chunk represents.
                progress = make_incremental_progressbar(CHUNK_SIZE / size)
                next(progress)

                def callback(block):
                    next(progress)
            conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback)


def download_ftp(local_path, hostname, remote_path,\
                 username='anonymous', password='', port=21,\
                 source_pair=None, progressbar=False):
    """
    Download `remote_path` from the FTP server `hostname` to `local_path`.
    `source_pair` is an optional `(address, port)` tuple to bind the
    connection to. A progress bar is only drawn if `progressbar` is set and
    the server reported a non-zero file size.
    """
    with FTP(source_address=source_pair) as conn:
        conn.connect(hostname, port)
        conn.login(username, password)
        size = conn.size(remote_path)
        with open(local_path, 'wb') as file:
            # No progressbar if we can't determine the size.
            if progressbar and size:
                progress = make_incremental_progressbar(CHUNK_SIZE / size)
                next(progress)

                def callback(block):
                    file.write(block)
                    next(progress)
            else:
                callback = file.write
            conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE)


def get_ftp_file_size(hostname, remote_path,\
                      username='anonymous', password='', port=21,\
                      source_pair=None):
    """
    Return the size in bytes of `remote_path` on the FTP server `hostname`.
    `source_pair` is an optional `(address, port)` tuple to bind the
    connection to.

    Raises `ValueError` if the server did not report a size.
    """
    # Bind to the `source_pair` parameter (the previous code referenced an
    # undefined name here and failed with NameError on every call).
    with FTP(source_address=source_pair) as conn:
        conn.connect(hostname, port)
        conn.login(username, password)
        size = conn.size(remote_path)
    # Compare against None explicitly so that an empty (0 byte) file is
    # reported as size 0 instead of being treated as a failure.
    if size is None:
        # SIZE is an extension to the FTP specification, although it's extremely common.
        raise ValueError('Failed to receive file size from FTP server. '
                         'Perhaps the server does not implement the SIZE command?')
    return size


## SFTP/SCP routines
def transfer_sftp(mode, local_path, hostname, remote_path,\
                  username=None, password=None, port=22,\
                  source_tuple=None, progressbar=False):
    """
    Perform an SFTP operation against `hostname`, selected by `mode`:

      * 'upload':   put `local_path` at `remote_path` (or inside it, if
                    `remote_path` is an existing directory).
      * 'download': fetch `remote_path` into `local_path`.
      * 'size':     return the size of `remote_path` in bytes.

    `source_tuple`, if given, has the format returned by
    `get_source_address()`, i.e. `((address, port), address_family)`, and
    is used to bind the outgoing connection. Any other `mode` does nothing
    and returns `None`.
    """
    sock = None
    try:
        if source_tuple:
            # Paramiko can't bind to a source address by itself, so we hand
            # it an already connected socket.
            (source_address, source_port), address_family = source_tuple
            sock = socket.socket(address_family, socket.SOCK_STREAM)
            sock.bind((source_address, source_port))
            sock.connect((hostname, port))
        callback = make_progressbar() if progressbar else None
        with SSHClient() as ssh:
            ssh.load_system_host_keys()
            if os.path.exists(KNOWN_HOSTS_FILE):
                ssh.load_host_keys(KNOWN_HOSTS_FILE)
            ssh.set_missing_host_key_policy(InteractivePolicy())
            ssh.connect(hostname, port, username, password, sock=sock)
            with ssh.open_sftp() as sftp:
                if mode == 'upload':
                    try:
                        # If the remote path is a directory, use the original filename.
                        if stat.S_ISDIR(sftp.stat(remote_path).st_mode):
                            path = os.path.join(remote_path, os.path.basename(local_path))
                        # A file exists at this destination. We're simply going to clobber it.
                        else:
                            path = remote_path
                    # This path doesn't point at any existing file. We can freely use this filename.
                    except IOError:
                        path = remote_path
                    # The upload is deliberately outside the `try` statement: any
                    # other error from `stat()` propagates as-is instead of being
                    # masked by an unbound `path`.
                    sftp.put(local_path, path, callback=callback)
                elif mode == 'download':
                    sftp.get(remote_path, local_path, callback=callback)
                elif mode == 'size':
                    return sftp.stat(remote_path).st_size
    finally:
        # Make sure the manually created socket never outlives the transfer,
        # even if binding, connecting or the SSH handshake failed.
        if sock is not None:
            sock.close()


def upload_sftp(*args, **kwargs):
    """
    Upload over SFTP; takes the arguments of `transfer_sftp()` minus `mode`.
    """
    transfer_sftp('upload', *args, **kwargs)


def download_sftp(*args, **kwargs):
    """
    Download over SFTP; takes the arguments of `transfer_sftp()` minus `mode`.
    """
    transfer_sftp('download', *args, **kwargs)


def get_sftp_file_size(*args, **kwargs):
    """
    Return the size of a remote file over SFTP. Takes the arguments of
    `transfer_sftp()` minus `mode` and `local_path`.
    """
    # There is no local file involved, hence `None` for `local_path`.
    return transfer_sftp('size', None, *args, **kwargs)


## TFTP routines
def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
    """
    Upload the file at `local_path` to `remote_path` on the TFTP server
    `hostname` by piping its contents into `curl`. `source` is passed to
    curl's `--interface` option.
    """
    # NOTE(review): `hostname`, `remote_path` and `source` are interpolated into
    # the command string unquoted - confirm how `vyos.util.cmd` tokenises it.
    source_option = f'--interface {source}' if source else ''
    progress_flag = '--progress-bar' if progressbar else '-s'
    with open(local_path, 'rb') as file:
        cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\
            stderr=None, input=file.read()).encode()


def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False):
    """
    Download `remote_path` from the TFTP server `hostname` to `local_path`
    by capturing the output of `curl`. `source` is passed to curl's
    `--interface` option.
    """
    # NOTE(review): `hostname`, `remote_path` and `source` are interpolated into
    # the command string unquoted - confirm how `vyos.util.cmd` tokenises it.
    source_option = f'--interface {source}' if source else ''
    # Not really applicable but we pass it for the sake of uniformity.
    progress_flag = '--progress-bar' if progressbar else '-s'
    with open(local_path, 'wb') as file:
        file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\
                       stderr=None).encode())

# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP,
# as TFTP does not specify a SIZE command.


## HTTP(S) routines
def install_request_opener(urlstring, username, password):
    """
    Take `username` and `password` strings and install the appropriate
    password manager to `urllib.request.urlopen()` for the given `urlstring`.
    """
    manager = urlreq.HTTPPasswordMgrWithDefaultRealm()
    # `None` as the realm registers the credentials for any realm at this URL.
    manager.add_password(None, urlstring, username, password)
    auth_handler = urlreq.HTTPBasicAuthHandler(manager)
    urlreq.install_opener(urlreq.build_opener(auth_handler))

# upload_http() is unimplemented.
def download_http(local_path, urlstring, username=None, password=None, progressbar=False):
    """
    Download the file from `urlstring` to `local_path`.
    Optionally takes `username` and `password` for authentication.
    A progress bar is only drawn if `progressbar` is set and the server
    reported a non-zero `Content-Length`.
    """
    request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
    if username:
        install_request_opener(urlstring, username, password)
    with open(local_path, 'wb') as file, urlreq.urlopen(request) as response:
        content_length = response.getheader('Content-Length')
        # The header is a string, so '0' would be truthy; convert it first to
        # avoid dividing by zero for an empty file.
        total = int(content_length) if content_length else 0
        if progressbar and total:
            progress = make_incremental_progressbar(CHUNK_SIZE / total)
            next(progress)
            for chunk in iter(lambda: response.read(CHUNK_SIZE), b''):
                file.write(chunk)
                next(progress)
            next(progress)
        # If we can't determine the size or if a progress bar wasn't requested,
        # we can let `shutil` take care of the copying.
        else:
            shutil.copyfileobj(response, file)


def get_http_file_size(urlstring, username=None, password=None):
    """
    Return the size of the file from `urlstring` in terms of number of bytes.
    Optionally takes `username` and `password` for authentication.

    Raises `ValueError` if the response carries no `Content-Length` header.
    """
    request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
    if username:
        install_request_opener(urlstring, username, password)
    with urlreq.urlopen(request) as response:
        size = response.getheader('Content-Length')
    if not size:
        # The server didn't send 'Content-Length' in the response headers.
        raise ValueError('Failed to receive file size from HTTP server.')
    return int(size)


## Dynamic dispatchers
def download(local_path, urlstring, source=None, progressbar=False):
    """
    Dispatch the appropriate download function for the given `urlstring` and save to `local_path`.
    Optionally takes a `source` address or interface (not valid for HTTP(S)).
    Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP.
    Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.

    Raises `ValueError` for any other URL scheme.
    """
    url = urllib.parse.urlparse(urlstring)
    username, password = get_authentication_variables(url.username, url.password)
    port = get_port_from_url(url)

    if url.scheme in ('http', 'https'):
        if source:
            print_error('Warning: Custom source address not supported for HTTP connections.')
        download_http(local_path, urlstring, username, password, progressbar)
    elif url.scheme == 'ftp':
        # FTP only needs the `(address, port)` pair, not the address family.
        source = get_source_address(source)[0] if source else None
        username = username if username else 'anonymous'
        download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
    elif url.scheme in ('sftp', 'scp'):
        source = get_source_address(source) if source else None
        download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
    elif url.scheme == 'tftp':
        download_tftp(local_path, url.hostname, url.path, port, source, progressbar)
    else:
        raise ValueError(f'Unsupported URL scheme: {url.scheme}')


def upload(local_path, urlstring, source=None, progressbar=False):
    """
    Dispatch the appropriate upload function for the given URL and upload from local path.
    Optionally takes a `source` address.
    Supports FTP, SFTP, SCP (through SFTP) and TFTP.
    Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.

    Raises `ValueError` for any other URL scheme.
    """
    url = urllib.parse.urlparse(urlstring)
    username, password = get_authentication_variables(url.username, url.password)
    port = get_port_from_url(url)

    if url.scheme == 'ftp':
        username = username if username else 'anonymous'
        # FTP only needs the `(address, port)` pair, not the address family.
        source = get_source_address(source)[0] if source else None
        upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
    elif url.scheme in ('sftp', 'scp'):
        source = get_source_address(source) if source else None
        upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
    elif url.scheme == 'tftp':
        upload_tftp(local_path, url.hostname, url.path, port, source, progressbar)
    else:
        raise ValueError(f'Unsupported URL scheme: {url.scheme}')


def get_remote_file_size(urlstring, source=None):
    """
    Dispatch the appropriate function to return the size of the remote file from `urlstring`
    in terms of number of bytes.
    Optionally takes a `source` address (not valid for HTTP(S)).
    Supports HTTP, HTTPS, FTP, SFTP and SCP (through SFTP).
    Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.

    Raises `ValueError` for any other URL scheme.
    """
    url = urllib.parse.urlparse(urlstring)
    username, password = get_authentication_variables(url.username, url.password)
    port = get_port_from_url(url)

    if url.scheme in ('http', 'https'):
        if source:
            print_error('Warning: Custom source address not supported for HTTP connections.')
        return get_http_file_size(urlstring, username, password)
    elif url.scheme == 'ftp':
        # FTP only needs the `(address, port)` pair, not the address family.
        source = get_source_address(source)[0] if source else None
        username = username if username else 'anonymous'
        return get_ftp_file_size(url.hostname, url.path, username, password, port, source)
    elif url.scheme in ('sftp', 'scp'):
        source = get_source_address(source) if source else None
        return get_sftp_file_size(url.hostname, url.path, username, password, port, source)
    else:
        raise ValueError(f'Unsupported URL scheme: {url.scheme}')


def get_remote_config(urlstring, source=None):
    """
    Download remote (config) file from `urlstring` and return the contents as a string.
    Args:
        remote file URI:
            tftp://<host>[:<port>]/<file>
            http[s]://<host>[:<port>]/<file>
            [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file>
        source address (optional):
            <interface>
            <IP address>
    """
    # We only need a unique path; close the handle right away instead of
    # leaving that to the garbage collector.
    with tempfile.NamedTemporaryFile(delete=False) as placeholder:
        temp = placeholder.name
    try:
        download(temp, urlstring, source)
        with open(temp, 'r') as file:
            return file.read()
    finally:
        # The temporary file was created with `delete=False`, so clean it up ourselves.
        os.remove(temp)


def friendly_download(local_path, urlstring, source=None):
    """
    Download from `urlstring` to `local_path` in an informative way.
    Checks the storage space before attempting download.
    Intended to be called from interactive, user-facing scripts.

    Progress and errors are reported through `print_error()`; on failure or
    user interruption the process exits with status 1.
    """
    # `dirname()` yields an empty string for a bare file name, which is not a
    # valid path for `shutil.disk_usage()`; fall back to the current directory.
    destination_directory = os.path.dirname(local_path) or '.'
    try:
        free_space = shutil.disk_usage(destination_directory).free
        try:
            file_size = get_remote_file_size(urlstring, source)
            if file_size < 1024 * 1024:
                print_error(f'The file is {file_size / 1024.0:.3f} KiB.')
            else:
                print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.')
            if file_size > free_space:
                raise OSError(f'Not enough disk space available in "{destination_directory}".')
        except ValueError:
            # Can't do a storage check in this case, so we bravely continue.
            file_size = 0
            print_error('Could not determine the file size in advance.')
        else:
            print_error('Downloading...')
        # Only bother with a progress bar for files larger than 1 MiB.
        download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024)
    except KeyboardInterrupt:
        print_error('Download aborted by user.')
        sys.exit(1)
    except Exception:
        import traceback
        # There are a myriad different reasons a download could fail.
        # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...)
        # We omit the scary stack trace but print the error nevertheless.
        print_error(f'Failed to download {urlstring}.')
        traceback.print_exception(*sys.exc_info()[:2], None)
        sys.exit(1)
    else:
        print_error('Download complete.')