#!/usr/bin/env python3
#
# Copyright (C) 2018 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import os
import sys
import subprocess
import ipaddress

import jinja2

import vyos.config
import vyos.keepalived

from vyos import ConfigError


# Path of the generated keepalived configuration. generate() writes the
# rendered template here; apply() deletes it when no VRRP group is configured.
config_file = "/etc/keepalived/keepalived.conf"

# Jinja2 template for keepalived.conf, rendered by generate() with two
# variables: "groups" (the non-disabled VRRP group dicts built by
# get_config()) and "sync_groups" (the sync-group dicts). Each VRRP group
# becomes a vrrp_instance that starts in state BACKUP, preceded by a
# vrrp_script when a health-check script is set. Each sync-group becomes a
# vrrp_sync_group. Transition scripts are invoked through
# vrrp-script-wrapper.py. A sync-group flagged for conntrack-sync calls
# vyatta-vrrp-conntracksync.sh on master/backup/fault transitions.
# NOTE(review): everything inside the string, including its leading "#"
# lines, is emitted into keepalived.conf verbatim, and the Jinja whitespace
# control markers ("-%}", "{%-") determine the generated layout.
config_tmpl = """
# Autogenerated by VyOS
# Do not edit this file, all your changes will be lost
# on next commit or reboot

{% for group in groups -%}

{% if group.health_check_script -%}
vrrp_script healthcheck_{{ group.name }} {
	script {{ group.health_check_script }}
	interval {{ group.health_check_interval }}
	fall {{ group.health_check_count }}
	rise 1

}
{% endif %}

vrrp_instance {{ group.name }} {
    {% if group.description -%}
    # {{ group.description }}
    {% endif -%}

    state BACKUP
    interface {{ group.interface }}
    virtual_router_id {{ group.vrid }}
    priority {{ group.priority }}
    advert_int {{ group.advertise_interval }}

    {% if group.preempt -%}
    preempt_delay {{ group.preempt_delay }}
    {% else -%}
    nopreempt
    {% endif -%}

    {% if group.peer_address -%}
    unicast_peer { {{ group.peer_address }} }
    {% endif -%}

    {% if group.hello_source -%}
      {%- if group.peer_address -%}
      unicast_src_ip {{ group.hello_source }}
      {%- else -%}
      mcast_src_ip {{ group.hello_source }}
      {%- endif %}
    {% endif -%}

    {% if group.use_vmac -%}
      use_vmac {{group.interface}}v{{group.vrid}}
    {% endif -%}

    {% if group.auth_password -%}
      authentication {
        auth_pass {{ group.auth_password }}
        auth_type {{ group.auth_type }}
      }
    {% endif -%}

    virtual_ipaddress {
    {% for addr in group.virtual_addresses -%}
        {{ addr }}
    {% endfor -%}
    }

    {% if group.health_check_script -%}
    track_script {
        healthcheck_{{ group.name }}
    }
    {% endif -%}

    {% if group.master_script -%}
        notify_master "/usr/libexec/vyos/system/vrrp-script-wrapper.py --state master --group {{ group.name }} --interface {{ group.interface }} {{ group.master_script }}"
    {% endif -%}

    {% if group.backup_script -%}
        notify_backup "/usr/libexec/vyos/system/vrrp-script-wrapper.py --state backup --group {{ group.name }} --interface {{ group.interface }} {{ group.backup_script }}"
    {% endif -%}

    {% if group.fault_script -%}
        notify_fault "/usr/libexec/vyos/system/vrrp-script-wrapper.py --state fault --group {{ group.name }} --interface {{ group.interface }} {{ group.fault_script }}"
    {% endif -%}
}

{% endfor -%}

{% for sync_group in sync_groups -%}
vrrp_sync_group {{ sync_group.name }} {
       group {
            {% for member in sync_group.members -%}
                {{ member }}
            {% endfor -%}
        }

        {% if sync_group.conntrack_sync -%}
            notify_master "/opt/vyatta/sbin/vyatta-vrrp-conntracksync.sh master {{ sync_group.name }}"
            notify_backup "/opt/vyatta/sbin/vyatta-vrrp-conntracksync.sh backup {{ sync_group.name }}"
            notify_fault "/opt/vyatta/sbin/vyatta-vrrp-conntracksync.sh fault {{ sync_group.name }}"
        {% endif -%}
}

{% endfor -%}

"""

