diff options
Diffstat (limited to 'python/vyos/remote.py')
-rw-r--r-- | python/vyos/remote.py | 53 |
1 files changed, 39 insertions, 14 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py index 80a4e7528..e972050b7 100644 --- a/python/vyos/remote.py +++ b/python/vyos/remote.py @@ -23,12 +23,15 @@ import tempfile import urllib.parse import urllib.request as urlreq +from vyos.template import get_ip +from vyos.template import ip_from_cidr +from vyos.template import is_interface +from vyos.template import is_ipv6 from vyos.util import cmd from vyos.util import ask_yes_no from vyos.util import print_error from vyos.util import make_progressbar from vyos.util import make_incremental_progressbar -from vyos.template import is_ipv6 from vyos.version import get_version from paramiko import SSHClient from paramiko import SSHException @@ -66,9 +69,24 @@ def get_authentication_variables(default_username=None, default_password=None): # Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`, # as we want the username and the password to have the same behaviour. if not username: - return (default_username, default_password) + return default_username, default_password + else: + return username, password + +def get_source_address(source): + """ + Take a string vaguely indicating an origin source (interface, hostname or IP address), + return a tuple in the format `(source_pair, address_family)` where + `source_pair` is `(source_address, source_port)`. + """ + # TODO: Properly distinguish between IPv4 and IPv6. + port = 0 + if is_interface(source): + source = ip_from_cidr(get_ip(source)[0]) + if is_ipv6(source): + return (source, port), socket.AF_INET6 else: - return (username, password) + return (socket.gethostbyname(source), port), socket.AF_INET def get_port_from_url(url): """ @@ -86,9 +104,9 @@ def get_port_from_url(url): ## FTP routines def upload_ftp(local_path, hostname, remote_path,\ username='anonymous', password='', port=21,\ - source=None, progressbar=False): + source_pair=None, progressbar=False): size = os.path.getsize(local_path) - with FTP(source_address=source) as conn: + with FTP(source_address=source_pair) as conn: conn.connect(hostname, port) conn.login(username, password) with open(local_path, 'rb') as file: @@ -102,8 +120,8 @@ def upload_ftp(local_path, hostname, remote_path,\ def download_ftp(local_path, hostname, remote_path,\ username='anonymous', password='', port=21,\ - source=None, progressbar=False): - with FTP(source_address=source) as conn: + source_pair=None, progressbar=False): + with FTP(source_address=source_pair) as conn: conn.connect(hostname, port) conn.login(username, password) size = conn.size(remote_path) @@ -119,7 +137,7 @@ def download_ftp(local_path, hostname, remote_path,\ def get_ftp_file_size(hostname, remote_path,\ username='anonymous', password='', port=21,\ - source=None): + source_pair=None): with FTP(source_address=source) as conn: conn.connect(hostname, port) conn.login(username, password) @@ -135,13 +153,12 @@ def get_ftp_file_size(hostname, remote_path,\ ## SFTP/SCP routines def transfer_sftp(mode, local_path, hostname, remote_path,\ username=None, password=None, port=22,\ - source=None, progressbar=False): + source_tuple=None, progressbar=False): sock = None - if source: - # Check if the given string is an IPv6 address. - address_family = socket.AF_INET6 if is_ipv6(source) else socket.AF_INET + if source_tuple: + (source_address, source_port), address_family = source_tuple sock = socket.socket(address_family, socket.SOCK_STREAM) - sock.bind((source, 0)) + sock.bind((source_address, source_port)) sock.connect((hostname, port)) callback = make_progressbar() if progressbar else None with SSHClient() as ssh: @@ -254,7 +271,7 @@ def get_http_file_size(urlstring, username=None, password=None): def download(local_path, urlstring, source=None, progressbar=False): """ Dispatch the appropriate download function for the given `urlstring` and save to `local_path`. - Optionally takes a `source` address (not valid for HTTP(S)). + Optionally takes a `source` address or interface (not valid for HTTP(S)). Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP. Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables. """ @@ -267,9 +284,11 @@ def download(local_path, urlstring, source=None, progressbar=False): print_error('Warning: Custom source address not supported for HTTP connections.') download_http(local_path, urlstring, username, password, progressbar) elif url.scheme == 'ftp': + source = get_source_address(source)[0] if source else None username = username if username else 'anonymous' download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': + source = get_source_address(source) if source else None download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': download_tftp(local_path, url.hostname, url.path, port, source, progressbar) @@ -289,8 +308,10 @@ def upload(local_path, urlstring, source=None, progressbar=False): if url.scheme == 'ftp': username = username if username else 'anonymous' + source = get_source_address(source)[0] if source else None upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'sftp' or url.scheme == 'scp': + source = get_source_address(source) if source else None upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar) elif url.scheme == 'tftp': upload_tftp(local_path, url.hostname, url.path, port, source, progressbar) @@ -310,11 +331,15 @@ def get_remote_file_size(urlstring, source=None): port = get_port_from_url(url) if url.scheme == 'http' or url.scheme == 'https': + if source: + print_error('Warning: Custom source address not supported for HTTP connections.') return get_http_file_size(urlstring, username, password) elif url.scheme == 'ftp': + source = get_source_address(source)[0] if source else None username = username if username else 'anonymous' return get_ftp_file_size(url.hostname, url.path, username, password, port, source) elif url.scheme == 'sftp' or url.scheme == 'scp': + source = get_source_address(source) if source else None return get_sftp_file_size(url.hostname, url.path, username, password, port, source) else: raise ValueError(f'Unsupported URL scheme: {url.scheme}') |