diff options
Diffstat (limited to 'tests/unittests/test_datasource')
| -rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 217 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_gce.py | 49 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_smartos.py | 229 | 
3 files changed, 382 insertions, 113 deletions
| diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 8112c69b..7e789853 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -116,9 +116,6 @@ class TestAzureDataSource(TestCase):              data['iid_from_shared_cfg'] = path              return 'i-my-azure-id' -        def _apply_hostname_bounce(**kwargs): -            data['apply_hostname_bounce'] = kwargs -          if data.get('ovfcontent') is not None:              populate_dir(os.path.join(self.paths.seed_dir, "azure"),                           {'ovf-env.xml': data['ovfcontent']}) @@ -132,7 +129,9 @@ class TestAzureDataSource(TestCase):              (mod, 'wait_for_files', _wait_for_files),              (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),              (mod, 'iid_from_shared_config', _iid_from_shared_config), -            (mod, 'apply_hostname_bounce', _apply_hostname_bounce), +            (mod, 'perform_hostname_bounce', mock.MagicMock()), +            (mod, 'get_hostname', mock.MagicMock()), +            (mod, 'set_hostname', mock.MagicMock()),              ])          dsrc = mod.DataSourceAzureNet( @@ -272,47 +271,6 @@ class TestAzureDataSource(TestCase):          for mypk in mypklist:              self.assertIn(mypk, dsrc.cfg['_pubkeys']) -    def test_disabled_bounce(self): -        pass - -    def test_apply_bounce_call_1(self): -        # hostname needs to get through to apply_hostname_bounce -        odata = {'HostName': 'my-random-hostname'} -        data = {'ovfcontent': construct_valid_ovf_env(data=odata)} - -        self._get_ds(data).get_data() -        self.assertIn('hostname', data['apply_hostname_bounce']) -        self.assertEqual(data['apply_hostname_bounce']['hostname'], -                         odata['HostName']) - -    def test_apply_bounce_call_configurable(self): -        # hostname_bounce should be configurable in datasource cfg -        cfg = {'hostname_bounce': {'interface': 'eth1', 'policy': 'off', -                                   'command': 'my-bounce-command', -                                   'hostname_command': 'my-hostname-command'}} -        odata = {'HostName': "xhost", -                'dscfg': {'text': b64e(yaml.dump(cfg)), -                          'encoding': 'base64'}} -        data = {'ovfcontent': construct_valid_ovf_env(data=odata)} -        self._get_ds(data).get_data() - -        for k in cfg['hostname_bounce']: -            self.assertIn(k, data['apply_hostname_bounce']) - -        for k, v in cfg['hostname_bounce'].items(): -            self.assertEqual(data['apply_hostname_bounce'][k], v) - -    def test_set_hostname_disabled(self): -        # config specifying set_hostname off should not bounce -        cfg = {'set_hostname': False} -        odata = {'HostName': "xhost", -                'dscfg': {'text': b64e(yaml.dump(cfg)), -                          'encoding': 'base64'}} -        data = {'ovfcontent': construct_valid_ovf_env(data=odata)} -        self._get_ds(data).get_data() - -        self.assertEqual(data.get('apply_hostname_bounce', "N/A"), "N/A") -      def test_default_ephemeral(self):          # make sure the ephemeral device works          odata = {} @@ -425,6 +383,175 @@ class TestAzureDataSource(TestCase):              load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) +class TestAzureBounce(TestCase): + +    def mock_out_azure_moving_parts(self): +        self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'invoke_agent')) +        self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'wait_for_files')) +        self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'iid_from_shared_config', +                              mock.MagicMock(return_value='i-my-azure-id'))) +        self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs', +                              mock.MagicMock(return_value=[]))) +        self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'find_ephemeral_disk', +                              mock.MagicMock(return_value=None))) +        self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'find_ephemeral_part', +                              mock.MagicMock(return_value=None))) + +    def setUp(self): +        super(TestAzureBounce, self).setUp() +        self.tmp = tempfile.mkdtemp() +        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent') +        self.paths = helpers.Paths({'cloud_dir': self.tmp}) +        self.addCleanup(shutil.rmtree, self.tmp) +        DataSourceAzure.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d +        self.patches = ExitStack() +        self.mock_out_azure_moving_parts() +        self.get_hostname = self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'get_hostname')) +        self.set_hostname = self.patches.enter_context( +            mock.patch.object(DataSourceAzure, 'set_hostname')) +        self.subp = self.patches.enter_context( +            mock.patch('cloudinit.sources.DataSourceAzure.util.subp')) + +    def tearDown(self): +        self.patches.close() + +    def _get_ds(self, ovfcontent=None): +        if ovfcontent is not None: +            populate_dir(os.path.join(self.paths.seed_dir, "azure"), +                         {'ovf-env.xml': ovfcontent}) +        return DataSourceAzure.DataSourceAzureNet( +            {}, distro=None, paths=self.paths) + +    def get_ovf_env_with_dscfg(self, hostname, cfg): +        odata = { +            'HostName': hostname, +            'dscfg': { +                'text': b64e(yaml.dump(cfg)), +                'encoding': 'base64' +            } +        } +        return construct_valid_ovf_env(data=odata) + +    def test_disabled_bounce_does_not_change_hostname(self): +        cfg = {'hostname_bounce': {'policy': 'off'}} +        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data() +        self.assertEqual(0, self.set_hostname.call_count) + +    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') +    def test_disabled_bounce_does_not_perform_bounce( +            self, perform_hostname_bounce): +        cfg = {'hostname_bounce': {'policy': 'off'}} +        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data() +        self.assertEqual(0, perform_hostname_bounce.call_count) + +    def test_same_hostname_does_not_change_hostname(self): +        host_name = 'unchanged-host-name' +        self.get_hostname.return_value = host_name +        cfg = {'hostname_bounce': {'policy': 'yes'}} +        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data() +        self.assertEqual(0, self.set_hostname.call_count) + +    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') +    def test_unchanged_hostname_does_not_perform_bounce( +            self, perform_hostname_bounce): +        host_name = 'unchanged-host-name' +        self.get_hostname.return_value = host_name +        cfg = {'hostname_bounce': {'policy': 'yes'}} +        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data() +        self.assertEqual(0, perform_hostname_bounce.call_count) + +    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') +    def test_force_performs_bounce_regardless(self, perform_hostname_bounce): +        host_name = 'unchanged-host-name' +        self.get_hostname.return_value = host_name +        cfg = {'hostname_bounce': {'policy': 'force'}} +        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data() +        self.assertEqual(1, perform_hostname_bounce.call_count) + +    def test_different_hostnames_sets_hostname(self): +        expected_hostname = 'azure-expected-host-name' +        self.get_hostname.return_value = 'default-host-name' +        self._get_ds( +            self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data() +        self.assertEqual(expected_hostname, +                         self.set_hostname.call_args_list[0][0][0]) + +    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') +    def test_different_hostnames_performs_bounce( +            self, perform_hostname_bounce): +        expected_hostname = 'azure-expected-host-name' +        self.get_hostname.return_value = 'default-host-name' +        self._get_ds( +            self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data() +        self.assertEqual(1, perform_hostname_bounce.call_count) + +    def test_different_hostnames_sets_hostname_back(self): +        initial_host_name = 'default-host-name' +        self.get_hostname.return_value = initial_host_name +        self._get_ds( +            self.get_ovf_env_with_dscfg('some-host-name', {})).get_data() +        self.assertEqual(initial_host_name, +                         self.set_hostname.call_args_list[-1][0][0]) + +    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') +    def test_failure_in_bounce_still_resets_host_name( +            self, perform_hostname_bounce): +        perform_hostname_bounce.side_effect = Exception +        initial_host_name = 'default-host-name' +        self.get_hostname.return_value = initial_host_name +        self._get_ds( +            self.get_ovf_env_with_dscfg('some-host-name', {})).get_data() +        self.assertEqual(initial_host_name, +                         self.set_hostname.call_args_list[-1][0][0]) + +    def test_environment_correct_for_bounce_command(self): +        interface = 'int0' +        hostname = 'my-new-host' +        old_hostname = 'my-old-host' +        self.get_hostname.return_value = old_hostname +        cfg = {'hostname_bounce': {'interface': interface, 'policy': 'force'}} +        data = self.get_ovf_env_with_dscfg(hostname, cfg) +        self._get_ds(data).get_data() +        self.assertEqual(1, self.subp.call_count) +        bounce_env = self.subp.call_args[1]['env'] +        self.assertEqual(interface, bounce_env['interface']) +        self.assertEqual(hostname, bounce_env['hostname']) +        self.assertEqual(old_hostname, bounce_env['old_hostname']) + +    def test_default_bounce_command_used_by_default(self): +        cmd = 'default-bounce-command' +        DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce']['command'] = cmd +        cfg = {'hostname_bounce': {'policy': 'force'}} +        data = self.get_ovf_env_with_dscfg('some-hostname', cfg) +        self._get_ds(data).get_data() +        self.assertEqual(1, self.subp.call_count) +        bounce_args = self.subp.call_args[1]['args'] +        self.assertEqual(cmd, bounce_args) + +    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce') +    def test_set_hostname_option_can_disable_bounce( +            self, perform_hostname_bounce): +        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}} +        data = self.get_ovf_env_with_dscfg('some-hostname', cfg) +        self._get_ds(data).get_data() + +        self.assertEqual(0, perform_hostname_bounce.call_count) + +    def test_set_hostname_option_can_disable_hostname_set(self): +        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}} +        data = self.get_ovf_env_with_dscfg('some-hostname', cfg) +        self._get_ds(data).get_data() + +        self.assertEqual(0, self.set_hostname.call_count) + +  class TestReadAzureOvf(TestCase):      def test_invalid_xml_raises_non_azure_ds(self):          invalid_xml = "<foo>" + construct_valid_ovf_env(data={}) diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 4280abc4..1fb100f7 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -113,10 +113,6 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):          self.assertEqual(GCE_META.get('instance/attributes/user-data'),                           self.ds.get_userdata_raw()) -        # we expect a list of public ssh keys with user names stripped -        self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'], -                         self.ds.get_public_ssh_keys()) -      # test partial metadata (missing user-data in particular)      @httpretty.activate      def test_metadata_partial(self): @@ -141,3 +137,48 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):          decoded = b64decode(              GCE_META_ENCODING.get('instance/attributes/user-data'))          self.assertEqual(decoded, self.ds.get_userdata_raw()) + +    @httpretty.activate +    def test_missing_required_keys_return_false(self): +        for required_key in ['instance/id', 'instance/zone', +                             'instance/hostname']: +            meta = GCE_META_PARTIAL.copy() +            del meta[required_key] +            httpretty.register_uri(httpretty.GET, MD_URL_RE, +                                   body=_new_request_callback(meta)) +            self.assertEqual(False, self.ds.get_data()) +            httpretty.reset() + +    @httpretty.activate +    def test_project_level_ssh_keys_are_used(self): +        httpretty.register_uri(httpretty.GET, MD_URL_RE, +                               body=_new_request_callback()) +        self.ds.get_data() + +        # we expect a list of public ssh keys with user names stripped +        self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'], +                         self.ds.get_public_ssh_keys()) + +    @httpretty.activate +    def test_instance_level_ssh_keys_are_used(self): +        key_content = 'ssh-rsa JustAUser root@server' +        meta = GCE_META.copy() +        meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content) + +        httpretty.register_uri(httpretty.GET, MD_URL_RE, +                               body=_new_request_callback(meta)) +        self.ds.get_data() + +        self.assertIn(key_content, self.ds.get_public_ssh_keys()) + +    @httpretty.activate +    def test_instance_level_keys_replace_project_level_keys(self): +        key_content = 'ssh-rsa JustAUser root@server' +        meta = GCE_META.copy() +        meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content) + +        httpretty.register_uri(httpretty.GET, MD_URL_RE, +                               body=_new_request_callback(meta)) +        self.ds.get_data() + +        self.assertEqual([key_content], self.ds.get_public_ssh_keys()) diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index cb0ab984..adee9019 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -24,20 +24,28 @@  from __future__ import print_function -from cloudinit import helpers as c_helpers -from cloudinit.sources import DataSourceSmartOS -from cloudinit.util import b64e -from .. import helpers  import os  import os.path  import re  import shutil -import tempfile  import stat +import tempfile  import uuid +from binascii import crc32 +import serial  import six +from cloudinit import helpers as c_helpers +from cloudinit.sources import DataSourceSmartOS +from cloudinit.util import b64e + +from .. import helpers + +try: +    from unittest import mock +except ImportError: +    import mock  MOCK_RETURNS = {      'hostname': 'test-host', @@ -56,63 +64,15 @@ MOCK_RETURNS = {  DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc') -class MockSerial(object): -    """Fake a serial terminal for testing the code that -        interfaces with the serial""" - -    port = None - -    def __init__(self, mockdata): -        self.last = None -        self.last = None -        self.new = True -        self.count = 0 -        self.mocked_out = [] -        self.mockdata = mockdata +def get_mock_client(mockdata): +    class MockMetadataClient(object): -    def open(self): -        return True +        def __init__(self, serial): +            pass -    def close(self): -        return True - -    def isOpen(self): -        return True - -    def write(self, line): -        if not isinstance(line, six.binary_type): -            raise TypeError("Should be writing binary lines.") -        line = line.decode('ascii').replace('GET ', '') -        self.last = line.rstrip() - -    def readline(self): -        if self.new: -            self.new = False -            if self.last in self.mockdata: -                line = 'SUCCESS\n' -            else: -                line = 'NOTFOUND %s\n' % self.last - -        elif self.last in self.mockdata: -            if not self.mocked_out: -                self.mocked_out = [x for x in self._format_out()] - -            if len(self.mocked_out) > self.count: -                self.count += 1 -                line = self.mocked_out[self.count - 1] -        return line.encode('ascii') - -    def _format_out(self): -        if self.last in self.mockdata: -            _mret = self.mockdata[self.last] -            try: -                for l in _mret.splitlines(): -                    yield "%s\n" % l.rstrip() -            except: -                yield "%s\n" % _mret.rstrip() - -            yield '.' -            yield '\n' +        def get_metadata(self, metadata_key): +            return mockdata.get(metadata_key) +    return MockMetadataClient  class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): @@ -160,9 +120,6 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):          if dmi_data is None:              dmi_data = DMI_DATA_RETURN -        def _get_serial(*_): -            return MockSerial(mockdata) -          def _dmi_data():              return dmi_data @@ -179,7 +136,9 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):              sys_cfg['datasource']['SmartOS'] = ds_cfg          self.apply_patches([(mod, 'LEGACY_USER_D', self.legacy_user_d)]) -        self.apply_patches([(mod, 'get_serial', _get_serial)]) +        self.apply_patches([(mod, 'get_serial', mock.MagicMock())]) +        self.apply_patches([ +            (mod, 'JoyentMetadataClient', get_mock_client(mockdata))])          self.apply_patches([(mod, 'dmi_data', _dmi_data)])          self.apply_patches([(os, 'uname', _os_uname)])          self.apply_patches([(mod, 'device_exists', lambda d: True)]) @@ -448,6 +407,18 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):          self.assertEqual(dsrc.device_name_to_device('FOO'),                           mydscfg['disk_aliases']['FOO']) +    @mock.patch('cloudinit.sources.DataSourceSmartOS.JoyentMetadataClient') +    @mock.patch('cloudinit.sources.DataSourceSmartOS.get_serial') +    def test_serial_console_closed_on_error(self, get_serial, metadata_client): +        class OurException(Exception): +            pass +        metadata_client.side_effect = OurException +        try: +            DataSourceSmartOS.query_data('noun', 'device', 0) +        except OurException: +            pass +        self.assertEqual(1, get_serial.return_value.close.call_count) +  def apply_patches(patches):      ret = [] @@ -458,3 +429,133 @@ def apply_patches(patches):          setattr(ref, name, replace)          ret.append((ref, name, orig))      return ret + + +class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): + +    def setUp(self): +        super(TestJoyentMetadataClient, self).setUp() +        self.serial = mock.MagicMock(spec=serial.Serial) +        self.request_id = 0xabcdef12 +        self.metadata_value = 'value' +        self.response_parts = { +            'command': 'SUCCESS', +            'crc': 'b5a9ff00', +            'length': 17 + len(b64e(self.metadata_value)), +            'payload': b64e(self.metadata_value), +            'request_id': '{0:08x}'.format(self.request_id), +        } + +        def make_response(): +            payload = '' +            if self.response_parts['payload']: +                payload = ' {0}'.format(self.response_parts['payload']) +            del self.response_parts['payload'] +            return ( +                'V2 {length} {crc} {request_id} {command}{payload}\n'.format( +                    payload=payload, **self.response_parts).encode('ascii')) +        self.serial.readline.side_effect = make_response +        self.patched_funcs.enter_context( +            mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint', +                       mock.Mock(return_value=self.request_id))) + +    def _get_client(self): +        return DataSourceSmartOS.JoyentMetadataClient(self.serial) + +    def assertEndsWith(self, haystack, prefix): +        self.assertTrue(haystack.endswith(prefix), +                        "{0} does not end with '{1}'".format( +                            repr(haystack), prefix)) + +    def assertStartsWith(self, haystack, prefix): +        self.assertTrue(haystack.startswith(prefix), +                        "{0} does not start with '{1}'".format( +                            repr(haystack), prefix)) + +    def test_get_metadata_writes_a_single_line(self): +        client = self._get_client() +        client.get_metadata('some_key') +        self.assertEqual(1, self.serial.write.call_count) +        written_line = self.serial.write.call_args[0][0] +        self.assertEndsWith(written_line, b'\n') +        self.assertEqual(1, written_line.count(b'\n')) + +    def _get_written_line(self, key='some_key'): +        client = self._get_client() +        client.get_metadata(key) +        return self.serial.write.call_args[0][0] + +    def test_get_metadata_writes_bytes(self): +        self.assertIsInstance(self._get_written_line(), six.binary_type) + +    def test_get_metadata_line_starts_with_v2(self): +        self.assertStartsWith(self._get_written_line(), b'V2') + +    def test_get_metadata_uses_get_command(self): +        parts = self._get_written_line().decode('ascii').strip().split(' ') +        self.assertEqual('GET', parts[4]) + +    def test_get_metadata_base64_encodes_argument(self): +        key = 'my_key' +        parts = self._get_written_line(key).decode('ascii').strip().split(' ') +        self.assertEqual(b64e(key), parts[5]) + +    def test_get_metadata_calculates_length_correctly(self): +        parts = self._get_written_line().decode('ascii').strip().split(' ') +        expected_length = len(' '.join(parts[3:])) +        self.assertEqual(expected_length, int(parts[1])) + +    def test_get_metadata_uses_appropriate_request_id(self): +        parts = self._get_written_line().decode('ascii').strip().split(' ') +        request_id = parts[3] +        self.assertEqual(8, len(request_id)) +        self.assertEqual(request_id, request_id.lower()) + +    def test_get_metadata_uses_random_number_for_request_id(self): +        line = self._get_written_line() +        request_id = line.decode('ascii').strip().split(' ')[3] +        self.assertEqual('{0:08x}'.format(self.request_id), request_id) + +    def test_get_metadata_checksums_correctly(self): +        parts = self._get_written_line().decode('ascii').strip().split(' ') +        expected_checksum = '{0:08x}'.format( +            crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff) +        checksum = parts[2] +        self.assertEqual(expected_checksum, checksum) + +    def test_get_metadata_reads_a_line(self): +        client = self._get_client() +        client.get_metadata('some_key') +        self.assertEqual(1, self.serial.readline.call_count) + +    def test_get_metadata_returns_valid_value(self): +        client = self._get_client() +        value = client.get_metadata('some_key') +        self.assertEqual(self.metadata_value, value) + +    def test_get_metadata_throws_exception_for_incorrect_length(self): +        self.response_parts['length'] = 0 +        client = self._get_client() +        self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, +                          client.get_metadata, 'some_key') + +    def test_get_metadata_throws_exception_for_incorrect_crc(self): +        self.response_parts['crc'] = 'deadbeef' +        client = self._get_client() +        self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, +                          client.get_metadata, 'some_key') + +    def test_get_metadata_throws_exception_for_request_id_mismatch(self): +        self.response_parts['request_id'] = 'deadbeef' +        client = self._get_client() +        client._checksum = lambda _: self.response_parts['crc'] +        self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, +                          client.get_metadata, 'some_key') + +    def test_get_metadata_returns_None_if_value_not_found(self): +        self.response_parts['payload'] = '' +        self.response_parts['command'] = 'NOTFOUND' +        self.response_parts['length'] = 17 +        client = self._get_client() +        client._checksum = lambda _: self.response_parts['crc'] +        self.assertIsNone(client.get_metadata('some_key')) | 
