summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/vyos/remote.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 2419f8873..732ef76b7 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -89,6 +89,7 @@ class WrappedFile:
def __getattr__(self, attr):
return getattr(self._obj, attr)
+
def check_storage(path, size):
"""
Check whether `path` has enough storage space for a transfer of `size` bytes.
@@ -236,9 +237,15 @@ class HttpC:
# Not only would it potentially mess up with the progress bar but
# `shutil.copyfileobj(request.raw, file)` does not handle automatic decoding.
s.headers.update({'Accept-Encoding': 'identity'})
- with s.head(self.urlstring) as r:
+ with s.head(self.urlstring, allow_redirects=True) as r:
# Abort early if the destination is inaccessible.
r.raise_for_status()
+ # If the request got redirected, keep the last URL we ended up with.
+ if r.history:
+ final_urlstring = r.history[-1].url
+ print_error('Redirecting to ' + final_urlstring)
+ else:
+ final_urlstring = self.urlstring
# Check for the prospective file size.
try:
size = int(r.headers['Content-Length'])
@@ -247,7 +254,7 @@ class HttpC:
size = None
if self.check_space:
check_storage(location, size)
- with s.get(self.urlstring, stream=True) as r, open(location, 'wb') as f:
+ with s.get(final_urlstring, stream=True) as r, open(location, 'wb') as f:
if self.progressbar and size:
progress = make_incremental_progressbar(CHUNK_SIZE / size)
next(progress)
@@ -262,7 +269,7 @@ class HttpC:
size = os.path.getsize(location) if self.progressbar else None
# Keep in mind that `data` can be a file-like or iterable object.
with self._establish() as s, file(location, 'rb') as f:
- s.post(self.urlstring, data=WrappedFile(f, size))
+ s.post(self.urlstring, data=WrappedFile(f, size), allow_redirects=True)
class TftpC:
@@ -324,7 +331,7 @@ def friendly_download(local_path, urlstring, source_host='', source_port=0):
print_error('Downloading...')
download(local_path, urlstring, True, True, source_host, source_port)
except KeyboardInterrupt:
- print_error('Download aborted by user.')
+ print_error('\nDownload aborted by user.')
sys.exit(1)
except:
import traceback