#!/usr/bin/env python3
#
# Copyright (C) 2019 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#

import os
import sys
import time
import json
import signal
import traceback
import re
import logging
import zmq

import jinja2

# Master verbosity switch: DEBUG-level logging when True, INFO otherwise.
debug = True

# Module logger, emitting through a single stream handler (stderr by default).
logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)


# The persisted daemon state lives in DATA_DIR/hostsd.state.
DATA_DIR = "/var/lib/vyos/"
STATE_FILE = os.path.join(DATA_DIR, "hostsd.state")

# ZeroMQ endpoint on which the daemon accepts change requests.
SOCKET_PATH = "ipc:///run/vyos-hostsd.sock"

# Files regenerated from the state after each state-changing request.
RESOLV_CONF_FILE = '/etc/resolv.conf'
HOSTS_FILE = '/etc/hosts'

# Jinja2 template for the hosts file.
# It is rendered with the daemon state and uses these keys:
#   - host_name / domain_name: build the 127.0.1.1 line. When domain_name is
#     set, the FQDN is listed first and the short host name follows as an
#     alias.
#   - hosts: a dict mapping each host name to a dict holding an 'address' and
#     a list of 'aliases'. It fills the DHCP / static-mapping section.
# NOTE: this is not a raw string, so Python turns the two \t sequences in the
# hosts loop into literal tab characters before Jinja2 sees them.
hosts_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###

# Local host
127.0.0.1       localhost
127.0.1.1       {{ host_name }}{% if domain_name %}.{{ domain_name }} {{ host_name }}{% endif %}

# The following lines are desirable for IPv6 capable hosts
::1             localhost ip6-localhost ip6-loopback
fe00::0         ip6-localnet
ff00::0         ip6-mcastprefix
ff02::1         ip6-allnodes
ff02::2         ip6-allrouters

# From DHCP and "system static host-mapping"
{%- if hosts %}
{% for h in hosts -%}
{{hosts[h]['address']}}\t{{h}}\t{% for a in hosts[h]['aliases'] %} {{a}} {% endfor %}
{% endfor %}
{%- endif %}
"""

# Compiled once at import time and reused for every render.
hosts_tmpl = jinja2.Template(hosts_tmpl_source)

# Jinja2 template for resolv.conf.
# It is rendered with the daemon state and emits, in order:
#   - one "nameserver" line per key of 'name_servers' (iterating a dict yields
#     its keys, i.e. the server addresses);
#   - an optional "domain" line when 'domain_name' is non-empty;
#   - an optional "search" line joining 'search_domains' with spaces.
# The "-" whitespace-control markers determine the exact blank-line layout of
# the output, so adjust them with care.
resolv_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###

{% for ns in name_servers -%}
nameserver {{ns}}
{% endfor -%}

{%- if domain_name %}
domain {{ domain_name }}
{%- endif %}

{%- if search_domains %}
search {{ search_domains | join(" ") }}
{%- endif %}

"""

# Compiled once at import time and reused for every render.
resolv_tmpl = jinja2.Template(resolv_tmpl_source)

# In-memory daemon state. It is saved to STATE_FILE and rendered into the
# resolv.conf and hosts files. Layout:
#
#   name_servers   -- {<server>: {"tag": <str>}}
#   hosts          -- {<host>: {"tag": <str>, "address": <str>,
#                               "aliases": <list of str>}}
#   host_name      -- local host name (default "vyos")
#   domain_name    -- local domain name, "" when unset
#   search_domains -- list of resolver search domains
#
# A tag records which producer created an entry (for example "static" or
# "dhcp-<intf>"). Each producer can then drop and re-add all of its own
# entries without tracking individual changes.
STATE = dict(
    name_servers={},
    hosts={},
    host_name="vyos",
    domain_name="",
    search_domains=[],
)


def make_resolv_conf(data):
    """Render ``data`` through the resolv.conf template and write the result
    to RESOLV_CONF_FILE, replacing the file's previous contents.

    Args:
        data: mapping supplying the template variables (the daemon state).
    """
    resolv_conf = resolv_tmpl.render(data)
    # The message is built from the constant so it always names the file
    # that is actually written.
    logger.info("Writing {0}".format(RESOLV_CONF_FILE))
    with open(RESOLV_CONF_FILE, 'w') as f:
        f.write(resolv_conf)

