summaryrefslogtreecommitdiff
path: root/python/vyos/remote.py
diff options
context:
space:
mode:
authorerkin <e.altunbas@vyos.io>2021-06-20 19:45:45 +0300
committererkin <e.altunbas@vyos.io>2021-06-20 19:45:45 +0300
commit58300075021e65fcf3e61c9ab4dee2c33454e10f (patch)
tree3acf086c509e7eabfb062b2d1cba6dea53ae3f7d /python/vyos/remote.py
parentbe167b110dabb1f7f6db7351d828bba1e54e358a (diff)
downloadvyos-1x-58300075021e65fcf3e61c9ab4dee2c33454e10f.tar.gz
vyos-1x-58300075021e65fcf3e61c9ab4dee2c33454e10f.zip
T3268: remote: Determine source address from given network interface
Diffstat (limited to 'python/vyos/remote.py')
-rw-r--r--python/vyos/remote.py53
1 files changed, 39 insertions, 14 deletions
diff --git a/python/vyos/remote.py b/python/vyos/remote.py
index 80a4e7528..e972050b7 100644
--- a/python/vyos/remote.py
+++ b/python/vyos/remote.py
@@ -23,12 +23,15 @@ import tempfile
import urllib.parse
import urllib.request as urlreq
+from vyos.template import get_ip
+from vyos.template import ip_from_cidr
+from vyos.template import is_interface
+from vyos.template import is_ipv6
from vyos.util import cmd
from vyos.util import ask_yes_no
from vyos.util import print_error
from vyos.util import make_progressbar
from vyos.util import make_incremental_progressbar
-from vyos.template import is_ipv6
from vyos.version import get_version
from paramiko import SSHClient
from paramiko import SSHException
@@ -66,9 +69,24 @@ def get_authentication_variables(default_username=None, default_password=None):
# Note that this is different from `os.getenv('REMOTE_USERNAME', default=default_username)`,
# as we want the username and the password to have the same behaviour.
if not username:
- return (default_username, default_password)
+ return default_username, default_password
+ else:
+ return username, password
+
+def get_source_address(source):
+ """
+ Take a string vaguely indicating an origin source (interface, hostname or IP address),
+ return a tuple in the format `(source_pair, address_family)` where
+ `source_pair` is `(source_address, source_port)`.
+ """
+ # TODO: Properly distinguish between IPv4 and IPv6.
+ port = 0
+ if is_interface(source):
+ source = ip_from_cidr(get_ip(source)[0])
+ if is_ipv6(source):
+ return (source, port), socket.AF_INET6
else:
- return (username, password)
+ return (socket.gethostbyname(source), port), socket.AF_INET
def get_port_from_url(url):
"""
@@ -86,9 +104,9 @@ def get_port_from_url(url):
## FTP routines
def upload_ftp(local_path, hostname, remote_path,\
username='anonymous', password='', port=21,\
- source=None, progressbar=False):
+ source_pair=None, progressbar=False):
size = os.path.getsize(local_path)
- with FTP(source_address=source) as conn:
+ with FTP(source_address=source_pair) as conn:
conn.connect(hostname, port)
conn.login(username, password)
with open(local_path, 'rb') as file:
@@ -102,8 +120,8 @@ def upload_ftp(local_path, hostname, remote_path,\
def download_ftp(local_path, hostname, remote_path,\
username='anonymous', password='', port=21,\
- source=None, progressbar=False):
- with FTP(source_address=source) as conn:
+ source_pair=None, progressbar=False):
+ with FTP(source_address=source_pair) as conn:
conn.connect(hostname, port)
conn.login(username, password)
size = conn.size(remote_path)
@@ -119,7 +137,7 @@ def download_ftp(local_path, hostname, remote_path,\
def get_ftp_file_size(hostname, remote_path,\
username='anonymous', password='', port=21,\
- source=None):
+ source_pair=None):
with FTP(source_address=source) as conn:
conn.connect(hostname, port)
conn.login(username, password)
@@ -135,13 +153,12 @@ def get_ftp_file_size(hostname, remote_path,\
## SFTP/SCP routines
def transfer_sftp(mode, local_path, hostname, remote_path,\
username=None, password=None, port=22,\
- source=None, progressbar=False):
+ source_tuple=None, progressbar=False):
sock = None
- if source:
- # Check if the given string is an IPv6 address.
- address_family = socket.AF_INET6 if is_ipv6(source) else socket.AF_INET
+ if source_tuple:
+ (source_address, source_port), address_family = source_tuple
sock = socket.socket(address_family, socket.SOCK_STREAM)
- sock.bind((source, 0))
+ sock.bind((source_address, source_port))
sock.connect((hostname, port))
callback = make_progressbar() if progressbar else None
with SSHClient() as ssh:
@@ -254,7 +271,7 @@ def get_http_file_size(urlstring, username=None, password=None):
def download(local_path, urlstring, source=None, progressbar=False):
"""
Dispatch the appropriate download function for the given `urlstring` and save to `local_path`.
- Optionally takes a `source` address (not valid for HTTP(S)).
+ Optionally takes a `source` address or interface (not valid for HTTP(S)).
Supports HTTP, HTTPS, FTP, SFTP, SCP (through SFTP) and TFTP.
Reads `$REMOTE_USERNAME` and `$REMOTE_PASSWORD` environment variables.
"""
@@ -267,9 +284,11 @@ def download(local_path, urlstring, source=None, progressbar=False):
print_error('Warning: Custom source address not supported for HTTP connections.')
download_http(local_path, urlstring, username, password, progressbar)
elif url.scheme == 'ftp':
+ source = get_source_address(source)[0] if source else None
username = username if username else 'anonymous'
download_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
+ source = get_source_address(source) if source else None
download_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
download_tftp(local_path, url.hostname, url.path, port, source, progressbar)
@@ -289,8 +308,10 @@ def upload(local_path, urlstring, source=None, progressbar=False):
if url.scheme == 'ftp':
username = username if username else 'anonymous'
+ source = get_source_address(source)[0] if source else None
upload_ftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'sftp' or url.scheme == 'scp':
+ source = get_source_address(source) if source else None
upload_sftp(local_path, url.hostname, url.path, username, password, port, source, progressbar)
elif url.scheme == 'tftp':
upload_tftp(local_path, url.hostname, url.path, port, source, progressbar)
@@ -310,11 +331,15 @@ def get_remote_file_size(urlstring, source=None):
port = get_port_from_url(url)
if url.scheme == 'http' or url.scheme == 'https':
+ if source:
+ print_error('Warning: Custom source address not supported for HTTP connections.')
return get_http_file_size(urlstring, username, password)
elif url.scheme == 'ftp':
+ source = get_source_address(source)[0] if source else None
username = username if username else 'anonymous'
return get_ftp_file_size(url.hostname, url.path, username, password, port, source)
elif url.scheme == 'sftp' or url.scheme == 'scp':
+ source = get_source_address(source) if source else None
return get_sftp_file_size(url.hostname, url.path, username, password, port, source)
else:
raise ValueError(f'Unsupported URL scheme: {url.scheme}')