diff options
Diffstat (limited to 'python')
| -rw-r--r-- | python/vyos/remote.py | 72 | ||||
| -rw-r--r-- | python/vyos/utils/kernel.py | 11 | 
2 files changed, 68 insertions, 15 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 16fe2b2c2..cf731c881 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -25,7 +25,7 @@ import urllib.parse  from ftplib import FTP  from ftplib import FTP_TLS -from paramiko import SSHClient +from paramiko import SSHClient, SSHException  from paramiko import MissingHostKeyPolicy  from requests import Session @@ -50,7 +50,7 @@ class InteractivePolicy(MissingHostKeyPolicy):      def missing_host_key(self, client, hostname, key):          print_error(f"Host '{hostname}' not found in known hosts.")          print_error('Fingerprint: ' + key.get_fingerprint().hex()) -        if ask_yes_no('Do you wish to continue?'): +        if sys.stdout.isatty() and ask_yes_no('Do you wish to continue?'):              if client._host_keys_filename\                 and ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):                  client._host_keys.add(hostname, key.get_name(), key) @@ -96,7 +96,13 @@ def check_storage(path, size):  class FtpC: -    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): +    def __init__(self, +                 url, +                 progressbar=False, +                 check_space=False, +                 source_host='', +                 source_port=0, +                 timeout=10):          self.secure = url.scheme == 'ftps'          self.hostname = url.hostname          self.path = url.path @@ -106,12 +112,15 @@ class FtpC:          self.source = (source_host, source_port)          self.progressbar = progressbar          self.check_space = check_space +        self.timeout = timeout      def _establish(self):          if self.secure: -            return FTP_TLS(source_address=self.source, context=ssl.create_default_context()) +            return FTP_TLS(source_address=self.source, +                           context=ssl.create_default_context(), +                           timeout=self.timeout)          else: -            return FTP(source_address=self.source) +            return FTP(source_address=self.source, timeout=self.timeout)      def download(self, location: str):          # Open the file upfront before establishing connection. @@ -150,7 +159,13 @@ class FtpC:  class SshC:      known_hosts = os.path.expanduser('~/.ssh/known_hosts') -    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): +    def __init__(self, +                 url, +                 progressbar=False, +                 check_space=False, +                 source_host='', +                 source_port=0, +                 timeout=10.0):          self.hostname = url.hostname          self.path = url.path          self.username = url.username or os.getenv('REMOTE_USERNAME') @@ -159,6 +174,7 @@ class SshC:          self.source = (source_host, source_port)          self.progressbar = progressbar          self.check_space = check_space +        self.timeout = timeout      def _establish(self):          ssh = SSHClient() @@ -169,7 +185,7 @@ class SshC:          ssh.set_missing_host_key_policy(InteractivePolicy())          # `socket.create_connection()` automatically picks a NIC and an IPv4/IPv6 address family          #  for us on dual-stack systems. -        sock = socket.create_connection((self.hostname, self.port), socket.getdefaulttimeout(), self.source) +        sock = socket.create_connection((self.hostname, self.port), self.timeout, self.source)          ssh.connect(self.hostname, self.port, self.username, self.password, sock=sock)          return ssh @@ -198,13 +214,20 @@ class SshC:  class HttpC: -    def __init__(self, url, progressbar=False, check_space=False, source_host='', source_port=0): +    def __init__(self, +                 url, +                 progressbar=False, +                 check_space=False, +                 source_host='', +                 source_port=0, +                 timeout=10.0):          self.urlstring = urllib.parse.urlunsplit(url)          self.progressbar = progressbar          self.check_space = check_space          self.source_pair = (source_host, source_port)          self.username = url.username or os.getenv('REMOTE_USERNAME')          self.password = url.password or os.getenv('REMOTE_PASSWORD') +        self.timeout = timeout      def _establish(self):          session = Session() @@ -220,8 +243,11 @@ class HttpC:              # Not only would it potentially mess up with the progress bar but              # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding.              s.headers.update({'Accept-Encoding': 'identity'}) -            with s.head(self.urlstring, allow_redirects=True) as r: +            with s.head(self.urlstring, +                        allow_redirects=True, +                        timeout=self.timeout) as r:                  # Abort early if the destination is inaccessible. +                print('pre-3')                  r.raise_for_status()                  # If the request got redirected, keep the last URL we ended up with.                  final_urlstring = r.url @@ -235,7 +261,8 @@ class HttpC:                      size = None              if self.check_space:                  check_storage(location, size) -            with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f: +            with s.get(final_urlstring, stream=True, +                       timeout=self.timeout) as r, open(location, 'wb') as f:                  if self.progressbar and size:                      progress = make_incremental_progressbar(CHUNK_SIZE / size)                      next(progress) @@ -249,7 +276,10 @@ class HttpC:      def upload(self, location: str):          # Does not yet support progressbars.          with self._establish() as s, open(location, 'rb') as f: -            s.post(self.urlstring, data=f, allow_redirects=True) +            s.post(self.urlstring, +                   data=f, +                   allow_redirects=True, +                   timeout=self.timeout)  class TftpC: @@ -258,10 +288,16 @@ class TftpC:      # 2. Since there's no concept authentication, we don't need to deal with keys/passwords.      # 3. It would be a waste to import, audit and maintain a third-party library for TFTP.      # 4. I'd rather not implement the entire protocol here, no matter how simple it is. -    def __init__(self, url, progressbar=False, check_space=False, source_host=None, source_port=0): +    def __init__(self, +                 url, +                 progressbar=False, +                 check_space=False, +                 source_host=None, +                 source_port=0, +                 timeout=10):          source_option = f'--interface {source_host} --local-port {source_port}' if source_host else ''          progress_flag = '--progress-bar' if progressbar else '-s' -        self.command = f'curl {source_option} {progress_flag}' +        self.command = f'curl {source_option} {progress_flag} --connect-timeout {timeout}'          self.urlstring = urllib.parse.urlunsplit(url)      def download(self, location: str): @@ -286,10 +322,16 @@ def urlc(urlstring, *args, **kwargs):          raise ValueError(f'Unsupported URL scheme: "{url.scheme}"')  def download(local_path, urlstring, *args, **kwargs): -    urlc(urlstring, *args, **kwargs).download(local_path) +    try: +        urlc(urlstring, *args, **kwargs).download(local_path) +    except Exception as err: +        print_error(f'Unable to download "{urlstring}": {err}')  def upload(local_path, urlstring, *args, **kwargs): -    urlc(urlstring, *args, **kwargs).upload(local_path) +    try: +        urlc(urlstring, *args, **kwargs).upload(local_path) +    except Exception as err: +        print_error(f'Unable to upload "{urlstring}": {err}')  def get_remote_config(urlstring, source_host='', source_port=0):      """ diff --git a/python/vyos/utils/kernel.py b/python/vyos/utils/kernel.py index 0eb113174..1f3bbdffe 100644 --- a/python/vyos/utils/kernel.py +++ b/python/vyos/utils/kernel.py @@ -25,3 +25,14 @@ def check_kmod(k_mod):          if not os.path.exists(f'/sys/module/{module}'):              if call(f'modprobe {module}') != 0:                  raise ConfigError(f'Loading Kernel module {module} failed') + +def unload_kmod(k_mod): +    """ Common utility function to unload required kernel modules on demand """ +    from vyos import ConfigError +    from vyos.utils.process import call +    if isinstance(k_mod, str): +        k_mod = k_mod.split() +    for module in k_mod: +        if os.path.exists(f'/sys/module/{module}'): +            if call(f'rmmod {module}') != 0: +                raise ConfigError(f'Unloading Kernel module {module} failed')  | 
