summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/vyos/progressbar.py70
-rw-r--r--python/vyos/remote.py44
-rw-r--r--python/vyos/utils/io.py39
3 files changed, 93 insertions, 60 deletions
diff --git a/python/vyos/progressbar.py b/python/vyos/progressbar.py
new file mode 100644
index 000000000..1793c445b
--- /dev/null
+++ b/python/vyos/progressbar.py
@@ -0,0 +1,70 @@
+# Copyright 2023 VyOS maintainers and contributors <maintainers@vyos.io>
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library. If not, see <http://www.gnu.org/licenses/>.
+
+import math
+import os
+import signal
+import subprocess
+import sys
+
+from vyos.utils.io import print_error
+
+class Progressbar:
+ def __init__(self, step=None):
+ self.total = 0.0
+ self.step = step
+ def __enter__(self):
+ # Recalculate terminal width with every window resize.
+ signal.signal(signal.SIGWINCH, lambda signum, frame: self._update_cols())
+ # Disable line wrapping to prevent the staircase effect.
+ subprocess.run(['tput', 'rmam'], check=False)
+ self._update_cols()
+ # Print an empty progressbar with entry.
+ self.progress(0, 1)
+ return self
+ def __exit__(self, exc_type, kexc_val, exc_tb):
+ # Revert to the default SIGWINCH handler (ie nothing).
+ signal.signal(signal.SIGWINCH, signal.SIG_DFL)
+ # Reenable line wrapping.
+ subprocess.run(['tput', 'smam'], check=False)
+ def _update_cols(self):
+ # `os.get_terminal_size()' is fast enough for our purposes.
+ self.col = max(os.get_terminal_size().columns - 15, 20)
+ def increment(self):
+ """
+ Stateful progressbar taking the step fraction at init and no input at
+ callback (for FTP)
+ """
+ if self.step:
+ if self.total < 1.0:
+ self.total += self.step
+ if self.total >= 1.0:
+ self.total = 1.0
+ # Ignore superfluous calls caused by fuzzy FTP size calculations.
+ self.step = None
+ self.progress(self.total, 1.0)
+ def progress(self, done, total):
+ """
+ Stateless progressbar taking no input at init and current progress with
+ final size at callback (for SSH)
+ """
+ if done <= total:
+ length = math.ceil(self.col * done / total)
+ percentage = str(math.ceil(100 * done / total)).rjust(3)
+ # Carriage return at the end will make sure the line will get overwritten.
+ print_error(f'[{length * "#"}{(self.col - length) * "_"}] {percentage}%', end='\r')
+ # Print a newline to make sure the full progressbar doesn't get overwritten by the next line.
+ if done == total:
+ print_error()
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index cf731c881..1ca8a9530 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -32,9 +32,8 @@ from requests import Session
from requests.adapters import HTTPAdapter
from requests.packages.urllib3 import PoolManager
+from vyos.progressbar import Progressbar
from vyos.utils.io import ask_yes_no
-from vyos.utils.io import make_incremental_progressbar
-from vyos.utils.io import make_progressbar
from vyos.utils.io import print_error
from vyos.utils.misc import begin
from vyos.utils.process import cmd
@@ -131,16 +130,16 @@ class FtpC:
if self.secure:
conn.prot_p()
# Almost all FTP servers support the `SIZE' command.
+ size = conn.size(self.path)
if self.check_space:
- check_storage(path, conn.size(self.path))
+ check_storage(path, size)
# No progressbar if we can't determine the size or if the file is too small.
if self.progressbar and size and size > CHUNK_SIZE:
- progress = make_incremental_progressbar(CHUNK_SIZE / size)
- next(progress)
- callback = lambda block: begin(f.write(block), next(progress))
+ with Progressbar(CHUNK_SIZE / size) as p:
+ callback = lambda block: begin(f.write(block), p.increment())
+ conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE)
else:
- callback = f.write
- conn.retrbinary('RETR ' + self.path, callback, CHUNK_SIZE)
+ conn.retrbinary('RETR ' + self.path, f.write, CHUNK_SIZE)
def upload(self, location: str):
size = os.path.getsize(location)
@@ -150,12 +149,10 @@ class FtpC:
if self.secure:
conn.prot_p()
if self.progressbar and size and size > CHUNK_SIZE:
- progress = make_incremental_progressbar(CHUNK_SIZE / size)
- next(progress)
- callback = lambda block: next(progress)
+ with Progressbar(CHUNK_SIZE / size) as p:
+ conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, lambda block: p.increment())
else:
- callback = None
- conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE, callback)
+ conn.storbinary('STOR ' + self.path, f, CHUNK_SIZE)
class SshC:
known_hosts = os.path.expanduser('~/.ssh/known_hosts')
@@ -190,14 +187,16 @@ class SshC:
return ssh
def download(self, location: str):
- callback = make_progressbar() if self.progressbar else None
with self._establish() as ssh, ssh.open_sftp() as sftp:
if self.check_space:
check_storage(location, sftp.stat(self.path).st_size)
- sftp.get(self.path, location, callback=callback)
+ if self.progressbar:
+ with Progressbar() as p:
+ sftp.get(self.path, location, callback=p.progress)
+ else:
+ sftp.get(self.path, location)
def upload(self, location: str):
- callback = make_progressbar() if self.progressbar else None
with self._establish() as ssh, ssh.open_sftp() as sftp:
try:
# If the remote path is a directory, use the original filename.
@@ -210,7 +209,11 @@ class SshC:
except IOError:
path = self.path
finally:
- sftp.put(location, path, callback=callback)
+ if self.progressbar:
+ with Progressbar() as p:
+ sftp.put(location, path, callback=p.progress)
+ else:
+ sftp.put(location, path)
class HttpC:
@@ -264,10 +267,9 @@ class HttpC:
with s.get(final_urlstring, stream=True,
timeout=self.timeout) as r, open(location, 'wb') as f:
if self.progressbar and size:
- progress = make_incremental_progressbar(CHUNK_SIZE / size)
- next(progress)
- for chunk in iter(lambda: begin(next(progress), r.raw.read(CHUNK_SIZE)), b''):
- f.write(chunk)
+ with Progressbar(CHUNK_SIZE / size) as p:
+ for chunk in iter(lambda: begin(p.increment(), r.raw.read(CHUNK_SIZE)), b''):
+ f.write(chunk)
else:
# We'll try to stream the download directly with `copyfileobj()` so that large
# files (like entire VyOS images) don't occupy much memory.
diff --git a/python/vyos/utils/io.py b/python/vyos/utils/io.py
index 843494855..5fffa62f8 100644
--- a/python/vyos/utils/io.py
+++ b/python/vyos/utils/io.py
@@ -24,45 +24,6 @@ def print_error(str='', end='\n'):
sys.stderr.write(end)
sys.stderr.flush()
-def make_progressbar():
- """
- Make a procedure that takes two arguments `done` and `total` and prints a
- progressbar based on the ratio thereof, whose length is determined by the
- width of the terminal.
- """
- import shutil, math
- col, _ = shutil.get_terminal_size()
- col = max(col - 15, 20)
- def print_progressbar(done, total):
- if done <= total:
- increment = total / col
- length = math.ceil(done / increment)
- percentage = str(math.ceil(100 * done / total)).rjust(3)
- print_error(f'[{length * "#"}{(col - length) * "_"}] {percentage}%', '\r')
- # Print a newline so that the subsequent prints don't overwrite the full bar.
- if done == total:
- print_error()
- return print_progressbar
-
-def make_incremental_progressbar(increment: float):
- """
- Make a generator that displays a progressbar that grows monotonically with
- every iteration.
- First call displays it at 0% and every subsequent iteration displays it
- at `increment` increments where 0.0 < `increment` < 1.0.
- Intended for FTP and HTTP transfers with stateless callbacks.
- """
- print_progressbar = make_progressbar()
- total = 0.0
- while total < 1.0:
- print_progressbar(total, 1.0)
- yield
- total += increment
- print_progressbar(1, 1)
- # Ignore further calls.
- while True:
- yield
-
def ask_input(question, default='', numeric_only=False, valid_responses=[]):
question_out = question
if default: