#!/usr/bin/env python3
#
# Copyright (C) 2019 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#
import os
import sys
import time
import json
import signal
import traceback
import re
import logging
import zmq
import collections
import jinja2
from vyos.util import popen, process_named_running
debug = True

# Logging: one stream handler on this module's logger; verbosity follows
# the ``debug`` switch above (DEBUG when set, INFO otherwise).
logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.DEBUG if debug else logging.INFO)
# Directory holding the persisted daemon state, and the state file itself
DATA_DIR = "/var/lib/vyos/"
STATE_FILE = os.path.join(DATA_DIR, "hostsd.state")

# ZeroMQ IPC endpoint the daemon binds its request/reply socket to
SOCKET_PATH = "ipc:///run/vyos-hostsd.sock"

# Files regenerated from the state after every change
RESOLV_CONF_FILE = "/etc/resolv.conf"
HOSTS_FILE = "/etc/hosts"
# Jinja2 template for the hosts file (HOSTS_FILE). It is rendered with the
# daemon state dict and reads its host_name, domain_name and hosts keys.
# Each hosts entry becomes one line: address, TAB, host name, TAB, then every
# alias surrounded by spaces. Whitespace-control markers ({%- / -%}) and the
# "\t" escapes are significant: edit the literal with care.
hosts_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###
# Local host
127.0.0.1 localhost
127.0.1.1 {{ host_name }}{% if domain_name %}.{{ domain_name }} {{ host_name }}{% endif %}
# The following lines are desirable for IPv6 capable hosts
::1 localhost ip6-localhost ip6-loopback
fe00::0 ip6-localnet
ff00::0 ip6-mcastprefix
ff02::1 ip6-allnodes
ff02::2 ip6-allrouters
# From DHCP and "system static host-mapping"
{%- if hosts %}
{% for h in hosts -%}
{{hosts[h]['address']}}\t{{h}}\t{% for a in hosts[h]['aliases'] %} {{a}} {% endfor %}
{% endfor %}
{%- endif %}
"""
# Compiled once at import time; make_hosts_file() renders it
hosts_tmpl = jinja2.Template(hosts_tmpl_source)
# Jinja2 template for resolv.conf (RESOLV_CONF_FILE), rendered with the
# daemon state dict. The output is built in this order:
#   1. name servers tagged "static", in the dict's iteration order;
#   2. every other name server, each preceded by a comment naming its tag;
#   3. an optional "domain" line (when domain_name is non-empty);
#   4. an optional "search" line (search_domains joined with spaces).
# Whitespace-control markers ({%- / -%}) are significant: edit with care.
resolv_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###
# name server from static configuration
{% for ns in name_servers -%}
{%- if name_servers[ns]['tag'] == "static" %}
nameserver {{ns}}
{%- endif %}
{% endfor -%}
{% for ns in name_servers -%}
{%- if name_servers[ns]['tag'] != "static" %}
# name server from {{name_servers[ns]['tag']}}
nameserver {{ns}}
{%- endif %}
{% endfor -%}
{%- if domain_name %}
domain {{ domain_name }}
{%- endif %}
{%- if search_domains %}
search {{ search_domains | join(" ") }}
{%- endif %}
"""
# Compiled once at import time; make_resolv_conf() renders it
resolv_tmpl = jinja2.Template(resolv_tmpl_source)
# The state data includes the name server entries and the hosts entries
# (both dicts keyed by server / host name), plus the host name, domain
# name and search domains.
#
# Name servers have the following structure:
# {"server": {"tag": <tag>}}
#
# Hosts entries are similar:
# {"host": {"tag": <tag>, "address": <address>, "aliases": <aliases>}}
#
# The tag is either "static" or "dhcp-<interface>"
# It's used to distinguish entries created
# by different scripts so that they can be removed
# and re-created without having to track what needs
# to be changed
# Default daemon state; replaced or updated from STATE_FILE at startup.
STATE = {
    "name_servers": collections.OrderedDict(),  # server -> {"tag": ...}
    "hosts": {},                                # host -> {"tag", "address", "aliases"}
    "host_name": "vyos",
    "domain_name": "",
    "search_domains": [],
}
def make_resolv_conf(data):
    """Render the resolv.conf template from *data* and overwrite RESOLV_CONF_FILE.

    The template is rendered before anything is logged or written, so a
    rendering error leaves the existing file untouched.
    """
    rendered = resolv_tmpl.render(data)
    logger.info("Writing /etc/resolv.conf")
    with open(RESOLV_CONF_FILE, 'w') as resolv_file:
        resolv_file.write(rendered)
