#!/usr/bin/env python3
#
# Copyright (C) 2024 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
from shutil import rmtree

from sys import exit

from netifaces import AF_INET
from psutil import net_if_addrs

from vyos.config import Config
from vyos.configverify import verify_pki_ca_certificate
from vyos.configverify import verify_pki_certificate
from vyos.pki import encode_certificate
from vyos.pki import encode_private_key
from vyos.pki import find_chain
from vyos.pki import load_certificate
from vyos.pki import load_private_key
from vyos.utils.dict import dict_search
from vyos.utils.file import makedir
from vyos.utils.file import write_file
from vyos.utils.network import check_port_availability
from vyos.utils.network import is_listen_port_bind_service
from vyos.utils.process import call
from vyos.template import render
from vyos import ConfigError
from vyos import airbag
airbag.enable()

# Runtime directory that holds every file this script generates for stunnel.
stunnel_dir = '/run/stunnel'
# Main configuration file rendered from the Jinja2 template.
config_file = stunnel_dir + '/stunnel.conf'
# Per-service CA bundles are written below this directory.
stunnel_ca_dir = stunnel_dir + '/ca'
# Per-service PSK secrets files are written below this directory.
stunnel_psk_dir = stunnel_dir + '/psk'

# Configuration options follow the stunnel(8) manual page:
# http://man.he.net/man8/stunnel4


def get_config(config=None):
    """Read the ``service stunnel`` subtree into a dictionary.

    :param config: optional Config object to read from; a new Config() is
                   created when it is not given (or is falsy)
    :returns: None when ``service stunnel`` does not exist, otherwise the
              config dict (requested with recursive defaults and the PKI
              tree attached) plus a ``config_file`` key holding the path of
              the rendered stunnel configuration
    """
    conf = config or Config()

    base = ['service', 'stunnel']
    if not conf.exists(base):
        return None

    dict_options = dict(get_first_key=True,
                        key_mangling=('-', '_'),
                        no_tag_node_value_mangle=True,
                        with_recursive_defaults=True,
                        with_pki=True)
    stunnel = conf.get_config_dict(base, **dict_options)

    # Exposed so later stages (e.g. the template) can refer to the path.
    stunnel['config_file'] = config_file
    return stunnel


