diff options
-rw-r--r-- | cloudinit/sources/DataSourceAzure.py | 46 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 29 |
2 files changed, 49 insertions, 26 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index bd3c742b..b93357d5 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -260,7 +260,6 @@ class WALinuxAgentShim(object): def __init__(self): self.endpoint = find_endpoint() - self.goal_state = None self.openssl_manager = OpenSSLManager() self.http_client = AzureEndpointHttpClient( self.openssl_manager.certificate) @@ -276,18 +275,24 @@ class WALinuxAgentShim(object): time.sleep(i + 1) else: break - self.goal_state = GoalState(response.contents, self.http_client) - self.public_keys = [] - if self.goal_state.certificates_xml is not None: - self.public_keys = self.openssl_manager.parse_certificates( - self.goal_state.certificates_xml) - self._report_ready() - - def _report_ready(self): + goal_state = GoalState(response.contents, self.http_client) + public_keys = [] + if goal_state.certificates_xml is not None: + public_keys = self.openssl_manager.parse_certificates( + goal_state.certificates_xml) + data = { + 'instance-id': iid_from_shared_config_content( + goal_state.shared_config_xml), + 'public-keys': public_keys, + } + self._report_ready(goal_state) + return data + + def _report_ready(self, goal_state): document = REPORT_READY_XML_TEMPLATE.format( - incarnation=self.goal_state.incarnation, - container_id=self.goal_state.container_id, - instance_id=self.goal_state.instance_id, + incarnation=goal_state.incarnation, + container_id=goal_state.container_id, + instance_id=goal_state.instance_id, ) self.http_client.post( "http://{}/machine?comp=health".format(self.endpoint), @@ -414,17 +419,16 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) - shim = WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - try: - self.metadata['instance-id'] = iid_from_shared_config_content( - shim.goal_state.shared_config_xml) - except ValueError as e: - LOG.warn( - "failed to get instance id in %s: %s", shim.shared_config, e) + shim = WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + except Exception as exc: + LOG.info("Error communicating with Azure fabric; assume we aren't" + " on Azure.", exc_info=True) + return False - self.metadata['public-keys'] = shim.public_keys + self.metadata['instance-id'] = data['instance-id'] + self.metadata['public-keys'] = data['public-keys'] found_ephemeral = find_ephemeral_disk() if found_ephemeral: diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index dc7f2663..fd5b24f8 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -160,6 +160,12 @@ class TestAzureDataSource(TestCase): mod = DataSourceAzure mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d + fake_shim = mock.MagicMock() + fake_shim().register_with_azure_and_fetch_data.return_value = { + 'instance-id': 'i-my-azure-id', + 'public-keys': [], + } + self.apply_patches([ (mod, 'list_possible_azure_ds_devs', dsdevs), (mod, 'invoke_agent', _invoke_agent), @@ -169,7 +175,8 @@ class TestAzureDataSource(TestCase): (mod, 'perform_hostname_bounce', mock.MagicMock()), (mod, 'get_hostname', mock.MagicMock()), (mod, 'set_hostname', mock.MagicMock()), - ]) + (mod, 'WALinuxAgentShim', fake_shim), + ]) dsrc = mod.DataSourceAzureNet( data.get('sys_cfg', {}), distro=None, paths=self.paths) @@ -852,6 +859,9 @@ class TestWALinuxAgentShim(TestCase): mock.patch.object(DataSourceAzure, 'find_endpoint')) self.GoalState = patches.enter_context( mock.patch.object(DataSourceAzure, 'GoalState')) + self.iid_from_shared_config_content = patches.enter_context( + mock.patch.object(DataSourceAzure, + 'iid_from_shared_config_content')) self.OpenSSLManager = patches.enter_context( mock.patch.object(DataSourceAzure, 'OpenSSLManager')) @@ -877,19 +887,28 @@ class TestWALinuxAgentShim(TestCase): def test_certificates_used_to_determine_public_keys(self): shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() + data = shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(self.GoalState.return_value.certificates_xml)], self.OpenSSLManager.return_value.parse_certificates.call_args_list) self.assertEqual( self.OpenSSLManager.return_value.parse_certificates.return_value, - shim.public_keys) + data['public-keys']) def test_absent_certificates_produces_empty_public_keys(self): self.GoalState.return_value.certificates_xml = None shim = DataSourceAzure.WALinuxAgentShim() - shim.register_with_azure_and_fetch_data() - self.assertEqual([], shim.public_keys) + data = shim.register_with_azure_and_fetch_data() + self.assertEqual([], data['public-keys']) + + def test_instance_id_returned_in_data(self): + shim = DataSourceAzure.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.shared_config_xml)], + self.iid_from_shared_config_content.call_args_list) + self.assertEqual(self.iid_from_shared_config_content.return_value, + data['instance-id']) def test_correct_url_used_for_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' |