summaryrefslogtreecommitdiff
path: root/src/services/vyos-hostsd
blob: 8f70eb4e96d54d7dc8830468e342ef280b407d9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
#!/usr/bin/env python3
#
# Copyright (C) 2019 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#

import os
import sys
import time
import json
import signal
import traceback

import zmq

import jinja2

# When true, the main loop echoes every request and response to stdout
debug = True

# Directory that holds the daemon's state checkpoint
DATA_DIR = "/var/lib/vyos/"
# JSON dump of STATE, rewritten after every state-changing request
STATE_FILE = os.path.join(DATA_DIR, "hostsd.state")

# ZeroMQ endpoint the daemon binds its reply socket to
SOCKET_PATH = "ipc:///run/vyos-hostsd.sock"

# Files generated from the templates below
RESOLV_CONF_FILE = '/etc/resolv.conf'
HOSTS_FILE = '/etc/hosts'

# Jinja2 template for /etc/hosts.
# It is rendered with the whole state dict as its context, so it reads
# "host_name", "domain_name" and "hosts" (a mapping of host name to
# {"address": ..., "aliases": [...]}) from there.
hosts_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###

# Local host
127.0.0.1       localhost
127.0.1.1       {{ host_name }}{% if domain_name %}.{{ domain_name }}{% endif %}

# The following lines are desirable for IPv6 capable hosts
::1             localhost ip6-localhost ip6-loopback
fe00::0         ip6-localnet
ff00::0         ip6-mcastprefix
ff02::1         ip6-allnodes
ff02::2         ip6-allrouters

# From DHCP and "system static host-mapping"
{%- if hosts %}
{% for h in hosts -%}
{{hosts[h]['address']}}\t{{h}}\t{% for a in hosts[h]['aliases'] %} {{a}} {% endfor %}
{% endfor %}
{%- endif %}
"""

# Compiled once at start-up and reused for every render
hosts_tmpl = jinja2.Template(hosts_tmpl_source)

# Jinja2 template for /etc/resolv.conf.
# It is rendered with the whole state dict as its context: iterating
# "name_servers" yields the server addresses (the mapping's keys), while
# "domain_name" and "search_domains" are only emitted when non-empty.
resolv_tmpl_source = """
### Autogenerated by VyOS ###
### Do not edit, your changes will get overwritten ###

{% for ns in name_servers -%}
nameserver {{ns}}
{% endfor -%}

{%- if domain_name %}
domain {{ domain_name }}
{%- endif %}

{%- if search_domains %}
search {{ search_domains | join(" ") }}
{%- endif %}