def make_hosts_file(state):
    """Render ``state`` through the hosts template and write the result to
    HOSTS_FILE, replacing the file's previous contents.

    Args:
        state: mapping supplying the template variables (the daemon state).
    """
    # Render first, like make_resolv_conf, so "Writing" is only logged once
    # the content actually exists.
    hosts = hosts_tmpl.render(state)
    # The message is built from the constant so it always names the file
    # that is actually written.
    logger.info("Writing {0}".format(HOSTS_FILE))
    with open(HOSTS_FILE, 'w') as f:
        f.write(hosts)

def add_hosts(data, entries, tag):
    """Add (or overwrite) host entries in ``data['hosts']`` under ``tag``.

    Args:
        data: the daemon state dict; its 'hosts' dict is updated in place.
        entries: iterable of dicts, each with a 'host' name, an 'address'
            and an optional list of 'aliases' (defaults to no aliases).
            A falsy value is a no-op.
        tag: tag stored on every added entry; delete_hosts removes entries
            by this tag.

    Raises:
        ValueError: if an entry is not an object or lacks 'host' or
            'address'. All entries are validated before any is stored, so
            a bad request leaves ``data`` untouched.
    """
    hosts = data['hosts']

    if not entries:
        return

    new_hosts = {}
    for e in entries:
        if not isinstance(e, dict):
            raise ValueError("Host entry must be an object: {0!r}".format(e))
        for field in ('host', 'address'):
            if field not in e:
                raise ValueError(
                    "Missing required host entry field \"{0}\"".format(field))
        new_hosts[e['host']] = {
            'tag': tag,
            'address': e['address'],
            # Aliases are optional; an absent list simply renders none.
            'aliases': e.get('aliases', []),
        }

    hosts.update(new_hosts)

def delete_hosts(data, tag):
    """Remove every entry in ``data['hosts']`` whose tag equals ``tag``."""
    host_table = data['hosts']

    # Collect the matching names first: a dict must not change size while
    # it is being iterated.
    doomed = [name for name, entry in host_table.items() if entry['tag'] == tag]
    for name in doomed:
        del host_table[name]

def add_name_servers(data, entries, tag):
    """Record each server in ``entries`` in ``data['name_servers']`` with
    ``tag``, overwriting any existing tag for that server.

    A falsy ``entries`` leaves the state unchanged.
    """
    servers = data['name_servers']
    for server in entries or ():
        servers[server] = {'tag': tag}

def delete_name_servers(data, tag):
    """Remove every name server whose tag fully matches the regex ``tag``.

    The pattern must match the whole stored tag. So "dhcp-eth1" removes only
    servers tagged "dhcp-eth1" (not "dhcp-eth10"), while a pattern such as
    "dhcp-.+" removes all DHCP-provided servers.

    Raises:
        ValueError: if ``tag`` is not a valid regular expression.
    """
    name_servers = data['name_servers']
    try:
        regex_filter = re.compile(tag)
    except (re.error, TypeError) as err:
        # Report a bad pattern as a client error, not "Internal error".
        raise ValueError(
            "Invalid tag pattern \"{0}\": {1}".format(tag, err)) from None
    # Iterate over a snapshot of the keys, because entries are deleted in the loop.
    for ns in list(name_servers):
        if regex_filter.fullmatch(name_servers[ns]['tag']):
            del name_servers[ns]

def set_host_name(state, data):
    """Update the host name, domain name and search domains in ``state``.

    Only settings present in ``data`` are applied. An empty or missing
    'host_name' leaves the current host name in place. 'domain_name' and
    'search_domains' are copied whenever present, so sending empty values
    clears them.
    """
    # Use .get(): a request that only changes the domain/search settings need
    # not carry a host_name. Indexing directly would raise KeyError, which
    # the client would see as "Internal error".
    if data.get('host_name'):
        state['host_name'] = data['host_name']
    if 'domain_name' in data:
        state['domain_name'] = data['domain_name']
    if 'search_domains' in data:
        state['search_domains'] = data['search_domains']

def get_name_servers(state, tag):
    """Return the name servers whose tag fully matches the regex ``tag``.

    The pattern must match the whole stored tag, the same rule that
    delete_name_servers applies. So "dhcp-eth1" does not also return servers
    tagged "dhcp-eth10".

    Returns:
        List of server keys from ``state['name_servers']``, in dict order.

    Raises:
        ValueError: if ``tag`` is not a valid regular expression.
    """
    name_servers = state['name_servers']
    try:
        regex_filter = re.compile(tag)
    except (re.error, TypeError) as err:
        # Report a bad pattern as a client error, not "Internal error".
        raise ValueError(
            "Invalid tag pattern \"{0}\": {1}".format(tag, err)) from None
    return [ns for ns in name_servers
            if regex_filter.fullmatch(name_servers[ns]['tag'])]