def make_hosts_file(state):
    """Render the hosts template from *state* and overwrite HOSTS_FILE."""
    logger.info("Writing /etc/hosts")
    content = hosts_tmpl.render(state)
    with open(HOSTS_FILE, 'w') as hosts_file:
        hosts_file.write(content)
def add_hosts(data, entries, tag):
    """Insert or replace host entries in data['hosts'] under *tag*.

    Each entry must provide 'host', 'address' and 'aliases'. An existing
    record with the same host name is replaced, whatever its tag was.
    A falsy *entries* (None or empty) is a no-op.

    Raises KeyError if an entry lacks a required key. Entries before the
    faulty one have already been applied; the faulty one is not stored.
    """
    hosts = data['hosts']
    if not entries:
        return
    for entry in entries:
        # Build the complete record before inserting it, so a malformed
        # entry cannot leave a half-initialised record (tag only) in the
        # state, where it would be persisted and rendered.
        record = {
            'tag': tag,
            'address': entry['address'],
            'aliases': entry['aliases'],
        }
        hosts[entry['host']] = record
def delete_hosts(data, tag):
    """Remove every entry in data['hosts'] whose tag equals *tag* exactly."""
    hosts = data['hosts']
    # Collect the names first: a dict must not change size while iterated.
    doomed = [name for name, record in hosts.items() if record['tag'] == tag]
    for name in doomed:
        hosts.pop(name)
def add_name_servers(data, entries, tag):
    """Record each server address in *entries* under *tag*.

    A server that is already known keeps its position but gets the new tag.
    A falsy *entries* (None or empty) is a no-op.
    """
    name_servers = data['name_servers']
    if entries:
        name_servers.update((server, {'tag': tag}) for server in entries)
def delete_name_servers(data, tag):
    """Remove every name server whose tag matches the regular expression *tag*.

    NOTE(review): re.match anchors only at the start of the tag, so
    'dhcp-eth0' also removes servers tagged 'dhcp-eth0.10' (and '.' is a
    regex wildcard). This is unlike delete_hosts(), which compares exactly.
    Confirm callers rely on this prefix/regex behaviour.
    """
    servers = data['name_servers']
    pattern = re.compile(tag)
    # Iterate over a snapshot of the keys so entries can be deleted in place
    for server in tuple(servers):
        if pattern.match(servers[server]['tag']) is not None:
            del servers[server]
def set_host_name(state, data):
    """Update the host name, domain name and search domains in *state*.

    Only the keys present in *data* are applied. A missing or empty
    'host_name' leaves the current host name unchanged.
    """
    # .get(): a request that only changes the domain name or search domains
    # need not carry 'host_name'. Indexing it directly raised KeyError and
    # was reported to the client as an internal error.
    if data.get('host_name'):
        state['host_name'] = data['host_name']
    if 'domain_name' in data:
        state['domain_name'] = data['domain_name']
    if 'search_domains' in data:
        state['search_domains'] = data['search_domains']
def get_name_servers(state, tag):
    """Return the name servers whose tag matches the regex *tag*, in stored order.

    NOTE(review): re.match is a prefix match, so 'dhcp-eth0' also selects
    servers tagged 'dhcp-eth0.10'. Confirm this is intended.
    """
    servers = state['name_servers']
    pattern = re.compile(tag)
    return [server for server in servers if pattern.match(servers[server]['tag'])]
def get_option(msg, key):
    """Return msg[key]; raise ValueError naming *key* when it is absent."""
    if key not in msg:
        raise ValueError('Missing required option "{0}"'.format(key))
    return msg[key]