"""

# Compiled once at start-up and reused for every render
resolv_tmpl = jinja2.Template(resolv_tmpl_source)

# The state data includes a mapping of name servers
# and a mapping of hosts entries.
#
# Name servers are keyed by the server address:
# {"<server>": {"tag": <str>}}
#
# Hosts entries are keyed by the host name:
# {"<host>": {"tag": <str>, "address": <str>, "aliases": <str list>}}
#
# The tag is either "static" or "dhcp-<intf>".
# It's used to distinguish entries created
# by different scripts so that they can be removed
# and re-created without having to track what needs
# to be changed.
#
# "host_name", "domain_name" and "search_domains" carry no tag;
# they are changed through the "set" operation.
STATE = {
    "name_servers": {},
    "hosts": {},
    "host_name": "vyos",
    "domain_name": "",
    "search_domains": []}


def make_resolv_conf(data):
    """Render the resolv.conf template with ``data`` and overwrite RESOLV_CONF_FILE."""
    # Render before opening the file so that a template error
    # doesn't leave a truncated file behind
    rendered = resolv_tmpl.render(data)
    print("Writing /etc/resolv.conf")
    with open(RESOLV_CONF_FILE, 'w') as out:
        out.write(rendered)

def make_hosts_file(state):
    """Render the hosts template with ``state`` and overwrite HOSTS_FILE."""
    print("Writing /etc/hosts")
    # Render before opening the file so that a template error
    # doesn't leave a truncated file behind
    rendered = hosts_tmpl.render(state)
    with open(HOSTS_FILE, 'w') as out:
        out.write(rendered)

def add_hosts(data, entries, tag):
    """Add or replace hosts entries in the state, marking them with ``tag``.

    Args:
        data: State dict; its ``'hosts'`` mapping is modified in place.
        entries: Iterable of dicts with the keys ``'host'`` and ``'address'``
            and, optionally, ``'aliases'`` (a missing or empty value is
            stored as an empty list). A falsy value is a no-op.
        tag: Owner tag stored with each entry, so that the entries
            can later be removed by tag.

    Raises:
        ValueError: If an entry is not a dict or lacks ``'host'`` or
            ``'address'``. The state is left untouched in that case.
    """
    hosts = data['hosts']

    if not entries:
        return

    # Validate and build everything first, so that a malformed entry
    # in the middle of the list can't leave the state half-updated
    new_hosts = {}
    for e in entries:
        if not isinstance(e, dict):
            raise ValueError("Malformed hosts entry {0!r}".format(e))
        for key in ('host', 'address'):
            if key not in e:
                raise ValueError(
                    "Hosts entry is missing required key \"{0}\"".format(key))
        new_hosts[e['host']] = {
            'tag': tag,
            'address': e['address'],
            'aliases': e.get('aliases') or [],
        }

    hosts.update(new_hosts)

def delete_hosts(data, tag):
    """Remove every entry of ``data['hosts']`` whose tag equals ``tag``."""
    table = data['hosts']

    # A dict can't shrink while it's being iterated,
    # so the matching names are collected up front
    doomed = [name for name, entry in table.items() if entry['tag'] == tag]

    for name in doomed:
        del table[name]

def add_name_servers(data, entries, tag):
    """Register each server in ``entries`` under ``data['name_servers']`` with ``tag``.

    A server that is already present is re-tagged. A falsy ``entries``
    leaves the state unchanged.
    """
    servers = data['name_servers']

    for server in entries or ():
        servers[server] = {'tag': tag}

def delete_name_servers(data, tag):
    """Remove every entry of ``data['name_servers']`` whose tag equals ``tag``."""
    servers = data['name_servers']

    # A dict can't shrink while it's being iterated,
    # so the matching servers are collected up front
    doomed = [addr for addr, meta in servers.items() if meta['tag'] == tag]

    for addr in doomed:
        del servers[addr]

def set_host_name(state, data):
    """Update host name, domain name and search domains in ``state``.

    Args:
        state: State dict, modified in place.
        data: Dict that may contain ``'host_name'``, ``'domain_name'``
            and ``'search_domains'``. Keys that are absent leave the
            corresponding state value untouched.

    Raises:
        ValueError: If ``data`` is not a dict.
    """
    if not isinstance(data, dict):
        raise ValueError("Host name data must be a dictionary")

    # A system always has a host name, so an empty one is ignored
    host_name = data.get('host_name')
    if host_name:
        state['host_name'] = host_name

    # Unlike the host name, these two may legitimately be unset: an empty
    # value sent by the client has to clear the old one, otherwise a
    # removed domain name would stay in resolv.conf forever
    if 'domain_name' in data:
        state['domain_name'] = data['domain_name'] or ""
    if 'search_domains' in data:
        state['search_domains'] = data['search_domains'] or []

def get_name_servers(state, tag):
    """Return the list of name servers in ``state`` that carry ``tag``."""
    servers = state['name_servers']
    return [addr for addr, meta in servers.items() if meta['tag'] == tag]

def get_option(msg, key):
    """Return ``msg[key]``, raising ValueError when the key is absent."""
    if key not in msg:
        raise ValueError('Missing required option "{0}"'.format(key))
    return msg[key]

def handle_message(msg_json):
    """Decode one JSON request, apply it to STATE and regenerate the files.

    Every request needs the options "op" and "type". Supported requests:

      * op "add"    (type "name_servers" or "hosts"), needs "tag" and "data"
      * op "delete" (type "name_servers" or "hosts"), needs "tag"
      * op "set"    (type "host_name"), needs "data"
      * op "get"    (type "name_servers"), needs "tag"

    A "get" request returns the matching name servers and changes nothing.
    All other requests rewrite resolv.conf, the hosts file and the state
    checkpoint, and return None.

    Raises ValueError for a missing option, an unknown operation or an
    unknown type.
    """
    request = json.loads(msg_json)

    op = get_option(request, 'op')
    _type = get_option(request, 'type')

    # Queries are read-only, so they return before anything is written
    if op == 'get':
        tag = get_option(request, 'tag')
        if _type != 'name_servers':
            raise ValueError("Unimplemented")
        return get_name_servers(STATE, tag)

    if op == 'add':
        tag = get_option(request, 'tag')
        entries = get_option(request, 'data')
        if _type == 'name_servers':
            add_name_servers(STATE, entries, tag)
        elif _type == 'hosts':
            add_hosts(STATE, entries, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'delete':
        tag = get_option(request, 'tag')
        if _type == 'name_servers':
            delete_name_servers(STATE, tag)
        elif _type == 'hosts':
            delete_hosts(STATE, tag)
        else:
            raise ValueError("Unknown message type {0}".format(_type))
    elif op == 'set':
        # Host name/domain name/search domain are set without a tag,
        # there can be only one anyway
        payload = get_option(request, 'data')
        if _type != 'host_name':
            raise ValueError("Unknown message type {0}".format(_type))
        set_host_name(STATE, payload)
    else:
        raise ValueError("Unknown operation {0}".format(op))

    # The state has changed: regenerate both files and checkpoint it
    make_resolv_conf(STATE)
    make_hosts_file(STATE)

    print("Saving state to {0}".format(STATE_FILE))
    with open(STATE_FILE, 'w') as checkpoint:
        json.dump(STATE, checkpoint)

def exit_handler(sig, frame):
    """Signal handler: remove the state checkpoint and exit with status 0.

    Args:
        sig: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).
    """
    print("Cleaning up state")
    try:
        os.unlink(STATE_FILE)
    except FileNotFoundError:
        # The checkpoint is only written once a state-changing request
        # has been handled, so it may not exist yet. A missing file must
        # not turn a clean shutdown into a crash.
        pass
    sys.exit(0)


if __name__ == '__main__':
    signal.signal(signal.SIGTERM, exit_handler)

    # Create a directory for state checkpoints
    os.makedirs(DATA_DIR, exist_ok=True)
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, 'r') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("State file does not contain a JSON object")
            # Merge into the defaults instead of replacing them, so that
            # a checkpoint lacking a key can't cause KeyErrors later on
            STATE.update(data)
        except Exception:
            print(traceback.format_exc())
            print("Failed to load the state file, using default")

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(SOCKET_PATH)

    while True:
        #  Wait for next request from client
        raw_message = socket.recv()
        print("Received a configuration change request")

        resp = {}

        try:
            # Decoding happens inside the try block: a request that isn't
            # valid UTF-8 raises UnicodeDecodeError (a ValueError), which
            # must produce an error reply rather than kill the daemon
            message = raw_message.decode()
            if debug:
                print("Request data: {0}".format(message))

            result = handle_message(message)
            resp['data'] = result
        except ValueError as e:
            resp['error'] = str(e)
        except Exception:
            # Deliberately not a bare "except:": the SystemExit raised by
            # the SIGTERM handler has to propagate, or a signal arriving
            # while a request is being handled would be swallowed
            print(traceback.format_exc())
            resp['error'] = "Internal error"

        if debug:
            print("Sent response: {0}".format(resp))

        #  Send reply back to client
        socket.send(json.dumps(resp).encode())