summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/vyos/remote.py138
1 files changed, 93 insertions, 45 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 3f24d4b33..ebbded67a 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -22,57 +22,74 @@ import urllib.parse
import urllib.request
from vyos.util import cmd
+from vyos.version import get_version
from paramiko import SSHClient
-def upload_ftp(local_path, hostname, remote_path,\
- username='anonymous', password='', port=21, source=None):
- with open(local_path, 'rb') as file:
- with FTP(source_address=source) as conn:
- conn.connect(hostname, port)
- conn.login(username, password)
- conn.storbinary(f'STOR {remote_path}', file)
-
-def download_ftp(local_path, hostname, remote_path,\
+## FTP routines
+def transfer_ftp(mode, local_path, hostname, remote_path,\
username='anonymous', password='', port=21, source=None):
- with open(local_path, 'wb') as file:
- with FTP(source_address=source) as conn:
- conn.connect(hostname, port)
- conn.login(username, password)
- conn.retrbinary(f'RETR {remote_path}', file.write)
+ with FTP(source_address=source) as conn:
+ conn.connect(hostname, port)
+ conn.login(username, password)
+ if mode == 'upload':
+ with open(local_path, 'rb') as file:
+ conn.storbinary(f'STOR {remote_path}', file)
+ elif mode == 'download':
+ with open(local_path, 'wb') as file:
+ conn.retrbinary(f'RETR {remote_path}', file.write)
+ elif mode == 'size':
+ size = conn.size(remote_path)
+ if size:
+ return size
+ else:
+ # SIZE is an extension to the FTP specification, although it's extremely common.
+ raise ValueError('Failed to receive file size from FTP server. \
+ Perhaps the server does not implement the SIZE command?')
-def upload_sftp(local_path, hostname, remote_path,\
- username=None, password=None, port=22, source=None):
- sock = None
- if source:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind((source, 0))
- sock.connect((hostname, port))
- with SSHClient() as ssh:
- ssh.load_system_host_keys()
- ssh.connect(hostname, port, username, password, sock=sock)
- with ssh.open_sftp() as sftp:
- sftp.put(local_path, remote_path)
- if sock:
- sock.shutdown()
- sock.close()
-
-def download_sftp(local_path, hostname, remote_path,\
+def upload_ftp(*args, **kwargs):
+ transfer_ftp('upload', *args, **kwargs)
+
+def download_ftp(*args, **kwargs):
+ transfer_ftp('download', *args, **kwargs)
+
+def get_ftp_file_size(*args, **kwargs):
+ return transfer_ftp('size', None, *args, **kwargs)
+
+## SFTP/SCP routines
+def transfer_sftp(mode, local_path, hostname, remote_path,\
username=None, password=None, port=22, source=None):
sock = None
if source:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((source, 0))
sock.connect((hostname, port))
- with SSHClient() as ssh:
- ssh.load_system_host_keys()
- ssh.connect(hostname, port, username, password, sock=sock)
- with ssh.open_sftp() as sftp:
- sftp.get(remote_path, local_path)
- if sock:
- sock.shutdown()
- sock.close()
+ try:
+ with SSHClient() as ssh:
+ ssh.load_system_host_keys()
+ ssh.connect(hostname, port, username, password, sock=sock)
+ with ssh.open_sftp() as sftp:
+ if mode == 'upload':
+ sftp.put(local_path, remote_path)
+ elif mode == 'download':
+ sftp.get(remote_path, local_path)
+ elif mode == 'size':
+ return sftp.stat(remote_path).st_size
+ finally:
+ if sock:
+ sock.shutdown()
+ sock.close()
+
+def upload_sftp(*args, **kwargs):
+ transfer_sftp('upload', *args, **kwargs)
+def download_sftp(*args, **kwargs):
+ transfer_sftp('download', *args, **kwargs)
+
+def get_sftp_file_size(*args, **kwargs):
+ return transfer_sftp('size', None, *args, **kwargs)
+
+## TFTP routines
def upload_tftp(local_path, hostname, remote_path, port=69, source=None):
source_option = f'--interface {source}' if source else ''
with open(local_path, 'rb') as file:
@@ -85,11 +102,27 @@ def download_tftp(local_path, hostname, remote_path, port=69, source=None):
file.write(cmd(f'curl {source_option} -s tftp://{hostname}:{port}/{remote_path}',\
stderr=None).encode())
+# get_tftp_file_size() is unimplemented because there is no way to obtain a file's size through TFTP,
+# as TFTP does not specify a SIZE command.
+
+## HTTP(S) routines
def download_http(urlstring, local_path):
+ request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
with open(local_path, 'wb') as file:
- with urllib.request.urlopen(urlstring) as response:
+ with urllib.request.urlopen(request) as response:
file.write(response.read())
+def get_http_file_size(urlstring):
+ request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
+ with urllib.request.urlopen(request) as response:
+ size = response.getheader('Content-Length')
+ if size:
+ return int(size)
+ # The server didn't send 'Content-Length' in the response headers.
+ else:
+ raise ValueError('Failed to receive file size from HTTP server.')
+
+# Dynamic dispatchers
def download(local_path, urlstring, source=None):
"""
Dispatch the appropriate download function for the given URL and save to local path.
@@ -97,7 +130,7 @@ def download(local_path, urlstring, source=None):
url = urllib.parse.urlparse(urlstring)
if url.scheme == 'http' or url.scheme == 'https':
if source:
- print("Warning: Custom source address not supported for HTTP connections.", file=sys.stderr)
+ print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr)
download_http(urlstring, local_path)
elif url.scheme == 'ftp':
username = url.username if url.username else 'anonymous'
@@ -107,7 +140,7 @@ def download(local_path, urlstring, source=None):
elif url.scheme == 'tftp':
download_tftp(local_path, url.hostname, url.path, source=source)
else:
- ValueError(f'Unsupported URL scheme: {url.scheme}')
+ raise ValueError(f'Unsupported URL scheme: {url.scheme}')
def upload(local_path, urlstring, source=None):
"""
@@ -122,9 +155,24 @@ def upload(local_path, urlstring, source=None):
elif url.scheme == 'tftp':
upload_tftp(local_path, url.hostname, url.path, source=source)
else:
- ValueError(f'Unsupported URL scheme: {url.scheme}')
+ raise ValueError(f'Unsupported URL scheme: {url.scheme}')
+
+def get_remote_file_size(urlstring, source=None):
+ """
+ Return the size of the remote file in bytes.
+ """
+ url = urllib.parse.urlparse(urlstring)
+ if url.scheme == 'http' or url.scheme == 'https':
+ return get_http_file_size(urlstring)
+ elif url.scheme == 'ftp':
+ username = url.username if url.username else 'anonymous'
+ return get_ftp_file_size(url.hostname, url.path, username, url.password, source=source)
+ elif url.scheme == 'sftp' or url.scheme == 'scp':
+ return get_sftp_file_size(url.hostname, url.path, url.username, url.password, source=source)
+ else:
+ raise ValueError(f'Unsupported URL scheme: {url.scheme}')
-def get_remote_config(urlstring):
+def get_remote_config(urlstring, source=None):
"""
Download remote (config) file and return the contents.
Args:
@@ -139,7 +187,7 @@ def get_remote_config(urlstring):
url = urllib.parse.urlparse(urlstring)
temp = tempfile.NamedTemporaryFile(delete=False).name
try:
- download(temp, urlstring)
+ download(temp, urlstring, source)
with open(temp, 'r') as file:
return file.read()
finally: