From aab730c57542fdd64665ea5c4158a5e7d2e8ac76 Mon Sep 17 00:00:00 2001
From: zsdc <taras@vyos.io>
Date: Fri, 21 Jul 2023 16:38:41 +0300
Subject: remote: T4412: fixed upload via SSH

- added timeout to socket creating
- added skipping SSH fingerprint check with a negative result if a
console is not interactive
- replaced tracebacks with human-readable error messages
- suppressed warnings from `cryptography` used by `paramiko`
---
 python/vyos/remote.py | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

(limited to 'python')

diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 66044fa52..d1db68f1c 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -21,11 +21,13 @@ import stat
 import sys
 import tempfile
 import urllib.parse
+import warnings
 
+from cryptography.utils import CryptographyDeprecationWarning
 from ftplib import FTP
 from ftplib import FTP_TLS
 
-from paramiko import SSHClient
+from paramiko import SSHClient, SSHException
 from paramiko import MissingHostKeyPolicy
 
 from requests import Session
@@ -43,6 +45,10 @@ from vyos.version import get_version
 
 CHUNK_SIZE = 8192
 
+# suppress warnings for deprecated APIs used by paramiko
+warnings.simplefilter('ignore', category=CryptographyDeprecationWarning)
+
+
 class InteractivePolicy(MissingHostKeyPolicy):
     """
     Paramiko policy for interactively querying the user on whether to proceed
@@ -51,7 +57,7 @@ class InteractivePolicy(MissingHostKeyPolicy):
     def missing_host_key(self, client, hostname, key):
         print_error(f"Host '{hostname}' not found in known hosts.")
         print_error('Fingerprint: ' + key.get_fingerprint().hex())
-        if ask_yes_no('Do you wish to continue?'):
+        if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'):
             if client._host_keys_filename\
                and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
                 client._host_keys.add(hostname, key.get_name(), key)
@@ -151,7 +157,14 @@ class FtpC:
 
 class SshC:
     known_hosts = os.path.expanduser('~/.ssh/known_hosts')
-    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+
+    def __init__(self,
+                 url,
+                 progressbar=False,
+                 check_space=False,
+                 source_host='',
+                 source_port=0,
+                 timeout=10.0):
         self.hostname = url.hostname
         self.path = url.path
         self.username = url.username or os.getenv('REMOTE_USERNAME')
@@ -160,6 +173,7 @@ class SshC:
         self.source = (source_host, source_port)
         self.progressbar = progressbar
         self.check_space = check_space
+        self.timeout = timeout
 
     def _establish(self):
         ssh = SSHClient()
@@ -170,8 +184,13 @@ class SshC:
         ssh.set_missing_host_key_policy(InteractivePolicy())
         # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family
         #  for us on dual-stack systems.
-        sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source)
-        ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)
+        try:
+            sock = socket.create_connection((self.hostname, self.port),
+                                            self.timeout, self.source)
+            ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)
+        except Exception as err:
+            print(f'Cannot connect to "{self.hostname}:{self.port}": {err}')
+            sys.exit(1)
         return ssh
 
     def download(self, location: str):
-- 
cgit v1.2.3