def get_config():
    """Collect the VRRP group and sync-group configuration.

    Returns:
        tuple: ``(vrrp_groups, sync_groups)``, two lists of dicts.

        Every VRRP group dict carries the following:

        * the group's name;
        * its ``disable``/``preempt``/``use_vmac`` flags;
        * the raw values of its options.

        Unset advertise-interval, priority, preempt-delay, health-check
        interval and health-check failure-count are replaced by 1, 100, 300,
        60 and 3 respectively. A set authentication type is rewritten to
        keepalived's ``PASS`` (for plaintext-password) or ``AH``.

        Every sync-group dict carries its name, its member list and a
        ``conntrack_sync`` flag. The flag is True only for the sync-group
        named by ``service conntrack-sync failover-mechanism vrrp sync-group``.
    """
    conf = vyos.config.Config()

    # (dict key, config path relative to the group node) for every
    # single-valued option of a VRRP group.
    value_options = (
        ("vrid", "vrid"),
        ("interface", "interface"),
        ("description", "description"),
        ("advertise_interval", "advertise-interval"),
        ("priority", "priority"),
        ("hello_source", "hello-source-address"),
        ("peer_address", "peer-address"),
        ("sync_group", "sync-group"),
        ("preempt_delay", "preempt-delay"),
        ("auth_password", "authentication password"),
        ("auth_type", "authentication type"),
        ("health_check_script", "health-check script"),
        ("health_check_interval", "health-check interval"),
        ("health_check_count", "health-check failure-count"),
        ("master_script", "transition-script master"),
        ("backup_script", "transition-script backup"),
        ("fault_script", "transition-script fault"),
    )

    # Values substituted when the corresponding option is unset (falsy).
    fallbacks = (
        ("advertise_interval", 1),
        ("priority", 100),
        ("preempt_delay", 5 * 60),
        ("health_check_interval", 60),
        ("health_check_count", 3),
    )

    vrrp_groups = []
    for name in conf.list_nodes("high-availability vrrp group"):
        conf.set_level("high-availability vrrp group {0}".format(name))

        group = {
            "name": name,
            "disable": bool(conf.exists("disable")),
        }
        for key, path in value_options:
            group[key] = conf.return_value(path)
        group["virtual_addresses"] = conf.return_values("virtual-address")
        group["preempt"] = not conf.exists("no-preempt")
        group["use_vmac"] = bool(conf.exists("rfc3768-compatibility"))

        for key, fallback in fallbacks:
            group[key] = group[key] or fallback

        # Map our authentication type option onto keepalived's keyword.
        auth_type = group["auth_type"]
        if auth_type:
            group["auth_type"] = "PASS" if auth_type == "plaintext-password" else "AH"

        vrrp_groups.append(group)

    conf.set_level("")

    # Name of the sync-group whose transitions drive conntrack-sync, if any.
    conntrack_sync_group = None
    if conf.exists("service conntrack-sync failover-mechanism vrrp"):
        conntrack_sync_group = conf.return_value("service conntrack-sync failover-mechanism vrrp sync-group")

    sync_groups = []
    for sg_name in conf.list_nodes("high-availability vrrp sync-group"):
        conf.set_level("high-availability vrrp sync-group {0}".format(sg_name))
        sync_groups.append({
            "conntrack_sync": bool(conntrack_sync_group) and conntrack_sync_group == sg_name,
            "name": sg_name,
            "members": conf.return_values("member"),
        })

    return (vrrp_groups, sync_groups)

def verify(data):
    """Validate the VRRP configuration returned by get_config().

    Args:
        data: ``(vrrp_groups, sync_groups)`` tuple as produced by get_config().

    Raises:
        ConfigError: in any of these cases:

            * a group is missing vrid, interface or virtual-address;
            * a group has an authentication password but no authentication
              type;
            * a group mixes IPv4 and IPv6 virtual addresses;
            * a group has a hello-source-address or peer-address of the other
              address family;
            * a group reuses a VRID already taken by another group on the same
              interface;
            * a sync-group lists a member that is not a configured VRRP group.
    """
    vrrp_groups, sync_groups = data

    for group in vrrp_groups:
        # Check required fields
        if not group["vrid"]:
            raise ConfigError("vrid is required but not set in VRRP group {0}".format(group["name"]))
        if not group["interface"]:
            raise ConfigError("interface is required but not set in VRRP group {0}".format(group["name"]))
        if not group["virtual_addresses"]:
            raise ConfigError("virtual-address is required but not set in VRRP group {0}".format(group["name"]))

        if group["auth_password"] and (not group["auth_type"]):
            raise ConfigError("authentication type is required but not set in VRRP group {0}".format(group["name"]))

        # Keepalived doesn't allow mixing IPv4 and IPv6 in one group, so we mirror that restriction
        vaddrs = [ipaddress.ip_interface(addr) for addr in group["virtual_addresses"]]
        vaddrs4 = [a for a in vaddrs if isinstance(a, ipaddress.IPv4Interface)]
        vaddrs6 = [a for a in vaddrs if isinstance(a, ipaddress.IPv6Interface)]

        if vaddrs4 and vaddrs6:
            raise ConfigError("VRRP group {0} mixes IPv4 and IPv6 virtual addresses, this is not allowed. Create separate groups for IPv4 and IPv6".format(group["name"]))

        if vaddrs4:
            if group["hello_source"]:
                hsa = ipaddress.ip_address(group["hello_source"])
                if isinstance(hsa, ipaddress.IPv6Address):
                    raise ConfigError("VRRP group {0} uses IPv4 but its hello-source-address is IPv6".format(group["name"]))
            if group["peer_address"]:
                pa = ipaddress.ip_address(group["peer_address"])
                if isinstance(pa, ipaddress.IPv6Address):
                    raise ConfigError("VRRP group {0} uses IPv4 but its peer-address is IPv6".format(group["name"]))

        if vaddrs6:
            if group["hello_source"]:
                hsa = ipaddress.ip_address(group["hello_source"])
                if isinstance(hsa, ipaddress.IPv4Address):
                    raise ConfigError("VRRP group {0} uses IPv6 but its hello-source-address is IPv4".format(group["name"]))
            if group["peer_address"]:
                pa = ipaddress.ip_address(group["peer_address"])
                if isinstance(pa, ipaddress.IPv4Address):
                    raise ConfigError("VRRP group {0} uses IPv6 but its peer-address is IPv4".format(group["name"]))

    # Disallow the same VRID on the same interface. The same VRID on different
    # interfaces is allowed. Keying on (interface, vrid) catches every
    # duplicate, not only groups that happen to be adjacent after sorting, and
    # never flags groups that merely share a VRID across interfaces.
    groups_by_key = {}
    for group in vrrp_groups:
        key = (group["interface"], group["vrid"])
        other = groups_by_key.get(key)
        if other is not None:
            raise ConfigError("VRID {0} is used in groups {1} and {2} that both use interface {3}. Groups on the same interface must use different VRIDs".format(
              group["vrid"], other["name"], group["name"], group["interface"]))
        groups_by_key[key] = group

    # Check sync groups
    vrrp_group_names = {group["name"] for group in vrrp_groups}

    for sync_group in sync_groups:
        for m in sync_group["members"]:
            if m not in vrrp_group_names:
                raise ConfigError("VRRP sync-group {0} refers to VRRP group {1}, but group {1} does not exist".format(sync_group["name"], m))