def get_option(msg, key):
    """Return ``msg[key]``, insisting that the option is present.

    Raises:
        ValueError: naming the missing option when ``key`` is absent. The
            main loop sends this text back to the client as the error.
    """
    if key not in msg:
        raise ValueError("Missing required option \"{0}\"".format(key))
    return msg[key]

def handle_message(msg_json):
    """Apply one JSON-encoded client request to STATE.

    The request is a JSON object with an "op" ("add", "delete", "set" or
    "get") and a "type". Depending on the op, it also needs a "tag" and/or
    a "data" member. After any state-changing op, the resolv.conf and hosts
    files are regenerated and STATE is saved to STATE_FILE.

    Returns:
        For "get": the list of name servers whose tag matches "tag".
        For every other op: None.

    Raises:
        ValueError: for malformed JSON, a message that is not a JSON object,
            a missing required option, or an unknown op/type. The main loop
            sends the ValueError text back to the client.
    """
    msg = json.loads(msg_json)
    # json.loads also accepts lists, strings and numbers. Reject them here
    # so they become a client error instead of a TypeError deeper down (or,
    # for a JSON string, a bogus substring test in get_option).
    if not isinstance(msg, dict):
        raise ValueError("Message must be a JSON object")

    op = get_option(msg, 'op')
    _type = get_option(msg, 'type')

    if op == 'delete':
        tag = get_option(msg, 'tag')

        if _type == 'name_servers':
            delete_name_servers(STATE, tag)
        elif _type == 'hosts':
            delete_hosts(STATE, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'add':
        tag = get_option(msg, 'tag')
        entries = get_option(msg, 'data')
        if _type == 'name_servers':
            add_name_servers(STATE, entries, tag)
        elif _type == 'hosts':
            add_hosts(STATE, entries, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'set':
        # Host name/domain name/search domain are set without a tag,
        # there can be only one anyway
        data = get_option(msg, 'data')
        if _type == 'host_name':
            set_host_name(STATE, data)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'get':
        # Read-only: nothing is regenerated or saved.
        tag = get_option(msg, 'tag')
        if _type == 'name_servers':
            return get_name_servers(STATE, tag)
        raise ValueError("Unimplemented")
    else:
        raise ValueError("Unknown operation {0}".format(op))

    make_resolv_conf(STATE)
    make_hosts_file(STATE)

    logger.info("Saving state to {0}".format(STATE_FILE))
    # Write to a temporary file and rename it into place. A crash or power
    # loss mid-write then cannot leave a truncated state file that fails to
    # load (and is silently replaced by defaults) on the next start.
    tmp_state_file = STATE_FILE + '.tmp'
    with open(tmp_state_file, 'w') as f:
        json.dump(STATE, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_state_file, STATE_FILE)

def exit_handler(sig, frame):
    """Remove the saved state and exit when the daemon is shut down cleanly.

    Installed as the SIGTERM handler. The state file may not exist (for
    example, when no change request has been handled yet). That case is not
    an error and must not stop the exit.
    """
    logger.info("Cleaning up state")
    try:
        os.unlink(STATE_FILE)
    except FileNotFoundError:
        pass
    sys.exit(0)


if __name__ == '__main__':
    signal.signal(signal.SIGTERM, exit_handler)

    # Create a directory for state checkpoints
    os.makedirs(DATA_DIR, exist_ok=True)

    # Restore the state saved by a previous run, if any. The file's values
    # are merged over the built-in defaults, so a state file missing some
    # keys still yields a complete STATE.
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, 'r') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("State file does not contain a JSON object")
            STATE.update(data)
        except (OSError, ValueError):
            logger.exception("Failed to load the state file, using default")

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(SOCKET_PATH)

    while True:
        #  Wait for next request from client
        raw_message = socket.recv()
        logger.info("Received a configuration change request")

        resp = {}

        try:
            # Decode inside the try. A REP socket must answer every request,
            # so a non-UTF-8 payload has to produce an error reply instead of
            # killing the daemon (UnicodeDecodeError is a ValueError).
            message = raw_message.decode()
            logger.debug("Request data: {0}".format(message))
            result = handle_message(message)
            resp['data'] = result
        except ValueError as e:
            resp['error'] = str(e)
        except Exception:
            # Catch Exception, not everything. The SystemExit raised by the
            # SIGTERM handler must still be able to stop the daemon.
            logger.exception("Failed to process the request")
            resp['error'] = "Internal error"

        logger.debug("Sent response: {0}".format(resp))

        #  Send reply back to client
        socket.send(json.dumps(resp).encode())