summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzsdc <taras@vyos.io>2023-07-25 12:47:31 +0300
committerzsdc <taras@vyos.io>2023-07-25 12:47:31 +0300
commit494729145397b42fb10c5be472df5d757005a573 (patch)
treeb0dbfc65f90284ceb8e80c14e3af26f8e0b44cf4
parent20b7155f4140f54cf7669256160b6fedd8c1ab7a (diff)
downloadvyos-1x-494729145397b42fb10c5be472df5d757005a573.tar.gz
vyos-1x-494729145397b42fb10c5be472df5d757005a573.zip
remote: T4412: Improved error handling for uploads/downloads
- added ability to set a timeout, with default value 10s - added exceptions handling to show nicer messages for users - denied to use untrusted SSH hosts in non-interactive mode
-rw-r--r--python/vyos/remote.py72
1 files changed, 57 insertions, 15 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 16fe2b2c2..cf731c881 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -25,7 +25,7 @@ import urllib.parse
from ftplib import FTP
from ftplib import FTP_TLS
-from paramiko import SSHClient
+from paramiko import SSHClient, SSHException
from paramiko import MissingHostKeyPolicy
from requests import Session
@@ -50,7 +50,7 @@ class InteractivePolicy(MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):
print_error(f"Host '{hostname}' not found in known hosts.")
print_error('Fingerprint: ' + key.get_fingerprint().hex())
- if ask_yes_no('Do you wish to continue?'):
+ if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'):
if client._host_keys_filename\
and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
client._host_keys.add(hostname, key.get_name(), key)
@@ -96,7 +96,13 @@ def check_storage(path, size):
class FtpC:
- def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+ def __init__(self,
+ url,
+ progressbar=False,
+ check_space=False,
+ source_host='',
+ source_port=0,
+ timeout=10):
self.secure = url.scheme == 'ftps'
self.hostname = url.hostname
self.path = url.path
@@ -106,12 +112,15 @@ class FtpC:
self.source = (source_host, source_port)
self.progressbar = progressbar
self.check_space = check_space
+ self.timeout = timeout
def _establish(self):
if self.secure:
- return FTP_TLS(source_address=self.source, context=ssl.create_default_context())
+ return FTP_TLS(source_address=self.source,
+ context=ssl.create_default_context(),
+ timeout=self.timeout)
else:
- return FTP(source_address=self.source)
+ return FTP(source_address=self.source, timeout=self.timeout)
def download(self, location: str):
# Open the file upfront before establishing connection.
@@ -150,7 +159,13 @@ class FtpC:
class SshC:
known_hosts = os.path.expanduser('~/.ssh/known_hosts')
- def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+ def __init__(self,
+ url,
+ progressbar=False,
+ check_space=False,
+ source_host='',
+ source_port=0,
+ timeout=10.0):
self.hostname = url.hostname
self.path = url.path
self.username = url.username or os.getenv('REMOTE_USERNAME')
@@ -159,6 +174,7 @@ class SshC:
self.source = (source_host, source_port)
self.progressbar = progressbar
self.check_space = check_space
+ self.timeout = timeout
def _establish(self):
ssh = SSHClient()
@@ -169,7 +185,7 @@ class SshC:
ssh.set_missing_host_key_policy(InteractivePolicy())
# `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family
# for us on dual-stack systems.
- sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source)
+ sock = socket.create_connection((self.hostname, self.port), self.timeout, self.source)
ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)
return ssh
@@ -198,13 +214,20 @@ class SshC:
class HttpC:
- def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+ def __init__(self,
+ url,
+ progressbar=False,
+ check_space=False,
+ source_host='',
+ source_port=0,
+ timeout=10.0):
self.urlstring = urllib.parse.urlunsplit(url)
self.progressbar = progressbar
self.check_space = check_space
self.source_pair = (source_host, source_port)
self.username = url.username or os.getenv('REMOTE_USERNAME')
self.password = url.password or os.getenv('REMOTE_PASSWORD')
+ self.timeout = timeout
def _establish(self):
session = Session()
@@ -220,8 +243,11 @@ class HttpC:
# Not only would it potentially mess up with the progress bar but
# `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding.
s.headers.update({'Accept-Encoding': 'identity'})
- with s.head(self.urlstring, allow_redirects=True) as r:
+ with s.head(self.urlstring,
+ allow_redirects=True,
+ timeout=self.timeout) as r:
# Abort early if the destination is inaccessible.
+ print('pre-3')
r.raise_for_status()
# If the request got redirected, keep the last URL we ended up with.
final_urlstring = r.url
@@ -235,7 +261,8 @@ class HttpC:
size = None
if self.check_space:
check_storage(location, size)
- with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f:
+ with s.get(final_urlstring, stream=True,
+ timeout=self.timeout) as r, open(location, 'wb') as f:
if self.progressbar and size:
progress = make_incremental_progressbar(CHUNK_SIZE / size)
next(progress)
@@ -249,7 +276,10 @@ class HttpC:
def upload(self, location: str):
# Does not yet support progressbars.
with self._establish() as s, open(location, 'rb') as f:
- s.post(self.urlstring, data=f, allow_redirects=True)
+ s.post(self.urlstring,
+ data=f,
+ allow_redirects=True,
+ timeout=self.timeout)
class TftpC:
@@ -258,10 +288,16 @@ class TftpC:
# 2. Since there's no concept authentication, we don't need to deal with keys/passwords.
# 3. It would be a waste to import, audit and maintain a third-party library for TFTP.
# 4. I'd rather not implement the entire protocol here, no matter how simple it is.
- def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0):
+ def __init__(self,
+ url,
+ progressbar=False,
+ check_space=False,
+ source_host=None,
+ source_port=0,
+ timeout=10):
source_option = f'--interface {source_host} --local-port {source_port}' if source_host else ''
progress_flag = '--progress-bar' if progressbar else '-s'
- self.command = f'curl {source_option} {progress_flag}'
+ self.command = f'curl {source_option} {progress_flag} --connect-timeout {timeout}'
self.urlstring = urllib.parse.urlunsplit(url)
def download(self, location: str):
@@ -286,10 +322,16 @@ def urlc(urlstring, *args, **kwargs):
raise ValueError(f'Unsupported URL scheme: "{url.scheme}"')
def download(local_path, urlstring, *args, **kwargs):
- urlc(urlstring, *args, **kwargs).download(local_path)
+ try:
+ urlc(urlstring, *args, **kwargs).download(local_path)
+ except Exception as err:
+ print_error(f'Unable to download "{urlstring}": {err}')
def upload(local_path, urlstring, *args, **kwargs):
- urlc(urlstring, *args, **kwargs).upload(local_path)
+ try:
+ urlc(urlstring, *args, **kwargs).upload(local_path)
+ except Exception as err:
+ print_error(f'Unable to upload "{urlstring}": {err}')
def get_remote_config(urlstring, source_host='', source_port=0):
"""