From 494729145397b42fb10c5be472df5d757005a573 Mon Sep 17 00:00:00 2001
From: zsdc <taras@vyos.io>
Date: Tue, 25 Jul 2023 12:47:31 +0300
Subject: remote: T4412: Improved error handling for uploads/downloads

- added ability to set a timeout, with default value 10s
- added exceptions handling to show nicer messages for users
- denied to use untrusted SSH hosts in non-interactive mode
---
 python/vyos/remote.py | 72 ++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 57 insertions(+), 15 deletions(-)

(limited to 'python')

diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 16fe2b2c2..cf731c881 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -25,7 +25,7 @@ import urllib.parse
 from ftplib import FTP
 from ftplib import FTP_TLS
 
-from paramiko import SSHClient
+from paramiko import SSHClient, SSHException
 from paramiko import MissingHostKeyPolicy
 
 from requests import Session
@@ -50,7 +50,7 @@ class InteractivePolicy(MissingHostKeyPolicy):
     def missing_host_key(self, client, hostname, key):
         print_error(f"Host '{hostname}' not found in known hosts.")
         print_error('Fingerprint: ' + key.get_fingerprint().hex())
-        if ask_yes_no('Do you wish to continue?'):
+        if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'):
             if client._host_keys_filename\
                and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
                 client._host_keys.add(hostname, key.get_name(), key)
@@ -96,7 +96,13 @@ def check_storage(path, size):
 
 
 class FtpC:
-    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+    def __init__(self,
+                 url,
+                 progressbar=False,
+                 check_space=False,
+                 source_host='',
+                 source_port=0,
+                 timeout=10):
         self.secure = url.scheme == 'ftps'
         self.hostname = url.hostname
         self.path = url.path
@@ -106,12 +112,15 @@ class FtpC:
         self.source = (source_host, source_port)
         self.progressbar = progressbar
         self.check_space = check_space
+        self.timeout = timeout
 
     def _establish(self):
         if self.secure:
-            return FTP_TLS(source_address=self.source, context=ssl.create_default_context())
+            return FTP_TLS(source_address=self.source,
+                           context=ssl.create_default_context(),
+                           timeout=self.timeout)
         else:
-            return FTP(source_address=self.source)
+            return FTP(source_address=self.source, timeout=self.timeout)
 
     def download(self, location: str):
         # Open the file upfront before establishing connection.
@@ -150,7 +159,13 @@ class FtpC:
 
 class SshC:
     known_hosts = os.path.expanduser('~/.ssh/known_hosts')
-    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+    def __init__(self,
+                 url,
+                 progressbar=False,
+                 check_space=False,
+                 source_host='',
+                 source_port=0,
+                 timeout=10.0):
         self.hostname = url.hostname
         self.path = url.path
         self.username = url.username or os.getenv('REMOTE_USERNAME')
@@ -159,6 +174,7 @@ class SshC:
         self.source = (source_host, source_port)
         self.progressbar = progressbar
         self.check_space = check_space
+        self.timeout = timeout
 
     def _establish(self):
         ssh = SSHClient()
@@ -169,7 +185,7 @@ class SshC:
         ssh.set_missing_host_key_policy(InteractivePolicy())
         # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family
         #  for us on dual-stack systems.
-        sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source)
+        sock = socket.create_connection((self.hostname, self.port), self.timeout, self.source)
         ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)
         return ssh
 
@@ -198,13 +214,20 @@ class SshC:
 
 
 class HttpC:
-    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+    def __init__(self,
+                 url,
+                 progressbar=False,
+                 check_space=False,
+                 source_host='',
+                 source_port=0,
+                 timeout=10.0):
         self.urlstring = urllib.parse.urlunsplit(url)
         self.progressbar = progressbar
         self.check_space = check_space
         self.source_pair = (source_host, source_port)
         self.username = url.username or os.getenv('REMOTE_USERNAME')
         self.password = url.password or os.getenv('REMOTE_PASSWORD')
+        self.timeout = timeout
 
     def _establish(self):
         session = Session()
@@ -220,8 +243,11 @@ class HttpC:
             # Not only would it potentially mess up with the progress bar but
             # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding.
             s.headers.update({'Accept-Encoding': 'identity'})
-            with s.head(self.urlstring, allow_redirects=True) as r:
+            with s.head(self.urlstring,
+                        allow_redirects=True,
+                        timeout=self.timeout) as r:
                 # Abort early if the destination is inaccessible.
+                print('pre-3')
                 r.raise_for_status()
                 # If the request got redirected, keep the last URL we ended up with.
                 final_urlstring = r.url
@@ -235,7 +261,8 @@ class HttpC:
                     size = None
             if self.check_space:
                 check_storage(location, size)
-            with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f:
+            with s.get(final_urlstring, stream=True,
+                       timeout=self.timeout) as r, open(location, 'wb') as f:
                 if self.progressbar and size:
                     progress = make_incremental_progressbar(CHUNK_SIZE / size)
                     next(progress)
@@ -249,7 +276,10 @@ class HttpC:
     def upload(self, location: str):
         # Does not yet support progressbars.
         with self._establish() as s, open(location, 'rb') as f:
-            s.post(self.urlstring, data=f, allow_redirects=True)
+            s.post(self.urlstring,
+                   data=f,
+                   allow_redirects=True,
+                   timeout=self.timeout)
 
 
 class TftpC:
@@ -258,10 +288,16 @@ class TftpC:
     # 2. Since there's no concept authentication, we don't need to deal with keys/passwords.
     # 3. It would be a waste to import, audit and maintain a third-party library for TFTP.
     # 4. I'd rather not implement the entire protocol here, no matter how simple it is.
-    def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0):
+    def __init__(self,
+                 url,
+                 progressbar=False,
+                 check_space=False,
+                 source_host=None,
+                 source_port=0,
+                 timeout=10):
         source_option = f'--interface {source_host} --local-port {source_port}' if source_host else ''
         progress_flag = '--progress-bar' if progressbar else '-s'
-        self.command = f'curl {source_option} {progress_flag}'
+        self.command = f'curl {source_option} {progress_flag} --connect-timeout {timeout}'
         self.urlstring = urllib.parse.urlunsplit(url)
 
     def download(self, location: str):
@@ -286,10 +322,16 @@ def urlc(urlstring, *args, **kwargs):
         raise ValueError(f'Unsupported URL scheme: "{url.scheme}"')
 
 def download(local_path, urlstring, *args, **kwargs):
-    urlc(urlstring, *args, **kwargs).download(local_path)
+    try:
+        urlc(urlstring, *args, **kwargs).download(local_path)
+    except Exception as err:
+        print_error(f'Unable to download "{urlstring}": {err}')
 
 def upload(local_path, urlstring, *args, **kwargs):
-    urlc(urlstring, *args, **kwargs).upload(local_path)
+    try:
+        urlc(urlstring, *args, **kwargs).upload(local_path)
+    except Exception as err:
+        print_error(f'Unable to upload "{urlstring}": {err}')
 
 def get_remote_config(urlstring, source_host='', source_port=0):
     """
-- 
cgit v1.2.3