summaryrefslogtreecommitdiff
path: root/python/vyos/remote.py
blob: f683a6d5a29820f77b8f10285d40e819e099ab28 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright 2021 VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.  If not, see <http://www.gnu.org/licenses/>.

from ftplib import FTP
import os
import shutil
import socket
import sys
import tempfile
import urllib.parse
import urllib.request

from paramiko import SSHClient, SSHException, MissingHostKeyPolicy

from vyos.util import cmd, ask_yes_no
from vyos.version import get_version


# Per-user OpenSSH known-hosts file (~/.ssh/known_hosts). transfer_sftp() loads it,
# when it exists, on top of the system host keys.
known_hosts_file = os.path.expanduser('~/.ssh/known_hosts')

class InteractivePolicy(MissingHostKeyPolicy):
    """
    Host key policy that asks the user on the terminal whether an SSH
    connection to a host missing from the loaded known hosts may go ahead.

    Declining raises SSHException, which aborts the connection. Accepting
    optionally offers to store the key permanently, but only when the client
    has a host keys file it can write back to.
    """
    def missing_host_key(self, client, hostname, key):
        print(f"Host '{hostname}' not found in known hosts.")
        fingerprint = key.get_fingerprint().hex()
        print(f'Fingerprint: {fingerprint}')

        if not ask_yes_no('Do you wish to continue?'):
            raise SSHException(f"Cannot connect to unknown host '{hostname}'.")

        # NOTE: reaches into private SSHClient attributes to persist the key.
        keys_file = client._host_keys_filename
        if not keys_file:
            # Nowhere to save the key; proceed with this connection only.
            return
        if ask_yes_no('Do you wish to permanently add this host/key pair to known hosts?'):
            client._host_keys.add(hostname, key.get_name(), key)
            client.save_host_keys(keys_file)

## FTP routines
def transfer_ftp(mode, local_path, hostname, remote_path,
                 username='anonymous', password='', port=21, source=None):
    """
    Perform a single FTP operation against ``hostname``.

    Args:
        mode: 'upload' (copy local_path to remote_path), 'download'
            (copy remote_path to local_path) or 'size' (return the size of
            remote_path in bytes; local_path is unused).
        local_path: local file to read from or write to.
        hostname: FTP server to connect to.
        remote_path: path of the file on the server.
        username: login name (anonymous login by default).
        password: login password.
        port: TCP port of the FTP server.
        source: optional local source address for the outgoing connections,
            either an address string or a ready ``(host, port)`` tuple.

    Returns:
        The remote file size in bytes for mode 'size', otherwise None.

    Raises:
        ValueError: for an unknown mode, or when the server answers SIZE
            without a usable size.
    """
    if mode not in ('upload', 'download', 'size'):
        raise ValueError(f'Unknown FTP transfer mode: {mode}')

    # ftplib hands source_address to socket.create_connection(), which binds
    # to it and therefore needs a (host, port) tuple, not a bare address string.
    if isinstance(source, str):
        source_address = (source, 0) if source else None
    else:
        source_address = source

    with FTP(source_address=source_address) as conn:
        conn.connect(hostname, port)
        conn.login(username, password)
        if mode == 'upload':
            with open(local_path, 'rb') as file:
                conn.storbinary(f'STOR {remote_path}', file)
        elif mode == 'download':
            with open(local_path, 'wb') as file:
                conn.retrbinary(f'RETR {remote_path}', file.write)
        else:
            size = conn.size(remote_path)
            # Compare against None: an empty remote file legitimately reports 0.
            if size is None:
                # SIZE is an extension to the FTP specification, although it's extremely common.
                raise ValueError('Failed to receive file size from FTP server. '
                                 'Perhaps the server does not implement the SIZE command?')
            return size

def upload_ftp(*args, **kwargs):
    """
    Upload a local file to an FTP server.

    Arguments are forwarded unchanged to transfer_ftp() after the mode:
    (local_path, hostname, remote_path, username, password, port, source).
    """
    mode = 'upload'
    transfer_ftp(mode, *args, **kwargs)