def verify(stunnel):
    """Validate the stunnel configuration dictionary.

    For every service under ``server`` and ``client`` this checks:

    * that both a ``listen`` and a ``connect`` port are present; a server
      using the ``socks`` protocol counts as an endpoint itself and must not
      have ``connect``;
    * that no listen address:port is claimed twice by stunnel services, and
      that it is not already bound by something other than stunnel;
    * protocol ``options`` (only allowed for ``connect`` and ``smtp``);
    * referenced PKI CA certificates / certificates, and that a server has
      either a certificate or PSK;
    * that every PSK entry has both ``id`` and ``secret``.

    Side effect: when ``options`` is present without ``authentication``, the
    protocol default ('basic' for connect, otherwise 'plain') is written into
    the dict so later stages see an explicit value.

    :param stunnel: dict returned by get_config(), or None
    :returns: None
    :raises ConfigError: on the first validation failure found
    """
    if not stunnel:
        return None

    # "address:port" strings already claimed by services checked earlier.
    stunnel_listen_addresses = list()
    for mode, conf in stunnel.items():
        if mode not in ['server', 'client']:
            continue

        for app, app_conf in conf.items():
            # connect, listen, exec and some protocols e.g. socks on server
            # mode are endpoints; exactly two are required.
            endpoints = 0
            if app_conf.get('protocol') == 'socks' and mode == 'server':
                if 'connect' in app_conf:
                    raise ConfigError("The 'connect' option cannot be used with the 'socks' protocol in server mode.")
                endpoints += 1

            for item in ['connect', 'listen']:
                if item in app_conf:
                    endpoints += 1
                    if 'port' not in app_conf[item]:
                        raise ConfigError(f'{mode} [{app}]: {item} port number is required!')
                elif item == 'listen':
                    raise ConfigError(f'{mode} [{app}]: {item} port number is required!')

            if endpoints != 2:
                raise ConfigError(f'{mode} [{app}]: connect port number is required!')

            # Without an explicit listen address, the collision check below
            # covers every IPv4 address currently assigned to a local interface.
            if 'address' in app_conf['listen']:
                laddresses = [dict_search('listen.address', app_conf)]
            else:
                laddresses = list()
                for iface_addresses in net_if_addrs().values():
                    for iface_addr in iface_addresses:
                        if iface_addr.family == AF_INET:
                            laddresses.append(iface_addr.address)

            lport = int(dict_search('listen.port', app_conf))

            for address in laddresses:
                if f'{address}:{lport}' in stunnel_listen_addresses:
                    raise ConfigError(
                        f'{mode} [{app}]: Address {address}:{lport} already '
                        f'in use by other stunnel service')

                stunnel_listen_addresses.append(f'{address}:{lport}')
                # A port held by the running stunnel itself is fine (it is
                # about to be reconfigured); anything else is a conflict.
                if not check_port_availability(address, lport, 'tcp') and \
                        not is_listen_port_bind_service(lport, 'stunnel'):
                    raise ConfigError(
                        f'{mode} [{app}]: Address {address}:{lport} already in use')

            if 'options' in app_conf:
                protocol = app_conf.get('protocol')
                if protocol not in ['connect', 'smtp']:
                    raise ConfigError("Additional option is only supported in the 'connect' and 'smtp' protocols.")
                if protocol == 'smtp' and ('domain' in app_conf['options'] or 'host' in app_conf['options']):
                    raise ConfigError("Protocol 'smtp' does not support options 'domain' and 'host'.")

                # set default authentication option
                if 'authentication' not in app_conf['options']:
                    app_conf['options']['authentication'] = 'basic' if protocol == 'connect' else 'plain'

                for option, option_config in app_conf['options'].items():
                    if option == 'authentication':
                        if protocol == 'connect' and option_config not in ['basic', 'ntlm']:
                            raise ConfigError("Supported authentication types for the 'connect' protocol are 'basic' or 'ntlm'")
                        elif protocol == 'smtp' and option_config not in ['plain', 'login']:
                            raise ConfigError("Supported authentication types for the 'smtp' protocol are 'plain' or 'login'")
                    if option == 'host':
                        if 'address' not in option_config:
                            raise ConfigError('Address is required for option host.')
                        if 'port' not in option_config:
                            raise ConfigError('Port is required for option host.')

            # check pki certs
            for key in ['ca_certificate', 'certificate']:
                tmp = dict_search(f'ssl.{key}', app_conf)
                if mode == 'server' and key != 'ca_certificate' and not tmp and 'psk' not in app_conf:
                    raise ConfigError(f'{mode} [{app}]: TLS server needs a certificate or PSK')
                if tmp:
                    if key == 'ca_certificate':
                        # ca_certificate is a list of CA names
                        for ca_cert in tmp:
                            verify_pki_ca_certificate(stunnel, ca_cert)
                    else:
                        verify_pki_certificate(stunnel, tmp)

            # check psk
            if 'psk' in app_conf:
                for psk, psk_conf in app_conf['psk'].items():
                    if 'id' not in psk_conf or 'secret' not in psk_conf:
                        # Prefix with the service like every other error here,
                        # so the user can tell which service is misconfigured.
                        raise ConfigError(
                            f'{mode} [{app}]: Authentication psk "{psk}" '
                            f'missing "id" or "secret"')


