diff options
Diffstat (limited to 'python/vyos/remote.py')
-rw-r--r-- | python/vyos/remote.py | 572 |
1 files changed, 254 insertions, 318 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index e972050b7..2419f8873 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -13,38 +13,40 @@ # You should have received a copy of the GNU Lesser General Public # License along with this library. If not, see <http://www.gnu.org/licenses/>. -from ftplib import FTP import os import shutil import socket +import ssl import stat import sys import tempfile import urllib.parse -import urllib.request as urlreq -from vyos.template import get_ip -from vyos.template import ip_from_cidr -from vyos.template import is_interface -from vyos.template import is_ipv6 -from vyos.util import cmd +from ftplib import FTP +from ftplib import FTP_TLS + +from paramiko import SSHClient +from paramiko import MissingHostKeyPolicy + +from requests import Session +from requests.adapters import HTTPAdapter +from requests.packages.urllib3 import PoolManager + from vyos.util import ask_yes_no -from vyos.util import print_error -from vyos.util import make_progressbar +from vyos.util import begin +from vyos.util import cmd from vyos.util import make_incremental_progressbar +from vyos.util import make_progressbar +from vyos.util import print_error from vyos.version import get_version -from paramiko import SSHClient -from paramiko import SSHException -from paramiko import MissingHostKeyPolicy -# This is a hardcoded path and no environment variable can change it. -KNOWN_HOSTS_FILE = os.path.expanduser('~/.ssh/known_hosts') + CHUNK_SIZE = 8192 class InteractivePolicy(MissingHostKeyPolicy): """ - Policy for interactively querying the user on whether to proceed with - SSH connections to unknown hosts. + Paramiko policy for interactively querying the user on whether to proceed + with SSH connections to unknown hosts. """ def missing_host_key(self, client, hostname, key): print_error(f"Host '{hostname}' not found in known hosts.") @@ -57,337 +59,270 @@ class InteractivePolicy(MissingHostKeyPolicy): else: raise SSHException(f"Cannot connect to unknown host '{hostname}'.") - -## Helper routines -def get_authentication_variables(default_username=None, default_password=None): +class SourceAdapter(HTTPAdapter): """ - Return the environment variables `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` and - return the defaults provided if environment variables are empty or nonexistent. + urllib3 transport adapter for setting source addresses per session. """ - username, password = os.getenv('REMOTE_USERNAME'), os.getenv('REMOTE_PASSWORD') - # Fall back to defaults if the username variable doesn't exist or is an empty string. - # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`, - # as we want the username and the password to have the same behaviour. - if not username: - return default_username, default_password - else: - return username, password - -def get_source_address(source): + def __init__(self, source_pair, *args, **kwargs): + # A source pair is a tuple of a source host string and source port respectively. + # Supply '' and 0 respectively for default values. + self._source_pair = source_pair + super(SourceAdapter, self).__init__(*args, **kwargs) + + def init_poolmanager(self, connections, maxsize, block=False): + self.poolmanager = PoolManager( + num_pools=connections, maxsize=maxsize, + block=block, source_address=self._source_pair) + +class WrappedFile: + def __init__(self, obj, size=None, chunk_size=CHUNK_SIZE): + self._obj = obj + self._progress = size and make_incremental_progressbar(chunk_size / size) + def read(self, size=-1): + if self._progress: + next(self._progress) + self._obj.read(size) + def write(self, size=-1): + if self._progress: + next(self._progress) + self._obj.write(size) + def __getattr__(self, attr): + return getattr(self._obj, attr) + +def check_storage(path, size): """ - Take a string vaguely indicating an origin source (interface, hostname or IP address), - return a tuple in the format `(source_pair, address_family)` where - `source_pair` is `(source_address, source_port)`. + Check whether `path` has enough storage space for a transfer of `size` bytes. """ - # TODO: Properly distinguish between IPv4 and IPv6. - port = 0 - if is_interface(source): - source = ip_from_cidr(get_ip(source)[0]) - if is_ipv6(source): - return (source, port), socket.AF_INET6 + path = os.path.abspath(os.path.expanduser(path)) + directory = path if os.path.isdir(path) else (os.path.dirname(os.path.expanduser(path)) or os.getcwd()) + # `size` can be None or 0 to indicate unknown size. + if not size: + print_error('Warning: Cannot determine size of remote file.') + print_error('Bravely continuing regardless.') + return + + if size < 1024 * 1024: + print_error(f'The file is {size / 1024.0:.3f} KiB.') else: - return (socket.gethostbyname(source), port), socket.AF_INET - -def get_port_from_url(url): - """ - Return the port number from the given `url` named tuple, fall back to - the default if there isn't one. - """ - defaults = {"http": 80, "https": 443, "ftp": 21, "tftp": 69,\ - "ssh": 22, "scp": 22, "sftp": 22} - if url.port: - return url.port - else: - return defaults[url.scheme] - - -## FTP routines -def upload_ftp(local_path, hostname, remote_path,\ - username='anonymous', password='', port=21,\ - source_pair=None, progressbar=False): - size = os.path.getsize(local_path) - with FTP(source_address=source_pair) as conn: - conn.connect(hostname, port) - conn.login(username, password) - with open(local_path, 'rb') as file: - if progressbar and size: + print_error(f'The file is {size / (1024.0 * 1024.0):.3f} MiB.') + + # Will throw `FileNotFoundError' if `directory' is absent. + if size > shutil.disk_usage(directory).free: + raise OSError(f'Not enough disk space available in "{directory}".') + + +class FtpC: + def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + self.secure = url.scheme == 'ftps' + self.hostname = url.hostname + self.path = url.path + self.username = url.username or os.getenv('REMOTE_USERNAME', 'anonymous') + self.password = url.password or os.getenv('REMOTE_PASSWORD', '') + self.port = url.port or 21 + self.source = (source_host, source_port) + self.progressbar = progressbar + self.check_space = check_space + + def _establish(self): + if self.secure: + return FTP_TLS(source_address=self.source, context=ssl.create_default_context()) + else: + return FTP(source_address=self.source) + + def download(self, location: str): + # Open the file upfront before establishing connection. + with open(location, 'wb') as f, self._establish() as conn: + conn.connect(self.hostname, self.port) + conn.login(self.username, self.password) + # Set secure connection over TLS. + if self.secure: + conn.prot_p() + # Almost all FTP servers support the `SIZE' command. + if self.check_space: + check_storage(path, conn.size(self.path)) + # No progressbar if we can't determine the size or if the file is too small. + if self.progressbar and size and size > CHUNK_SIZE: progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) - callback = lambda block: next(progress) + callback = lambda block: begin(f.write(block), next(progress)) else: - callback = None - conn.storbinary(f'STOR {remote_path}', file, CHUNK_SIZE, callback) - -def download_ftp(local_path, hostname, remote_path,\ - username='anonymous', password='', port=21,\ - source_pair=None, progressbar=False): - with FTP(source_address=source_pair) as conn: - conn.connect(hostname, port) - conn.login(username, password) - size = conn.size(remote_path) - with open(local_path, 'wb') as file: - # No progressbar if we can't determine the size. - if progressbar and size: + callback = f.write + conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) + + def upload(self, location: str): + size = os.path.getsize(location) + with open(location, 'rb') as f, self._establish() as conn: + conn.connect(self.hostname, self.port) + conn.login(self.username, self.password) + if self.secure: + conn.prot_p() + if self.progressbar and size and size > CHUNK_SIZE: progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) - callback = lambda block: (file.write(block), next(progress)) + callback = lambda block: next(progress) else: - callback = file.write - conn.retrbinary(f'RETR {remote_path}', callback, CHUNK_SIZE) - -def get_ftp_file_size(hostname, remote_path,\ - username='anonymous', password='', port=21,\ - source_pair=None): - with FTP(source_address=source) as conn: - conn.connect(hostname, port) - conn.login(username, password) - size = conn.size(remote_path) - if size: - return size - else: - # SIZE is an extension to the FTP specification, although it's extremely common. - raise ValueError('Failed to receive file size from FTP server. \ - Perhaps the server does not implement the SIZE command?') - - -## SFTP/SCP routines -def transfer_sftp(mode, local_path, hostname, remote_path,\ - username=None, password=None, port=22,\ - source_tuple=None, progressbar=False): - sock = None - if source_tuple: - (source_address, source_port), address_family = source_tuple - sock = socket.socket(address_family, socket.SOCK_STREAM) - sock.bind((source_address, source_port)) - sock.connect((hostname, port)) - callback = make_progressbar() if progressbar else None - with SSHClient() as ssh: + callback = None + conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, callback) + +class SshC: + known_hosts = os.path.expanduser('~/.ssh/known_hosts') + def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + self.hostname = url.hostname + self.path = url.path + self.username = url.username or os.getenv('REMOTE_USERNAME') + self.password = url.password or os.getenv('REMOTE_PASSWORD') + self.port = url.port or 22 + self.source = (source_host, source_port) + self.progressbar = progressbar + self.check_space = check_space + + def _establish(self): + ssh = SSHClient() ssh.load_system_host_keys() - if os.path.exists(KNOWN_HOSTS_FILE): - ssh.load_host_keys(KNOWN_HOSTS_FILE) + # Try to load from a user-local known hosts file if one exists. + if os.path.exists(self.known_hosts): + ssh.load_host_keys(self.known_hosts) ssh.set_missing_host_key_policy(InteractivePolicy()) - ssh.connect(hostname, port, username, password, sock=sock) - with ssh.open_sftp() as sftp: - if mode == 'upload': + # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family + # for us on dual-stack systems. + sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source) + ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock) + return ssh + + def download(self, location: str): + callback = make_progressbar() if self.progressbar else None + with self._establish() as ssh, ssh.open_sftp() as sftp: + if self.check_space: + check_storage(location, sftp.stat(self.path).st_size) + sftp.get(self.path, location, callback=callback) + + def upload(self, location: str): + callback = make_progressbar() if self.progressbar else None + with self._establish() as ssh, ssh.open_sftp() as sftp: + try: + # If the remote path is a directory, use the original filename. + if stat.S_ISDIR(sftp.stat(self.path).st_mode): + path = os.path.join(self.path, os.path.basename(location)) + # A file exists at this destination. We're simply going to clobber it. + else: + path = self.path + # This path doesn't point at any existing file. We can freely use this filename. + except IOError: + path = self.path + finally: + sftp.put(location, path, callback=callback) + + +class HttpC: + def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + self.urlstring = urllib.parse.urlunsplit(url) + self.progressbar = progressbar + self.check_space = check_space + self.source_pair = (source_host, source_port) + self.username = url.username or os.getenv('REMOTE_USERNAME') + self.password = url.password or os.getenv('REMOTE_PASSWORD') + + def _establish(self): + session = Session() + session.mount(self.urlstring, SourceAdapter(self.source_pair)) + session.headers.update({'User-Agent': 'VyOS/' + get_version()}) + if self.username: + session.auth = self.username, self.password + return session + + def download(self, location: str): + with self._establish() as s: + # We ask for uncompressed downloads so that we don't have to deal with decoding. + # Not only would it potentially mess up with the progress bar but + # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding. + s.headers.update({'Accept-Encoding': 'identity'}) + with s.head(self.urlstring) as r: + # Abort early if the destination is inaccessible. + r.raise_for_status() + # Check for the prospective file size. try: - # If the remote path is a directory, use the original filename. - if stat.S_ISDIR(sftp.stat(remote_path).st_mode): - path = os.path.join(remote_path, os.path.basename(local_path)) - # A file exists at this destination. We're simply going to clobber it. - else: - path = remote_path - # This path doesn't point at any existing file. We can freely use this filename. - except IOError: - path = remote_path - finally: - sftp.put(local_path, path, callback=callback) - elif mode == 'download': - sftp.get(remote_path, local_path, callback=callback) - elif mode == 'size': - return sftp.stat(remote_path).st_size - -def upload_sftp(*args, **kwargs): - transfer_sftp('upload', *args, **kwargs) - -def download_sftp(*args, **kwargs): - transfer_sftp('download', *args, **kwargs) - -def get_sftp_file_size(*args, **kwargs): - return transfer_sftp('size', None, *args, **kwargs) - - -## TFTP routines -def upload_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): - source_option = f'--interface {source}' if source else '' - progress_flag = '--progress-bar' if progressbar else '-s' - with open(local_path, 'rb') as file: - cmd(f'curl {source_option} {progress_flag} -T - tftp://{hostname}:{port}/{remote_path}',\ - stderr=None, input=file.read()).encode() - -def download_tftp(local_path, hostname, remote_path, port=69, source=None, progressbar=False): - source_option = f'--interface {source}' if source else '' - # Not really applicable but we pass it for the sake of uniformity. - progress_flag = '--progress-bar' if progressbar else '-s' - with open(local_path, 'wb') as file: - file.write(cmd(f'curl {source_option} {progress_flag} tftp://{hostname}:{port}/{remote_path}',\ - stderr=None).encode()) - -# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP, -# as TFTP does not specify a SIZE command. - - -## HTTP(S) routines -def install_request_opener(urlstring, username, password): - """ - Take `username` and `password` strings and install the appropriate - password manager to `urllib.request.urlopen()` for the given `urlstring`. - """ - manager = urlreq.HTTPPasswordMgrWithDefaultRealm() - manager.add_password(None, urlstring, username, password) - urlreq.install_opener(urlreq.build_opener(urlreq.HTTPBasicAuthHandler(manager))) - -# upload_http() is unimplemented. - -def download_http(local_path, urlstring, username=None, password=None, progressbar=False): - """ - Download the file from from `urlstring` to `local_path`. - Optionally takes `username` and `password` for authentication. - """ - request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) - if username: - install_request_opener(urlstring, username, password) - with open(local_path, 'wb') as file, urlreq.urlopen(request) as response: - size = response.getheader('Content-Length') - if progressbar and size: - progress = make_incremental_progressbar(CHUNK_SIZE / int(size)) - next(progress) - for chunk in iter(lambda: response.read(CHUNK_SIZE), b''): - file.write(chunk) - next(progress) - next(progress) - # If we can't determine the size or if a progress bar wasn't requested, - # we can let `shutil` take care of the copying. - else: - shutil.copyfileobj(response, file) - -def get_http_file_size(urlstring, username=None, password=None): - """ - Return the size of the file from `urlstring` in terms of number of bytes. - Optionally takes `username` and `password` for authentication. - """ - request = urlreq.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()}) - if username: - install_request_opener(urlstring, username, password) - with urlreq.urlopen(request) as response: - size = response.getheader('Content-Length') - if size: - return int(size) - # The server didn't send 'Content-Length' in the response headers. - else: - raise ValueError('Failed to receive file size from HTTP server.') - - -## Dynamic dispatchers -def download(local_path, urlstring, source=None, progressbar=False): + size = int(r.headers['Content-Length']) + # In case the server does not supply the header. + except KeyError: + size = None + if self.check_space: + check_storage(location, size) + with s.get(self.urlstring, stream=True) as r, open(location, 'wb') as f: + if self.progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / size) + next(progress) + for chunk in iter(lambda: begin(next(progress), r.raw.read(CHUNK_SIZE)), b''): + f.write(chunk) + else: + # We'll try to stream the download directly with `copyfileobj()` so that large + # files (like entire VyOS images) don't occupy much memory. + shutil.copyfileobj(r.raw, f) + + def upload(self, location: str): + size = os.path.getsize(location) if self.progressbar else None + # Keep in mind that `data` can be a file-like or iterable object. + with self._establish() as s, file(location, 'rb') as f: + s.post(self.urlstring, data=WrappedFile(f, size)) + + +class TftpC: + # We simply allow `curl` to take over because + # 1. TFTP is rather simple. + # 2. Since there's no concept authentication, we don't need to deal with keys/passwords. + # 3. It would be a waste to import, audit and maintain a third-party library for TFTP. + # 4. I'd rather not implement the entire protocol here, no matter how simple it is. + def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0): + source_option = f'--interface {source_host} --local-port {source_port}' if source_host else '' + progress_flag = '--progress-bar' if progressbar else '-s' + self.command = f'curl {source_option} {progress_flag}' + self.urlstring = urllib.parse.urlunsplit(url) + + def download(self, location: str): + with open(location, 'wb') as f: + f.write(cmd(f'{self.command} "{self.urlstring}"').encode()) + + def upload(self, location: str): + with open(location, 'rb') as f: + cmd(f'{self.command} -T - "{self.urlstring}"', input=f.read()) + + +def urlc(urlstring, *args, **kwargs): """ - Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. - Optionally takes a `source` address or interface (not valid for HTTP(S)). - Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP. - Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. + Dynamically dispatch the appropriate protocol class. """ - url = urllib.parse.urlparse(urlstring) - username, password = get_authentication_variables(url.username, url.password) - port = get_port_from_url(url) - - if url.scheme == 'http' or url.scheme == 'https': - if source: - print_error('Warning: Custom source address not supported for HTTP connections.') - download_http(local_path, urlstring, username, password, progressbar) - elif url.scheme == 'ftp': - source = get_source_address(source)[0] if source else None - username = username if username else 'anonymous' - download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) - elif url.scheme == 'sftp' or url.scheme == 'scp': - source = get_source_address(source) if source else None - download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) - elif url.scheme == 'tftp': - download_tftp(local_path, url.hostname, url.path, port, source, progressbar) - else: - raise ValueError(f'Unsupported URL scheme: {url.scheme}') + url_classes = {'http': HttpC, 'https': HttpC, 'ftp': FtpC, 'ftps': FtpC, \ + 'sftp': SshC, 'ssh': SshC, 'scp': SshC, 'tftp': TftpC} + url = urllib.parse.urlsplit(urlstring) + try: + return url_classes[url.scheme](url, *args, **kwargs) + except KeyError: + raise ValueError(f'Unsupported URL scheme: "{url.scheme}"') -def upload(local_path, urlstring, source=None, progressbar=False): - """ - Dispatch the appropriate upload function for the given URL and upload from local path. - Optionally takes a `source` address. - Supports FTP, SFTP, SCP (through SFTP) and TFTP. - Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. - """ - url = urllib.parse.urlparse(urlstring) - username, password = get_authentication_variables(url.username, url.password) - port = get_port_from_url(url) - - if url.scheme == 'ftp': - username = username if username else 'anonymous' - source = get_source_address(source)[0] if source else None - upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) - elif url.scheme == 'sftp' or url.scheme == 'scp': - source = get_source_address(source) if source else None - upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) - elif url.scheme == 'tftp': - upload_tftp(local_path, url.hostname, url.path, port, source, progressbar) - else: - raise ValueError(f'Unsupported URL scheme: {url.scheme}') +def download(local_path, urlstring, *args, **kwargs): + urlc(urlstring, *args, **kwargs).download(local_path) -def get_remote_file_size(urlstring, source=None): - """ - Dispatch the appropriate function to return the size of the remote file from `urlstring` - in terms of number of bytes. - Optionally takes a `source` address (not valid for HTTP(S)). - Supports HTTP, HTTPS, FTP and SFTP (through SFTP). - Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. - """ - url = urllib.parse.urlparse(urlstring) - username, password = get_authentication_variables(url.username, url.password) - port = get_port_from_url(url) - - if url.scheme == 'http' or url.scheme == 'https': - if source: - print_error('Warning: Custom source address not supported for HTTP connections.') - return get_http_file_size(urlstring, username, password) - elif url.scheme == 'ftp': - source = get_source_address(source)[0] if source else None - username = username if username else 'anonymous' - return get_ftp_file_size(url.hostname, url.path, username, password, port, source) - elif url.scheme == 'sftp' or url.scheme == 'scp': - source = get_source_address(source) if source else None - return get_sftp_file_size(url.hostname, url.path, username, password, port, source) - else: - raise ValueError(f'Unsupported URL scheme: {url.scheme}') +def upload(local_path, urlstring, *args, **kwargs): + urlc(urlstring, *args, **kwargs).upload(local_path) -def get_remote_config(urlstring, source=None): +def get_remote_config(urlstring, source_host='', source_port=0): """ - Download remote (config) file from `urlstring` and return the contents as a string. - Args: - remote file URI: - tftp://<host>[:<port>]/<file> - http[s]://<host>[:<port>]/<file> - [scp|sftp|ftp]://[<user>[:<passwd>]@]<host>[:port]/<file> - source address (optional): - <interface> - <IP address> + Quietly download a file and return it as a string. """ temp = tempfile.NamedTemporaryFile(delete=False).name try: - download(temp, urlstring, source) - with open(temp, 'r') as file: - return file.read() + download(temp, urlstring, False, False, source_host, source_port) + with open(temp, 'r') as f: + return f.read() finally: os.remove(temp) -def friendly_download(local_path, urlstring, source=None): +def friendly_download(local_path, urlstring, source_host='', source_port=0): """ - Download from `urlstring` to `local_path` in an informative way. - Checks the storage space before attempting download. - Intended to be called from interactive, user-facing scripts. + Download with a progress bar, reassuring messages and free space checks. """ - destination_directory = os.path.dirname(local_path) try: - free_space = shutil.disk_usage(destination_directory).free - try: - file_size = get_remote_file_size(urlstring, source) - if file_size < 1024 * 1024: - print_error(f'The file is {file_size / 1024.0:.3f} KiB.') - else: - print_error(f'The file is {file_size / (1024.0 * 1024.0):.3f} MiB.') - if file_size > free_space: - raise OSError(f'Not enough disk space available in "{destination_directory}".') - except ValueError: - # Can't do a storage check in this case, so we bravely continue. - file_size = 0 - print_error('Could not determine the file size in advance.') - else: - print_error('Downloading...') - download(local_path, urlstring, source, progressbar=file_size > 1024 * 1024) + print_error('Downloading...') + download(local_path, urlstring, True, True, source_host, source_port) except KeyboardInterrupt: print_error('Download aborted by user.') sys.exit(1) @@ -401,3 +336,4 @@ def friendly_download(local_path, urlstring, source=None): sys.exit(1) else: print_error('Download complete.') + sys.exit(0) |