def download_ftp(*args, **kwargs):
    """
    Download a file from an FTP server into a local file.

    Arguments are forwarded unchanged to transfer_ftp() after the mode:
    (local_path, hostname, remote_path, username, password, port, source).
    """
    mode = 'download'
    transfer_ftp(mode, *args, **kwargs)

def get_ftp_file_size(*args, **kwargs):
    """
    Return the size in bytes of a file on an FTP server.

    Arguments are forwarded to transfer_ftp() as
    (hostname, remote_path, username, password, port, source). No local file
    is involved, so None is passed in the local path position.
    """
    no_local_file = None
    return transfer_ftp('size', no_local_file, *args, **kwargs)

## SFTP/SCP routines
def transfer_sftp(mode, local_path, hostname, remote_path,
                  username=None, password=None, port=22, source=None):
    """
    Perform a single SFTP operation against ``hostname`` using paramiko.

    Host keys are loaded from the system store and, if it exists, from
    known_hosts_file; unknown hosts are handled by InteractivePolicy.

    Args:
        mode: 'upload' (copy local_path to remote_path), 'download'
            (copy remote_path to local_path) or 'size' (return the size of
            remote_path in bytes; local_path is unused).
        local_path: local file to read from or write to.
        hostname: SSH server to connect to.
        remote_path: path of the file on the server.
        username: login name passed to SSHClient.connect().
        password: password passed to SSHClient.connect().
        port: TCP port of the SSH server.
        source: optional local address to bind the outgoing TCP connection to.

    Returns:
        The remote file size in bytes for mode 'size', otherwise None.

    Raises:
        ValueError: for an unknown mode.
        SSHException: e.g. when the user refuses to connect to an unknown host.
    """
    if mode not in ('upload', 'download', 'size'):
        raise ValueError(f'Unknown SFTP transfer mode: {mode}')

    sock = None
    try:
        if source:
            # create_connection() resolves hostname for IPv4 and IPv6 alike,
            # binds to the source address and closes its own socket if the
            # connect fails, so nothing leaks on error.
            sock = socket.create_connection((hostname, port), source_address=(source, 0))
        with SSHClient() as ssh:
            ssh.load_system_host_keys()
            if os.path.exists(known_hosts_file):
                ssh.load_host_keys(known_hosts_file)
            ssh.set_missing_host_key_policy(InteractivePolicy())
            ssh.connect(hostname, port, username, password, sock=sock)
            with ssh.open_sftp() as sftp:
                if mode == 'upload':
                    sftp.put(local_path, remote_path)
                elif mode == 'download':
                    sftp.get(remote_path, local_path)
                else:
                    return sftp.stat(remote_path).st_size
    finally:
        if sock is not None:
            try:
                # shutdown() requires a 'how' argument. The socket may already be
                # closed (e.g. by the SSH client on exit), which is harmless here.
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            sock.close()

def upload_sftp(*args, **kwargs):
    """
    Upload a local file to an SFTP/SCP server.

    Arguments are forwarded unchanged to transfer_sftp() after the mode:
    (local_path, hostname, remote_path, username, password, port, source).
    """
    mode = 'upload'
    transfer_sftp(mode, *args, **kwargs)

def download_sftp(*args, **kwargs):
    """
    Download a file from an SFTP/SCP server into a local file.

    Arguments are forwarded unchanged to transfer_sftp() after the mode:
    (local_path, hostname, remote_path, username, password, port, source).
    """
    mode = 'download'
    transfer_sftp(mode, *args, **kwargs)

def get_sftp_file_size(*args, **kwargs):
    """
    Return the size in bytes of a file on an SFTP/SCP server.

    Arguments are forwarded to transfer_sftp() as
    (hostname, remote_path, username, password, port, source). No local file
    is involved, so None is passed in the local path position.
    """
    no_local_file = None
    return transfer_sftp('size', no_local_file, *args, **kwargs)