def handle_message(msg_json):
    """Apply one client request to STATE and regenerate the managed files.

    The request is a JSON object with these keys:
      op   -- 'add', 'delete', 'set' or 'get'
      type -- 'name_servers' or 'hosts' ('host_name' for 'set';
              only 'name_servers' for 'get')
      tag  -- entry tag, required by 'add', 'delete' and 'get'
      data -- payload, required by 'add' and 'set'

    'get' returns the list of matching name servers and touches no file.
    Every other successful op does the following, then returns None:
      * rewrites resolv.conf and the hosts file;
      * saves STATE to STATE_FILE;
      * asks pdns_recursor to reload its zones, if it is running.

    Raises ValueError for malformed JSON (json.JSONDecodeError is a
    ValueError), a missing option, or an unknown op/type.
    """
    msg = json.loads(msg_json)

    op = get_option(msg, 'op')
    _type = get_option(msg, 'type')

    if op == 'delete':
        tag = get_option(msg, 'tag')
        if _type == 'name_servers':
            delete_name_servers(STATE, tag)
        elif _type == 'hosts':
            delete_hosts(STATE, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'add':
        tag = get_option(msg, 'tag')
        entries = get_option(msg, 'data')
        if _type == 'name_servers':
            add_name_servers(STATE, entries, tag)
        elif _type == 'hosts':
            add_hosts(STATE, entries, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'set':
        # Host name/domain name/search domain are set without a tag,
        # there can be only one anyway
        data = get_option(msg, 'data')
        if _type == 'host_name':
            set_host_name(STATE, data)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'get':
        tag = get_option(msg, 'tag')
        if _type == 'name_servers':
            return get_name_servers(STATE, tag)
        raise ValueError("Unimplemented")
    else:
        raise ValueError("Unknown operation {0}".format(op))

    # Reaching this point means STATE was modified: 'get' returned above
    # and every invalid request raised. So always regenerate the files.
    make_resolv_conf(STATE)
    make_hosts_file(STATE)

    logger.info("Saving state to {0}".format(STATE_FILE))
    with open(STATE_FILE, 'w') as f:
        json.dump(STATE, f)

    if process_named_running("pdns_recursor"):
        output, return_code = popen("sudo rec_control --socket-dir=/run/powerdns reload-zones")
        if return_code > 0:
            # logger.error, not logger.exception: no exception is being
            # handled here, so exception() would only log "NoneType: None".
            logger.error("PowerDNS rec_control failed to reload (exit code {0}): {1}".format(return_code, output))
def exit_handler(sig, frame):
    """ Clean up the state when shutdown correctly

    Installed as the SIGTERM handler. The state file is only written after
    the first handled request, so it may legitimately not exist yet. Its
    absence must not turn a clean shutdown into an uncaught
    FileNotFoundError.
    """
    logger.info("Cleaning up state")
    try:
        os.unlink(STATE_FILE)
    except FileNotFoundError:
        pass
    sys.exit(0)
if __name__ == '__main__':
    signal.signal(signal.SIGTERM, exit_handler)

    # Create a directory for state checkpoints
    os.makedirs(DATA_DIR, exist_ok=True)

    # Restore the state saved by a previous run, if any. Loaded keys are
    # merged over the defaults, so a state file lacking some keys still
    # leaves every key the handlers index present.
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, 'r') as f:
                data = json.load(f)
        except (OSError, ValueError):
            logger.exception("Failed to load the state file, using default")
        else:
            if isinstance(data, dict):
                STATE.update(data)
            else:
                logger.error("Failed to load the state file, using default")

    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Set the right permissions on the socket, then change it back
    # (restored even if bind() fails)
    o_mask = os.umask(0)
    try:
        socket.bind(SOCKET_PATH)
    finally:
        os.umask(o_mask)

    while True:
        # Wait for next request from client
        raw_message = socket.recv()
        logger.info("Received a configuration change request")

        resp = {}
        try:
            # Decode inside the try: a non-UTF-8 request (UnicodeDecodeError
            # is a ValueError) gets an error reply instead of killing the
            # daemon, and the REP socket always answers its recv().
            message = raw_message.decode()
            logger.debug("Request data: {0}".format(message))
            resp['data'] = handle_message(message)
        except ValueError as e:
            # Client errors: bad JSON, missing options, unknown op/type
            resp['error'] = str(e)
        except Exception:
            # Deliberately not a bare except: the SystemExit raised by
            # exit_handler on SIGTERM must propagate so the daemon exits.
            logger.exception("Internal error while handling a request")
            resp['error'] = "Internal error"

        logger.debug("Sent response: {0}".format(resp))
        # Send reply back to client
        socket.send(json.dumps(resp).encode())