#!/usr/bin/env python3
#
# Copyright (C) 2019 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#

import os
import sys
import time
import json
import signal
import traceback
import re
import logging
import zmq
import collections

import jinja2

debug = True

# Configure logging
logger = logging.getLogger(__name__)
# set stream as output
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)

if debug:
    logger.setLevel(logging.DEBUG)
else:
    logger.setLevel(logging.INFO)


DATA_DIR = "/var/lib/vyos/"
STATE_FILE = os.path.join(DATA_DIR, "hostsd.state")

SOCKET_PATH = "ipc:///run/vyos-hostsd.sock"

RESOLV_CONF_FILE = '/etc/resolv.conf'
HOSTS_FILE = '/etc/hosts'

hosts_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###

# Local host
127.0.0.1 localhost
127.0.1.1 {{ host_name }}{% if domain_name %}.{{ domain_name }} {{ host_name }}{% endif %}

# The following lines are desirable for IPv6 capable hosts
::1 localhost ip6-localhost ip6-loopback
fe00::0 ip6-localnet
ff00::0 ip6-mcastprefix
ff02::1 ip6-allnodes
ff02::2 ip6-allrouters

# From DHCP and "system static host-mapping"
{%- if hosts %}
{% for h in hosts -%}
{{hosts[h]['address']}}\t{{h}}\t{% for a in hosts[h]['aliases'] %} {{a}} {% endfor %}
{% endfor %}
{%- endif %}
"""

hosts_tmpl = jinja2.Template(hosts_tmpl_source)

resolv_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###

# name server from static configuration
{% for ns in name_servers -%}
{%- if name_servers[ns]['tag'] == "static" %}
nameserver {{ns}}
{%- endif %}
{% endfor -%}

{% for ns in name_servers -%}
{%- if name_servers[ns]['tag'] != "static" %}
# name server from {{name_servers[ns]['tag']}}
nameserver {{ns}}
{%- endif %}
{% endfor -%}

{%- if domain_name %}
domain {{ domain_name }}
{%- endif %}

{%- if search_domains %}
search {{ search_domains | join(" ") }}
{%- endif %}

"""

resolv_tmpl = jinja2.Template(resolv_tmpl_source)

# The state data includes a mapping of name servers
# and a mapping of hosts entries.
#
# Name servers have the following structure:
# {"server": {"tag": <tag>}}
#
# Hosts entries are similar:
# {"host": {"tag": <tag>, "address": <address>, "aliases": <list of aliases>}}
#
# The tag is either "static" or "dhcp-<suffix>" (presumably the DHCP
# interface name -- confirm with the calling scripts).
# It's used to distinguish entries created
# by different scripts so that they can be removed
# and re-created without having to track what needs
# to be changed
STATE = {
    "name_servers": collections.OrderedDict({}),
    "hosts": {},
    "host_name": "vyos",
    "domain_name": "",
    "search_domains": []}


def make_resolv_conf(data):
    """Render the resolv.conf template from *data* and write RESOLV_CONF_FILE."""
    resolv_conf = resolv_tmpl.render(data)
    logger.info("Writing /etc/resolv.conf")
    with open(RESOLV_CONF_FILE, 'w') as f:
        f.write(resolv_conf)


def make_hosts_file(state):
    """Render the hosts template from *state* and write HOSTS_FILE."""
    logger.info("Writing /etc/hosts")
    hosts = hosts_tmpl.render(state)
    with open(HOSTS_FILE, 'w') as f:
        f.write(hosts)


def add_hosts(data, entries, tag):
    """Add host entries to data['hosts'], all marked with *tag*.

    Each entry is a dict with the keys 'host' and 'address' and an optional
    'aliases' list. An existing entry with the same host name is replaced.
    A falsy *entries* is a no-op.
    """
    hosts = data['hosts']
    if not entries:
        return

    for e in entries:
        hosts[e['host']] = {
            'tag': tag,
            'address': e['address'],
            # The template iterates over aliases, so default to an empty
            # list rather than failing with KeyError when it is omitted.
            'aliases': e.get('aliases', []),
        }


def delete_hosts(data, tag):
    """Remove every host entry whose tag is exactly *tag*."""
    hosts = data['hosts']
    # You can't delete items from a dict while iterating over it,
    # so we build a list of doomed items first
    keys_for_deletion = [h for h in hosts if hosts[h]['tag'] == tag]
    for k in keys_for_deletion:
        del hosts[k]


def add_name_servers(data, entries, tag):
    """Add each name server address in *entries* to data['name_servers'].

    Every server is marked with *tag*; re-adding a server overwrites its tag.
    A falsy *entries* is a no-op.
    """
    name_servers = data['name_servers']
    if not entries:
        return

    for e in entries:
        name_servers[e] = {'tag': tag}


def delete_name_servers(data, tag):
    """Remove name servers whose tag matches *tag*.

    Unlike delete_hosts(), *tag* is treated as a regular expression matched
    at the start of each server's tag (re.match), so a pattern can remove
    servers from several sources at once.
    """
    name_servers = data['name_servers']
    regex_filter = re.compile(tag)
    # Iterate over a snapshot of the keys because entries are deleted.
    for ns in list(name_servers.keys()):
        if regex_filter.match(name_servers[ns]['tag']):
            del name_servers[ns]