def generate(stunnel):
    """Write stunnel's runtime files and render its configuration.

    When stunnel is not configured, or has neither ``client`` nor ``server``
    services, the runtime directory is removed and nothing is generated.

    Otherwise, for every service this writes:

    * the certificate plus its chain (``<mode>-<app>-<cert>.pem``) and, when
      the PKI entry holds a private key, the key (``.pem.key``, mode 0600);
    * a bundle of the referenced CA certificates with their chains
      (``<mode>-<app>-ca.pem``) under ``stunnel_ca_dir``;
    * a PSK secrets file (one ``id:secret`` per line, mode 0600) under
      ``stunnel_psk_dir``.

    The generated paths are stored back into each service dict
    (``ssl.cert``, ``ssl.cert_key``, ``ssl.ca_path``, ``ssl.ca_file``,
    ``psk.file``) before the template is rendered with that dict. Files
    found in the runtime directory that were not regenerated are deleted.

    :param stunnel: dict returned by get_config() (already verified), or None
    :returns: None
    """
    if not stunnel or ('client' not in stunnel and 'server' not in stunnel):
        # Nothing to run: drop every previously generated file.
        if os.path.isdir(stunnel_dir):
            rmtree(stunnel_dir, ignore_errors=True)

        return None
    makedir(stunnel_dir)

    # Everything currently on disk, including the ca/ and psk/ subdirectories.
    # Whatever is not re-added to current_files below is stale and removed.
    exist_files = list()
    for root, _dirs, files in os.walk(stunnel_dir):
        for file in files:
            exist_files.append(os.path.join(root, file))

    # Keep the config file and its sibling pid file (stunnel.conf ->
    # stunnel.pid). The replacement must keep the dot; without it the name
    # becomes "stunnelpid" and a real "stunnel.pid" is deleted as stale.
    # NOTE(review): assumes the template sets stunnel's pid to this path —
    # confirm against stunnel/stunnel_config.j2.
    current_files = [config_file, config_file.replace('.conf', '.pid')]

    # All CA certificates from the PKI tree, used to complete chains.
    loaded_ca_certs = set()
    if 'pki' in stunnel and 'ca' in stunnel['pki']:
        loaded_ca_certs = {load_certificate(c['certificate'])
                           for c in stunnel['pki']['ca'].values()}

    for mode, conf in stunnel.items():
        if mode not in ['server', 'client']:
            continue

        for app, app_conf in conf.items():
            if 'ssl' in app_conf:
                if 'certificate' in app_conf['ssl']:
                    cert_name = app_conf['ssl']['certificate']

                    pki_cert = stunnel['pki']['certificate'][cert_name]
                    cert_file_path = os.path.join(stunnel_dir,
                                                  f'{mode}-{app}-{cert_name}.pem')
                    cert_key_path = os.path.join(stunnel_dir,
                                                 f'{mode}-{app}-{cert_name}.pem.key')
                    app_conf['ssl']['cert'] = cert_file_path

                    loaded_pki_cert = load_certificate(pki_cert['certificate'])
                    cert_full_chain = find_chain(loaded_pki_cert, loaded_ca_certs)

                    write_file(cert_file_path,
                               '\n'.join(encode_certificate(c) for c in cert_full_chain))
                    current_files.append(cert_file_path)

                    if 'private' in pki_cert and 'key' in pki_cert['private']:
                        app_conf['ssl']['cert_key'] = cert_key_path
                        loaded_key = load_private_key(pki_cert['private']['key'],
                                                      passphrase=None, wrap_tags=True)
                        key_pem = encode_private_key(loaded_key, passphrase=None)
                        write_file(cert_key_path, key_pem, mode=0o600)
                        current_files.append(cert_key_path)

                if 'ca_certificate' in app_conf['ssl']:
                    app_conf['ssl']['ca_path'] = stunnel_ca_dir
                    app_conf['ssl']['ca_file'] = f'{mode}-{app}-ca.pem'
                    ca_cert_file_path = os.path.join(stunnel_ca_dir, app_conf['ssl']['ca_file'])
                    ca_chains = []

                    for ca_name in app_conf['ssl']['ca_certificate']:
                        pki_ca_cert = stunnel['pki']['ca'][ca_name]
                        loaded_ca_cert = load_certificate(pki_ca_cert['certificate'])
                        ca_full_chain = find_chain(loaded_ca_cert, loaded_ca_certs)
                        ca_chains.append(
                            '\n'.join(encode_certificate(c) for c in ca_full_chain))

                    write_file(ca_cert_file_path, '\n'.join(ca_chains))
                    current_files.append(ca_cert_file_path)

            if 'psk' in app_conf:
                psk_file_path = os.path.join(stunnel_psk_dir, f'{mode}_{app}.txt')
                # Built before 'file' is added to the same dict below.
                psk_data = [f'{psk_conf["id"]}:{psk_conf["secret"]}'
                            for psk_conf in app_conf['psk'].values()]

                # The file holds shared secrets in clear text: restrict it to
                # the owner, exactly like the private key file above.
                write_file(psk_file_path, '\n'.join(psk_data), mode=0o600)
                app_conf['psk']['file'] = psk_file_path
                current_files.append(psk_file_path)

    for file in exist_files:
        if file not in current_files:
            os.unlink(file)

    render(config_file, 'stunnel/stunnel_config.j2', stunnel)


def apply(stunnel):
    """Bring the stunnel systemd unit in line with the configuration.

    The unit is restarted when at least one ``client`` or ``server`` service
    is configured, and stopped otherwise (including when stunnel is None).
    """
    has_services = bool(stunnel) and ('client' in stunnel or 'server' in stunnel)
    if has_services:
        command = 'systemctl restart stunnel.service'
    else:
        command = 'systemctl stop stunnel.service'
    call(command)


if __name__ == '__main__':
    # Standard VyOS config-script pipeline; any validation error is printed
    # and turned into a non-zero exit status.
    try:
        stunnel_config = get_config()
        for stage in (verify, generate, apply):
            stage(stunnel_config)
    except ConfigError as err:
        print(err)
        exit(1)