From e86edf18c18065b61760743cc4311ac0c75c79ab Mon Sep 17 00:00:00 2001 From: erkin Date: Thu, 2 Dec 2021 17:23:45 +0300 Subject: remote: T4037: Follow HTTP redirects --- python/vyos/remote.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'python') diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 2419f8873..732ef76b7 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -89,6 +89,7 @@ class WrappedFile: def __getattr__(self, attr): return getattr(self._obj, attr) + def check_storage(path, size): """ Check whether `path` has enough storage space for a transfer of `size` bytes. @@ -236,9 +237,15 @@ class HttpC: # Not only would it potentially mess up with the progress bar but # `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding. s.headers.update({'Accept-Encoding': 'identity'}) - with s.head(self.urlstring) as r: + with s.head(self.urlstring, allow_redirects=True) as r: # Abort early if the destination is inaccessible. r.raise_for_status() + # If the request got redirected, keep the last URL we ended up with. + if r.history: + final_urlstring = r.history[-1].url + print_error('Redirecting to ' + final_urlstring) + else: + final_urlstring = self.urlstring # Check for the prospective file size. try: size = int(r.headers['Content-Length']) @@ -247,7 +254,7 @@ class HttpC: size = None if self.check_space: check_storage(location, size) - with s.get(self.urlstring, stream=True) as r, open(location, 'wb') as f: + with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f: if self.progressbar and size: progress = make_incremental_progressbar(CHUNK_SIZE / size) next(progress) @@ -262,7 +269,7 @@ class HttpC: size = os.path.getsize(location) if self.progressbar else None # Keep in mind that `data` can be a file-like or iterable object. with self._establish() as s, file(location, 'rb') as f: - s.post(self.urlstring, data=WrappedFile(f, size)) + s.post(self.urlstring, data=WrappedFile(f, size), allow_redirects=True) class TftpC: @@ -324,7 +331,7 @@ def friendly_download(local_path, urlstring, source_host='', source_port=0): print_error('Downloading...') download(local_path, urlstring, True, True, source_host, source_port) except KeyboardInterrupt: - print_error('Download aborted by user.') + print_error('\nDownload aborted by user.') sys.exit(1) except: import traceback -- cgit v1.2.3