diff options
| -rw-r--r-- | cloudinit/sources/DataSourceAzure.py | 46 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 29 | 
2 files changed, 49 insertions, 26 deletions
| diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index bd3c742b..b93357d5 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -260,7 +260,6 @@ class WALinuxAgentShim(object):      def __init__(self):          self.endpoint = find_endpoint() -        self.goal_state = None          self.openssl_manager = OpenSSLManager()          self.http_client = AzureEndpointHttpClient(              self.openssl_manager.certificate) @@ -276,18 +275,24 @@ class WALinuxAgentShim(object):                  time.sleep(i + 1)              else:                  break -        self.goal_state = GoalState(response.contents, self.http_client) -        self.public_keys = [] -        if self.goal_state.certificates_xml is not None: -            self.public_keys = self.openssl_manager.parse_certificates( -                self.goal_state.certificates_xml) -        self._report_ready() - -    def _report_ready(self): +        goal_state = GoalState(response.contents, self.http_client) +        public_keys = [] +        if goal_state.certificates_xml is not None: +            public_keys = self.openssl_manager.parse_certificates( +                goal_state.certificates_xml) +        data = { +            'instance-id': iid_from_shared_config_content( +                goal_state.shared_config_xml), +            'public-keys': public_keys, +        } +        self._report_ready(goal_state) +        return data + +    def _report_ready(self, goal_state):          document = REPORT_READY_XML_TEMPLATE.format( -            incarnation=self.goal_state.incarnation, -            container_id=self.goal_state.container_id, -            instance_id=self.goal_state.instance_id, +            incarnation=goal_state.incarnation, +            container_id=goal_state.container_id, +            instance_id=goal_state.instance_id,          )          self.http_client.post(              "http://{}/machine?comp=health".format(self.endpoint), @@ -414,17 +419,16 @@ class DataSourceAzureNet(sources.DataSource):          # the directory to be protected.          write_files(ddir, files, dirmode=0o700) -        shim = WALinuxAgentShim() -        shim.register_with_azure_and_fetch_data() -          try: -            self.metadata['instance-id'] = iid_from_shared_config_content( -                shim.goal_state.shared_config_xml) -        except ValueError as e: -            LOG.warn( -                "failed to get instance id in %s: %s", shim.shared_config, e) +            shim = WALinuxAgentShim() +            data = shim.register_with_azure_and_fetch_data() +        except Exception as exc: +            LOG.info("Error communicating with Azure fabric; assume we aren't" +                     " on Azure.", exc_info=True) +            return False -        self.metadata['public-keys'] = shim.public_keys +        self.metadata['instance-id'] = data['instance-id'] +        self.metadata['public-keys'] = data['public-keys']          found_ephemeral = find_ephemeral_disk()          if found_ephemeral: diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index dc7f2663..fd5b24f8 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -160,6 +160,12 @@ class TestAzureDataSource(TestCase):          mod = DataSourceAzure          mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d +        fake_shim = mock.MagicMock() +        fake_shim().register_with_azure_and_fetch_data.return_value = { +            'instance-id': 'i-my-azure-id', +            'public-keys': [], +        } +          self.apply_patches([              (mod, 'list_possible_azure_ds_devs', dsdevs),              (mod, 'invoke_agent', _invoke_agent), @@ -169,7 +175,8 @@ class TestAzureDataSource(TestCase):              (mod, 'perform_hostname_bounce', mock.MagicMock()),              (mod, 'get_hostname', mock.MagicMock()),              (mod, 'set_hostname', mock.MagicMock()), -            ]) +            (mod, 'WALinuxAgentShim', fake_shim), +        ])          dsrc = mod.DataSourceAzureNet(              data.get('sys_cfg', {}), distro=None, paths=self.paths) @@ -852,6 +859,9 @@ class TestWALinuxAgentShim(TestCase):              mock.patch.object(DataSourceAzure, 'find_endpoint'))          self.GoalState = patches.enter_context(              mock.patch.object(DataSourceAzure, 'GoalState')) +        self.iid_from_shared_config_content = patches.enter_context( +            mock.patch.object(DataSourceAzure, +                              'iid_from_shared_config_content'))          self.OpenSSLManager = patches.enter_context(              mock.patch.object(DataSourceAzure, 'OpenSSLManager')) @@ -877,19 +887,28 @@ class TestWALinuxAgentShim(TestCase):      def test_certificates_used_to_determine_public_keys(self):          shim = DataSourceAzure.WALinuxAgentShim() -        shim.register_with_azure_and_fetch_data() +        data = shim.register_with_azure_and_fetch_data()          self.assertEqual(              [mock.call(self.GoalState.return_value.certificates_xml)],              self.OpenSSLManager.return_value.parse_certificates.call_args_list)          self.assertEqual(              self.OpenSSLManager.return_value.parse_certificates.return_value, -            shim.public_keys) +            data['public-keys'])      def test_absent_certificates_produces_empty_public_keys(self):          self.GoalState.return_value.certificates_xml = None          shim = DataSourceAzure.WALinuxAgentShim() -        shim.register_with_azure_and_fetch_data() -        self.assertEqual([], shim.public_keys) +        data = shim.register_with_azure_and_fetch_data() +        self.assertEqual([], data['public-keys']) + +    def test_instance_id_returned_in_data(self): +        shim = DataSourceAzure.WALinuxAgentShim() +        data = shim.register_with_azure_and_fetch_data() +        self.assertEqual( +            [mock.call(self.GoalState.return_value.shared_config_xml)], +            self.iid_from_shared_config_content.call_args_list) +        self.assertEqual(self.iid_from_shared_config_content.return_value, +                         data['instance-id'])      def test_correct_url_used_for_report_ready(self):          self.find_endpoint.return_value = 'test_endpoint' | 