## TFTP routines
def upload_tftp(local_path, hostname, remote_path, port=69, source=None):
    """
    Upload ``local_path`` to ``tftp://hostname:port/remote_path`` using curl.

    The file contents are piped to curl on stdin (``-T -``). When ``source``
    is given it is passed to curl's ``--interface`` option.
    """
    source_option = f'--interface {source}' if source else ''
    with open(local_path, 'rb') as file:
        data = file.read()
    # NOTE(review): hostname, remote_path and source are interpolated unquoted
    # into the command line; confirm how cmd() executes it before feeding it
    # untrusted URLs.
    # The command's output is not needed, so it is not post-processed.
    cmd(f'curl {source_option} -s -T - tftp://{hostname}:{port}/{remote_path}',
        stderr=None, input=data)

def download_tftp(local_path, hostname, remote_path, port=69, source=None):
    """
    Download ``tftp://hostname:port/remote_path`` into ``local_path`` using curl.

    When ``source`` is given it is passed to curl's ``--interface`` option.
    """
    interface_opt = f'--interface {source}' if source else ''
    tftp_url = f'tftp://{hostname}:{port}/{remote_path}'
    with open(local_path, 'wb') as out:
        output = cmd(f'curl {interface_opt} -s {tftp_url}', stderr=None)
        # NOTE(review): cmd() appears to return text (it is .encode()d here),
        # so the payload round-trips through str; confirm binary files survive.
        out.write(output.encode())

# get_tftp_file_size() is unimplemented: core TFTP has no SIZE command, and the
# optional RFC 2349 'tsize' option (only available during a transfer) is not used here.

## HTTP(S) routines
def download_http(urlstring, local_path):
    """
    Download ``urlstring`` over HTTP(S) and save the response body to ``local_path``.

    A ``VyOS/<version>`` User-Agent header is sent. The body is streamed to
    disk in chunks instead of being read into memory at once. The local file
    is only opened (and truncated) once urlopen() has returned a response.
    """
    request = urllib.request.Request(urlstring, headers={'User-Agent': 'VyOS/' + get_version()})
    with urllib.request.urlopen(request) as response:
        with open(local_path, 'wb') as file:
            shutil.copyfileobj(response, file)

def get_http_file_size(urlstring):
    """
    Return the size in bytes of the resource at ``urlstring``.

    The size is taken from the response's Content-Length header; a
    ``VyOS/<version>`` User-Agent header is sent with the request.

    Raises:
        ValueError: when the server sends no Content-Length header.
    """
    headers = {'User-Agent': 'VyOS/' + get_version()}
    request = urllib.request.Request(urlstring, headers=headers)
    with urllib.request.urlopen(request) as response:
        content_length = response.getheader('Content-Length')
    if not content_length:
        # The server didn't send 'Content-Length' in the response headers.
        raise ValueError('Failed to receive file size from HTTP server.')
    return int(content_length)

# Dynamic dispatchers
def download(local_path, urlstring, source=None):
    """
    Dispatch the appropriate download function for the given URL and save to local path.

    Supported schemes: http, https, ftp, sftp, scp, tftp. An explicit port in
    the URL (e.g. ``ftp://host:2121/file``) is honoured; otherwise each
    protocol helper uses its default port.

    Args:
        local_path: destination file.
        urlstring: URL of the file to download.
        source: optional local source address; not supported for HTTP(S),
            where it is ignored with a warning on stderr.

    Raises:
        ValueError: for an unsupported scheme or a malformed port.
    """
    url = urllib.parse.urlparse(urlstring)
    if url.scheme == 'http' or url.scheme == 'https':
        if source:
            print('Warning: Custom source address not supported for HTTP connections.', file=sys.stderr)
        # The full URL, including any port, is handed to urllib unchanged.
        download_http(urlstring, local_path)
        return

    if url.scheme not in ('ftp', 'sftp', 'scp', 'tftp'):
        raise ValueError(f'Unsupported URL scheme: {url.scheme}')

    # urlparse() keeps the port apart from the hostname; forward it only when
    # the URL names one. (url.port raises ValueError for a malformed port.)
    options = {'source': source}
    if url.port is not None:
        options['port'] = url.port

    if url.scheme == 'ftp':
        username = url.username if url.username else 'anonymous'
        download_ftp(local_path, url.hostname, url.path, username, url.password, **options)
    elif url.scheme == 'tftp':
        download_tftp(local_path, url.hostname, url.path, **options)
    else:  # sftp / scp
        download_sftp(local_path, url.hostname, url.path, url.username, url.password, **options)

