diff options
Diffstat (limited to 'cloudinit/sources/tests')
| -rw-r--r-- | cloudinit/sources/tests/test_init.py | 362 | ||||
| -rw-r--r-- | cloudinit/sources/tests/test_oracle.py | 331 | 
2 files changed, 664 insertions, 29 deletions
| diff --git a/cloudinit/sources/tests/test_init.py b/cloudinit/sources/tests/test_init.py index e7fda22a..8082019e 100644 --- a/cloudinit/sources/tests/test_init.py +++ b/cloudinit/sources/tests/test_init.py @@ -1,14 +1,17 @@  # This file is part of cloud-init. See LICENSE file for license information. +import copy  import inspect  import os  import six  import stat +from cloudinit.event import EventType  from cloudinit.helpers import Paths  from cloudinit import importer  from cloudinit.sources import ( -    INSTANCE_JSON_FILE, DataSource) +    EXPERIMENTAL_TEXT, INSTANCE_JSON_FILE, INSTANCE_JSON_SENSITIVE_FILE, +    REDACT_SENSITIVE_VALUE, UNSET, DataSource, redact_sensitive_keys)  from cloudinit.tests.helpers import CiTestCase, skipIf, mock  from cloudinit.user_data import UserDataProcessor  from cloudinit import util @@ -17,25 +20,32 @@ from cloudinit import util  class DataSourceTestSubclassNet(DataSource):      dsname = 'MyTestSubclass' +    url_max_wait = 55 -    def __init__(self, sys_cfg, distro, paths, custom_userdata=None): +    def __init__(self, sys_cfg, distro, paths, custom_metadata=None, +                 custom_userdata=None, get_data_retval=True):          super(DataSourceTestSubclassNet, self).__init__(              sys_cfg, distro, paths)          self._custom_userdata = custom_userdata +        self._custom_metadata = custom_metadata +        self._get_data_retval = get_data_retval      def _get_cloud_name(self):          return 'SubclassCloudName'      def _get_data(self): -        self.metadata = {'availability_zone': 'myaz', -                         'local-hostname': 'test-subclass-hostname', -                         'region': 'myregion'} +        if self._custom_metadata: +            self.metadata = self._custom_metadata +        else: +            self.metadata = {'availability_zone': 'myaz', +                             'local-hostname': 'test-subclass-hostname', +                             'region': 'myregion'}          if self._custom_userdata:              self.userdata_raw = self._custom_userdata          else:              self.userdata_raw = 'userdata_raw'          self.vendordata_raw = 'vendordata_raw' -        return True +        return self._get_data_retval  class InvalidDataSourceTestSubclassNet(DataSource): @@ -70,8 +80,7 @@ class TestDataSource(CiTestCase):          """Init uses DataSource.dsname for sourcing ds_cfg."""          sys_cfg = {'datasource': {'MyTestSubclass': {'key2': False}}}          distro = 'distrotest'  # generally should be a Distro object -        paths = Paths({}) -        datasource = DataSourceTestSubclassNet(sys_cfg, distro, paths) +        datasource = DataSourceTestSubclassNet(sys_cfg, distro, self.paths)          self.assertEqual({'key2': False}, datasource.ds_cfg)      def test_str_is_classname(self): @@ -81,6 +90,91 @@ class TestDataSource(CiTestCase):              'DataSourceTestSubclassNet',              str(DataSourceTestSubclassNet('', '', self.paths))) +    def test_datasource_get_url_params_defaults(self): +        """get_url_params default url config settings for the datasource.""" +        params = self.datasource.get_url_params() +        self.assertEqual(params.max_wait_seconds, self.datasource.url_max_wait) +        self.assertEqual(params.timeout_seconds, self.datasource.url_timeout) +        self.assertEqual(params.num_retries, self.datasource.url_retries) + +    def test_datasource_get_url_params_subclassed(self): +        """Subclasses can override get_url_params defaults.""" +        sys_cfg = {'datasource': {'MyTestSubclass': {'key2': False}}} +        distro = 'distrotest'  # generally should be a Distro object +        datasource = DataSourceTestSubclassNet(sys_cfg, distro, self.paths) +        expected = (datasource.url_max_wait, datasource.url_timeout, +                    datasource.url_retries) +        url_params = datasource.get_url_params() +        self.assertNotEqual(self.datasource.get_url_params(), url_params) +        self.assertEqual(expected, url_params) + +    def test_datasource_get_url_params_ds_config_override(self): +        """Datasource configuration options can override url param defaults.""" +        sys_cfg = { +            'datasource': { +                'MyTestSubclass': { +                    'max_wait': '1', 'timeout': '2', 'retries': '3'}}} +        datasource = DataSourceTestSubclassNet( +            sys_cfg, self.distro, self.paths) +        expected = (1, 2, 3) +        url_params = datasource.get_url_params() +        self.assertNotEqual( +            (datasource.url_max_wait, datasource.url_timeout, +             datasource.url_retries), +            url_params) +        self.assertEqual(expected, url_params) + +    def test_datasource_get_url_params_is_zero_or_greater(self): +        """get_url_params ignores timeouts with a value below 0.""" +        # Set an override that is below 0 which gets ignored. +        sys_cfg = {'datasource': {'_undef': {'timeout': '-1'}}} +        datasource = DataSource(sys_cfg, self.distro, self.paths) +        (_max_wait, timeout, _retries) = datasource.get_url_params() +        self.assertEqual(0, timeout) + +    def test_datasource_get_url_uses_defaults_on_errors(self): +        """On invalid system config values for url_params defaults are used.""" +        # All invalid values should be logged +        sys_cfg = {'datasource': { +            '_undef': { +                'max_wait': 'nope', 'timeout': 'bug', 'retries': 'nonint'}}} +        datasource = DataSource(sys_cfg, self.distro, self.paths) +        url_params = datasource.get_url_params() +        expected = (datasource.url_max_wait, datasource.url_timeout, +                    datasource.url_retries) +        self.assertEqual(expected, url_params) +        logs = self.logs.getvalue() +        expected_logs = [ +            "Config max_wait 'nope' is not an int, using default '-1'", +            "Config timeout 'bug' is not an int, using default '10'", +            "Config retries 'nonint' is not an int, using default '5'", +        ] +        for log in expected_logs: +            self.assertIn(log, logs) + +    @mock.patch('cloudinit.sources.net.find_fallback_nic') +    def test_fallback_interface_is_discovered(self, m_get_fallback_nic): +        """The fallback_interface is discovered via find_fallback_nic.""" +        m_get_fallback_nic.return_value = 'nic9' +        self.assertEqual('nic9', self.datasource.fallback_interface) + +    @mock.patch('cloudinit.sources.net.find_fallback_nic') +    def test_fallback_interface_logs_undiscovered(self, m_get_fallback_nic): +        """Log a warning when fallback_interface can not discover the nic.""" +        self.datasource._cloud_name = 'MySupahCloud' +        m_get_fallback_nic.return_value = None  # Couldn't discover nic +        self.assertIsNone(self.datasource.fallback_interface) +        self.assertEqual( +            'WARNING: Did not find a fallback interface on MySupahCloud.\n', +            self.logs.getvalue()) + +    @mock.patch('cloudinit.sources.net.find_fallback_nic') +    def test_wb_fallback_interface_is_cached(self, m_get_fallback_nic): +        """The fallback_interface is cached and won't be rediscovered.""" +        self.datasource._fallback_interface = 'nic10' +        self.assertEqual('nic10', self.datasource.fallback_interface) +        m_get_fallback_nic.assert_not_called() +      def test__get_data_unimplemented(self):          """Raise an error when _get_data is not implemented."""          with self.assertRaises(NotImplementedError) as context_manager: @@ -178,8 +272,19 @@ class TestDataSource(CiTestCase):                  self.assertEqual('fqdnhostname.domain.com',                                   datasource.get_hostname(fqdn=True)) -    def test_get_data_write_json_instance_data(self): -        """get_data writes INSTANCE_JSON_FILE to run_dir as readonly root.""" +    def test_get_data_does_not_write_instance_data_on_failure(self): +        """get_data does not write INSTANCE_JSON_FILE on get_data False.""" +        tmp = self.tmp_dir() +        datasource = DataSourceTestSubclassNet( +            self.sys_cfg, self.distro, Paths({'run_dir': tmp}), +            get_data_retval=False) +        self.assertFalse(datasource.get_data()) +        json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) +        self.assertFalse( +            os.path.exists(json_file), 'Found unexpected file %s' % json_file) + +    def test_get_data_writes_json_instance_data_on_success(self): +        """get_data writes INSTANCE_JSON_FILE to run_dir as world readable."""          tmp = self.tmp_dir()          datasource = DataSourceTestSubclassNet(              self.sys_cfg, self.distro, Paths({'run_dir': tmp})) @@ -187,40 +292,126 @@ class TestDataSource(CiTestCase):          json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp)          content = util.load_file(json_file)          expected = { -            'base64-encoded-keys': [], +            'base64_encoded_keys': [], +            'sensitive_keys': [],              'v1': {                  'availability-zone': 'myaz', +                'availability_zone': 'myaz',                  'cloud-name': 'subclasscloudname', +                'cloud_name': 'subclasscloudname',                  'instance-id': 'iid-datasource', +                'instance_id': 'iid-datasource',                  'local-hostname': 'test-subclass-hostname', +                'local_hostname': 'test-subclass-hostname',                  'region': 'myregion'},              'ds': { -                'meta-data': {'availability_zone': 'myaz', +                '_doc': EXPERIMENTAL_TEXT, +                'meta_data': {'availability_zone': 'myaz',                                'local-hostname': 'test-subclass-hostname', -                              'region': 'myregion'}, -                'user-data': 'userdata_raw', -                'vendor-data': 'vendordata_raw'}} +                              'region': 'myregion'}}}          self.assertEqual(expected, util.load_json(content))          file_stat = os.stat(json_file) +        self.assertEqual(0o644, stat.S_IMODE(file_stat.st_mode)) +        self.assertEqual(expected, util.load_json(content)) + +    def test_get_data_writes_json_instance_data_sensitive(self): +        """get_data writes INSTANCE_JSON_SENSITIVE_FILE as readonly root.""" +        tmp = self.tmp_dir() +        datasource = DataSourceTestSubclassNet( +            self.sys_cfg, self.distro, Paths({'run_dir': tmp}), +            custom_metadata={ +                'availability_zone': 'myaz', +                'local-hostname': 'test-subclass-hostname', +                'region': 'myregion', +                'some': {'security-credentials': { +                    'cred1': 'sekret', 'cred2': 'othersekret'}}}) +        self.assertEqual( +            ('security-credentials',), datasource.sensitive_metadata_keys) +        datasource.get_data() +        json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) +        sensitive_json_file = self.tmp_path(INSTANCE_JSON_SENSITIVE_FILE, tmp) +        redacted = util.load_json(util.load_file(json_file)) +        self.assertEqual( +            {'cred1': 'sekret', 'cred2': 'othersekret'}, +            redacted['ds']['meta_data']['some']['security-credentials']) +        content = util.load_file(sensitive_json_file) +        expected = { +            'base64_encoded_keys': [], +            'sensitive_keys': ['ds/meta_data/some/security-credentials'], +            'v1': { +                'availability-zone': 'myaz', +                'availability_zone': 'myaz', +                'cloud-name': 'subclasscloudname', +                'cloud_name': 'subclasscloudname', +                'instance-id': 'iid-datasource', +                'instance_id': 'iid-datasource', +                'local-hostname': 'test-subclass-hostname', +                'local_hostname': 'test-subclass-hostname', +                'region': 'myregion'}, +            'ds': { +                '_doc': EXPERIMENTAL_TEXT, +                'meta_data': { +                    'availability_zone': 'myaz', +                    'local-hostname': 'test-subclass-hostname', +                    'region': 'myregion', +                    'some': {'security-credentials': REDACT_SENSITIVE_VALUE}}} +        } +        self.maxDiff = None +        self.assertEqual(expected, util.load_json(content)) +        file_stat = os.stat(sensitive_json_file)          self.assertEqual(0o600, stat.S_IMODE(file_stat.st_mode)) +        self.assertEqual(expected, util.load_json(content))      def test_get_data_handles_redacted_unserializable_content(self):          """get_data warns unserializable content in INSTANCE_JSON_FILE."""          tmp = self.tmp_dir()          datasource = DataSourceTestSubclassNet(              self.sys_cfg, self.distro, Paths({'run_dir': tmp}), -            custom_userdata={'key1': 'val1', 'key2': {'key2.1': self.paths}}) -        self.assertTrue(datasource.get_data()) +            custom_metadata={'key1': 'val1', 'key2': {'key2.1': self.paths}}) +        datasource.get_data()          json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp)          content = util.load_file(json_file) -        expected_userdata = { +        expected_metadata = {              'key1': 'val1',              'key2': {                  'key2.1': "Warning: redacted unserializable type <class"                            " 'cloudinit.helpers.Paths'>"}}          instance_json = util.load_json(content)          self.assertEqual( -            expected_userdata, instance_json['ds']['user-data']) +            expected_metadata, instance_json['ds']['meta_data']) + +    def test_persist_instance_data_writes_ec2_metadata_when_set(self): +        """When ec2_metadata class attribute is set, persist to json.""" +        tmp = self.tmp_dir() +        datasource = DataSourceTestSubclassNet( +            self.sys_cfg, self.distro, Paths({'run_dir': tmp})) +        datasource.ec2_metadata = UNSET +        datasource.get_data() +        json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) +        instance_data = util.load_json(util.load_file(json_file)) +        self.assertNotIn('ec2_metadata', instance_data['ds']) +        datasource.ec2_metadata = {'ec2stuff': 'is good'} +        datasource.persist_instance_data() +        instance_data = util.load_json(util.load_file(json_file)) +        self.assertEqual( +            {'ec2stuff': 'is good'}, +            instance_data['ds']['ec2_metadata']) + +    def test_persist_instance_data_writes_network_json_when_set(self): +        """When network_data.json class attribute is set, persist to json.""" +        tmp = self.tmp_dir() +        datasource = DataSourceTestSubclassNet( +            self.sys_cfg, self.distro, Paths({'run_dir': tmp})) +        datasource.get_data() +        json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) +        instance_data = util.load_json(util.load_file(json_file)) +        self.assertNotIn('network_json', instance_data['ds']) +        datasource.network_json = {'network_json': 'is good'} +        datasource.persist_instance_data() +        instance_data = util.load_json(util.load_file(json_file)) +        self.assertEqual( +            {'network_json': 'is good'}, +            instance_data['ds']['network_json'])      @skipIf(not six.PY3, "json serialization on <= py2.7 handles bytes")      def test_get_data_base64encodes_unserializable_bytes(self): @@ -228,17 +419,17 @@ class TestDataSource(CiTestCase):          tmp = self.tmp_dir()          datasource = DataSourceTestSubclassNet(              self.sys_cfg, self.distro, Paths({'run_dir': tmp}), -            custom_userdata={'key1': 'val1', 'key2': {'key2.1': b'\x123'}}) +            custom_metadata={'key1': 'val1', 'key2': {'key2.1': b'\x123'}})          self.assertTrue(datasource.get_data())          json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp)          content = util.load_file(json_file)          instance_json = util.load_json(content) -        self.assertEqual( -            ['ds/user-data/key2/key2.1'], -            instance_json['base64-encoded-keys']) +        self.assertItemsEqual( +            ['ds/meta_data/key2/key2.1'], +            instance_json['base64_encoded_keys'])          self.assertEqual(              {'key1': 'val1', 'key2': {'key2.1': 'EjM='}}, -            instance_json['ds']['user-data']) +            instance_json['ds']['meta_data'])      @skipIf(not six.PY2, "json serialization on <= py2.7 handles bytes")      def test_get_data_handles_bytes_values(self): @@ -246,15 +437,15 @@ class TestDataSource(CiTestCase):          tmp = self.tmp_dir()          datasource = DataSourceTestSubclassNet(              self.sys_cfg, self.distro, Paths({'run_dir': tmp}), -            custom_userdata={'key1': 'val1', 'key2': {'key2.1': b'\x123'}}) +            custom_metadata={'key1': 'val1', 'key2': {'key2.1': b'\x123'}})          self.assertTrue(datasource.get_data())          json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp)          content = util.load_file(json_file)          instance_json = util.load_json(content) -        self.assertEqual([], instance_json['base64-encoded-keys']) +        self.assertEqual([], instance_json['base64_encoded_keys'])          self.assertEqual(              {'key1': 'val1', 'key2': {'key2.1': '\x123'}}, -            instance_json['ds']['user-data']) +            instance_json['ds']['meta_data'])      @skipIf(not six.PY2, "Only python2 hits UnicodeDecodeErrors on non-utf8")      def test_non_utf8_encoding_logs_warning(self): @@ -262,7 +453,7 @@ class TestDataSource(CiTestCase):          tmp = self.tmp_dir()          datasource = DataSourceTestSubclassNet(              self.sys_cfg, self.distro, Paths({'run_dir': tmp}), -            custom_userdata={'key1': 'val1', 'key2': {'key2.1': b'ab\xaadef'}}) +            custom_metadata={'key1': 'val1', 'key2': {'key2.1': b'ab\xaadef'}})          self.assertTrue(datasource.get_data())          json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp)          self.assertFalse(os.path.exists(json_file)) @@ -278,7 +469,7 @@ class TestDataSource(CiTestCase):          base_args = get_args(DataSource.get_hostname)  # pylint: disable=W1505          # Import all DataSource subclasses so we can inspect them.          modules = util.find_modules(os.path.dirname(os.path.dirname(__file__))) -        for loc, name in modules.items(): +        for _loc, name in modules.items():              mod_locs, _ = importer.find_module(name, ['cloudinit.sources'], [])              if mod_locs:                  importer.import_module(mod_locs[0]) @@ -296,3 +487,116 @@ class TestDataSource(CiTestCase):                      get_args(grandchild.get_hostname),  # pylint: disable=W1505                      '%s does not implement DataSource.get_hostname params'                      % grandchild) + +    def test_clear_cached_attrs_resets_cached_attr_class_attributes(self): +        """Class attributes listed in cached_attr_defaults are reset.""" +        count = 0 +        # Setup values for all cached class attributes +        for attr, value in self.datasource.cached_attr_defaults: +            setattr(self.datasource, attr, count) +            count += 1 +        self.datasource._dirty_cache = True +        self.datasource.clear_cached_attrs() +        for attr, value in self.datasource.cached_attr_defaults: +            self.assertEqual(value, getattr(self.datasource, attr)) + +    def test_clear_cached_attrs_noops_on_clean_cache(self): +        """Class attributes listed in cached_attr_defaults are reset.""" +        count = 0 +        # Setup values for all cached class attributes +        for attr, _ in self.datasource.cached_attr_defaults: +            setattr(self.datasource, attr, count) +            count += 1 +        self.datasource._dirty_cache = False   # Fake clean cache +        self.datasource.clear_cached_attrs() +        count = 0 +        for attr, _ in self.datasource.cached_attr_defaults: +            self.assertEqual(count, getattr(self.datasource, attr)) +            count += 1 + +    def test_clear_cached_attrs_skips_non_attr_class_attributes(self): +        """Skip any cached_attr_defaults which aren't class attributes.""" +        self.datasource._dirty_cache = True +        self.datasource.clear_cached_attrs() +        for attr in ('ec2_metadata', 'network_json'): +            self.assertFalse(hasattr(self.datasource, attr)) + +    def test_clear_cached_attrs_of_custom_attrs(self): +        """Custom attr_values can be passed to clear_cached_attrs.""" +        self.datasource._dirty_cache = True +        cached_attr_name = self.datasource.cached_attr_defaults[0][0] +        setattr(self.datasource, cached_attr_name, 'himom') +        self.datasource.myattr = 'orig' +        self.datasource.clear_cached_attrs( +            attr_defaults=(('myattr', 'updated'),)) +        self.assertEqual('himom', getattr(self.datasource, cached_attr_name)) +        self.assertEqual('updated', self.datasource.myattr) + +    def test_update_metadata_only_acts_on_supported_update_events(self): +        """update_metadata won't get_data on unsupported update events.""" +        self.datasource.update_events['network'].discard(EventType.BOOT) +        self.assertEqual( +            {'network': set([EventType.BOOT_NEW_INSTANCE])}, +            self.datasource.update_events) + +        def fake_get_data(): +            raise Exception('get_data should not be called') + +        self.datasource.get_data = fake_get_data +        self.assertFalse( +            self.datasource.update_metadata( +                source_event_types=[EventType.BOOT])) + +    def test_update_metadata_returns_true_on_supported_update_event(self): +        """update_metadata returns get_data response on supported events.""" + +        def fake_get_data(): +            return True + +        self.datasource.get_data = fake_get_data +        self.datasource._network_config = 'something' +        self.datasource._dirty_cache = True +        self.assertTrue( +            self.datasource.update_metadata( +                source_event_types=[ +                    EventType.BOOT, EventType.BOOT_NEW_INSTANCE])) +        self.assertEqual(UNSET, self.datasource._network_config) +        self.assertIn( +            "DEBUG: Update datasource metadata and network config due to" +            " events: New instance first boot", +            self.logs.getvalue()) + + +class TestRedactSensitiveData(CiTestCase): + +    def test_redact_sensitive_data_noop_when_no_sensitive_keys_present(self): +        """When sensitive_keys is absent or empty from metadata do nothing.""" +        md = {'my': 'data'} +        self.assertEqual( +            md, redact_sensitive_keys(md, redact_value='redacted')) +        md['sensitive_keys'] = [] +        self.assertEqual( +            md, redact_sensitive_keys(md, redact_value='redacted')) + +    def test_redact_sensitive_data_redacts_exact_match_name(self): +        """Only exact matched sensitive_keys are redacted from metadata.""" +        md = {'sensitive_keys': ['md/secure'], +              'md': {'secure': 's3kr1t', 'insecure': 'publik'}} +        secure_md = copy.deepcopy(md) +        secure_md['md']['secure'] = 'redacted' +        self.assertEqual( +            secure_md, +            redact_sensitive_keys(md, redact_value='redacted')) + +    def test_redact_sensitive_data_does_redacts_with_default_string(self): +        """When redact_value is absent, REDACT_SENSITIVE_VALUE is used.""" +        md = {'sensitive_keys': ['md/secure'], +              'md': {'secure': 's3kr1t', 'insecure': 'publik'}} +        secure_md = copy.deepcopy(md) +        secure_md['md']['secure'] = 'redacted for non-root user' +        self.assertEqual( +            secure_md, +            redact_sensitive_keys(md)) + + +# vi: ts=4 expandtab diff --git a/cloudinit/sources/tests/test_oracle.py b/cloudinit/sources/tests/test_oracle.py new file mode 100644 index 00000000..7599126c --- /dev/null +++ b/cloudinit/sources/tests/test_oracle.py @@ -0,0 +1,331 @@ +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit.sources import DataSourceOracle as oracle +from cloudinit.sources import BrokenMetadata +from cloudinit import helpers + +from cloudinit.tests import helpers as test_helpers + +from textwrap import dedent +import argparse +import httpretty +import json +import mock +import os +import six +import uuid + +DS_PATH = "cloudinit.sources.DataSourceOracle" +MD_VER = "2013-10-17" + + +class TestDataSourceOracle(test_helpers.CiTestCase): +    """Test datasource DataSourceOracle.""" + +    ds_class = oracle.DataSourceOracle + +    my_uuid = str(uuid.uuid4()) +    my_md = {"uuid": "ocid1.instance.oc1.phx.abyhqlj", +             "name": "ci-vm1", "availability_zone": "phx-ad-3", +             "hostname": "ci-vm1hostname", +             "launch_index": 0, "files": [], +             "public_keys": {"0": "ssh-rsa AAAAB3N...== user@host"}, +             "meta": {}} + +    def _patch_instance(self, inst, patches): +        """Patch an instance of a class 'inst'. +        for each name, kwargs in patches: +             inst.name = mock.Mock(**kwargs) +        returns a namespace object that has +             namespace.name = mock.Mock(**kwargs) +        Do not bother with cleanup as instance is assumed transient.""" +        mocks = argparse.Namespace() +        for name, kwargs in patches.items(): +            imock = mock.Mock(name=name, spec=getattr(inst, name), **kwargs) +            setattr(mocks, name, imock) +            setattr(inst, name, imock) +        return mocks + +    def _get_ds(self, sys_cfg=None, distro=None, paths=None, ud_proc=None, +                patches=None): +        if sys_cfg is None: +            sys_cfg = {} +        if patches is None: +            patches = {} +        if paths is None: +            tmpd = self.tmp_dir() +            dirs = {'cloud_dir': self.tmp_path('cloud_dir', tmpd), +                    'run_dir': self.tmp_path('run_dir')} +            for d in dirs.values(): +                os.mkdir(d) +            paths = helpers.Paths(dirs) + +        ds = self.ds_class(sys_cfg=sys_cfg, distro=distro, +                           paths=paths, ud_proc=ud_proc) + +        return ds, self._patch_instance(ds, patches) + +    def test_platform_not_viable_returns_false(self): +        ds, mocks = self._get_ds( +            patches={'_is_platform_viable': {'return_value': False}}) +        self.assertFalse(ds._get_data()) +        mocks._is_platform_viable.assert_called_once_with() + +    @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True) +    def test_without_userdata(self, m_is_iscsi_root): +        """If no user-data is provided, it should not be in return dict.""" +        ds, mocks = self._get_ds(patches={ +            '_is_platform_viable': {'return_value': True}, +            'crawl_metadata': { +                'return_value': { +                    MD_VER: {'system_uuid': self.my_uuid, +                             'meta_data': self.my_md}}}}) +        self.assertTrue(ds._get_data()) +        mocks._is_platform_viable.assert_called_once_with() +        mocks.crawl_metadata.assert_called_once_with() +        self.assertEqual(self.my_uuid, ds.system_uuid) +        self.assertEqual(self.my_md['availability_zone'], ds.availability_zone) +        self.assertIn(self.my_md["public_keys"]["0"], ds.get_public_ssh_keys()) +        self.assertEqual(self.my_md['uuid'], ds.get_instance_id()) +        self.assertIsNone(ds.userdata_raw) + +    @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True) +    def test_with_vendordata(self, m_is_iscsi_root): +        """Test with vendor data.""" +        vd = {'cloud-init': '#cloud-config\nkey: value'} +        ds, mocks = self._get_ds(patches={ +            '_is_platform_viable': {'return_value': True}, +            'crawl_metadata': { +                'return_value': { +                    MD_VER: {'system_uuid': self.my_uuid, +                             'meta_data': self.my_md, +                             'vendor_data': vd}}}}) +        self.assertTrue(ds._get_data()) +        mocks._is_platform_viable.assert_called_once_with() +        mocks.crawl_metadata.assert_called_once_with() +        self.assertEqual(vd, ds.vendordata_pure) +        self.assertEqual(vd['cloud-init'], ds.vendordata_raw) + +    @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True) +    def test_with_userdata(self, m_is_iscsi_root): +        """Ensure user-data is populated if present and is binary.""" +        my_userdata = b'abcdefg' +        ds, mocks = self._get_ds(patches={ +            '_is_platform_viable': {'return_value': True}, +            'crawl_metadata': { +                'return_value': { +                    MD_VER: {'system_uuid': self.my_uuid, +                             'meta_data': self.my_md, +                             'user_data': my_userdata}}}}) +        self.assertTrue(ds._get_data()) +        mocks._is_platform_viable.assert_called_once_with() +        mocks.crawl_metadata.assert_called_once_with() +        self.assertEqual(self.my_uuid, ds.system_uuid) +        self.assertIn(self.my_md["public_keys"]["0"], ds.get_public_ssh_keys()) +        self.assertEqual(self.my_md['uuid'], ds.get_instance_id()) +        self.assertEqual(my_userdata, ds.userdata_raw) + +    @mock.patch(DS_PATH + ".cmdline.read_kernel_cmdline_config") +    @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True) +    def test_network_cmdline(self, m_is_iscsi_root, m_cmdline_config): +        """network_config should read kernel cmdline.""" +        distro = mock.MagicMock() +        ds, _ = self._get_ds(distro=distro, patches={ +            '_is_platform_viable': {'return_value': True}, +            'crawl_metadata': { +                'return_value': { +                    MD_VER: {'system_uuid': self.my_uuid, +                             'meta_data': self.my_md}}}}) +        ncfg = {'version': 1, 'config': [{'a': 'b'}]} +        m_cmdline_config.return_value = ncfg +        self.assertTrue(ds._get_data()) +        self.assertEqual(ncfg, ds.network_config) +        m_cmdline_config.assert_called_once_with() +        self.assertFalse(distro.generate_fallback_config.called) + +    @mock.patch(DS_PATH + ".cmdline.read_kernel_cmdline_config") +    @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True) +    def test_network_fallback(self, m_is_iscsi_root, m_cmdline_config): +        """test that fallback network is generated if no kernel cmdline.""" +        distro = mock.MagicMock() +        ds, _ = self._get_ds(distro=distro, patches={ +            '_is_platform_viable': {'return_value': True}, +            'crawl_metadata': { +                'return_value': { +                    MD_VER: {'system_uuid': self.my_uuid, +                             'meta_data': self.my_md}}}}) +        ncfg = {'version': 1, 'config': [{'a': 'b'}]} +        m_cmdline_config.return_value = None +        self.assertTrue(ds._get_data()) +        ncfg = {'version': 1, 'config': [{'distro1': 'value'}]} +        distro.generate_fallback_config.return_value = ncfg +        self.assertEqual(ncfg, ds.network_config) +        m_cmdline_config.assert_called_once_with() +        distro.generate_fallback_config.assert_called_once_with() +        self.assertEqual(1, m_cmdline_config.call_count) + +        # test that the result got cached, and the methods not re-called. +        self.assertEqual(ncfg, ds.network_config) +        self.assertEqual(1, m_cmdline_config.call_count) + + +@mock.patch(DS_PATH + "._read_system_uuid", return_value=str(uuid.uuid4())) +class TestReadMetaData(test_helpers.HttprettyTestCase): +    """Test the read_metadata which interacts with http metadata service.""" + +    mdurl = oracle.METADATA_ENDPOINT +    my_md = {"uuid": "ocid1.instance.oc1.phx.abyhqlj", +             "name": "ci-vm1", "availability_zone": "phx-ad-3", +             "hostname": "ci-vm1hostname", +             "launch_index": 0, "files": [], +             "public_keys": {"0": "ssh-rsa AAAAB3N...== user@host"}, +             "meta": {}} + +    def populate_md(self, data): +        """call httppretty.register_url for each item dict 'data', +           including valid indexes. Text values converted to bytes.""" +        httpretty.register_uri( +            httpretty.GET, self.mdurl + MD_VER + "/", +            '\n'.join(data.keys()).encode('utf-8')) +        for k, v in data.items(): +            httpretty.register_uri( +                httpretty.GET, self.mdurl + MD_VER + "/" + k, +                v if not isinstance(v, six.text_type) else v.encode('utf-8')) + +    def test_broken_no_sys_uuid(self, m_read_system_uuid): +        """Datasource requires ability to read system_uuid and true return.""" +        m_read_system_uuid.return_value = None +        self.assertRaises(BrokenMetadata, oracle.read_metadata) + +    def test_broken_no_metadata_json(self, m_read_system_uuid): +        """Datasource requires meta_data.json.""" +        httpretty.register_uri( +            httpretty.GET, self.mdurl + MD_VER + "/", +            '\n'.join(['user_data']).encode('utf-8')) +        with self.assertRaises(BrokenMetadata) as cm: +            oracle.read_metadata() +        self.assertIn("Required field 'meta_data.json' missing", +                      str(cm.exception)) + +    def test_with_userdata(self, m_read_system_uuid): +        data = {'user_data': b'#!/bin/sh\necho hi world\n', +                'meta_data.json': json.dumps(self.my_md)} +        self.populate_md(data) +        result = oracle.read_metadata()[MD_VER] +        self.assertEqual(data['user_data'], result['user_data']) +        self.assertEqual(self.my_md, result['meta_data']) + +    def test_without_userdata(self, m_read_system_uuid): +        data = {'meta_data.json': json.dumps(self.my_md)} +        self.populate_md(data) +        result = oracle.read_metadata()[MD_VER] +        self.assertNotIn('user_data', result) +        self.assertEqual(self.my_md, result['meta_data']) + +    def test_unknown_fields_included(self, m_read_system_uuid): +        """Unknown fields listed in index should be included. +        And those ending in .json should be decoded.""" +        some_data = {'key1': 'data1', 'subk1': {'subd1': 'subv'}} +        some_vendor_data = {'cloud-init': 'foo'} +        data = {'meta_data.json': json.dumps(self.my_md), +                'some_data.json': json.dumps(some_data), +                'vendor_data.json': json.dumps(some_vendor_data), +                'other_blob': b'this is blob'} +        self.populate_md(data) +        result = oracle.read_metadata()[MD_VER] +        self.assertNotIn('user_data', result) +        self.assertEqual(self.my_md, result['meta_data']) +        self.assertEqual(some_data, result['some_data']) +        self.assertEqual(some_vendor_data, result['vendor_data']) +        self.assertEqual(data['other_blob'], result['other_blob']) + + +class TestIsPlatformViable(test_helpers.CiTestCase): +    @mock.patch(DS_PATH + ".util.read_dmi_data", +                return_value=oracle.CHASSIS_ASSET_TAG) +    def test_expected_viable(self, m_read_dmi_data): +        """System with known chassis tag is viable.""" +        self.assertTrue(oracle._is_platform_viable()) +        m_read_dmi_data.assert_has_calls([mock.call('chassis-asset-tag')]) + +    @mock.patch(DS_PATH + ".util.read_dmi_data", return_value=None) +    def test_expected_not_viable_dmi_data_none(self, m_read_dmi_data): +        """System without known chassis tag is not viable.""" +        self.assertFalse(oracle._is_platform_viable()) +        m_read_dmi_data.assert_has_calls([mock.call('chassis-asset-tag')]) + +    @mock.patch(DS_PATH + ".util.read_dmi_data", return_value="LetsGoCubs") +    def test_expected_not_viable_other(self, m_read_dmi_data): +        """System with unnown chassis tag is not viable.""" +        self.assertFalse(oracle._is_platform_viable()) +        m_read_dmi_data.assert_has_calls([mock.call('chassis-asset-tag')]) + + +class TestLoadIndex(test_helpers.CiTestCase): +    """_load_index handles parsing of an index into a proper list. +    The tests here guarantee correct parsing of html version or +    a fixed version.  See the function docstring for more doc.""" + +    _known_html_api_versions = dedent("""\ +        <html> +        <head><title>Index of /openstack/</title></head> +        <body bgcolor="white"> +        <h1>Index of /openstack/</h1><hr><pre><a href="../">../</a> +        <a href="2013-10-17/">2013-10-17/</a>   27-Jun-2018 12:22  - +        <a href="latest/">latest/</a>           27-Jun-2018 12:22  - +        </pre><hr></body> +        </html>""") + +    _known_html_contents = dedent("""\ +        <html> +        <head><title>Index of /openstack/2013-10-17/</title></head> +        <body bgcolor="white"> +        <h1>Index of /openstack/2013-10-17/</h1><hr><pre><a href="../">../</a> +        <a href="meta_data.json">meta_data.json</a>  27-Jun-2018 12:22  679 +        <a href="user_data">user_data</a>            27-Jun-2018 12:22  146 +        </pre><hr></body> +        </html>""") + +    def test_parse_html(self): +        """Test parsing of lower case html.""" +        self.assertEqual( +            ['2013-10-17/', 'latest/'], +            oracle._load_index(self._known_html_api_versions)) +        self.assertEqual( +            ['meta_data.json', 'user_data'], +            oracle._load_index(self._known_html_contents)) + +    def test_parse_html_upper(self): +        """Test parsing of upper case html, although known content is lower.""" +        def _toupper(data): +            return data.replace("<a", "<A").replace("html>", "HTML>") + +        self.assertEqual( +            ['2013-10-17/', 'latest/'], +            oracle._load_index(_toupper(self._known_html_api_versions))) +        self.assertEqual( +            ['meta_data.json', 'user_data'], +            oracle._load_index(_toupper(self._known_html_contents))) + +    def test_parse_newline_list_with_endl(self): +        """Test parsing of newline separated list with ending newline.""" +        self.assertEqual( +            ['2013-10-17/', 'latest/'], +            oracle._load_index("\n".join(["2013-10-17/", "latest/", ""]))) +        self.assertEqual( +            ['meta_data.json', 'user_data'], +            oracle._load_index("\n".join(["meta_data.json", "user_data", ""]))) + +    def test_parse_newline_list_without_endl(self): +        """Test parsing of newline separated list with no ending newline. + +        Actual openstack implementation does not include trailing newline.""" +        self.assertEqual( +            ['2013-10-17/', 'latest/'], +            oracle._load_index("\n".join(["2013-10-17/", "latest/"]))) +        self.assertEqual( +            ['meta_data.json', 'user_data'], +            oracle._load_index("\n".join(["meta_data.json", "user_data"]))) + + +# vi: ts=4 expandtab | 
