From 28c8aa7270a04adea69065477b13cfc0dd244acc Mon Sep 17 00:00:00 2001 From: Ben Howard Date: Wed, 14 Jan 2015 12:24:09 -0700 Subject: Drop reliance on dmidecode executable. --- tests/unittests/test_datasource/test_altcloud.py | 66 ++++++++++-------------- tests/unittests/test_util.py | 28 ++++++++++ 2 files changed, 54 insertions(+), 40 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_altcloud.py b/tests/unittests/test_datasource/test_altcloud.py index eaaa90e6..1a48ee5f 100644 --- a/tests/unittests/test_datasource/test_altcloud.py +++ b/tests/unittests/test_datasource/test_altcloud.py @@ -26,6 +26,7 @@ import shutil import tempfile from cloudinit import helpers +from cloudinit import util from unittest import TestCase # Get the cloudinit.sources.DataSourceAltCloud import items needed. @@ -98,6 +99,16 @@ def _remove_user_data_files(mount_dir, pass +def _dmi_data(expected): + ''' + Spoof the data received over DMI + ''' + def _data(key): + return expected + + return _data + + class TestGetCloudType(TestCase): ''' Test to exercise method: DataSourceAltCloud.get_cloud_type() @@ -106,24 +117,22 @@ class TestGetCloudType(TestCase): def setUp(self): '''Set up.''' self.paths = helpers.Paths({'cloud_dir': '/tmp'}) + self.dmi_data = util.read_dmi_data # We have a different code path for arm to deal with LP1243287 # We have to switch arch to x86_64 to avoid test failure force_arch('x86_64') def tearDown(self): # Reset - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['dmidecode', '--string', 'system-product-name'] - # Return back to original arch + util.read_dmi_data = self.dmi_data force_arch() def test_rhev(self): ''' Test method get_cloud_type() for RHEVm systems. - Forcing dmidecode return to match a RHEVm system: RHEV Hypervisor + Forcing read_dmi_data return to match a RHEVm system: RHEV Hypervisor ''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['echo', 'RHEV Hypervisor'] + util.read_dmi_data = _dmi_data('RHEV') dsrc = DataSourceAltCloud({}, None, self.paths) self.assertEquals('RHEV', \ dsrc.get_cloud_type()) @@ -131,10 +140,9 @@ class TestGetCloudType(TestCase): def test_vsphere(self): ''' Test method get_cloud_type() for vSphere systems. - Forcing dmidecode return to match a vSphere system: RHEV Hypervisor + Forcing read_dmi_data return to match a vSphere system: RHEV Hypervisor ''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['echo', 'VMware Virtual Platform'] + util.read_dmi_data = _dmi_data('VMware Virtual Platform') dsrc = DataSourceAltCloud({}, None, self.paths) self.assertEquals('VSPHERE', \ dsrc.get_cloud_type()) @@ -142,30 +150,9 @@ class TestGetCloudType(TestCase): def test_unknown(self): ''' Test method get_cloud_type() for unknown systems. - Forcing dmidecode return to match an unrecognized return. - ''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['echo', 'Unrecognized Platform'] - dsrc = DataSourceAltCloud({}, None, self.paths) - self.assertEquals('UNKNOWN', \ - dsrc.get_cloud_type()) - - def test_exception1(self): - ''' - Test method get_cloud_type() where command dmidecode fails. - ''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['ls', 'bad command'] - dsrc = DataSourceAltCloud({}, None, self.paths) - self.assertEquals('UNKNOWN', \ - dsrc.get_cloud_type()) - - def test_exception2(self): - ''' - Test method get_cloud_type() where command dmidecode is not available. + Forcing read_dmi_data return to match an unrecognized return. ''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['bad command'] + util.read_dmi_data = _dmi_data('Unrecognized Platform') dsrc = DataSourceAltCloud({}, None, self.paths) self.assertEquals('UNKNOWN', \ dsrc.get_cloud_type()) @@ -180,6 +167,7 @@ class TestGetDataCloudInfoFile(TestCase): '''Set up.''' self.paths = helpers.Paths({'cloud_dir': '/tmp'}) self.cloud_info_file = tempfile.mkstemp()[1] + self.dmi_data = util.read_dmi_data cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE = \ self.cloud_info_file @@ -192,6 +180,7 @@ class TestGetDataCloudInfoFile(TestCase): except OSError: pass + util.read_dmi_data = self.dmi_data cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE = \ '/etc/sysconfig/cloud-info' @@ -243,6 +232,7 @@ class TestGetDataNoCloudInfoFile(TestCase): def setUp(self): '''Set up.''' self.paths = helpers.Paths({'cloud_dir': '/tmp'}) + self.dmi_data = util.read_dmi_data cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE = \ 'no such file' # We have a different code path for arm to deal with LP1243287 @@ -253,16 +243,14 @@ class TestGetDataNoCloudInfoFile(TestCase): # Reset cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE = \ '/etc/sysconfig/cloud-info' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['dmidecode', '--string', 'system-product-name'] + util.read_dmi_data = self.dmi_data # Return back to original arch force_arch() def test_rhev_no_cloud_file(self): '''Test No cloud info file module get_data() forcing RHEV.''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['echo', 'RHEV Hypervisor'] + util.read_dmi_data = _dmi_data('RHEV Hypervisor') dsrc = DataSourceAltCloud({}, None, self.paths) dsrc.user_data_rhevm = lambda: True self.assertEquals(True, dsrc.get_data()) @@ -270,8 +258,7 @@ class TestGetDataNoCloudInfoFile(TestCase): def test_vsphere_no_cloud_file(self): '''Test No cloud info file module get_data() forcing VSPHERE.''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['echo', 'VMware Virtual Platform'] + util.read_dmi_data = _dmi_data('VMware Virtual Platform') dsrc = DataSourceAltCloud({}, None, self.paths) dsrc.user_data_vsphere = lambda: True self.assertEquals(True, dsrc.get_data()) @@ -279,8 +266,7 @@ class TestGetDataNoCloudInfoFile(TestCase): def test_failure_no_cloud_file(self): '''Test No cloud info file module get_data() forcing unrecognized.''' - cloudinit.sources.DataSourceAltCloud.CMD_DMI_SYSTEM = \ - ['echo', 'Unrecognized Platform'] + util.read_dmi_data = _dmi_data('Unrecognized Platform') dsrc = DataSourceAltCloud({}, None, self.paths) self.assertEquals(False, dsrc.get_data()) diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 35e92445..6ae41bd6 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -310,4 +310,32 @@ class TestMountinfoParsing(helpers.ResourceUsingTestCase): expected = ('none', 'tmpfs', '/run/lock') self.assertEqual(expected, util.parse_mount_info('/run/lock', lines)) + +class TestReadDMIData(helpers.FilesystemMockingTestCase): + + def _patchIn(self, root): + self.restore() + self.patchOS(root) + self.patchUtils(root) + + def _write_key(self, key, content): + new_root = self.makeDir() + self._patchIn(new_root) + util.ensure_dir(os.path.join('sys', 'class', 'dmi', 'id')) + + dmi_key = "/sys/class/dmi/id/{}".format(key) + util.write_file(dmi_key, content) + + def test_key(self): + key_content = "TEST-KEY-DATA" + self._write_key("key", key_content) + self.assertEquals(key_content, util.read_dmi_data("key")) + + def test_key_mismatch(self): + self._write_key("test", "ABC") + self.assertNotEqual("123", util.read_dmi_data("test")) + + def test_no_key(self): + self.assertFalse(util.read_dmi_data("key")) + # vi: ts=4 expandtab -- cgit v1.2.3 From 463d626ba53e54160f350d84831e1877b24f4024 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Wed, 21 Jan 2015 15:28:32 -0500 Subject: Only install cheetah (and only run the cheetah templating test) when in Python 2. Cheetah is not compatible with Python 3. --- requirements.txt | 4 +++- setup.py | 7 ++++++- tests/unittests/test_templating.py | 4 ++++ 3 files changed, 13 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/requirements.txt b/requirements.txt index 943dbef7..2a12ca3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ # Pypi requirements for cloud-init to work # Used for untemplating any files or strings with parameters. -cheetah jinja2 # This is used for any pretty printing of tabular data. @@ -32,3 +31,6 @@ requests # For patching pieces of cloud-config together jsonpatch + +# For Python 2/3 compatibility +six diff --git a/setup.py b/setup.py index 507d5b25..e88d9e88 100755 --- a/setup.py +++ b/setup.py @@ -175,6 +175,11 @@ else: } +requirements = read_requires() +if sys.version_info < (3,): + requirements.append('cheetah') + + setuptools.setup(name='cloud-init', version=get_version(), description='EC2 initialisation magic', @@ -187,6 +192,6 @@ setuptools.setup(name='cloud-init', ], license='GPLv3', data_files=data_files, - install_requires=read_requires(), + install_requires=requirements, cmdclass=cmdclass, ) diff --git a/tests/unittests/test_templating.py b/tests/unittests/test_templating.py index 3ba4ed8a..957467f6 100644 --- a/tests/unittests/test_templating.py +++ b/tests/unittests/test_templating.py @@ -16,6 +16,9 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import six +import unittest + from . import helpers as test_helpers import textwrap @@ -38,6 +41,7 @@ class TestTemplates(test_helpers.TestCase): out_data = templater.basic_render(in_data, {'b': 2}) self.assertEqual(expected_data.strip(), out_data) + @unittest.skipIf(six.PY3, 'Cheetah is not compatible with Python 3') def test_detection(self): blob = "## template:cheetah" -- cgit v1.2.3 From cccc0ff012d2e7b5c238609b22cc064b519e54a5 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Wed, 21 Jan 2015 15:35:56 -0500 Subject: Fix file modes to be Python 2/3 compatible. --- .bzrignore | 1 + cloudinit/distros/__init__.py | 2 +- cloudinit/util.py | 8 ++++---- tests/unittests/test__init__.py | 2 +- tests/unittests/test_data.py | 2 +- tests/unittests/test_datasource/test_altcloud.py | 2 +- tests/unittests/test_datasource/test_azure.py | 2 +- tests/unittests/test_distros/test_netconfig.py | 2 +- tests/unittests/test_handler/test_handler_ca_certs.py | 2 +- tests/unittests/test_handler/test_handler_growpart.py | 2 +- tests/unittests/test_runs/test_simple_run.py | 2 +- tests/unittests/test_util.py | 6 +++--- 12 files changed, 17 insertions(+), 16 deletions(-) (limited to 'tests/unittests') diff --git a/.bzrignore b/.bzrignore index 32f5a949..926e4581 100644 --- a/.bzrignore +++ b/.bzrignore @@ -1,3 +1,4 @@ .tox dist cloud_init.egg-info +__pycache__ diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 5eab780b..a913e15a 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -272,7 +272,7 @@ class Distro(object): if header: contents.write("%s\n" % (header)) contents.write("%s\n" % (eh)) - util.write_file(self.hosts_fn, contents.getvalue(), mode=0644) + util.write_file(self.hosts_fn, contents.getvalue(), mode=0o644) def _bring_up_interface(self, device_name): cmd = ['ifup', device_name] diff --git a/cloudinit/util.py b/cloudinit/util.py index bf8e7d80..9efc704a 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -1250,7 +1250,7 @@ def rename(src, dest): os.rename(src, dest) -def ensure_dirs(dirlist, mode=0755): +def ensure_dirs(dirlist, mode=0o755): for d in dirlist: ensure_dir(d, mode) @@ -1264,7 +1264,7 @@ def read_write_cmdline_url(target_fn): return try: if key and content: - write_file(target_fn, content, mode=0600) + write_file(target_fn, content, mode=0o600) LOG.debug(("Wrote to %s with contents of command line" " url %s (len=%s)"), target_fn, url, len(content)) elif key and not content: @@ -1489,7 +1489,7 @@ def append_file(path, content): write_file(path, content, omode="ab", mode=None) -def ensure_file(path, mode=0644): +def ensure_file(path, mode=0o644): write_file(path, content='', omode="ab", mode=mode) @@ -1507,7 +1507,7 @@ def chmod(path, mode): os.chmod(path, real_mode) -def write_file(filename, content, mode=0644, omode="wb"): +def write_file(filename, content, mode=0o644, omode="wb"): """ Writes a file with the given content and sets the file mode as specified. Resotres the SELinux context if possible. diff --git a/tests/unittests/test__init__.py b/tests/unittests/test__init__.py index 17965488..48db1a5e 100644 --- a/tests/unittests/test__init__.py +++ b/tests/unittests/test__init__.py @@ -48,7 +48,7 @@ class TestWalkerHandleHandler(MockerTestCase): # Mock the write_file function write_file_mock = self.mocker.replace(util.write_file, passthrough=False) - write_file_mock(expected_file_fullname, self.payload, 0600) + write_file_mock(expected_file_fullname, self.payload, 0o600) def test_no_errors(self): """Payload gets written to file and added to C{pdata}.""" diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index fd6bd8a1..5517f0b4 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -337,7 +337,7 @@ p: 1 mock_write = self.mocker.replace("cloudinit.util.write_file", passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0600) + mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) self.mocker.replay() log_file = self.capture_log(logging.WARNING) diff --git a/tests/unittests/test_datasource/test_altcloud.py b/tests/unittests/test_datasource/test_altcloud.py index eaaa90e6..9d8a4a20 100644 --- a/tests/unittests/test_datasource/test_altcloud.py +++ b/tests/unittests/test_datasource/test_altcloud.py @@ -45,7 +45,7 @@ def _write_cloud_info_file(value): cifile = open(cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE, 'w') cifile.write(value) cifile.close() - os.chmod(cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE, 0664) + os.chmod(cloudinit.sources.DataSourceAltCloud.CLOUD_INFO_FILE, 0o664) def _remove_cloud_info_file(): diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index e992a006..6e007a95 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -153,7 +153,7 @@ class TestAzureDataSource(MockerTestCase): ret = dsrc.get_data() self.assertTrue(ret) self.assertTrue(os.path.isdir(self.waagent_d)) - self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0700) + self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0o700) def test_user_cfg_set_agent_command_plain(self): # set dscfg in via plaintext diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 193338e8..47de034b 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -96,7 +96,7 @@ class TestNetCfgDistro(MockerTestCase): write_bufs = {} - def replace_write(filename, content, mode=0644, omode="wb"): + def replace_write(filename, content, mode=0o644, omode="wb"): buf = WriteBuffer() buf.mode = mode buf.omode = omode diff --git a/tests/unittests/test_handler/test_handler_ca_certs.py b/tests/unittests/test_handler/test_handler_ca_certs.py index 0558023a..75df807e 100644 --- a/tests/unittests/test_handler/test_handler_ca_certs.py +++ b/tests/unittests/test_handler/test_handler_ca_certs.py @@ -150,7 +150,7 @@ class TestAddCaCerts(MockerTestCase): mock_load = self.mocker.replace(util.load_file, passthrough=False) mock_write("/usr/share/ca-certificates/cloud-init-ca-certs.crt", - cert, mode=0644) + cert, mode=0o644) mock_load("/etc/ca-certificates.conf") self.mocker.result(ca_certs_content) diff --git a/tests/unittests/test_handler/test_handler_growpart.py b/tests/unittests/test_handler/test_handler_growpart.py index 5d0636d1..3056320d 100644 --- a/tests/unittests/test_handler/test_handler_growpart.py +++ b/tests/unittests/test_handler/test_handler_growpart.py @@ -145,7 +145,7 @@ class TestResize(MockerTestCase): # this patches out devent2dev, os.stat, and device_part_info # so in the end, doesn't test a lot devs = ["/dev/XXda1", "/dev/YYda2"] - devstat_ret = Bunch(st_mode=25008, st_ino=6078, st_dev=5L, + devstat_ret = Bunch(st_mode=25008, st_ino=6078, st_dev=5, st_nlink=1, st_uid=0, st_gid=6, st_size=0, st_atime=0, st_mtime=0, st_ctime=0) enoent = ["/dev/NOENT"] diff --git a/tests/unittests/test_runs/test_simple_run.py b/tests/unittests/test_runs/test_simple_run.py index c9ba949e..2d51a337 100644 --- a/tests/unittests/test_runs/test_simple_run.py +++ b/tests/unittests/test_runs/test_simple_run.py @@ -41,7 +41,7 @@ class TestSimpleRun(helpers.FilesystemMockingTestCase): { 'path': '/etc/blah.ini', 'content': 'blah', - 'permissions': 0755, + 'permissions': 0o755, }, ], 'cloud_init_modules': ['write-files'], diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 35e92445..203445b7 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -79,7 +79,7 @@ class TestWriteFile(MockerTestCase): create_contents = f.read() self.assertEqual(contents, create_contents) file_stat = os.stat(path) - self.assertEqual(0644, stat.S_IMODE(file_stat.st_mode)) + self.assertEqual(0o644, stat.S_IMODE(file_stat.st_mode)) def test_dir_is_created_if_required(self): """Verifiy that directories are created is required.""" @@ -97,12 +97,12 @@ class TestWriteFile(MockerTestCase): path = os.path.join(self.tmp, "NewFile.txt") contents = "Hey there" - util.write_file(path, contents, mode=0666) + util.write_file(path, contents, mode=0o666) self.assertTrue(os.path.exists(path)) self.assertTrue(os.path.isfile(path)) file_stat = os.stat(path) - self.assertEqual(0666, stat.S_IMODE(file_stat.st_mode)) + self.assertEqual(0o666, stat.S_IMODE(file_stat.st_mode)) def test_custom_omode(self): """Verify custom omode works properly.""" -- cgit v1.2.3 From a64bb4febc79fcf641f6471d8cc00c74ca915f3d Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Wed, 21 Jan 2015 15:42:59 -0500 Subject: More octal literal fixes. --- cloudinit/distros/__init__.py | 6 ++--- tests/unittests/test_data.py | 14 ++++++------ tests/unittests/test_datasource/test_altcloud.py | 4 ++-- tests/unittests/test_distros/test_netconfig.py | 26 +++++++++++----------- .../test_handler/test_handler_ca_certs.py | 6 ++--- 5 files changed, 28 insertions(+), 28 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index a913e15a..49a0b652 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -468,7 +468,7 @@ class Distro(object): util.make_header(base="added"), "#includedir %s" % (path), ''] sudoers_contents = "\n".join(lines) - util.write_file(sudo_base, sudoers_contents, 0440) + util.write_file(sudo_base, sudoers_contents, 0o440) else: lines = ['', util.make_header(base="added"), "#includedir %s" % (path), ''] @@ -478,7 +478,7 @@ class Distro(object): except IOError as e: util.logexc(LOG, "Failed to write %s", sudo_base) raise e - util.ensure_dir(path, 0750) + util.ensure_dir(path, 0o750) def write_sudo_rules(self, user, rules, sudo_file=None): if not sudo_file: @@ -506,7 +506,7 @@ class Distro(object): content, ] try: - util.write_file(sudo_file, "\n".join(contents), 0440) + util.write_file(sudo_file, "\n".join(contents), 0o440) except IOError as e: util.logexc(LOG, "Failed to write sudoers file %s", sudo_file) raise e diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 5517f0b4..03296e62 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -396,7 +396,7 @@ c: 4 mock_write = self.mocker.replace("cloudinit.util.write_file", passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0600) + mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) self.mocker.replay() log_file = self.capture_log(logging.WARNING) @@ -415,8 +415,8 @@ c: 4 outpath = os.path.join(ci.paths.get_ipath_cur("scripts"), "part-001") mock_write = self.mocker.replace("cloudinit.util.write_file", passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0600) - mock_write(outpath, script, 0700) + mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) + mock_write(outpath, script, 0o700) self.mocker.replay() log_file = self.capture_log(logging.WARNING) @@ -435,8 +435,8 @@ c: 4 outpath = os.path.join(ci.paths.get_ipath_cur("scripts"), "part-001") mock_write = self.mocker.replace("cloudinit.util.write_file", passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0600) - mock_write(outpath, script, 0700) + mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) + mock_write(outpath, script, 0o700) self.mocker.replay() log_file = self.capture_log(logging.WARNING) @@ -455,8 +455,8 @@ c: 4 outpath = os.path.join(ci.paths.get_ipath_cur("scripts"), "part-001") mock_write = self.mocker.replace("cloudinit.util.write_file", passthrough=False) - mock_write(outpath, script, 0700) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0600) + mock_write(outpath, script, 0o700) + mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) self.mocker.replay() log_file = self.capture_log(logging.WARNING) diff --git a/tests/unittests/test_datasource/test_altcloud.py b/tests/unittests/test_datasource/test_altcloud.py index 9d8a4a20..c74562d7 100644 --- a/tests/unittests/test_datasource/test_altcloud.py +++ b/tests/unittests/test_datasource/test_altcloud.py @@ -66,12 +66,12 @@ def _write_user_data_files(mount_dir, value): udfile = open(deltacloud_user_data_file, 'w') udfile.write(value) udfile.close() - os.chmod(deltacloud_user_data_file, 0664) + os.chmod(deltacloud_user_data_file, 0o664) udfile = open(user_data_file, 'w') udfile.write(value) udfile.close() - os.chmod(user_data_file, 0664) + os.chmod(user_data_file, 0o664) def _remove_user_data_files(mount_dir, diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 47de034b..33a1d6e1 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -112,7 +112,7 @@ class TestNetCfgDistro(MockerTestCase): self.assertIn('/etc/network/interfaces', write_bufs) write_buf = write_bufs['/etc/network/interfaces'] self.assertEquals(str(write_buf).strip(), BASE_NET_CFG.strip()) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) def assertCfgEquals(self, blob1, blob2): b1 = dict(SysConf(blob1.strip().splitlines())) @@ -136,7 +136,7 @@ class TestNetCfgDistro(MockerTestCase): write_bufs = {} - def replace_write(filename, content, mode=0644, omode="wb"): + def replace_write(filename, content, mode=0o644, omode="wb"): buf = WriteBuffer() buf.mode = mode buf.omode = omode @@ -169,7 +169,7 @@ DEVICE="lo" ONBOOT=yes ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0', write_bufs) write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0'] @@ -183,7 +183,7 @@ GATEWAY="192.168.1.254" BROADCAST="192.168.1.0" ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1', write_bufs) write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1'] @@ -193,7 +193,7 @@ BOOTPROTO="dhcp" ONBOOT=yes ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) self.assertIn('/etc/sysconfig/network', write_bufs) write_buf = write_bufs['/etc/sysconfig/network'] @@ -202,7 +202,7 @@ ONBOOT=yes NETWORKING=yes ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) def test_write_ipv6_rhel(self): rh_distro = self._get_distro('rhel') @@ -215,7 +215,7 @@ NETWORKING=yes write_bufs = {} - def replace_write(filename, content, mode=0644, omode="wb"): + def replace_write(filename, content, mode=0o644, omode="wb"): buf = WriteBuffer() buf.mode = mode buf.omode = omode @@ -248,7 +248,7 @@ DEVICE="lo" ONBOOT=yes ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0', write_bufs) write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0'] @@ -265,7 +265,7 @@ IPV6ADDR="2607:f0d0:1002:0011::2" IPV6_DEFAULTGW="2607:f0d0:1002:0011::1" ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1', write_bufs) write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1'] expected_buf = ''' @@ -281,7 +281,7 @@ IPV6ADDR="2607:f0d0:1002:0011::3" IPV6_DEFAULTGW="2607:f0d0:1002:0011::1" ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) self.assertIn('/etc/sysconfig/network', write_bufs) write_buf = write_bufs['/etc/sysconfig/network'] @@ -292,7 +292,7 @@ NETWORKING_IPV6=yes IPV6_AUTOCONF=no ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) def test_simple_write_freebsd(self): fbsd_distro = self._get_distro('freebsd') @@ -319,7 +319,7 @@ IPV6_AUTOCONF=no '/etc/resolv.conf': '', } - def replace_write(filename, content, mode=0644, omode="wb"): + def replace_write(filename, content, mode=0o644, omode="wb"): buf = WriteBuffer() buf.mode = mode buf.omode = omode @@ -355,4 +355,4 @@ ifconfig_vtnet1="DHCP" defaultrouter="192.168.1.254" ''' self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0644) + self.assertEquals(write_buf.mode, 0o644) diff --git a/tests/unittests/test_handler/test_handler_ca_certs.py b/tests/unittests/test_handler/test_handler_ca_certs.py index 75df807e..7fe47b74 100644 --- a/tests/unittests/test_handler/test_handler_ca_certs.py +++ b/tests/unittests/test_handler/test_handler_ca_certs.py @@ -171,7 +171,7 @@ class TestAddCaCerts(MockerTestCase): mock_load = self.mocker.replace(util.load_file, passthrough=False) mock_write("/usr/share/ca-certificates/cloud-init-ca-certs.crt", - cert, mode=0644) + cert, mode=0o644) mock_load("/etc/ca-certificates.conf") self.mocker.result(ca_certs_content) @@ -192,7 +192,7 @@ class TestAddCaCerts(MockerTestCase): mock_load = self.mocker.replace(util.load_file, passthrough=False) mock_write("/usr/share/ca-certificates/cloud-init-ca-certs.crt", - expected_cert_file, mode=0644) + expected_cert_file, mode=0o644) ca_certs_content = "line1\nline2\nline3" mock_load("/etc/ca-certificates.conf") @@ -233,7 +233,7 @@ class TestRemoveDefaultCaCerts(MockerTestCase): mock_delete_dir_contents("/usr/share/ca-certificates/") mock_delete_dir_contents("/etc/ssl/certs/") - mock_write("/etc/ca-certificates.conf", "", mode=0644) + mock_write("/etc/ca-certificates.conf", "", mode=0o644) mock_subp(('debconf-set-selections', '-'), "ca-certificates ca-certificates/trust_new_crts select no") self.mocker.replay() -- cgit v1.2.3 From c80892c9c326716724c3ff06d9a82516a4152c74 Mon Sep 17 00:00:00 2001 From: Ben Howard Date: Wed, 21 Jan 2015 15:42:55 -0700 Subject: Use either syspath or dmidecode based on the availability. --- cloudinit/util.py | 35 ++++++++++++++++++++++++++++++++++- tests/unittests/test_util.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index f7498b01..26456aa6 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -2016,7 +2016,7 @@ def human2bytes(size): return int(num * mpliers[mplier]) -def read_dmi_data(key): +def _read_dmi_syspath(key): """ Reads dmi data with from /sys/class/dmi/id """ @@ -2039,3 +2039,36 @@ def read_dmi_data(key): except Exception as e: logexc(LOG, "failed read of {}".format(dmi_key), e) return None + + +def _call_dmidecode(key, dmidecode_path): + """ + Calls out to dmidecode to get the data out. This is mostly for supporting + OS's without /sys/class/dmi/id support. + """ + try: + cmd = [dmidecode_path, "--string", key] + (result, _err) = subp(cmd) + LOG.debug("dmidecode returned '{}' for '{}'".format(result, key)) + return result + except OSError, _err: + LOG.debug('failed dmidecode cmd: {}\n{}'.format(cmd, _err.message)) + return None + + +def read_dmi_data(key): + """ + Wrapper for reading DMI data. This tries to determine whether the DMI + Data can be read directly, otherwise it will fallback to using dmidecode. + """ + if os.path.exists(DMI_SYS_PATH): + return _read_dmi_syspath(key) + + dmidecode_path = which('dmidecode') + if dmidecode_path: + return _call_dmidecode(key, dmidecode_path) + + LOG.warn("did not find either path {} or dmidecode command".format( + DMI_SYS_PATH)) + + return None diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 6ae41bd6..3e079131 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -319,6 +319,7 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): self.patchUtils(root) def _write_key(self, key, content): + """Mocks the sys path found on Linux systems.""" new_root = self.makeDir() self._patchIn(new_root) util.ensure_dir(os.path.join('sys', 'class', 'dmi', 'id')) @@ -326,6 +327,24 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): dmi_key = "/sys/class/dmi/id/{}".format(key) util.write_file(dmi_key, content) + def _no_syspath(self, key, content): + """ + In order to test a missing sys path and call outs to dmidecode, this + function fakes the results of dmidecode to test the results. + """ + new_root = self.makeDir() + self._patchIn(new_root) + self.real_which = util.which + self.real_subp = util.subp + + def _which(key): + return True + util.which = _which + + def _cdd(_key, error=None): + return (content, error) + util.subp = _cdd + def test_key(self): key_content = "TEST-KEY-DATA" self._write_key("key", key_content) @@ -333,9 +352,18 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): def test_key_mismatch(self): self._write_key("test", "ABC") - self.assertNotEqual("123", util.read_dmi_data("test")) + self.assertNotEqual("123", util.read_dmi_data("test")) def test_no_key(self): + self._no_syspath(None, None) self.assertFalse(util.read_dmi_data("key")) + def test_callout_dmidecode(self): + """test to make sure that dmidecode is used when no syspath""" + self._no_syspath("key", "stuff") + self.assertEquals("stuff", util.read_dmi_data("key")) + self._no_syspath("key", None) + self.assertFalse(None, util.read_dmi_data("key")) + + # vi: ts=4 expandtab -- cgit v1.2.3 From f895cb12141281702b34da18f2384deb64c881e7 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Wed, 21 Jan 2015 17:56:53 -0500 Subject: Largely merge lp:~harlowja/cloud-init/py2-3 albeit manually because it seemed to be behind trunk. `tox -e py27` passes full test suite. Now to work on replacing mocker. --- cloudinit/config/cc_apt_configure.py | 2 +- cloudinit/config/cc_debug.py | 7 +- cloudinit/config/cc_landscape.py | 2 +- cloudinit/config/cc_mcollective.py | 15 +-- cloudinit/config/cc_phone_home.py | 4 +- cloudinit/config/cc_puppet.py | 8 +- cloudinit/config/cc_resolv_conf.py | 4 +- cloudinit/config/cc_seed_random.py | 3 +- cloudinit/config/cc_ssh.py | 16 +-- cloudinit/config/cc_yum_add_repo.py | 7 +- cloudinit/distros/__init__.py | 55 ++++++----- cloudinit/distros/arch.py | 2 +- cloudinit/distros/freebsd.py | 12 ++- cloudinit/distros/net_util.py | 2 +- cloudinit/distros/parsers/hostname.py | 2 +- cloudinit/distros/parsers/hosts.py | 2 +- cloudinit/distros/parsers/resolv_conf.py | 2 +- cloudinit/distros/parsers/sys_conf.py | 5 +- cloudinit/distros/rhel.py | 2 +- cloudinit/distros/sles.py | 2 +- cloudinit/ec2_utils.py | 9 +- cloudinit/handlers/__init__.py | 2 +- cloudinit/handlers/boot_hook.py | 2 +- cloudinit/handlers/cloud_config.py | 2 +- cloudinit/handlers/shell_script.py | 2 +- cloudinit/handlers/upstart_job.py | 2 +- cloudinit/helpers.py | 13 +-- cloudinit/log.py | 7 +- cloudinit/mergers/__init__.py | 4 +- cloudinit/mergers/m_dict.py | 4 +- cloudinit/mergers/m_list.py | 6 +- cloudinit/mergers/m_str.py | 10 +- cloudinit/netinfo.py | 4 +- cloudinit/signal_handler.py | 2 +- cloudinit/sources/DataSourceConfigDrive.py | 4 +- cloudinit/sources/DataSourceDigitalOcean.py | 9 +- cloudinit/sources/DataSourceEc2.py | 4 +- cloudinit/sources/DataSourceMAAS.py | 2 +- cloudinit/sources/DataSourceOVF.py | 6 +- cloudinit/sources/DataSourceSmartOS.py | 15 +-- cloudinit/sources/__init__.py | 10 +- cloudinit/sources/helpers/openstack.py | 10 +- cloudinit/ssh_util.py | 6 +- cloudinit/stages.py | 23 ++--- cloudinit/type_utils.py | 32 ++++-- cloudinit/url_helper.py | 22 +++-- cloudinit/user_data.py | 8 +- cloudinit/util.py | 109 +++++++++++++-------- packages/bddeb | 1 + packages/brpm | 2 + tests/unittests/test_data.py | 12 +-- tests/unittests/test_datasource/test_nocloud.py | 2 +- tests/unittests/test_datasource/test_openstack.py | 7 +- tests/unittests/test_distros/test_netconfig.py | 4 +- .../test_handler/test_handler_apt_configure.py | 10 +- .../unittests/test_handler/test_handler_locale.py | 6 +- .../test_handler/test_handler_seed_random.py | 2 +- .../test_handler/test_handler_set_hostname.py | 6 +- .../test_handler/test_handler_timezone.py | 6 +- .../test_handler/test_handler_yum_add_repo.py | 7 +- 60 files changed, 315 insertions(+), 233 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_apt_configure.py b/cloudinit/config/cc_apt_configure.py index f10b76a3..de72903f 100644 --- a/cloudinit/config/cc_apt_configure.py +++ b/cloudinit/config/cc_apt_configure.py @@ -126,7 +126,7 @@ def mirror2lists_fileprefix(mirror): def rename_apt_lists(old_mirrors, new_mirrors, lists_d="/var/lib/apt/lists"): - for (name, omirror) in old_mirrors.iteritems(): + for (name, omirror) in old_mirrors.items(): nmirror = new_mirrors.get(name) if not nmirror: continue diff --git a/cloudinit/config/cc_debug.py b/cloudinit/config/cc_debug.py index 8c489426..bdc32fe6 100644 --- a/cloudinit/config/cc_debug.py +++ b/cloudinit/config/cc_debug.py @@ -34,7 +34,8 @@ It can be configured with the following option structure:: """ import copy -from StringIO import StringIO + +from six import StringIO from cloudinit import type_utils from cloudinit import util @@ -77,7 +78,7 @@ def handle(name, cfg, cloud, log, args): dump_cfg = copy.deepcopy(cfg) for k in SKIP_KEYS: dump_cfg.pop(k, None) - all_keys = list(dump_cfg.keys()) + all_keys = list(dump_cfg) for k in all_keys: if k.startswith("_"): dump_cfg.pop(k, None) @@ -103,6 +104,6 @@ def handle(name, cfg, cloud, log, args): line = "ci-info: %s\n" % (line) content_to_file.append(line) if out_file: - util.write_file(out_file, "".join(content_to_file), 0644, "w") + util.write_file(out_file, "".join(content_to_file), 0o644, "w") else: util.multi_log("".join(content_to_file), console=True, stderr=False) diff --git a/cloudinit/config/cc_landscape.py b/cloudinit/config/cc_landscape.py index 8a709677..0b9d846e 100644 --- a/cloudinit/config/cc_landscape.py +++ b/cloudinit/config/cc_landscape.py @@ -20,7 +20,7 @@ import os -from StringIO import StringIO +from six import StringIO from configobj import ConfigObj diff --git a/cloudinit/config/cc_mcollective.py b/cloudinit/config/cc_mcollective.py index b670390d..425420ae 100644 --- a/cloudinit/config/cc_mcollective.py +++ b/cloudinit/config/cc_mcollective.py @@ -19,7 +19,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +import six +from six import StringIO # Used since this can maintain comments # and doesn't need a top level section @@ -51,17 +52,17 @@ def handle(name, cfg, cloud, log, _args): # original file in order to be able to mix the rest up mcollective_config = ConfigObj(SERVER_CFG) # See: http://tiny.cc/jh9agw - for (cfg_name, cfg) in mcollective_cfg['conf'].iteritems(): + for (cfg_name, cfg) in mcollective_cfg['conf'].items(): if cfg_name == 'public-cert': - util.write_file(PUBCERT_FILE, cfg, mode=0644) + util.write_file(PUBCERT_FILE, cfg, mode=0o644) mcollective_config['plugin.ssl_server_public'] = PUBCERT_FILE mcollective_config['securityprovider'] = 'ssl' elif cfg_name == 'private-cert': - util.write_file(PRICERT_FILE, cfg, mode=0600) + util.write_file(PRICERT_FILE, cfg, mode=0o600) mcollective_config['plugin.ssl_server_private'] = PRICERT_FILE mcollective_config['securityprovider'] = 'ssl' else: - if isinstance(cfg, (basestring, str)): + if isinstance(cfg, six.string_types): # Just set it in the 'main' section mcollective_config[cfg_name] = cfg elif isinstance(cfg, (dict)): @@ -69,7 +70,7 @@ def handle(name, cfg, cloud, log, _args): # if it is needed and then add/or create items as needed if cfg_name not in mcollective_config.sections: mcollective_config[cfg_name] = {} - for (o, v) in cfg.iteritems(): + for (o, v) in cfg.items(): mcollective_config[cfg_name][o] = v else: # Otherwise just try to convert it to a string @@ -81,7 +82,7 @@ def handle(name, cfg, cloud, log, _args): contents = StringIO() mcollective_config.write(contents) contents = contents.getvalue() - util.write_file(SERVER_CFG, contents, mode=0644) + util.write_file(SERVER_CFG, contents, mode=0o644) # Start mcollective util.subp(['service', 'mcollective', 'start'], capture=False) diff --git a/cloudinit/config/cc_phone_home.py b/cloudinit/config/cc_phone_home.py index 5bc68b83..18a7ddad 100644 --- a/cloudinit/config/cc_phone_home.py +++ b/cloudinit/config/cc_phone_home.py @@ -81,7 +81,7 @@ def handle(name, cfg, cloud, log, args): 'pub_key_ecdsa': '/etc/ssh/ssh_host_ecdsa_key.pub', } - for (n, path) in pubkeys.iteritems(): + for (n, path) in pubkeys.items(): try: all_keys[n] = util.load_file(path) except: @@ -99,7 +99,7 @@ def handle(name, cfg, cloud, log, args): # Get them read to be posted real_submit_keys = {} - for (k, v) in submit_keys.iteritems(): + for (k, v) in submit_keys.items(): if v is None: real_submit_keys[k] = 'N/A' else: diff --git a/cloudinit/config/cc_puppet.py b/cloudinit/config/cc_puppet.py index 471a1a8a..6f1b3c57 100644 --- a/cloudinit/config/cc_puppet.py +++ b/cloudinit/config/cc_puppet.py @@ -18,7 +18,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +from six import StringIO import os import socket @@ -81,13 +81,13 @@ def handle(name, cfg, cloud, log, _args): cleaned_contents = '\n'.join(cleaned_lines) puppet_config.readfp(StringIO(cleaned_contents), filename=PUPPET_CONF_PATH) - for (cfg_name, cfg) in puppet_cfg['conf'].iteritems(): + for (cfg_name, cfg) in puppet_cfg['conf'].items(): # Cert configuration is a special case # Dump the puppet master ca certificate in the correct place if cfg_name == 'ca_cert': # Puppet ssl sub-directory isn't created yet # Create it with the proper permissions and ownership - util.ensure_dir(PUPPET_SSL_DIR, 0771) + util.ensure_dir(PUPPET_SSL_DIR, 0o771) util.chownbyname(PUPPET_SSL_DIR, 'puppet', 'root') util.ensure_dir(PUPPET_SSL_CERT_DIR) util.chownbyname(PUPPET_SSL_CERT_DIR, 'puppet', 'root') @@ -96,7 +96,7 @@ def handle(name, cfg, cloud, log, _args): else: # Iterate throug the config items, we'll use ConfigParser.set # to overwrite or create new items as needed - for (o, v) in cfg.iteritems(): + for (o, v) in cfg.items(): if o == 'certname': # Expand %f as the fqdn # TODO(harlowja) should this use the cloud fqdn?? diff --git a/cloudinit/config/cc_resolv_conf.py b/cloudinit/config/cc_resolv_conf.py index bbaa6c63..71d9e3a7 100644 --- a/cloudinit/config/cc_resolv_conf.py +++ b/cloudinit/config/cc_resolv_conf.py @@ -66,8 +66,8 @@ def generate_resolv_conf(template_fn, params, target_fname="/etc/resolv.conf"): false_flags = [] if 'options' in params: - for key, val in params['options'].iteritems(): - if type(val) == bool: + for key, val in params['options'].items(): + if isinstance(val, bool): if val: flags.append(key) else: diff --git a/cloudinit/config/cc_seed_random.py b/cloudinit/config/cc_seed_random.py index 49a6b3e8..3b7235bf 100644 --- a/cloudinit/config/cc_seed_random.py +++ b/cloudinit/config/cc_seed_random.py @@ -21,7 +21,8 @@ import base64 import os -from StringIO import StringIO + +from six import StringIO from cloudinit.settings import PER_INSTANCE from cloudinit import log as logging diff --git a/cloudinit/config/cc_ssh.py b/cloudinit/config/cc_ssh.py index 4c76581c..ab6940fa 100644 --- a/cloudinit/config/cc_ssh.py +++ b/cloudinit/config/cc_ssh.py @@ -34,12 +34,12 @@ DISABLE_ROOT_OPTS = ("no-port-forwarding,no-agent-forwarding," "rather than the user \\\"root\\\".\';echo;sleep 10\"") KEY_2_FILE = { - "rsa_private": ("/etc/ssh/ssh_host_rsa_key", 0600), - "rsa_public": ("/etc/ssh/ssh_host_rsa_key.pub", 0644), - "dsa_private": ("/etc/ssh/ssh_host_dsa_key", 0600), - "dsa_public": ("/etc/ssh/ssh_host_dsa_key.pub", 0644), - "ecdsa_private": ("/etc/ssh/ssh_host_ecdsa_key", 0600), - "ecdsa_public": ("/etc/ssh/ssh_host_ecdsa_key.pub", 0644), + "rsa_private": ("/etc/ssh/ssh_host_rsa_key", 0o600), + "rsa_public": ("/etc/ssh/ssh_host_rsa_key.pub", 0o644), + "dsa_private": ("/etc/ssh/ssh_host_dsa_key", 0o600), + "dsa_public": ("/etc/ssh/ssh_host_dsa_key.pub", 0o644), + "ecdsa_private": ("/etc/ssh/ssh_host_ecdsa_key", 0o600), + "ecdsa_public": ("/etc/ssh/ssh_host_ecdsa_key.pub", 0o644), } PRIV_2_PUB = { @@ -68,13 +68,13 @@ def handle(_name, cfg, cloud, log, _args): if "ssh_keys" in cfg: # if there are keys in cloud-config, use them - for (key, val) in cfg["ssh_keys"].iteritems(): + for (key, val) in cfg["ssh_keys"].items(): if key in KEY_2_FILE: tgt_fn = KEY_2_FILE[key][0] tgt_perms = KEY_2_FILE[key][1] util.write_file(tgt_fn, val, tgt_perms) - for (priv, pub) in PRIV_2_PUB.iteritems(): + for (priv, pub) in PRIV_2_PUB.items(): if pub in cfg['ssh_keys'] or priv not in cfg['ssh_keys']: continue pair = (KEY_2_FILE[priv][0], KEY_2_FILE[pub][0]) diff --git a/cloudinit/config/cc_yum_add_repo.py b/cloudinit/config/cc_yum_add_repo.py index 0d836f28..3b821af9 100644 --- a/cloudinit/config/cc_yum_add_repo.py +++ b/cloudinit/config/cc_yum_add_repo.py @@ -18,9 +18,10 @@ import os -from cloudinit import util - import configobj +import six + +from cloudinit import util def _canonicalize_id(repo_id): @@ -37,7 +38,7 @@ def _format_repo_value(val): # Can handle 'lists' in certain cases # See: http://bit.ly/Qqrf1t return "\n ".join([_format_repo_value(v) for v in val]) - if not isinstance(val, (basestring, str)): + if not isinstance(val, six.string_types): return str(val) return val diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 49a0b652..4ebccdda 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -21,7 +21,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +import six +from six import StringIO import abc import itertools @@ -334,7 +335,7 @@ class Distro(object): redact_opts = ['passwd'] # Check the values and create the command - for key, val in kwargs.iteritems(): + for key, val in kwargs.items(): if key in adduser_opts and val and isinstance(val, str): adduser_cmd.extend([adduser_opts[key], val]) @@ -393,7 +394,7 @@ class Distro(object): if 'ssh_authorized_keys' in kwargs: # Try to handle this in a smart manner. keys = kwargs['ssh_authorized_keys'] - if isinstance(keys, (basestring, str)): + if isinstance(keys, six.string_types): keys = [keys] if isinstance(keys, dict): keys = list(keys.values()) @@ -491,7 +492,7 @@ class Distro(object): if isinstance(rules, (list, tuple)): for rule in rules: lines.append("%s %s" % (user, rule)) - elif isinstance(rules, (basestring, str)): + elif isinstance(rules, six.string_types): lines.append("%s %s" % (user, rules)) else: msg = "Can not create sudoers rule addition with type %r" @@ -561,10 +562,10 @@ def _get_package_mirror_info(mirror_info, availability_zone=None, subst['ec2_region'] = "%s" % availability_zone[0:-1] results = {} - for (name, mirror) in mirror_info.get('failsafe', {}).iteritems(): + for (name, mirror) in mirror_info.get('failsafe', {}).items(): results[name] = mirror - for (name, searchlist) in mirror_info.get('search', {}).iteritems(): + for (name, searchlist) in mirror_info.get('search', {}).items(): mirrors = [] for tmpl in searchlist: try: @@ -604,30 +605,30 @@ def _get_arch_package_mirror_info(package_mirrors, arch): # is the standard form used in the rest # of cloud-init def _normalize_groups(grp_cfg): - if isinstance(grp_cfg, (str, basestring)): + if isinstance(grp_cfg, six.string_types): grp_cfg = grp_cfg.strip().split(",") - if isinstance(grp_cfg, (list)): + if isinstance(grp_cfg, list): c_grp_cfg = {} for i in grp_cfg: - if isinstance(i, (dict)): + if isinstance(i, dict): for k, v in i.items(): if k not in c_grp_cfg: - if isinstance(v, (list)): + if isinstance(v, list): c_grp_cfg[k] = list(v) - elif isinstance(v, (basestring, str)): + elif isinstance(v, six.string_types): c_grp_cfg[k] = [v] else: raise TypeError("Bad group member type %s" % type_utils.obj_name(v)) else: - if isinstance(v, (list)): + if isinstance(v, list): c_grp_cfg[k].extend(v) - elif isinstance(v, (basestring, str)): + elif isinstance(v, six.string_types): c_grp_cfg[k].append(v) else: raise TypeError("Bad group member type %s" % type_utils.obj_name(v)) - elif isinstance(i, (str, basestring)): + elif isinstance(i, six.string_types): if i not in c_grp_cfg: c_grp_cfg[i] = [] else: @@ -635,7 +636,7 @@ def _normalize_groups(grp_cfg): type_utils.obj_name(i)) grp_cfg = c_grp_cfg groups = {} - if isinstance(grp_cfg, (dict)): + if isinstance(grp_cfg, dict): for (grp_name, grp_members) in grp_cfg.items(): groups[grp_name] = util.uniq_merge_sorted(grp_members) else: @@ -661,29 +662,29 @@ def _normalize_groups(grp_cfg): # entry 'default' which will be marked as true # all other users will be marked as false. def _normalize_users(u_cfg, def_user_cfg=None): - if isinstance(u_cfg, (dict)): + if isinstance(u_cfg, dict): ad_ucfg = [] for (k, v) in u_cfg.items(): - if isinstance(v, (bool, int, basestring, str, float)): + if isinstance(v, (bool, int, float) + six.string_types): if util.is_true(v): ad_ucfg.append(str(k)) - elif isinstance(v, (dict)): + elif isinstance(v, dict): v['name'] = k ad_ucfg.append(v) else: raise TypeError(("Unmappable user value type %s" " for key %s") % (type_utils.obj_name(v), k)) u_cfg = ad_ucfg - elif isinstance(u_cfg, (str, basestring)): + elif isinstance(u_cfg, six.string_types): u_cfg = util.uniq_merge_sorted(u_cfg) users = {} for user_config in u_cfg: - if isinstance(user_config, (str, basestring, list)): + if isinstance(user_config, (list,) + six.string_types): for u in util.uniq_merge(user_config): if u and u not in users: users[u] = {} - elif isinstance(user_config, (dict)): + elif isinstance(user_config, dict): if 'name' in user_config: n = user_config.pop('name') prev_config = users.get(n) or {} @@ -784,11 +785,11 @@ def normalize_users_groups(cfg, distro): old_user = cfg['user'] # Translate it into the format that is more useful # going forward - if isinstance(old_user, (basestring, str)): + if isinstance(old_user, six.string_types): old_user = { 'name': old_user, } - if not isinstance(old_user, (dict)): + if not isinstance(old_user, dict): LOG.warn(("Format for 'user' key must be a string or " "dictionary and not %s"), type_utils.obj_name(old_user)) old_user = {} @@ -813,7 +814,7 @@ def normalize_users_groups(cfg, distro): default_user_config = util.mergemanydict([old_user, distro_user_config]) base_users = cfg.get('users', []) - if not isinstance(base_users, (list, dict, str, basestring)): + if not isinstance(base_users, (list, dict) + six.string_types): LOG.warn(("Format for 'users' key must be a comma separated string" " or a dictionary or a list and not %s"), type_utils.obj_name(base_users)) @@ -822,12 +823,12 @@ def normalize_users_groups(cfg, distro): if old_user: # Ensure that when user: is provided that this user # always gets added (as the default user) - if isinstance(base_users, (list)): + if isinstance(base_users, list): # Just add it on at the end... base_users.append({'name': 'default'}) - elif isinstance(base_users, (dict)): + elif isinstance(base_users, dict): base_users['default'] = dict(base_users).get('default', True) - elif isinstance(base_users, (str, basestring)): + elif isinstance(base_users, six.string_types): # Just append it on to be re-parsed later base_users += ",default" diff --git a/cloudinit/distros/arch.py b/cloudinit/distros/arch.py index 68bf1aab..e540e0bc 100644 --- a/cloudinit/distros/arch.py +++ b/cloudinit/distros/arch.py @@ -66,7 +66,7 @@ class Distro(distros.Distro): settings, entries) dev_names = entries.keys() # Format for netctl - for (dev, info) in entries.iteritems(): + for (dev, info) in entries.items(): nameservers = [] net_fn = self.network_conf_dir + dev net_cfg = { diff --git a/cloudinit/distros/freebsd.py b/cloudinit/distros/freebsd.py index f1b4a256..4c484639 100644 --- a/cloudinit/distros/freebsd.py +++ b/cloudinit/distros/freebsd.py @@ -16,7 +16,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +import six +from six import StringIO import re @@ -203,8 +204,9 @@ class Distro(distros.Distro): redact_opts = ['passwd'] - for key, val in kwargs.iteritems(): - if key in adduser_opts and val and isinstance(val, basestring): + for key, val in kwargs.items(): + if (key in adduser_opts and val + and isinstance(val, six.string_types)): adduser_cmd.extend([adduser_opts[key], val]) # Redact certain fields from the logs @@ -271,7 +273,7 @@ class Distro(distros.Distro): nameservers = [] searchdomains = [] dev_names = entries.keys() - for (device, info) in entries.iteritems(): + for (device, info) in entries.items(): # Skip the loopback interface. if device.startswith('lo'): continue @@ -323,7 +325,7 @@ class Distro(distros.Distro): resolvconf.add_search_domain(domain) except ValueError: util.logexc(LOG, "Failed to add search domain %s", domain) - util.write_file(self.resolv_conf_fn, str(resolvconf), 0644) + util.write_file(self.resolv_conf_fn, str(resolvconf), 0o644) return dev_names diff --git a/cloudinit/distros/net_util.py b/cloudinit/distros/net_util.py index 8b28e2d1..cadfa6b6 100644 --- a/cloudinit/distros/net_util.py +++ b/cloudinit/distros/net_util.py @@ -103,7 +103,7 @@ def translate_network(settings): consume[cmd] = args # Check if anything left over to consume absorb = False - for (cmd, args) in consume.iteritems(): + for (cmd, args) in consume.items(): if cmd == 'iface': absorb = True if absorb: diff --git a/cloudinit/distros/parsers/hostname.py b/cloudinit/distros/parsers/hostname.py index 617b3c36..84a1de42 100644 --- a/cloudinit/distros/parsers/hostname.py +++ b/cloudinit/distros/parsers/hostname.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +from six import StringIO from cloudinit.distros.parsers import chop_comment diff --git a/cloudinit/distros/parsers/hosts.py b/cloudinit/distros/parsers/hosts.py index 94c97051..3c5498ee 100644 --- a/cloudinit/distros/parsers/hosts.py +++ b/cloudinit/distros/parsers/hosts.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +from six import StringIO from cloudinit.distros.parsers import chop_comment diff --git a/cloudinit/distros/parsers/resolv_conf.py b/cloudinit/distros/parsers/resolv_conf.py index 5733c25a..8aee03a4 100644 --- a/cloudinit/distros/parsers/resolv_conf.py +++ b/cloudinit/distros/parsers/resolv_conf.py @@ -16,7 +16,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +from six import StringIO from cloudinit import util diff --git a/cloudinit/distros/parsers/sys_conf.py b/cloudinit/distros/parsers/sys_conf.py index 20ca1871..d795e12f 100644 --- a/cloudinit/distros/parsers/sys_conf.py +++ b/cloudinit/distros/parsers/sys_conf.py @@ -16,7 +16,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO +import six +from six import StringIO import pipes import re @@ -69,7 +70,7 @@ class SysConf(configobj.ConfigObj): return out_contents.getvalue() def _quote(self, value, multiline=False): - if not isinstance(value, (str, basestring)): + if not isinstance(value, six.string_types): raise ValueError('Value "%s" is not a string' % (value)) if len(value) == 0: return '' diff --git a/cloudinit/distros/rhel.py b/cloudinit/distros/rhel.py index d9588632..7408989c 100644 --- a/cloudinit/distros/rhel.py +++ b/cloudinit/distros/rhel.py @@ -73,7 +73,7 @@ class Distro(distros.Distro): searchservers = [] dev_names = entries.keys() use_ipv6 = False - for (dev, info) in entries.iteritems(): + for (dev, info) in entries.items(): net_fn = self.network_script_tpl % (dev) net_cfg = { 'DEVICE': dev, diff --git a/cloudinit/distros/sles.py b/cloudinit/distros/sles.py index 43682a12..0c6d1203 100644 --- a/cloudinit/distros/sles.py +++ b/cloudinit/distros/sles.py @@ -62,7 +62,7 @@ class Distro(distros.Distro): nameservers = [] searchservers = [] dev_names = entries.keys() - for (dev, info) in entries.iteritems(): + for (dev, info) in entries.items(): net_fn = self.network_script_tpl % (dev) mode = info.get('auto') if mode and mode.lower() == 'true': diff --git a/cloudinit/ec2_utils.py b/cloudinit/ec2_utils.py index e69d06ff..e1ed4091 100644 --- a/cloudinit/ec2_utils.py +++ b/cloudinit/ec2_utils.py @@ -17,7 +17,6 @@ # along with this program. If not, see . import functools -import httplib import json from cloudinit import log as logging @@ -25,7 +24,7 @@ from cloudinit import url_helper from cloudinit import util LOG = logging.getLogger(__name__) -SKIP_USERDATA_CODES = frozenset([httplib.NOT_FOUND]) +SKIP_USERDATA_CODES = frozenset([url_helper.NOT_FOUND]) class MetadataLeafDecoder(object): @@ -123,7 +122,7 @@ class MetadataMaterializer(object): leaf_contents = {} for (field, resource) in leaves.items(): leaf_url = url_helper.combine_url(base_url, resource) - leaf_blob = str(self._caller(leaf_url)) + leaf_blob = self._caller(leaf_url).contents leaf_contents[field] = self._leaf_decoder(field, leaf_blob) joined = {} joined.update(child_contents) @@ -160,7 +159,7 @@ def get_instance_userdata(api_version='latest', timeout=timeout, retries=retries, exception_cb=exception_cb) - user_data = str(response) + user_data = response.contents except url_helper.UrlError as e: if e.code not in SKIP_USERDATA_CODES: util.logexc(LOG, "Failed fetching userdata from url %s", ud_url) @@ -183,7 +182,7 @@ def get_instance_metadata(api_version='latest', try: response = caller(md_url) - materializer = MetadataMaterializer(str(response), + materializer = MetadataMaterializer(response.contents, md_url, caller, leaf_decoder=leaf_decoder) md = materializer.materialize() diff --git a/cloudinit/handlers/__init__.py b/cloudinit/handlers/__init__.py index 059d7495..d67a70ea 100644 --- a/cloudinit/handlers/__init__.py +++ b/cloudinit/handlers/__init__.py @@ -147,7 +147,7 @@ def walker_handle_handler(pdata, _ctype, _filename, payload): if not modfname.endswith(".py"): modfname = "%s.py" % (modfname) # TODO(harlowja): Check if path exists?? - util.write_file(modfname, payload, 0600) + util.write_file(modfname, payload, 0o600) handlers = pdata['handlers'] try: mod = fixup_handler(importer.import_module(modname)) diff --git a/cloudinit/handlers/boot_hook.py b/cloudinit/handlers/boot_hook.py index 3a50cf87..a4ea47ac 100644 --- a/cloudinit/handlers/boot_hook.py +++ b/cloudinit/handlers/boot_hook.py @@ -50,7 +50,7 @@ class BootHookPartHandler(handlers.Handler): filepath = os.path.join(self.boothook_dir, filename) contents = util.strip_prefix_suffix(util.dos2unix(payload), prefix=BOOTHOOK_PREFIX) - util.write_file(filepath, contents.lstrip(), 0700) + util.write_file(filepath, contents.lstrip(), 0o700) return filepath def handle_part(self, data, ctype, filename, payload, frequency): diff --git a/cloudinit/handlers/cloud_config.py b/cloudinit/handlers/cloud_config.py index bf994e33..07b6d0e0 100644 --- a/cloudinit/handlers/cloud_config.py +++ b/cloudinit/handlers/cloud_config.py @@ -95,7 +95,7 @@ class CloudConfigPartHandler(handlers.Handler): lines.append(util.yaml_dumps(self.cloud_buf)) else: lines = [] - util.write_file(self.cloud_fn, "\n".join(lines), 0600) + util.write_file(self.cloud_fn, "\n".join(lines), 0o600) def _extract_mergers(self, payload, headers): merge_header_headers = '' diff --git a/cloudinit/handlers/shell_script.py b/cloudinit/handlers/shell_script.py index 9755ab05..b5087693 100644 --- a/cloudinit/handlers/shell_script.py +++ b/cloudinit/handlers/shell_script.py @@ -52,4 +52,4 @@ class ShellScriptPartHandler(handlers.Handler): filename = util.clean_filename(filename) payload = util.dos2unix(payload) path = os.path.join(self.script_dir, filename) - util.write_file(path, payload, 0700) + util.write_file(path, payload, 0o700) diff --git a/cloudinit/handlers/upstart_job.py b/cloudinit/handlers/upstart_job.py index 50d193c4..c5bea711 100644 --- a/cloudinit/handlers/upstart_job.py +++ b/cloudinit/handlers/upstart_job.py @@ -65,7 +65,7 @@ class UpstartJobPartHandler(handlers.Handler): payload = util.dos2unix(payload) path = os.path.join(self.upstart_dir, filename) - util.write_file(path, payload, 0644) + util.write_file(path, payload, 0o644) if SUITABLE_UPSTART: util.subp(["initctl", "reload-configuration"], capture=False) diff --git a/cloudinit/helpers.py b/cloudinit/helpers.py index e701126e..ed396b5a 100644 --- a/cloudinit/helpers.py +++ b/cloudinit/helpers.py @@ -23,10 +23,11 @@ from time import time import contextlib -import io import os -from ConfigParser import (NoSectionError, NoOptionError, RawConfigParser) +import six +from six.moves.configparser import ( + NoSectionError, NoOptionError, RawConfigParser) from cloudinit.settings import (PER_INSTANCE, PER_ALWAYS, PER_ONCE, CFG_ENV_NAME) @@ -318,10 +319,10 @@ class ContentHandlers(object): return self.registered[content_type] def items(self): - return self.registered.items() + return list(self.registered.items()) - def iteritems(self): - return self.registered.iteritems() + # XXX This should really go away. + iteritems = items class Paths(object): @@ -449,7 +450,7 @@ class DefaultingConfigParser(RawConfigParser): def stringify(self, header=None): contents = '' - with io.BytesIO() as outputstream: + with six.StringIO() as outputstream: self.write(outputstream) outputstream.flush() contents = outputstream.getvalue() diff --git a/cloudinit/log.py b/cloudinit/log.py index 622c946c..3c79b9c9 100644 --- a/cloudinit/log.py +++ b/cloudinit/log.py @@ -28,7 +28,8 @@ import collections import os import sys -from StringIO import StringIO +import six +from six import StringIO # Logging levels for easy access CRITICAL = logging.CRITICAL @@ -72,13 +73,13 @@ def setupLogging(cfg=None): log_cfgs = [] log_cfg = cfg.get('logcfg') - if log_cfg and isinstance(log_cfg, (str, basestring)): + if log_cfg and isinstance(log_cfg, six.string_types): # If there is a 'logcfg' entry in the config, # respect it, it is the old keyname log_cfgs.append(str(log_cfg)) elif "log_cfgs" in cfg: for a_cfg in cfg['log_cfgs']: - if isinstance(a_cfg, (basestring, str)): + if isinstance(a_cfg, six.string_types): log_cfgs.append(a_cfg) elif isinstance(a_cfg, (collections.Iterable)): cfg_str = [str(c) for c in a_cfg] diff --git a/cloudinit/mergers/__init__.py b/cloudinit/mergers/__init__.py index 03aa1ee1..e13f55ac 100644 --- a/cloudinit/mergers/__init__.py +++ b/cloudinit/mergers/__init__.py @@ -18,6 +18,8 @@ import re +import six + from cloudinit import importer from cloudinit import log as logging from cloudinit import type_utils @@ -95,7 +97,7 @@ def dict_extract_mergers(config): raw_mergers = config.pop('merge_type', None) if raw_mergers is None: return parsed_mergers - if isinstance(raw_mergers, (str, basestring)): + if isinstance(raw_mergers, six.string_types): return string_extract_mergers(raw_mergers) for m in raw_mergers: if isinstance(m, (dict)): diff --git a/cloudinit/mergers/m_dict.py b/cloudinit/mergers/m_dict.py index a16141fa..87cf1a72 100644 --- a/cloudinit/mergers/m_dict.py +++ b/cloudinit/mergers/m_dict.py @@ -16,6 +16,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import six + DEF_MERGE_TYPE = 'no_replace' MERGE_TYPES = ('replace', DEF_MERGE_TYPE,) @@ -57,7 +59,7 @@ class Merger(object): return new_v if isinstance(new_v, (list, tuple)) and self._recurse_array: return self._merger.merge(old_v, new_v) - if isinstance(new_v, (basestring)) and self._recurse_str: + if isinstance(new_v, six.string_types) and self._recurse_str: return self._merger.merge(old_v, new_v) if isinstance(new_v, (dict)) and self._recurse_dict: return self._merger.merge(old_v, new_v) diff --git a/cloudinit/mergers/m_list.py b/cloudinit/mergers/m_list.py index 3b87b0fc..81e5c580 100644 --- a/cloudinit/mergers/m_list.py +++ b/cloudinit/mergers/m_list.py @@ -16,6 +16,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import six + DEF_MERGE_TYPE = 'replace' MERGE_TYPES = ('append', 'prepend', DEF_MERGE_TYPE, 'no_replace') @@ -73,7 +75,7 @@ class Merger(object): return old_v if isinstance(new_v, (list, tuple)) and self._recurse_array: return self._merger.merge(old_v, new_v) - if isinstance(new_v, (str, basestring)) and self._recurse_str: + if isinstance(new_v, six.string_types) and self._recurse_str: return self._merger.merge(old_v, new_v) if isinstance(new_v, (dict)) and self._recurse_dict: return self._merger.merge(old_v, new_v) @@ -82,6 +84,6 @@ class Merger(object): # Ok now we are replacing same indexes merged_list.extend(value) common_len = min(len(merged_list), len(merge_with)) - for i in xrange(0, common_len): + for i in range(0, common_len): merged_list[i] = merge_same_index(merged_list[i], merge_with[i]) return merged_list diff --git a/cloudinit/mergers/m_str.py b/cloudinit/mergers/m_str.py index e22ce28a..b00c4bf3 100644 --- a/cloudinit/mergers/m_str.py +++ b/cloudinit/mergers/m_str.py @@ -17,6 +17,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import six + class Merger(object): def __init__(self, _merger, opts): @@ -34,11 +36,11 @@ class Merger(object): # perform the following action, if appending we will # merge them together, otherwise we will just return value. def _on_str(self, value, merge_with): - if not isinstance(value, (basestring)): + if not isinstance(value, six.string_types): return merge_with if not self._append: return merge_with - if isinstance(value, unicode): - return value + unicode(merge_with) + if isinstance(value, six.text_type): + return value + six.text_type(merge_with) else: - return value + str(merge_with) + return value + six.binary_type(merge_with) diff --git a/cloudinit/netinfo.py b/cloudinit/netinfo.py index fb40cc0d..e30d6fb5 100644 --- a/cloudinit/netinfo.py +++ b/cloudinit/netinfo.py @@ -87,7 +87,7 @@ def netdev_info(empty=""): devs[curdev][target] = toks[i][len(field) + 1:] if empty != "": - for (_devname, dev) in devs.iteritems(): + for (_devname, dev) in devs.items(): for field in dev: if dev[field] == "": dev[field] = empty @@ -181,7 +181,7 @@ def netdev_pformat(): else: fields = ['Device', 'Up', 'Address', 'Mask', 'Scope', 'Hw-Address'] tbl = PrettyTable(fields) - for (dev, d) in netdev.iteritems(): + for (dev, d) in netdev.items(): tbl.add_row([dev, d["up"], d["addr"], d["mask"], ".", d["hwaddr"]]) if d.get('addr6'): tbl.add_row([dev, d["up"], diff --git a/cloudinit/signal_handler.py b/cloudinit/signal_handler.py index 40b0c94c..0d95f506 100644 --- a/cloudinit/signal_handler.py +++ b/cloudinit/signal_handler.py @@ -22,7 +22,7 @@ import inspect import signal import sys -from StringIO import StringIO +from six import StringIO from cloudinit import log as logging from cloudinit import util diff --git a/cloudinit/sources/DataSourceConfigDrive.py b/cloudinit/sources/DataSourceConfigDrive.py index 15244a0d..eb474079 100644 --- a/cloudinit/sources/DataSourceConfigDrive.py +++ b/cloudinit/sources/DataSourceConfigDrive.py @@ -216,11 +216,11 @@ def on_first_boot(data, distro=None): files = data.get('files', {}) if files: LOG.debug("Writing %s injected files", len(files)) - for (filename, content) in files.iteritems(): + for (filename, content) in files.items(): if not filename.startswith(os.sep): filename = os.sep + filename try: - util.write_file(filename, content, mode=0660) + util.write_file(filename, content, mode=0o660) except IOError: util.logexc(LOG, "Failed writing file: %s", filename) diff --git a/cloudinit/sources/DataSourceDigitalOcean.py b/cloudinit/sources/DataSourceDigitalOcean.py index 8f27ee89..b20ce2a1 100644 --- a/cloudinit/sources/DataSourceDigitalOcean.py +++ b/cloudinit/sources/DataSourceDigitalOcean.py @@ -18,7 +18,7 @@ from cloudinit import log as logging from cloudinit import util from cloudinit import sources from cloudinit import ec2_utils -from types import StringType + import functools @@ -72,10 +72,11 @@ class DataSourceDigitalOcean(sources.DataSource): return "\n".join(self.metadata['vendor-data']) def get_public_ssh_keys(self): - if type(self.metadata['public-keys']) is StringType: - return [self.metadata['public-keys']] + public_keys = self.metadata['public-keys'] + if isinstance(public_keys, list): + return public_keys else: - return self.metadata['public-keys'] + return [public_keys] @property def availability_zone(self): diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index 1b20ecf3..798869b7 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -156,8 +156,8 @@ class DataSourceEc2(sources.DataSource): # 'ephemeral0': '/dev/sdb', # 'root': '/dev/sda1'} found = None - bdm_items = self.metadata['block-device-mapping'].iteritems() - for (entname, device) in bdm_items: + bdm = self.metadata['block-device-mapping'] + for (entname, device) in bdm.items(): if entname == name: found = device break diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index dfe90bc6..9a3e30c5 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -262,7 +262,7 @@ def check_seed_contents(content, seed): userdata = content.get('user-data', "") md = {} - for (key, val) in content.iteritems(): + for (key, val) in content.items(): if key == 'user-data': continue md[key] = val diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index 7ba60735..58a4b2a2 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -66,7 +66,7 @@ class DataSourceOVF(sources.DataSource): np = {'iso': transport_iso9660, 'vmware-guestd': transport_vmware_guestd, } name = None - for (name, transfunc) in np.iteritems(): + for (name, transfunc) in np.items(): (contents, _dev, _fname) = transfunc() if contents: break @@ -138,7 +138,7 @@ def read_ovf_environment(contents): ud = "" cfg_props = ['password'] md_props = ['seedfrom', 'local-hostname', 'public-keys', 'instance-id'] - for (prop, val) in props.iteritems(): + for (prop, val) in props.items(): if prop == 'hostname': prop = "local-hostname" if prop in md_props: @@ -183,7 +183,7 @@ def transport_iso9660(require_iso=True): # Go through mounts to see if it was already mounted mounts = util.mounts() - for (dev, info) in mounts.iteritems(): + for (dev, info) in mounts.items(): fstype = info['fstype'] if fstype != "iso9660" and require_iso: continue diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 2733a2f6..7a975d78 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -30,12 +30,12 @@ # Comments with "@datadictionary" are snippets of the definition import base64 +import os +import serial + from cloudinit import log as logging from cloudinit import sources from cloudinit import util -import os -import os.path -import serial LOG = logging.getLogger(__name__) @@ -201,7 +201,7 @@ class DataSourceSmartOS(sources.DataSource): if b64_all is not None: self.b64_all = util.is_true(b64_all) - for ci_noun, attribute in SMARTOS_ATTRIB_MAP.iteritems(): + for ci_noun, attribute in SMARTOS_ATTRIB_MAP.items(): smartos_noun, strip = attribute md[ci_noun] = self.query(smartos_noun, strip=strip) @@ -218,11 +218,12 @@ class DataSourceSmartOS(sources.DataSource): user_script = os.path.join(data_d, 'user-script') u_script_l = "%s/user-script" % LEGACY_USER_D write_boot_content(md.get('user-script'), content_f=user_script, - link=u_script_l, shebang=True, mode=0700) + link=u_script_l, shebang=True, mode=0o700) operator_script = os.path.join(data_d, 'operator-script') write_boot_content(md.get('operator-script'), - content_f=operator_script, shebang=False, mode=0700) + content_f=operator_script, shebang=False, + mode=0o700) # @datadictionary: This key has no defined format, but its value # is written to the file /var/db/mdata-user-data on each boot prior @@ -381,7 +382,7 @@ def dmi_data(): def write_boot_content(content, content_f, link=None, shebang=False, - mode=0400): + mode=0o400): """ Write the content to content_f. Under the following rules: 1. If no content, remove the file diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 7c7ef9ab..39eab51b 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -23,6 +23,8 @@ import abc import os +import six + from cloudinit import importer from cloudinit import log as logging from cloudinit import type_utils @@ -130,7 +132,7 @@ class DataSource(object): # we want to return the correct value for what will actually # exist in this instance mappings = {"sd": ("vd", "xvd", "vtb")} - for (nfrom, tlist) in mappings.iteritems(): + for (nfrom, tlist) in mappings.items(): if not short_name.startswith(nfrom): continue for nto in tlist: @@ -218,18 +220,18 @@ def normalize_pubkey_data(pubkey_data): if not pubkey_data: return keys - if isinstance(pubkey_data, (basestring, str)): + if isinstance(pubkey_data, six.string_types): return str(pubkey_data).splitlines() if isinstance(pubkey_data, (list, set)): return list(pubkey_data) if isinstance(pubkey_data, (dict)): - for (_keyname, klist) in pubkey_data.iteritems(): + for (_keyname, klist) in pubkey_data.items(): # lp:506332 uec metadata service responds with # data that makes boto populate a string for 'klist' rather # than a list. - if isinstance(klist, (str, basestring)): + if isinstance(klist, six.string_types): klist = [klist] if isinstance(klist, (list, set)): for pkey in klist: diff --git a/cloudinit/sources/helpers/openstack.py b/cloudinit/sources/helpers/openstack.py index b7e19314..88c7a198 100644 --- a/cloudinit/sources/helpers/openstack.py +++ b/cloudinit/sources/helpers/openstack.py @@ -24,6 +24,8 @@ import copy import functools import os +import six + from cloudinit import ec2_utils from cloudinit import log as logging from cloudinit import sources @@ -205,7 +207,7 @@ class BaseReader(object): """ load_json_anytype = functools.partial( - util.load_json, root_types=(dict, basestring, list)) + util.load_json, root_types=(dict, list) + six.string_types) def datafiles(version): files = {} @@ -234,7 +236,7 @@ class BaseReader(object): 'version': 2, } data = datafiles(self._find_working_version()) - for (name, (path, required, translator)) in data.iteritems(): + for (name, (path, required, translator)) in data.items(): path = self._path_join(self.base_path, path) data = None found = False @@ -364,7 +366,7 @@ class ConfigDriveReader(BaseReader): raise NonReadable("%s: no files found" % (self.base_path)) md = {} - for (name, (key, translator, default)) in FILES_V1.iteritems(): + for (name, (key, translator, default)) in FILES_V1.items(): if name in found: path = found[name] try: @@ -478,7 +480,7 @@ def convert_vendordata_json(data, recurse=True): """ if not data: return None - if isinstance(data, (str, unicode, basestring)): + if isinstance(data, six.string_types): return data if isinstance(data, list): return copy.deepcopy(data) diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py index 14d0cb0f..9b2f5ed5 100644 --- a/cloudinit/ssh_util.py +++ b/cloudinit/ssh_util.py @@ -239,7 +239,7 @@ def setup_user_keys(keys, username, options=None): # Make sure the users .ssh dir is setup accordingly (ssh_dir, pwent) = users_ssh_info(username) if not os.path.isdir(ssh_dir): - util.ensure_dir(ssh_dir, mode=0700) + util.ensure_dir(ssh_dir, mode=0o700) util.chownbyid(ssh_dir, pwent.pw_uid, pwent.pw_gid) # Turn the 'update' keys given into actual entries @@ -252,8 +252,8 @@ def setup_user_keys(keys, username, options=None): (auth_key_fn, auth_key_entries) = extract_authorized_keys(username) with util.SeLinuxGuard(ssh_dir, recursive=True): content = update_authorized_keys(auth_key_entries, key_entries) - util.ensure_dir(os.path.dirname(auth_key_fn), mode=0700) - util.write_file(auth_key_fn, content, mode=0600) + util.ensure_dir(os.path.dirname(auth_key_fn), mode=0o700) + util.write_file(auth_key_fn, content, mode=0o600) util.chownbyid(auth_key_fn, pwent.pw_uid, pwent.pw_gid) diff --git a/cloudinit/stages.py b/cloudinit/stages.py index 67f467f7..f4f4591d 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -20,12 +20,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import cPickle as pickle - import copy import os import sys +import six +from six.moves import cPickle as pickle + from cloudinit.settings import (PER_INSTANCE, FREQUENCIES, CLOUD_CONFIG) from cloudinit import handlers @@ -202,7 +203,7 @@ class Init(object): util.logexc(LOG, "Failed pickling datasource %s", self.datasource) return False try: - util.write_file(pickled_fn, pk_contents, mode=0400) + util.write_file(pickled_fn, pk_contents, mode=0o400) except Exception: util.logexc(LOG, "Failed pickling datasource to %s", pickled_fn) return False @@ -324,15 +325,15 @@ class Init(object): def _store_userdata(self): raw_ud = "%s" % (self.datasource.get_userdata_raw()) - util.write_file(self._get_ipath('userdata_raw'), raw_ud, 0600) + util.write_file(self._get_ipath('userdata_raw'), raw_ud, 0o600) processed_ud = "%s" % (self.datasource.get_userdata()) - util.write_file(self._get_ipath('userdata'), processed_ud, 0600) + util.write_file(self._get_ipath('userdata'), processed_ud, 0o600) def _store_vendordata(self): raw_vd = "%s" % (self.datasource.get_vendordata_raw()) - util.write_file(self._get_ipath('vendordata_raw'), raw_vd, 0600) + util.write_file(self._get_ipath('vendordata_raw'), raw_vd, 0o600) processed_vd = "%s" % (self.datasource.get_vendordata()) - util.write_file(self._get_ipath('vendordata'), processed_vd, 0600) + util.write_file(self._get_ipath('vendordata'), processed_vd, 0o600) def _default_handlers(self, opts=None): if opts is None: @@ -384,7 +385,7 @@ class Init(object): if not path or not os.path.isdir(path): return potential_handlers = util.find_modules(path) - for (fname, mod_name) in potential_handlers.iteritems(): + for (fname, mod_name) in potential_handlers.items(): try: mod_locs, looked_locs = importer.find_module( mod_name, [''], ['list_types', 'handle_part']) @@ -422,7 +423,7 @@ class Init(object): def init_handlers(): # Init the handlers first - for (_ctype, mod) in c_handlers.iteritems(): + for (_ctype, mod) in c_handlers.items(): if mod in c_handlers.initialized: # Avoid initing the same module twice (if said module # is registered to more than one content-type). @@ -449,7 +450,7 @@ class Init(object): def finalize_handlers(): # Give callbacks opportunity to finalize - for (_ctype, mod) in c_handlers.iteritems(): + for (_ctype, mod) in c_handlers.items(): if mod not in c_handlers.initialized: # Said module was never inited in the first place, so lets # not attempt to finalize those that never got called. @@ -574,7 +575,7 @@ class Modules(object): for item in cfg_mods: if not item: continue - if isinstance(item, (str, basestring)): + if isinstance(item, six.string_types): module_list.append({ 'mod': item.strip(), }) diff --git a/cloudinit/type_utils.py b/cloudinit/type_utils.py index cc3d9495..b93efd6a 100644 --- a/cloudinit/type_utils.py +++ b/cloudinit/type_utils.py @@ -22,11 +22,31 @@ import types +import six + + +if six.PY3: + _NAME_TYPES = ( + types.ModuleType, + types.FunctionType, + types.LambdaType, + type, + ) +else: + _NAME_TYPES = ( + types.TypeType, + types.ModuleType, + types.FunctionType, + types.LambdaType, + types.ClassType, + ) + def obj_name(obj): - if isinstance(obj, (types.TypeType, - types.ModuleType, - types.FunctionType, - types.LambdaType)): - return str(obj.__name__) - return obj_name(obj.__class__) + if isinstance(obj, _NAME_TYPES): + return six.text_type(obj.__name__) + else: + if not hasattr(obj, '__class__'): + return repr(obj) + else: + return obj_name(obj.__class__) diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 3074dd08..62001dff 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -20,21 +20,29 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import httplib import time -import urllib + +import six import requests from requests import exceptions -from urlparse import (urlparse, urlunparse) +from six.moves.urllib.parse import ( + urlparse, urlunparse, + quote as urlquote) from cloudinit import log as logging from cloudinit import version LOG = logging.getLogger(__name__) -NOT_FOUND = httplib.NOT_FOUND +if six.PY2: + import httplib + NOT_FOUND = httplib.NOT_FOUND +else: + import http.client + NOT_FOUND = http.client.NOT_FOUND + # Check if requests has ssl support (added in requests >= 0.8.8) SSL_ENABLED = False @@ -70,7 +78,7 @@ def combine_url(base, *add_ons): path = url_parsed[2] if path and not path.endswith("/"): path += "/" - path += urllib.quote(str(add_on), safe="/:") + path += urlquote(str(add_on), safe="/:") url_parsed[2] = path return urlunparse(url_parsed) @@ -111,7 +119,7 @@ class UrlResponse(object): @property def contents(self): - return self._response.content + return self._response.text @property def url(self): @@ -135,7 +143,7 @@ class UrlResponse(object): return self._response.status_code def __str__(self): - return self.contents + return self._response.text class UrlError(IOError): diff --git a/cloudinit/user_data.py b/cloudinit/user_data.py index de6487d8..9111bd39 100644 --- a/cloudinit/user_data.py +++ b/cloudinit/user_data.py @@ -29,6 +29,8 @@ from email.mime.multipart import MIMEMultipart from email.mime.nonmultipart import MIMENonMultipart from email.mime.text import MIMEText +import six + from cloudinit import handlers from cloudinit import log as logging from cloudinit import util @@ -235,7 +237,7 @@ class UserDataProcessor(object): resp = util.read_file_or_url(include_url, ssl_details=self.ssl_details) if include_once_on and resp.ok(): - util.write_file(include_once_fn, str(resp), mode=0600) + util.write_file(include_once_fn, str(resp), mode=0o600) if resp.ok(): content = str(resp) else: @@ -256,7 +258,7 @@ class UserDataProcessor(object): # filename and type not be present # or # scalar(payload) - if isinstance(ent, (str, basestring)): + if isinstance(ent, six.string_types): ent = {'content': ent} if not isinstance(ent, (dict)): # TODO(harlowja) raise? @@ -337,7 +339,7 @@ def convert_string(raw_data, headers=None): data = util.decomp_gzip(raw_data) if "mime-version:" in data[0:4096].lower(): msg = email.message_from_string(data) - for (key, val) in headers.iteritems(): + for (key, val) in headers.items(): _replace_header(msg, key, val) else: mtype = headers.get(CONTENT_TYPE, NOT_MULTIPART_TYPE) diff --git a/cloudinit/util.py b/cloudinit/util.py index 9efc704a..434ba7fb 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -20,8 +20,6 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from StringIO import StringIO - import contextlib import copy as obj_copy import ctypes @@ -45,8 +43,10 @@ import subprocess import sys import tempfile import time -import urlparse +from six.moves.urllib import parse as urlparse + +import six import yaml from cloudinit import importer @@ -69,8 +69,26 @@ FN_REPLACEMENTS = { } FN_ALLOWED = ('_-.()' + string.digits + string.ascii_letters) +TRUE_STRINGS = ('true', '1', 'on', 'yes') +FALSE_STRINGS = ('off', '0', 'no', 'false') + + # Helper utils to see if running in a container -CONTAINER_TESTS = ['running-in-container', 'lxc-is-container'] +CONTAINER_TESTS = ('running-in-container', 'lxc-is-container') + + +def decode_binary(blob, encoding='utf-8'): + # Converts a binary type into a text type using given encoding. + if isinstance(blob, six.text_type): + return blob + return blob.decode(encoding) + + +def encode_text(text, encoding='utf-8'): + # Converts a text string into a binary type using given encoding. + if isinstance(text, six.binary_type): + return text + return text.encode(encoding) class ProcessExecutionError(IOError): @@ -95,7 +113,7 @@ class ProcessExecutionError(IOError): else: self.description = description - if not isinstance(exit_code, (long, int)): + if not isinstance(exit_code, six.integer_types): self.exit_code = '-' else: self.exit_code = exit_code @@ -151,7 +169,8 @@ class SeLinuxGuard(object): path = os.path.realpath(self.path) # path should be a string, not unicode - path = str(path) + if six.PY2: + path = str(path) try: stats = os.lstat(path) self.selinux.matchpathcon(path, stats[stat.ST_MODE]) @@ -209,10 +228,10 @@ def fork_cb(child_cb, *args, **kwargs): def is_true(val, addons=None): if isinstance(val, (bool)): return val is True - check_set = ['true', '1', 'on', 'yes'] + check_set = TRUE_STRINGS if addons: - check_set = check_set + addons - if str(val).lower().strip() in check_set: + check_set = list(check_set) + addons + if six.text_type(val).lower().strip() in check_set: return True return False @@ -220,10 +239,10 @@ def is_true(val, addons=None): def is_false(val, addons=None): if isinstance(val, (bool)): return val is False - check_set = ['off', '0', 'no', 'false'] + check_set = FALSE_STRINGS if addons: - check_set = check_set + addons - if str(val).lower().strip() in check_set: + check_set = list(check_set) + addons + if six.text_type(val).lower().strip() in check_set: return True return False @@ -273,7 +292,7 @@ def uniq_merge_sorted(*lists): def uniq_merge(*lists): combined_list = [] for a_list in lists: - if isinstance(a_list, (str, basestring)): + if isinstance(a_list, six.string_types): a_list = a_list.strip().split(",") # Kickout the empty ones a_list = [a for a in a_list if len(a)] @@ -282,7 +301,7 @@ def uniq_merge(*lists): def clean_filename(fn): - for (k, v) in FN_REPLACEMENTS.iteritems(): + for (k, v) in FN_REPLACEMENTS.items(): fn = fn.replace(k, v) removals = [] for k in fn: @@ -296,14 +315,14 @@ def clean_filename(fn): def decomp_gzip(data, quiet=True): try: - buf = StringIO(str(data)) + buf = six.BytesIO(encode_text(data)) with contextlib.closing(gzip.GzipFile(None, "rb", 1, buf)) as gh: - return gh.read() + return decode_binary(gh.read()) except Exception as e: if quiet: return data else: - raise DecompressionError(str(e)) + raise DecompressionError(six.text_type(e)) def extract_usergroup(ug_pair): @@ -362,7 +381,7 @@ def multi_log(text, console=True, stderr=True, def load_json(text, root_types=(dict,)): - decoded = json.loads(text) + decoded = json.loads(decode_binary(text)) if not isinstance(decoded, tuple(root_types)): expected_types = ", ".join([str(t) for t in root_types]) raise TypeError("(%s) root types expected, got %s instead" @@ -394,7 +413,7 @@ def get_cfg_option_str(yobj, key, default=None): if key not in yobj: return default val = yobj[key] - if not isinstance(val, (str, basestring)): + if not isinstance(val, six.string_types): val = str(val) return val @@ -433,7 +452,7 @@ def get_cfg_option_list(yobj, key, default=None): if isinstance(val, (list)): cval = [v for v in val] return cval - if not isinstance(val, (basestring)): + if not isinstance(val, six.string_types): val = str(val) return [val] @@ -708,10 +727,10 @@ def read_file_or_url(url, timeout=5, retries=10, def load_yaml(blob, default=None, allowed=(dict,)): loaded = default + blob = decode_binary(blob) try: - blob = str(blob) - LOG.debug(("Attempting to load yaml from string " - "of length %s with allowed root types %s"), + LOG.debug("Attempting to load yaml from string " + "of length %s with allowed root types %s", len(blob), allowed) converted = safeyaml.load(blob) if not isinstance(converted, allowed): @@ -746,14 +765,12 @@ def read_seeded(base="", ext="", timeout=5, retries=10, file_retries=0): md_resp = read_file_or_url(md_url, timeout, retries, file_retries) md = None if md_resp.ok(): - md_str = str(md_resp) - md = load_yaml(md_str, default={}) + md = load_yaml(md_resp.contents, default={}) ud_resp = read_file_or_url(ud_url, timeout, retries, file_retries) ud = None if ud_resp.ok(): - ud_str = str(ud_resp) - ud = ud_str + ud = ud_resp.contents return (md, ud) @@ -784,7 +801,7 @@ def read_conf_with_confd(cfgfile): if "conf_d" in cfg: confd = cfg['conf_d'] if confd: - if not isinstance(confd, (str, basestring)): + if not isinstance(confd, six.string_types): raise TypeError(("Config file %s contains 'conf_d' " "with non-string type %s") % (cfgfile, type_utils.obj_name(confd))) @@ -921,8 +938,8 @@ def get_cmdline_url(names=('cloud-config-url', 'url'), return (None, None, None) resp = read_file_or_url(url) - if resp.contents.startswith(starts) and resp.ok(): - return (key, url, str(resp)) + if resp.ok() and resp.contents.startswith(starts): + return (key, url, resp.contents) return (key, url, None) @@ -1076,9 +1093,9 @@ def uniq_list(in_list): return out_list -def load_file(fname, read_cb=None, quiet=False): +def load_file(fname, read_cb=None, quiet=False, decode=True): LOG.debug("Reading from %s (quiet=%s)", fname, quiet) - ofh = StringIO() + ofh = six.BytesIO() try: with open(fname, 'rb') as ifh: pipe_in_out(ifh, ofh, chunk_cb=read_cb) @@ -1089,7 +1106,10 @@ def load_file(fname, read_cb=None, quiet=False): raise contents = ofh.getvalue() LOG.debug("Read %s bytes from %s", len(contents), fname) - return contents + if decode: + return decode_binary(contents) + else: + return contents def get_cmdline(): @@ -1219,7 +1239,7 @@ def logexc(log, msg, *args): def hash_blob(blob, routine, mlen=None): hasher = hashlib.new(routine) - hasher.update(blob) + hasher.update(encode_text(blob)) digest = hasher.hexdigest() # Don't get to long now if mlen is not None: @@ -1280,8 +1300,7 @@ def yaml_dumps(obj, explicit_start=True, explicit_end=True): indent=4, explicit_start=explicit_start, explicit_end=explicit_end, - default_flow_style=False, - allow_unicode=True) + default_flow_style=False) def ensure_dir(path, mode=None): @@ -1515,11 +1534,17 @@ def write_file(filename, content, mode=0o644, omode="wb"): @param filename: The full path of the file to write. @param content: The content to write to the file. @param mode: The filesystem mode to set on the file. - @param omode: The open mode used when opening the file (r, rb, a, etc.) + @param omode: The open mode used when opening the file (w, wb, a, etc.) """ ensure_dir(os.path.dirname(filename)) - LOG.debug("Writing to %s - %s: [%s] %s bytes", - filename, omode, mode, len(content)) + if 'b' in omode.lower(): + content = encode_text(content) + write_type = 'bytes' + else: + content = decode_binary(content) + write_type = 'characters' + LOG.debug("Writing to %s - %s: [%s] %s %s", + filename, omode, mode, len(content), write_type) with SeLinuxGuard(path=filename): with open(filename, omode) as fh: fh.write(content) @@ -1608,10 +1633,10 @@ def shellify(cmdlist, add_header=True): if isinstance(args, list): fixed = [] for f in args: - fixed.append("'%s'" % (str(f).replace("'", escaped))) + fixed.append("'%s'" % (six.text_type(f).replace("'", escaped))) content = "%s%s\n" % (content, ' '.join(fixed)) cmds_made += 1 - elif isinstance(args, (str, basestring)): + elif isinstance(args, six.string_types): content = "%s%s\n" % (content, args) cmds_made += 1 else: @@ -1722,7 +1747,7 @@ def expand_package_list(version_fmt, pkgs): pkglist = [] for pkg in pkgs: - if isinstance(pkg, basestring): + if isinstance(pkg, six.string_types): pkglist.append(pkg) continue diff --git a/packages/bddeb b/packages/bddeb index 9d264f92..83ca68bb 100755 --- a/packages/bddeb +++ b/packages/bddeb @@ -38,6 +38,7 @@ PKG_MP = { 'pyserial': 'python-serial', 'pyyaml': 'python-yaml', 'requests': 'python-requests', + 'six': 'python-six', } DEBUILD_ARGS = ["-S", "-d"] diff --git a/packages/brpm b/packages/brpm index 9657b1dd..72bfca08 100755 --- a/packages/brpm +++ b/packages/brpm @@ -45,6 +45,7 @@ PKG_MP = { 'pyserial': 'pyserial', 'pyyaml': 'PyYAML', 'requests': 'python-requests', + 'six': 'python-six', }, 'suse': { 'argparse': 'python-argparse', @@ -56,6 +57,7 @@ PKG_MP = { 'pyserial': 'python-pyserial', 'pyyaml': 'python-yaml', 'requests': 'python-requests', + 'six': 'python-six', } } diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 03296e62..a35afc27 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -1,11 +1,11 @@ """Tests for handling of userdata within cloud init.""" -import StringIO - import gzip import logging import os +from six import BytesIO, StringIO + from email.mime.application import MIMEApplication from email.mime.base import MIMEBase from email.mime.multipart import MIMEMultipart @@ -53,7 +53,7 @@ class TestConsumeUserData(helpers.FilesystemMockingTestCase): self.patchUtils(root) def capture_log(self, lvl=logging.DEBUG): - log_file = StringIO.StringIO() + log_file = StringIO() self._log_handler = logging.StreamHandler(log_file) self._log_handler.setLevel(lvl) self._log = log.getLogger() @@ -351,9 +351,9 @@ p: 1 """Tests that individual message gzip encoding works.""" def gzip_part(text): - contents = StringIO.StringIO() - f = gzip.GzipFile(fileobj=contents, mode='w') - f.write(str(text)) + contents = BytesIO() + f = gzip.GzipFile(fileobj=contents, mode='wb') + f.write(util.encode_text(text)) f.flush() f.close() return MIMEApplication(contents.getvalue(), 'gzip') diff --git a/tests/unittests/test_datasource/test_nocloud.py b/tests/unittests/test_datasource/test_nocloud.py index e9235951..ae9e6c22 100644 --- a/tests/unittests/test_datasource/test_nocloud.py +++ b/tests/unittests/test_datasource/test_nocloud.py @@ -85,7 +85,7 @@ class TestNoCloudDataSource(MockerTestCase): data = { 'fs_label': None, - 'meta-data': {'instance-id': 'IID'}, + 'meta-data': yaml.safe_dump({'instance-id': 'IID'}), 'user-data': "USER_DATA_RAW", } diff --git a/tests/unittests/test_datasource/test_openstack.py b/tests/unittests/test_datasource/test_openstack.py index 49894e51..81ef1546 100644 --- a/tests/unittests/test_datasource/test_openstack.py +++ b/tests/unittests/test_datasource/test_openstack.py @@ -20,12 +20,11 @@ import copy import json import re -from StringIO import StringIO - -from urlparse import urlparse - from .. import helpers as test_helpers +from six import StringIO +from six.moves.urllib.parse import urlparse + from cloudinit import helpers from cloudinit import settings from cloudinit.sources import DataSourceOpenStack as ds diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 33a1d6e1..6e1a0b69 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -4,6 +4,8 @@ import mocker import os +from six import StringIO + from cloudinit import distros from cloudinit import helpers from cloudinit import settings @@ -11,8 +13,6 @@ from cloudinit import util from cloudinit.distros.parsers.sys_conf import SysConf -from StringIO import StringIO - BASE_NET_CFG = ''' auto lo diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index 203dd2aa..f5832365 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -16,12 +16,12 @@ class TestAptProxyConfig(MockerTestCase): self.cfile = os.path.join(self.tmp, "config.cfg") def _search_apt_config(self, contents, ptype, value): - print( + ## print( + ## r"acquire::%s::proxy\s+[\"']%s[\"'];\n" % (ptype, value), + ## contents, "flags=re.IGNORECASE") + return re.search( r"acquire::%s::proxy\s+[\"']%s[\"'];\n" % (ptype, value), - contents, "flags=re.IGNORECASE") - return(re.search( - r"acquire::%s::proxy\s+[\"']%s[\"'];\n" % (ptype, value), - contents, flags=re.IGNORECASE)) + contents, flags=re.IGNORECASE) def test_apt_proxy_written(self): cfg = {'apt_proxy': 'myproxy'} diff --git a/tests/unittests/test_handler/test_handler_locale.py b/tests/unittests/test_handler/test_handler_locale.py index eb251636..690ef86f 100644 --- a/tests/unittests/test_handler/test_handler_locale.py +++ b/tests/unittests/test_handler/test_handler_locale.py @@ -29,7 +29,7 @@ from .. import helpers as t_help from configobj import ConfigObj -from StringIO import StringIO +from six import BytesIO import logging @@ -59,6 +59,6 @@ class TestLocale(t_help.FilesystemMockingTestCase): cc = self._get_cloud('sles') cc_locale.handle('cc_locale', cfg, cc, LOG, []) - contents = util.load_file('/etc/sysconfig/language') - n_cfg = ConfigObj(StringIO(contents)) + contents = util.load_file('/etc/sysconfig/language', decode=False) + n_cfg = ConfigObj(BytesIO(contents)) self.assertEquals({'RC_LANG': cfg['locale']}, dict(n_cfg)) diff --git a/tests/unittests/test_handler/test_handler_seed_random.py b/tests/unittests/test_handler/test_handler_seed_random.py index 40481f16..579377fb 100644 --- a/tests/unittests/test_handler/test_handler_seed_random.py +++ b/tests/unittests/test_handler/test_handler_seed_random.py @@ -22,7 +22,7 @@ import base64 import gzip import tempfile -from StringIO import StringIO +from six import StringIO from cloudinit import cloud from cloudinit import distros diff --git a/tests/unittests/test_handler/test_handler_set_hostname.py b/tests/unittests/test_handler/test_handler_set_hostname.py index e1530e30..a9f7829b 100644 --- a/tests/unittests/test_handler/test_handler_set_hostname.py +++ b/tests/unittests/test_handler/test_handler_set_hostname.py @@ -9,7 +9,7 @@ from .. import helpers as t_help import logging -from StringIO import StringIO +from six import BytesIO from configobj import ConfigObj @@ -38,8 +38,8 @@ class TestHostname(t_help.FilesystemMockingTestCase): cc_set_hostname.handle('cc_set_hostname', cfg, cc, LOG, []) if not distro.uses_systemd(): - contents = util.load_file("/etc/sysconfig/network") - n_cfg = ConfigObj(StringIO(contents)) + contents = util.load_file("/etc/sysconfig/network", decode=False) + n_cfg = ConfigObj(BytesIO(contents)) self.assertEquals({'HOSTNAME': 'blah.blah.blah.yahoo.com'}, dict(n_cfg)) diff --git a/tests/unittests/test_handler/test_handler_timezone.py b/tests/unittests/test_handler/test_handler_timezone.py index 874db340..10ea2040 100644 --- a/tests/unittests/test_handler/test_handler_timezone.py +++ b/tests/unittests/test_handler/test_handler_timezone.py @@ -29,7 +29,7 @@ from .. import helpers as t_help from configobj import ConfigObj -from StringIO import StringIO +from six import BytesIO import logging @@ -67,8 +67,8 @@ class TestTimezone(t_help.FilesystemMockingTestCase): cc_timezone.handle('cc_timezone', cfg, cc, LOG, []) - contents = util.load_file('/etc/sysconfig/clock') - n_cfg = ConfigObj(StringIO(contents)) + contents = util.load_file('/etc/sysconfig/clock', decode=False) + n_cfg = ConfigObj(BytesIO(contents)) self.assertEquals({'TIMEZONE': cfg['timezone']}, dict(n_cfg)) contents = util.load_file('/etc/localtime') diff --git a/tests/unittests/test_handler/test_handler_yum_add_repo.py b/tests/unittests/test_handler/test_handler_yum_add_repo.py index 435c9787..81806ad1 100644 --- a/tests/unittests/test_handler/test_handler_yum_add_repo.py +++ b/tests/unittests/test_handler/test_handler_yum_add_repo.py @@ -6,7 +6,7 @@ from .. import helpers import logging -from StringIO import StringIO +from six import BytesIO import configobj @@ -52,8 +52,9 @@ class TestConfig(helpers.FilesystemMockingTestCase): } self.patchUtils(self.tmp) cc_yum_add_repo.handle('yum_add_repo', cfg, None, LOG, []) - contents = util.load_file("/etc/yum.repos.d/epel_testing.repo") - contents = configobj.ConfigObj(StringIO(contents)) + contents = util.load_file("/etc/yum.repos.d/epel_testing.repo", + decode=False) + contents = configobj.ConfigObj(BytesIO(contents)) expected = { 'epel_testing': { 'name': 'Extra Packages for Enterprise Linux 5 - Testing', -- cgit v1.2.3 From 6363b546c8bacb24b9350d108165a09bb65828a1 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 11:38:51 -0500 Subject: Port test__init__.py to unittest.mock. --- tests/unittests/helpers.py | 19 +--- tests/unittests/test__init__.py | 239 +++++++++++++++++++--------------------- tests/unittests/test_util.py | 3 +- tox.ini | 2 +- 4 files changed, 118 insertions(+), 145 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 52305397..4ca460f1 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -2,11 +2,6 @@ import os import sys import unittest -from contextlib import contextmanager - -from mocker import Mocker -from mocker import MockerTestCase - from cloudinit import helpers as ch from cloudinit import util @@ -86,17 +81,6 @@ else: pass -@contextmanager -def mocker(verify_calls=True): - m = Mocker() - try: - yield m - finally: - m.restore() - if verify_calls: - m.verify() - - # Makes the old path start # with new base instead of whatever # it previously had @@ -126,9 +110,8 @@ def retarget_many_wrapper(new_base, am, old_func): return wrapper -class ResourceUsingTestCase(MockerTestCase): +class ResourceUsingTestCase(unittest.TestCase): def __init__(self, methodName="runTest"): - MockerTestCase.__init__(self, methodName) self.resource_path = None def resourceLocation(self, subname=None): diff --git a/tests/unittests/test__init__.py b/tests/unittests/test__init__.py index 48db1a5e..ce4704d8 100644 --- a/tests/unittests/test__init__.py +++ b/tests/unittests/test__init__.py @@ -1,10 +1,19 @@ import os - -from mocker import MockerTestCase, ARGS, KWARGS +import shutil +import tempfile +import unittest + +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack from cloudinit import handlers from cloudinit import helpers -from cloudinit import importer from cloudinit import settings from cloudinit import url_helper from cloudinit import util @@ -22,76 +31,76 @@ class FakeModule(handlers.Handler): pass -class TestWalkerHandleHandler(MockerTestCase): +class TestWalkerHandleHandler(unittest.TestCase): def setUp(self): - - MockerTestCase.setUp(self) + unittest.TestCase.setUp(self) + tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmpdir) self.data = { "handlercount": 0, "frequency": "", - "handlerdir": self.makeDir(), + "handlerdir": tmpdir, "handlers": helpers.ContentHandlers(), "data": None} self.expected_module_name = "part-handler-%03d" % ( self.data["handlercount"],) expected_file_name = "%s.py" % self.expected_module_name - expected_file_fullname = os.path.join(self.data["handlerdir"], - expected_file_name) + self.expected_file_fullname = os.path.join( + self.data["handlerdir"], expected_file_name) self.module_fake = FakeModule() self.ctype = None self.filename = None self.payload = "dummy payload" - # Mock the write_file function - write_file_mock = self.mocker.replace(util.write_file, - passthrough=False) - write_file_mock(expected_file_fullname, self.payload, 0o600) + # Mock the write_file() function. We'll assert that it got called as + # expected in each of the individual tests. + self.resources = ExitStack() + self.write_file_mock = self.resources.enter_context( + mock.patch('cloudinit.util.write_file')) + + def tearDown(self): + self.resources.close() + unittest.TestCase.tearDown(self) def test_no_errors(self): """Payload gets written to file and added to C{pdata}.""" - import_mock = self.mocker.replace(importer.import_module, - passthrough=False) - import_mock(self.expected_module_name) - self.mocker.result(self.module_fake) - self.mocker.replay() - - handlers.walker_handle_handler(self.data, self.ctype, self.filename, - self.payload) - - self.assertEqual(1, self.data["handlercount"]) + with mock.patch('cloudinit.importer.import_module', + return_value=self.module_fake) as mockobj: + handlers.walker_handle_handler(self.data, self.ctype, + self.filename, self.payload) + mockobj.assert_called_with_once(self.expected_module_name) + self.write_file_mock.assert_called_with_once( + self.expected_file_fullname, self.payload, 0o600) + self.assertEqual(self.data['handlercount'], 1) def test_import_error(self): """Module import errors are logged. No handler added to C{pdata}.""" - import_mock = self.mocker.replace(importer.import_module, - passthrough=False) - import_mock(self.expected_module_name) - self.mocker.throw(ImportError()) - self.mocker.replay() - - handlers.walker_handle_handler(self.data, self.ctype, self.filename, - self.payload) - - self.assertEqual(0, self.data["handlercount"]) + with mock.patch('cloudinit.importer.import_module', + side_effect=ImportError) as mockobj: + handlers.walker_handle_handler(self.data, self.ctype, + self.filename, self.payload) + mockobj.assert_called_with_once(self.expected_module_name) + self.write_file_mock.assert_called_with_once( + self.expected_file_fullname, self.payload, 0o600) + self.assertEqual(self.data['handlercount'], 0) def test_attribute_error(self): """Attribute errors are logged. No handler added to C{pdata}.""" - import_mock = self.mocker.replace(importer.import_module, - passthrough=False) - import_mock(self.expected_module_name) - self.mocker.result(self.module_fake) - self.mocker.throw(AttributeError()) - self.mocker.replay() - - handlers.walker_handle_handler(self.data, self.ctype, self.filename, - self.payload) + with mock.patch('cloudinit.importer.import_module', + side_effect=AttributeError, + return_value=self.module_fake) as mockobj: + handlers.walker_handle_handler(self.data, self.ctype, + self.filename, self.payload) + mockobj.assert_called_with_once(self.expected_module_name) + self.write_file_mock.assert_called_with_once( + self.expected_file_fullname, self.payload, 0o600) + self.assertEqual(self.data['handlercount'], 0) - self.assertEqual(0, self.data["handlercount"]) - -class TestHandlerHandlePart(MockerTestCase): +class TestHandlerHandlePart(unittest.TestCase): def setUp(self): self.data = "fake data" @@ -108,95 +117,80 @@ class TestHandlerHandlePart(MockerTestCase): C{handle_part} is called without C{frequency} for C{handler_version} == 1. """ - mod_mock = self.mocker.mock() - getattr(mod_mock, "frequency") - self.mocker.result(settings.PER_INSTANCE) - getattr(mod_mock, "handler_version") - self.mocker.result(1) - mod_mock.handle_part(self.data, self.ctype, self.filename, - self.payload) - self.mocker.replay() - - handlers.run_part(mod_mock, self.data, self.filename, - self.payload, self.frequency, self.headers) + mod_mock = mock.Mock(frequency=settings.PER_INSTANCE, + handler_version=1) + handlers.run_part(mod_mock, self.data, self.filename, self.payload, + self.frequency, self.headers) + # Assert that the handle_part() method of the mock object got + # called with the expected arguments. + mod_mock.handle_part.assert_called_with_once( + self.data, self.ctype, self.filename, self.payload) def test_normal_version_2(self): """ C{handle_part} is called with C{frequency} for C{handler_version} == 2. """ - mod_mock = self.mocker.mock() - getattr(mod_mock, "frequency") - self.mocker.result(settings.PER_INSTANCE) - getattr(mod_mock, "handler_version") - self.mocker.result(2) - mod_mock.handle_part(self.data, self.ctype, self.filename, - self.payload, self.frequency) - self.mocker.replay() - - handlers.run_part(mod_mock, self.data, self.filename, - self.payload, self.frequency, self.headers) + mod_mock = mock.Mock(frequency=settings.PER_INSTANCE, + handler_version=2) + handlers.run_part(mod_mock, self.data, self.filename, self.payload, + self.frequency, self.headers) + # Assert that the handle_part() method of the mock object got + # called with the expected arguments. + mod_mock.handle_part.assert_called_with_once( + self.data, self.ctype, self.filename, self.payload) def test_modfreq_per_always(self): """ C{handle_part} is called regardless of frequency if nofreq is always. """ self.frequency = "once" - mod_mock = self.mocker.mock() - getattr(mod_mock, "frequency") - self.mocker.result(settings.PER_ALWAYS) - getattr(mod_mock, "handler_version") - self.mocker.result(1) - mod_mock.handle_part(self.data, self.ctype, self.filename, - self.payload) - self.mocker.replay() - - handlers.run_part(mod_mock, self.data, self.filename, - self.payload, self.frequency, self.headers) + mod_mock = mock.Mock(frequency=settings.PER_ALWAYS, + handler_version=1) + handlers.run_part(mod_mock, self.data, self.filename, self.payload, + self.frequency, self.headers) + # Assert that the handle_part() method of the mock object got + # called with the expected arguments. + mod_mock.handle_part.assert_called_with_once( + self.data, self.ctype, self.filename, self.payload) def test_no_handle_when_modfreq_once(self): """C{handle_part} is not called if frequency is once.""" self.frequency = "once" - mod_mock = self.mocker.mock() - getattr(mod_mock, "frequency") - self.mocker.result(settings.PER_ONCE) - self.mocker.replay() - - handlers.run_part(mod_mock, self.data, self.filename, - self.payload, self.frequency, self.headers) + mod_mock = mock.Mock(frequency=settings.PER_ONCE) + handlers.run_part(mod_mock, self.data, self.filename, self.payload, + self.frequency, self.headers) + # Assert that the handle_part() method of the mock object got + # called with the expected arguments. + mod_mock.handle_part.assert_called_with_once( + self.data, self.ctype, self.filename, self.payload) def test_exception_is_caught(self): """Exceptions within C{handle_part} are caught and logged.""" - mod_mock = self.mocker.mock() - getattr(mod_mock, "frequency") - self.mocker.result(settings.PER_INSTANCE) - getattr(mod_mock, "handler_version") - self.mocker.result(1) - mod_mock.handle_part(self.data, self.ctype, self.filename, - self.payload) - self.mocker.throw(Exception()) - self.mocker.replay() - - handlers.run_part(mod_mock, self.data, self.filename, - self.payload, self.frequency, self.headers) - - -class TestCmdlineUrl(MockerTestCase): + mod_mock = mock.Mock(frequency=settings.PER_INSTANCE, + handler_version=1) + handlers.run_part(mod_mock, self.data, self.filename, self.payload, + self.frequency, self.headers) + mod_mock.handle_part.side_effect = Exception + handlers.run_part(mod_mock, self.data, self.filename, self.payload, + self.frequency, self.headers) + mod_mock.handle_part.assert_called_with_once( + self.data, self.ctype, self.filename, self.payload) + + +class TestCmdlineUrl(unittest.TestCase): def test_invalid_content(self): url = "http://example.com/foo" key = "mykey" payload = "0" cmdline = "ro %s=%s bar=1" % (key, url) - mock_readurl = self.mocker.replace(url_helper.readurl, - passthrough=False) - mock_readurl(url, ARGS, KWARGS) - self.mocker.result(url_helper.StringResponse(payload)) - self.mocker.replay() - - self.assertEqual((key, url, None), - util.get_cmdline_url(names=[key], starts="xxxxxx", - cmdline=cmdline)) + with mock.patch('cloudinit.url_helper.readurl', + return_value=url_helper.StringResponse(payload)): + self.assertEqual( + util.get_cmdline_url(names=[key], starts="xxxxxx", + cmdline=cmdline), + (key, url, None)) def test_valid_content(self): url = "http://example.com/foo" @@ -204,27 +198,24 @@ class TestCmdlineUrl(MockerTestCase): payload = "xcloud-config\nmydata: foo\nbar: wark\n" cmdline = "ro %s=%s bar=1" % (key, url) - mock_readurl = self.mocker.replace(url_helper.readurl, - passthrough=False) - mock_readurl(url, ARGS, KWARGS) - self.mocker.result(url_helper.StringResponse(payload)) - self.mocker.replay() - - self.assertEqual((key, url, payload), - util.get_cmdline_url(names=[key], starts="xcloud-config", - cmdline=cmdline)) + with mock.patch('cloudinit.url_helper.readurl', + return_value=url_helper.StringResponse(payload)): + self.assertEqual( + util.get_cmdline_url(names=[key], starts="xcloud-config", + cmdline=cmdline), + (key, url, payload)) def test_no_key_found(self): url = "http://example.com/foo" key = "mykey" cmdline = "ro %s=%s bar=1" % (key, url) - self.mocker.replace(url_helper.readurl, passthrough=False) - self.mocker.result(url_helper.StringResponse("")) - self.mocker.replay() + with mock.patch('cloudinit.url_helper.readurl', + return_value=url_helper.StringResponse('')): + self.assertEqual( + util.get_cmdline_url(names=["does-not-appear"], + starts="#cloud-config", cmdline=cmdline), + (None, None, None)) - self.assertEqual((None, None, None), - util.get_cmdline_url(names=["does-not-appear"], - starts="#cloud-config", cmdline=cmdline)) # vi: ts=4 expandtab diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 203445b7..ee938969 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -2,7 +2,6 @@ import os import stat import yaml -from mocker import MockerTestCase from . import helpers import unittest @@ -61,7 +60,7 @@ class TestGetCfgOptionListOrStr(unittest.TestCase): self.assertEqual([], result) -class TestWriteFile(MockerTestCase): +class TestWriteFile(unittest.TestCase): def setUp(self): super(TestWriteFile, self).setUp() self.tmp = self.makeDir(prefix="unittest_") diff --git a/tox.ini b/tox.ini index 883f6196..e547c693 100644 --- a/tox.ini +++ b/tox.ini @@ -5,9 +5,9 @@ recreate = True [testenv] commands = python -m nose tests deps = + contextlib2 httpretty>=0.7.1 mock - mocker nose pep8==1.5.7 pyflakes -- cgit v1.2.3 From 9a16a6a2e9929098474e0db23db130cf172aeba4 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 14:46:21 -0500 Subject: Use .addCleanup() instead of a .tearDown() where appropriate, although we might have to rewrite this for Python 2.6. Disable Cepko tests (test_cs_util.py) since they are essentially worthless. Convert test_azure to unittest.mock. --- tests/unittests/test__init__.py | 9 ++-- tests/unittests/test_cs_util.py | 25 ++++++----- tests/unittests/test_datasource/test_azure.py | 64 +++++++++++++-------------- tox.ini | 2 +- 4 files changed, 49 insertions(+), 51 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test__init__.py b/tests/unittests/test__init__.py index ce4704d8..f5dc3435 100644 --- a/tests/unittests/test__init__.py +++ b/tests/unittests/test__init__.py @@ -57,14 +57,11 @@ class TestWalkerHandleHandler(unittest.TestCase): # Mock the write_file() function. We'll assert that it got called as # expected in each of the individual tests. - self.resources = ExitStack() - self.write_file_mock = self.resources.enter_context( + resources = ExitStack() + self.addCleanup(resources.close) + self.write_file_mock = resources.enter_context( mock.patch('cloudinit.util.write_file')) - def tearDown(self): - self.resources.close() - unittest.TestCase.tearDown(self) - def test_no_errors(self): """Payload gets written to file and added to C{pdata}.""" with mock.patch('cloudinit.importer.import_module', diff --git a/tests/unittests/test_cs_util.py b/tests/unittests/test_cs_util.py index 7d59222b..99fac84d 100644 --- a/tests/unittests/test_cs_util.py +++ b/tests/unittests/test_cs_util.py @@ -1,4 +1,4 @@ -from mocker import MockerTestCase +import unittest from cloudinit.cs_utils import Cepko @@ -26,16 +26,21 @@ class CepkoMock(Cepko): return SERVER_CONTEXT['tags'] -class CepkoResultTests(MockerTestCase): +# 2015-01-22 BAW: This test is completely useless because it only ever tests +# the CepkoMock object. Even in its original form, I don't think it ever +# touched the underlying Cepko class methods. +@unittest.skip('This test is completely useless') +class CepkoResultTests(unittest.TestCase): def setUp(self): - self.mocked = self.mocker.replace("cloudinit.cs_utils.Cepko", - spec=CepkoMock, - count=False, - passthrough=False) - self.mocked() - self.mocker.result(CepkoMock()) - self.mocker.replay() - self.c = Cepko() + pass + ## self.mocked = self.mocker.replace("cloudinit.cs_utils.Cepko", + ## spec=CepkoMock, + ## count=False, + ## passthrough=False) + ## self.mocked() + ## self.mocker.result(CepkoMock()) + ## self.mocker.replay() + ## self.c = Cepko() def test_getitem(self): result = self.c.all() diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 6e007a95..2dbcd389 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -3,12 +3,23 @@ from cloudinit.util import load_file from cloudinit.sources import DataSourceAzure from ..helpers import populate_dir +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + import base64 import crypt -from mocker import MockerTestCase import os import stat import yaml +import shutil +import tempfile +import unittest def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): @@ -66,26 +77,24 @@ def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): return content -class TestAzureDataSource(MockerTestCase): +class TestAzureDataSource(unittest.TestCase): def setUp(self): - # makeDir comes from MockerTestCase - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) # patch cloud_dir, so our 'seed_dir' is guaranteed empty self.paths = helpers.Paths({'cloud_dir': self.tmp}) self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent') - self.unapply = [] - super(TestAzureDataSource, self).setUp() + self.patches = ExitStack() + self.addCleanup(self.patches.close) - def tearDown(self): - apply_patches([i for i in reversed(self.unapply)]) - super(TestAzureDataSource, self).tearDown() + super(TestAzureDataSource, self).setUp() def apply_patches(self, patches): - ret = apply_patches(patches) - self.unapply += ret + for module, name, new in patches: + self.patches.enter_context(mock.patch.object(module, name, new)) def _get_ds(self, data): @@ -117,16 +126,14 @@ class TestAzureDataSource(MockerTestCase): mod = DataSourceAzure mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d - self.apply_patches([(mod, 'list_possible_azure_ds_devs', dsdevs)]) - - self.apply_patches([(mod, 'invoke_agent', _invoke_agent), - (mod, 'wait_for_files', _wait_for_files), - (mod, 'pubkeys_from_crt_files', - _pubkeys_from_crt_files), - (mod, 'iid_from_shared_config', - _iid_from_shared_config), - (mod, 'apply_hostname_bounce', - _apply_hostname_bounce), ]) + self.apply_patches([ + (mod, 'list_possible_azure_ds_devs', dsdevs), + (mod, 'invoke_agent', _invoke_agent), + (mod, 'wait_for_files', _wait_for_files), + (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files), + (mod, 'iid_from_shared_config', _iid_from_shared_config), + (mod, 'apply_hostname_bounce', _apply_hostname_bounce), + ]) dsrc = mod.DataSourceAzureNet( data.get('sys_cfg', {}), distro=None, paths=self.paths) @@ -402,7 +409,7 @@ class TestAzureDataSource(MockerTestCase): load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) -class TestReadAzureOvf(MockerTestCase): +class TestReadAzureOvf(unittest.TestCase): def test_invalid_xml_raises_non_azure_ds(self): invalid_xml = "" + construct_valid_ovf_env(data={}) self.assertRaises(DataSourceAzure.BrokenAzureDataSource, @@ -417,7 +424,7 @@ class TestReadAzureOvf(MockerTestCase): self.assertIn(mypk, cfg['_pubkeys']) -class TestReadAzureSharedConfig(MockerTestCase): +class TestReadAzureSharedConfig(unittest.TestCase): def test_valid_content(self): xml = """ @@ -429,14 +436,3 @@ class TestReadAzureSharedConfig(MockerTestCase): """ ret = DataSourceAzure.iid_from_shared_config_content(xml) self.assertEqual("MY_INSTANCE_ID", ret) - - -def apply_patches(patches): - ret = [] - for (ref, name, replace) in patches: - if replace is None: - continue - orig = getattr(ref, name) - setattr(ref, name, replace) - ret.append((ref, name, orig)) - return ret diff --git a/tox.ini b/tox.ini index e547c693..d050d7c8 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,7 @@ envlist = py26,py27,py34 recreate = True [testenv] -commands = python -m nose tests +commands = python -m nose -v tests deps = contextlib2 httpretty>=0.7.1 -- cgit v1.2.3 From 6f3fadcfeec94711de32574f7738f98aa0a697ac Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 15:41:52 -0500 Subject: Convert helpers.py and test_data.py from mocker to mock. --- tests/unittests/helpers.py | 56 +++++++++++++--------- tests/unittests/test_data.py | 110 +++++++++++++++++++++++++------------------ tox.ini | 2 +- 3 files changed, 97 insertions(+), 71 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 4ca460f1..43dd5aa0 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -1,7 +1,17 @@ import os import sys +import tempfile import unittest +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + from cloudinit import helpers as ch from cloudinit import util @@ -111,7 +121,11 @@ def retarget_many_wrapper(new_base, am, old_func): class ResourceUsingTestCase(unittest.TestCase): - def __init__(self, methodName="runTest"): + ## def __init__(self, methodName="runTest"): + ## self.resource_path = None + + def setUp(self): + unittest.TestCase.setUp(self) self.resource_path = None def resourceLocation(self, subname=None): @@ -139,17 +153,23 @@ class ResourceUsingTestCase(unittest.TestCase): return fh.read() def getCloudPaths(self): + tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmpdir) cp = ch.Paths({ - 'cloud_dir': self.makeDir(), + 'cloud_dir': tmpdir, 'templates_dir': self.resourceLocation(), }) return cp class FilesystemMockingTestCase(ResourceUsingTestCase): - def __init__(self, methodName="runTest"): - ResourceUsingTestCase.__init__(self, methodName) - self.patched_funcs = [] + ## def __init__(self, methodName="runTest"): + ## ResourceUsingTestCase.__init__(self, methodName) + + def setUp(self): + ResourceUsingTestCase.setUp(self) + self.patched_funcs = ExitStack() + self.addCleanup(self.patched_funcs.close) def replicateTestRoot(self, example_root, target_root): real_root = self.resourceLocation() @@ -163,15 +183,6 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): make_path = util.abs_join(make_path, f) shutil.copy(real_path, make_path) - def tearDown(self): - self.restore() - ResourceUsingTestCase.tearDown(self) - - def restore(self): - for (mod, f, func) in self.patched_funcs: - setattr(mod, f, func) - self.patched_funcs = [] - def patchUtils(self, new_root): patch_funcs = { util: [('write_file', 1), @@ -188,8 +199,8 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): for (f, am) in funcs: func = getattr(mod, f) trap_func = retarget_many_wrapper(new_root, am, func) - setattr(mod, f, trap_func) - self.patched_funcs.append((mod, f, func)) + self.patched_funcs.enter_context( + mock.patch.object(mod, f, trap_func)) # Handle subprocess calls func = getattr(util, 'subp') @@ -197,16 +208,15 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): def nsubp(*_args, **_kwargs): return ('', '') - setattr(util, 'subp', nsubp) - self.patched_funcs.append((util, 'subp', func)) + self.patched_funcs.enter_context( + mock.patch.object(util, 'subp', nsubp)) def null_func(*_args, **_kwargs): return None for f in ['chownbyid', 'chownbyname']: - func = getattr(util, f) - setattr(util, f, null_func) - self.patched_funcs.append((util, f, func)) + self.patched_funcs.enter_context( + mock.patch.object(util, f, null_func)) def patchOS(self, new_root): patch_funcs = { @@ -217,8 +227,8 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): for f in funcs: func = getattr(mod, f) trap_func = retarget_many_wrapper(new_root, 1, func) - setattr(mod, f, trap_func) - self.patched_funcs.append((mod, f, func)) + self.patched_funcs.enter_context( + mock.patch.object(mod, f, trap_func)) class HttprettyTestCase(TestCase): diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index a35afc27..71e02e5d 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -3,6 +3,13 @@ import gzip import logging import os +import shutil +import tempfile + +try: + from unittest import mock +except ImportError: + import mock from six import BytesIO, StringIO @@ -43,9 +50,9 @@ class TestConsumeUserData(helpers.FilesystemMockingTestCase): self._log_handler = None def tearDown(self): - helpers.FilesystemMockingTestCase.tearDown(self) if self._log_handler and self._log: self._log.removeHandler(self._log_handler) + helpers.FilesystemMockingTestCase.tearDown(self) def _patchIn(self, root): self.restore() @@ -71,7 +78,8 @@ class TestConsumeUserData(helpers.FilesystemMockingTestCase): ci = stages.Init() ci.datasource = FakeDataSource(blob) - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self.patchUtils(new_root) self.patchOS(new_root) ci.fetch() @@ -99,7 +107,8 @@ class TestConsumeUserData(helpers.FilesystemMockingTestCase): { "op": "add", "path": "/foo", "value": "quxC" } ] ''' - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self._patchIn(new_root) initer = stages.Init() initer.datasource = FakeDataSource(user_blob, vendordata=vendor_blob) @@ -138,7 +147,8 @@ class TestConsumeUserData(helpers.FilesystemMockingTestCase): { "op": "add", "path": "/foo", "value": "quxC" } ] ''' - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self._patchIn(new_root) initer = stages.Init() initer.datasource = FakeDataSource(user_blob, vendordata=vendor_blob) @@ -184,7 +194,8 @@ c: d ci = stages.Init() ci.datasource = FakeDataSource(str(message)) - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self.patchUtils(new_root) self.patchOS(new_root) ci.fetch() @@ -214,7 +225,8 @@ name: user run: - z ''' - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self._patchIn(new_root) initer = stages.Init() initer.datasource = FakeDataSource(user_blob, vendordata=vendor_blob) @@ -249,7 +261,8 @@ vendor_data: enabled: True prefix: /bin/true ''' - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self._patchIn(new_root) initer = stages.Init() initer.datasource = FakeDataSource(user_blob, vendordata=vendor_blob) @@ -309,7 +322,8 @@ p: 1 paths = c_helpers.Paths({}, ds=FakeDataSource('')) cloud_cfg = handlers.cloud_config.CloudConfigPartHandler(paths) - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self.patchUtils(new_root) self.patchOS(new_root) cloud_cfg.handle_part(None, handlers.CONTENT_START, None, None, None, @@ -374,7 +388,8 @@ c: 4 message.attach(gzip_part(base_content2)) ci = stages.Init() ci.datasource = FakeDataSource(str(message)) - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self.patchUtils(new_root) self.patchOS(new_root) ci.fetch() @@ -394,17 +409,15 @@ c: 4 message.set_payload("Just text") ci.datasource = FakeDataSource(message.as_string()) - mock_write = self.mocker.replace("cloudinit.util.write_file", - passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) - self.mocker.replay() - - log_file = self.capture_log(logging.WARNING) - ci.fetch() - ci.consume_data() - self.assertIn( - "Unhandled unknown content-type (text/plain)", - log_file.getvalue()) + with mock.patch('cloudinit.util.write_file') as mockobj: + log_file = self.capture_log(logging.WARNING) + ci.fetch() + ci.consume_data() + self.assertIn( + "Unhandled unknown content-type (text/plain)", + log_file.getvalue()) + mockobj.assert_called_once_with( + ci.paths.get_ipath("cloud_config"), "", 0o600) def test_shellscript(self): """Raw text starting #!/bin/sh is treated as script.""" @@ -413,16 +426,17 @@ c: 4 ci.datasource = FakeDataSource(script) outpath = os.path.join(ci.paths.get_ipath_cur("scripts"), "part-001") - mock_write = self.mocker.replace("cloudinit.util.write_file", - passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) - mock_write(outpath, script, 0o700) - self.mocker.replay() - log_file = self.capture_log(logging.WARNING) - ci.fetch() - ci.consume_data() - self.assertEqual("", log_file.getvalue()) + with mock.patch('cloudinit.util.write_file') as mockobj: + log_file = self.capture_log(logging.WARNING) + ci.fetch() + ci.consume_data() + self.assertEqual("", log_file.getvalue()) + + mockobj.assert_has_calls([ + mock.call(outpath, script, 0o700), + mock.call(ci.paths.get_ipath("cloud_config"), "", 0o600), + ]) def test_mime_text_x_shellscript(self): """Mime message of type text/x-shellscript is treated as script.""" @@ -433,16 +447,17 @@ c: 4 ci.datasource = FakeDataSource(message.as_string()) outpath = os.path.join(ci.paths.get_ipath_cur("scripts"), "part-001") - mock_write = self.mocker.replace("cloudinit.util.write_file", - passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) - mock_write(outpath, script, 0o700) - self.mocker.replay() - log_file = self.capture_log(logging.WARNING) - ci.fetch() - ci.consume_data() - self.assertEqual("", log_file.getvalue()) + with mock.patch('cloudinit.util.write_file') as mockobj: + log_file = self.capture_log(logging.WARNING) + ci.fetch() + ci.consume_data() + self.assertEqual("", log_file.getvalue()) + + mockobj.assert_has_calls([ + mock.call(outpath, script, 0o700), + mock.call(ci.paths.get_ipath("cloud_config"), "", 0o600), + ]) def test_mime_text_plain_shell(self): """Mime type text/plain starting #!/bin/sh is treated as script.""" @@ -453,13 +468,14 @@ c: 4 ci.datasource = FakeDataSource(message.as_string()) outpath = os.path.join(ci.paths.get_ipath_cur("scripts"), "part-001") - mock_write = self.mocker.replace("cloudinit.util.write_file", - passthrough=False) - mock_write(outpath, script, 0o700) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) - self.mocker.replay() - log_file = self.capture_log(logging.WARNING) - ci.fetch() - ci.consume_data() - self.assertEqual("", log_file.getvalue()) + with mock.patch('cloudinit.util.write_file') as mockobj: + log_file = self.capture_log(logging.WARNING) + ci.fetch() + ci.consume_data() + self.assertEqual("", log_file.getvalue()) + + mockobj.assert_has_calls([ + mock.call(outpath, script, 0o700), + mock.call(ci.paths.get_ipath("cloud_config"), "", 0o600), + ]) diff --git a/tox.ini b/tox.ini index d050d7c8..e547c693 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,7 @@ envlist = py26,py27,py34 recreate = True [testenv] -commands = python -m nose -v tests +commands = python -m nose tests deps = contextlib2 httpretty>=0.7.1 -- cgit v1.2.3 From b3bbd3985a75d00115e1623dc1ebd4248a7024a6 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 15:42:30 -0500 Subject: Clean up. --- tests/unittests/helpers.py | 6 ------ 1 file changed, 6 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 43dd5aa0..38a2176d 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -121,9 +121,6 @@ def retarget_many_wrapper(new_base, am, old_func): class ResourceUsingTestCase(unittest.TestCase): - ## def __init__(self, methodName="runTest"): - ## self.resource_path = None - def setUp(self): unittest.TestCase.setUp(self) self.resource_path = None @@ -163,9 +160,6 @@ class ResourceUsingTestCase(unittest.TestCase): class FilesystemMockingTestCase(ResourceUsingTestCase): - ## def __init__(self, methodName="runTest"): - ## ResourceUsingTestCase.__init__(self, methodName) - def setUp(self): ResourceUsingTestCase.setUp(self) self.patched_funcs = ExitStack() -- cgit v1.2.3 From 508670bdaf9545b6bcc8e2009e8bd3f08d6f8796 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 18:38:30 -0500 Subject: More test ports from mocker to mock. --- tests/unittests/test_builtin_handlers.py | 38 +++++--- tests/unittests/test_data.py | 21 ++--- .../unittests/test_datasource/test_configdrive.py | 104 ++++++++++++--------- tests/unittests/test_datasource/test_maas.py | 75 +++++++++------ tests/unittests/test_datasource/test_smartos.py | 10 +- tests/unittests/test_distros/test_generic.py | 5 +- tests/unittests/test_handler/test_handler_chef.py | 5 +- tests/unittests/test_handler/test_handler_debug.py | 5 +- .../unittests/test_handler/test_handler_locale.py | 5 +- .../test_handler/test_handler_set_hostname.py | 5 +- .../test_handler/test_handler_timezone.py | 5 +- .../test_handler/test_handler_yum_add_repo.py | 5 +- tests/unittests/test_runs/test_merge_run.py | 5 +- tests/unittests/test_runs/test_simple_run.py | 5 +- tests/unittests/test_util.py | 5 +- 15 files changed, 186 insertions(+), 112 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_builtin_handlers.py b/tests/unittests/test_builtin_handlers.py index af7f442e..47ff6318 100644 --- a/tests/unittests/test_builtin_handlers.py +++ b/tests/unittests/test_builtin_handlers.py @@ -1,6 +1,13 @@ """Tests of the built-in user data handlers.""" import os +import shutil +import tempfile + +try: + from unittest import mock +except ImportError: + import mock from . import helpers as test_helpers @@ -16,8 +23,10 @@ from cloudinit.settings import (PER_ALWAYS, PER_INSTANCE) class TestBuiltins(test_helpers.FilesystemMockingTestCase): def test_upstart_frequency_no_out(self): - c_root = self.makeDir() - up_root = self.makeDir() + c_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, c_root) + up_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, up_root) paths = helpers.Paths({ 'cloud_dir': c_root, 'upstart_dir': up_root, @@ -36,7 +45,8 @@ class TestBuiltins(test_helpers.FilesystemMockingTestCase): def test_upstart_frequency_single(self): # files should be written out when frequency is ! per-instance - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) freq = PER_INSTANCE self.patchOS(new_root) @@ -49,16 +59,16 @@ class TestBuiltins(test_helpers.FilesystemMockingTestCase): util.ensure_dir("/run") util.ensure_dir("/etc/upstart") - mock_subp = self.mocker.replace(util.subp, passthrough=False) - mock_subp(["initctl", "reload-configuration"], capture=False) - self.mocker.replay() + with mock.patch.object(util, 'subp') as mockobj: + h = upstart_job.UpstartJobPartHandler(paths) + h.handle_part('', handlers.CONTENT_START, + None, None, None) + h.handle_part('blah', 'text/upstart-job', + 'test.conf', 'blah', freq) + h.handle_part('', handlers.CONTENT_END, + None, None, None) - h = upstart_job.UpstartJobPartHandler(paths) - h.handle_part('', handlers.CONTENT_START, - None, None, None) - h.handle_part('blah', 'text/upstart-job', - 'test.conf', 'blah', freq) - h.handle_part('', handlers.CONTENT_END, - None, None, None) + self.assertEquals(len(os.listdir('/etc/upstart')), 1) - self.assertEquals(1, len(os.listdir('/etc/upstart'))) + mockobj.assert_called_once_with( + ['initctl', 'reload-configuration'], capture=False) diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 71e02e5d..7598837e 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -55,7 +55,6 @@ class TestConsumeUserData(helpers.FilesystemMockingTestCase): helpers.FilesystemMockingTestCase.tearDown(self) def _patchIn(self, root): - self.restore() self.patchOS(root) self.patchUtils(root) @@ -349,17 +348,17 @@ p: 1 data = "arbitrary text\n" ci.datasource = FakeDataSource(data) - mock_write = self.mocker.replace("cloudinit.util.write_file", - passthrough=False) - mock_write(ci.paths.get_ipath("cloud_config"), "", 0o600) - self.mocker.replay() + with mock.patch('cloudinit.util.write_file') as mockobj: + log_file = self.capture_log(logging.WARNING) + ci.fetch() + ci.consume_data() + self.assertIn( + "Unhandled non-multipart (text/x-not-multipart) userdata:", + log_file.getvalue()) + + mockobj.assert_called_once_with( + ci.paths.get_ipath("cloud_config"), "", 0o600) - log_file = self.capture_log(logging.WARNING) - ci.fetch() - ci.consume_data() - self.assertIn( - "Unhandled non-multipart (text/x-not-multipart) userdata:", - log_file.getvalue()) def test_mime_gzip_compressed(self): """Tests that individual message gzip encoding works.""" diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index d88066e5..800c5fd8 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -1,10 +1,18 @@ from copy import copy import json import os -import os.path - -import mocker -from mocker import MockerTestCase +import shutil +import tempfile +import unittest + +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack from cloudinit import helpers from cloudinit import settings @@ -12,8 +20,6 @@ from cloudinit.sources import DataSourceConfigDrive as ds from cloudinit.sources.helpers import openstack from cloudinit import util -from .. import helpers as unit_helpers - PUBKEY = u'ssh-rsa AAAAB3NzaC1....sIkJhq8wdX+4I3A4cYbYP ubuntu@server-460\n' EC2_META = { 'ami-id': 'ami-00000001', @@ -64,11 +70,12 @@ CFG_DRIVE_FILES_V2 = { 'openstack/latest/user_data': USER_DATA} -class TestConfigDriveDataSource(MockerTestCase): +class TestConfigDriveDataSource(unittest.TestCase): def setUp(self): super(TestConfigDriveDataSource, self).setUp() - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def test_ec2_metadata(self): populate_dir(self.tmp, CFG_DRIVE_FILES_V2) @@ -91,23 +98,28 @@ class TestConfigDriveDataSource(MockerTestCase): 'swap': '/dev/vda3', } for name, dev_name in name_tests.items(): - with unit_helpers.mocker() as my_mock: - find_mock = my_mock.replace(util.find_devs_with, - spec=False, passthrough=False) + with ExitStack() as mocks: provided_name = dev_name[len('/dev/'):] provided_name = "s" + provided_name[1:] - find_mock(mocker.ARGS) - my_mock.result([provided_name]) - exists_mock = my_mock.replace(os.path.exists, - spec=False, passthrough=False) - exists_mock(mocker.ARGS) - my_mock.result(False) - exists_mock(mocker.ARGS) - my_mock.result(True) - my_mock.replay() + find_mock = mocks.enter_context( + mock.patch.object(util, 'find_devs_with', + return_value=[provided_name])) + # We want os.path.exists() to return False on its first call, + # and True on its second call. We use a handy generator as + # the mock side effect for this. The mocked function returns + # what the side effect returns. + def exists_side_effect(): + yield False + yield True + exists_mock = mocks.enter_context( + mock.patch.object(os.path, 'exists', + side_effect=exists_side_effect())) device = cfg_ds.device_name_to_device(name) self.assertEquals(dev_name, device) + find_mock.assert_called_once_with(mock.ANY) + self.assertEqual(exists_mock.call_count, 2) + def test_dev_os_map(self): populate_dir(self.tmp, CFG_DRIVE_FILES_V2) cfg_ds = ds.DataSourceConfigDrive(settings.CFG_BUILTIN, @@ -123,19 +135,19 @@ class TestConfigDriveDataSource(MockerTestCase): 'swap': '/dev/vda3', } for name, dev_name in name_tests.items(): - with unit_helpers.mocker() as my_mock: - find_mock = my_mock.replace(util.find_devs_with, - spec=False, passthrough=False) - find_mock(mocker.ARGS) - my_mock.result([dev_name]) - exists_mock = my_mock.replace(os.path.exists, - spec=False, passthrough=False) - exists_mock(mocker.ARGS) - my_mock.result(True) - my_mock.replay() + with ExitStack() as mocks: + find_mock = mocks.enter_context( + mock.patch.object(util, 'find_devs_with', + return_value=[dev_name])) + exists_mock = mocks.enter_context( + mock.patch.object(os.path, 'exists', + return_value=True)) device = cfg_ds.device_name_to_device(name) self.assertEquals(dev_name, device) + find_mock.assert_called_once_with(mock.ANY) + exists_mock.assert_called_once_with(mock.ANY) + def test_dev_ec2_remap(self): populate_dir(self.tmp, CFG_DRIVE_FILES_V2) cfg_ds = ds.DataSourceConfigDrive(settings.CFG_BUILTIN, @@ -156,16 +168,21 @@ class TestConfigDriveDataSource(MockerTestCase): 'root2k': None, } for name, dev_name in name_tests.items(): - with unit_helpers.mocker(verify_calls=False) as my_mock: - exists_mock = my_mock.replace(os.path.exists, - spec=False, passthrough=False) - exists_mock(mocker.ARGS) - my_mock.result(False) - exists_mock(mocker.ARGS) - my_mock.result(True) - my_mock.replay() + # We want os.path.exists() to return False on its first call, + # and True on its second call. We use a handy generator as + # the mock side effect for this. The mocked function returns + # what the side effect returns. + def exists_side_effect(): + yield False + yield True + with mock.patch.object(os.path, 'exists', + side_effect=exists_side_effect()): device = cfg_ds.device_name_to_device(name) self.assertEquals(dev_name, device) + # We don't assert the call count for os.path.exists() because + # not all of the entries in name_tests results in two calls to + # that function. Specifically, 'root2k' doesn't seem to call + # it at all. def test_dev_ec2_map(self): populate_dir(self.tmp, CFG_DRIVE_FILES_V2) @@ -173,12 +190,6 @@ class TestConfigDriveDataSource(MockerTestCase): None, helpers.Paths({})) found = ds.read_config_drive(self.tmp) - exists_mock = self.mocker.replace(os.path.exists, - spec=False, passthrough=False) - exists_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result(True) - self.mocker.replay() ec2_md = found['ec2-metadata'] os_md = found['metadata'] cfg_ds.ec2_metadata = ec2_md @@ -193,8 +204,9 @@ class TestConfigDriveDataSource(MockerTestCase): 'root2k': None, } for name, dev_name in name_tests.items(): - device = cfg_ds.device_name_to_device(name) - self.assertEquals(dev_name, device) + with mock.patch.object(os.path, 'exists', return_value=True): + device = cfg_ds.device_name_to_device(name) + self.assertEquals(dev_name, device) def test_dir_valid(self): """Verify a dir is read as such.""" diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index c157beb8..6af0cd82 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -1,19 +1,26 @@ from copy import copy import os +import shutil +import tempfile +import unittest from cloudinit.sources import DataSourceMAAS from cloudinit import url_helper from ..helpers import populate_dir -import mocker +try: + from unittest import mock +except ImportError: + import mock -class TestMAASDataSource(mocker.MockerTestCase): +class TestMAASDataSource(unittest.TestCase): def setUp(self): super(TestMAASDataSource, self).setUp() # Make a temp directoy for tests to use. - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def test_seed_dir_valid(self): """Verify a valid seeddir is read as such.""" @@ -93,16 +100,18 @@ class TestMAASDataSource(mocker.MockerTestCase): def test_seed_url_valid(self): """Verify that valid seed_url is read as such.""" - valid = {'meta-data/instance-id': 'i-instanceid', + valid = { + 'meta-data/instance-id': 'i-instanceid', 'meta-data/local-hostname': 'test-hostname', 'meta-data/public-keys': 'test-hostname', - 'user-data': 'foodata'} + 'user-data': 'foodata', + } valid_order = [ 'meta-data/local-hostname', 'meta-data/instance-id', 'meta-data/public-keys', 'user-data', - ] + ] my_seed = "http://example.com/xmeta" my_ver = "1999-99-99" my_headers = {'header1': 'value1', 'header2': 'value2'} @@ -110,28 +119,38 @@ class TestMAASDataSource(mocker.MockerTestCase): def my_headers_cb(url): return my_headers - mock_request = self.mocker.replace(url_helper.readurl, - passthrough=False) - - for key in valid_order: - url = "%s/%s/%s" % (my_seed, my_ver, key) - mock_request(url, headers=None, timeout=mocker.ANY, - data=mocker.ANY, sec_between=mocker.ANY, - ssl_details=mocker.ANY, retries=mocker.ANY, - headers_cb=my_headers_cb, - exception_cb=mocker.ANY) - resp = valid.get(key) - self.mocker.result(url_helper.StringResponse(resp)) - self.mocker.replay() - - (userdata, metadata) = DataSourceMAAS.read_maas_seed_url(my_seed, - header_cb=my_headers_cb, version=my_ver) - - self.assertEqual("foodata", userdata) - self.assertEqual(metadata['instance-id'], - valid['meta-data/instance-id']) - self.assertEqual(metadata['local-hostname'], - valid['meta-data/local-hostname']) + # Each time url_helper.readurl() is called, something different is + # returned based on the canned data above. We need to build up a list + # of side effect return values, which the mock will return. At the + # same time, we'll build up a list of expected call arguments for + # asserting after the code under test is run. + calls = [] + + def side_effect(): + for key in valid_order: + resp = valid.get(key) + url = "%s/%s/%s" % (my_seed, my_ver, key) + calls.append( + mock.call(url, headers=None, timeout=mock.ANY, + data=mock.ANY, sec_between=mock.ANY, + ssl_details=mock.ANY, retries=mock.ANY, + headers_cb=my_headers_cb, + exception_cb=mock.ANY)) + yield url_helper.StringResponse(resp) + + # Now do the actual call of the code under test. + with mock.patch.object(url_helper, 'readurl', + side_effect=side_effect()) as mockobj: + userdata, metadata = DataSourceMAAS.read_maas_seed_url( + my_seed, header_cb=my_headers_cb, version=my_ver) + + self.assertEqual("foodata", userdata) + self.assertEqual(metadata['instance-id'], + valid['meta-data/instance-id']) + self.assertEqual(metadata['local-hostname'], + valid['meta-data/local-hostname']) + + mockobj.has_calls(calls) def test_seed_url_invalid(self): """Verify that invalid seed_url raises MAASSeedDirMalformed.""" diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 65675106..35d7ef5e 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -29,9 +29,12 @@ from .. import helpers import os import os.path import re +import shutil +import tempfile import stat import uuid + MOCK_RETURNS = { 'hostname': 'test-host', 'root_authorized_keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname', @@ -109,9 +112,10 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): def setUp(self): helpers.FilesystemMockingTestCase.setUp(self) - # makeDir comes from MockerTestCase - self.tmp = self.makeDir() - self.legacy_user_d = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + self.legacy_user_d = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.legacy_user_d) # If you should want to watch the logs... self._log = None diff --git a/tests/unittests/test_distros/test_generic.py b/tests/unittests/test_distros/test_generic.py index db6aa0e8..2c85cbdb 100644 --- a/tests/unittests/test_distros/test_generic.py +++ b/tests/unittests/test_distros/test_generic.py @@ -4,6 +4,8 @@ from cloudinit import util from .. import helpers import os +import shutil +import tempfile unknown_arch_info = { 'arches': ['default'], @@ -53,7 +55,8 @@ class TestGenericDistro(helpers.FilesystemMockingTestCase): def setUp(self): super(TestGenericDistro, self).setUp() # Make a temp directoy for tests to use. - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def _write_load_sudoers(self, _user, rules): cls = distros.fetch("ubuntu") diff --git a/tests/unittests/test_handler/test_handler_chef.py b/tests/unittests/test_handler/test_handler_chef.py index ef1aa208..b06a160c 100644 --- a/tests/unittests/test_handler/test_handler_chef.py +++ b/tests/unittests/test_handler/test_handler_chef.py @@ -12,6 +12,8 @@ from cloudinit.sources import DataSourceNone from .. import helpers as t_help import logging +import shutil +import tempfile LOG = logging.getLogger(__name__) @@ -19,7 +21,8 @@ LOG = logging.getLogger(__name__) class TestChef(t_help.FilesystemMockingTestCase): def setUp(self): super(TestChef, self).setUp() - self.tmp = self.makeDir(prefix="unittest_") + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def fetch_cloud(self, distro_kind): cls = distros.fetch(distro_kind) diff --git a/tests/unittests/test_handler/test_handler_debug.py b/tests/unittests/test_handler/test_handler_debug.py index 8891ca04..80708d7b 100644 --- a/tests/unittests/test_handler/test_handler_debug.py +++ b/tests/unittests/test_handler/test_handler_debug.py @@ -26,6 +26,8 @@ from cloudinit.sources import DataSourceNone from .. import helpers as t_help import logging +import shutil +import tempfile LOG = logging.getLogger(__name__) @@ -33,7 +35,8 @@ LOG = logging.getLogger(__name__) class TestDebug(t_help.FilesystemMockingTestCase): def setUp(self): super(TestDebug, self).setUp() - self.new_root = self.makeDir(prefix="unittest_") + self.new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.new_root) def _get_cloud(self, distro, metadata=None): self.patchUtils(self.new_root) diff --git a/tests/unittests/test_handler/test_handler_locale.py b/tests/unittests/test_handler/test_handler_locale.py index 690ef86f..de85eff6 100644 --- a/tests/unittests/test_handler/test_handler_locale.py +++ b/tests/unittests/test_handler/test_handler_locale.py @@ -32,6 +32,8 @@ from configobj import ConfigObj from six import BytesIO import logging +import shutil +import tempfile LOG = logging.getLogger(__name__) @@ -39,7 +41,8 @@ LOG = logging.getLogger(__name__) class TestLocale(t_help.FilesystemMockingTestCase): def setUp(self): super(TestLocale, self).setUp() - self.new_root = self.makeDir(prefix="unittest_") + self.new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.new_root) def _get_cloud(self, distro): self.patchUtils(self.new_root) diff --git a/tests/unittests/test_handler/test_handler_set_hostname.py b/tests/unittests/test_handler/test_handler_set_hostname.py index a9f7829b..d358b069 100644 --- a/tests/unittests/test_handler/test_handler_set_hostname.py +++ b/tests/unittests/test_handler/test_handler_set_hostname.py @@ -7,6 +7,8 @@ from cloudinit import util from .. import helpers as t_help +import shutil +import tempfile import logging from six import BytesIO @@ -19,7 +21,8 @@ LOG = logging.getLogger(__name__) class TestHostname(t_help.FilesystemMockingTestCase): def setUp(self): super(TestHostname, self).setUp() - self.tmp = self.makeDir(prefix="unittest_") + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def _fetch_distro(self, kind): cls = distros.fetch(kind) diff --git a/tests/unittests/test_handler/test_handler_timezone.py b/tests/unittests/test_handler/test_handler_timezone.py index 10ea2040..e3df8759 100644 --- a/tests/unittests/test_handler/test_handler_timezone.py +++ b/tests/unittests/test_handler/test_handler_timezone.py @@ -31,6 +31,8 @@ from configobj import ConfigObj from six import BytesIO +import shutil +import tempfile import logging LOG = logging.getLogger(__name__) @@ -39,7 +41,8 @@ LOG = logging.getLogger(__name__) class TestTimezone(t_help.FilesystemMockingTestCase): def setUp(self): super(TestTimezone, self).setUp() - self.new_root = self.makeDir(prefix="unittest_") + self.new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.new_root) def _get_cloud(self, distro): self.patchUtils(self.new_root) diff --git a/tests/unittests/test_handler/test_handler_yum_add_repo.py b/tests/unittests/test_handler/test_handler_yum_add_repo.py index 81806ad1..3a8aa7c1 100644 --- a/tests/unittests/test_handler/test_handler_yum_add_repo.py +++ b/tests/unittests/test_handler/test_handler_yum_add_repo.py @@ -4,6 +4,8 @@ from cloudinit.config import cc_yum_add_repo from .. import helpers +import shutil +import tempfile import logging from six import BytesIO @@ -16,7 +18,8 @@ LOG = logging.getLogger(__name__) class TestConfig(helpers.FilesystemMockingTestCase): def setUp(self): super(TestConfig, self).setUp() - self.tmp = self.makeDir(prefix="unittest_") + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def test_bad_config(self): cfg = { diff --git a/tests/unittests/test_runs/test_merge_run.py b/tests/unittests/test_runs/test_merge_run.py index 977adb34..2d920eb8 100644 --- a/tests/unittests/test_runs/test_merge_run.py +++ b/tests/unittests/test_runs/test_merge_run.py @@ -1,4 +1,6 @@ import os +import shutil +import tempfile from .. import helpers @@ -14,7 +16,8 @@ class TestMergeRun(helpers.FilesystemMockingTestCase): self.patchUtils(root) def test_none_ds(self): - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self.replicateTestRoot('simple_ubuntu', new_root) cfg = { 'datasource_list': ['None'], diff --git a/tests/unittests/test_runs/test_simple_run.py b/tests/unittests/test_runs/test_simple_run.py index 2d51a337..0279b8b0 100644 --- a/tests/unittests/test_runs/test_simple_run.py +++ b/tests/unittests/test_runs/test_simple_run.py @@ -1,4 +1,6 @@ import os +import shutil +import tempfile from .. import helpers @@ -33,7 +35,8 @@ class TestSimpleRun(helpers.FilesystemMockingTestCase): self._patchIn(root) def test_none_ds(self): - new_root = self.makeDir() + new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, new_root) self.replicateTestRoot('simple_ubuntu', new_root) cfg = { 'datasource_list': ['None'], diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index ee938969..5ac47b80 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -1,6 +1,8 @@ import os import stat import yaml +import shutil +import tempfile from . import helpers import unittest @@ -63,7 +65,8 @@ class TestGetCfgOptionListOrStr(unittest.TestCase): class TestWriteFile(unittest.TestCase): def setUp(self): super(TestWriteFile, self).setUp() - self.tmp = self.makeDir(prefix="unittest_") + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def test_basic_usage(self): """Verify basic usage with default args.""" -- cgit v1.2.3 From 8329ad5bad5d6b0fcb8f4736fb308914ca3600c7 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 19:53:04 -0500 Subject: More conversions from mocker to mock. --- tests/unittests/test_datasource/test_nocloud.py | 52 ++-- tests/unittests/test_datasource/test_opennebula.py | 14 +- tests/unittests/test_distros/test_generic.py | 1 - tests/unittests/test_distros/test_hostname.py | 4 +- tests/unittests/test_distros/test_hosts.py | 4 +- tests/unittests/test_distros/test_netconfig.py | 265 +++++++++------------ 6 files changed, 151 insertions(+), 189 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_nocloud.py b/tests/unittests/test_datasource/test_nocloud.py index ae9e6c22..480a4012 100644 --- a/tests/unittests/test_datasource/test_nocloud.py +++ b/tests/unittests/test_datasource/test_nocloud.py @@ -3,33 +3,38 @@ from cloudinit.sources import DataSourceNoCloud from cloudinit import util from ..helpers import populate_dir -from mocker import MockerTestCase import os import yaml +import shutil +import tempfile +import unittest +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack -class TestNoCloudDataSource(MockerTestCase): + +class TestNoCloudDataSource(unittest.TestCase): def setUp(self): - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) self.paths = helpers.Paths({'cloud_dir': self.tmp}) self.cmdline = "root=TESTCMDLINE" - self.unapply = [] - self.apply_patches([(util, 'get_cmdline', self._getcmdline)]) - super(TestNoCloudDataSource, self).setUp() - - def tearDown(self): - apply_patches([i for i in reversed(self.unapply)]) - super(TestNoCloudDataSource, self).tearDown() + self.mocks = ExitStack() + self.addCleanup(self.mocks.close) - def apply_patches(self, patches): - ret = apply_patches(patches) - self.unapply += ret + self.mocks.enter_context( + mock.patch.object(util, 'get_cmdline', return_value=self.cmdline)) - def _getcmdline(self): - return self.cmdline + super(TestNoCloudDataSource, self).setUp() def test_nocloud_seed_dir(self): md = {'instance-id': 'IID', 'dsmode': 'local'} @@ -59,7 +64,9 @@ class TestNoCloudDataSource(MockerTestCase): def my_find_devs_with(*args, **kwargs): raise PsuedoException - self.apply_patches([(util, 'find_devs_with', my_find_devs_with)]) + self.mocks.enter_context( + mock.patch.object(util, 'find_devs_with', + side_effect=PsuedoException)) # by default, NoCloud should search for filesystems by label sys_cfg = {'datasource': {'NoCloud': {}}} @@ -133,7 +140,7 @@ class TestNoCloudDataSource(MockerTestCase): self.assertTrue(ret) -class TestParseCommandLineData(MockerTestCase): +class TestParseCommandLineData(unittest.TestCase): def test_parse_cmdline_data_valid(self): ds_id = "ds=nocloud" @@ -178,15 +185,4 @@ class TestParseCommandLineData(MockerTestCase): self.assertFalse(ret) -def apply_patches(patches): - ret = [] - for (ref, name, replace) in patches: - if replace is None: - continue - orig = getattr(ref, name) - setattr(ref, name, replace) - ret.append((ref, name, orig)) - return ret - - # vi: ts=4 expandtab diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index b4fd1f4d..ddf77265 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -1,12 +1,15 @@ from cloudinit import helpers from cloudinit.sources import DataSourceOpenNebula as ds from cloudinit import util -from mocker import MockerTestCase from ..helpers import populate_dir from base64 import b64encode import os import pwd +import shutil +import tempfile +import unittest + TEST_VARS = { 'VAR1': 'single', @@ -37,12 +40,13 @@ CMD_IP_OUT = '''\ ''' -class TestOpenNebulaDataSource(MockerTestCase): +class TestOpenNebulaDataSource(unittest.TestCase): parsed_user = None def setUp(self): super(TestOpenNebulaDataSource, self).setUp() - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) self.paths = helpers.Paths({'cloud_dir': self.tmp}) # defaults for few tests @@ -228,7 +232,7 @@ class TestOpenNebulaDataSource(MockerTestCase): util.find_devs_with = orig_find_devs_with -class TestOpenNebulaNetwork(MockerTestCase): +class TestOpenNebulaNetwork(unittest.TestCase): def setUp(self): super(TestOpenNebulaNetwork, self).setUp() @@ -280,7 +284,7 @@ iface eth0 inet static ''') -class TestParseShellConfig(MockerTestCase): +class TestParseShellConfig(unittest.TestCase): def test_no_seconds(self): cfg = '\n'.join(["foo=bar", "SECONDS=2", "xx=foo"]) # we could test 'sleep 2', but that would make the test run slower. diff --git a/tests/unittests/test_distros/test_generic.py b/tests/unittests/test_distros/test_generic.py index 2c85cbdb..35153f0d 100644 --- a/tests/unittests/test_distros/test_generic.py +++ b/tests/unittests/test_distros/test_generic.py @@ -67,7 +67,6 @@ class TestGenericDistro(helpers.FilesystemMockingTestCase): self.patchUtils(self.tmp) d.write_sudo_rules("harlowja", rules) contents = util.load_file(d.ci_sudoers_fn) - self.restore() return contents def _count_in(self, lines_look_for, text_content): diff --git a/tests/unittests/test_distros/test_hostname.py b/tests/unittests/test_distros/test_hostname.py index 8e644f4d..143e29a9 100644 --- a/tests/unittests/test_distros/test_hostname.py +++ b/tests/unittests/test_distros/test_hostname.py @@ -1,4 +1,4 @@ -from mocker import MockerTestCase +import unittest from cloudinit.distros.parsers import hostname @@ -12,7 +12,7 @@ blahblah BASE_HOSTNAME = BASE_HOSTNAME.strip() -class TestHostnameHelper(MockerTestCase): +class TestHostnameHelper(unittest.TestCase): def test_parse_same(self): hn = hostname.HostnameConf(BASE_HOSTNAME) self.assertEquals(str(hn).strip(), BASE_HOSTNAME) diff --git a/tests/unittests/test_distros/test_hosts.py b/tests/unittests/test_distros/test_hosts.py index 687a0dab..fc701eaa 100644 --- a/tests/unittests/test_distros/test_hosts.py +++ b/tests/unittests/test_distros/test_hosts.py @@ -1,4 +1,4 @@ -from mocker import MockerTestCase +import unittest from cloudinit.distros.parsers import hosts @@ -14,7 +14,7 @@ BASE_ETC = ''' BASE_ETC = BASE_ETC.strip() -class TestHostsHelper(MockerTestCase): +class TestHostsHelper(unittest.TestCase): def test_parse(self): eh = hosts.HostsConf(BASE_ETC) self.assertEquals(eh.get_entry('127.0.0.1'), [['localhost']]) diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 6e1a0b69..91e630ae 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -1,8 +1,14 @@ -from mocker import MockerTestCase - -import mocker - import os +import unittest + +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack from six import StringIO @@ -74,7 +80,7 @@ class WriteBuffer(object): return self.buffer.getvalue() -class TestNetCfgDistro(MockerTestCase): +class TestNetCfgDistro(unittest.TestCase): def _get_distro(self, dname): cls = distros.fetch(dname) @@ -85,34 +91,28 @@ class TestNetCfgDistro(MockerTestCase): def test_simple_write_ub(self): ub_distro = self._get_distro('ubuntu') - util_mock = self.mocker.replace(util.write_file, - spec=False, passthrough=False) - exists_mock = self.mocker.replace(os.path.isfile, - spec=False, passthrough=False) + with ExitStack() as mocks: + write_bufs = {} - exists_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result(False) + def replace_write(filename, content, mode=0o644, omode="wb"): + buf = WriteBuffer() + buf.mode = mode + buf.omode = omode + buf.write(content) + write_bufs[filename] = buf - write_bufs = {} - - def replace_write(filename, content, mode=0o644, omode="wb"): - buf = WriteBuffer() - buf.mode = mode - buf.omode = omode - buf.write(content) - write_bufs[filename] = buf + mocks.enter_context( + mock.patch.object(util, 'write_file', replace_write)) + mocks.enter_context( + mock.patch.object(os.path, 'isfile', return_value=False)) - util_mock(mocker.ARGS) - self.mocker.call(replace_write) - self.mocker.replay() - ub_distro.apply_network(BASE_NET_CFG, False) + ub_distro.apply_network(BASE_NET_CFG, False) - self.assertEquals(len(write_bufs), 1) - self.assertIn('/etc/network/interfaces', write_bufs) - write_buf = write_bufs['/etc/network/interfaces'] - self.assertEquals(str(write_buf).strip(), BASE_NET_CFG.strip()) - self.assertEquals(write_buf.mode, 0o644) + self.assertEquals(len(write_bufs), 1) + self.assertIn('/etc/network/interfaces', write_bufs) + write_buf = write_bufs['/etc/network/interfaces'] + self.assertEquals(str(write_buf).strip(), BASE_NET_CFG.strip()) + self.assertEquals(write_buf.mode, 0o644) def assertCfgEquals(self, blob1, blob2): b1 = dict(SysConf(blob1.strip().splitlines())) @@ -127,12 +127,6 @@ class TestNetCfgDistro(MockerTestCase): def test_simple_write_rh(self): rh_distro = self._get_distro('rhel') - write_mock = self.mocker.replace(util.write_file, - spec=False, passthrough=False) - load_mock = self.mocker.replace(util.load_file, - spec=False, passthrough=False) - exists_mock = self.mocker.replace(os.path.isfile, - spec=False, passthrough=False) write_bufs = {} @@ -143,37 +137,31 @@ class TestNetCfgDistro(MockerTestCase): buf.write(content) write_bufs[filename] = buf - exists_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result(False) - - load_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result('') - - for _i in range(0, 3): - write_mock(mocker.ARGS) - self.mocker.call(replace_write) - - write_mock(mocker.ARGS) - self.mocker.call(replace_write) - - self.mocker.replay() - rh_distro.apply_network(BASE_NET_CFG, False) - - self.assertEquals(len(write_bufs), 4) - self.assertIn('/etc/sysconfig/network-scripts/ifcfg-lo', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-lo'] - expected_buf = ''' + with ExitStack() as mocks: + mocks.enter_context( + mock.patch.object(util, 'write_file', replace_write)) + mocks.enter_context( + mock.patch.object(util, 'load_file', return_value='')) + mocks.enter_context( + mock.patch.object(os.path, 'isfile', return_value=False)) + + rh_distro.apply_network(BASE_NET_CFG, False) + + self.assertEquals(len(write_bufs), 4) + self.assertIn('/etc/sysconfig/network-scripts/ifcfg-lo', + write_bufs) + write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-lo'] + expected_buf = ''' DEVICE="lo" ONBOOT=yes ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) - self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0'] - expected_buf = ''' + self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0', + write_bufs) + write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0'] + expected_buf = ''' DEVICE="eth0" BOOTPROTO="static" NETMASK="255.255.255.0" @@ -182,36 +170,31 @@ ONBOOT=yes GATEWAY="192.168.1.254" BROADCAST="192.168.1.0" ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) - self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1'] - expected_buf = ''' + self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1', + write_bufs) + write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1'] + expected_buf = ''' DEVICE="eth1" BOOTPROTO="dhcp" ONBOOT=yes ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) - self.assertIn('/etc/sysconfig/network', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network'] - expected_buf = ''' + self.assertIn('/etc/sysconfig/network', write_bufs) + write_buf = write_bufs['/etc/sysconfig/network'] + expected_buf = ''' # Created by cloud-init v. 0.7 NETWORKING=yes ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) def test_write_ipv6_rhel(self): rh_distro = self._get_distro('rhel') - write_mock = self.mocker.replace(util.write_file, - spec=False, passthrough=False) - load_mock = self.mocker.replace(util.load_file, - spec=False, passthrough=False) - exists_mock = self.mocker.replace(os.path.isfile, - spec=False, passthrough=False) write_bufs = {} @@ -222,37 +205,31 @@ NETWORKING=yes buf.write(content) write_bufs[filename] = buf - exists_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result(False) - - load_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result('') - - for _i in range(0, 3): - write_mock(mocker.ARGS) - self.mocker.call(replace_write) - - write_mock(mocker.ARGS) - self.mocker.call(replace_write) - - self.mocker.replay() - rh_distro.apply_network(BASE_NET_CFG_IPV6, False) - - self.assertEquals(len(write_bufs), 4) - self.assertIn('/etc/sysconfig/network-scripts/ifcfg-lo', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-lo'] - expected_buf = ''' + with ExitStack() as mocks: + mocks.enter_context( + mock.patch.object(util, 'write_file', replace_write)) + mocks.enter_context( + mock.patch.object(util, 'load_file', return_value='')) + mocks.enter_context( + mock.patch.object(os.path, 'isfile', return_value=False)) + + rh_distro.apply_network(BASE_NET_CFG_IPV6, False) + + self.assertEquals(len(write_bufs), 4) + self.assertIn('/etc/sysconfig/network-scripts/ifcfg-lo', + write_bufs) + write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-lo'] + expected_buf = ''' DEVICE="lo" ONBOOT=yes ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) - self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0'] - expected_buf = ''' + self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0', + write_bufs) + write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0'] + expected_buf = ''' DEVICE="eth0" BOOTPROTO="static" NETMASK="255.255.255.0" @@ -264,11 +241,12 @@ IPV6INIT=yes IPV6ADDR="2607:f0d0:1002:0011::2" IPV6_DEFAULTGW="2607:f0d0:1002:0011::1" ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) - self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1'] - expected_buf = ''' + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) + self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1', + write_bufs) + write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1'] + expected_buf = ''' DEVICE="eth1" BOOTPROTO="static" NETMASK="255.255.255.0" @@ -280,38 +258,22 @@ IPV6INIT=yes IPV6ADDR="2607:f0d0:1002:0011::3" IPV6_DEFAULTGW="2607:f0d0:1002:0011::1" ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) - self.assertIn('/etc/sysconfig/network', write_bufs) - write_buf = write_bufs['/etc/sysconfig/network'] - expected_buf = ''' + self.assertIn('/etc/sysconfig/network', write_bufs) + write_buf = write_bufs['/etc/sysconfig/network'] + expected_buf = ''' # Created by cloud-init v. 0.7 NETWORKING=yes NETWORKING_IPV6=yes IPV6_AUTOCONF=no ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) def test_simple_write_freebsd(self): fbsd_distro = self._get_distro('freebsd') - util_mock = self.mocker.replace(util.write_file, - spec=False, passthrough=False) - exists_mock = self.mocker.replace(os.path.isfile, - spec=False, passthrough=False) - load_mock = self.mocker.replace(util.load_file, - spec=False, passthrough=False) - subp_mock = self.mocker.replace(util.subp, - spec=False, passthrough=False) - - subp_mock(['ifconfig', '-a']) - self.mocker.count(0, None) - self.mocker.result(('vtnet0', '')) - - exists_mock(mocker.ARGS) - self.mocker.count(0, None) - self.mocker.result(False) write_bufs = {} read_bufs = { @@ -336,23 +298,24 @@ IPV6_AUTOCONF=no return str(write_bufs[fname]) return read_bufs[fname] - util_mock(mocker.ARGS) - self.mocker.call(replace_write) - self.mocker.count(0, None) - - load_mock(mocker.ARGS) - self.mocker.call(replace_read) - self.mocker.count(0, None) - - self.mocker.replay() - fbsd_distro.apply_network(BASE_NET_CFG, False) - - self.assertIn('/etc/rc.conf', write_bufs) - write_buf = write_bufs['/etc/rc.conf'] - expected_buf = ''' + with ExitStack() as mocks: + mocks.enter_context( + mock.patch.object(util, 'subp', return_value=('vtnet0', ''))) + mocks.enter_context( + mock.patch.object(os.path, 'exists', return_value=False)) + mocks.enter_context( + mock.patch.object(util, 'write_file', replace_write)) + mocks.enter_context( + mock.patch.object(util, 'load_file', replace_read)) + + fbsd_distro.apply_network(BASE_NET_CFG, False) + + self.assertIn('/etc/rc.conf', write_bufs) + write_buf = write_bufs['/etc/rc.conf'] + expected_buf = ''' ifconfig_vtnet0="192.168.1.5 netmask 255.255.255.0" ifconfig_vtnet1="DHCP" defaultrouter="192.168.1.254" ''' - self.assertCfgEquals(expected_buf, str(write_buf)) - self.assertEquals(write_buf.mode, 0o644) + self.assertCfgEquals(expected_buf, str(write_buf)) + self.assertEquals(write_buf.mode, 0o644) -- cgit v1.2.3 From c0aae445119252134e8c89b3b73999ed213135f1 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 20:30:10 -0500 Subject: More conversions from mocker to mock. --- tests/unittests/test_distros/test_resolv.py | 5 +- tests/unittests/test_distros/test_sysconfig.py | 5 +- .../test_distros/test_user_data_normalize.py | 4 +- .../test_handler/test_handler_apt_configure.py | 10 +- .../test_handler/test_handler_ca_certs.py | 253 ++++++++++++--------- 5 files changed, 155 insertions(+), 122 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_distros/test_resolv.py b/tests/unittests/test_distros/test_resolv.py index 6b6ff6aa..779b83e3 100644 --- a/tests/unittests/test_distros/test_resolv.py +++ b/tests/unittests/test_distros/test_resolv.py @@ -1,8 +1,7 @@ -from mocker import MockerTestCase - from cloudinit.distros.parsers import resolv_conf import re +import unittest BASE_RESOLVE = ''' @@ -14,7 +13,7 @@ nameserver 10.15.30.92 BASE_RESOLVE = BASE_RESOLVE.strip() -class TestResolvHelper(MockerTestCase): +class TestResolvHelper(unittest.TestCase): def test_parse_same(self): rp = resolv_conf.ResolvConf(BASE_RESOLVE) rp_r = str(rp).strip() diff --git a/tests/unittests/test_distros/test_sysconfig.py b/tests/unittests/test_distros/test_sysconfig.py index 0c651407..f66201b3 100644 --- a/tests/unittests/test_distros/test_sysconfig.py +++ b/tests/unittests/test_distros/test_sysconfig.py @@ -1,6 +1,5 @@ -from mocker import MockerTestCase - import re +import unittest from cloudinit.distros.parsers.sys_conf import SysConf @@ -8,7 +7,7 @@ from cloudinit.distros.parsers.sys_conf import SysConf # Lots of good examples @ # http://content.hccfl.edu/pollock/AUnix1/SysconfigFilesDesc.txt -class TestSysConfHelper(MockerTestCase): +class TestSysConfHelper(unittest.TestCase): # This function was added in 2.7, make it work for 2.6 def assertRegMatches(self, text, regexp): regexp = re.compile(regexp) diff --git a/tests/unittests/test_distros/test_user_data_normalize.py b/tests/unittests/test_distros/test_user_data_normalize.py index 50398c74..b90d6185 100644 --- a/tests/unittests/test_distros/test_user_data_normalize.py +++ b/tests/unittests/test_distros/test_user_data_normalize.py @@ -1,4 +1,4 @@ -from mocker import MockerTestCase +import unittest from cloudinit import distros from cloudinit import helpers @@ -15,7 +15,7 @@ bcfg = { } -class TestUGNormalize(MockerTestCase): +class TestUGNormalize(unittest.TestCase): def _make_distro(self, dtype, def_user=None): cfg = dict(settings.CFG_BUILTIN) diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index f5832365..2c3dad72 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -1,17 +1,19 @@ -from mocker import MockerTestCase - from cloudinit import util from cloudinit.config import cc_apt_configure import os import re +import shutil +import tempfile +import unittest -class TestAptProxyConfig(MockerTestCase): +class TestAptProxyConfig(unittest.TestCase): def setUp(self): super(TestAptProxyConfig, self).setUp() - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) self.pfile = os.path.join(self.tmp, "proxy.cfg") self.cfile = os.path.join(self.tmp, "config.cfg") diff --git a/tests/unittests/test_handler/test_handler_ca_certs.py b/tests/unittests/test_handler/test_handler_ca_certs.py index 7fe47b74..97213a0c 100644 --- a/tests/unittests/test_handler/test_handler_ca_certs.py +++ b/tests/unittests/test_handler/test_handler_ca_certs.py @@ -1,5 +1,3 @@ -from mocker import MockerTestCase - from cloudinit import cloud from cloudinit import helpers from cloudinit import util @@ -7,9 +5,21 @@ from cloudinit import util from cloudinit.config import cc_ca_certs import logging +import shutil +import tempfile +import unittest + +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack -class TestNoConfig(MockerTestCase): +class TestNoConfig(unittest.TestCase): def setUp(self): super(TestNoConfig, self).setUp() self.name = "ca-certs" @@ -22,15 +32,20 @@ class TestNoConfig(MockerTestCase): Test that nothing is done if no ca-certs configuration is provided. """ config = util.get_builtin_cfg() - self.mocker.replace(util.write_file, passthrough=False) - self.mocker.replace(cc_ca_certs.update_ca_certs, passthrough=False) - self.mocker.replay() + with ExitStack() as mocks: + util_mock = mocks.enter_context( + mock.patch.object(util, 'write_file')) + certs_mock = mocks.enter_context( + mock.patch.object(cc_ca_certs, 'update_ca_certs')) - cc_ca_certs.handle(self.name, config, self.cloud_init, self.log, - self.args) + cc_ca_certs.handle(self.name, config, self.cloud_init, self.log, + self.args) + self.assertEqual(util_mock.call_count, 0) + self.assertEqual(certs_mock.call_count, 0) -class TestConfig(MockerTestCase): + +class TestConfig(unittest.TestCase): def setUp(self): super(TestConfig, self).setUp() self.name = "ca-certs" @@ -39,16 +54,16 @@ class TestConfig(MockerTestCase): self.log = logging.getLogger("TestNoConfig") self.args = [] - # Mock out the functions that actually modify the system - self.mock_add = self.mocker.replace(cc_ca_certs.add_ca_certs, - passthrough=False) - self.mock_update = self.mocker.replace(cc_ca_certs.update_ca_certs, - passthrough=False) - self.mock_remove = self.mocker.replace( - cc_ca_certs.remove_default_ca_certs, passthrough=False) + self.mocks = ExitStack() + self.addCleanup(self.mocks.close) - # Order must be correct - self.mocker.order() + # Mock out the functions that actually modify the system + self.mock_add = self.mocks.enter_context( + mock.patch.object(cc_ca_certs, 'add_ca_certs')) + self.mock_update = self.mocks.enter_context( + mock.patch.object(cc_ca_certs, 'update_ca_certs')) + self.mock_remove = self.mocks.enter_context( + mock.patch.object(cc_ca_certs, 'remove_default_ca_certs')) def test_no_trusted_list(self): """ @@ -57,86 +72,88 @@ class TestConfig(MockerTestCase): """ config = {"ca-certs": {}} - # No functions should be called - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.assertEqual(self.mock_add.call_count, 0) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 0) + def test_empty_trusted_list(self): """Test that no certificate are written if 'trusted' list is empty.""" config = {"ca-certs": {"trusted": []}} - # No functions should be called - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.assertEqual(self.mock_add.call_count, 0) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 0) + def test_single_trusted(self): """Test that a single cert gets passed to add_ca_certs.""" config = {"ca-certs": {"trusted": ["CERT1"]}} - self.mock_add(["CERT1"]) - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.mock_add.assert_called_once_with(['CERT1']) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 0) + def test_multiple_trusted(self): """Test that multiple certs get passed to add_ca_certs.""" config = {"ca-certs": {"trusted": ["CERT1", "CERT2"]}} - self.mock_add(["CERT1", "CERT2"]) - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.mock_add.assert_called_once_with(['CERT1', 'CERT2']) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 0) + def test_remove_default_ca_certs(self): """Test remove_defaults works as expected.""" config = {"ca-certs": {"remove-defaults": True}} - self.mock_remove() - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.assertEqual(self.mock_add.call_count, 0) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 1) + def test_no_remove_defaults_if_false(self): """Test remove_defaults is not called when config value is False.""" config = {"ca-certs": {"remove-defaults": False}} - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.assertEqual(self.mock_add.call_count, 0) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 0) + def test_correct_order_for_remove_then_add(self): """Test remove_defaults is not called when config value is False.""" config = {"ca-certs": {"remove-defaults": True, "trusted": ["CERT1"]}} - self.mock_remove() - self.mock_add(["CERT1"]) - self.mock_update() - self.mocker.replay() - cc_ca_certs.handle(self.name, config, self.cloud, self.log, self.args) + self.mock_add.assert_called_once_with(['CERT1']) + self.assertEqual(self.mock_update.call_count, 1) + self.assertEqual(self.mock_remove.call_count, 1) -class TestAddCaCerts(MockerTestCase): + +class TestAddCaCerts(unittest.TestCase): def setUp(self): super(TestAddCaCerts, self).setUp() + tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmpdir) self.paths = helpers.Paths({ - 'cloud_dir': self.makeDir() + 'cloud_dir': tmpdir, }) def test_no_certs_in_list(self): """Test that no certificate are written if not provided.""" - self.mocker.replace(util.write_file, passthrough=False) - self.mocker.replay() - cc_ca_certs.add_ca_certs([]) + with mock.patch.object(util, 'write_file') as mockobj: + cc_ca_certs.add_ca_certs([]) + self.assertEqual(mockobj.call_count, 0) def test_single_cert_trailing_cr(self): """Test adding a single certificate to the trusted CAs @@ -146,19 +163,21 @@ class TestAddCaCerts(MockerTestCase): ca_certs_content = "line1\nline2\ncloud-init-ca-certs.crt\nline3\n" expected = "line1\nline2\nline3\ncloud-init-ca-certs.crt\n" - mock_write = self.mocker.replace(util.write_file, passthrough=False) - mock_load = self.mocker.replace(util.load_file, passthrough=False) - - mock_write("/usr/share/ca-certificates/cloud-init-ca-certs.crt", - cert, mode=0o644) - - mock_load("/etc/ca-certificates.conf") - self.mocker.result(ca_certs_content) + with ExitStack() as mocks: + mock_write = mocks.enter_context( + mock.patch.object(util, 'write_file')) + mock_load = mocks.enter_context( + mock.patch.object(util, 'load_file', + return_value=ca_certs_content)) - mock_write("/etc/ca-certificates.conf", expected, omode="wb") - self.mocker.replay() + cc_ca_certs.add_ca_certs([cert]) - cc_ca_certs.add_ca_certs([cert]) + mock_write.assert_has_calls([ + mock.call("/usr/share/ca-certificates/cloud-init-ca-certs.crt", + cert, mode=0o644), + mock.call("/etc/ca-certificates.conf", expected, omode="wb"), + ]) + mock_load.assert_called_once_with("/etc/ca-certificates.conf") def test_single_cert_no_trailing_cr(self): """Test adding a single certificate to the trusted CAs @@ -167,75 +186,89 @@ class TestAddCaCerts(MockerTestCase): ca_certs_content = "line1\nline2\nline3" - mock_write = self.mocker.replace(util.write_file, passthrough=False) - mock_load = self.mocker.replace(util.load_file, passthrough=False) + with ExitStack() as mocks: + mock_write = mocks.enter_context( + mock.patch.object(util, 'write_file')) + mock_load = mocks.enter_context( + mock.patch.object(util, 'load_file', + return_value=ca_certs_content)) - mock_write("/usr/share/ca-certificates/cloud-init-ca-certs.crt", - cert, mode=0o644) + cc_ca_certs.add_ca_certs([cert]) - mock_load("/etc/ca-certificates.conf") - self.mocker.result(ca_certs_content) + mock_write.assert_has_calls([ + mock.call("/usr/share/ca-certificates/cloud-init-ca-certs.crt", + cert, mode=0o644), + mock.call("/etc/ca-certificates.conf", + "%s\n%s\n" % (ca_certs_content, + "cloud-init-ca-certs.crt"), + omode="wb"), + ]) - mock_write("/etc/ca-certificates.conf", - "%s\n%s\n" % (ca_certs_content, "cloud-init-ca-certs.crt"), - omode="wb") - self.mocker.replay() - - cc_ca_certs.add_ca_certs([cert]) + mock_load.assert_called_once_with("/etc/ca-certificates.conf") def test_multiple_certs(self): """Test adding multiple certificates to the trusted CAs.""" certs = ["CERT1\nLINE2\nLINE3", "CERT2\nLINE2\nLINE3"] expected_cert_file = "\n".join(certs) - - mock_write = self.mocker.replace(util.write_file, passthrough=False) - mock_load = self.mocker.replace(util.load_file, passthrough=False) - - mock_write("/usr/share/ca-certificates/cloud-init-ca-certs.crt", - expected_cert_file, mode=0o644) - ca_certs_content = "line1\nline2\nline3" - mock_load("/etc/ca-certificates.conf") - self.mocker.result(ca_certs_content) - out = "%s\n%s\n" % (ca_certs_content, "cloud-init-ca-certs.crt") - mock_write("/etc/ca-certificates.conf", out, omode="wb") + with ExitStack() as mocks: + mock_write = mocks.enter_context( + mock.patch.object(util, 'write_file')) + mock_load = mocks.enter_context( + mock.patch.object(util, 'load_file', + return_value=ca_certs_content)) - self.mocker.replay() + cc_ca_certs.add_ca_certs(certs) - cc_ca_certs.add_ca_certs(certs) + mock_write.assert_has_calls([ + mock.call("/usr/share/ca-certificates/cloud-init-ca-certs.crt", + expected_cert_file, mode=0o644), + mock.call("/etc/ca-certificates.conf", + "%s\n%s\n" % (ca_certs_content, + "cloud-init-ca-certs.crt"), + omode='wb'), + ]) + mock_load.assert_called_once_with("/etc/ca-certificates.conf") -class TestUpdateCaCerts(MockerTestCase): - def test_commands(self): - mock_check_call = self.mocker.replace(util.subp, - passthrough=False) - mock_check_call(["update-ca-certificates"], capture=False) - self.mocker.replay() - cc_ca_certs.update_ca_certs() +class TestUpdateCaCerts(unittest.TestCase): + def test_commands(self): + with mock.patch.object(util, 'subp') as mockobj: + cc_ca_certs.update_ca_certs() + mockobj.assert_called_once_with( + ["update-ca-certificates"], capture=False) -class TestRemoveDefaultCaCerts(MockerTestCase): +class TestRemoveDefaultCaCerts(unittest.TestCase): def setUp(self): super(TestRemoveDefaultCaCerts, self).setUp() + tmpdir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, tmpdir) self.paths = helpers.Paths({ - 'cloud_dir': self.makeDir() + 'cloud_dir': tmpdir, }) def test_commands(self): - mock_delete_dir_contents = self.mocker.replace( - util.delete_dir_contents, passthrough=False) - mock_write = self.mocker.replace(util.write_file, passthrough=False) - mock_subp = self.mocker.replace(util.subp, - passthrough=False) - - mock_delete_dir_contents("/usr/share/ca-certificates/") - mock_delete_dir_contents("/etc/ssl/certs/") - mock_write("/etc/ca-certificates.conf", "", mode=0o644) - mock_subp(('debconf-set-selections', '-'), - "ca-certificates ca-certificates/trust_new_crts select no") - self.mocker.replay() - - cc_ca_certs.remove_default_ca_certs() + with ExitStack() as mocks: + mock_delete = mocks.enter_context( + mock.patch.object(util, 'delete_dir_contents')) + mock_write = mocks.enter_context( + mock.patch.object(util, 'write_file')) + mock_subp = mocks.enter_context(mock.patch.object(util, 'subp')) + + cc_ca_certs.remove_default_ca_certs() + + mock_delete.assert_has_calls([ + mock.call("/usr/share/ca-certificates/"), + mock.call("/etc/ssl/certs/"), + ]) + + mock_write.assert_called_once_with( + "/etc/ca-certificates.conf", "", mode=0o644) + + mock_subp.assert_called_once_with( + ('debconf-set-selections', '-'), + "ca-certificates ca-certificates/trust_new_crts select no") -- cgit v1.2.3 From 6f2a62c2fde85839ed437549597498a707f5da68 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 20:52:01 -0500 Subject: Conversion from mocker to mock completed. --- .../test_handler/test_handler_growpart.py | 107 +++++++++++---------- tests/unittests/test_pathprefix2dict.py | 10 +- tests/unittests/test_runs/test_merge_run.py | 3 +- tests/unittests/test_runs/test_simple_run.py | 4 +- tests/unittests/test_util.py | 31 +++--- 5 files changed, 86 insertions(+), 69 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_handler/test_handler_growpart.py b/tests/unittests/test_handler/test_handler_growpart.py index 3056320d..89727863 100644 --- a/tests/unittests/test_handler/test_handler_growpart.py +++ b/tests/unittests/test_handler/test_handler_growpart.py @@ -1,5 +1,3 @@ -from mocker import MockerTestCase - from cloudinit import cloud from cloudinit import util @@ -9,6 +7,16 @@ import errno import logging import os import re +import unittest + +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack # growpart: # mode: auto # off, on, auto, 'growpart' @@ -42,7 +50,7 @@ growpart disk partition """ -class TestDisabled(MockerTestCase): +class TestDisabled(unittest.TestCase): def setUp(self): super(TestDisabled, self).setUp() self.name = "growpart" @@ -57,14 +65,14 @@ class TestDisabled(MockerTestCase): # this really only verifies that resizer_factory isn't called config = {'growpart': {'mode': 'off'}} - self.mocker.replace(cc_growpart.resizer_factory, - passthrough=False) - self.mocker.replay() - self.handle(self.name, config, self.cloud_init, self.log, self.args) + with mock.patch.object(cc_growpart, 'resizer_factory') as mockobj: + self.handle(self.name, config, self.cloud_init, self.log, + self.args) + self.assertEqual(mockobj.call_count, 0) -class TestConfig(MockerTestCase): +class TestConfig(unittest.TestCase): def setUp(self): super(TestConfig, self).setUp() self.name = "growpart" @@ -77,69 +85,70 @@ class TestConfig(MockerTestCase): self.cloud_init = None self.handle = cc_growpart.handle - # Order must be correct - self.mocker.order() - def test_no_resizers_auto_is_fine(self): - subp = self.mocker.replace(util.subp, passthrough=False) - subp(['growpart', '--help'], env={'LANG': 'C'}) - self.mocker.result((HELP_GROWPART_NO_RESIZE, "")) - self.mocker.replay() + with mock.patch.object( + util, 'subp', + return_value=(HELP_GROWPART_NO_RESIZE, "")) as mockobj: - config = {'growpart': {'mode': 'auto'}} - self.handle(self.name, config, self.cloud_init, self.log, self.args) + config = {'growpart': {'mode': 'auto'}} + self.handle(self.name, config, self.cloud_init, self.log, + self.args) + + mockobj.assert_called_once_with( + ['growpart', '--help'], env={'LANG': 'C'}) def test_no_resizers_mode_growpart_is_exception(self): - subp = self.mocker.replace(util.subp, passthrough=False) - subp(['growpart', '--help'], env={'LANG': 'C'}) - self.mocker.result((HELP_GROWPART_NO_RESIZE, "")) - self.mocker.replay() + with mock.patch.object( + util, 'subp', + return_value=(HELP_GROWPART_NO_RESIZE, "")) as mockobj: + config = {'growpart': {'mode': "growpart"}} + self.assertRaises( + ValueError, self.handle, self.name, config, + self.cloud_init, self.log, self.args) - config = {'growpart': {'mode': "growpart"}} - self.assertRaises(ValueError, self.handle, self.name, config, - self.cloud_init, self.log, self.args) + mockobj.assert_called_once_with( + ['growpart', '--help'], env={'LANG': 'C'}) def test_mode_auto_prefers_growpart(self): - subp = self.mocker.replace(util.subp, passthrough=False) - subp(['growpart', '--help'], env={'LANG': 'C'}) - self.mocker.result((HELP_GROWPART_RESIZE, "")) - self.mocker.replay() + with mock.patch.object( + util, 'subp', + return_value=(HELP_GROWPART_RESIZE, "")) as mockobj: + ret = cc_growpart.resizer_factory(mode="auto") + self.assertIsInstance(ret, cc_growpart.ResizeGrowPart) - ret = cc_growpart.resizer_factory(mode="auto") - self.assertTrue(isinstance(ret, cc_growpart.ResizeGrowPart)) + mockobj.assert_called_once_with( + ['growpart', '--help'], env={'LANG': 'C'}) def test_handle_with_no_growpart_entry(self): # if no 'growpart' entry in config, then mode=auto should be used myresizer = object() + retval = (("/", cc_growpart.RESIZE.CHANGED, "my-message",),) + + with ExitStack() as mocks: + factory = mocks.enter_context( + mock.patch.object(cc_growpart, 'resizer_factory', + return_value=myresizer)) + rsdevs = mocks.enter_context( + mock.patch.object(cc_growpart, 'resize_devices', + return_value=retval)) + mocks.enter_context( + mock.patch.object(cc_growpart, 'RESIZERS', + (('mysizer', object),) + )) - factory = self.mocker.replace(cc_growpart.resizer_factory, - passthrough=False) - rsdevs = self.mocker.replace(cc_growpart.resize_devices, - passthrough=False) - factory("auto") - self.mocker.result(myresizer) - rsdevs(myresizer, ["/"]) - self.mocker.result((("/", cc_growpart.RESIZE.CHANGED, "my-message",),)) - self.mocker.replay() - - try: - orig_resizers = cc_growpart.RESIZERS - cc_growpart.RESIZERS = (('mysizer', object),) self.handle(self.name, {}, self.cloud_init, self.log, self.args) - finally: - cc_growpart.RESIZERS = orig_resizers + factory.assert_called_once_with('auto') + rsdevs.assert_called_once_with(myresizer, ['/']) -class TestResize(MockerTestCase): + +class TestResize(unittest.TestCase): def setUp(self): super(TestResize, self).setUp() self.name = "growpart" self.log = logging.getLogger("TestResize") - # Order must be correct - self.mocker.order() - def test_simple_devices(self): # test simple device list # this patches out devent2dev, os.stat, and device_part_info diff --git a/tests/unittests/test_pathprefix2dict.py b/tests/unittests/test_pathprefix2dict.py index 590c4b82..38a56dc2 100644 --- a/tests/unittests/test_pathprefix2dict.py +++ b/tests/unittests/test_pathprefix2dict.py @@ -1,13 +1,17 @@ from cloudinit import util -from mocker import MockerTestCase from .helpers import populate_dir +import shutil +import tempfile +import unittest -class TestPathPrefix2Dict(MockerTestCase): + +class TestPathPrefix2Dict(unittest.TestCase): def setUp(self): - self.tmp = self.makeDir() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def test_required_only(self): dirdata = {'f1': 'f1content', 'f2': 'f2content'} diff --git a/tests/unittests/test_runs/test_merge_run.py b/tests/unittests/test_runs/test_merge_run.py index 2d920eb8..d0ec36a9 100644 --- a/tests/unittests/test_runs/test_merge_run.py +++ b/tests/unittests/test_runs/test_merge_run.py @@ -4,14 +4,13 @@ import tempfile from .. import helpers -from cloudinit.settings import (PER_INSTANCE) +from cloudinit.settings import PER_INSTANCE from cloudinit import stages from cloudinit import util class TestMergeRun(helpers.FilesystemMockingTestCase): def _patchIn(self, root): - self.restore() self.patchOS(root) self.patchUtils(root) diff --git a/tests/unittests/test_runs/test_simple_run.py b/tests/unittests/test_runs/test_simple_run.py index 0279b8b0..e19e65cd 100644 --- a/tests/unittests/test_runs/test_simple_run.py +++ b/tests/unittests/test_runs/test_simple_run.py @@ -4,19 +4,17 @@ import tempfile from .. import helpers -from cloudinit.settings import (PER_INSTANCE) +from cloudinit.settings import PER_INSTANCE from cloudinit import stages from cloudinit import util class TestSimpleRun(helpers.FilesystemMockingTestCase): def _patchIn(self, root): - self.restore() self.patchOS(root) self.patchUtils(root) def _pp_root(self, root, repatch=True): - self.restore() for (dirpath, dirnames, filenames) in os.walk(root): print(dirpath) for f in filenames: diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 5ac47b80..b1f5d62c 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -6,6 +6,12 @@ import tempfile from . import helpers import unittest +import six + +try: + from unittest import mock +except ImportError: + import mock from cloudinit import importer from cloudinit import util @@ -128,23 +134,24 @@ class TestWriteFile(unittest.TestCase): with open(my_file, "w") as fp: fp.write("My Content") - import_mock = self.mocker.replace(importer.import_module, - passthrough=False) - import_mock('selinux') - fake_se = FakeSelinux(my_file) - self.mocker.result(fake_se) - self.mocker.replay() - with util.SeLinuxGuard(my_file) as is_on: - self.assertTrue(is_on) + + with mock.patch.object(importer, 'import_module', + return_value=fake_se) as mockobj: + with util.SeLinuxGuard(my_file) as is_on: + self.assertTrue(is_on) + self.assertEqual(1, len(fake_se.restored)) self.assertEqual(my_file, fake_se.restored[0]) + mockobj.assert_called_once_with('selinux') -class TestDeleteDirContents(MockerTestCase): + +class TestDeleteDirContents(unittest.TestCase): def setUp(self): super(TestDeleteDirContents, self).setUp() - self.tmp = self.makeDir(prefix="unittest_") + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) def assertDirEmpty(self, dirname): self.assertEqual([], os.listdir(dirname)) @@ -248,8 +255,8 @@ class TestLoadYaml(unittest.TestCase): self.mydefault) def test_python_unicode(self): - # complex type of python/unicde is explicitly allowed - myobj = {'1': unicode("FOOBAR")} + # complex type of python/unicode is explicitly allowed + myobj = {'1': six.text_type("FOOBAR")} safe_yaml = yaml.dump(myobj) self.assertEqual(util.load_yaml(blob=safe_yaml, default=self.mydefault), -- cgit v1.2.3 From 3b798b5d5c3caa5d0e8e534855e29010ca932aaa Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Thu, 22 Jan 2015 21:21:04 -0500 Subject: Low hanging Python 3 fruit. --- cloudinit/config/cc_ca_certs.py | 4 ++-- cloudinit/config/cc_chef.py | 6 ++++-- cloudinit/distros/__init__.py | 12 ++++++++++-- cloudinit/distros/debian.py | 2 +- cloudinit/distros/rhel_util.py | 4 ++-- cloudinit/distros/sles.py | 2 +- cloudinit/sources/DataSourceAltCloud.py | 12 ++++++------ cloudinit/sources/DataSourceAzure.py | 4 ++-- cloudinit/sources/DataSourceMAAS.py | 10 ++++++---- cloudinit/sources/DataSourceOpenNebula.py | 2 +- cloudinit/templater.py | 2 +- cloudinit/util.py | 7 +++++-- templates/resolv.conf.tmpl | 2 +- tests/unittests/helpers.py | 4 ++-- tests/unittests/test_datasource/test_configdrive.py | 2 +- tests/unittests/test_datasource/test_digitalocean.py | 7 +++---- tests/unittests/test_datasource/test_gce.py | 2 +- tests/unittests/test_datasource/test_opennebula.py | 2 +- tests/unittests/test_datasource/test_smartos.py | 4 +++- .../unittests/test_handler/test_handler_apt_configure.py | 2 +- tests/unittests/test_merging.py | 16 +++++++++------- tools/ccfg-merge-debug | 4 ++-- 22 files changed, 65 insertions(+), 47 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_ca_certs.py b/cloudinit/config/cc_ca_certs.py index 4f2a46a1..8248b020 100644 --- a/cloudinit/config/cc_ca_certs.py +++ b/cloudinit/config/cc_ca_certs.py @@ -44,7 +44,7 @@ def add_ca_certs(certs): if certs: # First ensure they are strings... cert_file_contents = "\n".join([str(c) for c in certs]) - util.write_file(CA_CERT_FULL_PATH, cert_file_contents, mode=0644) + util.write_file(CA_CERT_FULL_PATH, cert_file_contents, mode=0o644) # Append cert filename to CA_CERT_CONFIG file. # We have to strip the content because blank lines in the file @@ -63,7 +63,7 @@ def remove_default_ca_certs(): """ util.delete_dir_contents(CA_CERT_PATH) util.delete_dir_contents(CA_CERT_SYSTEM_PATH) - util.write_file(CA_CERT_CONFIG, "", mode=0644) + util.write_file(CA_CERT_CONFIG, "", mode=0o644) debconf_sel = "ca-certificates ca-certificates/trust_new_crts select no" util.subp(('debconf-set-selections', '-'), debconf_sel) diff --git a/cloudinit/config/cc_chef.py b/cloudinit/config/cc_chef.py index fc837363..584199e5 100644 --- a/cloudinit/config/cc_chef.py +++ b/cloudinit/config/cc_chef.py @@ -76,6 +76,8 @@ from cloudinit import templater from cloudinit import url_helper from cloudinit import util +import six + RUBY_VERSION_DEFAULT = "1.8" CHEF_DIRS = tuple([ @@ -261,7 +263,7 @@ def run_chef(chef_cfg, log): cmd_args = chef_cfg['exec_arguments'] if isinstance(cmd_args, (list, tuple)): cmd.extend(cmd_args) - elif isinstance(cmd_args, (str, basestring)): + elif isinstance(cmd_args, six.string_types): cmd.append(cmd_args) else: log.warn("Unknown type %s provided for chef" @@ -300,7 +302,7 @@ def install_chef(cloud, chef_cfg, log): with util.tempdir() as tmpd: # Use tmpdir over tmpfile to avoid 'text file busy' on execute tmpf = "%s/chef-omnibus-install" % tmpd - util.write_file(tmpf, str(content), mode=0700) + util.write_file(tmpf, str(content), mode=0o700) util.subp([tmpf], capture=False) else: log.warn("Unknown chef install type '%s'", install_type) diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 4ebccdda..6b96d58c 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -25,7 +25,6 @@ import six from six import StringIO import abc -import itertools import os import re @@ -37,6 +36,15 @@ from cloudinit import util from cloudinit.distros.parsers import hosts +try: + # Python 3 + from six import filter +except ImportError: + # Python 2 + from itertools import ifilter as filter + + + OSFAMILIES = { 'debian': ['debian', 'ubuntu'], 'redhat': ['fedora', 'rhel'], @@ -853,7 +861,7 @@ def extract_default(users, default_name=None, default_config=None): return config['default'] tmp_users = users.items() - tmp_users = dict(itertools.ifilter(safe_find, tmp_users)) + tmp_users = dict(filter(safe_find, tmp_users)) if not tmp_users: return (default_name, default_config) else: diff --git a/cloudinit/distros/debian.py b/cloudinit/distros/debian.py index b09eb094..6d3a82bf 100644 --- a/cloudinit/distros/debian.py +++ b/cloudinit/distros/debian.py @@ -97,7 +97,7 @@ class Distro(distros.Distro): if not conf: conf = HostnameConf('') conf.set_hostname(your_hostname) - util.write_file(out_fn, str(conf), 0644) + util.write_file(out_fn, str(conf), 0o644) def _read_system_hostname(self): sys_hostname = self._read_hostname(self.hostname_conf_fn) diff --git a/cloudinit/distros/rhel_util.py b/cloudinit/distros/rhel_util.py index 063d536e..903d7793 100644 --- a/cloudinit/distros/rhel_util.py +++ b/cloudinit/distros/rhel_util.py @@ -50,7 +50,7 @@ def update_sysconfig_file(fn, adjustments, allow_empty=False): ] if not exists: lines.insert(0, util.make_header()) - util.write_file(fn, "\n".join(lines) + "\n", 0644) + util.write_file(fn, "\n".join(lines) + "\n", 0o644) # Helper function to read a RHEL/SUSE /etc/sysconfig/* file @@ -86,4 +86,4 @@ def update_resolve_conf_file(fn, dns_servers, search_servers): r_conf.add_search_domain(s) except ValueError: util.logexc(LOG, "Failed at adding search domain %s", s) - util.write_file(fn, str(r_conf), 0644) + util.write_file(fn, str(r_conf), 0o644) diff --git a/cloudinit/distros/sles.py b/cloudinit/distros/sles.py index 0c6d1203..620c974c 100644 --- a/cloudinit/distros/sles.py +++ b/cloudinit/distros/sles.py @@ -113,7 +113,7 @@ class Distro(distros.Distro): if not conf: conf = HostnameConf('') conf.set_hostname(hostname) - util.write_file(out_fn, str(conf), 0644) + util.write_file(out_fn, str(conf), 0o644) def _read_system_hostname(self): host_fn = self.hostname_conf_fn diff --git a/cloudinit/sources/DataSourceAltCloud.py b/cloudinit/sources/DataSourceAltCloud.py index 1e913a6e..69053d0b 100644 --- a/cloudinit/sources/DataSourceAltCloud.py +++ b/cloudinit/sources/DataSourceAltCloud.py @@ -124,11 +124,11 @@ class DataSourceAltCloud(sources.DataSource): cmd = CMD_DMI_SYSTEM try: (cmd_out, _err) = util.subp(cmd) - except ProcessExecutionError, _err: + except ProcessExecutionError as _err: LOG.debug(('Failed command: %s\n%s') % \ (' '.join(cmd), _err.message)) return 'UNKNOWN' - except OSError, _err: + except OSError as _err: LOG.debug(('Failed command: %s\n%s') % \ (' '.join(cmd), _err.message)) return 'UNKNOWN' @@ -211,11 +211,11 @@ class DataSourceAltCloud(sources.DataSource): cmd = CMD_PROBE_FLOPPY (cmd_out, _err) = util.subp(cmd) LOG.debug(('Command: %s\nOutput%s') % (' '.join(cmd), cmd_out)) - except ProcessExecutionError, _err: + except ProcessExecutionError as _err: util.logexc(LOG, 'Failed command: %s\n%s', ' '.join(cmd), _err.message) return False - except OSError, _err: + except OSError as _err: util.logexc(LOG, 'Failed command: %s\n%s', ' '.join(cmd), _err.message) return False @@ -228,11 +228,11 @@ class DataSourceAltCloud(sources.DataSource): cmd.append('--exit-if-exists=' + floppy_dev) (cmd_out, _err) = util.subp(cmd) LOG.debug(('Command: %s\nOutput%s') % (' '.join(cmd), cmd_out)) - except ProcessExecutionError, _err: + except ProcessExecutionError as _err: util.logexc(LOG, 'Failed command: %s\n%s', ' '.join(cmd), _err.message) return False - except OSError, _err: + except OSError as _err: util.logexc(LOG, 'Failed command: %s\n%s', ' '.join(cmd), _err.message) return False diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 09bc196d..29ae2c22 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -151,7 +151,7 @@ class DataSourceAzureNet(sources.DataSource): # walinux agent writes files world readable, but expects # the directory to be protected. - write_files(ddir, files, dirmode=0700) + write_files(ddir, files, dirmode=0o700) # handle the hostname 'publishing' try: @@ -390,7 +390,7 @@ def write_files(datadir, files, dirmode=None): util.ensure_dir(datadir, dirmode) for (name, content) in files.items(): util.write_file(filename=os.path.join(datadir, name), - content=content, mode=0600) + content=content, mode=0o600) def invoke_agent(cmd): diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index 9a3e30c5..8f9c81de 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -18,6 +18,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import print_function + from email.utils import parsedate import errno import oauth.oauth as oauth @@ -361,7 +363,7 @@ if __name__ == "__main__": return (urllib2.urlopen(req).read()) def printurl(url, headers_cb): - print "== %s ==\n%s\n" % (url, geturl(url, headers_cb)) + print("== %s ==\n%s\n" % (url, geturl(url, headers_cb))) def crawl(url, headers_cb=None): if url.endswith("/"): @@ -386,9 +388,9 @@ if __name__ == "__main__": version=args.apiver) else: (userdata, metadata) = read_maas_seed_url(args.url) - print "=== userdata ===" - print userdata - print "=== metadata ===" + print("=== userdata ===") + print(userdata) + print("=== metadata ===") pprint.pprint(metadata) elif args.subcmd == "get": diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index e2469f6e..f9dac29e 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -280,7 +280,7 @@ def parse_shell_config(content, keylist=None, bash=None, asuser=None, # allvars expands to all existing variables by using '${!x*}' notation # where x is lower or upper case letters or '_' - allvars = ["${!%s*}" % x for x in string.letters + "_"] + allvars = ["${!%s*}" % x for x in string.ascii_letters + "_"] keylist_in = keylist if keylist is None: diff --git a/cloudinit/templater.py b/cloudinit/templater.py index 4cd3f13d..a9231482 100644 --- a/cloudinit/templater.py +++ b/cloudinit/templater.py @@ -137,7 +137,7 @@ def render_from_file(fn, params): return renderer(content, params) -def render_to_file(fn, outfn, params, mode=0644): +def render_to_file(fn, outfn, params, mode=0o644): contents = render_from_file(fn, params) util.write_file(outfn, contents, mode=mode) diff --git a/cloudinit/util.py b/cloudinit/util.py index 434ba7fb..94fd5c70 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -142,6 +142,9 @@ class ProcessExecutionError(IOError): 'reason': self.reason, } IOError.__init__(self, message) + # For backward compatibility with Python 2. + if not hasattr(self, 'message'): + self.message = message class SeLinuxGuard(object): @@ -260,7 +263,7 @@ def translate_bool(val, addons=None): def rand_str(strlen=32, select_from=None): if not select_from: - select_from = string.letters + string.digits + select_from = string.ascii_letters + string.digits return "".join([random.choice(select_from) for _x in range(0, strlen)]) @@ -1127,7 +1130,7 @@ def pipe_in_out(in_fh, out_fh, chunk_size=1024, chunk_cb=None): bytes_piped = 0 while True: data = in_fh.read(chunk_size) - if data == '': + if len(data) == 0: break else: out_fh.write(data) diff --git a/templates/resolv.conf.tmpl b/templates/resolv.conf.tmpl index 1300156c..bfae80db 100644 --- a/templates/resolv.conf.tmpl +++ b/templates/resolv.conf.tmpl @@ -24,7 +24,7 @@ sortlist {% for sort in sortlist %}{{sort}} {% endfor %} {% if options or flags %} options {% for flag in flags %}{{flag}} {% endfor %} -{% for key, value in options.iteritems() -%} +{% for key, value in options.items() -%} {{key}}:{{value}} {% endfor %} {% endif %} diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 38a2176d..70b8116f 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -65,7 +65,7 @@ if PY26: def assertDictContainsSubset(self, expected, actual, msg=None): missing = [] mismatched = [] - for k, v in expected.iteritems(): + for k, v in expected.items(): if k not in actual: missing.append(k) elif actual[k] != v: @@ -243,7 +243,7 @@ class HttprettyTestCase(TestCase): def populate_dir(path, files): if not os.path.exists(path): os.makedirs(path) - for (name, content) in files.iteritems(): + for (name, content) in files.items(): with open(os.path.join(path, name), "w") as fp: fp.write(content) fp.close() diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index 800c5fd8..258c68e2 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -338,7 +338,7 @@ def populate_ds_from_read_config(cfg_ds, source, results): def populate_dir(seed_dir, files): - for (name, content) in files.iteritems(): + for (name, content) in files.items(): path = os.path.join(seed_dir, name) dirname = os.path.dirname(path) if not os.path.isdir(dirname): diff --git a/tests/unittests/test_datasource/test_digitalocean.py b/tests/unittests/test_datasource/test_digitalocean.py index d1270fc2..98f9cfac 100644 --- a/tests/unittests/test_datasource/test_digitalocean.py +++ b/tests/unittests/test_datasource/test_digitalocean.py @@ -18,8 +18,7 @@ import httpretty import re -from types import ListType -from urlparse import urlparse +from six.moves.urllib_parse import urlparse from cloudinit import settings from cloudinit import helpers @@ -110,7 +109,7 @@ class TestDataSourceDigitalOcean(test_helpers.HttprettyTestCase): self.assertEqual([DO_META.get('public-keys')], self.ds.get_public_ssh_keys()) - self.assertIs(type(self.ds.get_public_ssh_keys()), ListType) + self.assertIsInstance(self.ds.get_public_ssh_keys(), list) @httpretty.activate def test_multiple_ssh_keys(self): @@ -124,4 +123,4 @@ class TestDataSourceDigitalOcean(test_helpers.HttprettyTestCase): self.assertEqual(DO_META.get('public-keys').splitlines(), self.ds.get_public_ssh_keys()) - self.assertIs(type(self.ds.get_public_ssh_keys()), ListType) + self.assertIsInstance(self.ds.get_public_ssh_keys(), list) diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 06050bb1..aa60eb33 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -19,7 +19,7 @@ import httpretty import re from base64 import b64encode, b64decode -from urlparse import urlparse +from six.moves.urllib_parse import urlparse from cloudinit import settings from cloudinit import helpers diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index ddf77265..b79237f0 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -294,7 +294,7 @@ class TestParseShellConfig(unittest.TestCase): def populate_context_dir(path, variables): data = "# Context variables generated by OpenNebula\n" - for (k, v) in variables.iteritems(): + for (k, v) in variables.items(): data += ("%s='%s'\n" % (k.upper(), v.replace(r"'", r"'\''"))) populate_dir(path, {'context.sh': data}) diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 35d7ef5e..01b9b73e 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -22,6 +22,8 @@ # return responses. # +from __future__ import print_function + import base64 from cloudinit import helpers as c_helpers from cloudinit.sources import DataSourceSmartOS @@ -369,7 +371,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): permissions = oct(os.stat(name_f)[stat.ST_MODE])[-3:] if re.match(r'.*\/mdata-user-data$', name_f): found_new = True - print name_f + print(name_f) self.assertEquals(permissions, '400') self.assertFalse(found_new) diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index 2c3dad72..d72fa8c7 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -62,7 +62,7 @@ class TestAptProxyConfig(unittest.TestCase): contents = str(util.read_file_or_url(self.pfile)) - for ptype, pval in values.iteritems(): + for ptype, pval in values.items(): self.assertTrue(self._search_apt_config(contents, ptype, pval)) def test_proxy_deleted(self): diff --git a/tests/unittests/test_merging.py b/tests/unittests/test_merging.py index 07b610f7..976d8283 100644 --- a/tests/unittests/test_merging.py +++ b/tests/unittests/test_merging.py @@ -11,11 +11,13 @@ import glob import os import random import re +import six import string SOURCE_PAT = "source*.*yaml" EXPECTED_PAT = "expected%s.yaml" -TYPES = [long, int, dict, str, list, tuple, None] +TYPES = [dict, str, list, tuple, None] +TYPES.extend(six.integer_types) def _old_mergedict(src, cand): @@ -25,7 +27,7 @@ def _old_mergedict(src, cand): Nested dictionaries are merged recursively. """ if isinstance(src, dict) and isinstance(cand, dict): - for (k, v) in cand.iteritems(): + for (k, v) in cand.items(): if k not in src: src[k] = v else: @@ -42,8 +44,8 @@ def _old_mergemanydict(*args): def _random_str(rand): base = '' - for _i in xrange(rand.randint(1, 2 ** 8)): - base += rand.choice(string.letters + string.digits) + for _i in range(rand.randint(1, 2 ** 8)): + base += rand.choice(string.ascii_letters + string.digits) return base @@ -64,7 +66,7 @@ def _make_dict(current_depth, max_depth, rand): if t in [dict, list, tuple]: if t in [dict]: amount = rand.randint(0, 5) - keys = [_random_str(rand) for _i in xrange(0, amount)] + keys = [_random_str(rand) for _i in range(0, amount)] base = {} for k in keys: try: @@ -74,14 +76,14 @@ def _make_dict(current_depth, max_depth, rand): elif t in [list, tuple]: base = [] amount = rand.randint(0, 5) - for _i in xrange(0, amount): + for _i in range(0, amount): try: base.append(_make_dict(current_depth + 1, max_depth, rand)) except _NoMoreException: pass if t in [tuple]: base = tuple(base) - elif t in [long, int]: + elif t in six.integer_types: base = rand.randint(0, 2 ** 8) elif t in [str]: base = _random_str(rand) diff --git a/tools/ccfg-merge-debug b/tools/ccfg-merge-debug index 85227da7..1f08e0cb 100755 --- a/tools/ccfg-merge-debug +++ b/tools/ccfg-merge-debug @@ -51,7 +51,7 @@ def main(): c_handlers.register(ccph) called = [] - for (_ctype, mod) in c_handlers.iteritems(): + for (_ctype, mod) in c_handlers.items(): if mod in called: continue handlers.call_begin(mod, data, frequency) @@ -76,7 +76,7 @@ def main(): # Give callbacks opportunity to finalize called = [] - for (_ctype, mod) in c_handlers.iteritems(): + for (_ctype, mod) in c_handlers.items(): if mod in called: continue handlers.call_end(mod, data, frequency) -- cgit v1.2.3 From 841db73600e3f203243c773109d71ab88d3334bc Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 11:14:06 -0500 Subject: More test repairs. --- cloudinit/distros/__init__.py | 2 +- cloudinit/user_data.py | 9 +++++++ tests/unittests/helpers.py | 12 ++++++--- tests/unittests/test_builtin_handlers.py | 1 - tests/unittests/test_datasource/test_azure.py | 31 +++++++++++++--------- tests/unittests/test_datasource/test_gce.py | 2 +- tests/unittests/test_datasource/test_opennebula.py | 10 +++++-- tests/unittests/test_datasource/test_smartos.py | 12 ++++++--- tests/unittests/test_filters/test_launch_index.py | 8 +++--- tests/unittests/test_handler/test_handler_chef.py | 3 ++- .../test_handler/test_handler_seed_random.py | 11 ++++++-- tests/unittests/test_util.py | 6 ++--- 12 files changed, 73 insertions(+), 34 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 00fb95fb..ab874b45 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -857,7 +857,7 @@ def extract_default(users, default_name=None, default_config=None): if not tmp_users: return (default_name, default_config) else: - name = tmp_users.keys()[0] + name = list(tmp_users)[0] config = tmp_users[name] config.pop('default', None) return (name, config) diff --git a/cloudinit/user_data.py b/cloudinit/user_data.py index 9111bd39..ff21259c 100644 --- a/cloudinit/user_data.py +++ b/cloudinit/user_data.py @@ -109,6 +109,15 @@ class UserDataProcessor(object): ctype = None ctype_orig = part.get_content_type() payload = part.get_payload(decode=True) + # In Python 3, decoding the payload will ironically hand us a + # bytes object. 'decode' means to decode according to + # Content-Transfer-Encoding, not according to any charset in the + # Content-Type. So, if we end up with bytes, first try to decode + # to str via CT charset, and failing that, try utf-8 using + # surrogate escapes. + if six.PY3 and isinstance(payload, bytes): + charset = part.get_charset() or 'utf-8' + payload = payload.decode(charset, errors='surrogateescape') was_compressed = False # When the message states it is of a gzipped content type ensure diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 70b8116f..4b8dcc5c 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -1,8 +1,11 @@ import os import sys +import shutil import tempfile import unittest +import six + try: from unittest import mock except ImportError: @@ -15,8 +18,6 @@ except ImportError: from cloudinit import helpers as ch from cloudinit import util -import shutil - # Used for detecting different python versions PY2 = False PY26 = False @@ -115,7 +116,12 @@ def retarget_many_wrapper(new_base, am, old_func): nam = len(n_args) for i in range(0, nam): path = args[i] - n_args[i] = rebase_path(path, new_base) + # patchOS() wraps various os and os.path functions, however in + # Python 3 some of these now accept file-descriptors (integers). + # That breaks rebase_path() so in lieu of a better solution, just + # don't rebase if we get a fd. + if isinstance(path, six.string_types): + n_args[i] = rebase_path(path, new_base) return old_func(*n_args, **kwds) return wrapper diff --git a/tests/unittests/test_builtin_handlers.py b/tests/unittests/test_builtin_handlers.py index 47ff6318..ad32d0b2 100644 --- a/tests/unittests/test_builtin_handlers.py +++ b/tests/unittests/test_builtin_handlers.py @@ -21,7 +21,6 @@ from cloudinit.settings import (PER_ALWAYS, PER_INSTANCE) class TestBuiltins(test_helpers.FilesystemMockingTestCase): - def test_upstart_frequency_no_out(self): c_root = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, c_root) diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 2dbcd389..1f0330b3 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -22,6 +22,13 @@ import tempfile import unittest +def b64(source): + # In Python 3, b64encode only accepts bytes and returns bytes. + if not isinstance(source, bytes): + source = source.encode('utf-8') + return base64.b64encode(source).decode('us-ascii') + + def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): if data is None: data = {'HostName': 'FOOHOST'} @@ -51,7 +58,7 @@ def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): content += "<%s%s>%s\n" % (key, attrs, val, key) if userdata: - content += "%s\n" % (base64.b64encode(userdata)) + content += "%s\n" % (b64(userdata)) if pubkeys: content += "\n" @@ -181,7 +188,7 @@ class TestAzureDataSource(unittest.TestCase): # set dscfg in via base64 encoded yaml cfg = {'agent_command': "my_command"} odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': base64.b64encode(yaml.dump(cfg)), + 'dscfg': {'text': b64(yaml.dump(cfg)), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} @@ -233,13 +240,13 @@ class TestAzureDataSource(unittest.TestCase): def test_userdata_found(self): mydata = "FOOBAR" - odata = {'UserData': base64.b64encode(mydata)} + odata = {'UserData': b64(mydata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertEqual(dsrc.userdata_raw, mydata) + self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8')) def test_no_datasource_expected(self): # no source should be found if no seed_dir and no devs @@ -281,7 +288,7 @@ class TestAzureDataSource(unittest.TestCase): 'command': 'my-bounce-command', 'hostname_command': 'my-hostname-command'}} odata = {'HostName': "xhost", - 'dscfg': {'text': base64.b64encode(yaml.dump(cfg)), + 'dscfg': {'text': b64(yaml.dump(cfg)), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} self._get_ds(data).get_data() @@ -296,7 +303,7 @@ class TestAzureDataSource(unittest.TestCase): # config specifying set_hostname off should not bounce cfg = {'set_hostname': False} odata = {'HostName': "xhost", - 'dscfg': {'text': base64.b64encode(yaml.dump(cfg)), + 'dscfg': {'text': b64(yaml.dump(cfg)), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} self._get_ds(data).get_data() @@ -325,7 +332,7 @@ class TestAzureDataSource(unittest.TestCase): # Make sure that user can affect disk aliases dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}} odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': base64.b64encode(yaml.dump(dscfg)), + 'dscfg': {'text': b64(yaml.dump(dscfg)), 'encoding': 'base64'}} usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'}, 'ephemeral0': False}} @@ -347,7 +354,7 @@ class TestAzureDataSource(unittest.TestCase): dsrc = self._get_ds(data) dsrc.get_data() - self.assertEqual(userdata, dsrc.userdata_raw) + self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw) def test_ovf_env_arrives_in_waagent_dir(self): xml = construct_valid_ovf_env(data={}, userdata="FOODATA") @@ -362,7 +369,7 @@ class TestAzureDataSource(unittest.TestCase): def test_existing_ovf_same(self): # waagent/SharedConfig left alone if found ovf-env.xml same as cached - odata = {'UserData': base64.b64encode("SOMEUSERDATA")} + odata = {'UserData': b64("SOMEUSERDATA")} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} populate_dir(self.waagent_d, @@ -386,9 +393,9 @@ class TestAzureDataSource(unittest.TestCase): # 'get_data' should remove SharedConfig.xml in /var/lib/waagent # if ovf-env.xml differs. cached_ovfenv = construct_valid_ovf_env( - {'userdata': base64.b64encode("FOO_USERDATA")}) + {'userdata': b64("FOO_USERDATA")}) new_ovfenv = construct_valid_ovf_env( - {'userdata': base64.b64encode("NEW_USERDATA")}) + {'userdata': b64("NEW_USERDATA")}) populate_dir(self.waagent_d, {'ovf-env.xml': cached_ovfenv, @@ -398,7 +405,7 @@ class TestAzureDataSource(unittest.TestCase): dsrc = self._get_ds({'ovfcontent': new_ovfenv}) ret = dsrc.get_data() self.assertTrue(ret) - self.assertEqual(dsrc.userdata_raw, "NEW_USERDATA") + self.assertEqual(dsrc.userdata_raw, b"NEW_USERDATA") self.assertTrue(os.path.exists( os.path.join(self.waagent_d, 'otherfile'))) self.assertFalse( diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index aa60eb33..6dd4b5ed 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -45,7 +45,7 @@ GCE_META_ENCODING = { 'instance/id': '12345', 'instance/hostname': 'server.project-baz.local', 'instance/zone': 'baz/bang', - 'instance/attributes/user-data': b64encode('/bin/echo baz\n'), + 'instance/attributes/user-data': b64encode(b'/bin/echo baz\n'), 'instance/attributes/user-data-encoding': 'base64', } diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index b79237f0..1a8d2122 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -10,6 +10,12 @@ import shutil import tempfile import unittest +def b64(source): + # In Python 3, b64encode only accepts bytes and returns bytes. + if not isinstance(source, bytes): + source = source.encode('utf-8') + return b64encode(source).decode('us-ascii') + TEST_VARS = { 'VAR1': 'single', @@ -180,7 +186,7 @@ class TestOpenNebulaDataSource(unittest.TestCase): self.assertEqual(USER_DATA, results['userdata']) def test_user_data_encoding_required_for_decode(self): - b64userdata = b64encode(USER_DATA) + b64userdata = b64(USER_DATA) for k in ('USER_DATA', 'USERDATA'): my_d = os.path.join(self.tmp, k) populate_context_dir(my_d, {k: b64userdata}) @@ -192,7 +198,7 @@ class TestOpenNebulaDataSource(unittest.TestCase): def test_user_data_base64_encoding(self): for k in ('USER_DATA', 'USERDATA'): my_d = os.path.join(self.tmp, k) - populate_context_dir(my_d, {k: b64encode(USER_DATA), + populate_context_dir(my_d, {k: b64(USER_DATA), 'USERDATA_ENCODING': 'base64'}) results = ds.read_context_disk_dir(my_d) diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 01b9b73e..2fb9e1b6 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -36,6 +36,12 @@ import tempfile import stat import uuid +def b64(source): + # In Python 3, b64encode only accepts bytes and returns bytes. + if not isinstance(source, bytes): + source = source.encode('utf-8') + return base64.b64encode(source).decode('us-ascii') + MOCK_RETURNS = { 'hostname': 'test-host', @@ -233,7 +239,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): my_returns = MOCK_RETURNS.copy() my_returns['base64_all'] = "true" for k in ('hostname', 'cloud-init:user-data'): - my_returns[k] = base64.b64encode(my_returns[k]) + my_returns[k] = b64(my_returns[k]) dsrc = self._get_ds(mockdata=my_returns) ret = dsrc.get_data() @@ -254,7 +260,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): my_returns['b64-cloud-init:user-data'] = "true" my_returns['b64-hostname'] = "true" for k in ('hostname', 'cloud-init:user-data'): - my_returns[k] = base64.b64encode(my_returns[k]) + my_returns[k] = b64(my_returns[k]) dsrc = self._get_ds(mockdata=my_returns) ret = dsrc.get_data() @@ -270,7 +276,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): my_returns = MOCK_RETURNS.copy() my_returns['base64_keys'] = 'hostname,ignored' for k in ('hostname',): - my_returns[k] = base64.b64encode(my_returns[k]) + my_returns[k] = b64(my_returns[k]) dsrc = self._get_ds(mockdata=my_returns) ret = dsrc.get_data() diff --git a/tests/unittests/test_filters/test_launch_index.py b/tests/unittests/test_filters/test_launch_index.py index 2f4c2fda..95d24b9b 100644 --- a/tests/unittests/test_filters/test_launch_index.py +++ b/tests/unittests/test_filters/test_launch_index.py @@ -2,7 +2,7 @@ import copy from .. import helpers -import itertools +from six.moves import filterfalse from cloudinit.filters import launch_index from cloudinit import user_data as ud @@ -36,11 +36,9 @@ class TestLaunchFilter(helpers.ResourceUsingTestCase): return False # Do some basic payload checking msg1_msgs = [m for m in msg1.walk()] - msg1_msgs = [m for m in - itertools.ifilterfalse(ud.is_skippable, msg1_msgs)] + msg1_msgs = [m for m in filterfalse(ud.is_skippable, msg1_msgs)] msg2_msgs = [m for m in msg2.walk()] - msg2_msgs = [m for m in - itertools.ifilterfalse(ud.is_skippable, msg2_msgs)] + msg2_msgs = [m for m in filterfalse(ud.is_skippable, msg2_msgs)] for i in range(0, len(msg2_msgs)): m1_msg = msg1_msgs[i] m2_msg = msg2_msgs[i] diff --git a/tests/unittests/test_handler/test_handler_chef.py b/tests/unittests/test_handler/test_handler_chef.py index b06a160c..8ab27911 100644 --- a/tests/unittests/test_handler/test_handler_chef.py +++ b/tests/unittests/test_handler/test_handler_chef.py @@ -11,6 +11,7 @@ from cloudinit.sources import DataSourceNone from .. import helpers as t_help +import six import logging import shutil import tempfile @@ -77,7 +78,7 @@ class TestChef(t_help.FilesystemMockingTestCase): for k, v in cfg['chef'].items(): self.assertIn(v, c) for k, v in cc_chef.CHEF_RB_TPL_DEFAULTS.items(): - if isinstance(v, basestring): + if isinstance(v, six.string_types): self.assertIn(v, c) c = util.load_file(cc_chef.CHEF_FB_PATH) self.assertEqual({}, json.loads(c)) diff --git a/tests/unittests/test_handler/test_handler_seed_random.py b/tests/unittests/test_handler/test_handler_seed_random.py index 579377fb..c2da5ced 100644 --- a/tests/unittests/test_handler/test_handler_seed_random.py +++ b/tests/unittests/test_handler/test_handler_seed_random.py @@ -38,6 +38,13 @@ import logging LOG = logging.getLogger(__name__) +def b64(source): + # In Python 3, b64encode only accepts bytes and returns bytes. + if not isinstance(source, bytes): + source = source.encode('utf-8') + return base64.b64encode(source).decode('us-ascii') + + class TestRandomSeed(t_help.TestCase): def setUp(self): super(TestRandomSeed, self).setUp() @@ -134,7 +141,7 @@ class TestRandomSeed(t_help.TestCase): self.assertEquals("big-toe", contents) def test_append_random_base64(self): - data = base64.b64encode('bubbles') + data = b64('bubbles') cfg = { 'random_seed': { 'file': self._seed_file, @@ -147,7 +154,7 @@ class TestRandomSeed(t_help.TestCase): self.assertEquals("bubbles", contents) def test_append_random_b64(self): - data = base64.b64encode('kit-kat') + data = b64('kit-kat') cfg = { 'random_seed': { 'file': self._seed_file, diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index b1f5d62c..b0207ace 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -119,7 +119,7 @@ class TestWriteFile(unittest.TestCase): # Create file first with basic content with open(path, "wb") as f: - f.write("LINE1\n") + f.write(b"LINE1\n") util.write_file(path, contents, omode="a") self.assertTrue(os.path.exists(path)) @@ -194,7 +194,7 @@ class TestDeleteDirContents(unittest.TestCase): os.mkdir(os.path.join(self.tmp, "new_dir")) f_name = os.path.join(self.tmp, "new_dir", "new_file.txt") with open(f_name, "wb") as f: - f.write("DELETE ME") + f.write(b"DELETE ME") util.delete_dir_contents(self.tmp) @@ -205,7 +205,7 @@ class TestDeleteDirContents(unittest.TestCase): file_name = os.path.join(self.tmp, "new_file.txt") link_name = os.path.join(self.tmp, "new_file_link.txt") with open(file_name, "wb") as f: - f.write("DELETE ME") + f.write(b"DELETE ME") os.symlink(file_name, link_name) util.delete_dir_contents(self.tmp) -- cgit v1.2.3 From 926a3df79a10ede61967c60f48ff0670a36e689a Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 12:41:04 -0500 Subject: More Python 3 test fixes. --- cloudinit/config/cc_write_files.py | 5 +++-- tests/unittests/test_datasource/test_opennebula.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_write_files.py b/cloudinit/config/cc_write_files.py index a73d6f4e..4b03ea91 100644 --- a/cloudinit/config/cc_write_files.py +++ b/cloudinit/config/cc_write_files.py @@ -18,6 +18,7 @@ import base64 import os +import six from cloudinit.settings import PER_INSTANCE from cloudinit import util @@ -25,7 +26,7 @@ from cloudinit import util frequency = PER_INSTANCE DEFAULT_OWNER = "root:root" -DEFAULT_PERMS = 0644 +DEFAULT_PERMS = 0o644 UNKNOWN_ENC = 'text/plain' @@ -79,7 +80,7 @@ def write_files(name, files, log): def decode_perms(perm, default, log): try: - if isinstance(perm, (int, long, float)): + if isinstance(perm, six.integer_types + (float,)): # Just 'downcast' it (if a float) return int(perm) else: diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index 1a8d2122..31c6232f 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -33,7 +33,7 @@ TEST_VARS = { } INVALID_CONTEXT = ';' -USER_DATA = '#cloud-config\napt_upgrade: true' +USER_DATA = b'#cloud-config\napt_upgrade: true' SSH_KEY = 'ssh-rsa AAAAB3NzaC1....sIkJhq8wdX+4I3A4cYbYP ubuntu@server-460-%i' HOSTNAME = 'foo.example.com' PUBLIC_IP = '10.0.0.3' -- cgit v1.2.3 From de5974fe93dd717e0c7ba6de17db3192cc258cff Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 14:31:09 -0500 Subject: * More str/bytes fixes. * Temporarily skip the MAAS tests in py3 since they need to be ported to oauthlib. --- cloudinit/sources/DataSourceOpenNebula.py | 12 +++++++++--- cloudinit/sources/DataSourceSmartOS.py | 15 +++++++++++++-- tests/unittests/test_datasource/test_maas.py | 7 ++++++- tests/unittests/test_datasource/test_opennebula.py | 4 ++-- 4 files changed, 30 insertions(+), 8 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index f9dac29e..691b39f8 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -379,7 +379,8 @@ def read_context_disk_dir(source_dir, asuser=None): raise BrokenContextDiskDir("configured user '%s' " "does not exist", asuser) try: - with open(os.path.join(source_dir, 'context.sh'), 'r') as f: + path = os.path.join(source_dir, 'context.sh') + with open(path, 'r', encoding='utf-8') as f: content = f.read().strip() context = parse_shell_config(content, asuser=asuser) @@ -426,14 +427,19 @@ def read_context_disk_dir(source_dir, asuser=None): context.get('USER_DATA_ENCODING')) if encoding == "base64": try: - results['userdata'] = base64.b64decode(results['userdata']) + userdata = base64.b64decode(results['userdata']) + # In Python 3 we still expect a str, but b64decode will return + # bytes. Convert to str. + if isinstance(userdata, bytes): + userdata = userdata.decode('utf-8') + results['userdata'] = userdata except TypeError: LOG.warn("Failed base64 decoding of userdata") # generate static /etc/network/interfaces # only if there are any required context variables # http://opennebula.org/documentation:rel3.8:cong#network_configuration - for k in context.keys(): + for k in context: if re.match(r'^ETH\d+_IP$', k): (out, _) = util.subp(['/sbin/ip', 'link']) net = OpenNebulaNetwork(out, context) diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 7a975d78..d3ed40c5 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -30,6 +30,7 @@ # Comments with "@datadictionary" are snippets of the definition import base64 +import binascii import os import serial @@ -350,8 +351,18 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, if b64: try: - return base64.b64decode(resp) - except TypeError: + # Generally, we want native strings in the values. Python 3's + # b64decode will return bytes though, so decode them to utf-8 if + # possible. If that fails, return the bytes. + decoded = base64.b64decode(resp) + try: + if isinstance(decoded, bytes): + return decoded.decode('utf-8') + except UnicodeDecodeError: + pass + return decoded + # Bogus input produces different errors in Python 2 and 3; catch both. + except (TypeError, binascii.Error): LOG.warn("Failed base64 decoding key '%s'", noun) return resp diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index 6af0cd82..66fe22ae 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -4,7 +4,11 @@ import shutil import tempfile import unittest -from cloudinit.sources import DataSourceMAAS +# XXX DataSourceMAAS must be ported to oauthlib for Python 3 +import six +if not six.PY3: + from cloudinit.sources import DataSourceMAAS + from cloudinit import url_helper from ..helpers import populate_dir @@ -14,6 +18,7 @@ except ImportError: import mock +@unittest.skipIf(six.PY3, 'DataSourceMAAS must be ported to oauthlib') class TestMAASDataSource(unittest.TestCase): def setUp(self): diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index 31c6232f..ef534bab 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -33,7 +33,7 @@ TEST_VARS = { } INVALID_CONTEXT = ';' -USER_DATA = b'#cloud-config\napt_upgrade: true' +USER_DATA = '#cloud-config\napt_upgrade: true' SSH_KEY = 'ssh-rsa AAAAB3NzaC1....sIkJhq8wdX+4I3A4cYbYP ubuntu@server-460-%i' HOSTNAME = 'foo.example.com' PUBLIC_IP = '10.0.0.3' @@ -300,7 +300,7 @@ class TestParseShellConfig(unittest.TestCase): def populate_context_dir(path, variables): data = "# Context variables generated by OpenNebula\n" - for (k, v) in variables.items(): + for k, v in variables.items(): data += ("%s='%s'\n" % (k.upper(), v.replace(r"'", r"'\''"))) populate_dir(path, {'context.sh': data}) -- cgit v1.2.3 From f5d6d0e6433f12d05676bea03f78d57966c35b0a Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 15:09:48 -0500 Subject: Down to it. --- tests/unittests/test_handler/test_handler_seed_random.py | 10 +++++----- tests/unittests/test_util.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_handler/test_handler_seed_random.py b/tests/unittests/test_handler/test_handler_seed_random.py index c2da5ced..d3f18fa0 100644 --- a/tests/unittests/test_handler/test_handler_seed_random.py +++ b/tests/unittests/test_handler/test_handler_seed_random.py @@ -22,7 +22,7 @@ import base64 import gzip import tempfile -from six import StringIO +from six import BytesIO from cloudinit import cloud from cloudinit import distros @@ -76,7 +76,7 @@ class TestRandomSeed(t_help.TestCase): return def _compress(self, text): - contents = StringIO() + contents = BytesIO() gz_fh = gzip.GzipFile(mode='wb', fileobj=contents) gz_fh.write(text) gz_fh.close() @@ -103,7 +103,7 @@ class TestRandomSeed(t_help.TestCase): self.assertEquals("tiny-tim-was-here", contents) def test_append_random_unknown_encoding(self): - data = self._compress("tiny-toe") + data = self._compress(b"tiny-toe") cfg = { 'random_seed': { 'file': self._seed_file, @@ -115,7 +115,7 @@ class TestRandomSeed(t_help.TestCase): self._get_cloud('ubuntu'), LOG, []) def test_append_random_gzip(self): - data = self._compress("tiny-toe") + data = self._compress(b"tiny-toe") cfg = { 'random_seed': { 'file': self._seed_file, @@ -128,7 +128,7 @@ class TestRandomSeed(t_help.TestCase): self.assertEquals("tiny-toe", contents) def test_append_random_gz(self): - data = self._compress("big-toe") + data = self._compress(b"big-toe") cfg = { 'random_seed': { 'file': self._seed_file, diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index b0207ace..f537d332 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -166,7 +166,7 @@ class TestDeleteDirContents(unittest.TestCase): def test_deletes_files(self): """Single file should be deleted.""" with open(os.path.join(self.tmp, "new_file.txt"), "wb") as f: - f.write("DELETE ME") + f.write(b"DELETE ME") util.delete_dir_contents(self.tmp) -- cgit v1.2.3 From fabff4aec884467729fc372bb67f240752c15511 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 16:37:29 -0500 Subject: Port the MAAS code to oauthlib. --- cloudinit/sources/DataSourceMAAS.py | 56 ++++++++++++++++------------ requirements.txt | 2 +- tests/unittests/test_datasource/test_maas.py | 7 +--- 3 files changed, 35 insertions(+), 30 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index 8f9c81de..39296f08 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -22,10 +22,11 @@ from __future__ import print_function from email.utils import parsedate import errno -import oauth.oauth as oauth +import oauthlib import os import time -import urllib2 + +from six.moves.urllib_request import Request, urlopen from cloudinit import log as logging from cloudinit import sources @@ -274,25 +275,34 @@ def check_seed_contents(content, seed): def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, timestamp=None): - consumer = oauth.OAuthConsumer(consumer_key, consumer_secret) - token = oauth.OAuthToken(token_key, token_secret) - - if timestamp is None: - ts = int(time.time()) - else: - ts = timestamp - - params = { - 'oauth_version': "1.0", - 'oauth_nonce': oauth.generate_nonce(), - 'oauth_timestamp': ts, - 'oauth_token': token.key, - 'oauth_consumer_key': consumer.key, - } - req = oauth.OAuthRequest(http_url=url, parameters=params) - req.sign_request(oauth.OAuthSignatureMethod_PLAINTEXT(), - consumer, token) - return req.to_header() + client = oauthlib.oauth1.Client( + consumer_key, + client_secret=consumer_secret, + resource_owner_key=token_key, + resource_owner_secret=token_secret, + signature_method=oauthlib.SIGNATURE_PLAINTEXT) + uri, signed_headers, body = client.sign(url) + return signed_headers + + ## consumer = oauth.OAuthConsumer(consumer_key, consumer_secret) + ## token = oauth.OAuthToken(token_key, token_secret) + + ## if timestamp is None: + ## ts = int(time.time()) + ## else: + ## ts = timestamp + + ## params = { + ## 'oauth_version': "1.0", + ## 'oauth_nonce': oauth.generate_nonce(), + ## 'oauth_timestamp': ts, + ## 'oauth_token': token.key, + ## 'oauth_consumer_key': consumer.key, + ## } + ## req = oauth.OAuthRequest(http_url=url, parameters=params) + ## req.sign_request(oauth.OAuthSignatureMethod_PLAINTEXT(), + ## consumer, token) + ## return req.to_header() class MAASSeedDirNone(Exception): @@ -359,8 +369,8 @@ if __name__ == "__main__": creds[key] = cfg[key] def geturl(url, headers_cb): - req = urllib2.Request(url, data=None, headers=headers_cb(url)) - return (urllib2.urlopen(req).read()) + req = Request(url, data=None, headers=headers_cb(url)) + return urlopen(req).read() def printurl(url, headers_cb): print("== %s ==\n%s\n" % (url, geturl(url, headers_cb))) diff --git a/requirements.txt b/requirements.txt index 2a12ca3e..19c88857 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ PrettyTable # This one is currently only used by the MAAS datasource. If that # datasource is removed, this is no longer needed -oauth +oauthlib # This one is currently used only by the CloudSigma and SmartOS datasources. # If these datasources are removed, this is no longer needed diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index 66fe22ae..6af0cd82 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -4,11 +4,7 @@ import shutil import tempfile import unittest -# XXX DataSourceMAAS must be ported to oauthlib for Python 3 -import six -if not six.PY3: - from cloudinit.sources import DataSourceMAAS - +from cloudinit.sources import DataSourceMAAS from cloudinit import url_helper from ..helpers import populate_dir @@ -18,7 +14,6 @@ except ImportError: import mock -@unittest.skipIf(six.PY3, 'DataSourceMAAS must be ported to oauthlib') class TestMAASDataSource(unittest.TestCase): def setUp(self): -- cgit v1.2.3 From 5e2b8ef0703eb4582a5a8ba50ae7c83a8294d65a Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 20:02:31 -0500 Subject: Repair the Python 2.6 tests. --- cloudinit/util.py | 18 ++++++++-------- tests/unittests/helpers.py | 25 +++++++++++++++++++--- tests/unittests/test__init__.py | 6 ++++-- tests/unittests/test_cs_util.py | 16 +++++++++++++- tests/unittests/test_datasource/test_azure.py | 7 +++--- tests/unittests/test_datasource/test_cloudsigma.py | 1 + .../unittests/test_datasource/test_configdrive.py | 6 ++++-- tests/unittests/test_datasource/test_maas.py | 5 ++--- tests/unittests/test_datasource/test_nocloud.py | 7 +++--- tests/unittests/test_datasource/test_opennebula.py | 4 ++-- tests/unittests/test_distros/test_netconfig.py | 4 ++-- tests/unittests/test_distros/test_resolv.py | 4 ++-- tests/unittests/test_distros/test_sysconfig.py | 4 ++-- .../test_distros/test_user_data_normalize.py | 7 +++--- .../test_handler/test_handler_apt_configure.py | 3 ++- .../test_handler/test_handler_ca_certs.py | 7 +++--- .../test_handler/test_handler_growpart.py | 3 ++- tests/unittests/test_pathprefix2dict.py | 6 +++--- tests/unittests/test_templating.py | 19 +++++++++++++++- tests/unittests/test_util.py | 15 ++++++------- tox.ini | 10 +++++++++ 21 files changed, 122 insertions(+), 55 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index 32c19ba2..766f8e32 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -2059,23 +2059,23 @@ def _read_dmi_syspath(key): Reads dmi data with from /sys/class/dmi/id """ - dmi_key = "{}/{}".format(DMI_SYS_PATH, key) - LOG.debug("querying dmi data {}".format(dmi_key)) + dmi_key = "{0}/{1}".format(DMI_SYS_PATH, key) + LOG.debug("querying dmi data {0}".format(dmi_key)) try: if not os.path.exists(dmi_key): - LOG.debug("did not find {}".format(dmi_key)) + LOG.debug("did not find {0}".format(dmi_key)) return None key_data = load_file(dmi_key) if not key_data: - LOG.debug("{} did not return any data".format(key)) + LOG.debug("{0} did not return any data".format(key)) return None - LOG.debug("dmi data {} returned {}".format(dmi_key, key_data)) + LOG.debug("dmi data {0} returned {0}".format(dmi_key, key_data)) return key_data.strip() except Exception as e: - logexc(LOG, "failed read of {}".format(dmi_key), e) + logexc(LOG, "failed read of {0}".format(dmi_key), e) return None @@ -2087,10 +2087,10 @@ def _call_dmidecode(key, dmidecode_path): try: cmd = [dmidecode_path, "--string", key] (result, _err) = subp(cmd) - LOG.debug("dmidecode returned '{}' for '{}'".format(result, key)) + LOG.debug("dmidecode returned '{0}' for '{0}'".format(result, key)) return result except OSError as _err: - LOG.debug('failed dmidecode cmd: {}\n{}'.format(cmd, _err.message)) + LOG.debug('failed dmidecode cmd: {0}\n{0}'.format(cmd, _err.message)) return None @@ -2106,7 +2106,7 @@ def read_dmi_data(key): if dmidecode_path: return _call_dmidecode(key, dmidecode_path) - LOG.warn("did not find either path {} or dmidecode command".format( + LOG.warn("did not find either path {0} or dmidecode command".format( DMI_SYS_PATH)) return None diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 828579e8..424d0626 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -39,8 +39,20 @@ else: PY3 = True if PY26: - # For now add these on, taken from python 2.7 + slightly adjusted + # For now add these on, taken from python 2.7 + slightly adjusted. Drop + # all this once Python 2.6 is dropped as a minimum requirement. class TestCase(unittest.TestCase): + def setUp(self): + unittest.TestCase.setUp(self) + self.__all_cleanups = ExitStack() + + def tearDown(self): + self.__all_cleanups.close() + unittest.TestCase.tearDown(self) + + def addCleanup(self, function, *args, **kws): + self.__all_cleanups.callback(function, *args, **kws) + def assertIs(self, expr1, expr2, msg=None): if expr1 is not expr2: standardMsg = '%r is not %r' % (expr1, expr2) @@ -63,6 +75,13 @@ if PY26: standardMsg = standardMsg % (value) self.fail(self._formatMessage(msg, standardMsg)) + def assertIsInstance(self, obj, cls, msg=None): + """Same as self.assertTrue(isinstance(obj, cls)), with a nicer + default message.""" + if not isinstance(obj, cls): + standardMsg = '%s is not an instance of %r' % (repr(obj), cls) + self.fail(self._formatMessage(msg, standardMsg)) + def assertDictContainsSubset(self, expected, actual, msg=None): missing = [] mismatched = [] @@ -126,9 +145,9 @@ def retarget_many_wrapper(new_base, am, old_func): return wrapper -class ResourceUsingTestCase(unittest.TestCase): +class ResourceUsingTestCase(TestCase): def setUp(self): - unittest.TestCase.setUp(self) + TestCase.setUp(self) self.resource_path = None def resourceLocation(self, subname=None): diff --git a/tests/unittests/test__init__.py b/tests/unittests/test__init__.py index f5dc3435..1a307e56 100644 --- a/tests/unittests/test__init__.py +++ b/tests/unittests/test__init__.py @@ -18,6 +18,8 @@ from cloudinit import settings from cloudinit import url_helper from cloudinit import util +from .helpers import TestCase + class FakeModule(handlers.Handler): def __init__(self): @@ -31,10 +33,10 @@ class FakeModule(handlers.Handler): pass -class TestWalkerHandleHandler(unittest.TestCase): +class TestWalkerHandleHandler(TestCase): def setUp(self): - unittest.TestCase.setUp(self) + super(TestWalkerHandleHandler, self).setUp() tmpdir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, tmpdir) diff --git a/tests/unittests/test_cs_util.py b/tests/unittests/test_cs_util.py index 99fac84d..337ac9a0 100644 --- a/tests/unittests/test_cs_util.py +++ b/tests/unittests/test_cs_util.py @@ -1,7 +1,21 @@ +from __future__ import print_function + +import sys import unittest from cloudinit.cs_utils import Cepko +try: + skip = unittest.skip +except AttributeError: + # Python 2.6. Doesn't have to be high fidelity. + def skip(reason): + def decorator(func): + def wrapper(*args, **kws): + print(reason, file=sys.stderr) + return wrapper + return decorator + SERVER_CONTEXT = { "cpu": 1000, @@ -29,7 +43,7 @@ class CepkoMock(Cepko): # 2015-01-22 BAW: This test is completely useless because it only ever tests # the CepkoMock object. Even in its original form, I don't think it ever # touched the underlying Cepko class methods. -@unittest.skip('This test is completely useless') +@skip('This test is completely useless') class CepkoResultTests(unittest.TestCase): def setUp(self): pass diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 1f0330b3..97a53bee 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -1,7 +1,7 @@ from cloudinit import helpers from cloudinit.util import load_file from cloudinit.sources import DataSourceAzure -from ..helpers import populate_dir +from ..helpers import TestCase, populate_dir try: from unittest import mock @@ -84,9 +84,10 @@ def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): return content -class TestAzureDataSource(unittest.TestCase): +class TestAzureDataSource(TestCase): def setUp(self): + super(TestAzureDataSource, self).setUp() self.tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp) @@ -416,7 +417,7 @@ class TestAzureDataSource(unittest.TestCase): load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) -class TestReadAzureOvf(unittest.TestCase): +class TestReadAzureOvf(TestCase): def test_invalid_xml_raises_non_azure_ds(self): invalid_xml = "" + construct_valid_ovf_env(data={}) self.assertRaises(DataSourceAzure.BrokenAzureDataSource, diff --git a/tests/unittests/test_datasource/test_cloudsigma.py b/tests/unittests/test_datasource/test_cloudsigma.py index 306ac7d8..772d189a 100644 --- a/tests/unittests/test_datasource/test_cloudsigma.py +++ b/tests/unittests/test_datasource/test_cloudsigma.py @@ -39,6 +39,7 @@ class CepkoMock(Cepko): class DataSourceCloudSigmaTest(test_helpers.TestCase): def setUp(self): + super(DataSourceCloudSigmaTest, self).setUp() self.datasource = DataSourceCloudSigma.DataSourceCloudSigma("", "", "") self.datasource.is_running_in_cloudsigma = lambda: True self.datasource.cepko = CepkoMock(SERVER_CONTEXT) diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index 258c68e2..fd930877 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -3,7 +3,6 @@ import json import os import shutil import tempfile -import unittest try: from unittest import mock @@ -20,6 +19,9 @@ from cloudinit.sources import DataSourceConfigDrive as ds from cloudinit.sources.helpers import openstack from cloudinit import util +from ..helpers import TestCase + + PUBKEY = u'ssh-rsa AAAAB3NzaC1....sIkJhq8wdX+4I3A4cYbYP ubuntu@server-460\n' EC2_META = { 'ami-id': 'ami-00000001', @@ -70,7 +72,7 @@ CFG_DRIVE_FILES_V2 = { 'openstack/latest/user_data': USER_DATA} -class TestConfigDriveDataSource(unittest.TestCase): +class TestConfigDriveDataSource(TestCase): def setUp(self): super(TestConfigDriveDataSource, self).setUp() diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index 6af0cd82..d25e1adc 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -2,11 +2,10 @@ from copy import copy import os import shutil import tempfile -import unittest from cloudinit.sources import DataSourceMAAS from cloudinit import url_helper -from ..helpers import populate_dir +from ..helpers import TestCase, populate_dir try: from unittest import mock @@ -14,7 +13,7 @@ except ImportError: import mock -class TestMAASDataSource(unittest.TestCase): +class TestMAASDataSource(TestCase): def setUp(self): super(TestMAASDataSource, self).setUp() diff --git a/tests/unittests/test_datasource/test_nocloud.py b/tests/unittests/test_datasource/test_nocloud.py index 480a4012..4f967f58 100644 --- a/tests/unittests/test_datasource/test_nocloud.py +++ b/tests/unittests/test_datasource/test_nocloud.py @@ -1,7 +1,7 @@ from cloudinit import helpers from cloudinit.sources import DataSourceNoCloud from cloudinit import util -from ..helpers import populate_dir +from ..helpers import TestCase, populate_dir import os import yaml @@ -19,9 +19,10 @@ except ImportError: from contextlib2 import ExitStack -class TestNoCloudDataSource(unittest.TestCase): +class TestNoCloudDataSource(TestCase): def setUp(self): + super(TestNoCloudDataSource, self).setUp() self.tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp) self.paths = helpers.Paths({'cloud_dir': self.tmp}) @@ -34,8 +35,6 @@ class TestNoCloudDataSource(unittest.TestCase): self.mocks.enter_context( mock.patch.object(util, 'get_cmdline', return_value=self.cmdline)) - super(TestNoCloudDataSource, self).setUp() - def test_nocloud_seed_dir(self): md = {'instance-id': 'IID', 'dsmode': 'local'} ud = "USER_DATA_HERE" diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index ef534bab..e5a4bd18 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -1,7 +1,7 @@ from cloudinit import helpers from cloudinit.sources import DataSourceOpenNebula as ds from cloudinit import util -from ..helpers import populate_dir +from ..helpers import TestCase, populate_dir from base64 import b64encode import os @@ -46,7 +46,7 @@ CMD_IP_OUT = '''\ ''' -class TestOpenNebulaDataSource(unittest.TestCase): +class TestOpenNebulaDataSource(TestCase): parsed_user = None def setUp(self): diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 91e630ae..6d30c5b8 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -1,5 +1,4 @@ import os -import unittest try: from unittest import mock @@ -11,6 +10,7 @@ except ImportError: from contextlib2 import ExitStack from six import StringIO +from ..helpers import TestCase from cloudinit import distros from cloudinit import helpers @@ -80,7 +80,7 @@ class WriteBuffer(object): return self.buffer.getvalue() -class TestNetCfgDistro(unittest.TestCase): +class TestNetCfgDistro(TestCase): def _get_distro(self, dname): cls = distros.fetch(dname) diff --git a/tests/unittests/test_distros/test_resolv.py b/tests/unittests/test_distros/test_resolv.py index 779b83e3..faaf5b7f 100644 --- a/tests/unittests/test_distros/test_resolv.py +++ b/tests/unittests/test_distros/test_resolv.py @@ -1,7 +1,7 @@ from cloudinit.distros.parsers import resolv_conf import re -import unittest +from ..helpers import TestCase BASE_RESOLVE = ''' @@ -13,7 +13,7 @@ nameserver 10.15.30.92 BASE_RESOLVE = BASE_RESOLVE.strip() -class TestResolvHelper(unittest.TestCase): +class TestResolvHelper(TestCase): def test_parse_same(self): rp = resolv_conf.ResolvConf(BASE_RESOLVE) rp_r = str(rp).strip() diff --git a/tests/unittests/test_distros/test_sysconfig.py b/tests/unittests/test_distros/test_sysconfig.py index f66201b3..03d89a10 100644 --- a/tests/unittests/test_distros/test_sysconfig.py +++ b/tests/unittests/test_distros/test_sysconfig.py @@ -1,13 +1,13 @@ import re -import unittest from cloudinit.distros.parsers.sys_conf import SysConf +from ..helpers import TestCase # Lots of good examples @ # http://content.hccfl.edu/pollock/AUnix1/SysconfigFilesDesc.txt -class TestSysConfHelper(unittest.TestCase): +class TestSysConfHelper(TestCase): # This function was added in 2.7, make it work for 2.6 def assertRegMatches(self, text, regexp): regexp = re.compile(regexp) diff --git a/tests/unittests/test_distros/test_user_data_normalize.py b/tests/unittests/test_distros/test_user_data_normalize.py index b90d6185..e4488e2a 100644 --- a/tests/unittests/test_distros/test_user_data_normalize.py +++ b/tests/unittests/test_distros/test_user_data_normalize.py @@ -1,9 +1,10 @@ -import unittest - from cloudinit import distros from cloudinit import helpers from cloudinit import settings +from ..helpers import TestCase + + bcfg = { 'name': 'bob', 'plain_text_passwd': 'ubuntu', @@ -15,7 +16,7 @@ bcfg = { } -class TestUGNormalize(unittest.TestCase): +class TestUGNormalize(TestCase): def _make_distro(self, dtype, def_user=None): cfg = dict(settings.CFG_BUILTIN) diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index d72fa8c7..6bccff11 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -1,6 +1,7 @@ from cloudinit import util from cloudinit.config import cc_apt_configure +from ..helpers import TestCase import os import re @@ -9,7 +10,7 @@ import tempfile import unittest -class TestAptProxyConfig(unittest.TestCase): +class TestAptProxyConfig(TestCase): def setUp(self): super(TestAptProxyConfig, self).setUp() self.tmp = tempfile.mkdtemp() diff --git a/tests/unittests/test_handler/test_handler_ca_certs.py b/tests/unittests/test_handler/test_handler_ca_certs.py index 97213a0c..a6b9c0fd 100644 --- a/tests/unittests/test_handler/test_handler_ca_certs.py +++ b/tests/unittests/test_handler/test_handler_ca_certs.py @@ -3,6 +3,7 @@ from cloudinit import helpers from cloudinit import util from cloudinit.config import cc_ca_certs +from ..helpers import TestCase import logging import shutil @@ -45,7 +46,7 @@ class TestNoConfig(unittest.TestCase): self.assertEqual(certs_mock.call_count, 0) -class TestConfig(unittest.TestCase): +class TestConfig(TestCase): def setUp(self): super(TestConfig, self).setUp() self.name = "ca-certs" @@ -139,7 +140,7 @@ class TestConfig(unittest.TestCase): self.assertEqual(self.mock_remove.call_count, 1) -class TestAddCaCerts(unittest.TestCase): +class TestAddCaCerts(TestCase): def setUp(self): super(TestAddCaCerts, self).setUp() @@ -241,7 +242,7 @@ class TestUpdateCaCerts(unittest.TestCase): ["update-ca-certificates"], capture=False) -class TestRemoveDefaultCaCerts(unittest.TestCase): +class TestRemoveDefaultCaCerts(TestCase): def setUp(self): super(TestRemoveDefaultCaCerts, self).setUp() diff --git a/tests/unittests/test_handler/test_handler_growpart.py b/tests/unittests/test_handler/test_handler_growpart.py index 89727863..bef0d80d 100644 --- a/tests/unittests/test_handler/test_handler_growpart.py +++ b/tests/unittests/test_handler/test_handler_growpart.py @@ -2,6 +2,7 @@ from cloudinit import cloud from cloudinit import util from cloudinit.config import cc_growpart +from ..helpers import TestCase import errno import logging @@ -72,7 +73,7 @@ class TestDisabled(unittest.TestCase): self.assertEqual(mockobj.call_count, 0) -class TestConfig(unittest.TestCase): +class TestConfig(TestCase): def setUp(self): super(TestConfig, self).setUp() self.name = "growpart" diff --git a/tests/unittests/test_pathprefix2dict.py b/tests/unittests/test_pathprefix2dict.py index 38a56dc2..d38260e6 100644 --- a/tests/unittests/test_pathprefix2dict.py +++ b/tests/unittests/test_pathprefix2dict.py @@ -1,15 +1,15 @@ from cloudinit import util -from .helpers import populate_dir +from .helpers import TestCase, populate_dir import shutil import tempfile -import unittest -class TestPathPrefix2Dict(unittest.TestCase): +class TestPathPrefix2Dict(TestCase): def setUp(self): + TestCase.setUp(self) self.tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp) diff --git a/tests/unittests/test_templating.py b/tests/unittests/test_templating.py index 957467f6..fbad405f 100644 --- a/tests/unittests/test_templating.py +++ b/tests/unittests/test_templating.py @@ -16,6 +16,9 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import print_function + +import sys import six import unittest @@ -24,6 +27,20 @@ import textwrap from cloudinit import templater +try: + skipIf = unittest.skipIf +except AttributeError: + # Python 2.6. Doesn't have to be high fidelity. + def skipIf(condition, reason): + def decorator(func): + def wrapper(*args, **kws): + if condition: + return func(*args, **kws) + else: + print(reason, file=sys.stderr) + return wrapper + return decorator + class TestTemplates(test_helpers.TestCase): def test_render_basic(self): @@ -41,7 +58,7 @@ class TestTemplates(test_helpers.TestCase): out_data = templater.basic_render(in_data, {'b': 2}) self.assertEqual(expected_data.strip(), out_data) - @unittest.skipIf(six.PY3, 'Cheetah is not compatible with Python 3') + @skipIf(six.PY3, 'Cheetah is not compatible with Python 3') def test_detection(self): blob = "## template:cheetah" diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 7a224230..a1bd2c46 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -7,7 +7,6 @@ import shutil import tempfile from . import helpers -import unittest import six try: @@ -38,7 +37,7 @@ class FakeSelinux(object): self.restored.append(path) -class TestGetCfgOptionListOrStr(unittest.TestCase): +class TestGetCfgOptionListOrStr(helpers.TestCase): def test_not_found_no_default(self): """None is returned if key is not found and no default given.""" config = {} @@ -70,7 +69,7 @@ class TestGetCfgOptionListOrStr(unittest.TestCase): self.assertEqual([], result) -class TestWriteFile(unittest.TestCase): +class TestWriteFile(helpers.TestCase): def setUp(self): super(TestWriteFile, self).setUp() self.tmp = tempfile.mkdtemp() @@ -149,7 +148,7 @@ class TestWriteFile(unittest.TestCase): mockobj.assert_called_once_with('selinux') -class TestDeleteDirContents(unittest.TestCase): +class TestDeleteDirContents(helpers.TestCase): def setUp(self): super(TestDeleteDirContents, self).setUp() self.tmp = tempfile.mkdtemp() @@ -215,20 +214,20 @@ class TestDeleteDirContents(unittest.TestCase): self.assertDirEmpty(self.tmp) -class TestKeyValStrings(unittest.TestCase): +class TestKeyValStrings(helpers.TestCase): def test_keyval_str_to_dict(self): expected = {'1': 'one', '2': 'one+one', 'ro': True} cmdline = "1=one ro 2=one+one" self.assertEqual(expected, util.keyval_str_to_dict(cmdline)) -class TestGetCmdline(unittest.TestCase): +class TestGetCmdline(helpers.TestCase): def test_cmdline_reads_debug_env(self): os.environ['DEBUG_PROC_CMDLINE'] = 'abcd 123' self.assertEqual(os.environ['DEBUG_PROC_CMDLINE'], util.get_cmdline()) -class TestLoadYaml(unittest.TestCase): +class TestLoadYaml(helpers.TestCase): mydefault = "7b03a8ebace993d806255121073fed52" def test_simple(self): @@ -335,7 +334,7 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): self._patchIn(new_root) util.ensure_dir(os.path.join('sys', 'class', 'dmi', 'id')) - dmi_key = "/sys/class/dmi/id/{}".format(key) + dmi_key = "/sys/class/dmi/id/{0}".format(key) util.write_file(dmi_key, content) def _no_syspath(self, key, content): diff --git a/tox.ini b/tox.ini index e547c693..d04cd47c 100644 --- a/tox.ini +++ b/tox.ini @@ -11,3 +11,13 @@ deps = nose pep8==1.5.7 pyflakes + +[testenv:py26] +commands = nosetests tests +deps = + contextlib2 + httpretty>=0.7.1 + mock + nose + pep8==1.5.7 + pyflakes -- cgit v1.2.3 From e2cfd12dc79a6350d9f1f3111a15bdb6666ccad7 Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Mon, 26 Jan 2015 20:05:48 -0500 Subject: super() works in all of Python 2.6, 2.7, and 3.4. --- tests/unittests/helpers.py | 6 +++--- tests/unittests/test_data.py | 2 +- tests/unittests/test_datasource/test_smartos.py | 2 +- tests/unittests/test_pathprefix2dict.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 424d0626..f92e7ac2 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -43,7 +43,7 @@ if PY26: # all this once Python 2.6 is dropped as a minimum requirement. class TestCase(unittest.TestCase): def setUp(self): - unittest.TestCase.setUp(self) + super(TestCase, self).setUp() self.__all_cleanups = ExitStack() def tearDown(self): @@ -147,7 +147,7 @@ def retarget_many_wrapper(new_base, am, old_func): class ResourceUsingTestCase(TestCase): def setUp(self): - TestCase.setUp(self) + super(ResourceUsingTestCase, self).setUp() self.resource_path = None def resourceLocation(self, subname=None): @@ -186,7 +186,7 @@ class ResourceUsingTestCase(TestCase): class FilesystemMockingTestCase(ResourceUsingTestCase): def setUp(self): - ResourceUsingTestCase.setUp(self) + super(FilesystemMockingTestCase, self).setUp() self.patched_funcs = ExitStack() def tearDown(self): diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 7598837e..e900faa8 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -44,7 +44,7 @@ class FakeDataSource(sources.DataSource): class TestConsumeUserData(helpers.FilesystemMockingTestCase): def setUp(self): - helpers.FilesystemMockingTestCase.setUp(self) + super(TestConsumeUserData, self).setUp() self._log = None self._log_file = None self._log_handler = None diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 2fb9e1b6..b5ebf94d 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -118,7 +118,7 @@ class MockSerial(object): class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): def setUp(self): - helpers.FilesystemMockingTestCase.setUp(self) + super(TestSmartOSDataSource, self).setUp() self.tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp) diff --git a/tests/unittests/test_pathprefix2dict.py b/tests/unittests/test_pathprefix2dict.py index d38260e6..7089bde6 100644 --- a/tests/unittests/test_pathprefix2dict.py +++ b/tests/unittests/test_pathprefix2dict.py @@ -9,7 +9,7 @@ import tempfile class TestPathPrefix2Dict(TestCase): def setUp(self): - TestCase.setUp(self) + super(TestPathPrefix2Dict, self).setUp() self.tmp = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp) -- cgit v1.2.3 From 6e742d20e9ed56498925c7c850cd5da65d063b4b Mon Sep 17 00:00:00 2001 From: Barry Warsaw Date: Tue, 27 Jan 2015 15:03:52 -0500 Subject: Respond to review: - Refactor both the base64 encoding and decoding into utility functions. Also: - Mechanically fix some other broken untested code. --- cloudinit/config/cc_seed_random.py | 8 +------ cloudinit/config/cc_ssh_authkey_fingerprints.py | 2 +- cloudinit/sources/DataSourceOpenNebula.py | 7 +----- cloudinit/sources/DataSourceSmartOS.py | 11 +-------- cloudinit/util.py | 20 ++++++++++++++++ tests/unittests/test_datasource/test_azure.py | 28 ++++++++-------------- tests/unittests/test_datasource/test_opennebula.py | 11 ++------- tests/unittests/test_datasource/test_smartos.py | 14 ++++------- .../test_handler/test_handler_seed_random.py | 12 ++-------- 9 files changed, 42 insertions(+), 71 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_seed_random.py b/cloudinit/config/cc_seed_random.py index 981e1b08..bb64b0f5 100644 --- a/cloudinit/config/cc_seed_random.py +++ b/cloudinit/config/cc_seed_random.py @@ -38,13 +38,7 @@ def _decode(data, encoding=None): if not encoding or encoding.lower() in ['raw']: return data elif encoding.lower() in ['base64', 'b64']: - # Try to give us a native string in both Python 2 and 3, and remember - # that b64decode() returns bytes in Python 3. - decoded = base64.b64decode(data) - try: - return decoded.decode('utf-8') - except UnicodeDecodeError: - return decoded + return util.b64d(data) elif encoding.lower() in ['gzip', 'gz']: return util.decomp_gzip(data, quiet=False) else: diff --git a/cloudinit/config/cc_ssh_authkey_fingerprints.py b/cloudinit/config/cc_ssh_authkey_fingerprints.py index 51580633..6ce831bc 100644 --- a/cloudinit/config/cc_ssh_authkey_fingerprints.py +++ b/cloudinit/config/cc_ssh_authkey_fingerprints.py @@ -32,7 +32,7 @@ from cloudinit import util def _split_hash(bin_hash): split_up = [] - for i in xrange(0, len(bin_hash), 2): + for i in range(0, len(bin_hash), 2): split_up.append(bin_hash[i:i + 2]) return split_up diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index a0275cda..61709c1b 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -426,12 +426,7 @@ def read_context_disk_dir(source_dir, asuser=None): context.get('USER_DATA_ENCODING')) if encoding == "base64": try: - userdata = base64.b64decode(results['userdata']) - # In Python 3 we still expect a str, but b64decode will return - # bytes. Convert to str. - if isinstance(userdata, bytes): - userdata = userdata.decode('utf-8') - results['userdata'] = userdata + results['userdata'] = util.b64d(results['userdata']) except TypeError: LOG.warn("Failed base64 decoding of userdata") diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index f59ad3d6..9d48beab 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -351,16 +351,7 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, if b64: try: - # Generally, we want native strings in the values. Python 3's - # b64decode will return bytes though, so decode them to utf-8 if - # possible. If that fails, return the bytes. - decoded = base64.b64decode(resp) - try: - if isinstance(decoded, bytes): - return decoded.decode('utf-8') - except UnicodeDecodeError: - pass - return decoded + return util.b64d(resp) # Bogus input produces different errors in Python 2 and 3; catch both. except (TypeError, binascii.Error): LOG.warn("Failed base64 decoding key '%s'", noun) diff --git a/cloudinit/util.py b/cloudinit/util.py index 766f8e32..8916cc11 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -44,6 +44,7 @@ import sys import tempfile import time +from base64 import b64decode, b64encode from six.moves.urllib import parse as urlparse import six @@ -90,6 +91,25 @@ def encode_text(text, encoding='utf-8'): return text return text.encode(encoding) + +def b64d(source): + # Base64 decode some data, accepting bytes or unicode/str, and returning + # str/unicode if the result is utf-8 compatible, otherwise returning bytes. + decoded = b64decode(source) + if isinstance(decoded, bytes): + try: + return decoded.decode('utf-8') + except UnicodeDecodeError: + return decoded + +def b64e(source): + # Base64 encode some data, accepting bytes or unicode/str, and returning + # str/unicode if the result is utf-8 compatible, otherwise returning bytes. + if not isinstance(source, bytes): + source = source.encode('utf-8') + return b64encode(source).decode('utf-8') + + # Path for DMI Data DMI_SYS_PATH = "/sys/class/dmi/id" diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 97a53bee..965bce4b 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -1,5 +1,5 @@ from cloudinit import helpers -from cloudinit.util import load_file +from cloudinit.util import b64e, load_file from cloudinit.sources import DataSourceAzure from ..helpers import TestCase, populate_dir @@ -12,7 +12,6 @@ try: except ImportError: from contextlib2 import ExitStack -import base64 import crypt import os import stat @@ -22,13 +21,6 @@ import tempfile import unittest -def b64(source): - # In Python 3, b64encode only accepts bytes and returns bytes. - if not isinstance(source, bytes): - source = source.encode('utf-8') - return base64.b64encode(source).decode('us-ascii') - - def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): if data is None: data = {'HostName': 'FOOHOST'} @@ -58,7 +50,7 @@ def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): content += "<%s%s>%s\n" % (key, attrs, val, key) if userdata: - content += "%s\n" % (b64(userdata)) + content += "%s\n" % (b64e(userdata)) if pubkeys: content += "\n" @@ -189,7 +181,7 @@ class TestAzureDataSource(TestCase): # set dscfg in via base64 encoded yaml cfg = {'agent_command': "my_command"} odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': b64(yaml.dump(cfg)), + 'dscfg': {'text': b64e(yaml.dump(cfg)), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} @@ -241,7 +233,7 @@ class TestAzureDataSource(TestCase): def test_userdata_found(self): mydata = "FOOBAR" - odata = {'UserData': b64(mydata)} + odata = {'UserData': b64e(mydata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} dsrc = self._get_ds(data) @@ -289,7 +281,7 @@ class TestAzureDataSource(TestCase): 'command': 'my-bounce-command', 'hostname_command': 'my-hostname-command'}} odata = {'HostName': "xhost", - 'dscfg': {'text': b64(yaml.dump(cfg)), + 'dscfg': {'text': b64e(yaml.dump(cfg)), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} self._get_ds(data).get_data() @@ -304,7 +296,7 @@ class TestAzureDataSource(TestCase): # config specifying set_hostname off should not bounce cfg = {'set_hostname': False} odata = {'HostName': "xhost", - 'dscfg': {'text': b64(yaml.dump(cfg)), + 'dscfg': {'text': b64e(yaml.dump(cfg)), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} self._get_ds(data).get_data() @@ -333,7 +325,7 @@ class TestAzureDataSource(TestCase): # Make sure that user can affect disk aliases dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}} odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': b64(yaml.dump(dscfg)), + 'dscfg': {'text': b64e(yaml.dump(dscfg)), 'encoding': 'base64'}} usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'}, 'ephemeral0': False}} @@ -370,7 +362,7 @@ class TestAzureDataSource(TestCase): def test_existing_ovf_same(self): # waagent/SharedConfig left alone if found ovf-env.xml same as cached - odata = {'UserData': b64("SOMEUSERDATA")} + odata = {'UserData': b64e("SOMEUSERDATA")} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} populate_dir(self.waagent_d, @@ -394,9 +386,9 @@ class TestAzureDataSource(TestCase): # 'get_data' should remove SharedConfig.xml in /var/lib/waagent # if ovf-env.xml differs. cached_ovfenv = construct_valid_ovf_env( - {'userdata': b64("FOO_USERDATA")}) + {'userdata': b64e("FOO_USERDATA")}) new_ovfenv = construct_valid_ovf_env( - {'userdata': b64("NEW_USERDATA")}) + {'userdata': b64e("NEW_USERDATA")}) populate_dir(self.waagent_d, {'ovf-env.xml': cached_ovfenv, diff --git a/tests/unittests/test_datasource/test_opennebula.py b/tests/unittests/test_datasource/test_opennebula.py index e5a4bd18..27adf21b 100644 --- a/tests/unittests/test_datasource/test_opennebula.py +++ b/tests/unittests/test_datasource/test_opennebula.py @@ -3,19 +3,12 @@ from cloudinit.sources import DataSourceOpenNebula as ds from cloudinit import util from ..helpers import TestCase, populate_dir -from base64 import b64encode import os import pwd import shutil import tempfile import unittest -def b64(source): - # In Python 3, b64encode only accepts bytes and returns bytes. - if not isinstance(source, bytes): - source = source.encode('utf-8') - return b64encode(source).decode('us-ascii') - TEST_VARS = { 'VAR1': 'single', @@ -186,7 +179,7 @@ class TestOpenNebulaDataSource(TestCase): self.assertEqual(USER_DATA, results['userdata']) def test_user_data_encoding_required_for_decode(self): - b64userdata = b64(USER_DATA) + b64userdata = util.b64e(USER_DATA) for k in ('USER_DATA', 'USERDATA'): my_d = os.path.join(self.tmp, k) populate_context_dir(my_d, {k: b64userdata}) @@ -198,7 +191,7 @@ class TestOpenNebulaDataSource(TestCase): def test_user_data_base64_encoding(self): for k in ('USER_DATA', 'USERDATA'): my_d = os.path.join(self.tmp, k) - populate_context_dir(my_d, {k: b64(USER_DATA), + populate_context_dir(my_d, {k: util.b64e(USER_DATA), 'USERDATA_ENCODING': 'base64'}) results = ds.read_context_disk_dir(my_d) diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index b5ebf94d..8b62b1b1 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -24,9 +24,9 @@ from __future__ import print_function -import base64 from cloudinit import helpers as c_helpers from cloudinit.sources import DataSourceSmartOS +from cloudinit.util import b64e from .. import helpers import os import os.path @@ -36,12 +36,6 @@ import tempfile import stat import uuid -def b64(source): - # In Python 3, b64encode only accepts bytes and returns bytes. - if not isinstance(source, bytes): - source = source.encode('utf-8') - return base64.b64encode(source).decode('us-ascii') - MOCK_RETURNS = { 'hostname': 'test-host', @@ -239,7 +233,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): my_returns = MOCK_RETURNS.copy() my_returns['base64_all'] = "true" for k in ('hostname', 'cloud-init:user-data'): - my_returns[k] = b64(my_returns[k]) + my_returns[k] = b64e(my_returns[k]) dsrc = self._get_ds(mockdata=my_returns) ret = dsrc.get_data() @@ -260,7 +254,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): my_returns['b64-cloud-init:user-data'] = "true" my_returns['b64-hostname'] = "true" for k in ('hostname', 'cloud-init:user-data'): - my_returns[k] = b64(my_returns[k]) + my_returns[k] = b64e(my_returns[k]) dsrc = self._get_ds(mockdata=my_returns) ret = dsrc.get_data() @@ -276,7 +270,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): my_returns = MOCK_RETURNS.copy() my_returns['base64_keys'] = 'hostname,ignored' for k in ('hostname',): - my_returns[k] = b64(my_returns[k]) + my_returns[k] = b64e(my_returns[k]) dsrc = self._get_ds(mockdata=my_returns) ret = dsrc.get_data() diff --git a/tests/unittests/test_handler/test_handler_seed_random.py b/tests/unittests/test_handler/test_handler_seed_random.py index d3f18fa0..0bcdcb31 100644 --- a/tests/unittests/test_handler/test_handler_seed_random.py +++ b/tests/unittests/test_handler/test_handler_seed_random.py @@ -18,7 +18,6 @@ from cloudinit.config import cc_seed_random -import base64 import gzip import tempfile @@ -38,13 +37,6 @@ import logging LOG = logging.getLogger(__name__) -def b64(source): - # In Python 3, b64encode only accepts bytes and returns bytes. - if not isinstance(source, bytes): - source = source.encode('utf-8') - return base64.b64encode(source).decode('us-ascii') - - class TestRandomSeed(t_help.TestCase): def setUp(self): super(TestRandomSeed, self).setUp() @@ -141,7 +133,7 @@ class TestRandomSeed(t_help.TestCase): self.assertEquals("big-toe", contents) def test_append_random_base64(self): - data = b64('bubbles') + data = util.b64e('bubbles') cfg = { 'random_seed': { 'file': self._seed_file, @@ -154,7 +146,7 @@ class TestRandomSeed(t_help.TestCase): self.assertEquals("bubbles", contents) def test_append_random_b64(self): - data = b64('kit-kat') + data = util.b64e('kit-kat') cfg = { 'random_seed': { 'file': self._seed_file, -- cgit v1.2.3 From a4a6702758cf60ecb8742d78e576733dbbdbb9a0 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 10 Feb 2015 20:32:32 +0000 Subject: make bddeb work with python3 or python2 painful, and not perfect, but at this point the output builds on a vivid system python2 (bddeb --python2) or python3. * remove use of cheetah by bddeb in favor of builtin renderer * add '--python2' flag to bddeb and knowledge of python 2 and python3 package names. * read-dependencies can now read test-requirements also. * differenciate from build-requirements and runtime requirements. --- packages/bddeb | 102 ++++++++++++++++------ packages/debian/changelog.in | 2 +- packages/debian/control.in | 29 +++--- packages/debian/rules | 17 ---- packages/debian/rules.in | 19 ++++ test-requirements.txt | 3 +- tests/unittests/helpers.py | 16 ++++ tests/unittests/test_handler/test_handler_chef.py | 10 ++- tests/unittests/test_templating.py | 16 +--- tools/read-dependencies | 10 ++- 10 files changed, 140 insertions(+), 84 deletions(-) delete mode 100755 packages/debian/rules create mode 100755 packages/debian/rules.in (limited to 'tests/unittests') diff --git a/packages/bddeb b/packages/bddeb index 83ca68bb..c4efe264 100755 --- a/packages/bddeb +++ b/packages/bddeb @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/env python3 import os import shutil @@ -27,23 +27,35 @@ import argparse # Package names that will showup in requires to what we can actually # use in our debian 'control' file, this is a translation of the 'requires' # file pypi package name to a debian/ubuntu package name. -PKG_MP = { - 'argparse': 'python-argparse', - 'cheetah': 'python-cheetah', - 'configobj': 'python-configobj', - 'jinja2': 'python-jinja2', - 'jsonpatch': 'python-jsonpatch | python-json-patch', - 'oauth': 'python-oauth', - 'prettytable': 'python-prettytable', - 'pyserial': 'python-serial', - 'pyyaml': 'python-yaml', - 'requests': 'python-requests', - 'six': 'python-six', +STD_NAMED_PACKAGES = [ + 'configobj', + 'jinja2', + 'jsonpatch', + 'oauthlib', + 'prettytable', + 'requests', + 'six', + 'httpretty', + 'mock', + 'nose', + 'setuptools', +] +NONSTD_NAMED_PACKAGES = { + 'argparse': ('python-argparse', None), + 'contextlib2': ('python-contextlib2', None), + 'cheetah': ('python-cheetah', None), + 'pyserial': ('python-serial', 'python3-serial'), + 'pyyaml': ('python-yaml', 'python3-yaml'), + 'six': ('python-six', 'python3-six'), + 'pep8': ('pep8', 'python3-pep8'), + 'pyflakes': ('pyflakes', 'pyflakes'), } + DEBUILD_ARGS = ["-S", "-d"] -def write_debian_folder(root, version, revno, append_requires=[]): +def write_debian_folder(root, version, revno, pkgmap, + pyver="3", append_requires=[]): deb_dir = util.abs_join(root, 'debian') os.makedirs(deb_dir) @@ -59,25 +71,42 @@ def write_debian_folder(root, version, revno, append_requires=[]): # Write out the control file template cmd = [util.abs_join(find_root(), 'tools', 'read-dependencies')] (stdout, _stderr) = util.subp(cmd) - pkgs = [p.lower().strip() for p in stdout.splitlines()] + pypi_pkgs = [p.lower().strip() for p in stdout.splitlines()] + + (stdout, _stderr) = util.subp(cmd + ['test-requirements.txt']) + pypi_test_pkgs = [p.lower().strip() for p in stdout.splitlines()] # Map to known packages requires = append_requires - for p in pkgs: - tgt_pkg = PKG_MP.get(p) - if not tgt_pkg: - raise RuntimeError(("Do not know how to translate pypi dependency" - " %r to a known package") % (p)) - else: - requires.append(tgt_pkg) + test_requires = [] + lists = ((pypi_pkgs, requires), (pypi_test_pkgs, test_requires)) + for pypilist, target in lists: + for p in pypilist: + if p not in pkgmap: + raise RuntimeError(("Do not know how to translate pypi " + "dependency %r to a known package") % (p)) + elif pkgmap[p]: + target.append(pkgmap[p]) + + if pyver == "3": + python = "python3" + else: + python = "python" templater.render_to_file(util.abs_join(find_root(), 'packages', 'debian', 'control.in'), util.abs_join(deb_dir, 'control'), - params={'requires': requires}) + params={'requires': ','.join(requires), + 'test_requires': ','.join(test_requires), + 'python': python}) + + templater.render_to_file(util.abs_join(find_root(), + 'packages', 'debian', 'rules.in'), + util.abs_join(deb_dir, 'rules'), + params={'python': python, 'pyver': pyver}) # Just copy the following directly - for base_fn in ['dirs', 'copyright', 'compat', 'rules']: + for base_fn in ['dirs', 'copyright', 'compat']: shutil.copy(util.abs_join(find_root(), 'packages', 'debian', base_fn), util.abs_join(deb_dir, base_fn)) @@ -91,12 +120,16 @@ def main(): " (default: %(default)s)"), default=False, action='store_true') - parser.add_argument("--no-cloud-utils", dest="no_cloud_utils", - help=("don't depend on cloud-utils package" + parser.add_argument("--cloud-utils", dest="cloud_utils", + help=("depend on cloud-utils package" " (default: %(default)s)"), default=False, action='store_true') + parser.add_argument("--python2", dest="python2", + help=("build debs for python2 rather than python3"), + default=False, action='store_true') + parser.add_argument("--init-system", dest="init_system", help=("build deb with INIT_SYSTEM=xxx" " (default: %(default)s"), @@ -123,6 +156,18 @@ def main(): if args.verbose: capture = False + pkgmap = {} + for p in NONSTD_NAMED_PACKAGES: + pkgmap[p] = NONSTD_NAMED_PACKAGES[p][int(not args.python2)] + + for p in STD_NAMED_PACKAGES: + if args.python2: + pkgmap[p] = "python-" + p + pyver = "2" + else: + pkgmap[p] = "python3-" + p + pyver = "3" + with util.tempdir() as tdir: cmd = [util.abs_join(find_root(), 'tools', 'read-version')] @@ -153,11 +198,12 @@ def main(): shutil.move(extracted_name, xdir) print("Creating a debian/ folder in %r" % (xdir)) - if not args.no_cloud_utils: + if args.cloud_utils: append_requires=['cloud-utils | cloud-guest-utils'] else: append_requires=[] - write_debian_folder(xdir, version, revno, append_requires) + write_debian_folder(xdir, version, revno, pkgmap, + pyver=pyver, append_requires=append_requires) # The naming here seems to follow some debian standard # so it will whine if it is changed... diff --git a/packages/debian/changelog.in b/packages/debian/changelog.in index e3e94f54..c9affe47 100644 --- a/packages/debian/changelog.in +++ b/packages/debian/changelog.in @@ -1,4 +1,4 @@ -## This is a cheetah template +## template:basic cloud-init (${version}~bzr${revision}-1) UNRELEASED; urgency=low * build diff --git a/packages/debian/control.in b/packages/debian/control.in index 9207e5f4..bd6e3867 100644 --- a/packages/debian/control.in +++ b/packages/debian/control.in @@ -1,4 +1,4 @@ -## This is a cheetah template +## template:basic Source: cloud-init Section: admin Priority: optional @@ -6,31 +6,22 @@ Maintainer: Scott Moser Build-Depends: debhelper (>= 9), dh-python, dh-systemd, - python (>= 2.6.6-3~), - python-nose, pyflakes, - python-setuptools, - python-selinux, - python-cheetah, - python-mocker, - python-httpretty, -#for $r in $requires - ${r}, -#end for + ${python}, + ${test_requires}, + ${requires} XS-Python-Version: all -Standards-Version: 3.9.3 +Standards-Version: 3.9.6 Package: cloud-init Architecture: all Depends: procps, - python, -#for $r in $requires - ${r}, -#end for - python-software-properties | software-properties-common, - \${misc:Depends}, + ${python}, + ${requires}, + software-properties-common, + ${misc:Depends}, Recommends: sudo -XB-Python-Version: \${python:Versions} +XB-Python-Version: ${python:Versions} Description: Init scripts for cloud instances Cloud instances need special scripts to run during initialisation to retrieve and install ssh keys and to let the user run various scripts. diff --git a/packages/debian/rules b/packages/debian/rules deleted file mode 100755 index 9e0c5ddb..00000000 --- a/packages/debian/rules +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/make -f - -INIT_SYSTEM ?= upstart,systemd -export PYBUILD_INSTALL_ARGS=--init-system=$(INIT_SYSTEM) - -%: - dh $@ --with python2,systemd --buildsystem pybuild - -override_dh_install: - dh_install - install -d debian/cloud-init/etc/rsyslog.d - cp tools/21-cloudinit.conf debian/cloud-init/etc/rsyslog.d/21-cloudinit.conf - -override_dh_auto_test: - # Becuase setup tools didn't copy data... - cp -r tests/data .pybuild/pythonX.Y_2.7/build/tests - http_proxy= dh_auto_test -- --test-nose diff --git a/packages/debian/rules.in b/packages/debian/rules.in new file mode 100755 index 00000000..bb2e1d5c --- /dev/null +++ b/packages/debian/rules.in @@ -0,0 +1,19 @@ +## template:basic +#!/usr/bin/make -f + +INIT_SYSTEM ?= upstart,systemd +PYVER ?= python${pyver} +export PYBUILD_INSTALL_ARGS=--init-system=$(INIT_SYSTEM) + +%: + dh $@ --with $(PYVER),systemd --buildsystem pybuild + +override_dh_install: + dh_install + install -d debian/cloud-init/etc/rsyslog.d + cp tools/21-cloudinit.conf debian/cloud-init/etc/rsyslog.d/21-cloudinit.conf + +override_dh_auto_test: + # Because setup tools didn't copy data... + [ ! -d .pybuild/pythonX.Y_?.?/build/tests ] || cp -r tests/data .pybuild/pythonX.Y_?.?/build/tests + http_proxy= dh_auto_test -- --test-nose diff --git a/test-requirements.txt b/test-requirements.txt index 230f0404..9b3d07c5 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,6 +1,7 @@ httpretty>=0.7.1 mock -mocker nose pep8==1.5.7 pyflakes +contextlib2 +setuptools diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index f92e7ac2..6b9394b3 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -1,3 +1,5 @@ +from __future__ import print_function + import os import sys import shutil @@ -275,3 +277,17 @@ def populate_dir(path, files): with open(os.path.join(path, name), "w") as fp: fp.write(content) fp.close() + +try: + skipIf = unittest.skipIf +except AttributeError: + # Python 2.6. Doesn't have to be high fidelity. + def skipIf(condition, reason): + def decorator(func): + def wrapper(*args, **kws): + if condition: + return func(*args, **kws) + else: + print(reason, file=sys.stderr) + return wrapper + return decorator diff --git a/tests/unittests/test_handler/test_handler_chef.py b/tests/unittests/test_handler/test_handler_chef.py index 8ab27911..edad88cb 100644 --- a/tests/unittests/test_handler/test_handler_chef.py +++ b/tests/unittests/test_handler/test_handler_chef.py @@ -18,6 +18,8 @@ import tempfile LOG = logging.getLogger(__name__) +CLIENT_TEMPL = os.path.sep.join(["templates", "chef_client.rb.tmpl"]) + class TestChef(t_help.FilesystemMockingTestCase): def setUp(self): @@ -41,9 +43,13 @@ class TestChef(t_help.FilesystemMockingTestCase): for d in cc_chef.CHEF_DIRS: self.assertFalse(os.path.isdir(d)) + @t_help.skipIf(not os.path.isfile(CLIENT_TEMPL), + CLIENT_TEMPL + " is not available") def test_basic_config(self): - # This should create a file of the format... """ + test basic config looks sane + + # This should create a file of the format... # Created by cloud-init v. 0.7.6 on Sat, 11 Oct 2014 23:57:21 +0000 log_level :info ssl_verify_mode :verify_none @@ -105,6 +111,8 @@ class TestChef(t_help.FilesystemMockingTestCase): 'c': 'd', }, json.loads(c)) + @t_help.skipIf(not os.path.isfile(CLIENT_TEMPL), + CLIENT_TEMPL + " is not available") def test_template_deletes(self): tpl_file = util.load_file('templates/chef_client.rb.tmpl') self.patchUtils(self.tmp) diff --git a/tests/unittests/test_templating.py b/tests/unittests/test_templating.py index fbad405f..2b821150 100644 --- a/tests/unittests/test_templating.py +++ b/tests/unittests/test_templating.py @@ -27,20 +27,6 @@ import textwrap from cloudinit import templater -try: - skipIf = unittest.skipIf -except AttributeError: - # Python 2.6. Doesn't have to be high fidelity. - def skipIf(condition, reason): - def decorator(func): - def wrapper(*args, **kws): - if condition: - return func(*args, **kws) - else: - print(reason, file=sys.stderr) - return wrapper - return decorator - class TestTemplates(test_helpers.TestCase): def test_render_basic(self): @@ -58,7 +44,7 @@ class TestTemplates(test_helpers.TestCase): out_data = templater.basic_render(in_data, {'b': 2}) self.assertEqual(expected_data.strip(), out_data) - @skipIf(six.PY3, 'Cheetah is not compatible with Python 3') + @test_helpers.skipIf(six.PY3, 'Cheetah is not compatible with Python 3') def test_detection(self): blob = "## template:cheetah" diff --git a/tools/read-dependencies b/tools/read-dependencies index fee3efcf..6a6f3e12 100755 --- a/tools/read-dependencies +++ b/tools/read-dependencies @@ -1,6 +1,7 @@ #!/usr/bin/env python import os +import re import sys if 'CLOUD_INIT_TOP_D' in os.environ: @@ -14,10 +15,15 @@ for fname in ("setup.py", "requirements.txt"): "exist in cloud-init root directory." % fname) sys.exit(1) -with open(os.path.join(topd, "requirements.txt"), "r") as fp: +if len(sys.argv) > 1: + reqfile = sys.argv[1] +else: + reqfile = "requirements.txt" + +with open(os.path.join(topd, reqfile), "r") as fp: for line in fp: if not line.strip() or line.startswith("#"): continue - sys.stdout.write(line) + sys.stdout.write(re.split("[>=.<]*", line)[0].strip() + "\n") sys.exit(0) -- cgit v1.2.3 From 4b578b2d1fbc6af141081c457ffc74c811e788f3 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 10 Feb 2015 15:46:48 -0500 Subject: skip cheetah dependent test --- tests/unittests/test_templating.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_templating.py b/tests/unittests/test_templating.py index 2b821150..cf7c03b0 100644 --- a/tests/unittests/test_templating.py +++ b/tests/unittests/test_templating.py @@ -27,6 +27,12 @@ import textwrap from cloudinit import templater +try: + import Cheetah + HAS_CHEETAH = True +except ImportError: + HAS_CHEETAH = False + class TestTemplates(test_helpers.TestCase): def test_render_basic(self): @@ -44,7 +50,7 @@ class TestTemplates(test_helpers.TestCase): out_data = templater.basic_render(in_data, {'b': 2}) self.assertEqual(expected_data.strip(), out_data) - @test_helpers.skipIf(six.PY3, 'Cheetah is not compatible with Python 3') + @test_helpers.skipIf(not HAS_CHEETAH, 'cheetah renderer not available') def test_detection(self): blob = "## template:cheetah" -- cgit v1.2.3 From db66b4a42e24dc83d9316df14485c8413ac94abe Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 10 Feb 2015 20:50:23 -0500 Subject: pep8 --- cloudinit/util.py | 3 ++- tests/unittests/test_cs_util.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index c998154a..b845adfd 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -101,6 +101,7 @@ def b64d(source): except UnicodeDecodeError: return decoded + def b64e(source): # Base64 encode some data, accepting bytes or unicode/str, and returning # str/unicode if the result is utf-8 compatible, otherwise returning bytes. @@ -116,7 +117,7 @@ def fully_decoded_payload(part): # bytes, first try to decode to str via CT charset, and failing that, try # utf-8 using surrogate escapes. cte_payload = part.get_payload(decode=True) - if ( six.PY3 and + if (six.PY3 and part.get_content_maintype() == 'text' and isinstance(cte_payload, bytes)): charset = part.get_charset() or 'utf-8' diff --git a/tests/unittests/test_cs_util.py b/tests/unittests/test_cs_util.py index 337ac9a0..d7273035 100644 --- a/tests/unittests/test_cs_util.py +++ b/tests/unittests/test_cs_util.py @@ -47,14 +47,14 @@ class CepkoMock(Cepko): class CepkoResultTests(unittest.TestCase): def setUp(self): pass - ## self.mocked = self.mocker.replace("cloudinit.cs_utils.Cepko", - ## spec=CepkoMock, - ## count=False, - ## passthrough=False) - ## self.mocked() - ## self.mocker.result(CepkoMock()) - ## self.mocker.replay() - ## self.c = Cepko() + # self.mocked = self.mocker.replace("cloudinit.cs_utils.Cepko", + # spec=CepkoMock, + # count=False, + # passthrough=False) + # self.mocked() + # self.mocker.result(CepkoMock()) + # self.mocker.replay() + # self.c = Cepko() def test_getitem(self): result = self.c.all() -- cgit v1.2.3 From f5f280cae778bd214b91664f28d9eed997fbcda5 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 10 Feb 2015 20:51:59 -0500 Subject: pep8 --- tests/unittests/test_data.py | 1 - tests/unittests/test_datasource/test_configdrive.py | 1 + tests/unittests/test_handler/test_handler_apt_configure.py | 3 --- 3 files changed, 1 insertion(+), 4 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index e900faa8..e5b227f8 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -359,7 +359,6 @@ p: 1 mockobj.assert_called_once_with( ci.paths.get_ipath("cloud_config"), "", 0o600) - def test_mime_gzip_compressed(self): """Tests that individual message gzip encoding works.""" diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index fd930877..e28bdd84 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -110,6 +110,7 @@ class TestConfigDriveDataSource(TestCase): # and True on its second call. We use a handy generator as # the mock side effect for this. The mocked function returns # what the side effect returns. + def exists_side_effect(): yield False yield True diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index 6bccff11..d8fe9a4f 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -19,9 +19,6 @@ class TestAptProxyConfig(TestCase): self.cfile = os.path.join(self.tmp, "config.cfg") def _search_apt_config(self, contents, ptype, value): - ## print( - ## r"acquire::%s::proxy\s+[\"']%s[\"'];\n" % (ptype, value), - ## contents, "flags=re.IGNORECASE") return re.search( r"acquire::%s::proxy\s+[\"']%s[\"'];\n" % (ptype, value), contents, flags=re.IGNORECASE) -- cgit v1.2.3 From 4cfdde8be624f5dc9a9ec214ea60f9d1f43ee424 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 11 Feb 2015 11:54:20 +0000 Subject: Fix import ordering in test_util.py. --- tests/unittests/test_util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index a1bd2c46..23821521 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -1,21 +1,21 @@ from __future__ import print_function import os -import stat -import yaml import shutil +import stat import tempfile -from . import helpers import six +import yaml + +from cloudinit import importer, util +from . import helpers try: from unittest import mock except ImportError: import mock -from cloudinit import importer -from cloudinit import util class FakeSelinux(object): -- cgit v1.2.3 From abccec1150d6fada29eae9819968e3d4419440ab Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 11 Feb 2015 11:54:48 +0000 Subject: Add helpers for patching open and stdout/stderr. --- tests/unittests/helpers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 6b9394b3..ce77af93 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -254,6 +254,19 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): self.patched_funcs.enter_context( mock.patch.object(mod, f, trap_func)) + def patchOpen(self, new_root): + trap_func = retarget_many_wrapper(new_root, 1, open) + name = 'builtins.open' if PY3 else '__builtin__.open' + self.patched_funcs.enter_context(mock.patch(name, trap_func)) + + def patchStdoutAndStderr(self, stdout=None, stderr=None): + if stdout is not None: + self.patched_funcs.enter_context( + mock.patch.object(sys, 'stdout', stdout)) + if stderr is not None: + self.patched_funcs.enter_context( + mock.patch.object(sys, 'stderr', stderr)) + class HttprettyTestCase(TestCase): # necessary as http_proxy gets in the way of httpretty -- cgit v1.2.3 From 62f5f10c2572585b89f5c837b84015a86e6af357 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 11 Feb 2015 11:55:06 +0000 Subject: Add unittests for util.multi_log. --- tests/unittests/test_util.py | 68 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) (limited to 'tests/unittests') diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 23821521..2772ce58 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -1,5 +1,6 @@ from __future__ import print_function +import logging import os import shutil import stat @@ -377,4 +378,71 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): self.assertFalse(None, util.read_dmi_data("key")) +class TestMultiLog(helpers.FilesystemMockingTestCase): + + def _createConsole(self, root): + os.mkdir(os.path.join(root, 'dev')) + open(os.path.join(root, 'dev', 'console'), 'a').close() + + def setUp(self): + super(TestMultiLog, self).setUp() + self.root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.root) + self.patchOS(self.root) + self.patchUtils(self.root) + self.patchOpen(self.root) + self.stdout = six.StringIO() + self.stderr = six.StringIO() + self.patchStdoutAndStderr(self.stdout, self.stderr) + + def test_stderr_used_by_default(self): + logged_string = 'test stderr output' + util.multi_log(logged_string) + self.assertEqual(logged_string, self.stderr.getvalue()) + + def test_stderr_not_used_if_false(self): + util.multi_log('should not see this', stderr=False) + self.assertEqual('', self.stderr.getvalue()) + + def test_logs_go_to_console_by_default(self): + self._createConsole(self.root) + logged_string = 'something very important' + util.multi_log(logged_string) + self.assertEqual(logged_string, open('/dev/console').read()) + + def test_logs_dont_go_to_stdout_if_console_exists(self): + self._createConsole(self.root) + util.multi_log('something') + self.assertEqual('', self.stdout.getvalue()) + + def test_logs_go_to_stdout_if_console_does_not_exist(self): + logged_string = 'something very important' + util.multi_log(logged_string) + self.assertEqual(logged_string, self.stdout.getvalue()) + + def test_logs_go_to_log_if_given(self): + log = mock.MagicMock() + logged_string = 'something very important' + util.multi_log(logged_string, log=log) + self.assertEqual([((mock.ANY, logged_string), {})], + log.log.call_args_list) + + def test_newlines_stripped_from_log_call(self): + log = mock.MagicMock() + expected_string = 'something very important' + util.multi_log('{0}\n'.format(expected_string), log=log) + self.assertEqual((mock.ANY, expected_string), log.log.call_args[0]) + + def test_log_level_defaults_to_debug(self): + log = mock.MagicMock() + util.multi_log('message', log=log) + self.assertEqual((logging.DEBUG, mock.ANY), log.log.call_args[0]) + + def test_given_log_level_used(self): + log = mock.MagicMock() + log_level = mock.Mock() + util.multi_log('message', log=log, log_level=log_level) + self.assertEqual((log_level, mock.ANY), log.log.call_args[0]) + + # vi: ts=4 expandtab -- cgit v1.2.3 From d5c93d393ed41b33522c2d6f12465b1f12d28e95 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Wed, 11 Feb 2015 12:10:34 -0500 Subject: pep8 --- tests/unittests/test_util.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 2772ce58..b96da663 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -444,5 +444,4 @@ class TestMultiLog(helpers.FilesystemMockingTestCase): util.multi_log('message', log=log, log_level=log_level) self.assertEqual((log_level, mock.ANY), log.log.call_args[0]) - # vi: ts=4 expandtab -- cgit v1.2.3 From 9dde8e6cc29d789e341921ec46b01046f51258f8 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 13 Feb 2015 15:49:40 -0500 Subject: tests/unittests/test_util.py: pep8 --- tests/unittests/test_util.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index b96da663..33c191a9 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -18,7 +18,6 @@ except ImportError: import mock - class FakeSelinux(object): def __init__(self, match_what): -- cgit v1.2.3 From 589ced475c9e200d4645f0b06f7846dae412b194 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 18 Feb 2015 13:30:51 +0000 Subject: Read ovf-env.xml as bytes. This should fix the Azure data source on Python 3, and is appropriate as XML shouldn't really be read as a string. --- cloudinit/sources/DataSourceAzure.py | 4 ++-- tests/unittests/helpers.py | 5 +++-- tests/unittests/test_datasource/test_azure.py | 6 ++++++ 3 files changed, 11 insertions(+), 4 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 444070bb..6e030217 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -137,7 +137,7 @@ class DataSourceAzureNet(sources.DataSource): if found != ddir: cached_ovfenv = util.load_file( - os.path.join(ddir, 'ovf-env.xml'), quiet=True) + os.path.join(ddir, 'ovf-env.xml'), quiet=True, decode=False) if cached_ovfenv != files['ovf-env.xml']: # source was not walinux-agent's datadir, so we have to clean # up so 'wait_for_files' doesn't return early due to stale data @@ -593,7 +593,7 @@ def load_azure_ds_dir(source_dir): if not os.path.isfile(ovf_file): raise NonAzureDataSource("No ovf-env file found") - with open(ovf_file, "r") as fp: + with open(ovf_file, "rb") as fp: contents = fp.read() md, ud, cfg = read_azure_ovf(contents) diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index ce77af93..7516bd02 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -287,10 +287,11 @@ def populate_dir(path, files): if not os.path.exists(path): os.makedirs(path) for (name, content) in files.items(): - with open(os.path.join(path, name), "w") as fp: - fp.write(content) + with open(os.path.join(path, name), "wb") as fp: + fp.write(content.encode('utf-8')) fp.close() + try: skipIf = unittest.skipIf except AttributeError: diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 965bce4b..38d70fcd 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -360,6 +360,12 @@ class TestAzureDataSource(TestCase): self.assertTrue(os.path.exists(ovf_env_path)) self.assertEqual(xml, load_file(ovf_env_path)) + def test_ovf_can_include_unicode(self): + xml = construct_valid_ovf_env(data={}) + xml = u'\ufeff{0}'.format(xml) + dsrc = self._get_ds({'ovfcontent': xml}) + dsrc.get_data() + def test_existing_ovf_same(self): # waagent/SharedConfig left alone if found ovf-env.xml same as cached odata = {'UserData': b64e("SOMEUSERDATA")} -- cgit v1.2.3 From 385b48ae2b82b9c267710d41b9f470104ea248b7 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 20 Feb 2015 10:30:36 +0000 Subject: Add automated tests for CloudStack passwords. --- tests/unittests/test_datasource/test_cloudstack.py | 86 ++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 tests/unittests/test_datasource/test_cloudstack.py (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_cloudstack.py b/tests/unittests/test_datasource/test_cloudstack.py new file mode 100644 index 00000000..959d78ae --- /dev/null +++ b/tests/unittests/test_datasource/test_cloudstack.py @@ -0,0 +1,86 @@ +from cloudinit import helpers +from cloudinit.sources.DataSourceCloudStack import DataSourceCloudStack +from ..helpers import TestCase + +try: + from unittest import mock +except ImportError: + import mock +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + + +class TestCloudStackPasswordFetching(TestCase): + + def setUp(self): + super(TestCloudStackPasswordFetching, self).setUp() + self.patches = ExitStack() + self.addCleanup(self.patches.close) + mod_name = 'cloudinit.sources.DataSourceCloudStack' + self.patches.enter_context(mock.patch('{0}.ec2'.format(mod_name))) + self.patches.enter_context(mock.patch('{0}.uhelp'.format(mod_name))) + + def _set_password_server_response(self, response_string): + http_client = mock.MagicMock() + http_client.HTTPConnection.return_value.sock.recv.return_value = \ + response_string.encode('utf-8') + self.patches.enter_context( + mock.patch('cloudinit.sources.DataSourceCloudStack.http_client', + http_client)) + return http_client + + def test_empty_password_doesnt_create_config(self): + self._set_password_server_response('') + ds = DataSourceCloudStack({}, None, helpers.Paths({})) + ds.get_data() + self.assertEqual({}, ds.get_config_obj()) + + def test_saved_password_doesnt_create_config(self): + self._set_password_server_response('saved_password') + ds = DataSourceCloudStack({}, None, helpers.Paths({})) + ds.get_data() + self.assertEqual({}, ds.get_config_obj()) + + def test_password_sets_password(self): + password = 'SekritSquirrel' + self._set_password_server_response(password) + ds = DataSourceCloudStack({}, None, helpers.Paths({})) + ds.get_data() + self.assertEqual(password, ds.get_config_obj()['password']) + + def test_bad_request_doesnt_stop_ds_from_working(self): + self._set_password_server_response('bad_request') + ds = DataSourceCloudStack({}, None, helpers.Paths({})) + self.assertTrue(ds.get_data()) + + def assertRequestTypesSent(self, http_client, expected_request_types): + request_types = [ + kwargs['headers']['DomU_Request'] + for _, kwargs + in http_client.HTTPConnection.return_value.request.call_args_list] + self.assertEqual(expected_request_types, request_types) + + def test_valid_response_means_password_marked_as_saved(self): + password = 'SekritSquirrel' + http_client = self._set_password_server_response(password) + ds = DataSourceCloudStack({}, None, helpers.Paths({})) + ds.get_data() + self.assertRequestTypesSent(http_client, + ['send_my_password', 'saved_password']) + + def _check_password_not_saved_for(self, response_string): + http_client = self._set_password_server_response(response_string) + ds = DataSourceCloudStack({}, None, helpers.Paths({})) + ds.get_data() + self.assertRequestTypesSent(http_client, ['send_my_password']) + + def test_password_not_saved_if_empty(self): + self._check_password_not_saved_for('') + + def test_password_not_saved_if_already_saved(self): + self._check_password_not_saved_for('saved_password') + + def test_password_not_saved_if_bad_request(self): + self._check_password_not_saved_for('bad_request') -- cgit v1.2.3 From f1ee9275a504c20153b795923b1f51d3005d745c Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 24 Feb 2015 11:58:22 -0500 Subject: use util.decode_binary rather than str, add tests. just seems to make more sense to decode here. Add a test showing the previous failure (testBytesInPayload) And one that should pass (testStringInPayload) Also, add a test for unencoded content in the ovf xml (test_userdata_plain) And explicitly set encoding on another test (test_userdata_found). --- cloudinit/user_data.py | 4 ++-- tests/unittests/test_datasource/test_azure.py | 14 +++++++++++-- tests/unittests/test_udprocess.py | 30 +++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 tests/unittests/test_udprocess.py (limited to 'tests/unittests') diff --git a/cloudinit/user_data.py b/cloudinit/user_data.py index 8fd7fba5..b11894ce 100644 --- a/cloudinit/user_data.py +++ b/cloudinit/user_data.py @@ -336,8 +336,8 @@ def convert_string(raw_data, headers=None): raw_data = '' if not headers: headers = {} - data = util.decomp_gzip(raw_data) - if "mime-version:" in str(data[0:4096]).lower(): + data = util.decode_binary(util.decomp_gzip(raw_data)) + if "mime-version:" in data[0:4096].lower(): msg = email.message_from_string(data) for (key, val) in headers.items(): _replace_header(msg, key, val) diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 38d70fcd..8112c69b 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -1,5 +1,5 @@ from cloudinit import helpers -from cloudinit.util import b64e, load_file +from cloudinit.util import b64e, decode_binary, load_file from cloudinit.sources import DataSourceAzure from ..helpers import TestCase, populate_dir @@ -231,9 +231,19 @@ class TestAzureDataSource(TestCase): self.assertEqual(defuser['passwd'], crypt.crypt(odata['UserPassword'], defuser['passwd'][0:pos])) + def test_userdata_plain(self): + mydata = "FOOBAR" + odata = {'UserData': {'text': mydata, 'encoding': 'plain'}} + data = {'ovfcontent': construct_valid_ovf_env(data=odata)} + + dsrc = self._get_ds(data) + ret = dsrc.get_data() + self.assertTrue(ret) + self.assertEqual(decode_binary(dsrc.userdata_raw), mydata) + def test_userdata_found(self): mydata = "FOOBAR" - odata = {'UserData': b64e(mydata)} + odata = {'UserData': {'text': b64e(mydata), 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} dsrc = self._get_ds(data) diff --git a/tests/unittests/test_udprocess.py b/tests/unittests/test_udprocess.py new file mode 100644 index 00000000..39adbf9d --- /dev/null +++ b/tests/unittests/test_udprocess.py @@ -0,0 +1,30 @@ +from . import helpers + +from six.moves import filterfalse + +from cloudinit import user_data as ud +from cloudinit import util + +def count_messages(root): + am = 0 + for m in root.walk(): + if ud.is_skippable(m): + continue + am += 1 + return am + + +class TestUDProcess(helpers.ResourceUsingTestCase): + + def testBytesInPayload(self): + msg = b'#cloud-config\napt_update: True\n' + ud_proc = ud.UserDataProcessor(self.getCloudPaths()) + message = ud_proc.process(msg) + self.assertTrue(count_messages(message) == 1) + + def testStringInPayload(self): + msg = '#cloud-config\napt_update: True\n' + + ud_proc = ud.UserDataProcessor(self.getCloudPaths()) + message = ud_proc.process(msg) + self.assertTrue(count_messages(message) == 1) -- cgit v1.2.3 From 4770c61cf1a114348c763094c4cfd3d007894b0d Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 24 Feb 2015 16:09:42 -0500 Subject: move recently added test_udprocess tests to test_data, improve a bit explicitly test compressed user-data. --- tests/unittests/test_data.py | 52 +++++++++++++++++++++++++++++++++------ tests/unittests/test_udprocess.py | 30 ---------------------- 2 files changed, 45 insertions(+), 37 deletions(-) delete mode 100644 tests/unittests/test_udprocess.py (limited to 'tests/unittests') diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index e5b227f8..48475515 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -23,6 +23,7 @@ from cloudinit import log from cloudinit.settings import (PER_INSTANCE) from cloudinit import sources from cloudinit import stages +from cloudinit import user_data as ud from cloudinit import util INSTANCE_ID = "i-testing" @@ -39,6 +40,25 @@ class FakeDataSource(sources.DataSource): self.vendordata_raw = vendordata +def count_messages(root): + am = 0 + for m in root.walk(): + if ud.is_skippable(m): + continue + am += 1 + return am + + +def gzip_text(text): + contents = BytesIO() + f = gzip.GzipFile(fileobj=contents, mode='wb') + f.write(util.encode_text(text)) + f.flush() + f.close() + return contents.getvalue() + + + # FIXME: these tests shouldn't be checking log output?? # Weirddddd... class TestConsumeUserData(helpers.FilesystemMockingTestCase): @@ -363,12 +383,7 @@ p: 1 """Tests that individual message gzip encoding works.""" def gzip_part(text): - contents = BytesIO() - f = gzip.GzipFile(fileobj=contents, mode='wb') - f.write(util.encode_text(text)) - f.flush() - f.close() - return MIMEApplication(contents.getvalue(), 'gzip') + return MIMEApplication(gzip_text(text), 'gzip') base_content1 = ''' #cloud-config @@ -405,7 +420,7 @@ c: 4 ci = stages.Init() message = MIMEBase("text", "plain") message.set_payload("Just text") - ci.datasource = FakeDataSource(message.as_string()) + ci.datasource = FakeDataSource(message.as_string().encode()) with mock.patch('cloudinit.util.write_file') as mockobj: log_file = self.capture_log(logging.WARNING) @@ -477,3 +492,26 @@ c: 4 mock.call(outpath, script, 0o700), mock.call(ci.paths.get_ipath("cloud_config"), "", 0o600), ]) + + +class TestUDProcess(helpers.ResourceUsingTestCase): + + def test_bytes_in_userdata(self): + msg = b'#cloud-config\napt_update: True\n' + ud_proc = ud.UserDataProcessor(self.getCloudPaths()) + message = ud_proc.process(msg) + self.assertTrue(count_messages(message) == 1) + + def test_string_in_userdata(self): + msg = '#cloud-config\napt_update: True\n' + + ud_proc = ud.UserDataProcessor(self.getCloudPaths()) + message = ud_proc.process(msg) + self.assertTrue(count_messages(message) == 1) + + def test_compressed_in_userdata(self): + msg = gzip_text('#cloud-config\napt_update: True\n') + + ud_proc = ud.UserDataProcessor(self.getCloudPaths()) + message = ud_proc.process(msg) + self.assertTrue(count_messages(message) == 1) diff --git a/tests/unittests/test_udprocess.py b/tests/unittests/test_udprocess.py deleted file mode 100644 index 39adbf9d..00000000 --- a/tests/unittests/test_udprocess.py +++ /dev/null @@ -1,30 +0,0 @@ -from . import helpers - -from six.moves import filterfalse - -from cloudinit import user_data as ud -from cloudinit import util - -def count_messages(root): - am = 0 - for m in root.walk(): - if ud.is_skippable(m): - continue - am += 1 - return am - - -class TestUDProcess(helpers.ResourceUsingTestCase): - - def testBytesInPayload(self): - msg = b'#cloud-config\napt_update: True\n' - ud_proc = ud.UserDataProcessor(self.getCloudPaths()) - message = ud_proc.process(msg) - self.assertTrue(count_messages(message) == 1) - - def testStringInPayload(self): - msg = '#cloud-config\napt_update: True\n' - - ud_proc = ud.UserDataProcessor(self.getCloudPaths()) - message = ud_proc.process(msg) - self.assertTrue(count_messages(message) == 1) -- cgit v1.2.3 From 8cd5d7b143f882d80d45b1c04bdde1949846d4f1 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Wed, 25 Feb 2015 19:40:33 -0500 Subject: move towards user-data being binary UrlResponse: biggest change... make readurl return bytes, making user know what to do with it. util: add load_tfile_or_url for loading text file or url as read_file_or_url now returns bytes ec2_utils: all meta-data is text, remove non-obvious string translations DigitalOcean: adjust for ec2_utils DataSourceGCE, DataSourceMAAS: user-data is binary other fields are text. openstack.py: read paths without decoding to text. This is ok as paths other than user-data are json, and load_json will handle load_file still returns text, and that is what most things use. --- cloudinit/ec2_utils.py | 14 +++++++++++--- cloudinit/sources/DataSourceDigitalOcean.py | 8 ++++++-- cloudinit/sources/DataSourceGCE.py | 21 ++++++++++++--------- cloudinit/sources/DataSourceMAAS.py | 14 +++++++++++--- cloudinit/sources/helpers/openstack.py | 2 +- cloudinit/url_helper.py | 2 +- cloudinit/util.py | 11 ++++++++--- tests/unittests/helpers.py | 5 ++++- tests/unittests/test_datasource/test_configdrive.py | 15 ++++++++++----- tests/unittests/test_datasource/test_gce.py | 2 +- tests/unittests/test_datasource/test_maas.py | 8 ++++---- tests/unittests/test_datasource/test_nocloud.py | 14 +++++++------- tests/unittests/test_datasource/test_openstack.py | 6 +++--- tests/unittests/test_ec2_util.py | 2 +- .../test_handler/test_handler_apt_configure.py | 12 ++++++------ tests/unittests/test_pathprefix2dict.py | 10 +++++----- 16 files changed, 91 insertions(+), 55 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/ec2_utils.py b/cloudinit/ec2_utils.py index e1ed4091..7cf99186 100644 --- a/cloudinit/ec2_utils.py +++ b/cloudinit/ec2_utils.py @@ -41,6 +41,10 @@ class MetadataLeafDecoder(object): def __call__(self, field, blob): if not blob: return blob + try: + blob = util.decode_binary(blob) + except UnicodeDecodeError: + return blob if self._maybe_json_object(blob): try: # Assume it's json, unless it fails parsing... @@ -69,6 +73,8 @@ class MetadataMaterializer(object): def _parse(self, blob): leaves = {} children = [] + blob = util.decode_binary(blob) + if not blob: return (leaves, children) @@ -117,12 +123,12 @@ class MetadataMaterializer(object): child_url = url_helper.combine_url(base_url, c) if not child_url.endswith("/"): child_url += "/" - child_blob = str(self._caller(child_url)) + child_blob = self._caller(child_url) child_contents[c] = self._materialize(child_blob, child_url) leaf_contents = {} for (field, resource) in leaves.items(): leaf_url = url_helper.combine_url(base_url, resource) - leaf_blob = self._caller(leaf_url).contents + leaf_blob = self._caller(leaf_url) leaf_contents[field] = self._leaf_decoder(field, leaf_blob) joined = {} joined.update(child_contents) @@ -179,11 +185,13 @@ def get_instance_metadata(api_version='latest', caller = functools.partial(util.read_file_or_url, ssl_details=ssl_details, timeout=timeout, retries=retries) + def mcaller(url): + return caller(url).contents try: response = caller(md_url) materializer = MetadataMaterializer(response.contents, - md_url, caller, + md_url, mcaller, leaf_decoder=leaf_decoder) md = materializer.materialize() if not isinstance(md, (dict)): diff --git a/cloudinit/sources/DataSourceDigitalOcean.py b/cloudinit/sources/DataSourceDigitalOcean.py index 76ddaa9d..5d47564d 100644 --- a/cloudinit/sources/DataSourceDigitalOcean.py +++ b/cloudinit/sources/DataSourceDigitalOcean.py @@ -54,9 +54,13 @@ class DataSourceDigitalOcean(sources.DataSource): def get_data(self): caller = functools.partial(util.read_file_or_url, timeout=self.timeout, retries=self.retries) - md = ec2_utils.MetadataMaterializer(str(caller(self.metadata_address)), + + def mcaller(url): + return caller(url).contents + + md = ec2_utils.MetadataMaterializer(mcaller(self.metadata_address), base_url=self.metadata_address, - caller=caller) + caller=mcaller) self.metadata = md.materialize() diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py index 6936c74e..608c07f1 100644 --- a/cloudinit/sources/DataSourceGCE.py +++ b/cloudinit/sources/DataSourceGCE.py @@ -53,15 +53,15 @@ class DataSourceGCE(sources.DataSource): # GCE metadata server requires a custom header since v1 headers = {'X-Google-Metadata-Request': True} - # url_map: (our-key, path, required) + # url_map: (our-key, path, required, is_text) url_map = [ - ('instance-id', 'instance/id', True), - ('availability-zone', 'instance/zone', True), - ('local-hostname', 'instance/hostname', True), - ('public-keys', 'project/attributes/sshKeys', False), - ('user-data', 'instance/attributes/user-data', False), + ('instance-id', 'instance/id', True, True), + ('availability-zone', 'instance/zone', True, True), + ('local-hostname', 'instance/hostname', True, True), + ('public-keys', 'project/attributes/sshKeys', False, True), + ('user-data', 'instance/attributes/user-data', False, False), ('user-data-encoding', 'instance/attributes/user-data-encoding', - False), + False, True), ] # if we cannot resolve the metadata server, then no point in trying @@ -71,13 +71,16 @@ class DataSourceGCE(sources.DataSource): # iterate over url_map keys to get metadata items found = False - for (mkey, path, required) in url_map: + for (mkey, path, required, is_text) in url_map: try: resp = url_helper.readurl(url=self.metadata_address + path, headers=headers) if resp.code == 200: found = True - self.metadata[mkey] = resp.contents + if is_text: + self.metadata[mkey] = util.decode_binary(resp.contents) + else: + self.metadata[mkey] = resp.contents else: if required: msg = "required url %s returned code %s. not GCE" diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index 082cc58f..35c5b5e1 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -36,6 +36,8 @@ from cloudinit import util LOG = logging.getLogger(__name__) MD_VERSION = "2012-03-01" +BINARY_FIELDS = ('user-data',) + class DataSourceMAAS(sources.DataSource): """ @@ -185,7 +187,9 @@ def read_maas_seed_dir(seed_d): md = {} for fname in files: try: - md[fname] = util.load_file(os.path.join(seed_d, fname)) + print("fname: %s / %s" % (fname, fname not in BINARY_FIELDS)) + md[fname] = util.load_file(os.path.join(seed_d, fname), + decode=fname not in BINARY_FIELDS) except IOError as e: if e.errno != errno.ENOENT: raise @@ -218,6 +222,7 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None, 'public-keys': "%s/%s" % (base_url, 'meta-data/public-keys'), 'user-data': "%s/%s" % (base_url, 'user-data'), } + md = {} for name in file_order: url = files.get(name) @@ -238,7 +243,10 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None, timeout=timeout, ssl_details=ssl_details) if resp.ok(): - md[name] = str(resp) + if name in BINARY_FIELDS: + md[name] = resp.contents + else: + md[name] = util.decode_binary(resp.contents) else: LOG.warn(("Fetching from %s resulted in" " an invalid http code %s"), url, resp.code) @@ -263,7 +271,7 @@ def check_seed_contents(content, seed): if len(missing): raise MAASSeedDirMalformed("%s: missing files %s" % (seed, missing)) - userdata = content.get('user-data', "") + userdata = content.get('user-data', b"") md = {} for (key, val) in content.items(): if key == 'user-data': diff --git a/cloudinit/sources/helpers/openstack.py b/cloudinit/sources/helpers/openstack.py index 88c7a198..bd93d22f 100644 --- a/cloudinit/sources/helpers/openstack.py +++ b/cloudinit/sources/helpers/openstack.py @@ -327,7 +327,7 @@ class ConfigDriveReader(BaseReader): return os.path.join(*components) def _path_read(self, path): - return util.load_file(path) + return util.load_file(path, decode=False) def _fetch_available_versions(self): if self._versions is None: diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 62001dff..2d81a062 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -119,7 +119,7 @@ class UrlResponse(object): @property def contents(self): - return self._response.text + return self._response.content @property def url(self): diff --git a/cloudinit/util.py b/cloudinit/util.py index 4fbdf0a9..efbc3c8d 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -739,6 +739,10 @@ def fetch_ssl_details(paths=None): return ssl_details +def load_tfile_or_url(*args, **kwargs): + return(decode_binary(read_file_or_url(*args, **kwargs).contents)) + + def read_file_or_url(url, timeout=5, retries=10, headers=None, data=None, sec_between=1, ssl_details=None, headers_cb=None, exception_cb=None): @@ -750,7 +754,7 @@ def read_file_or_url(url, timeout=5, retries=10, LOG.warn("Unable to post data to file resource %s", url) file_path = url[len("file://"):] try: - contents = load_file(file_path) + contents = load_file(file_path, decode=False) except IOError as e: code = e.errno if e.errno == errno.ENOENT: @@ -806,7 +810,7 @@ def read_seeded(base="", ext="", timeout=5, retries=10, file_retries=0): ud_url = "%s%s%s" % (base, "user-data", ext) md_url = "%s%s%s" % (base, "meta-data", ext) - md_resp = read_file_or_url(md_url, timeout, retries, file_retries) + md_resp = load_tfile_or_url(md_url, timeout, retries, file_retries) md = None if md_resp.ok(): md = load_yaml(md_resp.contents, default={}) @@ -815,6 +819,7 @@ def read_seeded(base="", ext="", timeout=5, retries=10, file_retries=0): ud = None if ud_resp.ok(): ud = ud_resp.contents + print("returning %s (%s)" % (ud_resp.contents.__class__, ud_resp.contents)) return (md, ud) @@ -2030,7 +2035,7 @@ def pathprefix2dict(base, required=None, optional=None, delim=os.path.sep): ret = {} for f in required + optional: try: - ret[f] = load_file(base + delim + f, quiet=False) + ret[f] = load_file(base + delim + f, quiet=False, decode=False) except IOError as e: if e.errno != errno.ENOENT: raise diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 7516bd02..24e1e881 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -288,7 +288,10 @@ def populate_dir(path, files): os.makedirs(path) for (name, content) in files.items(): with open(os.path.join(path, name), "wb") as fp: - fp.write(content.encode('utf-8')) + if isinstance(content, six.binary_type): + fp.write(content) + else: + fp.write(content.encode('utf-8')) fp.close() diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index e28bdd84..83aca505 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -2,6 +2,7 @@ from copy import copy import json import os import shutil +import six import tempfile try: @@ -45,7 +46,7 @@ EC2_META = { 'reservation-id': 'r-iru5qm4m', 'security-groups': ['default'] } -USER_DATA = '#!/bin/sh\necho This is user data\n' +USER_DATA = b'#!/bin/sh\necho This is user data\n' OSTACK_META = { 'availability_zone': 'nova', 'files': [{'content_path': '/content/0000', 'path': '/etc/foo.cfg'}, @@ -56,8 +57,8 @@ OSTACK_META = { 'public_keys': {'mykey': PUBKEY}, 'uuid': 'b0fa911b-69d4-4476-bbe2-1c92bff6535c'} -CONTENT_0 = 'This is contents of /etc/foo.cfg\n' -CONTENT_1 = '# this is /etc/bar/bar.cfg\n' +CONTENT_0 = b'This is contents of /etc/foo.cfg\n' +CONTENT_1 = b'# this is /etc/bar/bar.cfg\n' CFG_DRIVE_FILES_V2 = { 'ec2/2009-04-04/meta-data.json': json.dumps(EC2_META), @@ -346,8 +347,12 @@ def populate_dir(seed_dir, files): dirname = os.path.dirname(path) if not os.path.isdir(dirname): os.makedirs(dirname) - with open(path, "w") as fp: + if isinstance(content, six.text_type): + mode = "w" + else: + mode = "wb" + + with open(path, mode) as fp: fp.write(content) - fp.close() # vi: ts=4 expandtab diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 6dd4b5ed..d28f3b08 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -32,7 +32,7 @@ GCE_META = { 'instance/zone': 'foo/bar', 'project/attributes/sshKeys': 'user:ssh-rsa AA2..+aRD0fyVw== root@server', 'instance/hostname': 'server.project-foo.local', - 'instance/attributes/user-data': '/bin/echo foo\n', + 'instance/attributes/user-data': b'/bin/echo foo\n', } GCE_META_PARTIAL = { diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index d25e1adc..f109bb04 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -26,7 +26,7 @@ class TestMAASDataSource(TestCase): data = {'instance-id': 'i-valid01', 'local-hostname': 'valid01-hostname', - 'user-data': 'valid01-userdata', + 'user-data': b'valid01-userdata', 'public-keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname'} my_d = os.path.join(self.tmp, "valid") @@ -46,7 +46,7 @@ class TestMAASDataSource(TestCase): data = {'instance-id': 'i-valid-extra', 'local-hostname': 'valid-extra-hostname', - 'user-data': 'valid-extra-userdata', 'foo': 'bar'} + 'user-data': b'valid-extra-userdata', 'foo': 'bar'} my_d = os.path.join(self.tmp, "valid_extra") populate_dir(my_d, data) @@ -103,7 +103,7 @@ class TestMAASDataSource(TestCase): 'meta-data/instance-id': 'i-instanceid', 'meta-data/local-hostname': 'test-hostname', 'meta-data/public-keys': 'test-hostname', - 'user-data': 'foodata', + 'user-data': b'foodata', } valid_order = [ 'meta-data/local-hostname', @@ -143,7 +143,7 @@ class TestMAASDataSource(TestCase): userdata, metadata = DataSourceMAAS.read_maas_seed_url( my_seed, header_cb=my_headers_cb, version=my_ver) - self.assertEqual("foodata", userdata) + self.assertEqual(b"foodata", userdata) self.assertEqual(metadata['instance-id'], valid['meta-data/instance-id']) self.assertEqual(metadata['local-hostname'], diff --git a/tests/unittests/test_datasource/test_nocloud.py b/tests/unittests/test_datasource/test_nocloud.py index 4f967f58..85b4c25a 100644 --- a/tests/unittests/test_datasource/test_nocloud.py +++ b/tests/unittests/test_datasource/test_nocloud.py @@ -37,7 +37,7 @@ class TestNoCloudDataSource(TestCase): def test_nocloud_seed_dir(self): md = {'instance-id': 'IID', 'dsmode': 'local'} - ud = "USER_DATA_HERE" + ud = b"USER_DATA_HERE" populate_dir(os.path.join(self.paths.seed_dir, "nocloud"), {'user-data': ud, 'meta-data': yaml.safe_dump(md)}) @@ -92,20 +92,20 @@ class TestNoCloudDataSource(TestCase): data = { 'fs_label': None, 'meta-data': yaml.safe_dump({'instance-id': 'IID'}), - 'user-data': "USER_DATA_RAW", + 'user-data': b"USER_DATA_RAW", } sys_cfg = {'datasource': {'NoCloud': data}} dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths) ret = dsrc.get_data() - self.assertEqual(dsrc.userdata_raw, "USER_DATA_RAW") + self.assertEqual(dsrc.userdata_raw, b"USER_DATA_RAW") self.assertEqual(dsrc.metadata.get('instance-id'), 'IID') self.assertTrue(ret) def test_nocloud_seed_with_vendordata(self): md = {'instance-id': 'IID', 'dsmode': 'local'} - ud = "USER_DATA_HERE" - vd = "THIS IS MY VENDOR_DATA" + ud = b"USER_DATA_HERE" + vd = b"THIS IS MY VENDOR_DATA" populate_dir(os.path.join(self.paths.seed_dir, "nocloud"), {'user-data': ud, 'meta-data': yaml.safe_dump(md), @@ -126,7 +126,7 @@ class TestNoCloudDataSource(TestCase): def test_nocloud_no_vendordata(self): populate_dir(os.path.join(self.paths.seed_dir, "nocloud"), - {'user-data': "ud", 'meta-data': "instance-id: IID\n"}) + {'user-data': b"ud", 'meta-data': "instance-id: IID\n"}) sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}} @@ -134,7 +134,7 @@ class TestNoCloudDataSource(TestCase): dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths) ret = dsrc.get_data() - self.assertEqual(dsrc.userdata_raw, "ud") + self.assertEqual(dsrc.userdata_raw, b"ud") self.assertFalse(dsrc.vendordata) self.assertTrue(ret) diff --git a/tests/unittests/test_datasource/test_openstack.py b/tests/unittests/test_datasource/test_openstack.py index 81ef1546..81411ced 100644 --- a/tests/unittests/test_datasource/test_openstack.py +++ b/tests/unittests/test_datasource/test_openstack.py @@ -49,7 +49,7 @@ EC2_META = { 'public-ipv4': '0.0.0.1', 'reservation-id': 'r-iru5qm4m', } -USER_DATA = '#!/bin/sh\necho This is user data\n' +USER_DATA = b'#!/bin/sh\necho This is user data\n' VENDOR_DATA = { 'magic': '', } @@ -63,8 +63,8 @@ OSTACK_META = { 'public_keys': {'mykey': PUBKEY}, 'uuid': 'b0fa911b-69d4-4476-bbe2-1c92bff6535c', } -CONTENT_0 = 'This is contents of /etc/foo.cfg\n' -CONTENT_1 = '# this is /etc/bar/bar.cfg\n' +CONTENT_0 = b'This is contents of /etc/foo.cfg\n' +CONTENT_1 = b'# this is /etc/bar/bar.cfg\n' OS_FILES = { 'openstack/latest/meta_data.json': json.dumps(OSTACK_META), 'openstack/latest/user_data': USER_DATA, diff --git a/tests/unittests/test_ec2_util.py b/tests/unittests/test_ec2_util.py index 84aa002e..bd43accf 100644 --- a/tests/unittests/test_ec2_util.py +++ b/tests/unittests/test_ec2_util.py @@ -16,7 +16,7 @@ class TestEc2Util(helpers.HttprettyTestCase): body='stuff', status=200) userdata = eu.get_instance_userdata(self.VERSION) - self.assertEquals('stuff', userdata) + self.assertEquals('stuff', userdata.decode('utf-8')) @hp.activate def test_userdata_fetch_fail_not_found(self): diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index d8fe9a4f..02cad8b2 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -30,7 +30,7 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.pfile)) self.assertFalse(os.path.isfile(self.cfile)) - contents = str(util.read_file_or_url(self.pfile)) + contents = util.load_tfile_or_url(self.pfile) self.assertTrue(self._search_apt_config(contents, "http", "myproxy")) def test_apt_http_proxy_written(self): @@ -40,7 +40,7 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.pfile)) self.assertFalse(os.path.isfile(self.cfile)) - contents = str(util.read_file_or_url(self.pfile)) + contents = util.load_tfile_or_url(self.pfile) self.assertTrue(self._search_apt_config(contents, "http", "myproxy")) def test_apt_all_proxy_written(self): @@ -58,7 +58,7 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.pfile)) self.assertFalse(os.path.isfile(self.cfile)) - contents = str(util.read_file_or_url(self.pfile)) + contents = util.load_tfile_or_url(self.pfile) for ptype, pval in values.items(): self.assertTrue(self._search_apt_config(contents, ptype, pval)) @@ -74,7 +74,7 @@ class TestAptProxyConfig(TestCase): cc_apt_configure.apply_apt_config({'apt_proxy': "foo"}, self.pfile, self.cfile) self.assertTrue(os.path.isfile(self.pfile)) - contents = str(util.read_file_or_url(self.pfile)) + contents = util.load_tfile_or_url(self.pfile) self.assertTrue(self._search_apt_config(contents, "http", "foo")) def test_config_written(self): @@ -86,14 +86,14 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.cfile)) self.assertFalse(os.path.isfile(self.pfile)) - self.assertEqual(str(util.read_file_or_url(self.cfile)), payload) + self.assertEqual(util.load_tfile_or_url(self.cfile), payload) def test_config_replaced(self): util.write_file(self.pfile, "content doesnt matter") cc_apt_configure.apply_apt_config({'apt_config': "foo"}, self.pfile, self.cfile) self.assertTrue(os.path.isfile(self.cfile)) - self.assertEqual(str(util.read_file_or_url(self.cfile)), "foo") + self.assertEqual(util.load_tfile_or_url(self.cfile), "foo") def test_config_deleted(self): # if no 'apt_config' is provided, delete any previously written file diff --git a/tests/unittests/test_pathprefix2dict.py b/tests/unittests/test_pathprefix2dict.py index 7089bde6..38fd75b6 100644 --- a/tests/unittests/test_pathprefix2dict.py +++ b/tests/unittests/test_pathprefix2dict.py @@ -14,28 +14,28 @@ class TestPathPrefix2Dict(TestCase): self.addCleanup(shutil.rmtree, self.tmp) def test_required_only(self): - dirdata = {'f1': 'f1content', 'f2': 'f2content'} + dirdata = {'f1': b'f1content', 'f2': b'f2content'} populate_dir(self.tmp, dirdata) ret = util.pathprefix2dict(self.tmp, required=['f1', 'f2']) self.assertEqual(dirdata, ret) def test_required_missing(self): - dirdata = {'f1': 'f1content'} + dirdata = {'f1': b'f1content'} populate_dir(self.tmp, dirdata) kwargs = {'required': ['f1', 'f2']} self.assertRaises(ValueError, util.pathprefix2dict, self.tmp, **kwargs) def test_no_required_and_optional(self): - dirdata = {'f1': 'f1c', 'f2': 'f2c'} + dirdata = {'f1': b'f1c', 'f2': b'f2c'} populate_dir(self.tmp, dirdata) ret = util.pathprefix2dict(self.tmp, required=None, - optional=['f1', 'f2']) + optional=['f1', 'f2']) self.assertEqual(dirdata, ret) def test_required_and_optional(self): - dirdata = {'f1': 'f1c', 'f2': 'f2c'} + dirdata = {'f1': b'f1c', 'f2': b'f2c'} populate_dir(self.tmp, dirdata) ret = util.pathprefix2dict(self.tmp, required=['f1'], optional=['f2']) -- cgit v1.2.3 From e4399a07a798ea83156f300116d33e39f6b6c19f Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Thu, 26 Feb 2015 15:26:15 +0000 Subject: Fix traceback with no arguments on Python 3. Fixes bug 1424277. --- bin/cloud-init | 2 ++ tests/unittests/test_cli.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 tests/unittests/test_cli.py (limited to 'tests/unittests') diff --git a/bin/cloud-init b/bin/cloud-init index 6c83c2e7..e95fea28 100755 --- a/bin/cloud-init +++ b/bin/cloud-init @@ -609,6 +609,8 @@ def main(): # Setup signal handlers before running signal_handler.attach_handlers() + if not hasattr(args, 'action'): + parser.error('too few arguments') (name, functor) = args.action if name in ("modules", "init"): functor = status_wrapper diff --git a/tests/unittests/test_cli.py b/tests/unittests/test_cli.py new file mode 100644 index 00000000..ba447f87 --- /dev/null +++ b/tests/unittests/test_cli.py @@ -0,0 +1,48 @@ +import imp +import sys + +import six + +from . import helpers as test_helpers + +try: + from unittest import mock +except ImportError: + import mock + + +class TestCLI(test_helpers.FilesystemMockingTestCase): + + def setUp(self): + super(TestCLI, self).setUp() + self.stderr = six.StringIO() + self.patchStdoutAndStderr(stderr=self.stderr) + self.sys_exit = mock.MagicMock() + self.patched_funcs.enter_context( + mock.patch.object(sys, 'exit', self.sys_exit)) + + def _call_main(self): + self.patched_funcs.enter_context( + mock.patch.object(sys, 'argv', ['cloud-init'])) + cli = imp.load_module( + 'cli', open('bin/cloud-init'), '', ('', 'r', imp.PY_SOURCE)) + try: + return cli.main() + except: + pass + + def test_no_arguments_shows_usage(self): + self._call_main() + self.assertIn('usage: cloud-init', self.stderr.getvalue()) + + def test_no_arguments_exits_2(self): + exit_code = self._call_main() + if self.sys_exit.call_count: + self.assertEqual(mock.call(2), self.sys_exit.call_args) + else: + self.assertEqual(2, exit_code) + + def test_no_arguments_shows_error_message(self): + self._call_main() + self.assertIn('cloud-init: error: too few arguments', + self.stderr.getvalue()) -- cgit v1.2.3 From f0388cfffadc0596faeda9e11775597444bff25d Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 2 Mar 2015 15:41:16 -0500 Subject: pep8 --- cloudinit/ec2_utils.py | 1 + cloudinit/stages.py | 3 ++- tests/unittests/test_data.py | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/ec2_utils.py b/cloudinit/ec2_utils.py index 7cf99186..37b92a83 100644 --- a/cloudinit/ec2_utils.py +++ b/cloudinit/ec2_utils.py @@ -185,6 +185,7 @@ def get_instance_metadata(api_version='latest', caller = functools.partial(util.read_file_or_url, ssl_details=ssl_details, timeout=timeout, retries=retries) + def mcaller(url): return caller(url).contents diff --git a/cloudinit/stages.py b/cloudinit/stages.py index 94fcf4cc..45d64823 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -346,7 +346,8 @@ class Init(object): processed_vd = str(self.datasource.get_vendordata()) if processed_vd is None: processed_vd = '' - util.write_file(self._get_ipath('vendordata'), str(processed_vd), 0o600) + util.write_file(self._get_ipath('vendordata'), str(processed_vd), + 0o600) def _default_handlers(self, opts=None): if opts is None: diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 48475515..8fc280e4 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -58,7 +58,6 @@ def gzip_text(text): return contents.getvalue() - # FIXME: these tests shouldn't be checking log output?? # Weirddddd... class TestConsumeUserData(helpers.FilesystemMockingTestCase): -- cgit v1.2.3 From a934ae9543ccc9c13fbdedddcc04fa82853a7ec2 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 2 Mar 2015 15:56:15 -0500 Subject: get_cmdline_url: fix in python3 when calling get_cmdline_url was passing a string to response.contents.startswith() where response.contents is now bytes. this changes it to convert input to text, and also to default to text. --- cloudinit/util.py | 4 +++- tests/unittests/test__init__.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index 039aa3f2..cc20305c 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -970,7 +970,7 @@ def get_fqdn_from_hosts(hostname, filename="/etc/hosts"): def get_cmdline_url(names=('cloud-config-url', 'url'), - starts="#cloud-config", cmdline=None): + starts=b"#cloud-config", cmdline=None): if cmdline is None: cmdline = get_cmdline() @@ -986,6 +986,8 @@ def get_cmdline_url(names=('cloud-config-url', 'url'), return (None, None, None) resp = read_file_or_url(url) + # allow callers to pass starts as text when comparing to bytes contents + starts = encode_text(starts) if resp.ok() and resp.contents.startswith(starts): return (key, url, resp.contents) diff --git a/tests/unittests/test__init__.py b/tests/unittests/test__init__.py index 1a307e56..c32783a6 100644 --- a/tests/unittests/test__init__.py +++ b/tests/unittests/test__init__.py @@ -181,7 +181,7 @@ class TestCmdlineUrl(unittest.TestCase): def test_invalid_content(self): url = "http://example.com/foo" key = "mykey" - payload = "0" + payload = b"0" cmdline = "ro %s=%s bar=1" % (key, url) with mock.patch('cloudinit.url_helper.readurl', @@ -194,13 +194,13 @@ class TestCmdlineUrl(unittest.TestCase): def test_valid_content(self): url = "http://example.com/foo" key = "mykey" - payload = "xcloud-config\nmydata: foo\nbar: wark\n" + payload = b"xcloud-config\nmydata: foo\nbar: wark\n" cmdline = "ro %s=%s bar=1" % (key, url) with mock.patch('cloudinit.url_helper.readurl', return_value=url_helper.StringResponse(payload)): self.assertEqual( - util.get_cmdline_url(names=[key], starts="xcloud-config", + util.get_cmdline_url(names=[key], starts=b"xcloud-config", cmdline=cmdline), (key, url, payload)) @@ -210,7 +210,7 @@ class TestCmdlineUrl(unittest.TestCase): cmdline = "ro %s=%s bar=1" % (key, url) with mock.patch('cloudinit.url_helper.readurl', - return_value=url_helper.StringResponse('')): + return_value=url_helper.StringResponse(b'')): self.assertEqual( util.get_cmdline_url(names=["does-not-appear"], starts="#cloud-config", cmdline=cmdline), -- cgit v1.2.3 From 8663b57ebba7aa4f6916f53e74df4f890bbc8b9a Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 4 Mar 2015 10:46:27 +0000 Subject: Convert dmidecode values to sysfs names before looking for them. dmidecode and /sys/class/dmi/id/* use different names for the same information. This modified the logic in util.read_dmi_data to map from dmidecode names to sysfs names before looking in sysfs. --- cloudinit/util.py | 62 ++++++++++++++++++++++-------- tests/unittests/test_util.py | 91 ++++++++++++++++++++++++-------------------- 2 files changed, 97 insertions(+), 56 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index cc20305c..f95e71c8 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -128,6 +128,28 @@ def fully_decoded_payload(part): # Path for DMI Data DMI_SYS_PATH = "/sys/class/dmi/id" +# dmidecode and /sys/class/dmi/id/* use different names for the same value, +# this allows us to refer to them by one canonical name +DMIDECODE_TO_DMI_SYS_MAPPING = { + 'baseboard-asset-tag': 'board_asset_tag', + 'baseboard-manufacturer': 'board_vendor', + 'baseboard-product-name': 'board_name', + 'baseboard-serial-number': 'board_serial', + 'baseboard-version': 'board_version', + 'bios-release-date': 'bios_date', + 'bios-vendor': 'bios_vendor', + 'bios-version': 'bios_version', + 'chassis-asset-tag': 'chassis_asset_tag', + 'chassis-manufacturer': 'chassis_vendor', + 'chassis-serial-number': 'chassis_serial', + 'chassis-version': 'chassis_version', + 'system-manufacturer': 'sys_vendor', + 'system-product-name': 'product_name', + 'system-serial-number': 'product_serial', + 'system-uuid': 'product_uuid', + 'system-version': 'product_version', +} + class ProcessExecutionError(IOError): @@ -2103,24 +2125,26 @@ def _read_dmi_syspath(key): """ Reads dmi data with from /sys/class/dmi/id """ - - dmi_key = "{0}/{1}".format(DMI_SYS_PATH, key) - LOG.debug("querying dmi data {0}".format(dmi_key)) + if key not in DMIDECODE_TO_DMI_SYS_MAPPING: + return None + mapped_key = DMIDECODE_TO_DMI_SYS_MAPPING[key] + dmi_key_path = "{0}/{1}".format(DMI_SYS_PATH, mapped_key) + LOG.debug("querying dmi data {0}".format(dmi_key_path)) try: - if not os.path.exists(dmi_key): - LOG.debug("did not find {0}".format(dmi_key)) + if not os.path.exists(dmi_key_path): + LOG.debug("did not find {0}".format(dmi_key_path)) return None - key_data = load_file(dmi_key) + key_data = load_file(dmi_key_path) if not key_data: - LOG.debug("{0} did not return any data".format(key)) + LOG.debug("{0} did not return any data".format(dmi_key_path)) return None - LOG.debug("dmi data {0} returned {0}".format(dmi_key, key_data)) + LOG.debug("dmi data {0} returned {1}".format(dmi_key_path, key_data)) return key_data.strip() except Exception as e: - logexc(LOG, "failed read of {0}".format(dmi_key), e) + logexc(LOG, "failed read of {0}".format(dmi_key_path), e) return None @@ -2134,18 +2158,27 @@ def _call_dmidecode(key, dmidecode_path): (result, _err) = subp(cmd) LOG.debug("dmidecode returned '{0}' for '{0}'".format(result, key)) return result - except OSError as _err: + except (IOError, OSError) as _err: LOG.debug('failed dmidecode cmd: {0}\n{0}'.format(cmd, _err.message)) return None def read_dmi_data(key): """ - Wrapper for reading DMI data. This tries to determine whether the DMI - Data can be read directly, otherwise it will fallback to using dmidecode. + Wrapper for reading DMI data. + + This will do the following (returning the first that produces a + result): + 1) Use a mapping to translate `key` from dmidecode naming to + sysfs naming and look in /sys/class/dmi/... for a value. + 2) Use `key` as a sysfs key directly and look in /sys/class/dmi/... + 3) Fall-back to passing `key` to `dmidecode --string`. + + If all of the above fail to find a value, None will be returned. """ - if os.path.exists(DMI_SYS_PATH): - return _read_dmi_syspath(key) + syspath_value = _read_dmi_syspath(key) + if syspath_value is not None: + return syspath_value dmidecode_path = which('dmidecode') if dmidecode_path: @@ -2153,5 +2186,4 @@ def read_dmi_data(key): LOG.warn("did not find either path {0} or dmidecode command".format( DMI_SYS_PATH)) - return None diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 33c191a9..7da1f755 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -323,58 +323,67 @@ class TestMountinfoParsing(helpers.ResourceUsingTestCase): class TestReadDMIData(helpers.FilesystemMockingTestCase): - def _patchIn(self, root): - self.patchOS(root) - self.patchUtils(root) + def setUp(self): + super(TestReadDMIData, self).setUp() + self.new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.new_root) + self.patchOS(self.new_root) + self.patchUtils(self.new_root) - def _write_key(self, key, content): - """Mocks the sys path found on Linux systems.""" - new_root = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, new_root) - self._patchIn(new_root) + def _create_sysfs_parent_directory(self): util.ensure_dir(os.path.join('sys', 'class', 'dmi', 'id')) + def _create_sysfs_file(self, key, content): + """Mocks the sys path found on Linux systems.""" + self._create_sysfs_parent_directory() dmi_key = "/sys/class/dmi/id/{0}".format(key) util.write_file(dmi_key, content) - def _no_syspath(self, key, content): + def _configure_dmidecode_return(self, key, content, error=None): """ In order to test a missing sys path and call outs to dmidecode, this function fakes the results of dmidecode to test the results. """ - new_root = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, new_root) - self._patchIn(new_root) - self.real_which = util.which - self.real_subp = util.subp - - def _which(key): - return True - util.which = _which - - def _cdd(_key, error=None): + def _dmidecode_subp(cmd): + if cmd[-1] != key: + raise util.ProcessExecutionError() return (content, error) - util.subp = _cdd - - def test_key(self): - key_content = "TEST-KEY-DATA" - self._write_key("key", key_content) - self.assertEquals(key_content, util.read_dmi_data("key")) - - def test_key_mismatch(self): - self._write_key("test", "ABC") - self.assertNotEqual("123", util.read_dmi_data("test")) - - def test_no_key(self): - self._no_syspath(None, None) - self.assertFalse(util.read_dmi_data("key")) - - def test_callout_dmidecode(self): - """test to make sure that dmidecode is used when no syspath""" - self._no_syspath("key", "stuff") - self.assertEquals("stuff", util.read_dmi_data("key")) - self._no_syspath("key", None) - self.assertFalse(None, util.read_dmi_data("key")) + + self.patched_funcs.enter_context( + mock.patch.object(util, 'which', lambda _: True)) + self.patched_funcs.enter_context( + mock.patch.object(util, 'subp', _dmidecode_subp)) + + def patch_mapping(self, new_mapping): + self.patched_funcs.enter_context( + mock.patch('cloudinit.util.DMIDECODE_TO_DMI_SYS_MAPPING', + new_mapping)) + + def test_sysfs_used_with_key_in_mapping_and_file_on_disk(self): + self.patch_mapping({'mapped-key': 'mapped-value'}) + expected_dmi_value = 'sys-used-correctly' + self._create_sysfs_file('mapped-value', expected_dmi_value) + self._configure_dmidecode_return('mapped-key', 'wrong-wrong-wrong') + self.assertEqual(expected_dmi_value, util.read_dmi_data('mapped-key')) + + def test_dmidecode_used_if_no_sysfs_file_on_disk(self): + self.patch_mapping({}) + self._create_sysfs_parent_directory() + expected_dmi_value = 'dmidecode-used' + self._configure_dmidecode_return('use-dmidecode', expected_dmi_value) + self.assertEqual(expected_dmi_value, + util.read_dmi_data('use-dmidecode')) + + def test_none_returned_if_neither_source_has_data(self): + self.patch_mapping({}) + self._configure_dmidecode_return('key', 'value') + self.assertEqual(None, util.read_dmi_data('expect-fail')) + + def test_none_returned_if_dmidecode_not_in_path(self): + self.patched_funcs.enter_context( + mock.patch.object(util, 'which', lambda _: False)) + self.patch_mapping({}) + self.assertEqual(None, util.read_dmi_data('expect-fail')) class TestMultiLog(helpers.FilesystemMockingTestCase): -- cgit v1.2.3 From 04a60cf949e085f06fa1f93b39a74b8d288f0e2e Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 4 Mar 2015 12:29:29 +0000 Subject: Fix hang caused by HTTPretty on Python 3.4.2. HTTPretty can causes hangs on Python 3.4.2 (and maybe Python 3.4.1), due to a Python bug (fixed in Python 3.4.3). This works around the problem in the appropriate Python versions. See https://github.com/gabrielfalcao/HTTPretty/pull/193 and https://github.com/gabrielfalcao/HTTPretty/issues/221 for details. --- tests/unittests/helpers.py | 37 +++++++++++++++++++++- .../unittests/test_datasource/test_digitalocean.py | 3 +- tests/unittests/test_datasource/test_gce.py | 3 +- tests/unittests/test_datasource/test_openstack.py | 2 +- tests/unittests/test_ec2_util.py | 2 +- 5 files changed, 42 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 24e1e881..61a1f6ff 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -1,5 +1,6 @@ from __future__ import print_function +import functools import os import sys import shutil @@ -25,9 +26,10 @@ PY2 = False PY26 = False PY27 = False PY3 = False +FIX_HTTPRETTY = False _PY_VER = sys.version_info -_PY_MAJOR, _PY_MINOR = _PY_VER[0:2] +_PY_MAJOR, _PY_MINOR, _PY_MICRO = _PY_VER[0:3] if (_PY_MAJOR, _PY_MINOR) <= (2, 6): if (_PY_MAJOR, _PY_MINOR) == (2, 6): PY26 = True @@ -39,6 +41,8 @@ else: PY2 = True if (_PY_MAJOR, _PY_MINOR) >= (3, 0): PY3 = True + if _PY_MINOR == 4 and _PY_MICRO < 3: + FIX_HTTPRETTY = True if PY26: # For now add these on, taken from python 2.7 + slightly adjusted. Drop @@ -268,6 +272,37 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): mock.patch.object(sys, 'stderr', stderr)) +def import_httpretty(): + """Import HTTPretty and monkey patch Python 3.4 issue. + See https://github.com/gabrielfalcao/HTTPretty/pull/193 and + as well as https://github.com/gabrielfalcao/HTTPretty/issues/221. + + Lifted from + https://github.com/inveniosoftware/datacite/blob/master/tests/helpers.py + """ + if not FIX_HTTPRETTY: + import httpretty + else: + import socket + old_SocketType = socket.SocketType + + import httpretty + from httpretty import core + + def sockettype_patch(f): + @functools.wraps(f) + def inner(*args, **kwargs): + f(*args, **kwargs) + socket.SocketType = old_SocketType + socket.__dict__['SocketType'] = old_SocketType + return inner + + core.httpretty.disable = sockettype_patch( + httpretty.httpretty.disable + ) + return httpretty + + class HttprettyTestCase(TestCase): # necessary as http_proxy gets in the way of httpretty # https://github.com/gabrielfalcao/HTTPretty/issues/122 diff --git a/tests/unittests/test_datasource/test_digitalocean.py b/tests/unittests/test_datasource/test_digitalocean.py index 98f9cfac..679d1b82 100644 --- a/tests/unittests/test_datasource/test_digitalocean.py +++ b/tests/unittests/test_datasource/test_digitalocean.py @@ -15,7 +15,6 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import httpretty import re from six.moves.urllib_parse import urlparse @@ -26,6 +25,8 @@ from cloudinit.sources import DataSourceDigitalOcean from .. import helpers as test_helpers +httpretty = test_helpers.import_httpretty() + # Abbreviated for the test DO_INDEX = """id hostname diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index d28f3b08..4280abc4 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -15,7 +15,6 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import httpretty import re from base64 import b64encode, b64decode @@ -27,6 +26,8 @@ from cloudinit.sources import DataSourceGCE from .. import helpers as test_helpers +httpretty = test_helpers.import_httpretty() + GCE_META = { 'instance/id': '123', 'instance/zone': 'foo/bar', diff --git a/tests/unittests/test_datasource/test_openstack.py b/tests/unittests/test_datasource/test_openstack.py index 81411ced..0aa1ba84 100644 --- a/tests/unittests/test_datasource/test_openstack.py +++ b/tests/unittests/test_datasource/test_openstack.py @@ -31,7 +31,7 @@ from cloudinit.sources import DataSourceOpenStack as ds from cloudinit.sources.helpers import openstack from cloudinit import util -import httpretty as hp +hp = test_helpers.import_httpretty() BASE_URL = "http://169.254.169.254" PUBKEY = u'ssh-rsa AAAAB3NzaC1....sIkJhq8wdX+4I3A4cYbYP ubuntu@server-460\n' diff --git a/tests/unittests/test_ec2_util.py b/tests/unittests/test_ec2_util.py index bd43accf..99fc54be 100644 --- a/tests/unittests/test_ec2_util.py +++ b/tests/unittests/test_ec2_util.py @@ -3,7 +3,7 @@ from . import helpers from cloudinit import ec2_utils as eu from cloudinit import url_helper as uh -import httpretty as hp +hp = helpers.import_httpretty() class TestEc2Util(helpers.HttprettyTestCase): -- cgit v1.2.3 From 5eb2aab5d010e7b8d5e4146959e50f2a9f67d504 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 4 Mar 2015 17:20:48 +0000 Subject: Add util.message_from_string to wrap email.message_from_string. This is to work-around the fact that email.message_from_string uses cStringIO in Python 2.6, which can't handle Unicode. --- cloudinit/user_data.py | 4 +--- cloudinit/util.py | 7 +++++++ tests/unittests/test_util.py | 7 +++++++ 3 files changed, 15 insertions(+), 3 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/user_data.py b/cloudinit/user_data.py index 663a9048..eb3c7336 100644 --- a/cloudinit/user_data.py +++ b/cloudinit/user_data.py @@ -22,8 +22,6 @@ import os -import email - from email.mime.base import MIMEBase from email.mime.multipart import MIMEMultipart from email.mime.nonmultipart import MIMENonMultipart @@ -338,7 +336,7 @@ def convert_string(raw_data, headers=None): headers = {} data = util.decode_binary(util.decomp_gzip(raw_data)) if "mime-version:" in data[0:4096].lower(): - msg = email.message_from_string(data) + msg = util.message_from_string(data) for (key, val) in headers.items(): _replace_header(msg, key, val) else: diff --git a/cloudinit/util.py b/cloudinit/util.py index b6065410..971c1c2d 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -23,6 +23,7 @@ import contextlib import copy as obj_copy import ctypes +import email import errno import glob import grp @@ -2187,3 +2188,9 @@ def read_dmi_data(key): LOG.warn("did not find either path %s or dmidecode command", DMI_SYS_PATH) return None + + +def message_from_string(string): + if sys.version_info[:2] < (2, 7): + return email.message_from_file(six.StringIO(string)) + return email.message_from_string(string) diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 7da1f755..1619b5d2 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -452,4 +452,11 @@ class TestMultiLog(helpers.FilesystemMockingTestCase): util.multi_log('message', log=log, log_level=log_level) self.assertEqual((log_level, mock.ANY), log.log.call_args[0]) + +class TestMessageFromString(helpers.TestCase): + + def test_unicode_not_messed_up(self): + roundtripped = util.message_from_string(u'\n').as_string() + self.assertNotIn('\x00', roundtripped) + # vi: ts=4 expandtab -- cgit v1.2.3 From 31a8aab92656279b141a9c29e484c4895bde15d3 Mon Sep 17 00:00:00 2001 From: Oleg Strikov Date: Wed, 11 Mar 2015 20:22:54 +0300 Subject: userdata-handlers: python3-related fixes on do-not-process-this-part path Cloud-init crashed when received multipart userdata object with 'application/octet-stream' part or some other 'application/*' part except archived ones (x-gzip and friends). These parts are not processed by cloud-init and result only in a message in the log. We used some non-python3-friendly techniques while generating this log message which was a reason for the crash. --- cloudinit/handlers/__init__.py | 24 ++++++++++++++++++------ tests/unittests/test_data.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/handlers/__init__.py b/cloudinit/handlers/__init__.py index 6b7abbcd..d62fcd19 100644 --- a/cloudinit/handlers/__init__.py +++ b/cloudinit/handlers/__init__.py @@ -163,12 +163,19 @@ def walker_handle_handler(pdata, _ctype, _filename, payload): def _extract_first_or_bytes(blob, size): - # Extract the first line upto X bytes or X bytes from more than the - # first line if the first line does not contain enough bytes - first_line = blob.split("\n", 1)[0] - if len(first_line) >= size: - start = first_line[:size] - else: + # Extract the first line or upto X symbols for text objects + # Extract first X bytes for binary objects + try: + if isinstance(blob, six.string_types): + start = blob.split("\n", 1)[0] + else: + # We want to avoid decoding the whole blob (it might be huge) + # By taking 4*size bytes we have a guarantee to decode size utf8 chars + start = blob[:4*size].decode(errors='ignore').split("\n", 1)[0] + if len(start) >= size: + start = start[:size] + except UnicodeDecodeError: + # Bytes array doesn't contain a text object -- return chunk of raw bytes start = blob[0:size] return start @@ -183,6 +190,11 @@ def _escape_string(text): except TypeError: # Give up... pass + except AttributeError: + # We're in Python3 and received blob as text + # No escaping is needed because bytes are printed + # as 'b\xAA\xBB' automatically in Python3 + pass return text diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 8fc280e4..4f24e2dd 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -13,6 +13,7 @@ except ImportError: from six import BytesIO, StringIO +from email import encoders from email.mime.application import MIMEApplication from email.mime.base import MIMEBase from email.mime.multipart import MIMEMultipart @@ -492,6 +493,23 @@ c: 4 mock.call(ci.paths.get_ipath("cloud_config"), "", 0o600), ]) + def test_mime_application_octet_stream(self): + """Mime message of type application/octet-stream is ignored but shows warning.""" + ci = stages.Init() + message = MIMEBase("application", "octet-stream") + message.set_payload(b'\xbf\xe6\xb2\xc3\xd3\xba\x13\xa4\xd8\xa1\xcc\xbf') + encoders.encode_base64(message) + ci.datasource = FakeDataSource(message.as_string().encode()) + + with mock.patch('cloudinit.util.write_file') as mockobj: + log_file = self.capture_log(logging.WARNING) + ci.fetch() + ci.consume_data() + self.assertIn( + "Unhandled unknown content-type (application/octet-stream)", + log_file.getvalue()) + mockobj.assert_called_once_with( + ci.paths.get_ipath("cloud_config"), "", 0o600) class TestUDProcess(helpers.ResourceUsingTestCase): -- cgit v1.2.3 From c8a7b446de26c6bc19df1b8bb7d2b39cb9487749 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 13 Mar 2015 10:18:12 +0000 Subject: Write and read bytes to/from the SmartOS serial console. --- cloudinit/sources/DataSourceSmartOS.py | 5 +++-- tests/unittests/test_datasource/test_smartos.py | 15 ++++++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 9d48beab..896fde3f 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -319,7 +319,8 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, return False ser = get_serial(seed_device, seed_timeout) - ser.write("GET %s\n" % noun.rstrip()) + request_line = "GET %s\n" % noun.rstrip() + ser.write(request_line.encode('ascii')) status = str(ser.readline()).rstrip() response = [] eom_found = False @@ -329,7 +330,7 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, return default while not eom_found: - m = ser.readline() + m = ser.readline().decode('ascii') if m.rstrip() == ".": eom_found = True else: diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 8b62b1b1..cb0ab984 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -36,6 +36,8 @@ import tempfile import stat import uuid +import six + MOCK_RETURNS = { 'hostname': 'test-host', @@ -78,24 +80,27 @@ class MockSerial(object): return True def write(self, line): - line = line.replace('GET ', '') + if not isinstance(line, six.binary_type): + raise TypeError("Should be writing binary lines.") + line = line.decode('ascii').replace('GET ', '') self.last = line.rstrip() def readline(self): if self.new: self.new = False if self.last in self.mockdata: - return 'SUCCESS\n' + line = 'SUCCESS\n' else: - return 'NOTFOUND %s\n' % self.last + line = 'NOTFOUND %s\n' % self.last - if self.last in self.mockdata: + elif self.last in self.mockdata: if not self.mocked_out: self.mocked_out = [x for x in self._format_out()] if len(self.mocked_out) > self.count: self.count += 1 - return self.mocked_out[self.count - 1] + line = self.mocked_out[self.count - 1] + return line.encode('ascii') def _format_out(self): if self.last in self.mockdata: -- cgit v1.2.3 From e298037a04284370d30faa7a1808e985be3a4cc4 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 18 Mar 2015 13:32:54 +0000 Subject: Add tests for cc_disk_setup.is_disk_used. --- .../test_handler/test_handler_disk_setup.py | 30 ++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/unittests/test_handler/test_handler_disk_setup.py (limited to 'tests/unittests') diff --git a/tests/unittests/test_handler/test_handler_disk_setup.py b/tests/unittests/test_handler/test_handler_disk_setup.py new file mode 100644 index 00000000..ddef8d48 --- /dev/null +++ b/tests/unittests/test_handler/test_handler_disk_setup.py @@ -0,0 +1,30 @@ +from cloudinit.config import cc_disk_setup +from ..helpers import ExitStack, mock, TestCase + + +class TestIsDiskUsed(TestCase): + + def setUp(self): + super(TestIsDiskUsed, self).setUp() + self.patches = ExitStack() + mod_name = 'cloudinit.config.cc_disk_setup' + self.enumerate_disk = self.patches.enter_context( + mock.patch('{0}.enumerate_disk'.format(mod_name))) + self.check_fs = self.patches.enter_context( + mock.patch('{0}.check_fs'.format(mod_name))) + + def test_multiple_child_nodes_returns_true(self): + self.enumerate_disk.return_value = (mock.MagicMock() for _ in range(2)) + self.check_fs.return_value = (mock.MagicMock(), None, mock.MagicMock()) + self.assertTrue(cc_disk_setup.is_disk_used(mock.MagicMock())) + + def test_valid_filesystem_returns_true(self): + self.enumerate_disk.return_value = (mock.MagicMock() for _ in range(1)) + self.check_fs.return_value = ( + mock.MagicMock(), 'ext4', mock.MagicMock()) + self.assertTrue(cc_disk_setup.is_disk_used(mock.MagicMock())) + + def test_one_child_nodes_and_no_fs_returns_false(self): + self.enumerate_disk.return_value = (mock.MagicMock() for _ in range(1)) + self.check_fs.return_value = (mock.MagicMock(), None, mock.MagicMock()) + self.assertFalse(cc_disk_setup.is_disk_used(mock.MagicMock())) -- cgit v1.2.3 From 9c9cc1ad42d04fdb375d72d5f2b2cfafba898f50 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 25 Mar 2015 15:54:05 +0000 Subject: Organise imports in test_smartos.py. --- tests/unittests/test_datasource/test_smartos.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index cb0ab984..cdd83bf8 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -24,20 +24,21 @@ from __future__ import print_function -from cloudinit import helpers as c_helpers -from cloudinit.sources import DataSourceSmartOS -from cloudinit.util import b64e -from .. import helpers import os import os.path import re import shutil -import tempfile import stat +import tempfile import uuid import six +from cloudinit import helpers as c_helpers +from cloudinit.sources import DataSourceSmartOS +from cloudinit.util import b64e + +from .. import helpers MOCK_RETURNS = { 'hostname': 'test-host', -- cgit v1.2.3 From 7c63a4096d9b6c9dc10605c289ee048c7b0778c6 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 25 Mar 2015 15:54:07 +0000 Subject: Convert DataSourceSmartOS to use v2 metadata. --- cloudinit/sources/DataSourceSmartOS.py | 75 +++++--- tests/unittests/test_datasource/test_smartos.py | 216 ++++++++++++++++++++---- 2 files changed, 239 insertions(+), 52 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 896fde3f..694a011a 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -29,9 +29,10 @@ # http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html # Comments with "@datadictionary" are snippets of the definition -import base64 import binascii import os +import random +import re import serial from cloudinit import log as logging @@ -301,6 +302,53 @@ def get_serial(seed_device, seed_timeout): return ser +class JoyentMetadataFetchException(Exception): + pass + + +class JoyentMetadataClient(object): + + def __init__(self, serial): + self.serial = serial + + def _checksum(self, body): + return '{0:08x}'.format( + binascii.crc32(body.encode('utf-8')) & 0xffffffff) + + def _get_value_from_frame(self, expected_request_id, frame): + regex = ( + r'V2 (?P\d+) (?P[0-9a-f]+)' + r' (?P(?P[0-9a-f]+) (?PSUCCESS|NOTFOUND)' + r'( (?P.+))?)') + frame_data = re.match(regex, frame).groupdict() + if int(frame_data['length']) != len(frame_data['body']): + raise JoyentMetadataFetchException( + 'Incorrect frame length given ({0} != {1}).'.format( + frame_data['length'], len(frame_data['body']))) + expected_checksum = self._checksum(frame_data['body']) + if frame_data['checksum'] != expected_checksum: + raise JoyentMetadataFetchException( + 'Invalid checksum (expected: {0}; got {1}).'.format( + expected_checksum, frame_data['checksum'])) + if frame_data['request_id'] != expected_request_id: + raise JoyentMetadataFetchException( + 'Request ID mismatch (expected: {0}; got {1}).'.format( + expected_request_id, frame_data['request_id'])) + if not frame_data.get('payload', None): + return None + return util.b64d(frame_data['payload']) + + def get_metadata(self, metadata_key): + request_id = '{0:08x}'.format(random.randint(0, 0xffffffff)) + message_body = '{0} GET {1}'.format(request_id, + util.b64e(metadata_key)) + msg = 'V2 {0} {1} {2}\n'.format( + len(message_body), self._checksum(message_body), message_body) + self.serial.write(msg.encode('ascii')) + response = self.serial.readline().decode('ascii') + return self._get_value_from_frame(request_id, response) + + def query_data(noun, seed_device, seed_timeout, strip=False, default=None, b64=None): """Makes a request to via the serial console via "GET " @@ -314,34 +362,21 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, encoded, so this method relies on being told if the data is base64 or not. """ - if not noun: return False ser = get_serial(seed_device, seed_timeout) - request_line = "GET %s\n" % noun.rstrip() - ser.write(request_line.encode('ascii')) - status = str(ser.readline()).rstrip() - response = [] - eom_found = False - - if 'SUCCESS' not in status: - ser.close() - return default - - while not eom_found: - m = ser.readline().decode('ascii') - if m.rstrip() == ".": - eom_found = True - else: - response.append(m) + client = JoyentMetadataClient(ser) + response = client.get_metadata(noun) ser.close() + if response is None: + return default if b64 is None: b64 = query_data('b64-%s' % noun, seed_device=seed_device, - seed_timeout=seed_timeout, b64=False, - default=False, strip=True) + seed_timeout=seed_timeout, b64=False, + default=False, strip=True) b64 = util.is_true(b64) resp = None diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index cdd83bf8..c79cf3aa 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -31,15 +31,24 @@ import shutil import stat import tempfile import uuid +from binascii import crc32 + +import serial +import six import six from cloudinit import helpers as c_helpers from cloudinit.sources import DataSourceSmartOS -from cloudinit.util import b64e +from cloudinit.util import b64d, b64e from .. import helpers +try: + from unittest import mock +except ImportError: + import mock + MOCK_RETURNS = { 'hostname': 'test-host', 'root_authorized_keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname', @@ -57,6 +66,37 @@ MOCK_RETURNS = { DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc') +def _checksum(body): + return '{0:08x}'.format(crc32(body.encode('utf-8')) & 0xffffffff) + + +def _generate_v2_frame(request_id, command, body=None): + body_parts = [request_id, command] + if body: + body_parts.append(b64e(body)) + message_body = ' '.join(body_parts) + return 'V2 {0} {1} {2}\n'.format( + len(message_body), _checksum(message_body), message_body).encode( + 'ascii') + + +def _parse_v2_frame(line): + line = line.decode('ascii') + if not line.endswith('\n'): + raise Exception('Frames must end with a newline.') + version, length, checksum, body = line.strip().split(' ', 3) + if version != 'V2': + raise Exception('Frames must begin with V2.') + if int(length) != len(body): + raise Exception('Incorrect frame length given ({0} != {1}).'.format( + length, len(body))) + expected_checksum = _checksum(body) + if checksum != expected_checksum: + raise Exception('Invalid checksum.') + request_id, command, payload = body.split() + return request_id, command, b64d(payload) + + class MockSerial(object): """Fake a serial terminal for testing the code that interfaces with the serial""" @@ -81,39 +121,21 @@ class MockSerial(object): return True def write(self, line): - if not isinstance(line, six.binary_type): - raise TypeError("Should be writing binary lines.") - line = line.decode('ascii').replace('GET ', '') - self.last = line.rstrip() + self.last = line def readline(self): - if self.new: - self.new = False - if self.last in self.mockdata: - line = 'SUCCESS\n' - else: - line = 'NOTFOUND %s\n' % self.last - - elif self.last in self.mockdata: - if not self.mocked_out: - self.mocked_out = [x for x in self._format_out()] - - if len(self.mocked_out) > self.count: - self.count += 1 - line = self.mocked_out[self.count - 1] - return line.encode('ascii') - - def _format_out(self): - if self.last in self.mockdata: - _mret = self.mockdata[self.last] - try: - for l in _mret.splitlines(): - yield "%s\n" % l.rstrip() - except: - yield "%s\n" % _mret.rstrip() - - yield '.' - yield '\n' + if self.last == '\n': + return 'invalid command\n' + elif self.last == 'NEGOTIATE V2\n': + return 'V2_OK\n' + request_id, command, request_body = _parse_v2_frame(self.last) + if command != 'GET': + raise Exception('MockSerial only supports GET requests.') + metadata_key = request_body.strip() + if metadata_key in self.mockdata: + return _generate_v2_frame( + request_id, 'SUCCESS', self.mockdata[metadata_key]) + return _generate_v2_frame(request_id, 'NOTFOUND') class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): @@ -459,3 +481,133 @@ def apply_patches(patches): setattr(ref, name, replace) ret.append((ref, name, orig)) return ret + + +class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): + + def setUp(self): + super(TestJoyentMetadataClient, self).setUp() + self.serial = mock.MagicMock(spec=serial.Serial) + self.request_id = 0xabcdef12 + self.metadata_value = 'value' + self.response_parts = { + 'command': 'SUCCESS', + 'crc': 'b5a9ff00', + 'length': 17 + len(b64e(self.metadata_value)), + 'payload': b64e(self.metadata_value), + 'request_id': '{0:08x}'.format(self.request_id), + } + + def make_response(): + payload = '' + if self.response_parts['payload']: + payload = ' {0}'.format(self.response_parts['payload']) + del self.response_parts['payload'] + return ( + 'V2 {length} {crc} {request_id} {command}{payload}\n'.format( + payload=payload, **self.response_parts).encode('ascii')) + self.serial.readline.side_effect = make_response + self.patched_funcs.enter_context( + mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint', + mock.Mock(return_value=self.request_id))) + + def _get_client(self): + return DataSourceSmartOS.JoyentMetadataClient(self.serial) + + def assertEndsWith(self, haystack, prefix): + self.assertTrue(haystack.endswith(prefix), + "{0} does not end with '{1}'".format( + repr(haystack), prefix)) + + def assertStartsWith(self, haystack, prefix): + self.assertTrue(haystack.startswith(prefix), + "{0} does not start with '{1}'".format( + repr(haystack), prefix)) + + def test_get_metadata_writes_a_single_line(self): + client = self._get_client() + client.get_metadata('some_key') + self.assertEqual(1, self.serial.write.call_count) + written_line = self.serial.write.call_args[0][0] + self.assertEndsWith(written_line, b'\n') + self.assertEqual(1, written_line.count(b'\n')) + + def _get_written_line(self, key='some_key'): + client = self._get_client() + client.get_metadata(key) + return self.serial.write.call_args[0][0] + + def test_get_metadata_writes_bytes(self): + self.assertIsInstance(self._get_written_line(), six.binary_type) + + def test_get_metadata_line_starts_with_v2(self): + self.assertStartsWith(self._get_written_line(), b'V2') + + def test_get_metadata_uses_get_command(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + self.assertEqual('GET', parts[4]) + + def test_get_metadata_base64_encodes_argument(self): + key = 'my_key' + parts = self._get_written_line(key).decode('ascii').strip().split(' ') + self.assertEqual(b64e(key), parts[5]) + + def test_get_metadata_calculates_length_correctly(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + expected_length = len(' '.join(parts[3:])) + self.assertEqual(expected_length, int(parts[1])) + + def test_get_metadata_uses_appropriate_request_id(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + request_id = parts[3] + self.assertEqual(8, len(request_id)) + self.assertEqual(request_id, request_id.lower()) + + def test_get_metadata_uses_random_number_for_request_id(self): + line = self._get_written_line() + request_id = line.decode('ascii').strip().split(' ')[3] + self.assertEqual('{0:08x}'.format(self.request_id), request_id) + + def test_get_metadata_checksums_correctly(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + expected_checksum = '{0:08x}'.format( + crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff) + checksum = parts[2] + self.assertEqual(expected_checksum, checksum) + + def test_get_metadata_reads_a_line(self): + client = self._get_client() + client.get_metadata('some_key') + self.assertEqual(1, self.serial.readline.call_count) + + def test_get_metadata_returns_valid_value(self): + client = self._get_client() + value = client.get_metadata('some_key') + self.assertEqual(self.metadata_value, value) + + def test_get_metadata_throws_exception_for_incorrect_length(self): + self.response_parts['length'] = 0 + client = self._get_client() + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client.get_metadata, 'some_key') + + def test_get_metadata_throws_exception_for_incorrect_crc(self): + self.response_parts['crc'] = 'deadbeef' + client = self._get_client() + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client.get_metadata, 'some_key') + + def test_get_metadata_throws_exception_for_request_id_mismatch(self): + self.response_parts['request_id'] = 'deadbeef' + client = self._get_client() + client._checksum = lambda _: self.response_parts['crc'] + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client.get_metadata, 'some_key') + + def test_get_metadata_returns_None_if_value_not_found(self): + self.response_parts['payload'] = '' + self.response_parts['command'] = 'NOTFOUND' + self.response_parts['length'] = 17 + client = self._get_client() + client._checksum = lambda _: self.response_parts['crc'] + self.assertIsNone(client.get_metadata('some_key')) -- cgit v1.2.3 From 7b8b62490303928aa4d68271fdd73b2e9712c504 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 25 Mar 2015 15:54:10 +0000 Subject: Refactor tests to assume JoyentMetadataClient is correct. We are treating JoyentMetadataClient as a unit which the data source depends on, so we mock it out instead of providing a fake implementation of it. --- tests/unittests/test_datasource/test_smartos.py | 88 ++++--------------------- 1 file changed, 13 insertions(+), 75 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index c79cf3aa..39991cc2 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -40,7 +40,7 @@ import six from cloudinit import helpers as c_helpers from cloudinit.sources import DataSourceSmartOS -from cloudinit.util import b64d, b64e +from cloudinit.util import b64e from .. import helpers @@ -66,76 +66,15 @@ MOCK_RETURNS = { DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc') -def _checksum(body): - return '{0:08x}'.format(crc32(body.encode('utf-8')) & 0xffffffff) - - -def _generate_v2_frame(request_id, command, body=None): - body_parts = [request_id, command] - if body: - body_parts.append(b64e(body)) - message_body = ' '.join(body_parts) - return 'V2 {0} {1} {2}\n'.format( - len(message_body), _checksum(message_body), message_body).encode( - 'ascii') - - -def _parse_v2_frame(line): - line = line.decode('ascii') - if not line.endswith('\n'): - raise Exception('Frames must end with a newline.') - version, length, checksum, body = line.strip().split(' ', 3) - if version != 'V2': - raise Exception('Frames must begin with V2.') - if int(length) != len(body): - raise Exception('Incorrect frame length given ({0} != {1}).'.format( - length, len(body))) - expected_checksum = _checksum(body) - if checksum != expected_checksum: - raise Exception('Invalid checksum.') - request_id, command, payload = body.split() - return request_id, command, b64d(payload) - - -class MockSerial(object): - """Fake a serial terminal for testing the code that - interfaces with the serial""" - - port = None - - def __init__(self, mockdata): - self.last = None - self.last = None - self.new = True - self.count = 0 - self.mocked_out = [] - self.mockdata = mockdata - - def open(self): - return True - - def close(self): - return True - - def isOpen(self): - return True - - def write(self, line): - self.last = line - - def readline(self): - if self.last == '\n': - return 'invalid command\n' - elif self.last == 'NEGOTIATE V2\n': - return 'V2_OK\n' - request_id, command, request_body = _parse_v2_frame(self.last) - if command != 'GET': - raise Exception('MockSerial only supports GET requests.') - metadata_key = request_body.strip() - if metadata_key in self.mockdata: - return _generate_v2_frame( - request_id, 'SUCCESS', self.mockdata[metadata_key]) - return _generate_v2_frame(request_id, 'NOTFOUND') +def get_mock_client(mockdata): + class MockMetadataClient(object): + + def __init__(self, serial): + pass + + def get_metadata(self, metadata_key): + return mockdata.get(metadata_key) + return MockMetadataClient class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): @@ -183,9 +122,6 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): if dmi_data is None: dmi_data = DMI_DATA_RETURN - def _get_serial(*_): - return MockSerial(mockdata) - def _dmi_data(): return dmi_data @@ -202,7 +138,9 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): sys_cfg['datasource']['SmartOS'] = ds_cfg self.apply_patches([(mod, 'LEGACY_USER_D', self.legacy_user_d)]) - self.apply_patches([(mod, 'get_serial', _get_serial)]) + self.apply_patches([(mod, 'get_serial', mock.MagicMock())]) + self.apply_patches([ + (mod, 'JoyentMetadataClient', get_mock_client(mockdata))]) self.apply_patches([(mod, 'dmi_data', _dmi_data)]) self.apply_patches([(os, 'uname', _os_uname)]) self.apply_patches([(mod, 'device_exists', lambda d: True)]) -- cgit v1.2.3 From d52feae7ad38670964edebb0eea5db2c8c80f760 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 25 Mar 2015 15:54:19 +0000 Subject: Ensure that the serial console is always closed. --- cloudinit/sources/DataSourceSmartOS.py | 9 +++++---- tests/unittests/test_datasource/test_smartos.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 61dd044f..237fc140 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -30,9 +30,11 @@ # Comments with "@datadictionary" are snippets of the definition import binascii +import contextlib import os import random import re + import serial from cloudinit import log as logging @@ -371,11 +373,10 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, if not noun: return False - ser = get_serial(seed_device, seed_timeout) + with contextlib.closing(get_serial(seed_device, seed_timeout)) as ser: + client = JoyentMetadataClient(ser) + response = client.get_metadata(noun) - client = JoyentMetadataClient(ser) - response = client.get_metadata(noun) - ser.close() if response is None: return default diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 39991cc2..28b41eaf 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -409,6 +409,18 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): self.assertEqual(dsrc.device_name_to_device('FOO'), mydscfg['disk_aliases']['FOO']) + @mock.patch('cloudinit.sources.DataSourceSmartOS.JoyentMetadataClient') + @mock.patch('cloudinit.sources.DataSourceSmartOS.get_serial') + def test_serial_console_closed_on_error(self, get_serial, metadata_client): + class OurException(Exception): + pass + metadata_client.side_effect = OurException + try: + DataSourceSmartOS.query_data('noun', 'device', 0) + except OurException: + pass + self.assertEqual(1, get_serial.return_value.close.call_count) + def apply_patches(patches): ret = [] -- cgit v1.2.3 From a373e1097f6be460914e6cbbc897c6aa8e4aaefe Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 26 Mar 2015 20:39:25 -0400 Subject: commit work in progress. tests pass. --- cloudinit/config/cc_snappy.py | 159 +++++++++++++++----- .../unittests/test_handler/test_handler_snappy.py | 163 +++++++++++++++++++++ 2 files changed, 285 insertions(+), 37 deletions(-) create mode 100644 tests/unittests/test_handler/test_handler_snappy.py (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index 133336d4..bef8c170 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -7,18 +7,21 @@ from cloudinit import util from cloudinit.settings import PER_INSTANCE import glob +import six +import tempfile import os LOG = logging.getLogger(__name__) frequency = PER_INSTANCE -SNAPPY_ENV_PATH = "/writable/system-data/etc/snappy.env" +SNAPPY_CMD = "snappy" BUILTIN_CFG = { 'packages': [], 'packages_dir': '/writable/user-data/cloud-init/click_packages', 'ssh_enabled': False, - 'system_snappy': "auto" + 'system_snappy': "auto", + 'configs': {}, } """ @@ -27,43 +30,111 @@ snappy: ssh_enabled: True packages: - etcd - - {'name': 'pkg1', 'config': "wark"} + - pkg2 + configs: + pkgname: config-blob + pkgname2: config-blob """ -def install_package(pkg_name, config=None): - cmd = ["snappy", "install"] - if config: - if os.path.isfile(config): - cmd.append("--config-file=" + config) +def get_fs_package_ops(fspath): + if not fspath: + return [] + ops = [] + for snapfile in glob.glob(os.path.sep.join([fspath, '*.snap'])): + cfg = snapfile.rpartition(".")[0] + ".config" + name = os.path.basename(snapfile).rpartition(".")[0] + if not os.path.isfile(cfg): + cfg = None + ops.append(makeop('install', name, config=None, + path=snapfile, cfgfile=cfg)) + return ops + + +def makeop(op, name, config=None, path=None, cfgfile=None): + return({'op': op, 'name': name, 'config': config, 'path': path, + 'cfgfile': cfgfile}) + + +def get_package_ops(packages, configs, installed=None, fspath=None): + # get the install an config operations that should be done + if installed is None: + installed = read_installed_packages() + + if not packages: + packages = [] + if not configs: + configs = {} + + ops = [] + ops += get_fs_package_ops(fspath) + + for name in packages: + ops.append(makeop('install', name, configs.get('name'))) + + to_install = [f['name'] for f in ops] + + for name in configs: + if name in installed and name not in to_install: + ops.append(makeop('config', name, config=configs[name])) + + # prefer config entries to filepath entries + for op in ops: + name = op['name'] + if name in configs and op['op'] == 'install' and 'cfgfile' in op: + LOG.debug("preferring configs[%s] over '%s'", name, op['cfgfile']) + op['cfgfile'] = None + op['config'] = configs[op['name']] + + return ops + + +def render_snap_op(op, name, path=None, cfgfile=None, config=None): + if op not in ('install', 'config'): + raise ValueError("cannot render op '%s'" % op) + + try: + cfg_tmpf = None + if config is not None: + if isinstance(config, six.binary_type): + cfg_bytes = config + elif isinstance(config, six.text_type): + cfg_bytes = config_data.encode() + else: + cfg_bytes = yaml.safe_dump(config).encode() + + (fd, cfg_tmpf) = tempfile.mkstemp() + os.write(fd, config_data) + os.close(fd) + cfgfile = cfg_tmpf + + cmd = [SNAPPY_CMD, op] + if op == 'install' and cfgfile: + cmd.append('--config=' + cfgfile) + elif op == 'config': + cmd.append(cfgfile) + + util.subp(cmd) + + finally: + if tmpfile: + os.unlink(tmpfile) + + +def read_installed_packages(): + return [p[0] for p in read_pkg_data()] + + +def read_pkg_data(): + out, err = util.subp([SNAPPY_CMD, "list"]) + for line in out.splitlines()[1:]: + toks = line.split(sep=None, maxsplit=3) + if len(toks) == 3: + (name, date, version) = toks + dev = None else: - cmd.append("--config=" + config) - cmd.append(pkg_name) - util.subp(cmd) - - -def install_packages(package_dir, packages): - local_pkgs = glob.glob(os.path.sep.join([package_dir, '*.click'])) - LOG.debug("installing local packages %s" % local_pkgs) - if local_pkgs: - for pkg in local_pkgs: - cfg = pkg.replace(".click", ".config") - if not os.path.isfile(cfg): - cfg = None - install_package(pkg, config=cfg) - - LOG.debug("installing click packages") - if packages: - for pkg in packages: - if not pkg: - continue - if isinstance(pkg, str): - name = pkg - config = None - elif pkg: - name = pkg.get('name', pkg) - config = pkg.get('config') - install_package(pkg_name=name, config=config) + (name, date, version, dev) = toks + pkgs.append((name, date, version, dev,)) def disable_enable_ssh(enabled): @@ -107,7 +178,21 @@ def handle(name, cfg, cloud, log, args): LOG.debug("%s: 'auto' mode, and system not snappy", name) return - install_packages(mycfg['packages_dir'], - mycfg['packages']) + pkg_ops = get_package_ops(packages=mycfg['packages'], + configs=mycfg['configs'], + fspath=mycfg['packages_dir']) + + fails = [] + for pkg_op in pkg_ops: + try: + render_snap_op(op=pkg_op['op'], name=pkg_op['name'], + cfgfile=pkg_op['cfgfile'], config=pkg_op['config']) + except Exception as e: + fails.append((pkg_op, e,)) + LOG.warn("'%s' failed for '%s': %s", + pkg_op['op'], pkg_op['name'], e) disable_enable_ssh(mycfg.get('ssh_enabled', False)) + + if fails: + raise Exception("failed to install/configure snaps") diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py new file mode 100644 index 00000000..6b6d3584 --- /dev/null +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -0,0 +1,163 @@ +from cloudinit.config.cc_snappy import (makeop, get_package_ops) +from cloudinit import util +from .. import helpers as t_help + +import os +import tempfile + +class TestInstallPackages(t_help.TestCase): + def setUp(self): + super(TestInstallPackages, self).setUp() + self.unapply = [] + + # by default 'which' has nothing in its path + self.apply_patches([(util, 'subp', self._subp)]) + self.subp_called = [] + self.snapcmds = [] + self.tmp = tempfile.mkdtemp() + + def tearDown(self): + apply_patches([i for i in reversed(self.unapply)]) + + def apply_patches(self, patches): + ret = apply_patches(patches) + self.unapply += ret + + def _subp(self, *args, **kwargs): + # supports subp calling with cmd as args or kwargs + if 'args' not in kwargs: + kwargs['args'] = args[0] + self.subp_called.append(kwargs) + snap_cmds = [] + args = kwargs['args'] + if args[0:2] == ['snappy', 'config']: + if args[3] == "-": + config = kwargs.get('data', '') + else: + with open(args[3], "rb") as fp: + config = fp.read() + snap_cmds.append(('config', args[2], config,)) + elif args[0:2] == ['snappy', 'install']: + # basically parse the snappy command and add + # to snap_installs a tuple (pkg, config) + config = None + pkg = None + for arg in args[2:]: + if arg.startswith("--config="): + cfgfile = arg.partition("=")[2] + if cfgfile == "-": + config = kwargs.get('data', '') + elif cfgfile: + with open(cfgfile, "rb") as fp: + config = fp.read() + elif not pkg and not arg.startswith("-"): + pkg = os.path.basename(arg) + self.snap_installs.append(('install', pkg, config,)) + + def test_package_ops_1(self): + ret = get_package_ops( + packages=['pkg1', 'pkg2', 'pkg3'], + configs={'pkg2': b'mycfg2'}, installed=[]) + self.assertEqual(ret, + [makeop('install', 'pkg1', None, None), + makeop('install', 'pkg2', b'mycfg2', None), + makeop('install', 'pkg3', None, None)]) + + def test_package_ops_config_only(self): + ret = get_package_ops( + packages=None, + configs={'pkg2': b'mycfg2'}, installed=['pkg1', 'pkg2']) + self.assertEqual(ret, + [makeop('config', 'pkg2', b'mycfg2')]) + + def test_package_ops_install_and_config(self): + ret = get_package_ops( + packages=['pkg3', 'pkg2'], + configs={'pkg2': b'mycfg2', 'xinstalled': b'xcfg'}, + installed=['xinstalled']) + self.assertEqual(ret, + [makeop('install', 'pkg3'), + makeop('install', 'pkg2', b'mycfg2'), + makeop('config', 'xinstalled', b'xcfg')]) + + def test_package_ops_with_file(self): + t_help.populate_dir(self.tmp, + {"snapf1.snap": b"foo1", "snapf1.config": b"snapf1cfg", + "snapf2.snap": b"foo2", "foo.bar": "ignored"}) + ret = get_package_ops( + packages=['pkg1'], configs={}, installed=[], fspath=self.tmp) + self.assertEqual(ret, + [makeop_tmpd(self.tmp, 'install', 'snapf1', path="snapf1.snap", + cfgfile="snapf1.config"), + makeop_tmpd(self.tmp, 'install', 'snapf2', path="snapf2.snap"), + makeop('install', 'pkg1')]) + + +def makeop_tmpd(tmpd, op, name, config=None, path=None, cfgfile=None): + if cfgfile: + cfgfile = os.path.sep.join([tmpd, cfgfile]) + if path: + path = os.path.sep.join([tmpd, path]) + return(makeop(op=op, name=name, config=config, path=path, cfgfile=cfgfile)) + +# def test_local_snaps_no_config(self): +# t_help.populate_dir(self.tmp, +# {"snap1.snap": b"foo", "snap2.snap": b"foo", "foosnap.txt": b"foo"}) +# cc_snappy.install_packages(self.tmp, None) +# self.assertEqual(self.snap_installs, +# [("snap1.snap", None), ("snap2.snap", None)]) +# +# def test_local_snaps_mixed_config(self): +# t_help.populate_dir(self.tmp, +# {"snap1.snap": b"foo", "snap2.snap": b"snap2", +# "snap1.config": b"snap1config"}) +# cc_snappy.install_packages(self.tmp, None) +# self.assertEqual(self.snap_installs, +# [("snap1.snap", b"snap1config"), ("snap2.snap", None)]) +# +# def test_local_snaps_all_config(self): +# t_help.populate_dir(self.tmp, +# {"snap1.snap": "foo", "snap1.config": b"snap1config", +# "snap2.snap": "snap2", "snap2.config": b"snap2config"}) +# cc_snappy.install_packages(self.tmp, None) +# self.assertEqual(self.snap_installs, +# [("snap1.snap", b"snap1config"), ("snap2.snap", b"snap2config")]) +# +# def test_local_snaps_and_packages(self): +# t_help.populate_dir(self.tmp, +# {"snap1.snap": "foo", "snap1.config": b"snap1config"}) +# cc_snappy.install_packages(self.tmp, ["snap-in-store"]) +# self.assertEqual(self.snap_installs, +# [("snap1.snap", b"snap1config"), ("snap-in-store", None)]) +# +# def test_packages_no_config(self): +# cc_snappy.install_packages(self.tmp, ["snap-in-store"]) +# self.assertEqual(self.snap_installs, +# [("snap-in-store", None)]) +# +# def test_packages_mixed_config(self): +# cc_snappy.install_packages(self.tmp, +# ["snap-in-store", +# {'name': 'snap2-in-store', 'config': b"foo"}]) +# self.assertEqual(self.snap_installs, +# [("snap-in-store", None), ("snap2-in-store", b"foo")]) +# +# def test_packages_all_config(self): +# cc_snappy.install_packages(self.tmp, +# [{'name': 'snap1-in-store', 'config': b"boo"}, +# {'name': 'snap2-in-store', 'config': b"wark"}]) +# self.assertEqual(self.snap_installs, +# [("snap1-in-store", b"boo"), ("snap2-in-store", b"wark")]) +# +# + +def apply_patches(patches): + ret = [] + for (ref, name, replace) in patches: + if replace is None: + continue + orig = getattr(ref, name) + setattr(ref, name, replace) + ret.append((ref, name, orig)) + return ret + -- cgit v1.2.3 From bd7165dd67338f742f999fb2c53ec5f67fc66477 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 26 Mar 2015 21:14:17 -0400 Subject: start of snap_op tests --- cloudinit/config/cc_snappy.py | 10 ++++-- .../unittests/test_handler/test_handler_snappy.py | 40 +++++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index bef8c170..cf441c92 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -114,11 +114,17 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): elif op == 'config': cmd.append(cfgfile) + if op == 'install': + if path: + cmd.append(path) + else: + cmd.append(name) + util.subp(cmd) finally: - if tmpfile: - os.unlink(tmpfile) + if cfg_tmpf: + os.unlink(cfg_tmpf) def read_installed_packages(): diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index 6b6d3584..7dc77970 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -1,4 +1,5 @@ -from cloudinit.config.cc_snappy import (makeop, get_package_ops) +from cloudinit.config.cc_snappy import ( + makeop, get_package_ops, render_snap_op) from cloudinit import util from .. import helpers as t_help @@ -36,7 +37,7 @@ class TestInstallPackages(t_help.TestCase): else: with open(args[3], "rb") as fp: config = fp.read() - snap_cmds.append(('config', args[2], config,)) + self.snapcmds.append(['config', args[2], config]) elif args[0:2] == ['snappy', 'install']: # basically parse the snappy command and add # to snap_installs a tuple (pkg, config) @@ -51,8 +52,8 @@ class TestInstallPackages(t_help.TestCase): with open(cfgfile, "rb") as fp: config = fp.read() elif not pkg and not arg.startswith("-"): - pkg = os.path.basename(arg) - self.snap_installs.append(('install', pkg, config,)) + pkg = arg + self.snapcmds.append(['install', pkg, config]) def test_package_ops_1(self): ret = get_package_ops( @@ -92,6 +93,37 @@ class TestInstallPackages(t_help.TestCase): makeop_tmpd(self.tmp, 'install', 'snapf2', path="snapf2.snap"), makeop('install', 'pkg1')]) + #def render_snap_op(op, name, path=None, cfgfile=None, config=None): + def test_render_op_localsnap(self): + t_help.populate_dir(self.tmp, {"snapf1.snap": b"foo1"}) + op = makeop_tmpd(self.tmp, 'install', 'snapf1', + path='snapf1.snap') + render_snap_op(**op) + self.assertEqual(self.snapcmds, + [['install', op['path'], None]]) + + def test_render_op_localsnap_localconfig(self): + t_help.populate_dir(self.tmp, + {"snapf1.snap": b"foo1", 'snapf1.config': b'snapf1cfg'}) + op = makeop_tmpd(self.tmp, 'install', 'snapf1', + path='snapf1.snap', cfgfile='snapf1.config') + render_snap_op(**op) + self.assertEqual(self.snapcmds, + [['install', op['path'], b'snapf1cfg']]) + + def test_render_op_localsnap_config(self): + pass + + def test_render_op_snap(self): + pass + + def test_render_op_snap_config(self): + pass + + def test_render_op_config(self): + pass + + def makeop_tmpd(tmpd, op, name, config=None, path=None, cfgfile=None): if cfgfile: -- cgit v1.2.3 From df43c6bd3726c9a34b9f8ff4bbf75957aa751011 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 26 Mar 2015 21:55:26 -0400 Subject: pep8, and some more tests --- cloudinit/config/cc_snappy.py | 13 +- .../unittests/test_handler/test_handler_snappy.py | 131 ++++++++------------- 2 files changed, 57 insertions(+), 87 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index cf441c92..c926ae0a 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -99,26 +99,25 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): if isinstance(config, six.binary_type): cfg_bytes = config elif isinstance(config, six.text_type): - cfg_bytes = config_data.encode() + cfg_bytes = config.encode() else: cfg_bytes = yaml.safe_dump(config).encode() (fd, cfg_tmpf) = tempfile.mkstemp() - os.write(fd, config_data) + os.write(fd, cfg_bytes) os.close(fd) cfgfile = cfg_tmpf cmd = [SNAPPY_CMD, op] - if op == 'install' and cfgfile: - cmd.append('--config=' + cfgfile) - elif op == 'config': - cmd.append(cfgfile) - if op == 'install': + if cfgfile: + cmd.append('--config=' + cfgfile) if path: cmd.append(path) else: cmd.append(name) + elif op == 'config': + cmd += [name, cfgfile] util.subp(cmd) diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index 7dc77970..8759a07d 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -4,8 +4,10 @@ from cloudinit import util from .. import helpers as t_help import os +import shutil import tempfile + class TestInstallPackages(t_help.TestCase): def setUp(self): super(TestInstallPackages, self).setUp() @@ -15,15 +17,19 @@ class TestInstallPackages(t_help.TestCase): self.apply_patches([(util, 'subp', self._subp)]) self.subp_called = [] self.snapcmds = [] - self.tmp = tempfile.mkdtemp() + self.tmp = tempfile.mkdtemp(prefix="TestInstallPackages") def tearDown(self): apply_patches([i for i in reversed(self.unapply)]) + shutil.rmtree(self.tmp) def apply_patches(self, patches): ret = apply_patches(patches) self.unapply += ret + def populate_tmp(self, files): + return t_help.populate_dir(self.tmp, files) + def _subp(self, *args, **kwargs): # supports subp calling with cmd as args or kwargs if 'args' not in kwargs: @@ -31,6 +37,8 @@ class TestInstallPackages(t_help.TestCase): self.subp_called.append(kwargs) snap_cmds = [] args = kwargs['args'] + # here we basically parse the snappy command invoked + # and append to snapcmds a list of (mode, pkg, config) if args[0:2] == ['snappy', 'config']: if args[3] == "-": config = kwargs.get('data', '') @@ -39,8 +47,6 @@ class TestInstallPackages(t_help.TestCase): config = fp.read() self.snapcmds.append(['config', args[2], config]) elif args[0:2] == ['snappy', 'install']: - # basically parse the snappy command and add - # to snap_installs a tuple (pkg, config) config = None pkg = None for arg in args[2:]: @@ -59,72 +65,88 @@ class TestInstallPackages(t_help.TestCase): ret = get_package_ops( packages=['pkg1', 'pkg2', 'pkg3'], configs={'pkg2': b'mycfg2'}, installed=[]) - self.assertEqual(ret, - [makeop('install', 'pkg1', None, None), - makeop('install', 'pkg2', b'mycfg2', None), - makeop('install', 'pkg3', None, None)]) + self.assertEqual( + ret, [makeop('install', 'pkg1', None, None), + makeop('install', 'pkg2', b'mycfg2', None), + makeop('install', 'pkg3', None, None)]) def test_package_ops_config_only(self): ret = get_package_ops( packages=None, configs={'pkg2': b'mycfg2'}, installed=['pkg1', 'pkg2']) - self.assertEqual(ret, - [makeop('config', 'pkg2', b'mycfg2')]) + self.assertEqual( + ret, [makeop('config', 'pkg2', b'mycfg2')]) def test_package_ops_install_and_config(self): ret = get_package_ops( packages=['pkg3', 'pkg2'], configs={'pkg2': b'mycfg2', 'xinstalled': b'xcfg'}, installed=['xinstalled']) - self.assertEqual(ret, - [makeop('install', 'pkg3'), - makeop('install', 'pkg2', b'mycfg2'), - makeop('config', 'xinstalled', b'xcfg')]) + self.assertEqual( + ret, [makeop('install', 'pkg3'), + makeop('install', 'pkg2', b'mycfg2'), + makeop('config', 'xinstalled', b'xcfg')]) def test_package_ops_with_file(self): - t_help.populate_dir(self.tmp, + self.populate_tmp( {"snapf1.snap": b"foo1", "snapf1.config": b"snapf1cfg", "snapf2.snap": b"foo2", "foo.bar": "ignored"}) ret = get_package_ops( packages=['pkg1'], configs={}, installed=[], fspath=self.tmp) - self.assertEqual(ret, + self.assertEqual( + ret, [makeop_tmpd(self.tmp, 'install', 'snapf1', path="snapf1.snap", cfgfile="snapf1.config"), makeop_tmpd(self.tmp, 'install', 'snapf2', path="snapf2.snap"), makeop('install', 'pkg1')]) - #def render_snap_op(op, name, path=None, cfgfile=None, config=None): + def test_package_ops_config_overrides_file(self): + # config data overrides local file .config + self.populate_tmp( + {"snapf1.snap": b"foo1", "snapf1.config": b"snapf1cfg"}) + ret = get_package_ops( + packages=[], configs={'snapf1': 'snapf1cfg-config'}, + installed=[], fspath=self.tmp) + self.assertEqual( + ret, [makeop_tmpd(self.tmp, 'install', 'snapf1', + path="snapf1.snap", config="snapf1cfg-config")]) + def test_render_op_localsnap(self): - t_help.populate_dir(self.tmp, {"snapf1.snap": b"foo1"}) + self.populate_tmp({"snapf1.snap": b"foo1"}) op = makeop_tmpd(self.tmp, 'install', 'snapf1', path='snapf1.snap') render_snap_op(**op) - self.assertEqual(self.snapcmds, - [['install', op['path'], None]]) + self.assertEqual( + self.snapcmds, [['install', op['path'], None]]) def test_render_op_localsnap_localconfig(self): - t_help.populate_dir(self.tmp, + self.populate_tmp( {"snapf1.snap": b"foo1", 'snapf1.config': b'snapf1cfg'}) op = makeop_tmpd(self.tmp, 'install', 'snapf1', path='snapf1.snap', cfgfile='snapf1.config') render_snap_op(**op) - self.assertEqual(self.snapcmds, - [['install', op['path'], b'snapf1cfg']]) - - def test_render_op_localsnap_config(self): - pass + self.assertEqual( + self.snapcmds, [['install', op['path'], b'snapf1cfg']]) def test_render_op_snap(self): - pass + op = makeop('install', 'snapf1') + render_snap_op(**op) + self.assertEqual( + self.snapcmds, [['install', 'snapf1', None]]) def test_render_op_snap_config(self): - pass + op = makeop('install', 'snapf1', config=b'myconfig') + render_snap_op(**op) + self.assertEqual( + self.snapcmds, [['install', 'snapf1', b'myconfig']]) def test_render_op_config(self): - pass + op = makeop('config', 'snapf1', config=b'myconfig') + render_snap_op(**op) + self.assertEqual( + self.snapcmds, [['config', 'snapf1', b'myconfig']]) - def makeop_tmpd(tmpd, op, name, config=None, path=None, cfgfile=None): if cfgfile: cfgfile = os.path.sep.join([tmpd, cfgfile]) @@ -132,56 +154,6 @@ def makeop_tmpd(tmpd, op, name, config=None, path=None, cfgfile=None): path = os.path.sep.join([tmpd, path]) return(makeop(op=op, name=name, config=config, path=path, cfgfile=cfgfile)) -# def test_local_snaps_no_config(self): -# t_help.populate_dir(self.tmp, -# {"snap1.snap": b"foo", "snap2.snap": b"foo", "foosnap.txt": b"foo"}) -# cc_snappy.install_packages(self.tmp, None) -# self.assertEqual(self.snap_installs, -# [("snap1.snap", None), ("snap2.snap", None)]) -# -# def test_local_snaps_mixed_config(self): -# t_help.populate_dir(self.tmp, -# {"snap1.snap": b"foo", "snap2.snap": b"snap2", -# "snap1.config": b"snap1config"}) -# cc_snappy.install_packages(self.tmp, None) -# self.assertEqual(self.snap_installs, -# [("snap1.snap", b"snap1config"), ("snap2.snap", None)]) -# -# def test_local_snaps_all_config(self): -# t_help.populate_dir(self.tmp, -# {"snap1.snap": "foo", "snap1.config": b"snap1config", -# "snap2.snap": "snap2", "snap2.config": b"snap2config"}) -# cc_snappy.install_packages(self.tmp, None) -# self.assertEqual(self.snap_installs, -# [("snap1.snap", b"snap1config"), ("snap2.snap", b"snap2config")]) -# -# def test_local_snaps_and_packages(self): -# t_help.populate_dir(self.tmp, -# {"snap1.snap": "foo", "snap1.config": b"snap1config"}) -# cc_snappy.install_packages(self.tmp, ["snap-in-store"]) -# self.assertEqual(self.snap_installs, -# [("snap1.snap", b"snap1config"), ("snap-in-store", None)]) -# -# def test_packages_no_config(self): -# cc_snappy.install_packages(self.tmp, ["snap-in-store"]) -# self.assertEqual(self.snap_installs, -# [("snap-in-store", None)]) -# -# def test_packages_mixed_config(self): -# cc_snappy.install_packages(self.tmp, -# ["snap-in-store", -# {'name': 'snap2-in-store', 'config': b"foo"}]) -# self.assertEqual(self.snap_installs, -# [("snap-in-store", None), ("snap2-in-store", b"foo")]) -# -# def test_packages_all_config(self): -# cc_snappy.install_packages(self.tmp, -# [{'name': 'snap1-in-store', 'config': b"boo"}, -# {'name': 'snap2-in-store', 'config': b"wark"}]) -# self.assertEqual(self.snap_installs, -# [("snap1-in-store", b"boo"), ("snap2-in-store", b"wark")]) -# -# def apply_patches(patches): ret = [] @@ -192,4 +164,3 @@ def apply_patches(patches): setattr(ref, name, replace) ret.append((ref, name, orig)) return ret - -- cgit v1.2.3 From 4c341a87d4b0804565e74e6335a0293dab6c0c7b Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 26 Mar 2015 22:10:01 -0400 Subject: add tests for data types --- cloudinit/config/cc_snappy.py | 2 +- .../unittests/test_handler/test_handler_snappy.py | 35 +++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index d1447fe5..bd928e54 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -101,7 +101,7 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): elif isinstance(config, six.text_type): cfg_bytes = config.encode() else: - cfg_bytes = yaml.safe_dump(config).encode() + cfg_bytes = util.yaml_dumps(config) (fd, cfg_tmpf) = tempfile.mkstemp() os.write(fd, cfg_bytes) diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index 8759a07d..81d891d2 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -140,12 +140,45 @@ class TestInstallPackages(t_help.TestCase): self.assertEqual( self.snapcmds, [['install', 'snapf1', b'myconfig']]) - def test_render_op_config(self): + def test_render_op_config_bytes(self): op = makeop('config', 'snapf1', config=b'myconfig') render_snap_op(**op) self.assertEqual( self.snapcmds, [['config', 'snapf1', b'myconfig']]) + def test_render_op_config_string(self): + mycfg = 'myconfig: foo\nhisconfig: bar\n' + op = makeop('config', 'snapf1', config=mycfg) + render_snap_op(**op) + self.assertEqual( + self.snapcmds, [['config', 'snapf1', mycfg.encode()]]) + + def test_render_op_config_dict(self): + # config entry for package can be a dict, not a string blob + mycfg = {'foo': 'bar'} + op = makeop('config', 'snapf1', config=mycfg) + render_snap_op(**op) + # snapcmds is a list of 3-entry lists. data_found will be the + # blob of data in the file in 'snappy install --config=' + data_found = self.snapcmds[0][2] + self.assertEqual(mycfg, util.load_yaml(data_found)) + + def test_render_op_config_list(self): + # config entry for package can be a list, not a string blob + mycfg = ['foo', 'bar', 'wark', {'f1': 'b1'}] + op = makeop('config', 'snapf1', config=mycfg) + render_snap_op(**op) + data_found = self.snapcmds[0][2] + self.assertEqual(mycfg, util.load_yaml(data_found, allowed=(list,))) + + def test_render_op_config_int(self): + # config entry for package can be a list, not a string blob + mycfg = 1 + op = makeop('config', 'snapf1', config=mycfg) + render_snap_op(**op) + data_found = self.snapcmds[0][2] + self.assertEqual(mycfg, util.load_yaml(data_found, allowed=(int,))) + def makeop_tmpd(tmpd, op, name, config=None, path=None, cfgfile=None): if cfgfile: -- cgit v1.2.3 From b7b8004bc58f1243d023092e67f3b78743086ff2 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 27 Mar 2015 11:11:05 -0400 Subject: change 'configs' to 'config', and namespace input to 'snappy config' the input to 'snappy config ' is expected to have config: : content: So here we pad that input correctly. Note, that a .config file on disk is not modified. Also, we change 'configs' to just be 'config', to be possibly compatible with the a future 'snappy config /' that dumped: config: pkg1: data1 pkg2: data2 --- cloudinit/config/cc_snappy.py | 29 +++++----- .../unittests/test_handler/test_handler_snappy.py | 61 ++++++++++++++++------ 2 files changed, 61 insertions(+), 29 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index e664234a..a3af98a6 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -8,9 +8,11 @@ Example config: system_snappy: auto ssh_enabled: False packages: [etcd, pkg2] - configs: - pkgname: pkgname-config-blob - pkg2: config-blob + config: + pkgname: + key2: value2 + pkg2: + key1: value1 packages_dir: '/writable/user-data/cloud-init/snaps' - ssh_enabled: @@ -31,7 +33,7 @@ Example config: /foo.snap snappy install /bar.snap - Note, that if provided a 'configs' entry for 'ubuntu-core', then + Note, that if provided a 'config' entry for 'ubuntu-core', then cloud-init will invoke: snappy config ubuntu-core Allowing you to configure ubuntu-core in this way. """ @@ -56,7 +58,7 @@ BUILTIN_CFG = { 'packages_dir': '/writable/user-data/cloud-init/snaps', 'ssh_enabled': False, 'system_snappy': "auto", - 'configs': {}, + 'config': {}, } @@ -119,15 +121,14 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): try: cfg_tmpf = None if config is not None: - if isinstance(config, six.binary_type): - cfg_bytes = config - elif isinstance(config, six.text_type): - cfg_bytes = config.encode() - else: - cfg_bytes = util.yaml_dumps(config).encode() - + # input to 'snappy config packagename' must have nested data. odd. + # config: + # packagename: + # config + # Note, however, we do not touch config files on disk. + nested_cfg = {'config': {name: config}} (fd, cfg_tmpf) = tempfile.mkstemp() - os.write(fd, cfg_bytes) + os.write(fd, util.yaml_dumps(nested_cfg).encode()) os.close(fd) cfgfile = cfg_tmpf @@ -218,7 +219,7 @@ def handle(name, cfg, cloud, log, args): return pkg_ops = get_package_ops(packages=mycfg['packages'], - configs=mycfg['configs'], + configs=mycfg['config'], fspath=mycfg['packages_dir']) set_snappy_command() diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index 81d891d2..f56a22f7 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -6,6 +6,9 @@ from .. import helpers as t_help import os import shutil import tempfile +import yaml + +ALLOWED = (dict, list, int, str) class TestInstallPackages(t_help.TestCase): @@ -44,7 +47,7 @@ class TestInstallPackages(t_help.TestCase): config = kwargs.get('data', '') else: with open(args[3], "rb") as fp: - config = fp.read() + config = yaml.safe_load(fp.read()) self.snapcmds.append(['config', args[2], config]) elif args[0:2] == ['snappy', 'install']: config = None @@ -56,7 +59,7 @@ class TestInstallPackages(t_help.TestCase): config = kwargs.get('data', '') elif cfgfile: with open(cfgfile, "rb") as fp: - config = fp.read() + config = yaml.safe_load(fp.read()) elif not pkg and not arg.startswith("-"): pkg = arg self.snapcmds.append(['install', pkg, config]) @@ -126,7 +129,7 @@ class TestInstallPackages(t_help.TestCase): path='snapf1.snap', cfgfile='snapf1.config') render_snap_op(**op) self.assertEqual( - self.snapcmds, [['install', op['path'], b'snapf1cfg']]) + self.snapcmds, [['install', op['path'], 'snapf1cfg']]) def test_render_op_snap(self): op = makeop('install', 'snapf1') @@ -135,49 +138,77 @@ class TestInstallPackages(t_help.TestCase): self.snapcmds, [['install', 'snapf1', None]]) def test_render_op_snap_config(self): - op = makeop('install', 'snapf1', config=b'myconfig') + mycfg = {'key1': 'value1'} + name = "snapf1" + op = makeop('install', name, config=mycfg) render_snap_op(**op) self.assertEqual( - self.snapcmds, [['install', 'snapf1', b'myconfig']]) + self.snapcmds, [['install', name, {'config': {name: mycfg}}]]) def test_render_op_config_bytes(self): - op = makeop('config', 'snapf1', config=b'myconfig') + name = "snapf1" + mycfg = b'myconfig' + op = makeop('config', name, config=mycfg) render_snap_op(**op) self.assertEqual( - self.snapcmds, [['config', 'snapf1', b'myconfig']]) + self.snapcmds, [['config', 'snapf1', {'config': {name: mycfg}}]]) def test_render_op_config_string(self): + name = 'snapf1' mycfg = 'myconfig: foo\nhisconfig: bar\n' - op = makeop('config', 'snapf1', config=mycfg) + op = makeop('config', name, config=mycfg) render_snap_op(**op) self.assertEqual( - self.snapcmds, [['config', 'snapf1', mycfg.encode()]]) + self.snapcmds, [['config', 'snapf1', {'config': {name: mycfg}}]]) def test_render_op_config_dict(self): # config entry for package can be a dict, not a string blob mycfg = {'foo': 'bar'} - op = makeop('config', 'snapf1', config=mycfg) + name = 'snapf1' + op = makeop('config', name, config=mycfg) render_snap_op(**op) # snapcmds is a list of 3-entry lists. data_found will be the # blob of data in the file in 'snappy install --config=' data_found = self.snapcmds[0][2] - self.assertEqual(mycfg, util.load_yaml(data_found)) + self.assertEqual(mycfg, data_found['config'][name]) def test_render_op_config_list(self): # config entry for package can be a list, not a string blob mycfg = ['foo', 'bar', 'wark', {'f1': 'b1'}] - op = makeop('config', 'snapf1', config=mycfg) + name = "snapf1" + op = makeop('config', name, config=mycfg) render_snap_op(**op) data_found = self.snapcmds[0][2] - self.assertEqual(mycfg, util.load_yaml(data_found, allowed=(list,))) + self.assertEqual(mycfg, data_found['config'][name]) def test_render_op_config_int(self): # config entry for package can be a list, not a string blob mycfg = 1 - op = makeop('config', 'snapf1', config=mycfg) + name = 'snapf1' + op = makeop('config', name, config=mycfg) render_snap_op(**op) data_found = self.snapcmds[0][2] - self.assertEqual(mycfg, util.load_yaml(data_found, allowed=(int,))) + self.assertEqual(mycfg, data_found['config'][name]) + + def test_render_does_not_pad_cfgfile(self): + # package_ops with cfgfile should not modify --file= content. + mydata = "foo1: bar1\nk: [l1, l2, l3]\n" + self.populate_tmp( + {"snapf1.snap": b"foo1", "snapf1.config": mydata.encode()}) + ret = get_package_ops( + packages=[], configs={}, installed=[], fspath=self.tmp) + self.assertEqual( + ret, + [makeop_tmpd(self.tmp, 'install', 'snapf1', path="snapf1.snap", + cfgfile="snapf1.config")]) + + # now the op was ok, but test that render didn't mess it up. + render_snap_op(**ret[0]) + data_found = self.snapcmds[0][2] + # the data found gets loaded in the snapcmd interpretation + # so this comparison is a bit lossy, but input to snappy config + # is expected to be yaml loadable, so it should be OK. + self.assertEqual(yaml.safe_load(mydata), data_found) def makeop_tmpd(tmpd, op, name, config=None, path=None, cfgfile=None): -- cgit v1.2.3 From b4989280d7285f214c1016efa36a20ad57821d6b Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 27 Mar 2015 14:19:56 -0400 Subject: address namespacing --- cloudinit/config/cc_snappy.py | 52 +++++++++++++++++----- .../unittests/test_handler/test_handler_snappy.py | 46 +++++++++++++++++++ 2 files changed, 88 insertions(+), 10 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index f237feef..f8f67e1f 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -7,7 +7,7 @@ Example config: snappy: system_snappy: auto ssh_enabled: False - packages: [etcd, pkg2] + packages: [etcd, pkg2.smoser] config: pkgname: key2: value2 @@ -18,11 +18,15 @@ Example config: - ssh_enabled: This defaults to 'False'. Set to a non-false value to enable ssh service - snap installation and config - The above would install 'etcd', and then install 'pkg2' with a + The above would install 'etcd', and then install 'pkg2.smoser' with a '--config=' argument where 'file' as 'config-blob' inside it. If 'pkgname' is installed already, then 'snappy config pkgname ' will be called where 'file' has 'pkgname-config-blob' as its content. + Entries in 'config' can be namespaced or non-namespaced for a package. + In either case, the config provided to snappy command is non-namespaced. + The package name is provided as it appears. + If 'packages_dir' has files in it that end in '.snap', then they are installed. Given 3 files: /foo.snap @@ -52,6 +56,7 @@ LOG = logging.getLogger(__name__) frequency = PER_INSTANCE SNAPPY_CMD = "snappy" +NAMESPACE_DELIM = '.' BUILTIN_CFG = { 'packages': [], @@ -81,10 +86,20 @@ def makeop(op, name, config=None, path=None, cfgfile=None): 'cfgfile': cfgfile}) +def get_package_config(configs, name): + # load the package's config from the configs dict. + # prefer full-name entry (config-example.canonical) + # over short name entry (config-example) + if name in configs: + return configs[name] + return configs.get(name.partition(NAMESPACE_DELIM)[0]) + + def get_package_ops(packages, configs, installed=None, fspath=None): # get the install an config operations that should be done if installed is None: installed = read_installed_packages() + short_installed = [p.partition(NAMESPACE_DELIM)[0] for p in installed] if not packages: packages = [] @@ -95,21 +110,31 @@ def get_package_ops(packages, configs, installed=None, fspath=None): ops += get_fs_package_ops(fspath) for name in packages: - ops.append(makeop('install', name, configs.get('name'))) + ops.append(makeop('install', name, get_package_config(configs, name))) to_install = [f['name'] for f in ops] + short_to_install = [f['name'].partition(NAMESPACE_DELIM)[0] for f in ops] for name in configs: - if name in installed and name not in to_install: - ops.append(makeop('config', name, config=configs[name])) + if name in to_install: + continue + shortname = name.partition(NAMESPACE_DELIM)[0] + if shortname in short_to_install: + continue + if name in installed or shortname in short_installed: + ops.append(makeop('config', name, + config=get_package_config(configs, name))) # prefer config entries to filepath entries for op in ops: + if op['op'] != 'install' or not op['cfgfile']: + continue name = op['name'] - if name in configs and op['op'] == 'install' and 'cfgfile' in op: - LOG.debug("preferring configs[%s] over '%s'", name, op['cfgfile']) + fromcfg = get_package_config(configs, op['name']) + if fromcfg: + LOG.debug("preferring configs[%(name)s] over '%(cfgfile)s'", op) op['cfgfile'] = None - op['config'] = configs[op['name']] + op['config'] = fromcfg return ops @@ -118,6 +143,7 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): if op not in ('install', 'config'): raise ValueError("cannot render op '%s'" % op) + shortname = name.partition(NAMESPACE_DELIM)[0] try: cfg_tmpf = None if config is not None: @@ -126,7 +152,7 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): # packagename: # config # Note, however, we do not touch config files on disk. - nested_cfg = {'config': {name: config}} + nested_cfg = {'config': {shortname: config}} (fd, cfg_tmpf) = tempfile.mkstemp() os.write(fd, util.yaml_dumps(nested_cfg).encode()) os.close(fd) @@ -151,7 +177,13 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): def read_installed_packages(): - return [p[0] for p in read_pkg_data()] + ret = [] + for (name, date, version, dev) in read_pkg_data(): + if dev: + ret.append(NAMESPACE_DELIM.join(name, dev)) + else: + ret.append(name) + return ret def read_pkg_data(): diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index f56a22f7..f0776259 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -90,6 +90,15 @@ class TestInstallPackages(t_help.TestCase): makeop('install', 'pkg2', b'mycfg2'), makeop('config', 'xinstalled', b'xcfg')]) + def test_package_ops_install_long_config_short(self): + # a package can be installed by full name, but have config by short + cfg = {'k1': 'k2'} + ret = get_package_ops( + packages=['config-example.canonical'], + configs={'config-example': cfg}, installed=[]) + self.assertEqual( + ret, [makeop('install', 'config-example.canonical', cfg)]) + def test_package_ops_with_file(self): self.populate_tmp( {"snapf1.snap": b"foo1", "snapf1.config": b"snapf1cfg", @@ -114,6 +123,34 @@ class TestInstallPackages(t_help.TestCase): ret, [makeop_tmpd(self.tmp, 'install', 'snapf1', path="snapf1.snap", config="snapf1cfg-config")]) + def test_package_ops_namespacing(self): + cfgs = { + 'config-example': {'k1': 'v1'}, + 'pkg1': {'p1': 'p2'}, + 'ubuntu-core': {'c1': 'c2'}, + 'notinstalled.smoser': {'s1': 's2'}, + } + cfg = {'config-example-k1': 'config-example-k2'} + ret = get_package_ops( + packages=['config-example.canonical'], configs=cfgs, + installed=['config-example.smoser', 'pkg1.canonical', + 'ubuntu-core']) + + expected_configs = [ + makeop('config', 'pkg1', config=cfgs['pkg1']), + makeop('config', 'ubuntu-core', config=cfgs['ubuntu-core'])] + expected_installs = [ + makeop('install', 'config-example.canonical', + config=cfgs['config-example'])] + + installs = [i for i in ret if i['op'] == 'install'] + configs = [c for c in ret if c['op'] == 'config'] + + self.assertEqual(installs, expected_installs) + # configs are not ordered + self.assertEqual(len(configs), len(expected_configs)) + self.assertTrue(all(found in expected_configs for found in configs)) + def test_render_op_localsnap(self): self.populate_tmp({"snapf1.snap": b"foo1"}) op = makeop_tmpd(self.tmp, 'install', 'snapf1', @@ -190,6 +227,15 @@ class TestInstallPackages(t_help.TestCase): data_found = self.snapcmds[0][2] self.assertEqual(mycfg, data_found['config'][name]) + def test_render_long_configs_short(self): + # install a namespaced package should have un-namespaced config + mycfg = {'k1': 'k2'} + name = 'snapf1' + op = makeop('install', name + ".smoser", config=mycfg) + render_snap_op(**op) + data_found = self.snapcmds[0][2] + self.assertEqual(mycfg, data_found['config'][name]) + def test_render_does_not_pad_cfgfile(self): # package_ops with cfgfile should not modify --file= content. mydata = "foo1: bar1\nk: [l1, l2, l3]\n" -- cgit v1.2.3 From 09cc5909e3d69c03622b7dc2c4cb35fd378743cb Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 27 Mar 2015 16:20:05 -0400 Subject: be more user-friendly when looking for matching .config On fspath installs, look for .config files harder. Given a file named: pkg.namespace_0.version_arch.snap We'll search for config files named: pkg.namespace_0.version_arch.config pkg.namespace.config pkg.config --- cloudinit/config/cc_snappy.py | 21 +++++++++---- .../unittests/test_handler/test_handler_snappy.py | 34 ++++++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index d9dd9771..74ae3ac0 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -67,15 +67,26 @@ BUILTIN_CFG = { } +def parse_filename(fname): + fname = os.path.basename(fname) + fname_noext = fname.rpartition(".")[0] + name = fname_noext.partition("_")[0] + shortname = name.partition(".")[0] + return(name, shortname, fname_noext) + + def get_fs_package_ops(fspath): if not fspath: return [] ops = [] - for snapfile in glob.glob(os.path.sep.join([fspath, '*.snap'])): - cfg = snapfile.rpartition(".")[0] + ".config" - name = os.path.basename(snapfile).rpartition(".")[0] - if not os.path.isfile(cfg): - cfg = None + for snapfile in sorted(glob.glob(os.path.sep.join([fspath, '*.snap']))): + (name, shortname, fname_noext) = parse_filename(snapfile) + cfg = None + for cand in set((fname_noext, name, shortname,)): + fpcand = os.path.sep.join([fspath, cand]) + ".config" + if os.path.isfile(fpcand): + cfg = fpcand + break ops.append(makeop('install', name, config=None, path=snapfile, cfgfile=cfg)) return ops diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index f0776259..84512846 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -112,6 +112,40 @@ class TestInstallPackages(t_help.TestCase): makeop_tmpd(self.tmp, 'install', 'snapf2', path="snapf2.snap"), makeop('install', 'pkg1')]) + def test_package_ops_common_filename(self): + # fish package name from filename + # package names likely look like: pkgname.namespace_version_arch.snap + fname = "xkcd-webserver.canonical_0.3.4_all.snap" + name = "xkcd-webserver.canonical" + shortname = "xkcd-webserver" + + # find filenames + self.populate_tmp( + {"pkg-ws.smoser_0.3.4_all.snap": "pkg-ws-snapdata", + "pkg-ws.config": "pkg-ws-config", + "pkg1.smoser_1.2.3_all.snap": "pkg1.snapdata", + "pkg1.smoser.config": "pkg1.smoser.config-data", + "pkg1.config": "pkg1.config-data", + "pkg2.smoser_0.0_amd64.snap": "pkg2-snapdata", + "pkg2.smoser_0.0_amd64.config": "pkg2.config", + }) + + ret = get_package_ops( + packages=[], configs={}, installed=[], fspath=self.tmp) + raise Exception("ret: %s" % ret) + self.assertEqual( + ret, + [makeop_tmpd(self.tmp, 'install', 'pkg-ws.smoser', + path="pkg-ws.smoser_0.3.4_all.snap", + cfgfile="pkg-ws.config"), + makeop_tmpd(self.tmp, 'install', 'pkg1.smoser', + path="pkg1.smoser_1.2.3_all.snap", + cfgfile="pkg1.smoser.config"), + makeop_tmpd(self.tmp, 'install', 'pkg2.smoser', + path="pkg2.smoser_0.0_amd64.snap", + cfgfile="pkg2.smoser_0.0_amd64.config"), + ]) + def test_package_ops_config_overrides_file(self): # config data overrides local file .config self.populate_tmp( -- cgit v1.2.3 From f3007f9a398f8deaa8247941cfbe023de9d99850 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 27 Mar 2015 16:23:42 -0400 Subject: remove debug --- tests/unittests/test_handler/test_handler_snappy.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index 84512846..8effd99d 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -132,7 +132,6 @@ class TestInstallPackages(t_help.TestCase): ret = get_package_ops( packages=[], configs={}, installed=[], fspath=self.tmp) - raise Exception("ret: %s" % ret) self.assertEqual( ret, [makeop_tmpd(self.tmp, 'install', 'pkg-ws.smoser', -- cgit v1.2.3 From 8165000c3975db07cb5b8b29410635dd6c9345bd Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 31 Mar 2015 14:20:00 -0400 Subject: adjust cc_snappy for snappy install package with config. It was believed that to install a package with config the command was: snappy install --config=config-file Instead, what was implemented in snappy was: snappy install [] This modifies cloud-init to invoke the latter and changes the tests appropriately. LP: #1438836 --- cloudinit/config/cc_snappy.py | 9 ++++----- tests/unittests/test_handler/test_handler_snappy.py | 10 ++++++---- 2 files changed, 10 insertions(+), 9 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index 131ee7ea..6a7ae09b 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -19,7 +19,7 @@ Example config: This defaults to 'False'. Set to a non-false value to enable ssh service - snap installation and config The above would install 'etcd', and then install 'pkg2.smoser' with a - '--config=' argument where 'file' as 'config-blob' inside it. + '' argument where 'config-file' has 'config-blob' inside it. If 'pkgname' is installed already, then 'snappy config pkgname ' will be called where 'file' has 'pkgname-config-blob' as its content. @@ -33,8 +33,7 @@ Example config: /foo.config /bar.snap cloud-init will invoke: - snappy install "--config=/foo.config" \ - /foo.snap + snappy install /foo.snap /foo.config snappy install /bar.snap Note, that if provided a 'config' entry for 'ubuntu-core', then @@ -171,13 +170,13 @@ def render_snap_op(op, name, path=None, cfgfile=None, config=None): cmd = [SNAPPY_CMD, op] if op == 'install': - if cfgfile: - cmd.append('--config=' + cfgfile) if path: cmd.append("--allow-unauthenticated") cmd.append(path) else: cmd.append(name) + if cfgfile: + cmd.append(cfgfile) elif op == 'config': cmd += [name, cfgfile] diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index 8effd99d..f3109bac 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -53,15 +53,17 @@ class TestInstallPackages(t_help.TestCase): config = None pkg = None for arg in args[2:]: - if arg.startswith("--config="): - cfgfile = arg.partition("=")[2] + if arg.startswith("-"): + continue + if not pkg: + pkg = arg + elif not config: + cfgfile = arg if cfgfile == "-": config = kwargs.get('data', '') elif cfgfile: with open(cfgfile, "rb") as fp: config = yaml.safe_load(fp.read()) - elif not pkg and not arg.startswith("-"): - pkg = arg self.snapcmds.append(['install', pkg, config]) def test_package_ops_1(self): -- cgit v1.2.3 From b6060efa4bd1de7f49f6aca3e97cfe77947f3a93 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 15 Apr 2015 12:13:17 +0100 Subject: Add unit tests for Azure hostname bouncing. Including minor refactoring to make mocking considerably easier. --- cloudinit/sources/DataSourceAzure.py | 28 ++-- tests/unittests/test_datasource/test_azure.py | 186 +++++++++++++++++++------- 2 files changed, 161 insertions(+), 53 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 6e030217..d4211fc4 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -66,6 +66,14 @@ DS_CFG_PATH = ['datasource', DS_NAME] DEF_EPHEMERAL_LABEL = 'Temporary Storage' +def get_hostname(hostname_command='hostname'): + return util.subp(hostname_command, capture=True)[0].strip() + + +def set_hostname(hostname, hostname_command='hostname'): + util.subp([hostname_command, hostname]) + + class DataSourceAzureNet(sources.DataSource): def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) @@ -313,13 +321,22 @@ def handle_set_hostname(enabled, hostname, cfg): hostname_command=cfg['hostname_command']) +def perform_hostname_bounce(command, env): + shell = not isinstance(command, (list, tuple)) + # capture=False, see comments in bug 1202758 and bug 1206164. + util.log_time(logfunc=LOG.debug, msg="publishing hostname", + get_uptime=True, func=util.subp, + kwargs={'args': command, 'shell': shell, 'capture': False, + 'env': env}) + + def apply_hostname_bounce(hostname, policy, interface, command, hostname_command="hostname"): # set the hostname to 'hostname' if it is not already set to that. # then, if policy is not off, bounce the interface using command - prev_hostname = util.subp(hostname_command, capture=True)[0].strip() + prev_hostname = get_hostname() - util.subp([hostname_command, hostname]) + set_hostname(hostname, hostname_command) msg = ("phostname=%s hostname=%s policy=%s interface=%s" % (prev_hostname, hostname, policy, interface)) @@ -341,12 +358,7 @@ def apply_hostname_bounce(hostname, policy, interface, command, command = BOUNCE_COMMAND LOG.debug("pubhname: publishing hostname [%s]", msg) - shell = not isinstance(command, (list, tuple)) - # capture=False, see comments in bug 1202758 and bug 1206164. - util.log_time(logfunc=LOG.debug, msg="publishing hostname", - get_uptime=True, func=util.subp, - kwargs={'args': command, 'shell': shell, 'capture': False, - 'env': env}) + perform_hostname_bounce(command, env) def crtfile_to_pubkey(fname): diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 8112c69b..3adf9bdf 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -116,9 +116,6 @@ class TestAzureDataSource(TestCase): data['iid_from_shared_cfg'] = path return 'i-my-azure-id' - def _apply_hostname_bounce(**kwargs): - data['apply_hostname_bounce'] = kwargs - if data.get('ovfcontent') is not None: populate_dir(os.path.join(self.paths.seed_dir, "azure"), {'ovf-env.xml': data['ovfcontent']}) @@ -132,7 +129,9 @@ class TestAzureDataSource(TestCase): (mod, 'wait_for_files', _wait_for_files), (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files), (mod, 'iid_from_shared_config', _iid_from_shared_config), - (mod, 'apply_hostname_bounce', _apply_hostname_bounce), + (mod, 'perform_hostname_bounce', mock.MagicMock()), + (mod, 'get_hostname', mock.MagicMock()), + (mod, 'set_hostname', mock.MagicMock()), ]) dsrc = mod.DataSourceAzureNet( @@ -272,47 +271,6 @@ class TestAzureDataSource(TestCase): for mypk in mypklist: self.assertIn(mypk, dsrc.cfg['_pubkeys']) - def test_disabled_bounce(self): - pass - - def test_apply_bounce_call_1(self): - # hostname needs to get through to apply_hostname_bounce - odata = {'HostName': 'my-random-hostname'} - data = {'ovfcontent': construct_valid_ovf_env(data=odata)} - - self._get_ds(data).get_data() - self.assertIn('hostname', data['apply_hostname_bounce']) - self.assertEqual(data['apply_hostname_bounce']['hostname'], - odata['HostName']) - - def test_apply_bounce_call_configurable(self): - # hostname_bounce should be configurable in datasource cfg - cfg = {'hostname_bounce': {'interface': 'eth1', 'policy': 'off', - 'command': 'my-bounce-command', - 'hostname_command': 'my-hostname-command'}} - odata = {'HostName': "xhost", - 'dscfg': {'text': b64e(yaml.dump(cfg)), - 'encoding': 'base64'}} - data = {'ovfcontent': construct_valid_ovf_env(data=odata)} - self._get_ds(data).get_data() - - for k in cfg['hostname_bounce']: - self.assertIn(k, data['apply_hostname_bounce']) - - for k, v in cfg['hostname_bounce'].items(): - self.assertEqual(data['apply_hostname_bounce'][k], v) - - def test_set_hostname_disabled(self): - # config specifying set_hostname off should not bounce - cfg = {'set_hostname': False} - odata = {'HostName': "xhost", - 'dscfg': {'text': b64e(yaml.dump(cfg)), - 'encoding': 'base64'}} - data = {'ovfcontent': construct_valid_ovf_env(data=odata)} - self._get_ds(data).get_data() - - self.assertEqual(data.get('apply_hostname_bounce', "N/A"), "N/A") - def test_default_ephemeral(self): # make sure the ephemeral device works odata = {} @@ -425,6 +383,144 @@ class TestAzureDataSource(TestCase): load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) +class TestAzureBounce(TestCase): + + def mock_out_azure_moving_parts(self): + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'invoke_agent')) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'wait_for_files')) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'iid_from_shared_config', + mock.MagicMock(return_value='i-my-azure-id'))) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs', + mock.MagicMock(return_value=[]))) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'find_ephemeral_disk', + mock.MagicMock(return_value=None))) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'find_ephemeral_part', + mock.MagicMock(return_value=None))) + + def setUp(self): + super(TestAzureBounce, self).setUp() + self.tmp = tempfile.mkdtemp() + self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent') + self.paths = helpers.Paths({'cloud_dir': self.tmp}) + self.addCleanup(shutil.rmtree, self.tmp) + DataSourceAzure.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d + self.patches = ExitStack() + self.mock_out_azure_moving_parts() + self.get_hostname = self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'get_hostname')) + self.set_hostname = self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'set_hostname')) + self.subp = self.patches.enter_context( + mock.patch('cloudinit.sources.DataSourceAzure.util.subp')) + + def tearDown(self): + self.patches.close() + + def _get_ds(self, ovfcontent=None): + if ovfcontent is not None: + populate_dir(os.path.join(self.paths.seed_dir, "azure"), + {'ovf-env.xml': ovfcontent}) + return DataSourceAzure.DataSourceAzureNet( + {}, distro=None, paths=self.paths) + + def get_ovf_env_with_dscfg(self, hostname, cfg): + odata = { + 'HostName': hostname, + 'dscfg': { + 'text': b64e(yaml.dump(cfg)), + 'encoding': 'base64' + } + } + return construct_valid_ovf_env(data=odata) + + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') + def test_disabled_bounce_does_not_perform_bounce( + self, perform_hostname_bounce): + cfg = {'hostname_bounce': {'policy': 'off'}} + self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data() + self.assertEqual(0, perform_hostname_bounce.call_count) + + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') + def test_unchanged_hostname_does_not_perform_bounce( + self, perform_hostname_bounce): + host_name = 'unchanged-host-name' + self.get_hostname.return_value = host_name + cfg = {'hostname_bounce': {'policy': 'yes'}} + self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data() + self.assertEqual(0, perform_hostname_bounce.call_count) + + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') + def test_force_performs_bounce_regardless(self, perform_hostname_bounce): + host_name = 'unchanged-host-name' + self.get_hostname.return_value = host_name + cfg = {'hostname_bounce': {'policy': 'force'}} + self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data() + self.assertEqual(1, perform_hostname_bounce.call_count) + + def test_different_hostnames_sets_hostname(self): + expected_hostname = 'azure-expected-host-name' + self.get_hostname.return_value = 'default-host-name' + self._get_ds( + self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data() + self.assertEqual(expected_hostname, + self.set_hostname.call_args_list[0][0][0]) + + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') + def test_different_hostnames_performs_bounce( + self, perform_hostname_bounce): + expected_hostname = 'azure-expected-host-name' + self.get_hostname.return_value = 'default-host-name' + self._get_ds( + self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data() + self.assertEqual(1, perform_hostname_bounce.call_count) + + def test_environment_correct_for_bounce_command(self): + interface = 'int0' + hostname = 'my-new-host' + old_hostname = 'my-old-host' + self.get_hostname.return_value = old_hostname + cfg = {'hostname_bounce': {'interface': interface, 'policy': 'force'}} + data = self.get_ovf_env_with_dscfg(hostname, cfg) + self._get_ds(data).get_data() + self.assertEqual(1, self.subp.call_count) + bounce_env = self.subp.call_args[1]['env'] + self.assertEqual(interface, bounce_env['interface']) + self.assertEqual(hostname, bounce_env['hostname']) + self.assertEqual(old_hostname, bounce_env['old_hostname']) + + def test_default_bounce_command_used_by_default(self): + cmd = 'default-bounce-command' + DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce']['command'] = cmd + cfg = {'hostname_bounce': {'policy': 'force'}} + data = self.get_ovf_env_with_dscfg('some-hostname', cfg) + self._get_ds(data).get_data() + self.assertEqual(1, self.subp.call_count) + bounce_args = self.subp.call_args[1]['args'] + self.assertEqual(cmd, bounce_args) + + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') + def test_set_hostname_option_can_disable_bounce( + self, perform_hostname_bounce): + cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}} + data = self.get_ovf_env_with_dscfg('some-hostname', cfg) + self._get_ds(data).get_data() + + self.assertEqual(0, perform_hostname_bounce.call_count) + + def test_set_hostname_option_can_disable_hostname_set(self): + cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}} + data = self.get_ovf_env_with_dscfg('some-hostname', cfg) + self._get_ds(data).get_data() + + self.assertEqual(0, self.set_hostname.call_count) + + class TestReadAzureOvf(TestCase): def test_invalid_xml_raises_non_azure_ds(self): invalid_xml = "" + construct_valid_ovf_env(data={}) -- cgit v1.2.3 From b8706d7dc930c5c9dce1f96a000c66e5dda14e02 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 15 Apr 2015 12:13:17 +0100 Subject: Reset host name after bounce has allowed walinuxagent to run successfully. --- cloudinit/sources/DataSourceAzure.py | 134 +++++++++++++------------- tests/unittests/test_datasource/test_azure.py | 31 ++++++ 2 files changed, 99 insertions(+), 66 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index d4211fc4..a19d9ca2 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -17,6 +17,7 @@ # along with this program. If not, see . import base64 +import contextlib import crypt import fnmatch import os @@ -74,6 +75,28 @@ def set_hostname(hostname, hostname_command='hostname'): util.subp([hostname_command, hostname]) +@contextlib.contextmanager +def temporary_hostname(temp_hostname, cfg, hostname_command='hostname'): + """ + Set a temporary hostname, restoring the previous hostname on exit. + + Will have the value of the previous hostname when used as a context + manager, or None if the hostname was not changed. + """ + policy = cfg['hostname_bounce']['policy'] + previous_hostname = get_hostname(hostname_command) + if (not util.is_true(cfg.get('set_hostname')) + or util.is_false(policy) + or (previous_hostname == temp_hostname and policy != 'force')): + yield None + return + set_hostname(temp_hostname, hostname_command) + try: + yield previous_hostname + finally: + set_hostname(previous_hostname, hostname_command) + + class DataSourceAzureNet(sources.DataSource): def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) @@ -162,33 +185,40 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) - # handle the hostname 'publishing' - try: - handle_set_hostname(mycfg.get('set_hostname'), - self.metadata.get('local-hostname'), - mycfg['hostname_bounce']) - except Exception as e: - LOG.warn("Failed publishing hostname: %s", e) - util.logexc(LOG, "handling set_hostname failed") - - try: - invoke_agent(mycfg['agent_command']) - except util.ProcessExecutionError: - # claim the datasource even if the command failed - util.logexc(LOG, "agent command '%s' failed.", - mycfg['agent_command']) - - shcfgxml = os.path.join(ddir, "SharedConfig.xml") - wait_for = [shcfgxml] - - fp_files = [] - for pk in self.cfg.get('_pubkeys', []): - bname = str(pk['fingerprint'] + ".crt") - fp_files += [os.path.join(ddir, bname)] + temp_hostname = self.metadata.get('local-hostname') + hostname_command = mycfg['hostname_bounce']['hostname_command'] + with temporary_hostname(temp_hostname, mycfg, + hostname_command=hostname_command) \ + as previous_hostname: + if (previous_hostname is not None + and util.is_true(mycfg.get('set_hostname'))): + cfg = mycfg['hostname_bounce'] + try: + perform_hostname_bounce(hostname=temp_hostname, + cfg=cfg, + prev_hostname=previous_hostname) + except Exception as e: + LOG.warn("Failed publishing hostname: %s", e) + util.logexc(LOG, "handling set_hostname failed") - missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", - func=wait_for_files, - args=(wait_for + fp_files,)) + try: + invoke_agent(mycfg['agent_command']) + except util.ProcessExecutionError: + # claim the datasource even if the command failed + util.logexc(LOG, "agent command '%s' failed.", + mycfg['agent_command']) + + shcfgxml = os.path.join(ddir, "SharedConfig.xml") + wait_for = [shcfgxml] + + fp_files = [] + for pk in self.cfg.get('_pubkeys', []): + bname = str(pk['fingerprint'] + ".crt") + fp_files += [os.path.join(ddir, bname)] + + missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", + func=wait_for_files, + args=(wait_for + fp_files,)) if len(missing): LOG.warn("Did not find files, but going on: %s", missing) @@ -307,48 +337,15 @@ def support_new_ephemeral(cfg): return mod_list -def handle_set_hostname(enabled, hostname, cfg): - if not util.is_true(enabled): - return - - if not hostname: - LOG.warn("set_hostname was true but no local-hostname") - return - - apply_hostname_bounce(hostname=hostname, policy=cfg['policy'], - interface=cfg['interface'], - command=cfg['command'], - hostname_command=cfg['hostname_command']) - - -def perform_hostname_bounce(command, env): - shell = not isinstance(command, (list, tuple)) - # capture=False, see comments in bug 1202758 and bug 1206164. - util.log_time(logfunc=LOG.debug, msg="publishing hostname", - get_uptime=True, func=util.subp, - kwargs={'args': command, 'shell': shell, 'capture': False, - 'env': env}) - - -def apply_hostname_bounce(hostname, policy, interface, command, - hostname_command="hostname"): +def perform_hostname_bounce(hostname, cfg, prev_hostname): # set the hostname to 'hostname' if it is not already set to that. # then, if policy is not off, bounce the interface using command - prev_hostname = get_hostname() - - set_hostname(hostname, hostname_command) - - msg = ("phostname=%s hostname=%s policy=%s interface=%s" % - (prev_hostname, hostname, policy, interface)) - - if util.is_false(policy): - LOG.debug("pubhname: policy false, skipping [%s]", msg) - return - - if prev_hostname == hostname and policy != "force": - LOG.debug("pubhname: no change, policy != force. skipping. [%s]", msg) - return + command = cfg['command'] + interface = cfg['interface'] + policy = cfg['policy'] + msg = ("hostname=%s policy=%s interface=%s" % + (hostname, policy, interface)) env = os.environ.copy() env['interface'] = interface env['hostname'] = hostname @@ -358,7 +355,12 @@ def apply_hostname_bounce(hostname, policy, interface, command, command = BOUNCE_COMMAND LOG.debug("pubhname: publishing hostname [%s]", msg) - perform_hostname_bounce(command, env) + shell = not isinstance(command, (list, tuple)) + # capture=False, see comments in bug 1202758 and bug 1206164. + util.log_time(logfunc=LOG.debug, msg="publishing hostname", + get_uptime=True, func=util.subp, + kwargs={'args': command, 'shell': shell, 'capture': False, + 'env': env}) def crtfile_to_pubkey(fname): diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 3adf9bdf..7e789853 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -439,6 +439,11 @@ class TestAzureBounce(TestCase): } return construct_valid_ovf_env(data=odata) + def test_disabled_bounce_does_not_change_hostname(self): + cfg = {'hostname_bounce': {'policy': 'off'}} + self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data() + self.assertEqual(0, self.set_hostname.call_count) + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') def test_disabled_bounce_does_not_perform_bounce( self, perform_hostname_bounce): @@ -446,6 +451,13 @@ class TestAzureBounce(TestCase): self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data() self.assertEqual(0, perform_hostname_bounce.call_count) + def test_same_hostname_does_not_change_hostname(self): + host_name = 'unchanged-host-name' + self.get_hostname.return_value = host_name + cfg = {'hostname_bounce': {'policy': 'yes'}} + self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data() + self.assertEqual(0, self.set_hostname.call_count) + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') def test_unchanged_hostname_does_not_perform_bounce( self, perform_hostname_bounce): @@ -480,6 +492,25 @@ class TestAzureBounce(TestCase): self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data() self.assertEqual(1, perform_hostname_bounce.call_count) + def test_different_hostnames_sets_hostname_back(self): + initial_host_name = 'default-host-name' + self.get_hostname.return_value = initial_host_name + self._get_ds( + self.get_ovf_env_with_dscfg('some-host-name', {})).get_data() + self.assertEqual(initial_host_name, + self.set_hostname.call_args_list[-1][0][0]) + + @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') + def test_failure_in_bounce_still_resets_host_name( + self, perform_hostname_bounce): + perform_hostname_bounce.side_effect = Exception + initial_host_name = 'default-host-name' + self.get_hostname.return_value = initial_host_name + self._get_ds( + self.get_ovf_env_with_dscfg('some-host-name', {})).get_data() + self.assertEqual(initial_host_name, + self.set_hostname.call_args_list[-1][0][0]) + def test_environment_correct_for_bounce_command(self): interface = 'int0' hostname = 'my-new-host' -- cgit v1.2.3 From dcd4b2b371059bd6249b4e43af371ee1162273e8 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 16 Apr 2015 16:41:06 -0400 Subject: pep8 fixes --- cloudinit/config/cc_snappy.py | 4 ++-- cloudinit/handlers/__init__.py | 6 +++--- tests/unittests/test_data.py | 5 +++-- 3 files changed, 8 insertions(+), 7 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index 6a7ae09b..bfe76558 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -72,7 +72,7 @@ def parse_filename(fname): name = fname_noext.partition("_")[0] shortname = name.partition(".")[0] return(name, shortname, fname_noext) - + def get_fs_package_ops(fspath): if not fspath: @@ -98,7 +98,7 @@ def makeop(op, name, config=None, path=None, cfgfile=None): def get_package_config(configs, name): # load the package's config from the configs dict. - # prefer full-name entry (config-example.canonical) + # prefer full-name entry (config-example.canonical) # over short name entry (config-example) if name in configs: return configs[name] diff --git a/cloudinit/handlers/__init__.py b/cloudinit/handlers/__init__.py index d62fcd19..52defe66 100644 --- a/cloudinit/handlers/__init__.py +++ b/cloudinit/handlers/__init__.py @@ -170,12 +170,12 @@ def _extract_first_or_bytes(blob, size): start = blob.split("\n", 1)[0] else: # We want to avoid decoding the whole blob (it might be huge) - # By taking 4*size bytes we have a guarantee to decode size utf8 chars - start = blob[:4*size].decode(errors='ignore').split("\n", 1)[0] + # By taking 4*size bytes we guarantee to decode size utf8 chars + start = blob[:4 * size].decode(errors='ignore').split("\n", 1)[0] if len(start) >= size: start = start[:size] except UnicodeDecodeError: - # Bytes array doesn't contain a text object -- return chunk of raw bytes + # Bytes array doesn't contain text so return chunk of raw bytes start = blob[0:size] return start diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 4f24e2dd..b950c9a5 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -494,10 +494,10 @@ c: 4 ]) def test_mime_application_octet_stream(self): - """Mime message of type application/octet-stream is ignored but shows warning.""" + """Mime type application/octet-stream is ignored but shows warning.""" ci = stages.Init() message = MIMEBase("application", "octet-stream") - message.set_payload(b'\xbf\xe6\xb2\xc3\xd3\xba\x13\xa4\xd8\xa1\xcc\xbf') + message.set_payload(b'\xbf\xe6\xb2\xc3\xd3\xba\x13\xa4\xd8\xa1\xcc') encoders.encode_base64(message) ci.datasource = FakeDataSource(message.as_string().encode()) @@ -511,6 +511,7 @@ c: 4 mockobj.assert_called_once_with( ci.paths.get_ipath("cloud_config"), "", 0o600) + class TestUDProcess(helpers.ResourceUsingTestCase): def test_bytes_in_userdata(self): -- cgit v1.2.3 From 341a805fca9a06ce12e9f4bbbe15b3dded9eb6a4 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 16 Apr 2015 17:00:19 -0400 Subject: fix cloud-config-archive handling handling of cloud-config-archive input would fail in fully_decoded_payload. part.get_charset() would return a Charset object, but get_charset.input_codec is a string suitable for passing to decode. This handles that correctly, and is more careful about binary data inside input. The test added verifies that cloud-config inside a cloud-config-archive is handled correctly and also that binary data there is ignored without exceptions raised. LP: #1445143 --- cloudinit/handlers/__init__.py | 5 ++++- cloudinit/user_data.py | 9 +++++++-- cloudinit/util.py | 8 ++++++-- tests/unittests/test_data.py | 27 +++++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/handlers/__init__.py b/cloudinit/handlers/__init__.py index 52defe66..53d5604a 100644 --- a/cloudinit/handlers/__init__.py +++ b/cloudinit/handlers/__init__.py @@ -263,7 +263,10 @@ def fixup_handler(mod, def_freq=PER_INSTANCE): def type_from_starts_with(payload, default=None): - payload_lc = payload.lower() + try: + payload_lc = util.decode_binary(payload).lower() + except UnicodeDecodeError: + return default payload_lc = payload_lc.lstrip() for text in INCLUSION_SRCH: if payload_lc.startswith(text): diff --git a/cloudinit/user_data.py b/cloudinit/user_data.py index eb3c7336..f7c5787c 100644 --- a/cloudinit/user_data.py +++ b/cloudinit/user_data.py @@ -49,6 +49,7 @@ INCLUDE_TYPES = ['text/x-include-url', 'text/x-include-once-url'] ARCHIVE_TYPES = ["text/cloud-config-archive"] UNDEF_TYPE = "text/plain" ARCHIVE_UNDEF_TYPE = "text/cloud-config" +ARCHIVE_UNDEF_BINARY_TYPE = "application/octet-stream" # This seems to hit most of the gzip possible content types. DECOMP_TYPES = [ @@ -265,11 +266,15 @@ class UserDataProcessor(object): content = ent.get('content', '') mtype = ent.get('type') if not mtype: - mtype = handlers.type_from_starts_with(content, - ARCHIVE_UNDEF_TYPE) + default = ARCHIVE_UNDEF_TYPE + if isinstance(content, six.binary_type): + default = ARCHIVE_UNDEF_BINARY_TYPE + mtype = handlers.type_from_starts_with(content, default) maintype, subtype = mtype.split('/', 1) if maintype == "text": + if isinstance(content, six.binary_type): + content = content.decode() msg = MIMEText(content, _subtype=subtype) else: msg = MIMEBase(maintype, subtype) diff --git a/cloudinit/util.py b/cloudinit/util.py index 971c1c2d..cae57770 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -121,8 +121,12 @@ def fully_decoded_payload(part): if (six.PY3 and part.get_content_maintype() == 'text' and isinstance(cte_payload, bytes)): - charset = part.get_charset() or 'utf-8' - return cte_payload.decode(charset, errors='surrogateescape') + charset = part.get_charset() + if charset and charset.input_codec: + encoding = charset.input_codec + else: + encoding = 'utf-8' + return cte_payload.decode(encoding, errors='surrogateescape') return cte_payload diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index b950c9a5..1b15dafa 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -512,6 +512,33 @@ c: 4 ci.paths.get_ipath("cloud_config"), "", 0o600) + def test_cloud_config_archive(self): + non_decodable = b'\x11\xc9\xb4gTH\xee\x12' + data = [{'content': '#cloud-config\npassword: gocubs\n'}, + {'content': '#cloud-config\nlocale: chicago\n'}, + {'content': non_decodable}] + message = b'#cloud-config-archive\n' + util.yaml_dumps(data).encode() + + ci = stages.Init() + ci.datasource = FakeDataSource(message) + + fs = {} + + def fsstore(filename, content, mode=0o0644, omode="wb"): + fs[filename] = content + + # consuming the user-data provided should write 'cloud_config' file + # which will have our yaml in it. + with mock.patch('cloudinit.util.write_file') as mockobj: + mockobj.side_effect = fsstore + ci.fetch() + ci.consume_data() + + cfg = util.load_yaml(fs[ci.paths.get_ipath("cloud_config")]) + self.assertEqual(cfg.get('password'), 'gocubs') + self.assertEqual(cfg.get('locale'), 'chicago') + + class TestUDProcess(helpers.ResourceUsingTestCase): def test_bytes_in_userdata(self): -- cgit v1.2.3 From 3551333795539c2cec1195841d2a641046981ba6 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Mon, 20 Apr 2015 11:52:50 +0100 Subject: Add test that missing GCE required key returns False. --- tests/unittests/test_datasource/test_gce.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 4280abc4..540a55d0 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -141,3 +141,14 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): decoded = b64decode( GCE_META_ENCODING.get('instance/attributes/user-data')) self.assertEqual(decoded, self.ds.get_userdata_raw()) + + @httpretty.activate + def test_missing_required_keys_return_false(self): + for required_key in ['instance/id', 'instance/zone', + 'instance/hostname']: + meta = GCE_META_PARTIAL.copy() + del meta[required_key] + httpretty.register_uri(httpretty.GET, MD_URL_RE, + body=_new_request_callback(meta)) + self.assertEqual(False, self.ds.get_data()) + httpretty.reset() -- cgit v1.2.3 From 4fc65f02ae3fbf1a2062e6169ee39b5c5d5e23bc Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Mon, 20 Apr 2015 15:24:22 +0100 Subject: GCE instance-level SSH keys override project-level keys. (LP: #1403617) --- cloudinit/sources/DataSourceGCE.py | 3 ++- tests/unittests/test_datasource/test_gce.py | 38 ++++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py index 1a133c28..f4ed915d 100644 --- a/cloudinit/sources/DataSourceGCE.py +++ b/cloudinit/sources/DataSourceGCE.py @@ -80,7 +80,8 @@ class DataSourceGCE(sources.DataSource): ('instance-id', ('instance/id',), True, True), ('availability-zone', ('instance/zone',), True, True), ('local-hostname', ('instance/hostname',), True, True), - ('public-keys', ('project/attributes/sshKeys',), False, True), + ('public-keys', ('project/attributes/sshKeys', + 'instance/attributes/sshKeys'), False, True), ('user-data', ('instance/attributes/user-data',), False, False), ('user-data-encoding', ('instance/attributes/user-data-encoding',), False, True), diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 540a55d0..1fb100f7 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -113,10 +113,6 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): self.assertEqual(GCE_META.get('instance/attributes/user-data'), self.ds.get_userdata_raw()) - # we expect a list of public ssh keys with user names stripped - self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'], - self.ds.get_public_ssh_keys()) - # test partial metadata (missing user-data in particular) @httpretty.activate def test_metadata_partial(self): @@ -152,3 +148,37 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): body=_new_request_callback(meta)) self.assertEqual(False, self.ds.get_data()) httpretty.reset() + + @httpretty.activate + def test_project_level_ssh_keys_are_used(self): + httpretty.register_uri(httpretty.GET, MD_URL_RE, + body=_new_request_callback()) + self.ds.get_data() + + # we expect a list of public ssh keys with user names stripped + self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'], + self.ds.get_public_ssh_keys()) + + @httpretty.activate + def test_instance_level_ssh_keys_are_used(self): + key_content = 'ssh-rsa JustAUser root@server' + meta = GCE_META.copy() + meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content) + + httpretty.register_uri(httpretty.GET, MD_URL_RE, + body=_new_request_callback(meta)) + self.ds.get_data() + + self.assertIn(key_content, self.ds.get_public_ssh_keys()) + + @httpretty.activate + def test_instance_level_keys_replace_project_level_keys(self): + key_content = 'ssh-rsa JustAUser root@server' + meta = GCE_META.copy() + meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content) + + httpretty.register_uri(httpretty.GET, MD_URL_RE, + body=_new_request_callback(meta)) + self.ds.get_data() + + self.assertEqual([key_content], self.ds.get_public_ssh_keys()) -- cgit v1.2.3 From 96854d720d4bd356181acfa093744599a807ea8e Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 1 May 2015 05:38:56 -0400 Subject: fix 'Make pyflakes' --- Makefile | 2 +- cloudinit/config/cc_apt_pipelining.py | 2 +- cloudinit/config/cc_snappy.py | 2 -- cloudinit/sources/DataSourceOpenNebula.py | 1 - tests/unittests/test_datasource/test_smartos.py | 2 -- tests/unittests/test_handler/test_handler_apt_configure.py | 1 - tests/unittests/test_handler/test_handler_snappy.py | 5 ----- tests/unittests/test_templating.py | 5 +---- tools/hacking.py | 2 +- tools/validate-yaml.py | 3 +-- 10 files changed, 5 insertions(+), 20 deletions(-) (limited to 'tests/unittests') diff --git a/Makefile b/Makefile index 009257ca..bb0c5253 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ pep8: @$(CWD)/tools/run-pep8 $(PY_FILES) pyflakes: - pyflakes $(PY_FILES) + @$(CWD)/tools/tox-venv py34 pyflakes $(PY_FILES) pip-requirements: @echo "Installing cloud-init dependencies..." diff --git a/cloudinit/config/cc_apt_pipelining.py b/cloudinit/config/cc_apt_pipelining.py index e5629175..40c32c84 100644 --- a/cloudinit/config/cc_apt_pipelining.py +++ b/cloudinit/config/cc_apt_pipelining.py @@ -43,7 +43,7 @@ def handle(_name, cfg, _cloud, log, _args): write_apt_snippet("0", log, DEFAULT_FILE) elif apt_pipe_value_s in ("none", "unchanged", "os"): return - elif apt_pipe_value_s in [str(b) for b in xrange(0, 6)]: + elif apt_pipe_value_s in [str(b) for b in range(0, 6)]: write_apt_snippet(apt_pipe_value_s, log, DEFAULT_FILE) else: log.warn("Invalid option for apt_pipeling: %s", apt_pipe_value) diff --git a/cloudinit/config/cc_snappy.py b/cloudinit/config/cc_snappy.py index bfe76558..7aaec94a 100644 --- a/cloudinit/config/cc_snappy.py +++ b/cloudinit/config/cc_snappy.py @@ -42,12 +42,10 @@ Example config: """ from cloudinit import log as logging -from cloudinit import templater from cloudinit import util from cloudinit.settings import PER_INSTANCE import glob -import six import tempfile import os diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index 61709c1b..ac2c3b45 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -24,7 +24,6 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import base64 import os import pwd import re diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 28b41eaf..adee9019 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -36,8 +36,6 @@ from binascii import crc32 import serial import six -import six - from cloudinit import helpers as c_helpers from cloudinit.sources import DataSourceSmartOS from cloudinit.util import b64e diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index 02cad8b2..895728b3 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -7,7 +7,6 @@ import os import re import shutil import tempfile -import unittest class TestAptProxyConfig(TestCase): diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index f3109bac..eceb14d9 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -38,7 +38,6 @@ class TestInstallPackages(t_help.TestCase): if 'args' not in kwargs: kwargs['args'] = args[0] self.subp_called.append(kwargs) - snap_cmds = [] args = kwargs['args'] # here we basically parse the snappy command invoked # and append to snapcmds a list of (mode, pkg, config) @@ -117,9 +116,6 @@ class TestInstallPackages(t_help.TestCase): def test_package_ops_common_filename(self): # fish package name from filename # package names likely look like: pkgname.namespace_version_arch.snap - fname = "xkcd-webserver.canonical_0.3.4_all.snap" - name = "xkcd-webserver.canonical" - shortname = "xkcd-webserver" # find filenames self.populate_tmp( @@ -165,7 +161,6 @@ class TestInstallPackages(t_help.TestCase): 'ubuntu-core': {'c1': 'c2'}, 'notinstalled.smoser': {'s1': 's2'}, } - cfg = {'config-example-k1': 'config-example-k2'} ret = get_package_ops( packages=['config-example.canonical'], configs=cfgs, installed=['config-example.smoser', 'pkg1.canonical', diff --git a/tests/unittests/test_templating.py b/tests/unittests/test_templating.py index cf7c03b0..0c19a2c2 100644 --- a/tests/unittests/test_templating.py +++ b/tests/unittests/test_templating.py @@ -18,10 +18,6 @@ from __future__ import print_function -import sys -import six -import unittest - from . import helpers as test_helpers import textwrap @@ -30,6 +26,7 @@ from cloudinit import templater try: import Cheetah HAS_CHEETAH = True + Cheetah # make pyflakes happy, as Cheetah is not used here except ImportError: HAS_CHEETAH = False diff --git a/tools/hacking.py b/tools/hacking.py index e7797564..3175df38 100755 --- a/tools/hacking.py +++ b/tools/hacking.py @@ -128,7 +128,7 @@ def cloud_docstring_multiline_end(physical_line): """ pos = max([physical_line.find(i) for i in DOCSTRING_TRIPLE]) # start if (pos != -1 and len(physical_line) == pos): - print physical_line + print(physical_line) if (physical_line[pos + 3] == ' '): return (pos, "N403: multi line docstring end on new line") diff --git a/tools/validate-yaml.py b/tools/validate-yaml.py index eda59cb8..6e164590 100755 --- a/tools/validate-yaml.py +++ b/tools/validate-yaml.py @@ -4,7 +4,6 @@ """ import sys - import yaml @@ -17,7 +16,7 @@ if __name__ == "__main__": yaml.safe_load(fh.read()) fh.close() sys.stdout.write(" - ok\n") - except Exception, e: + except Exception as e: sys.stdout.write(" - bad (%s)\n" % (e)) bads += 1 if bads > 0: -- cgit v1.2.3 From 6ddf7beb112f016be7ebd6fe296de6eaaf3aa9ca Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Thu, 7 May 2015 14:46:47 +0100 Subject: Implement basic replacement for walinuxagent in Azure data source. --- cloudinit/sources/DataSourceAzure.py | 292 +++++++++++++++++++---- tests/unittests/test_datasource/test_azure.py | 331 ++++++++++++++++++++++++++ 2 files changed, 574 insertions(+), 49 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index a19d9ca2..bd3c742b 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -22,8 +22,14 @@ import crypt import fnmatch import os import os.path +import re +import socket +import struct +import tempfile import time +from contextlib import contextmanager from xml.dom import minidom +from xml.etree import ElementTree from cloudinit import log as logging from cloudinit.settings import PER_ALWAYS @@ -34,13 +40,11 @@ LOG = logging.getLogger(__name__) DS_NAME = 'Azure' DEFAULT_METADATA = {"instance-id": "iid-AZURE-NODE"} -AGENT_START = ['service', 'walinuxagent', 'start'] BOUNCE_COMMAND = ['sh', '-xc', "i=$interface; x=0; ifdown $i || x=$?; ifup $i || x=$?; exit $x"] DATA_DIR_CLEAN_LIST = ['SharedConfig.xml'] BUILTIN_DS_CONFIG = { - 'agent_command': AGENT_START, 'data_dir': "/var/lib/waagent", 'set_hostname': True, 'hostname_bounce': { @@ -66,6 +70,231 @@ BUILTIN_CLOUD_CONFIG = { DS_CFG_PATH = ['datasource', DS_NAME] DEF_EPHEMERAL_LABEL = 'Temporary Storage' +REPORT_READY_XML_TEMPLATE = """\ + + + {incarnation} + + {container_id} + + + {instance_id} + + Ready + + + + +""" + + +@contextmanager +def cd(newdir): + prevdir = os.getcwd() + os.chdir(os.path.expanduser(newdir)) + try: + yield + finally: + os.chdir(prevdir) + + +class AzureEndpointHttpClient(object): + + headers = { + 'x-ms-agent-name': 'WALinuxAgent', + 'x-ms-version': '2012-11-30', + } + + def __init__(self, certificate): + self.extra_secure_headers = { + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": certificate, + } + + def get(self, url, secure=False): + headers = self.headers + if secure: + headers = self.headers.copy() + headers.update(self.extra_secure_headers) + return util.read_file_or_url(url, headers=headers) + + def post(self, url, data=None, extra_headers=None): + headers = self.headers + if extra_headers is not None: + headers = self.headers.copy() + headers.update(extra_headers) + return util.read_file_or_url(url, data=data, headers=headers) + + +def find_endpoint(): + content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') + value = None + for line in content.splitlines(): + if 'unknown-245' in line: + value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') + if value is None: + raise Exception('No endpoint found in DHCP config.') + if ':' in value: + hex_string = '' + for hex_pair in value.split(':'): + if len(hex_pair) == 1: + hex_pair = '0' + hex_pair + hex_string += hex_pair + value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) + else: + value = value.encode('utf-8') + return socket.inet_ntoa(value) + + +class GoalState(object): + + def __init__(self, xml, http_client): + self.http_client = http_client + self.root = ElementTree.fromstring(xml) + + def _text_from_xpath(self, xpath): + element = self.root.find(xpath) + if element is not None: + return element.text + return None + + @property + def container_id(self): + return self._text_from_xpath('./Container/ContainerId') + + @property + def incarnation(self): + return self._text_from_xpath('./Incarnation') + + @property + def instance_id(self): + return self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance/InstanceId') + + @property + def shared_config_xml(self): + url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' + '/Configuration/SharedConfig') + return self.http_client.get(url).contents + + @property + def certificates_xml(self): + url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' + '/Configuration/Certificates') + if url is not None: + return self.http_client.get(url, secure=True).contents + return None + + +class OpenSSLManager(object): + + certificate_names = { + 'private_key': 'TransportPrivate.pem', + 'certificate': 'TransportCert.pem', + } + + def __init__(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.certificate = None + self.generate_certificate() + + def generate_certificate(self): + if self.certificate is not None: + return + with cd(self.tmpdir.name): + util.subp([ + 'openssl', 'req', '-x509', '-nodes', '-subj', + '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', + '-keyout', self.certificate_names['private_key'], + '-out', self.certificate_names['certificate'], + ]) + certificate = '' + for line in open(self.certificate_names['certificate']): + if "CERTIFICATE" not in line: + certificate += line.rstrip() + self.certificate = certificate + + def parse_certificates(self, certificates_xml): + tag = ElementTree.fromstring(certificates_xml).find( + './/Data') + certificates_content = tag.text + lines = [ + b'MIME-Version: 1.0', + b'Content-Disposition: attachment; filename="Certificates.p7m"', + b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', + b'Content-Transfer-Encoding: base64', + b'', + certificates_content.encode('utf-8'), + ] + with cd(self.tmpdir.name): + with open('Certificates.p7m', 'wb') as f: + f.write(b'\n'.join(lines)) + out, _ = util.subp( + 'openssl cms -decrypt -in Certificates.p7m -inkey' + ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' + ' -password pass:'.format(**self.certificate_names), + shell=True) + private_keys, certificates = [], [] + current = [] + for line in out.splitlines(): + current.append(line) + if re.match(r'[-]+END .*?KEY[-]+$', line): + private_keys.append('\n'.join(current)) + current = [] + elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): + certificates.append('\n'.join(current)) + current = [] + keys = [] + for certificate in certificates: + with cd(self.tmpdir.name): + public_key, _ = util.subp( + 'openssl x509 -noout -pubkey |' + 'ssh-keygen -i -m PKCS8 -f /dev/stdin', + data=certificate, + shell=True) + keys.append(public_key) + return keys + + +class WALinuxAgentShim(object): + + def __init__(self): + self.endpoint = find_endpoint() + self.goal_state = None + self.openssl_manager = OpenSSLManager() + self.http_client = AzureEndpointHttpClient( + self.openssl_manager.certificate) + self.values = {} + + def register_with_azure_and_fetch_data(self): + LOG.info('Registering with Azure...') + for i in range(10): + try: + response = self.http_client.get( + 'http://{}/machine/?comp=goalstate'.format(self.endpoint)) + except Exception: + time.sleep(i + 1) + else: + break + self.goal_state = GoalState(response.contents, self.http_client) + self.public_keys = [] + if self.goal_state.certificates_xml is not None: + self.public_keys = self.openssl_manager.parse_certificates( + self.goal_state.certificates_xml) + self._report_ready() + + def _report_ready(self): + document = REPORT_READY_XML_TEMPLATE.format( + incarnation=self.goal_state.incarnation, + container_id=self.goal_state.container_id, + instance_id=self.goal_state.instance_id, + ) + self.http_client.post( + "http://{}/machine?comp=health".format(self.endpoint), + data=document, + extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, + ) + def get_hostname(hostname_command='hostname'): return util.subp(hostname_command, capture=True)[0].strip() @@ -185,53 +414,17 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) - temp_hostname = self.metadata.get('local-hostname') - hostname_command = mycfg['hostname_bounce']['hostname_command'] - with temporary_hostname(temp_hostname, mycfg, - hostname_command=hostname_command) \ - as previous_hostname: - if (previous_hostname is not None - and util.is_true(mycfg.get('set_hostname'))): - cfg = mycfg['hostname_bounce'] - try: - perform_hostname_bounce(hostname=temp_hostname, - cfg=cfg, - prev_hostname=previous_hostname) - except Exception as e: - LOG.warn("Failed publishing hostname: %s", e) - util.logexc(LOG, "handling set_hostname failed") + shim = WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() - try: - invoke_agent(mycfg['agent_command']) - except util.ProcessExecutionError: - # claim the datasource even if the command failed - util.logexc(LOG, "agent command '%s' failed.", - mycfg['agent_command']) - - shcfgxml = os.path.join(ddir, "SharedConfig.xml") - wait_for = [shcfgxml] - - fp_files = [] - for pk in self.cfg.get('_pubkeys', []): - bname = str(pk['fingerprint'] + ".crt") - fp_files += [os.path.join(ddir, bname)] - - missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", - func=wait_for_files, - args=(wait_for + fp_files,)) - if len(missing): - LOG.warn("Did not find files, but going on: %s", missing) - - if shcfgxml in missing: - LOG.warn("SharedConfig.xml missing, using static instance-id") - else: - try: - self.metadata['instance-id'] = iid_from_shared_config(shcfgxml) - except ValueError as e: - LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) + try: + self.metadata['instance-id'] = iid_from_shared_config_content( + shim.goal_state.shared_config_xml) + except ValueError as e: + LOG.warn( + "failed to get instance id in %s: %s", shim.shared_config, e) - pubkeys = pubkeys_from_crt_files(fp_files) - self.metadata['public-keys'] = pubkeys + self.metadata['public-keys'] = shim.public_keys found_ephemeral = find_ephemeral_disk() if found_ephemeral: @@ -363,10 +556,11 @@ def perform_hostname_bounce(hostname, cfg, prev_hostname): 'env': env}) -def crtfile_to_pubkey(fname): +def crtfile_to_pubkey(fname, data=None): pipeline = ('openssl x509 -noout -pubkey < "$0" |' 'ssh-keygen -i -m PKCS8 -f /dev/stdin') - (out, _err) = util.subp(['sh', '-c', pipeline, fname], capture=True) + (out, _err) = util.subp(['sh', '-c', pipeline, fname], + capture=True, data=data) return out.rstrip() diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 7e789853..dc7f2663 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -15,11 +15,48 @@ except ImportError: import crypt import os import stat +import struct import yaml import shutil import tempfile import unittest +from cloudinit import url_helper + + +GOAL_STATE_TEMPLATE = """\ + + + 2012-11-30 + {incarnation} + + Started + 300000 + + 16001 + + FALSE + + + {container_id} + + + {instance_id} + Started + + http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1 + {shared_config_url} + http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1 + http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1 + {certificates_url} + 68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml + + + + + +""" + def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): if data is None: @@ -579,3 +616,297 @@ class TestReadAzureSharedConfig(unittest.TestCase): """ ret = DataSourceAzure.iid_from_shared_config_content(xml) self.assertEqual("MY_INSTANCE_ID", ret) + + +class TestFindEndpoint(TestCase): + + def setUp(self): + super(TestFindEndpoint, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.load_file = patches.enter_context( + mock.patch.object(DataSourceAzure.util, 'load_file')) + + def test_missing_file(self): + self.load_file.side_effect = IOError + self.assertRaises(IOError, DataSourceAzure.find_endpoint) + + def test_missing_special_azure_line(self): + self.load_file.return_value = '' + self.assertRaises(Exception, DataSourceAzure.find_endpoint) + + def _build_lease_content(self, ip_address, use_hex=True): + ip_address_repr = ':'.join( + [hex(int(part)).replace('0x', '') + for part in ip_address.split('.')]) + if not use_hex: + ip_address_repr = struct.pack( + '>L', int(ip_address_repr.replace(':', ''), 16)) + ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) + return '\n'.join([ + 'lease {', + ' interface "eth0";', + ' option unknown-245 {0};'.format(ip_address_repr), + '}']) + + def test_hex_string(self): + ip_address = '98.76.54.32' + file_content = self._build_lease_content(ip_address) + self.load_file.return_value = file_content + self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + + def test_hex_string_with_single_character_part(self): + ip_address = '4.3.2.1' + file_content = self._build_lease_content(ip_address) + self.load_file.return_value = file_content + self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + + def test_packed_string(self): + ip_address = '98.76.54.32' + file_content = self._build_lease_content(ip_address, use_hex=False) + self.load_file.return_value = file_content + self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + + def test_latest_lease_used(self): + ip_addresses = ['4.3.2.1', '98.76.54.32'] + file_content = '\n'.join([self._build_lease_content(ip_address) + for ip_address in ip_addresses]) + self.load_file.return_value = file_content + self.assertEqual(ip_addresses[-1], DataSourceAzure.find_endpoint()) + + +class TestGoalStateParsing(TestCase): + + default_parameters = { + 'incarnation': 1, + 'container_id': 'MyContainerId', + 'instance_id': 'MyInstanceId', + 'shared_config_url': 'MySharedConfigUrl', + 'certificates_url': 'MyCertificatesUrl', + } + + def _get_goal_state(self, http_client=None, **kwargs): + if http_client is None: + http_client = mock.MagicMock() + parameters = self.default_parameters.copy() + parameters.update(kwargs) + xml = GOAL_STATE_TEMPLATE.format(**parameters) + if parameters['certificates_url'] is None: + new_xml_lines = [] + for line in xml.splitlines(): + if 'Certificates' in line: + continue + new_xml_lines.append(line) + xml = '\n'.join(new_xml_lines) + return DataSourceAzure.GoalState(xml, http_client) + + def test_incarnation_parsed_correctly(self): + incarnation = '123' + goal_state = self._get_goal_state(incarnation=incarnation) + self.assertEqual(incarnation, goal_state.incarnation) + + def test_container_id_parsed_correctly(self): + container_id = 'TestContainerId' + goal_state = self._get_goal_state(container_id=container_id) + self.assertEqual(container_id, goal_state.container_id) + + def test_instance_id_parsed_correctly(self): + instance_id = 'TestInstanceId' + goal_state = self._get_goal_state(instance_id=instance_id) + self.assertEqual(instance_id, goal_state.instance_id) + + def test_shared_config_xml_parsed_and_fetched_correctly(self): + http_client = mock.MagicMock() + shared_config_url = 'TestSharedConfigUrl' + goal_state = self._get_goal_state( + http_client=http_client, shared_config_url=shared_config_url) + shared_config_xml = goal_state.shared_config_xml + self.assertEqual(1, http_client.get.call_count) + self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) + self.assertEqual(http_client.get.return_value.contents, + shared_config_xml) + + def test_certificates_xml_parsed_and_fetched_correctly(self): + http_client = mock.MagicMock() + certificates_url = 'TestSharedConfigUrl' + goal_state = self._get_goal_state( + http_client=http_client, certificates_url=certificates_url) + certificates_xml = goal_state.certificates_xml + self.assertEqual(1, http_client.get.call_count) + self.assertEqual(certificates_url, http_client.get.call_args[0][0]) + self.assertTrue(http_client.get.call_args[1].get('secure', False)) + self.assertEqual(http_client.get.return_value.contents, + certificates_xml) + + def test_missing_certificates_skips_http_get(self): + http_client = mock.MagicMock() + goal_state = self._get_goal_state( + http_client=http_client, certificates_url=None) + certificates_xml = goal_state.certificates_xml + self.assertEqual(0, http_client.get.call_count) + self.assertIsNone(certificates_xml) + + +class TestAzureEndpointHttpClient(TestCase): + + regular_headers = { + 'x-ms-agent-name': 'WALinuxAgent', + 'x-ms-version': '2012-11-30', + } + + def setUp(self): + super(TestAzureEndpointHttpClient, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.read_file_or_url = patches.enter_context( + mock.patch.object(DataSourceAzure.util, 'read_file_or_url')) + + def test_non_secure_get(self): + client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) + url = 'MyTestUrl' + response = client.get(url, secure=False) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(mock.call(url, headers=self.regular_headers), + self.read_file_or_url.call_args) + + def test_secure_get(self): + url = 'MyTestUrl' + certificate = mock.MagicMock() + expected_headers = self.regular_headers.copy() + expected_headers.update({ + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": certificate, + }) + client = DataSourceAzure.AzureEndpointHttpClient(certificate) + response = client.get(url, secure=True) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(mock.call(url, headers=expected_headers), + self.read_file_or_url.call_args) + + def test_post(self): + data = mock.MagicMock() + url = 'MyTestUrl' + client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) + response = client.post(url, data=data) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual( + mock.call(url, data=data, headers=self.regular_headers), + self.read_file_or_url.call_args) + + def test_post_with_extra_headers(self): + url = 'MyTestUrl' + client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) + extra_headers = {'test': 'header'} + client.post(url, extra_headers=extra_headers) + self.assertEqual(1, self.read_file_or_url.call_count) + expected_headers = self.regular_headers.copy() + expected_headers.update(extra_headers) + self.assertEqual( + mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), + self.read_file_or_url.call_args) + + +class TestOpenSSLManager(TestCase): + + def setUp(self): + super(TestOpenSSLManager, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.subp = patches.enter_context( + mock.patch.object(DataSourceAzure.util, 'subp')) + + @mock.patch.object(DataSourceAzure, 'cd', mock.MagicMock()) + @mock.patch.object(DataSourceAzure.tempfile, 'TemporaryDirectory') + def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): + manager = DataSourceAzure.OpenSSLManager() + self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) + + @mock.patch('builtins.open') + def test_generate_certificate_uses_tmpdir(self, open): + subp_directory = {} + + def capture_directory(*args, **kwargs): + subp_directory['path'] = os.getcwd() + + self.subp.side_effect = capture_directory + manager = DataSourceAzure.OpenSSLManager() + self.assertEqual(manager.tmpdir.name, subp_directory['path']) + + +class TestWALinuxAgentShim(TestCase): + + def setUp(self): + super(TestWALinuxAgentShim, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.AzureEndpointHttpClient = patches.enter_context( + mock.patch.object(DataSourceAzure, 'AzureEndpointHttpClient')) + self.find_endpoint = patches.enter_context( + mock.patch.object(DataSourceAzure, 'find_endpoint')) + self.GoalState = patches.enter_context( + mock.patch.object(DataSourceAzure, 'GoalState')) + self.OpenSSLManager = patches.enter_context( + mock.patch.object(DataSourceAzure, 'OpenSSLManager')) + + def test_http_client_uses_certificate(self): + shim = DataSourceAzure.WALinuxAgentShim() + self.assertEqual( + [mock.call(self.OpenSSLManager.return_value.certificate)], + self.AzureEndpointHttpClient.call_args_list) + self.assertEqual(self.AzureEndpointHttpClient.return_value, + shim.http_client) + + def test_correct_url_used_for_goalstate(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = DataSourceAzure.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + get = self.AzureEndpointHttpClient.return_value.get + self.assertEqual( + [mock.call('http://test_endpoint/machine/?comp=goalstate')], + get.call_args_list) + self.assertEqual( + [mock.call(get.return_value.contents, shim.http_client)], + self.GoalState.call_args_list) + + def test_certificates_used_to_determine_public_keys(self): + shim = DataSourceAzure.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.certificates_xml)], + self.OpenSSLManager.return_value.parse_certificates.call_args_list) + self.assertEqual( + self.OpenSSLManager.return_value.parse_certificates.return_value, + shim.public_keys) + + def test_absent_certificates_produces_empty_public_keys(self): + self.GoalState.return_value.certificates_xml = None + shim = DataSourceAzure.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + self.assertEqual([], shim.public_keys) + + def test_correct_url_used_for_report_ready(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = DataSourceAzure.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + expected_url = 'http://test_endpoint/machine?comp=health' + self.assertEqual( + [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], + shim.http_client.post.call_args_list) + + def test_goal_state_values_used_for_report_ready(self): + self.GoalState.return_value.incarnation = 'TestIncarnation' + self.GoalState.return_value.container_id = 'TestContainerId' + self.GoalState.return_value.instance_id = 'TestInstanceId' + shim = DataSourceAzure.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + posted_document = shim.http_client.post.call_args[1]['data'] + self.assertIn('TestIncarnation', posted_document) + self.assertIn('TestContainerId', posted_document) + self.assertIn('TestInstanceId', posted_document) -- cgit v1.2.3 From 2edfd791b29df3271bdc3aff40d60336ddd636ed Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 12:58:18 +0100 Subject: Return a dict of data from WALinuxAgentShim, rather than accessing attributes. --- cloudinit/sources/DataSourceAzure.py | 46 +++++++++++++++------------ tests/unittests/test_datasource/test_azure.py | 29 ++++++++++++++--- 2 files changed, 49 insertions(+), 26 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index bd3c742b..b93357d5 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -260,7 +260,6 @@ class WALinuxAgentShim(object): def __init__(self): self.endpoint = find_endpoint() - self.goal_state = None self.openssl_manager = OpenSSLManager() self.http_client = AzureEndpointHttpClient( self.openssl_manager.certificate) @@ -276,18 +275,24 @@ class WALinuxAgentShim(object): time.sleep(i + 1) else: break - self.goal_state = GoalState(response.contents, self.http_client) - self.public_keys = [] - if self.goal_state.certificates_xml is not None: - self.public_keys = self.openssl_manager.parse_certificates( - self.goal_state.certificates_xml) - self._report_ready() - - def _report_ready(self): + goal_state = GoalState(response.contents, self.http_client) + public_keys = [] + if goal_state.certificates_xml is not None: + public_keys = self.openssl_manager.parse_certificates( + goal_state.certificates_xml) + data = { + 'instance-id': iid_from_shared_config_content( + goal_state.shared_config_xml), + 'public-keys': public_keys, + } + self._report_ready(goal_state) + return data + + def _report_ready(self, goal_state): document = REPORT_READY_XML_TEMPLATE.format( - incarnation=self.goal_state.incarnation, - container_id=self.goal_state.container_id, - instance_id=self.goal_state.instance_id, + incarnation=goal_state.incarnation, + container_id=goal_state.container_id, + instance_id=goal_state.instance_id, ) self.http_client.post( "http://{}/machine?comp=health".format(self.endpoint), @@ -414,17 +419,16 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) - shim = WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - try: - self.metadata['instance-id'] = iid_from_shared_config_content( - shim.goal_state.shared_config_xml) - except ValueError as e: - LOG.warn( - "failed to get instance id in %s: %s", shim.shared_config, e) + shim = WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + except Exception as exc: + LOG.info("Error communicating with Azure fabric; assume we aren't" + " on Azure.", exc_info=True) + return False - self.metadata['public-keys'] = shim.public_keys + self.metadata['instance-id'] = data['instance-id'] + self.metadata['public-keys'] = data['public-keys'] found_ephemeral = find_ephemeral_disk() if found_ephemeral: diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index dc7f2663..fd5b24f8 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -160,6 +160,12 @@ class TestAzureDataSource(TestCase): mod = DataSourceAzure mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d + fake_shim = mock.MagicMock() + fake_shim().register_with_azure_and_fetch_data.return_value = { + 'instance-id': 'i-my-azure-id', + 'public-keys': [], + } + self.apply_patches([ (mod, 'list_possible_azure_ds_devs', dsdevs), (mod, 'invoke_agent', _invoke_agent), @@ -169,7 +175,8 @@ class TestAzureDataSource(TestCase): (mod, 'perform_hostname_bounce', mock.MagicMock()), (mod, 'get_hostname', mock.MagicMock()), (mod, 'set_hostname', mock.MagicMock()), - ]) + (mod, 'WALinuxAgentShim', fake_shim), + ]) dsrc = mod.DataSourceAzureNet( data.get('sys_cfg', {}), distro=None, paths=self.paths) @@ -852,6 +859,9 @@ class TestWALinuxAgentShim(TestCase): mock.patch.object(DataSourceAzure, 'find_endpoint')) self.GoalState = patches.enter_context( mock.patch.object(DataSourceAzure, 'GoalState')) + self.iid_from_shared_config_content = patches.enter_context( + mock.patch.object(DataSourceAzure, + 'iid_from_shared_config_content')) self.OpenSSLManager = patches.enter_context( mock.patch.object(DataSourceAzure, 'OpenSSLManager')) @@ -877,19 +887,28 @@ class TestWALinuxAgentShim(TestCase): def test_certificates_used_to_determine_public_keys(self): shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() + data = shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(self.GoalState.return_value.certificates_xml)], self.OpenSSLManager.return_value.parse_certificates.call_args_list) self.assertEqual( self.OpenSSLManager.return_value.parse_certificates.return_value, - shim.public_keys) + data['public-keys']) def test_absent_certificates_produces_empty_public_keys(self): self.GoalState.return_value.certificates_xml = None shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - self.assertEqual([], shim.public_keys) + data = shim.register_with_azure_and_fetch_data() + self.assertEqual([], data['public-keys']) + + def test_instance_id_returned_in_data(self): + shim = DataSourceAzure.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.shared_config_xml)], + self.iid_from_shared_config_content.call_args_list) + self.assertEqual(self.iid_from_shared_config_content.return_value, + data['instance-id']) def test_correct_url_used_for_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' -- cgit v1.2.3 From 7ca682408f857fcfd04bfc026ea6c697c1fd4b86 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 12:59:57 +0100 Subject: Make find_endpoint a staticmethod to clean up top-level namespace. --- cloudinit/sources/DataSourceAzure.py | 84 ++++++++++++++------------- tests/unittests/test_datasource/test_azure.py | 21 ++++--- 2 files changed, 57 insertions(+), 48 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index c783732d..ba4afa5f 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -70,22 +70,6 @@ BUILTIN_CLOUD_CONFIG = { DS_CFG_PATH = ['datasource', DS_NAME] DEF_EPHEMERAL_LABEL = 'Temporary Storage' -REPORT_READY_XML_TEMPLATE = """\ - - - {incarnation} - - {container_id} - - - {instance_id} - - Ready - - - - -""" @contextmanager @@ -126,29 +110,6 @@ class AzureEndpointHttpClient(object): return util.read_file_or_url(url, data=data, headers=headers) -def find_endpoint(): - LOG.debug('Finding Azure endpoint...') - content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') - value = None - for line in content.splitlines(): - if 'unknown-245' in line: - value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') - if value is None: - raise Exception('No endpoint found in DHCP config.') - if ':' in value: - hex_string = '' - for hex_pair in value.split(':'): - if len(hex_pair) == 1: - hex_pair = '0' + hex_pair - hex_string += hex_pair - value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) - else: - value = value.encode('utf-8') - endpoint_ip_address = socket.inet_ntoa(value) - LOG.debug('Azure endpoint found at %s', endpoint_ip_address) - return endpoint_ip_address - - class GoalState(object): def __init__(self, xml, http_client): @@ -268,14 +229,55 @@ class OpenSSLManager(object): class WALinuxAgentShim(object): + REPORT_READY_XML_TEMPLATE = '\n'.join([ + '', + '', + ' {incarnation}', + ' ', + ' {container_id}', + ' ', + ' ', + ' {instance_id}', + ' ', + ' Ready', + ' ', + ' ', + ' ', + ' ', + '']) + def __init__(self): LOG.debug('WALinuxAgentShim instantiated...') - self.endpoint = find_endpoint() + self.endpoint = self.find_endpoint() self.openssl_manager = OpenSSLManager() self.http_client = AzureEndpointHttpClient( self.openssl_manager.certificate) self.values = {} + @staticmethod + def find_endpoint(): + LOG.debug('Finding Azure endpoint...') + content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') + value = None + for line in content.splitlines(): + if 'unknown-245' in line: + value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') + if value is None: + raise Exception('No endpoint found in DHCP config.') + if ':' in value: + hex_string = '' + for hex_pair in value.split(':'): + if len(hex_pair) == 1: + hex_pair = '0' + hex_pair + hex_string += hex_pair + value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) + else: + value = value.encode('utf-8') + endpoint_ip_address = socket.inet_ntoa(value) + LOG.debug('Azure endpoint found at %s', endpoint_ip_address) + return endpoint_ip_address + def register_with_azure_and_fetch_data(self): LOG.info('Registering with Azure...') for i in range(10): @@ -303,7 +305,7 @@ class WALinuxAgentShim(object): def _report_ready(self, goal_state): LOG.debug('Reporting ready to Azure fabric.') - document = REPORT_READY_XML_TEMPLATE.format( + document = self.REPORT_READY_XML_TEMPLATE.format( incarnation=goal_state.incarnation, container_id=goal_state.container_id, instance_id=goal_state.instance_id, diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index fd5b24f8..28703029 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -637,11 +637,13 @@ class TestFindEndpoint(TestCase): def test_missing_file(self): self.load_file.side_effect = IOError - self.assertRaises(IOError, DataSourceAzure.find_endpoint) + self.assertRaises(IOError, + DataSourceAzure.WALinuxAgentShim.find_endpoint) def test_missing_special_azure_line(self): self.load_file.return_value = '' - self.assertRaises(Exception, DataSourceAzure.find_endpoint) + self.assertRaises(Exception, + DataSourceAzure.WALinuxAgentShim.find_endpoint) def _build_lease_content(self, ip_address, use_hex=True): ip_address_repr = ':'.join( @@ -661,26 +663,30 @@ class TestFindEndpoint(TestCase): ip_address = '98.76.54.32' file_content = self._build_lease_content(ip_address) self.load_file.return_value = file_content - self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + self.assertEqual(ip_address, + DataSourceAzure.WALinuxAgentShim.find_endpoint()) def test_hex_string_with_single_character_part(self): ip_address = '4.3.2.1' file_content = self._build_lease_content(ip_address) self.load_file.return_value = file_content - self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + self.assertEqual(ip_address, + DataSourceAzure.WALinuxAgentShim.find_endpoint()) def test_packed_string(self): ip_address = '98.76.54.32' file_content = self._build_lease_content(ip_address, use_hex=False) self.load_file.return_value = file_content - self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + self.assertEqual(ip_address, + DataSourceAzure.WALinuxAgentShim.find_endpoint()) def test_latest_lease_used(self): ip_addresses = ['4.3.2.1', '98.76.54.32'] file_content = '\n'.join([self._build_lease_content(ip_address) for ip_address in ip_addresses]) self.load_file.return_value = file_content - self.assertEqual(ip_addresses[-1], DataSourceAzure.find_endpoint()) + self.assertEqual(ip_addresses[-1], + DataSourceAzure.WALinuxAgentShim.find_endpoint()) class TestGoalStateParsing(TestCase): @@ -856,7 +862,8 @@ class TestWALinuxAgentShim(TestCase): self.AzureEndpointHttpClient = patches.enter_context( mock.patch.object(DataSourceAzure, 'AzureEndpointHttpClient')) self.find_endpoint = patches.enter_context( - mock.patch.object(DataSourceAzure, 'find_endpoint')) + mock.patch.object( + DataSourceAzure.WALinuxAgentShim, 'find_endpoint')) self.GoalState = patches.enter_context( mock.patch.object(DataSourceAzure, 'GoalState')) self.iid_from_shared_config_content = patches.enter_context( -- cgit v1.2.3 From b9f26689e8b3bb7a3486771c6362107232a7dcf4 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 13:16:42 +0100 Subject: Split WALinuxAgentShim code out to separate file. --- cloudinit/sources/DataSourceAzure.py | 271 +-------------- cloudinit/sources/helpers/azure.py | 273 +++++++++++++++ tests/unittests/test_datasource/test_azure.py | 364 -------------------- .../unittests/test_datasource/test_azure_helper.py | 377 +++++++++++++++++++++ 4 files changed, 653 insertions(+), 632 deletions(-) create mode 100644 cloudinit/sources/helpers/azure.py create mode 100644 tests/unittests/test_datasource/test_azure_helper.py (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index c2dc6b4c..5e147950 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -17,23 +17,19 @@ # along with this program. If not, see . import base64 +import contextlib import crypt import fnmatch import os import os.path -import re -import socket -import struct -import tempfile -import time -from contextlib import contextmanager from xml.dom import minidom -from xml.etree import ElementTree from cloudinit import log as logging from cloudinit.settings import PER_ALWAYS from cloudinit import sources from cloudinit import util +from cloudinit.sources.helpers.azure import ( + iid_from_shared_config_content, WALinuxAgentShim) LOG = logging.getLogger(__name__) @@ -70,253 +66,6 @@ DS_CFG_PATH = ['datasource', DS_NAME] DEF_EPHEMERAL_LABEL = 'Temporary Storage' - -@contextmanager -def cd(newdir): - prevdir = os.getcwd() - os.chdir(os.path.expanduser(newdir)) - try: - yield - finally: - os.chdir(prevdir) - - -class AzureEndpointHttpClient(object): - - headers = { - 'x-ms-agent-name': 'WALinuxAgent', - 'x-ms-version': '2012-11-30', - } - - def __init__(self, certificate): - self.extra_secure_headers = { - "x-ms-cipher-name": "DES_EDE3_CBC", - "x-ms-guest-agent-public-x509-cert": certificate, - } - - def get(self, url, secure=False): - headers = self.headers - if secure: - headers = self.headers.copy() - headers.update(self.extra_secure_headers) - return util.read_file_or_url(url, headers=headers) - - def post(self, url, data=None, extra_headers=None): - headers = self.headers - if extra_headers is not None: - headers = self.headers.copy() - headers.update(extra_headers) - return util.read_file_or_url(url, data=data, headers=headers) - - -class GoalState(object): - - def __init__(self, xml, http_client): - self.http_client = http_client - self.root = ElementTree.fromstring(xml) - self._certificates_xml = None - - def _text_from_xpath(self, xpath): - element = self.root.find(xpath) - if element is not None: - return element.text - return None - - @property - def container_id(self): - return self._text_from_xpath('./Container/ContainerId') - - @property - def incarnation(self): - return self._text_from_xpath('./Incarnation') - - @property - def instance_id(self): - return self._text_from_xpath( - './Container/RoleInstanceList/RoleInstance/InstanceId') - - @property - def shared_config_xml(self): - url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' - '/Configuration/SharedConfig') - return self.http_client.get(url).contents - - @property - def certificates_xml(self): - if self._certificates_xml is None: - url = self._text_from_xpath( - './Container/RoleInstanceList/RoleInstance' - '/Configuration/Certificates') - if url is not None: - self._certificates_xml = self.http_client.get( - url, secure=True).contents - return self._certificates_xml - - -class OpenSSLManager(object): - - certificate_names = { - 'private_key': 'TransportPrivate.pem', - 'certificate': 'TransportCert.pem', - } - - def __init__(self): - self.tmpdir = tempfile.TemporaryDirectory() - self.certificate = None - self.generate_certificate() - - def generate_certificate(self): - LOG.debug('Generating certificate for communication with fabric...') - if self.certificate is not None: - LOG.debug('Certificate already generated.') - return - with cd(self.tmpdir.name): - util.subp([ - 'openssl', 'req', '-x509', '-nodes', '-subj', - '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', - '-keyout', self.certificate_names['private_key'], - '-out', self.certificate_names['certificate'], - ]) - certificate = '' - for line in open(self.certificate_names['certificate']): - if "CERTIFICATE" not in line: - certificate += line.rstrip() - self.certificate = certificate - LOG.debug('New certificate generated.') - - def parse_certificates(self, certificates_xml): - tag = ElementTree.fromstring(certificates_xml).find( - './/Data') - certificates_content = tag.text - lines = [ - b'MIME-Version: 1.0', - b'Content-Disposition: attachment; filename="Certificates.p7m"', - b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', - b'Content-Transfer-Encoding: base64', - b'', - certificates_content.encode('utf-8'), - ] - with cd(self.tmpdir.name): - with open('Certificates.p7m', 'wb') as f: - f.write(b'\n'.join(lines)) - out, _ = util.subp( - 'openssl cms -decrypt -in Certificates.p7m -inkey' - ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' - ' -password pass:'.format(**self.certificate_names), - shell=True) - private_keys, certificates = [], [] - current = [] - for line in out.splitlines(): - current.append(line) - if re.match(r'[-]+END .*?KEY[-]+$', line): - private_keys.append('\n'.join(current)) - current = [] - elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): - certificates.append('\n'.join(current)) - current = [] - keys = [] - for certificate in certificates: - with cd(self.tmpdir.name): - public_key, _ = util.subp( - 'openssl x509 -noout -pubkey |' - 'ssh-keygen -i -m PKCS8 -f /dev/stdin', - data=certificate, - shell=True) - keys.append(public_key) - return keys - - -class WALinuxAgentShim(object): - - REPORT_READY_XML_TEMPLATE = '\n'.join([ - '', - '', - ' {incarnation}', - ' ', - ' {container_id}', - ' ', - ' ', - ' {instance_id}', - ' ', - ' Ready', - ' ', - ' ', - ' ', - ' ', - '']) - - def __init__(self): - LOG.debug('WALinuxAgentShim instantiated...') - self.endpoint = self.find_endpoint() - self.openssl_manager = OpenSSLManager() - self.http_client = AzureEndpointHttpClient( - self.openssl_manager.certificate) - self.values = {} - - @staticmethod - def find_endpoint(): - LOG.debug('Finding Azure endpoint...') - content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') - value = None - for line in content.splitlines(): - if 'unknown-245' in line: - value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') - if value is None: - raise Exception('No endpoint found in DHCP config.') - if ':' in value: - hex_string = '' - for hex_pair in value.split(':'): - if len(hex_pair) == 1: - hex_pair = '0' + hex_pair - hex_string += hex_pair - value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) - else: - value = value.encode('utf-8') - endpoint_ip_address = socket.inet_ntoa(value) - LOG.debug('Azure endpoint found at %s', endpoint_ip_address) - return endpoint_ip_address - - def register_with_azure_and_fetch_data(self): - LOG.info('Registering with Azure...') - for i in range(10): - try: - response = self.http_client.get( - 'http://{}/machine/?comp=goalstate'.format(self.endpoint)) - except Exception: - time.sleep(i + 1) - else: - break - LOG.debug('Successfully fetched GoalState XML.') - goal_state = GoalState(response.contents, self.http_client) - public_keys = [] - if goal_state.certificates_xml is not None: - LOG.debug('Certificate XML found; parsing out public keys.') - public_keys = self.openssl_manager.parse_certificates( - goal_state.certificates_xml) - data = { - 'instance-id': iid_from_shared_config_content( - goal_state.shared_config_xml), - 'public-keys': public_keys, - } - self._report_ready(goal_state) - return data - - def _report_ready(self, goal_state): - LOG.debug('Reporting ready to Azure fabric.') - document = self.REPORT_READY_XML_TEMPLATE.format( - incarnation=goal_state.incarnation, - container_id=goal_state.container_id, - instance_id=goal_state.instance_id, - ) - self.http_client.post( - "http://{}/machine?comp=health".format(self.endpoint), - data=document, - extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, - ) - LOG.info('Reported ready to Azure fabric.') - - def get_hostname(hostname_command='hostname'): return util.subp(hostname_command, capture=True)[0].strip() @@ -690,20 +439,6 @@ def load_azure_ovf_pubkeys(sshnode): return found -def single_node_at_path(node, pathlist): - curnode = node - for tok in pathlist: - results = find_child(curnode, lambda n: n.localName == tok) - if len(results) == 0: - raise ValueError("missing %s token in %s" % (tok, str(pathlist))) - if len(results) > 1: - raise ValueError("found %s nodes of type %s looking for %s" % - (len(results), tok, str(pathlist))) - curnode = results[0] - - return curnode - - def read_azure_ovf(contents): try: dom = minidom.parseString(contents) diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py new file mode 100644 index 00000000..60f116e0 --- /dev/null +++ b/cloudinit/sources/helpers/azure.py @@ -0,0 +1,273 @@ +import logging +import os +import re +import socket +import struct +import tempfile +import time +from contextlib import contextmanager +from xml.etree import ElementTree + +from cloudinit import util + + +LOG = logging.getLogger(__name__) + + +@contextmanager +def cd(newdir): + prevdir = os.getcwd() + os.chdir(os.path.expanduser(newdir)) + try: + yield + finally: + os.chdir(prevdir) + + +class AzureEndpointHttpClient(object): + + headers = { + 'x-ms-agent-name': 'WALinuxAgent', + 'x-ms-version': '2012-11-30', + } + + def __init__(self, certificate): + self.extra_secure_headers = { + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": certificate, + } + + def get(self, url, secure=False): + headers = self.headers + if secure: + headers = self.headers.copy() + headers.update(self.extra_secure_headers) + return util.read_file_or_url(url, headers=headers) + + def post(self, url, data=None, extra_headers=None): + headers = self.headers + if extra_headers is not None: + headers = self.headers.copy() + headers.update(extra_headers) + return util.read_file_or_url(url, data=data, headers=headers) + + +class GoalState(object): + + def __init__(self, xml, http_client): + self.http_client = http_client + self.root = ElementTree.fromstring(xml) + self._certificates_xml = None + + def _text_from_xpath(self, xpath): + element = self.root.find(xpath) + if element is not None: + return element.text + return None + + @property + def container_id(self): + return self._text_from_xpath('./Container/ContainerId') + + @property + def incarnation(self): + return self._text_from_xpath('./Incarnation') + + @property + def instance_id(self): + return self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance/InstanceId') + + @property + def shared_config_xml(self): + url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' + '/Configuration/SharedConfig') + return self.http_client.get(url).contents + + @property + def certificates_xml(self): + if self._certificates_xml is None: + url = self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance' + '/Configuration/Certificates') + if url is not None: + self._certificates_xml = self.http_client.get( + url, secure=True).contents + return self._certificates_xml + + +class OpenSSLManager(object): + + certificate_names = { + 'private_key': 'TransportPrivate.pem', + 'certificate': 'TransportCert.pem', + } + + def __init__(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.certificate = None + self.generate_certificate() + + def generate_certificate(self): + LOG.debug('Generating certificate for communication with fabric...') + if self.certificate is not None: + LOG.debug('Certificate already generated.') + return + with cd(self.tmpdir.name): + util.subp([ + 'openssl', 'req', '-x509', '-nodes', '-subj', + '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', + '-keyout', self.certificate_names['private_key'], + '-out', self.certificate_names['certificate'], + ]) + certificate = '' + for line in open(self.certificate_names['certificate']): + if "CERTIFICATE" not in line: + certificate += line.rstrip() + self.certificate = certificate + LOG.debug('New certificate generated.') + + def parse_certificates(self, certificates_xml): + tag = ElementTree.fromstring(certificates_xml).find( + './/Data') + certificates_content = tag.text + lines = [ + b'MIME-Version: 1.0', + b'Content-Disposition: attachment; filename="Certificates.p7m"', + b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', + b'Content-Transfer-Encoding: base64', + b'', + certificates_content.encode('utf-8'), + ] + with cd(self.tmpdir.name): + with open('Certificates.p7m', 'wb') as f: + f.write(b'\n'.join(lines)) + out, _ = util.subp( + 'openssl cms -decrypt -in Certificates.p7m -inkey' + ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' + ' -password pass:'.format(**self.certificate_names), + shell=True) + private_keys, certificates = [], [] + current = [] + for line in out.splitlines(): + current.append(line) + if re.match(r'[-]+END .*?KEY[-]+$', line): + private_keys.append('\n'.join(current)) + current = [] + elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): + certificates.append('\n'.join(current)) + current = [] + keys = [] + for certificate in certificates: + with cd(self.tmpdir.name): + public_key, _ = util.subp( + 'openssl x509 -noout -pubkey |' + 'ssh-keygen -i -m PKCS8 -f /dev/stdin', + data=certificate, + shell=True) + keys.append(public_key) + return keys + + +def iid_from_shared_config_content(content): + """ + find INSTANCE_ID in: + + + + + """ + root = ElementTree.fromstring(content) + depnode = root.find('Deployment') + return depnode.get('name') + + +class WALinuxAgentShim(object): + + REPORT_READY_XML_TEMPLATE = '\n'.join([ + '', + '', + ' {incarnation}', + ' ', + ' {container_id}', + ' ', + ' ', + ' {instance_id}', + ' ', + ' Ready', + ' ', + ' ', + ' ', + ' ', + '']) + + def __init__(self): + LOG.debug('WALinuxAgentShim instantiated...') + self.endpoint = self.find_endpoint() + self.openssl_manager = OpenSSLManager() + self.http_client = AzureEndpointHttpClient( + self.openssl_manager.certificate) + self.values = {} + + @staticmethod + def find_endpoint(): + LOG.debug('Finding Azure endpoint...') + content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') + value = None + for line in content.splitlines(): + if 'unknown-245' in line: + value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') + if value is None: + raise Exception('No endpoint found in DHCP config.') + if ':' in value: + hex_string = '' + for hex_pair in value.split(':'): + if len(hex_pair) == 1: + hex_pair = '0' + hex_pair + hex_string += hex_pair + value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) + else: + value = value.encode('utf-8') + endpoint_ip_address = socket.inet_ntoa(value) + LOG.debug('Azure endpoint found at %s', endpoint_ip_address) + return endpoint_ip_address + + def register_with_azure_and_fetch_data(self): + LOG.info('Registering with Azure...') + for i in range(10): + try: + response = self.http_client.get( + 'http://{}/machine/?comp=goalstate'.format(self.endpoint)) + except Exception: + time.sleep(i + 1) + else: + break + LOG.debug('Successfully fetched GoalState XML.') + goal_state = GoalState(response.contents, self.http_client) + public_keys = [] + if goal_state.certificates_xml is not None: + LOG.debug('Certificate XML found; parsing out public keys.') + public_keys = self.openssl_manager.parse_certificates( + goal_state.certificates_xml) + data = { + 'instance-id': iid_from_shared_config_content( + goal_state.shared_config_xml), + 'public-keys': public_keys, + } + self._report_ready(goal_state) + return data + + def _report_ready(self, goal_state): + LOG.debug('Reporting ready to Azure fabric.') + document = self.REPORT_READY_XML_TEMPLATE.format( + incarnation=goal_state.incarnation, + container_id=goal_state.container_id, + instance_id=goal_state.instance_id, + ) + self.http_client.post( + "http://{}/machine?comp=health".format(self.endpoint), + data=document, + extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, + ) + LOG.info('Reported ready to Azure fabric.') diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 28703029..ee7109e1 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -15,47 +15,9 @@ except ImportError: import crypt import os import stat -import struct import yaml import shutil import tempfile -import unittest - -from cloudinit import url_helper - - -GOAL_STATE_TEMPLATE = """\ - - - 2012-11-30 - {incarnation} - - Started - 300000 - - 16001 - - FALSE - - - {container_id} - - - {instance_id} - Started - - http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1 - {shared_config_url} - http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1 - http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1 - {certificates_url} - 68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml - - - - - -""" def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): @@ -610,329 +572,3 @@ class TestReadAzureOvf(TestCase): for mypk in mypklist: self.assertIn(mypk, cfg['_pubkeys']) - -class TestReadAzureSharedConfig(unittest.TestCase): - def test_valid_content(self): - xml = """ - - - - - - - """ - ret = DataSourceAzure.iid_from_shared_config_content(xml) - self.assertEqual("MY_INSTANCE_ID", ret) - - -class TestFindEndpoint(TestCase): - - def setUp(self): - super(TestFindEndpoint, self).setUp() - patches = ExitStack() - self.addCleanup(patches.close) - - self.load_file = patches.enter_context( - mock.patch.object(DataSourceAzure.util, 'load_file')) - - def test_missing_file(self): - self.load_file.side_effect = IOError - self.assertRaises(IOError, - DataSourceAzure.WALinuxAgentShim.find_endpoint) - - def test_missing_special_azure_line(self): - self.load_file.return_value = '' - self.assertRaises(Exception, - DataSourceAzure.WALinuxAgentShim.find_endpoint) - - def _build_lease_content(self, ip_address, use_hex=True): - ip_address_repr = ':'.join( - [hex(int(part)).replace('0x', '') - for part in ip_address.split('.')]) - if not use_hex: - ip_address_repr = struct.pack( - '>L', int(ip_address_repr.replace(':', ''), 16)) - ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) - return '\n'.join([ - 'lease {', - ' interface "eth0";', - ' option unknown-245 {0};'.format(ip_address_repr), - '}']) - - def test_hex_string(self): - ip_address = '98.76.54.32' - file_content = self._build_lease_content(ip_address) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - DataSourceAzure.WALinuxAgentShim.find_endpoint()) - - def test_hex_string_with_single_character_part(self): - ip_address = '4.3.2.1' - file_content = self._build_lease_content(ip_address) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - DataSourceAzure.WALinuxAgentShim.find_endpoint()) - - def test_packed_string(self): - ip_address = '98.76.54.32' - file_content = self._build_lease_content(ip_address, use_hex=False) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - DataSourceAzure.WALinuxAgentShim.find_endpoint()) - - def test_latest_lease_used(self): - ip_addresses = ['4.3.2.1', '98.76.54.32'] - file_content = '\n'.join([self._build_lease_content(ip_address) - for ip_address in ip_addresses]) - self.load_file.return_value = file_content - self.assertEqual(ip_addresses[-1], - DataSourceAzure.WALinuxAgentShim.find_endpoint()) - - -class TestGoalStateParsing(TestCase): - - default_parameters = { - 'incarnation': 1, - 'container_id': 'MyContainerId', - 'instance_id': 'MyInstanceId', - 'shared_config_url': 'MySharedConfigUrl', - 'certificates_url': 'MyCertificatesUrl', - } - - def _get_goal_state(self, http_client=None, **kwargs): - if http_client is None: - http_client = mock.MagicMock() - parameters = self.default_parameters.copy() - parameters.update(kwargs) - xml = GOAL_STATE_TEMPLATE.format(**parameters) - if parameters['certificates_url'] is None: - new_xml_lines = [] - for line in xml.splitlines(): - if 'Certificates' in line: - continue - new_xml_lines.append(line) - xml = '\n'.join(new_xml_lines) - return DataSourceAzure.GoalState(xml, http_client) - - def test_incarnation_parsed_correctly(self): - incarnation = '123' - goal_state = self._get_goal_state(incarnation=incarnation) - self.assertEqual(incarnation, goal_state.incarnation) - - def test_container_id_parsed_correctly(self): - container_id = 'TestContainerId' - goal_state = self._get_goal_state(container_id=container_id) - self.assertEqual(container_id, goal_state.container_id) - - def test_instance_id_parsed_correctly(self): - instance_id = 'TestInstanceId' - goal_state = self._get_goal_state(instance_id=instance_id) - self.assertEqual(instance_id, goal_state.instance_id) - - def test_shared_config_xml_parsed_and_fetched_correctly(self): - http_client = mock.MagicMock() - shared_config_url = 'TestSharedConfigUrl' - goal_state = self._get_goal_state( - http_client=http_client, shared_config_url=shared_config_url) - shared_config_xml = goal_state.shared_config_xml - self.assertEqual(1, http_client.get.call_count) - self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) - self.assertEqual(http_client.get.return_value.contents, - shared_config_xml) - - def test_certificates_xml_parsed_and_fetched_correctly(self): - http_client = mock.MagicMock() - certificates_url = 'TestSharedConfigUrl' - goal_state = self._get_goal_state( - http_client=http_client, certificates_url=certificates_url) - certificates_xml = goal_state.certificates_xml - self.assertEqual(1, http_client.get.call_count) - self.assertEqual(certificates_url, http_client.get.call_args[0][0]) - self.assertTrue(http_client.get.call_args[1].get('secure', False)) - self.assertEqual(http_client.get.return_value.contents, - certificates_xml) - - def test_missing_certificates_skips_http_get(self): - http_client = mock.MagicMock() - goal_state = self._get_goal_state( - http_client=http_client, certificates_url=None) - certificates_xml = goal_state.certificates_xml - self.assertEqual(0, http_client.get.call_count) - self.assertIsNone(certificates_xml) - - -class TestAzureEndpointHttpClient(TestCase): - - regular_headers = { - 'x-ms-agent-name': 'WALinuxAgent', - 'x-ms-version': '2012-11-30', - } - - def setUp(self): - super(TestAzureEndpointHttpClient, self).setUp() - patches = ExitStack() - self.addCleanup(patches.close) - - self.read_file_or_url = patches.enter_context( - mock.patch.object(DataSourceAzure.util, 'read_file_or_url')) - - def test_non_secure_get(self): - client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) - url = 'MyTestUrl' - response = client.get(url, secure=False) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) - self.assertEqual(mock.call(url, headers=self.regular_headers), - self.read_file_or_url.call_args) - - def test_secure_get(self): - url = 'MyTestUrl' - certificate = mock.MagicMock() - expected_headers = self.regular_headers.copy() - expected_headers.update({ - "x-ms-cipher-name": "DES_EDE3_CBC", - "x-ms-guest-agent-public-x509-cert": certificate, - }) - client = DataSourceAzure.AzureEndpointHttpClient(certificate) - response = client.get(url, secure=True) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) - self.assertEqual(mock.call(url, headers=expected_headers), - self.read_file_or_url.call_args) - - def test_post(self): - data = mock.MagicMock() - url = 'MyTestUrl' - client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) - response = client.post(url, data=data) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) - self.assertEqual( - mock.call(url, data=data, headers=self.regular_headers), - self.read_file_or_url.call_args) - - def test_post_with_extra_headers(self): - url = 'MyTestUrl' - client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) - extra_headers = {'test': 'header'} - client.post(url, extra_headers=extra_headers) - self.assertEqual(1, self.read_file_or_url.call_count) - expected_headers = self.regular_headers.copy() - expected_headers.update(extra_headers) - self.assertEqual( - mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), - self.read_file_or_url.call_args) - - -class TestOpenSSLManager(TestCase): - - def setUp(self): - super(TestOpenSSLManager, self).setUp() - patches = ExitStack() - self.addCleanup(patches.close) - - self.subp = patches.enter_context( - mock.patch.object(DataSourceAzure.util, 'subp')) - - @mock.patch.object(DataSourceAzure, 'cd', mock.MagicMock()) - @mock.patch.object(DataSourceAzure.tempfile, 'TemporaryDirectory') - def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): - manager = DataSourceAzure.OpenSSLManager() - self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) - - @mock.patch('builtins.open') - def test_generate_certificate_uses_tmpdir(self, open): - subp_directory = {} - - def capture_directory(*args, **kwargs): - subp_directory['path'] = os.getcwd() - - self.subp.side_effect = capture_directory - manager = DataSourceAzure.OpenSSLManager() - self.assertEqual(manager.tmpdir.name, subp_directory['path']) - - -class TestWALinuxAgentShim(TestCase): - - def setUp(self): - super(TestWALinuxAgentShim, self).setUp() - patches = ExitStack() - self.addCleanup(patches.close) - - self.AzureEndpointHttpClient = patches.enter_context( - mock.patch.object(DataSourceAzure, 'AzureEndpointHttpClient')) - self.find_endpoint = patches.enter_context( - mock.patch.object( - DataSourceAzure.WALinuxAgentShim, 'find_endpoint')) - self.GoalState = patches.enter_context( - mock.patch.object(DataSourceAzure, 'GoalState')) - self.iid_from_shared_config_content = patches.enter_context( - mock.patch.object(DataSourceAzure, - 'iid_from_shared_config_content')) - self.OpenSSLManager = patches.enter_context( - mock.patch.object(DataSourceAzure, 'OpenSSLManager')) - - def test_http_client_uses_certificate(self): - shim = DataSourceAzure.WALinuxAgentShim() - self.assertEqual( - [mock.call(self.OpenSSLManager.return_value.certificate)], - self.AzureEndpointHttpClient.call_args_list) - self.assertEqual(self.AzureEndpointHttpClient.return_value, - shim.http_client) - - def test_correct_url_used_for_goalstate(self): - self.find_endpoint.return_value = 'test_endpoint' - shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - get = self.AzureEndpointHttpClient.return_value.get - self.assertEqual( - [mock.call('http://test_endpoint/machine/?comp=goalstate')], - get.call_args_list) - self.assertEqual( - [mock.call(get.return_value.contents, shim.http_client)], - self.GoalState.call_args_list) - - def test_certificates_used_to_determine_public_keys(self): - shim = DataSourceAzure.WALinuxAgentShim() - data = shim.register_with_azure_and_fetch_data() - self.assertEqual( - [mock.call(self.GoalState.return_value.certificates_xml)], - self.OpenSSLManager.return_value.parse_certificates.call_args_list) - self.assertEqual( - self.OpenSSLManager.return_value.parse_certificates.return_value, - data['public-keys']) - - def test_absent_certificates_produces_empty_public_keys(self): - self.GoalState.return_value.certificates_xml = None - shim = DataSourceAzure.WALinuxAgentShim() - data = shim.register_with_azure_and_fetch_data() - self.assertEqual([], data['public-keys']) - - def test_instance_id_returned_in_data(self): - shim = DataSourceAzure.WALinuxAgentShim() - data = shim.register_with_azure_and_fetch_data() - self.assertEqual( - [mock.call(self.GoalState.return_value.shared_config_xml)], - self.iid_from_shared_config_content.call_args_list) - self.assertEqual(self.iid_from_shared_config_content.return_value, - data['instance-id']) - - def test_correct_url_used_for_report_ready(self): - self.find_endpoint.return_value = 'test_endpoint' - shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - expected_url = 'http://test_endpoint/machine?comp=health' - self.assertEqual( - [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], - shim.http_client.post.call_args_list) - - def test_goal_state_values_used_for_report_ready(self): - self.GoalState.return_value.incarnation = 'TestIncarnation' - self.GoalState.return_value.container_id = 'TestContainerId' - self.GoalState.return_value.instance_id = 'TestInstanceId' - shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - posted_document = shim.http_client.post.call_args[1]['data'] - self.assertIn('TestIncarnation', posted_document) - self.assertIn('TestContainerId', posted_document) - self.assertIn('TestInstanceId', posted_document) diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py new file mode 100644 index 00000000..47b77840 --- /dev/null +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -0,0 +1,377 @@ +import os +import struct +import unittest + +from cloudinit.sources.helpers import azure as azure_helper +from ..helpers import TestCase + +try: + from unittest import mock +except ImportError: + import mock + +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + + +GOAL_STATE_TEMPLATE = """\ + + + 2012-11-30 + {incarnation} + + Started + 300000 + + 16001 + + FALSE + + + {container_id} + + + {instance_id} + Started + + http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1 + {shared_config_url} + http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1 + http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1 + {certificates_url} + 68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml + + + + + +""" + + +class TestReadAzureSharedConfig(unittest.TestCase): + + def test_valid_content(self): + xml = """ + + + + + + + """ + ret = azure_helper.iid_from_shared_config_content(xml) + self.assertEqual("MY_INSTANCE_ID", ret) + + +class TestFindEndpoint(TestCase): + + def setUp(self): + super(TestFindEndpoint, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.load_file = patches.enter_context( + mock.patch.object(azure_helper.util, 'load_file')) + + def test_missing_file(self): + self.load_file.side_effect = IOError + self.assertRaises(IOError, + azure_helper.WALinuxAgentShim.find_endpoint) + + def test_missing_special_azure_line(self): + self.load_file.return_value = '' + self.assertRaises(Exception, + azure_helper.WALinuxAgentShim.find_endpoint) + + def _build_lease_content(self, ip_address, use_hex=True): + ip_address_repr = ':'.join( + [hex(int(part)).replace('0x', '') + for part in ip_address.split('.')]) + if not use_hex: + ip_address_repr = struct.pack( + '>L', int(ip_address_repr.replace(':', ''), 16)) + ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) + return '\n'.join([ + 'lease {', + ' interface "eth0";', + ' option unknown-245 {0};'.format(ip_address_repr), + '}']) + + def test_hex_string(self): + ip_address = '98.76.54.32' + file_content = self._build_lease_content(ip_address) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + + def test_hex_string_with_single_character_part(self): + ip_address = '4.3.2.1' + file_content = self._build_lease_content(ip_address) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + + def test_packed_string(self): + ip_address = '98.76.54.32' + file_content = self._build_lease_content(ip_address, use_hex=False) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + + def test_latest_lease_used(self): + ip_addresses = ['4.3.2.1', '98.76.54.32'] + file_content = '\n'.join([self._build_lease_content(ip_address) + for ip_address in ip_addresses]) + self.load_file.return_value = file_content + self.assertEqual(ip_addresses[-1], + azure_helper.WALinuxAgentShim.find_endpoint()) + + +class TestGoalStateParsing(TestCase): + + default_parameters = { + 'incarnation': 1, + 'container_id': 'MyContainerId', + 'instance_id': 'MyInstanceId', + 'shared_config_url': 'MySharedConfigUrl', + 'certificates_url': 'MyCertificatesUrl', + } + + def _get_goal_state(self, http_client=None, **kwargs): + if http_client is None: + http_client = mock.MagicMock() + parameters = self.default_parameters.copy() + parameters.update(kwargs) + xml = GOAL_STATE_TEMPLATE.format(**parameters) + if parameters['certificates_url'] is None: + new_xml_lines = [] + for line in xml.splitlines(): + if 'Certificates' in line: + continue + new_xml_lines.append(line) + xml = '\n'.join(new_xml_lines) + return azure_helper.GoalState(xml, http_client) + + def test_incarnation_parsed_correctly(self): + incarnation = '123' + goal_state = self._get_goal_state(incarnation=incarnation) + self.assertEqual(incarnation, goal_state.incarnation) + + def test_container_id_parsed_correctly(self): + container_id = 'TestContainerId' + goal_state = self._get_goal_state(container_id=container_id) + self.assertEqual(container_id, goal_state.container_id) + + def test_instance_id_parsed_correctly(self): + instance_id = 'TestInstanceId' + goal_state = self._get_goal_state(instance_id=instance_id) + self.assertEqual(instance_id, goal_state.instance_id) + + def test_shared_config_xml_parsed_and_fetched_correctly(self): + http_client = mock.MagicMock() + shared_config_url = 'TestSharedConfigUrl' + goal_state = self._get_goal_state( + http_client=http_client, shared_config_url=shared_config_url) + shared_config_xml = goal_state.shared_config_xml + self.assertEqual(1, http_client.get.call_count) + self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) + self.assertEqual(http_client.get.return_value.contents, + shared_config_xml) + + def test_certificates_xml_parsed_and_fetched_correctly(self): + http_client = mock.MagicMock() + certificates_url = 'TestSharedConfigUrl' + goal_state = self._get_goal_state( + http_client=http_client, certificates_url=certificates_url) + certificates_xml = goal_state.certificates_xml + self.assertEqual(1, http_client.get.call_count) + self.assertEqual(certificates_url, http_client.get.call_args[0][0]) + self.assertTrue(http_client.get.call_args[1].get('secure', False)) + self.assertEqual(http_client.get.return_value.contents, + certificates_xml) + + def test_missing_certificates_skips_http_get(self): + http_client = mock.MagicMock() + goal_state = self._get_goal_state( + http_client=http_client, certificates_url=None) + certificates_xml = goal_state.certificates_xml + self.assertEqual(0, http_client.get.call_count) + self.assertIsNone(certificates_xml) + + +class TestAzureEndpointHttpClient(TestCase): + + regular_headers = { + 'x-ms-agent-name': 'WALinuxAgent', + 'x-ms-version': '2012-11-30', + } + + def setUp(self): + super(TestAzureEndpointHttpClient, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.read_file_or_url = patches.enter_context( + mock.patch.object(azure_helper.util, 'read_file_or_url')) + + def test_non_secure_get(self): + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + url = 'MyTestUrl' + response = client.get(url, secure=False) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(mock.call(url, headers=self.regular_headers), + self.read_file_or_url.call_args) + + def test_secure_get(self): + url = 'MyTestUrl' + certificate = mock.MagicMock() + expected_headers = self.regular_headers.copy() + expected_headers.update({ + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": certificate, + }) + client = azure_helper.AzureEndpointHttpClient(certificate) + response = client.get(url, secure=True) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(mock.call(url, headers=expected_headers), + self.read_file_or_url.call_args) + + def test_post(self): + data = mock.MagicMock() + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + response = client.post(url, data=data) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual( + mock.call(url, data=data, headers=self.regular_headers), + self.read_file_or_url.call_args) + + def test_post_with_extra_headers(self): + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + extra_headers = {'test': 'header'} + client.post(url, extra_headers=extra_headers) + self.assertEqual(1, self.read_file_or_url.call_count) + expected_headers = self.regular_headers.copy() + expected_headers.update(extra_headers) + self.assertEqual( + mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), + self.read_file_or_url.call_args) + + +class TestOpenSSLManager(TestCase): + + def setUp(self): + super(TestOpenSSLManager, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.subp = patches.enter_context( + mock.patch.object(azure_helper.util, 'subp')) + + @mock.patch.object(azure_helper, 'cd', mock.MagicMock()) + @mock.patch.object(azure_helper.tempfile, 'TemporaryDirectory') + def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): + manager = azure_helper.OpenSSLManager() + self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) + + @mock.patch('builtins.open') + def test_generate_certificate_uses_tmpdir(self, open): + subp_directory = {} + + def capture_directory(*args, **kwargs): + subp_directory['path'] = os.getcwd() + + self.subp.side_effect = capture_directory + manager = azure_helper.OpenSSLManager() + self.assertEqual(manager.tmpdir.name, subp_directory['path']) + + +class TestWALinuxAgentShim(TestCase): + + def setUp(self): + super(TestWALinuxAgentShim, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.AzureEndpointHttpClient = patches.enter_context( + mock.patch.object(azure_helper, 'AzureEndpointHttpClient')) + self.find_endpoint = patches.enter_context( + mock.patch.object( + azure_helper.WALinuxAgentShim, 'find_endpoint')) + self.GoalState = patches.enter_context( + mock.patch.object(azure_helper, 'GoalState')) + self.iid_from_shared_config_content = patches.enter_context( + mock.patch.object(azure_helper, 'iid_from_shared_config_content')) + self.OpenSSLManager = patches.enter_context( + mock.patch.object(azure_helper, 'OpenSSLManager')) + + def test_http_client_uses_certificate(self): + shim = azure_helper.WALinuxAgentShim() + self.assertEqual( + [mock.call(self.OpenSSLManager.return_value.certificate)], + self.AzureEndpointHttpClient.call_args_list) + self.assertEqual(self.AzureEndpointHttpClient.return_value, + shim.http_client) + + def test_correct_url_used_for_goalstate(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + get = self.AzureEndpointHttpClient.return_value.get + self.assertEqual( + [mock.call('http://test_endpoint/machine/?comp=goalstate')], + get.call_args_list) + self.assertEqual( + [mock.call(get.return_value.contents, shim.http_client)], + self.GoalState.call_args_list) + + def test_certificates_used_to_determine_public_keys(self): + shim = azure_helper.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.certificates_xml)], + self.OpenSSLManager.return_value.parse_certificates.call_args_list) + self.assertEqual( + self.OpenSSLManager.return_value.parse_certificates.return_value, + data['public-keys']) + + def test_absent_certificates_produces_empty_public_keys(self): + self.GoalState.return_value.certificates_xml = None + shim = azure_helper.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual([], data['public-keys']) + + def test_instance_id_returned_in_data(self): + shim = azure_helper.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.shared_config_xml)], + self.iid_from_shared_config_content.call_args_list) + self.assertEqual(self.iid_from_shared_config_content.return_value, + data['instance-id']) + + def test_correct_url_used_for_report_ready(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + expected_url = 'http://test_endpoint/machine?comp=health' + self.assertEqual( + [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], + shim.http_client.post.call_args_list) + + def test_goal_state_values_used_for_report_ready(self): + self.GoalState.return_value.incarnation = 'TestIncarnation' + self.GoalState.return_value.container_id = 'TestContainerId' + self.GoalState.return_value.instance_id = 'TestInstanceId' + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + posted_document = shim.http_client.post.call_args[1]['data'] + self.assertIn('TestIncarnation', posted_document) + self.assertIn('TestContainerId', posted_document) + self.assertIn('TestInstanceId', posted_document) -- cgit v1.2.3 From 9c7643c4a0dee7843963709c361b755baf843a4b Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 13:16:44 +0100 Subject: Stop using Python 3 only tempfile.TemporaryDirectory (but lose free cleanup). --- cloudinit/sources/helpers/azure.py | 8 ++++---- tests/unittests/test_datasource/test_azure_helper.py | 17 +++++++++++------ 2 files changed, 15 insertions(+), 10 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 60f116e0..cb13187f 100644 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -104,7 +104,7 @@ class OpenSSLManager(object): } def __init__(self): - self.tmpdir = tempfile.TemporaryDirectory() + self.tmpdir = tempfile.mkdtemp() self.certificate = None self.generate_certificate() @@ -113,7 +113,7 @@ class OpenSSLManager(object): if self.certificate is not None: LOG.debug('Certificate already generated.') return - with cd(self.tmpdir.name): + with cd(self.tmpdir): util.subp([ 'openssl', 'req', '-x509', '-nodes', '-subj', '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', @@ -139,7 +139,7 @@ class OpenSSLManager(object): b'', certificates_content.encode('utf-8'), ] - with cd(self.tmpdir.name): + with cd(self.tmpdir): with open('Certificates.p7m', 'wb') as f: f.write(b'\n'.join(lines)) out, _ = util.subp( @@ -159,7 +159,7 @@ class OpenSSLManager(object): current = [] keys = [] for certificate in certificates: - with cd(self.tmpdir.name): + with cd(self.tmpdir): public_key, _ = util.subp( 'openssl x509 -noout -pubkey |' 'ssh-keygen -i -m PKCS8 -f /dev/stdin', diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 47b77840..398a9007 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -273,15 +273,20 @@ class TestOpenSSLManager(TestCase): self.subp = patches.enter_context( mock.patch.object(azure_helper.util, 'subp')) + try: + self.open = patches.enter_context( + mock.patch('__builtin__.open')) + except ImportError: + self.open = patches.enter_context( + mock.patch('builtins.open')) @mock.patch.object(azure_helper, 'cd', mock.MagicMock()) - @mock.patch.object(azure_helper.tempfile, 'TemporaryDirectory') - def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): + @mock.patch.object(azure_helper.tempfile, 'mkdtemp') + def test_openssl_manager_creates_a_tmpdir(self, mkdtemp): manager = azure_helper.OpenSSLManager() - self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) + self.assertEqual(mkdtemp.return_value, manager.tmpdir) - @mock.patch('builtins.open') - def test_generate_certificate_uses_tmpdir(self, open): + def test_generate_certificate_uses_tmpdir(self): subp_directory = {} def capture_directory(*args, **kwargs): @@ -289,7 +294,7 @@ class TestOpenSSLManager(TestCase): self.subp.side_effect = capture_directory manager = azure_helper.OpenSSLManager() - self.assertEqual(manager.tmpdir.name, subp_directory['path']) + self.assertEqual(manager.tmpdir, subp_directory['path']) class TestWALinuxAgentShim(TestCase): -- cgit v1.2.3 From 84868622c404cda5efd2a753e2de30c1afca49a2 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 13:18:02 +0100 Subject: Move our walinuxagent implementation to a single function call. --- cloudinit/sources/DataSourceAzure.py | 8 ++-- cloudinit/sources/helpers/azure.py | 31 ++++++++---- tests/unittests/test_datasource/test_azure.py | 19 ++++++-- .../unittests/test_datasource/test_azure_helper.py | 56 ++++++++++++++++++++-- 4 files changed, 92 insertions(+), 22 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 5e147950..4053cfa6 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -29,7 +29,7 @@ from cloudinit.settings import PER_ALWAYS from cloudinit import sources from cloudinit import util from cloudinit.sources.helpers.azure import ( - iid_from_shared_config_content, WALinuxAgentShim) + get_metadata_from_fabric, iid_from_shared_config_content) LOG = logging.getLogger(__name__) @@ -185,15 +185,13 @@ class DataSourceAzureNet(sources.DataSource): write_files(ddir, files, dirmode=0o700) try: - shim = WALinuxAgentShim() - data = shim.register_with_azure_and_fetch_data() + fabric_data = get_metadata_from_fabric() except Exception as exc: LOG.info("Error communicating with Azure fabric; assume we aren't" " on Azure.", exc_info=True) return False - self.metadata['instance-id'] = data['instance-id'] - self.metadata['public-keys'] = data['public-keys'] + self.metadata.update(fabric_data) found_ephemeral = find_ephemeral_disk() if found_ephemeral: diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index cb13187f..dfdfa7c2 100644 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -108,6 +108,9 @@ class OpenSSLManager(object): self.certificate = None self.generate_certificate() + def clean_up(self): + util.del_dir(self.tmpdir) + def generate_certificate(self): LOG.debug('Generating certificate for communication with fabric...') if self.certificate is not None: @@ -205,11 +208,13 @@ class WALinuxAgentShim(object): def __init__(self): LOG.debug('WALinuxAgentShim instantiated...') self.endpoint = self.find_endpoint() - self.openssl_manager = OpenSSLManager() - self.http_client = AzureEndpointHttpClient( - self.openssl_manager.certificate) + self.openssl_manager = None self.values = {} + def clean_up(self): + if self.openssl_manager is not None: + self.openssl_manager.clean_up() + @staticmethod def find_endpoint(): LOG.debug('Finding Azure endpoint...') @@ -234,17 +239,19 @@ class WALinuxAgentShim(object): return endpoint_ip_address def register_with_azure_and_fetch_data(self): + self.openssl_manager = OpenSSLManager() + http_client = AzureEndpointHttpClient(self.openssl_manager.certificate) LOG.info('Registering with Azure...') for i in range(10): try: - response = self.http_client.get( + response = http_client.get( 'http://{}/machine/?comp=goalstate'.format(self.endpoint)) except Exception: time.sleep(i + 1) else: break LOG.debug('Successfully fetched GoalState XML.') - goal_state = GoalState(response.contents, self.http_client) + goal_state = GoalState(response.contents, http_client) public_keys = [] if goal_state.certificates_xml is not None: LOG.debug('Certificate XML found; parsing out public keys.') @@ -255,19 +262,27 @@ class WALinuxAgentShim(object): goal_state.shared_config_xml), 'public-keys': public_keys, } - self._report_ready(goal_state) + self._report_ready(goal_state, http_client) return data - def _report_ready(self, goal_state): + def _report_ready(self, goal_state, http_client): LOG.debug('Reporting ready to Azure fabric.') document = self.REPORT_READY_XML_TEMPLATE.format( incarnation=goal_state.incarnation, container_id=goal_state.container_id, instance_id=goal_state.instance_id, ) - self.http_client.post( + http_client.post( "http://{}/machine?comp=health".format(self.endpoint), data=document, extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, ) LOG.info('Reported ready to Azure fabric.') + + +def get_metadata_from_fabric(): + shim = WALinuxAgentShim() + try: + return shim.register_with_azure_and_fetch_data() + finally: + shim.clean_up() diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index ee7109e1..983be4cd 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -122,11 +122,10 @@ class TestAzureDataSource(TestCase): mod = DataSourceAzure mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d - fake_shim = mock.MagicMock() - fake_shim().register_with_azure_and_fetch_data.return_value = { + self.get_metadata_from_fabric = mock.MagicMock(return_value={ 'instance-id': 'i-my-azure-id', 'public-keys': [], - } + }) self.apply_patches([ (mod, 'list_possible_azure_ds_devs', dsdevs), @@ -137,7 +136,7 @@ class TestAzureDataSource(TestCase): (mod, 'perform_hostname_bounce', mock.MagicMock()), (mod, 'get_hostname', mock.MagicMock()), (mod, 'set_hostname', mock.MagicMock()), - (mod, 'WALinuxAgentShim', fake_shim), + (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric), ]) dsrc = mod.DataSourceAzureNet( @@ -388,6 +387,18 @@ class TestAzureDataSource(TestCase): self.assertEqual(new_ovfenv, load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) + def test_exception_fetching_fabric_data_doesnt_propagate(self): + ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + self.get_metadata_from_fabric.side_effect = Exception + self.assertFalse(ds.get_data()) + + def test_fabric_data_included_in_metadata(self): + ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + self.get_metadata_from_fabric.return_value = {'test': 'value'} + ret = ds.get_data() + self.assertTrue(ret) + self.assertEqual('value', ds.metadata['test']) + class TestAzureBounce(TestCase): diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 398a9007..5fac2ade 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -296,6 +296,14 @@ class TestOpenSSLManager(TestCase): manager = azure_helper.OpenSSLManager() self.assertEqual(manager.tmpdir, subp_directory['path']) + @mock.patch.object(azure_helper, 'cd', mock.MagicMock()) + @mock.patch.object(azure_helper.tempfile, 'mkdtemp', mock.MagicMock()) + @mock.patch.object(azure_helper.util, 'del_dir') + def test_clean_up(self, del_dir): + manager = azure_helper.OpenSSLManager() + manager.clean_up() + self.assertEqual([mock.call(manager.tmpdir)], del_dir.call_args_list) + class TestWALinuxAgentShim(TestCase): @@ -318,11 +326,10 @@ class TestWALinuxAgentShim(TestCase): def test_http_client_uses_certificate(self): shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(self.OpenSSLManager.return_value.certificate)], self.AzureEndpointHttpClient.call_args_list) - self.assertEqual(self.AzureEndpointHttpClient.return_value, - shim.http_client) def test_correct_url_used_for_goalstate(self): self.find_endpoint.return_value = 'test_endpoint' @@ -333,7 +340,8 @@ class TestWALinuxAgentShim(TestCase): [mock.call('http://test_endpoint/machine/?comp=goalstate')], get.call_args_list) self.assertEqual( - [mock.call(get.return_value.contents, shim.http_client)], + [mock.call(get.return_value.contents, + self.AzureEndpointHttpClient.return_value)], self.GoalState.call_args_list) def test_certificates_used_to_determine_public_keys(self): @@ -368,7 +376,7 @@ class TestWALinuxAgentShim(TestCase): expected_url = 'http://test_endpoint/machine?comp=health' self.assertEqual( [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], - shim.http_client.post.call_args_list) + self.AzureEndpointHttpClient.return_value.post.call_args_list) def test_goal_state_values_used_for_report_ready(self): self.GoalState.return_value.incarnation = 'TestIncarnation' @@ -376,7 +384,45 @@ class TestWALinuxAgentShim(TestCase): self.GoalState.return_value.instance_id = 'TestInstanceId' shim = azure_helper.WALinuxAgentShim() shim.register_with_azure_and_fetch_data() - posted_document = shim.http_client.post.call_args[1]['data'] + posted_document = ( + self.AzureEndpointHttpClient.return_value.post.call_args[1]['data'] + ) self.assertIn('TestIncarnation', posted_document) self.assertIn('TestContainerId', posted_document) self.assertIn('TestInstanceId', posted_document) + + def test_clean_up_can_be_called_at_any_time(self): + shim = azure_helper.WALinuxAgentShim() + shim.clean_up() + + def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self): + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + shim.clean_up() + self.assertEqual( + 1, self.OpenSSLManager.return_value.clean_up.call_count) + + +class TestGetMetadataFromFabric(TestCase): + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_data_from_shim_returned(self, shim): + ret = azure_helper.get_metadata_from_fabric() + self.assertEqual( + shim.return_value.register_with_azure_and_fetch_data.return_value, + ret) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_success_calls_clean_up(self, shim): + azure_helper.get_metadata_from_fabric() + self.assertEqual(1, shim.return_value.clean_up.call_count) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_failure_in_registration_calls_clean_up(self, shim): + class SentinelException(Exception): + pass + shim.return_value.register_with_azure_and_fetch_data.side_effect = ( + SentinelException) + self.assertRaises(SentinelException, + azure_helper.get_metadata_from_fabric) + self.assertEqual(1, shim.return_value.clean_up.call_count) -- cgit v1.2.3 From 1185aeae80fc8279946069bb8eec492b3cb81556 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 16:22:36 +0100 Subject: Reintroduce original code path. --- cloudinit/sources/DataSourceAzure.py | 74 +++++++++++++++++++++------ tests/unittests/test_datasource/test_azure.py | 5 ++ 2 files changed, 63 insertions(+), 16 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 4053cfa6..3c7820a6 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -22,6 +22,7 @@ import crypt import fnmatch import os import os.path +import time from xml.dom import minidom from cloudinit import log as logging @@ -35,11 +36,13 @@ LOG = logging.getLogger(__name__) DS_NAME = 'Azure' DEFAULT_METADATA = {"instance-id": "iid-AZURE-NODE"} +AGENT_START = ['service', 'walinuxagent', 'start'] BOUNCE_COMMAND = ['sh', '-xc', "i=$interface; x=0; ifdown $i || x=$?; ifup $i || x=$?; exit $x"] DATA_DIR_CLEAN_LIST = ['SharedConfig.xml'] BUILTIN_DS_CONFIG = { + 'agent_command': '__builtin__', 'data_dir': "/var/lib/waagent", 'set_hostname': True, 'hostname_bounce': { @@ -110,6 +113,56 @@ class DataSourceAzureNet(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [seed=%s]" % (root, self.seed) + def get_metadata_from_agent(self): + temp_hostname = self.metadata.get('local-hostname') + hostname_command = self.ds_cfg['hostname_bounce']['hostname_command'] + with temporary_hostname(temp_hostname, self.ds_cfg, + hostname_command=hostname_command) \ + as previous_hostname: + if (previous_hostname is not None + and util.is_true(self.ds_cfg.get('set_hostname'))): + cfg = self.ds_cfg['hostname_bounce'] + try: + perform_hostname_bounce(hostname=temp_hostname, + cfg=cfg, + prev_hostname=previous_hostname) + except Exception as e: + LOG.warn("Failed publishing hostname: %s", e) + util.logexc(LOG, "handling set_hostname failed") + + try: + invoke_agent(self.ds_cfg['agent_command']) + except util.ProcessExecutionError: + # claim the datasource even if the command failed + util.logexc(LOG, "agent command '%s' failed.", + self.ds_cfg['agent_command']) + + ddir = self.ds_cfg['data_dir'] + shcfgxml = os.path.join(ddir, "SharedConfig.xml") + wait_for = [shcfgxml] + + fp_files = [] + for pk in self.cfg.get('_pubkeys', []): + bname = str(pk['fingerprint'] + ".crt") + fp_files += [os.path.join(ddir, bname)] + + missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", + func=wait_for_files, + args=(wait_for + fp_files,)) + if len(missing): + LOG.warn("Did not find files, but going on: %s", missing) + + metadata = {} + if shcfgxml in missing: + LOG.warn("SharedConfig.xml missing, using static instance-id") + else: + try: + metadata['instance-id'] = iid_from_shared_config(shcfgxml) + except ValueError as e: + LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) + metadata['public-keys'] = pubkeys_from_crt_files(fp_files) + return metadata + def get_data(self): # azure removes/ejects the cdrom containing the ovf-env.xml # file on reboot. So, in order to successfully reboot we @@ -162,8 +215,6 @@ class DataSourceAzureNet(sources.DataSource): # now update ds_cfg to reflect contents pass in config user_ds_cfg = util.get_cfg_by_path(self.cfg, DS_CFG_PATH, {}) self.ds_cfg = util.mergemanydict([user_ds_cfg, self.ds_cfg]) - mycfg = self.ds_cfg - ddir = mycfg['data_dir'] if found != ddir: cached_ovfenv = util.load_file( @@ -184,8 +235,12 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) + if self.ds_cfg['agent_command'] == '__builtin__': + metadata_func = get_metadata_from_fabric + else: + metadata_func = self.get_metadata_from_agent try: - fabric_data = get_metadata_from_fabric() + fabric_data = metadata_func() except Exception as exc: LOG.info("Error communicating with Azure fabric; assume we aren't" " on Azure.", exc_info=True) @@ -567,19 +622,6 @@ def iid_from_shared_config(path): return iid_from_shared_config_content(content) -def iid_from_shared_config_content(content): - """ - find INSTANCE_ID in: - - - - - """ - dom = minidom.parseString(content) - depnode = single_node_at_path(dom, ["SharedConfig", "Deployment"]) - return depnode.attributes.get('name').value - - class BrokenAzureDataSource(Exception): pass diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 983be4cd..c72dc801 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -389,11 +389,13 @@ class TestAzureDataSource(TestCase): def test_exception_fetching_fabric_data_doesnt_propagate(self): ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + ds.ds_cfg['agent_command'] = '__builtin__' self.get_metadata_from_fabric.side_effect = Exception self.assertFalse(ds.get_data()) def test_fabric_data_included_in_metadata(self): ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + ds.ds_cfg['agent_command'] = '__builtin__' self.get_metadata_from_fabric.return_value = {'test': 'value'} ret = ds.get_data() self.assertTrue(ret) @@ -419,6 +421,9 @@ class TestAzureBounce(TestCase): self.patches.enter_context( mock.patch.object(DataSourceAzure, 'find_ephemeral_part', mock.MagicMock(return_value=None))) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'get_metadata_from_fabric', + mock.MagicMock(return_value={}))) def setUp(self): super(TestAzureBounce, self).setUp() -- cgit v1.2.3 From 512eb552e0ca740e1d285dc1b66a56579bcf68ec Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 8 May 2015 16:52:49 +0100 Subject: Fix retrying. --- cloudinit/sources/helpers/azure.py | 9 +++++++-- tests/unittests/test_datasource/test_azure_helper.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index dfdfa7c2..2ce728f5 100644 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -242,14 +242,19 @@ class WALinuxAgentShim(object): self.openssl_manager = OpenSSLManager() http_client = AzureEndpointHttpClient(self.openssl_manager.certificate) LOG.info('Registering with Azure...') - for i in range(10): + attempts = 0 + while True: try: response = http_client.get( 'http://{}/machine/?comp=goalstate'.format(self.endpoint)) except Exception: - time.sleep(i + 1) + if attempts < 10: + time.sleep(attempts + 1) + else: + raise else: break + attempts += 1 LOG.debug('Successfully fetched GoalState XML.') goal_state = GoalState(response.contents, http_client) public_keys = [] diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 5fac2ade..23bc997c 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -323,6 +323,8 @@ class TestWALinuxAgentShim(TestCase): mock.patch.object(azure_helper, 'iid_from_shared_config_content')) self.OpenSSLManager = patches.enter_context( mock.patch.object(azure_helper, 'OpenSSLManager')) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) def test_http_client_uses_certificate(self): shim = azure_helper.WALinuxAgentShim() @@ -402,6 +404,15 @@ class TestWALinuxAgentShim(TestCase): self.assertEqual( 1, self.OpenSSLManager.return_value.clean_up.call_count) + def test_failure_to_fetch_goalstate_bubbles_up(self): + class SentinelException(Exception): + pass + self.AzureEndpointHttpClient.return_value.get.side_effect = ( + SentinelException) + shim = azure_helper.WALinuxAgentShim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) + class TestGetMetadataFromFabric(TestCase): -- cgit v1.2.3 From 74023961b70a178039ecf10f68745f6927113978 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 14 May 2015 17:06:39 -0400 Subject: read_seeded: fix reed_seeded after regression read_seeded was assuming a Response object back from load_tfile_or_url but load_tfile_or_url was returning string. since the only other user of this was a test, move load_tfile_or_url to a test, and just do the right thing in read_seeded. LP: #1455233 --- cloudinit/util.py | 8 ++------ .../test_handler/test_handler_apt_configure.py | 15 +++++++++------ tests/unittests/test_util.py | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 12 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index cae57770..db4e02b8 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -766,10 +766,6 @@ def fetch_ssl_details(paths=None): return ssl_details -def load_tfile_or_url(*args, **kwargs): - return(decode_binary(read_file_or_url(*args, **kwargs).contents)) - - def read_file_or_url(url, timeout=5, retries=10, headers=None, data=None, sec_between=1, ssl_details=None, headers_cb=None, exception_cb=None): @@ -837,10 +833,10 @@ def read_seeded(base="", ext="", timeout=5, retries=10, file_retries=0): ud_url = "%s%s%s" % (base, "user-data", ext) md_url = "%s%s%s" % (base, "meta-data", ext) - md_resp = load_tfile_or_url(md_url, timeout, retries, file_retries) + md_resp = read_file_or_url(md_url, timeout, retries, file_retries) md = None if md_resp.ok(): - md = load_yaml(md_resp.contents, default={}) + md = load_yaml(decode_binary(md_resp.contents), default={}) ud_resp = read_file_or_url(ud_url, timeout, retries, file_retries) ud = None diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index 895728b3..4a74ea47 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -8,6 +8,9 @@ import re import shutil import tempfile +def load_tfile_or_url(*args, **kwargs): + return(util.decode_binary(util.read_file_or_url(*args, **kwargs).contents)) + class TestAptProxyConfig(TestCase): def setUp(self): @@ -29,7 +32,7 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.pfile)) self.assertFalse(os.path.isfile(self.cfile)) - contents = util.load_tfile_or_url(self.pfile) + contents = load_tfile_or_url(self.pfile) self.assertTrue(self._search_apt_config(contents, "http", "myproxy")) def test_apt_http_proxy_written(self): @@ -39,7 +42,7 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.pfile)) self.assertFalse(os.path.isfile(self.cfile)) - contents = util.load_tfile_or_url(self.pfile) + contents = load_tfile_or_url(self.pfile) self.assertTrue(self._search_apt_config(contents, "http", "myproxy")) def test_apt_all_proxy_written(self): @@ -57,7 +60,7 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.pfile)) self.assertFalse(os.path.isfile(self.cfile)) - contents = util.load_tfile_or_url(self.pfile) + contents = load_tfile_or_url(self.pfile) for ptype, pval in values.items(): self.assertTrue(self._search_apt_config(contents, ptype, pval)) @@ -73,7 +76,7 @@ class TestAptProxyConfig(TestCase): cc_apt_configure.apply_apt_config({'apt_proxy': "foo"}, self.pfile, self.cfile) self.assertTrue(os.path.isfile(self.pfile)) - contents = util.load_tfile_or_url(self.pfile) + contents = load_tfile_or_url(self.pfile) self.assertTrue(self._search_apt_config(contents, "http", "foo")) def test_config_written(self): @@ -85,14 +88,14 @@ class TestAptProxyConfig(TestCase): self.assertTrue(os.path.isfile(self.cfile)) self.assertFalse(os.path.isfile(self.pfile)) - self.assertEqual(util.load_tfile_or_url(self.cfile), payload) + self.assertEqual(load_tfile_or_url(self.cfile), payload) def test_config_replaced(self): util.write_file(self.pfile, "content doesnt matter") cc_apt_configure.apply_apt_config({'apt_config': "foo"}, self.pfile, self.cfile) self.assertTrue(os.path.isfile(self.cfile)) - self.assertEqual(util.load_tfile_or_url(self.cfile), "foo") + self.assertEqual(load_tfile_or_url(self.cfile), "foo") def test_config_deleted(self): # if no 'apt_config' is provided, delete any previously written file diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 1619b5d2..95990165 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -459,4 +459,21 @@ class TestMessageFromString(helpers.TestCase): roundtripped = util.message_from_string(u'\n').as_string() self.assertNotIn('\x00', roundtripped) + +class TestReadSeeded(helpers.TestCase): + def setUp(self): + super(TestReadSeeded, self).setUp() + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + + def test_unicode_not_messed_up(self): + ud = b"userdatablob" + helpers.populate_dir( + self.tmp, {'meta-data': "key1: val1", 'user-data': ud}) + sdir = self.tmp + os.path.sep + (found_md, found_ud) = util.read_seeded(sdir) + + self.assertEqual(found_md, {'key1': 'val1'}) + self.assertEqual(found_ud, ud) + # vi: ts=4 expandtab -- cgit v1.2.3 From 33db529855c0c0746091868c69f5d694bff7b9a8 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 15 May 2015 16:28:24 -0400 Subject: pep8 fixes --- cloudinit/distros/rhel.py | 4 ++-- tests/unittests/test_datasource/test_azure.py | 1 - tests/unittests/test_datasource/test_azure_helper.py | 15 ++++++++++----- .../unittests/test_handler/test_handler_apt_configure.py | 1 + 4 files changed, 13 insertions(+), 8 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/distros/rhel.py b/cloudinit/distros/rhel.py index eec17c61..30c805a6 100644 --- a/cloudinit/distros/rhel.py +++ b/cloudinit/distros/rhel.py @@ -133,7 +133,7 @@ class Distro(distros.Distro): rhel_util.update_sysconfig_file(out_fn, locale_cfg) def _write_hostname(self, hostname, out_fn): - # systemd will never update previous-hostname for us, so + # systemd will never update previous-hostname for us, so # we need to do it ourselves if self.uses_systemd() and out_fn.endswith('/previous-hostname'): util.write_file(out_fn, hostname) @@ -161,7 +161,7 @@ class Distro(distros.Distro): def _read_hostname(self, filename, default=None): if self.uses_systemd() and filename.endswith('/previous-hostname'): - return util.load_file(filename).strip() + return util.load_file(filename).strip() elif self.uses_systemd(): (out, _err) = util.subp(['hostname']) if len(out): diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index c72dc801..4c4b8eec 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -587,4 +587,3 @@ class TestReadAzureOvf(TestCase): (_md, _ud, cfg) = DataSourceAzure.read_azure_ovf(content) for mypk in mypklist: self.assertIn(mypk, cfg['_pubkeys']) - diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 23bc997c..a5228870 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -18,7 +18,8 @@ except ImportError: GOAL_STATE_TEMPLATE = """\ - + 2012-11-30 {incarnation} @@ -36,12 +37,16 @@ GOAL_STATE_TEMPLATE = """\ {instance_id} Started - http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1 + + http://100.86.192.70:80/...hostingEnvironmentConfig... + {shared_config_url} - http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1 - http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1 + + http://100.86.192.70:80/...extensionsConfig... + + http://100.86.192.70:80/...fullConfig... {certificates_url} - 68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml + 68ce47.0.68ce47.0.utl-trusty--292258.1.xml diff --git a/tests/unittests/test_handler/test_handler_apt_configure.py b/tests/unittests/test_handler/test_handler_apt_configure.py index 4a74ea47..1ed185ca 100644 --- a/tests/unittests/test_handler/test_handler_apt_configure.py +++ b/tests/unittests/test_handler/test_handler_apt_configure.py @@ -8,6 +8,7 @@ import re import shutil import tempfile + def load_tfile_or_url(*args, **kwargs): return(util.decode_binary(util.read_file_or_url(*args, **kwargs).contents)) -- cgit v1.2.3 From 2dc69f02e30cf6c89a9f0f036e797c78d8e00b84 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 15 May 2015 16:57:35 -0400 Subject: temporarily disasble test if no bin/cloud-init the tests of bin/cloud-init would fail in a package build environment. so, temporariliy skip them in that environment. --- tests/unittests/test_cli.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_cli.py b/tests/unittests/test_cli.py index ba447f87..ed863399 100644 --- a/tests/unittests/test_cli.py +++ b/tests/unittests/test_cli.py @@ -1,6 +1,6 @@ import imp +import os import sys - import six from . import helpers as test_helpers @@ -11,6 +11,9 @@ except ImportError: import mock +BIN_CLOUDINIT = "bin/cloud-init" + + class TestCLI(test_helpers.FilesystemMockingTestCase): def setUp(self): @@ -25,16 +28,18 @@ class TestCLI(test_helpers.FilesystemMockingTestCase): self.patched_funcs.enter_context( mock.patch.object(sys, 'argv', ['cloud-init'])) cli = imp.load_module( - 'cli', open('bin/cloud-init'), '', ('', 'r', imp.PY_SOURCE)) + 'cli', open(BIN_CLOUDINIT), '', ('', 'r', imp.PY_SOURCE)) try: return cli.main() except: pass + @test_helpers.skipIf(not os.path.isfile(BIN_CLOUDINIT), "no bin/cloudinit") def test_no_arguments_shows_usage(self): self._call_main() self.assertIn('usage: cloud-init', self.stderr.getvalue()) + @test_helpers.skipIf(not os.path.isfile(BIN_CLOUDINIT), "no bin/cloudinit") def test_no_arguments_exits_2(self): exit_code = self._call_main() if self.sys_exit.call_count: @@ -42,6 +47,7 @@ class TestCLI(test_helpers.FilesystemMockingTestCase): else: self.assertEqual(2, exit_code) + @test_helpers.skipIf(not os.path.isfile(BIN_CLOUDINIT), "no bin/cloudinit") def test_no_arguments_shows_error_message(self): self._call_main() self.assertIn('cloud-init: error: too few arguments', -- cgit v1.2.3 From cf2b017c8bf2adb02a2c7a9c9f03754402cb73c4 Mon Sep 17 00:00:00 2001 From: Brent Baude Date: Thu, 21 May 2015 13:32:30 -0500 Subject: This commit consists of three things based on feedback from smosher: cc_rh_subscription: Use of self.log.info limited, uses the util.subp for subprocesses, removed full path for subscription-manager cloud-config-rh_subscription.txt: A heavily commented example file on how to use rh_subscription and its main keys test_rh_subscription.py: a set of unittests for rh_subscription --- cloudinit/config/cc_rh_subscription.py | 181 +++++++++++++------------ tests/unittests/test_rh_subscription.py | 231 ++++++++++++++++++++++++++++++++ 2 files changed, 323 insertions(+), 89 deletions(-) create mode 100644 tests/unittests/test_rh_subscription.py (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rh_subscription.py b/cloudinit/config/cc_rh_subscription.py index b8056dbb..00a88456 100644 --- a/cloudinit/config/cc_rh_subscription.py +++ b/cloudinit/config/cc_rh_subscription.py @@ -16,15 +16,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import os -import subprocess import itertools +from cloudinit import util def handle(_name, cfg, _cloud, log, _args): sm = SubscriptionManager(cfg) sm.log = log - if not sm.is_registered: try: verify, verify_msg = sm._verify_keys() @@ -41,21 +39,22 @@ def handle(_name, cfg, _cloud, log, _args): # Attempt to change the service level if sm.auto_attach and sm.servicelevel is not None: - if not sm._set_service_level(): - raise SubscriptionError("Setting of service-level " - "failed") - else: - sm.log.info("Completed auto-attach with service level") + if not sm._set_service_level(): + raise SubscriptionError("Setting of service-level " + "failed") + else: + sm.log.debug("Completed auto-attach with service level") elif sm.auto_attach: if not sm._set_auto_attach(): raise SubscriptionError("Setting auto-attach failed") else: - sm.log.info("Completed auto-attach") + sm.log.debug("Completed auto-attach") if sm.pools is not None: if type(sm.pools) is not list: - raise SubscriptionError("Pools must in the format of a " - "list.") + pool_fail = "Pools must in the format of a list" + raise SubscriptionError(pool_fail) + return_stat = sm.addPool(sm.pools) if not return_stat: raise SubscriptionError("Unable to attach pools {0}" @@ -66,8 +65,8 @@ def handle(_name, cfg, _cloud, log, _args): raise SubscriptionError("Unable to add or remove repos") sm.log.info("rh_subscription plugin completed successfully") except SubscriptionError as e: - sm.log.warn(e) - sm.log.info("rh_subscription plugin did not complete successfully") + sm.log.warn(str(e)) + sm.log.warn("rh_subscription plugin did not complete successfully") else: sm.log.info("System is already registered") @@ -91,7 +90,7 @@ class SubscriptionManager(object): self.enable_repo = self.rhel_cfg.get('enable-repo') self.disable_repo = self.rhel_cfg.get('disable-repo') self.servicelevel = self.rhel_cfg.get('service-level') - self.subman = ['/bin/subscription-manager'] + self.subman = ['subscription-manager'] self.valid_rh_keys = ['org', 'activation-key', 'username', 'password', 'disable-repo', 'enable-repo', 'add-pool', 'rhsm-baseurl', 'server-hostname', @@ -135,11 +134,22 @@ class SubscriptionManager(object): ''' cmd = list(itertools.chain(self.subman, ['identity'])) - if subprocess.call(cmd, stdout=open(os.devnull, 'wb'), - stderr=open(os.devnull, 'wb')) == 1: + try: + self._sub_man_cli(cmd) + except util.ProcessExecutionError: return False - else: - return True + + return True + + def _sub_man_cli(self, cmd, logstring_val=False): + ''' + Uses the prefered cloud-init subprocess def of util.subp + and runs subscription-manager. Breaking this to a + separate function for later use in mocking and unittests + ''' + return_out, return_err = util.subp(cmd, logstring=logstring_val) + + return return_out, return_err def rhn_register(self): ''' @@ -163,11 +173,13 @@ class SubscriptionManager(object): if self.server_hostname is not None: cmd.append("--serverurl={0}".format(self.server_hostname)) - return_msg, return_code = self._captureRun(cmd) - - if return_code is not 0: - self.log.warn("Registration with {0} and {1} failed.".format( - self.activation_key, self.org)) + try: + return_out, return_err = self._sub_man_cli(cmd, + logstring_val=True) + except util.ProcessExecutionError as e: + if e.stdout == "": + self.log.warn("Registration failed due " + "to: {0}".format(e.stderr)) return False elif (self.userid is not None) and (self.password is not None): @@ -186,14 +198,13 @@ class SubscriptionManager(object): cmd.append("--serverurl={0}".format(self.server_hostname)) # Attempting to register the system only - return_msg, return_code = self._captureRun(cmd) - - if return_code is not 0: - # Return message is in a set - if return_msg[0] == "": - self.log.warn("Registration failed") - if return_msg[1] is not "": - self.log.warn(return_msg[1]) + try: + return_out, return_err = self._sub_man_cli(cmd, + logstring_val=True) + except util.ProcessExecutionError as e: + if e.stdout == "": + self.log.warn("Registration failed due " + "to: {0}".format(e.stderr)) return False else: @@ -203,8 +214,8 @@ class SubscriptionManager(object): "and password") return False - reg_id = return_msg[0].split("ID: ")[1].rstrip() - self.log.info("Registered successfully with ID {0}".format(reg_id)) + reg_id = return_out.split("ID: ")[1].rstrip() + self.log.debug("Registered successfully with ID {0}".format(reg_id)) return True def _set_service_level(self): @@ -212,41 +223,34 @@ class SubscriptionManager(object): ['attach', '--auto', '--servicelevel={0}' .format(self.servicelevel)])) - return_msg, return_code = self._captureRun(cmd) - - if return_code is not 0: - self.log.warn("Setting the service level failed with: " - "{0}".format(return_msg[1].strip())) + try: + return_out, return_err = self._sub_man_cli(cmd) + except util.ProcessExecutionError as e: + if e.stdout.rstrip() != '': + for line in e.stdout.split("\n"): + if line is not '': + self.log.warn(line) + else: + self.log.warn("Setting the service level failed with: " + "{0}".format(e.stderr.strip())) return False - else: - for line in return_msg[0].split("\n"): - if line is not "": - self.log.info(line) - return True + for line in return_out.split("\n"): + if line is not "": + self.log.debug(line) + return True def _set_auto_attach(self): cmd = list(itertools.chain(self.subman, ['attach', '--auto'])) - return_msg, return_code = self._captureRun(cmd) - - if return_code is not 0: + try: + return_out, return_err = self._sub_man_cli(cmd) + except util.ProcessExecutionError: self.log.warn("Auto-attach failed with: " - "{0}]".format(return_msg[1].strip())) + "{0}]".format(return_err.strip())) return False - else: - for line in return_msg[0].split("\n"): - if line is not "": - self.log.info(line) - return True - - def _captureRun(self, cmd): - ''' - Subprocess command that captures and returns the output and - return code. - ''' - - r = subprocess.Popen(cmd, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - return r.communicate(), r.returncode + for line in return_out.split("\n"): + if line is not "": + self.log.debug(line) + return True def _getPools(self): ''' @@ -259,14 +263,15 @@ class SubscriptionManager(object): # Get all available pools cmd = list(itertools.chain(self.subman, ['list', '--available', '--pool-only'])) - results = subprocess.check_output(cmd) + results, errors = self._sub_man_cli(cmd) available = (results.rstrip()).split("\n") - # Get all available pools + # Get all consumed pools cmd = list(itertools.chain(self.subman, ['list', '--consumed', '--pool-only'])) - results = subprocess.check_output(cmd) + results, errors = self._sub_man_cli(cmd) consumed = (results.rstrip()).split("\n") + return available, consumed def _getRepos(self): @@ -276,21 +281,19 @@ class SubscriptionManager(object): ''' cmd = list(itertools.chain(self.subman, ['repos', '--list-enabled'])) - result, return_code = self._captureRun(cmd) - + return_out, return_err = self._sub_man_cli(cmd) active_repos = [] - for repo in result[0].split("\n"): + for repo in return_out.split("\n"): if "Repo ID:" in repo: active_repos.append((repo.split(':')[1]).strip()) cmd = list(itertools.chain(self.subman, ['repos', '--list-disabled'])) - result, return_code = self._captureRun(cmd) + return_out, return_err = self._sub_man_cli(cmd) inactive_repos = [] - for repo in result[0].split("\n"): + for repo in return_out.split("\n"): if "Repo ID:" in repo: inactive_repos.append((repo.split(':')[1]).strip()) - return active_repos, inactive_repos def addPool(self, pools): @@ -301,7 +304,7 @@ class SubscriptionManager(object): # An empty list was passed if len(pools) == 0: - self.log.info("No pools to attach") + self.log.debug("No pools to attach") return True pool_available, pool_consumed = self._getPools() @@ -315,13 +318,14 @@ class SubscriptionManager(object): if len(pool_list) > 0: cmd.extend(pool_list) try: - self._captureRun(cmd) - self.log.info("Attached the following pools to your " - "system: %s" % (", ".join(pool_list)) - .replace('--pool=', '')) + self._sub_man_cli(cmd) + self.log.debug("Attached the following pools to your " + "system: %s" % (", ".join(pool_list)) + .replace('--pool=', '')) return True - except subprocess.CalledProcessError: - self.log.warn("Unable to attach pool {0}".format(pool)) + except util.ProcessExecutionError as e: + self.log.warn("Unable to attach pool {0} " + "due to {1}".format(pool, e)) return False def update_repos(self, erepos, drepos): @@ -341,7 +345,7 @@ class SubscriptionManager(object): # Bail if both lists are not populated if (len(erepos) == 0) and (len(drepos) == 0): - self.log.info("No repo IDs to enable or disable") + self.log.debug("No repo IDs to enable or disable") return True active_repos, inactive_repos = self._getRepos() @@ -368,14 +372,14 @@ class SubscriptionManager(object): for fail in enable_list_fail: # Check if the repo exists or not if fail in active_repos: - self.log.info("Repo {0} is already enabled".format(fail)) + self.log.debug("Repo {0} is already enabled".format(fail)) else: self.log.warn("Repo {0} does not appear to " "exist".format(fail)) if len(disable_list_fail) > 0: for fail in disable_list_fail: - self.log.info("Repo {0} not disabled " - "because it is not enabled".format(fail)) + self.log.debug("Repo {0} not disabled " + "because it is not enabled".format(fail)) cmd = list(itertools.chain(self.subman, ['repos'])) if enable_list > 0: @@ -384,16 +388,15 @@ class SubscriptionManager(object): cmd.extend(disable_list) try: - return_msg, return_code = self._captureRun(cmd) - - except subprocess.CalledProcessError as e: + self._sub_man_cli(cmd) + except util.ProcessExecutionError as e: self.log.warn("Unable to alter repos due to {0}".format(e)) return False if enable_list > 0: - self.log.info("Enabled the following repos: %s" % - (", ".join(enable_list)).replace('--enable=', '')) + self.log.debug("Enabled the following repos: %s" % + (", ".join(enable_list)).replace('--enable=', '')) if disable_list > 0: - self.log.info("Disabled the following repos: %s" % - (", ".join(disable_list)).replace('--disable=', '')) + self.log.debug("Disabled the following repos: %s" % + (", ".join(disable_list)).replace('--disable=', '')) return True diff --git a/tests/unittests/test_rh_subscription.py b/tests/unittests/test_rh_subscription.py new file mode 100644 index 00000000..ba9181ec --- /dev/null +++ b/tests/unittests/test_rh_subscription.py @@ -0,0 +1,231 @@ +from cloudinit import util +from cloudinit.config import cc_rh_subscription +import logging +import mock +import unittest + + +class GoodTests(unittest.TestCase): + def setUp(self): + super(GoodTests, self).setUp() + self.name = "cc_rh_subscription" + self.cloud_init = None + self.log = logging.getLogger("good_tests") + self.args = [] + self.handle = cc_rh_subscription.handle + self.SM = cc_rh_subscription.SubscriptionManager + + self.config = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks' + }} + self.config_full = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'auto-attach': True, + 'service-level': 'self-support', + 'add-pool': ['pool1', 'pool2', 'pool3'], + 'enable-repo': ['repo1', 'repo2', 'repo3'], + 'disable-repo': ['repo4', 'repo5'] + }} + + def test_already_registered(self): + ''' + Emulates a system that is already registered. Ensure it gets + a non-ProcessExecution error from is_registered() + ''' + good_message = 'System is already registered' + with mock.patch.object(cc_rh_subscription.SubscriptionManager, + '_sub_man_cli') as mockobj: + self.log.info = mock.MagicMock(wraps=self.log.info) + self.handle(self.name, self.config, self.cloud_init, + self.log, self.args) + self.assertEqual(mockobj.call_count, 1) + self.log.info.assert_called_with(good_message) + + def test_simple_registration(self): + ''' + Simple registration with username and password + ''' + good_message = 'rh_subscription plugin completed successfully' + self.log.info = mock.MagicMock(wraps=self.log.info) + reg = "The system has been registered with ID:" \ + " 12345678-abde-abcde-1234-1234567890abc" + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError, (reg, 'bar')]) + self.handle(self.name, self.config, self.cloud_init, + self.log, self.args) + self.SM._sub_man_cli.assert_called_with_once(['subscription-manager', + 'identity']) + self.SM._sub_man_cli.assert_called_with_once( + ['subscription-manager', 'register', '--username=scooby@do.com', + '--password=scooby-snacks'], logstring_val=True) + + self.log.info.assert_called_with(good_message) + self.assertEqual(self.SM._sub_man_cli.call_count, 2) + + def test_full_registration(self): + ''' + Registration with auto-attach, service-level, adding pools, + and enabling and disabling yum repos + ''' + pool_message = 'Pool pool2 is not available' + repo_message1 = 'Repo repo1 is already enabled' + repo_message2 = 'Enabled the following repos: repo2, repo3' + good_message = 'rh_subscription plugin completed successfully' + self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.log.debug = mock.MagicMock(wraps=self.log.debug) + reg = "The system has been registered with ID:" \ + " 12345678-abde-abcde-1234-1234567890abc" + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError, (reg, 'bar'), + ('Service level set to: self-support', ''), + ('pool1\npool3\n', ''), ('pool2\n', ''), ('', ''), + ('Repo ID: repo1\nRepo ID: repo5\n', ''), + ('Repo ID: repo2\nRepo ID: repo3\nRepo ID: ' + 'repo4', ''), + ('', '')]) + self.handle(self.name, self.config_full, self.cloud_init, + self.log, self.args) + self.log.warn.assert_any_call(pool_message) + self.log.debug.assert_any_call(repo_message1) + self.log.debug.assert_any_call(repo_message2) + self.log.info.assert_any_call(good_message) + self.SM._sub_man_cli.assert_called_with_once(['subscription-manager', + 'attach', '-pool=pool1', + '--pool=pool33']) + self.assertEqual(self.SM._sub_man_cli.call_count, 9) + + +class BadTests(unittest.TestCase): + def setUp(self): + super(BadTests, self).setUp() + self.name = "cc_rh_subscription" + self.cloud_init = None + self.log = logging.getLogger("bad_tests") + self.orig = self.log + self.args = [] + self.handle = cc_rh_subscription.handle + self.SM = cc_rh_subscription.SubscriptionManager + self.reg = "The system has been registered with ID:" \ + " 12345678-abde-abcde-1234-1234567890abc" + + self.config_no_password = {'rh_subscription': + {'username': 'scooby@do.com' + }} + + self.config_no_key = {'rh_subscription': + {'activation-key': '1234abcde', + }} + + self.config_service = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'service-level': 'self-support' + }} + + self.config_badpool = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'add-pool': 'not_a_list' + }} + self.config_badrepo = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'enable-repo': 'not_a_list' + }} + self.config_badkey = {'rh_subscription': + {'activation_key': 'abcdef1234', + 'org': '123', + }} + + def test_no_password(self): + ''' + Attempt to register without the password key/value + ''' + self.missing_info(self.config_no_password) + + def test_no_org(self): + ''' + Attempt to register without the org key/value + ''' + self.missing_info(self.config_no_key) + + def test_service_level_without_auto(self): + ''' + Attempt to register using service-level without the auto-attach key + ''' + good_message = 'The service-level key must be used in conjunction'\ + ' with the auto-attach key. Please re-run with '\ + 'auto-attach: True' + + self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) + self.handle(self.name, self.config_service, self.cloud_init, + self.log, self.args) + self.log.warn.assert_any_call(good_message) + self.assertRaises(cc_rh_subscription.SubscriptionError) + self.assertEqual(self.SM._sub_man_cli.call_count, 1) + + def test_pool_not_a_list(self): + ''' + Register with pools that are not in the format of a list + ''' + good_message = "Pools must in the format of a list" + self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) + self.handle(self.name, self.config_badpool, self.cloud_init, + self.log, self.args) + self.log.warn.assert_any_call(good_message) + self.assertRaises(cc_rh_subscription.SubscriptionError) + self.assertEqual(self.SM._sub_man_cli.call_count, 2) + + def test_repo_not_a_list(self): + ''' + Register with repos that are not in the format of a list + ''' + good_message = "Repo IDs must in the format of a list." + self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) + self.handle(self.name, self.config_badrepo, self.cloud_init, + self.log, self.args) + self.log.warn.assert_any_call(good_message) + self.assertRaises(cc_rh_subscription.SubscriptionError) + self.assertEqual(self.SM._sub_man_cli.call_count, 2) + + def test_bad_key_value(self): + ''' + Attempt to register with a key that we don't know + ''' + good_message = 'activation_key is not a valid key for rh_subscription.'\ + ' Valid keys are: org, activation-key, username, '\ + 'password, disable-repo, enable-repo, add-pool,'\ + ' rhsm-baseurl, server-hostname, auto-attach, '\ + 'service-level' + self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) + self.handle(self.name, self.config_badkey, self.cloud_init, + self.log, self.args) + self.assertRaises(cc_rh_subscription.SubscriptionError) + self.log.warn.assert_any_call(good_message) + self.assertEqual(self.SM._sub_man_cli.call_count, 1) + + def missing_info(self, config): + ''' + Helper def for tests that having missing information + ''' + good_message = "Unable to register system due to incomplete "\ + "information." + self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM._sub_man_cli = mock.MagicMock( + side_effect=[util.ProcessExecutionError]) + self.handle(self.name, config, self.cloud_init, + self.log, self.args) + self.SM._sub_man_cli.assert_called_with(['subscription-manager', + 'identity']) + self.log.warn.assert_any_call(good_message) + self.assertEqual(self.SM._sub_man_cli.call_count, 1) -- cgit v1.2.3 From 8af1802c9971ec1f2ebac23e9b42d5b42f43afae Mon Sep 17 00:00:00 2001 From: Ben Howard Date: Fri, 22 May 2015 10:28:17 -0600 Subject: AZURE: Redact on-disk user password in /var/lib/ovf-env.xml The fabric provides the user password in plain text via the CDROM, and cloud-init has previously wrote the ovf-env.xml in /var/lib/waagent with the password in plain text. This change redacts the password. --- cloudinit/sources/DataSourceAzure.py | 28 ++++++++-- tests/unittests/test_datasource/test_azure.py | 73 ++++++++++++++++++++++++--- 2 files changed, 91 insertions(+), 10 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index f2388c63..d0a882ca 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -23,6 +23,8 @@ import fnmatch import os import os.path import time +import xml.etree.ElementTree as ET + from xml.dom import minidom from cloudinit import log as logging @@ -68,6 +70,10 @@ BUILTIN_CLOUD_CONFIG = { DS_CFG_PATH = ['datasource', DS_NAME] DEF_EPHEMERAL_LABEL = 'Temporary Storage' +# The redacted password fails to meet password complexity requirements +# so we can safely use this to mask/redact the password in the ovf-env.xml +DEF_PASSWD_REDACTION = 'REDACTED' + def get_hostname(hostname_command='hostname'): return util.subp(hostname_command, capture=True)[0].strip() @@ -414,14 +420,30 @@ def wait_for_files(flist, maxwait=60, naplen=.5): def write_files(datadir, files, dirmode=None): + + def _redact_password(cnt, fname): + """Azure provides the UserPassword in plain text. So we redact it""" + try: + root = ET.fromstring(cnt) + for elem in root.iter(): + if ('UserPassword' in elem.tag and + elem.text != DEF_PASSWD_REDACTION): + elem.text = DEF_PASSWD_REDACTION + return ET.tostring(root) + except Exception as e: + LOG.critical("failed to redact userpassword in {}".format(fname)) + return cnt + if not datadir: return if not files: files = {} util.ensure_dir(datadir, dirmode) for (name, content) in files.items(): - util.write_file(filename=os.path.join(datadir, name), - content=content, mode=0o600) + fname = os.path.join(datadir, name) + if 'ovf-env.xml' in name: + content = _redact_password(content, fname) + util.write_file(filename=fname, content=content, mode=0o600) def invoke_agent(cmd): @@ -576,7 +598,7 @@ def read_azure_ovf(contents): defuser = {} if username: defuser['name'] = username - if password: + if password and DEF_PASSWD_REDACTION != password: defuser['passwd'] = encrypt_pass(password) defuser['lock_passwd'] = False diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 4c4b8eec..33b971f6 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -18,6 +18,7 @@ import stat import yaml import shutil import tempfile +import xml.etree.ElementTree as ET def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): @@ -144,6 +145,39 @@ class TestAzureDataSource(TestCase): return dsrc + def xml_equals(self, oxml, nxml): + """Compare two sets of XML to make sure they are equal""" + + def create_tag_index(xml): + et = ET.fromstring(xml) + ret = {} + for x in et.iter(): + ret[x.tag] = x + return ret + + def tags_exists(x, y): + for tag in x.keys(): + self.assertIn(tag, y) + for tag in y.keys(): + self.assertIn(tag, x) + + def tags_equal(x, y): + for x_tag, x_val in x.items(): + y_val = y.get(x_val.tag) + self.assertEquals(x_val.text, y_val.text) + + old_cnt = create_tag_index(oxml) + new_cnt = create_tag_index(nxml) + tags_exists(old_cnt, new_cnt) + tags_equal(old_cnt, new_cnt) + + def xml_notequals(self, oxml, nxml): + try: + self.xml_equals(oxml, nxml) + except AssertionError as e: + return + raise AssertionError("XML is the same") + def test_basic_seed_dir(self): odata = {'HostName': "myhost", 'UserName': "myuser"} data = {'ovfcontent': construct_valid_ovf_env(data=odata), @@ -322,6 +356,31 @@ class TestAzureDataSource(TestCase): self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw) + def test_password_redacted_in_ovf(self): + odata = {'HostName': "myhost", 'UserName': "myuser", + 'UserPassword': "mypass"} + data = {'ovfcontent': construct_valid_ovf_env(data=odata)} + dsrc = self._get_ds(data) + ret = dsrc.get_data() + + self.assertTrue(ret) + ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml') + + # The XML should not be same since the user password is redacted + on_disk_ovf = load_file(ovf_env_path) + self.xml_notequals(data['ovfcontent'], on_disk_ovf) + + # Make sure that the redacted password on disk is not used by CI + self.assertNotEquals(dsrc.cfg.get('password'), + DataSourceAzure.DEF_PASSWD_REDACTION) + + # Make sure that the password was really encrypted + et = ET.fromstring(on_disk_ovf) + for elem in et.iter(): + if 'UserPassword' in elem.tag: + self.assertEquals(DataSourceAzure.DEF_PASSWD_REDACTION, + elem.text) + def test_ovf_env_arrives_in_waagent_dir(self): xml = construct_valid_ovf_env(data={}, userdata="FOODATA") dsrc = self._get_ds({'ovfcontent': xml}) @@ -331,7 +390,7 @@ class TestAzureDataSource(TestCase): # we expect that the ovf-env.xml file is copied there. ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml') self.assertTrue(os.path.exists(ovf_env_path)) - self.assertEqual(xml, load_file(ovf_env_path)) + self.xml_equals(xml, load_file(ovf_env_path)) def test_ovf_can_include_unicode(self): xml = construct_valid_ovf_env(data={}) @@ -380,12 +439,12 @@ class TestAzureDataSource(TestCase): self.assertEqual(dsrc.userdata_raw, b"NEW_USERDATA") self.assertTrue(os.path.exists( os.path.join(self.waagent_d, 'otherfile'))) - self.assertFalse( - os.path.exists(os.path.join(self.waagent_d, 'SharedConfig.xml'))) - self.assertTrue( - os.path.exists(os.path.join(self.waagent_d, 'ovf-env.xml'))) - self.assertEqual(new_ovfenv, - load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) + self.assertFalse(os.path.exists( + os.path.join(self.waagent_d, 'SharedConfig.xml'))) + self.assertTrue(os.path.exists( + os.path.join(self.waagent_d, 'ovf-env.xml'))) + new_xml = load_file(os.path.join(self.waagent_d, 'ovf-env.xml')) + self.xml_equals(new_ovfenv, new_xml) def test_exception_fetching_fabric_data_doesnt_propagate(self): ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) -- cgit v1.2.3 From d9470a429935d4a5e12a5a3d1f57867362f92c57 Mon Sep 17 00:00:00 2001 From: Brent Baude Date: Wed, 27 May 2015 13:01:35 -0500 Subject: Updated files with upstream review comments thanks to Dan and Scott --- cloudinit/config/cc_rh_subscription.py | 108 ++++++++++---------- tests/unittests/test_rh_subscription.py | 169 ++++++++++++++------------------ 2 files changed, 128 insertions(+), 149 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rh_subscription.py b/cloudinit/config/cc_rh_subscription.py index 00a88456..db3d5525 100644 --- a/cloudinit/config/cc_rh_subscription.py +++ b/cloudinit/config/cc_rh_subscription.py @@ -1,6 +1,6 @@ # vi: ts=4 expandtab # -# Copyright (C) Red Hat, Inc. +# Copyright (C) 2015 Red Hat, Inc. # # Author: Brent Baude # @@ -16,7 +16,6 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import itertools from cloudinit import util @@ -51,10 +50,10 @@ def handle(_name, cfg, _cloud, log, _args): sm.log.debug("Completed auto-attach") if sm.pools is not None: - if type(sm.pools) is not list: + if not isinstance(sm.pools, (list)): pool_fail = "Pools must in the format of a list" raise SubscriptionError(pool_fail) - + return_stat = sm.addPool(sm.pools) if not return_stat: raise SubscriptionError("Unable to attach pools {0}" @@ -63,12 +62,12 @@ def handle(_name, cfg, _cloud, log, _args): return_stat = sm.update_repos(sm.enable_repo, sm.disable_repo) if not return_stat: raise SubscriptionError("Unable to add or remove repos") - sm.log.info("rh_subscription plugin completed successfully") + sm.log_sucess("rh_subscription plugin completed successfully") except SubscriptionError as e: - sm.log.warn(str(e)) - sm.log.warn("rh_subscription plugin did not complete successfully") + sm.log_warn(str(e)) + sm.log_warn("rh_subscription plugin did not complete successfully") else: - sm.log.info("System is already registered") + sm.log_sucess("System is already registered") class SubscriptionError(Exception): @@ -76,6 +75,11 @@ class SubscriptionError(Exception): class SubscriptionManager(object): + valid_rh_keys = ['org', 'activation-key', 'username', 'password', + 'disable-repo', 'enable-repo', 'add-pool', + 'rhsm-baseurl', 'server-hostname', + 'auto-attach', 'service-level'] + def __init__(self, cfg): self.cfg = cfg self.rhel_cfg = self.cfg.get('rh_subscription', {}) @@ -91,12 +95,16 @@ class SubscriptionManager(object): self.disable_repo = self.rhel_cfg.get('disable-repo') self.servicelevel = self.rhel_cfg.get('service-level') self.subman = ['subscription-manager'] - self.valid_rh_keys = ['org', 'activation-key', 'username', 'password', - 'disable-repo', 'enable-repo', 'add-pool', - 'rhsm-baseurl', 'server-hostname', - 'auto-attach', 'service-level'] self.is_registered = self._is_registered() + def log_sucess(self, msg): + '''Simple wrapper for logging info messages. Useful for unittests''' + self.log.info(msg) + + def log_warn(self, msg): + '''Simple wrapper for logging warning messages. Useful for unittests''' + self.log.warn(msg) + def _verify_keys(self): ''' Checks that the keys in the rh_subscription dict from the user-data @@ -112,14 +120,15 @@ class SubscriptionManager(object): # Check for bad auto-attach value if (self.auto_attach is not None) and \ - (str(self.auto_attach).upper() not in ['TRUE', 'FALSE']): + not (util.is_true(self.auto_attach) or + util.is_false(self.auto_attach)): not_bool = "The key auto-attach must be a value of "\ "either True or False" return False, not_bool if (self.servicelevel is not None) and \ - ((not self.auto_attach) or - (str(self.auto_attach).upper() == "FALSE")): + ((not self.auto_attach) + or (util.is_false(str(self.auto_attach)))): no_auto = "The service-level key must be used in conjunction with "\ "the auto-attach key. Please re-run with auto-attach: "\ @@ -132,7 +141,7 @@ class SubscriptionManager(object): Checks if the system is already registered and returns True if so, else False ''' - cmd = list(itertools.chain(self.subman, ['identity'])) + cmd = ['identity'] try: self._sub_man_cli(cmd) @@ -147,9 +156,8 @@ class SubscriptionManager(object): and runs subscription-manager. Breaking this to a separate function for later use in mocking and unittests ''' - return_out, return_err = util.subp(cmd, logstring=logstring_val) - - return return_out, return_err + cmd = self.subman + cmd + return util.subp(cmd, logstring=logstring_val) def rhn_register(self): ''' @@ -159,10 +167,8 @@ class SubscriptionManager(object): if (self.activation_key is not None) and (self.org is not None): # register by activation key - cmd = list(itertools.chain(self.subman, ['register', - '--activationkey={0}'. - format(self.activation_key), - '--org={0}'.format(self.org)])) + cmd = ['register', '--activationkey={0}'. + format(self.activation_key), '--org={0}'.format(self.org)] # If the baseurl and/or server url are passed in, we register # with them. @@ -178,15 +184,14 @@ class SubscriptionManager(object): logstring_val=True) except util.ProcessExecutionError as e: if e.stdout == "": - self.log.warn("Registration failed due " + self.log_warn("Registration failed due " "to: {0}".format(e.stderr)) return False elif (self.userid is not None) and (self.password is not None): # register by username and password - cmd = list(itertools.chain(self.subman, ['register', - '--username={0}'.format(self.userid), - '--password={0}'.format(self.password)])) + cmd = ['register', '--username={0}'.format(self.userid), + '--password={0}'.format(self.password)] # If the baseurl and/or server url are passed in, we register # with them. @@ -203,14 +208,14 @@ class SubscriptionManager(object): logstring_val=True) except util.ProcessExecutionError as e: if e.stdout == "": - self.log.warn("Registration failed due " + self.log_warn("Registration failed due " "to: {0}".format(e.stderr)) return False else: - self.log.warn("Unable to register system due to incomplete " + self.log_warn("Unable to register system due to incomplete " "information.") - self.log.warn("Use either activationkey and org *or* userid " + self.log_warn("Use either activationkey and org *or* userid " "and password") return False @@ -219,9 +224,8 @@ class SubscriptionManager(object): return True def _set_service_level(self): - cmd = list(itertools.chain(self.subman, - ['attach', '--auto', '--servicelevel={0}' - .format(self.servicelevel)])) + cmd = ['attach', '--auto', '--servicelevel={0}' + .format(self.servicelevel)] try: return_out, return_err = self._sub_man_cli(cmd) @@ -229,9 +233,9 @@ class SubscriptionManager(object): if e.stdout.rstrip() != '': for line in e.stdout.split("\n"): if line is not '': - self.log.warn(line) + self.log_warn(line) else: - self.log.warn("Setting the service level failed with: " + self.log_warn("Setting the service level failed with: " "{0}".format(e.stderr.strip())) return False for line in return_out.split("\n"): @@ -240,11 +244,11 @@ class SubscriptionManager(object): return True def _set_auto_attach(self): - cmd = list(itertools.chain(self.subman, ['attach', '--auto'])) + cmd = ['attach', '--auto'] try: return_out, return_err = self._sub_man_cli(cmd) except util.ProcessExecutionError: - self.log.warn("Auto-attach failed with: " + self.log_warn("Auto-attach failed with: " "{0}]".format(return_err.strip())) return False for line in return_out.split("\n"): @@ -261,14 +265,12 @@ class SubscriptionManager(object): consumed = [] # Get all available pools - cmd = list(itertools.chain(self.subman, ['list', '--available', - '--pool-only'])) + cmd = ['list', '--available', '--pool-only'] results, errors = self._sub_man_cli(cmd) available = (results.rstrip()).split("\n") # Get all consumed pools - cmd = list(itertools.chain(self.subman, ['list', '--consumed', - '--pool-only'])) + cmd = ['list', '--consumed', '--pool-only'] results, errors = self._sub_man_cli(cmd) consumed = (results.rstrip()).split("\n") @@ -280,14 +282,14 @@ class SubscriptionManager(object): them in list form. ''' - cmd = list(itertools.chain(self.subman, ['repos', '--list-enabled'])) + cmd = ['repos', '--list-enabled'] return_out, return_err = self._sub_man_cli(cmd) active_repos = [] for repo in return_out.split("\n"): if "Repo ID:" in repo: active_repos.append((repo.split(':')[1]).strip()) - cmd = list(itertools.chain(self.subman, ['repos', '--list-disabled'])) + cmd = ['repos', '--list-disabled'] return_out, return_err = self._sub_man_cli(cmd) inactive_repos = [] @@ -309,12 +311,12 @@ class SubscriptionManager(object): pool_available, pool_consumed = self._getPools() pool_list = [] - cmd = list(itertools.chain(self.subman, ['attach'])) + cmd = ['attach'] for pool in pools: if (pool not in pool_consumed) and (pool in pool_available): pool_list.append('--pool={0}'.format(pool)) else: - self.log.warn("Pool {0} is not available".format(pool)) + self.log_warn("Pool {0} is not available".format(pool)) if len(pool_list) > 0: cmd.extend(pool_list) try: @@ -324,7 +326,7 @@ class SubscriptionManager(object): .replace('--pool=', '')) return True except util.ProcessExecutionError as e: - self.log.warn("Unable to attach pool {0} " + self.log_warn("Unable to attach pool {0} " "due to {1}".format(pool, e)) return False @@ -335,12 +337,12 @@ class SubscriptionManager(object): executes the action to disable or enable ''' - if (erepos is not None) and (type(erepos) is not list): - self.log.warn("Repo IDs must in the format of a list.") + if (erepos is not None) and (not isinstance(erepos, (list))): + self.log_warn("Repo IDs must in the format of a list.") return False - if (drepos is not None) and (type(drepos) is not list): - self.log.warn("Repo IDs must in the format of a list.") + if (drepos is not None) and (not isinstance(drepos, (list))): + self.log_warn("Repo IDs must in the format of a list.") return False # Bail if both lists are not populated @@ -374,14 +376,14 @@ class SubscriptionManager(object): if fail in active_repos: self.log.debug("Repo {0} is already enabled".format(fail)) else: - self.log.warn("Repo {0} does not appear to " + self.log_warn("Repo {0} does not appear to " "exist".format(fail)) if len(disable_list_fail) > 0: for fail in disable_list_fail: self.log.debug("Repo {0} not disabled " "because it is not enabled".format(fail)) - cmd = list(itertools.chain(self.subman, ['repos'])) + cmd = ['repos'] if enable_list > 0: cmd.extend(enable_list) if disable_list > 0: @@ -390,7 +392,7 @@ class SubscriptionManager(object): try: self._sub_man_cli(cmd) except util.ProcessExecutionError as e: - self.log.warn("Unable to alter repos due to {0}".format(e)) + self.log_warn("Unable to alter repos due to {0}".format(e)) return False if enable_list > 0: diff --git a/tests/unittests/test_rh_subscription.py b/tests/unittests/test_rh_subscription.py index ba9181ec..2f813f41 100644 --- a/tests/unittests/test_rh_subscription.py +++ b/tests/unittests/test_rh_subscription.py @@ -34,34 +34,33 @@ class GoodTests(unittest.TestCase): Emulates a system that is already registered. Ensure it gets a non-ProcessExecution error from is_registered() ''' - good_message = 'System is already registered' with mock.patch.object(cc_rh_subscription.SubscriptionManager, '_sub_man_cli') as mockobj: - self.log.info = mock.MagicMock(wraps=self.log.info) + self.SM.log_sucess = mock.MagicMock() self.handle(self.name, self.config, self.cloud_init, self.log, self.args) + self.assertEqual(self.SM.log_sucess.call_count, 1) self.assertEqual(mockobj.call_count, 1) - self.log.info.assert_called_with(good_message) def test_simple_registration(self): ''' Simple registration with username and password ''' - good_message = 'rh_subscription plugin completed successfully' - self.log.info = mock.MagicMock(wraps=self.log.info) + self.SM.log_sucess = mock.MagicMock() reg = "The system has been registered with ID:" \ " 12345678-abde-abcde-1234-1234567890abc" self.SM._sub_man_cli = mock.MagicMock( side_effect=[util.ProcessExecutionError, (reg, 'bar')]) self.handle(self.name, self.config, self.cloud_init, self.log, self.args) - self.SM._sub_man_cli.assert_called_with_once(['subscription-manager', - 'identity']) - self.SM._sub_man_cli.assert_called_with_once( - ['subscription-manager', 'register', '--username=scooby@do.com', - '--password=scooby-snacks'], logstring_val=True) - - self.log.info.assert_called_with(good_message) + self.assertIn(mock.call(['identity']), + self.SM._sub_man_cli.call_args_list) + self.assertIn(mock.call(['register', '--username=scooby@do.com', + '--password=scooby-snacks'], + logstring_val=True), + self.SM._sub_man_cli.call_args_list) + + self.assertEqual(self.SM.log_sucess.call_count, 1) self.assertEqual(self.SM._sub_man_cli.call_count, 2) def test_full_registration(self): @@ -69,12 +68,12 @@ class GoodTests(unittest.TestCase): Registration with auto-attach, service-level, adding pools, and enabling and disabling yum repos ''' - pool_message = 'Pool pool2 is not available' - repo_message1 = 'Repo repo1 is already enabled' - repo_message2 = 'Enabled the following repos: repo2, repo3' - good_message = 'rh_subscription plugin completed successfully' - self.log.warn = mock.MagicMock(wraps=self.log.warn) - self.log.debug = mock.MagicMock(wraps=self.log.debug) + call_lists = [] + call_lists.append(['attach', '--pool=pool1', '--pool=pool3']) + call_lists.append(['repos', '--enable=repo2', '--enable=repo3', + '--disable=repo5']) + call_lists.append(['attach', '--auto', '--servicelevel=self-support']) + self.SM.log_sucess = mock.MagicMock() reg = "The system has been registered with ID:" \ " 12345678-abde-abcde-1234-1234567890abc" self.SM._sub_man_cli = mock.MagicMock( @@ -87,145 +86,123 @@ class GoodTests(unittest.TestCase): ('', '')]) self.handle(self.name, self.config_full, self.cloud_init, self.log, self.args) - self.log.warn.assert_any_call(pool_message) - self.log.debug.assert_any_call(repo_message1) - self.log.debug.assert_any_call(repo_message2) - self.log.info.assert_any_call(good_message) - self.SM._sub_man_cli.assert_called_with_once(['subscription-manager', - 'attach', '-pool=pool1', - '--pool=pool33']) + for call in call_lists: + self.assertIn(mock.call(call), self.SM._sub_man_cli.call_args_list) + self.assertEqual(self.SM.log_sucess.call_count, 1) self.assertEqual(self.SM._sub_man_cli.call_count, 9) -class BadTests(unittest.TestCase): +class TestBadInput(unittest.TestCase): + name = "cc_rh_subscription" + cloud_init = None + log = logging.getLogger("bad_tests") + args = [] + SM = cc_rh_subscription.SubscriptionManager + reg = "The system has been registered with ID:" \ + " 12345678-abde-abcde-1234-1234567890abc" + + config_no_password = {'rh_subscription': + {'username': 'scooby@do.com' + }} + + config_no_key = {'rh_subscription': + {'activation-key': '1234abcde', + }} + + config_service = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'service-level': 'self-support' + }} + + config_badpool = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'add-pool': 'not_a_list' + }} + config_badrepo = {'rh_subscription': + {'username': 'scooby@do.com', + 'password': 'scooby-snacks', + 'enable-repo': 'not_a_list' + }} + config_badkey = {'rh_subscription': + {'activation_key': 'abcdef1234', + 'org': '123', + }} + def setUp(self): - super(BadTests, self).setUp() - self.name = "cc_rh_subscription" - self.cloud_init = None - self.log = logging.getLogger("bad_tests") - self.orig = self.log - self.args = [] + super(TestBadInput, self).setUp() self.handle = cc_rh_subscription.handle - self.SM = cc_rh_subscription.SubscriptionManager - self.reg = "The system has been registered with ID:" \ - " 12345678-abde-abcde-1234-1234567890abc" - - self.config_no_password = {'rh_subscription': - {'username': 'scooby@do.com' - }} - - self.config_no_key = {'rh_subscription': - {'activation-key': '1234abcde', - }} - - self.config_service = {'rh_subscription': - {'username': 'scooby@do.com', - 'password': 'scooby-snacks', - 'service-level': 'self-support' - }} - - self.config_badpool = {'rh_subscription': - {'username': 'scooby@do.com', - 'password': 'scooby-snacks', - 'add-pool': 'not_a_list' - }} - self.config_badrepo = {'rh_subscription': - {'username': 'scooby@do.com', - 'password': 'scooby-snacks', - 'enable-repo': 'not_a_list' - }} - self.config_badkey = {'rh_subscription': - {'activation_key': 'abcdef1234', - 'org': '123', - }} def test_no_password(self): ''' Attempt to register without the password key/value ''' - self.missing_info(self.config_no_password) + self.input_is_missing_data(self.config_no_password) def test_no_org(self): ''' Attempt to register without the org key/value ''' - self.missing_info(self.config_no_key) + self.input_is_missing_data(self.config_no_key) def test_service_level_without_auto(self): ''' Attempt to register using service-level without the auto-attach key ''' - good_message = 'The service-level key must be used in conjunction'\ - ' with the auto-attach key. Please re-run with '\ - 'auto-attach: True' - - self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM.log_warn = mock.MagicMock() self.SM._sub_man_cli = mock.MagicMock( side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) self.handle(self.name, self.config_service, self.cloud_init, self.log, self.args) - self.log.warn.assert_any_call(good_message) - self.assertRaises(cc_rh_subscription.SubscriptionError) self.assertEqual(self.SM._sub_man_cli.call_count, 1) + self.assertEqual(self.SM.log_warn.call_count, 2) def test_pool_not_a_list(self): ''' Register with pools that are not in the format of a list ''' - good_message = "Pools must in the format of a list" - self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM.log_warn = mock.MagicMock() self.SM._sub_man_cli = mock.MagicMock( side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) self.handle(self.name, self.config_badpool, self.cloud_init, self.log, self.args) - self.log.warn.assert_any_call(good_message) - self.assertRaises(cc_rh_subscription.SubscriptionError) self.assertEqual(self.SM._sub_man_cli.call_count, 2) + self.assertEqual(self.SM.log_warn.call_count, 2) def test_repo_not_a_list(self): ''' Register with repos that are not in the format of a list ''' - good_message = "Repo IDs must in the format of a list." - self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM.log_warn = mock.MagicMock() self.SM._sub_man_cli = mock.MagicMock( side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) self.handle(self.name, self.config_badrepo, self.cloud_init, self.log, self.args) - self.log.warn.assert_any_call(good_message) - self.assertRaises(cc_rh_subscription.SubscriptionError) + self.assertEqual(self.SM.log_warn.call_count, 3) self.assertEqual(self.SM._sub_man_cli.call_count, 2) def test_bad_key_value(self): ''' Attempt to register with a key that we don't know ''' - good_message = 'activation_key is not a valid key for rh_subscription.'\ - ' Valid keys are: org, activation-key, username, '\ - 'password, disable-repo, enable-repo, add-pool,'\ - ' rhsm-baseurl, server-hostname, auto-attach, '\ - 'service-level' - self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM.log_warn = mock.MagicMock() self.SM._sub_man_cli = mock.MagicMock( side_effect=[util.ProcessExecutionError, (self.reg, 'bar')]) self.handle(self.name, self.config_badkey, self.cloud_init, self.log, self.args) - self.assertRaises(cc_rh_subscription.SubscriptionError) - self.log.warn.assert_any_call(good_message) + self.assertEqual(self.SM.log_warn.call_count, 2) self.assertEqual(self.SM._sub_man_cli.call_count, 1) - def missing_info(self, config): + def input_is_missing_data(self, config): ''' Helper def for tests that having missing information ''' - good_message = "Unable to register system due to incomplete "\ - "information." - self.log.warn = mock.MagicMock(wraps=self.log.warn) + self.SM.log_warn = mock.MagicMock() self.SM._sub_man_cli = mock.MagicMock( side_effect=[util.ProcessExecutionError]) self.handle(self.name, config, self.cloud_init, self.log, self.args) - self.SM._sub_man_cli.assert_called_with(['subscription-manager', - 'identity']) - self.log.warn.assert_any_call(good_message) + self.SM._sub_man_cli.assert_called_with(['identity']) + self.assertEqual(self.SM.log_warn.call_count, 4) self.assertEqual(self.SM._sub_man_cli.call_count, 1) -- cgit v1.2.3 From 3c01b8e48400697362f190984ab9c96dee27a369 Mon Sep 17 00:00:00 2001 From: Brent Baude Date: Fri, 29 May 2015 09:18:49 -0500 Subject: Corrected spelling error on variable name --- cloudinit/config/cc_rh_subscription.py | 6 +++--- tests/unittests/test_rh_subscription.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rh_subscription.py b/cloudinit/config/cc_rh_subscription.py index e57e8a07..cabebca4 100644 --- a/cloudinit/config/cc_rh_subscription.py +++ b/cloudinit/config/cc_rh_subscription.py @@ -62,12 +62,12 @@ def handle(_name, cfg, _cloud, log, _args): return_stat = sm.update_repos(sm.enable_repo, sm.disable_repo) if not return_stat: raise SubscriptionError("Unable to add or remove repos") - sm.log_sucess("rh_subscription plugin completed successfully") + sm.log_success("rh_subscription plugin completed successfully") except SubscriptionError as e: sm.log_warn(str(e)) sm.log_warn("rh_subscription plugin did not complete successfully") else: - sm.log_sucess("System is already registered") + sm.log_success("System is already registered") class SubscriptionError(Exception): @@ -97,7 +97,7 @@ class SubscriptionManager(object): self.subman = ['subscription-manager'] self.is_registered = self._is_registered() - def log_sucess(self, msg): + def log_success(self, msg): '''Simple wrapper for logging info messages. Useful for unittests''' self.log.info(msg) diff --git a/tests/unittests/test_rh_subscription.py b/tests/unittests/test_rh_subscription.py index 2f813f41..38d5763a 100644 --- a/tests/unittests/test_rh_subscription.py +++ b/tests/unittests/test_rh_subscription.py @@ -36,17 +36,17 @@ class GoodTests(unittest.TestCase): ''' with mock.patch.object(cc_rh_subscription.SubscriptionManager, '_sub_man_cli') as mockobj: - self.SM.log_sucess = mock.MagicMock() + self.SM.log_success = mock.MagicMock() self.handle(self.name, self.config, self.cloud_init, self.log, self.args) - self.assertEqual(self.SM.log_sucess.call_count, 1) + self.assertEqual(self.SM.log_success.call_count, 1) self.assertEqual(mockobj.call_count, 1) def test_simple_registration(self): ''' Simple registration with username and password ''' - self.SM.log_sucess = mock.MagicMock() + self.SM.log_success = mock.MagicMock() reg = "The system has been registered with ID:" \ " 12345678-abde-abcde-1234-1234567890abc" self.SM._sub_man_cli = mock.MagicMock( @@ -60,7 +60,7 @@ class GoodTests(unittest.TestCase): logstring_val=True), self.SM._sub_man_cli.call_args_list) - self.assertEqual(self.SM.log_sucess.call_count, 1) + self.assertEqual(self.SM.log_success.call_count, 1) self.assertEqual(self.SM._sub_man_cli.call_count, 2) def test_full_registration(self): @@ -73,7 +73,7 @@ class GoodTests(unittest.TestCase): call_lists.append(['repos', '--enable=repo2', '--enable=repo3', '--disable=repo5']) call_lists.append(['attach', '--auto', '--servicelevel=self-support']) - self.SM.log_sucess = mock.MagicMock() + self.SM.log_success = mock.MagicMock() reg = "The system has been registered with ID:" \ " 12345678-abde-abcde-1234-1234567890abc" self.SM._sub_man_cli = mock.MagicMock( @@ -88,7 +88,7 @@ class GoodTests(unittest.TestCase): self.log, self.args) for call in call_lists: self.assertIn(mock.call(call), self.SM._sub_man_cli.call_args_list) - self.assertEqual(self.SM.log_sucess.call_count, 1) + self.assertEqual(self.SM.log_success.call_count, 1) self.assertEqual(self.SM._sub_man_cli.call_count, 9) -- cgit v1.2.3 From 2a955375dbf0d546d8999793923c4e8f23e042c6 Mon Sep 17 00:00:00 2001 From: Lars Kellogg-Stedman Date: Wed, 3 Jun 2015 13:16:23 -0400 Subject: transform paths in functions taking more than a single argument Patch FilesystemMockingTestcase.patchOS to support methods taking more than a single path argument. This is required in order to properly mock `os.symlink`, which takes two path arguments. --- tests/unittests/helpers.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/helpers.py b/tests/unittests/helpers.py index 61a1f6ff..7f4b8784 100644 --- a/tests/unittests/helpers.py +++ b/tests/unittests/helpers.py @@ -248,13 +248,15 @@ class FilesystemMockingTestCase(ResourceUsingTestCase): def patchOS(self, new_root): patch_funcs = { - os.path: ['isfile', 'exists', 'islink', 'isdir'], - os: ['listdir'], + os.path: [('isfile', 1), ('exists', 1), + ('islink', 1), ('isdir', 1)], + os: [('listdir', 1), ('mkdir', 1), + ('lstat', 1), ('symlink', 2)], } for (mod, funcs) in patch_funcs.items(): - for f in funcs: + for f, nargs in funcs: func = getattr(mod, f) - trap_func = retarget_many_wrapper(new_root, 1, func) + trap_func = retarget_many_wrapper(new_root, nargs, func) self.patched_funcs.enter_context( mock.patch.object(mod, f, trap_func)) -- cgit v1.2.3 From 8db399f9149a81de5d65f0759792766ecd509ab3 Mon Sep 17 00:00:00 2001 From: Lars Kellogg-Stedman Date: Wed, 3 Jun 2015 13:18:38 -0400 Subject: add tests for systemd detection This adds the following tests in test_distros.test_generic: - test_systemd_in_use Test the situation in which /run/systemd/system exists. - test_systemd_not_in_use Test the situation in which /run/systemd/system does not exists. - test_systemd_symlink This tests the situation in which /run/systemd/system exists but is a *symlink* to a directory, which according to sd_booted() should return false. --- tests/unittests/test_distros/test_generic.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'tests/unittests') diff --git a/tests/unittests/test_distros/test_generic.py b/tests/unittests/test_distros/test_generic.py index 35153f0d..8e3bd78a 100644 --- a/tests/unittests/test_distros/test_generic.py +++ b/tests/unittests/test_distros/test_generic.py @@ -194,6 +194,29 @@ class TestGenericDistro(helpers.FilesystemMockingTestCase): {'primary': 'http://fs-primary-intel', 'security': 'http://security-mirror2-intel'}) + def test_systemd_in_use(self): + cls = distros.fetch("ubuntu") + d = cls("ubuntu", {}, None) + self.patchOS(self.tmp) + self.patchUtils(self.tmp) + os.makedirs('/run/systemd/system') + self.assertTrue(d.uses_systemd()) + + def test_systemd_not_in_use(self): + cls = distros.fetch("ubuntu") + d = cls("ubuntu", {}, None) + self.patchOS(self.tmp) + self.patchUtils(self.tmp) + self.assertFalse(d.uses_systemd()) + + def test_systemd_symlink(self): + cls = distros.fetch("ubuntu") + d = cls("ubuntu", {}, None) + self.patchOS(self.tmp) + self.patchUtils(self.tmp) + os.makedirs('/run/systemd') + os.symlink('/', '/run/systemd/system') + self.assertFalse(d.uses_systemd()) # def _get_package_mirror_info(mirror_info, availability_zone=None, # mirror_filter=util.search_for_mirror): -- cgit v1.2.3 From ba7fc871f2e73e0adbf883ef8253180f41cdcfe8 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Tue, 16 Jun 2015 17:35:03 +0100 Subject: Use wget to fetch CloudStack passwords. Different versions of the CloudStack password server respond differently; wget handles these nicely for us, so it's easier to just use wget. LP: #1440263, #1464253 --- cloudinit/sources/DataSourceCloudStack.py | 35 +++++++--------------- tests/unittests/test_datasource/test_cloudstack.py | 30 +++++++++---------- 2 files changed, 25 insertions(+), 40 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceCloudStack.py b/cloudinit/sources/DataSourceCloudStack.py index 7b32e1fa..d0cac5bb 100644 --- a/cloudinit/sources/DataSourceCloudStack.py +++ b/cloudinit/sources/DataSourceCloudStack.py @@ -29,8 +29,6 @@ import time from socket import inet_ntoa from struct import pack -from six.moves import http_client - from cloudinit import ec2_utils as ec2 from cloudinit import log as logging from cloudinit import url_helper as uhelp @@ -47,35 +45,22 @@ class CloudStackPasswordServerClient(object): has documentation about the system. This implementation is following that found at https://github.com/shankerbalan/cloudstack-scripts/blob/master/cloud-set-guest-password-debian - - The CloudStack password server is, essentially, a broken HTTP - server. It requires us to provide a valid HTTP request (including a - DomU_Request header, which is the meat of the request), but just - writes the text of its response on to the socket, without a status - line or any HTTP headers. This makes HTTP libraries sad, which - explains the screwiness of the implementation of this class. - - This should be fixed in CloudStack by commit - a72f14ea9cb832faaac946b3cf9f56856b50142a in December 2014. """ def __init__(self, virtual_router_address): self.virtual_router_address = virtual_router_address def _do_request(self, domu_request): - # We have to provide a valid HTTP request, but a valid HTTP - # response is not returned. This means that getresponse() chokes, - # so we use the socket directly to read off the response. - # Because we're reading off the socket directly, we can't re-use the - # connection. - conn = http_client.HTTPConnection(self.virtual_router_address, 8080) - try: - conn.request('GET', '', headers={'DomU_Request': domu_request}) - conn.sock.settimeout(30) - output = conn.sock.recv(1024).decode('utf-8').strip() - finally: - conn.close() - return output + # The password server was in the past, a broken HTTP server, but is now + # fixed. wget handles this seamlessly, so it's easier to shell out to + # that rather than write our own handling code. + output, _ = util.subp([ + 'wget', '--quiet', '--tries', '3', '--timeout', '20', + '--output-document', '-', '--header', + 'DomU_Request: {0}'.format(domu_request), + '{0}:8080'.format(self.virtual_router_address) + ]) + return output.strip() def get_password(self): password = self._do_request('send_my_password') diff --git a/tests/unittests/test_datasource/test_cloudstack.py b/tests/unittests/test_datasource/test_cloudstack.py index 959d78ae..656d80d1 100644 --- a/tests/unittests/test_datasource/test_cloudstack.py +++ b/tests/unittests/test_datasource/test_cloudstack.py @@ -23,13 +23,11 @@ class TestCloudStackPasswordFetching(TestCase): self.patches.enter_context(mock.patch('{0}.uhelp'.format(mod_name))) def _set_password_server_response(self, response_string): - http_client = mock.MagicMock() - http_client.HTTPConnection.return_value.sock.recv.return_value = \ - response_string.encode('utf-8') + subp = mock.MagicMock(return_value=(response_string, '')) self.patches.enter_context( - mock.patch('cloudinit.sources.DataSourceCloudStack.http_client', - http_client)) - return http_client + mock.patch('cloudinit.sources.DataSourceCloudStack.util.subp', + subp)) + return subp def test_empty_password_doesnt_create_config(self): self._set_password_server_response('') @@ -55,26 +53,28 @@ class TestCloudStackPasswordFetching(TestCase): ds = DataSourceCloudStack({}, None, helpers.Paths({})) self.assertTrue(ds.get_data()) - def assertRequestTypesSent(self, http_client, expected_request_types): - request_types = [ - kwargs['headers']['DomU_Request'] - for _, kwargs - in http_client.HTTPConnection.return_value.request.call_args_list] + def assertRequestTypesSent(self, subp, expected_request_types): + request_types = [] + for call in subp.call_args_list: + args = call[0][0] + for arg in args: + if arg.startswith('DomU_Request'): + request_types.append(arg.split()[1]) self.assertEqual(expected_request_types, request_types) def test_valid_response_means_password_marked_as_saved(self): password = 'SekritSquirrel' - http_client = self._set_password_server_response(password) + subp = self._set_password_server_response(password) ds = DataSourceCloudStack({}, None, helpers.Paths({})) ds.get_data() - self.assertRequestTypesSent(http_client, + self.assertRequestTypesSent(subp, ['send_my_password', 'saved_password']) def _check_password_not_saved_for(self, response_string): - http_client = self._set_password_server_response(response_string) + subp = self._set_password_server_response(response_string) ds = DataSourceCloudStack({}, None, helpers.Paths({})) ds.get_data() - self.assertRequestTypesSent(http_client, ['send_my_password']) + self.assertRequestTypesSent(subp, ['send_my_password']) def test_password_not_saved_if_empty(self): self._check_password_not_saved_for('') -- cgit v1.2.3 From 33e1f251aa2ff75475fe0f02f1e344dbe949d89d Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Mon, 6 Jul 2015 15:33:06 +0100 Subject: Reduce repetition in GCE data source tests. --- tests/unittests/test_datasource/test_gce.py | 47 ++++++++--------------------- 1 file changed, 12 insertions(+), 35 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 1fb100f7..98b68f09 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -55,8 +55,8 @@ MD_URL_RE = re.compile( r'http://metadata.google.internal./computeMetadata/v1/.*') -def _new_request_callback(gce_meta=None): - if not gce_meta: +def _set_mock_metadata(gce_meta=None): + if gce_meta is None: gce_meta = GCE_META def _request_callback(method, uri, headers): @@ -70,9 +70,10 @@ def _new_request_callback(gce_meta=None): else: return (404, headers, '') - return _request_callback + httpretty.register_uri(httpretty.GET, MD_URL_RE, body=_request_callback) +@httpretty.activate class TestDataSourceGCE(test_helpers.HttprettyTestCase): def setUp(self): @@ -81,23 +82,16 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): helpers.Paths({})) super(TestDataSourceGCE, self).setUp() - @httpretty.activate def test_connection(self): - httpretty.register_uri( - httpretty.GET, MD_URL_RE, - body=_new_request_callback()) - + _set_mock_metadata() success = self.ds.get_data() self.assertTrue(success) req_header = httpretty.last_request().headers self.assertDictContainsSubset(HEADERS, req_header) - @httpretty.activate def test_metadata(self): - httpretty.register_uri( - httpretty.GET, MD_URL_RE, - body=_new_request_callback()) + _set_mock_metadata() self.ds.get_data() shostname = GCE_META.get('instance/hostname').split('.')[0] @@ -107,18 +101,12 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): self.assertEqual(GCE_META.get('instance/id'), self.ds.get_instance_id()) - self.assertEqual(GCE_META.get('instance/zone'), - self.ds.availability_zone) - self.assertEqual(GCE_META.get('instance/attributes/user-data'), self.ds.get_userdata_raw()) # test partial metadata (missing user-data in particular) - @httpretty.activate def test_metadata_partial(self): - httpretty.register_uri( - httpretty.GET, MD_URL_RE, - body=_new_request_callback(GCE_META_PARTIAL)) + _set_mock_metadata(GCE_META_PARTIAL) self.ds.get_data() self.assertEqual(GCE_META_PARTIAL.get('instance/id'), @@ -127,58 +115,47 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): shostname = GCE_META_PARTIAL.get('instance/hostname').split('.')[0] self.assertEqual(shostname, self.ds.get_hostname()) - @httpretty.activate def test_metadata_encoding(self): - httpretty.register_uri( - httpretty.GET, MD_URL_RE, - body=_new_request_callback(GCE_META_ENCODING)) + _set_mock_metadata(GCE_META_ENCODING) self.ds.get_data() decoded = b64decode( GCE_META_ENCODING.get('instance/attributes/user-data')) self.assertEqual(decoded, self.ds.get_userdata_raw()) - @httpretty.activate def test_missing_required_keys_return_false(self): for required_key in ['instance/id', 'instance/zone', 'instance/hostname']: meta = GCE_META_PARTIAL.copy() del meta[required_key] - httpretty.register_uri(httpretty.GET, MD_URL_RE, - body=_new_request_callback(meta)) + _set_mock_metadata(meta) self.assertEqual(False, self.ds.get_data()) httpretty.reset() - @httpretty.activate def test_project_level_ssh_keys_are_used(self): - httpretty.register_uri(httpretty.GET, MD_URL_RE, - body=_new_request_callback()) + _set_mock_metadata() self.ds.get_data() # we expect a list of public ssh keys with user names stripped self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'], self.ds.get_public_ssh_keys()) - @httpretty.activate def test_instance_level_ssh_keys_are_used(self): key_content = 'ssh-rsa JustAUser root@server' meta = GCE_META.copy() meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content) - httpretty.register_uri(httpretty.GET, MD_URL_RE, - body=_new_request_callback(meta)) + _set_mock_metadata(meta) self.ds.get_data() self.assertIn(key_content, self.ds.get_public_ssh_keys()) - @httpretty.activate def test_instance_level_keys_replace_project_level_keys(self): key_content = 'ssh-rsa JustAUser root@server' meta = GCE_META.copy() meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content) - httpretty.register_uri(httpretty.GET, MD_URL_RE, - body=_new_request_callback(meta)) + _set_mock_metadata(meta) self.ds.get_data() self.assertEqual([key_content], self.ds.get_public_ssh_keys()) -- cgit v1.2.3 From afb5421ee717174b989bfed61333f2073b3f3f50 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Mon, 6 Jul 2015 15:33:33 +0100 Subject: Return a sensible value for DataSourceGCE.availability_zone. --- cloudinit/sources/DataSourceGCE.py | 4 ++++ tests/unittests/test_datasource/test_gce.py | 5 +++++ 2 files changed, 9 insertions(+) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py index f4ed915d..1b28a68c 100644 --- a/cloudinit/sources/DataSourceGCE.py +++ b/cloudinit/sources/DataSourceGCE.py @@ -116,6 +116,10 @@ class DataSourceGCE(sources.DataSource): lines = self.metadata['public-keys'].splitlines() self.metadata['public-keys'] = [self._trim_key(k) for k in lines] + if self.metadata['availability-zone']: + self.metadata['availability-zone'] = self.metadata[ + 'availability-zone'].split('/')[-1] + encoding = self.metadata.get('user-data-encoding') if encoding: if encoding == 'base64': diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 98b68f09..fa714070 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -159,3 +159,8 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): self.ds.get_data() self.assertEqual([key_content], self.ds.get_public_ssh_keys()) + + def test_only_last_part_of_zone_used_for_availability_zone(self): + _set_mock_metadata() + self.ds.get_data() + self.assertEqual('bar', self.ds.availability_zone) -- cgit v1.2.3 From 55487d9eb52343bd72271b072734740b52b25c1d Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Tue, 21 Jul 2015 13:06:11 +0100 Subject: Refactor cc_mounts.sanitize_devname to make it easier to modify. --- cloudinit/config/cc_mounts.py | 104 +++++++--------- .../unittests/test_handler/test_handler_mounts.py | 133 +++++++++++++++++++++ 2 files changed, 177 insertions(+), 60 deletions(-) create mode 100644 tests/unittests/test_handler/test_handler_mounts.py (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_mounts.py b/cloudinit/config/cc_mounts.py index 1cb1e839..f970c2ca 100644 --- a/cloudinit/config/cc_mounts.py +++ b/cloudinit/config/cc_mounts.py @@ -28,15 +28,15 @@ from cloudinit import type_utils from cloudinit import util # Shortname matches 'sda', 'sda1', 'xvda', 'hda', 'sdb', xvdb, vda, vdd1, sr0 -SHORTNAME_FILTER = r"^([x]{0,1}[shv]d[a-z][0-9]*|sr[0-9]+)$" -SHORTNAME = re.compile(SHORTNAME_FILTER) +DEVICE_NAME_FILTER = r"^([x]{0,1}[shv]d[a-z][0-9]*|sr[0-9]+)$" +DEVICE_NAME_RE = re.compile(DEVICE_NAME_FILTER) WS = re.compile("[%s]+" % (whitespace)) FSTAB_PATH = "/etc/fstab" LOG = logging.getLogger(__name__) -def is_mdname(name): +def is_meta_device_name(name): # return true if this is a metadata service name if name in ["ami", "root", "swap"]: return True @@ -48,6 +48,23 @@ def is_mdname(name): return False +def _get_nth_partition_for_device(device_path, partition_number): + potential_suffixes = [str(partition_number), 'p%s' % (partition_number,)] + for suffix in potential_suffixes: + potential_partition_device = '%s%s' % (device_path, suffix) + if os.path.exists(potential_partition_device): + return potential_partition_device + return None + + +def _is_block_device(device_path, partition_path=None): + device_name = device_path.split('/')[-1] + sys_path = os.path.join('/sys/block/', device_name) + if partition_path is not None: + sys_path = os.path.join(sys_path, partition_path.split('/')[-1]) + return os.path.exists(sys_path) + + def sanitize_devname(startname, transformer, log): log.debug("Attempting to determine the real name of %s", startname) @@ -58,21 +75,34 @@ def sanitize_devname(startname, transformer, log): devname = "ephemeral0" log.debug("Adjusted mount option from ephemeral to ephemeral0") - (blockdev, part) = util.expand_dotted_devname(devname) + device_path, partition_number = util.expand_dotted_devname(devname) - if is_mdname(blockdev): - orig = blockdev - blockdev = transformer(blockdev) - if not blockdev: + if is_meta_device_name(device_path): + orig = device_path + device_path = transformer(device_path) + if not device_path: return None - if not blockdev.startswith("/"): - blockdev = "/dev/%s" % blockdev - log.debug("Mapped metadata name %s to %s", orig, blockdev) + if not device_path.startswith("/"): + device_path = "/dev/%s" % (device_path,) + log.debug("Mapped metadata name %s to %s", orig, device_path) + else: + if DEVICE_NAME_RE.match(startname): + device_path = "/dev/%s" % (device_path,) + + partition_path = None + if partition_number is None: + partition_path = _get_nth_partition_for_device(device_path, 1) else: - if SHORTNAME.match(startname): - blockdev = "/dev/%s" % blockdev + partition_path = _get_nth_partition_for_device(device_path, + partition_number) + if partition_path is None: + return None - return devnode_for_dev_part(blockdev, part) + if _is_block_device(device_path, partition_path): + if partition_path is not None: + return partition_path + return device_path + return None def suggested_swapsize(memsize=None, maxsize=None, fsys=None): @@ -366,49 +396,3 @@ def handle(_name, cfg, cloud, log, _args): util.subp(("mount", "-a")) except: util.logexc(log, "Activating mounts via 'mount -a' failed") - - -def devnode_for_dev_part(device, partition): - """ - Find the name of the partition. While this might seem rather - straight forward, its not since some devices are '' - while others are 'p'. For example, /dev/xvda3 on EC2 - will present as /dev/xvda3p1 for the first partition since /dev/xvda3 is - a block device. - """ - if not os.path.exists(device): - return None - - short_name = os.path.basename(device) - sys_path = "/sys/block/%s" % short_name - - if not os.path.exists(sys_path): - LOG.debug("did not find entry for %s in /sys/block", short_name) - return None - - sys_long_path = sys_path + "/" + short_name - - if partition is not None: - partition = str(partition) - - if partition is None: - valid_mappings = [sys_long_path + "1", sys_long_path + "p1"] - elif partition != "0": - valid_mappings = [sys_long_path + "%s" % partition, - sys_long_path + "p%s" % partition] - else: - valid_mappings = [] - - for cdisk in valid_mappings: - if not os.path.exists(cdisk): - continue - - dev_path = "/dev/%s" % os.path.basename(cdisk) - if os.path.exists(dev_path): - return dev_path - - if partition is None or partition == "0": - return device - - LOG.debug("Did not fine partition %s for device %s", partition, device) - return None diff --git a/tests/unittests/test_handler/test_handler_mounts.py b/tests/unittests/test_handler/test_handler_mounts.py new file mode 100644 index 00000000..355674b2 --- /dev/null +++ b/tests/unittests/test_handler/test_handler_mounts.py @@ -0,0 +1,133 @@ +import os.path +import shutil +import tempfile + +from cloudinit.config import cc_mounts + +from .. import helpers as test_helpers + +try: + from unittest import mock +except ImportError: + import mock + + +class TestSanitizeDevname(test_helpers.FilesystemMockingTestCase): + + def setUp(self): + super(TestSanitizeDevname, self).setUp() + self.new_root = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.new_root) + self.patchOS(self.new_root) + + def _touch(self, path): + path = os.path.join(self.new_root, path.lstrip('/')) + basedir = os.path.dirname(path) + if not os.path.exists(basedir): + os.makedirs(basedir) + open(path, 'a').close() + + def _makedirs(self, directory): + directory = os.path.join(self.new_root, directory.lstrip('/')) + if not os.path.exists(directory): + os.makedirs(directory) + + def mock_existence_of_disk(self, disk_path): + self._touch(disk_path) + self._makedirs(os.path.join('/sys/block', disk_path.split('/')[-1])) + + def mock_existence_of_partition(self, disk_path, partition_number): + self.mock_existence_of_disk(disk_path) + self._touch(disk_path + str(partition_number)) + disk_name = disk_path.split('/')[-1] + self._makedirs(os.path.join('/sys/block', + disk_name, + disk_name + str(partition_number))) + + def test_existent_full_disk_path_is_returned(self): + disk_path = '/dev/sda' + self.mock_existence_of_disk(disk_path) + self.assertEqual(disk_path, + cc_mounts.sanitize_devname(disk_path, + lambda x: None, + mock.Mock())) + + def test_existent_disk_name_returns_full_path(self): + disk_name = 'sda' + disk_path = '/dev/' + disk_name + self.mock_existence_of_disk(disk_path) + self.assertEqual(disk_path, + cc_mounts.sanitize_devname(disk_name, + lambda x: None, + mock.Mock())) + + def test_existent_meta_disk_is_returned(self): + actual_disk_path = '/dev/sda' + self.mock_existence_of_disk(actual_disk_path) + self.assertEqual( + actual_disk_path, + cc_mounts.sanitize_devname('ephemeral0', + lambda x: actual_disk_path, + mock.Mock())) + + def test_existent_meta_partition_is_returned(self): + disk_name, partition_part = '/dev/sda', '1' + actual_partition_path = disk_name + partition_part + self.mock_existence_of_partition(disk_name, partition_part) + self.assertEqual( + actual_partition_path, + cc_mounts.sanitize_devname('ephemeral0.1', + lambda x: disk_name, + mock.Mock())) + + def test_existent_meta_partition_with_p_is_returned(self): + disk_name, partition_part = '/dev/sda', 'p1' + actual_partition_path = disk_name + partition_part + self.mock_existence_of_partition(disk_name, partition_part) + self.assertEqual( + actual_partition_path, + cc_mounts.sanitize_devname('ephemeral0.1', + lambda x: disk_name, + mock.Mock())) + + def test_first_partition_returned_if_existent_disk_is_partitioned(self): + disk_name, partition_part = '/dev/sda', '1' + actual_partition_path = disk_name + partition_part + self.mock_existence_of_partition(disk_name, partition_part) + self.assertEqual( + actual_partition_path, + cc_mounts.sanitize_devname('ephemeral0', + lambda x: disk_name, + mock.Mock())) + + def test_nth_partition_returned_if_requested(self): + disk_name, partition_part = '/dev/sda', '3' + actual_partition_path = disk_name + partition_part + self.mock_existence_of_partition(disk_name, partition_part) + self.assertEqual( + actual_partition_path, + cc_mounts.sanitize_devname('ephemeral0.3', + lambda x: disk_name, + mock.Mock())) + + def test_transformer_returning_none_returns_none(self): + self.assertIsNone( + cc_mounts.sanitize_devname( + 'ephemeral0', lambda x: None, mock.Mock())) + + def test_missing_device_returns_none(self): + self.assertIsNone( + cc_mounts.sanitize_devname('/dev/sda', None, mock.Mock())) + + def test_missing_sys_returns_none(self): + disk_path = '/dev/sda' + self._makedirs(disk_path) + self.assertIsNone( + cc_mounts.sanitize_devname(disk_path, None, mock.Mock())) + + def test_existent_disk_but_missing_partition_returns_none(self): + disk_path = '/dev/sda' + self.mock_existence_of_disk(disk_path) + self.assertIsNone( + cc_mounts.sanitize_devname( + 'ephemeral0.1', lambda x: disk_path, mock.Mock())) -- cgit v1.2.3 From 9461b1235f7278440ffb84f1e3d95b3f906e444b Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Tue, 21 Jul 2015 13:06:11 +0100 Subject: Use /dev/disk devices for Azure ephemeral disk. The ephemeral disk will not necessarily be assigned the same name at each boot (LP: #1411582), so we use some udev rules to ensure we always get the right one. --- cloudinit/sources/DataSourceAzure.py | 39 ++++++++++++++------------- tests/unittests/test_datasource/test_azure.py | 6 +++-- 2 files changed, 25 insertions(+), 20 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index d0a882ca..1193d88b 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -254,7 +254,7 @@ class DataSourceAzureNet(sources.DataSource): self.metadata.update(fabric_data) - found_ephemeral = find_ephemeral_disk() + found_ephemeral = find_fabric_formatted_ephemeral_disk() if found_ephemeral: self.ds_cfg['disk_aliases']['ephemeral0'] = found_ephemeral LOG.debug("using detected ephemeral0 of %s", found_ephemeral) @@ -276,30 +276,33 @@ def count_files(mp): return len(fnmatch.filter(os.listdir(mp), '*[!cdrom]*')) -def find_ephemeral_part(): +def find_fabric_formatted_ephemeral_part(): """ - Locate the default ephmeral0.1 device. This will be the first device - that has a LABEL of DEF_EPHEMERAL_LABEL and is a NTFS device. If Azure - gets more ephemeral devices, this logic will only identify the first - such device. + Locate the first fabric formatted ephemeral device. """ - c_label_devs = util.find_devs_with("LABEL=%s" % DEF_EPHEMERAL_LABEL) - c_fstype_devs = util.find_devs_with("TYPE=ntfs") - for dev in c_label_devs: - if dev in c_fstype_devs: - return dev + potential_locations = ['/dev/disk/cloud/azure_resource-part1', + '/dev/disk/azure/resource-part1'] + device_location = None + for potential_location in potential_locations: + if os.path.exists(potential_location): + device_location = potential_location + break + if device_location is None: + return None + ntfs_devices = util.find_devs_with("TYPE=ntfs") + real_device = os.path.realpath(device_location) + if real_device in ntfs_devices: + return device_location return None -def find_ephemeral_disk(): +def find_fabric_formatted_ephemeral_disk(): """ Get the ephemeral disk. """ - part_dev = find_ephemeral_part() - if part_dev and str(part_dev[-1]).isdigit(): - return part_dev[:-1] - elif part_dev: - return part_dev + part_dev = find_fabric_formatted_ephemeral_part() + if part_dev: + return part_dev.split('-')[0] return None @@ -313,7 +316,7 @@ def support_new_ephemeral(cfg): new ephemeral device is detected, cloud-init overrides the default frequency for both disk-setup and mounts for the current boot only. """ - device = find_ephemeral_part() + device = find_fabric_formatted_ephemeral_part() if not device: LOG.debug("no default fabric formated ephemeral0.1 found") return None diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 33b971f6..3b7e3293 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -475,10 +475,12 @@ class TestAzureBounce(TestCase): mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs', mock.MagicMock(return_value=[]))) self.patches.enter_context( - mock.patch.object(DataSourceAzure, 'find_ephemeral_disk', + mock.patch.object(DataSourceAzure, + 'find_fabric_formatted_ephemeral_disk', mock.MagicMock(return_value=None))) self.patches.enter_context( - mock.patch.object(DataSourceAzure, 'find_ephemeral_part', + mock.patch.object(DataSourceAzure, + 'find_fabric_formatted_ephemeral_part', mock.MagicMock(return_value=None))) self.patches.enter_context( mock.patch.object(DataSourceAzure, 'get_metadata_from_fabric', -- cgit v1.2.3 From 0db6078c5555cdddcd195cefdc04154e4a754dd6 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 21 Jul 2015 12:02:44 -0400 Subject: tests: fix TestHandlerHandlePart tests these tests were previously passing, but doing so erroneously. I believe that an update to mock caused them to start failing. I've updated the tests now. The simple change is replacing 'assert_called_with_once' with 'assert_called_once_with'. The second set of changes is seemingly a correction of the following tests expectations: test_normal_version_2 : was not expecting to get frequency passed into handle_part, but should have been. test_no_handle_when_modfreq_once: was expecting to have handle_part called even though the test implies otherwise. test_exception_is_caught: this test just looked broken. Now, we're testing that the part handler is called and that no exception is raised past handle_part --- tests/unittests/test__init__.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test__init__.py b/tests/unittests/test__init__.py index c32783a6..153f1658 100644 --- a/tests/unittests/test__init__.py +++ b/tests/unittests/test__init__.py @@ -70,8 +70,8 @@ class TestWalkerHandleHandler(TestCase): return_value=self.module_fake) as mockobj: handlers.walker_handle_handler(self.data, self.ctype, self.filename, self.payload) - mockobj.assert_called_with_once(self.expected_module_name) - self.write_file_mock.assert_called_with_once( + mockobj.assert_called_once_with(self.expected_module_name) + self.write_file_mock.assert_called_once_with( self.expected_file_fullname, self.payload, 0o600) self.assertEqual(self.data['handlercount'], 1) @@ -81,8 +81,8 @@ class TestWalkerHandleHandler(TestCase): side_effect=ImportError) as mockobj: handlers.walker_handle_handler(self.data, self.ctype, self.filename, self.payload) - mockobj.assert_called_with_once(self.expected_module_name) - self.write_file_mock.assert_called_with_once( + mockobj.assert_called_once_with(self.expected_module_name) + self.write_file_mock.assert_called_once_with( self.expected_file_fullname, self.payload, 0o600) self.assertEqual(self.data['handlercount'], 0) @@ -93,8 +93,8 @@ class TestWalkerHandleHandler(TestCase): return_value=self.module_fake) as mockobj: handlers.walker_handle_handler(self.data, self.ctype, self.filename, self.payload) - mockobj.assert_called_with_once(self.expected_module_name) - self.write_file_mock.assert_called_with_once( + mockobj.assert_called_once_with(self.expected_module_name) + self.write_file_mock.assert_called_once_with( self.expected_file_fullname, self.payload, 0o600) self.assertEqual(self.data['handlercount'], 0) @@ -122,7 +122,7 @@ class TestHandlerHandlePart(unittest.TestCase): self.frequency, self.headers) # Assert that the handle_part() method of the mock object got # called with the expected arguments. - mod_mock.handle_part.assert_called_with_once( + mod_mock.handle_part.assert_called_once_with( self.data, self.ctype, self.filename, self.payload) def test_normal_version_2(self): @@ -136,8 +136,9 @@ class TestHandlerHandlePart(unittest.TestCase): self.frequency, self.headers) # Assert that the handle_part() method of the mock object got # called with the expected arguments. - mod_mock.handle_part.assert_called_with_once( - self.data, self.ctype, self.filename, self.payload) + mod_mock.handle_part.assert_called_once_with( + self.data, self.ctype, self.filename, self.payload, + settings.PER_INSTANCE) def test_modfreq_per_always(self): """ @@ -150,7 +151,7 @@ class TestHandlerHandlePart(unittest.TestCase): self.frequency, self.headers) # Assert that the handle_part() method of the mock object got # called with the expected arguments. - mod_mock.handle_part.assert_called_with_once( + mod_mock.handle_part.assert_called_once_with( self.data, self.ctype, self.filename, self.payload) def test_no_handle_when_modfreq_once(self): @@ -159,21 +160,20 @@ class TestHandlerHandlePart(unittest.TestCase): mod_mock = mock.Mock(frequency=settings.PER_ONCE) handlers.run_part(mod_mock, self.data, self.filename, self.payload, self.frequency, self.headers) - # Assert that the handle_part() method of the mock object got - # called with the expected arguments. - mod_mock.handle_part.assert_called_with_once( - self.data, self.ctype, self.filename, self.payload) + self.assertEqual(0, mod_mock.handle_part.call_count) def test_exception_is_caught(self): """Exceptions within C{handle_part} are caught and logged.""" mod_mock = mock.Mock(frequency=settings.PER_INSTANCE, handler_version=1) - handlers.run_part(mod_mock, self.data, self.filename, self.payload, - self.frequency, self.headers) mod_mock.handle_part.side_effect = Exception - handlers.run_part(mod_mock, self.data, self.filename, self.payload, - self.frequency, self.headers) - mod_mock.handle_part.assert_called_with_once( + try: + handlers.run_part(mod_mock, self.data, self.filename, + self.payload, self.frequency, self.headers) + except Exception: + self.fail("Exception was not caught in handle_part") + + mod_mock.handle_part.assert_called_once_with( self.data, self.ctype, self.filename, self.payload) -- cgit v1.2.3 From b5230bc3e9d65692093cae9d2f4ca628435a382b Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 21 Jul 2015 12:36:53 -0400 Subject: fix 'make pyflakes' --- cloudinit/sources/DataSourceAzure.py | 2 +- tests/unittests/test_datasource/test_azure.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index d0a882ca..2ce85637 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -430,7 +430,7 @@ def write_files(datadir, files, dirmode=None): elem.text != DEF_PASSWD_REDACTION): elem.text = DEF_PASSWD_REDACTION return ET.tostring(root) - except Exception as e: + except Exception: LOG.critical("failed to redact userpassword in {}".format(fname)) return cnt diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 33b971f6..d632bcb9 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -174,7 +174,7 @@ class TestAzureDataSource(TestCase): def xml_notequals(self, oxml, nxml): try: self.xml_equals(oxml, nxml) - except AssertionError as e: + except AssertionError: return raise AssertionError("XML is the same") -- cgit v1.2.3 From 73c5bbfa31b922a0ba403216c0fc1f63b22a9262 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Wed, 22 Jul 2015 13:06:34 +0100 Subject: Make full data source available to code that handles mirror selection. --- cloudinit/distros/__init__.py | 15 +++++++-------- cloudinit/sources/__init__.py | 3 +-- tests/unittests/test_distros/test_generic.py | 22 +++++++++++++++------- 3 files changed, 23 insertions(+), 17 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 8a947867..47b76c68 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -117,12 +117,11 @@ class Distro(object): arch = self.get_primary_arch() return _get_arch_package_mirror_info(mirror_info, arch) - def get_package_mirror_info(self, arch=None, - availability_zone=None): + def get_package_mirror_info(self, arch=None, data_source=None): # This resolves the package_mirrors config option # down to a single dict of {mirror_name: mirror_url} arch_info = self._get_arch_package_mirror_info(arch) - return _get_package_mirror_info(availability_zone=availability_zone, + return _get_package_mirror_info(data_source=data_source, mirror_info=arch_info) def apply_network(self, settings, bring_up=True): @@ -556,7 +555,7 @@ class Distro(object): LOG.info("Added user '%s' to group '%s'" % (member, name)) -def _get_package_mirror_info(mirror_info, availability_zone=None, +def _get_package_mirror_info(mirror_info, data_source=None, mirror_filter=util.search_for_mirror): # given a arch specific 'mirror_info' entry (from package_mirrors) # search through the 'search' entries, and fallback appropriately @@ -572,11 +571,11 @@ def _get_package_mirror_info(mirror_info, availability_zone=None, ec2_az_re = ("^[a-z][a-z]-(%s)-[1-9][0-9]*[a-z]$" % directions_re) subst = {} - if availability_zone: - subst['availability_zone'] = availability_zone + if data_source and data_source.availability_zone: + subst['availability_zone'] = data_source.availability_zone - if availability_zone and re.match(ec2_az_re, availability_zone): - subst['ec2_region'] = "%s" % availability_zone[0:-1] + if re.match(ec2_az_re, data_source.availability_zone): + subst['ec2_region'] = "%s" % data_source.availability_zone[0:-1] results = {} for (name, mirror) in mirror_info.get('failsafe', {}).items(): diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 39eab51b..1a036638 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -210,8 +210,7 @@ class DataSource(object): return hostname def get_package_mirror_info(self): - return self.distro.get_package_mirror_info( - availability_zone=self.availability_zone) + return self.distro.get_package_mirror_info(data_source=self) def normalize_pubkey_data(pubkey_data): diff --git a/tests/unittests/test_distros/test_generic.py b/tests/unittests/test_distros/test_generic.py index 8e3bd78a..6ed1704c 100644 --- a/tests/unittests/test_distros/test_generic.py +++ b/tests/unittests/test_distros/test_generic.py @@ -7,6 +7,11 @@ import os import shutil import tempfile +try: + from unittest import mock +except ImportError: + import mock + unknown_arch_info = { 'arches': ['default'], 'failsafe': {'primary': 'http://fs-primary-default', @@ -144,33 +149,35 @@ class TestGenericDistro(helpers.FilesystemMockingTestCase): def test_get_package_mirror_info_az_ec2(self): arch_mirrors = gapmi(package_mirrors, arch="amd64") + data_source_mock = mock.Mock(availability_zone="us-east-1a") - results = gpmi(arch_mirrors, availability_zone="us-east-1a", + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_first) self.assertEqual(results, {'primary': 'http://us-east-1.ec2/', 'security': 'http://security-mirror1-intel'}) - results = gpmi(arch_mirrors, availability_zone="us-east-1a", + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_second) self.assertEqual(results, {'primary': 'http://us-east-1a.clouds/', 'security': 'http://security-mirror2-intel'}) - results = gpmi(arch_mirrors, availability_zone="us-east-1a", + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_none) self.assertEqual(results, package_mirrors[0]['failsafe']) def test_get_package_mirror_info_az_non_ec2(self): arch_mirrors = gapmi(package_mirrors, arch="amd64") + data_source_mock = mock.Mock(availability_zone="nova.cloudvendor") - results = gpmi(arch_mirrors, availability_zone="nova.cloudvendor", + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_first) self.assertEqual(results, {'primary': 'http://nova.cloudvendor.clouds/', 'security': 'http://security-mirror1-intel'}) - results = gpmi(arch_mirrors, availability_zone="nova.cloudvendor", + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_last) self.assertEqual(results, {'primary': 'http://nova.cloudvendor.clouds/', @@ -178,17 +185,18 @@ class TestGenericDistro(helpers.FilesystemMockingTestCase): def test_get_package_mirror_info_none(self): arch_mirrors = gapmi(package_mirrors, arch="amd64") + data_source_mock = mock.Mock(availability_zone=None) # because both search entries here replacement based on # availability-zone, the filter will be called with an empty list and # failsafe should be taken. - results = gpmi(arch_mirrors, availability_zone=None, + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_first) self.assertEqual(results, {'primary': 'http://fs-primary-intel', 'security': 'http://security-mirror1-intel'}) - results = gpmi(arch_mirrors, availability_zone=None, + results = gpmi(arch_mirrors, data_source=data_source_mock, mirror_filter=self.return_last) self.assertEqual(results, {'primary': 'http://fs-primary-intel', -- cgit v1.2.3 From 6970029c661ab858a55dd467e5c593694ab39512 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 24 Jul 2015 16:58:57 -0400 Subject: commit initial re-work/re-implementation of syslog config --- cloudinit/config/cc_syslog.py | 183 +++++++++++++++++++++ doc/examples/cloud-config-syslog.txt | 30 ++++ .../unittests/test_handler/test_handler_syslog.py | 32 ++++ 3 files changed, 245 insertions(+) create mode 100644 cloudinit/config/cc_syslog.py create mode 100644 doc/examples/cloud-config-syslog.txt create mode 100644 tests/unittests/test_handler/test_handler_syslog.py (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_syslog.py b/cloudinit/config/cc_syslog.py new file mode 100644 index 00000000..21a8e8a9 --- /dev/null +++ b/cloudinit/config/cc_syslog.py @@ -0,0 +1,183 @@ +# vi: ts=4 expandtab +# +# Copyright (C) 2015 Canonical Ltd. +# +# Author: Scott Moser +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3, as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from cloudinit import log as logging +from cloudinit import util +from cloudinit.settings import PER_INSTANCE + +import re + +LOG = logging.getLogger(__name__) + +frequency = PER_INSTANCE + +BUILTIN_CFG = { + 'remotes_file': '/etc/rsyslog.d/20-cloudinit-remotes.conf', + 'remotes': {}, + 'service_name': 'rsyslog', +} + +COMMENT_RE = re.compile(r'[ ]*[#]+[ ]*') +HOST_PORT_RE = re.compile( + r'^(?P[@]{0,2})' + '(([[](?P[^\]]*)[\]])|(?P[^:]*))' + '([:](?P[0-9]+))?$') + + +def parse_remotes_line(line, name=None): + try: + data, comment = COMMENT_RE.split(line) + comment = comment.strip() + except ValueError: + data, comment = (line, None) + + toks = data.strip().split() + match = None + if len(toks) == 1: + host_port = data + elif len(toks) == 2: + match, host_port = toks + else: + raise ValueError("line had multiple spaces: %s" % data) + + toks = HOST_PORT_RE.match(host_port) + + if not toks: + raise ValueError("Invalid host specification '%s'" % host_port) + + proto = toks.group('proto') + addr = toks.group('addr') or toks.group('bracket_addr') + port = toks.group('port') + print("host_port: %s" % addr) + print("port: %s" % port) + + if addr.startswith("[") and not addr.endswith("]"): + raise ValueError("host spec had invalid brackets: %s" % addr) + + if comment and not name: + name = comment + + t = SyslogRemotesLine(name=name, match=match, proto=proto, + addr=addr, port=port) + t.validate() + return t + + +class SyslogRemotesLine(object): + def __init__(self, name=None, match=None, proto=None, addr=None, + port=None): + if not match: + match = "*.*" + self.name = name + self.match = match + if proto == "@": + proto = "udp" + elif proto == "@@": + proto = "tcp" + self.proto = proto + + self.addr = addr + if port: + self.port = int(port) + else: + self.port = None + + def validate(self): + if self.port: + try: + int(self.port) + except ValueError: + raise ValueError("port '%s' is not an integer" % self.port) + + if not self.addr: + raise ValueError("address is required") + + def __repr__(self): + return "[name=%s match=%s proto=%s address=%s port=%s]" % ( + self.name, self.match, self.proto, self.addr, self.port + ) + + def __str__(self): + buf = self.match + " " + if self.proto == "udp": + buf += " @" + elif self.proto == "tcp": + buf += " @@" + + if ":" in self.addr: + buf += "[" + self.addr + "]" + else: + buf += self.addr + + if self.port: + buf += ":%s" % self.port + + if self.name: + buf += " # %s" % self.name + return buf + + +def remotes_to_rsyslog_cfg(remotes, header=None): + if not remotes: + return None + lines = [] + if header is not None: + lines.append(header) + for name, line in remotes.items(): + try: + lines.append(parse_remotes_line(line, name=name)) + except ValueError as e: + LOG.warn("failed loading remote %s: %s [%s]", name, line, e) + return '\n'.join(str(lines)) + '\n' + + +def reload_syslog(systemd, service='rsyslog'): + if systemd: + cmd = ['systemctl', 'reload-or-try-restart', service] + else: + cmd = ['service', service, 'reload'] + try: + util.subp(cmd, capture=True) + except util.ProcessExecutionError as e: + LOG.warn("Failed to reload syslog using '%s': %s", ' '.join(cmd), e) + + +def handle(name, cfg, cloud, log, args): + cfgin = cfg.get('syslog') + if not cfgin: + cfgin = {} + mycfg = util.mergemanydict([cfgin, BUILTIN_CFG]) + + remotes_file = mycfg.get('remotes_file') + if util.is_false(remotes_file): + LOG.debug("syslog/remotes_file empty, doing nothing") + return + + remotes = mycfg.get('remotes_dict', {}) + if remotes and not isinstance(remotes, dict): + LOG.warn("syslog/remotes: content is not a dictionary") + return + + config_data = remotes_to_rsyslog_cfg( + remotes, header="#cloud-init syslog module") + + util.write_file(remotes_file, config_data) + + reload_syslog( + systemd=cloud.distro.uses_systemd(), + service=mycfg.get('service_name')) diff --git a/doc/examples/cloud-config-syslog.txt b/doc/examples/cloud-config-syslog.txt new file mode 100644 index 00000000..9ec5e120 --- /dev/null +++ b/doc/examples/cloud-config-syslog.txt @@ -0,0 +1,30 @@ +## syslog module allows you to configure the systems syslog. +## configuration of syslog is under the top level cloud-config +## entry 'syslog'. +## +## "remotes" +## remotes is a dictionary. items are of 'name: remote_info' +## name is simply a name (example 'maas'). It has no importance other than +## for cloud-init merging configs +## +## remote_info is of the format +## * optional filter for log messages +## default if not present: *.* +## * optional leading '@' or '@@' (indicates udp or tcp). +## default if not present (udp): @ +## This is rsyslog format for that. if not present, is '@' which is udp +## * ipv4 or ipv6 or hostname +## ipv6 addresses must be encoded in [::1] format. example: @[fd00::1]:514 +## * optional port +## port defaults to 514 +## +## Example: +#cloud-config +syslog: + remotes: + # udp to host 'maas.mydomain' port 514 + maashost: maas.mydomain + # udp to ipv4 host on port 514 + maas: "@[10.5.1.56]:514" + # tcp to host ipv6 host on port 555 + maasipv6: "*.* @@[FE80::0202:B3FF:FE1E:8329]:555" diff --git a/tests/unittests/test_handler/test_handler_syslog.py b/tests/unittests/test_handler/test_handler_syslog.py new file mode 100644 index 00000000..bbfd521e --- /dev/null +++ b/tests/unittests/test_handler/test_handler_syslog.py @@ -0,0 +1,32 @@ +from cloudinit.config.cc_syslog import ( + parse_remotes_line, SyslogRemotesLine, remotes_to_rsyslog_cfg) +from cloudinit import util +from .. import helpers as t_help + + +class TestParseRemotesLine(t_help.TestCase): + def test_valid_port(self): + r = parse_remotes_line("foo:9") + self.assertEqual(9, r.port) + + def test_invalid_port(self): + with self.assertRaises(ValueError): + parse_remotes_line("*.* foo:abc") + + def test_valid_ipv6(self): + r = parse_remotes_line("*.* [::1]") + self.assertEqual("*.* [::1]", str(r)) + + def test_valid_ipv6_with_port(self): + r = parse_remotes_line("*.* [::1]:100") + self.assertEqual(r.port, 100) + self.assertEqual(r.addr, "::1") + self.assertEqual("*.* [::1]:100", str(r)) + + def test_invalid_multiple_colon(self): + with self.assertRaises(ValueError): + parse_remotes_line("*.* ::1:100") + + def test_name_in_string(self): + r = parse_remotes_line("syslog.host", name="foobar") + self.assertEqual("*.* syslog.host # foobar", str(r)) -- cgit v1.2.3 From 22cb92421234c31b783ed9df01439c734535ba01 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 27 Jul 2015 16:49:30 -0400 Subject: add rsyslog tests reasonable test of reworked rsyslog module --- cloudinit/config/cc_rsyslog.py | 27 ++--- .../unittests/test_handler/test_handler_rsyslog.py | 113 +++++++++++++++++++++ 2 files changed, 122 insertions(+), 18 deletions(-) create mode 100644 tests/unittests/test_handler/test_handler_rsyslog.py (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rsyslog.py b/cloudinit/config/cc_rsyslog.py index 7d5657bc..a07200c3 100644 --- a/cloudinit/config/cc_rsyslog.py +++ b/cloudinit/config/cc_rsyslog.py @@ -113,17 +113,18 @@ def load_config(cfg): # support converting the old top level format into new format mycfg = cfg.get('rsyslog', {}) - if isinstance(mycfg, list): - mycfg[KEYNAME_CONFIGS] = mycfg + if isinstance(cfg.get('rsyslog'), list): + mycfg = {KEYNAME_CONFIGS: cfg.get('rsyslog')} if KEYNAME_LEGACY_FILENAME in cfg: mycfg[KEYNAME_FILENAME] = cfg[KEYNAME_LEGACY_FILENAME] if KEYNAME_LEGACY_DIR in cfg: - mycfg[KEYNAME_DIR] = cfg[KEYNAME_DIR] + mycfg[KEYNAME_DIR] = cfg[KEYNAME_LEGACY_DIR] fillup = ( - (KEYNAME_DIR, DEF_DIR, six.text_type), - (KEYNAME_FILENAME, DEF_FILENAME, six.text_type) - (KEYNAME_RELOAD, DEF_RELOAD, (six.text_type, list))) + (KEYNAME_CONFIGS, [], list), + (KEYNAME_DIR, DEF_DIR, six.string_types), + (KEYNAME_FILENAME, DEF_FILENAME, six.string_types), + (KEYNAME_RELOAD, DEF_RELOAD, six.string_types + (list,))) for key, default, vtypes in fillup: if key not in mycfg or not isinstance(mycfg[key], vtypes): @@ -156,7 +157,8 @@ def apply_rsyslog_changes(configs, def_fname, cfg_dir): filename = os.path.join(cfg_dir, filename) # Truncate filename first time you see it - omode = "ab" if filename not in files: + omode = "ab" + if filename not in files: omode = "wb" files.append(filename) @@ -170,17 +172,6 @@ def apply_rsyslog_changes(configs, def_fname, cfg_dir): def handle(name, cfg, cloud, log, _args): - # rsyslog: - # configs: - # - "*.* @@192.158.1.1" - # - content: "*.* @@192.0.2.1:10514" - # - filename: 01-examplecom.conf - # content: | - # *.* @@syslogd.example.com - # config_dir: DEF_DIR - # config_filename: DEF_FILENAME - # service_reload: "auto" - if 'rsyslog' not in cfg: log.debug(("Skipping module named %s," " no 'rsyslog' key in configuration"), name) diff --git a/tests/unittests/test_handler/test_handler_rsyslog.py b/tests/unittests/test_handler/test_handler_rsyslog.py new file mode 100644 index 00000000..3501ff95 --- /dev/null +++ b/tests/unittests/test_handler/test_handler_rsyslog.py @@ -0,0 +1,113 @@ +import os +import shutil +import tempfile + +from cloudinit.config.cc_rsyslog import ( + load_config, DEF_FILENAME, DEF_DIR, DEF_RELOAD, apply_rsyslog_changes) +from cloudinit import util + +from .. import helpers as t_help + + +class TestLoadConfig(t_help.TestCase): + def setUp(self): + super(TestLoadConfig, self).setUp() + self.basecfg = { + 'config_filename': DEF_FILENAME, + 'config_dir': DEF_DIR, + 'service_reload_command': DEF_RELOAD, + 'configs': [], + } + + def test_legacy_full(self): + found = load_config({ + 'rsyslog': ['*.* @192.168.1.1'], + 'rsyslog_dir': "mydir", + 'rsyslog_filename': "myfilename"}) + expected = { + 'configs': ['*.* @192.168.1.1'], + 'config_dir': "mydir", + 'config_filename': 'myfilename', + 'service_reload_command': 'auto'} + self.assertEqual(found, expected) + + def test_legacy_defaults(self): + found = load_config({ + 'rsyslog': ['*.* @192.168.1.1']}) + self.basecfg.update({ + 'configs': ['*.* @192.168.1.1']}) + self.assertEqual(found, self.basecfg) + + def test_new_defaults(self): + self.assertEqual(load_config({}), self.basecfg) + + def test_new_configs(self): + cfgs = ['*.* myhost', '*.* my2host'] + self.basecfg.update({'configs': cfgs}) + self.assertEqual( + load_config({'rsyslog': {'configs': cfgs}}), + self.basecfg) + + +class TestApplyChanges(t_help.TestCase): + def setUp(self): + self.tmp = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp) + + def test_simple(self): + cfgline = "*.* foohost" + changed = apply_rsyslog_changes( + configs=[cfgline], def_fname="foo.cfg", cfg_dir=self.tmp) + + fname = os.path.join(self.tmp, "foo.cfg") + self.assertEqual([fname], changed) + self.assertEqual( + util.load_file(fname), cfgline + "\n") + + def test_multiple_files(self): + configs = [ + '*.* foohost', + {'content': 'abc', 'filename': 'my.cfg'}, + {'content': 'filefoo-content', + 'filename': os.path.join(self.tmp, 'mydir/mycfg')}, + ] + + changed = apply_rsyslog_changes( + configs=configs, def_fname="default.cfg", cfg_dir=self.tmp) + + expected = [ + (os.path.join(self.tmp, "default.cfg"), + "*.* foohost\n"), + (os.path.join(self.tmp, "my.cfg"), "abc\n"), + (os.path.join(self.tmp, "mydir/mycfg"), "filefoo-content\n"), + ] + self.assertEqual([f[0] for f in expected], changed) + actual = [] + for fname, _content in expected: + util.load_file(fname) + actual.append((fname, util.load_file(fname),)) + self.assertEqual(expected, actual) + + def test_repeat_def(self): + configs = ['*.* foohost', "*.warn otherhost"] + + changed = apply_rsyslog_changes( + configs=configs, def_fname="default.cfg", cfg_dir=self.tmp) + + fname = os.path.join(self.tmp, "default.cfg") + self.assertEqual([fname], changed) + + expected_content = '\n'.join([c for c in configs]) + '\n' + found_content = util.load_file(fname) + self.assertEqual(expected_content, found_content) + + def test_multiline_content(self): + configs = ['line1', 'line2\nline3\n'] + + changed = apply_rsyslog_changes( + configs=configs, def_fname="default.cfg", cfg_dir=self.tmp) + + fname = os.path.join(self.tmp, "default.cfg") + expected_content = '\n'.join([c for c in configs]) + '\n' + found_content = util.load_file(fname) + self.assertEqual(expected_content, found_content) -- cgit v1.2.3 From 6dd505fd02e0933d8770c8932a927940f6a0e025 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 09:27:26 -0400 Subject: add support for 'remotes' --- cloudinit/config/cc_rsyslog.py | 156 ++++++++++++++++++++- cloudinit/config/cc_syslog.py | 2 +- doc/examples/cloud-config-rsyslog.txt | 37 +++++ doc/examples/cloud-config-syslog.txt | 30 ---- .../unittests/test_handler/test_handler_rsyslog.py | 38 ++++- 5 files changed, 227 insertions(+), 36 deletions(-) create mode 100644 doc/examples/cloud-config-rsyslog.txt delete mode 100644 doc/examples/cloud-config-syslog.txt (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rsyslog.py b/cloudinit/config/cc_rsyslog.py index 9599e925..8c02e826 100644 --- a/cloudinit/config/cc_rsyslog.py +++ b/cloudinit/config/cc_rsyslog.py @@ -37,10 +37,33 @@ Under 'rsyslog' you can define: For simply logging to an existing remote syslog server, via udp: configs: ["*.* @192.168.1.1"] + - remotes: [default={}] + This is a dictionary of name / value pairs. + In comparison to 'config's, it is more focused in that it only supports + remote syslog configuration. It is not rsyslog specific, and could + convert to other syslog implementations. + + Each entry in remotes is a 'name' and a 'value'. + * name: an string identifying the entry. good practice would indicate + using a consistent and identifiable string for the producer. + For example, the MAAS service could use 'maas' as the key. + * value consists of the following parts: + * optional filter for log messages + default if not present: *.* + * optional leading '@' or '@@' (indicates udp or tcp respectively). + default if not present (udp): @ + This is rsyslog format for that. if not present, is '@'. + * ipv4 or ipv6 or hostname + ipv6 addresses must be in [::1] format. (@[fd00::1]:514) + * optional port + port defaults to 514 + - config_filename: [default=20-cloud-config.conf] this is the file name to use if none is provided in a config entry. + - config_dir: [default=/etc/rsyslog.d] this directory is used for filenames that are not absolute paths. + - service_reload_command: [default="auto"] this command is executed if files have been written and thus the syslog daemon needs to be told. @@ -61,9 +84,12 @@ Example config: configs: - "*.* @@192.158.1.1" - content: "*.* @@192.0.2.1:10514" - - filename: 01-examplecom.conf + filename: 01-example.conf - content: | *.* @@syslogd.example.com + remotes: + maas: "192.168.1.1" + juju: "10.0.4.1" config_dir: config_dir config_filename: config_filename service_reload_command: [your, syslog, restart, command] @@ -77,6 +103,7 @@ Example Legacy config: """ import os +import re import six from cloudinit import log as logging @@ -85,6 +112,7 @@ from cloudinit import util DEF_FILENAME = "20-cloud-config.conf" DEF_DIR = "/etc/rsyslog.d" DEF_RELOAD = "auto" +DEF_REMOTES = {} KEYNAME_CONFIGS = 'configs' KEYNAME_FILENAME = 'config_filename' @@ -92,9 +120,15 @@ KEYNAME_DIR = 'config_dir' KEYNAME_RELOAD = 'service_reload_command' KEYNAME_LEGACY_FILENAME = 'rsyslog_filename' KEYNAME_LEGACY_DIR = 'rsyslog_dir' +KEYNAME_REMOTES = 'remotes' LOG = logging.getLogger(__name__) +COMMENT_RE = re.compile(r'[ ]*[#]+[ ]*') +HOST_PORT_RE = re.compile( + r'^(?P[@]{0,2})' + '(([[](?P[^\]]*)[\]])|(?P[^:]*))' + '([:](?P[0-9]+))?$') def reload_syslog(command=DEF_RELOAD, systemd=False): service = 'rsyslog' @@ -124,7 +158,8 @@ def load_config(cfg): (KEYNAME_CONFIGS, [], list), (KEYNAME_DIR, DEF_DIR, six.string_types), (KEYNAME_FILENAME, DEF_FILENAME, six.string_types), - (KEYNAME_RELOAD, DEF_RELOAD, six.string_types + (list,))) + (KEYNAME_RELOAD, DEF_RELOAD, six.string_types + (list,)), + (KEYNAME_REMOTES, DEF_REMOTES, dict)) for key, default, vtypes in fillup: if key not in mycfg or not isinstance(mycfg[key], vtypes): @@ -171,6 +206,113 @@ def apply_rsyslog_changes(configs, def_fname, cfg_dir): return files +def parse_remotes_line(line, name=None): + try: + data, comment = COMMENT_RE.split(line) + comment = comment.strip() + except ValueError: + data, comment = (line, None) + + toks = data.strip().split() + match = None + if len(toks) == 1: + host_port = data + elif len(toks) == 2: + match, host_port = toks + else: + raise ValueError("line had multiple spaces: %s" % data) + + toks = HOST_PORT_RE.match(host_port) + + if not toks: + raise ValueError("Invalid host specification '%s'" % host_port) + + proto = toks.group('proto') + addr = toks.group('addr') or toks.group('bracket_addr') + port = toks.group('port') + + if addr.startswith("[") and not addr.endswith("]"): + raise ValueError("host spec had invalid brackets: %s" % addr) + + if comment and not name: + name = comment + + t = SyslogRemotesLine(name=name, match=match, proto=proto, + addr=addr, port=port) + t.validate() + return t + + +class SyslogRemotesLine(object): + def __init__(self, name=None, match=None, proto=None, addr=None, + port=None): + if not match: + match = "*.*" + self.name = name + self.match = match + if proto == "@": + proto = "udp" + elif proto == "@@": + proto = "tcp" + self.proto = proto + + self.addr = addr + if port: + self.port = int(port) + else: + self.port = None + + def validate(self): + if self.port: + try: + int(self.port) + except ValueError: + raise ValueError("port '%s' is not an integer" % self.port) + + if not self.addr: + raise ValueError("address is required") + + def __repr__(self): + return "[name=%s match=%s proto=%s address=%s port=%s]" % ( + self.name, self.match, self.proto, self.addr, self.port + ) + + def __str__(self): + buf = self.match + " " + if self.proto == "udp": + buf += " @" + elif self.proto == "tcp": + buf += " @@" + + if ":" in self.addr: + buf += "[" + self.addr + "]" + else: + buf += self.addr + + if self.port: + buf += ":%s" % self.port + + if self.name: + buf += " # %s" % self.name + return buf + + +def remotes_to_rsyslog_cfg(remotes, header=None, footer=None): + if not remotes: + return None + lines = [] + if header is not None: + lines.append(header) + for name, line in remotes.items(): + try: + lines.append(parse_remotes_line(line, name=name)) + except ValueError as e: + LOG.warn("failed loading remote %s: %s [%s]", name, line, e) + if footer is not None: + lines.append(footer) + return '\n'.join(str(lines)) + '\n' + + def handle(name, cfg, cloud, log, _args): if 'rsyslog' not in cfg: log.debug(("Skipping module named %s," @@ -178,6 +320,16 @@ def handle(name, cfg, cloud, log, _args): return mycfg = load_config(cfg) + configs = mycfg[KEYNAME_CONFIGS] + + if mycfg[KEYNAME_REMOTES]: + configs.append( + remotes_to_rsyslog_cfg( + mycfg[KEYNAME_REMOTES], + header="# begin remotes", + footer="# end remotes", + )) + if not mycfg['configs']: log.debug("Empty config rsyslog['configs'], nothing to do") return diff --git a/cloudinit/config/cc_syslog.py b/cloudinit/config/cc_syslog.py index 21a8e8a9..27793f8b 100644 --- a/cloudinit/config/cc_syslog.py +++ b/cloudinit/config/cc_syslog.py @@ -168,7 +168,7 @@ def handle(name, cfg, cloud, log, args): LOG.debug("syslog/remotes_file empty, doing nothing") return - remotes = mycfg.get('remotes_dict', {}) + remotes = mycfg.get('remotes', {}) if remotes and not isinstance(remotes, dict): LOG.warn("syslog/remotes: content is not a dictionary") return diff --git a/doc/examples/cloud-config-rsyslog.txt b/doc/examples/cloud-config-rsyslog.txt new file mode 100644 index 00000000..ff60e3a8 --- /dev/null +++ b/doc/examples/cloud-config-rsyslog.txt @@ -0,0 +1,37 @@ +## the rsyslog module allows you to configure the systems syslog. +## configuration of syslog is under the top level cloud-config +## entry 'rsyslog'. +## +## Example: +#cloud-config +rsyslog: + remotes: + # udp to host 'maas.mydomain' port 514 + maashost: maas.mydomain + # udp to ipv4 host on port 514 + maas: "@[10.5.1.56]:514" + # tcp to host ipv6 host on port 555 + maasipv6: "*.* @@[FE80::0202:B3FF:FE1E:8329]:555" + configs: + - "*.* @@192.158.1.1" + - content: "*.* @@192.0.2.1:10514" + filename: 01-example.conf + - content: | + *.* @@syslogd.example.com + config_dir: /etc/rsyslog.d + config_filename: 20-cloud-config.conf + service_reload_command: [your, syslog, reload, command] + +## Additionally the following legacy format is supported +## it is converted into the format above before use. +## rsyslog_filename -> rsyslog/config_filename +## rsyslog_dir -> rsyslog/config_dir +## rsyslog -> rsyslog/configs +# rsyslog: +# - "*.* @@192.158.1.1" +# - content: "*.* @@192.0.2.1:10514" +# filename: 01-example.conf +# - content: | +# *.* @@syslogd.example.com +# rsyslog_filename: 20-cloud-config.conf +# rsyslog_dir: /etc/rsyslog.d diff --git a/doc/examples/cloud-config-syslog.txt b/doc/examples/cloud-config-syslog.txt deleted file mode 100644 index 9ec5e120..00000000 --- a/doc/examples/cloud-config-syslog.txt +++ /dev/null @@ -1,30 +0,0 @@ -## syslog module allows you to configure the systems syslog. -## configuration of syslog is under the top level cloud-config -## entry 'syslog'. -## -## "remotes" -## remotes is a dictionary. items are of 'name: remote_info' -## name is simply a name (example 'maas'). It has no importance other than -## for cloud-init merging configs -## -## remote_info is of the format -## * optional filter for log messages -## default if not present: *.* -## * optional leading '@' or '@@' (indicates udp or tcp). -## default if not present (udp): @ -## This is rsyslog format for that. if not present, is '@' which is udp -## * ipv4 or ipv6 or hostname -## ipv6 addresses must be encoded in [::1] format. example: @[fd00::1]:514 -## * optional port -## port defaults to 514 -## -## Example: -#cloud-config -syslog: - remotes: - # udp to host 'maas.mydomain' port 514 - maashost: maas.mydomain - # udp to ipv4 host on port 514 - maas: "@[10.5.1.56]:514" - # tcp to host ipv6 host on port 555 - maasipv6: "*.* @@[FE80::0202:B3FF:FE1E:8329]:555" diff --git a/tests/unittests/test_handler/test_handler_rsyslog.py b/tests/unittests/test_handler/test_handler_rsyslog.py index 3501ff95..0bace685 100644 --- a/tests/unittests/test_handler/test_handler_rsyslog.py +++ b/tests/unittests/test_handler/test_handler_rsyslog.py @@ -3,7 +3,8 @@ import shutil import tempfile from cloudinit.config.cc_rsyslog import ( - load_config, DEF_FILENAME, DEF_DIR, DEF_RELOAD, apply_rsyslog_changes) + apply_rsyslog_changes, DEF_DIR, DEF_FILENAME, DEF_RELOAD, load_config, + parse_remotes_line) from cloudinit import util from .. import helpers as t_help @@ -17,6 +18,7 @@ class TestLoadConfig(t_help.TestCase): 'config_dir': DEF_DIR, 'service_reload_command': DEF_RELOAD, 'configs': [], + 'remotes': {}, } def test_legacy_full(self): @@ -24,12 +26,14 @@ class TestLoadConfig(t_help.TestCase): 'rsyslog': ['*.* @192.168.1.1'], 'rsyslog_dir': "mydir", 'rsyslog_filename': "myfilename"}) - expected = { + self.basecfg.update({ 'configs': ['*.* @192.168.1.1'], 'config_dir': "mydir", 'config_filename': 'myfilename', 'service_reload_command': 'auto'} - self.assertEqual(found, expected) + ) + + self.assertEqual(found, self.basecfg) def test_legacy_defaults(self): found = load_config({ @@ -111,3 +115,31 @@ class TestApplyChanges(t_help.TestCase): expected_content = '\n'.join([c for c in configs]) + '\n' found_content = util.load_file(fname) self.assertEqual(expected_content, found_content) + + +class TestParseRemotesLine(t_help.TestCase): + def test_valid_port(self): + r = parse_remotes_line("foo:9") + self.assertEqual(9, r.port) + + def test_invalid_port(self): + with self.assertRaises(ValueError): + parse_remotes_line("*.* foo:abc") + + def test_valid_ipv6(self): + r = parse_remotes_line("*.* [::1]") + self.assertEqual("*.* [::1]", str(r)) + + def test_valid_ipv6_with_port(self): + r = parse_remotes_line("*.* [::1]:100") + self.assertEqual(r.port, 100) + self.assertEqual(r.addr, "::1") + self.assertEqual("*.* [::1]:100", str(r)) + + def test_invalid_multiple_colon(self): + with self.assertRaises(ValueError): + parse_remotes_line("*.* ::1:100") + + def test_name_in_string(self): + r = parse_remotes_line("syslog.host", name="foobar") + self.assertEqual("*.* syslog.host # foobar", str(r)) -- cgit v1.2.3 From f61a62434b36ab873b2b82a5ba69eda826755bfc Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 10:12:02 -0400 Subject: fix bug in remotes_to_rsyslog_cfg, add test --- cloudinit/config/cc_rsyslog.py | 4 +-- .../unittests/test_handler/test_handler_rsyslog.py | 32 ++++++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rsyslog.py b/cloudinit/config/cc_rsyslog.py index 8c02e826..915ab420 100644 --- a/cloudinit/config/cc_rsyslog.py +++ b/cloudinit/config/cc_rsyslog.py @@ -305,12 +305,12 @@ def remotes_to_rsyslog_cfg(remotes, header=None, footer=None): lines.append(header) for name, line in remotes.items(): try: - lines.append(parse_remotes_line(line, name=name)) + lines.append(str(parse_remotes_line(line, name=name))) except ValueError as e: LOG.warn("failed loading remote %s: %s [%s]", name, line, e) if footer is not None: lines.append(footer) - return '\n'.join(str(lines)) + '\n' + return '\n'.join(lines) + "\n" def handle(name, cfg, cloud, log, _args): diff --git a/tests/unittests/test_handler/test_handler_rsyslog.py b/tests/unittests/test_handler/test_handler_rsyslog.py index 0bace685..292559c5 100644 --- a/tests/unittests/test_handler/test_handler_rsyslog.py +++ b/tests/unittests/test_handler/test_handler_rsyslog.py @@ -4,7 +4,7 @@ import tempfile from cloudinit.config.cc_rsyslog import ( apply_rsyslog_changes, DEF_DIR, DEF_FILENAME, DEF_RELOAD, load_config, - parse_remotes_line) + parse_remotes_line, remotes_to_rsyslog_cfg) from cloudinit import util from .. import helpers as t_help @@ -80,10 +80,10 @@ class TestApplyChanges(t_help.TestCase): configs=configs, def_fname="default.cfg", cfg_dir=self.tmp) expected = [ - (os.path.join(self.tmp, "default.cfg"), - "*.* foohost\n"), - (os.path.join(self.tmp, "my.cfg"), "abc\n"), - (os.path.join(self.tmp, "mydir/mycfg"), "filefoo-content\n"), + (os.path.join(self.tmp, "default.cfg"), + "*.* foohost\n"), + (os.path.join(self.tmp, "my.cfg"), "abc\n"), + (os.path.join(self.tmp, "mydir/mycfg"), "filefoo-content\n"), ] self.assertEqual([f[0] for f in expected], changed) actual = [] @@ -108,7 +108,7 @@ class TestApplyChanges(t_help.TestCase): def test_multiline_content(self): configs = ['line1', 'line2\nline3\n'] - changed = apply_rsyslog_changes( + apply_rsyslog_changes( configs=configs, def_fname="default.cfg", cfg_dir=self.tmp) fname = os.path.join(self.tmp, "default.cfg") @@ -143,3 +143,23 @@ class TestParseRemotesLine(t_help.TestCase): def test_name_in_string(self): r = parse_remotes_line("syslog.host", name="foobar") self.assertEqual("*.* syslog.host # foobar", str(r)) + + +class TestRemotesToSyslog(t_help.TestCase): + def test_simple(self): + # str rendered line must appear in remotes_to_ryslog_cfg return + mycfg = "*.* myhost" + myline = str(parse_remotes_line(mycfg, name="myname")) + r = remotes_to_rsyslog_cfg({'myname': mycfg}) + lines = r.splitlines() + self.assertEqual(1, len(lines)) + self.assertTrue(myline in r.splitlines()) + + def test_header_footer(self): + header = "#foo head" + footer = "#foo foot" + r = remotes_to_rsyslog_cfg( + {'myname': "*.* myhost"}, header=header, footer=footer) + lines = r.splitlines() + self.assertTrue(header, lines[0]) + self.assertTrue(footer, lines[-1]) -- cgit v1.2.3 From c6e7fb1752a93ed534080adf0588e4c7cdd99071 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 10:34:31 -0400 Subject: add trailing newline only if necessary --- cloudinit/config/cc_rsyslog.py | 9 +++++---- tests/unittests/test_handler/test_handler_rsyslog.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rsyslog.py b/cloudinit/config/cc_rsyslog.py index 915ab420..2bb00728 100644 --- a/cloudinit/config/cc_rsyslog.py +++ b/cloudinit/config/cc_rsyslog.py @@ -188,8 +188,7 @@ def apply_rsyslog_changes(configs, def_fname, cfg_dir): LOG.warn("Entry %s has an empty filename", cur_pos + 1) continue - if not filename.startswith("/"): - filename = os.path.join(cfg_dir, filename) + filename = os.path.join(cfg_dir, filename) # Truncate filename first time you see it omode = "ab" @@ -198,8 +197,10 @@ def apply_rsyslog_changes(configs, def_fname, cfg_dir): files.append(filename) try: - contents = "%s\n" % (content) - util.write_file(filename, contents, omode=omode) + endl = "" + if not content.endswith("\n"): + endl = "\n" + util.write_file(filename, content + endl, omode=omode) except Exception: util.logexc(LOG, "Failed to write to %s", filename) diff --git a/tests/unittests/test_handler/test_handler_rsyslog.py b/tests/unittests/test_handler/test_handler_rsyslog.py index 292559c5..e7666615 100644 --- a/tests/unittests/test_handler/test_handler_rsyslog.py +++ b/tests/unittests/test_handler/test_handler_rsyslog.py @@ -112,7 +112,7 @@ class TestApplyChanges(t_help.TestCase): configs=configs, def_fname="default.cfg", cfg_dir=self.tmp) fname = os.path.join(self.tmp, "default.cfg") - expected_content = '\n'.join([c for c in configs]) + '\n' + expected_content = '\n'.join([c for c in configs]) found_content = util.load_file(fname) self.assertEqual(expected_content, found_content) -- cgit v1.2.3 From 0a581732a40ff814b1fc0dace9f519b7a5c779e6 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 10:48:32 -0400 Subject: must declare proto of '@' --- cloudinit/config/cc_rsyslog.py | 6 ++++-- tests/unittests/test_handler/test_handler_rsyslog.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rsyslog.py b/cloudinit/config/cc_rsyslog.py index 2bb00728..5ecf1629 100644 --- a/cloudinit/config/cc_rsyslog.py +++ b/cloudinit/config/cc_rsyslog.py @@ -251,6 +251,8 @@ class SyslogRemotesLine(object): match = "*.*" self.name = name self.match = match + if not proto: + proto = "udp" if proto == "@": proto = "udp" elif proto == "@@": @@ -281,9 +283,9 @@ class SyslogRemotesLine(object): def __str__(self): buf = self.match + " " if self.proto == "udp": - buf += " @" + buf += "@" elif self.proto == "tcp": - buf += " @@" + buf += "@@" if ":" in self.addr: buf += "[" + self.addr + "]" diff --git a/tests/unittests/test_handler/test_handler_rsyslog.py b/tests/unittests/test_handler/test_handler_rsyslog.py index e7666615..7bfa65a9 100644 --- a/tests/unittests/test_handler/test_handler_rsyslog.py +++ b/tests/unittests/test_handler/test_handler_rsyslog.py @@ -128,13 +128,13 @@ class TestParseRemotesLine(t_help.TestCase): def test_valid_ipv6(self): r = parse_remotes_line("*.* [::1]") - self.assertEqual("*.* [::1]", str(r)) + self.assertEqual("*.* @[::1]", str(r)) def test_valid_ipv6_with_port(self): r = parse_remotes_line("*.* [::1]:100") self.assertEqual(r.port, 100) self.assertEqual(r.addr, "::1") - self.assertEqual("*.* [::1]:100", str(r)) + self.assertEqual("*.* @[::1]:100", str(r)) def test_invalid_multiple_colon(self): with self.assertRaises(ValueError): @@ -142,7 +142,7 @@ class TestParseRemotesLine(t_help.TestCase): def test_name_in_string(self): r = parse_remotes_line("syslog.host", name="foobar") - self.assertEqual("*.* syslog.host # foobar", str(r)) + self.assertEqual("*.* @syslog.host # foobar", str(r)) class TestRemotesToSyslog(t_help.TestCase): -- cgit v1.2.3 From d5f93dbd908c349548554cb69ca3afd05077cf57 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 10:49:48 -0400 Subject: remove 'syslog' module (its been moved to rsyslog) --- cloudinit/config/cc_syslog.py | 183 --------------------- .../unittests/test_handler/test_handler_syslog.py | 32 ---- 2 files changed, 215 deletions(-) delete mode 100644 cloudinit/config/cc_syslog.py delete mode 100644 tests/unittests/test_handler/test_handler_syslog.py (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_syslog.py b/cloudinit/config/cc_syslog.py deleted file mode 100644 index 27793f8b..00000000 --- a/cloudinit/config/cc_syslog.py +++ /dev/null @@ -1,183 +0,0 @@ -# vi: ts=4 expandtab -# -# Copyright (C) 2015 Canonical Ltd. -# -# Author: Scott Moser -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License version 3, as -# published by the Free Software Foundation. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -from cloudinit import log as logging -from cloudinit import util -from cloudinit.settings import PER_INSTANCE - -import re - -LOG = logging.getLogger(__name__) - -frequency = PER_INSTANCE - -BUILTIN_CFG = { - 'remotes_file': '/etc/rsyslog.d/20-cloudinit-remotes.conf', - 'remotes': {}, - 'service_name': 'rsyslog', -} - -COMMENT_RE = re.compile(r'[ ]*[#]+[ ]*') -HOST_PORT_RE = re.compile( - r'^(?P[@]{0,2})' - '(([[](?P[^\]]*)[\]])|(?P[^:]*))' - '([:](?P[0-9]+))?$') - - -def parse_remotes_line(line, name=None): - try: - data, comment = COMMENT_RE.split(line) - comment = comment.strip() - except ValueError: - data, comment = (line, None) - - toks = data.strip().split() - match = None - if len(toks) == 1: - host_port = data - elif len(toks) == 2: - match, host_port = toks - else: - raise ValueError("line had multiple spaces: %s" % data) - - toks = HOST_PORT_RE.match(host_port) - - if not toks: - raise ValueError("Invalid host specification '%s'" % host_port) - - proto = toks.group('proto') - addr = toks.group('addr') or toks.group('bracket_addr') - port = toks.group('port') - print("host_port: %s" % addr) - print("port: %s" % port) - - if addr.startswith("[") and not addr.endswith("]"): - raise ValueError("host spec had invalid brackets: %s" % addr) - - if comment and not name: - name = comment - - t = SyslogRemotesLine(name=name, match=match, proto=proto, - addr=addr, port=port) - t.validate() - return t - - -class SyslogRemotesLine(object): - def __init__(self, name=None, match=None, proto=None, addr=None, - port=None): - if not match: - match = "*.*" - self.name = name - self.match = match - if proto == "@": - proto = "udp" - elif proto == "@@": - proto = "tcp" - self.proto = proto - - self.addr = addr - if port: - self.port = int(port) - else: - self.port = None - - def validate(self): - if self.port: - try: - int(self.port) - except ValueError: - raise ValueError("port '%s' is not an integer" % self.port) - - if not self.addr: - raise ValueError("address is required") - - def __repr__(self): - return "[name=%s match=%s proto=%s address=%s port=%s]" % ( - self.name, self.match, self.proto, self.addr, self.port - ) - - def __str__(self): - buf = self.match + " " - if self.proto == "udp": - buf += " @" - elif self.proto == "tcp": - buf += " @@" - - if ":" in self.addr: - buf += "[" + self.addr + "]" - else: - buf += self.addr - - if self.port: - buf += ":%s" % self.port - - if self.name: - buf += " # %s" % self.name - return buf - - -def remotes_to_rsyslog_cfg(remotes, header=None): - if not remotes: - return None - lines = [] - if header is not None: - lines.append(header) - for name, line in remotes.items(): - try: - lines.append(parse_remotes_line(line, name=name)) - except ValueError as e: - LOG.warn("failed loading remote %s: %s [%s]", name, line, e) - return '\n'.join(str(lines)) + '\n' - - -def reload_syslog(systemd, service='rsyslog'): - if systemd: - cmd = ['systemctl', 'reload-or-try-restart', service] - else: - cmd = ['service', service, 'reload'] - try: - util.subp(cmd, capture=True) - except util.ProcessExecutionError as e: - LOG.warn("Failed to reload syslog using '%s': %s", ' '.join(cmd), e) - - -def handle(name, cfg, cloud, log, args): - cfgin = cfg.get('syslog') - if not cfgin: - cfgin = {} - mycfg = util.mergemanydict([cfgin, BUILTIN_CFG]) - - remotes_file = mycfg.get('remotes_file') - if util.is_false(remotes_file): - LOG.debug("syslog/remotes_file empty, doing nothing") - return - - remotes = mycfg.get('remotes', {}) - if remotes and not isinstance(remotes, dict): - LOG.warn("syslog/remotes: content is not a dictionary") - return - - config_data = remotes_to_rsyslog_cfg( - remotes, header="#cloud-init syslog module") - - util.write_file(remotes_file, config_data) - - reload_syslog( - systemd=cloud.distro.uses_systemd(), - service=mycfg.get('service_name')) diff --git a/tests/unittests/test_handler/test_handler_syslog.py b/tests/unittests/test_handler/test_handler_syslog.py deleted file mode 100644 index bbfd521e..00000000 --- a/tests/unittests/test_handler/test_handler_syslog.py +++ /dev/null @@ -1,32 +0,0 @@ -from cloudinit.config.cc_syslog import ( - parse_remotes_line, SyslogRemotesLine, remotes_to_rsyslog_cfg) -from cloudinit import util -from .. import helpers as t_help - - -class TestParseRemotesLine(t_help.TestCase): - def test_valid_port(self): - r = parse_remotes_line("foo:9") - self.assertEqual(9, r.port) - - def test_invalid_port(self): - with self.assertRaises(ValueError): - parse_remotes_line("*.* foo:abc") - - def test_valid_ipv6(self): - r = parse_remotes_line("*.* [::1]") - self.assertEqual("*.* [::1]", str(r)) - - def test_valid_ipv6_with_port(self): - r = parse_remotes_line("*.* [::1]:100") - self.assertEqual(r.port, 100) - self.assertEqual(r.addr, "::1") - self.assertEqual("*.* [::1]:100", str(r)) - - def test_invalid_multiple_colon(self): - with self.assertRaises(ValueError): - parse_remotes_line("*.* ::1:100") - - def test_name_in_string(self): - r = parse_remotes_line("syslog.host", name="foobar") - self.assertEqual("*.* syslog.host # foobar", str(r)) -- cgit v1.2.3 From 55472eb02eaa5b88676a96e006f6838020f8ffe3 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 11:44:32 -0400 Subject: rsyslog: skip empty or None in remotes format This allows user to specify the following to overwrite a previously declared entry without warnings. rsyslog: {'remotes': {'foo': None}} --- cloudinit/config/cc_rsyslog.py | 2 ++ tests/unittests/test_handler/test_handler_rsyslog.py | 9 +++++++++ 2 files changed, 11 insertions(+) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rsyslog.py b/cloudinit/config/cc_rsyslog.py index 5ecf1629..a0132d28 100644 --- a/cloudinit/config/cc_rsyslog.py +++ b/cloudinit/config/cc_rsyslog.py @@ -307,6 +307,8 @@ def remotes_to_rsyslog_cfg(remotes, header=None, footer=None): if header is not None: lines.append(header) for name, line in remotes.items(): + if not line: + continue try: lines.append(str(parse_remotes_line(line, name=name))) except ValueError as e: diff --git a/tests/unittests/test_handler/test_handler_rsyslog.py b/tests/unittests/test_handler/test_handler_rsyslog.py index 7bfa65a9..b932165c 100644 --- a/tests/unittests/test_handler/test_handler_rsyslog.py +++ b/tests/unittests/test_handler/test_handler_rsyslog.py @@ -163,3 +163,12 @@ class TestRemotesToSyslog(t_help.TestCase): lines = r.splitlines() self.assertTrue(header, lines[0]) self.assertTrue(footer, lines[-1]) + + def test_with_empty_or_null(self): + mycfg = "*.* myhost" + myline = str(parse_remotes_line(mycfg, name="myname")) + r = remotes_to_rsyslog_cfg( + {'myname': mycfg, 'removed': None, 'removed2': ""}) + lines = r.splitlines() + self.assertEqual(1, len(lines)) + self.assertTrue(myline in r.splitlines()) -- cgit v1.2.3 From c33b3becebfa7bf3f6e2ee67ea7bc3def6feeb8c Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 28 Jul 2015 16:15:10 -0400 Subject: pull from 2.0 trunk @ a433358bbcf4e8a771b80cae34468409ed5a811d --- cloudinit/registry.py | 23 +++++ cloudinit/reporting.py | 122 ++++++++++++++++++++++++ tests/unittests/test_registry.py | 28 ++++++ tests/unittests/test_reporting.py | 192 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 365 insertions(+) create mode 100644 cloudinit/registry.py create mode 100644 cloudinit/reporting.py create mode 100644 tests/unittests/test_registry.py create mode 100644 tests/unittests/test_reporting.py (limited to 'tests/unittests') diff --git a/cloudinit/registry.py b/cloudinit/registry.py new file mode 100644 index 00000000..46cf0585 --- /dev/null +++ b/cloudinit/registry.py @@ -0,0 +1,23 @@ +import copy + + +class DictRegistry(object): + """A simple registry for a mapping of objects.""" + + def __init__(self): + self._items = {} + + def register_item(self, key, item): + """Add item to the registry.""" + if key in self._items: + raise ValueError( + 'Item already registered with key {0}'.format(key)) + self._items[key] = item + + @property + def registered_items(self): + """All the items that have been registered. + + This cannot be used to modify the contents of the registry. + """ + return copy.copy(self._items) diff --git a/cloudinit/reporting.py b/cloudinit/reporting.py new file mode 100644 index 00000000..d2dd4fec --- /dev/null +++ b/cloudinit/reporting.py @@ -0,0 +1,122 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab +""" +cloud-init reporting framework + +The reporting framework is intended to allow all parts of cloud-init to +report events in a structured manner. +""" + +import abc +import logging + +from cloudinit.registry import DictRegistry + + +FINISH_EVENT_TYPE = 'finish' +START_EVENT_TYPE = 'start' + +DEFAULT_CONFIG = { + 'logging': {'type': 'log'}, +} + + +instantiated_handler_registry = DictRegistry() +available_handlers = DictRegistry() + + +class ReportingEvent(object): + """Encapsulation of event formatting.""" + + def __init__(self, event_type, name, description): + self.event_type = event_type + self.name = name + self.description = description + + def as_string(self): + """The event represented as a string.""" + return '{0}: {1}: {2}'.format( + self.event_type, self.name, self.description) + + +class FinishReportingEvent(ReportingEvent): + + def __init__(self, name, description, successful=None): + super(FinishReportingEvent, self).__init__( + FINISH_EVENT_TYPE, name, description) + self.successful = successful + + def as_string(self): + if self.successful is None: + return super(FinishReportingEvent, self).as_string() + success_string = 'success' if self.successful else 'fail' + return '{0}: {1}: {2}: {3}'.format( + self.event_type, self.name, success_string, self.description) + + +class ReportingHandler(object): + + @abc.abstractmethod + def publish_event(self, event): + raise NotImplementedError + + +class LogHandler(ReportingHandler): + """Publishes events to the cloud-init log at the ``INFO`` log level.""" + + def publish_event(self, event): + """Publish an event to the ``INFO`` log level.""" + logger = logging.getLogger( + '.'.join([__name__, event.event_type, event.name])) + logger.info(event.as_string()) + + +def add_configuration(config): + for handler_name, handler_config in config.items(): + handler_config = handler_config.copy() + cls = available_handlers.registered_items[handler_config.pop('type')] + instance = cls(**handler_config) + instantiated_handler_registry.register_item(handler_name, instance) + + +def report_event(event): + """Report an event to all registered event handlers. + + This should generally be called via one of the other functions in + the reporting module. + + :param event_type: + The type of the event; this should be a constant from the + reporting module. + """ + for _, handler in instantiated_handler_registry.registered_items.items(): + handler.publish_event(event) + + +def report_finish_event(event_name, event_description, successful=None): + """Report a "finish" event. + + See :py:func:`.report_event` for parameter details. + """ + event = FinishReportingEvent(event_name, event_description, successful) + return report_event(event) + + +def report_start_event(event_name, event_description): + """Report a "start" event. + + :param event_name: + The name of the event; this should be a topic which events would + share (e.g. it will be the same for start and finish events). + + :param event_description: + A human-readable description of the event that has occurred. + """ + event = ReportingEvent(START_EVENT_TYPE, event_name, event_description) + return report_event(event) + + +available_handlers.register_item('log', LogHandler) +add_configuration(DEFAULT_CONFIG) diff --git a/tests/unittests/test_registry.py b/tests/unittests/test_registry.py new file mode 100644 index 00000000..bcf01475 --- /dev/null +++ b/tests/unittests/test_registry.py @@ -0,0 +1,28 @@ +from cloudinit.registry import DictRegistry + +from .helpers import (mock, TestCase) + + +class TestDictRegistry(TestCase): + + def test_added_item_included_in_output(self): + registry = DictRegistry() + item_key, item_to_register = 'test_key', mock.Mock() + registry.register_item(item_key, item_to_register) + self.assertEqual({item_key: item_to_register}, + registry.registered_items) + + def test_registry_starts_out_empty(self): + self.assertEqual({}, DictRegistry().registered_items) + + def test_modifying_registered_items_isnt_exposed_to_other_callers(self): + registry = DictRegistry() + registry.registered_items['test_item'] = mock.Mock() + self.assertEqual({}, registry.registered_items) + + def test_keys_cannot_be_replaced(self): + registry = DictRegistry() + item_key = 'test_key' + registry.register_item(item_key, mock.Mock()) + self.assertRaises(ValueError, + registry.register_item, item_key, mock.Mock()) diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py new file mode 100644 index 00000000..f4011a79 --- /dev/null +++ b/tests/unittests/test_reporting.py @@ -0,0 +1,192 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab + +from cloudinit import reporting + +from .helpers import (mock, TestCase) + + +def _fake_registry(): + return mock.Mock(registered_items={'a': mock.MagicMock(), + 'b': mock.MagicMock()}) + + +class TestReportStartEvent(TestCase): + + @mock.patch('cloudinit.reporting.instantiated_handler_registry', + new_callable=_fake_registry) + def test_report_start_event_passes_something_with_as_string_to_handlers( + self, instantiated_handler_registry): + event_name, event_description = 'my_test_event', 'my description' + reporting.report_start_event(event_name, event_description) + expected_string_representation = ': '.join( + ['start', event_name, event_description]) + for _, handler in ( + instantiated_handler_registry.registered_items.items()): + self.assertEqual(1, handler.publish_event.call_count) + event = handler.publish_event.call_args[0][0] + self.assertEqual(expected_string_representation, event.as_string()) + + +class TestReportFinishEvent(TestCase): + + def _report_finish_event(self, successful=None): + event_name, event_description = 'my_test_event', 'my description' + reporting.report_finish_event( + event_name, event_description, successful=successful) + return event_name, event_description + + def assertHandlersPassedObjectWithAsString( + self, handlers, expected_as_string): + for _, handler in handlers.items(): + self.assertEqual(1, handler.publish_event.call_count) + event = handler.publish_event.call_args[0][0] + self.assertEqual(expected_as_string, event.as_string()) + + @mock.patch('cloudinit.reporting.instantiated_handler_registry', + new_callable=_fake_registry) + def test_report_finish_event_passes_something_with_as_string_to_handlers( + self, instantiated_handler_registry): + event_name, event_description = self._report_finish_event() + expected_string_representation = ': '.join( + ['finish', event_name, event_description]) + self.assertHandlersPassedObjectWithAsString( + instantiated_handler_registry.registered_items, + expected_string_representation) + + @mock.patch('cloudinit.reporting.instantiated_handler_registry', + new_callable=_fake_registry) + def test_reporting_successful_finish_has_sensible_string_repr( + self, instantiated_handler_registry): + event_name, event_description = self._report_finish_event( + successful=True) + expected_string_representation = ': '.join( + ['finish', event_name, 'success', event_description]) + self.assertHandlersPassedObjectWithAsString( + instantiated_handler_registry.registered_items, + expected_string_representation) + + @mock.patch('cloudinit.reporting.instantiated_handler_registry', + new_callable=_fake_registry) + def test_reporting_unsuccessful_finish_has_sensible_string_repr( + self, instantiated_handler_registry): + event_name, event_description = self._report_finish_event( + successful=False) + expected_string_representation = ': '.join( + ['finish', event_name, 'fail', event_description]) + self.assertHandlersPassedObjectWithAsString( + instantiated_handler_registry.registered_items, + expected_string_representation) + + +class TestReportingEvent(TestCase): + + def test_as_string(self): + event_type, name, description = 'test_type', 'test_name', 'test_desc' + event = reporting.ReportingEvent(event_type, name, description) + expected_string_representation = ': '.join( + [event_type, name, description]) + self.assertEqual(expected_string_representation, event.as_string()) + + +class TestReportingHandler(TestCase): + + def test_no_default_publish_event_implementation(self): + self.assertRaises(NotImplementedError, + reporting.ReportingHandler().publish_event, None) + + +class TestLogHandler(TestCase): + + @mock.patch.object(reporting.logging, 'getLogger') + def test_appropriate_logger_used(self, getLogger): + event_type, event_name = 'test_type', 'test_name' + event = reporting.ReportingEvent(event_type, event_name, 'description') + reporting.LogHandler().publish_event(event) + self.assertEqual( + [mock.call( + 'cloudinit.reporting.{0}.{1}'.format(event_type, event_name))], + getLogger.call_args_list) + + @mock.patch.object(reporting.logging, 'getLogger') + def test_single_log_message_at_info_published(self, getLogger): + event = reporting.ReportingEvent('type', 'name', 'description') + reporting.LogHandler().publish_event(event) + self.assertEqual(1, getLogger.return_value.info.call_count) + + @mock.patch.object(reporting.logging, 'getLogger') + def test_log_message_uses_event_as_string(self, getLogger): + event = reporting.ReportingEvent('type', 'name', 'description') + reporting.LogHandler().publish_event(event) + self.assertIn(event.as_string(), + getLogger.return_value.info.call_args[0][0]) + + +class TestDefaultRegisteredHandler(TestCase): + + def test_log_handler_registered_by_default(self): + registered_items = ( + reporting.instantiated_handler_registry.registered_items) + for _, item in registered_items.items(): + if isinstance(item, reporting.LogHandler): + break + else: + self.fail('No reporting LogHandler registered by default.') + + +class TestReportingConfiguration(TestCase): + + @mock.patch.object(reporting, 'instantiated_handler_registry') + def test_empty_configuration_doesnt_add_handlers( + self, instantiated_handler_registry): + reporting.add_configuration({}) + self.assertEqual( + 0, instantiated_handler_registry.register_item.call_count) + + @mock.patch.object( + reporting, 'instantiated_handler_registry', reporting.DictRegistry()) + @mock.patch.object(reporting, 'available_handlers') + def test_looks_up_handler_by_type_and_adds_it(self, available_handlers): + handler_type_name = 'test_handler' + handler_cls = mock.Mock() + available_handlers.registered_items = {handler_type_name: handler_cls} + handler_name = 'my_test_handler' + reporting.add_configuration( + {handler_name: {'type': handler_type_name}}) + self.assertEqual( + {handler_name: handler_cls.return_value}, + reporting.instantiated_handler_registry.registered_items) + + @mock.patch.object( + reporting, 'instantiated_handler_registry', reporting.DictRegistry()) + @mock.patch.object(reporting, 'available_handlers') + def test_uses_non_type_parts_of_config_dict_as_kwargs( + self, available_handlers): + handler_type_name = 'test_handler' + handler_cls = mock.Mock() + available_handlers.registered_items = {handler_type_name: handler_cls} + extra_kwargs = {'foo': 'bar', 'bar': 'baz'} + handler_config = extra_kwargs.copy() + handler_config.update({'type': handler_type_name}) + handler_name = 'my_test_handler' + reporting.add_configuration({handler_name: handler_config}) + self.assertEqual( + handler_cls.return_value, + reporting.instantiated_handler_registry.registered_items[ + handler_name]) + self.assertEqual([mock.call(**extra_kwargs)], + handler_cls.call_args_list) + + @mock.patch.object( + reporting, 'instantiated_handler_registry', reporting.DictRegistry()) + @mock.patch.object(reporting, 'available_handlers') + def test_handler_config_not_modified(self, available_handlers): + handler_type_name = 'test_handler' + handler_cls = mock.Mock() + available_handlers.registered_items = {handler_type_name: handler_cls} + handler_config = {'type': handler_type_name, 'foo': 'bar'} + expected_handler_config = handler_config.copy() + reporting.add_configuration({'my_test_handler': handler_config}) + self.assertEqual(expected_handler_config, handler_config) -- cgit v1.2.3 From b5574a9925b29417a1b351e7b38c54bc7d144dba Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 30 Jul 2015 18:06:01 -0400 Subject: tests pass --- bin/cloud-init | 28 ++++++++++-- cloudinit/reporting.py | 91 +++++++++++++++++++++++++++++++++++---- cloudinit/sources/__init__.py | 16 ++++--- cloudinit/stages.py | 10 ++++- tests/unittests/test_reporting.py | 14 +++--- 5 files changed, 134 insertions(+), 25 deletions(-) (limited to 'tests/unittests') diff --git a/bin/cloud-init b/bin/cloud-init index 1d3e7ee3..7f21e49f 100755 --- a/bin/cloud-init +++ b/bin/cloud-init @@ -46,6 +46,7 @@ from cloudinit import sources from cloudinit import stages from cloudinit import templater from cloudinit import util +from cloudinit import reporting from cloudinit import version from cloudinit.settings import (PER_INSTANCE, PER_ALWAYS, PER_ONCE, @@ -313,7 +314,7 @@ def main_modules(action_name, args): # 5. Run the modules for the given stage name # 6. Done! w_msg = welcome_format("%s:%s" % (action_name, name)) - init = stages.Init(ds_deps=[]) + init = stages.Init(ds_deps=[], reporter=args.reporter) # Stage 1 init.read_cfg(extract_fns(args)) # Stage 2 @@ -549,6 +550,8 @@ def main(): ' found (use at your own risk)'), dest='force', default=False) + + parser.set_defaults(reporter=None) subparsers = parser.add_subparsers() # Each action and its sub-options (if any) @@ -595,6 +598,9 @@ def main(): help=("frequency of the module"), required=False, choices=list(FREQ_SHORT_NAMES.keys())) + parser_single.add_argument("--report", action="store_true", + help="enable reporting", + required=False) parser_single.add_argument("module_args", nargs="*", metavar='argument', help=('any additional arguments to' @@ -617,8 +623,24 @@ def main(): if name in ("modules", "init"): functor = status_wrapper - return util.log_time(logfunc=LOG.debug, msg="cloud-init mode '%s'" % name, - get_uptime=True, func=functor, args=(name, args)) + reporting = True + if name == "init": + if args.local: + rname, rdesc = ("init-local", "searching for local datasources") + else: + rname, rdesc = ("init-network", "searching for network datasources") + elif name == "modules": + rname, rdesc = ("modules-%s" % args.mode, "running modules for %s") + elif name == "single": + rname, rdesc = ("single/%s" % args.name, + "running single module %s" % args.name) + reporting = args.report + + reporter = reporting.ReportStack(rname, rdesc, reporting=reporting) + with reporter: + return util.log_time( + logfunc=LOG.debug, msg="cloud-init mode '%s'" % name, + get_uptime=True, func=functor, args=(name, args)) if __name__ == '__main__': diff --git a/cloudinit/reporting.py b/cloudinit/reporting.py index d2dd4fec..c925f661 100644 --- a/cloudinit/reporting.py +++ b/cloudinit/reporting.py @@ -20,9 +20,18 @@ START_EVENT_TYPE = 'start' DEFAULT_CONFIG = { 'logging': {'type': 'log'}, + 'print': {'type': 'print'}, } +class _nameset(set): + def __getattr__(self, name): + if name in self: + return name + raise AttributeError + +status = _nameset(("SUCCESS", "WARN", "FAIL")) + instantiated_handler_registry = DictRegistry() available_handlers = DictRegistry() @@ -43,17 +52,18 @@ class ReportingEvent(object): class FinishReportingEvent(ReportingEvent): - def __init__(self, name, description, successful=None): + def __init__(self, name, description, result=None): super(FinishReportingEvent, self).__init__( FINISH_EVENT_TYPE, name, description) - self.successful = successful + if result is None: + result = status.SUCCESS + self.result = result + if result not in status: + raise ValueError("Invalid result: %s" % result) def as_string(self): - if self.successful is None: - return super(FinishReportingEvent, self).as_string() - success_string = 'success' if self.successful else 'fail' return '{0}: {1}: {2}: {3}'.format( - self.event_type, self.name, success_string, self.description) + self.event_type, self.name, self.result, self.description) class ReportingHandler(object): @@ -73,6 +83,11 @@ class LogHandler(ReportingHandler): logger.info(event.as_string()) +class PrintHandler(ReportingHandler): + def publish_event(self, event): + print(event.as_string()) + + def add_configuration(config): for handler_name, handler_config in config.items(): handler_config = handler_config.copy() @@ -95,12 +110,12 @@ def report_event(event): handler.publish_event(event) -def report_finish_event(event_name, event_description, successful=None): +def report_finish_event(event_name, event_description, result): """Report a "finish" event. See :py:func:`.report_event` for parameter details. """ - event = FinishReportingEvent(event_name, event_description, successful) + event = FinishReportingEvent(event_name, event_description, result) return report_event(event) @@ -118,5 +133,65 @@ def report_start_event(event_name, event_description): return report_event(event) +class ReportStack(object): + def __init__(self, name, description, parent=None, reporting=None, + exc_result=None): + self.parent = parent + self.reporting = reporting + self.name = name + self.description = description + + if exc_result is None: + exc_result = status.FAIL + self.exc_result = exc_result + + if reporting is None: + # if reporting is specified respect it, otherwise use parent's value + if parent: + reporting = parent.reporting + else: + reporting = True + if parent: + self.fullname = '/'.join((name, parent.fullname,)) + else: + self.fullname = self.name + self.children = {} + + def __enter__(self): + self.exception = None + if self.reporting: + report_start_event(self.fullname, self.description) + if self.parent: + self.parent.children[self.name] = (None, None) + return self + + def childrens_finish_info(self, result=None, description=None): + for result in (status.FAIL, status.WARN): + for name, (value, msg) in self.children.items(): + if value == result: + return (result, "[" + name + "]" + msg) + if result is None: + result = status.SUCCESS + if description is None: + description = self.description + return (result, description) + + def finish_info(self, exc): + # return tuple of description, and value + if exc: + # by default, exceptions are fatal + return (self.exc_result, self.description) + return self.childrens_finish_info() + + def __exit__(self, exc_type, exc_value, traceback): + self.exception = exc_value + (result, msg) = self.finish_info(exc_value) + if self.parent: + self.parent.children[self.name] = (result, msg) + if self.reporting: + report_finish_event(self.fullname, msg, result) + + available_handlers.register_item('log', LogHandler) +available_handlers.register_item('print', PrintHandler) add_configuration(DEFAULT_CONFIG) diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index a21c08c2..c4848d5d 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -27,6 +27,7 @@ import six from cloudinit import importer from cloudinit import log as logging +from cloudinit import reporting from cloudinit import type_utils from cloudinit import user_data as ud from cloudinit import util @@ -246,17 +247,22 @@ def normalize_pubkey_data(pubkey_data): return keys -def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list): +def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list, reporter): ds_list = list_sources(cfg_list, ds_deps, pkg_list) ds_names = [type_utils.obj_name(f) for f in ds_list] LOG.debug("Searching for data source in: %s", ds_names) for cls in ds_list: + myreporter = reporting.ReportStack( + "check-%s" % cls, "searching for %s" % cls, + parent=reporter, exc_result=reporting.status.WARN) + try: - LOG.debug("Seeing if we can get any data from %s", cls) - s = cls(sys_cfg, distro, paths) - if s.get_data(): - return (s, type_utils.obj_name(cls)) + with myreporter: + LOG.debug("Seeing if we can get any data from %s", cls) + s = cls(sys_cfg, distro, paths) + if s.get_data(): + return (s, type_utils.obj_name(cls)) except Exception: util.logexc(LOG, "Getting data from %s failed", cls) diff --git a/cloudinit/stages.py b/cloudinit/stages.py index d28e765b..dbcdbece 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -46,6 +46,7 @@ from cloudinit import log as logging from cloudinit import sources from cloudinit import type_utils from cloudinit import util +from cloudinit import reporting LOG = logging.getLogger(__name__) @@ -53,7 +54,7 @@ NULL_DATA_SOURCE = None class Init(object): - def __init__(self, ds_deps=None): + def __init__(self, reporter=None, ds_deps=None): if ds_deps is not None: self.ds_deps = ds_deps else: @@ -65,6 +66,11 @@ class Init(object): # Changed only when a fetch occurs self.datasource = NULL_DATA_SOURCE + if reporter is None: + reporter = reporting.ReportStack( + name="init-reporter", description="init-desc", reporting=False) + self.reporter = reporter + def _reset(self, reset_ds=False): # Recreated on access self._cfg = None @@ -246,7 +252,7 @@ class Init(object): self.paths, copy.deepcopy(self.ds_deps), cfg_list, - pkg_list) + pkg_list, self.reporter) LOG.info("Loaded datasource %s - %s", dsname, ds) self.datasource = ds # Ensure we adjust our path members datasource diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index f4011a79..5700118f 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -32,10 +32,10 @@ class TestReportStartEvent(TestCase): class TestReportFinishEvent(TestCase): - def _report_finish_event(self, successful=None): + def _report_finish_event(self, result=None): event_name, event_description = 'my_test_event', 'my description' reporting.report_finish_event( - event_name, event_description, successful=successful) + event_name, event_description, result=result) return event_name, event_description def assertHandlersPassedObjectWithAsString( @@ -51,7 +51,7 @@ class TestReportFinishEvent(TestCase): self, instantiated_handler_registry): event_name, event_description = self._report_finish_event() expected_string_representation = ': '.join( - ['finish', event_name, event_description]) + ['finish', event_name, reporting.status.SUCCESS, event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) @@ -61,9 +61,9 @@ class TestReportFinishEvent(TestCase): def test_reporting_successful_finish_has_sensible_string_repr( self, instantiated_handler_registry): event_name, event_description = self._report_finish_event( - successful=True) + result=reporting.status.SUCCESS) expected_string_representation = ': '.join( - ['finish', event_name, 'success', event_description]) + ['finish', event_name, reporting.status.SUCCESS, event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) @@ -73,9 +73,9 @@ class TestReportFinishEvent(TestCase): def test_reporting_unsuccessful_finish_has_sensible_string_repr( self, instantiated_handler_registry): event_name, event_description = self._report_finish_event( - successful=False) + result=reporting.status.FAIL) expected_string_representation = ': '.join( - ['finish', event_name, 'fail', event_description]) + ['finish', event_name, reporting.status.FAIL, event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) -- cgit v1.2.3 From 89c564a6fd5ac89869f83541370557e3fa58495c Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Sun, 2 Aug 2015 17:51:40 -0400 Subject: fix tests from sync change ReportStack to ReportEventStack change default ReportEventStack to be status.SUCCESS instead of None --- bin/cloud-init | 3 +- cloudinit/cloud.py | 2 +- cloudinit/reporting/__init__.py | 91 +++++++++++++++++++++++++++++++++++---- cloudinit/reporting/handlers.py | 7 +++ cloudinit/sources/__init__.py | 2 +- cloudinit/stages.py | 10 ++--- tests/unittests/test_reporting.py | 19 ++++---- 7 files changed, 109 insertions(+), 25 deletions(-) (limited to 'tests/unittests') diff --git a/bin/cloud-init b/bin/cloud-init index d369a806..51253c42 100755 --- a/bin/cloud-init +++ b/bin/cloud-init @@ -637,7 +637,8 @@ def main(): "running single module %s" % args.name) report_on = args.report - args.reporter = reporting.ReportStack( + reporting.add_configuration({'print': {'type': 'print'}}) + args.reporter = reporting.ReportEventStack( rname, rdesc, reporting_enabled=report_on) with args.reporter: return util.log_time( diff --git a/cloudinit/cloud.py b/cloudinit/cloud.py index 71eb80eb..a0fb42a3 100644 --- a/cloudinit/cloud.py +++ b/cloudinit/cloud.py @@ -47,7 +47,7 @@ class Cloud(object): self._cfg = cfg self._runners = runners if reporter is None: - reporter = reporting.ReportStack( + reporter = reporting.ReportEventStack( name="unnamed-cloud-reporter", description="unnamed-cloud-reporter", reporting_enabled=False) diff --git a/cloudinit/reporting/__init__.py b/cloudinit/reporting/__init__.py index b0364eec..78dde715 100644 --- a/cloudinit/reporting/__init__.py +++ b/cloudinit/reporting/__init__.py @@ -22,6 +22,15 @@ DEFAULT_CONFIG = { instantiated_handler_registry = DictRegistry() +class _nameset(set): + def __getattr__(self, name): + if name in self: + return name + raise AttributeError("%s not a valid value" % name) + + +status = _nameset(("SUCCESS", "WARN", "FAIL")) + class ReportingEvent(object): """Encapsulation of event formatting.""" @@ -39,17 +48,16 @@ class ReportingEvent(object): class FinishReportingEvent(ReportingEvent): - def __init__(self, name, description, successful=None): + def __init__(self, name, description, result=status.SUCCESS): super(FinishReportingEvent, self).__init__( FINISH_EVENT_TYPE, name, description) - self.successful = successful + self.result = result + if result not in status: + raise ValueError("Invalid result: %s" % result) def as_string(self): - if self.successful is None: - return super(FinishReportingEvent, self).as_string() - success_string = 'success' if self.successful else 'fail' return '{0}: {1}: {2}: {3}'.format( - self.event_type, self.name, success_string, self.description) + self.event_type, self.name, self.result, self.description) def add_configuration(config): @@ -74,12 +82,13 @@ def report_event(event): handler.publish_event(event) -def report_finish_event(event_name, event_description, successful=None): +def report_finish_event(event_name, event_description, + result=status.SUCCESS): """Report a "finish" event. See :py:func:`.report_event` for parameter details. """ - event = FinishReportingEvent(event_name, event_description, successful) + event = FinishReportingEvent(event_name, event_description, result) return report_event(event) @@ -97,4 +106,70 @@ def report_start_event(event_name, event_description): return report_event(event) +class ReportEventStack(object): + def __init__(self, name, description, message=None, parent=None, + reporting_enabled=None, result_on_exception=status.FAIL): + self.parent = parent + self.name = name + self.description = description + self.message = message + self.result_on_exception = result_on_exception + self.result = status.SUCCESS + + # use parents reporting value if not provided + if reporting_enabled is None: + if parent: + reporting_enabled = parent.reporting_enabled + else: + reporting_enabled = True + self.reporting_enabled = reporting_enabled + + if parent: + self.fullname = '/'.join((parent.fullname, name,)) + else: + self.fullname = self.name + self.children = {} + + def __repr__(self): + return ("%s reporting=%s" % (self.fullname, self.reporting_enabled)) + + def __enter__(self): + self.result = status.SUCCESS + if self.reporting_enabled: + report_start_event(self.fullname, self.description) + if self.parent: + self.parent.children[self.name] = (None, None) + return self + + def childrens_finish_info(self): + for cand_result in (status.FAIL, status.WARN): + for name, (value, msg) in self.children.items(): + if value == cand_result: + return (value, "[" + name + "]" + msg) + return (self.result, self.message) + + @property + def message(self): + if self._message is not None: + return self._message + return self.description + + @message.setter + def message(self, value): + self._message = value + + def finish_info(self, exc): + # return tuple of description, and value + if exc: + return (self.result_on_exception, self.message) + return self.childrens_finish_info() + + def __exit__(self, exc_type, exc_value, traceback): + (result, msg) = self.finish_info(exc_value) + if self.parent: + self.parent.children[self.name] = (result, msg) + if self.reporting_enabled: + report_finish_event(self.fullname, msg, result) + + add_configuration(DEFAULT_CONFIG) diff --git a/cloudinit/reporting/handlers.py b/cloudinit/reporting/handlers.py index be323f53..1d5ca524 100644 --- a/cloudinit/reporting/handlers.py +++ b/cloudinit/reporting/handlers.py @@ -21,5 +21,12 @@ class LogHandler(ReportingHandler): logger.info(event.as_string()) +class StderrHandler(ReportingHandler): + def publish_event(self, event): + #sys.stderr.write(event.as_string() + "\n") + print(event.as_string()) + + available_handlers = DictRegistry() available_handlers.register_item('log', LogHandler) +available_handlers.register_item('print', StderrHandler) diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 3b48f173..d07cf1fa 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -254,7 +254,7 @@ def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list, reporter): LOG.debug("Searching for %s data source in: %s", mode, ds_names) for name, cls in zip(ds_names, ds_list): - myrep = reporting.ReportStack( + myrep = reporting.ReportEventStack( name="search-%s-%s" % (mode, name.replace("DataSource", "")), description="searching for %s data from %s" % (mode, name), message = "no %s data found from %s" % (mode, name), diff --git a/cloudinit/stages.py b/cloudinit/stages.py index 8c79ae4e..42989bb4 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -67,7 +67,7 @@ class Init(object): self.datasource = NULL_DATA_SOURCE if reporter is None: - reporter = reporting.ReportStack( + reporter = reporting.ReportEventStack( name="init-reporter", description="init-desc", reporting_enabled=False) self.reporter = reporter @@ -242,7 +242,7 @@ class Init(object): if self.datasource is not NULL_DATA_SOURCE: return self.datasource - with reporting.ReportStack( + with reporting.ReportEventStack( name="check-cache", description="attempting to read from cache", parent=self.reporter) as myrep: ds = self._restore_from_cache() @@ -508,11 +508,11 @@ class Init(object): def consume_data(self, frequency=PER_INSTANCE): # Consume the userdata first, because we need want to let the part # handlers run first (for merging stuff) - with reporting.ReportStack( + with reporting.ReportEventStack( "consume-user-data", "reading and applying user-data", parent=self.reporter): self._consume_userdata(frequency) - with reporting.ReportStack( + with reporting.ReportEventStack( "consume-vendor-data", "reading and applying vendor-data", parent=self.reporter): self._consume_userdata(frequency) @@ -705,7 +705,7 @@ class Modules(object): run_name = "config-%s" % (name) desc="running %s with frequency %s" % (run_name, freq) - myrep = reporting.ReportStack( + myrep = reporting.ReportEventStack( name=run_name, description=desc, parent=self.reporter) with myrep: diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index 5700118f..4f4cf3a4 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -32,7 +32,7 @@ class TestReportStartEvent(TestCase): class TestReportFinishEvent(TestCase): - def _report_finish_event(self, result=None): + def _report_finish_event(self, result=reporting.status.SUCCESS): event_name, event_description = 'my_test_event', 'my description' reporting.report_finish_event( event_name, event_description, result=result) @@ -95,31 +95,32 @@ class TestReportingHandler(TestCase): def test_no_default_publish_event_implementation(self): self.assertRaises(NotImplementedError, - reporting.ReportingHandler().publish_event, None) + reporting.handlers.ReportingHandler().publish_event, + None) class TestLogHandler(TestCase): - @mock.patch.object(reporting.logging, 'getLogger') + @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_appropriate_logger_used(self, getLogger): event_type, event_name = 'test_type', 'test_name' event = reporting.ReportingEvent(event_type, event_name, 'description') - reporting.LogHandler().publish_event(event) + reporting.handlers.LogHandler().publish_event(event) self.assertEqual( [mock.call( 'cloudinit.reporting.{0}.{1}'.format(event_type, event_name))], getLogger.call_args_list) - @mock.patch.object(reporting.logging, 'getLogger') + @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_single_log_message_at_info_published(self, getLogger): event = reporting.ReportingEvent('type', 'name', 'description') - reporting.LogHandler().publish_event(event) + reporting.handlers.LogHandler().publish_event(event) self.assertEqual(1, getLogger.return_value.info.call_count) - @mock.patch.object(reporting.logging, 'getLogger') + @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_log_message_uses_event_as_string(self, getLogger): event = reporting.ReportingEvent('type', 'name', 'description') - reporting.LogHandler().publish_event(event) + reporting.handlers.LogHandler().publish_event(event) self.assertIn(event.as_string(), getLogger.return_value.info.call_args[0][0]) @@ -130,7 +131,7 @@ class TestDefaultRegisteredHandler(TestCase): registered_items = ( reporting.instantiated_handler_registry.registered_items) for _, item in registered_items.items(): - if isinstance(item, reporting.LogHandler): + if isinstance(item, reporting.handlers.LogHandler): break else: self.fail('No reporting LogHandler registered by default.') -- cgit v1.2.3 From 5585b397cfb4ba397e9cfba3d86e3d10af20eb71 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 4 Aug 2015 22:01:27 -0500 Subject: fix pep8 --- bin/cloud-init | 3 ++- cloudinit/reporting/handlers.py | 7 ------- cloudinit/sources/__init__.py | 2 +- cloudinit/stages.py | 10 ++++++---- tests/unittests/test_reporting.py | 6 ++++-- 5 files changed, 13 insertions(+), 15 deletions(-) (limited to 'tests/unittests') diff --git a/bin/cloud-init b/bin/cloud-init index 51253c42..40cdbb06 100755 --- a/bin/cloud-init +++ b/bin/cloud-init @@ -628,7 +628,8 @@ def main(): if args.local: rname, rdesc = ("init-local", "searching for local datasources") else: - rname, rdesc = ("init-network", "searching for network datasources") + rname, rdesc = ("init-network", + "searching for network datasources") elif name == "modules": rname, rdesc = ("modules-%s" % args.mode, "running modules for %s" % args.mode) diff --git a/cloudinit/reporting/handlers.py b/cloudinit/reporting/handlers.py index 1d5ca524..be323f53 100644 --- a/cloudinit/reporting/handlers.py +++ b/cloudinit/reporting/handlers.py @@ -21,12 +21,5 @@ class LogHandler(ReportingHandler): logger.info(event.as_string()) -class StderrHandler(ReportingHandler): - def publish_event(self, event): - #sys.stderr.write(event.as_string() + "\n") - print(event.as_string()) - - available_handlers = DictRegistry() available_handlers.register_item('log', LogHandler) -available_handlers.register_item('print', StderrHandler) diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index cf50c1fb..838cd198 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -257,7 +257,7 @@ def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list, reporter): myrep = reporting.ReportEventStack( name="search-%s" % name.replace("DataSource", ""), description="searching for %s data from %s" % (mode, name), - message = "no %s data found from %s" % (mode, name), + message="no %s data found from %s" % (mode, name), parent=reporter) try: with myrep: diff --git a/cloudinit/stages.py b/cloudinit/stages.py index 7b489b9f..d300709d 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -243,7 +243,8 @@ class Init(object): return self.datasource with reporting.ReportEventStack( - name="check-cache", description="attempting to read from cache", + name="check-cache", + description="attempting to read from cache", parent=self.reporter) as myrep: ds = self._restore_from_cache() if ds: @@ -708,17 +709,18 @@ class Modules(object): # This name will affect the semaphore name created run_name = "config-%s" % (name) - desc="running %s with frequency %s" % (run_name, freq) + desc = "running %s with frequency %s" % (run_name, freq) myrep = reporting.ReportEventStack( name=run_name, description=desc, parent=self.reporter) with myrep: - ran, _r = cc.run(run_name, mod.handle, func_args, freq=freq) + ran, _r = cc.run(run_name, mod.handle, func_args, + freq=freq) if ran: myrep.message = "%s ran successfully" % run_name else: myrep.message = "%s previously ran" % run_name - + except Exception as e: util.logexc(LOG, "Running module %s (%s) failed", name, mod) failures.append((name, e)) diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index 4f4cf3a4..ddfac541 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -51,7 +51,8 @@ class TestReportFinishEvent(TestCase): self, instantiated_handler_registry): event_name, event_description = self._report_finish_event() expected_string_representation = ': '.join( - ['finish', event_name, reporting.status.SUCCESS, event_description]) + ['finish', event_name, reporting.status.SUCCESS, + event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) @@ -63,7 +64,8 @@ class TestReportFinishEvent(TestCase): event_name, event_description = self._report_finish_event( result=reporting.status.SUCCESS) expected_string_representation = ': '.join( - ['finish', event_name, reporting.status.SUCCESS, event_description]) + ['finish', event_name, reporting.status.SUCCESS, + event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) -- cgit v1.2.3 From 963647d5197523fabe319df5f0502ec6dce64bd6 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 4 Aug 2015 22:03:01 -0500 Subject: sync tests back --- tests/unittests/test_reporting.py | 131 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) (limited to 'tests/unittests') diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index ddfac541..ffeb55d2 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -82,6 +82,9 @@ class TestReportFinishEvent(TestCase): instantiated_handler_registry.registered_items, expected_string_representation) + def test_invalid_result_raises_attribute_error(self): + self.assertRaises(ValueError, self._report_finish_event, ("BOGUS",)) + class TestReportingEvent(TestCase): @@ -193,3 +196,131 @@ class TestReportingConfiguration(TestCase): expected_handler_config = handler_config.copy() reporting.add_configuration({'my_test_handler': handler_config}) self.assertEqual(expected_handler_config, handler_config) + + +class TestReportingEventStack(TestCase): + @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.report_start_event') + def test_start_and_finish_success(self, report_start, report_finish): + with reporting.ReportEventStack(name="myname", description="mydesc"): + pass + self.assertEqual( + [mock.call('myname', 'mydesc')], report_start.call_args_list) + self.assertEqual( + [mock.call('myname', 'mydesc', reporting.status.SUCCESS)], + report_finish.call_args_list) + + @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.report_start_event') + def test_finish_exception_defaults_fail(self, report_start, report_finish): + name = "myname" + desc = "mydesc" + try: + with reporting.ReportEventStack(name, description=desc): + raise ValueError("This didnt work") + except ValueError: + pass + self.assertEqual([mock.call(name, desc)], report_start.call_args_list) + self.assertEqual( + [mock.call(name, desc, reporting.status.FAIL)], + report_finish.call_args_list) + + @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.report_start_event') + def test_result_on_exception_used(self, report_start, report_finish): + name = "myname" + desc = "mydesc" + try: + with reporting.ReportEventStack( + name, desc, result_on_exception=reporting.status.WARN): + raise ValueError("This didnt work") + except ValueError: + pass + self.assertEqual([mock.call(name, desc)], report_start.call_args_list) + self.assertEqual( + [mock.call(name, desc, reporting.status.WARN)], + report_finish.call_args_list) + + @mock.patch('cloudinit.reporting.report_start_event') + def test_child_fullname_respects_parent(self, report_start): + parent_name = "topname" + c1_name = "c1name" + c2_name = "c2name" + c2_expected_fullname = '/'.join([parent_name, c1_name, c2_name]) + c1_expected_fullname = '/'.join([parent_name, c1_name]) + + parent = reporting.ReportEventStack(parent_name, "topdesc") + c1 = reporting.ReportEventStack(c1_name, "c1desc", parent=parent) + c2 = reporting.ReportEventStack(c2_name, "c2desc", parent=c1) + with c1: + report_start.assert_called_with(c1_expected_fullname, "c1desc") + with c2: + report_start.assert_called_with(c2_expected_fullname, "c2desc") + + @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.report_start_event') + def test_child_result_bubbles_up(self, report_start, report_finish): + parent = reporting.ReportEventStack("topname", "topdesc") + child = reporting.ReportEventStack("c_name", "c_desc", parent=parent) + with parent: + with child: + child.result = reporting.status.WARN + + report_finish.assert_called_with( + "topname", "topdesc", reporting.status.WARN) + + @mock.patch('cloudinit.reporting.report_finish_event') + def test_message_used_in_finish(self, report_finish): + with reporting.ReportEventStack("myname", "mydesc", + message="mymessage"): + pass + self.assertEqual( + [mock.call("myname", "mymessage", reporting.status.SUCCESS)], + report_finish.call_args_list) + + @mock.patch('cloudinit.reporting.report_finish_event') + def test_message_updatable(self, report_finish): + with reporting.ReportEventStack("myname", "mydesc") as c: + c.message = "all good" + self.assertEqual( + [mock.call("myname", "all good", reporting.status.SUCCESS)], + report_finish.call_args_list) + + @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.report_finish_event') + def test_reporting_disabled_does_not_report_events( + self, report_start, report_finish): + with reporting.ReportEventStack("a", "b", reporting_enabled=False): + pass + self.assertEqual(report_start.call_count, 0) + self.assertEqual(report_finish.call_count, 0) + + @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.report_finish_event') + def test_reporting_child_default_to_parent( + self, report_start, report_finish): + parent = reporting.ReportEventStack( + "pname", "pdesc", reporting_enabled=False) + child = reporting.ReportEventStack("cname", "cdesc", parent=parent) + with parent: + with child: + pass + pass + self.assertEqual(report_start.call_count, 0) + self.assertEqual(report_finish.call_count, 0) + + def test_reporting_event_has_sane_repr(self): + myrep = reporting.ReportEventStack("fooname", "foodesc", + reporting_enabled=True).__repr__() + self.assertIn("fooname", myrep) + self.assertIn("foodesc", myrep) + self.assertIn("True", myrep) + + def test_set_invalid_result_raises_value_error(self): + f = reporting.ReportEventStack("myname", "mydesc") + self.assertRaises(ValueError, setattr, f, "result", "BOGUS") + + +class TestStatusAccess(TestCase): + def test_invalid_status_access_raises_value_error(self): + self.assertRaises(AttributeError, getattr, reporting.status, "BOGUS") -- cgit v1.2.3 From 6fdb23b6cbc8de14ebcffc17e9e49342b7bf193d Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 6 Aug 2015 18:19:46 -0500 Subject: sync with cloudinit 2.0 for registry and reporting --- cloudinit/registry.py | 14 +++++++++++ cloudinit/reporting/__init__.py | 29 +++++++++++++++++++--- cloudinit/reporting/handlers.py | 15 +++++++++++- tests/unittests/test_reporting.py | 51 ++++++++++++++++++++++++++++++++------- 4 files changed, 95 insertions(+), 14 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/registry.py b/cloudinit/registry.py index 46cf0585..04368ddf 100644 --- a/cloudinit/registry.py +++ b/cloudinit/registry.py @@ -1,3 +1,7 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab import copy @@ -5,6 +9,9 @@ class DictRegistry(object): """A simple registry for a mapping of objects.""" def __init__(self): + self.reset() + + def reset(self): self._items = {} def register_item(self, key, item): @@ -14,6 +21,13 @@ class DictRegistry(object): 'Item already registered with key {0}'.format(key)) self._items[key] = item + def unregister_item(self, key, force=True): + """Remove item from the registry.""" + if key in self._items: + del self._items[key] + elif not force: + raise KeyError("%s: key not present to unregister" % key) + @property def registered_items(self): """All the items that have been registered. diff --git a/cloudinit/reporting/__init__.py b/cloudinit/reporting/__init__.py index 2b92ab58..d0bc14e3 100644 --- a/cloudinit/reporting/__init__.py +++ b/cloudinit/reporting/__init__.py @@ -20,8 +20,6 @@ DEFAULT_CONFIG = { 'logging': {'type': 'log'}, } -instantiated_handler_registry = DictRegistry() - class _nameset(set): def __getattr__(self, name): @@ -46,6 +44,11 @@ class ReportingEvent(object): return '{0}: {1}: {2}'.format( self.event_type, self.name, self.description) + def as_dict(self): + """The event represented as a dictionary.""" + return {'name': self.name, 'description': self.description, + 'event_type': self.event_type} + class FinishReportingEvent(ReportingEvent): @@ -60,9 +63,26 @@ class FinishReportingEvent(ReportingEvent): return '{0}: {1}: {2}: {3}'.format( self.event_type, self.name, self.result, self.description) + def as_dict(self): + """The event represented as json friendly.""" + data = super(FinishReportingEvent, self).as_dict() + data['result'] = self.result + return data + -def add_configuration(config): +def update_configuration(config): + """Update the instanciated_handler_registry. + + :param config: + The dictionary containing changes to apply. If a key is given + with a False-ish value, the registered handler matching that name + will be unregistered. + """ for handler_name, handler_config in config.items(): + if not handler_config: + instantiated_handler_registry.unregister_item( + handler_name, force=True) + continue handler_config = handler_config.copy() cls = available_handlers.registered_items[handler_config.pop('type')] instance = cls(**handler_config) @@ -214,4 +234,5 @@ class ReportEventStack(object): report_finish_event(self.fullname, msg, result) -add_configuration(DEFAULT_CONFIG) +instantiated_handler_registry = DictRegistry() +update_configuration(DEFAULT_CONFIG) diff --git a/cloudinit/reporting/handlers.py b/cloudinit/reporting/handlers.py index be323f53..86cbe3c3 100644 --- a/cloudinit/reporting/handlers.py +++ b/cloudinit/reporting/handlers.py @@ -1,14 +1,27 @@ +# vi: ts=4 expandtab + import abc import logging +import oauthlib.oauth1 as oauth1 + +import six from cloudinit.registry import DictRegistry +from cloudinit import url_helper +from cloudinit import util +@six.add_metaclass(abc.ABCMeta) class ReportingHandler(object): + """Base class for report handlers. + + Implement :meth:`~publish_event` for controlling what + the handler does with an event. + """ @abc.abstractmethod def publish_event(self, event): - raise NotImplementedError + """Publish an event to the ``INFO`` log level.""" class LogHandler(ReportingHandler): diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index ffeb55d2..1a4ee8c4 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -4,6 +4,7 @@ # vi: ts=4 expandtab from cloudinit import reporting +from cloudinit.reporting import handlers from .helpers import (mock, TestCase) @@ -95,13 +96,29 @@ class TestReportingEvent(TestCase): [event_type, name, description]) self.assertEqual(expected_string_representation, event.as_string()) + def test_as_dict(self): + event_type, name, desc = 'test_type', 'test_name', 'test_desc' + event = reporting.ReportingEvent(event_type, name, desc) + self.assertEqual( + {'event_type': event_type, 'name': name, 'description': desc}, + event.as_dict()) + + +class TestFinishReportingEvent(TestCase): + def test_as_has_result(self): + result = reporting.status.SUCCESS + name, desc = 'test_name', 'test_desc' + event = reporting.FinishReportingEvent(name, desc, result) + ret = event.as_dict() + self.assertTrue('result' in ret) + self.assertEqual(ret['result'], result) + -class TestReportingHandler(TestCase): +class TestBaseReportingHandler(TestCase): - def test_no_default_publish_event_implementation(self): - self.assertRaises(NotImplementedError, - reporting.handlers.ReportingHandler().publish_event, - None) + def test_base_reporting_handler_is_abstract(self): + regexp = r".*abstract.*publish_event.*" + self.assertRaisesRegexp(TypeError, regexp, handlers.ReportingHandler) class TestLogHandler(TestCase): @@ -147,7 +164,7 @@ class TestReportingConfiguration(TestCase): @mock.patch.object(reporting, 'instantiated_handler_registry') def test_empty_configuration_doesnt_add_handlers( self, instantiated_handler_registry): - reporting.add_configuration({}) + reporting.update_configuration({}) self.assertEqual( 0, instantiated_handler_registry.register_item.call_count) @@ -159,7 +176,7 @@ class TestReportingConfiguration(TestCase): handler_cls = mock.Mock() available_handlers.registered_items = {handler_type_name: handler_cls} handler_name = 'my_test_handler' - reporting.add_configuration( + reporting.update_configuration( {handler_name: {'type': handler_type_name}}) self.assertEqual( {handler_name: handler_cls.return_value}, @@ -177,7 +194,7 @@ class TestReportingConfiguration(TestCase): handler_config = extra_kwargs.copy() handler_config.update({'type': handler_type_name}) handler_name = 'my_test_handler' - reporting.add_configuration({handler_name: handler_config}) + reporting.update_configuration({handler_name: handler_config}) self.assertEqual( handler_cls.return_value, reporting.instantiated_handler_registry.registered_items[ @@ -194,9 +211,25 @@ class TestReportingConfiguration(TestCase): available_handlers.registered_items = {handler_type_name: handler_cls} handler_config = {'type': handler_type_name, 'foo': 'bar'} expected_handler_config = handler_config.copy() - reporting.add_configuration({'my_test_handler': handler_config}) + reporting.update_configuration({'my_test_handler': handler_config}) self.assertEqual(expected_handler_config, handler_config) + @mock.patch.object( + reporting, 'instantiated_handler_registry', reporting.DictRegistry()) + @mock.patch.object(reporting, 'available_handlers') + def test_handlers_removed_if_falseish_specified(self, available_handlers): + handler_type_name = 'test_handler' + handler_cls = mock.Mock() + available_handlers.registered_items = {handler_type_name: handler_cls} + handler_name = 'my_test_handler' + reporting.update_configuration( + {handler_name: {'type': handler_type_name}}) + self.assertEqual( + 1, len(reporting.instantiated_handler_registry.registered_items)) + reporting.update_configuration({handler_name: None}) + self.assertEqual( + 0, len(reporting.instantiated_handler_registry.registered_items)) + class TestReportingEventStack(TestCase): @mock.patch('cloudinit.reporting.report_finish_event') -- cgit v1.2.3 From 48cb8699efb5c6116dfa7b4d76d0a5fb6b3fbbbf Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 7 Aug 2015 00:22:49 -0500 Subject: hopefully fix DataSourceMAAS --- cloudinit/sources/DataSourceMAAS.py | 58 +++++++++++----------------- tests/unittests/test_datasource/test_maas.py | 2 +- 2 files changed, 24 insertions(+), 36 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index 279da238..2f36bbe2 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -164,12 +164,12 @@ def read_maas_seed_dir(seed_d): return check_seed_contents(md, seed_d) -def read_maas_seed_url(seed_url, header_cb=None, timeout=None, +def read_maas_seed_url(seed_url, read_file_or_url=None, timeout=None, version=MD_VERSION, paths=None): """ Read the maas datasource at seed_url. - - header_cb is a method that should return a headers dictionary for - a given url + read_file_or_url is a method that should provide an interface + like util.read_file_or_url Expected format of seed_url is are the following files: * //meta-data/instance-id @@ -190,14 +190,12 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None, 'user-data': "%s/%s" % (base_url, 'user-data'), } + if read_file_or_url is None: + read_file_or_url = util.read_file_or_url + md = {} for name in file_order: url = files.get(name) - if not header_cb: - def _cb(url): - return {} - header_cb = _cb - if name == 'user-data': retries = 0 else: @@ -205,10 +203,8 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None, try: ssl_details = util.fetch_ssl_details(paths) - resp = util.read_file_or_url(url, retries=retries, - headers_cb=header_cb, - timeout=timeout, - ssl_details=ssl_details) + resp = read_file_or_url(url, retries=retries, + timeout=timeout, ssl_details=ssl_details) if resp.ok(): if name in BINARY_FIELDS: md[name] = resp.contents @@ -311,47 +307,39 @@ if __name__ == "__main__": if key in cfg and creds[key] is None: creds[key] = cfg[key] - def geturl(url, headers_cb): - req = Request(url, data=None, headers=headers_cb(url)) - return urlopen(req).read() + oauth_helper = url_helper.OauthUrlHelper(**creds) + + def geturl(url): + return oauth_helper.readurl(url).contents def printurl(url, headers_cb): - print("== %s ==\n%s\n" % (url, geturl(url, headers_cb))) + print("== %s ==\n%s\n" % (url, geturl(url))) - def crawl(url, headers_cb=None): + def crawl(url): if url.endswith("/"): - for line in geturl(url, headers_cb).splitlines(): + for line in geturl(url).splitlines(): if line.endswith("/"): - crawl("%s%s" % (url, line), headers_cb) + crawl("%s%s" % (url, line)) else: - printurl("%s%s" % (url, line), headers_cb) + printurl("%s%s" % (url, line)) else: - printurl(url, headers_cb) - - def my_headers(url): - headers = {} - if creds.get('consumer_key', None) is not None: - headers = oauth_headers(url, **creds) - return headers + printurl(url) if args.subcmd == "check-seed": - if args.url.startswith("http"): - (userdata, metadata) = read_maas_seed_url(args.url, - header_cb=my_headers, - version=args.apiver) - else: - (userdata, metadata) = read_maas_seed_url(args.url) + (userdata, metadata) = read_maas_seed_url( + args.url, read_file_or_url=oauth_helper.read_file_or_url, + version=args.apiver) print("=== userdata ===") print(userdata) print("=== metadata ===") pprint.pprint(metadata) elif args.subcmd == "get": - printurl(args.url, my_headers) + printurl(args.url) elif args.subcmd == "crawl": if not args.url.endswith("/"): args.url = "%s/" % args.url - crawl(args.url, my_headers) + crawl(args.url) main() diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index f109bb04..eb97b692 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -141,7 +141,7 @@ class TestMAASDataSource(TestCase): with mock.patch.object(url_helper, 'readurl', side_effect=side_effect()) as mockobj: userdata, metadata = DataSourceMAAS.read_maas_seed_url( - my_seed, header_cb=my_headers_cb, version=my_ver) + my_seed, version=my_ver) self.assertEqual(b"foodata", userdata) self.assertEqual(metadata['instance-id'], -- cgit v1.2.3 From b39070772aba62d68fea14603b8d657bbb529d5e Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Fri, 7 Aug 2015 16:06:36 -0500 Subject: reporting: fix logging reproter and tests --- cloudinit/reporting/handlers.py | 2 +- tests/unittests/test_reporting.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/reporting/handlers.py b/cloudinit/reporting/handlers.py index 1343311f..172679cc 100644 --- a/cloudinit/reporting/handlers.py +++ b/cloudinit/reporting/handlers.py @@ -35,7 +35,7 @@ class LogHandler(ReportingHandler): else: input_level = level try: - level = gettattr(logging, level.upper()) + level = getattr(logging, level.upper()) except: LOG.warn("invalid level '%s', using WARN", input_level) level = logging.WARN diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index 1a4ee8c4..66d4e87e 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -137,14 +137,14 @@ class TestLogHandler(TestCase): def test_single_log_message_at_info_published(self, getLogger): event = reporting.ReportingEvent('type', 'name', 'description') reporting.handlers.LogHandler().publish_event(event) - self.assertEqual(1, getLogger.return_value.info.call_count) + self.assertEqual(1, getLogger.return_value.log.call_count) @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_log_message_uses_event_as_string(self, getLogger): event = reporting.ReportingEvent('type', 'name', 'description') - reporting.handlers.LogHandler().publish_event(event) + reporting.handlers.LogHandler(level="INFO").publish_event(event) self.assertIn(event.as_string(), - getLogger.return_value.info.call_args[0][0]) + getLogger.return_value.log.call_args[0][1]) class TestDefaultRegisteredHandler(TestCase): -- cgit v1.2.3 From 50bcb0f77d29a76a03946c6da13b15be25257402 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 31 Aug 2015 13:33:30 -0400 Subject: split 'events' portion of reporting into separate file this just separates events from other things that could conceivably be reported. --- cloudinit/cloud.py | 4 +- cloudinit/reporting/__init__.py | 203 +----------------------------------- cloudinit/reporting/events.py | 210 ++++++++++++++++++++++++++++++++++++++ cloudinit/sources/__init__.py | 4 +- cloudinit/stages.py | 14 +-- tests/unittests/test_reporting.py | 121 +++++++++++----------- 6 files changed, 285 insertions(+), 271 deletions(-) create mode 100644 cloudinit/reporting/events.py (limited to 'tests/unittests') diff --git a/cloudinit/cloud.py b/cloudinit/cloud.py index edee3887..3e6be203 100644 --- a/cloudinit/cloud.py +++ b/cloudinit/cloud.py @@ -24,7 +24,7 @@ import copy import os from cloudinit import log as logging -from cloudinit import reporting +from cloudinit.reporting import events LOG = logging.getLogger(__name__) @@ -48,7 +48,7 @@ class Cloud(object): self._cfg = cfg self._runners = runners if reporter is None: - reporter = reporting.ReportEventStack( + reporter = events.ReportEventStack( name="unnamed-cloud-reporter", description="unnamed-cloud-reporter", reporting_enabled=False) diff --git a/cloudinit/reporting/__init__.py b/cloudinit/reporting/__init__.py index 502af95c..6b41ae61 100644 --- a/cloudinit/reporting/__init__.py +++ b/cloudinit/reporting/__init__.py @@ -1,7 +1,6 @@ # Copyright 2015 Canonical Ltd. # This file is part of cloud-init. See LICENCE file for license information. # -# vi: ts=4 expandtab """ cloud-init reporting framework @@ -10,66 +9,13 @@ report events in a structured manner. """ from ..registry import DictRegistry -from ..reporting.handlers import available_handlers - - -FINISH_EVENT_TYPE = 'finish' -START_EVENT_TYPE = 'start' +from .handlers import available_handlers DEFAULT_CONFIG = { 'logging': {'type': 'log'}, } -class _nameset(set): - def __getattr__(self, name): - if name in self: - return name - raise AttributeError("%s not a valid value" % name) - - -status = _nameset(("SUCCESS", "WARN", "FAIL")) - - -class ReportingEvent(object): - """Encapsulation of event formatting.""" - - def __init__(self, event_type, name, description): - self.event_type = event_type - self.name = name - self.description = description - - def as_string(self): - """The event represented as a string.""" - return '{0}: {1}: {2}'.format( - self.event_type, self.name, self.description) - - def as_dict(self): - """The event represented as a dictionary.""" - return {'name': self.name, 'description': self.description, - 'event_type': self.event_type} - - -class FinishReportingEvent(ReportingEvent): - - def __init__(self, name, description, result=status.SUCCESS): - super(FinishReportingEvent, self).__init__( - FINISH_EVENT_TYPE, name, description) - self.result = result - if result not in status: - raise ValueError("Invalid result: %s" % result) - - def as_string(self): - return '{0}: {1}: {2}: {3}'.format( - self.event_type, self.name, self.result, self.description) - - def as_dict(self): - """The event represented as json friendly.""" - data = super(FinishReportingEvent, self).as_dict() - data['result'] = self.result - return data - - def update_configuration(config): """Update the instanciated_handler_registry. @@ -90,150 +36,7 @@ def update_configuration(config): instantiated_handler_registry.register_item(handler_name, instance) -def report_event(event): - """Report an event to all registered event handlers. - - This should generally be called via one of the other functions in - the reporting module. - - :param event_type: - The type of the event; this should be a constant from the - reporting module. - """ - for _, handler in instantiated_handler_registry.registered_items.items(): - handler.publish_event(event) - - -def report_finish_event(event_name, event_description, - result=status.SUCCESS): - """Report a "finish" event. - - See :py:func:`.report_event` for parameter details. - """ - event = FinishReportingEvent(event_name, event_description, result) - return report_event(event) - - -def report_start_event(event_name, event_description): - """Report a "start" event. - - :param event_name: - The name of the event; this should be a topic which events would - share (e.g. it will be the same for start and finish events). - - :param event_description: - A human-readable description of the event that has occurred. - """ - event = ReportingEvent(START_EVENT_TYPE, event_name, event_description) - return report_event(event) - - -class ReportEventStack(object): - """Context Manager for using :py:func:`report_event` - - This enables calling :py:func:`report_start_event` and - :py:func:`report_finish_event` through a context manager. - - :param name: - the name of the event - - :param description: - the event's description, passed on to :py:func:`report_start_event` - - :param message: - the description to use for the finish event. defaults to - :param:description. - - :param parent: - :type parent: :py:class:ReportEventStack or None - The parent of this event. The parent is populated with - results of all its children. The name used in reporting - is / - - :param reporting_enabled: - Indicates if reporting events should be generated. - If not provided, defaults to the parent's value, or True if no parent - is provided. - - :param result_on_exception: - The result value to set if an exception is caught. default - value is FAIL. - """ - def __init__(self, name, description, message=None, parent=None, - reporting_enabled=None, result_on_exception=status.FAIL): - self.parent = parent - self.name = name - self.description = description - self.message = message - self.result_on_exception = result_on_exception - self.result = status.SUCCESS - - # use parents reporting value if not provided - if reporting_enabled is None: - if parent: - reporting_enabled = parent.reporting_enabled - else: - reporting_enabled = True - self.reporting_enabled = reporting_enabled - - if parent: - self.fullname = '/'.join((parent.fullname, name,)) - else: - self.fullname = self.name - self.children = {} - - def __repr__(self): - return ("ReportEventStack(%s, %s, reporting_enabled=%s)" % - (self.name, self.description, self.reporting_enabled)) - - def __enter__(self): - self.result = status.SUCCESS - if self.reporting_enabled: - report_start_event(self.fullname, self.description) - if self.parent: - self.parent.children[self.name] = (None, None) - return self - - def _childrens_finish_info(self): - for cand_result in (status.FAIL, status.WARN): - for name, (value, msg) in self.children.items(): - if value == cand_result: - return (value, self.message) - return (self.result, self.message) - - @property - def result(self): - return self._result - - @result.setter - def result(self, value): - if value not in status: - raise ValueError("'%s' not a valid result" % value) - self._result = value - - @property - def message(self): - if self._message is not None: - return self._message - return self.description - - @message.setter - def message(self, value): - self._message = value - - def _finish_info(self, exc): - # return tuple of description, and value - if exc: - return (self.result_on_exception, self.message) - return self._childrens_finish_info() - - def __exit__(self, exc_type, exc_value, traceback): - (result, msg) = self._finish_info(exc_value) - if self.parent: - self.parent.children[self.name] = (result, msg) - if self.reporting_enabled: - report_finish_event(self.fullname, msg, result) - - instantiated_handler_registry = DictRegistry() update_configuration(DEFAULT_CONFIG) + +# vi: ts=4 expandtab diff --git a/cloudinit/reporting/events.py b/cloudinit/reporting/events.py new file mode 100644 index 00000000..e35e41dd --- /dev/null +++ b/cloudinit/reporting/events.py @@ -0,0 +1,210 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +""" +cloud-init events + +Report events in a structured manner. +The events here are most likely used via reporting. +""" + +from . import instantiated_handler_registry + +FINISH_EVENT_TYPE = 'finish' +START_EVENT_TYPE = 'start' + + +class _nameset(set): + def __getattr__(self, name): + if name in self: + return name + raise AttributeError("%s not a valid value" % name) + + +status = _nameset(("SUCCESS", "WARN", "FAIL")) + + +class ReportingEvent(object): + """Encapsulation of event formatting.""" + + def __init__(self, event_type, name, description): + self.event_type = event_type + self.name = name + self.description = description + + def as_string(self): + """The event represented as a string.""" + return '{0}: {1}: {2}'.format( + self.event_type, self.name, self.description) + + def as_dict(self): + """The event represented as a dictionary.""" + return {'name': self.name, 'description': self.description, + 'event_type': self.event_type} + + +class FinishReportingEvent(ReportingEvent): + + def __init__(self, name, description, result=status.SUCCESS): + super(FinishReportingEvent, self).__init__( + FINISH_EVENT_TYPE, name, description) + self.result = result + if result not in status: + raise ValueError("Invalid result: %s" % result) + + def as_string(self): + return '{0}: {1}: {2}: {3}'.format( + self.event_type, self.name, self.result, self.description) + + def as_dict(self): + """The event represented as json friendly.""" + data = super(FinishReportingEvent, self).as_dict() + data['result'] = self.result + return data + + +def report_event(event): + """Report an event to all registered event handlers. + + This should generally be called via one of the other functions in + the reporting module. + + :param event_type: + The type of the event; this should be a constant from the + reporting module. + """ + for _, handler in instantiated_handler_registry.registered_items.items(): + handler.publish_event(event) + + +def report_finish_event(event_name, event_description, + result=status.SUCCESS): + """Report a "finish" event. + + See :py:func:`.report_event` for parameter details. + """ + event = FinishReportingEvent(event_name, event_description, result) + return report_event(event) + + +def report_start_event(event_name, event_description): + """Report a "start" event. + + :param event_name: + The name of the event; this should be a topic which events would + share (e.g. it will be the same for start and finish events). + + :param event_description: + A human-readable description of the event that has occurred. + """ + event = ReportingEvent(START_EVENT_TYPE, event_name, event_description) + return report_event(event) + + +class ReportEventStack(object): + """Context Manager for using :py:func:`report_event` + + This enables calling :py:func:`report_start_event` and + :py:func:`report_finish_event` through a context manager. + + :param name: + the name of the event + + :param description: + the event's description, passed on to :py:func:`report_start_event` + + :param message: + the description to use for the finish event. defaults to + :param:description. + + :param parent: + :type parent: :py:class:ReportEventStack or None + The parent of this event. The parent is populated with + results of all its children. The name used in reporting + is / + + :param reporting_enabled: + Indicates if reporting events should be generated. + If not provided, defaults to the parent's value, or True if no parent + is provided. + + :param result_on_exception: + The result value to set if an exception is caught. default + value is FAIL. + """ + def __init__(self, name, description, message=None, parent=None, + reporting_enabled=None, result_on_exception=status.FAIL): + self.parent = parent + self.name = name + self.description = description + self.message = message + self.result_on_exception = result_on_exception + self.result = status.SUCCESS + + # use parents reporting value if not provided + if reporting_enabled is None: + if parent: + reporting_enabled = parent.reporting_enabled + else: + reporting_enabled = True + self.reporting_enabled = reporting_enabled + + if parent: + self.fullname = '/'.join((parent.fullname, name,)) + else: + self.fullname = self.name + self.children = {} + + def __repr__(self): + return ("ReportEventStack(%s, %s, reporting_enabled=%s)" % + (self.name, self.description, self.reporting_enabled)) + + def __enter__(self): + self.result = status.SUCCESS + if self.reporting_enabled: + report_start_event(self.fullname, self.description) + if self.parent: + self.parent.children[self.name] = (None, None) + return self + + def _childrens_finish_info(self): + for cand_result in (status.FAIL, status.WARN): + for name, (value, msg) in self.children.items(): + if value == cand_result: + return (value, self.message) + return (self.result, self.message) + + @property + def result(self): + return self._result + + @result.setter + def result(self, value): + if value not in status: + raise ValueError("'%s' not a valid result" % value) + self._result = value + + @property + def message(self): + if self._message is not None: + return self._message + return self.description + + @message.setter + def message(self, value): + self._message = value + + def _finish_info(self, exc): + # return tuple of description, and value + if exc: + return (self.result_on_exception, self.message) + return self._childrens_finish_info() + + def __exit__(self, exc_type, exc_value, traceback): + (result, msg) = self._finish_info(exc_value) + if self.parent: + self.parent.children[self.name] = (result, msg) + if self.reporting_enabled: + report_finish_event(self.fullname, msg, result) + +# vi: ts=4 expandtab diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 838cd198..d3cfa560 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -27,12 +27,12 @@ import six from cloudinit import importer from cloudinit import log as logging -from cloudinit import reporting from cloudinit import type_utils from cloudinit import user_data as ud from cloudinit import util from cloudinit.filters import launch_index +from cloudinit.reporting import events DEP_FILESYSTEM = "FILESYSTEM" DEP_NETWORK = "NETWORK" @@ -254,7 +254,7 @@ def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list, reporter): LOG.debug("Searching for %s data source in: %s", mode, ds_names) for name, cls in zip(ds_names, ds_list): - myrep = reporting.ReportEventStack( + myrep = events.ReportEventStack( name="search-%s" % name.replace("DataSource", ""), description="searching for %s data from %s" % (mode, name), message="no %s data found from %s" % (mode, name), diff --git a/cloudinit/stages.py b/cloudinit/stages.py index d300709d..9f192c8d 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -46,7 +46,7 @@ from cloudinit import log as logging from cloudinit import sources from cloudinit import type_utils from cloudinit import util -from cloudinit import reporting +from cloudinit.reporting import events LOG = logging.getLogger(__name__) @@ -67,7 +67,7 @@ class Init(object): self.datasource = NULL_DATA_SOURCE if reporter is None: - reporter = reporting.ReportEventStack( + reporter = events.ReportEventStack( name="init-reporter", description="init-desc", reporting_enabled=False) self.reporter = reporter @@ -242,7 +242,7 @@ class Init(object): if self.datasource is not NULL_DATA_SOURCE: return self.datasource - with reporting.ReportEventStack( + with events.ReportEventStack( name="check-cache", description="attempting to read from cache", parent=self.reporter) as myrep: @@ -509,11 +509,11 @@ class Init(object): def consume_data(self, frequency=PER_INSTANCE): # Consume the userdata first, because we need want to let the part # handlers run first (for merging stuff) - with reporting.ReportEventStack( + with events.ReportEventStack( "consume-user-data", "reading and applying user-data", parent=self.reporter): self._consume_userdata(frequency) - with reporting.ReportEventStack( + with events.ReportEventStack( "consume-vendor-data", "reading and applying vendor-data", parent=self.reporter): self._consume_vendordata(frequency) @@ -595,7 +595,7 @@ class Modules(object): # Created on first use self._cached_cfg = None if reporter is None: - reporter = reporting.ReportEventStack( + reporter = events.ReportEventStack( name="module-reporter", description="module-desc", reporting_enabled=False) self.reporter = reporter @@ -710,7 +710,7 @@ class Modules(object): run_name = "config-%s" % (name) desc = "running %s with frequency %s" % (run_name, freq) - myrep = reporting.ReportEventStack( + myrep = events.ReportEventStack( name=run_name, description=desc, parent=self.reporter) with myrep: diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index 66d4e87e..bb67ef73 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -5,6 +5,7 @@ from cloudinit import reporting from cloudinit.reporting import handlers +from cloudinit.reporting import events from .helpers import (mock, TestCase) @@ -16,12 +17,12 @@ def _fake_registry(): class TestReportStartEvent(TestCase): - @mock.patch('cloudinit.reporting.instantiated_handler_registry', + @mock.patch('cloudinit.reporting.events.instantiated_handler_registry', new_callable=_fake_registry) def test_report_start_event_passes_something_with_as_string_to_handlers( self, instantiated_handler_registry): event_name, event_description = 'my_test_event', 'my description' - reporting.report_start_event(event_name, event_description) + events.report_start_event(event_name, event_description) expected_string_representation = ': '.join( ['start', event_name, event_description]) for _, handler in ( @@ -33,9 +34,9 @@ class TestReportStartEvent(TestCase): class TestReportFinishEvent(TestCase): - def _report_finish_event(self, result=reporting.status.SUCCESS): + def _report_finish_event(self, result=events.status.SUCCESS): event_name, event_description = 'my_test_event', 'my description' - reporting.report_finish_event( + events.report_finish_event( event_name, event_description, result=result) return event_name, event_description @@ -46,39 +47,39 @@ class TestReportFinishEvent(TestCase): event = handler.publish_event.call_args[0][0] self.assertEqual(expected_as_string, event.as_string()) - @mock.patch('cloudinit.reporting.instantiated_handler_registry', + @mock.patch('cloudinit.reporting.events.instantiated_handler_registry', new_callable=_fake_registry) def test_report_finish_event_passes_something_with_as_string_to_handlers( self, instantiated_handler_registry): event_name, event_description = self._report_finish_event() expected_string_representation = ': '.join( - ['finish', event_name, reporting.status.SUCCESS, + ['finish', event_name, events.status.SUCCESS, event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) - @mock.patch('cloudinit.reporting.instantiated_handler_registry', + @mock.patch('cloudinit.reporting.events.instantiated_handler_registry', new_callable=_fake_registry) def test_reporting_successful_finish_has_sensible_string_repr( self, instantiated_handler_registry): event_name, event_description = self._report_finish_event( - result=reporting.status.SUCCESS) + result=events.status.SUCCESS) expected_string_representation = ': '.join( - ['finish', event_name, reporting.status.SUCCESS, + ['finish', event_name, events.status.SUCCESS, event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) - @mock.patch('cloudinit.reporting.instantiated_handler_registry', + @mock.patch('cloudinit.reporting.events.instantiated_handler_registry', new_callable=_fake_registry) def test_reporting_unsuccessful_finish_has_sensible_string_repr( self, instantiated_handler_registry): event_name, event_description = self._report_finish_event( - result=reporting.status.FAIL) + result=events.status.FAIL) expected_string_representation = ': '.join( - ['finish', event_name, reporting.status.FAIL, event_description]) + ['finish', event_name, events.status.FAIL, event_description]) self.assertHandlersPassedObjectWithAsString( instantiated_handler_registry.registered_items, expected_string_representation) @@ -91,14 +92,14 @@ class TestReportingEvent(TestCase): def test_as_string(self): event_type, name, description = 'test_type', 'test_name', 'test_desc' - event = reporting.ReportingEvent(event_type, name, description) + event = events.ReportingEvent(event_type, name, description) expected_string_representation = ': '.join( [event_type, name, description]) self.assertEqual(expected_string_representation, event.as_string()) def test_as_dict(self): event_type, name, desc = 'test_type', 'test_name', 'test_desc' - event = reporting.ReportingEvent(event_type, name, desc) + event = events.ReportingEvent(event_type, name, desc) self.assertEqual( {'event_type': event_type, 'name': name, 'description': desc}, event.as_dict()) @@ -106,9 +107,9 @@ class TestReportingEvent(TestCase): class TestFinishReportingEvent(TestCase): def test_as_has_result(self): - result = reporting.status.SUCCESS + result = events.status.SUCCESS name, desc = 'test_name', 'test_desc' - event = reporting.FinishReportingEvent(name, desc, result) + event = events.FinishReportingEvent(name, desc, result) ret = event.as_dict() self.assertTrue('result' in ret) self.assertEqual(ret['result'], result) @@ -126,7 +127,7 @@ class TestLogHandler(TestCase): @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_appropriate_logger_used(self, getLogger): event_type, event_name = 'test_type', 'test_name' - event = reporting.ReportingEvent(event_type, event_name, 'description') + event = events.ReportingEvent(event_type, event_name, 'description') reporting.handlers.LogHandler().publish_event(event) self.assertEqual( [mock.call( @@ -135,13 +136,13 @@ class TestLogHandler(TestCase): @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_single_log_message_at_info_published(self, getLogger): - event = reporting.ReportingEvent('type', 'name', 'description') + event = events.ReportingEvent('type', 'name', 'description') reporting.handlers.LogHandler().publish_event(event) self.assertEqual(1, getLogger.return_value.log.call_count) @mock.patch.object(reporting.handlers.logging, 'getLogger') def test_log_message_uses_event_as_string(self, getLogger): - event = reporting.ReportingEvent('type', 'name', 'description') + event = events.ReportingEvent('type', 'name', 'description') reporting.handlers.LogHandler(level="INFO").publish_event(event) self.assertIn(event.as_string(), getLogger.return_value.log.call_args[0][1]) @@ -232,49 +233,49 @@ class TestReportingConfiguration(TestCase): class TestReportingEventStack(TestCase): - @mock.patch('cloudinit.reporting.report_finish_event') - @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_start_event') def test_start_and_finish_success(self, report_start, report_finish): - with reporting.ReportEventStack(name="myname", description="mydesc"): + with events.ReportEventStack(name="myname", description="mydesc"): pass self.assertEqual( [mock.call('myname', 'mydesc')], report_start.call_args_list) self.assertEqual( - [mock.call('myname', 'mydesc', reporting.status.SUCCESS)], + [mock.call('myname', 'mydesc', events.status.SUCCESS)], report_finish.call_args_list) - @mock.patch('cloudinit.reporting.report_finish_event') - @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_start_event') def test_finish_exception_defaults_fail(self, report_start, report_finish): name = "myname" desc = "mydesc" try: - with reporting.ReportEventStack(name, description=desc): + with events.ReportEventStack(name, description=desc): raise ValueError("This didnt work") except ValueError: pass self.assertEqual([mock.call(name, desc)], report_start.call_args_list) self.assertEqual( - [mock.call(name, desc, reporting.status.FAIL)], + [mock.call(name, desc, events.status.FAIL)], report_finish.call_args_list) - @mock.patch('cloudinit.reporting.report_finish_event') - @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_start_event') def test_result_on_exception_used(self, report_start, report_finish): name = "myname" desc = "mydesc" try: - with reporting.ReportEventStack( - name, desc, result_on_exception=reporting.status.WARN): + with events.ReportEventStack( + name, desc, result_on_exception=events.status.WARN): raise ValueError("This didnt work") except ValueError: pass self.assertEqual([mock.call(name, desc)], report_start.call_args_list) self.assertEqual( - [mock.call(name, desc, reporting.status.WARN)], + [mock.call(name, desc, events.status.WARN)], report_finish.call_args_list) - @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.events.report_start_event') def test_child_fullname_respects_parent(self, report_start): parent_name = "topname" c1_name = "c1name" @@ -282,59 +283,59 @@ class TestReportingEventStack(TestCase): c2_expected_fullname = '/'.join([parent_name, c1_name, c2_name]) c1_expected_fullname = '/'.join([parent_name, c1_name]) - parent = reporting.ReportEventStack(parent_name, "topdesc") - c1 = reporting.ReportEventStack(c1_name, "c1desc", parent=parent) - c2 = reporting.ReportEventStack(c2_name, "c2desc", parent=c1) + parent = events.ReportEventStack(parent_name, "topdesc") + c1 = events.ReportEventStack(c1_name, "c1desc", parent=parent) + c2 = events.ReportEventStack(c2_name, "c2desc", parent=c1) with c1: report_start.assert_called_with(c1_expected_fullname, "c1desc") with c2: report_start.assert_called_with(c2_expected_fullname, "c2desc") - @mock.patch('cloudinit.reporting.report_finish_event') - @mock.patch('cloudinit.reporting.report_start_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_start_event') def test_child_result_bubbles_up(self, report_start, report_finish): - parent = reporting.ReportEventStack("topname", "topdesc") - child = reporting.ReportEventStack("c_name", "c_desc", parent=parent) + parent = events.ReportEventStack("topname", "topdesc") + child = events.ReportEventStack("c_name", "c_desc", parent=parent) with parent: with child: - child.result = reporting.status.WARN + child.result = events.status.WARN report_finish.assert_called_with( - "topname", "topdesc", reporting.status.WARN) + "topname", "topdesc", events.status.WARN) - @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') def test_message_used_in_finish(self, report_finish): - with reporting.ReportEventStack("myname", "mydesc", - message="mymessage"): + with events.ReportEventStack("myname", "mydesc", + message="mymessage"): pass self.assertEqual( - [mock.call("myname", "mymessage", reporting.status.SUCCESS)], + [mock.call("myname", "mymessage", events.status.SUCCESS)], report_finish.call_args_list) - @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') def test_message_updatable(self, report_finish): - with reporting.ReportEventStack("myname", "mydesc") as c: + with events.ReportEventStack("myname", "mydesc") as c: c.message = "all good" self.assertEqual( - [mock.call("myname", "all good", reporting.status.SUCCESS)], + [mock.call("myname", "all good", events.status.SUCCESS)], report_finish.call_args_list) - @mock.patch('cloudinit.reporting.report_start_event') - @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_start_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') def test_reporting_disabled_does_not_report_events( self, report_start, report_finish): - with reporting.ReportEventStack("a", "b", reporting_enabled=False): + with events.ReportEventStack("a", "b", reporting_enabled=False): pass self.assertEqual(report_start.call_count, 0) self.assertEqual(report_finish.call_count, 0) - @mock.patch('cloudinit.reporting.report_start_event') - @mock.patch('cloudinit.reporting.report_finish_event') + @mock.patch('cloudinit.reporting.events.report_start_event') + @mock.patch('cloudinit.reporting.events.report_finish_event') def test_reporting_child_default_to_parent( self, report_start, report_finish): - parent = reporting.ReportEventStack( + parent = events.ReportEventStack( "pname", "pdesc", reporting_enabled=False) - child = reporting.ReportEventStack("cname", "cdesc", parent=parent) + child = events.ReportEventStack("cname", "cdesc", parent=parent) with parent: with child: pass @@ -343,17 +344,17 @@ class TestReportingEventStack(TestCase): self.assertEqual(report_finish.call_count, 0) def test_reporting_event_has_sane_repr(self): - myrep = reporting.ReportEventStack("fooname", "foodesc", - reporting_enabled=True).__repr__() + myrep = events.ReportEventStack("fooname", "foodesc", + reporting_enabled=True).__repr__() self.assertIn("fooname", myrep) self.assertIn("foodesc", myrep) self.assertIn("True", myrep) def test_set_invalid_result_raises_value_error(self): - f = reporting.ReportEventStack("myname", "mydesc") + f = events.ReportEventStack("myname", "mydesc") self.assertRaises(ValueError, setattr, f, "result", "BOGUS") class TestStatusAccess(TestCase): def test_invalid_status_access_raises_value_error(self): - self.assertRaises(AttributeError, getattr, reporting.status, "BOGUS") + self.assertRaises(AttributeError, getattr, events.status, "BOGUS") -- cgit v1.2.3 From 7820a43baf94e11ce458476a442edd726a406aba Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 31 Aug 2015 13:57:05 -0400 Subject: events: add timestamp and origin, support file posting This adds 'timestamp' and 'origin' to events. The timestamp is simply that, a floating point timestamp of when the event occurred. The origin indicates the source / reporter of this. It is useful to have a single endpoint with multiple different things reporting to it. For example, MAAS will configure cloud-init and curtin to report to the same endpoint and then it can differenciate who made the post. Admittedly, they could use multiple endpoints, but this this seems sane. Also, add support for posting files at the close of an event. This is utilized in curtin to post a log file when the install is done. files are posted on success or fail of the event. --- cloudinit/reporting/events.py | 58 +++++++++++++++++++++++++++++++-------- tests/unittests/test_reporting.py | 15 ++++++---- 2 files changed, 56 insertions(+), 17 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/reporting/events.py b/cloudinit/reporting/events.py index e35e41dd..2f767f64 100644 --- a/cloudinit/reporting/events.py +++ b/cloudinit/reporting/events.py @@ -2,17 +2,22 @@ # This file is part of cloud-init. See LICENCE file for license information. # """ -cloud-init events +events for reporting. -Report events in a structured manner. -The events here are most likely used via reporting. +The events here are designed to be used with reporting. +They can be published to registered handlers with report_event. """ +import base64 +import os.path +import time from . import instantiated_handler_registry FINISH_EVENT_TYPE = 'finish' START_EVENT_TYPE = 'start' +DEFAULT_EVENT_ORIGIN = 'cloudinit' + class _nameset(set): def __getattr__(self, name): @@ -27,10 +32,13 @@ status = _nameset(("SUCCESS", "WARN", "FAIL")) class ReportingEvent(object): """Encapsulation of event formatting.""" - def __init__(self, event_type, name, description): + def __init__(self, event_type, name, description, + origin=DEFAULT_EVENT_ORIGIN, timestamp=time.time()): self.event_type = event_type self.name = name self.description = description + self.origin = origin + self.timestamp = timestamp def as_string(self): """The event represented as a string.""" @@ -40,15 +48,20 @@ class ReportingEvent(object): def as_dict(self): """The event represented as a dictionary.""" return {'name': self.name, 'description': self.description, - 'event_type': self.event_type} + 'event_type': self.event_type, 'origin': self.origin, + 'timestamp': self.timestamp} class FinishReportingEvent(ReportingEvent): - def __init__(self, name, description, result=status.SUCCESS): + def __init__(self, name, description, result=status.SUCCESS, + post_files=None): super(FinishReportingEvent, self).__init__( FINISH_EVENT_TYPE, name, description) self.result = result + if post_files is None: + post_files = [] + self.post_files = post_files if result not in status: raise ValueError("Invalid result: %s" % result) @@ -60,6 +73,8 @@ class FinishReportingEvent(ReportingEvent): """The event represented as json friendly.""" data = super(FinishReportingEvent, self).as_dict() data['result'] = self.result + if self.post_files: + data['files'] = _collect_file_info(self.post_files) return data @@ -78,12 +93,13 @@ def report_event(event): def report_finish_event(event_name, event_description, - result=status.SUCCESS): + result=status.SUCCESS, post_files=None): """Report a "finish" event. See :py:func:`.report_event` for parameter details. """ - event = FinishReportingEvent(event_name, event_description, result) + event = FinishReportingEvent(event_name, event_description, result, + post_files=post_files) return report_event(event) @@ -133,13 +149,17 @@ class ReportEventStack(object): value is FAIL. """ def __init__(self, name, description, message=None, parent=None, - reporting_enabled=None, result_on_exception=status.FAIL): + reporting_enabled=None, result_on_exception=status.FAIL, + post_files=None): self.parent = parent self.name = name self.description = description self.message = message self.result_on_exception = result_on_exception self.result = status.SUCCESS + if post_files is None: + post_files = [] + self.post_files = post_files # use parents reporting value if not provided if reporting_enabled is None: @@ -205,6 +225,22 @@ class ReportEventStack(object): if self.parent: self.parent.children[self.name] = (result, msg) if self.reporting_enabled: - report_finish_event(self.fullname, msg, result) + report_finish_event(self.fullname, msg, result, + post_files=self.post_files) + + +def _collect_file_info(files): + if not files: + return None + ret = [] + for fname in files: + if not os.path.isfile(fname): + content = None + else: + with open(fname, "rb") as fp: + content = base64.b64encode(fp.read()).decode() + ret.append({'path': fname, 'content': content, + 'encoding': 'base64'}) + return ret -# vi: ts=4 expandtab +# vi: ts=4 expandtab syntax=python diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index bb67ef73..0a441adf 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -241,7 +241,8 @@ class TestReportingEventStack(TestCase): self.assertEqual( [mock.call('myname', 'mydesc')], report_start.call_args_list) self.assertEqual( - [mock.call('myname', 'mydesc', events.status.SUCCESS)], + [mock.call('myname', 'mydesc', events.status.SUCCESS, + post_files=[])], report_finish.call_args_list) @mock.patch('cloudinit.reporting.events.report_finish_event') @@ -256,7 +257,7 @@ class TestReportingEventStack(TestCase): pass self.assertEqual([mock.call(name, desc)], report_start.call_args_list) self.assertEqual( - [mock.call(name, desc, events.status.FAIL)], + [mock.call(name, desc, events.status.FAIL, post_files=[])], report_finish.call_args_list) @mock.patch('cloudinit.reporting.events.report_finish_event') @@ -272,7 +273,7 @@ class TestReportingEventStack(TestCase): pass self.assertEqual([mock.call(name, desc)], report_start.call_args_list) self.assertEqual( - [mock.call(name, desc, events.status.WARN)], + [mock.call(name, desc, events.status.WARN, post_files=[])], report_finish.call_args_list) @mock.patch('cloudinit.reporting.events.report_start_event') @@ -301,7 +302,7 @@ class TestReportingEventStack(TestCase): child.result = events.status.WARN report_finish.assert_called_with( - "topname", "topdesc", events.status.WARN) + "topname", "topdesc", events.status.WARN, post_files=[]) @mock.patch('cloudinit.reporting.events.report_finish_event') def test_message_used_in_finish(self, report_finish): @@ -309,7 +310,8 @@ class TestReportingEventStack(TestCase): message="mymessage"): pass self.assertEqual( - [mock.call("myname", "mymessage", events.status.SUCCESS)], + [mock.call("myname", "mymessage", events.status.SUCCESS, + post_files=[])], report_finish.call_args_list) @mock.patch('cloudinit.reporting.events.report_finish_event') @@ -317,7 +319,8 @@ class TestReportingEventStack(TestCase): with events.ReportEventStack("myname", "mydesc") as c: c.message = "all good" self.assertEqual( - [mock.call("myname", "all good", events.status.SUCCESS)], + [mock.call("myname", "all good", events.status.SUCCESS, + post_files=[])], report_finish.call_args_list) @mock.patch('cloudinit.reporting.events.report_start_event') -- cgit v1.2.3 From cba8282acc1c957698480bae2d0c2032b884d80d Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Wed, 2 Sep 2015 16:44:51 -0400 Subject: fix test_as_dict --- tests/unittests/test_reporting.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py index 0a441adf..32356ef9 100644 --- a/tests/unittests/test_reporting.py +++ b/tests/unittests/test_reporting.py @@ -100,9 +100,15 @@ class TestReportingEvent(TestCase): def test_as_dict(self): event_type, name, desc = 'test_type', 'test_name', 'test_desc' event = events.ReportingEvent(event_type, name, desc) - self.assertEqual( - {'event_type': event_type, 'name': name, 'description': desc}, - event.as_dict()) + expected = {'event_type': event_type, 'name': name, + 'description': desc, 'origin': 'cloudinit'} + + # allow for timestamp to differ, but must be present + as_dict = event.as_dict() + self.assertIn('timestamp', as_dict) + del as_dict['timestamp'] + + self.assertEqual(expected, as_dict) class TestFinishReportingEvent(TestCase): -- cgit v1.2.3 From 3c39c3f7638245e9581a2e1f4faae2dc2680f0c7 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 8 Sep 2015 14:26:30 -0400 Subject: NoCloud: fix consumption of vendor-data the content of vendordata was was being assigned to vendordata, rather than vendordata_raw. The result was that it is not processed for includes or part handlers or other things as it is in other datasources. LP: #1493453 --- ChangeLog | 1 + cloudinit/sources/DataSourceNoCloud.py | 2 +- tests/unittests/test_datasource/test_nocloud.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/ChangeLog b/ChangeLog index 7869ab7e..6fb70696 100644 --- a/ChangeLog +++ b/ChangeLog @@ -60,6 +60,7 @@ - rsyslog: add additional configuration mode (LP: #1478103) - status_wrapper in main: fix use of print_exc when handling exception - reporting: add reporting module for web hook or logging of events. + - NoCloud: fix consumption of vendordata (LP: #1493453) 0.7.6: - open 0.7.6 - Enable vendordata on CloudSigma datasource (LP: #1303986) diff --git a/cloudinit/sources/DataSourceNoCloud.py b/cloudinit/sources/DataSourceNoCloud.py index 6a861af3..4dffe6e6 100644 --- a/cloudinit/sources/DataSourceNoCloud.py +++ b/cloudinit/sources/DataSourceNoCloud.py @@ -190,7 +190,7 @@ class DataSourceNoCloud(sources.DataSource): self.seed = ",".join(found) self.metadata = mydata['meta-data'] self.userdata_raw = mydata['user-data'] - self.vendordata = mydata['vendor-data'] + self.vendordata_raw = mydata['vendor-data'] return True LOG.debug("%s: not claiming datasource, dsmode=%s", self, diff --git a/tests/unittests/test_datasource/test_nocloud.py b/tests/unittests/test_datasource/test_nocloud.py index 85b4c25a..2d5fc37c 100644 --- a/tests/unittests/test_datasource/test_nocloud.py +++ b/tests/unittests/test_datasource/test_nocloud.py @@ -121,7 +121,7 @@ class TestNoCloudDataSource(TestCase): ret = dsrc.get_data() self.assertEqual(dsrc.userdata_raw, ud) self.assertEqual(dsrc.metadata, md) - self.assertEqual(dsrc.vendordata, vd) + self.assertEqual(dsrc.vendordata_raw, vd) self.assertTrue(ret) def test_nocloud_no_vendordata(self): -- cgit v1.2.3 From ba3e59cbb5ae58a2267fcbcd23eecaaa26f2c396 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 8 Sep 2015 16:53:59 -0400 Subject: power_state: support 'condition' argument if 'condition' is provided to config in power_state, then consult it before powering off. This allows the user to shut down only if a condition is met, and leave the system in a debuggable state otherwise. An example is as simple as: power_state: mode: poweroff condition: ['sh', '-c', '[ -f /disable-poweroff ]'] --- ChangeLog | 1 + cloudinit/config/cc_power_state_change.py | 57 +++++++++++++++++++--- doc/examples/cloud-config-power-state.txt | 9 ++++ .../test_handler/test_handler_power_state.py | 48 ++++++++++++++++-- 4 files changed, 105 insertions(+), 10 deletions(-) (limited to 'tests/unittests') diff --git a/ChangeLog b/ChangeLog index 6fb70696..bbb7e990 100644 --- a/ChangeLog +++ b/ChangeLog @@ -61,6 +61,7 @@ - status_wrapper in main: fix use of print_exc when handling exception - reporting: add reporting module for web hook or logging of events. - NoCloud: fix consumption of vendordata (LP: #1493453) + - power_state_change: support 'condition' to disable or enable poweroff 0.7.6: - open 0.7.6 - Enable vendordata on CloudSigma datasource (LP: #1303986) diff --git a/cloudinit/config/cc_power_state_change.py b/cloudinit/config/cc_power_state_change.py index 09d37371..7d9567e3 100644 --- a/cloudinit/config/cc_power_state_change.py +++ b/cloudinit/config/cc_power_state_change.py @@ -22,6 +22,7 @@ from cloudinit import util import errno import os import re +import six import subprocess import time @@ -48,10 +49,40 @@ def givecmdline(pid): return None +def check_condition(cond, log=None): + if isinstance(cond, bool): + if log: + log.debug("Static Condition: %s" % cond) + return cond + + pre = "check_condition command (%s): " % cond + try: + proc = subprocess.Popen(cond, shell=not isinstance(cond, list)) + proc.communicate() + ret = proc.returncode + if ret == 0: + if log: + log.debug(pre + "exited 0. condition met.") + return True + elif ret == 1: + if log: + log.debug(pre + "exited 1. condition not met.") + return False + else: + if log: + log.warn(pre + "unexpected exit %s. " % ret + + "do not apply change.") + return False + except Exception as e: + if log: + log.warn(pre + "Unexpected error: %s" % e) + return False + + def handle(_name, cfg, _cloud, log, _args): try: - (args, timeout) = load_power_state(cfg) + (args, timeout, condition) = load_power_state(cfg) if args is None: log.debug("no power_state provided. doing nothing") return @@ -59,6 +90,10 @@ def handle(_name, cfg, _cloud, log, _args): log.warn("%s Not performing power state change!" % str(e)) return + if condition is False: + log.debug("Condition was false. Will not perform state change.") + return + mypid = os.getpid() cmdline = givecmdline(mypid) @@ -70,8 +105,8 @@ def handle(_name, cfg, _cloud, log, _args): log.debug("After pid %s ends, will execute: %s" % (mypid, ' '.join(args))) - util.fork_cb(run_after_pid_gone, mypid, cmdline, timeout, log, execmd, - [args, devnull_fp]) + util.fork_cb(run_after_pid_gone, mypid, cmdline, timeout, log, + condition, execmd, [args, devnull_fp]) def load_power_state(cfg): @@ -80,7 +115,7 @@ def load_power_state(cfg): pstate = cfg.get('power_state') if pstate is None: - return (None, None) + return (None, None, None) if not isinstance(pstate, dict): raise TypeError("power_state is not a dict.") @@ -115,7 +150,10 @@ def load_power_state(cfg): raise ValueError("failed to convert timeout '%s' to float." % pstate['timeout']) - return (args, timeout) + condition = pstate.get("condition", True) + if not isinstance(condition, six.string_types + (list, bool)): + raise TypeError("condition type %s invalid. must be list, bool, str") + return (args, timeout, condition) def doexit(sysexit): @@ -133,7 +171,7 @@ def execmd(exe_args, output=None, data_in=None): doexit(ret) -def run_after_pid_gone(pid, pidcmdline, timeout, log, func, args): +def run_after_pid_gone(pid, pidcmdline, timeout, log, condition, func, args): # wait until pid, with /proc/pid/cmdline contents of pidcmdline # is no longer alive. After it is gone, or timeout has passed # execute func(args) @@ -175,4 +213,11 @@ def run_after_pid_gone(pid, pidcmdline, timeout, log, func, args): if log: log.debug(msg) + + try: + if not check_condition(condition, log): + return + except Exception as e: + fatal("Unexpected Exception when checking condition: %s" % e) + func(*args) diff --git a/doc/examples/cloud-config-power-state.txt b/doc/examples/cloud-config-power-state.txt index 8df14366..b470153d 100644 --- a/doc/examples/cloud-config-power-state.txt +++ b/doc/examples/cloud-config-power-state.txt @@ -23,9 +23,18 @@ # message: provided as the message argument to 'shutdown'. default is none. # timeout: the amount of time to give the cloud-init process to finish # before executing shutdown. +# condition: apply state change only if condition is met. +# May be boolean True (always met), or False (never met), +# or a command string or list to be executed. +# command's exit code indicates: +# 0: condition met +# 1: condition not met +# other exit codes will result in 'not met', but are reserved +# for future use. # power_state: delay: "+30" mode: poweroff message: Bye Bye timeout: 30 + condition: True diff --git a/tests/unittests/test_handler/test_handler_power_state.py b/tests/unittests/test_handler/test_handler_power_state.py index 2f86b8f8..5687b10d 100644 --- a/tests/unittests/test_handler/test_handler_power_state.py +++ b/tests/unittests/test_handler/test_handler_power_state.py @@ -1,6 +1,9 @@ +import sys + from cloudinit.config import cc_power_state_change as psc from .. import helpers as t_help +from ..helpers import mock class TestLoadPowerState(t_help.TestCase): @@ -9,12 +12,12 @@ class TestLoadPowerState(t_help.TestCase): def test_no_config(self): # completely empty config should mean do nothing - (cmd, _timeout) = psc.load_power_state({}) + (cmd, _timeout, _condition) = psc.load_power_state({}) self.assertEqual(cmd, None) def test_irrelevant_config(self): # no power_state field in config should return None for cmd - (cmd, _timeout) = psc.load_power_state({'foo': 'bar'}) + (cmd, _timeout, _condition) = psc.load_power_state({'foo': 'bar'}) self.assertEqual(cmd, None) def test_invalid_mode(self): @@ -53,23 +56,60 @@ class TestLoadPowerState(t_help.TestCase): def test_no_message(self): # if message is not present, then no argument should be passed for it cfg = {'power_state': {'mode': 'poweroff'}} - (cmd, _timeout) = psc.load_power_state(cfg) + (cmd, _timeout, _condition) = psc.load_power_state(cfg) self.assertNotIn("", cmd) check_lps_ret(psc.load_power_state(cfg)) self.assertTrue(len(cmd) == 3) + def test_condition_null_raises(self): + cfg = {'power_state': {'mode': 'poweroff', 'condition': None}} + self.assertRaises(TypeError, psc.load_power_state, cfg) + + def test_condition_default_is_true(self): + cfg = {'power_state': {'mode': 'poweroff'}} + _cmd, _timeout, cond = psc.load_power_state(cfg) + self.assertEqual(cond, True) + + +class TestCheckCondition(t_help.TestCase): + def cmd_with_exit(self, rc): + return([sys.executable, '-c', 'import sys; sys.exit(%s)' % rc]) + + def test_true_is_true(self): + self.assertEqual(psc.check_condition(True), True) + + def test_false_is_false(self): + self.assertEqual(psc.check_condition(False), False) + + def test_cmd_exit_zero_true(self): + self.assertEqual(psc.check_condition(self.cmd_with_exit(0)), True) + + def test_cmd_exit_one_false(self): + self.assertEqual(psc.check_condition(self.cmd_with_exit(1)), False) + + def test_cmd_exit_nonzero_warns(self): + mocklog = mock.Mock() + self.assertEqual( + psc.check_condition(self.cmd_with_exit(2), mocklog), False) + self.assertEqual(mocklog.warn.call_count, 1) + + def check_lps_ret(psc_return, mode=None): - if len(psc_return) != 2: + if len(psc_return) != 3: raise TypeError("length returned = %d" % len(psc_return)) errs = [] cmd = psc_return[0] timeout = psc_return[1] + condition = psc_return[2] if 'shutdown' not in psc_return[0][0]: errs.append("string 'shutdown' not in cmd") + if 'condition' is None: + errs.append("condition was not returned") + if mode is not None: opt = {'halt': '-H', 'poweroff': '-P', 'reboot': '-r'}[mode] if opt not in psc_return[0]: -- cgit v1.2.3 From 41900b72f31a1bd0eebe2f58a8598bfab25f0003 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 9 Oct 2015 14:01:11 +0100 Subject: Handle escaped quotes in WALinuxAgentShim.find_endpoint. This fixes bug 1488891. --- cloudinit/sources/helpers/azure.py | 2 +- tests/unittests/test_datasource/test_azure_helper.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 281d733e..33003da0 100644 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -233,7 +233,7 @@ class WALinuxAgentShim(object): hex_string += hex_pair value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) else: - value = value.encode('utf-8') + value = value.replace('\\', '').encode('utf-8') endpoint_ip_address = socket.inet_ntoa(value) LOG.debug('Azure endpoint found at %s', endpoint_ip_address) return endpoint_ip_address diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index a5228870..68af31cd 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -97,7 +97,8 @@ class TestFindEndpoint(TestCase): if not use_hex: ip_address_repr = struct.pack( '>L', int(ip_address_repr.replace(':', ''), 16)) - ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) + ip_address_repr = '"{0}"'.format( + ip_address_repr.decode('utf-8').replace('"', '\\"')) return '\n'.join([ 'lease {', ' interface "eth0";', @@ -125,6 +126,13 @@ class TestFindEndpoint(TestCase): self.assertEqual(ip_address, azure_helper.WALinuxAgentShim.find_endpoint()) + def test_packed_string_with_escaped_quote(self): + ip_address = '100.72.34.108' + file_content = self._build_lease_content(ip_address, use_hex=False) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + def test_latest_lease_used(self): ip_addresses = ['4.3.2.1', '98.76.54.32'] file_content = '\n'.join([self._build_lease_content(ip_address) -- cgit v1.2.3 From 20dc4190e27c7778cfa6c2943961f2ad27e14b48 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 9 Oct 2015 14:01:11 +0100 Subject: Handle colons in packed strings in WALinuxAgentShim.find_endpoint. This fixes bug 1488896. --- cloudinit/sources/helpers/azure.py | 12 +++++++----- tests/unittests/test_datasource/test_azure_helper.py | 7 +++++++ 2 files changed, 14 insertions(+), 5 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 33003da0..21b4cd21 100644 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -225,16 +225,18 @@ class WALinuxAgentShim(object): value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') if value is None: raise Exception('No endpoint found in DHCP config.') - if ':' in value: + unescaped_value = value.replace('\\', '') + if len(unescaped_value) > 4: hex_string = '' - for hex_pair in value.split(':'): + for hex_pair in unescaped_value.split(':'): if len(hex_pair) == 1: hex_pair = '0' + hex_pair hex_string += hex_pair - value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) + packed_bytes = struct.pack( + '>L', int(hex_string.replace(':', ''), 16)) else: - value = value.replace('\\', '').encode('utf-8') - endpoint_ip_address = socket.inet_ntoa(value) + packed_bytes = unescaped_value.encode('utf-8') + endpoint_ip_address = socket.inet_ntoa(packed_bytes) LOG.debug('Azure endpoint found at %s', endpoint_ip_address) return endpoint_ip_address diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 68af31cd..5f906837 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -133,6 +133,13 @@ class TestFindEndpoint(TestCase): self.assertEqual(ip_address, azure_helper.WALinuxAgentShim.find_endpoint()) + def test_packed_string_containing_a_colon(self): + ip_address = '100.72.58.108' + file_content = self._build_lease_content(ip_address, use_hex=False) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + def test_latest_lease_used(self): ip_addresses = ['4.3.2.1', '98.76.54.32'] file_content = '\n'.join([self._build_lease_content(ip_address) -- cgit v1.2.3 From a9f5d20bc55abc5ff12a5ce1cec1125173f11c7b Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 9 Oct 2015 14:01:11 +0100 Subject: Convert test helper to staticmethod. --- tests/unittests/test_datasource/test_azure_helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 5f906837..1746f6b8 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -90,7 +90,8 @@ class TestFindEndpoint(TestCase): self.assertRaises(Exception, azure_helper.WALinuxAgentShim.find_endpoint) - def _build_lease_content(self, ip_address, use_hex=True): + @staticmethod + def _build_lease_content(ip_address, use_hex=True): ip_address_repr = ':'.join( [hex(int(part)).replace('0x', '') for part in ip_address.split('.')]) -- cgit v1.2.3 From 2f5a5ed9f9c5ba390a693bc441ab18cbcdff9623 Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 9 Oct 2015 14:01:11 +0100 Subject: Refactor tests to test helper method directly, and remove need for test helper. --- .../unittests/test_datasource/test_azure_helper.py | 84 +++++++++++----------- 1 file changed, 42 insertions(+), 42 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 1746f6b8..fed2a7d8 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -91,63 +91,63 @@ class TestFindEndpoint(TestCase): azure_helper.WALinuxAgentShim.find_endpoint) @staticmethod - def _build_lease_content(ip_address, use_hex=True): - ip_address_repr = ':'.join( - [hex(int(part)).replace('0x', '') - for part in ip_address.split('.')]) - if not use_hex: - ip_address_repr = struct.pack( - '>L', int(ip_address_repr.replace(':', ''), 16)) - ip_address_repr = '"{0}"'.format( - ip_address_repr.decode('utf-8').replace('"', '\\"')) + def _build_lease_content(encoded_address): return '\n'.join([ 'lease {', ' interface "eth0";', - ' option unknown-245 {0};'.format(ip_address_repr), + ' option unknown-245 {0};'.format(encoded_address), '}']) - def test_hex_string(self): - ip_address = '98.76.54.32' - file_content = self._build_lease_content(ip_address) + def test_latest_lease_used(self): + encoded_addresses = ['5:4:3:2', '4:3:2:1'] + file_content = '\n'.join([self._build_lease_content(encoded_address) + for encoded_address in encoded_addresses]) self.load_file.return_value = file_content - self.assertEqual(ip_address, + self.assertEqual(encoded_addresses[-1].replace(':', '.'), azure_helper.WALinuxAgentShim.find_endpoint()) + +class TestExtractIpAddressFromLeaseValue(TestCase): + + def test_hex_string(self): + ip_address, encoded_address = '98.76.54.32', '62:4c:36:20' + self.assertEqual( + ip_address, + azure_helper.WALinuxAgentShim.get_ip_from_lease_value( + encoded_address + )) + def test_hex_string_with_single_character_part(self): - ip_address = '4.3.2.1' - file_content = self._build_lease_content(ip_address) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - azure_helper.WALinuxAgentShim.find_endpoint()) + ip_address, encoded_address = '4.3.2.1', '4:3:2:1' + self.assertEqual( + ip_address, + azure_helper.WALinuxAgentShim.get_ip_from_lease_value( + encoded_address + )) def test_packed_string(self): - ip_address = '98.76.54.32' - file_content = self._build_lease_content(ip_address, use_hex=False) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - azure_helper.WALinuxAgentShim.find_endpoint()) + ip_address, encoded_address = '98.76.54.32', 'bL6 ' + self.assertEqual( + ip_address, + azure_helper.WALinuxAgentShim.get_ip_from_lease_value( + encoded_address + )) def test_packed_string_with_escaped_quote(self): - ip_address = '100.72.34.108' - file_content = self._build_lease_content(ip_address, use_hex=False) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - azure_helper.WALinuxAgentShim.find_endpoint()) + ip_address, encoded_address = '100.72.34.108', 'dH\\"l' + self.assertEqual( + ip_address, + azure_helper.WALinuxAgentShim.get_ip_from_lease_value( + encoded_address + )) def test_packed_string_containing_a_colon(self): - ip_address = '100.72.58.108' - file_content = self._build_lease_content(ip_address, use_hex=False) - self.load_file.return_value = file_content - self.assertEqual(ip_address, - azure_helper.WALinuxAgentShim.find_endpoint()) - - def test_latest_lease_used(self): - ip_addresses = ['4.3.2.1', '98.76.54.32'] - file_content = '\n'.join([self._build_lease_content(ip_address) - for ip_address in ip_addresses]) - self.load_file.return_value = file_content - self.assertEqual(ip_addresses[-1], - azure_helper.WALinuxAgentShim.find_endpoint()) + ip_address, encoded_address = '100.72.58.108', 'dH:l' + self.assertEqual( + ip_address, + azure_helper.WALinuxAgentShim.get_ip_from_lease_value( + encoded_address + )) class TestGoalStateParsing(TestCase): -- cgit v1.2.3 From 92ceca45c5d2983742ce18d2e8b2e671629ef4b0 Mon Sep 17 00:00:00 2001 From: Ben Howard Date: Wed, 14 Oct 2015 16:32:35 -0700 Subject: AZURE: support extracting SSH key values from ovf-env.xml Azure has or will be offering shortly the ability to directly define the SSH key value instead of a fingerprint in the ovf-env.xml file. This patch favors defined SSH keys over the fingerprint method (LP: #1506244). --- cloudinit/sources/DataSourceAzure.py | 16 ++++++--- tests/unittests/test_datasource/test_azure.py | 52 ++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 12 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index ff950deb..eb9fd042 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -148,9 +148,15 @@ class DataSourceAzureNet(sources.DataSource): wait_for = [shcfgxml] fp_files = [] + key_value = None for pk in self.cfg.get('_pubkeys', []): - bname = str(pk['fingerprint'] + ".crt") - fp_files += [os.path.join(ddir, bname)] + if pk.get('value', None): + key_value = pk['value'] + LOG.info("ssh authentication: using value from fabric") + else: + bname = str(pk['fingerprint'] + ".crt") + fp_files += [os.path.join(ddir, bname)] + LOG.info("ssh authentication: using fingerprint from fabirc") missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", func=wait_for_files, @@ -166,7 +172,8 @@ class DataSourceAzureNet(sources.DataSource): metadata['instance-id'] = iid_from_shared_config(shcfgxml) except ValueError as e: LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) - metadata['public-keys'] = pubkeys_from_crt_files(fp_files) + + metadata['public-keys'] = key_value or pubkeys_from_crt_files(fp_files) return metadata def get_data(self): @@ -497,7 +504,8 @@ def load_azure_ovf_pubkeys(sshnode): for pk_node in pubkeys: if not pk_node.hasChildNodes(): continue - cur = {'fingerprint': "", 'path': ""} + + cur = {'fingerprint': "", 'path': "", 'value': ""} for child in pk_node.childNodes: if child.nodeType == text_node or not child.localName: continue diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 8952374f..ec0435f5 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -54,10 +54,13 @@ def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): if pubkeys: content += "\n" - for fp, path in pubkeys: + for fp, path, value in pubkeys: content += " " - content += ("%s%s" % - (fp, path)) + if fp and path: + content += ("%s%s" % + (fp, path)) + if value: + content += "%s" % value content += "\n" content += "" content += """ @@ -297,10 +300,10 @@ class TestAzureDataSource(TestCase): self.assertFalse(ret) self.assertFalse('agent_invoked' in data) - def test_cfg_has_pubkeys(self): + def test_cfg_has_pubkeys_fingerprint(self): odata = {'HostName': "myhost", 'UserName': "myuser"} - mypklist = [{'fingerprint': 'fp1', 'path': 'path1'}] - pubkeys = [(x['fingerprint'], x['path']) for x in mypklist] + mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}] + pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist] data = {'ovfcontent': construct_valid_ovf_env(data=odata, pubkeys=pubkeys)} @@ -309,6 +312,39 @@ class TestAzureDataSource(TestCase): self.assertTrue(ret) for mypk in mypklist: self.assertIn(mypk, dsrc.cfg['_pubkeys']) + self.assertIn('pubkey_from', dsrc.metadata['public-keys'][-1]) + + def test_cfg_has_pubkeys_value(self): + # make sure that provided key is used over fingerprint + odata = {'HostName': "myhost", 'UserName': "myuser"} + mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': 'value1'}] + pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist] + data = {'ovfcontent': construct_valid_ovf_env(data=odata, + pubkeys=pubkeys)} + + dsrc = self._get_ds(data) + ret = dsrc.get_data() + self.assertTrue(ret) + + for mypk in mypklist: + self.assertIn(mypk, dsrc.cfg['_pubkeys']) + self.assertIn(mypk['value'], dsrc.metadata['public-keys']) + + def test_cfg_has_no_fingerprint_has_value(self): + # test value is used when fingerprint not provided + odata = {'HostName': "myhost", 'UserName': "myuser"} + mypklist = [{'fingerprint': None, 'path': 'path1', 'value': 'value1'}] + pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist] + data = {'ovfcontent': construct_valid_ovf_env(data=odata, + pubkeys=pubkeys)} + + dsrc = self._get_ds(data) + ret = dsrc.get_data() + self.assertTrue(ret) + + for mypk in mypklist: + self.assertIn(mypk['value'], dsrc.metadata['public-keys']) + def test_default_ephemeral(self): # make sure the ephemeral device works @@ -642,8 +678,8 @@ class TestReadAzureOvf(TestCase): DataSourceAzure.read_azure_ovf, invalid_xml) def test_load_with_pubkeys(self): - mypklist = [{'fingerprint': 'fp1', 'path': 'path1'}] - pubkeys = [(x['fingerprint'], x['path']) for x in mypklist] + mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}] + pubkeys = [(x['fingerprint'], x['path'], x['value']) for x in mypklist] content = construct_valid_ovf_env(pubkeys=pubkeys) (_md, _ud, cfg) = DataSourceAzure.read_azure_ovf(content) for mypk in mypklist: -- cgit v1.2.3 From 34b208a05361ae6ab4a51a6a999c9ac4ab77f06a Mon Sep 17 00:00:00 2001 From: Daniel Watkins Date: Fri, 30 Oct 2015 16:26:31 +0000 Subject: Use DMI data to find Azure instance IDs. This replaces the use of SharedConfig.xml in both the walinuxagent case, and the case where we communicate with the Azure fabric ourselves. --- cloudinit/sources/DataSourceAzure.py | 38 +---------- cloudinit/sources/helpers/azure.py | 21 ------ tests/unittests/test_datasource/test_azure.py | 77 +++++----------------- .../unittests/test_datasource/test_azure_helper.py | 42 +----------- 4 files changed, 23 insertions(+), 155 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index c6228e6c..bd80a8a6 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -31,8 +31,7 @@ from cloudinit import log as logging from cloudinit.settings import PER_ALWAYS from cloudinit import sources from cloudinit import util -from cloudinit.sources.helpers.azure import ( - get_metadata_from_fabric, iid_from_shared_config_content) +from cloudinit.sources.helpers.azure import get_metadata_from_fabric LOG = logging.getLogger(__name__) @@ -41,7 +40,6 @@ DEFAULT_METADATA = {"instance-id": "iid-AZURE-NODE"} AGENT_START = ['service', 'walinuxagent', 'start'] BOUNCE_COMMAND = ['sh', '-xc', "i=$interface; x=0; ifdown $i || x=$?; ifup $i || x=$?; exit $x"] -DATA_DIR_CLEAN_LIST = ['SharedConfig.xml'] BUILTIN_DS_CONFIG = { 'agent_command': AGENT_START, @@ -144,8 +142,6 @@ class DataSourceAzureNet(sources.DataSource): self.ds_cfg['agent_command']) ddir = self.ds_cfg['data_dir'] - shcfgxml = os.path.join(ddir, "SharedConfig.xml") - wait_for = [shcfgxml] fp_files = [] key_value = None @@ -160,19 +156,11 @@ class DataSourceAzureNet(sources.DataSource): missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", func=wait_for_files, - args=(wait_for + fp_files,)) + args=(fp_files,)) if len(missing): LOG.warn("Did not find files, but going on: %s", missing) metadata = {} - if shcfgxml in missing: - LOG.warn("SharedConfig.xml missing, using static instance-id") - else: - try: - metadata['instance-id'] = iid_from_shared_config(shcfgxml) - except ValueError as e: - LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) - metadata['public-keys'] = key_value or pubkeys_from_crt_files(fp_files) return metadata @@ -229,21 +217,6 @@ class DataSourceAzureNet(sources.DataSource): user_ds_cfg = util.get_cfg_by_path(self.cfg, DS_CFG_PATH, {}) self.ds_cfg = util.mergemanydict([user_ds_cfg, self.ds_cfg]) - if found != ddir: - cached_ovfenv = util.load_file( - os.path.join(ddir, 'ovf-env.xml'), quiet=True, decode=False) - if cached_ovfenv != files['ovf-env.xml']: - # source was not walinux-agent's datadir, so we have to clean - # up so 'wait_for_files' doesn't return early due to stale data - cleaned = [] - for f in [os.path.join(ddir, f) for f in DATA_DIR_CLEAN_LIST]: - if os.path.exists(f): - util.del_file(f) - cleaned.append(f) - if cleaned: - LOG.info("removed stale file(s) in '%s': %s", - ddir, str(cleaned)) - # walinux agent writes files world readable, but expects # the directory to be protected. write_files(ddir, files, dirmode=0o700) @@ -259,6 +232,7 @@ class DataSourceAzureNet(sources.DataSource): " on Azure.", exc_info=True) return False + self.metadata['instance-id'] = util.read_dmi_data('system-uuid') self.metadata.update(fabric_data) found_ephemeral = find_fabric_formatted_ephemeral_disk() @@ -649,12 +623,6 @@ def load_azure_ds_dir(source_dir): return (md, ud, cfg, {'ovf-env.xml': contents}) -def iid_from_shared_config(path): - with open(path, "rb") as fp: - content = fp.read() - return iid_from_shared_config_content(content) - - class BrokenAzureDataSource(Exception): pass diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 281d733e..d90c22fd 100644 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -78,12 +78,6 @@ class GoalState(object): return self._text_from_xpath( './Container/RoleInstanceList/RoleInstance/InstanceId') - @property - def shared_config_xml(self): - url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' - '/Configuration/SharedConfig') - return self.http_client.get(url).contents - @property def certificates_xml(self): if self._certificates_xml is None: @@ -172,19 +166,6 @@ class OpenSSLManager(object): return keys -def iid_from_shared_config_content(content): - """ - find INSTANCE_ID in: - - - - - """ - root = ElementTree.fromstring(content) - depnode = root.find('Deployment') - return depnode.get('name') - - class WALinuxAgentShim(object): REPORT_READY_XML_TEMPLATE = '\n'.join([ @@ -263,8 +244,6 @@ class WALinuxAgentShim(object): public_keys = self.openssl_manager.parse_certificates( goal_state.certificates_xml) data = { - 'instance-id': iid_from_shared_config_content( - goal_state.shared_config_xml), 'public-keys': public_keys, } self._report_ready(goal_state, http_client) diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index ec0435f5..3933794f 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -115,10 +115,6 @@ class TestAzureDataSource(TestCase): data['pubkey_files'] = flist return ["pubkey_from: %s" % f for f in flist] - def _iid_from_shared_config(path): - data['iid_from_shared_cfg'] = path - return 'i-my-azure-id' - if data.get('ovfcontent') is not None: populate_dir(os.path.join(self.paths.seed_dir, "azure"), {'ovf-env.xml': data['ovfcontent']}) @@ -127,20 +123,22 @@ class TestAzureDataSource(TestCase): mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d self.get_metadata_from_fabric = mock.MagicMock(return_value={ - 'instance-id': 'i-my-azure-id', 'public-keys': [], }) + self.instance_id = 'test-instance-id' + self.apply_patches([ (mod, 'list_possible_azure_ds_devs', dsdevs), (mod, 'invoke_agent', _invoke_agent), (mod, 'wait_for_files', _wait_for_files), (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files), - (mod, 'iid_from_shared_config', _iid_from_shared_config), (mod, 'perform_hostname_bounce', mock.MagicMock()), (mod, 'get_hostname', mock.MagicMock()), (mod, 'set_hostname', mock.MagicMock()), (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric), + (mod.util, 'read_dmi_data', mock.MagicMock( + return_value=self.instance_id)), ]) dsrc = mod.DataSourceAzureNet( @@ -193,7 +191,6 @@ class TestAzureDataSource(TestCase): self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName']) self.assertTrue(os.path.isfile( os.path.join(self.waagent_d, 'ovf-env.xml'))) - self.assertEqual(dsrc.metadata['instance-id'], 'i-my-azure-id') def test_waagent_d_has_0700_perms(self): # we expect /var/lib/waagent to be created 0700 @@ -345,7 +342,6 @@ class TestAzureDataSource(TestCase): for mypk in mypklist: self.assertIn(mypk['value'], dsrc.metadata['public-keys']) - def test_default_ephemeral(self): # make sure the ephemeral device works odata = {} @@ -434,54 +430,6 @@ class TestAzureDataSource(TestCase): dsrc = self._get_ds({'ovfcontent': xml}) dsrc.get_data() - def test_existing_ovf_same(self): - # waagent/SharedConfig left alone if found ovf-env.xml same as cached - odata = {'UserData': b64e("SOMEUSERDATA")} - data = {'ovfcontent': construct_valid_ovf_env(data=odata)} - - populate_dir(self.waagent_d, - {'ovf-env.xml': data['ovfcontent'], - 'otherfile': 'otherfile-content', - 'SharedConfig.xml': 'mysharedconfig'}) - - dsrc = self._get_ds(data) - ret = dsrc.get_data() - self.assertTrue(ret) - self.assertTrue(os.path.exists( - os.path.join(self.waagent_d, 'ovf-env.xml'))) - self.assertTrue(os.path.exists( - os.path.join(self.waagent_d, 'otherfile'))) - self.assertTrue(os.path.exists( - os.path.join(self.waagent_d, 'SharedConfig.xml'))) - - def test_existing_ovf_diff(self): - # waagent/SharedConfig must be removed if ovfenv is found elsewhere - - # 'get_data' should remove SharedConfig.xml in /var/lib/waagent - # if ovf-env.xml differs. - cached_ovfenv = construct_valid_ovf_env( - {'userdata': b64e("FOO_USERDATA")}) - new_ovfenv = construct_valid_ovf_env( - {'userdata': b64e("NEW_USERDATA")}) - - populate_dir(self.waagent_d, - {'ovf-env.xml': cached_ovfenv, - 'SharedConfig.xml': "mysharedconfigxml", - 'otherfile': 'otherfilecontent'}) - - dsrc = self._get_ds({'ovfcontent': new_ovfenv}) - ret = dsrc.get_data() - self.assertTrue(ret) - self.assertEqual(dsrc.userdata_raw, b"NEW_USERDATA") - self.assertTrue(os.path.exists( - os.path.join(self.waagent_d, 'otherfile'))) - self.assertFalse(os.path.exists( - os.path.join(self.waagent_d, 'SharedConfig.xml'))) - self.assertTrue(os.path.exists( - os.path.join(self.waagent_d, 'ovf-env.xml'))) - new_xml = load_file(os.path.join(self.waagent_d, 'ovf-env.xml')) - self.xml_equals(new_ovfenv, new_xml) - def test_exception_fetching_fabric_data_doesnt_propagate(self): ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) ds.ds_cfg['agent_command'] = '__builtin__' @@ -496,6 +444,17 @@ class TestAzureDataSource(TestCase): self.assertTrue(ret) self.assertEqual('value', ds.metadata['test']) + def test_instance_id_from_dmidecode_used(self): + ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + ds.get_data() + self.assertEqual(self.instance_id, ds.metadata['instance-id']) + + def test_instance_id_from_dmidecode_used_for_builtin(self): + ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + ds.ds_cfg['agent_command'] = '__builtin__' + ds.get_data() + self.assertEqual(self.instance_id, ds.metadata['instance-id']) + class TestAzureBounce(TestCase): @@ -504,9 +463,6 @@ class TestAzureBounce(TestCase): mock.patch.object(DataSourceAzure, 'invoke_agent')) self.patches.enter_context( mock.patch.object(DataSourceAzure, 'wait_for_files')) - self.patches.enter_context( - mock.patch.object(DataSourceAzure, 'iid_from_shared_config', - mock.MagicMock(return_value='i-my-azure-id'))) self.patches.enter_context( mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs', mock.MagicMock(return_value=[]))) @@ -521,6 +477,9 @@ class TestAzureBounce(TestCase): self.patches.enter_context( mock.patch.object(DataSourceAzure, 'get_metadata_from_fabric', mock.MagicMock(return_value={}))) + self.patches.enter_context( + mock.patch.object(DataSourceAzure.util, 'read_dmi_data', + mock.MagicMock(return_value='test-instance-id'))) def setUp(self): super(TestAzureBounce, self).setUp() diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index a5228870..0638c974 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -40,7 +40,7 @@ GOAL_STATE_TEMPLATE = """\ http://100.86.192.70:80/...hostingEnvironmentConfig... - {shared_config_url} + http://100.86.192.70:80/..SharedConfig.. http://100.86.192.70:80/...extensionsConfig... @@ -55,21 +55,6 @@ GOAL_STATE_TEMPLATE = """\ """ -class TestReadAzureSharedConfig(unittest.TestCase): - - def test_valid_content(self): - xml = """ - - - - - - - """ - ret = azure_helper.iid_from_shared_config_content(xml) - self.assertEqual("MY_INSTANCE_ID", ret) - - class TestFindEndpoint(TestCase): def setUp(self): @@ -140,7 +125,6 @@ class TestGoalStateParsing(TestCase): 'incarnation': 1, 'container_id': 'MyContainerId', 'instance_id': 'MyInstanceId', - 'shared_config_url': 'MySharedConfigUrl', 'certificates_url': 'MyCertificatesUrl', } @@ -174,20 +158,9 @@ class TestGoalStateParsing(TestCase): goal_state = self._get_goal_state(instance_id=instance_id) self.assertEqual(instance_id, goal_state.instance_id) - def test_shared_config_xml_parsed_and_fetched_correctly(self): - http_client = mock.MagicMock() - shared_config_url = 'TestSharedConfigUrl' - goal_state = self._get_goal_state( - http_client=http_client, shared_config_url=shared_config_url) - shared_config_xml = goal_state.shared_config_xml - self.assertEqual(1, http_client.get.call_count) - self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) - self.assertEqual(http_client.get.return_value.contents, - shared_config_xml) - def test_certificates_xml_parsed_and_fetched_correctly(self): http_client = mock.MagicMock() - certificates_url = 'TestSharedConfigUrl' + certificates_url = 'TestCertificatesUrl' goal_state = self._get_goal_state( http_client=http_client, certificates_url=certificates_url) certificates_xml = goal_state.certificates_xml @@ -324,8 +297,6 @@ class TestWALinuxAgentShim(TestCase): azure_helper.WALinuxAgentShim, 'find_endpoint')) self.GoalState = patches.enter_context( mock.patch.object(azure_helper, 'GoalState')) - self.iid_from_shared_config_content = patches.enter_context( - mock.patch.object(azure_helper, 'iid_from_shared_config_content')) self.OpenSSLManager = patches.enter_context( mock.patch.object(azure_helper, 'OpenSSLManager')) patches.enter_context( @@ -367,15 +338,6 @@ class TestWALinuxAgentShim(TestCase): data = shim.register_with_azure_and_fetch_data() self.assertEqual([], data['public-keys']) - def test_instance_id_returned_in_data(self): - shim = azure_helper.WALinuxAgentShim() - data = shim.register_with_azure_and_fetch_data() - self.assertEqual( - [mock.call(self.GoalState.return_value.shared_config_xml)], - self.iid_from_shared_config_content.call_args_list) - self.assertEqual(self.iid_from_shared_config_content.return_value, - data['instance-id']) - def test_correct_url_used_for_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = azure_helper.WALinuxAgentShim() -- cgit v1.2.3 From ee40614b0a34a110265493c176c64db823aa34b3 Mon Sep 17 00:00:00 2001 From: Wesley Wiedenmeier Date: Wed, 3 Feb 2016 22:21:40 -0600 Subject: lxd: add support for setting up lxd using 'lxd init' If lxd key is present in cfg, then run 'lxd init' with values from the 'init' entry in lxd configuration as flags. --- ChangeLog | 1 + cloudinit/config/cc_lxd.py | 50 +++++++++++++++++++ config/cloud.cfg | 1 + tests/unittests/test_handler/test_handler_lxd.py | 62 ++++++++++++++++++++++++ 4 files changed, 114 insertions(+) create mode 100644 cloudinit/config/cc_lxd.py create mode 100644 tests/unittests/test_handler/test_handler_lxd.py (limited to 'tests/unittests') diff --git a/ChangeLog b/ChangeLog index 0ba16492..9fbc920d 100644 --- a/ChangeLog +++ b/ChangeLog @@ -71,6 +71,7 @@ - Azure: get instance id from dmi instead of SharedConfig (LP: #1506187) - systemd/power_state: fix power_state to work even if cloud-final exited non-zero (LP: #1449318) + - lxd: add support for setting up lxd using 'lxd init' 0.7.6: - open 0.7.6 - Enable vendordata on CloudSigma datasource (LP: #1303986) diff --git a/cloudinit/config/cc_lxd.py b/cloudinit/config/cc_lxd.py new file mode 100644 index 00000000..0db8356b --- /dev/null +++ b/cloudinit/config/cc_lxd.py @@ -0,0 +1,50 @@ +# vi: ts=4 expandtab +# +# Copyright (C) 2016 Canonical Ltd. +# +# Author: Wesley Wiedenmeier +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3, as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +""" +This module initializes lxd using 'lxd init' + +Example config: + #cloud-config + lxd: + init: + network_address: + network_port: + storage_backend: + storage_create_device: + storage_create_loop: + storage_pool: + trust_password: +""" + +from cloudinit import util + + +def handle(name, cfg, cloud, log, args): + if not cfg.get('lxd') and cfg['lxd'].get('init'): + log.debug("Skipping module named %s, not present or disabled by cfg") + return + lxd_conf = cfg['lxd']['init'] + keys = ('network_address', 'network_port', 'storage_backend', + 'storage_create_device', 'storage_create_loop', 'storage_pool', + 'trust_password') + cmd = ['lxd', 'init', '--auto'] + for k in keys: + if lxd_conf.get(k): + cmd.extend(["--%s" % k.replace('_', '-'), lxd_conf[k]]) + util.subp(cmd) diff --git a/config/cloud.cfg b/config/cloud.cfg index 74794ab0..795df19f 100644 --- a/config/cloud.cfg +++ b/config/cloud.cfg @@ -56,6 +56,7 @@ cloud_config_modules: - fan - landscape - timezone + - lxd - puppet - chef - salt-minion diff --git a/tests/unittests/test_handler/test_handler_lxd.py b/tests/unittests/test_handler/test_handler_lxd.py new file mode 100644 index 00000000..89863d52 --- /dev/null +++ b/tests/unittests/test_handler/test_handler_lxd.py @@ -0,0 +1,62 @@ +from cloudinit.config import cc_lxd +from cloudinit import (util, distros, helpers, cloud) +from cloudinit.sources import DataSourceNoCloud +from .. import helpers as t_help + +import logging + +LOG = logging.getLogger(__name__) + + +class TestLxd(t_help.TestCase): + def setUp(self): + super(TestLxd, self).setUp() + self.unapply = [] + apply_patches([(util, 'subp', self._mock_subp)]) + self.subp_called = [] + + def tearDown(self): + apply_patches([i for i in reversed(self.unapply)]) + + def _mock_subp(self, *args, **kwargs): + if 'args' not in kwargs: + kwargs['args'] = args[0] + self.subp_called.append(kwargs) + return + + def _get_cloud(self, distro): + cls = distros.fetch(distro) + paths = helpers.Paths({}) + d = cls(distro, {}, paths) + ds = DataSourceNoCloud.DataSourceNoCloud({}, d, paths) + cc = cloud.Cloud(ds, paths, {}, d, None) + return cc + + def test_lxd_init(self): + cfg = { + 'lxd': { + 'init': { + 'network_address': '0.0.0.0', + 'storage_backend': 'zfs', + 'storage_pool': 'poolname', + } + } + } + cc = self._get_cloud('ubuntu') + cc_lxd.handle('cc_lxd', cfg, cc, LOG, []) + + self.assertEqual( + self.subp_called[0].get('args'), + ['lxd', 'init', '--auto', '--network-address', '0.0.0.0', + '--storage-backend', 'zfs', '--storage-pool', 'poolname']) + + +def apply_patches(patches): + ret = [] + for (ref, name, replace) in patches: + if replace is None: + continue + orig = getattr(ref, name) + setattr(ref, name, replace) + ret.append((ref, name, orig)) + return ret -- cgit v1.2.3 From 75ba44d2730b89f13b2069961ea8de63f65ea780 Mon Sep 17 00:00:00 2001 From: Robert Jennings Date: Thu, 4 Feb 2016 15:52:08 -0600 Subject: SmartOS: Add support for Joyent LX-Brand Zones (LP: #1540965) LX-brand zones on Joyent's SmartOS use a different metadata source (socket file) than the KVM-based SmartOS virtualization (serial port). This patch adds support for recognizing the different flavors of virtualization on SmartOS and setting up a metadata source file object. After the file object is created, the rest of the code for the datasource LP: #1540965 --- cloudinit/sources/DataSourceSmartOS.py | 257 ++++++++++++++---------- doc/examples/cloud-config-datasources.txt | 7 + tests/unittests/test_datasource/test_smartos.py | 85 +++++--- 3 files changed, 216 insertions(+), 133 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index c9b497df..7453379a 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -20,10 +20,13 @@ # Datasource for provisioning on SmartOS. This works on Joyent # and public/private Clouds using SmartOS. # -# SmartOS hosts use a serial console (/dev/ttyS1) on Linux Guests. +# SmartOS hosts use a serial console (/dev/ttyS1) on KVM Linux Guests # The meta-data is transmitted via key/value pairs made by # requests on the console. For example, to get the hostname, you # would send "GET hostname" on /dev/ttyS1. +# For Linux Guests running in LX-Brand Zones on SmartOS hosts +# a socket (/native/.zonecontrol/metadata.sock) is used instead +# of a serial console. # # Certain behavior is defined by the DataDictionary # http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html @@ -34,6 +37,8 @@ import contextlib import os import random import re +import socket +import stat import serial @@ -46,6 +51,7 @@ LOG = logging.getLogger(__name__) SMARTOS_ATTRIB_MAP = { # Cloud-init Key : (SmartOS Key, Strip line endings) + 'instance-id': ('sdc:uuid', True), 'local-hostname': ('hostname', True), 'public-keys': ('root_authorized_keys', True), 'user-script': ('user-script', False), @@ -76,6 +82,7 @@ DS_CFG_PATH = ['datasource', DS_NAME] # BUILTIN_DS_CONFIG = { 'serial_device': '/dev/ttyS1', + 'metadata_sockfile': '/native/.zonecontrol/metadata.sock', 'seed_timeout': 60, 'no_base64_decode': ['root_authorized_keys', 'motd_sys_info', @@ -83,6 +90,7 @@ BUILTIN_DS_CONFIG = { 'user-data', 'user-script', 'sdc:datacenter_name', + 'sdc:uuid', ], 'base64_keys': [], 'base64_all': False, @@ -150,17 +158,27 @@ class DataSourceSmartOS(sources.DataSource): def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.is_smartdc = None - self.ds_cfg = util.mergemanydict([ self.ds_cfg, util.get_cfg_by_path(sys_cfg, DS_CFG_PATH, {}), BUILTIN_DS_CONFIG]) self.metadata = {} - self.cfg = BUILTIN_CLOUD_CONFIG - self.seed = self.ds_cfg.get("serial_device") - self.seed_timeout = self.ds_cfg.get("serial_timeout") + # SDC LX-Brand Zones lack dmidecode (no /dev/mem) but + # report 'BrandZ virtual linux' as the kernel version + if os.uname()[3].lower() == 'brandz virtual linux': + LOG.debug("Host is SmartOS, guest in Zone") + self.is_smartdc = True + self.smartos_type = 'lx-brand' + self.cfg = {} + self.seed = self.ds_cfg.get("metadata_sockfile") + else: + self.is_smartdc = True + self.smartos_type = 'kvm' + self.seed = self.ds_cfg.get("serial_device") + self.cfg = BUILTIN_CLOUD_CONFIG + self.seed_timeout = self.ds_cfg.get("serial_timeout") self.smartos_no_base64 = self.ds_cfg.get('no_base64_decode') self.b64_keys = self.ds_cfg.get('base64_keys') self.b64_all = self.ds_cfg.get('base64_all') @@ -170,12 +188,49 @@ class DataSourceSmartOS(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [seed=%s]" % (root, self.seed) + def _get_seed_file_object(self): + if not self.seed: + raise AttributeError("seed device is not set") + + if self.smartos_type == 'lx-brand': + if not stat.S_ISSOCK(os.stat(self.seed).st_mode): + LOG.debug("Seed %s is not a socket", self.seed) + return None + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(self.seed) + return sock.makefile('rwb') + else: + if not stat.S_ISCHR(os.stat(self.seed).st_mode): + LOG.debug("Seed %s is not a character device") + return None + ser = serial.Serial(self.seed, timeout=self.seed_timeout) + if not ser.isOpen(): + raise SystemError("Unable to open %s" % self.seed) + return ser + return None + + def _set_provisioned(self): + '''Mark the instance provisioning state as successful. + + When run in a zone, the host OS will look for /var/svc/provisioning + to be renamed as /var/svc/provision_success. This should be done + after meta-data is successfully retrieved and from this point + the host considers the provision of the zone to be a success and + keeps the zone running. + ''' + + LOG.debug('Instance provisioning state set as successful') + svc_path = '/var/svc' + if os.path.exists('/'.join([svc_path, 'provisioning'])): + os.rename('/'.join([svc_path, 'provisioning']), + '/'.join([svc_path, 'provision_success'])) + def get_data(self): md = {} ud = "" if not device_exists(self.seed): - LOG.debug("No serial device '%s' found for SmartOS datasource", + LOG.debug("No metadata device '%s' found for SmartOS datasource", self.seed) return False @@ -185,29 +240,36 @@ class DataSourceSmartOS(sources.DataSource): LOG.debug("Disabling SmartOS datasource on arm (LP: #1243287)") return False - dmi_info = dmi_data() - if dmi_info is False: - LOG.debug("No dmidata utility found") - return False - - system_uuid, system_type = tuple(dmi_info) - if 'smartdc' not in system_type.lower(): - LOG.debug("Host is not on SmartOS. system_type=%s", system_type) + # SDC KVM instances will provide dmi data, LX-brand does not + if self.smartos_type == 'kvm': + dmi_info = dmi_data() + if dmi_info is False: + LOG.debug("No dmidata utility found") + return False + + system_type = dmi_info + if 'smartdc' not in system_type.lower(): + LOG.debug("Host is not on SmartOS. system_type=%s", + system_type) + return False + LOG.debug("Host is SmartOS, guest in KVM") + + seed_obj = self._get_seed_file_object() + if seed_obj is None: + LOG.debug('Seed file object not found.') return False - self.is_smartdc = True - md['instance-id'] = system_uuid + with contextlib.closing(seed_obj) as seed: + b64_keys = self.query('base64_keys', seed, strip=True, b64=False) + if b64_keys is not None: + self.b64_keys = [k.strip() for k in str(b64_keys).split(',')] - b64_keys = self.query('base64_keys', strip=True, b64=False) - if b64_keys is not None: - self.b64_keys = [k.strip() for k in str(b64_keys).split(',')] + b64_all = self.query('base64_all', seed, strip=True, b64=False) + if b64_all is not None: + self.b64_all = util.is_true(b64_all) - b64_all = self.query('base64_all', strip=True, b64=False) - if b64_all is not None: - self.b64_all = util.is_true(b64_all) - - for ci_noun, attribute in SMARTOS_ATTRIB_MAP.items(): - smartos_noun, strip = attribute - md[ci_noun] = self.query(smartos_noun, strip=strip) + for ci_noun, attribute in SMARTOS_ATTRIB_MAP.items(): + smartos_noun, strip = attribute + md[ci_noun] = self.query(smartos_noun, seed, strip=strip) # @datadictionary: This key may contain a program that is written # to a file in the filesystem of the guest on each boot and then @@ -240,7 +302,7 @@ class DataSourceSmartOS(sources.DataSource): # Handle the cloud-init regular meta if not md['local-hostname']: - md['local-hostname'] = system_uuid + md['local-hostname'] = md['instance-id'] ud = None if md['user-data']: @@ -257,6 +319,8 @@ class DataSourceSmartOS(sources.DataSource): self.metadata = util.mergemanydict([md, self.metadata]) self.userdata_raw = ud self.vendordata_raw = md['vendor-data'] + + self._set_provisioned() return True def device_name_to_device(self, name): @@ -268,40 +332,64 @@ class DataSourceSmartOS(sources.DataSource): def get_instance_id(self): return self.metadata['instance-id'] - def query(self, noun, strip=False, default=None, b64=None): + def query(self, noun, seed_file, strip=False, default=None, b64=None): if b64 is None: if noun in self.smartos_no_base64: b64 = False elif self.b64_all or noun in self.b64_keys: b64 = True - return query_data(noun=noun, strip=strip, seed_device=self.seed, - seed_timeout=self.seed_timeout, default=default, - b64=b64) + return self._query_data(noun, seed_file, strip=strip, + default=default, b64=b64) + def _query_data(self, noun, seed_file, strip=False, + default=None, b64=None): + """Makes a request via "GET " -def device_exists(device): - """Symplistic method to determine if the device exists or not""" - return os.path.exists(device) + In the response, the first line is the status, while subsequent + lines are is the value. A blank line with a "." is used to + indicate end of response. + If the response is expected to be base64 encoded, then set + b64encoded to true. Unfortantely, there is no way to know if + something is 100% encoded, so this method relies on being told + if the data is base64 or not. + """ -def get_serial(seed_device, seed_timeout): - """This is replaced in unit testing, allowing us to replace - serial.Serial with a mocked class. + if not noun: + return False - The timeout value of 60 seconds should never be hit. The value - is taken from SmartOS own provisioning tools. Since we are reading - each line individually up until the single ".", the transfer is - usually very fast (i.e. microseconds) to get the response. - """ - if not seed_device: - raise AttributeError("seed_device value is not set") + response = JoyentMetadataClient(seed_file).get_metadata(noun) + + if response is None: + return default + + if b64 is None: + b64 = self._query_data('b64-%s' % noun, seed_file, b64=False, + default=False, strip=True) + b64 = util.is_true(b64) + + resp = None + if b64 or strip: + resp = "".join(response).rstrip() + else: + resp = "".join(response) - ser = serial.Serial(seed_device, timeout=seed_timeout) - if not ser.isOpen(): - raise SystemError("Unable to open %s" % seed_device) + if b64: + try: + return util.b64d(resp) + # Bogus input produces different errors in Python 2 and 3; + # catch both. + except (TypeError, binascii.Error): + LOG.warn("Failed base64 decoding key '%s'", noun) + return resp - return ser + return resp + + +def device_exists(device): + """Symplistic method to determine if the device exists or not""" + return os.path.exists(device) class JoyentMetadataFetchException(Exception): @@ -320,8 +408,8 @@ class JoyentMetadataClient(object): r' (?P(?P[0-9a-f]+) (?PSUCCESS|NOTFOUND)' r'( (?P.+))?)') - def __init__(self, serial): - self.serial = serial + def __init__(self, metasource): + self.metasource = metasource def _checksum(self, body): return '{0:08x}'.format( @@ -356,67 +444,30 @@ class JoyentMetadataClient(object): util.b64e(metadata_key)) msg = 'V2 {0} {1} {2}\n'.format( len(message_body), self._checksum(message_body), message_body) - LOG.debug('Writing "%s" to serial port.', msg) - self.serial.write(msg.encode('ascii')) - response = self.serial.readline().decode('ascii') - LOG.debug('Read "%s" from serial port.', response) - return self._get_value_from_frame(request_id, response) - - -def query_data(noun, seed_device, seed_timeout, strip=False, default=None, - b64=None): - """Makes a request to via the serial console via "GET " - - In the response, the first line is the status, while subsequent lines - are is the value. A blank line with a "." is used to indicate end of - response. - - If the response is expected to be base64 encoded, then set b64encoded - to true. Unfortantely, there is no way to know if something is 100% - encoded, so this method relies on being told if the data is base64 or - not. - """ - if not noun: - return False - - with contextlib.closing(get_serial(seed_device, seed_timeout)) as ser: - client = JoyentMetadataClient(ser) - response = client.get_metadata(noun) - - if response is None: - return default - - if b64 is None: - b64 = query_data('b64-%s' % noun, seed_device=seed_device, - seed_timeout=seed_timeout, b64=False, - default=False, strip=True) - b64 = util.is_true(b64) - - resp = None - if b64 or strip: - resp = "".join(response).rstrip() - else: - resp = "".join(response) - - if b64: - try: - return util.b64d(resp) - # Bogus input produces different errors in Python 2 and 3; catch both. - except (TypeError, binascii.Error): - LOG.warn("Failed base64 decoding key '%s'", noun) - return resp + LOG.debug('Writing "%s" to metadata transport.', msg) + self.metasource.write(msg.encode('ascii')) + self.metasource.flush() + + response = bytearray() + response.extend(self.metasource.read(1)) + while response[-1:] != b'\n': + response.extend(self.metasource.read(1)) + response = response.rstrip().decode('ascii') + LOG.debug('Read "%s" from metadata transport.', response) + + if 'SUCCESS' not in response: + return None - return resp + return self._get_value_from_frame(request_id, response) def dmi_data(): - sys_uuid = util.read_dmi_data("system-uuid") sys_type = util.read_dmi_data("system-product-name") - if not sys_uuid or not sys_type: + if not sys_type: return None - return (sys_uuid.lower(), sys_type) + return sys_type def write_boot_content(content, content_f, link=None, shebang=False, diff --git a/doc/examples/cloud-config-datasources.txt b/doc/examples/cloud-config-datasources.txt index 3bde4aac..2651c027 100644 --- a/doc/examples/cloud-config-datasources.txt +++ b/doc/examples/cloud-config-datasources.txt @@ -51,12 +51,19 @@ datasource: policy: on # [can be 'on', 'off' or 'force'] SmartOS: + # For KVM guests: # Smart OS datasource works over a serial console interacting with # a server on the other end. By default, the second serial console is the # device. SmartOS also uses a serial timeout of 60 seconds. serial_device: /dev/ttyS1 serial_timeout: 60 + # For LX-Brand Zones guests: + # Smart OS datasource works over a socket interacting with + # the host on the other end. By default, the socket file is in + # the native .zoncontrol directory. + metadata_sockfile: /native/.zonecontrol/metadata.sock + # a list of keys that will not be base64 decoded even if base64_all no_base64_decode: ['root_authorized_keys', 'motd_sys_info', 'iptables_disable'] diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index adee9019..1235436d 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -31,6 +31,7 @@ import shutil import stat import tempfile import uuid +import unittest from binascii import crc32 import serial @@ -56,12 +57,13 @@ MOCK_RETURNS = { 'cloud-init:user-data': '\n'.join(['#!/bin/sh', '/bin/true', '']), 'sdc:datacenter_name': 'somewhere2', 'sdc:operator-script': '\n'.join(['bin/true', '']), + 'sdc:uuid': str(uuid.uuid4()), 'sdc:vendor-data': '\n'.join(['VENDOR_DATA', '']), 'user-data': '\n'.join(['something', '']), 'user-script': '\n'.join(['/bin/true', '']), } -DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc') +DMI_DATA_RETURN = 'smartdc' def get_mock_client(mockdata): @@ -111,7 +113,8 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): ret = apply_patches(patches) self.unapply += ret - def _get_ds(self, sys_cfg=None, ds_cfg=None, mockdata=None, dmi_data=None): + def _get_ds(self, sys_cfg=None, ds_cfg=None, mockdata=None, dmi_data=None, + is_lxbrand=False): mod = DataSourceSmartOS if mockdata is None: @@ -124,9 +127,13 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): return dmi_data def _os_uname(): - # LP: #1243287. tests assume this runs, but running test on - # arm would cause them all to fail. - return ('LINUX', 'NODENAME', 'RELEASE', 'VERSION', 'x86_64') + if not is_lxbrand: + # LP: #1243287. tests assume this runs, but running test on + # arm would cause them all to fail. + return ('LINUX', 'NODENAME', 'RELEASE', 'VERSION', 'x86_64') + else: + return ('LINUX', 'NODENAME', 'RELEASE', 'BRANDZ VIRTUAL LINUX', + 'X86_64') if sys_cfg is None: sys_cfg = {} @@ -136,7 +143,6 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): sys_cfg['datasource']['SmartOS'] = ds_cfg self.apply_patches([(mod, 'LEGACY_USER_D', self.legacy_user_d)]) - self.apply_patches([(mod, 'get_serial', mock.MagicMock())]) self.apply_patches([ (mod, 'JoyentMetadataClient', get_mock_client(mockdata))]) self.apply_patches([(mod, 'dmi_data', _dmi_data)]) @@ -144,6 +150,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): self.apply_patches([(mod, 'device_exists', lambda d: True)]) dsrc = mod.DataSourceSmartOS(sys_cfg, distro=None, paths=self.paths) + self.apply_patches([(dsrc, '_get_seed_file_object', mock.MagicMock())]) return dsrc def test_seed(self): @@ -151,14 +158,29 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): dsrc = self._get_ds() ret = dsrc.get_data() self.assertTrue(ret) + self.assertEquals('kvm', dsrc.smartos_type) self.assertEquals('/dev/ttyS1', dsrc.seed) + def test_seed_lxbrand(self): + # default seed should be /dev/ttyS1 + dsrc = self._get_ds(is_lxbrand=True) + ret = dsrc.get_data() + self.assertTrue(ret) + self.assertEquals('lx-brand', dsrc.smartos_type) + self.assertEquals('/native/.zonecontrol/metadata.sock', dsrc.seed) + def test_issmartdc(self): dsrc = self._get_ds() ret = dsrc.get_data() self.assertTrue(ret) self.assertTrue(dsrc.is_smartdc) + def test_issmartdc_lxbrand(self): + dsrc = self._get_ds(is_lxbrand=True) + ret = dsrc.get_data() + self.assertTrue(ret) + self.assertTrue(dsrc.is_smartdc) + def test_no_base64(self): ds_cfg = {'no_base64_decode': ['test_var1'], 'all_base': True} dsrc = self._get_ds(ds_cfg=ds_cfg) @@ -169,7 +191,8 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): dsrc = self._get_ds(mockdata=MOCK_RETURNS) ret = dsrc.get_data() self.assertTrue(ret) - self.assertEquals(DMI_DATA_RETURN[0], dsrc.metadata['instance-id']) + self.assertEquals(MOCK_RETURNS['sdc:uuid'], + dsrc.metadata['instance-id']) def test_root_keys(self): dsrc = self._get_ds(mockdata=MOCK_RETURNS) @@ -407,18 +430,6 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): self.assertEqual(dsrc.device_name_to_device('FOO'), mydscfg['disk_aliases']['FOO']) - @mock.patch('cloudinit.sources.DataSourceSmartOS.JoyentMetadataClient') - @mock.patch('cloudinit.sources.DataSourceSmartOS.get_serial') - def test_serial_console_closed_on_error(self, get_serial, metadata_client): - class OurException(Exception): - pass - metadata_client.side_effect = OurException - try: - DataSourceSmartOS.query_data('noun', 'device', 0) - except OurException: - pass - self.assertEqual(1, get_serial.return_value.close.call_count) - def apply_patches(patches): ret = [] @@ -447,14 +458,25 @@ class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): } def make_response(): - payload = '' - if self.response_parts['payload']: - payload = ' {0}'.format(self.response_parts['payload']) - del self.response_parts['payload'] - return ( - 'V2 {length} {crc} {request_id} {command}{payload}\n'.format( - payload=payload, **self.response_parts).encode('ascii')) - self.serial.readline.side_effect = make_response + payloadstr = '' + if 'payload' in self.response_parts: + payloadstr = ' {0}'.format(self.response_parts['payload']) + return ('V2 {length} {crc} {request_id} ' + '{command}{payloadstr}\n'.format( + payloadstr=payloadstr, + **self.response_parts).encode('ascii')) + + self.metasource_data = None + + def read_response(length): + if not self.metasource_data: + self.metasource_data = make_response() + self.metasource_data_len = len(self.metasource_data) + resp = self.metasource_data[:length] + self.metasource_data = self.metasource_data[length:] + return resp + + self.serial.read.side_effect = read_response self.patched_funcs.enter_context( mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint', mock.Mock(return_value=self.request_id))) @@ -477,7 +499,9 @@ class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): client.get_metadata('some_key') self.assertEqual(1, self.serial.write.call_count) written_line = self.serial.write.call_args[0][0] - self.assertEndsWith(written_line, b'\n') + print(type(written_line)) + self.assertEndsWith(written_line.decode('ascii'), + b'\n'.decode('ascii')) self.assertEqual(1, written_line.count(b'\n')) def _get_written_line(self, key='some_key'): @@ -489,7 +513,8 @@ class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): self.assertIsInstance(self._get_written_line(), six.binary_type) def test_get_metadata_line_starts_with_v2(self): - self.assertStartsWith(self._get_written_line(), b'V2') + foo = self._get_written_line() + self.assertStartsWith(foo.decode('ascii'), b'V2'.decode('ascii')) def test_get_metadata_uses_get_command(self): parts = self._get_written_line().decode('ascii').strip().split(' ') @@ -526,7 +551,7 @@ class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): def test_get_metadata_reads_a_line(self): client = self._get_client() client.get_metadata('some_key') - self.assertEqual(1, self.serial.readline.call_count) + self.assertEqual(self.metasource_data_len, self.serial.read.call_count) def test_get_metadata_returns_valid_value(self): client = self._get_client() -- cgit v1.2.3 From 1f24e50c655ccc0c49afcd3225ce8f0e0bce05b9 Mon Sep 17 00:00:00 2001 From: Wesley Wiedenmeier Date: Thu, 4 Feb 2016 19:11:02 -0600 Subject: Use mock in test_handler_lxd.py and add test for lxd installation --- tests/unittests/test_handler/test_handler_lxd.py | 76 +++++++++++------------- 1 file changed, 36 insertions(+), 40 deletions(-) (limited to 'tests/unittests') diff --git a/tests/unittests/test_handler/test_handler_lxd.py b/tests/unittests/test_handler/test_handler_lxd.py index 89863d52..4d858b8f 100644 --- a/tests/unittests/test_handler/test_handler_lxd.py +++ b/tests/unittests/test_handler/test_handler_lxd.py @@ -1,28 +1,31 @@ from cloudinit.config import cc_lxd -from cloudinit import (util, distros, helpers, cloud) +from cloudinit import (distros, helpers, cloud) from cloudinit.sources import DataSourceNoCloud from .. import helpers as t_help import logging +try: + from unittest import mock +except ImportError: + import mock + LOG = logging.getLogger(__name__) class TestLxd(t_help.TestCase): + lxd_cfg = { + 'lxd': { + 'init': { + 'network_address': '0.0.0.0', + 'storage_backend': 'zfs', + 'storage_pool': 'poolname', + } + } + } + def setUp(self): super(TestLxd, self).setUp() - self.unapply = [] - apply_patches([(util, 'subp', self._mock_subp)]) - self.subp_called = [] - - def tearDown(self): - apply_patches([i for i in reversed(self.unapply)]) - - def _mock_subp(self, *args, **kwargs): - if 'args' not in kwargs: - kwargs['args'] = args[0] - self.subp_called.append(kwargs) - return def _get_cloud(self, distro): cls = distros.fetch(distro) @@ -32,31 +35,24 @@ class TestLxd(t_help.TestCase): cc = cloud.Cloud(ds, paths, {}, d, None) return cc - def test_lxd_init(self): - cfg = { - 'lxd': { - 'init': { - 'network_address': '0.0.0.0', - 'storage_backend': 'zfs', - 'storage_pool': 'poolname', - } - } - } + @mock.patch("cloudinit.config.cc_lxd.util") + def test_lxd_init(self, mock_util): + cc = self._get_cloud('ubuntu') + mock_util.which.return_value = True + cc_lxd.handle('cc_lxd', self.lxd_cfg, cc, LOG, []) + self.assertTrue(mock_util.which.called) + init_call = mock_util.subp.call_args_list[0][0][0] + self.assertEquals(init_call, + ['lxd', 'init', '--auto', '--network-address', + '0.0.0.0', '--storage-backend', 'zfs', + '--storage-pool', 'poolname']) + + @mock.patch("cloudinit.config.cc_lxd.util") + def test_lxd_install(self, mock_util): cc = self._get_cloud('ubuntu') - cc_lxd.handle('cc_lxd', cfg, cc, LOG, []) - - self.assertEqual( - self.subp_called[0].get('args'), - ['lxd', 'init', '--auto', '--network-address', '0.0.0.0', - '--storage-backend', 'zfs', '--storage-pool', 'poolname']) - - -def apply_patches(patches): - ret = [] - for (ref, name, replace) in patches: - if replace is None: - continue - orig = getattr(ref, name) - setattr(ref, name, replace) - ret.append((ref, name, orig)) - return ret + cc.distro = mock.MagicMock() + mock_util.which.return_value = None + cc_lxd.handle('cc_lxd', self.lxd_cfg, cc, LOG, []) + self.assertTrue(cc.distro.install_packages.called) + install_pkg = cc.distro.install_packages.call_args_list[0][0][0] + self.assertEquals(install_pkg, ('lxd',)) -- cgit v1.2.3 From 9ba841efc4263fcd1ec8eb266a3afa83b525ba49 Mon Sep 17 00:00:00 2001 From: Sankar Tanguturi Date: Mon, 8 Feb 2016 17:06:39 -0800 Subject: - Added new Unittests for VMware support. --- tests/data/vmware/cust-dhcp-2nic.cfg | 34 ++++++++++ tests/data/vmware/cust-static-2nic.cfg | 39 +++++++++++ tests/unittests/test_vmware_config_file.py | 103 +++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+) create mode 100644 tests/data/vmware/cust-dhcp-2nic.cfg create mode 100644 tests/data/vmware/cust-static-2nic.cfg create mode 100644 tests/unittests/test_vmware_config_file.py (limited to 'tests/unittests') diff --git a/tests/data/vmware/cust-dhcp-2nic.cfg b/tests/data/vmware/cust-dhcp-2nic.cfg new file mode 100644 index 00000000..f687311a --- /dev/null +++ b/tests/data/vmware/cust-dhcp-2nic.cfg @@ -0,0 +1,34 @@ +[NETWORK] +NETWORKING = yes +BOOTPROTO = dhcp +HOSTNAME = myhost1 +DOMAINNAME = eng.vmware.com + +[NIC-CONFIG] +NICS = NIC1,NIC2 + +[NIC1] +MACADDR = 00:50:56:a6:8c:08 +ONBOOT = yes +IPv4_MODE = BACKWARDS_COMPATIBLE +BOOTPROTO = dhcp + +[NIC2] +MACADDR = 00:50:56:a6:5a:de +ONBOOT = yes +IPv4_MODE = BACKWARDS_COMPATIBLE +BOOTPROTO = dhcp + +# some random comment + +[PASSWORD] +# secret +-PASS = c2VjcmV0Cg== + +[DNS] +DNSFROMDHCP=yes +SUFFIX|1 = eng.vmware.com + +[DATETIME] +TIMEZONE = Africa/Abidjan +UTC = yes diff --git a/tests/data/vmware/cust-static-2nic.cfg b/tests/data/vmware/cust-static-2nic.cfg new file mode 100644 index 00000000..0d80c2c4 --- /dev/null +++ b/tests/data/vmware/cust-static-2nic.cfg @@ -0,0 +1,39 @@ +[NETWORK] +NETWORKING = yes +BOOTPROTO = dhcp +HOSTNAME = myhost1 +DOMAINNAME = eng.vmware.com + +[NIC-CONFIG] +NICS = NIC1,NIC2 + +[NIC1] +MACADDR = 00:50:56:a6:8c:08 +ONBOOT = yes +IPv4_MODE = BACKWARDS_COMPATIBLE +BOOTPROTO = static +IPADDR = 10.20.87.154 +NETMASK = 255.255.252.0 +GATEWAY = 10.20.87.253, 10.20.87.105 +IPv6ADDR|1 = fc00:10:20:87::154 +IPv6NETMASK|1 = 64 +IPv6GATEWAY|1 = fc00:10:20:87::253 +[NIC2] +MACADDR = 00:50:56:a6:ef:7d +ONBOOT = yes +IPv4_MODE = BACKWARDS_COMPATIBLE +BOOTPROTO = static +IPADDR = 192.168.6.102 +NETMASK = 255.255.0.0 +GATEWAY = 192.168.0.10 + +[DNS] +DNSFROMDHCP=no +SUFFIX|1 = eng.vmware.com +SUFFIX|2 = proxy.vmware.com +NAMESERVER|1 = 10.20.145.1 +NAMESERVER|2 = 10.20.145.2 + +[DATETIME] +TIMEZONE = Africa/Abidjan +UTC = yes diff --git a/tests/unittests/test_vmware_config_file.py b/tests/unittests/test_vmware_config_file.py new file mode 100644 index 00000000..51166dd7 --- /dev/null +++ b/tests/unittests/test_vmware_config_file.py @@ -0,0 +1,103 @@ +# vi: ts=4 expandtab +# +# Copyright (C) 2015 Canonical Ltd. +# Copyright (C) 2016 VMware INC. +# +# Author: Sankar Tanguturi +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3, as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import logging +import sys +import unittest + +from cloudinit.sources.helpers.vmware.imc.boot_proto import BootProtoEnum +from cloudinit.sources.helpers.vmware.imc.config import Config +from cloudinit.sources.helpers.vmware.imc.config_file import ConfigFile + +logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) +logger = logging.getLogger(__name__) + + +class TestVmwareConfigFile(unittest.TestCase): + + def test_utility_methods(self): + cf = ConfigFile("tests/data/vmware/cust-dhcp-2nic.cfg") + + cf.clear() + + self.assertEqual(0, cf.size(), "clear size") + + cf._insertKey(" PASSWORD|-PASS ", " foo ") + cf._insertKey("BAR", " ") + + self.assertEqual(2, cf.size(), "insert size") + self.assertEqual('foo', cf["PASSWORD|-PASS"], "password") + self.assertTrue("PASSWORD|-PASS" in cf, "hasPassword") + self.assertFalse(cf.should_keep_current_value("PASSWORD|-PASS"), + "keepPassword") + self.assertFalse(cf.should_remove_current_value("PASSWORD|-PASS"), + "removePassword") + self.assertFalse("FOO" in cf, "hasFoo") + self.assertTrue(cf.should_keep_current_value("FOO"), "keepFoo") + self.assertFalse(cf.should_remove_current_value("FOO"), "removeFoo") + self.assertTrue("BAR" in cf, "hasBar") + self.assertFalse(cf.should_keep_current_value("BAR"), "keepBar") + self.assertTrue(cf.should_remove_current_value("BAR"), "removeBar") + + def test_configfile_static_2nics(self): + cf = ConfigFile("tests/data/vmware/cust-static-2nic.cfg") + + conf = Config(cf) + + self.assertEqual('myhost1', conf.host_name, "hostName") + self.assertEqual('Africa/Abidjan', conf.timezone, "tz") + self.assertTrue(conf.utc, "utc") + + self.assertEqual(['10.20.145.1', '10.20.145.2'], + conf.name_servers, + "dns") + self.assertEqual(['eng.vmware.com', 'proxy.vmware.com'], + conf.dns_suffixes, + "suffixes") + + nics = conf.nics + ipv40 = nics[0].staticIpv4 + + self.assertEqual(2, len(nics), "nics") + self.assertEqual('NIC1', nics[0].name, "nic0") + self.assertEqual('00:50:56:a6:8c:08', nics[0].mac, "mac0") + self.assertEqual(BootProtoEnum.STATIC, nics[0].bootProto, "bootproto0") + self.assertEqual('10.20.87.154', ipv40[0].ip, "ipv4Addr0") + self.assertEqual('255.255.252.0', ipv40[0].netmask, "ipv4Mask0") + self.assertEqual(2, len(ipv40[0].gateways), "ipv4Gw0") + self.assertEqual('10.20.87.253', ipv40[0].gateways[0], "ipv4Gw0_0") + self.assertEqual('10.20.87.105', ipv40[0].gateways[1], "ipv4Gw0_1") + + self.assertEqual(1, len(nics[0].staticIpv6), "ipv6Cnt0") + self.assertEqual('fc00:10:20:87::154', + nics[0].staticIpv6[0].ip, + "ipv6Addr0") + + self.assertEqual('NIC2', nics[1].name, "nic1") + self.assertTrue(not nics[1].staticIpv6, "ipv61 dhcp") + + def test_config_file_dhcp_2nics(self): + cf = ConfigFile("tests/data/vmware/cust-dhcp-2nic.cfg") + + conf = Config(cf) + nics = conf.nics + self.assertEqual(2, len(nics), "nics") + self.assertEqual('NIC1', nics[0].name, "nic0") + self.assertEqual('00:50:56:a6:8c:08', nics[0].mac, "mac0") + self.assertEqual(BootProtoEnum.DHCP, nics[0].bootProto, "bootproto0") -- cgit v1.2.3 From 0ce71cb8975e19677eea415101e15da5f4095cd5 Mon Sep 17 00:00:00 2001 From: Sankar Tanguturi Date: Tue, 16 Feb 2016 17:34:24 -0800 Subject: - Used proper 4 space indentations for config_nic.py and nic.py - Implemented the 'search_file' function using 'os.walk()' - Fixed few variable names. - Removed size() function in config_file.py - Updated the test_config_file.py to use len() instead of .size() --- cloudinit/sources/DataSourceOVF.py | 34 +- .../sources/helpers/vmware/imc/config_file.py | 4 - cloudinit/sources/helpers/vmware/imc/config_nic.py | 433 +++++++++++---------- cloudinit/sources/helpers/vmware/imc/nic.py | 20 +- tests/unittests/test_vmware_config_file.py | 4 +- 5 files changed, 238 insertions(+), 257 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index add7d243..6d3bf7bb 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -64,13 +64,12 @@ class DataSourceOVF(sources.DataSource): (seedfile, contents) = get_ovf_env(self.paths.seed_dir) dmi_info = dmi_data() - system_uuid = "" system_type = "" - if dmi_info is False: + if dmi_info is None: LOG.debug("No dmidata utility found") else: - system_uuid, system_type = tuple(dmi_info) + (_, system_type) = dmi_info if 'vmware' in system_type.lower(): LOG.debug("VMware Virtual Platform found") @@ -172,11 +171,11 @@ class DataSourceOVFNet(DataSourceOVF): self.supported_seed_starts = ("http://", "https://", "ftp://") -def wait_for_imc_cfg_file(directoryPath, filename, maxwait=180, naplen=5): +def wait_for_imc_cfg_file(dirpath, filename, maxwait=180, naplen=5): waited = 0 while waited < maxwait: - fileFullPath = search_file(directoryPath, filename) + fileFullPath = search_file(dirpath, filename) if fileFullPath: return fileFullPath time.sleep(naplen) @@ -357,28 +356,13 @@ def dmi_data(): return (sys_uuid.lower(), sys_type) -def search_file(directoryPath, filename): - if not directoryPath or not filename: +def search_file(dirpath, filename): + if not dirpath or not filename: return None - dirs = [] - - if os.path.isdir(directoryPath): - dirs.append(directoryPath) - - while dirs: - dir = dirs.pop() - children = [] - try: - children.extend(os.listdir(dir)) - except: - LOG.debug("Ignoring the error while searching the directory %s" % dir) - for child in children: - childFullPath = os.path.join(dir, child) - if os.path.isdir(childFullPath): - dirs.append(childFullPath) - elif child == filename: - return childFullPath + for root, dirs, files in os.walk(dirpath): + if filename in files: + return os.path.join(root, filename) return None diff --git a/cloudinit/sources/helpers/vmware/imc/config_file.py b/cloudinit/sources/helpers/vmware/imc/config_file.py index 7c47d14c..bb9fb7dc 100644 --- a/cloudinit/sources/helpers/vmware/imc/config_file.py +++ b/cloudinit/sources/helpers/vmware/imc/config_file.py @@ -61,10 +61,6 @@ class ConfigFile(ConfigSource, dict): self[key] = val - def size(self): - """Return the number of properties present.""" - return len(self) - def _loadConfigFile(self, filename): """ Parses properties from the specified config file. diff --git a/cloudinit/sources/helpers/vmware/imc/config_nic.py b/cloudinit/sources/helpers/vmware/imc/config_nic.py index 8e2fc5d3..d79e6936 100644 --- a/cloudinit/sources/helpers/vmware/imc/config_nic.py +++ b/cloudinit/sources/helpers/vmware/imc/config_nic.py @@ -26,221 +26,222 @@ logger = logging.getLogger(__name__) class NicConfigurator: - def __init__(self, nics): - """ - Initialize the Nic Configurator - @param nics (list) an array of nics to configure - """ - self.nics = nics - self.mac2Name = {} - self.ipv4PrimaryGateway = None - self.ipv6PrimaryGateway = None - self.find_devices() - self._primaryNic = self.get_primary_nic() - - def get_primary_nic(self): - """ - Retrieve the primary nic if it exists - @return (NicBase): the primary nic if exists, None otherwise - """ - primaryNic = None - - for nic in self.nics: - if nic.primary: - if primaryNic: - raise Exception('There can only be one primary nic', - primaryNic.mac, nic.mac) + def __init__(self, nics): + """ + Initialize the Nic Configurator + @param nics (list) an array of nics to configure + """ + self.nics = nics + self.mac2Name = {} + self.ipv4PrimaryGateway = None + self.ipv6PrimaryGateway = None + self.find_devices() + self._primaryNic = self.get_primary_nic() + + def get_primary_nic(self): + """ + Retrieve the primary nic if it exists + @return (NicBase): the primary nic if exists, None otherwise + """ + primaryNic = None + + for nic in self.nics: + if nic.primary: + if primaryNic: + raise Exception('There can only be one primary nic', + primaryNic.mac, nic.mac) primaryNic = nic - return primaryNic - - def find_devices(self): - """ - Create the mac2Name dictionary - The mac address(es) are in the lower case - """ - cmd = 'ip addr show' - outText = subprocess.check_output(cmd, shell=True).decode() - sections = re.split(r'\n\d+: ', '\n' + outText)[1:] - - macPat = r'link/ether (([0-9A-Fa-f]{2}[:]){5}([0-9A-Fa-f]{2}))' - for section in sections: - matcher = re.search(macPat, section) - if not matcher: # Only keep info about nics - continue - mac = matcher.group(1).lower() - name = section.split(':', 1)[0] - self.mac2Name[mac] = name - - def gen_one_nic(self, nic): - """ - Return the lines needed to configure a nic - @return (str list): the string list to configure the nic - @param nic (NicBase): the nic to configure - """ - lines = [] - name = self.mac2Name.get(nic.mac.lower()) - if not name: - raise ValueError('No known device has MACADDR: %s' % nic.mac) - - if nic.onboot: - lines.append('auto %s' % name) - - # Customize IPv4 - lines.extend(self.gen_ipv4(name, nic)) - - # Customize IPv6 - lines.extend(self.gen_ipv6(name, nic)) - - lines.append('') - - return lines - - def gen_ipv4(self, name, nic): - """ - Return the lines needed to configure the IPv4 setting of a nic - @return (str list): the string list to configure the gateways - @param name (str): name of the nic - @param nic (NicBase): the nic to configure - """ - lines = [] - - bootproto = nic.bootProto.lower() - if nic.ipv4_mode.lower() == 'disabled': - bootproto = 'manual' - lines.append('iface %s inet %s' % (name, bootproto)) - - if bootproto != 'static': - return lines - - # Static Ipv4 - v4 = nic.staticIpv4 - if v4.ip: - lines.append(' address %s' % v4.ip) - if v4.netmask: - lines.append(' netmask %s' % v4.netmask) - - # Add the primary gateway - if nic.primary and v4.gateways: - self.ipv4PrimaryGateway = v4.gateways[0] - lines.append(' gateway %s metric 0' % self.ipv4PrimaryGateway) - return lines - - # Add routes if there is no primary nic - if not self._primaryNic: - lines.extend(self.gen_ipv4_route(nic, v4.gateways)) - - return lines - - def gen_ipv4_route(self, nic, gateways): - """ - Return the lines needed to configure additional Ipv4 route - @return (str list): the string list to configure the gateways - @param nic (NicBase): the nic to configure - @param gateways (str list): the list of gateways - """ - lines = [] - - for gateway in gateways: - lines.append(' up route add default gw %s metric 10000' % gateway) - - return lines - - def gen_ipv6(self, name, nic): - """ - Return the lines needed to configure the gateways for a nic - @return (str list): the string list to configure the gateways - @param name (str): name of the nic - @param nic (NicBase): the nic to configure - """ - lines = [] - - if not nic.staticIpv6: - return lines - - # Static Ipv6 - addrs = nic.staticIpv6 - lines.append('iface %s inet6 static' % name) - lines.append(' address %s' % addrs[0].ip) - lines.append(' netmask %s' % addrs[0].netmask) - - for addr in addrs[1:]: - lines.append(' up ifconfig %s inet6 add %s/%s' % (name, addr.ip, - addr.netmask)) - # Add the primary gateway - if nic.primary: - for addr in addrs: - if addr.gateway: - self.ipv6PrimaryGateway = addr.gateway - lines.append(' gateway %s' % self.ipv6PrimaryGateway) - return lines - - # Add routes if there is no primary nic - if not self._primaryNic: - lines.extend(self._genIpv6Route(name, nic, addrs)) - - return lines - - def _genIpv6Route(self, name, nic, addrs): - lines = [] - - for addr in addrs: - lines.append(' up route -A inet6 add default gw %s metric 10000' % - addr.gateway) - - return lines - - def generate(self): - """Return the lines that is needed to configure the nics""" - lines = [] - lines.append('iface lo inet loopback') - lines.append('auto lo') - lines.append('') - - for nic in self.nics: - lines.extend(self.gen_one_nic(nic)) - - return lines - - def clear_dhcp(self): - logger.info('Clearing DHCP leases') - - subprocess.call('pkill dhclient', shell=True) - subprocess.check_call('rm -f /var/lib/dhcp/*', shell=True) - - def if_down_up(self): - names = [] - for nic in self.nics: - name = self.mac2Name.get(nic.mac.lower()) - names.append(name) - - for name in names: - logger.info('Bring down interface %s' % name) - subprocess.check_call('ifdown %s' % name, shell=True) - - self.clear_dhcp() - - for name in names: - logger.info('Bring up interface %s' % name) - subprocess.check_call('ifup %s' % name, shell=True) - - def configure(self): - """ - Configure the /etc/network/intefaces - Make a back up of the original - """ - containingDir = '/etc/network' - - interfaceFile = os.path.join(containingDir, 'interfaces') - originalFile = os.path.join(containingDir, - 'interfaces.before_vmware_customization') - - if not os.path.exists(originalFile) and os.path.exists(interfaceFile): - os.rename(interfaceFile, originalFile) - - lines = self.generate() - with open(interfaceFile, 'w') as fp: - for line in lines: - fp.write('%s\n' % line) - - self.if_down_up() + return primaryNic + + def find_devices(self): + """ + Create the mac2Name dictionary + The mac address(es) are in the lower case + """ + cmd = 'ip addr show' + outText = subprocess.check_output(cmd, shell=True).decode() + sections = re.split(r'\n\d+: ', '\n' + outText)[1:] + + macPat = r'link/ether (([0-9A-Fa-f]{2}[:]){5}([0-9A-Fa-f]{2}))' + for section in sections: + matcher = re.search(macPat, section) + if not matcher: # Only keep info about nics + continue + mac = matcher.group(1).lower() + name = section.split(':', 1)[0] + self.mac2Name[mac] = name + + def gen_one_nic(self, nic): + """ + Return the lines needed to configure a nic + @return (str list): the string list to configure the nic + @param nic (NicBase): the nic to configure + """ + lines = [] + name = self.mac2Name.get(nic.mac.lower()) + if not name: + raise ValueError('No known device has MACADDR: %s' % nic.mac) + + if nic.onboot: + lines.append('auto %s' % name) + + # Customize IPv4 + lines.extend(self.gen_ipv4(name, nic)) + + # Customize IPv6 + lines.extend(self.gen_ipv6(name, nic)) + + lines.append('') + + return lines + + def gen_ipv4(self, name, nic): + """ + Return the lines needed to configure the IPv4 setting of a nic + @return (str list): the string list to configure the gateways + @param name (str): name of the nic + @param nic (NicBase): the nic to configure + """ + lines = [] + + bootproto = nic.bootProto.lower() + if nic.ipv4_mode.lower() == 'disabled': + bootproto = 'manual' + lines.append('iface %s inet %s' % (name, bootproto)) + + if bootproto != 'static': + return lines + + # Static Ipv4 + v4 = nic.staticIpv4 + if v4.ip: + lines.append(' address %s' % v4.ip) + if v4.netmask: + lines.append(' netmask %s' % v4.netmask) + + # Add the primary gateway + if nic.primary and v4.gateways: + self.ipv4PrimaryGateway = v4.gateways[0] + lines.append(' gateway %s metric 0' % self.ipv4PrimaryGateway) + return lines + + # Add routes if there is no primary nic + if not self._primaryNic: + lines.extend(self.gen_ipv4_route(nic, v4.gateways)) + + return lines + + def gen_ipv4_route(self, nic, gateways): + """ + Return the lines needed to configure additional Ipv4 route + @return (str list): the string list to configure the gateways + @param nic (NicBase): the nic to configure + @param gateways (str list): the list of gateways + """ + lines = [] + + for gateway in gateways: + lines.append(' up route add default gw %s metric 10000' % + gateway) + + return lines + + def gen_ipv6(self, name, nic): + """ + Return the lines needed to configure the gateways for a nic + @return (str list): the string list to configure the gateways + @param name (str): name of the nic + @param nic (NicBase): the nic to configure + """ + lines = [] + + if not nic.staticIpv6: + return lines + + # Static Ipv6 + addrs = nic.staticIpv6 + lines.append('iface %s inet6 static' % name) + lines.append(' address %s' % addrs[0].ip) + lines.append(' netmask %s' % addrs[0].netmask) + + for addr in addrs[1:]: + lines.append(' up ifconfig %s inet6 add %s/%s' % (name, addr.ip, + addr.netmask)) + # Add the primary gateway + if nic.primary: + for addr in addrs: + if addr.gateway: + self.ipv6PrimaryGateway = addr.gateway + lines.append(' gateway %s' % self.ipv6PrimaryGateway) + return lines + + # Add routes if there is no primary nic + if not self._primaryNic: + lines.extend(self._genIpv6Route(name, nic, addrs)) + + return lines + + def _genIpv6Route(self, name, nic, addrs): + lines = [] + + for addr in addrs: + lines.append(' up route -A inet6 add default gw %s metric 10000' % + addr.gateway) + + return lines + + def generate(self): + """Return the lines that is needed to configure the nics""" + lines = [] + lines.append('iface lo inet loopback') + lines.append('auto lo') + lines.append('') + + for nic in self.nics: + lines.extend(self.gen_one_nic(nic)) + + return lines + + def clear_dhcp(self): + logger.info('Clearing DHCP leases') + + subprocess.call('pkill dhclient', shell=True) + subprocess.check_call('rm -f /var/lib/dhcp/*', shell=True) + + def if_down_up(self): + names = [] + for nic in self.nics: + name = self.mac2Name.get(nic.mac.lower()) + names.append(name) + + for name in names: + logger.info('Bring down interface %s' % name) + subprocess.check_call('ifdown %s' % name, shell=True) + + self.clear_dhcp() + + for name in names: + logger.info('Bring up interface %s' % name) + subprocess.check_call('ifup %s' % name, shell=True) + + def configure(self): + """ + Configure the /etc/network/intefaces + Make a back up of the original + """ + containingDir = '/etc/network' + + interfaceFile = os.path.join(containingDir, 'interfaces') + originalFile = os.path.join(containingDir, + 'interfaces.before_vmware_customization') + + if not os.path.exists(originalFile) and os.path.exists(interfaceFile): + os.rename(interfaceFile, originalFile) + + lines = self.generate() + with open(interfaceFile, 'w') as fp: + for line in lines: + fp.write('%s\n' % line) + + self.if_down_up() diff --git a/cloudinit/sources/helpers/vmware/imc/nic.py b/cloudinit/sources/helpers/vmware/imc/nic.py index 6628a3ec..b5d704ea 100644 --- a/cloudinit/sources/helpers/vmware/imc/nic.py +++ b/cloudinit/sources/helpers/vmware/imc/nic.py @@ -49,35 +49,35 @@ class Nic(NicBase): def primary(self): value = self._get('PRIMARY') if value: - value = value.lower() - return value == 'yes' or value == 'true' + value = value.lower() + return value == 'yes' or value == 'true' else: - return False + return False @property def onboot(self): value = self._get('ONBOOT') if value: - value = value.lower() - return value == 'yes' or value == 'true' + value = value.lower() + return value == 'yes' or value == 'true' else: - return False + return False @property def bootProto(self): value = self._get('BOOTPROTO') if value: - return value.lower() + return value.lower() else: - return "" + return "" @property def ipv4_mode(self): value = self._get('IPv4_MODE') if value: - return value.lower() + return value.lower() else: - return "" + return "" @property def staticIpv4(self): diff --git a/tests/unittests/test_vmware_config_file.py b/tests/unittests/test_vmware_config_file.py index 51166dd7..d5c7367b 100644 --- a/tests/unittests/test_vmware_config_file.py +++ b/tests/unittests/test_vmware_config_file.py @@ -36,12 +36,12 @@ class TestVmwareConfigFile(unittest.TestCase): cf.clear() - self.assertEqual(0, cf.size(), "clear size") + self.assertEqual(0, len(cf), "clear size") cf._insertKey(" PASSWORD|-PASS ", " foo ") cf._insertKey("BAR", " ") - self.assertEqual(2, cf.size(), "insert size") + self.assertEqual(2, len(cf), "insert size") self.assertEqual('foo', cf["PASSWORD|-PASS"], "password") self.assertTrue("PASSWORD|-PASS" in cf, "hasPassword") self.assertFalse(cf.should_keep_current_value("PASSWORD|-PASS"), -- cgit v1.2.3 From 14915526ca67bbf7842028d48170015b85f87469 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 1 Mar 2016 00:19:55 -0500 Subject: lxd: general fix after testing A few changes: a.) change to using '--name=value' rather than '--name' 'value' b.) make sure only strings are passed to command (useful for storage_create_loop: which is likely an integer) c.) document simple working example d.) support installing zfs if not present and storage_backedn has it. --- cloudinit/config/cc_lxd.py | 35 ++++++++++++++++++------ doc/examples/cloud-config-lxd.txt | 7 +++++ tests/unittests/test_handler/test_handler_lxd.py | 9 +++--- 3 files changed, 38 insertions(+), 13 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_lxd.py b/cloudinit/config/cc_lxd.py index aaafb643..84eec7a5 100644 --- a/cloudinit/config/cc_lxd.py +++ b/cloudinit/config/cc_lxd.py @@ -38,16 +38,36 @@ from cloudinit import util def handle(name, cfg, cloud, log, args): # Get config lxd_cfg = cfg.get('lxd') - if not lxd_cfg and isinstance(lxd_cfg, dict): + if not lxd_cfg: log.debug("Skipping module named %s, not present or disabled by cfg") return + if not isinstance(lxd_cfg, dict): + log.warn("lxd config must be a dictionary. found a '%s'", + type(lxd_cfg)) + return + + init_cfg = lxd_cfg.get('init') + if not init_cfg: + init_cfg = {} + + if not isinstance(init_cfg, dict): + log.warn("lxd/init config must be a dictionary. found a '%s'", + type(init_cfg)) + init_cfg = {} + + packages = [] + if (init_cfg.get("storage_backend") == "zfs" and not util.which('zfs')): + packages.append('zfs') # Ensure lxd is installed if not util.which("lxd"): + packages.append('lxd') + + if len(packages): try: - cloud.distro.install_packages(("lxd",)) + cloud.distro.install_packages(packages) except util.ProcessExecutionError as e: - log.warn("no lxd executable and could not install lxd:", e) + log.warn("failed to install packages %s: %s", packages, e) return # Set up lxd if init config is given @@ -55,14 +75,11 @@ def handle(name, cfg, cloud, log, args): 'network_address', 'network_port', 'storage_backend', 'storage_create_device', 'storage_create_loop', 'storage_pool', 'trust_password') - init_cfg = lxd_cfg.get('init') + if init_cfg: - if not isinstance(init_cfg, dict): - log.warn("lxd/init config must be a dictionary. found a '%s'", - type(f)) - return cmd = ['lxd', 'init', '--auto'] for k in init_keys: if init_cfg.get(k): - cmd.extend(["--%s" % k.replace('_', '-'), init_cfg[k]]) + cmd.extend(["--%s=%s" % + (k.replace('_', '-'), str(init_cfg[k]))]) util.subp(cmd) diff --git a/doc/examples/cloud-config-lxd.txt b/doc/examples/cloud-config-lxd.txt index f66da4c3..b9bb4aa5 100644 --- a/doc/examples/cloud-config-lxd.txt +++ b/doc/examples/cloud-config-lxd.txt @@ -19,3 +19,10 @@ lxd: network_port: 8443 storage_backend: zfs storage_pool: datapool + storage_create_loop: 10 + + +# The simplist working configuration is +# lxd: +# init: +# storage_backend: dir diff --git a/tests/unittests/test_handler/test_handler_lxd.py b/tests/unittests/test_handler/test_handler_lxd.py index 4d858b8f..65794a41 100644 --- a/tests/unittests/test_handler/test_handler_lxd.py +++ b/tests/unittests/test_handler/test_handler_lxd.py @@ -43,9 +43,10 @@ class TestLxd(t_help.TestCase): self.assertTrue(mock_util.which.called) init_call = mock_util.subp.call_args_list[0][0][0] self.assertEquals(init_call, - ['lxd', 'init', '--auto', '--network-address', - '0.0.0.0', '--storage-backend', 'zfs', - '--storage-pool', 'poolname']) + ['lxd', 'init', '--auto', + '--network-address=0.0.0.0', + '--storage-backend=zfs', + '--storage-pool=poolname']) @mock.patch("cloudinit.config.cc_lxd.util") def test_lxd_install(self, mock_util): @@ -55,4 +56,4 @@ class TestLxd(t_help.TestCase): cc_lxd.handle('cc_lxd', self.lxd_cfg, cc, LOG, []) self.assertTrue(cc.distro.install_packages.called) install_pkg = cc.distro.install_packages.call_args_list[0][0][0] - self.assertEquals(install_pkg, ('lxd',)) + self.assertEquals(sorted(install_pkg), ['lxd', 'zfs']) -- cgit v1.2.3 From c496b6a11d504ef62371cb5e03ac80b4ceb37540 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 3 Mar 2016 12:20:48 -0500 Subject: run pyflakes in more places, fix fallout this makes 'make' run pyflakes, so failures there will stop a build. also adds it to tox. --- Makefile | 6 ++++-- cloudinit/sources/DataSourceOVF.py | 3 ++- cloudinit/sources/helpers/vmware/imc/config_nic.py | 1 - cloudinit/util.py | 2 +- tests/unittests/test_datasource/test_azure_helper.py | 2 -- tests/unittests/test_datasource/test_smartos.py | 1 - tests/unittests/test_handler/test_handler_power_state.py | 2 +- tox.ini | 6 +++++- 8 files changed, 13 insertions(+), 10 deletions(-) (limited to 'tests/unittests') diff --git a/Makefile b/Makefile index bb0c5253..8987d51c 100644 --- a/Makefile +++ b/Makefile @@ -14,13 +14,15 @@ ifeq ($(distro),) distro = redhat endif -all: test check_version +all: check + +check: test check_version pyflakes pep8: @$(CWD)/tools/run-pep8 $(PY_FILES) pyflakes: - @$(CWD)/tools/tox-venv py34 pyflakes $(PY_FILES) + @pyflakes $(PY_FILES) pip-requirements: @echo "Installing cloud-init dependencies..." diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index 72ba5aba..d12601a4 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -90,7 +90,8 @@ class DataSourceOVF(sources.DataSource): nicConfigurator.configure() vmwarePlatformFound = True except Exception as inst: - LOG.debug("Error while parsing the Customization Config File") + LOG.debug("Error while parsing the Customization " + "Config File: %s", inst) finally: dirPath = os.path.dirname(vmwareImcConfigFilePath) shutil.rmtree(dirPath) diff --git a/cloudinit/sources/helpers/vmware/imc/config_nic.py b/cloudinit/sources/helpers/vmware/imc/config_nic.py index 172a1649..6d721134 100644 --- a/cloudinit/sources/helpers/vmware/imc/config_nic.py +++ b/cloudinit/sources/helpers/vmware/imc/config_nic.py @@ -19,7 +19,6 @@ import logging import os -import subprocess import re from cloudinit import util diff --git a/cloudinit/util.py b/cloudinit/util.py index 45d49e66..0a639bb9 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -2147,7 +2147,7 @@ def _read_dmi_syspath(key): LOG.debug("dmi data %s returned %s", dmi_key_path, key_data) return key_data.strip() - except Exception as e: + except Exception: logexc(LOG, "failed read of %s", dmi_key_path) return None diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 8dbdfb0b..1134199b 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,6 +1,4 @@ import os -import struct -import unittest from cloudinit.sources.helpers import azure as azure_helper from ..helpers import TestCase diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 1235436d..ccb9f080 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -31,7 +31,6 @@ import shutil import stat import tempfile import uuid -import unittest from binascii import crc32 import serial diff --git a/tests/unittests/test_handler/test_handler_power_state.py b/tests/unittests/test_handler/test_handler_power_state.py index 5687b10d..cd376e9c 100644 --- a/tests/unittests/test_handler/test_handler_power_state.py +++ b/tests/unittests/test_handler/test_handler_power_state.py @@ -107,7 +107,7 @@ def check_lps_ret(psc_return, mode=None): if 'shutdown' not in psc_return[0][0]: errs.append("string 'shutdown' not in cmd") - if 'condition' is None: + if condition is None: errs.append("condition was not returned") if mode is not None: diff --git a/tox.ini b/tox.ini index b72df0c9..fd65f6ef 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py27,py3 +envlist = py27,py3,pyflakes recreate = True [testenv] @@ -10,6 +10,10 @@ deps = -r{toxinidir}/test-requirements.txt [testenv:py3] basepython = python3 +[testenv:pyflakes] +basepython = python3 +commands = {envpython} -m pyflakes {posargs:cloudinit/ tests/ tools/} + # https://github.com/gabrielfalcao/HTTPretty/issues/223 setenv = LC_ALL = en_US.utf-8 -- cgit v1.2.3 From 96f1742b36241cee152aa2cb5b4a5e1a267a4770 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 3 Mar 2016 15:17:24 -0500 Subject: fix lxd module to not do anything unless config provided --- cloudinit/config/cc_lxd.py | 30 ++++++++++++------------ tests/unittests/test_handler/test_handler_lxd.py | 16 +++++++++++++ 2 files changed, 31 insertions(+), 15 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_lxd.py b/cloudinit/config/cc_lxd.py index 84eec7a5..80a4d219 100644 --- a/cloudinit/config/cc_lxd.py +++ b/cloudinit/config/cc_lxd.py @@ -47,22 +47,24 @@ def handle(name, cfg, cloud, log, args): return init_cfg = lxd_cfg.get('init') - if not init_cfg: - init_cfg = {} - if not isinstance(init_cfg, dict): log.warn("lxd/init config must be a dictionary. found a '%s'", type(init_cfg)) init_cfg = {} - packages = [] - if (init_cfg.get("storage_backend") == "zfs" and not util.which('zfs')): - packages.append('zfs') + if not init_cfg: + log.debug("no lxd/init config. disabled.") + return + packages = [] # Ensure lxd is installed if not util.which("lxd"): packages.append('lxd') - + + # if using zfs, get the utils + if (init_cfg.get("storage_backend") == "zfs" and not util.which('zfs')): + packages.append('zfs') + if len(packages): try: cloud.distro.install_packages(packages) @@ -75,11 +77,9 @@ def handle(name, cfg, cloud, log, args): 'network_address', 'network_port', 'storage_backend', 'storage_create_device', 'storage_create_loop', 'storage_pool', 'trust_password') - - if init_cfg: - cmd = ['lxd', 'init', '--auto'] - for k in init_keys: - if init_cfg.get(k): - cmd.extend(["--%s=%s" % - (k.replace('_', '-'), str(init_cfg[k]))]) - util.subp(cmd) + cmd = ['lxd', 'init', '--auto'] + for k in init_keys: + if init_cfg.get(k): + cmd.extend(["--%s=%s" % + (k.replace('_', '-'), str(init_cfg[k]))]) + util.subp(cmd) diff --git a/tests/unittests/test_handler/test_handler_lxd.py b/tests/unittests/test_handler/test_handler_lxd.py index 65794a41..7ffa2a53 100644 --- a/tests/unittests/test_handler/test_handler_lxd.py +++ b/tests/unittests/test_handler/test_handler_lxd.py @@ -57,3 +57,19 @@ class TestLxd(t_help.TestCase): self.assertTrue(cc.distro.install_packages.called) install_pkg = cc.distro.install_packages.call_args_list[0][0][0] self.assertEquals(sorted(install_pkg), ['lxd', 'zfs']) + + @mock.patch("cloudinit.config.cc_lxd.util") + def test_no_init_does_nothing(self, mock_util): + cc = self._get_cloud('ubuntu') + cc.distro = mock.MagicMock() + cc_lxd.handle('cc_lxd', {'lxd': {}}, cc, LOG, []) + self.assertFalse(cc.distro.install_packages.called) + self.assertFalse(mock_util.subp.called) + + @mock.patch("cloudinit.config.cc_lxd.util") + def test_no_lxd_does_nothing(self, mock_util): + cc = self._get_cloud('ubuntu') + cc.distro = mock.MagicMock() + cc_lxd.handle('cc_lxd', {'package_update': True}, cc, LOG, []) + self.assertFalse(cc.distro.install_packages.called) + self.assertFalse(mock_util.subp.called) -- cgit v1.2.3 From 8cb7c3f7b5339e686bfbf95996b51afafeaf9c9e Mon Sep 17 00:00:00 2001 From: Ryan Harper Date: Thu, 3 Mar 2016 16:20:10 -0600 Subject: Update pep8 runner and fix pep8 issues --- Makefile | 9 ++-- bin/cloud-init | 43 +++++++++--------- cloudinit/config/cc_apt_configure.py | 6 ++- cloudinit/config/cc_disk_setup.py | 31 +++++++------ cloudinit/config/cc_grub_dpkg.py | 8 ++-- cloudinit/config/cc_keys_to_console.py | 2 +- cloudinit/config/cc_lxd.py | 2 +- cloudinit/config/cc_mounts.py | 12 ++--- cloudinit/config/cc_power_state_change.py | 2 +- cloudinit/config/cc_puppet.py | 6 +-- cloudinit/config/cc_resizefs.py | 2 +- cloudinit/config/cc_rh_subscription.py | 4 +- cloudinit/config/cc_set_hostname.py | 2 +- cloudinit/config/cc_ssh.py | 7 +-- cloudinit/config/cc_update_etc_hosts.py | 6 +-- cloudinit/config/cc_update_hostname.py | 2 +- cloudinit/config/cc_yum_add_repo.py | 2 +- cloudinit/distros/__init__.py | 12 ++--- cloudinit/distros/arch.py | 6 +-- cloudinit/distros/debian.py | 5 ++- cloudinit/distros/freebsd.py | 4 +- cloudinit/distros/gentoo.py | 4 +- cloudinit/distros/parsers/hostname.py | 2 +- cloudinit/distros/parsers/resolv_conf.py | 2 +- cloudinit/distros/parsers/sys_conf.py | 7 ++- cloudinit/filters/launch_index.py | 2 +- cloudinit/helpers.py | 7 +-- cloudinit/sources/DataSourceAzure.py | 21 +++++---- cloudinit/sources/DataSourceConfigDrive.py | 2 +- cloudinit/sources/DataSourceEc2.py | 10 ++--- cloudinit/sources/DataSourceMAAS.py | 15 ++++--- cloudinit/sources/DataSourceOVF.py | 4 +- cloudinit/sources/DataSourceOpenNebula.py | 3 +- cloudinit/sources/DataSourceSmartOS.py | 7 ++- cloudinit/ssh_util.py | 3 +- cloudinit/stages.py | 18 ++++---- cloudinit/url_helper.py | 6 +-- cloudinit/util.py | 15 ++++--- tests/unittests/test_data.py | 5 ++- tests/unittests/test_datasource/test_altcloud.py | 23 +++++----- tests/unittests/test_datasource/test_azure.py | 15 ++++--- .../unittests/test_datasource/test_configdrive.py | 12 ++--- tests/unittests/test_datasource/test_maas.py | 16 +++---- tests/unittests/test_datasource/test_smartos.py | 6 +-- .../test_handler/test_handler_power_state.py | 3 +- .../test_handler/test_handler_seed_random.py | 3 +- .../unittests/test_handler/test_handler_snappy.py | 3 +- tests/unittests/test_sshutil.py | 3 +- tests/unittests/test_templating.py | 3 +- tools/hacking.py | 16 +++---- tools/mock-meta.py | 27 +++++++----- tools/run-pep8 | 51 ++++++++-------------- 52 files changed, 244 insertions(+), 243 deletions(-) (limited to 'tests/unittests') diff --git a/Makefile b/Makefile index 058ac199..fb65b70b 100644 --- a/Makefile +++ b/Makefile @@ -20,13 +20,14 @@ all: test check_version check: pep8 pyflakes pyflakes3 unittest pep8: - @$(CWD)/tools/run-pep8 $(PY_FILES) + @$(CWD)/tools/run-pep8 pyflakes: - @$(CWD)/tools/tox-venv py27 pyflakes $(PY_FILES) + @$(CWD)/tools/run-pyflakes -pyflakes: - @$(CWD)/tools/tox-venv py34 pyflakes $(PY_FILES) +pyflakes3: + @$(CWD)/tools/run-pyflakes3 + unittest: nosetests $(noseopts) tests/unittests diff --git a/bin/cloud-init b/bin/cloud-init index 9b90c45e..7f665e7e 100755 --- a/bin/cloud-init +++ b/bin/cloud-init @@ -194,7 +194,7 @@ def main_init(name, args): if args.debug: # Reset so that all the debug handlers are closed out LOG.debug(("Logging being reset, this logger may no" - " longer be active shortly")) + " longer be active shortly")) logging.resetLogging() logging.setupLogging(init.cfg) apply_reporting_cfg(init.cfg) @@ -276,9 +276,9 @@ def main_init(name, args): # This may run user-data handlers and/or perform # url downloads and such as needed. (ran, _results) = init.cloudify().run('consume_data', - init.consume_data, - args=[PER_INSTANCE], - freq=PER_INSTANCE) + init.consume_data, + args=[PER_INSTANCE], + freq=PER_INSTANCE) if not ran: # Just consume anything that is set to run per-always # if nothing ran in the per-instance code @@ -349,7 +349,7 @@ def main_modules(action_name, args): if args.debug: # Reset so that all the debug handlers are closed out LOG.debug(("Logging being reset, this logger may no" - " longer be active shortly")) + " longer be active shortly")) logging.resetLogging() logging.setupLogging(mods.cfg) apply_reporting_cfg(init.cfg) @@ -534,7 +534,8 @@ def status_wrapper(name, args, data_d=None, link_d=None): errors.extend(v1[m].get('errors', [])) atomic_write_json(result_path, - {'v1': {'datasource': v1['datasource'], 'errors': errors}}) + {'v1': {'datasource': v1['datasource'], + 'errors': errors}}) util.sym_link(os.path.relpath(result_path, link_d), result_link, force=True) @@ -578,13 +579,13 @@ def main(): # These settings are used for the 'config' and 'final' stages parser_mod = subparsers.add_parser('modules', - help=('activates modules ' - 'using a given configuration key')) + help=('activates modules using ' + 'a given configuration key')) parser_mod.add_argument("--mode", '-m', action='store', - help=("module configuration name " - "to use (default: %(default)s)"), - default='config', - choices=('init', 'config', 'final')) + help=("module configuration name " + "to use (default: %(default)s)"), + default='config', + choices=('init', 'config', 'final')) parser_mod.set_defaults(action=('modules', main_modules)) # These settings are used when you want to query information @@ -600,22 +601,22 @@ def main(): # This subcommand allows you to run a single module parser_single = subparsers.add_parser('single', - help=('run a single module ')) + help=('run a single module ')) parser_single.set_defaults(action=('single', main_single)) parser_single.add_argument("--name", '-n', action="store", - help="module name to run", - required=True) + help="module name to run", + required=True) parser_single.add_argument("--frequency", action="store", - help=("frequency of the module"), - required=False, - choices=list(FREQ_SHORT_NAMES.keys())) + help=("frequency of the module"), + required=False, + choices=list(FREQ_SHORT_NAMES.keys())) parser_single.add_argument("--report", action="store_true", help="enable reporting", required=False) parser_single.add_argument("module_args", nargs="*", - metavar='argument', - help=('any additional arguments to' - ' pass to this module')) + metavar='argument', + help=('any additional arguments to' + ' pass to this module')) parser_single.set_defaults(action=('single', main_single)) args = parser.parse_args() diff --git a/cloudinit/config/cc_apt_configure.py b/cloudinit/config/cc_apt_configure.py index 9e9e9e26..702977cb 100644 --- a/cloudinit/config/cc_apt_configure.py +++ b/cloudinit/config/cc_apt_configure.py @@ -91,7 +91,8 @@ def handle(name, cfg, cloud, log, _args): if matchcfg: matcher = re.compile(matchcfg).search else: - matcher = lambda f: False + def matcher(x): + return False errors = add_sources(cfg['apt_sources'], params, aa_repo_match=matcher) @@ -173,7 +174,8 @@ def add_sources(srclist, template_params=None, aa_repo_match=None): template_params = {} if aa_repo_match is None: - aa_repo_match = lambda f: False + def aa_repo_match(x): + return False errorlist = [] for ent in srclist: diff --git a/cloudinit/config/cc_disk_setup.py b/cloudinit/config/cc_disk_setup.py index d5b0d1d7..0ecc2e4c 100644 --- a/cloudinit/config/cc_disk_setup.py +++ b/cloudinit/config/cc_disk_setup.py @@ -167,11 +167,12 @@ def enumerate_disk(device, nodeps=False): parts = [x for x in (info.strip()).splitlines() if len(x.split()) > 0] for part in parts: - d = {'name': None, - 'type': None, - 'fstype': None, - 'label': None, - } + d = { + 'name': None, + 'type': None, + 'fstype': None, + 'label': None, + } for key, value in value_splitter(part): d[key.lower()] = value @@ -701,11 +702,12 @@ def lookup_force_flag(fs): """ A force flag might be -F or -F, this look it up """ - flags = {'ext': '-F', - 'btrfs': '-f', - 'xfs': '-f', - 'reiserfs': '-f', - } + flags = { + 'ext': '-F', + 'btrfs': '-f', + 'xfs': '-f', + 'reiserfs': '-f', + } if 'ext' in fs.lower(): fs = 'ext' @@ -824,10 +826,11 @@ def mkfs(fs_cfg): # Create the commands if fs_cmd: - fs_cmd = fs_cfg['cmd'] % {'label': label, - 'filesystem': fs_type, - 'device': device, - } + fs_cmd = fs_cfg['cmd'] % { + 'label': label, + 'filesystem': fs_type, + 'device': device, + } else: # Find the mkfs command mkfs_cmd = util.which("mkfs.%s" % fs_type) diff --git a/cloudinit/config/cc_grub_dpkg.py b/cloudinit/config/cc_grub_dpkg.py index 456597af..acd3e60a 100644 --- a/cloudinit/config/cc_grub_dpkg.py +++ b/cloudinit/config/cc_grub_dpkg.py @@ -38,11 +38,11 @@ def handle(name, cfg, _cloud, log, _args): idevs = util.get_cfg_option_str(mycfg, "grub-pc/install_devices", None) idevs_empty = util.get_cfg_option_str(mycfg, - "grub-pc/install_devices_empty", None) + "grub-pc/install_devices_empty", + None) if ((os.path.exists("/dev/sda1") and not os.path.exists("/dev/sda")) or - (os.path.exists("/dev/xvda1") - and not os.path.exists("/dev/xvda"))): + (os.path.exists("/dev/xvda1") and not os.path.exists("/dev/xvda"))): if idevs is None: idevs = "" if idevs_empty is None: @@ -66,7 +66,7 @@ def handle(name, cfg, _cloud, log, _args): (idevs, idevs_empty)) log.debug("Setting grub debconf-set-selections with '%s','%s'" % - (idevs, idevs_empty)) + (idevs, idevs_empty)) try: util.subp(['debconf-set-selections'], dconf_sel) diff --git a/cloudinit/config/cc_keys_to_console.py b/cloudinit/config/cc_keys_to_console.py index f1c1adff..aa844ee9 100644 --- a/cloudinit/config/cc_keys_to_console.py +++ b/cloudinit/config/cc_keys_to_console.py @@ -48,7 +48,7 @@ def handle(name, cfg, cloud, log, _args): "ssh_fp_console_blacklist", []) key_blacklist = util.get_cfg_option_list(cfg, "ssh_key_console_blacklist", - ["ssh-dss"]) + ["ssh-dss"]) try: cmd = [helper_path] diff --git a/cloudinit/config/cc_lxd.py b/cloudinit/config/cc_lxd.py index 7d8a0202..e2fdf68e 100644 --- a/cloudinit/config/cc_lxd.py +++ b/cloudinit/config/cc_lxd.py @@ -59,7 +59,7 @@ def handle(name, cfg, cloud, log, args): if init_cfg: if not isinstance(init_cfg, dict): log.warn("lxd/init config must be a dictionary. found a '%s'", - type(init_cfg)) + type(init_cfg)) return cmd = ['lxd', 'init', '--auto'] for k in init_keys: diff --git a/cloudinit/config/cc_mounts.py b/cloudinit/config/cc_mounts.py index 11089d8d..4fe3ee21 100644 --- a/cloudinit/config/cc_mounts.py +++ b/cloudinit/config/cc_mounts.py @@ -204,12 +204,12 @@ def setup_swapfile(fname, size=None, maxsize=None): try: util.ensure_dir(tdir) util.log_time(LOG.debug, msg, func=util.subp, - args=[['sh', '-c', - ('rm -f "$1" && umask 0066 && ' - '{ fallocate -l "${2}M" "$1" || ' - ' dd if=/dev/zero "of=$1" bs=1M "count=$2"; } && ' - 'mkswap "$1" || { r=$?; rm -f "$1"; exit $r; }'), - 'setup_swap', fname, mbsize]]) + args=[['sh', '-c', + ('rm -f "$1" && umask 0066 && ' + '{ fallocate -l "${2}M" "$1" || ' + ' dd if=/dev/zero "of=$1" bs=1M "count=$2"; } && ' + 'mkswap "$1" || { r=$?; rm -f "$1"; exit $r; }'), + 'setup_swap', fname, mbsize]]) except Exception as e: raise IOError("Failed %s: %s" % (msg, e)) diff --git a/cloudinit/config/cc_power_state_change.py b/cloudinit/config/cc_power_state_change.py index 7d9567e3..cc3f7f70 100644 --- a/cloudinit/config/cc_power_state_change.py +++ b/cloudinit/config/cc_power_state_change.py @@ -105,7 +105,7 @@ def handle(_name, cfg, _cloud, log, _args): log.debug("After pid %s ends, will execute: %s" % (mypid, ' '.join(args))) - util.fork_cb(run_after_pid_gone, mypid, cmdline, timeout, log, + util.fork_cb(run_after_pid_gone, mypid, cmdline, timeout, log, condition, execmd, [args, devnull_fp]) diff --git a/cloudinit/config/cc_puppet.py b/cloudinit/config/cc_puppet.py index 4501598e..774d3322 100644 --- a/cloudinit/config/cc_puppet.py +++ b/cloudinit/config/cc_puppet.py @@ -36,8 +36,8 @@ def _autostart_puppet(log): # Set puppet to automatically start if os.path.exists('/etc/default/puppet'): util.subp(['sed', '-i', - '-e', 's/^START=.*/START=yes/', - '/etc/default/puppet'], capture=False) + '-e', 's/^START=.*/START=yes/', + '/etc/default/puppet'], capture=False) elif os.path.exists('/bin/systemctl'): util.subp(['/bin/systemctl', 'enable', 'puppet.service'], capture=False) @@ -65,7 +65,7 @@ def handle(name, cfg, cloud, log, _args): " doing nothing.")) elif install: log.debug(("Attempting to install puppet %s,"), - version if version else 'latest') + version if version else 'latest') cloud.distro.install_packages(('puppet', version)) # ... and then update the puppet configuration diff --git a/cloudinit/config/cc_resizefs.py b/cloudinit/config/cc_resizefs.py index cbc07853..2a2a9f59 100644 --- a/cloudinit/config/cc_resizefs.py +++ b/cloudinit/config/cc_resizefs.py @@ -166,7 +166,7 @@ def handle(name, cfg, _cloud, log, args): func=do_resize, args=(resize_cmd, log)) else: util.log_time(logfunc=log.debug, msg="Resizing", - func=do_resize, args=(resize_cmd, log)) + func=do_resize, args=(resize_cmd, log)) action = 'Resized' if resize_root == NOBLOCK: diff --git a/cloudinit/config/cc_rh_subscription.py b/cloudinit/config/cc_rh_subscription.py index 3b30c47e..6f474aed 100644 --- a/cloudinit/config/cc_rh_subscription.py +++ b/cloudinit/config/cc_rh_subscription.py @@ -127,8 +127,8 @@ class SubscriptionManager(object): return False, not_bool if (self.servicelevel is not None) and \ - ((not self.auto_attach) - or (util.is_false(str(self.auto_attach)))): + ((not self.auto_attach) or + (util.is_false(str(self.auto_attach)))): no_auto = ("The service-level key must be used in conjunction " "with the auto-attach key. Please re-run with " diff --git a/cloudinit/config/cc_set_hostname.py b/cloudinit/config/cc_set_hostname.py index 5d7f4331..f43d8d5a 100644 --- a/cloudinit/config/cc_set_hostname.py +++ b/cloudinit/config/cc_set_hostname.py @@ -24,7 +24,7 @@ from cloudinit import util def handle(name, cfg, cloud, log, _args): if util.get_cfg_option_bool(cfg, "preserve_hostname", False): log.debug(("Configuration option 'preserve_hostname' is set," - " not setting the hostname in module %s"), name) + " not setting the hostname in module %s"), name) return (hostname, fqdn) = util.get_hostname_fqdn(cfg, cloud) diff --git a/cloudinit/config/cc_ssh.py b/cloudinit/config/cc_ssh.py index 5bd2dec6..d24e43c0 100644 --- a/cloudinit/config/cc_ssh.py +++ b/cloudinit/config/cc_ssh.py @@ -30,9 +30,10 @@ from cloudinit import distros as ds from cloudinit import ssh_util from cloudinit import util -DISABLE_ROOT_OPTS = ("no-port-forwarding,no-agent-forwarding," -"no-X11-forwarding,command=\"echo \'Please login as the user \\\"$USER\\\" " -"rather than the user \\\"root\\\".\';echo;sleep 10\"") +DISABLE_ROOT_OPTS = ( + "no-port-forwarding,no-agent-forwarding," + "no-X11-forwarding,command=\"echo \'Please login as the user \\\"$USER\\\"" + " rather than the user \\\"root\\\".\';echo;sleep 10\"") GENERATE_KEY_NAMES = ['rsa', 'dsa', 'ecdsa', 'ed25519'] KEY_FILE_TPL = '/etc/ssh/ssh_host_%s_key' diff --git a/cloudinit/config/cc_update_etc_hosts.py b/cloudinit/config/cc_update_etc_hosts.py index d3dd1f32..15703efe 100644 --- a/cloudinit/config/cc_update_etc_hosts.py +++ b/cloudinit/config/cc_update_etc_hosts.py @@ -41,10 +41,10 @@ def handle(name, cfg, cloud, log, _args): if not tpl_fn_name: raise RuntimeError(("No hosts template could be" " found for distro %s") % - (cloud.distro.osfamily)) + (cloud.distro.osfamily)) templater.render_to_file(tpl_fn_name, '/etc/hosts', - {'hostname': hostname, 'fqdn': fqdn}) + {'hostname': hostname, 'fqdn': fqdn}) elif manage_hosts == "localhost": (hostname, fqdn) = util.get_hostname_fqdn(cfg, cloud) @@ -57,4 +57,4 @@ def handle(name, cfg, cloud, log, _args): cloud.distro.update_etc_hosts(hostname, fqdn) else: log.debug(("Configuration option 'manage_etc_hosts' is not set," - " not managing /etc/hosts in module %s"), name) + " not managing /etc/hosts in module %s"), name) diff --git a/cloudinit/config/cc_update_hostname.py b/cloudinit/config/cc_update_hostname.py index e396ba13..5b78afe1 100644 --- a/cloudinit/config/cc_update_hostname.py +++ b/cloudinit/config/cc_update_hostname.py @@ -29,7 +29,7 @@ frequency = PER_ALWAYS def handle(name, cfg, cloud, log, _args): if util.get_cfg_option_bool(cfg, "preserve_hostname", False): log.debug(("Configuration option 'preserve_hostname' is set," - " not updating the hostname in module %s"), name) + " not updating the hostname in module %s"), name) return (hostname, fqdn) = util.get_hostname_fqdn(cfg, cloud) diff --git a/cloudinit/config/cc_yum_add_repo.py b/cloudinit/config/cc_yum_add_repo.py index 3b821af9..64fba869 100644 --- a/cloudinit/config/cc_yum_add_repo.py +++ b/cloudinit/config/cc_yum_add_repo.py @@ -92,7 +92,7 @@ def handle(name, cfg, _cloud, log, _args): for req_field in ['baseurl']: if req_field not in repo_config: log.warn(("Repository %s does not contain a %s" - " configuration 'required' entry"), + " configuration 'required' entry"), repo_id, req_field) missing_required += 1 if not missing_required: diff --git a/cloudinit/distros/__init__.py b/cloudinit/distros/__init__.py index 71884b32..661a9fd2 100644 --- a/cloudinit/distros/__init__.py +++ b/cloudinit/distros/__init__.py @@ -211,8 +211,8 @@ class Distro(object): # If the system hostname is different than the previous # one or the desired one lets update it as well - if (not sys_hostname) or (sys_hostname == prev_hostname - and sys_hostname != hostname): + if ((not sys_hostname) or (sys_hostname == prev_hostname and + sys_hostname != hostname)): update_files.append(sys_fn) # If something else has changed the hostname after we set it @@ -221,7 +221,7 @@ class Distro(object): if (sys_hostname and prev_hostname and sys_hostname != prev_hostname): LOG.info("%s differs from %s, assuming user maintained hostname.", - prev_hostname_fn, sys_fn) + prev_hostname_fn, sys_fn) return # Remove duplicates (incase the previous config filename) @@ -289,7 +289,7 @@ class Distro(object): def _bring_up_interface(self, device_name): cmd = ['ifup', device_name] LOG.debug("Attempting to run bring up interface %s using command %s", - device_name, cmd) + device_name, cmd) try: (_out, err) = util.subp(cmd) if len(err): @@ -548,7 +548,7 @@ class Distro(object): for member in members: if not util.is_user(member): LOG.warn("Unable to add group member '%s' to group '%s'" - "; user does not exist.", member, name) + "; user does not exist.", member, name) continue util.subp(['usermod', '-a', '-G', name, member]) @@ -886,7 +886,7 @@ def fetch(name): locs, looked_locs = importer.find_module(name, ['', __name__], ['Distro']) if not locs: raise ImportError("No distribution found for distro %s (searched %s)" - % (name, looked_locs)) + % (name, looked_locs)) mod = importer.import_module(locs[0]) cls = getattr(mod, 'Distro') return cls diff --git a/cloudinit/distros/arch.py b/cloudinit/distros/arch.py index 45fcf26f..93a2e008 100644 --- a/cloudinit/distros/arch.py +++ b/cloudinit/distros/arch.py @@ -74,7 +74,7 @@ class Distro(distros.Distro): 'Interface': dev, 'IP': info.get('bootproto'), 'Address': "('%s/%s')" % (info.get('address'), - info.get('netmask')), + info.get('netmask')), 'Gateway': info.get('gateway'), 'DNS': str(tuple(info.get('dns-nameservers'))).replace(',', '') } @@ -86,7 +86,7 @@ class Distro(distros.Distro): if nameservers: util.write_file(self.resolve_conf_fn, - convert_resolv_conf(nameservers)) + convert_resolv_conf(nameservers)) return dev_names @@ -102,7 +102,7 @@ class Distro(distros.Distro): def _bring_up_interface(self, device_name): cmd = ['netctl', 'restart', device_name] LOG.debug("Attempting to run bring up interface %s using command %s", - device_name, cmd) + device_name, cmd) try: (_out, err) = util.subp(cmd) if len(err): diff --git a/cloudinit/distros/debian.py b/cloudinit/distros/debian.py index 6d3a82bf..db5890b1 100644 --- a/cloudinit/distros/debian.py +++ b/cloudinit/distros/debian.py @@ -159,8 +159,9 @@ class Distro(distros.Distro): # Allow the output of this to flow outwards (ie not be captured) util.log_time(logfunc=LOG.debug, - msg="apt-%s [%s]" % (command, ' '.join(cmd)), func=util.subp, - args=(cmd,), kwargs={'env': e, 'capture': False}) + msg="apt-%s [%s]" % (command, ' '.join(cmd)), + func=util.subp, + args=(cmd,), kwargs={'env': e, 'capture': False}) def update_package_sources(self): self._runner.run("update-sources", self.package_command, diff --git a/cloudinit/distros/freebsd.py b/cloudinit/distros/freebsd.py index 4c484639..72012056 100644 --- a/cloudinit/distros/freebsd.py +++ b/cloudinit/distros/freebsd.py @@ -205,8 +205,8 @@ class Distro(distros.Distro): redact_opts = ['passwd'] for key, val in kwargs.items(): - if (key in adduser_opts and val - and isinstance(val, six.string_types)): + if (key in adduser_opts and val and + isinstance(val, six.string_types)): adduser_cmd.extend([adduser_opts[key], val]) # Redact certain fields from the logs diff --git a/cloudinit/distros/gentoo.py b/cloudinit/distros/gentoo.py index 9e80583c..6267dd6e 100644 --- a/cloudinit/distros/gentoo.py +++ b/cloudinit/distros/gentoo.py @@ -66,7 +66,7 @@ class Distro(distros.Distro): def _bring_up_interface(self, device_name): cmd = ['/etc/init.d/net.%s' % device_name, 'restart'] LOG.debug("Attempting to run bring up interface %s using command %s", - device_name, cmd) + device_name, cmd) try: (_out, err) = util.subp(cmd) if len(err): @@ -88,7 +88,7 @@ class Distro(distros.Distro): (_out, err) = util.subp(cmd) if len(err): LOG.warn("Running %s resulted in stderr output: %s", cmd, - err) + err) except util.ProcessExecutionError: util.logexc(LOG, "Running interface command %s failed", cmd) return False diff --git a/cloudinit/distros/parsers/hostname.py b/cloudinit/distros/parsers/hostname.py index 84a1de42..efb185d4 100644 --- a/cloudinit/distros/parsers/hostname.py +++ b/cloudinit/distros/parsers/hostname.py @@ -84,5 +84,5 @@ class HostnameConf(object): hostnames_found.add(head) if len(hostnames_found) > 1: raise IOError("Multiple hostnames (%s) found!" - % (hostnames_found)) + % (hostnames_found)) return entries diff --git a/cloudinit/distros/parsers/resolv_conf.py b/cloudinit/distros/parsers/resolv_conf.py index 8aee03a4..2ed13d9c 100644 --- a/cloudinit/distros/parsers/resolv_conf.py +++ b/cloudinit/distros/parsers/resolv_conf.py @@ -132,7 +132,7 @@ class ResolvConf(object): # Some hard limit on 256 chars total raise ValueError(("Adding %r would go beyond the " "256 maximum search list character limit") - % (search_domain)) + % (search_domain)) self._remove_option('search') self._contents.append(('option', ['search', s_list, ''])) return flat_sds diff --git a/cloudinit/distros/parsers/sys_conf.py b/cloudinit/distros/parsers/sys_conf.py index d795e12f..6157cf32 100644 --- a/cloudinit/distros/parsers/sys_conf.py +++ b/cloudinit/distros/parsers/sys_conf.py @@ -77,8 +77,7 @@ class SysConf(configobj.ConfigObj): quot_func = None if value[0] in ['"', "'"] and value[-1] in ['"', "'"]: if len(value) == 1: - quot_func = (lambda x: - self._get_single_quote(x) % x) + quot_func = (lambda x: self._get_single_quote(x) % x) else: # Quote whitespace if it isn't the start + end of a shell command if value.strip().startswith("$(") and value.strip().endswith(")"): @@ -91,10 +90,10 @@ class SysConf(configobj.ConfigObj): # to use single quotes which won't get expanded... if re.search(r"[\n\"']", value): quot_func = (lambda x: - self._get_triple_quote(x) % x) + self._get_triple_quote(x) % x) else: quot_func = (lambda x: - self._get_single_quote(x) % x) + self._get_single_quote(x) % x) else: quot_func = pipes.quote if not quot_func: diff --git a/cloudinit/filters/launch_index.py b/cloudinit/filters/launch_index.py index 5bebd318..baecdac9 100644 --- a/cloudinit/filters/launch_index.py +++ b/cloudinit/filters/launch_index.py @@ -61,7 +61,7 @@ class Filter(object): discarded += 1 LOG.debug(("Discarding %s multipart messages " "which do not match launch index %s"), - discarded, self.wanted_idx) + discarded, self.wanted_idx) new_message = copy.copy(message) new_message.set_payload(new_msgs) new_message[ud.ATTACHMENT_FIELD] = str(len(new_msgs)) diff --git a/cloudinit/helpers.py b/cloudinit/helpers.py index 5e99d185..a6eb20fe 100644 --- a/cloudinit/helpers.py +++ b/cloudinit/helpers.py @@ -139,9 +139,10 @@ class FileSemaphores(object): # but the item had run before we did canon_sem_name. if cname != name and os.path.exists(self._get_path(name, freq)): LOG.warn("%s has run without canonicalized name [%s].\n" - "likely the migrator has not yet run. It will run next boot.\n" - "run manually with: cloud-init single --name=migrator" - % (name, cname)) + "likely the migrator has not yet run. " + "It will run next boot.\n" + "run manually with: cloud-init single --name=migrator" + % (name, cname)) return True return False diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index bd80a8a6..b03ab895 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -38,7 +38,8 @@ LOG = logging.getLogger(__name__) DS_NAME = 'Azure' DEFAULT_METADATA = {"instance-id": "iid-AZURE-NODE"} AGENT_START = ['service', 'walinuxagent', 'start'] -BOUNCE_COMMAND = ['sh', '-xc', +BOUNCE_COMMAND = [ + 'sh', '-xc', "i=$interface; x=0; ifdown $i || x=$?; ifup $i || x=$?; exit $x"] BUILTIN_DS_CONFIG = { @@ -91,9 +92,9 @@ def temporary_hostname(temp_hostname, cfg, hostname_command='hostname'): """ policy = cfg['hostname_bounce']['policy'] previous_hostname = get_hostname(hostname_command) - if (not util.is_true(cfg.get('set_hostname')) - or util.is_false(policy) - or (previous_hostname == temp_hostname and policy != 'force')): + if (not util.is_true(cfg.get('set_hostname')) or + util.is_false(policy) or + (previous_hostname == temp_hostname and policy != 'force')): yield None return set_hostname(temp_hostname, hostname_command) @@ -123,8 +124,8 @@ class DataSourceAzureNet(sources.DataSource): with temporary_hostname(temp_hostname, self.ds_cfg, hostname_command=hostname_command) \ as previous_hostname: - if (previous_hostname is not None - and util.is_true(self.ds_cfg.get('set_hostname'))): + if (previous_hostname is not None and + util.is_true(self.ds_cfg.get('set_hostname'))): cfg = self.ds_cfg['hostname_bounce'] try: perform_hostname_bounce(hostname=temp_hostname, @@ -152,7 +153,8 @@ class DataSourceAzureNet(sources.DataSource): else: bname = str(pk['fingerprint'] + ".crt") fp_files += [os.path.join(ddir, bname)] - LOG.debug("ssh authentication: using fingerprint from fabirc") + LOG.debug("ssh authentication: " + "using fingerprint from fabirc") missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", func=wait_for_files, @@ -506,7 +508,7 @@ def read_azure_ovf(contents): raise BrokenAzureDataSource("invalid xml: %s" % e) results = find_child(dom.documentElement, - lambda n: n.localName == "ProvisioningSection") + lambda n: n.localName == "ProvisioningSection") if len(results) == 0: raise NonAzureDataSource("No ProvisioningSection") @@ -516,7 +518,8 @@ def read_azure_ovf(contents): provSection = results[0] lpcs_nodes = find_child(provSection, - lambda n: n.localName == "LinuxProvisioningConfigurationSet") + lambda n: + n.localName == "LinuxProvisioningConfigurationSet") if len(results) == 0: raise NonAzureDataSource("No LinuxProvisioningConfigurationSet") diff --git a/cloudinit/sources/DataSourceConfigDrive.py b/cloudinit/sources/DataSourceConfigDrive.py index eb474079..e3916208 100644 --- a/cloudinit/sources/DataSourceConfigDrive.py +++ b/cloudinit/sources/DataSourceConfigDrive.py @@ -39,7 +39,7 @@ FS_TYPES = ('vfat', 'iso9660') LABEL_TYPES = ('config-2',) POSSIBLE_MOUNTS = ('sr', 'cd') OPTICAL_DEVICES = tuple(('/dev/%s%s' % (z, i) for z in POSSIBLE_MOUNTS - for i in range(0, 2))) + for i in range(0, 2))) class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource): diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index 0032d06c..6a897f7d 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -61,12 +61,12 @@ class DataSourceEc2(sources.DataSource): if not self.wait_for_metadata_service(): return False start_time = time.time() - self.userdata_raw = ec2.get_instance_userdata(self.api_ver, - self.metadata_address) + self.userdata_raw = \ + ec2.get_instance_userdata(self.api_ver, self.metadata_address) self.metadata = ec2.get_instance_metadata(self.api_ver, self.metadata_address) LOG.debug("Crawl of metadata service took %s seconds", - int(time.time() - start_time)) + int(time.time() - start_time)) return True except Exception: util.logexc(LOG, "Failed reading from metadata address %s", @@ -132,13 +132,13 @@ class DataSourceEc2(sources.DataSource): start_time = time.time() url = uhelp.wait_for_url(urls=urls, max_wait=max_wait, - timeout=timeout, status_cb=LOG.warn) + timeout=timeout, status_cb=LOG.warn) if url: LOG.debug("Using metadata source: '%s'", url2base[url]) else: LOG.critical("Giving up on md from %s after %s seconds", - urls, int(time.time() - start_time)) + urls, int(time.time() - start_time)) self.metadata_address = url2base.get(url) return bool(url) diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index cfc59ca5..f18c4cee 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -275,17 +275,18 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description='Interact with MAAS DS') parser.add_argument("--config", metavar="file", - help="specify DS config file", default=None) + help="specify DS config file", default=None) parser.add_argument("--ckey", metavar="key", - help="the consumer key to auth with", default=None) + help="the consumer key to auth with", default=None) parser.add_argument("--tkey", metavar="key", - help="the token key to auth with", default=None) + help="the token key to auth with", default=None) parser.add_argument("--csec", metavar="secret", - help="the consumer secret (likely '')", default="") + help="the consumer secret (likely '')", default="") parser.add_argument("--tsec", metavar="secret", - help="the token secret to auth with", default=None) + help="the token secret to auth with", default=None) parser.add_argument("--apiver", metavar="version", - help="the apiver to use ("" can be used)", default=MD_VERSION) + help="the apiver to use ("" can be used)", + default=MD_VERSION) subcmds = parser.add_subparsers(title="subcommands", dest="subcmd") subcmds.add_parser('crawl', help="crawl the datasource") @@ -297,7 +298,7 @@ if __name__ == "__main__": args = parser.parse_args() creds = {'consumer_key': args.ckey, 'token_key': args.tkey, - 'token_secret': args.tsec, 'consumer_secret': args.csec} + 'token_secret': args.tsec, 'consumer_secret': args.csec} if args.config: cfg = util.read_conf(args.config) diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index 58a4b2a2..adf9b12e 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -264,14 +264,14 @@ def get_properties(contents): # could also check here that elem.namespaceURI == # "http://schemas.dmtf.org/ovf/environment/1" propSections = find_child(dom.documentElement, - lambda n: n.localName == "PropertySection") + lambda n: n.localName == "PropertySection") if len(propSections) == 0: raise XmlError("No 'PropertySection's") props = {} propElems = find_child(propSections[0], - (lambda n: n.localName == "Property")) + (lambda n: n.localName == "Property")) for elem in propElems: key = elem.attributes.getNamedItemNS(envNsURI, "key").value diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index ac2c3b45..b26940d1 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -404,7 +404,8 @@ def read_context_disk_dir(source_dir, asuser=None): if ssh_key_var: lines = context.get(ssh_key_var).splitlines() results['metadata']['public-keys'] = [l for l in lines - if len(l) and not l.startswith("#")] + if len(l) and not + l.startswith("#")] # custom hostname -- try hostname or leave cloud-init # itself create hostname from IP address later diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 7453379a..139ee52c 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -90,8 +90,7 @@ BUILTIN_DS_CONFIG = { 'user-data', 'user-script', 'sdc:datacenter_name', - 'sdc:uuid', - ], + 'sdc:uuid'], 'base64_keys': [], 'base64_all': False, 'disk_aliases': {'ephemeral0': '/dev/vdb'}, @@ -450,7 +449,7 @@ class JoyentMetadataClient(object): response = bytearray() response.extend(self.metasource.read(1)) - while response[-1:] != b'\n': + while response[-1:] != b'\n': response.extend(self.metasource.read(1)) response = response.rstrip().decode('ascii') LOG.debug('Read "%s" from metadata transport.', response) @@ -513,7 +512,7 @@ def write_boot_content(content, content_f, link=None, shebang=False, except Exception as e: util.logexc(LOG, ("Failed to identify script type for %s" % - content_f, e)) + content_f, e)) if link: try: diff --git a/cloudinit/ssh_util.py b/cloudinit/ssh_util.py index 9b2f5ed5..c74a7ae2 100644 --- a/cloudinit/ssh_util.py +++ b/cloudinit/ssh_util.py @@ -31,7 +31,8 @@ LOG = logging.getLogger(__name__) DEF_SSHD_CFG = "/etc/ssh/sshd_config" # taken from openssh source key.c/key_type_from_name -VALID_KEY_TYPES = ("rsa", "dsa", "ssh-rsa", "ssh-dss", "ecdsa", +VALID_KEY_TYPES = ( + "rsa", "dsa", "ssh-rsa", "ssh-dss", "ecdsa", "ssh-rsa-cert-v00@openssh.com", "ssh-dss-cert-v00@openssh.com", "ssh-rsa-cert-v00@openssh.com", "ssh-dss-cert-v00@openssh.com", "ssh-rsa-cert-v01@openssh.com", "ssh-dss-cert-v01@openssh.com", diff --git a/cloudinit/stages.py b/cloudinit/stages.py index 9f192c8d..dbcf3d55 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -509,13 +509,13 @@ class Init(object): def consume_data(self, frequency=PER_INSTANCE): # Consume the userdata first, because we need want to let the part # handlers run first (for merging stuff) - with events.ReportEventStack( - "consume-user-data", "reading and applying user-data", - parent=self.reporter): + with events.ReportEventStack("consume-user-data", + "reading and applying user-data", + parent=self.reporter): self._consume_userdata(frequency) - with events.ReportEventStack( - "consume-vendor-data", "reading and applying vendor-data", - parent=self.reporter): + with events.ReportEventStack("consume-vendor-data", + "reading and applying vendor-data", + parent=self.reporter): self._consume_vendordata(frequency) # Perform post-consumption adjustments so that @@ -655,7 +655,7 @@ class Modules(object): else: raise TypeError(("Failed to read '%s' item in config," " unknown type %s") % - (item, type_utils.obj_name(item))) + (item, type_utils.obj_name(item))) return module_list def _fixup_modules(self, raw_mods): @@ -762,8 +762,8 @@ class Modules(object): if skipped: LOG.info("Skipping modules %s because they are not verified " - "on distro '%s'. To run anyway, add them to " - "'unverified_modules' in config.", skipped, d_name) + "on distro '%s'. To run anyway, add them to " + "'unverified_modules' in config.", skipped, d_name) if forced: LOG.info("running unverified_modules: %s", forced) diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index f2e1390e..936f7da5 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -252,9 +252,9 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1, # attrs return UrlResponse(r) except exceptions.RequestException as e: - if (isinstance(e, (exceptions.HTTPError)) - and hasattr(e, 'response') # This appeared in v 0.10.8 - and hasattr(e.response, 'status_code')): + if (isinstance(e, (exceptions.HTTPError)) and + hasattr(e, 'response') and # This appeared in v 0.10.8 + hasattr(e.response, 'status_code')): excps.append(UrlError(e, code=e.response.status_code, headers=e.response.headers, url=url)) diff --git a/cloudinit/util.py b/cloudinit/util.py index 45d49e66..de37b0f5 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -612,7 +612,7 @@ def redirect_output(outfmt, errfmt, o_out=None, o_err=None): def make_url(scheme, host, port=None, - path='', params='', query='', fragment=''): + path='', params='', query='', fragment=''): pieces = [] pieces.append(scheme or '') @@ -804,8 +804,8 @@ def load_yaml(blob, default=None, allowed=(dict,)): blob = decode_binary(blob) try: LOG.debug("Attempting to load yaml from string " - "of length %s with allowed root types %s", - len(blob), allowed) + "of length %s with allowed root types %s", + len(blob), allowed) converted = safeyaml.load(blob) if not isinstance(converted, allowed): # Yes this will just be caught, but thats ok for now... @@ -878,7 +878,7 @@ def read_conf_with_confd(cfgfile): if not isinstance(confd, six.string_types): raise TypeError(("Config file %s contains 'conf_d' " "with non-string type %s") % - (cfgfile, type_utils.obj_name(confd))) + (cfgfile, type_utils.obj_name(confd))) else: confd = str(confd).strip() elif os.path.isdir("%s.d" % cfgfile): @@ -1041,7 +1041,8 @@ def is_resolvable(name): for iname in badnames: try: result = socket.getaddrinfo(iname, None, 0, 0, - socket.SOCK_STREAM, socket.AI_CANONNAME) + socket.SOCK_STREAM, + socket.AI_CANONNAME) badresults[iname] = [] for (_fam, _stype, _proto, cname, sockaddr) in result: badresults[iname].append("%s: %s" % (cname, sockaddr[0])) @@ -1109,7 +1110,7 @@ def close_stdin(): def find_devs_with(criteria=None, oformat='device', - tag=None, no_cache=False, path=None): + tag=None, no_cache=False, path=None): """ find devices matching given criteria (via blkid) criteria can be *one* of: @@ -1628,7 +1629,7 @@ def write_file(filename, content, mode=0o644, omode="wb"): content = decode_binary(content) write_type = 'characters' LOG.debug("Writing to %s - %s: [%s] %s %s", - filename, omode, mode, len(content), write_type) + filename, omode, mode, len(content), write_type) with SeLinuxGuard(path=filename): with open(filename, omode) as fh: fh.write(content) diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index c603bfdb..9c1ec1d4 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -27,11 +27,12 @@ from cloudinit import stages from cloudinit import user_data as ud from cloudinit import util -INSTANCE_ID = "i-testing" - from . import helpers +INSTANCE_ID = "i-testing" + + class FakeDataSource(sources.DataSource): def __init__(self, userdata=None, vendordata=None): diff --git a/tests/unittests/test_datasource/test_altcloud.py b/tests/unittests/test_datasource/test_altcloud.py index e9cd2fa5..85759c68 100644 --- a/tests/unittests/test_datasource/test_altcloud.py +++ b/tests/unittests/test_datasource/test_altcloud.py @@ -134,8 +134,7 @@ class TestGetCloudType(TestCase): ''' util.read_dmi_data = _dmi_data('RHEV') dsrc = DataSourceAltCloud({}, None, self.paths) - self.assertEquals('RHEV', \ - dsrc.get_cloud_type()) + self.assertEquals('RHEV', dsrc.get_cloud_type()) def test_vsphere(self): ''' @@ -144,8 +143,7 @@ class TestGetCloudType(TestCase): ''' util.read_dmi_data = _dmi_data('VMware Virtual Platform') dsrc = DataSourceAltCloud({}, None, self.paths) - self.assertEquals('VSPHERE', \ - dsrc.get_cloud_type()) + self.assertEquals('VSPHERE', dsrc.get_cloud_type()) def test_unknown(self): ''' @@ -154,8 +152,7 @@ class TestGetCloudType(TestCase): ''' util.read_dmi_data = _dmi_data('Unrecognized Platform') dsrc = DataSourceAltCloud({}, None, self.paths) - self.assertEquals('UNKNOWN', \ - dsrc.get_cloud_type()) + self.assertEquals('UNKNOWN', dsrc.get_cloud_type()) class TestGetDataCloudInfoFile(TestCase): @@ -412,27 +409,27 @@ class TestReadUserDataCallback(TestCase): '''Test read_user_data_callback() with both files.''' self.assertEquals('test user data', - read_user_data_callback(self.mount_dir)) + read_user_data_callback(self.mount_dir)) def test_callback_dc(self): '''Test read_user_data_callback() with only DC file.''' _remove_user_data_files(self.mount_dir, - dc_file=False, - non_dc_file=True) + dc_file=False, + non_dc_file=True) self.assertEquals('test user data', - read_user_data_callback(self.mount_dir)) + read_user_data_callback(self.mount_dir)) def test_callback_non_dc(self): '''Test read_user_data_callback() with only non-DC file.''' _remove_user_data_files(self.mount_dir, - dc_file=True, - non_dc_file=False) + dc_file=True, + non_dc_file=False) self.assertEquals('test user data', - read_user_data_callback(self.mount_dir)) + read_user_data_callback(self.mount_dir)) def test_callback_none(self): '''Test read_user_data_callback() no files are found.''' diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 3933794f..4c9c7d8b 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -207,7 +207,7 @@ class TestAzureDataSource(TestCase): yaml_cfg = "{agent_command: my_command}\n" cfg = yaml.safe_load(yaml_cfg) odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': yaml_cfg, 'encoding': 'plain'}} + 'dscfg': {'text': yaml_cfg, 'encoding': 'plain'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} dsrc = self._get_ds(data) @@ -219,8 +219,8 @@ class TestAzureDataSource(TestCase): # set dscfg in via base64 encoded yaml cfg = {'agent_command': "my_command"} odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': b64e(yaml.dump(cfg)), - 'encoding': 'base64'}} + 'dscfg': {'text': b64e(yaml.dump(cfg)), + 'encoding': 'base64'}} data = {'ovfcontent': construct_valid_ovf_env(data=odata)} dsrc = self._get_ds(data) @@ -267,7 +267,8 @@ class TestAzureDataSource(TestCase): # should equal that after the '$' pos = defuser['passwd'].rfind("$") + 1 self.assertEqual(defuser['passwd'], - crypt.crypt(odata['UserPassword'], defuser['passwd'][0:pos])) + crypt.crypt(odata['UserPassword'], + defuser['passwd'][0:pos])) def test_userdata_plain(self): mydata = "FOOBAR" @@ -364,8 +365,8 @@ class TestAzureDataSource(TestCase): # Make sure that user can affect disk aliases dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}} odata = {'HostName': "myhost", 'UserName': "myuser", - 'dscfg': {'text': b64e(yaml.dump(dscfg)), - 'encoding': 'base64'}} + 'dscfg': {'text': b64e(yaml.dump(dscfg)), + 'encoding': 'base64'}} usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'}, 'ephemeral0': False}} userdata = '#cloud-config' + yaml.dump(usercfg) + "\n" @@ -634,7 +635,7 @@ class TestReadAzureOvf(TestCase): def test_invalid_xml_raises_non_azure_ds(self): invalid_xml = "" + construct_valid_ovf_env(data={}) self.assertRaises(DataSourceAzure.BrokenAzureDataSource, - DataSourceAzure.read_azure_ovf, invalid_xml) + DataSourceAzure.read_azure_ovf, invalid_xml) def test_load_with_pubkeys(self): mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}] diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index 83aca505..3954ceb3 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -293,9 +293,8 @@ class TestConfigDriveDataSource(TestCase): util.is_partition = my_is_partition devs_with_answers = {"TYPE=vfat": [], - "TYPE=iso9660": ["/dev/vdb"], - "LABEL=config-2": ["/dev/vdb"], - } + "TYPE=iso9660": ["/dev/vdb"], + "LABEL=config-2": ["/dev/vdb"]} self.assertEqual(["/dev/vdb"], ds.find_candidate_devs()) # add a vfat item @@ -306,9 +305,10 @@ class TestConfigDriveDataSource(TestCase): # verify that partitions are considered, that have correct label. devs_with_answers = {"TYPE=vfat": ["/dev/sda1"], - "TYPE=iso9660": [], "LABEL=config-2": ["/dev/vdb3"]} + "TYPE=iso9660": [], + "LABEL=config-2": ["/dev/vdb3"]} self.assertEqual(["/dev/vdb3"], - ds.find_candidate_devs()) + ds.find_candidate_devs()) finally: util.find_devs_with = orig_find_devs_with @@ -319,7 +319,7 @@ class TestConfigDriveDataSource(TestCase): populate_dir(self.tmp, CFG_DRIVE_FILES_V2) myds = cfg_ds_from_dir(self.tmp) self.assertEqual(myds.get_public_ssh_keys(), - [OSTACK_META['public_keys']['mykey']]) + [OSTACK_META['public_keys']['mykey']]) def cfg_ds_from_dir(seed_d): diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index eb97b692..77d15cac 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -25,9 +25,9 @@ class TestMAASDataSource(TestCase): """Verify a valid seeddir is read as such.""" data = {'instance-id': 'i-valid01', - 'local-hostname': 'valid01-hostname', - 'user-data': b'valid01-userdata', - 'public-keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname'} + 'local-hostname': 'valid01-hostname', + 'user-data': b'valid01-userdata', + 'public-keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname'} my_d = os.path.join(self.tmp, "valid") populate_dir(my_d, data) @@ -45,8 +45,8 @@ class TestMAASDataSource(TestCase): """Verify extra files do not affect seed_dir validity.""" data = {'instance-id': 'i-valid-extra', - 'local-hostname': 'valid-extra-hostname', - 'user-data': b'valid-extra-userdata', 'foo': 'bar'} + 'local-hostname': 'valid-extra-hostname', + 'user-data': b'valid-extra-userdata', 'foo': 'bar'} my_d = os.path.join(self.tmp, "valid_extra") populate_dir(my_d, data) @@ -64,7 +64,7 @@ class TestMAASDataSource(TestCase): """Verify that invalid seed_dir raises MAASSeedDirMalformed.""" valid = {'instance-id': 'i-instanceid', - 'local-hostname': 'test-hostname', 'user-data': ''} + 'local-hostname': 'test-hostname', 'user-data': ''} my_based = os.path.join(self.tmp, "valid_extra") @@ -94,8 +94,8 @@ class TestMAASDataSource(TestCase): def test_seed_dir_missing(self): """Verify that missing seed_dir raises MAASSeedDirNone.""" self.assertRaises(DataSourceMAAS.MAASSeedDirNone, - DataSourceMAAS.read_maas_seed_dir, - os.path.join(self.tmp, "nonexistantdirectory")) + DataSourceMAAS.read_maas_seed_dir, + os.path.join(self.tmp, "nonexistantdirectory")) def test_seed_url_valid(self): """Verify that valid seed_url is read as such.""" diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 1235436d..5e617b83 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -463,8 +463,8 @@ class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): payloadstr = ' {0}'.format(self.response_parts['payload']) return ('V2 {length} {crc} {request_id} ' '{command}{payloadstr}\n'.format( - payloadstr=payloadstr, - **self.response_parts).encode('ascii')) + payloadstr=payloadstr, + **self.response_parts).encode('ascii')) self.metasource_data = None @@ -501,7 +501,7 @@ class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): written_line = self.serial.write.call_args[0][0] print(type(written_line)) self.assertEndsWith(written_line.decode('ascii'), - b'\n'.decode('ascii')) + b'\n'.decode('ascii')) self.assertEqual(1, written_line.count(b'\n')) def _get_written_line(self, key='some_key'): diff --git a/tests/unittests/test_handler/test_handler_power_state.py b/tests/unittests/test_handler/test_handler_power_state.py index 5687b10d..f9660ff6 100644 --- a/tests/unittests/test_handler/test_handler_power_state.py +++ b/tests/unittests/test_handler/test_handler_power_state.py @@ -74,7 +74,7 @@ class TestLoadPowerState(t_help.TestCase): class TestCheckCondition(t_help.TestCase): def cmd_with_exit(self, rc): return([sys.executable, '-c', 'import sys; sys.exit(%s)' % rc]) - + def test_true_is_true(self): self.assertEqual(psc.check_condition(True), True) @@ -94,7 +94,6 @@ class TestCheckCondition(t_help.TestCase): self.assertEqual(mocklog.warn.call_count, 1) - def check_lps_ret(psc_return, mode=None): if len(psc_return) != 3: raise TypeError("length returned = %d" % len(psc_return)) diff --git a/tests/unittests/test_handler/test_handler_seed_random.py b/tests/unittests/test_handler/test_handler_seed_random.py index 0bcdcb31..34d11f21 100644 --- a/tests/unittests/test_handler/test_handler_seed_random.py +++ b/tests/unittests/test_handler/test_handler_seed_random.py @@ -190,7 +190,8 @@ class TestRandomSeed(t_help.TestCase): c = self._get_cloud('ubuntu', {}) self.whichdata = {} self.assertRaises(ValueError, cc_seed_random.handle, - 'test', {'random_seed': {'command_required': True}}, c, LOG, []) + 'test', {'random_seed': {'command_required': True}}, + c, LOG, []) def test_seed_command_and_required(self): c = self._get_cloud('ubuntu', {}) diff --git a/tests/unittests/test_handler/test_handler_snappy.py b/tests/unittests/test_handler/test_handler_snappy.py index eceb14d9..8aeff53c 100644 --- a/tests/unittests/test_handler/test_handler_snappy.py +++ b/tests/unittests/test_handler/test_handler_snappy.py @@ -125,8 +125,7 @@ class TestInstallPackages(t_help.TestCase): "pkg1.smoser.config": "pkg1.smoser.config-data", "pkg1.config": "pkg1.config-data", "pkg2.smoser_0.0_amd64.snap": "pkg2-snapdata", - "pkg2.smoser_0.0_amd64.config": "pkg2.config", - }) + "pkg2.smoser_0.0_amd64.config": "pkg2.config"}) ret = get_package_ops( packages=[], configs={}, installed=[], fspath=self.tmp) diff --git a/tests/unittests/test_sshutil.py b/tests/unittests/test_sshutil.py index 3b317121..9aeb1cde 100644 --- a/tests/unittests/test_sshutil.py +++ b/tests/unittests/test_sshutil.py @@ -32,7 +32,8 @@ VALID_CONTENT = { ), } -TEST_OPTIONS = ("no-port-forwarding,no-agent-forwarding,no-X11-forwarding," +TEST_OPTIONS = ( + "no-port-forwarding,no-agent-forwarding,no-X11-forwarding," 'command="echo \'Please login as the user \"ubuntu\" rather than the' 'user \"root\".\';echo;sleep 10"') diff --git a/tests/unittests/test_templating.py b/tests/unittests/test_templating.py index 0c19a2c2..b9863650 100644 --- a/tests/unittests/test_templating.py +++ b/tests/unittests/test_templating.py @@ -114,5 +114,6 @@ $a,$b''' codename) out_data = templater.basic_render(in_data, - {'mirror': mirror, 'codename': codename}) + {'mirror': mirror, + 'codename': codename}) self.assertEqual(ex_data, out_data) diff --git a/tools/hacking.py b/tools/hacking.py index 3175df38..1a0631c2 100755 --- a/tools/hacking.py +++ b/tools/hacking.py @@ -47,10 +47,10 @@ def import_normalize(line): # handle "from x import y as z" to "import x.y as z" split_line = line.split() if (line.startswith("from ") and "," not in line and - split_line[2] == "import" and split_line[3] != "*" and - split_line[1] != "__future__" and - (len(split_line) == 4 or - (len(split_line) == 6 and split_line[4] == "as"))): + split_line[2] == "import" and split_line[3] != "*" and + split_line[1] != "__future__" and + (len(split_line) == 4 or + (len(split_line) == 6 and split_line[4] == "as"))): return "import %s.%s" % (split_line[1], split_line[3]) else: return line @@ -74,7 +74,7 @@ def cloud_import_alphabetical(physical_line, line_number, lines): split_line[0] == "import" and split_previous[0] == "import"): if split_line[1] < split_previous[1]: return (0, "N306: imports not in alphabetical order (%s, %s)" - % (split_previous[1], split_line[1])) + % (split_previous[1], split_line[1])) def cloud_docstring_start_space(physical_line): @@ -87,8 +87,8 @@ def cloud_docstring_start_space(physical_line): pos = max([physical_line.find(i) for i in DOCSTRING_TRIPLE]) # start if (pos != -1 and len(physical_line) > pos + 1): if (physical_line[pos + 3] == ' '): - return (pos, "N401: one line docstring should not start with" - " a space") + return (pos, + "N401: one line docstring should not start with a space") def cloud_todo_format(physical_line): @@ -167,4 +167,4 @@ if __name__ == "__main__": finally: if len(_missingImport) > 0: print >> sys.stderr, ("%i imports missing in this test environment" - % len(_missingImport)) + % len(_missingImport)) diff --git a/tools/mock-meta.py b/tools/mock-meta.py index dfbc2a71..1c746f17 100755 --- a/tools/mock-meta.py +++ b/tools/mock-meta.py @@ -126,11 +126,11 @@ class WebException(Exception): def yamlify(data): formatted = yaml.dump(data, - line_break="\n", - indent=4, - explicit_start=True, - explicit_end=True, - default_flow_style=False) + line_break="\n", + indent=4, + explicit_start=True, + explicit_end=True, + default_flow_style=False) return formatted @@ -282,7 +282,7 @@ class MetaDataHandler(object): else: log.warn(("Did not implement action %s, " "returning empty response: %r"), - action, NOT_IMPL_RESPONSE) + action, NOT_IMPL_RESPONSE) return NOT_IMPL_RESPONSE @@ -404,14 +404,17 @@ def setup_logging(log_level, fmt='%(levelname)s: @%(name)s : %(message)s'): def extract_opts(): parser = OptionParser() parser.add_option("-p", "--port", dest="port", action="store", type=int, - default=80, metavar="PORT", - help="port from which to serve traffic (default: %default)") + default=80, metavar="PORT", + help=("port from which to serve traffic" + " (default: %default)")) parser.add_option("-a", "--addr", dest="address", action="store", type=str, - default='0.0.0.0', metavar="ADDRESS", - help="address from which to serve traffic (default: %default)") + default='0.0.0.0', metavar="ADDRESS", + help=("address from which to serve traffic" + " (default: %default)")) parser.add_option("-f", '--user-data-file', dest='user_data_file', - action='store', metavar='FILE', - help="user data filename to serve back to incoming requests") + action='store', metavar='FILE', + help=("user data filename to serve back to" + "incoming requests")) (options, args) = parser.parse_args() out = dict() out['extra'] = args diff --git a/tools/run-pep8 b/tools/run-pep8 index ccd6be5a..086400fc 100755 --- a/tools/run-pep8 +++ b/tools/run-pep8 @@ -1,39 +1,22 @@ #!/bin/bash -if [ $# -eq 0 ]; then - files=( bin/cloud-init $(find * -name "*.py" -type f) ) +pycheck_dirs=( "cloudinit/" "bin/" "tests/" "tools/" ) +# FIXME: cloud-init modifies sys module path, pep8 does not like +# bin_files=( "bin/cloud-init" ) +CR=" +" +[ "$1" = "-v" ] && { verbose="$1"; shift; } || verbose="" + +set -f +if [ $# -eq 0 ]; then unset IFS + IFS="$CR" + files=( "${bin_files[@]}" "${pycheck_dirs[@]}" ) + unset IFS else - files=( "$@" ); + files=( "$@" ) fi -if [ -f 'hacking.py' ] -then - base=`pwd` -else - base=`pwd`/tools/ -fi - -IGNORE="" - -# King Arthur: Be quiet! ... Be Quiet! I Order You to Be Quiet. -IGNORE="$IGNORE,E121" # Continuation line indentation is not a multiple of four -IGNORE="$IGNORE,E123" # Closing bracket does not match indentation of opening bracket's line -IGNORE="$IGNORE,E124" # Closing bracket missing visual indentation -IGNORE="$IGNORE,E125" # Continuation line does not distinguish itself from next logical line -IGNORE="$IGNORE,E126" # Continuation line over-indented for hanging indent -IGNORE="$IGNORE,E127" # Continuation line over-indented for visual indent -IGNORE="$IGNORE,E128" # Continuation line under-indented for visual indent -IGNORE="$IGNORE,E502" # The backslash is redundant between brackets -IGNORE="${IGNORE#,}" # remove the leading ',' added above - -cmd=( - ${base}/hacking.py - - --ignore="$IGNORE" - - "${files[@]}" -) - -echo -e "\nRunning 'cloudinit' pep8:" -echo "${cmd[@]}" -"${cmd[@]}" +myname=${0##*/} +cmd=( "${myname#run-}" $verbose "${files[@]}" ) +echo "Running: " "${cmd[@]}" 1>&2 +exec "${cmd[@]}" -- cgit v1.2.3 From 3d9153d16b194e7a3139c290e723ef17518e617d Mon Sep 17 00:00:00 2001 From: Ryan Harper Date: Thu, 3 Mar 2016 16:32:32 -0600 Subject: Fix pyflake/pyflake3 errors Now we can run make check to assess pep8, pyflakes for python2 or 3 And execute unittests via nosetests (2 and 3). --- Makefile | 1 - cloudinit/util.py | 2 +- tests/unittests/test_datasource/test_azure_helper.py | 2 -- tests/unittests/test_datasource/test_smartos.py | 1 - .../unittests/test_handler/test_handler_power_state.py | 2 +- tools/run-pyflakes | 18 ++++++++++++++++++ tools/run-pyflakes3 | 2 ++ 7 files changed, 22 insertions(+), 6 deletions(-) create mode 100755 tools/run-pyflakes create mode 100755 tools/run-pyflakes3 (limited to 'tests/unittests') diff --git a/Makefile b/Makefile index fb65b70b..fc91f829 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,6 @@ pyflakes: pyflakes3: @$(CWD)/tools/run-pyflakes3 - unittest: nosetests $(noseopts) tests/unittests nosetests3 $(noseopts) tests/unittests diff --git a/cloudinit/util.py b/cloudinit/util.py index de37b0f5..e7407ea4 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -2148,7 +2148,7 @@ def _read_dmi_syspath(key): LOG.debug("dmi data %s returned %s", dmi_key_path, key_data) return key_data.strip() - except Exception as e: + except Exception: logexc(LOG, "failed read of %s", dmi_key_path) return None diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 8dbdfb0b..1134199b 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,6 +1,4 @@ import os -import struct -import unittest from cloudinit.sources.helpers import azure as azure_helper from ..helpers import TestCase diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 5e617b83..616e9f0e 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -31,7 +31,6 @@ import shutil import stat import tempfile import uuid -import unittest from binascii import crc32 import serial diff --git a/tests/unittests/test_handler/test_handler_power_state.py b/tests/unittests/test_handler/test_handler_power_state.py index f9660ff6..04ce5687 100644 --- a/tests/unittests/test_handler/test_handler_power_state.py +++ b/tests/unittests/test_handler/test_handler_power_state.py @@ -106,7 +106,7 @@ def check_lps_ret(psc_return, mode=None): if 'shutdown' not in psc_return[0][0]: errs.append("string 'shutdown' not in cmd") - if 'condition' is None: + if condition is None: errs.append("condition was not returned") if mode is not None: diff --git a/tools/run-pyflakes b/tools/run-pyflakes new file mode 100755 index 00000000..4bea17f4 --- /dev/null +++ b/tools/run-pyflakes @@ -0,0 +1,18 @@ +#!/bin/bash + +PYTHON_VERSION=${PYTHON_VERSION:-2} +CR=" +" +pycheck_dirs=( "cloudinit/" "bin/" "tests/" "tools/" ) + +set -f +if [ $# -eq 0 ]; then + files=( "${pycheck_dirs[@]}" ) +else + files=( "$@" ) +fi + +cmd=( "python${PYTHON_VERSION}" -m "pyflakes" "${files[@]}" ) + +echo "Running: " "${cmd[@]}" 1>&2 +exec "${cmd[@]}" diff --git a/tools/run-pyflakes3 b/tools/run-pyflakes3 new file mode 100755 index 00000000..e9f0863d --- /dev/null +++ b/tools/run-pyflakes3 @@ -0,0 +1,2 @@ +#!/bin/sh +PYTHON_VERSION=3 exec "${0%/*}/run-pyflakes" "$@" -- cgit v1.2.3 From 70acc910c3368980d7cb8971391a2c9dfaf3fda8 Mon Sep 17 00:00:00 2001 From: Ryan Harper Date: Fri, 4 Mar 2016 09:51:05 -0600 Subject: pep8: update formatting to pass pep8 1.4.6 (trusty) and 1.6.2 (xenial) make check fails in a trusty sbuild due to different rules on older pep8. Fix formatting to pass in older and newer pep8. --- cloudinit/config/cc_rh_subscription.py | 4 +--- tests/unittests/test_datasource/test_azure.py | 2 +- tools/hacking.py | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/config/cc_rh_subscription.py b/cloudinit/config/cc_rh_subscription.py index 6f474aed..6087c45c 100644 --- a/cloudinit/config/cc_rh_subscription.py +++ b/cloudinit/config/cc_rh_subscription.py @@ -126,10 +126,8 @@ class SubscriptionManager(object): "(True/False " return False, not_bool - if (self.servicelevel is not None) and \ - ((not self.auto_attach) or + if (self.servicelevel is not None) and ((not self.auto_attach) or (util.is_false(str(self.auto_attach)))): - no_auto = ("The service-level key must be used in conjunction " "with the auto-attach key. Please re-run with " "auto-attach: True") diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 4c9c7d8b..444e2799 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -268,7 +268,7 @@ class TestAzureDataSource(TestCase): pos = defuser['passwd'].rfind("$") + 1 self.assertEqual(defuser['passwd'], crypt.crypt(odata['UserPassword'], - defuser['passwd'][0:pos])) + defuser['passwd'][0:pos])) def test_userdata_plain(self): mydata = "FOOBAR" diff --git a/tools/hacking.py b/tools/hacking.py index 1a0631c2..716c1154 100755 --- a/tools/hacking.py +++ b/tools/hacking.py @@ -49,8 +49,8 @@ def import_normalize(line): if (line.startswith("from ") and "," not in line and split_line[2] == "import" and split_line[3] != "*" and split_line[1] != "__future__" and - (len(split_line) == 4 or - (len(split_line) == 6 and split_line[4] == "as"))): + (len(split_line) == 4 or (len(split_line) == 6 and + split_line[4] == "as"))): return "import %s.%s" % (split_line[1], split_line[3]) else: return line -- cgit v1.2.3 From 6e31038b9cccbcb4a33693060b96fc4f71d86789 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 7 Mar 2016 21:31:25 -0500 Subject: No longer run pollinate by default in seed_random The user can still choose to run pollinate here to seed their random data. And in an environment with network datasource, that would be expected to work. However, we do not want to run it any more from cloud-init because a.) pollinate's own init system jobs should get it ran before ssh, which is the primary purpose of wanting cloud-init to run it. b.) with a local datasource, there is no network guarantee when init_modules run, so pollinate -q would often cause issues then. c.) cloud-init would run pollinate and log the failure causing many cloud-init specific failures that it could do nothing about. LP: #1554152 --- ChangeLog | 1 + cloudinit/config/cc_seed_random.py | 2 +- tests/unittests/test_handler/test_handler_seed_random.py | 14 ++++++++------ 3 files changed, 10 insertions(+), 7 deletions(-) (limited to 'tests/unittests') diff --git a/ChangeLog b/ChangeLog index a80a5d5f..6da276b5 100644 --- a/ChangeLog +++ b/ChangeLog @@ -85,6 +85,7 @@ unless it is already a file (LP: #1543025). - Enable password changing via a hashed string [Alex Sirbu] - Added BigStep datasource [Alex Sirbu] + - No longer run pollinate in seed_random (LP: #1554152) 0.7.6: - open 0.7.6 diff --git a/cloudinit/config/cc_seed_random.py b/cloudinit/config/cc_seed_random.py index 3288a853..1b011216 100644 --- a/cloudinit/config/cc_seed_random.py +++ b/cloudinit/config/cc_seed_random.py @@ -83,7 +83,7 @@ def handle(name, cfg, cloud, log, _args): len(seed_data), seed_path) util.append_file(seed_path, seed_data) - command = mycfg.get('command', ['pollinate', '-q']) + command = mycfg.get('command', None) req = mycfg.get('command_required', False) try: env = os.environ.copy() diff --git a/tests/unittests/test_handler/test_handler_seed_random.py b/tests/unittests/test_handler/test_handler_seed_random.py index 34d11f21..98bc9b81 100644 --- a/tests/unittests/test_handler/test_handler_seed_random.py +++ b/tests/unittests/test_handler/test_handler_seed_random.py @@ -170,28 +170,30 @@ class TestRandomSeed(t_help.TestCase): contents = util.load_file(self._seed_file) self.assertEquals('tiny-tim-was-here-so-was-josh', contents) - def test_seed_command_not_provided_pollinate_available(self): + def test_seed_command_provided_and_available(self): c = self._get_cloud('ubuntu', {}) self.whichdata = {'pollinate': '/usr/bin/pollinate'} - cc_seed_random.handle('test', {}, c, LOG, []) + cfg = {'random_seed': {'command': ['pollinate', '-q']}} + cc_seed_random.handle('test', cfg, c, LOG, []) subp_args = [f['args'] for f in self.subp_called] self.assertIn(['pollinate', '-q'], subp_args) - def test_seed_command_not_provided_pollinate_not_available(self): + def test_seed_command_not_provided(self): c = self._get_cloud('ubuntu', {}) self.whichdata = {} cc_seed_random.handle('test', {}, c, LOG, []) # subp should not have been called as which would say not available - self.assertEquals(self.subp_called, list()) + self.assertFalse(self.subp_called) def test_unavailable_seed_command_and_required_raises_error(self): c = self._get_cloud('ubuntu', {}) self.whichdata = {} + cfg = {'random_seed': {'command': ['THIS_NO_COMMAND'], + 'command_required': True}} self.assertRaises(ValueError, cc_seed_random.handle, - 'test', {'random_seed': {'command_required': True}}, - c, LOG, []) + 'test', cfg, c, LOG, []) def test_seed_command_and_required(self): c = self._get_cloud('ubuntu', {}) -- cgit v1.2.3 From b839ad32b9bf4541583ecbe68a0bd5dd9f12345a Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 10 Mar 2016 12:32:46 -0500 Subject: dmi data: fix failure of reading dmi data for unset dmi values it is not uncommon to find dmi data in /sys full of 'ff'. utf-8 decoding of those would fail, causing warning and stacktrace. Return '.' instead of \xff. This maps to what dmidecode would return $ dmidecode --string system-product-name ................................. --- ChangeLog | 1 + cloudinit/util.py | 13 ++++++++++--- tests/unittests/test_util.py | 9 +++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) (limited to 'tests/unittests') diff --git a/ChangeLog b/ChangeLog index da1ca9ee..ebaacf6a 100644 --- a/ChangeLog +++ b/ChangeLog @@ -88,6 +88,7 @@ - No longer run pollinate in seed_random (LP: #1554152) - groups: add defalt user to 'lxd' group. Create groups listed for a user if they do not exist. (LP: #1539317) + - dmi data: fix failure of reading dmi data for unset dmi values 0.7.6: - open 0.7.6 diff --git a/cloudinit/util.py b/cloudinit/util.py index e7407ea4..1d50edc9 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -2140,13 +2140,20 @@ def _read_dmi_syspath(key): LOG.debug("did not find %s", dmi_key_path) return None - key_data = load_file(dmi_key_path) + key_data = load_file(dmi_key_path, decode=False) if not key_data: LOG.debug("%s did not return any data", dmi_key_path) return None - LOG.debug("dmi data %s returned %s", dmi_key_path, key_data) - return key_data.strip() + # in the event that this is all \xff and a carriage return + # then return '.' in its place. + if key_data == b'\xff' * (len(key_data) - 1) + b'\n': + key_data = b'.' * (len(key_data) - 1) + b'\n' + + str_data = key_data.decode('utf8').strip() + + LOG.debug("dmi data %s returned %s", dmi_key_path, str_data) + return str_data except Exception: logexc(LOG, "failed read of %s", dmi_key_path) diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 95990165..542e4075 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -385,6 +385,15 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): self.patch_mapping({}) self.assertEqual(None, util.read_dmi_data('expect-fail')) + def test_dots_returned_instead_of_foxfox(self): + my_len = 32 + dmi_value = b'\xff' * my_len + b'\n' + expected = '.' * my_len + dmi_key = 'system-product-name' + sysfs_key = 'product_name' + self._create_sysfs_file(sysfs_key, dmi_value) + self.assertEqual(expected, util.read_dmi_data(dmi_key)) + class TestMultiLog(helpers.FilesystemMockingTestCase): -- cgit v1.2.3 From 03f80fa62eb85270a7a96850c5e689a1c4bc0049 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Mon, 14 Mar 2016 09:21:02 -0400 Subject: change return value for dmi data of all \xff to be "" Previously we returned a string of "." the same length as the dmi field. That seems confusing to the user as "." would seem like a valid response when in fact this value should not be considered valid. So now, in this case, return empty string. --- cloudinit/util.py | 7 +++++-- tests/unittests/test_util.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/util.py b/cloudinit/util.py index 1a517c79..caae17ce 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -2148,7 +2148,7 @@ def _read_dmi_syspath(key): # uninitialized dmi values show as all \xff and /sys appends a '\n'. # in that event, return a string of '.' in the same length. if key_data == b'\xff' * (len(key_data) - 1) + b'\n': - key_data = b'.' * (len(key_data) - 1) + b'\n' + key_data = b"" str_data = key_data.decode('utf8').strip() LOG.debug("dmi data %s returned %s", dmi_key_path, str_data) @@ -2193,7 +2193,10 @@ def read_dmi_data(key): dmidecode_path = which('dmidecode') if dmidecode_path: - return _call_dmidecode(key, dmidecode_path) + ret = _call_dmidecode(key, dmidecode_path) + if ret is not None and ret.replace(".", "") == "": + return "" + return ret LOG.warn("did not find either path %s or dmidecode command", DMI_SYS_PATH) diff --git a/tests/unittests/test_util.py b/tests/unittests/test_util.py index 542e4075..bdee9719 100644 --- a/tests/unittests/test_util.py +++ b/tests/unittests/test_util.py @@ -388,7 +388,7 @@ class TestReadDMIData(helpers.FilesystemMockingTestCase): def test_dots_returned_instead_of_foxfox(self): my_len = 32 dmi_value = b'\xff' * my_len + b'\n' - expected = '.' * my_len + expected = "" dmi_key = 'system-product-name' sysfs_key = 'product_name' self._create_sysfs_file(sysfs_key, dmi_value) -- cgit v1.2.3 From 7f0871dc5b141ff4bf601b6d96021eba8a3bb0ec Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Tue, 22 Mar 2016 05:40:05 -0400 Subject: write to 50-cloud-init.cfg and write systemd.link rules. --- cloudinit/distros/debian.py | 8 +++-- cloudinit/net/__init__.py | 48 ++++++++++++++++++++------ tests/unittests/test_distros/test_netconfig.py | 5 +-- 3 files changed, 46 insertions(+), 15 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/distros/debian.py b/cloudinit/distros/debian.py index 909d6deb..b14fa3e2 100644 --- a/cloudinit/distros/debian.py +++ b/cloudinit/distros/debian.py @@ -46,7 +46,8 @@ APT_GET_WRAPPER = { class Distro(distros.Distro): hostname_conf_fn = "/etc/hostname" locale_conf_fn = "/etc/default/locale" - network_conf_fn = "/etc/network/interfaces" + network_conf_fn = "/etc/network/interfaces.d/50-cloud-init.cfg" + links_prefix = "/etc/systemd/network/50-cloud-init-" def __init__(self, name, cfg, paths): distros.Distro.__init__(self, name, cfg, paths) @@ -79,7 +80,10 @@ class Distro(distros.Distro): def _write_network_config(self, netconfig): ns = net.parse_net_config_data(netconfig) - net.render_network_state(network_state=ns, target="/") + net.render_network_state(target="/", network_state=ns, + eni=self.network_conf_fn, + links_prefix=self.links_prefix) + util.del_file("/etc/network/interfaces.d/eth0.cfg") return [] def _bring_up_interfaces(self, device_names): diff --git a/cloudinit/net/__init__.py b/cloudinit/net/__init__.py index 63fad2fa..ae7b1c04 100644 --- a/cloudinit/net/__init__.py +++ b/cloudinit/net/__init__.py @@ -20,7 +20,6 @@ import errno import glob import os import re -import string from cloudinit import log as logging from cloudinit import util @@ -30,6 +29,7 @@ from . import network_state LOG = logging.getLogger(__name__) SYS_CLASS_NET = "/sys/class/net/" +LINKS_FNAME_PREFIX = "etc/systemd/network/50-cloud-init-" NET_CONFIG_OPTIONS = [ "address", "netmask", "broadcast", "network", "metric", "gateway", @@ -423,19 +423,45 @@ def render_interfaces(network_state): return content -def render_network_state(target, network_state): - eni = 'etc/network/interfaces' - netrules = 'etc/udev/rules.d/70-persistent-net.rules' +def render_network_state(target, network_state, eni="etc/network/interfaces", + links_prefix=LINKS_FNAME_PREFIX, + netrules='etc/udev/rules.d/70-persistent-net.rules'): - eni = os.path.sep.join((target, eni,)) - util.ensure_dir(os.path.dirname(eni)) - with open(eni, 'w+') as f: + fpeni = os.path.sep.join((target, eni,)) + util.ensure_dir(os.path.dirname(fpeni)) + with open(fpeni, 'w+') as f: f.write(render_interfaces(network_state)) - netrules = os.path.sep.join((target, netrules,)) - util.ensure_dir(os.path.dirname(netrules)) - with open(netrules, 'w+') as f: - f.write(render_persistent_net(network_state)) + if netrules: + netrules = os.path.sep.join((target, netrules,)) + util.ensure_dir(os.path.dirname(netrules)) + with open(netrules, 'w+') as f: + f.write(render_persistent_net(network_state)) + + if links_prefix: + render_systemd_links(target, network_state, links_prefix) + + +def render_systemd_links(target, network_state, + links_prefix=LINKS_FNAME_PREFIX): + fp_prefix = os.path.sep.join((target, links_prefix)) + for f in glob.glob(fp_prefix + "*"): + os.unlink(f) + + interfaces = network_state.get('interfaces') + for iface in interfaces.values(): + if (iface['type'] == 'physical' and 'name' in iface and + 'mac_address' in iface): + fname = fp_prefix + iface['name'] + ".link" + with open(fname, "w") as fp: + fp.write("\n".join([ + "[Match]", + "MACAddress=" + iface['mac_address'], + "", + "[Link]", + "Name=" + iface['name'], + "" + ])) def is_disabled_cfg(cfg): diff --git a/tests/unittests/test_distros/test_netconfig.py b/tests/unittests/test_distros/test_netconfig.py index 6d30c5b8..2c2a424d 100644 --- a/tests/unittests/test_distros/test_netconfig.py +++ b/tests/unittests/test_distros/test_netconfig.py @@ -109,8 +109,9 @@ class TestNetCfgDistro(TestCase): ub_distro.apply_network(BASE_NET_CFG, False) self.assertEquals(len(write_bufs), 1) - self.assertIn('/etc/network/interfaces', write_bufs) - write_buf = write_bufs['/etc/network/interfaces'] + eni_name = '/etc/network/interfaces.d/50-cloud-init.cfg' + self.assertIn(eni_name, write_bufs) + write_buf = write_bufs[eni_name] self.assertEquals(str(write_buf).strip(), BASE_NET_CFG.strip()) self.assertEquals(write_buf.mode, 0o644) -- cgit v1.2.3 From 2b85dabb802766e0b3b1949d744c8860c0cb838a Mon Sep 17 00:00:00 2001 From: Ryan Harper Date: Wed, 23 Mar 2016 11:05:22 -0500 Subject: configdata: parse and convert openstack network_data json to network_config --- cloudinit/net/__init__.py | 34 +++-- cloudinit/net/network_state.py | 45 ++++++- cloudinit/sources/DataSourceConfigDrive.py | 137 +++++++++++++++++++++ cloudinit/sources/helpers/openstack.py | 34 ++++- .../unittests/test_datasource/test_configdrive.py | 50 +++++++- 5 files changed, 289 insertions(+), 11 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/net/__init__.py b/cloudinit/net/__init__.py index ae7b1c04..76cd4e8b 100644 --- a/cloudinit/net/__init__.py +++ b/cloudinit/net/__init__.py @@ -336,7 +336,7 @@ def iface_add_attrs(iface): 'index', 'subnets', ] - if iface['type'] not in ['bond', 'bridge']: + if iface['type'] not in ['bond', 'bridge', 'vlan']: ignore_map.append('mac_address') for key, value in iface.items(): @@ -348,19 +348,34 @@ def iface_add_attrs(iface): return content -def render_route(route): - content = "up route add" +def render_route(route, indent=""): + content = "" + up = indent + "post-up route add" + down = indent + "pre-down route del" + eol = " || true\n" mapping = { 'network': '-net', 'netmask': 'netmask', 'gateway': 'gw', 'metric': 'metric', } - for k in ['network', 'netmask', 'gateway', 'metric']: - if k in route: - content += " %s %s" % (mapping[k], route[k]) + if route['network'] == '0.0.0.0' and route['netmask'] == '0.0.0.0': + default_gw = " default gw %s" % route['gateway'] + content += up + default_gw + eol + content += down + default_gw + eol + elif route['network'] == '::' and route['netmask'] == 0: + # ipv6! + default_gw = " -A inet6 default gw %s" % route['gateway'] + content += up + default_gw + eol + content += down + default_gw + eol + else: + route_line = "" + for k in ['network', 'netmask', 'gateway', 'metric']: + if k in route: + route_line += " %s %s" % (mapping[k], route[k]) + content += up + route_line + eol + content += down + route_line + eol - content += '\n' return content @@ -384,6 +399,7 @@ def render_interfaces(network_state): if len(value): content += " dns-{} {}\n".format(dnskey, " ".join(value)) + content += "\n" for iface in sorted(interfaces.values(), key=lambda k: (order[k['type']], k['name'])): content += "auto {name}\n".format(**iface) @@ -409,6 +425,8 @@ def render_interfaces(network_state): content += iface_add_subnet(iface, subnet) content += iface_add_attrs(iface) + for route in subnet.get('routes', []): + content += render_route(route, indent=" ") content += "\n" else: content += "iface {name} {inet} {mode}\n".format(**iface) @@ -419,7 +437,7 @@ def render_interfaces(network_state): content += render_route(route) # global replacements until v2 format - content = content.replace('mac_address', 'hwaddress') + content = content.replace('mac_address', 'hwaddress ether') return content diff --git a/cloudinit/net/network_state.py b/cloudinit/net/network_state.py index df04c526..e32d2cdf 100644 --- a/cloudinit/net/network_state.py +++ b/cloudinit/net/network_state.py @@ -124,6 +124,17 @@ class NetworkState: iface = interfaces.get(command['name'], {}) for param, val in command.get('params', {}).items(): iface.update({param: val}) + + # convert subnet ipv6 netmask to cidr as needed + subnets = command.get('subnets') + if subnets: + for subnet in subnets: + if subnet['type'] == 'static': + if 'netmask' in subnet and ':' in subnet['address']: + subnet['netmask'] = mask2cidr(subnet['netmask']) + for route in subnet.get('routes', []): + if 'netmask' in route: + route['netmask'] = mask2cidr(route['netmask']) iface.update({ 'name': command.get('name'), 'type': command.get('type'), @@ -133,7 +144,7 @@ class NetworkState: 'mtu': command.get('mtu'), 'address': None, 'gateway': None, - 'subnets': command.get('subnets'), + 'subnets': subnets, }) self.network_state['interfaces'].update({command.get('name'): iface}) self.dump_network_state() @@ -144,6 +155,7 @@ class NetworkState: iface eth0.222 inet static address 10.10.10.1 netmask 255.255.255.0 + hwaddress ether BC:76:4E:06:96:B3 vlan-raw-device eth0 ''' required_keys = [ @@ -335,6 +347,37 @@ def cidr2mask(cidr): return ".".join([str(x) for x in mask]) +def ipv4mask2cidr(mask): + if '.' not in mask: + return mask + return sum([bin(int(x)).count('1') for x in mask.split('.')]) + + +def ipv6mask2cidr(mask): + if ':' not in mask: + return mask + + bitCount = [0, 0x8000, 0xc000, 0xe000, 0xf000, 0xf800, 0xfc00, 0xfe00, + 0xff00, 0xff80, 0xffc0, 0xffe0, 0xfff0, 0xfff8, 0xfffc, + 0xfffe, 0xffff] + cidr = 0 + for word in mask.split(':'): + if not word or int(word, 16) == 0: + break + cidr += bitCount.index(int(word, 16)) + + return cidr + + +def mask2cidr(mask): + if ':' in mask: + return ipv6mask2cidr(mask) + elif '.' in mask: + return ipv4mask2cidr(mask) + else: + return mask + + if __name__ == '__main__': import sys import random diff --git a/cloudinit/sources/DataSourceConfigDrive.py b/cloudinit/sources/DataSourceConfigDrive.py index 6fc9e05b..d84fab54 100644 --- a/cloudinit/sources/DataSourceConfigDrive.py +++ b/cloudinit/sources/DataSourceConfigDrive.py @@ -18,6 +18,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import copy import os from cloudinit import log as logging @@ -50,6 +51,8 @@ class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource): self.seed_dir = os.path.join(paths.seed_dir, 'config_drive') self.version = None self.ec2_metadata = None + self._network_config = None + self.network_json = None self.files = {} def __str__(self): @@ -144,12 +147,27 @@ class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource): LOG.warn("Invalid content in vendor-data: %s", e) self.vendordata_raw = None + nd = results.get('networkdata') + self.networkdata_pure = nd + try: + self.network_json = openstack.convert_networkdata_json(nd) + except ValueError as e: + LOG.warn("Invalid content in network-data: %s", e) + self.network_json = None + + if self.network_json: + self._network_config = convert_network_data(self.network_json) + return True def check_instance_id(self): # quickly (local check only) if self.instance_id is still valid return sources.instance_id_matches_system_uuid(self.get_instance_id()) + @property + def network_config(self): + return self._network_config + class DataSourceConfigDriveNet(DataSourceConfigDrive): def __init__(self, sys_cfg, distro, paths): @@ -287,3 +305,122 @@ datasources = [ # Return a list of data sources that match this set of dependencies def get_datasource_list(depends): return sources.list_from_depends(depends, datasources) + + +# Convert OpenStack ConfigDrive NetworkData json to network_config yaml +def convert_network_data(network_json=None): + """Return a dictionary of network_config by parsing provided + OpenStack ConfigDrive NetworkData json format + + OpenStack network_data.json provides a 3 element dictionary + - "links" (links are network devices, physical or virtual) + - "networks" (networks are ip network configurations for one or more + links) + - services (non-ip services, like dns) + + networks and links are combined via network items referencing specific + links via a 'link_id' which maps to a links 'id' field. + + To convert this format to network_config yaml, we first iterate over the + links and then walk the network list to determine if any of the networks + utilize the current link; if so we generate a subnet entry for the device + + We also need to map network_data.json fields to network_config fields. For + example, the network_data links 'id' field is equivalent to network_config + 'name' field for devices. We apply more of this mapping to the various + link types that we encounter. + + There are additional fields that are populated in the network_data.json + from OpenStack that are not relevant to network_config yaml, so we + enumerate a dictionary of valid keys for network_yaml and apply filtering + to drop these superflous keys from the network_config yaml. + """ + if network_json is None: + return None + + # dict of network_config key for filtering network_json + valid_keys = { + 'physical': [ + 'name', + 'type', + 'mac_address', + 'subnets', + 'params', + ], + 'subnet': [ + 'type', + 'address', + 'netmask', + 'broadcast', + 'metric', + 'gateway', + 'pointopoint', + 'mtu', + 'scope', + 'dns_nameservers', + 'dns_search', + 'routes', + ], + } + + links = network_json.get('links', []) + networks = network_json.get('networks', []) + services = network_json.get('services', []) + + config = [] + for link in links: + subnets = [] + cfg = {k: v for k, v in link.items() + if k in valid_keys['physical']} + cfg.update({'name': link['id']}) + for network in [net for net in networks + if net['link'] == link['id']]: + subnet = {k: v for k, v in network.items() + if k in valid_keys['subnet']} + if 'dhcp' in network['type']: + t = 'dhcp6' if network['type'].startswith('ipv6') else 'dhcp4' + subnet.update({ + 'type': t, + }) + else: + subnet.update({ + 'type': 'static', + 'address': network.get('ip_address'), + }) + subnets.append(subnet) + cfg.update({'subnets': subnets}) + if link['type'] in ['ethernet', 'vif', 'ovs']: + cfg.update({ + 'type': 'physical', + 'mac_address': link['ethernet_mac_address']}) + elif link['type'] in ['bond']: + params = {} + for k, v in link.items(): + if k == 'bond_links': + continue + elif k.startswith('bond'): + params.update({k: v}) + cfg.update({ + 'bond_interfaces': copy.deepcopy(link['bond_links']), + 'params': params, + }) + elif link['type'] in ['vlan']: + cfg.update({ + 'name': "%s.%s" % (link['vlan_link'], + link['vlan_id']), + 'vlan_link': link['vlan_link'], + 'vlan_id': link['vlan_id'], + 'mac_address': link['vlan_mac_address'], + }) + else: + raise ValueError( + 'Unknown network_data link type: %s' % link['type']) + + config.append(cfg) + + for service in services: + cfg = service + cfg.update({'type': 'nameserver'}) + config.append(cfg) + + return {'version': 1, 'config': config} diff --git a/cloudinit/sources/helpers/openstack.py b/cloudinit/sources/helpers/openstack.py index bd93d22f..eb50a7be 100644 --- a/cloudinit/sources/helpers/openstack.py +++ b/cloudinit/sources/helpers/openstack.py @@ -51,11 +51,13 @@ OS_LATEST = 'latest' OS_FOLSOM = '2012-08-10' OS_GRIZZLY = '2013-04-04' OS_HAVANA = '2013-10-17' +OS_KILO = '2015-10-15' # keep this in chronological order. new supported versions go at the end. OS_VERSIONS = ( OS_FOLSOM, OS_GRIZZLY, OS_HAVANA, + OS_KILO, ) @@ -229,6 +231,11 @@ class BaseReader(object): False, load_json_anytype, ) + files['networkdata'] = ( + self._path_join("openstack", version, 'network_data.json'), + False, + load_json_anytype, + ) return files results = { @@ -334,7 +341,7 @@ class ConfigDriveReader(BaseReader): path = self._path_join(self.base_path, 'openstack') found = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path))] - self._versions = found + self._versions = sorted(found) return self._versions def _read_ec2_metadata(self): @@ -490,3 +497,28 @@ def convert_vendordata_json(data, recurse=True): recurse=False) raise ValueError("vendordata['cloud-init'] cannot be dict") raise ValueError("Unknown data type for vendordata: %s" % type(data)) + + +def convert_networkdata_json(data, recurse=True): + """ data: a loaded json *object* (strings, arrays, dicts). + return something suitable for cloudinit networkdata_raw. + + if data is: + None: return None + string: return string + list: return data + the list is then processed in UserDataProcessor + dict: return convert_networkdata_json(data.get('cloud-init')) + """ + if not data: + return None + if isinstance(data, six.string_types): + return data + if isinstance(data, list): + return copy.deepcopy(data) + if isinstance(data, dict): + if recurse is True: + return convert_networkdata_json(data.get('cloud-init'), + recurse=False) + raise ValueError("networkdata['cloud-init'] cannot be dict") + raise ValueError("Unknown data type for networkdata: %s" % type(data)) diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index bfd787d1..01f8c5ce 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -59,6 +59,34 @@ OSTACK_META = { CONTENT_0 = b'This is contents of /etc/foo.cfg\n' CONTENT_1 = b'# this is /etc/bar/bar.cfg\n' +NETWORK_DATA = { + 'services': [ + {'type': 'dns', 'address': '199.204.44.24'}, + {'type': 'dns', 'address': '199.204.47.54'} + ], + 'links': [ + {'vif_id': '2ecc7709-b3f7-4448-9580-e1ec32d75bbd', + 'ethernet_mac_address': 'fa:16:3e:69:b0:58', + 'type': 'ovs', 'mtu': None, 'id': 'tap2ecc7709-b3'}, + {'vif_id': '2f88d109-5b57-40e6-af32-2472df09dc33', + 'ethernet_mac_address': 'fa:16:3e:d4:57:ad', + 'type': 'ovs', 'mtu': None, 'id': 'tap2f88d109-5b'}, + {'vif_id': '1a5382f8-04c5-4d75-ab98-d666c1ef52cc', + 'ethernet_mac_address': 'fa:16:3e:05:30:fe', + 'type': 'ovs', 'mtu': None, 'id': 'tap1a5382f8-04'} + ], + 'networks': [ + {'link': 'tap2ecc7709-b3', 'type': 'ipv4_dhcp', + 'network_id': '6d6357ac-0f70-4afa-8bd7-c274cc4ea235', + 'id': 'network0'}, + {'link': 'tap2f88d109-5b', 'type': 'ipv4_dhcp', + 'network_id': 'd227a9b3-6960-4d94-8976-ee5788b44f54', + 'id': 'network1'}, + {'link': 'tap1a5382f8-04', 'type': 'ipv4_dhcp', + 'network_id': 'dab2ba57-cae2-4311-a5ed-010b263891f5', + 'id': 'network2'} + ] +} CFG_DRIVE_FILES_V2 = { 'ec2/2009-04-04/meta-data.json': json.dumps(EC2_META), @@ -70,7 +98,11 @@ CFG_DRIVE_FILES_V2 = { 'openstack/content/0000': CONTENT_0, 'openstack/content/0001': CONTENT_1, 'openstack/latest/meta_data.json': json.dumps(OSTACK_META), - 'openstack/latest/user_data': USER_DATA} + 'openstack/latest/user_data': USER_DATA, + 'openstack/latest/network_data.json': json.dumps(NETWORK_DATA), + 'openstack/2015-10-15/meta_data.json': json.dumps(OSTACK_META), + 'openstack/2015-10-15/user_data': USER_DATA, + 'openstack/2015-10-15/network_data.json': json.dumps(NETWORK_DATA)} class TestConfigDriveDataSource(TestCase): @@ -225,6 +257,7 @@ class TestConfigDriveDataSource(TestCase): self.assertEqual(USER_DATA, found['userdata']) self.assertEqual(expected_md, found['metadata']) + self.assertEqual(NETWORK_DATA, found['networkdata']) self.assertEqual(found['files']['/etc/foo.cfg'], CONTENT_0) self.assertEqual(found['files']['/etc/bar/bar.cfg'], CONTENT_1) @@ -321,6 +354,19 @@ class TestConfigDriveDataSource(TestCase): self.assertEqual(myds.get_public_ssh_keys(), [OSTACK_META['public_keys']['mykey']]) + def test_network_data_is_found(self): + """Verify that network_data is present in ds in config-drive-v2.""" + populate_dir(self.tmp, CFG_DRIVE_FILES_V2) + myds = cfg_ds_from_dir(self.tmp) + self.assertEqual(myds.network_json, NETWORK_DATA) + + def test_network_config_is_converted(self): + """Verify that network_data is converted and present on ds object.""" + populate_dir(self.tmp, CFG_DRIVE_FILES_V2) + myds = cfg_ds_from_dir(self.tmp) + network_config = ds.convert_network_data(NETWORK_DATA) + self.assertEqual(myds.network_config, network_config) + def cfg_ds_from_dir(seed_d): found = ds.read_config_drive(seed_d) @@ -339,6 +385,8 @@ def populate_ds_from_read_config(cfg_ds, source, results): cfg_ds.ec2_metadata = results.get('ec2-metadata') cfg_ds.userdata_raw = results.get('userdata') cfg_ds.version = results.get('version') + cfg_ds.network_json = results.get('networkdata') + cfg_ds._network_config = ds.convert_network_data(cfg_ds.network_json) def populate_dir(seed_dir, files): -- cgit v1.2.3 From d8cfa30be45492140c04fb9c9ed69e83e9f041d9 Mon Sep 17 00:00:00 2001 From: Ryan Harper Date: Wed, 23 Mar 2016 13:32:23 -0500 Subject: unittest: fix bad json test with ConfigDrive Introduced a new path in configdrive, openstack/2015-10-15/, needed to add bogus data in that path as well to ensure config reader didn't find good data when testing for exception thrown. --- tests/unittests/test_datasource/test_configdrive.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests/unittests') diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py index 01f8c5ce..89b15f54 100644 --- a/tests/unittests/test_datasource/test_configdrive.py +++ b/tests/unittests/test_datasource/test_configdrive.py @@ -283,6 +283,7 @@ class TestConfigDriveDataSource(TestCase): data = copy(CFG_DRIVE_FILES_V2) data["openstack/2012-08-10/meta_data.json"] = "non-json garbage {}" + data["openstack/2015-10-15/meta_data.json"] = "non-json garbage {}" data["openstack/latest/meta_data.json"] = "non-json garbage {}" populate_dir(self.tmp, data) -- cgit v1.2.3 From 44f71ff4ccb72c89f6cdebc8a7b4e7a0d7029818 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Wed, 23 Mar 2016 15:32:51 -0400 Subject: add unit test --- tests/unittests/test_net.py | 94 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/unittests/test_net.py (limited to 'tests/unittests') diff --git a/tests/unittests/test_net.py b/tests/unittests/test_net.py new file mode 100644 index 00000000..06a55643 --- /dev/null +++ b/tests/unittests/test_net.py @@ -0,0 +1,94 @@ +from cloudinit import util +from cloudinit import net +from .helpers import TestCase +import copy +import os + +DHCP_CONTENT_1 = """ +DEVICE='eth0' +PROTO='dhcp' +IPV4ADDR='192.168.122.89' +IPV4BROADCAST='192.168.122.255' +IPV4NETMASK='255.255.255.0' +IPV4GATEWAY='192.168.122.1' +IPV4DNS0='192.168.122.1' +IPV4DNS1='0.0.0.0' +HOSTNAME='foohost' +DNSDOMAIN='' +NISDOMAIN='' +ROOTSERVER='192.168.122.1' +ROOTPATH='' +filename='' +UPTIME='21' +DHCPLEASETIME='3600' +DOMAINSEARCH='foo.com' +""" + +DHCP_EXPECTED_1 = { + 'name': 'eth0', + 'type': 'physical', + 'subnets': [{'broadcast': '192.168.122.255', + 'gateway': '192.168.122.1', + 'dns_search': ['foo.com'], + 'type': 'dhcp', + 'netmask': '255.255.255.0', + 'dns_nameservers': ['192.168.122.1']}], +} + + +STATIC_CONTENT_1 = """ +DEVICE='eth1' +PROTO='static' +IPV4ADDR='10.0.0.2' +IPV4BROADCAST='10.0.0.255' +IPV4NETMASK='255.255.255.0' +IPV4GATEWAY='10.0.0.1' +IPV4DNS0='10.0.1.1' +IPV4DNS1='0.0.0.0' +HOSTNAME='foohost' +UPTIME='21' +DHCPLEASETIME='3600' +DOMAINSEARCH='foo.com' +""" + +STATIC_EXPECTED_1 = { + 'name': 'eth1', + 'type': 'physical', + 'subnets': [{'broadcast': '10.0.0.255', 'gateway': '10.0.0.1', + 'dns_search': ['foo.com'], 'type': 'static', + 'netmask': '255.255.255.0', + 'dns_nameservers': ['10.0.1.1']}], +} + +class TestNetConfigParsing(TestCase): + def test_klibc_convert_dhcp(self): + found = net._klibc_to_config_entry(DHCP_CONTENT_1) + self.assertEqual(found, ('eth0', DHCP_EXPECTED_1)) + + def test_klibc_convert_static(self): + found = net._klibc_to_config_entry(STATIC_CONTENT_1) + self.assertEqual(found, ('eth1', STATIC_EXPECTED_1)) + + def test_config_from_klibc_net_cfg(self): + files = [] + pairs = (('net-eth0.cfg', DHCP_CONTENT_1), + ('net-eth1.cfg', STATIC_CONTENT_1)) + + macs = {'eth1': 'b8:ae:ed:75:ff:2b', + 'eth0': 'b8:ae:ed:75:ff:2a'} + + dhcp = copy.deepcopy(DHCP_EXPECTED_1) + dhcp['mac_address'] = macs['eth0'] + + static = copy.deepcopy(STATIC_EXPECTED_1) + static['mac_address'] = macs['eth1'] + + expected = {'version': 1, 'config': [dhcp, static]} + with util.tempdir() as tmpd: + for fname, content in pairs: + fp = os.path.join(tmpd, fname) + files.append(fp) + util.write_file(fp, content) + + found = net.config_from_klibc_net_cfg(files=files, mac_addrs=macs) + self.assertEqual(found, expected) -- cgit v1.2.3 From 51fa67a88dc0dc631c19770142b16c0b56c21384 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Wed, 23 Mar 2016 15:38:32 -0400 Subject: one more tox --- tests/unittests/test_net.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests/unittests') diff --git a/tests/unittests/test_net.py b/tests/unittests/test_net.py index 06a55643..11c0b1eb 100644 --- a/tests/unittests/test_net.py +++ b/tests/unittests/test_net.py @@ -60,6 +60,7 @@ STATIC_EXPECTED_1 = { 'dns_nameservers': ['10.0.1.1']}], } + class TestNetConfigParsing(TestCase): def test_klibc_convert_dhcp(self): found = net._klibc_to_config_entry(DHCP_CONTENT_1) -- cgit v1.2.3 From f9cae62c2f70c582a6b12a98911dc82180881364 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 24 Mar 2016 11:19:41 -0400 Subject: add suport for base64 encoded gzipped text on command line add tests to show this functional. --- cloudinit/net/__init__.py | 34 +++++++++++++++++++++++++++++++++- tests/unittests/test_net.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) (limited to 'tests/unittests') diff --git a/cloudinit/net/__init__.py b/cloudinit/net/__init__.py index 3a208e43..7d7f274d 100644 --- a/cloudinit/net/__init__.py +++ b/cloudinit/net/__init__.py @@ -19,6 +19,8 @@ import base64 import errno import glob +import gzip +import io import os import re import shlex @@ -647,6 +649,36 @@ def generate_fallback_config(): return nconf +def _decomp_gzip(blob, strict=True): + # decompress blob. raise exception if not compressed unless strict=False. + with io.BytesIO(blob) as iobuf: + gzfp = None + try: + gzfp = gzip.GzipFile(mode="rb", fileobj=iobuf) + return gzfp.read() + except IOError: + if strict: + raise + return blob + finally: + if gzfp: + gzfp.close() + + +def _b64dgz(b64str, gzipped="try"): + # decode a base64 string. If gzipped is true, transparently uncompresss + # if gzipped is 'try', then try gunzip, returning the original on fail. + try: + blob = base64.b64decode(b64str) + except TypeError: + raise ValueError("Invalid base64 text: %s" % b64str) + + if not gzipped: + return blob + + return _decomp_gzip(blob, strict=gzipped != "try") + + def read_kernel_cmdline_config(files=None, mac_addrs=None, cmdline=None): if cmdline is None: cmdline = util.get_cmdline() @@ -657,7 +689,7 @@ def read_kernel_cmdline_config(files=None, mac_addrs=None, cmdline=None): if tok.startswith("network-config="): data64 = tok.split("=", 1)[1] if data64: - return util.load_yaml(base64.b64decode(data64)) + return util.load_yaml(_b64dgz(data64)) if 'ip=' not in cmdline: return None diff --git a/tests/unittests/test_net.py b/tests/unittests/test_net.py index 11c0b1eb..16c44588 100644 --- a/tests/unittests/test_net.py +++ b/tests/unittests/test_net.py @@ -1,7 +1,12 @@ from cloudinit import util from cloudinit import net from .helpers import TestCase + +import base64 import copy +import io +import gzip +import json import os DHCP_CONTENT_1 = """ @@ -62,6 +67,11 @@ STATIC_EXPECTED_1 = { class TestNetConfigParsing(TestCase): + simple_cfg = { + 'config': [{"type": "physical", "name": "eth0", + "mac_address": "c0:d6:9f:2c:e8:80", + "subnets": [{"type": "dhcp4"}]}]} + def test_klibc_convert_dhcp(self): found = net._klibc_to_config_entry(DHCP_CONTENT_1) self.assertEqual(found, ('eth0', DHCP_EXPECTED_1)) @@ -93,3 +103,25 @@ class TestNetConfigParsing(TestCase): found = net.config_from_klibc_net_cfg(files=files, mac_addrs=macs) self.assertEqual(found, expected) + + def test_cmdline_with_b64(self): + data = base64.b64encode(json.dumps(self.simple_cfg).encode()) + encoded_text = data.decode() + cmdline = 'ro network-config=' + encoded_text + ' root=foo' + found = net.read_kernel_cmdline_config(cmdline=cmdline) + self.assertEqual(found, self.simple_cfg) + + def test_cmdline_with_b64_gz(self): + data = _gzip_data(json.dumps(self.simple_cfg).encode()) + encoded_text = base64.b64encode(data).decode() + cmdline = 'ro network-config=' + encoded_text + ' root=foo' + found = net.read_kernel_cmdline_config(cmdline=cmdline) + self.assertEqual(found, self.simple_cfg) + + +def _gzip_data(data): + with io.BytesIO() as iobuf: + gzfp = gzip.GzipFile(mode="wb", fileobj=iobuf) + gzfp.write(data) + gzfp.close() + return iobuf.getvalue() -- cgit v1.2.3 From 557650728d3c1b1c1bbd29f9292d43133c00cdd4 Mon Sep 17 00:00:00 2001 From: Scott Moser Date: Thu, 24 Mar 2016 15:46:52 -0400 Subject: add comments and improve error messages --- cloudinit/net/__init__.py | 20 +++++++++++++++++--- tests/unittests/test_net.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) (limited to 'tests/unittests') diff --git a/cloudinit/net/__init__.py b/cloudinit/net/__init__.py index e13ca470..7af9b03a 100644 --- a/cloudinit/net/__init__.py +++ b/cloudinit/net/__init__.py @@ -304,6 +304,19 @@ def _load_shell_content(content, add_empty=False, empty_val=None): def _klibc_to_config_entry(content, mac_addrs=None): + """Convert a klibc writtent shell content file to a 'config' entry + When ip= is seen on the kernel command line in debian initramfs + and networking is brought up, ipconfig will populate + /run/net-.cfg. + + The files are shell style syntax, and examples are in the tests + provided here. There is no good documentation on this unfortunately. + + DEVICE= is expected/required and PROTO should indicate if + this is 'static' or 'dhcp'. + """ + + if mac_addrs is None: mac_addrs = {} @@ -575,11 +588,12 @@ def is_disabled_cfg(cfg): def sys_netdev_info(name, field): if not os.path.exists(os.path.join(SYS_CLASS_NET, name)): - raise OSError("%s: interface does not exist in /sys" % name) + raise OSError("%s: interface does not exist in %s" % + (name, SYS_CLASS_NET)) fname = os.path.join(SYS_CLASS_NET, name, field) if not os.path.exists(fname): - raise OSError("%s: %s does not exist in /sys" % (name, fname)) + raise OSError("%s: could not find sysfs entry: %s" % (name, fname)) data = util.load_file(fname) if data[-1] == '\n': data = data[:-1] @@ -647,7 +661,7 @@ def generate_fallback_config(): nconf['config'].append( {'type': 'physical', 'name': target_name, - 'mac_address': mac, 'subnets': [{'type': 'dhcp4'}]}) + 'mac_address': mac, 'subnets': [{'type': 'dhcp'}]}) return nconf diff --git a/tests/unittests/test_net.py b/tests/unittests/test_net.py index 16c44588..dfb31710 100644 --- a/tests/unittests/test_net.py +++ b/tests/unittests/test_net.py @@ -70,7 +70,7 @@ class TestNetConfigParsing(TestCase): simple_cfg = { 'config': [{"type": "physical", "name": "eth0", "mac_address": "c0:d6:9f:2c:e8:80", - "subnets": [{"type": "dhcp4"}]}]} + "subnets": [{"type": "dhcp"}]}]} def test_klibc_convert_dhcp(self): found = net._klibc_to_config_entry(DHCP_CONTENT_1) -- cgit v1.2.3