def set_host_name(state, data):
    """Update host name, domain name and search domains from *data*.

    Only keys present in *data* are applied. An empty or missing 'host_name'
    leaves the current host name unchanged.
    """
    # .get(): a message that only sets domain_name/search_domains used to
    # fail with KeyError and be reported as "Internal error".
    if data.get('host_name'):
        state['host_name'] = data['host_name']
    if 'domain_name' in data:
        state['domain_name'] = data['domain_name']
    if 'search_domains' in data:
        state['search_domains'] = data['search_domains']


def get_name_servers(state, tag):
    """Return the name servers whose tag matches the regex *tag* (re.match)."""
    servers = state['name_servers']
    regex_filter = re.compile(tag)
    return [n for n in servers if regex_filter.match(servers[n]['tag'])]


def get_option(msg, key):
    """Return msg[key], raising ValueError if the key is absent."""
    if key in msg:
        return msg[key]
    else:
        raise ValueError("Missing required option \"{0}\"".format(key))


def handle_message(msg_json):
    """Apply one JSON request to STATE.

    The request is a JSON object with an 'op' ('add', 'delete', 'set' or
    'get') and a 'type'. 'tag' is required for add/delete/get and 'data' for
    add/set.

    For 'get', returns the list of matching name servers without touching
    any file. For every other op, re-renders resolv.conf and the hosts file,
    saves STATE to STATE_FILE and returns None.

    Raises:
        ValueError: malformed JSON, non-object message, missing option,
            or unknown op/type.
    """
    msg = json.loads(msg_json)
    if not isinstance(msg, dict):
        raise ValueError("Malformed message: expected a JSON object")

    op = get_option(msg, 'op')
    _type = get_option(msg, 'type')

    if op == 'delete':
        tag = get_option(msg, 'tag')

        if _type == 'name_servers':
            delete_name_servers(STATE, tag)
        elif _type == 'hosts':
            delete_hosts(STATE, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'add':
        tag = get_option(msg, 'tag')
        entries = get_option(msg, 'data')
        if _type == 'name_servers':
            add_name_servers(STATE, entries, tag)
        elif _type == 'hosts':
            add_hosts(STATE, entries, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'set':
        # Host name/domain name/search domain are set without a tag,
        # there can be only one anyway
        data = get_option(msg, 'data')
        if _type == 'host_name':
            set_host_name(STATE, data)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'get':
        tag = get_option(msg, 'tag')
        if _type == 'name_servers':
            result = get_name_servers(STATE, tag)
        else:
            raise ValueError("Unimplemented")
        # Read-only operation: nothing to regenerate or persist.
        return result
    else:
        raise ValueError("Unknown operation {0}".format(op))

    make_resolv_conf(STATE)
    make_hosts_file(STATE)

    logger.info("Saving state to {0}".format(STATE_FILE))
    with open(STATE_FILE, 'w') as f:
        json.dump(STATE, f)


def exit_handler(sig, frame):
    """ Clean up the state when shutdown correctly """
    logger.info("Cleaning up state")
    try:
        os.unlink(STATE_FILE)
    except FileNotFoundError:
        # No request has been handled yet, so no state was ever saved;
        # still exit cleanly instead of dying with a traceback.
        pass
    sys.exit(0)


if __name__ == '__main__':
    signal.signal(signal.SIGTERM, exit_handler)

    # Create a directory for state checkpoints
    os.makedirs(DATA_DIR, exist_ok=True)
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, 'r') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("State file does not contain a JSON object")
        except (OSError, ValueError):
            logger.exception("Failed to load the state file, using default")
        else:
            # Merge over the defaults so keys missing from an older or
            # partial state file keep their default values.
            STATE.update(data)

    context = zmq.Context()
    socket = context.socket(zmq.REP)

    # Set the right permissions on the socket, then change it back
    o_mask = os.umask(0)
    try:
        socket.bind(SOCKET_PATH)
    finally:
        os.umask(o_mask)

    while True:
        # Wait for next request from client
        raw_message = socket.recv()
        logger.info("Received a configuration change request")

        resp = {}
        try:
            # Decode inside the try: invalid UTF-8 (UnicodeDecodeError is a
            # ValueError) now yields an error reply instead of crashing
            # the daemon.
            message = raw_message.decode()
            logger.debug("Request data: {0}".format(message))
            result = handle_message(message)
            resp['data'] = result
        except ValueError as e:
            resp['error'] = str(e)
        except Exception:
            # Deliberately not a bare except: SystemExit raised by
            # exit_handler on SIGTERM (or KeyboardInterrupt) must stop the
            # daemon, not be reported to the client as an internal error.
            logger.exception("Internal error while handling a request")
            resp['error'] = "Internal error"

        logger.debug("Sent response: {0}".format(resp))

        # Send reply back to client
        socket.send(json.dumps(resp).encode())