def upload(local_path, urlstring, source=None):
    """
    Dispatch the appropriate upload function for the given URL and upload from local path.

    Supported schemes: ftp, sftp, scp, tftp. An explicit port in the URL
    (e.g. ``sftp://user@host:2222/file``) is honoured; otherwise each
    protocol helper uses its default port.

    Args:
        local_path: file to upload.
        urlstring: destination URL.
        source: optional local source address for the connection.

    Raises:
        ValueError: for an unsupported scheme or a malformed port.
    """
    url = urllib.parse.urlparse(urlstring)
    if url.scheme not in ('ftp', 'sftp', 'scp', 'tftp'):
        raise ValueError(f'Unsupported URL scheme: {url.scheme}')

    # urlparse() keeps the port apart from the hostname; forward it only when
    # the URL names one. (url.port raises ValueError for a malformed port.)
    options = {'source': source}
    if url.port is not None:
        options['port'] = url.port

    if url.scheme == 'ftp':
        username = url.username if url.username else 'anonymous'
        upload_ftp(local_path, url.hostname, url.path, username, url.password, **options)
    elif url.scheme == 'tftp':
        upload_tftp(local_path, url.hostname, url.path, **options)
    else:  # sftp / scp
        upload_sftp(local_path, url.hostname, url.path, url.username, url.password, **options)

def get_remote_file_size(urlstring, source=None):
    """
    Return the size of the remote file in bytes.

    Supported schemes: http, https, ftp, sftp, scp (TFTP cannot report sizes).
    For ftp/sftp/scp an explicit port in the URL is honoured; otherwise the
    protocol default is used. For HTTP(S) ``source`` is not used.

    Raises:
        ValueError: for an unsupported scheme, a malformed port, or when the
            server does not report a size.
    """
    url = urllib.parse.urlparse(urlstring)
    if url.scheme == 'http' or url.scheme == 'https':
        return get_http_file_size(urlstring)

    if url.scheme not in ('ftp', 'sftp', 'scp'):
        raise ValueError(f'Unsupported URL scheme: {url.scheme}')

    # urlparse() keeps the port apart from the hostname; forward it only when
    # the URL names one. (url.port raises ValueError for a malformed port.)
    options = {'source': source}
    if url.port is not None:
        options['port'] = url.port

    if url.scheme == 'ftp':
        username = url.username if url.username else 'anonymous'
        return get_ftp_file_size(url.hostname, url.path, username, url.password, **options)
    # sftp / scp
    return get_sftp_file_size(url.hostname, url.path, url.username, url.password, **options)

def get_remote_config(urlstring, source=None):
    """
    Download remote (config) file and return the contents.

    Args:
        urlstring: URL of the remote file, one of:
            scp://<user>[:<passwd>]@<host>/<file>
            sftp://<user>[:<passwd>]@<host>/<file>
            http://<host>/<file>
            https://<host>/<file>
            ftp://[<user>[:<passwd>]@]<host>/<file>
            tftp://<host>/<file>
        source: optional local source address, passed on to download().

    Returns:
        The file contents as a str.
    """
    # Only the temporary file's name is needed: close our handle immediately
    # (the downloader reopens the path itself) so no descriptor is leaked.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        temp = tmp.name
    try:
        download(temp, urlstring, source)
        with open(temp, 'r') as file:
            return file.read()
    finally:
        os.remove(temp)