diff options
-rw-r--r-- | cloudinit/distros/helpers.py | 179 | ||||
-rw-r--r-- | cloudinit/distros/rhel.py | 32 | ||||
-rw-r--r-- | cloudinit/util.py | 10 | ||||
-rw-r--r-- | tests/unittests/test_distros/test_resolv.py | 61 |
4 files changed, 273 insertions, 9 deletions
diff --git a/cloudinit/distros/helpers.py b/cloudinit/distros/helpers.py new file mode 100644 index 00000000..e1db74dc --- /dev/null +++ b/cloudinit/distros/helpers.py @@ -0,0 +1,179 @@ +# vi: ts=4 expandtab +# +# Copyright (C) 2012 Canonical Ltd. +# Copyright (C) 2012 Yahoo! Inc. +# +# Author: Scott Moser <scott.moser@canonical.com> +# Author: Joshua Harlow <harlowja@yahoo-inc.com> +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3, as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from StringIO import StringIO + +from cloudinit import util + + +# See: man resolv.conf +class ResolvConf(object): + def __init__(self, text): + self._text = text + self._contents = None + + def parse(self): + if self._contents is None: + self._contents = self._parse(self._text) + + @property + def nameservers(self): + self.parse() + return self._retr_option('nameserver') + + @property + def local_domain(self): + self.parse() + dm = self._retr_option('domain') + if dm: + return dm[0] + return None + + @property + def search_domains(self): + self.parse() + current_sds = self._retr_option('search') + flat_sds = [] + for sdlist in current_sds: + for sd in sdlist.split(None): + if sd: + flat_sds.append(sd) + return flat_sds + + def __str__(self): + self.parse() + contents = StringIO() + for (line_type, components) in self._contents: + if line_type == 'blank': + contents.write("\n") + elif line_type == 'all_comment': + contents.write("%s\n" % (components[0])) + elif line_type == 'option': + (cfg_opt, cfg_value, comment_tail) = components + line = "%s %s" % (cfg_opt, cfg_value) + if len(comment_tail): + line += comment_tail + contents.write("%s\n" % (line)) + return contents.getvalue() + + def _retr_option(self, opt_name): + found = [] + for (line_type, components) in self._contents: + if line_type == 'option': + (cfg_opt, cfg_value, comment_tail) = components + if cfg_opt == opt_name: + found.append(cfg_value) + return found + + def add_nameserver(self, ns): + self.parse() + current_ns = self._retr_option('nameserver') + new_ns = list(current_ns) + new_ns.append(str(ns)) + new_ns = util.uniq_list(new_ns) + if len(new_ns) == len(current_ns): + return current_ns + if len(current_ns) >= 3: + # Hard restriction on only 3 name servers + raise ValueError(("Adding %r would go beyond the " + "'3' maximum name servers") % (ns)) + self._remove_option('nameserver') + for n in new_ns: + self._contents.append(('option', ['nameserver', n, ''])) + return new_ns + + def _remove_option(self, opt_name): + + def remove_opt(item): + line_type, components = item + if line_type != 'option': + return True + (cfg_opt, cfg_value, comment_tail) = components + if cfg_opt != opt_name: + return True + return False + + new_contents = filter(remove_opt, self._contents) + self._contents = new_contents + + def add_search_domain(self, search_domain): + flat_sds = self.search_domains + new_sds = list(flat_sds) + new_sds.append(str(search_domain)) + new_sds = util.uniq_list(new_sds) + if len(flat_sds) == len(new_sds): + return new_sds + if len(flat_sds) >= 6: + # Hard restriction on only 6 search domains + raise ValueError(("Adding %r would go beyond the " + "'6' maximum search domains") % (search_domain)) + s_list = " ".join(new_sds) + if len(s_list) > 256: + # Some hard limit on 256 chars total + raise ValueError(("Adding %r would go beyond the " + "256 maximum search list character limit") + % (search_domain)) + self._remove_option('search') + self._contents.append(('option', ['search', s_list, ''])) + return flat_sds + + @local_domain.setter + def local_domain(self, domain): + self.parse() + self._remove_option('domain') + self._contents.append(('option', ['domain', str(domain), ''])) + return domain + + def _parse(self, contents): + entries = [] + for (i, line) in enumerate(contents.splitlines()): + sline = line.strip() + if not sline: + entries.append(('blank', [line])) + continue + comment_s_loc = sline.find(";") + comment_h_loc = sline.find("#") + comment_loc = -1 + if comment_s_loc != -1 and comment_h_loc != -1: + comment_loc = min(comment_h_loc, comment_s_loc) + elif comment_s_loc != -1: + comment_loc = comment_s_loc + elif comment_h_loc != -1: + comment_loc = comment_h_loc + head = line + tail = None + if comment_loc != -1: + head = line[:comment_loc] + tail = line[comment_loc:] + if not len(head.strip()): + entries.append(('all_comment', [line])) + continue + if not tail: + tail = '' + try: + (cfg_opt, cfg_values) = head.split(None, 1) + except (IndexError, ValueError): + raise IOError("Incorrectly formatted resolv.conf line %s" % (i + 1)) + if cfg_opt not in ('nameserver', 'domain', 'search', 'sortlist', 'options'): + raise IOError("Unexpected resolv.conf option %s" % (cfg_opt)) + entries.append(("option", [cfg_opt, cfg_values, tail])) + return entries + + diff --git a/cloudinit/distros/rhel.py b/cloudinit/distros/rhel.py index ec4dc2cc..1c9d493d 100644 --- a/cloudinit/distros/rhel.py +++ b/cloudinit/distros/rhel.py @@ -23,6 +23,8 @@ import os from cloudinit import distros +from cloudinit.distros import helpers as d_helpers + from cloudinit import helpers from cloudinit import log as logging from cloudinit import util @@ -68,17 +70,29 @@ class Distro(distros.Distro): def install_packages(self, pkglist): self.package_command('install', pkglist) - def _write_resolve(self, dns_servers, search_servers): - contents = [] + def _adjust_resolve(self, dns_servers, search_servers): + r_conf = d_helpers.ResolvConf(util.load_file("/etc/resolv.conf")) + try: + r_conf.parse() + except IOError: + util.logexc(LOG, + "Failed at parsing %s reverting to an empty instance", + "/etc/resolv.conf") + r_conf = d_helpers.ResolvConf('') + r_conf.parse() if dns_servers: for s in dns_servers: - contents.append("nameserver %s" % (s)) + try: + r_conf.add_nameserver(s) + except ValueError: + util.logexc(LOG, "Failed at adding nameserver %s", s) if search_servers: - contents.append("search %s" % (" ".join(search_servers))) - if contents: - resolve_rw_fn = self._paths.join(False, "/etc/resolv.conf") - contents.insert(0, '# Created by cloud-init') - util.write_file(resolve_rw_fn, "\n".join(contents), 0644) + for s in search_servers: + try: + r_conf.add_search_domain(s) + except ValueError: + util.logexc(LOG, "Failed at adding search domain %s", s) + util.write_file("/etc/resolv.conf", str(r_conf), 0644) def _write_network(self, settings): # TODO(harlowja) fix this... since this is the ubuntu format @@ -126,7 +140,7 @@ class Distro(distros.Distro): net_rw_fn = self._paths.join(False, net_fn) util.write_file(net_rw_fn, w_contents, 0644) if nameservers or searchservers: - self._write_resolve(nameservers, searchservers) + self._adjust_resolve(nameservers, searchservers) def set_hostname(self, hostname): out_fn = self._paths.join(False, '/etc/sysconfig/network') diff --git a/cloudinit/util.py b/cloudinit/util.py index 33da73eb..e5fa42c6 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -952,6 +952,16 @@ def find_devs_with(criteria=None, oformat='device', return entries +def uniq_list(in_list): + out_list = [] + for i in in_list: + if i in out_list: + continue + else: + out_list.append(i) + return out_list + + def load_file(fname, read_cb=None, quiet=False): LOG.debug("Reading from %s (quiet=%s)", fname, quiet) ofh = StringIO() diff --git a/tests/unittests/test_distros/test_resolv.py b/tests/unittests/test_distros/test_resolv.py new file mode 100644 index 00000000..5f122833 --- /dev/null +++ b/tests/unittests/test_distros/test_resolv.py @@ -0,0 +1,61 @@ +from mocker import MockerTestCase + +from cloudinit.distros import helpers + + +BASE_RESOLVE = ''' +; generated by /sbin/dhclient-script +search blah.yahoo.com yahoo.com +nameserver 10.15.44.14 +nameserver 10.15.30.92 +''' +BASE_RESOLVE = BASE_RESOLVE.strip() + + +class TestResolvHelper(MockerTestCase): + def test_parse_same(self): + rp = helpers.ResolvConf(BASE_RESOLVE) + rp_r = str(rp).strip() + self.assertEquals(BASE_RESOLVE, rp_r) + + def test_local_domain(self): + rp = helpers.ResolvConf(BASE_RESOLVE) + self.assertEquals(None, rp.local_domain) + + rp.local_domain = "bob" + self.assertEquals('bob', rp.local_domain) + self.assertIn('domain bob', str(rp)) + + def test_nameservers(self): + rp = helpers.ResolvConf(BASE_RESOLVE) + self.assertIn('10.15.44.14', rp.nameservers) + self.assertIn('10.15.30.92', rp.nameservers) + rp.add_nameserver('10.2') + self.assertIn('10.2', rp.nameservers) + self.assertIn('nameserver 10.2', str(rp)) + self.assertNotIn('10.3', rp.nameservers) + self.assertEquals(len(rp.nameservers), 3) + rp.add_nameserver('10.2') + with self.assertRaises(ValueError): + rp.add_nameserver('10.3') + self.assertNotIn('10.3', rp.nameservers) + + def test_search_domains(self): + rp = helpers.ResolvConf(BASE_RESOLVE) + self.assertIn('yahoo.com', rp.search_domains) + self.assertIn('blah.yahoo.com', rp.search_domains) + rp.add_search_domain('bbb.y.com') + self.assertIn('bbb.y.com', rp.search_domains) + self.assertRegexpMatches(str(rp), r'search(.*)bbb.y.com(.*)') + self.assertIn('bbb.y.com', rp.search_domains) + rp.add_search_domain('bbb.y.com') + self.assertEquals(len(rp.search_domains), 3) + rp.add_search_domain('bbb2.y.com') + self.assertEquals(len(rp.search_domains), 4) + rp.add_search_domain('bbb3.y.com') + self.assertEquals(len(rp.search_domains), 5) + rp.add_search_domain('bbb4.y.com') + self.assertEquals(len(rp.search_domains), 6) + with self.assertRaises(ValueError): + rp.add_search_domain('bbb5.y.com') + self.assertEquals(len(rp.search_domains), 6) |