From e56b55452549cb037da0a4165154ffa494e9678a Mon Sep 17 00:00:00 2001 From: Thomas Stringer Date: Thu, 10 Sep 2020 14:29:54 -0400 Subject: Retrieve SSH keys from IMDS first with OVF as a fallback (#509) * pull ssh keys from imds first and fall back to ovf if unavailable * refactor log and diagnostic messages * refactor the OpenSSLManager instantiation and certificate usage * fix unit test where exception was being silenced for generate cert * fix tests now that certificate is not always generated * add documentation for ssh key retrieval * add ability to check if http client has security enabled * refactor certificate logic to GoalState --- tests/unittests/test_datasource/test_azure_helper.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) (limited to 'tests/unittests/test_datasource/test_azure_helper.py') diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 5e6d3d2d..5c31b8be 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -609,11 +609,11 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.return_value.container_id = self.test_container_id self.GoalState.return_value.instance_id = self.test_instance_id - def test_azure_endpoint_client_uses_certificate_during_report_ready(self): + def test_http_client_does_not_use_certificate(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( - [mock.call(self.OpenSSLManager.return_value.certificate)], + [mock.call(None)], self.AzureEndpointHttpClient.call_args_list) def test_correct_url_used_for_goalstate_during_report_ready(self): @@ -625,8 +625,11 @@ class TestWALinuxAgentShim(CiTestCase): [mock.call('http://test_endpoint/machine/?comp=goalstate')], get.call_args_list) self.assertEqual( - [mock.call(get.return_value.contents, - self.AzureEndpointHttpClient.return_value)], + [mock.call( + get.return_value.contents, + self.AzureEndpointHttpClient.return_value, + False + )], self.GoalState.call_args_list) def test_certificates_used_to_determine_public_keys(self): @@ -701,7 +704,7 @@ class TestWALinuxAgentShim(CiTestCase): shim.register_with_azure_and_fetch_data() shim.clean_up() self.assertEqual( - 1, self.OpenSSLManager.return_value.clean_up.call_count) + 0, self.OpenSSLManager.return_value.clean_up.call_count) def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): self.AzureEndpointHttpClient.return_value.get \ -- cgit v1.2.3 From 8766784f4b1d1f9f6a9094e1268e4accb811ea7f Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Fri, 16 Oct 2020 08:54:38 -0700 Subject: DataSourceAzure: write marker file after report ready in preprovisioning (#590) DataSourceAzure previously writes the preprovisioning reported ready marker file before it goes through the report ready workflow. On certain VM instances, the marker file is successfully written but then reporting ready fails. Upon rare VM reboots by the platform, cloud-init sees that the report ready marker file already exists. The existence of this marker file tells cloud-init not to report ready again (because it mistakenly assumes that it already reported ready in preprovisioning). In this scenario, cloud-init instead erroneously takes the reprovisioning workflow instead of reporting ready again. --- cloudinit/sources/DataSourceAzure.py | 23 +++++- tests/unittests/test_datasource/test_azure.py | 86 +++++++++++++++++----- .../unittests/test_datasource/test_azure_helper.py | 3 +- 3 files changed, 90 insertions(+), 22 deletions(-) (limited to 'tests/unittests/test_datasource/test_azure_helper.py') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 8858fbd5..70e32f46 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -720,12 +720,23 @@ class DataSourceAzure(sources.DataSource): self._ephemeral_dhcp_ctx.clean_network() break + report_ready_succeeded = self._report_ready(lease=lease) + if not report_ready_succeeded: + msg = ('Failed reporting ready while in ' + 'the preprovisioning pool.') + report_diagnostic_event(msg, logger_func=LOG.error) + self._ephemeral_dhcp_ctx.clean_network() + raise sources.InvalidMetaDataException(msg) + path = REPORTED_READY_MARKER_FILE LOG.info( "Creating a marker file to report ready: %s", path) util.write_file(path, "{pid}: {time}\n".format( pid=os.getpid(), time=time())) - self._report_ready(lease=lease) + report_diagnostic_event( + 'Successfully created reported ready marker file ' + 'while in the preprovisioning pool.', + logger_func=LOG.debug) report_ready = False with events.ReportEventStack( @@ -773,14 +784,20 @@ class DataSourceAzure(sources.DataSource): return return_val @azure_ds_telemetry_reporter - def _report_ready(self, lease): - """Tells the fabric provisioning has completed """ + def _report_ready(self, lease: dict) -> bool: + """Tells the fabric provisioning has completed. + + @param lease: dhcp lease to use for sending the ready signal. + @return: The success status of sending the ready signal. + """ try: get_metadata_from_fabric(None, lease['unknown-245']) + return True except Exception as e: report_diagnostic_event( "Error communicating with Azure fabric; You may experience " "connectivity issues: %s" % e, logger_func=LOG.warning) + return False def _should_reprovision(self, ret): """Whether or not we should poll IMDS for reprovisioning data. diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 3b9456f7..56c1cf18 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -1161,6 +1161,19 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds({'ovfcontent': xml}) dsrc.get_data() + def test_dsaz_report_ready_returns_true_when_report_succeeds( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + self.assertTrue(dsrc._report_ready(lease=mock.MagicMock())) + + def test_dsaz_report_ready_returns_false_and_does_not_propagate_exc( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + self.get_metadata_from_fabric.side_effect = Exception + self.assertFalse(dsrc._report_ready(lease=mock.MagicMock())) + def test_exception_fetching_fabric_data_doesnt_propagate(self): """Errors communicating with fabric should warn, but return True.""" dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) @@ -2041,7 +2054,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): @mock.patch('time.sleep', mock.MagicMock()) @mock.patch(MOCKPATH + 'EphemeralDHCPv4') def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, report_ready_func, - fake_resp, m_media_switch, m_dhcp, + m_request, m_media_switch, m_dhcp, m_net): """The poll_imds will retry DHCP on IMDS timeout.""" report_file = self.tmp_path('report_marker', self.tmp) @@ -2070,7 +2083,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): # Third try should succeed and stop retries or redhcp return mock.MagicMock(status_code=200, text="good", content="good") - fake_resp.side_effect = fake_timeout_once + m_request.side_effect = fake_timeout_once dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): @@ -2080,11 +2093,10 @@ class TestPreprovisioningPollIMDS(CiTestCase): self.assertEqual(3, m_dhcpv4.call_count, 'Expected 3 DHCP calls') self.assertEqual(4, self.tries, 'Expected 4 total reads from IMDS') - def test_poll_imds_report_ready_false(self, - report_ready_func, fake_resp, - m_media_switch, m_dhcp, m_net): - """The poll_imds should not call reporting ready - when flag is false""" + def test_does_not_poll_imds_report_ready_when_marker_file_exists( + self, m_report_ready, m_request, m_media_switch, m_dhcp, m_net): + """poll_imds should not call report ready when the reported ready + marker file exists""" report_file = self.tmp_path('report_marker', self.tmp) write_file(report_file, content='dont run report_ready :)') m_dhcp.return_value = [{ @@ -2095,11 +2107,49 @@ class TestPreprovisioningPollIMDS(CiTestCase): dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): dsa._poll_imds() - self.assertEqual(report_ready_func.call_count, 0) + self.assertEqual(m_report_ready.call_count, 0) + + def test_poll_imds_report_ready_success_writes_marker_file( + self, m_report_ready, m_request, m_media_switch, m_dhcp, m_net): + """poll_imds should write the report_ready marker file if + reporting ready succeeds""" + report_file = self.tmp_path('report_marker', self.tmp) + m_dhcp.return_value = [{ + 'interface': 'eth9', 'fixed-address': '192.168.2.9', + 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', + 'unknown-245': '624c3620'}] + m_media_switch.return_value = None + dsa = dsaz.DataSourceAzure({}, distro=None, paths=self.paths) + self.assertFalse(os.path.exists(report_file)) + with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): + dsa._poll_imds() + self.assertEqual(m_report_ready.call_count, 1) + self.assertTrue(os.path.exists(report_file)) + + def test_poll_imds_report_ready_failure_raises_exc_and_doesnt_write_marker( + self, m_report_ready, m_request, m_media_switch, m_dhcp, m_net): + """poll_imds should write the report_ready marker file if + reporting ready succeeds""" + report_file = self.tmp_path('report_marker', self.tmp) + m_dhcp.return_value = [{ + 'interface': 'eth9', 'fixed-address': '192.168.2.9', + 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', + 'unknown-245': '624c3620'}] + m_media_switch.return_value = None + m_report_ready.return_value = False + dsa = dsaz.DataSourceAzure({}, distro=None, paths=self.paths) + self.assertFalse(os.path.exists(report_file)) + with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): + self.assertRaises( + InvalidMetaDataException, + dsa._poll_imds) + self.assertEqual(m_report_ready.call_count, 1) + self.assertFalse(os.path.exists(report_file)) -@mock.patch(MOCKPATH + 'subp.subp') -@mock.patch(MOCKPATH + 'util.write_file') +@mock.patch(MOCKPATH + 'DataSourceAzure._report_ready', mock.MagicMock()) +@mock.patch(MOCKPATH + 'subp.subp', mock.MagicMock()) +@mock.patch(MOCKPATH + 'util.write_file', mock.MagicMock()) @mock.patch(MOCKPATH + 'util.is_FreeBSD') @mock.patch('cloudinit.sources.helpers.netlink.' 'wait_for_media_disconnect_connect') @@ -2115,10 +2165,10 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): self.paths = helpers.Paths({'cloud_dir': tmp}) dsaz.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d - def test_poll_imds_returns_ovf_env(self, fake_resp, + def test_poll_imds_returns_ovf_env(self, m_request, m_dhcp, m_net, m_media_switch, - m_is_bsd, write_f, subp): + m_is_bsd): """The _poll_imds method should return the ovf_env.xml.""" m_is_bsd.return_value = False m_media_switch.return_value = None @@ -2128,11 +2178,11 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): url = 'http://{0}/metadata/reprovisiondata?api-version=2017-04-02' host = "169.254.169.254" full_url = url.format(host) - fake_resp.return_value = mock.MagicMock(status_code=200, text="ovf", + m_request.return_value = mock.MagicMock(status_code=200, text="ovf", content="ovf") dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) self.assertTrue(len(dsa._poll_imds()) > 0) - self.assertEqual(fake_resp.call_args_list, + self.assertEqual(m_request.call_args_list, [mock.call(allow_redirects=True, headers={'Metadata': 'true', 'User-Agent': @@ -2147,10 +2197,10 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): static_routes=None) self.assertEqual(m_net.call_count, 2) - def test__reprovision_calls__poll_imds(self, fake_resp, + def test__reprovision_calls__poll_imds(self, m_request, m_dhcp, m_net, m_media_switch, - m_is_bsd, write_f, subp): + m_is_bsd): """The _reprovision method should call poll IMDS.""" m_is_bsd.return_value = False m_media_switch.return_value = None @@ -2165,7 +2215,7 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): username = "myuser" odata = {'HostName': hostname, 'UserName': username} content = construct_valid_ovf_env(data=odata) - fake_resp.return_value = mock.MagicMock(status_code=200, text=content, + m_request.return_value = mock.MagicMock(status_code=200, text=content, content=content) dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) md, _ud, cfg, _d = dsa._reprovision() @@ -2182,7 +2232,7 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, url=full_url ), - fake_resp.call_args_list) + m_request.call_args_list) self.assertEqual(m_dhcp.call_count, 2) m_net.assert_any_call( broadcast='192.168.2.255', interface='eth9', ip='192.168.2.9', diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 5c31b8be..6e004e34 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -742,7 +742,8 @@ class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): self.assertEqual(1, shim.return_value.clean_up.call_count) @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_failure_in_registration_calls_clean_up(self, shim): + def test_failure_in_registration_propagates_exc_and_calls_clean_up( + self, shim): shim.return_value.register_with_azure_and_fetch_data.side_effect = ( SentinelException) self.assertRaises(SentinelException, -- cgit v1.2.3 From d807df288f8cef29ca74f0b00c326b084e825782 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Wed, 18 Nov 2020 09:34:04 -0800 Subject: DataSourceAzure: send failure signal on Azure datasource failure (#594) On systems where the Azure datasource is a viable platform for crawling metadata, cloud-init occasionally encounters fatal irrecoverable errors during the crawling of the Azure datasource. When this happens, cloud-init crashes, and Azure VM provisioning would fail. However, instead of failing immediately, the user will continue seeing provisioning for a long time until it times out with "OS Provisioning Timed Out" message. In these situations, cloud-init should report failure to the Azure datasource endpoint indicating provisioning failure. The user will immediately see provisioning terminate, giving them a much better failure experience instead of pointlessly waiting for OS provisioning timeout. --- cloudinit/sources/DataSourceAzure.py | 73 ++- cloudinit/sources/helpers/azure.py | 80 ++- tests/unittests/test_datasource/test_azure.py | 322 ++++++++++-- .../unittests/test_datasource/test_azure_helper.py | 569 ++++++++++++++++++--- 4 files changed, 921 insertions(+), 123 deletions(-) (limited to 'tests/unittests/test_datasource/test_azure_helper.py') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index fa3e0a2b..ab139b8d 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -29,6 +29,7 @@ from cloudinit import util from cloudinit.reporting import events from cloudinit.sources.helpers.azure import ( + DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE, azure_ds_reporter, azure_ds_telemetry_reporter, get_metadata_from_fabric, @@ -38,7 +39,8 @@ from cloudinit.sources.helpers.azure import ( EphemeralDHCPv4WithReporting, is_byte_swapped, dhcp_log_cb, - push_log_to_kvp) + push_log_to_kvp, + report_failure_to_fabric) LOG = logging.getLogger(__name__) @@ -508,8 +510,9 @@ class DataSourceAzure(sources.DataSource): if perform_reprovision: LOG.info("Reporting ready to Azure after getting ReprovisionData") - use_cached_ephemeral = (net.is_up(self.fallback_interface) and - getattr(self, '_ephemeral_dhcp_ctx', None)) + use_cached_ephemeral = ( + self.distro.networking.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None)) if use_cached_ephemeral: self._report_ready(lease=self._ephemeral_dhcp_ctx.lease) self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral @@ -560,9 +563,14 @@ class DataSourceAzure(sources.DataSource): logfunc=LOG.debug, msg='Crawl of metadata service', func=self.crawl_metadata ) - except sources.InvalidMetaDataException as e: - LOG.warning('Could not crawl Azure metadata: %s', e) + except Exception as e: + report_diagnostic_event( + 'Could not crawl Azure metadata: %s' % e, + logger_func=LOG.error) + self._report_failure( + description=DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) return False + if (self.distro and self.distro.name == 'ubuntu' and self.ds_cfg.get('apply_network_config')): maybe_remove_ubuntu_network_config_scripts() @@ -785,6 +793,61 @@ class DataSourceAzure(sources.DataSource): return return_val @azure_ds_telemetry_reporter + def _report_failure(self, description=None) -> bool: + """Tells the Azure fabric that provisioning has failed. + + @param description: A description of the error encountered. + @return: The success status of sending the failure signal. + """ + unknown_245_key = 'unknown-245' + + try: + if (self.distro.networking.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None) and + getattr(self._ephemeral_dhcp_ctx, 'lease', None) and + unknown_245_key in self._ephemeral_dhcp_ctx.lease): + report_diagnostic_event( + 'Using cached ephemeral dhcp context ' + 'to report failure to Azure', logger_func=LOG.debug) + report_failure_to_fabric( + dhcp_opts=self._ephemeral_dhcp_ctx.lease[unknown_245_key], + description=description) + self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using ' + 'cached ephemeral dhcp context: %s' % e, + logger_func=LOG.error) + + try: + report_diagnostic_event( + 'Using new ephemeral dhcp to report failure to Azure', + logger_func=LOG.debug) + with EphemeralDHCPv4WithReporting(azure_ds_reporter) as lease: + report_failure_to_fabric( + dhcp_opts=lease[unknown_245_key], + description=description) + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using new ephemeral dhcp: %s' % e, + logger_func=LOG.debug) + + try: + report_diagnostic_event( + 'Using fallback lease to report failure to Azure') + report_failure_to_fabric( + fallback_lease_file=self.dhclient_lease_file, + description=description) + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using fallback lease: %s' % e, + logger_func=LOG.debug) + + return False + def _report_ready(self, lease: dict) -> bool: """Tells the fabric provisioning has completed. diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 4071a50e..951c7a10 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -17,6 +17,7 @@ from cloudinit import stages from cloudinit import temp_utils from contextlib import contextmanager from xml.etree import ElementTree +from xml.sax.saxutils import escape from cloudinit import subp from cloudinit import url_helper @@ -50,6 +51,11 @@ azure_ds_reporter = events.ReportEventStack( description="initialize reporter for azure ds", reporting_enabled=True) +DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE = ( + 'The VM encountered an error during deployment. ' + 'Please visit https://aka.ms/linuxprovisioningerror ' + 'for more information on remediation.') + def azure_ds_telemetry_reporter(func): def impl(*args, **kwargs): @@ -379,12 +385,20 @@ class OpenSSLManager: def __init__(self): self.tmpdir = temp_utils.mkdtemp() - self.certificate = None + self._certificate = None self.generate_certificate() def clean_up(self): util.del_dir(self.tmpdir) + @property + def certificate(self): + return self._certificate + + @certificate.setter + def certificate(self, value): + self._certificate = value + @azure_ds_telemetry_reporter def generate_certificate(self): LOG.debug('Generating certificate for communication with fabric...') @@ -507,6 +521,10 @@ class GoalStateHealthReporter: ''') PROVISIONING_SUCCESS_STATUS = 'Ready' + PROVISIONING_NOT_READY_STATUS = 'NotReady' + PROVISIONING_FAILURE_SUBSTATUS = 'ProvisioningFailed' + + HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512 def __init__( self, goal_state: GoalState, @@ -545,19 +563,39 @@ class GoalStateHealthReporter: LOG.info('Reported ready to Azure fabric.') + @azure_ds_telemetry_reporter + def send_failure_signal(self, description: str) -> None: + document = self.build_report( + incarnation=self._goal_state.incarnation, + container_id=self._goal_state.container_id, + instance_id=self._goal_state.instance_id, + status=self.PROVISIONING_NOT_READY_STATUS, + substatus=self.PROVISIONING_FAILURE_SUBSTATUS, + description=description) + try: + self._post_health_report(document=document) + except Exception as e: + msg = "exception while reporting failure: %s" % e + report_diagnostic_event(msg, logger_func=LOG.error) + raise + + LOG.warning('Reported failure to Azure fabric.') + def build_report( self, incarnation: str, container_id: str, instance_id: str, status: str, substatus=None, description=None) -> str: health_detail = '' if substatus is not None: health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( - health_substatus=substatus, health_description=description) + health_substatus=escape(substatus), + health_description=escape( + description[:self.HEALTH_REPORT_DESCRIPTION_TRIM_LEN])) health_report = self.HEALTH_REPORT_XML_TEMPLATE.format( - incarnation=incarnation, - container_id=container_id, - instance_id=instance_id, - health_status=status, + incarnation=escape(str(incarnation)), + container_id=escape(container_id), + instance_id=escape(instance_id), + health_status=escape(status), health_detail_subsection=health_detail) return health_report @@ -797,6 +835,20 @@ class WALinuxAgentShim: health_reporter.send_ready_signal() return {'public-keys': ssh_keys} + @azure_ds_telemetry_reporter + def register_with_azure_and_report_failure(self, description: str) -> None: + """Gets the VM's GoalState from Azure, uses the GoalState information + to report failure/send provisioning failure signal to Azure. + + @param: user visible error description of provisioning failure. + """ + if self.azure_endpoint_client is None: + self.azure_endpoint_client = AzureEndpointHttpClient(None) + goal_state = self._fetch_goal_state_from_azure(need_certificate=False) + health_reporter = GoalStateHealthReporter( + goal_state, self.azure_endpoint_client, self.endpoint) + health_reporter.send_failure_signal(description=description) + @azure_ds_telemetry_reporter def _fetch_goal_state_from_azure( self, @@ -804,6 +856,7 @@ class WALinuxAgentShim: """Fetches the GoalState XML from the Azure endpoint, parses the XML, and returns a GoalState object. + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML """ unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() @@ -844,6 +897,7 @@ class WALinuxAgentShim: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_goal_state_xml: GoalState XML string + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML """ try: @@ -942,6 +996,20 @@ def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, shim.clean_up() +@azure_ds_telemetry_reporter +def report_failure_to_fabric(fallback_lease_file=None, dhcp_opts=None, + description=None): + shim = WALinuxAgentShim(fallback_lease_file=fallback_lease_file, + dhcp_options=dhcp_opts) + if not description: + description = DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE + try: + shim.register_with_azure_and_report_failure( + description=description) + finally: + shim.clean_up() + + def dhcp_log_cb(out, err): report_diagnostic_event( "dhclient output stream: %s" % out, logger_func=LOG.debug) diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 433fbc66..d9752ab7 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -461,6 +461,8 @@ class TestGetMetadataFromIMDS(HttprettyTestCase): class TestAzureDataSource(CiTestCase): + with_logs = True + def setUp(self): super(TestAzureDataSource, self).setUp() self.tmp = self.tmp_dir() @@ -549,9 +551,12 @@ scbus-1 on xpt0 bus 0 dsaz.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d - self.get_metadata_from_fabric = mock.MagicMock(return_value={ - 'public-keys': [], - }) + self.m_is_platform_viable = mock.MagicMock(autospec=True) + self.m_get_metadata_from_fabric = mock.MagicMock( + return_value={'public-keys': []}) + self.m_report_failure_to_fabric = mock.MagicMock(autospec=True) + self.m_ephemeral_dhcpv4 = mock.MagicMock() + self.m_ephemeral_dhcpv4_with_reporting = mock.MagicMock() self.instance_id = 'D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8' @@ -568,7 +573,17 @@ scbus-1 on xpt0 bus 0 (dsaz, 'perform_hostname_bounce', mock.MagicMock()), (dsaz, 'get_hostname', mock.MagicMock()), (dsaz, 'set_hostname', mock.MagicMock()), - (dsaz, 'get_metadata_from_fabric', self.get_metadata_from_fabric), + (dsaz, '_is_platform_viable', + self.m_is_platform_viable), + (dsaz, 'get_metadata_from_fabric', + self.m_get_metadata_from_fabric), + (dsaz, 'report_failure_to_fabric', + self.m_report_failure_to_fabric), + (dsaz, 'EphemeralDHCPv4', self.m_ephemeral_dhcpv4), + (dsaz, 'EphemeralDHCPv4WithReporting', + self.m_ephemeral_dhcpv4_with_reporting), + (dsaz, 'get_boot_telemetry', mock.MagicMock()), + (dsaz, 'get_system_info', mock.MagicMock()), (dsaz.subp, 'which', lambda x: True), (dsaz.dmi, 'read_dmi_data', mock.MagicMock( side_effect=_dmi_mocks)), @@ -632,15 +647,87 @@ scbus-1 on xpt0 bus 0 dev = ds.get_resource_disk_on_freebsd(1) self.assertEqual("da1", dev) - @mock.patch(MOCKPATH + '_is_platform_viable') - def test_call_is_platform_viable_seed(self, m_is_platform_viable): + def test_not_is_platform_viable_seed_should_return_no_datasource(self): """Check seed_dir using _is_platform_viable and return False.""" # Return a non-matching asset tag value - m_is_platform_viable.return_value = False - dsrc = dsaz.DataSourceAzure( - {}, distro=mock.Mock(), paths=self.paths) - self.assertFalse(dsrc.get_data()) - m_is_platform_viable.assert_called_with(dsrc.seed_dir) + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = False + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_report_failure') as m_report_failure: + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + # Assert that for non viable platforms, + # there is no communication with the Azure datasource. + self.assertEqual( + 0, + m_crawl_metadata.call_count) + self.assertEqual( + 0, + m_report_failure.call_count) + + def test_platform_viable_but_no_devs_should_return_no_datasource(self): + """For platforms where the Azure platform is viable + (which is indicated by the matching asset tag), + the absence of any devs at all (devs == candidate sources + for crawling Azure datasource) is NOT expected. + Report failure to Azure as this is an unexpected fatal error. + """ + data = {} + dsrc = self._get_ds(data) + with mock.patch.object(dsrc, '_report_failure') as m_report_failure: + self.m_is_platform_viable.return_value = True + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + self.assertEqual( + 1, + m_report_failure.call_count) + + def test_crawl_metadata_exception_returns_no_datasource(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertEqual( + 1, + m_crawl_metadata.call_count) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + + def test_crawl_metadata_exception_should_report_failure_with_msg(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_report_failure') as m_report_failure: + m_crawl_metadata.side_effect = Exception + dsrc.get_data() + self.assertEqual( + 1, + m_crawl_metadata.call_count) + m_report_failure.assert_called_once_with( + description=dsaz.DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_crawl_metadata_exc_should_log_could_not_crawl_msg(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + dsrc.get_data() + self.assertEqual( + 1, + m_crawl_metadata.call_count) + self.assertIn( + "Could not crawl Azure metadata", + self.logs.getvalue()) def test_basic_seed_dir(self): odata = {'HostName': "myhost", 'UserName': "myuser"} @@ -761,7 +848,7 @@ scbus-1 on xpt0 bus 0 'cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready') @mock.patch('cloudinit.sources.DataSourceAzure.DataSourceAzure._poll_imds') def test_crawl_metadata_on_reprovision_reports_ready( - self, poll_imds_func, report_ready_func, m_write, m_dhcp + self, poll_imds_func, m_report_ready, m_write, m_dhcp ): """If reprovisioning, report ready at the end""" ovfenv = construct_valid_ovf_env( @@ -775,18 +862,16 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) poll_imds_func.return_value = ovfenv dsrc.crawl_metadata() - self.assertEqual(1, report_ready_func.call_count) + self.assertEqual(1, m_report_ready.call_count) @mock.patch('cloudinit.sources.DataSourceAzure.util.write_file') @mock.patch('cloudinit.sources.helpers.netlink.' 'wait_for_media_disconnect_connect') @mock.patch( 'cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready') - @mock.patch('cloudinit.net.dhcp.EphemeralIPv4Network') - @mock.patch('cloudinit.net.dhcp.maybe_perform_dhcp_discovery') @mock.patch('cloudinit.sources.DataSourceAzure.readurl') def test_crawl_metadata_on_reprovision_reports_ready_using_lease( - self, m_readurl, m_dhcp, m_net, report_ready_func, + self, m_readurl, m_report_ready, m_media_switch, m_write ): """If reprovisioning, report ready using the obtained lease""" @@ -800,20 +885,30 @@ scbus-1 on xpt0 bus 0 } dsrc = self._get_ds(data) - lease = { - 'interface': 'eth9', 'fixed-address': '192.168.2.9', - 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', - 'unknown-245': '624c3620'} - m_dhcp.return_value = [lease] - m_media_switch.return_value = None + with mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: - reprovision_ovfenv = construct_valid_ovf_env() - m_readurl.return_value = url_helper.StringResponse( - reprovision_ovfenv.encode('utf-8')) + # For this mock, net should not be up, + # so that cached ephemeral won't be used. + # This is so that a NEW ephemeral dhcp lease will be discovered + # and used instead. + m_dsrc_distro_networking_is_up.return_value = False - dsrc.crawl_metadata() - self.assertEqual(2, report_ready_func.call_count) - report_ready_func.assert_called_with(lease=lease) + lease = { + 'interface': 'eth9', 'fixed-address': '192.168.2.9', + 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', + 'unknown-245': '624c3620'} + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.return_value = lease + m_media_switch.return_value = None + + reprovision_ovfenv = construct_valid_ovf_env() + m_readurl.return_value = url_helper.StringResponse( + reprovision_ovfenv.encode('utf-8')) + + dsrc.crawl_metadata() + self.assertEqual(2, m_report_ready.call_count) + m_report_ready.assert_called_with(lease=lease) def test_waagent_d_has_0700_perms(self): # we expect /var/lib/waagent to be created 0700 @@ -971,7 +1066,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertTrue('default_user' in dsrc.cfg['system_info']) + self.assertIn('default_user', dsrc.cfg['system_info']) defuser = dsrc.cfg['system_info']['default_user'] # default user should be updated username and should not be locked. @@ -993,7 +1088,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertTrue('default_user' in dsrc.cfg['system_info']) + self.assertIn('default_user', dsrc.cfg['system_info']) defuser = dsrc.cfg['system_info']['default_user'] # default user should be updated username and should not be locked. @@ -1021,14 +1116,6 @@ scbus-1 on xpt0 bus 0 self.assertTrue(ret) self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8')) - def test_no_datasource_expected(self): - # no source should be found if no seed_dir and no devs - data = {} - dsrc = self._get_ds({}) - ret = dsrc.get_data() - self.assertFalse(ret) - self.assertFalse('agent_invoked' in data) - def test_cfg_has_pubkeys_fingerprint(self): odata = {'HostName': "myhost", 'UserName': "myuser"} mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}] @@ -1171,21 +1258,168 @@ scbus-1 on xpt0 bus 0 self): dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.side_effect = Exception + self.m_get_metadata_from_fabric.side_effect = Exception self.assertFalse(dsrc._report_ready(lease=mock.MagicMock())) + def test_dsaz_report_failure_returns_true_when_report_succeeds(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) + self.assertEqual( + 1, + self.m_report_failure_to_fabric.call_count) + + def test_dsaz_report_failure_returns_false_and_does_not_propagate_exc( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_ephemeral_dhcp_ctx') \ + as m_ephemeral_dhcp_ctx, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # setup mocks to allow using cached ephemeral dhcp lease + m_dsrc_distro_networking_is_up.return_value = True + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + m_ephemeral_dhcp_ctx.lease = test_lease + + # We expect 3 calls to report_failure_to_fabric, + # because we try 3 different methods of calling report failure. + # The different methods are attempted in the following order: + # 1. Using cached ephemeral dhcp context to report failure to Azure + # 2. Using new ephemeral dhcp to report failure to Azure + # 3. Using fallback lease to report failure to Azure + self.m_report_failure_to_fabric.side_effect = Exception + self.assertFalse(dsrc._report_failure()) + self.assertEqual( + 3, + self.m_report_failure_to_fabric.call_count) + + def test_dsaz_report_failure_description_msg(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + test_msg = 'Test report failure description message' + self.assertTrue(dsrc._report_failure(description=test_msg)) + self.m_report_failure_to_fabric.assert_called_once_with( + dhcp_opts=mock.ANY, description=test_msg) + + def test_dsaz_report_failure_no_description_msg(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) # no description msg + self.m_report_failure_to_fabric.assert_called_once_with( + dhcp_opts=mock.ANY, description=None) + + def test_dsaz_report_failure_uses_cached_ephemeral_dhcp_ctx_lease(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_ephemeral_dhcp_ctx') \ + as m_ephemeral_dhcp_ctx, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # setup mocks to allow using cached ephemeral dhcp lease + m_dsrc_distro_networking_is_up.return_value = True + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + m_ephemeral_dhcp_ctx.lease = test_lease + + self.assertTrue(dsrc._report_failure()) + + # ensure called with cached ephemeral dhcp lease option 245 + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, dhcp_opts=test_lease_dhcp_option_245) + + # ensure cached ephemeral is cleaned + self.assertEqual( + 1, + m_ephemeral_dhcp_ctx.clean_network.call_count) + + def test_dsaz_report_failure_no_net_uses_new_ephemeral_dhcp_lease(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # net is not up and cannot use cached ephemeral dhcp + m_dsrc_distro_networking_is_up.return_value = False + # setup ephemeral dhcp lease discovery mock + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.return_value = test_lease + + self.assertTrue(dsrc._report_failure()) + + # ensure called with the newly discovered + # ephemeral dhcp lease option 245 + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, dhcp_opts=test_lease_dhcp_option_245) + + def test_dsaz_report_failure_no_net_and_no_dhcp_uses_fallback_lease( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # net is not up and cannot use cached ephemeral dhcp + m_dsrc_distro_networking_is_up.return_value = False + # ephemeral dhcp discovery failure, + # so cannot use a new ephemeral dhcp + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) + + # ensure called with fallback lease + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, + fallback_lease_file=dsrc.dhclient_lease_file) + def test_exception_fetching_fabric_data_doesnt_propagate(self): """Errors communicating with fabric should warn, but return True.""" dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.side_effect = Exception + self.m_get_metadata_from_fabric.side_effect = Exception ret = self._get_and_setup(dsrc) self.assertTrue(ret) def test_fabric_data_included_in_metadata(self): dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.return_value = {'test': 'value'} + self.m_get_metadata_from_fabric.return_value = {'test': 'value'} ret = self._get_and_setup(dsrc) self.assertTrue(ret) self.assertEqual('value', dsrc.metadata['test']) @@ -2053,7 +2287,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): @mock.patch('time.sleep', mock.MagicMock()) @mock.patch(MOCKPATH + 'EphemeralDHCPv4') - def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, report_ready_func, + def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, m_report_ready, m_request, m_media_switch, m_dhcp, m_net): """The poll_imds will retry DHCP on IMDS timeout.""" @@ -2088,8 +2322,8 @@ class TestPreprovisioningPollIMDS(CiTestCase): dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): dsa._poll_imds() - self.assertEqual(report_ready_func.call_count, 1) - report_ready_func.assert_called_with(lease=lease) + self.assertEqual(m_report_ready.call_count, 1) + m_report_ready.assert_called_with(lease=lease) self.assertEqual(3, m_dhcpv4.call_count, 'Expected 3 DHCP calls') self.assertEqual(4, self.tries, 'Expected 4 total reads from IMDS') diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 6e004e34..adf68857 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -5,6 +5,7 @@ import re import unittest from textwrap import dedent from xml.etree import ElementTree +from xml.sax.saxutils import escape, unescape from cloudinit.sources.helpers import azure as azure_helper from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -70,6 +71,15 @@ HEALTH_REPORT_XML_TEMPLATE = '''\ ''' +HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = dedent('''\ +
+ {health_substatus} + {health_description} +
+ ''') + +HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512 + class SentinelException(Exception): pass @@ -461,17 +471,24 @@ class TestOpenSSLManagerActions(CiTestCase): class TestGoalStateHealthReporter(CiTestCase): + maxDiff = None + default_parameters = { 'incarnation': 1634, 'container_id': 'MyContainerId', 'instance_id': 'MyInstanceId' } - test_endpoint = 'TestEndpoint' - test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) + test_azure_endpoint = 'TestEndpoint' + test_health_report_url = 'http://{0}/machine?comp=health'.format( + test_azure_endpoint) test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} provisioning_success_status = 'Ready' + provisioning_not_ready_status = 'NotReady' + provisioning_failure_substatus = 'ProvisioningFailed' + provisioning_failure_err_description = ( + 'Test error message containing provisioning failure details') def setUp(self): super(TestGoalStateHealthReporter, self).setUp() @@ -496,17 +513,40 @@ class TestGoalStateHealthReporter(CiTestCase): self.GoalState.return_value.incarnation = \ self.default_parameters['incarnation'] + def _text_from_xpath_in_xroot(self, xroot, xpath): + element = xroot.find(xpath) + if element is not None: + return element.text + return None + def _get_formatted_health_report_xml_string(self, **kwargs): return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + def _get_formatted_health_detail_subsection_xml_string(self, **kwargs): + return HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format(**kwargs) + def _get_report_ready_health_document(self): return self._get_formatted_health_report_xml_string( - incarnation=self.default_parameters['incarnation'], - container_id=self.default_parameters['container_id'], - instance_id=self.default_parameters['instance_id'], - health_status=self.provisioning_success_status, + incarnation=escape(str(self.default_parameters['incarnation'])), + container_id=escape(self.default_parameters['container_id']), + instance_id=escape(self.default_parameters['instance_id']), + health_status=escape(self.provisioning_success_status), health_detail_subsection='') + def _get_report_failure_health_document(self): + health_detail_subsection = \ + self._get_formatted_health_detail_subsection_xml_string( + health_substatus=escape(self.provisioning_failure_substatus), + health_description=escape( + self.provisioning_failure_err_description)) + + return self._get_formatted_health_report_xml_string( + incarnation=escape(str(self.default_parameters['incarnation'])), + container_id=escape(self.default_parameters['container_id']), + instance_id=escape(self.default_parameters['instance_id']), + health_status=escape(self.provisioning_not_ready_status), + health_detail_subsection=health_detail_subsection) + def test_send_ready_signal_sends_post_request(self): with mock.patch.object( azure_helper.GoalStateHealthReporter, @@ -514,55 +554,130 @@ class TestGoalStateHealthReporter(CiTestCase): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), - client, self.test_endpoint) + client, self.test_azure_endpoint) reporter.send_ready_signal() self.assertEqual(1, self.post.call_count) self.assertEqual( mock.call( - self.test_url, + self.test_health_report_url, + data=m_build_report.return_value, + extra_headers=self.test_default_headers), + self.post.call_args) + + def test_send_failure_signal_sends_post_request(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, + 'build_report') as m_build_report: + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + client, self.test_azure_endpoint) + reporter.send_failure_signal( + description=self.provisioning_failure_err_description) + + self.assertEqual(1, self.post.call_count) + self.assertEqual( + mock.call( + self.test_health_report_url, data=m_build_report.return_value, extra_headers=self.test_default_headers), self.post.call_args) - def test_build_report_for_health_document(self): + def test_build_report_for_ready_signal_health_document(self): health_document = self._get_report_ready_health_document() reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), azure_helper.AzureEndpointHttpClient(mock.MagicMock()), - self.test_endpoint) + self.test_azure_endpoint) generated_health_document = reporter.build_report( incarnation=self.default_parameters['incarnation'], container_id=self.default_parameters['container_id'], instance_id=self.default_parameters['instance_id'], status=self.provisioning_success_status) + self.assertEqual(health_document, generated_health_document) - self.assertIn( - '{}'.format( - str(self.default_parameters['incarnation'])), - generated_health_document) - self.assertIn( - ''.join([ - '', - self.default_parameters['container_id'], - '']), - generated_health_document) - self.assertIn( - ''.join([ - '', - self.default_parameters['instance_id'], - '']), - generated_health_document) - self.assertIn( - ''.join([ - '', - self.provisioning_success_status, - '']), - generated_health_document + + generated_xroot = ElementTree.fromstring(generated_health_document) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './GoalStateIncarnation'), + str(self.default_parameters['incarnation'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './Container/ContainerId'), + str(self.default_parameters['container_id'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/InstanceId'), + str(self.default_parameters['instance_id'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/State'), + escape(self.provisioning_success_status)) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details')) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/SubStatus')) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') ) - self.assertNotIn('
', generated_health_document) - self.assertNotIn('', generated_health_document) - self.assertNotIn('', generated_health_document) + + def test_build_report_for_failure_signal_health_document(self): + health_document = self._get_report_failure_health_document() + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=self.provisioning_failure_err_description) + + self.assertEqual(health_document, generated_health_document) + + generated_xroot = ElementTree.fromstring(generated_health_document) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './GoalStateIncarnation'), + str(self.default_parameters['incarnation'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './Container/ContainerId'), + self.default_parameters['container_id']) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/InstanceId'), + self.default_parameters['instance_id']) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/State'), + escape(self.provisioning_not_ready_status)) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/' + 'SubStatus'), + escape(self.provisioning_failure_substatus)) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/' + 'Description'), + escape(self.provisioning_failure_err_description)) def test_send_ready_signal_calls_build_report(self): with mock.patch.object( @@ -571,7 +686,7 @@ class TestGoalStateHealthReporter(CiTestCase): reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), azure_helper.AzureEndpointHttpClient(mock.MagicMock()), - self.test_endpoint) + self.test_azure_endpoint) reporter.send_ready_signal() self.assertEqual(1, m_build_report.call_count) @@ -583,6 +698,131 @@ class TestGoalStateHealthReporter(CiTestCase): status=self.provisioning_success_status), m_build_report.call_args) + def test_send_failure_signal_calls_build_report(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, 'build_report' + ) as m_build_report: + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + reporter.send_failure_signal( + description=self.provisioning_failure_err_description) + + self.assertEqual(1, m_build_report.call_count) + self.assertEqual( + mock.call( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=self.provisioning_failure_err_description), + m_build_report.call_args) + + def test_build_report_escapes_chars(self): + incarnation = 'jd8\'9*&^<\'A>' + instance_id = 'Opo>>>jas\'&d;[p&fp\"a<&aa\'sd!@&!)((*<&>' + health_substatus = '&as\"d<d<\'^@!5&6<7' + health_description = '&&&>!#$\"&&><>&\"sd<67<]>>' + + health_detail_subsection = \ + self._get_formatted_health_detail_subsection_xml_string( + health_substatus=escape(health_substatus), + health_description=escape(health_description)) + health_document = self._get_formatted_health_report_xml_string( + incarnation=escape(incarnation), + container_id=escape(container_id), + instance_id=escape(instance_id), + health_status=escape(health_status), + health_detail_subsection=health_detail_subsection) + + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + generated_health_document = reporter.build_report( + incarnation=incarnation, + container_id=container_id, + instance_id=instance_id, + status=health_status, + substatus=health_substatus, + description=health_description) + + self.assertEqual(health_document, generated_health_document) + + def test_build_report_conforms_to_length_limits(self): + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + long_err_msg = 'a9&ea8>>>e as1< d\"q2*&(^%\'a=5<' * 100 + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=long_err_msg) + + generated_xroot = ElementTree.fromstring(generated_health_document) + generated_health_report_description = self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') + self.assertEqual( + len(unescape(generated_health_report_description)), + HEALTH_REPORT_DESCRIPTION_TRIM_LEN) + + def test_trim_description_then_escape_conforms_to_len_limits_worst_case( + self): + """When unescaped characters are XML-escaped, the length increases. + Char Escape String + < < + > > + " " + ' ' + & & + + We (step 1) trim the health report XML's description field, + and then (step 2) XML-escape the health report XML's description field. + + The health report XML's description field limit within cloud-init + is HEALTH_REPORT_DESCRIPTION_TRIM_LEN. + + The Azure platform's limit on the health report XML's description field + is 4096 chars. + + For worst-case chars, there is a 5x blowup in length + when the chars are XML-escaped. + ' and " when XML-escaped have a 5x blowup. + + Ensure that (1) trimming and then (2) XML-escaping does not blow past + the Azure platform's limit for health report XML's description field + (4096 chars). + """ + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + long_err_msg = '\'\"' * 10000 + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=long_err_msg) + + generated_xroot = ElementTree.fromstring(generated_health_document) + generated_health_report_description = self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') + # The escaped description string should be less than + # the Azure platform limit for the escaped description string. + self.assertLessEqual(len(generated_health_report_description), 4096) + class TestWALinuxAgentShim(CiTestCase): @@ -598,7 +838,7 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState = patches.enter_context( mock.patch.object(azure_helper, 'GoalState')) self.OpenSSLManager = patches.enter_context( - mock.patch.object(azure_helper, 'OpenSSLManager')) + mock.patch.object(azure_helper, 'OpenSSLManager', autospec=True)) patches.enter_context( mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) @@ -609,24 +849,47 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.return_value.container_id = self.test_container_id self.GoalState.return_value.instance_id = self.test_instance_id - def test_http_client_does_not_use_certificate(self): + def test_http_client_does_not_use_certificate_for_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(None)], self.AzureEndpointHttpClient.call_args_list) + def test_http_client_does_not_use_certificate_for_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + self.assertEqual( + [mock.call(None)], + self.AzureEndpointHttpClient.call_args_list) + def test_correct_url_used_for_goalstate_during_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = wa_shim() shim.register_with_azure_and_fetch_data() - get = self.AzureEndpointHttpClient.return_value.get + m_get = self.AzureEndpointHttpClient.return_value.get + self.assertEqual( + [mock.call('http://test_endpoint/machine/?comp=goalstate')], + m_get.call_args_list) + self.assertEqual( + [mock.call( + m_get.return_value.contents, + self.AzureEndpointHttpClient.return_value, + False + )], + self.GoalState.call_args_list) + + def test_correct_url_used_for_goalstate_during_report_failure(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + m_get = self.AzureEndpointHttpClient.return_value.get self.assertEqual( [mock.call('http://test_endpoint/machine/?comp=goalstate')], - get.call_args_list) + m_get.call_args_list) self.assertEqual( [mock.call( - get.return_value.contents, + m_get.return_value.contents, self.AzureEndpointHttpClient.return_value, False )], @@ -670,6 +933,16 @@ class TestWALinuxAgentShim(CiTestCase): self.AzureEndpointHttpClient.return_value.post .call_args_list) + def test_correct_url_used_for_report_failure(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + expected_url = 'http://test_endpoint/machine?comp=health' + self.assertEqual( + [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], + self.AzureEndpointHttpClient.return_value.post + .call_args_list) + def test_goal_state_values_used_for_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() @@ -681,44 +954,128 @@ class TestWALinuxAgentShim(CiTestCase): self.assertIn(self.test_container_id, posted_document) self.assertIn(self.test_instance_id, posted_document) - def test_xml_elems_in_report_ready(self): + def test_goal_state_values_used_for_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data'] + ) + self.assertIn(self.test_incarnation, posted_document) + self.assertIn(self.test_container_id, posted_document) + self.assertIn(self.test_instance_id, posted_document) + + def test_xml_elems_in_report_ready_post(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() health_document = HEALTH_REPORT_XML_TEMPLATE.format( - incarnation=self.test_incarnation, - container_id=self.test_container_id, - instance_id=self.test_instance_id, - health_status='Ready', + incarnation=escape(self.test_incarnation), + container_id=escape(self.test_container_id), + instance_id=escape(self.test_instance_id), + health_status=escape('Ready'), health_detail_subsection='') posted_document = ( self.AzureEndpointHttpClient.return_value.post .call_args[1]['data']) self.assertEqual(health_document, posted_document) + def test_xml_elems_in_report_failure_post(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + health_document = HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=escape(self.test_incarnation), + container_id=escape(self.test_container_id), + instance_id=escape(self.test_instance_id), + health_status=escape('NotReady'), + health_detail_subsection=HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE + .format( + health_substatus=escape('ProvisioningFailed'), + health_description=escape('TestDesc'))) + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data']) + self.assertEqual(health_document, posted_document) + + @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True) + def test_register_with_azure_and_fetch_data_calls_send_ready_signal( + self, m_goal_state_health_reporter): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + self.assertEqual( + 1, + m_goal_state_health_reporter.return_value.send_ready_signal + .call_count) + + @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True) + def test_register_with_azure_and_report_failure_calls_send_failure_signal( + self, m_goal_state_health_reporter): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + m_goal_state_health_reporter.return_value.send_failure_signal \ + .assert_called_once_with(description='TestDesc') + + def test_register_with_azure_and_report_failure_does_not_need_certificates( + self): + shim = wa_shim() + with mock.patch.object( + shim, '_fetch_goal_state_from_azure', autospec=True + ) as m_fetch_goal_state_from_azure: + shim.register_with_azure_and_report_failure(description='TestDesc') + m_fetch_goal_state_from_azure.assert_called_once_with( + need_certificate=False) + def test_clean_up_can_be_called_at_any_time(self): shim = wa_shim() shim.clean_up() + def test_openssl_manager_not_instantiated_by_shim_report_status(self): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + shim.register_with_azure_and_report_failure(description='TestDesc') + shim.clean_up() + self.OpenSSLManager.assert_not_called() + def test_clean_up_after_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() shim.clean_up() - self.assertEqual( - 0, self.OpenSSLManager.return_value.clean_up.call_count) + self.OpenSSLManager.return_value.clean_up.assert_not_called() + + def test_clean_up_after_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + shim.clean_up() + self.OpenSSLManager.return_value.clean_up.assert_not_called() def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): self.AzureEndpointHttpClient.return_value.get \ - .side_effect = (SentinelException) + .side_effect = SentinelException shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_failure_raises_exc_on_get_exc(self): + self.AzureEndpointHttpClient.return_value.get \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): self.GoalState.side_effect = SentinelException shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_failure_raises_exc_on_parse_exc( + self): + self.GoalState.side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + def test_failure_to_send_report_ready_health_doc_bubbles_up(self): self.AzureEndpointHttpClient.return_value.post \ .side_effect = SentinelException @@ -726,56 +1083,132 @@ class TestWALinuxAgentShim(CiTestCase): self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_failure_to_send_report_failure_health_doc_bubbles_up(self): + self.AzureEndpointHttpClient.return_value.post \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_data_from_shim_returned(self, shim): + def setUp(self): + super(TestGetMetadataGoalStateXMLAndReportReadyToFabric, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_shim = patches.enter_context( + mock.patch.object(azure_helper, 'WALinuxAgentShim')) + + def test_data_from_shim_returned(self): ret = azure_helper.get_metadata_from_fabric() self.assertEqual( - shim.return_value.register_with_azure_and_fetch_data.return_value, + self.m_shim.return_value.register_with_azure_and_fetch_data + .return_value, ret) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_success_calls_clean_up(self, shim): + def test_success_calls_clean_up(self): azure_helper.get_metadata_from_fabric() - self.assertEqual(1, shim.return_value.clean_up.call_count) + self.assertEqual(1, self.m_shim.return_value.clean_up.call_count) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_failure_in_registration_propagates_exc_and_calls_clean_up( - self, shim): - shim.return_value.register_with_azure_and_fetch_data.side_effect = ( - SentinelException) + self): + self.m_shim.return_value.register_with_azure_and_fetch_data \ + .side_effect = SentinelException self.assertRaises(SentinelException, azure_helper.get_metadata_from_fabric) - self.assertEqual(1, shim.return_value.clean_up.call_count) + self.assertEqual(1, self.m_shim.return_value.clean_up.call_count) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_calls_shim_register_with_azure_and_fetch_data(self, shim): + def test_calls_shim_register_with_azure_and_fetch_data(self): m_pubkey_info = mock.MagicMock() azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info) self.assertEqual( 1, - shim.return_value + self.m_shim.return_value .register_with_azure_and_fetch_data.call_count) self.assertEqual( mock.call(pubkey_info=m_pubkey_info), - shim.return_value + self.m_shim.return_value .register_with_azure_and_fetch_data.call_args) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_instantiates_shim_with_kwargs(self, shim): + def test_instantiates_shim_with_kwargs(self): m_fallback_lease_file = mock.MagicMock() m_dhcp_options = mock.MagicMock() azure_helper.get_metadata_from_fabric( fallback_lease_file=m_fallback_lease_file, dhcp_opts=m_dhcp_options) - self.assertEqual(1, shim.call_count) + self.assertEqual(1, self.m_shim.call_count) self.assertEqual( mock.call( fallback_lease_file=m_fallback_lease_file, dhcp_options=m_dhcp_options), - shim.call_args) + self.m_shim.call_args) + + +class TestGetMetadataGoalStateXMLAndReportFailureToFabric(CiTestCase): + + def setUp(self): + super( + TestGetMetadataGoalStateXMLAndReportFailureToFabric, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_shim = patches.enter_context( + mock.patch.object(azure_helper, 'WALinuxAgentShim')) + + def test_success_calls_clean_up(self): + azure_helper.report_failure_to_fabric() + self.assertEqual( + 1, + self.m_shim.return_value.clean_up.call_count) + + def test_failure_in_shim_report_failure_propagates_exc_and_calls_clean_up( + self): + self.m_shim.return_value.register_with_azure_and_report_failure \ + .side_effect = SentinelException + self.assertRaises(SentinelException, + azure_helper.report_failure_to_fabric) + self.assertEqual( + 1, + self.m_shim.return_value.clean_up.call_count) + + def test_report_failure_to_fabric_with_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric(description='TestDesc') + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with(description='TestDesc') + + def test_report_failure_to_fabric_with_no_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric() + # default err message description should be shown to the user + # if no description is passed in + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with( + description=azure_helper + .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_report_failure_to_fabric_empty_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric(description='') + # default err message description should be shown to the user + # if an empty description is passed in + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with( + description=azure_helper + .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_instantiates_shim_with_kwargs(self): + m_fallback_lease_file = mock.MagicMock() + m_dhcp_options = mock.MagicMock() + azure_helper.report_failure_to_fabric( + fallback_lease_file=m_fallback_lease_file, + dhcp_opts=m_dhcp_options) + self.m_shim.assert_called_once_with( + fallback_lease_file=m_fallback_lease_file, + dhcp_options=m_dhcp_options) class TestExtractIpAddressFromNetworkd(CiTestCase): -- cgit v1.2.3 From 6df0230b1201d6bed8661b19d8f3758797635377 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Wed, 18 Nov 2020 10:02:56 -0800 Subject: Azure helper: Increase Azure Endpoint HTTP retries (#619) Increase Azure Endpoint HTTP retries to handle occasional platform network blips. Introduce a common method http_with_retries in the azure.py helper, which will serve as the common HTTP request handler for all HTTP requests with the Azure endpoint. This method has builtin retries and reporting diagnostics logic. --- cloudinit/sources/helpers/azure.py | 55 ++++- .../unittests/test_datasource/test_azure_helper.py | 227 +++++++++++++++++---- 2 files changed, 241 insertions(+), 41 deletions(-) (limited to 'tests/unittests/test_datasource/test_azure_helper.py') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 951c7a10..2b3303c7 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -284,6 +284,54 @@ def _get_dhcp_endpoint_option_name(): return azure_endpoint +@azure_ds_telemetry_reporter +def http_with_retries(url, **kwargs) -> str: + """Wrapper around url_helper.readurl() with custom telemetry logging + that url_helper.readurl() does not provide. + """ + exc = None + + max_readurl_attempts = 240 + default_readurl_timeout = 5 + periodic_logging_attempts = 12 + + if 'timeout' not in kwargs: + kwargs['timeout'] = default_readurl_timeout + + # remove kwargs that cause url_helper.readurl to retry, + # since we are already implementing our own retry logic. + if kwargs.pop('retries', None): + LOG.warning( + 'Ignoring retries kwarg passed in for ' + 'communication with Azure endpoint.') + if kwargs.pop('infinite', None): + LOG.warning( + 'Ignoring infinite kwarg passed in for communication ' + 'with Azure endpoint.') + + for attempt in range(1, max_readurl_attempts + 1): + try: + ret = url_helper.readurl(url, **kwargs) + + report_diagnostic_event( + 'Successful HTTP request with Azure endpoint %s after ' + '%d attempts' % (url, attempt), + logger_func=LOG.debug) + + return ret + + except Exception as e: + exc = e + if attempt % periodic_logging_attempts == 0: + report_diagnostic_event( + 'Failed HTTP request with Azure endpoint %s during ' + 'attempt %d with exception: %s' % + (url, attempt, e), + logger_func=LOG.debug) + + raise exc + + class AzureEndpointHttpClient: headers = { @@ -302,16 +350,15 @@ class AzureEndpointHttpClient: if secure: headers = self.headers.copy() headers.update(self.extra_secure_headers) - return url_helper.readurl(url, headers=headers, - timeout=5, retries=10, sec_between=5) + return http_with_retries(url, headers=headers) def post(self, url, data=None, extra_headers=None): headers = self.headers if extra_headers is not None: headers = self.headers.copy() headers.update(extra_headers) - return url_helper.readurl(url, data=data, headers=headers, - timeout=5, retries=10, sec_between=5) + return http_with_retries( + url, data=data, headers=headers) class InvalidGoalStateXMLException(Exception): diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index adf68857..b8899807 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,5 +1,6 @@ # This file is part of cloud-init. See LICENSE file for license information. +import copy import os import re import unittest @@ -291,29 +292,25 @@ class TestAzureEndpointHttpClient(CiTestCase): super(TestAzureEndpointHttpClient, self).setUp() patches = ExitStack() self.addCleanup(patches.close) - - self.readurl = patches.enter_context( - mock.patch.object(azure_helper.url_helper, 'readurl')) - patches.enter_context( - mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + self.m_http_with_retries = patches.enter_context( + mock.patch.object(azure_helper, 'http_with_retries')) def test_non_secure_get(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) url = 'MyTestUrl' response = client.get(url, secure=False) - self.assertEqual(1, self.readurl.call_count) - self.assertEqual(self.readurl.return_value, response) + self.assertEqual(1, self.m_http_with_retries.call_count) + self.assertEqual(self.m_http_with_retries.return_value, response) self.assertEqual( - mock.call(url, headers=self.regular_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, headers=self.regular_headers), + self.m_http_with_retries.call_args) def test_non_secure_get_raises_exception(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException url = 'MyTestUrl' - with self.assertRaises(SentinelException): - client.get(url, secure=False) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises(SentinelException, client.get, url, secure=False) + self.assertEqual(1, self.m_http_with_retries.call_count) def test_secure_get(self): url = 'MyTestUrl' @@ -325,39 +322,37 @@ class TestAzureEndpointHttpClient(CiTestCase): }) client = azure_helper.AzureEndpointHttpClient(m_certificate) response = client.get(url, secure=True) - self.assertEqual(1, self.readurl.call_count) - self.assertEqual(self.readurl.return_value, response) + self.assertEqual(1, self.m_http_with_retries.call_count) + self.assertEqual(self.m_http_with_retries.return_value, response) self.assertEqual( - mock.call(url, headers=expected_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, headers=expected_headers), + self.m_http_with_retries.call_args) def test_secure_get_raises_exception(self): url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException - with self.assertRaises(SentinelException): - client.get(url, secure=True) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises(SentinelException, client.get, url, secure=True) + self.assertEqual(1, self.m_http_with_retries.call_count) def test_post(self): m_data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) response = client.post(url, data=m_data) - self.assertEqual(1, self.readurl.call_count) - self.assertEqual(self.readurl.return_value, response) + self.assertEqual(1, self.m_http_with_retries.call_count) + self.assertEqual(self.m_http_with_retries.return_value, response) self.assertEqual( - mock.call(url, data=m_data, headers=self.regular_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, data=m_data, headers=self.regular_headers), + self.m_http_with_retries.call_args) def test_post_raises_exception(self): m_data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException - with self.assertRaises(SentinelException): - client.post(url, data=m_data) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises(SentinelException, client.post, url, data=m_data) + self.assertEqual(1, self.m_http_with_retries.call_count) def test_post_with_extra_headers(self): url = 'MyTestUrl' @@ -366,21 +361,179 @@ class TestAzureEndpointHttpClient(CiTestCase): client.post(url, extra_headers=extra_headers) expected_headers = self.regular_headers.copy() expected_headers.update(extra_headers) - self.assertEqual(1, self.readurl.call_count) + self.assertEqual(1, self.m_http_with_retries.call_count) self.assertEqual( - mock.call(mock.ANY, data=mock.ANY, headers=expected_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, data=mock.ANY, headers=expected_headers), + self.m_http_with_retries.call_args) def test_post_with_sleep_with_extra_headers_raises_exception(self): m_data = mock.MagicMock() url = 'MyTestUrl' extra_headers = {'test': 'header'} client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException - with self.assertRaises(SentinelException): - client.post( - url, data=m_data, extra_headers=extra_headers) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises( + SentinelException, client.post, + url, data=m_data, extra_headers=extra_headers) + self.assertEqual(1, self.m_http_with_retries.call_count) + + +class TestAzureHelperHttpWithRetries(CiTestCase): + + with_logs = True + + max_readurl_attempts = 240 + default_readurl_timeout = 5 + periodic_logging_attempts = 12 + + def setUp(self): + super(TestAzureHelperHttpWithRetries, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_readurl = patches.enter_context( + mock.patch.object( + azure_helper.url_helper, 'readurl', mock.MagicMock())) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + + def test_http_with_retries(self): + self.m_readurl.return_value = 'TestResp' + self.assertEqual( + azure_helper.http_with_retries('testurl'), + self.m_readurl.return_value) + self.assertEqual(self.m_readurl.call_count, 1) + + def test_http_with_retries_propagates_readurl_exc_and_logs_exc( + self): + self.m_readurl.side_effect = SentinelException + + self.assertRaises( + SentinelException, azure_helper.http_with_retries, 'testurl') + self.assertEqual(self.m_readurl.call_count, self.max_readurl_attempts) + + self.assertIsNotNone( + re.search( + r'Failed HTTP request with Azure endpoint \S* during ' + r'attempt \d+ with exception: \S*', + self.logs.getvalue())) + self.assertIsNone( + re.search( + r'Successful HTTP request with Azure endpoint \S* after ' + r'\d+ attempts', + self.logs.getvalue())) + + def test_http_with_retries_delayed_success_due_to_temporary_readurl_exc( + self): + self.m_readurl.side_effect = \ + [SentinelException] * self.periodic_logging_attempts + \ + ['TestResp'] + self.m_readurl.return_value = 'TestResp' + + response = azure_helper.http_with_retries('testurl') + self.assertEqual( + response, + self.m_readurl.return_value) + self.assertEqual( + self.m_readurl.call_count, + self.periodic_logging_attempts + 1) + + def test_http_with_retries_long_delay_logs_periodic_failure_msg(self): + self.m_readurl.side_effect = \ + [SentinelException] * self.periodic_logging_attempts + \ + ['TestResp'] + self.m_readurl.return_value = 'TestResp' + + azure_helper.http_with_retries('testurl') + + self.assertEqual( + self.m_readurl.call_count, + self.periodic_logging_attempts + 1) + self.assertIsNotNone( + re.search( + r'Failed HTTP request with Azure endpoint \S* during ' + r'attempt \d+ with exception: \S*', + self.logs.getvalue())) + self.assertIsNotNone( + re.search( + r'Successful HTTP request with Azure endpoint \S* after ' + r'\d+ attempts', + self.logs.getvalue())) + + def test_http_with_retries_short_delay_does_not_log_periodic_failure_msg( + self): + self.m_readurl.side_effect = \ + [SentinelException] * \ + (self.periodic_logging_attempts - 1) + \ + ['TestResp'] + self.m_readurl.return_value = 'TestResp' + + azure_helper.http_with_retries('testurl') + self.assertEqual( + self.m_readurl.call_count, + self.periodic_logging_attempts) + + self.assertIsNone( + re.search( + r'Failed HTTP request with Azure endpoint \S* during ' + r'attempt \d+ with exception: \S*', + self.logs.getvalue())) + self.assertIsNotNone( + re.search( + r'Successful HTTP request with Azure endpoint \S* after ' + r'\d+ attempts', + self.logs.getvalue())) + + def test_http_with_retries_calls_url_helper_readurl_with_args_kwargs(self): + testurl = mock.MagicMock() + kwargs = { + 'headers': mock.MagicMock(), + 'data': mock.MagicMock(), + # timeout kwarg should not be modified or deleted if present + 'timeout': mock.MagicMock() + } + azure_helper.http_with_retries(testurl, **kwargs) + self.m_readurl.assert_called_once_with(testurl, **kwargs) + + def test_http_with_retries_adds_timeout_kwarg_if_not_present(self): + testurl = mock.MagicMock() + kwargs = { + 'headers': mock.MagicMock(), + 'data': mock.MagicMock() + } + expected_kwargs = copy.deepcopy(kwargs) + expected_kwargs['timeout'] = self.default_readurl_timeout + + azure_helper.http_with_retries(testurl, **kwargs) + self.m_readurl.assert_called_once_with(testurl, **expected_kwargs) + + def test_http_with_retries_deletes_retries_kwargs_passed_in( + self): + """http_with_retries already implements retry logic, + so url_helper.readurl should not have retries. + http_with_retries should delete kwargs that + cause url_helper.readurl to retry. + """ + testurl = mock.MagicMock() + kwargs = { + 'headers': mock.MagicMock(), + 'data': mock.MagicMock(), + 'timeout': mock.MagicMock(), + 'retries': mock.MagicMock(), + 'infinite': mock.MagicMock() + } + expected_kwargs = copy.deepcopy(kwargs) + expected_kwargs.pop('retries', None) + expected_kwargs.pop('infinite', None) + + azure_helper.http_with_retries(testurl, **kwargs) + self.m_readurl.assert_called_once_with(testurl, **expected_kwargs) + self.assertIn( + 'retries kwarg passed in for communication with Azure endpoint.', + self.logs.getvalue()) + self.assertIn( + 'infinite kwarg passed in for communication with Azure endpoint.', + self.logs.getvalue()) class TestOpenSSLManager(CiTestCase): -- cgit v1.2.3