summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/distros/helpers.py179
-rw-r--r--cloudinit/distros/rhel.py32
-rw-r--r--cloudinit/util.py10
-rw-r--r--tests/unittests/test_distros/test_resolv.py61
4 files changed, 273 insertions, 9 deletions
diff --git a/cloudinit/distros/helpers.py b/cloudinit/distros/helpers.py
new file mode 100644
index 00000000..e1db74dc
--- /dev/null
+++ b/cloudinit/distros/helpers.py
@@ -0,0 +1,179 @@
+# vi: ts=4 expandtab
+#
+# Copyright (C) 2012 Canonical Ltd.
+# Copyright (C) 2012 Yahoo! Inc.
+#
+# Author: Scott Moser <scott.moser@canonical.com>
+# Author: Joshua Harlow <harlowja@yahoo-inc.com>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 3, as
+# published by the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from StringIO import StringIO
+
+from cloudinit import util
+
+
+# See: man resolv.conf
+class ResolvConf(object):
+ def __init__(self, text):
+ self._text = text
+ self._contents = None
+
+ def parse(self):
+ if self._contents is None:
+ self._contents = self._parse(self._text)
+
+ @property
+ def nameservers(self):
+ self.parse()
+ return self._retr_option('nameserver')
+
+ @property
+ def local_domain(self):
+ self.parse()
+ dm = self._retr_option('domain')
+ if dm:
+ return dm[0]
+ return None
+
+ @property
+ def search_domains(self):
+ self.parse()
+ current_sds = self._retr_option('search')
+ flat_sds = []
+ for sdlist in current_sds:
+ for sd in sdlist.split(None):
+ if sd:
+ flat_sds.append(sd)
+ return flat_sds
+
+ def __str__(self):
+ self.parse()
+ contents = StringIO()
+ for (line_type, components) in self._contents:
+ if line_type == 'blank':
+ contents.write("\n")
+ elif line_type == 'all_comment':
+ contents.write("%s\n" % (components[0]))
+ elif line_type == 'option':
+ (cfg_opt, cfg_value, comment_tail) = components
+ line = "%s %s" % (cfg_opt, cfg_value)
+ if len(comment_tail):
+ line += comment_tail
+ contents.write("%s\n" % (line))
+ return contents.getvalue()
+
+ def _retr_option(self, opt_name):
+ found = []
+ for (line_type, components) in self._contents:
+ if line_type == 'option':
+ (cfg_opt, cfg_value, comment_tail) = components
+ if cfg_opt == opt_name:
+ found.append(cfg_value)
+ return found
+
+ def add_nameserver(self, ns):
+ self.parse()
+ current_ns = self._retr_option('nameserver')
+ new_ns = list(current_ns)
+ new_ns.append(str(ns))
+ new_ns = util.uniq_list(new_ns)
+ if len(new_ns) == len(current_ns):
+ return current_ns
+ if len(current_ns) >= 3:
+ # Hard restriction on only 3 name servers
+ raise ValueError(("Adding %r would go beyond the "
+ "'3' maximum name servers") % (ns))
+ self._remove_option('nameserver')
+ for n in new_ns:
+ self._contents.append(('option', ['nameserver', n, '']))
+ return new_ns
+
+ def _remove_option(self, opt_name):
+
+ def remove_opt(item):
+ line_type, components = item
+ if line_type != 'option':
+ return True
+ (cfg_opt, cfg_value, comment_tail) = components
+ if cfg_opt != opt_name:
+ return True
+ return False
+
+ new_contents = filter(remove_opt, self._contents)
+ self._contents = new_contents
+
+ def add_search_domain(self, search_domain):
+ flat_sds = self.search_domains
+ new_sds = list(flat_sds)
+ new_sds.append(str(search_domain))
+ new_sds = util.uniq_list(new_sds)
+ if len(flat_sds) == len(new_sds):
+ return new_sds
+ if len(flat_sds) >= 6:
+ # Hard restriction on only 6 search domains
+ raise ValueError(("Adding %r would go beyond the "
+ "'6' maximum search domains") % (search_domain))
+ s_list = " ".join(new_sds)
+ if len(s_list) > 256:
+ # Some hard limit on 256 chars total
+ raise ValueError(("Adding %r would go beyond the "
+ "256 maximum search list character limit")
+ % (search_domain))
+ self._remove_option('search')
+ self._contents.append(('option', ['search', s_list, '']))
+ return flat_sds
+
+ @local_domain.setter
+ def local_domain(self, domain):
+ self.parse()
+ self._remove_option('domain')
+ self._contents.append(('option', ['domain', str(domain), '']))
+ return domain
+
+ def _parse(self, contents):
+ entries = []
+ for (i, line) in enumerate(contents.splitlines()):
+ sline = line.strip()
+ if not sline:
+ entries.append(('blank', [line]))
+ continue
+ comment_s_loc = sline.find(";")
+ comment_h_loc = sline.find("#")
+ comment_loc = -1
+ if comment_s_loc != -1 and comment_h_loc != -1:
+ comment_loc = min(comment_h_loc, comment_s_loc)
+ elif comment_s_loc != -1:
+ comment_loc = comment_s_loc
+ elif comment_h_loc != -1:
+ comment_loc = comment_h_loc
+ head = line
+ tail = None
+ if comment_loc != -1:
+ head = line[:comment_loc]
+ tail = line[comment_loc:]
+ if not len(head.strip()):
+ entries.append(('all_comment', [line]))
+ continue
+ if not tail:
+ tail = ''
+ try:
+ (cfg_opt, cfg_values) = head.split(None, 1)
+ except (IndexError, ValueError):
+ raise IOError("Incorrectly formatted resolv.conf line %s" % (i + 1))
+ if cfg_opt not in ('nameserver', 'domain', 'search', 'sortlist', 'options'):
+ raise IOError("Unexpected resolv.conf option %s" % (cfg_opt))
+ entries.append(("option", [cfg_opt, cfg_values, tail]))
+ return entries
+
+
diff --git a/cloudinit/distros/rhel.py b/cloudinit/distros/rhel.py
index ec4dc2cc..1c9d493d 100644
--- a/cloudinit/distros/rhel.py
+++ b/cloudinit/distros/rhel.py
@@ -23,6 +23,8 @@
import os
from cloudinit import distros
+from cloudinit.distros import helpers as d_helpers
+
from cloudinit import helpers
from cloudinit import log as logging
from cloudinit import util
@@ -68,17 +70,29 @@ class Distro(distros.Distro):
def install_packages(self, pkglist):
self.package_command('install', pkglist)
- def _write_resolve(self, dns_servers, search_servers):
- contents = []
+ def _adjust_resolve(self, dns_servers, search_servers):
+ r_conf = d_helpers.ResolvConf(util.load_file("/etc/resolv.conf"))
+ try:
+ r_conf.parse()
+ except IOError:
+ util.logexc(LOG,
+ "Failed at parsing %s reverting to an empty instance",
+ "/etc/resolv.conf")
+ r_conf = d_helpers.ResolvConf('')
+ r_conf.parse()
if dns_servers:
for s in dns_servers:
- contents.append("nameserver %s" % (s))
+ try:
+ r_conf.add_nameserver(s)
+ except ValueError:
+ util.logexc(LOG, "Failed at adding nameserver %s", s)
if search_servers:
- contents.append("search %s" % (" ".join(search_servers)))
- if contents:
- resolve_rw_fn = self._paths.join(False, "/etc/resolv.conf")
- contents.insert(0, '# Created by cloud-init')
- util.write_file(resolve_rw_fn, "\n".join(contents), 0644)
+ for s in search_servers:
+ try:
+ r_conf.add_search_domain(s)
+ except ValueError:
+ util.logexc(LOG, "Failed at adding search domain %s", s)
+ util.write_file("/etc/resolv.conf", str(r_conf), 0644)
def _write_network(self, settings):
# TODO(harlowja) fix this... since this is the ubuntu format
@@ -126,7 +140,7 @@ class Distro(distros.Distro):
net_rw_fn = self._paths.join(False, net_fn)
util.write_file(net_rw_fn, w_contents, 0644)
if nameservers or searchservers:
- self._write_resolve(nameservers, searchservers)
+ self._adjust_resolve(nameservers, searchservers)
def set_hostname(self, hostname):
out_fn = self._paths.join(False, '/etc/sysconfig/network')
diff --git a/cloudinit/util.py b/cloudinit/util.py
index 33da73eb..e5fa42c6 100644
--- a/cloudinit/util.py
+++ b/cloudinit/util.py
@@ -952,6 +952,16 @@ def find_devs_with(criteria=None, oformat='device',
return entries
+def uniq_list(in_list):
+ out_list = []
+ for i in in_list:
+ if i in out_list:
+ continue
+ else:
+ out_list.append(i)
+ return out_list
+
+
def load_file(fname, read_cb=None, quiet=False):
LOG.debug("Reading from %s (quiet=%s)", fname, quiet)
ofh = StringIO()
diff --git a/tests/unittests/test_distros/test_resolv.py b/tests/unittests/test_distros/test_resolv.py
new file mode 100644
index 00000000..5f122833
--- /dev/null
+++ b/tests/unittests/test_distros/test_resolv.py
@@ -0,0 +1,61 @@
+from mocker import MockerTestCase
+
+from cloudinit.distros import helpers
+
+
+BASE_RESOLVE = '''
+; generated by /sbin/dhclient-script
+search blah.yahoo.com yahoo.com
+nameserver 10.15.44.14
+nameserver 10.15.30.92
+'''
+BASE_RESOLVE = BASE_RESOLVE.strip()
+
+
+class TestResolvHelper(MockerTestCase):
+ def test_parse_same(self):
+ rp = helpers.ResolvConf(BASE_RESOLVE)
+ rp_r = str(rp).strip()
+ self.assertEquals(BASE_RESOLVE, rp_r)
+
+ def test_local_domain(self):
+ rp = helpers.ResolvConf(BASE_RESOLVE)
+ self.assertEquals(None, rp.local_domain)
+
+ rp.local_domain = "bob"
+ self.assertEquals('bob', rp.local_domain)
+ self.assertIn('domain bob', str(rp))
+
+ def test_nameservers(self):
+ rp = helpers.ResolvConf(BASE_RESOLVE)
+ self.assertIn('10.15.44.14', rp.nameservers)
+ self.assertIn('10.15.30.92', rp.nameservers)
+ rp.add_nameserver('10.2')
+ self.assertIn('10.2', rp.nameservers)
+ self.assertIn('nameserver 10.2', str(rp))
+ self.assertNotIn('10.3', rp.nameservers)
+ self.assertEquals(len(rp.nameservers), 3)
+ rp.add_nameserver('10.2')
+ with self.assertRaises(ValueError):
+ rp.add_nameserver('10.3')
+ self.assertNotIn('10.3', rp.nameservers)
+
+ def test_search_domains(self):
+ rp = helpers.ResolvConf(BASE_RESOLVE)
+ self.assertIn('yahoo.com', rp.search_domains)
+ self.assertIn('blah.yahoo.com', rp.search_domains)
+ rp.add_search_domain('bbb.y.com')
+ self.assertIn('bbb.y.com', rp.search_domains)
+ self.assertRegexpMatches(str(rp), r'search(.*)bbb.y.com(.*)')
+ self.assertIn('bbb.y.com', rp.search_domains)
+ rp.add_search_domain('bbb.y.com')
+ self.assertEquals(len(rp.search_domains), 3)
+ rp.add_search_domain('bbb2.y.com')
+ self.assertEquals(len(rp.search_domains), 4)
+ rp.add_search_domain('bbb3.y.com')
+ self.assertEquals(len(rp.search_domains), 5)
+ rp.add_search_domain('bbb4.y.com')
+ self.assertEquals(len(rp.search_domains), 6)
+ with self.assertRaises(ValueError):
+ rp.add_search_domain('bbb5.y.com')
+ self.assertEquals(len(rp.search_domains), 6)