def generate(data):
    """Render keepalived.conf from the validated VRRP configuration.

    Disabled VRRP groups are left out of the rendered file. They are also
    dropped, with a warning, from the member lists of any sync-group that
    references them, so a generated vrrp_sync_group never names a
    vrrp_instance that is absent from the file.

    Args:
        data: ``(vrrp_groups, sync_groups)`` tuple as produced by get_config().
            It is not modified.

    Returns:
        None. The rendered text is written to ``config_file``.
    """
    vrrp_groups, sync_groups = data

    disabled = {group["name"] for group in vrrp_groups if group["disable"]}

    # Remove disabled groups from the sync group member lists. Build new
    # sync-group dicts rather than editing the member lists in place, so the
    # caller's data is left untouched.
    active_sync_groups = []
    for sync_group in sync_groups:
        members = []
        for member in sync_group["members"]:
            if member in disabled:
                print("Warning: ignoring disabled VRRP group {0} in sync-group {1}".format(member, sync_group["name"]))
            else:
                members.append(member)
        active_sync_groups.append(dict(sync_group, members=members))

    # Filter out disabled groups
    vrrp_groups = [group for group in vrrp_groups if not group["disable"]]

    tmpl = jinja2.Template(config_tmpl)
    config_text = tmpl.render({"groups": vrrp_groups, "sync_groups": active_sync_groups})

    with open(config_file, 'w') as f:
        f.write(config_text)
    return None

def apply(data):
    """Bring the keepalived service in line with the generated configuration.

    If at least one VRRP group is configured, keepalived is restarted when it
    is not running and reloaded when it is. This uses the group list from
    get_config(), which still includes disabled groups. If no group is
    configured, the service is stopped and the generated config file is
    deleted.

    Args:
        data: ``(vrrp_groups, sync_groups)`` tuple; only the VRRP group list
            is consulted.

    Raises:
        ConfigError: if the systemctl restart/reload command exits non-zero.
    """
    vrrp_groups, _sync_groups = data

    if not vrrp_groups:
        # VRRP is removed in the commit
        print("Stopping the VRRP process")
        subprocess.call("sudo systemctl stop keepalived.service", shell=True)
        os.unlink(config_file)
        return None

    if vyos.keepalived.vrrp_running():
        print("Reloading the VRRP process")
        command = "sudo systemctl reload keepalived.service"
    else:
        print("Starting the VRRP process")
        command = "sudo systemctl restart keepalived.service"

    if subprocess.call(command, shell=True) != 0:
        raise ConfigError("keepalived failed to start")
    return None


if __name__ == '__main__':
    # Read the configuration, validate it, render keepalived.conf, then
    # start/reload/stop keepalived. Any ConfigError aborts with exit status 1.
    try:
        vrrp_config = get_config()
        for stage in (verify, generate, apply):
            stage(vrrp_config)
    except ConfigError as err:
        print("VRRP error: {0}".format(str(err)))
        sys.exit(1)