diff options
| -rw-r--r-- | ChangeLog | 9 | ||||
| -rw-r--r-- | cloudinit/distros/__init__.py | 187 | ||||
| -rw-r--r-- | cloudinit/distros/debian.py | 115 | ||||
| -rw-r--r-- | cloudinit/distros/parsers/__init__.py | 28 | ||||
| -rw-r--r-- | cloudinit/distros/parsers/hostname.py | 88 | ||||
| -rw-r--r-- | cloudinit/distros/parsers/hosts.py | 93 | ||||
| -rw-r--r-- | cloudinit/distros/parsers/resolv_conf.py | 169 | ||||
| -rw-r--r-- | cloudinit/distros/parsers/sys_conf.py | 113 | ||||
| -rw-r--r-- | cloudinit/distros/rhel.py | 186 | ||||
| -rw-r--r-- | cloudinit/util.py | 25 | ||||
| -rw-r--r-- | tests/unittests/test_distros/test_hostname.py | 38 | ||||
| -rw-r--r-- | tests/unittests/test_distros/test_hosts.py | 41 | ||||
| -rw-r--r-- | tests/unittests/test_distros/test_netconfig.py | 7 | ||||
| -rw-r--r-- | tests/unittests/test_distros/test_resolv.py | 63 | ||||
| -rw-r--r-- | tests/unittests/test_distros/test_sysconfig.py | 83 | 
15 files changed, 1006 insertions, 239 deletions
| @@ -41,6 +41,15 @@     metadata port the lazy loaded dictionary will continue working properly      instead of trying to make additional url calls which will fail (LP: #1068801)      - Added dependency on distribute's python-pkg-resources + - use a set of helper/parsing classes to perform system configuration +   file modification in a manner that provides a nice object oriented +   interface to those objects as well as makes it possible to test +   those parsing entities without having to invoke distro class code. +    - Created parsers for: +     - /etc/sysconfig +     - /etc/hostname  +     - resolv.conf +     - /etc/hosts  0.7.0:   - add a 'exception_cb' argument to 'wait_for_url'.  If provided, this     method will be called back with the exception received and the message. diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 8a98e334..464ae550 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -24,6 +24,7 @@  from StringIO import StringIO  import abc +import collections  import itertools  import os  import re @@ -33,11 +34,18 @@ from cloudinit import log as logging  from cloudinit import ssh_util  from cloudinit import util +from cloudinit.distros.parsers import hosts +  LOG = logging.getLogger(__name__)  class Distro(object):      __metaclass__ = abc.ABCMeta +    default_user = None +    default_user_groups = None +    hosts_fn = "/etc/hosts" +    ci_sudoers_fn = "/etc/sudoers.d/90-cloud-init-users" +    hostname_conf_fn = "/etc/hostname"      def __init__(self, name, cfg, paths):          self._paths = paths @@ -57,13 +65,10 @@ class Distro(object):      def get_option(self, opt_name, default=None):          return self._cfg.get(opt_name, default) -    @abc.abstractmethod      def set_hostname(self, hostname, fqdn=None): -        raise NotImplementedError() - -    @abc.abstractmethod -    def update_hostname(self, hostname, fqdn, prev_hostname_fn): -        raise NotImplementedError() +        writeable_hostname = self._select_hostname(hostname, fqdn) +        self._write_hostname(writeable_hostname, self.hostname_conf_fn) +        self._apply_hostname(hostname)      @abc.abstractmethod      def package_command(self, cmd, args=None): @@ -87,7 +92,7 @@ class Distro(object):      def get_package_mirror_info(self, arch=None,                                  availability_zone=None): -        # this resolves the package_mirrors config option +        # This resolves the package_mirrors config option          # down to a single dict of {mirror_name: mirror_url}          arch_info = self._get_arch_package_mirror_info(arch)          return _get_package_mirror_info(availability_zone=availability_zone, @@ -112,41 +117,110 @@ class Distro(object):      def _get_localhost_ip(self):          return "127.0.0.1" +    @abc.abstractmethod +    def _read_hostname(self, filename, default=None): +        raise NotImplementedError() + +    @abc.abstractmethod +    def _write_hostname(self, hostname, filename): +        raise NotImplementedError() + +    @abc.abstractmethod +    def _read_system_hostname(self): +        raise NotImplementedError() + +    def _apply_hostname(self, hostname): +        # This really only sets the hostname +        # temporarily (until reboot so it should +        # not be depended on). Use the write +        # hostname functions for 'permanent' adjustments. +        LOG.debug("Non-persistently setting the system hostname to %s", +                  hostname) +        try: +            util.subp(['hostname', hostname]) +        except util.ProcessExecutionError: +            util.logexc(LOG, ("Failed to non-persistently adjust" +                              " the system hostname to %s"), hostname) + +    @abc.abstractmethod +    def _select_hostname(self, hostname, fqdn): +        raise NotImplementedError() + +    def update_hostname(self, hostname, fqdn, +                        previous_hostname_filename): +        applying_hostname = hostname +        hostname = self._select_hostname(hostname, fqdn) +        prev_hostname = self._read_hostname(prev_hostname_fn) +        (sys_fn, sys_hostname) = self._read_system_hostname() +        update_files = [] +        if not prev_hostname or prev_hostname != hostname: +            update_files.append(prev_hostname_fn) + +        if (not sys_hostname) or (sys_hostname == prev_hostname +                                  and sys_hostname != hostname): +            update_files.append(sys_fn) + +        update_files = set([f for f in update_files if f]) +        LOG.debug("Attempting to update hostname to %s in %s files", +                  hostname, len(update_files)) + +        for fn in update_files: +            try: +                self._write_hostname(hostname, fn) +            except IOError: +                util.logexc(LOG, "Failed to write hostname %s to %s", +                            hostname, fn) + +        if (sys_hostname and prev_hostname and +            sys_hostname != prev_hostname): +            LOG.debug("%s differs from %s, assuming user maintained hostname.", +                       prev_hostname_fn, sys_fn) + +        if sys_fn in update_files: +            self._apply_hostname(applying_hostname) +      def update_etc_hosts(self, hostname, fqdn): -        # Format defined at -        # http://unixhelp.ed.ac.uk/CGI/man-cgi?hosts -        header = "# Added by cloud-init" -        real_header = "%s on %s" % (header, util.time_rfc2822()) +        header = '' +        if os.path.exists(self.hosts_fn): +            eh = hosts.HostsConf(util.load_file(self.hosts_fn)) +        else: +            eh = hosts.HostsConf('') +            header = util.make_header(base="added")          local_ip = self._get_localhost_ip() -        hosts_line = "%s\t%s %s" % (local_ip, fqdn, hostname) -        new_etchosts = StringIO() -        need_write = False -        need_change = True -        for line in util.load_file("/etc/hosts").splitlines(): -            if line.strip().startswith(header): -                continue -            if not line.strip() or line.strip().startswith("#"): -                new_etchosts.write("%s\n" % (line)) -                continue -            split_line = [s.strip() for s in line.split()] -            if len(split_line) < 2: -                new_etchosts.write("%s\n" % (line)) -                continue -            (ip, hosts) = split_line[0], split_line[1:] -            if ip == local_ip: -                if sorted([hostname, fqdn]) == sorted(hosts): -                    need_change = False -                if need_change: -                    line = "%s\n%s" % (real_header, hosts_line) -                    need_change = False -                    need_write = True -            new_etchosts.write("%s\n" % (line)) +        prev_info = eh.get_entry(local_ip) +        need_change = False +        if not prev_info: +            eh.add_entry(local_ip, fqdn, hostname) +            need_change = True +        else: +            need_change = True +            for entry in prev_info: +                entry_fqdn = None +                entry_aliases = [] +                if len(entry) >= 1: +                    entry_fqdn = entry[0] +                if len(entry) >= 2: +                    entry_aliases = entry[1:] +                if entry_fqdn is not None and entry_fqdn == fqdn: +                    if hostname in entry_aliases: +                        # Exists already, leave it be +                        need_change = False +            if need_change: +                # Doesn't exist, add that entry in... +                new_entries = list(prev_info) +                new_entries.append([fqdn, hostname]) +                eh.del_entries(local_ip) +                for entry in new_entries: +                    if len(entry) == 1: +                        eh.add_entry(local_ip, entry[0]) +                    elif len(entry) >= 2: +                        eh.add_entry(local_ip, *entry)          if need_change: -            new_etchosts.write("%s\n%s\n" % (real_header, hosts_line)) -            need_write = True -        if need_write: -            contents = new_etchosts.getvalue() -            util.write_file("/etc/hosts", contents, mode=0644) +            contents = StringIO() +            if header: +                contents.write("%s\n" % (header)) +            contents.write("%s\n" % (eh)) +            util.write_file(self.hosts_fn, contents.getvalue(), mode=0644)      def _bring_up_interface(self, device_name):          cmd = ['ifup', device_name] @@ -305,12 +379,12 @@ class Distro(object):                  if not base_exists:                      lines = [('# See sudoers(5) for more information'                                ' on "#include" directives:'), '', -                             '# Added by cloud-init', +                             util.make_header(base="added"),                               "#includedir %s" % (path), '']                      sudoers_contents = "\n".join(lines)                      util.write_file(sudo_base, sudoers_contents, 0440)                  else: -                    lines = ['', '# Added by cloud-init', +                    lines = ['', util.make_header(base="added"),                               "#includedir %s" % (path), '']                      sudoers_contents = "\n".join(lines)                      util.append_file(sudo_base, sudoers_contents) @@ -322,26 +396,35 @@ class Distro(object):      def write_sudo_rules(self, user, rules, sudo_file=None):          if not sudo_file: -            sudo_file = "/etc/sudoers.d/90-cloud-init-users" +            sudo_file = self.ci_sudoers_fn -        content_header = "# user rules for %s" % user -        content = "%s\n%s %s\n\n" % (content_header, user, rules) - -        if isinstance(rules, list): -            content = "%s\n" % content_header +        lines = [ +            '', +            "# User rules for %s" % user, +        ] +        if isinstance(rules, collections.Iterable):              for rule in rules: -                content += "%s %s\n" % (user, rule) -            content += "\n" +                lines.append("%s %s" % (user, rule)) +        else: +            lines.append("%s %s" % (user, rules)) +        content = "\n".join(lines)          self.ensure_sudo_dir(os.path.dirname(sudo_file)) -          if not os.path.exists(sudo_file): -            util.write_file(sudo_file, content, 0440) +            contents = [ +                util.make_header(), +                content, +            ] +            try: +                util.write_file(sudo_file, "\n".join(contents), 0440) +            except IOError as e: +                util.logexc(LOG, "Failed to write sudoers file %s", sudo_file) +                raise e          else:              try:                  util.append_file(sudo_file, content)              except IOError as e: -                util.logexc(LOG, "Failed to write %s" % sudo_file, e) +                util.logexc(LOG, "Failed to append sudoers file %s", sudo_file)                  raise e      def create_group(self, name, members): diff --git a/cloudinit/distros/debian.py b/cloudinit/distros/debian.py index ed4070b4..b6e7654f 100644 --- a/cloudinit/distros/debian.py +++ b/cloudinit/distros/debian.py @@ -27,12 +27,20 @@ from cloudinit import helpers  from cloudinit import log as logging  from cloudinit import util +from cloudinit.distros.parsers.hostname import HostnameConf +  from cloudinit.settings import PER_INSTANCE  LOG = logging.getLogger(__name__)  class Distro(distros.Distro): +    hostname_conf_fn = "/etc/hostname" +    locale_conf_fn = "/etc/default/locale" +    network_conf_fn = "/etc/network/interfaces" +    tz_conf_fn = "/etc/timezone" +    tz_local_fn = "/etc/localtime" +    tz_zone_dir = "/usr/share/zoneinfo"      def __init__(self, name, cfg, paths):          distros.Distro.__init__(self, name, cfg, paths) @@ -43,10 +51,15 @@ class Distro(distros.Distro):      def apply_locale(self, locale, out_fn=None):          if not out_fn: -            out_fn = '/etc/default/locale' +            out_fn = self.locale_conf_fn          util.subp(['locale-gen', locale], capture=False)          util.subp(['update-locale', locale], capture=False) -        lines = ["# Created by cloud-init", 'LANG="%s"' % (locale), ""] +        # "" provides trailing newline during join +        lines = [ +            util.make_header(), +            'LANG="%s"' % (locale), +            "", +        ]          util.write_file(out_fn, "\n".join(lines))      def install_packages(self, pkglist): @@ -54,7 +67,7 @@ class Distro(distros.Distro):          self.package_command('install', pkglist)      def _write_network(self, settings): -        util.write_file("/etc/network/interfaces", settings) +        util.write_file(self.network_conf_fn, settings)          return ['all']      def _bring_up_interfaces(self, device_names): @@ -67,64 +80,66 @@ class Distro(distros.Distro):          else:              return distros.Distro._bring_up_interfaces(self, device_names) -    def set_hostname(self, hostname, fqdn=None): -        self._write_hostname(hostname, "/etc/hostname") -        LOG.debug("Setting hostname to %s", hostname) -        util.subp(['hostname', hostname]) - -    def _write_hostname(self, hostname, out_fn): -        # "" gives trailing newline. -        util.write_file(out_fn, "%s\n" % str(hostname), 0644) - -    def update_hostname(self, hostname, fqdn, prev_fn): -        hostname_prev = self._read_hostname(prev_fn) -        hostname_in_etc = self._read_hostname("/etc/hostname") -        update_files = [] -        if not hostname_prev or hostname_prev != hostname: -            update_files.append(prev_fn) -        if (not hostname_in_etc or -            (hostname_in_etc == hostname_prev and -             hostname_in_etc != hostname)): -            update_files.append("/etc/hostname") -        for fn in update_files: -            try: -                self._write_hostname(hostname, fn) -            except: -                util.logexc(LOG, "Failed to write hostname %s to %s", -                            hostname, fn) -        if (hostname_in_etc and hostname_prev and -            hostname_in_etc != hostname_prev): -            LOG.debug(("%s differs from /etc/hostname." -                        " Assuming user maintained hostname."), prev_fn) -        if "/etc/hostname" in update_files: -            LOG.debug("Setting hostname to %s", hostname) -            util.subp(['hostname', hostname]) +    def _select_hostname(self, hostname, fqdn): +        # Prefer the short hostname over the long +        # fully qualified domain name +        if not hostname: +            return fqdn +        return hostname + +    def _write_hostname(self, your_hostname, out_fn): +        conf = self._read_hostname_conf(out_fn) +        if not conf: +            conf = HostnameConf('') +            conf.parse() +        conf.set_hostname(your_hostname) +        util.write_file(out_fn, str(conf), 0644) + +    def _read_system_hostname(self): +        conf = self._read_hostname_conf(self.hostname_conf_fn) +        if conf: +            sys_hostname = conf.hostname +        else: +            sys_hostname = None +        return (self.hostname_conf_fn, sys_hostname) + +    def _read_hostname_conf(self, filename): +        try: +            conf = HostnameConf(util.load_file(filename)) +            conf.parse() +            return conf +        except IOError: +            util.logexc(LOG, "Error reading hostname from %s", filename) +            return None      def _read_hostname(self, filename, default=None): -        contents = util.load_file(filename, quiet=True) -        for line in contents.splitlines(): -            c_pos = line.find("#") -            # Handle inline comments -            if c_pos != -1: -                line = line[0:c_pos] -            line_c = line.strip() -            if line_c: -                return line_c -        return default +        conf = self._read_hostname_conf(filename) +        if not conf: +            return default +        if not conf.hostname: +            return default +        return conf.hostname      def _get_localhost_ip(self):          # Note: http://www.leonardoborda.com/blog/127-0-1-1-ubuntu-debian/          return "127.0.1.1"      def set_timezone(self, tz): -        tz_file = os.path.join("/usr/share/zoneinfo", tz) +        # TODO(harlowja): move this code into +        # the parent distro... +        tz_file = os.path.join(self.tz_zone_dir, str(tz))          if not os.path.isfile(tz_file):              raise RuntimeError(("Invalid timezone %s,"                                  " no file found at %s") % (tz, tz_file)) -        # "" provides trailing newline during join -        tz_lines = ["# Created by cloud-init", str(tz), ""] -        util.write_file("/etc/timezone", "\n".join(tz_lines)) -        util.copy(tz_file, "/etc/localtime") +        # Note: "" provides trailing newline during join +        tz_lines = [ +            util.make_header(), +            str(tz), +            "", +        ] +        util.write_file(self.tz_conf_fn, "\n".join(tz_lines)) +        # This ensures that the correct tz will be used for the system +        util.copy(tz_file, self.tz_local_fn)      def package_command(self, command, args=None):          e = os.environ.copy() diff --git a/cloudinit/distros/parsers/__init__.py b/cloudinit/distros/parsers/__init__.py new file mode 100644 index 00000000..1c413eaa --- /dev/null +++ b/cloudinit/distros/parsers/__init__.py @@ -0,0 +1,28 @@ +# vi: ts=4 expandtab +# +#    Copyright (C) 2012 Yahoo! Inc. +# +#    Author: Joshua Harlow <harlowja@yahoo-inc.com> +# +#    This program is free software: you can redistribute it and/or modify +#    it under the terms of the GNU General Public License version 3, as +#    published by the Free Software Foundation. +# +#    This program is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    GNU General Public License for more details. +# +#    You should have received a copy of the GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. + + +def chop_comment(text, comment_chars): +    comment_locations = [text.find(c) for c in comment_chars] +    comment_locations = [c for c in comment_locations if c != -1] +    if not comment_locations: +        return (text, '') +    min_comment = min(comment_locations) +    before_comment = text[0:min_comment] +    comment = text[min_comment:] +    return (before_comment, comment) diff --git a/cloudinit/distros/parsers/hostname.py b/cloudinit/distros/parsers/hostname.py new file mode 100644 index 00000000..617b3c36 --- /dev/null +++ b/cloudinit/distros/parsers/hostname.py @@ -0,0 +1,88 @@ +# vi: ts=4 expandtab +# +#    Copyright (C) 2012 Yahoo! Inc. +# +#    Author: Joshua Harlow <harlowja@yahoo-inc.com> +# +#    This program is free software: you can redistribute it and/or modify +#    it under the terms of the GNU General Public License version 3, as +#    published by the Free Software Foundation. +# +#    This program is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    GNU General Public License for more details. +# +#    You should have received a copy of the GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. + +from StringIO import StringIO + +from cloudinit.distros.parsers import chop_comment + + +# Parser that knows how to work with /etc/hostname format +class HostnameConf(object): +    def __init__(self, text): +        self._text = text +        self._contents = None + +    def parse(self): +        if self._contents is None: +            self._contents = self._parse(self._text) + +    def __str__(self): +        self.parse() +        contents = StringIO() +        for (line_type, components) in self._contents: +            if line_type == 'blank': +                contents.write("%s\n" % (components[0])) +            elif line_type == 'all_comment': +                contents.write("%s\n" % (components[0])) +            elif line_type == 'hostname': +                (hostname, tail) = components +                contents.write("%s%s\n" % (hostname, tail)) +        # Ensure trailing newline +        contents = contents.getvalue() +        if not contents.endswith("\n"): +            contents += "\n" +        return contents + +    @property +    def hostname(self): +        self.parse() +        for (line_type, components) in self._contents: +            if line_type == 'hostname': +                return components[0] +        return None + +    def set_hostname(self, your_hostname): +        your_hostname = your_hostname.strip() +        if not your_hostname: +            return +        self.parse() +        replaced = False +        for (line_type, components) in self._contents: +            if line_type == 'hostname': +                components[0] = str(your_hostname) +                replaced = True +        if not replaced: +            self._contents.append(('hostname', [str(your_hostname), ''])) + +    def _parse(self, contents): +        entries = [] +        hostnames_found = set() +        for line in contents.splitlines(): +            if not len(line.strip()): +                entries.append(('blank', [line])) +                continue +            (head, tail) = chop_comment(line.strip(), '#') +            if not len(head): +                entries.append(('all_comment', [line])) +                continue +            entries.append(('hostname', [head, tail])) +            hostnames_found.add(head) +        if len(hostnames_found) > 1: +            raise IOError("Multiple hostnames (%s) found!" +                           % (hostnames_found)) +        return entries diff --git a/cloudinit/distros/parsers/hosts.py b/cloudinit/distros/parsers/hosts.py new file mode 100644 index 00000000..958a7c31 --- /dev/null +++ b/cloudinit/distros/parsers/hosts.py @@ -0,0 +1,93 @@ +# vi: ts=4 expandtab +# +#    Copyright (C) 2012 Yahoo! Inc. +# +#    Author: Joshua Harlow <harlowja@yahoo-inc.com> +# +#    This program is free software: you can redistribute it and/or modify +#    it under the terms of the GNU General Public License version 3, as +#    published by the Free Software Foundation. +# +#    This program is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    GNU General Public License for more details. +# +#    You should have received a copy of the GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. + +from StringIO import StringIO + +from cloudinit.distros.parsers import chop_comment + + +# See: man hosts +# or http://unixhelp.ed.ac.uk/CGI/man-cgi?hosts +# or http://tinyurl.com/6lmox3 +class HostsConf(object): +    def __init__(self, text): +        self._text = text +        self._contents = None + +    def parse(self): +        if self._contents is None: +            self._contents = self._parse(self._text) + +    def get_entry(self, ip): +        self.parse() +        options = [] +        for (line_type, components) in self._contents: +            if line_type == 'option': +                (pieces, _tail) = components +                if len(pieces) and pieces[0] == ip: +                    options.append(pieces[1:]) +        return options + +    def del_entries(self, ip): +        self.parse() +        n_entries = [] +        for (line_type, components) in self._contents: +            if line_type != 'option': +                n_entries.append((line_type, components)) +                continue +            else: +                (pieces, _tail) = components +                if len(pieces) and pieces[0] == ip: +                    pass +                elif len(pieces): +                    n_entries.append((line_type, list(components))) +        self._contents = n_entries + +    def add_entry(self, ip, canonical_hostname, *aliases): +        self.parse() +        self._contents.append(('option', +                              ([ip, canonical_hostname] + list(aliases), ''))) + +    def _parse(self, contents): +        entries = [] +        for line in contents.splitlines(): +            if not len(line.strip()): +                entries.append(('blank', [line])) +                continue +            (head, tail) = chop_comment(line.strip(), '#') +            if not len(head): +                entries.append(('all_comment', [line])) +                continue +            entries.append(('option', [head.split(None), tail])) +        return entries + +    def __str__(self): +        self.parse() +        contents = StringIO() +        for (line_type, components) in self._contents: +            if line_type == 'blank': +                contents.write("%s\n" % (components[0])) +            elif line_type == 'all_comment': +                contents.write("%s\n" % (components[0])) +            elif line_type == 'option': +                (pieces, tail) = components +                pieces = [str(p) for p in pieces] +                pieces = "\t".join(pieces) +                contents.write("%s%s\n" % (pieces, tail)) +        return contents.getvalue() + diff --git a/cloudinit/distros/parsers/resolv_conf.py b/cloudinit/distros/parsers/resolv_conf.py new file mode 100644 index 00000000..5733c25a --- /dev/null +++ b/cloudinit/distros/parsers/resolv_conf.py @@ -0,0 +1,169 @@ +# vi: ts=4 expandtab +# +#    Copyright (C) 2012 Yahoo! Inc. +# +#    Author: Joshua Harlow <harlowja@yahoo-inc.com> +# +#    This program is free software: you can redistribute it and/or modify +#    it under the terms of the GNU General Public License version 3, as +#    published by the Free Software Foundation. +# +#    This program is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    GNU General Public License for more details. +# +#    You should have received a copy of the GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. + +from StringIO import StringIO + +from cloudinit import util + +from cloudinit.distros.parsers import chop_comment + + +# See: man resolv.conf +class ResolvConf(object): +    def __init__(self, text): +        self._text = text +        self._contents = None + +    def parse(self): +        if self._contents is None: +            self._contents = self._parse(self._text) + +    @property +    def nameservers(self): +        self.parse() +        return self._retr_option('nameserver') + +    @property +    def local_domain(self): +        self.parse() +        dm = self._retr_option('domain') +        if dm: +            return dm[0] +        return None + +    @property +    def search_domains(self): +        self.parse() +        current_sds = self._retr_option('search') +        flat_sds = [] +        for sdlist in current_sds: +            for sd in sdlist.split(None): +                if sd: +                    flat_sds.append(sd) +        return flat_sds + +    def __str__(self): +        self.parse() +        contents = StringIO() +        for (line_type, components) in self._contents: +            if line_type == 'blank': +                contents.write("\n") +            elif line_type == 'all_comment': +                contents.write("%s\n" % (components[0])) +            elif line_type == 'option': +                (cfg_opt, cfg_value, comment_tail) = components +                line = "%s %s" % (cfg_opt, cfg_value) +                if len(comment_tail): +                    line += comment_tail +                contents.write("%s\n" % (line)) +        return contents.getvalue() + +    def _retr_option(self, opt_name): +        found = [] +        for (line_type, components) in self._contents: +            if line_type == 'option': +                (cfg_opt, cfg_value, _comment_tail) = components +                if cfg_opt == opt_name: +                    found.append(cfg_value) +        return found + +    def add_nameserver(self, ns): +        self.parse() +        current_ns = self._retr_option('nameserver') +        new_ns = list(current_ns) +        new_ns.append(str(ns)) +        new_ns = util.uniq_list(new_ns) +        if len(new_ns) == len(current_ns): +            return current_ns +        if len(current_ns) >= 3: +            # Hard restriction on only 3 name servers +            raise ValueError(("Adding %r would go beyond the " +                              "'3' maximum name servers") % (ns)) +        self._remove_option('nameserver') +        for n in new_ns: +            self._contents.append(('option', ['nameserver', n, ''])) +        return new_ns + +    def _remove_option(self, opt_name): + +        def remove_opt(item): +            line_type, components = item +            if line_type != 'option': +                return False +            (cfg_opt, _cfg_value, _comment_tail) = components +            if cfg_opt != opt_name: +                return False +            return True + +        new_contents = [] +        for c in self._contents: +            if not remove_opt(c): +                new_contents.append(c) +        self._contents = new_contents + +    def add_search_domain(self, search_domain): +        flat_sds = self.search_domains +        new_sds = list(flat_sds) +        new_sds.append(str(search_domain)) +        new_sds = util.uniq_list(new_sds) +        if len(flat_sds) == len(new_sds): +            return new_sds +        if len(flat_sds) >= 6: +            # Hard restriction on only 6 search domains +            raise ValueError(("Adding %r would go beyond the " +                              "'6' maximum search domains") % (search_domain)) +        s_list = " ".join(new_sds) +        if len(s_list) > 256: +            # Some hard limit on 256 chars total +            raise ValueError(("Adding %r would go beyond the " +                              "256 maximum search list character limit") +                              % (search_domain)) +        self._remove_option('search') +        self._contents.append(('option', ['search', s_list, ''])) +        return flat_sds + +    @local_domain.setter +    def local_domain(self, domain): +        self.parse() +        self._remove_option('domain') +        self._contents.append(('option', ['domain', str(domain), ''])) +        return domain + +    def _parse(self, contents): +        entries = [] +        for (i, line) in enumerate(contents.splitlines()): +            sline = line.strip() +            if not sline: +                entries.append(('blank', [line])) +                continue +            (head, tail) = chop_comment(line, ';#') +            if not len(head.strip()): +                entries.append(('all_comment', [line])) +                continue +            if not tail: +                tail = '' +            try: +                (cfg_opt, cfg_values) = head.split(None, 1) +            except (IndexError, ValueError): +                raise IOError("Incorrectly formatted resolv.conf line %s" +                              % (i + 1)) +            if cfg_opt not in ['nameserver', 'domain', +                               'search', 'sortlist', 'options']: +                raise IOError("Unexpected resolv.conf option %s" % (cfg_opt)) +            entries.append(("option", [cfg_opt, cfg_values, tail])) +        return entries diff --git a/cloudinit/distros/parsers/sys_conf.py b/cloudinit/distros/parsers/sys_conf.py new file mode 100644 index 00000000..20ca1871 --- /dev/null +++ b/cloudinit/distros/parsers/sys_conf.py @@ -0,0 +1,113 @@ +# vi: ts=4 expandtab +# +#    Copyright (C) 2012 Yahoo! Inc. +# +#    Author: Joshua Harlow <harlowja@yahoo-inc.com> +# +#    This program is free software: you can redistribute it and/or modify +#    it under the terms of the GNU General Public License version 3, as +#    published by the Free Software Foundation. +# +#    This program is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    GNU General Public License for more details. +# +#    You should have received a copy of the GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. + +from StringIO import StringIO + +import pipes +import re + +# This library is used to parse/write +# out the various sysconfig files edited (best attempt effort) +# +# It has to be slightly modified though +# to ensure that all values are quoted/unquoted correctly +# since these configs are usually sourced into +# bash scripts... +import configobj + +# See: http://pubs.opengroup.org/onlinepubs/000095399/basedefs/xbd_chap08.html +# or look at the 'param_expand()' function in the subst.c file in the bash +# source tarball... +SHELL_VAR_RULE = r'[a-zA-Z_]+[a-zA-Z0-9_]*' +SHELL_VAR_REGEXES = [ +    # Basic variables +    re.compile(r"\$" + SHELL_VAR_RULE), +    # Things like $?, $0, $-, $@ +    re.compile(r"\$[0-9#\?\-@\*]"), +    # Things like ${blah:1} - but this one +    # gets very complex so just try the +    # simple path +    re.compile(r"\$\{.+\}"), +] + + +def _contains_shell_variable(text): +    for r in SHELL_VAR_REGEXES: +        if r.search(text): +            return True +    return False + + +class SysConf(configobj.ConfigObj): +    def __init__(self, contents): +        configobj.ConfigObj.__init__(self, contents, +                                     interpolation=False, +                                     write_empty_values=True) + +    def __str__(self): +        contents = self.write() +        out_contents = StringIO() +        if isinstance(contents, (list, tuple)): +            out_contents.write("\n".join(contents)) +        else: +            out_contents.write(str(contents)) +        return out_contents.getvalue() + +    def _quote(self, value, multiline=False): +        if not isinstance(value, (str, basestring)): +            raise ValueError('Value "%s" is not a string' % (value)) +        if len(value) == 0: +            return '' +        quot_func = None +        if value[0] in ['"', "'"] and value[-1] in ['"', "'"]: +            if len(value) == 1: +                quot_func = (lambda x: +                                self._get_single_quote(x) % x) +        else: +            # Quote whitespace if it isn't the start + end of a shell command +            if value.strip().startswith("$(") and value.strip().endswith(")"): +                pass +            else: +                if re.search(r"[\t\r\n ]", value): +                    if _contains_shell_variable(value): +                        # If it contains shell variables then we likely want to +                        # leave it alone since the pipes.quote function likes +                        # to use single quotes which won't get expanded... +                        if re.search(r"[\n\"']", value): +                            quot_func = (lambda x: +                                            self._get_triple_quote(x) % x) +                        else: +                            quot_func = (lambda x: +                                            self._get_single_quote(x) % x) +                    else: +                        quot_func = pipes.quote +        if not quot_func: +            return value +        return quot_func(value) + +    def _write_line(self, indent_string, entry, this_entry, comment): +        # Ensure it is formatted fine for +        # how these sysconfig scripts are used +        val = self._decode_element(self._quote(this_entry)) +        key = self._decode_element(self._quote(entry)) +        cmnt = self._decode_element(comment) +        return '%s%s%s%s%s' % (indent_string, +                               key, +                               self._a_to_u('='), +                               val, +                               cmnt) diff --git a/cloudinit/distros/rhel.py b/cloudinit/distros/rhel.py index e4c27216..7df01c62 100644 --- a/cloudinit/distros/rhel.py +++ b/cloudinit/distros/rhel.py @@ -23,39 +23,18 @@  import os  from cloudinit import distros + +from cloudinit.distros.parsers.resolv_conf import ResolvConf +from cloudinit.distros.parsers.sys_conf import SysConf +  from cloudinit import helpers  from cloudinit import log as logging  from cloudinit import util -from cloudinit import version  from cloudinit.settings import PER_INSTANCE  LOG = logging.getLogger(__name__) -NETWORK_FN_TPL = '/etc/sysconfig/network-scripts/ifcfg-%s' - -# See: http://tiny.cc/6r99fw -# For what alot of these files that are being written -# are and the format of them - -# This library is used to parse/write -# out the various sysconfig files edited -# -# It has to be slightly modified though -# to ensure that all values are quoted -# since these configs are usually sourced into -# bash scripts... -from configobj import ConfigObj - -# See: http://tiny.cc/oezbgw -D_QUOTE_CHARS = { -    "\"": "\\\"", -    "(": "\\(", -    ")": "\\)", -    "$": '\$', -    '`': '\`', -} -  def _make_sysconfig_bool(val):      if val: @@ -64,12 +43,16 @@ def _make_sysconfig_bool(val):          return 'no' -def _make_header(): -    ci_ver = version.version_string() -    return '# Created by cloud-init v. %s' % (ci_ver) - -  class Distro(distros.Distro): +    # See: http://tiny.cc/6r99fw +    clock_conf_fn = "/etc/sysconfig/clock" +    locale_conf_fn = '/etc/sysconfig/i18n' +    network_conf_fn = "/etc/sysconfig/network" +    hostname_conf_fn = "/etc/sysconfig/network"  +    network_script_tpl = '/etc/sysconfig/network-scripts/ifcfg-%s' +    resolve_conf_fn = "/etc/resolv.conf" +    tz_local_fn = "/etc/localtime" +    tz_zone_dir = "/usr/share/zoneinfo"      def __init__(self, name, cfg, paths):          distros.Distro.__init__(self, name, cfg, paths) @@ -81,16 +64,29 @@ class Distro(distros.Distro):      def install_packages(self, pkglist):          self.package_command('install', pkglist) -    def _write_resolve(self, dns_servers, search_servers): -        contents = [] +    def _adjust_resolve(self, dns_servers, search_servers): +        r_conf = ResolvConf(util.load_file(self.resolve_conf_fn)) +        try: +            r_conf.parse() +        except IOError: +            util.logexc(LOG,  +                        "Failed at parsing %s reverting to an empty instance", +                        self.resolve_conf_fn) +            r_conf = ResolvConf('') +            r_conf.parse()          if dns_servers:              for s in dns_servers: -                contents.append("nameserver %s" % (s)) +                try: +                    r_conf.add_nameserver(s) +                except ValueError: +                    util.logexc(LOG, "Failed at adding nameserver %s", s)          if search_servers: -            contents.append("search %s" % (" ".join(search_servers))) -        if contents: -            contents.insert(0, _make_header()) -            util.write_file("/etc/resolv.conf", "\n".join(contents), 0644) +            for s in search_servers: +                try: +                    r_conf.add_search_domain(s) +                except ValueError: +                    util.logexc(LOG, "Failed at adding search domain %s", s) +        util.write_file(self.resolve_conf_fn, str(r_conf), 0644)      def _write_network(self, settings):          # TODO(harlowja) fix this... since this is the ubuntu format @@ -102,7 +98,7 @@ class Distro(distros.Distro):          searchservers = []          dev_names = entries.keys()          for (dev, info) in entries.iteritems(): -            net_fn = NETWORK_FN_TPL % (dev) +            net_fn = self.network_script_tpl % (dev)              net_cfg = {                  'DEVICE': dev,                  'NETMASK': info.get('netmask'), @@ -119,12 +115,12 @@ class Distro(distros.Distro):              if 'dns-search' in info:                  searchservers.extend(info['dns-search'])          if nameservers or searchservers: -            self._write_resolve(nameservers, searchservers) +            self._adjust_resolve(nameservers, searchservers)          if dev_names:              net_cfg = {                  'NETWORKING': _make_sysconfig_bool(True),              } -            self._update_sysconfig_file("/etc/sysconfig/network", net_cfg) +            self._update_sysconfig_file(self.network_conf_fn, net_cfg)          return dev_names      def _update_sysconfig_file(self, fn, adjustments, allow_empty=False): @@ -141,24 +137,16 @@ class Distro(distros.Distro):              contents[k] = v              updated_am += 1          if updated_am: -            lines = contents.write() +            lines = [ +                str(contents), +            ]              if not exists: -                lines.insert(0, _make_header()) +                lines.insert(0, util.make_header())              util.write_file(fn, "\n".join(lines), 0644) -    def set_hostname(self, hostname, fqdn=None): -        # See: http://bit.ly/TwitgL -        # Should be fqdn if we can use it -        sysconfig_hostname = fqdn -        if not sysconfig_hostname: -            sysconfig_hostname = hostname -        self._write_hostname(sysconfig_hostname, '/etc/sysconfig/network') -        LOG.debug("Setting hostname to %s", hostname) -        util.subp(['hostname', hostname]) -      def apply_locale(self, locale, out_fn=None):          if not out_fn: -            out_fn = '/etc/sysconfig/i18n' +            out_fn = self.locale_conf_fn          locale_cfg = {              'LANG': locale,          } @@ -170,34 +158,16 @@ class Distro(distros.Distro):          }          self._update_sysconfig_file(out_fn, host_cfg) -    def update_hostname(self, hostname, fqdn, prev_file): +    def _select_hostname(self, hostname, fqdn):          # See: http://bit.ly/TwitgL          # Should be fqdn if we can use it -        sysconfig_hostname = fqdn -        if not sysconfig_hostname: -            sysconfig_hostname = hostname -        hostname_prev = self._read_hostname(prev_file) -        hostname_in_sys = self._read_hostname("/etc/sysconfig/network") -        update_files = [] -        if not hostname_prev or hostname_prev != sysconfig_hostname: -            update_files.append(prev_file) -        if (not hostname_in_sys or -            (hostname_in_sys == hostname_prev -             and hostname_in_sys != sysconfig_hostname)): -            update_files.append("/etc/sysconfig/network") -        for fn in update_files: -            try: -                self._write_hostname(sysconfig_hostname, fn) -            except: -                util.logexc(LOG, "Failed to write hostname %s to %s", -                            sysconfig_hostname, fn) -        if (hostname_in_sys and hostname_prev and -            hostname_in_sys != hostname_prev): -            LOG.debug(("%s differs from /etc/sysconfig/network." -                        " Assuming user maintained hostname."), prev_file) -        if "/etc/sysconfig/network" in update_files: -            LOG.debug("Setting hostname to %s", hostname) -            util.subp(['hostname', hostname]) +        if fqdn: +            return fqdn +        return hostname + +    def _read_system_hostname(self): +        return (self.network_conf_fn, +                self._read_hostname(self.network_conf_fn))      def _read_hostname(self, filename, default=None):          (_exists, contents) = self._read_conf(filename) @@ -213,7 +183,8 @@ class Distro(distros.Distro):              exists = True          else:              contents = [] -        return (exists, QuotingConfigObj(contents)) +        return (exists, +                SysConf(contents))      def _bring_up_interfaces(self, device_names):          if device_names and 'all' in device_names: @@ -222,17 +193,19 @@ class Distro(distros.Distro):          return distros.Distro._bring_up_interfaces(self, device_names)      def set_timezone(self, tz): -        tz_file = os.path.join("/usr/share/zoneinfo", tz) +        # TODO(harlowja): move this code into +        # the parent distro... +        tz_file = os.path.join(self.tz_zone_dir, str(tz))          if not os.path.isfile(tz_file):              raise RuntimeError(("Invalid timezone %s,"                                  " no file found at %s") % (tz, tz_file))          # Adjust the sysconfig clock zone setting          clock_cfg = { -            'ZONE': tz, +            'ZONE': str(tz),          } -        self._update_sysconfig_file("/etc/sysconfig/clock", clock_cfg) +        self._update_sysconfig_file(self.clock_conf_fn, clock_cfg)          # This ensures that the correct tz will be used for the system -        util.copy(tz_file, "/etc/localtime") +        util.copy(tz_file, self.tz_local_fn)      def package_command(self, command, args=None):          cmd = ['yum'] @@ -256,51 +229,6 @@ class Distro(distros.Distro):                           ["makecache"], freq=PER_INSTANCE) -# This class helps adjust the configobj -# writing to ensure that when writing a k/v -# on a line, that they are properly quoted -# and have no spaces between the '=' sign. -# - This is mainly due to the fact that -# the sysconfig scripts are often sourced -# directly into bash/shell scripts so ensure -# that it works for those types of use cases. -class QuotingConfigObj(ConfigObj): -    def __init__(self, lines): -        ConfigObj.__init__(self, lines, -                           interpolation=False, -                           write_empty_values=True) - -    def _quote_posix(self, text): -        if not text: -            return '' -        for (k, v) in D_QUOTE_CHARS.iteritems(): -            text = text.replace(k, v) -        return '"%s"' % (text) - -    def _quote_special(self, text): -        if text.lower() in ['yes', 'no', 'true', 'false']: -            return text -        else: -            return self._quote_posix(text) - -    def _write_line(self, indent_string, entry, this_entry, comment): -        # Ensure it is formatted fine for -        # how these sysconfig scripts are used -        val = self._decode_element(self._quote(this_entry)) -        # Single quoted strings should -        # always work. -        if not val.startswith("'"): -            # Perform any special quoting -            val = self._quote_special(val) -        key = self._decode_element(self._quote(entry, multiline=False)) -        cmnt = self._decode_element(comment) -        return '%s%s%s%s%s' % (indent_string, -                               key, -                               "=", -                               val, -                               cmnt) - -  # This is a util function to translate a ubuntu /etc/network/interfaces 'blob'  # to a rhel equiv. that can then be written to /etc/sysconfig/network-scripts/  # TODO(harlowja) remove when we have python-netcf active... diff --git a/cloudinit/util.py b/cloudinit/util.py index 4f5b15ee..ab918433 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -52,6 +52,7 @@ from cloudinit import importer  from cloudinit import log as logging  from cloudinit import safeyaml  from cloudinit import url_helper as uhelp +from cloudinit import version  from cloudinit.settings import (CFG_BUILTIN) @@ -272,11 +273,7 @@ def uniq_merge(*lists):              # Kickout the empty ones              a_list = [a for a in a_list if len(a)]          combined_list.extend(a_list) -    uniq_list = [] -    for i in combined_list: -        if i not in uniq_list: -            uniq_list.append(i) -    return uniq_list +    return uniq_list(combined_list)  def clean_filename(fn): @@ -989,6 +986,16 @@ def peek_file(fname, max_bytes):          return ifh.read(max_bytes) +def uniq_list(in_list): +    out_list = [] +    for i in in_list: +        if i in out_list: +            continue +        else: +            out_list.append(i) +    return out_list + +  def load_file(fname, read_cb=None, quiet=False):      LOG.debug("Reading from %s (quiet=%s)", fname, quiet)      ofh = StringIO() @@ -1428,6 +1435,14 @@ def subp(args, data=None, rcs=None, env=None, capture=True, shell=False,      return (out, err) +def make_header(comment_char="#", base='created'): +    ci_ver = version.version_string() +    header = str(comment_char) +    header += " %s by cloud-init v. %s" % (base.title(), ci_ver) +    header += " on %s" % time_rfc2822() +    return header + +  def abs_join(*paths):      return os.path.abspath(os.path.join(*paths)) diff --git a/tests/unittests/test_distros/test_hostname.py b/tests/unittests/test_distros/test_hostname.py new file mode 100644 index 00000000..8e644f4d --- /dev/null +++ b/tests/unittests/test_distros/test_hostname.py @@ -0,0 +1,38 @@ +from mocker import MockerTestCase + +from cloudinit.distros.parsers import hostname + + +BASE_HOSTNAME = ''' +# My super-duper-hostname + +blahblah + +''' +BASE_HOSTNAME = BASE_HOSTNAME.strip() + + +class TestHostnameHelper(MockerTestCase): +    def test_parse_same(self): +        hn = hostname.HostnameConf(BASE_HOSTNAME) +        self.assertEquals(str(hn).strip(), BASE_HOSTNAME) +        self.assertEquals(hn.hostname, 'blahblah') + +    def test_no_adjust_hostname(self): +        hn = hostname.HostnameConf(BASE_HOSTNAME) +        prev_name = hn.hostname +        hn.set_hostname("") +        self.assertEquals(hn.hostname, prev_name) + +    def test_adjust_hostname(self): +        hn = hostname.HostnameConf(BASE_HOSTNAME) +        prev_name = hn.hostname +        self.assertEquals(prev_name, 'blahblah') +        hn.set_hostname("bbbbd") +        self.assertEquals(hn.hostname, 'bbbbd') +        expected_out = ''' +# My super-duper-hostname + +bbbbd +''' +        self.assertEquals(str(hn).strip(), expected_out.strip()) diff --git a/tests/unittests/test_distros/test_hosts.py b/tests/unittests/test_distros/test_hosts.py new file mode 100644 index 00000000..687a0dab --- /dev/null +++ b/tests/unittests/test_distros/test_hosts.py @@ -0,0 +1,41 @@ +from mocker import MockerTestCase + +from cloudinit.distros.parsers import hosts + + +BASE_ETC = ''' +# Example +127.0.0.1	localhost +192.168.1.10	foo.mydomain.org  foo +192.168.1.10 	bar.mydomain.org  bar +146.82.138.7	master.debian.org      master +209.237.226.90	www.opensource.org +''' +BASE_ETC = BASE_ETC.strip() + + +class TestHostsHelper(MockerTestCase): +    def test_parse(self): +        eh = hosts.HostsConf(BASE_ETC) +        self.assertEquals(eh.get_entry('127.0.0.1'), [['localhost']]) +        self.assertEquals(eh.get_entry('192.168.1.10'), +                          [['foo.mydomain.org', 'foo'], +                           ['bar.mydomain.org', 'bar']]) +        eh = str(eh) +        self.assertTrue(eh.startswith('# Example')) + +    def test_add(self): +        eh = hosts.HostsConf(BASE_ETC) +        eh.add_entry('127.0.0.0', 'blah') +        self.assertEquals(eh.get_entry('127.0.0.0'), [['blah']]) +        eh.add_entry('127.0.0.3', 'blah', 'blah2', 'blah3') +        self.assertEquals(eh.get_entry('127.0.0.3'), +                          [['blah', 'blah2', 'blah3']]) + +    def test_del(self): +        eh = hosts.HostsConf(BASE_ETC) +        eh.add_entry('127.0.0.0', 'blah') +        self.assertEquals(eh.get_entry('127.0.0.0'), [['blah']]) + +        eh.del_entries('127.0.0.0') +        self.assertEquals(eh.get_entry('127.0.0.0'), []) diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 55765f0c..9763b14b 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -9,6 +9,8 @@ from cloudinit import helpers  from cloudinit import settings  from cloudinit import util +from cloudinit.distros.parsers.sys_conf import SysConf +  from StringIO import StringIO @@ -83,9 +85,8 @@ class TestNetCfgDistro(MockerTestCase):          self.assertEquals(write_buf.mode, 0644)      def assertCfgEquals(self, blob1, blob2): -        cfg_tester = distros.rhel.QuotingConfigObj -        b1 = dict(cfg_tester(blob1.strip().splitlines())) -        b2 = dict(cfg_tester(blob2.strip().splitlines())) +        b1 = dict(SysConf(blob1.strip().splitlines())) +        b2 = dict(SysConf(blob2.strip().splitlines()))          self.assertEquals(b1, b2)          for (k, v) in b1.items():              self.assertIn(k, b2) diff --git a/tests/unittests/test_distros/test_resolv.py b/tests/unittests/test_distros/test_resolv.py new file mode 100644 index 00000000..d947dda0 --- /dev/null +++ b/tests/unittests/test_distros/test_resolv.py @@ -0,0 +1,63 @@ +from mocker import MockerTestCase + +from cloudinit.distros.parsers import resolv_conf + +import re + + +BASE_RESOLVE = ''' +; generated by /sbin/dhclient-script +search blah.yahoo.com yahoo.com +nameserver 10.15.44.14 +nameserver 10.15.30.92 +''' +BASE_RESOLVE = BASE_RESOLVE.strip() + + +class TestResolvHelper(MockerTestCase): +    def test_parse_same(self): +        rp = resolv_conf.ResolvConf(BASE_RESOLVE) +        rp_r = str(rp).strip() +        self.assertEquals(BASE_RESOLVE, rp_r) + +    def test_local_domain(self): +        rp = resolv_conf.ResolvConf(BASE_RESOLVE) +        self.assertEquals(None, rp.local_domain) + +        rp.local_domain = "bob" +        self.assertEquals('bob', rp.local_domain) +        self.assertIn('domain bob', str(rp)) + +    def test_nameservers(self): +        rp = resolv_conf.ResolvConf(BASE_RESOLVE) +        self.assertIn('10.15.44.14', rp.nameservers) +        self.assertIn('10.15.30.92', rp.nameservers) +        rp.add_nameserver('10.2') +        self.assertIn('10.2', rp.nameservers) +        self.assertIn('nameserver 10.2', str(rp)) +        self.assertNotIn('10.3', rp.nameservers) +        self.assertEquals(len(rp.nameservers), 3) +        rp.add_nameserver('10.2') +        with self.assertRaises(ValueError): +            rp.add_nameserver('10.3') +        self.assertNotIn('10.3', rp.nameservers) + +    def test_search_domains(self): +        rp = resolv_conf.ResolvConf(BASE_RESOLVE) +        self.assertIn('yahoo.com', rp.search_domains) +        self.assertIn('blah.yahoo.com', rp.search_domains) +        rp.add_search_domain('bbb.y.com') +        self.assertIn('bbb.y.com', rp.search_domains) +        self.assertTrue(re.search(r'search(.*)bbb.y.com(.*)', str(rp))) +        self.assertIn('bbb.y.com', rp.search_domains) +        rp.add_search_domain('bbb.y.com') +        self.assertEquals(len(rp.search_domains), 3) +        rp.add_search_domain('bbb2.y.com') +        self.assertEquals(len(rp.search_domains), 4) +        rp.add_search_domain('bbb3.y.com') +        self.assertEquals(len(rp.search_domains), 5) +        rp.add_search_domain('bbb4.y.com') +        self.assertEquals(len(rp.search_domains), 6) +        with self.assertRaises(ValueError): +            rp.add_search_domain('bbb5.y.com') +        self.assertEquals(len(rp.search_domains), 6) diff --git a/tests/unittests/test_distros/test_sysconfig.py b/tests/unittests/test_distros/test_sysconfig.py new file mode 100644 index 00000000..1e34909d --- /dev/null +++ b/tests/unittests/test_distros/test_sysconfig.py @@ -0,0 +1,83 @@ +from mocker import MockerTestCase + +import re + +from cloudinit.distros.parsers.sys_conf import SysConf + + +# Lots of good examples @ +# http://content.hccfl.edu/pollock/AUnix1/SysconfigFilesDesc.txt + +class TestSysConfHelper(MockerTestCase): +    # This function was added in 2.7, make it work for 2.6 +    def assertRegMatches(self, text, regexp): +        regexp = re.compile(regexp) +        self.assertTrue(regexp.search(text), +                        msg="%s must match %s!" % (text, regexp.pattern)) + +    def test_parse_no_change(self): +        contents = '''# A comment +USESMBAUTH=no +KEYTABLE=/usr/lib/kbd/keytables/us.map +SHORTDATE=$(date +%y:%m:%d:%H:%M) +HOSTNAME=blahblah +NETMASK0=255.255.255.0 +# Inline comment +LIST=$LOGROOT/incremental-list +IPV6TO4_ROUTING='eth0-:0004::1/64 eth1-:0005::1/64' +ETHTOOL_OPTS="-K ${DEVICE} tso on; -G ${DEVICE} rx 256 tx 256" +USEMD5=no''' +        conf = SysConf(contents.splitlines()) +        self.assertEquals(conf['HOSTNAME'], 'blahblah') +        self.assertEquals(conf['SHORTDATE'], '$(date +%y:%m:%d:%H:%M)') +        # Should be unquoted +        self.assertEquals(conf['ETHTOOL_OPTS'], ('-K ${DEVICE} tso on; ' +                                                 '-G ${DEVICE} rx 256 tx 256')) +        self.assertEquals(contents, str(conf)) + +    def test_parse_shell_vars(self): +        contents = 'USESMBAUTH=$XYZ' +        conf = SysConf(contents.splitlines()) +        self.assertEquals(contents, str(conf)) +        conf = SysConf('') +        conf['B'] = '${ZZ}d apples' +        # Should be quoted +        self.assertEquals('B="${ZZ}d apples"', str(conf)) +        conf = SysConf('') +        conf['B'] = '$? d apples' +        self.assertEquals('B="$? d apples"', str(conf)) +        contents = 'IPMI_WATCHDOG_OPTIONS="timeout=60"' +        conf = SysConf(contents.splitlines()) +        self.assertEquals('IPMI_WATCHDOG_OPTIONS=timeout=60', str(conf)) + +    def test_parse_adjust(self): +        contents = 'IPV6TO4_ROUTING="eth0-:0004::1/64 eth1-:0005::1/64"' +        conf = SysConf(contents.splitlines()) +        # Should be unquoted +        self.assertEquals('eth0-:0004::1/64 eth1-:0005::1/64', +                          conf['IPV6TO4_ROUTING']) +        conf['IPV6TO4_ROUTING'] = "blah \tblah" +        contents2 = str(conf).strip() +        # Should be requoted due to whitespace +        self.assertRegMatches(contents2, +                              r'IPV6TO4_ROUTING=[\']blah\s+blah[\']') + +    def test_parse_no_adjust_shell(self): +        conf = SysConf(''.splitlines()) +        conf['B'] = ' $(time)' +        contents = str(conf) +        self.assertEquals('B= $(time)', contents) + +    def test_parse_empty(self): +        contents = '' +        conf = SysConf(contents.splitlines()) +        self.assertEquals('', str(conf).strip()) + +    def test_parse_add_new(self): +        contents = 'BLAH=b' +        conf = SysConf(contents.splitlines()) +        conf['Z'] = 'd' +        contents = str(conf) +        self.assertIn("Z=d", contents) +        self.assertIn("BLAH=b", contents) +         | 
