diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/vyos/remote.py | 400 | ||||
-rw-r--r-- | python/vyos/util.py | 63 |
2 files changed, 354 insertions, 109 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index db20b173d..aa62ac60d 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -14,133 +14,315 @@ # License along with this library. If not, see <http://www.gnu.org/licenses/>. import os +import shutil import socket +import ssl +import stat import sys import tempfile -from ftplib import FTP import urllib.parse -import urllib.request -from vyos.util import cmd +from ftplib import FTP +from ftplib import FTP_TLS + from paramiko import SSHClient +from paramiko import MissingHostKeyPolicy +from requests import Session +from requests.adapters import HTTPAdapter +from requests.packages.urllib3 import PoolManager -def upload_ftp(local_path, hostname, remote_path,\ - username='anonymous', password='', port=21, source=None): - with open(local_path, 'rb') as file: - with FTP(source_address=source) as conn: - conn.connect(hostname, port) - conn.login(username, password) - conn.storbinary(f'STOR {remote_path}', file) - -def download_ftp(local_path, hostname, remote_path,\ - username='anonymous', password='', port=21, source=None): - with open(local_path, 'wb') as file: - with FTP(source_address=source) as conn: - conn.connect(hostname, port) - conn.login(username, password) - conn.retrbinary(f'RETR {remote_path}', file.write) - -def upload_sftp(local_path, hostname, remote_path,\ - username=None, password=None, port=22, source=None): - sock = None - if source: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind((source, 0)) - sock.connect((hostname, port)) - with SSHClient() as ssh: - ssh.load_system_host_keys() - ssh.connect(hostname, port, username, password, sock=sock) - with ssh.open_sftp() as sftp: - sftp.put(local_path, remote_path) - if sock: - sock.shutdown() - sock.close() - -def download_sftp(local_path, hostname, remote_path,\ - username=None, password=None, port=22, source=None): - sock = None - if source: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind((source, 0)) - sock.connect((hostname, port)) - with SSHClient() as ssh: - ssh.load_system_host_keys() - ssh.connect(hostname, port, username, password, sock=sock) - with ssh.open_sftp() as sftp: - sftp.get(remote_path, local_path) - if sock: - sock.shutdown() - sock.close() - -def upload_tftp(local_path, hostname, remote_path, port=69, source=None): - source_option = f'--interface {source}' if source else '' - with open(local_path, 'rb') as file: - cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',\ - stderr=None, input=file.read()).encode() - -def download_tftp(local_path, hostname, remote_path, port=69, source=None): - source_option = f'--interface {source}' if source else '' - with open(local_path, 'wb') as file: - file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\ - stderr=None).encode()) - -def download_http(urlstring, local_path): - with open(local_path, 'wb') as file: - with urllib.request.urlopen(urlstring) as response: - file.write(response.read()) - -def download(local_path, urlstring, source=None): +from vyos.util import ask_yes_no +from vyos.util import begin +from vyos.util import cmd +from vyos.util import make_incremental_progressbar +from vyos.util import make_progressbar +from vyos.util import print_error +from vyos.version import get_version + + +CHUNK_SIZE = 8192 + +class InteractivePolicy(MissingHostKeyPolicy): """ - Dispatch the appropriate download function for the given URL and save to local path. + Paramiko policy for interactively querying the user on whether to proceed + with SSH connections to unknown hosts. """ - url = urllib.parse.urlparse(urlstring) - if url.scheme == 'http' or url.scheme == 'https': - if source: - print("Warning: Custom source address not supported for HTTP connections.", file=sys.stderr) - download_http(urlstring, local_path) - elif url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - download_ftp(local_path, url.hostname, url.path, username, url.password, source=source) - elif url.scheme == 'sftp' or url.scheme == 'scp': - download_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source) - elif url.scheme == 'tftp': - download_tftp(local_path, url.hostname, url.path, source=source) - else: - raise ValueError(f'Unsupported URL scheme: {url.scheme}') + def missing_host_key(self, client, hostname, key): + print_error(f"Host '{hostname}' not found in known hosts.") + print_error('Fingerprint: ' + key.get_fingerprint().hex()) + if ask_yes_no('Do you wish to continue?'): + if client._host_keys_filename\ + and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'): + client._host_keys.add(hostname, key.get_name(), key) + client.save_host_keys(client._host_keys_filename) + else: + raise SSHException(f"Cannot connect to unknown host '{hostname}'.") -def upload(local_path, urlstring, source=None): +class SourceAdapter(HTTPAdapter): """ - Dispatch the appropriate upload function for the given URL and upload from local path. + urllib3 transport adapter for setting source addresses per session. """ - url = urllib.parse.urlparse(urlstring) - if url.scheme == 'ftp': - username = url.username if url.username else 'anonymous' - upload_ftp(local_path, url.hostname, url.path, username, url.password, source=source) - elif url.scheme == 'sftp' or url.scheme == 'scp': - upload_sftp(local_path, url.hostname, url.path, url.username, url.password, source=source) - elif url.scheme == 'tftp': - upload_tftp(local_path, url.hostname, url.path, source=source) + def __init__(self, source_pair, *args, **kwargs): + # A source pair is a tuple of a source host string and source port respectively. + # Supply '' and 0 respectively for default values. + self._source_pair = source_pair + super(SourceAdapter, self).__init__(*args, **kwargs) + + def init_poolmanager(self, connections, maxsize, block=False): + self.poolmanager = PoolManager( + num_pools=connections, maxsize=maxsize, + block=block, source_address=self._source_pair) + + +def check_storage(path, size): + """ + Check whether `path` has enough storage space for a transfer of `size` bytes. + """ + path = os.path.abspath(os.path.expanduser(path)) + directory = path if os.path.isdir(path) else (os.path.dirname(os.path.expanduser(path)) or os.getcwd()) + # `size` can be None or 0 to indicate unknown size. + if not size: + print_error('Warning: Cannot determine size of remote file.') + print_error('Bravely continuing regardless.') + return + + if size < 1024 * 1024: + print_error(f'The file is {size / 1024.0:.3f} KiB.') else: - raise ValueError(f'Unsupported URL scheme: {url.scheme}') + print_error(f'The file is {size / (1024.0 * 1024.0):.3f} MiB.') + + # Will throw `FileNotFoundError' if `directory' is absent. + if size > shutil.disk_usage(directory).free: + raise OSError(f'Not enough disk space available in "{directory}".') + -def get_remote_config(urlstring): +class FtpC: + def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + self.secure = url.scheme == 'ftps' + self.hostname = url.hostname + self.path = url.path + self.username = url.username or os.getenv('REMOTE_USERNAME', 'anonymous') + self.password = url.password or os.getenv('REMOTE_PASSWORD', '') + self.port = url.port or 21 + self.source = (source_host, source_port) + self.progressbar = progressbar + self.check_space = check_space + + def _establish(self): + if self.secure: + return FTP_TLS(source_address=self.source, context=ssl.create_default_context()) + else: + return FTP(source_address=self.source) + + def download(self, location: str): + # Open the file upfront before establishing connection. + with open(location, 'wb') as f, self._establish() as conn: + conn.connect(self.hostname, self.port) + conn.login(self.username, self.password) + # Set secure connection over TLS. + if self.secure: + conn.prot_p() + # Almost all FTP servers support the `SIZE' command. + if self.check_space: + check_storage(path, conn.size(self.path)) + # No progressbar if we can't determine the size or if the file is too small. + if self.progressbar and size and size > CHUNK_SIZE: + progress = make_incremental_progressbar(CHUNK_SIZE / size) + next(progress) + callback = lambda block: begin(f.write(block), next(progress)) + else: + callback = f.write + conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE) + + def upload(self, location: str): + size = os.path.getsize(location) + with open(location, 'rb') as f, self._establish() as conn: + conn.connect(self.hostname, self.port) + conn.login(self.username, self.password) + if self.secure: + conn.prot_p() + if self.progressbar and size and size > CHUNK_SIZE: + progress = make_incremental_progressbar(CHUNK_SIZE / size) + next(progress) + callback = lambda block: next(progress) + else: + callback = None + conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, callback) + +class SshC: + known_hosts = os.path.expanduser('~/.ssh/known_hosts') + def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + self.hostname = url.hostname + self.path = url.path + self.username = url.username or os.getenv('REMOTE_USERNAME') + self.password = url.password or os.getenv('REMOTE_PASSWORD') + self.port = url.port or 22 + self.source = (source_host, source_port) + self.progressbar = progressbar + self.check_space = check_space + + def _establish(self): + ssh = SSHClient() + ssh.load_system_host_keys() + # Try to load from a user-local known hosts file if one exists. + if os.path.exists(self.known_hosts): + ssh.load_host_keys(self.known_hosts) + ssh.set_missing_host_key_policy(InteractivePolicy()) + # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family + # for us on dual-stack systems. + sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source) + ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock) + return ssh + + def download(self, location: str): + callback = make_progressbar() if self.progressbar else None + with self._establish() as ssh, ssh.open_sftp() as sftp: + if self.check_space: + check_storage(location, sftp.stat(self.path).st_size) + sftp.get(self.path, location, callback=callback) + + def upload(self, location: str): + callback = make_progressbar() if self.progressbar else None + with self._establish() as ssh, ssh.open_sftp() as sftp: + try: + # If the remote path is a directory, use the original filename. + if stat.S_ISDIR(sftp.stat(self.path).st_mode): + path = os.path.join(self.path, os.path.basename(location)) + # A file exists at this destination. We're simply going to clobber it. + else: + path = self.path + # This path doesn't point at any existing file. We can freely use this filename. + except IOError: + path = self.path + finally: + sftp.put(location, path, callback=callback) + + +class HttpC: + def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): + self.urlstring = urllib.parse.urlunsplit(url) + self.progressbar = progressbar + self.check_space = check_space + self.source_pair = (source_host, source_port) + self.username = url.username or os.getenv('REMOTE_USERNAME') + self.password = url.password or os.getenv('REMOTE_PASSWORD') + + def _establish(self): + session = Session() + session.mount(self.urlstring, SourceAdapter(self.source_pair)) + session.headers.update({'User-Agent': 'VyOS/' + get_version()}) + if self.username: + session.auth = self.username, self.password + return session + + def download(self, location: str): + with self._establish() as s: + # We ask for uncompressed downloads so that we don't have to deal with decoding. + # Not only would it potentially mess up with the progress bar but + # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding. + s.headers.update({'Accept-Encoding': 'identity'}) + with s.head(self.urlstring, allow_redirects=True) as r: + # Abort early if the destination is inaccessible. + r.raise_for_status() + # If the request got redirected, keep the last URL we ended up with. + final_urlstring = r.url + if r.history: + print_error('Redirecting to ' + final_urlstring) + # Check for the prospective file size. + try: + size = int(r.headers['Content-Length']) + # In case the server does not supply the header. + except KeyError: + size = None + if self.check_space: + check_storage(location, size) + with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f: + if self.progressbar and size: + progress = make_incremental_progressbar(CHUNK_SIZE / size) + next(progress) + for chunk in iter(lambda: begin(next(progress), r.raw.read(CHUNK_SIZE)), b''): + f.write(chunk) + else: + # We'll try to stream the download directly with `copyfileobj()` so that large + # files (like entire VyOS images) don't occupy much memory. + shutil.copyfileobj(r.raw, f) + + def upload(self, location: str): + # Does not yet support progressbars. + with self._establish() as s, open(location, 'rb') as f: + s.post(self.urlstring, data=f, allow_redirects=True) + + +class TftpC: + # We simply allow `curl` to take over because + # 1. TFTP is rather simple. + # 2. Since there's no concept authentication, we don't need to deal with keys/passwords. + # 3. It would be a waste to import, audit and maintain a third-party library for TFTP. + # 4. I'd rather not implement the entire protocol here, no matter how simple it is. + def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0): + source_option = f'--interface {source_host} --local-port {source_port}' if source_host else '' + progress_flag = '--progress-bar' if progressbar else '-s' + self.command = f'curl {source_option} {progress_flag}' + self.urlstring = urllib.parse.urlunsplit(url) + + def download(self, location: str): + with open(location, 'wb') as f: + f.write(cmd(f'{self.command} "{self.urlstring}"').encode()) + + def upload(self, location: str): + with open(location, 'rb') as f: + cmd(f'{self.command} -T - "{self.urlstring}"', input=f.read()) + + +def urlc(urlstring, *args, **kwargs): + """ + Dynamically dispatch the appropriate protocol class. + """ + url_classes = {'http': HttpC, 'https': HttpC, 'ftp': FtpC, 'ftps': FtpC, \ + 'sftp': SshC, 'ssh': SshC, 'scp': SshC, 'tftp': TftpC} + url = urllib.parse.urlsplit(urlstring) + try: + return url_classes[url.scheme](url, *args, **kwargs) + except KeyError: + raise ValueError(f'Unsupported URL scheme: "{url.scheme}"') + +def download(local_path, urlstring, *args, **kwargs): + urlc(urlstring, *args, **kwargs).download(local_path) + +def upload(local_path, urlstring, *args, **kwargs): + urlc(urlstring, *args, **kwargs).upload(local_path) + +def get_remote_config(urlstring, source_host='', source_port=0): """ - Download remote (config) file and return the contents. - Args: - remote file URI: - scp://<user>[:<passwd>]@<host>/<file> - sftp://<user>[:<passwd>]@<host>/<file> - http://<host>/<file> - https://<host>/<file> - ftp://[<user>[:<passwd>]@]<host>/<file> - tftp://<host>/<file> + Quietly download a file and return it as a string. """ - url = urllib.parse.urlparse(urlstring) temp = tempfile.NamedTemporaryFile(delete=False).name try: - download(temp, urlstring) - with open(temp, 'r') as file: - return file.read() + download(temp, urlstring, False, False, source_host, source_port) + with open(temp, 'r') as f: + return f.read() finally: os.remove(temp) + +def friendly_download(local_path, urlstring, source_host='', source_port=0): + """ + Download with a progress bar, reassuring messages and free space checks. + """ + try: + print_error('Downloading...') + download(local_path, urlstring, True, True, source_host, source_port) + except KeyboardInterrupt: + print_error('\nDownload aborted by user.') + sys.exit(1) + except: + import traceback + # There are a myriad different reasons a download could fail. + # SSH errors, FTP errors, I/O errors, HTTP errors (403, 404...) + # We omit the scary stack trace but print the error nevertheless. + print_error(f'Failed to download {urlstring}.') + traceback.print_exception(*sys.exc_info()[:2], None) + sys.exit(1) + else: + print_error('Download complete.') + sys.exit(0) diff --git a/python/vyos/util.py b/python/vyos/util.py index dcb332e5c..f2f302559 100644 --- a/python/vyos/util.py +++ b/python/vyos/util.py @@ -704,6 +704,69 @@ def get_interface_config(interface): tmp = loads(cmd(f'ip -d -j link show {interface}'))[0] return tmp +def print_error(str='', end='\n'): + """ + Print `str` to stderr, terminated with `end`. + Used for warnings and out-of-band messages to avoid mangling precious + stdout output. + """ + sys.stderr.write(str) + sys.stderr.write(end) + sys.stderr.flush() + +def make_progressbar(): + """ + Make a procedure that takes two arguments `done` and `total` and prints a + progressbar based on the ratio thereof, whose length is determined by the + width of the terminal. + """ + import shutil, math + col, _ = shutil.get_terminal_size() + col = max(col - 15, 20) + def print_progressbar(done, total): + if done <= total: + increment = total / col + length = math.ceil(done / increment) + percentage = str(math.ceil(100 * done / total)).rjust(3) + print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r') + # Print a newline so that the subsequent prints don't overwrite the full bar. + if done == total: + print_error() + return print_progressbar + +def make_incremental_progressbar(increment: float): + """ + Make a generator that displays a progressbar that grows monotonically with + every iteration. + First call displays it at 0% and every subsequent iteration displays it + at `increment` increments where 0.0 < `increment` < 1.0. + Intended for FTP and HTTP transfers with stateless callbacks. + """ + print_progressbar = make_progressbar() + total = 0.0 + while total < 1.0: + print_progressbar(total, 1.0) + yield + total += increment + print_progressbar(1, 1) + # Ignore further calls. + while True: + yield + +def begin(*args): + """ + Evaluate arguments in order and return the result of the *last* argument. + For combining multiple expressions in one statement. Useful for lambdas. + """ + return args[-1] + +def begin0(*args): + """ + Evaluate arguments in order and return the result of the *first* argument. + For combining multiple expressions in one statement. Useful for lambdas. + """ + return args[0] + def is_systemd_service_active(service): """ Test is a specified systemd service is activated. Returns True if service is active, false otherwise. |