summaryrefslogtreecommitdiff
path: root/python/vyos/remote.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/vyos/remote.py')
-rw-r--r--python/vyos/remote.py29
1 files changed, 24 insertions, 5 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 66044fa52..d1db68f1c 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -21,11 +21,13 @@ import stat
import sys
import tempfile
import urllib.parse
+import warnings
+from cryptography.utils import CryptographyDeprecationWarning
from ftplib import FTP
from ftplib import FTP_TLS
-from paramiko import SSHClient
+from paramiko import SSHClient, SSHException
from paramiko import MissingHostKeyPolicy
from requests import Session
@@ -43,6 +45,10 @@ from vyos.version import get_version
CHUNK_SIZE = 8192
+# suppress warnings for deprecated APIs used by paramiko
+warnings.simplefilter('ignore', category=CryptographyDeprecationWarning)
+
+
class InteractivePolicy(MissingHostKeyPolicy):
"""
Paramiko policy for interactively querying the user on whether to proceed
@@ -51,7 +57,7 @@ class InteractivePolicy(MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):
print_error(f"Host '{hostname}' not found in known hosts.")
print_error('Fingerprint: ' + key.get_fingerprint().hex())
- if ask_yes_no('Do you wish to continue?'):
+ if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'):
if client._host_keys_filename\
and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
client._host_keys.add(hostname, key.get_name(), key)
@@ -151,7 +157,14 @@ class FtpC:
class SshC:
known_hosts = os.path.expanduser('~/.ssh/known_hosts')
- def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0):
+
+ def __init__(self,
+ url,
+ progressbar=False,
+ check_space=False,
+ source_host='',
+ source_port=0,
+ timeout=10.0):
self.hostname = url.hostname
self.path = url.path
self.username = url.username or os.getenv('REMOTE_USERNAME')
@@ -160,6 +173,7 @@ class SshC:
self.source = (source_host, source_port)
self.progressbar = progressbar
self.check_space = check_space
+ self.timeout = timeout
def _establish(self):
ssh = SSHClient()
@@ -170,8 +184,13 @@ class SshC:
ssh.set_missing_host_key_policy(InteractivePolicy())
# `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family
# for us on dual-stack systems.
- sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source)
- ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)
+ try:
+ sock = socket.create_connection((self.hostname, self.port),
+ self.timeout, self.source)
+ ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)
+ except Exception as err:
+ print(f'Cannot connect to "{self.hostname}:{self.port}": {err}')
+ sys.exit(1)
return ssh
def download(self, location: str):