From 13362f536e9d8a092ec20dcb5abe7a0b86407f45 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Fri, 28 Aug 2020 08:23:59 -0700 Subject: Add method type hints for Azure helper (#540) This reverts commit 8d25d5e6fac39ab3319ec5d37d23196429fb0c95. --- cloudinit/sources/helpers/azure.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index b968a96f..507f6ac8 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -288,7 +288,8 @@ class InvalidGoalStateXMLException(Exception): class GoalState: - def __init__(self, unparsed_xml, azure_endpoint_client): + def __init__(self, unparsed_xml: str, + azure_endpoint_client: AzureEndpointHttpClient) -> None: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_xml: string representing a GoalState XML. @@ -478,7 +479,10 @@ class GoalStateHealthReporter: PROVISIONING_SUCCESS_STATUS = 'Ready' - def __init__(self, goal_state, azure_endpoint_client, endpoint): + def __init__( + self, goal_state: GoalState, + azure_endpoint_client: AzureEndpointHttpClient, + endpoint: str) -> None: """Creates instance that will report provisioning status to an endpoint @param goal_state: An instance of class GoalState that contains @@ -495,7 +499,7 @@ class GoalStateHealthReporter: self._endpoint = endpoint @azure_ds_telemetry_reporter - def send_ready_signal(self): + def send_ready_signal(self) -> None: document = self.build_report( incarnation=self._goal_state.incarnation, container_id=self._goal_state.container_id, @@ -513,8 +517,8 @@ class GoalStateHealthReporter: LOG.info('Reported ready to Azure fabric.') def build_report( - self, incarnation, container_id, instance_id, - status, substatus=None, description=None): + self, incarnation: str, container_id: str, instance_id: str, + status: str, substatus=None, description=None) -> str: health_detail = '' if substatus is not None: health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( @@ -530,7 +534,7 @@ class GoalStateHealthReporter: return health_report @azure_ds_telemetry_reporter - def _post_health_report(self, document): + def _post_health_report(self, document: str) -> None: push_log_to_kvp() # Whenever report_diagnostic_event(diagnostic_msg) is invoked in code, @@ -726,7 +730,7 @@ class WALinuxAgentShim: return endpoint_ip_address @azure_ds_telemetry_reporter - def register_with_azure_and_fetch_data(self, pubkey_info=None): + def register_with_azure_and_fetch_data(self, pubkey_info=None) -> dict: """Gets the VM's GoalState from Azure, uses the GoalState information to report ready/send the ready signal/provisioning complete signal to Azure, and then uses pubkey_info to filter and obtain the user's @@ -750,7 +754,7 @@ class WALinuxAgentShim: return {'public-keys': ssh_keys} @azure_ds_telemetry_reporter - def _fetch_goal_state_from_azure(self): + def _fetch_goal_state_from_azure(self) -> GoalState: """Fetches the GoalState XML from the Azure endpoint, parses the XML, and returns a GoalState object. @@ -760,7 +764,7 @@ class WALinuxAgentShim: return self._parse_raw_goal_state_xml(unparsed_goal_state_xml) @azure_ds_telemetry_reporter - def _get_raw_goal_state_xml_from_azure(self): + def _get_raw_goal_state_xml_from_azure(self) -> str: """Fetches the GoalState XML from the Azure endpoint and returns the XML as a string. 
@@ -780,7 +784,8 @@ class WALinuxAgentShim: return response.contents @azure_ds_telemetry_reporter - def _parse_raw_goal_state_xml(self, unparsed_goal_state_xml): + def _parse_raw_goal_state_xml( + self, unparsed_goal_state_xml: str) -> GoalState: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_goal_state_xml: GoalState XML string @@ -803,7 +808,8 @@ class WALinuxAgentShim: return goal_state @azure_ds_telemetry_reporter - def _get_user_pubkeys(self, goal_state, pubkey_info): + def _get_user_pubkeys( + self, goal_state: GoalState, pubkey_info: list) -> list: """Gets and filters the VM admin user's authorized pubkeys. The admin user in this case is the username specified as "admin" @@ -838,7 +844,7 @@ class WALinuxAgentShim: return ssh_keys @staticmethod - def _filter_pubkeys(keys_by_fingerprint, pubkey_info): + def _filter_pubkeys(keys_by_fingerprint: dict, pubkey_info: list) -> list: """ Filter and return only the user's actual pubkeys. @param keys_by_fingerprint: pubkey fingerprint -> pubkey value dict -- cgit v1.2.3 From e56b55452549cb037da0a4165154ffa494e9678a Mon Sep 17 00:00:00 2001 From: Thomas Stringer Date: Thu, 10 Sep 2020 14:29:54 -0400 Subject: Retrieve SSH keys from IMDS first with OVF as a fallback (#509) * pull ssh keys from imds first and fall back to ovf if unavailable * refactor log and diagnostic messages * refactor the OpenSSLManager instantiation and certificate usage * fix unit test where exception was being silenced for generate cert * fix tests now that certificate is not always generated * add documentation for ssh key retrieval * add ability to check if http client has security enabled * refactor certificate logic to GoalState --- cloudinit/sources/DataSourceAzure.py | 53 +++++++++++++++++- cloudinit/sources/helpers/azure.py | 50 ++++++++++++----- doc/rtd/topics/datasources/azure.rst | 6 ++ tests/unittests/test_datasource/test_azure.py | 64 ++++++++++++++++++---- .../unittests/test_datasource/test_azure_helper.py | 13 +++-- 5 files changed, 156 insertions(+), 30 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index f3c6452b..e98fd497 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -561,6 +561,40 @@ class DataSourceAzure(sources.DataSource): def device_name_to_device(self, name): return self.ds_cfg['disk_aliases'].get(name) + @azure_ds_telemetry_reporter + def get_public_ssh_keys(self): + """ + Try to get the ssh keys from IMDS first, and if that fails + (i.e. IMDS is unavailable) then fallback to getting the ssh + keys from OVF. + + The benefit to getting keys from IMDS is a large performance + advantage, so this is a strong preference. But we must keep + OVF as a second option for environments that don't have IMDS. 
+ """ + LOG.debug('Retrieving public SSH keys') + ssh_keys = [] + try: + ssh_keys = [ + public_key['keyData'] + for public_key + in self.metadata['imds']['compute']['publicKeys'] + ] + LOG.debug('Retrieved SSH keys from IMDS') + except KeyError: + log_msg = 'Unable to get keys from IMDS, falling back to OVF' + LOG.debug(log_msg) + report_diagnostic_event(log_msg) + try: + ssh_keys = self.metadata['public-keys'] + LOG.debug('Retrieved keys from OVF') + except KeyError: + log_msg = 'No keys available from OVF' + LOG.debug(log_msg) + report_diagnostic_event(log_msg) + + return ssh_keys + def get_config_obj(self): return self.cfg @@ -764,7 +798,22 @@ class DataSourceAzure(sources.DataSource): if self.ds_cfg['agent_command'] == AGENT_START_BUILTIN: self.bounce_network_with_azure_hostname() - pubkey_info = self.cfg.get('_pubkeys', None) + pubkey_info = None + try: + public_keys = self.metadata['imds']['compute']['publicKeys'] + LOG.debug( + 'Successfully retrieved %s key(s) from IMDS', + len(public_keys) + if public_keys is not None + else 0 + ) + except KeyError: + LOG.debug( + 'Unable to retrieve SSH keys from IMDS during ' + 'negotiation, falling back to OVF' + ) + pubkey_info = self.cfg.get('_pubkeys', None) + metadata_func = partial(get_metadata_from_fabric, fallback_lease_file=self. dhclient_lease_file, @@ -1443,7 +1492,7 @@ def get_metadata_from_imds(fallback_nic, retries): @azure_ds_telemetry_reporter def _get_metadata_from_imds(retries): - url = IMDS_URL + "instance?api-version=2017-12-01" + url = IMDS_URL + "instance?api-version=2019-06-01" headers = {"Metadata": "true"} try: response = readurl( diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 507f6ac8..79445a81 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -288,12 +288,16 @@ class InvalidGoalStateXMLException(Exception): class GoalState: - def __init__(self, unparsed_xml: str, - azure_endpoint_client: AzureEndpointHttpClient) -> None: + def __init__( + self, + unparsed_xml: str, + azure_endpoint_client: AzureEndpointHttpClient, + need_certificate: bool = True) -> None: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_xml: string representing a GoalState XML. - @param azure_endpoint_client: instance of AzureEndpointHttpClient + @param azure_endpoint_client: instance of AzureEndpointHttpClient. + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML string. """ self.azure_endpoint_client = azure_endpoint_client @@ -322,7 +326,7 @@ class GoalState: url = self._text_from_xpath( './Container/RoleInstanceList/RoleInstance' '/Configuration/Certificates') - if url is not None: + if url is not None and need_certificate: with events.ReportEventStack( name="get-certificates-xml", description="get certificates xml", @@ -741,27 +745,38 @@ class WALinuxAgentShim: GoalState. @return: The list of user's authorized pubkey values. 
""" - if self.openssl_manager is None: + http_client_certificate = None + if self.openssl_manager is None and pubkey_info is not None: self.openssl_manager = OpenSSLManager() + http_client_certificate = self.openssl_manager.certificate if self.azure_endpoint_client is None: self.azure_endpoint_client = AzureEndpointHttpClient( - self.openssl_manager.certificate) - goal_state = self._fetch_goal_state_from_azure() - ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) + http_client_certificate) + goal_state = self._fetch_goal_state_from_azure( + need_certificate=http_client_certificate is not None + ) + ssh_keys = None + if pubkey_info is not None: + ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) health_reporter = GoalStateHealthReporter( goal_state, self.azure_endpoint_client, self.endpoint) health_reporter.send_ready_signal() return {'public-keys': ssh_keys} @azure_ds_telemetry_reporter - def _fetch_goal_state_from_azure(self) -> GoalState: + def _fetch_goal_state_from_azure( + self, + need_certificate: bool) -> GoalState: """Fetches the GoalState XML from the Azure endpoint, parses the XML, and returns a GoalState object. @return: GoalState object representing the GoalState XML """ unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() - return self._parse_raw_goal_state_xml(unparsed_goal_state_xml) + return self._parse_raw_goal_state_xml( + unparsed_goal_state_xml, + need_certificate + ) @azure_ds_telemetry_reporter def _get_raw_goal_state_xml_from_azure(self) -> str: @@ -774,7 +789,11 @@ class WALinuxAgentShim: LOG.info('Registering with Azure...') url = 'http://{}/machine/?comp=goalstate'.format(self.endpoint) try: - response = self.azure_endpoint_client.get(url) + with events.ReportEventStack( + name="goalstate-retrieval", + description="retrieve goalstate", + parent=azure_ds_reporter): + response = self.azure_endpoint_client.get(url) except Exception as e: msg = 'failed to register with Azure: %s' % e LOG.warning(msg) @@ -785,7 +804,9 @@ class WALinuxAgentShim: @azure_ds_telemetry_reporter def _parse_raw_goal_state_xml( - self, unparsed_goal_state_xml: str) -> GoalState: + self, + unparsed_goal_state_xml: str, + need_certificate: bool) -> GoalState: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_goal_state_xml: GoalState XML string @@ -793,7 +814,10 @@ class WALinuxAgentShim: """ try: goal_state = GoalState( - unparsed_goal_state_xml, self.azure_endpoint_client) + unparsed_goal_state_xml, + self.azure_endpoint_client, + need_certificate + ) except Exception as e: msg = 'Error processing GoalState XML: %s' % e LOG.warning(msg) diff --git a/doc/rtd/topics/datasources/azure.rst b/doc/rtd/topics/datasources/azure.rst index fdb919a5..e04c3a33 100644 --- a/doc/rtd/topics/datasources/azure.rst +++ b/doc/rtd/topics/datasources/azure.rst @@ -68,6 +68,12 @@ configuration information to the instance. Cloud-init uses the IMDS for: - network configuration for the instance which is applied per boot - a preprovisioing gate which blocks instance configuration until Azure fabric is ready to provision +- retrieving SSH public keys. Cloud-init will first try to utilize SSH keys + returned from IMDS, and if they are not provided from IMDS then it will + fallback to using the OVF file provided from the CD-ROM. 
There is a large + performance benefit to using IMDS for SSH key retrieval, but in order to + support environments where IMDS is not available then we must continue to + all for keys from OVF Configuration diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 47e03bd1..2dda9925 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -102,7 +102,13 @@ NETWORK_METADATA = { "vmId": "ff702a6b-cb6a-4fcd-ad68-b4ce38227642", "vmScaleSetName": "", "vmSize": "Standard_DS1_v2", - "zone": "" + "zone": "", + "publicKeys": [ + { + "keyData": "key1", + "path": "path1" + } + ] }, "network": { "interface": [ @@ -302,7 +308,7 @@ class TestGetMetadataFromIMDS(HttprettyTestCase): def setUp(self): super(TestGetMetadataFromIMDS, self).setUp() - self.network_md_url = dsaz.IMDS_URL + "instance?api-version=2017-12-01" + self.network_md_url = dsaz.IMDS_URL + "instance?api-version=2019-06-01" @mock.patch(MOCKPATH + 'readurl') @mock.patch(MOCKPATH + 'EphemeralDHCPv4') @@ -1304,6 +1310,40 @@ scbus-1 on xpt0 bus 0 dsaz.get_hostname(hostname_command=("hostname",)) m_subp.assert_called_once_with(("hostname",), capture=True) + @mock.patch( + 'cloudinit.sources.helpers.azure.OpenSSLManager.parse_certificates') + def test_get_public_ssh_keys_with_imds(self, m_parse_certificates): + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + dsrc = self._get_ds(data) + dsrc.get_data() + dsrc.setup(True) + ssh_keys = dsrc.get_public_ssh_keys() + self.assertEqual(ssh_keys, ['key1']) + self.assertEqual(m_parse_certificates.call_count, 0) + + @mock.patch(MOCKPATH + 'get_metadata_from_imds') + def test_get_public_ssh_keys_without_imds( + self, + m_get_metadata_from_imds): + m_get_metadata_from_imds.return_value = dict() + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + dsrc = self._get_ds(data) + dsaz.get_metadata_from_fabric.return_value = {'public-keys': ['key2']} + dsrc.get_data() + dsrc.setup(True) + ssh_keys = dsrc.get_public_ssh_keys() + self.assertEqual(ssh_keys, ['key2']) + class TestAzureBounce(CiTestCase): @@ -2094,14 +2134,18 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): md, _ud, cfg, _d = dsa._reprovision() self.assertEqual(md['local-hostname'], hostname) self.assertEqual(cfg['system_info']['default_user']['name'], username) - self.assertEqual(fake_resp.call_args_list, - [mock.call(allow_redirects=True, - headers={'Metadata': 'true', - 'User-Agent': - 'Cloud-Init/%s' % vs()}, - method='GET', - timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, - url=full_url)]) + self.assertIn( + mock.call( + allow_redirects=True, + headers={ + 'Metadata': 'true', + 'User-Agent': 'Cloud-Init/%s' % vs() + }, + method='GET', + timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, + url=full_url + ), + fake_resp.call_args_list) self.assertEqual(m_dhcp.call_count, 2) m_net.assert_any_call( broadcast='192.168.2.255', interface='eth9', ip='192.168.2.9', diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 5e6d3d2d..5c31b8be 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ 
-609,11 +609,11 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.return_value.container_id = self.test_container_id self.GoalState.return_value.instance_id = self.test_instance_id - def test_azure_endpoint_client_uses_certificate_during_report_ready(self): + def test_http_client_does_not_use_certificate(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( - [mock.call(self.OpenSSLManager.return_value.certificate)], + [mock.call(None)], self.AzureEndpointHttpClient.call_args_list) def test_correct_url_used_for_goalstate_during_report_ready(self): @@ -625,8 +625,11 @@ class TestWALinuxAgentShim(CiTestCase): [mock.call('http://test_endpoint/machine/?comp=goalstate')], get.call_args_list) self.assertEqual( - [mock.call(get.return_value.contents, - self.AzureEndpointHttpClient.return_value)], + [mock.call( + get.return_value.contents, + self.AzureEndpointHttpClient.return_value, + False + )], self.GoalState.call_args_list) def test_certificates_used_to_determine_public_keys(self): @@ -701,7 +704,7 @@ class TestWALinuxAgentShim(CiTestCase): shim.register_with_azure_and_fetch_data() shim.clean_up() self.assertEqual( - 1, self.OpenSSLManager.return_value.clean_up.call_count) + 0, self.OpenSSLManager.return_value.clean_up.call_count) def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): self.AzureEndpointHttpClient.return_value.get \ -- cgit v1.2.3 From 3b05b1a6c58dfc7533a16f795405bda0e53aa9d8 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Thu, 15 Oct 2020 07:19:57 -0700 Subject: azure: clean up and refactor report_diagnostic_event (#563) This moves logging into `report_diagnostic_event`, to clean up its callsites. --- cloudinit/sources/DataSourceAzure.py | 120 +++++++++++++++++-------------- cloudinit/sources/helpers/azure.py | 99 ++++++++++++++----------- tests/unittests/test_reporting_hyperv.py | 22 +++++- 3 files changed, 142 insertions(+), 99 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index fc32f8b1..8858fbd5 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -299,9 +299,9 @@ def temporary_hostname(temp_hostname, cfg, hostname_command='hostname'): try: set_hostname(temp_hostname, hostname_command) except Exception as e: - msg = 'Failed setting temporary hostname: %s' % e - report_diagnostic_event(msg) - LOG.warning(msg) + report_diagnostic_event( + 'Failed setting temporary hostname: %s' % e, + logger_func=LOG.warning) yield None return try: @@ -356,7 +356,9 @@ class DataSourceAzure(sources.DataSource): cfg=cfg, prev_hostname=previous_hn) except Exception as e: - LOG.warning("Failed publishing hostname: %s", e) + report_diagnostic_event( + "Failed publishing hostname: %s" % e, + logger_func=LOG.warning) util.logexc(LOG, "handling set_hostname failed") return False @@ -454,24 +456,23 @@ class DataSourceAzure(sources.DataSource): except NonAzureDataSource: report_diagnostic_event( - "Did not find Azure data source in %s" % cdev) + "Did not find Azure data source in %s" % cdev, + logger_func=LOG.debug) continue except BrokenAzureDataSource as exc: msg = 'BrokenAzureDataSource: %s' % exc - report_diagnostic_event(msg) + report_diagnostic_event(msg, logger_func=LOG.error) raise sources.InvalidMetaDataException(msg) except util.MountFailedError: - msg = '%s was not mountable' % cdev - report_diagnostic_event(msg) - LOG.warning(msg) + report_diagnostic_event( + '%s was not mountable' % 
cdev, logger_func=LOG.warning) continue perform_reprovision = reprovision or self._should_reprovision(ret) if perform_reprovision: if util.is_FreeBSD(): msg = "Free BSD is not supported for PPS VMs" - LOG.error(msg) - report_diagnostic_event(msg) + report_diagnostic_event(msg, logger_func=LOG.error) raise sources.InvalidMetaDataException(msg) ret = self._reprovision() imds_md = get_metadata_from_imds( @@ -486,16 +487,18 @@ class DataSourceAzure(sources.DataSource): 'userdata_raw': userdata_raw}) found = cdev - LOG.debug("found datasource in %s", cdev) + report_diagnostic_event( + 'found datasource in %s' % cdev, logger_func=LOG.debug) break if not found: msg = 'No Azure metadata found' - report_diagnostic_event(msg) + report_diagnostic_event(msg, logger_func=LOG.error) raise sources.InvalidMetaDataException(msg) if found == ddir: - LOG.debug("using files cached in %s", ddir) + report_diagnostic_event( + "using files cached in %s" % ddir, logger_func=LOG.debug) seed = _get_random_seed() if seed: @@ -516,7 +519,8 @@ class DataSourceAzure(sources.DataSource): self._report_ready(lease=lease) except Exception as e: report_diagnostic_event( - "exception while reporting ready: %s" % e) + "exception while reporting ready: %s" % e, + logger_func=LOG.error) raise return crawled_data @@ -605,14 +609,14 @@ class DataSourceAzure(sources.DataSource): except KeyError: log_msg = 'Unable to get keys from IMDS, falling back to OVF' LOG.debug(log_msg) - report_diagnostic_event(log_msg) + report_diagnostic_event(log_msg, logger_func=LOG.debug) try: ssh_keys = self.metadata['public-keys'] LOG.debug('Retrieved keys from OVF') except KeyError: log_msg = 'No keys available from OVF' LOG.debug(log_msg) - report_diagnostic_event(log_msg) + report_diagnostic_event(log_msg, logger_func=LOG.debug) return ssh_keys @@ -666,16 +670,14 @@ class DataSourceAzure(sources.DataSource): if self.imds_poll_counter == self.imds_logging_threshold: # Reducing the logging frequency as we are polling IMDS self.imds_logging_threshold *= 2 - LOG.debug("Call to IMDS with arguments %s failed " - "with status code %s after %s retries", - msg, exception.code, self.imds_poll_counter) LOG.debug("Backing off logging threshold for the same " "exception to %d", self.imds_logging_threshold) report_diagnostic_event("poll IMDS with %s failed. " "Exception: %s and code: %s" % (msg, exception.cause, - exception.code)) + exception.code), + logger_func=LOG.debug) self.imds_poll_counter += 1 return True else: @@ -684,12 +686,15 @@ class DataSourceAzure(sources.DataSource): report_diagnostic_event("poll IMDS with %s failed. 
" "Exception: %s and code: %s" % (msg, exception.cause, - exception.code)) + exception.code), + logger_func=LOG.warning) return False - LOG.debug("poll IMDS failed with an unexpected exception: %s", - exception) - return False + report_diagnostic_event( + "poll IMDS failed with an " + "unexpected exception: %s" % exception, + logger_func=LOG.warning) + return False LOG.debug("Wait for vnetswitch to happen") while True: @@ -709,8 +714,9 @@ class DataSourceAzure(sources.DataSource): try: nl_sock = netlink.create_bound_netlink_socket() except netlink.NetlinkCreateSocketError as e: - report_diagnostic_event(e) - LOG.warning(e) + report_diagnostic_event( + 'Failed to create bound netlink socket: %s' % e, + logger_func=LOG.warning) self._ephemeral_dhcp_ctx.clean_network() break @@ -729,9 +735,10 @@ class DataSourceAzure(sources.DataSource): try: netlink.wait_for_media_disconnect_connect( nl_sock, lease['interface']) - except AssertionError as error: - report_diagnostic_event(error) - LOG.error(error) + except AssertionError as e: + report_diagnostic_event( + 'Error while waiting for vnet switch: %s' % e, + logger_func=LOG.error) break vnet_switched = True @@ -757,9 +764,11 @@ class DataSourceAzure(sources.DataSource): if vnet_switched: report_diagnostic_event("attempted dhcp %d times after reuse" % - dhcp_attempts) + dhcp_attempts, + logger_func=LOG.debug) report_diagnostic_event("polled imds %d times after reuse" % - self.imds_poll_counter) + self.imds_poll_counter, + logger_func=LOG.debug) return return_val @@ -768,10 +777,10 @@ class DataSourceAzure(sources.DataSource): """Tells the fabric provisioning has completed """ try: get_metadata_from_fabric(None, lease['unknown-245']) - except Exception: - LOG.warning( - "Error communicating with Azure fabric; You may experience." - "connectivity issues.", exc_info=True) + except Exception as e: + report_diagnostic_event( + "Error communicating with Azure fabric; You may experience " + "connectivity issues: %s" % e, logger_func=LOG.warning) def _should_reprovision(self, ret): """Whether or not we should poll IMDS for reprovisioning data. @@ -849,10 +858,7 @@ class DataSourceAzure(sources.DataSource): except Exception as e: report_diagnostic_event( "Error communicating with Azure fabric; You may experience " - "connectivity issues: %s" % e) - LOG.warning( - "Error communicating with Azure fabric; You may experience " - "connectivity issues.", exc_info=True) + "connectivity issues: %s" % e, logger_func=LOG.warning) return False util.del_file(REPORTED_READY_MARKER_FILE) @@ -1017,9 +1023,10 @@ def address_ephemeral_resize(devpath=RESOURCE_DISK_PATH, maxwait=120, log_pre="Azure ephemeral disk: ") if missing: - LOG.warning("ephemeral device '%s' did" - " not appear after %d seconds.", - devpath, maxwait) + report_diagnostic_event( + "ephemeral device '%s' did not appear after %d seconds." 
% + (devpath, maxwait), + logger_func=LOG.warning) return result = False @@ -1104,7 +1111,9 @@ def pubkeys_from_crt_files(flist): errors.append(fname) if errors: - LOG.warning("failed to convert the crt files to pubkey: %s", errors) + report_diagnostic_event( + "failed to convert the crt files to pubkey: %s" % errors, + logger_func=LOG.warning) return pubkeys @@ -1216,7 +1225,7 @@ def read_azure_ovf(contents): dom = minidom.parseString(contents) except Exception as e: error_str = "Invalid ovf-env.xml: %s" % e - report_diagnostic_event(error_str) + report_diagnostic_event(error_str, logger_func=LOG.warning) raise BrokenAzureDataSource(error_str) from e results = find_child(dom.documentElement, @@ -1523,7 +1532,9 @@ def get_metadata_from_imds(fallback_nic, retries): azure_ds_reporter, fallback_nic): return util.log_time(**kwargs) except Exception as e: - report_diagnostic_event("exception while getting metadata: %s" % e) + report_diagnostic_event( + "exception while getting metadata: %s" % e, + logger_func=LOG.warning) raise @@ -1537,9 +1548,10 @@ def _get_metadata_from_imds(retries): url, timeout=IMDS_TIMEOUT_IN_SECONDS, headers=headers, retries=retries, exception_cb=retry_on_url_exc) except Exception as e: - msg = 'Ignoring IMDS instance metadata: %s' % e - report_diagnostic_event(msg) - LOG.debug(msg) + report_diagnostic_event( + 'Ignoring IMDS instance metadata. ' + 'Get metadata from IMDS failed: %s' % e, + logger_func=LOG.warning) return {} try: from json.decoder import JSONDecodeError @@ -1550,9 +1562,10 @@ def _get_metadata_from_imds(retries): try: return util.load_json(str(response)) except json_decode_error as e: - report_diagnostic_event('non-json imds response' % e) - LOG.warning( - 'Ignoring non-json IMDS instance metadata: %s', str(response)) + report_diagnostic_event( + 'Ignoring non-json IMDS instance metadata response: %s. ' + 'Loading non-json IMDS response failed: %s' % (str(response), e), + logger_func=LOG.warning) return {} @@ -1604,9 +1617,8 @@ def _is_platform_viable(seed_dir): if asset_tag == AZURE_CHASSIS_ASSET_TAG: return True msg = "Non-Azure DMI asset tag '%s' discovered." 
% asset_tag - LOG.debug(msg) evt.description = msg - report_diagnostic_event(msg) + report_diagnostic_event(msg, logger_func=LOG.debug) if os.path.exists(os.path.join(seed_dir, 'ovf-env.xml')): return True return False diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 79445a81..560cadba 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -180,12 +180,15 @@ def get_system_info(): return evt -def report_diagnostic_event(str): +def report_diagnostic_event( + msg: str, *, logger_func=None) -> events.ReportingEvent: """Report a diagnostic event""" + if callable(logger_func): + logger_func(msg) evt = events.ReportingEvent( DIAGNOSTIC_EVENT_TYPE, 'diagnostic message', - str, events.DEFAULT_EVENT_ORIGIN) - events.report_event(evt) + msg, events.DEFAULT_EVENT_ORIGIN) + events.report_event(evt, excluded_handler_types={"log"}) # return the event for unit testing purpose return evt @@ -215,7 +218,8 @@ def push_log_to_kvp(file_name=CFG_BUILTIN['def_log_file']): log_pushed_to_kvp = bool(os.path.isfile(LOG_PUSHED_TO_KVP_MARKER_FILE)) if log_pushed_to_kvp: - report_diagnostic_event("cloud-init.log is already pushed to KVP") + report_diagnostic_event( + "cloud-init.log is already pushed to KVP", logger_func=LOG.debug) return LOG.debug("Dumping cloud-init.log file to KVP") @@ -225,13 +229,15 @@ def push_log_to_kvp(file_name=CFG_BUILTIN['def_log_file']): seek_index = max(f.tell() - MAX_LOG_TO_KVP_LENGTH, 0) report_diagnostic_event( "Dumping last {} bytes of cloud-init.log file to KVP".format( - f.tell() - seek_index)) + f.tell() - seek_index), + logger_func=LOG.debug) f.seek(seek_index, os.SEEK_SET) report_compressed_event("cloud-init.log", f.read()) util.write_file(LOG_PUSHED_TO_KVP_MARKER_FILE, '') except Exception as ex: - report_diagnostic_event("Exception when dumping log file: %s" % - repr(ex)) + report_diagnostic_event( + "Exception when dumping log file: %s" % repr(ex), + logger_func=LOG.warning) @contextmanager @@ -305,9 +311,9 @@ class GoalState: try: self.root = ElementTree.fromstring(unparsed_xml) except ElementTree.ParseError as e: - msg = 'Failed to parse GoalState XML: %s' - LOG.warning(msg, e) - report_diagnostic_event(msg % (e,)) + report_diagnostic_event( + 'Failed to parse GoalState XML: %s' % e, + logger_func=LOG.warning) raise self.container_id = self._text_from_xpath('./Container/ContainerId') @@ -317,9 +323,8 @@ class GoalState: for attr in ("container_id", "instance_id", "incarnation"): if getattr(self, attr) is None: - msg = 'Missing %s in GoalState XML' - LOG.warning(msg, attr) - report_diagnostic_event(msg % (attr,)) + msg = 'Missing %s in GoalState XML' % attr + report_diagnostic_event(msg, logger_func=LOG.warning) raise InvalidGoalStateXMLException(msg) self.certificates_xml = None @@ -513,9 +518,9 @@ class GoalStateHealthReporter: try: self._post_health_report(document=document) except Exception as e: - msg = "exception while reporting ready: %s" % e - LOG.error(msg) - report_diagnostic_event(msg) + report_diagnostic_event( + "exception while reporting ready: %s" % e, + logger_func=LOG.error) raise LOG.info('Reported ready to Azure fabric.') @@ -698,39 +703,48 @@ class WALinuxAgentShim: value = dhcp245 LOG.debug("Using Azure Endpoint from dhcp options") if value is None: - report_diagnostic_event("No Azure endpoint from dhcp options") - LOG.debug('Finding Azure endpoint from networkd...') + report_diagnostic_event( + 'No Azure endpoint from dhcp options. 
' + 'Finding Azure endpoint from networkd...', + logger_func=LOG.debug) value = WALinuxAgentShim._networkd_get_value_from_leases() if value is None: # Option-245 stored in /run/cloud-init/dhclient.hooks/.json # a dhclient exit hook that calls cloud-init-dhclient-hook - report_diagnostic_event("No Azure endpoint from networkd") - LOG.debug('Finding Azure endpoint from hook json...') + report_diagnostic_event( + 'No Azure endpoint from networkd. ' + 'Finding Azure endpoint from hook json...', + logger_func=LOG.debug) dhcp_options = WALinuxAgentShim._load_dhclient_json() value = WALinuxAgentShim._get_value_from_dhcpoptions(dhcp_options) if value is None: # Fallback and check the leases file if unsuccessful - report_diagnostic_event("No Azure endpoint from dhclient logs") - LOG.debug("Unable to find endpoint in dhclient logs. " - " Falling back to check lease files") + report_diagnostic_event( + 'No Azure endpoint from dhclient logs. ' + 'Unable to find endpoint in dhclient logs. ' + 'Falling back to check lease files', + logger_func=LOG.debug) if fallback_lease_file is None: - LOG.warning("No fallback lease file was specified.") + report_diagnostic_event( + 'No fallback lease file was specified.', + logger_func=LOG.warning) value = None else: - LOG.debug("Looking for endpoint in lease file %s", - fallback_lease_file) + report_diagnostic_event( + 'Looking for endpoint in lease file %s' + % fallback_lease_file, logger_func=LOG.debug) value = WALinuxAgentShim._get_value_from_leases_file( fallback_lease_file) if value is None: - msg = "No lease found; using default endpoint" - report_diagnostic_event(msg) - LOG.warning(msg) value = DEFAULT_WIRESERVER_ENDPOINT + report_diagnostic_event( + 'No lease found; using default endpoint: %s' % value, + logger_func=LOG.warning) endpoint_ip_address = WALinuxAgentShim.get_ip_from_lease_value(value) - msg = 'Azure endpoint found at %s' % endpoint_ip_address - report_diagnostic_event(msg) - LOG.debug(msg) + report_diagnostic_event( + 'Azure endpoint found at %s' % endpoint_ip_address, + logger_func=LOG.debug) return endpoint_ip_address @azure_ds_telemetry_reporter @@ -795,9 +809,9 @@ class WALinuxAgentShim: parent=azure_ds_reporter): response = self.azure_endpoint_client.get(url) except Exception as e: - msg = 'failed to register with Azure: %s' % e - LOG.warning(msg) - report_diagnostic_event(msg) + report_diagnostic_event( + 'failed to register with Azure and fetch GoalState XML: %s' + % e, logger_func=LOG.warning) raise LOG.debug('Successfully fetched GoalState XML.') return response.contents @@ -819,16 +833,15 @@ class WALinuxAgentShim: need_certificate ) except Exception as e: - msg = 'Error processing GoalState XML: %s' % e - LOG.warning(msg) - report_diagnostic_event(msg) + report_diagnostic_event( + 'Error processing GoalState XML: %s' % e, + logger_func=LOG.warning) raise msg = ', '.join([ 'GoalState XML container id: %s' % goal_state.container_id, 'GoalState XML instance id: %s' % goal_state.instance_id, 'GoalState XML incarnation: %s' % goal_state.incarnation]) - LOG.debug(msg) - report_diagnostic_event(msg) + report_diagnostic_event(msg, logger_func=LOG.debug) return goal_state @azure_ds_telemetry_reporter @@ -910,8 +923,10 @@ def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, def dhcp_log_cb(out, err): - report_diagnostic_event("dhclient output stream: %s" % out) - report_diagnostic_event("dhclient error stream: %s" % err) + report_diagnostic_event( + "dhclient output stream: %s" % out, logger_func=LOG.debug) + 
report_diagnostic_event( + "dhclient error stream: %s" % err, logger_func=LOG.debug) class EphemeralDHCPv4WithReporting: diff --git a/tests/unittests/test_reporting_hyperv.py b/tests/unittests/test_reporting_hyperv.py index 47ede670..3f63a60e 100644 --- a/tests/unittests/test_reporting_hyperv.py +++ b/tests/unittests/test_reporting_hyperv.py @@ -188,18 +188,34 @@ class TextKvpReporter(CiTestCase): if not re.search("variant=" + pattern, evt_msg): raise AssertionError("missing distro variant string") - def test_report_diagnostic_event(self): + def test_report_diagnostic_event_without_logger_func(self): reporter = HyperVKvpReportingHandler(kvp_file_path=self.tmp_file_path) + diagnostic_msg = "test_diagnostic" + reporter.publish_event( + azure.report_diagnostic_event(diagnostic_msg)) + reporter.q.join() + kvps = list(reporter._iterate_kvps(0)) + self.assertEqual(1, len(kvps)) + evt_msg = kvps[0]['value'] + + if diagnostic_msg not in evt_msg: + raise AssertionError("missing expected diagnostic message") + def test_report_diagnostic_event_with_logger_func(self): + reporter = HyperVKvpReportingHandler(kvp_file_path=self.tmp_file_path) + logger_func = mock.MagicMock() + diagnostic_msg = "test_diagnostic" reporter.publish_event( - azure.report_diagnostic_event("test_diagnostic")) + azure.report_diagnostic_event(diagnostic_msg, + logger_func=logger_func)) reporter.q.join() kvps = list(reporter._iterate_kvps(0)) self.assertEqual(1, len(kvps)) evt_msg = kvps[0]['value'] - if "test_diagnostic" not in evt_msg: + if diagnostic_msg not in evt_msg: raise AssertionError("missing expected diagnostic message") + logger_func.assert_called_once_with(diagnostic_msg) def test_report_compressed_event(self): reporter = HyperVKvpReportingHandler(kvp_file_path=self.tmp_file_path) -- cgit v1.2.3 From c86283f0d9fe8a2634dc3c47727e6218fdaf25e2 Mon Sep 17 00:00:00 2001 From: Moustafa Moustafa Date: Wed, 4 Nov 2020 11:51:16 -0800 Subject: azure: enable pushing the log to KVP from the last pushed byte (#614) This allows the cloud-init log to be pushed multiple times during boot, with the latest lines being pushed each time. --- cloudinit/sources/helpers/azure.py | 44 +++++++++++++++++++++++--------- tests/unittests/test_reporting_hyperv.py | 7 ++--- 2 files changed, 36 insertions(+), 15 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 560cadba..4071a50e 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -9,6 +9,7 @@ import struct import time import textwrap import zlib +from errno import ENOENT from cloudinit.settings import CFG_BUILTIN from cloudinit.net import dhcp @@ -41,8 +42,9 @@ COMPRESSED_EVENT_TYPE = 'compressed' # cloud-init.log files where the P95 of the file sizes was 537KB and the time # consumed to dump 500KB file was (P95:76, P99:233, P99.9:1170) in ms MAX_LOG_TO_KVP_LENGTH = 512000 -# Marker file to indicate whether cloud-init.log is pushed to KVP -LOG_PUSHED_TO_KVP_MARKER_FILE = '/var/lib/cloud/data/log_pushed_to_kvp' +# File to store the last byte of cloud-init.log that was pushed to KVP. This +# file will be deleted with every VM reboot. 
+LOG_PUSHED_TO_KVP_INDEX_FILE = '/run/cloud-init/log_pushed_to_kvp_index' azure_ds_reporter = events.ReportEventStack( name="azure-ds", description="initialize reporter for azure ds", @@ -214,32 +216,50 @@ def report_compressed_event(event_name, event_content): def push_log_to_kvp(file_name=CFG_BUILTIN['def_log_file']): """Push a portion of cloud-init.log file or the whole file to KVP based on the file size. - If called more than once, it skips pushing the log file to KVP again.""" + The first time this function is called after VM boot, It will push the last + n bytes of the log file such that n < MAX_LOG_TO_KVP_LENGTH + If called again on the same boot, it continues from where it left off.""" - log_pushed_to_kvp = bool(os.path.isfile(LOG_PUSHED_TO_KVP_MARKER_FILE)) - if log_pushed_to_kvp: - report_diagnostic_event( - "cloud-init.log is already pushed to KVP", logger_func=LOG.debug) - return + start_index = get_last_log_byte_pushed_to_kvp_index() LOG.debug("Dumping cloud-init.log file to KVP") try: with open(file_name, "rb") as f: f.seek(0, os.SEEK_END) - seek_index = max(f.tell() - MAX_LOG_TO_KVP_LENGTH, 0) + seek_index = max(f.tell() - MAX_LOG_TO_KVP_LENGTH, start_index) report_diagnostic_event( - "Dumping last {} bytes of cloud-init.log file to KVP".format( - f.tell() - seek_index), + "Dumping last {0} bytes of cloud-init.log file to KVP starting" + " from index: {1}".format(f.tell() - seek_index, seek_index), logger_func=LOG.debug) f.seek(seek_index, os.SEEK_SET) report_compressed_event("cloud-init.log", f.read()) - util.write_file(LOG_PUSHED_TO_KVP_MARKER_FILE, '') + util.write_file(LOG_PUSHED_TO_KVP_INDEX_FILE, str(f.tell())) except Exception as ex: report_diagnostic_event( "Exception when dumping log file: %s" % repr(ex), logger_func=LOG.warning) +@azure_ds_telemetry_reporter +def get_last_log_byte_pushed_to_kvp_index(): + try: + with open(LOG_PUSHED_TO_KVP_INDEX_FILE, "r") as f: + return int(f.read()) + except IOError as e: + if e.errno != ENOENT: + report_diagnostic_event("Reading LOG_PUSHED_TO_KVP_INDEX_FILE" + " failed: %s." % repr(e), + logger_func=LOG.warning) + except ValueError as e: + report_diagnostic_event("Invalid value in LOG_PUSHED_TO_KVP_INDEX_FILE" + ": %s." % repr(e), + logger_func=LOG.warning) + except Exception as e: + report_diagnostic_event("Failed to get the last log byte pushed to KVP" + ": %s." 
% repr(e), logger_func=LOG.warning) + return 0 + + @contextmanager def cd(newdir): prevdir = os.getcwd() diff --git a/tests/unittests/test_reporting_hyperv.py b/tests/unittests/test_reporting_hyperv.py index 3f63a60e..8f7b3694 100644 --- a/tests/unittests/test_reporting_hyperv.py +++ b/tests/unittests/test_reporting_hyperv.py @@ -237,7 +237,7 @@ class TextKvpReporter(CiTestCase): instantiated_handler_registry.register_item("telemetry", reporter) log_file = self.tmp_path("cloud-init.log") azure.MAX_LOG_TO_KVP_LENGTH = 100 - azure.LOG_PUSHED_TO_KVP_MARKER_FILE = self.tmp_path( + azure.LOG_PUSHED_TO_KVP_INDEX_FILE = self.tmp_path( 'log_pushed_to_kvp') with open(log_file, "w") as f: log_content = "A" * 50 + "B" * 100 @@ -254,8 +254,9 @@ class TextKvpReporter(CiTestCase): self.assertNotEqual( event.event_type, azure.COMPRESSED_EVENT_TYPE) self.validate_compressed_kvps( - reporter, 1, - [log_content[-azure.MAX_LOG_TO_KVP_LENGTH:].encode()]) + reporter, 2, + [log_content[-azure.MAX_LOG_TO_KVP_LENGTH:].encode(), + extra_content.encode()]) finally: instantiated_handler_registry.unregister_item("telemetry", force=False) -- cgit v1.2.3 From d807df288f8cef29ca74f0b00c326b084e825782 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Wed, 18 Nov 2020 09:34:04 -0800 Subject: DataSourceAzure: send failure signal on Azure datasource failure (#594) On systems where the Azure datasource is a viable platform for crawling metadata, cloud-init occasionally encounters fatal irrecoverable errors during the crawling of the Azure datasource. When this happens, cloud-init crashes, and Azure VM provisioning would fail. However, instead of failing immediately, the user will continue seeing provisioning for a long time until it times out with "OS Provisioning Timed Out" message. In these situations, cloud-init should report failure to the Azure datasource endpoint indicating provisioning failure. The user will immediately see provisioning terminate, giving them a much better failure experience instead of pointlessly waiting for OS provisioning timeout. 
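
The diff below implements this by trying three transports in order and falling through to the next only when the previous one raises: the cached ephemeral DHCP context, a freshly brought-up ephemeral DHCP lease, and finally the on-disk dhclient lease file. The following is a minimal, self-contained sketch of that fallback ordering only; report_failure_to_fabric mirrors the helper added in this patch, while get_cached_lease, bring_up_ephemeral_dhcp, and the lease-file path are hypothetical stand-ins for DataSourceAzure state used in the real code.

# Sketch of the provisioning-failure fallback order introduced by this patch.
# The three lease sources are hypothetical stand-ins, not the real datasource
# attributes; only the ordering and the option-245 usage follow the diff below.

def get_cached_lease():
    """Hypothetical: return a cached ephemeral DHCP lease dict, or None."""
    return None

def bring_up_ephemeral_dhcp():
    """Hypothetical: obtain a fresh ephemeral DHCP lease dict."""
    return {'unknown-245': '624c3620'}

def report_failure_to_fabric(dhcp_opts=None, fallback_lease_file=None,
                             description=None):
    """Hypothetical stand-in for the helper added in helpers/azure.py."""
    print('reporting failure via', dhcp_opts or fallback_lease_file)

def report_failure(description=None):
    # 1. Reuse the cached ephemeral DHCP lease if one is still available.
    lease = get_cached_lease()
    if lease and 'unknown-245' in lease:
        try:
            report_failure_to_fabric(dhcp_opts=lease['unknown-245'],
                                     description=description)
            return True
        except Exception:
            pass
    # 2. Otherwise bring up a new ephemeral DHCP context and use its lease.
    try:
        lease = bring_up_ephemeral_dhcp()
        report_failure_to_fabric(dhcp_opts=lease['unknown-245'],
                                 description=description)
        return True
    except Exception:
        pass
    # 3. Finally fall back to the dhclient lease file (path is illustrative).
    try:
        report_failure_to_fabric(
            fallback_lease_file='/var/lib/dhcp/dhclient.leases',
            description=description)
        return True
    except Exception:
        return False

Each attempt returns as soon as one transport succeeds, which matches the early-return structure of _report_failure in the diff.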
--- cloudinit/sources/DataSourceAzure.py | 73 ++- cloudinit/sources/helpers/azure.py | 80 ++- tests/unittests/test_datasource/test_azure.py | 322 ++++++++++-- .../unittests/test_datasource/test_azure_helper.py | 569 ++++++++++++++++++--- 4 files changed, 921 insertions(+), 123 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index fa3e0a2b..ab139b8d 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -29,6 +29,7 @@ from cloudinit import util from cloudinit.reporting import events from cloudinit.sources.helpers.azure import ( + DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE, azure_ds_reporter, azure_ds_telemetry_reporter, get_metadata_from_fabric, @@ -38,7 +39,8 @@ from cloudinit.sources.helpers.azure import ( EphemeralDHCPv4WithReporting, is_byte_swapped, dhcp_log_cb, - push_log_to_kvp) + push_log_to_kvp, + report_failure_to_fabric) LOG = logging.getLogger(__name__) @@ -508,8 +510,9 @@ class DataSourceAzure(sources.DataSource): if perform_reprovision: LOG.info("Reporting ready to Azure after getting ReprovisionData") - use_cached_ephemeral = (net.is_up(self.fallback_interface) and - getattr(self, '_ephemeral_dhcp_ctx', None)) + use_cached_ephemeral = ( + self.distro.networking.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None)) if use_cached_ephemeral: self._report_ready(lease=self._ephemeral_dhcp_ctx.lease) self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral @@ -560,9 +563,14 @@ class DataSourceAzure(sources.DataSource): logfunc=LOG.debug, msg='Crawl of metadata service', func=self.crawl_metadata ) - except sources.InvalidMetaDataException as e: - LOG.warning('Could not crawl Azure metadata: %s', e) + except Exception as e: + report_diagnostic_event( + 'Could not crawl Azure metadata: %s' % e, + logger_func=LOG.error) + self._report_failure( + description=DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) return False + if (self.distro and self.distro.name == 'ubuntu' and self.ds_cfg.get('apply_network_config')): maybe_remove_ubuntu_network_config_scripts() @@ -785,6 +793,61 @@ class DataSourceAzure(sources.DataSource): return return_val @azure_ds_telemetry_reporter + def _report_failure(self, description=None) -> bool: + """Tells the Azure fabric that provisioning has failed. + + @param description: A description of the error encountered. + @return: The success status of sending the failure signal. 
+ """ + unknown_245_key = 'unknown-245' + + try: + if (self.distro.networking.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None) and + getattr(self._ephemeral_dhcp_ctx, 'lease', None) and + unknown_245_key in self._ephemeral_dhcp_ctx.lease): + report_diagnostic_event( + 'Using cached ephemeral dhcp context ' + 'to report failure to Azure', logger_func=LOG.debug) + report_failure_to_fabric( + dhcp_opts=self._ephemeral_dhcp_ctx.lease[unknown_245_key], + description=description) + self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using ' + 'cached ephemeral dhcp context: %s' % e, + logger_func=LOG.error) + + try: + report_diagnostic_event( + 'Using new ephemeral dhcp to report failure to Azure', + logger_func=LOG.debug) + with EphemeralDHCPv4WithReporting(azure_ds_reporter) as lease: + report_failure_to_fabric( + dhcp_opts=lease[unknown_245_key], + description=description) + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using new ephemeral dhcp: %s' % e, + logger_func=LOG.debug) + + try: + report_diagnostic_event( + 'Using fallback lease to report failure to Azure') + report_failure_to_fabric( + fallback_lease_file=self.dhclient_lease_file, + description=description) + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using fallback lease: %s' % e, + logger_func=LOG.debug) + + return False + def _report_ready(self, lease: dict) -> bool: """Tells the fabric provisioning has completed. diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 4071a50e..951c7a10 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -17,6 +17,7 @@ from cloudinit import stages from cloudinit import temp_utils from contextlib import contextmanager from xml.etree import ElementTree +from xml.sax.saxutils import escape from cloudinit import subp from cloudinit import url_helper @@ -50,6 +51,11 @@ azure_ds_reporter = events.ReportEventStack( description="initialize reporter for azure ds", reporting_enabled=True) +DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE = ( + 'The VM encountered an error during deployment. 
' + 'Please visit https://aka.ms/linuxprovisioningerror ' + 'for more information on remediation.') + def azure_ds_telemetry_reporter(func): def impl(*args, **kwargs): @@ -379,12 +385,20 @@ class OpenSSLManager: def __init__(self): self.tmpdir = temp_utils.mkdtemp() - self.certificate = None + self._certificate = None self.generate_certificate() def clean_up(self): util.del_dir(self.tmpdir) + @property + def certificate(self): + return self._certificate + + @certificate.setter + def certificate(self, value): + self._certificate = value + @azure_ds_telemetry_reporter def generate_certificate(self): LOG.debug('Generating certificate for communication with fabric...') @@ -507,6 +521,10 @@ class GoalStateHealthReporter: ''') PROVISIONING_SUCCESS_STATUS = 'Ready' + PROVISIONING_NOT_READY_STATUS = 'NotReady' + PROVISIONING_FAILURE_SUBSTATUS = 'ProvisioningFailed' + + HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512 def __init__( self, goal_state: GoalState, @@ -545,19 +563,39 @@ class GoalStateHealthReporter: LOG.info('Reported ready to Azure fabric.') + @azure_ds_telemetry_reporter + def send_failure_signal(self, description: str) -> None: + document = self.build_report( + incarnation=self._goal_state.incarnation, + container_id=self._goal_state.container_id, + instance_id=self._goal_state.instance_id, + status=self.PROVISIONING_NOT_READY_STATUS, + substatus=self.PROVISIONING_FAILURE_SUBSTATUS, + description=description) + try: + self._post_health_report(document=document) + except Exception as e: + msg = "exception while reporting failure: %s" % e + report_diagnostic_event(msg, logger_func=LOG.error) + raise + + LOG.warning('Reported failure to Azure fabric.') + def build_report( self, incarnation: str, container_id: str, instance_id: str, status: str, substatus=None, description=None) -> str: health_detail = '' if substatus is not None: health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( - health_substatus=substatus, health_description=description) + health_substatus=escape(substatus), + health_description=escape( + description[:self.HEALTH_REPORT_DESCRIPTION_TRIM_LEN])) health_report = self.HEALTH_REPORT_XML_TEMPLATE.format( - incarnation=incarnation, - container_id=container_id, - instance_id=instance_id, - health_status=status, + incarnation=escape(str(incarnation)), + container_id=escape(container_id), + instance_id=escape(instance_id), + health_status=escape(status), health_detail_subsection=health_detail) return health_report @@ -797,6 +835,20 @@ class WALinuxAgentShim: health_reporter.send_ready_signal() return {'public-keys': ssh_keys} + @azure_ds_telemetry_reporter + def register_with_azure_and_report_failure(self, description: str) -> None: + """Gets the VM's GoalState from Azure, uses the GoalState information + to report failure/send provisioning failure signal to Azure. + + @param: user visible error description of provisioning failure. + """ + if self.azure_endpoint_client is None: + self.azure_endpoint_client = AzureEndpointHttpClient(None) + goal_state = self._fetch_goal_state_from_azure(need_certificate=False) + health_reporter = GoalStateHealthReporter( + goal_state, self.azure_endpoint_client, self.endpoint) + health_reporter.send_failure_signal(description=description) + @azure_ds_telemetry_reporter def _fetch_goal_state_from_azure( self, @@ -804,6 +856,7 @@ class WALinuxAgentShim: """Fetches the GoalState XML from the Azure endpoint, parses the XML, and returns a GoalState object. + @param need_certificate: switch to know if certificates is needed. 
@return: GoalState object representing the GoalState XML """ unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() @@ -844,6 +897,7 @@ class WALinuxAgentShim: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_goal_state_xml: GoalState XML string + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML """ try: @@ -942,6 +996,20 @@ def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, shim.clean_up() +@azure_ds_telemetry_reporter +def report_failure_to_fabric(fallback_lease_file=None, dhcp_opts=None, + description=None): + shim = WALinuxAgentShim(fallback_lease_file=fallback_lease_file, + dhcp_options=dhcp_opts) + if not description: + description = DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE + try: + shim.register_with_azure_and_report_failure( + description=description) + finally: + shim.clean_up() + + def dhcp_log_cb(out, err): report_diagnostic_event( "dhclient output stream: %s" % out, logger_func=LOG.debug) diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 433fbc66..d9752ab7 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -461,6 +461,8 @@ class TestGetMetadataFromIMDS(HttprettyTestCase): class TestAzureDataSource(CiTestCase): + with_logs = True + def setUp(self): super(TestAzureDataSource, self).setUp() self.tmp = self.tmp_dir() @@ -549,9 +551,12 @@ scbus-1 on xpt0 bus 0 dsaz.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d - self.get_metadata_from_fabric = mock.MagicMock(return_value={ - 'public-keys': [], - }) + self.m_is_platform_viable = mock.MagicMock(autospec=True) + self.m_get_metadata_from_fabric = mock.MagicMock( + return_value={'public-keys': []}) + self.m_report_failure_to_fabric = mock.MagicMock(autospec=True) + self.m_ephemeral_dhcpv4 = mock.MagicMock() + self.m_ephemeral_dhcpv4_with_reporting = mock.MagicMock() self.instance_id = 'D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8' @@ -568,7 +573,17 @@ scbus-1 on xpt0 bus 0 (dsaz, 'perform_hostname_bounce', mock.MagicMock()), (dsaz, 'get_hostname', mock.MagicMock()), (dsaz, 'set_hostname', mock.MagicMock()), - (dsaz, 'get_metadata_from_fabric', self.get_metadata_from_fabric), + (dsaz, '_is_platform_viable', + self.m_is_platform_viable), + (dsaz, 'get_metadata_from_fabric', + self.m_get_metadata_from_fabric), + (dsaz, 'report_failure_to_fabric', + self.m_report_failure_to_fabric), + (dsaz, 'EphemeralDHCPv4', self.m_ephemeral_dhcpv4), + (dsaz, 'EphemeralDHCPv4WithReporting', + self.m_ephemeral_dhcpv4_with_reporting), + (dsaz, 'get_boot_telemetry', mock.MagicMock()), + (dsaz, 'get_system_info', mock.MagicMock()), (dsaz.subp, 'which', lambda x: True), (dsaz.dmi, 'read_dmi_data', mock.MagicMock( side_effect=_dmi_mocks)), @@ -632,15 +647,87 @@ scbus-1 on xpt0 bus 0 dev = ds.get_resource_disk_on_freebsd(1) self.assertEqual("da1", dev) - @mock.patch(MOCKPATH + '_is_platform_viable') - def test_call_is_platform_viable_seed(self, m_is_platform_viable): + def test_not_is_platform_viable_seed_should_return_no_datasource(self): """Check seed_dir using _is_platform_viable and return False.""" # Return a non-matching asset tag value - m_is_platform_viable.return_value = False - dsrc = dsaz.DataSourceAzure( - {}, distro=mock.Mock(), paths=self.paths) - self.assertFalse(dsrc.get_data()) - m_is_platform_viable.assert_called_with(dsrc.seed_dir) + data = {} + dsrc = self._get_ds(data) + 
self.m_is_platform_viable.return_value = False + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_report_failure') as m_report_failure: + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + # Assert that for non viable platforms, + # there is no communication with the Azure datasource. + self.assertEqual( + 0, + m_crawl_metadata.call_count) + self.assertEqual( + 0, + m_report_failure.call_count) + + def test_platform_viable_but_no_devs_should_return_no_datasource(self): + """For platforms where the Azure platform is viable + (which is indicated by the matching asset tag), + the absence of any devs at all (devs == candidate sources + for crawling Azure datasource) is NOT expected. + Report failure to Azure as this is an unexpected fatal error. + """ + data = {} + dsrc = self._get_ds(data) + with mock.patch.object(dsrc, '_report_failure') as m_report_failure: + self.m_is_platform_viable.return_value = True + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + self.assertEqual( + 1, + m_report_failure.call_count) + + def test_crawl_metadata_exception_returns_no_datasource(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertEqual( + 1, + m_crawl_metadata.call_count) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + + def test_crawl_metadata_exception_should_report_failure_with_msg(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_report_failure') as m_report_failure: + m_crawl_metadata.side_effect = Exception + dsrc.get_data() + self.assertEqual( + 1, + m_crawl_metadata.call_count) + m_report_failure.assert_called_once_with( + description=dsaz.DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_crawl_metadata_exc_should_log_could_not_crawl_msg(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + dsrc.get_data() + self.assertEqual( + 1, + m_crawl_metadata.call_count) + self.assertIn( + "Could not crawl Azure metadata", + self.logs.getvalue()) def test_basic_seed_dir(self): odata = {'HostName': "myhost", 'UserName': "myuser"} @@ -761,7 +848,7 @@ scbus-1 on xpt0 bus 0 'cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready') @mock.patch('cloudinit.sources.DataSourceAzure.DataSourceAzure._poll_imds') def test_crawl_metadata_on_reprovision_reports_ready( - self, poll_imds_func, report_ready_func, m_write, m_dhcp + self, poll_imds_func, m_report_ready, m_write, m_dhcp ): """If reprovisioning, report ready at the end""" ovfenv = construct_valid_ovf_env( @@ -775,18 +862,16 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) poll_imds_func.return_value = ovfenv dsrc.crawl_metadata() - self.assertEqual(1, report_ready_func.call_count) + self.assertEqual(1, m_report_ready.call_count) @mock.patch('cloudinit.sources.DataSourceAzure.util.write_file') 
@mock.patch('cloudinit.sources.helpers.netlink.' 'wait_for_media_disconnect_connect') @mock.patch( 'cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready') - @mock.patch('cloudinit.net.dhcp.EphemeralIPv4Network') - @mock.patch('cloudinit.net.dhcp.maybe_perform_dhcp_discovery') @mock.patch('cloudinit.sources.DataSourceAzure.readurl') def test_crawl_metadata_on_reprovision_reports_ready_using_lease( - self, m_readurl, m_dhcp, m_net, report_ready_func, + self, m_readurl, m_report_ready, m_media_switch, m_write ): """If reprovisioning, report ready using the obtained lease""" @@ -800,20 +885,30 @@ scbus-1 on xpt0 bus 0 } dsrc = self._get_ds(data) - lease = { - 'interface': 'eth9', 'fixed-address': '192.168.2.9', - 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', - 'unknown-245': '624c3620'} - m_dhcp.return_value = [lease] - m_media_switch.return_value = None + with mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: - reprovision_ovfenv = construct_valid_ovf_env() - m_readurl.return_value = url_helper.StringResponse( - reprovision_ovfenv.encode('utf-8')) + # For this mock, net should not be up, + # so that cached ephemeral won't be used. + # This is so that a NEW ephemeral dhcp lease will be discovered + # and used instead. + m_dsrc_distro_networking_is_up.return_value = False - dsrc.crawl_metadata() - self.assertEqual(2, report_ready_func.call_count) - report_ready_func.assert_called_with(lease=lease) + lease = { + 'interface': 'eth9', 'fixed-address': '192.168.2.9', + 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', + 'unknown-245': '624c3620'} + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.return_value = lease + m_media_switch.return_value = None + + reprovision_ovfenv = construct_valid_ovf_env() + m_readurl.return_value = url_helper.StringResponse( + reprovision_ovfenv.encode('utf-8')) + + dsrc.crawl_metadata() + self.assertEqual(2, m_report_ready.call_count) + m_report_ready.assert_called_with(lease=lease) def test_waagent_d_has_0700_perms(self): # we expect /var/lib/waagent to be created 0700 @@ -971,7 +1066,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertTrue('default_user' in dsrc.cfg['system_info']) + self.assertIn('default_user', dsrc.cfg['system_info']) defuser = dsrc.cfg['system_info']['default_user'] # default user should be updated username and should not be locked. @@ -993,7 +1088,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertTrue('default_user' in dsrc.cfg['system_info']) + self.assertIn('default_user', dsrc.cfg['system_info']) defuser = dsrc.cfg['system_info']['default_user'] # default user should be updated username and should not be locked. 
@@ -1021,14 +1116,6 @@ scbus-1 on xpt0 bus 0 self.assertTrue(ret) self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8')) - def test_no_datasource_expected(self): - # no source should be found if no seed_dir and no devs - data = {} - dsrc = self._get_ds({}) - ret = dsrc.get_data() - self.assertFalse(ret) - self.assertFalse('agent_invoked' in data) - def test_cfg_has_pubkeys_fingerprint(self): odata = {'HostName': "myhost", 'UserName': "myuser"} mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}] @@ -1171,21 +1258,168 @@ scbus-1 on xpt0 bus 0 self): dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.side_effect = Exception + self.m_get_metadata_from_fabric.side_effect = Exception self.assertFalse(dsrc._report_ready(lease=mock.MagicMock())) + def test_dsaz_report_failure_returns_true_when_report_succeeds(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) + self.assertEqual( + 1, + self.m_report_failure_to_fabric.call_count) + + def test_dsaz_report_failure_returns_false_and_does_not_propagate_exc( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_ephemeral_dhcp_ctx') \ + as m_ephemeral_dhcp_ctx, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # setup mocks to allow using cached ephemeral dhcp lease + m_dsrc_distro_networking_is_up.return_value = True + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + m_ephemeral_dhcp_ctx.lease = test_lease + + # We expect 3 calls to report_failure_to_fabric, + # because we try 3 different methods of calling report failure. + # The different methods are attempted in the following order: + # 1. Using cached ephemeral dhcp context to report failure to Azure + # 2. Using new ephemeral dhcp to report failure to Azure + # 3. 
Using fallback lease to report failure to Azure + self.m_report_failure_to_fabric.side_effect = Exception + self.assertFalse(dsrc._report_failure()) + self.assertEqual( + 3, + self.m_report_failure_to_fabric.call_count) + + def test_dsaz_report_failure_description_msg(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + test_msg = 'Test report failure description message' + self.assertTrue(dsrc._report_failure(description=test_msg)) + self.m_report_failure_to_fabric.assert_called_once_with( + dhcp_opts=mock.ANY, description=test_msg) + + def test_dsaz_report_failure_no_description_msg(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) # no description msg + self.m_report_failure_to_fabric.assert_called_once_with( + dhcp_opts=mock.ANY, description=None) + + def test_dsaz_report_failure_uses_cached_ephemeral_dhcp_ctx_lease(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_ephemeral_dhcp_ctx') \ + as m_ephemeral_dhcp_ctx, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # setup mocks to allow using cached ephemeral dhcp lease + m_dsrc_distro_networking_is_up.return_value = True + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + m_ephemeral_dhcp_ctx.lease = test_lease + + self.assertTrue(dsrc._report_failure()) + + # ensure called with cached ephemeral dhcp lease option 245 + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, dhcp_opts=test_lease_dhcp_option_245) + + # ensure cached ephemeral is cleaned + self.assertEqual( + 1, + m_ephemeral_dhcp_ctx.clean_network.call_count) + + def test_dsaz_report_failure_no_net_uses_new_ephemeral_dhcp_lease(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # net is not up and cannot use cached ephemeral dhcp + m_dsrc_distro_networking_is_up.return_value = False + # setup ephemeral dhcp lease discovery mock + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.return_value = test_lease + + self.assertTrue(dsrc._report_failure()) + + # ensure called with the newly discovered + # ephemeral dhcp lease option 245 + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, dhcp_opts=test_lease_dhcp_option_245) + + def test_dsaz_report_failure_no_net_and_no_dhcp_uses_fallback_lease( + self): + dsrc = 
self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # net is not up and cannot use cached ephemeral dhcp + m_dsrc_distro_networking_is_up.return_value = False + # ephemeral dhcp discovery failure, + # so cannot use a new ephemeral dhcp + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) + + # ensure called with fallback lease + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, + fallback_lease_file=dsrc.dhclient_lease_file) + def test_exception_fetching_fabric_data_doesnt_propagate(self): """Errors communicating with fabric should warn, but return True.""" dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.side_effect = Exception + self.m_get_metadata_from_fabric.side_effect = Exception ret = self._get_and_setup(dsrc) self.assertTrue(ret) def test_fabric_data_included_in_metadata(self): dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.return_value = {'test': 'value'} + self.m_get_metadata_from_fabric.return_value = {'test': 'value'} ret = self._get_and_setup(dsrc) self.assertTrue(ret) self.assertEqual('value', dsrc.metadata['test']) @@ -2053,7 +2287,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): @mock.patch('time.sleep', mock.MagicMock()) @mock.patch(MOCKPATH + 'EphemeralDHCPv4') - def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, report_ready_func, + def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, m_report_ready, m_request, m_media_switch, m_dhcp, m_net): """The poll_imds will retry DHCP on IMDS timeout.""" @@ -2088,8 +2322,8 @@ class TestPreprovisioningPollIMDS(CiTestCase): dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): dsa._poll_imds() - self.assertEqual(report_ready_func.call_count, 1) - report_ready_func.assert_called_with(lease=lease) + self.assertEqual(m_report_ready.call_count, 1) + m_report_ready.assert_called_with(lease=lease) self.assertEqual(3, m_dhcpv4.call_count, 'Expected 3 DHCP calls') self.assertEqual(4, self.tries, 'Expected 4 total reads from IMDS') diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 6e004e34..adf68857 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -5,6 +5,7 @@ import re import unittest from textwrap import dedent from xml.etree import ElementTree +from xml.sax.saxutils import escape, unescape from cloudinit.sources.helpers import azure as azure_helper from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -70,6 +71,15 @@ HEALTH_REPORT_XML_TEMPLATE = '''\ ''' +HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = dedent('''\ +
<Details> + <SubStatus>{health_substatus}</SubStatus> + <Description>{health_description}</Description> + </Details>
+ ''') + +HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512 + class SentinelException(Exception): pass @@ -461,17 +471,24 @@ class TestOpenSSLManagerActions(CiTestCase): class TestGoalStateHealthReporter(CiTestCase): + maxDiff = None + default_parameters = { 'incarnation': 1634, 'container_id': 'MyContainerId', 'instance_id': 'MyInstanceId' } - test_endpoint = 'TestEndpoint' - test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) + test_azure_endpoint = 'TestEndpoint' + test_health_report_url = 'http://{0}/machine?comp=health'.format( + test_azure_endpoint) test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} provisioning_success_status = 'Ready' + provisioning_not_ready_status = 'NotReady' + provisioning_failure_substatus = 'ProvisioningFailed' + provisioning_failure_err_description = ( + 'Test error message containing provisioning failure details') def setUp(self): super(TestGoalStateHealthReporter, self).setUp() @@ -496,17 +513,40 @@ class TestGoalStateHealthReporter(CiTestCase): self.GoalState.return_value.incarnation = \ self.default_parameters['incarnation'] + def _text_from_xpath_in_xroot(self, xroot, xpath): + element = xroot.find(xpath) + if element is not None: + return element.text + return None + def _get_formatted_health_report_xml_string(self, **kwargs): return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + def _get_formatted_health_detail_subsection_xml_string(self, **kwargs): + return HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format(**kwargs) + def _get_report_ready_health_document(self): return self._get_formatted_health_report_xml_string( - incarnation=self.default_parameters['incarnation'], - container_id=self.default_parameters['container_id'], - instance_id=self.default_parameters['instance_id'], - health_status=self.provisioning_success_status, + incarnation=escape(str(self.default_parameters['incarnation'])), + container_id=escape(self.default_parameters['container_id']), + instance_id=escape(self.default_parameters['instance_id']), + health_status=escape(self.provisioning_success_status), health_detail_subsection='') + def _get_report_failure_health_document(self): + health_detail_subsection = \ + self._get_formatted_health_detail_subsection_xml_string( + health_substatus=escape(self.provisioning_failure_substatus), + health_description=escape( + self.provisioning_failure_err_description)) + + return self._get_formatted_health_report_xml_string( + incarnation=escape(str(self.default_parameters['incarnation'])), + container_id=escape(self.default_parameters['container_id']), + instance_id=escape(self.default_parameters['instance_id']), + health_status=escape(self.provisioning_not_ready_status), + health_detail_subsection=health_detail_subsection) + def test_send_ready_signal_sends_post_request(self): with mock.patch.object( azure_helper.GoalStateHealthReporter, @@ -514,55 +554,130 @@ class TestGoalStateHealthReporter(CiTestCase): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), - client, self.test_endpoint) + client, self.test_azure_endpoint) reporter.send_ready_signal() self.assertEqual(1, self.post.call_count) self.assertEqual( mock.call( - self.test_url, + self.test_health_report_url, + data=m_build_report.return_value, + extra_headers=self.test_default_headers), + self.post.call_args) + + def test_send_failure_signal_sends_post_request(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, + 'build_report') as 
m_build_report: + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + client, self.test_azure_endpoint) + reporter.send_failure_signal( + description=self.provisioning_failure_err_description) + + self.assertEqual(1, self.post.call_count) + self.assertEqual( + mock.call( + self.test_health_report_url, data=m_build_report.return_value, extra_headers=self.test_default_headers), self.post.call_args) - def test_build_report_for_health_document(self): + def test_build_report_for_ready_signal_health_document(self): health_document = self._get_report_ready_health_document() reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), azure_helper.AzureEndpointHttpClient(mock.MagicMock()), - self.test_endpoint) + self.test_azure_endpoint) generated_health_document = reporter.build_report( incarnation=self.default_parameters['incarnation'], container_id=self.default_parameters['container_id'], instance_id=self.default_parameters['instance_id'], status=self.provisioning_success_status) + self.assertEqual(health_document, generated_health_document) - self.assertIn( - '{}'.format( - str(self.default_parameters['incarnation'])), - generated_health_document) - self.assertIn( - ''.join([ - '', - self.default_parameters['container_id'], - '']), - generated_health_document) - self.assertIn( - ''.join([ - '', - self.default_parameters['instance_id'], - '']), - generated_health_document) - self.assertIn( - ''.join([ - '', - self.provisioning_success_status, - '']), - generated_health_document + + generated_xroot = ElementTree.fromstring(generated_health_document) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './GoalStateIncarnation'), + str(self.default_parameters['incarnation'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './Container/ContainerId'), + str(self.default_parameters['container_id'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/InstanceId'), + str(self.default_parameters['instance_id'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/State'), + escape(self.provisioning_success_status)) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details')) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/SubStatus')) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') ) - self.assertNotIn('
', generated_health_document) - self.assertNotIn('', generated_health_document) - self.assertNotIn('', generated_health_document) + + def test_build_report_for_failure_signal_health_document(self): + health_document = self._get_report_failure_health_document() + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=self.provisioning_failure_err_description) + + self.assertEqual(health_document, generated_health_document) + + generated_xroot = ElementTree.fromstring(generated_health_document) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './GoalStateIncarnation'), + str(self.default_parameters['incarnation'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './Container/ContainerId'), + self.default_parameters['container_id']) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/InstanceId'), + self.default_parameters['instance_id']) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/State'), + escape(self.provisioning_not_ready_status)) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/' + 'SubStatus'), + escape(self.provisioning_failure_substatus)) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/' + 'Description'), + escape(self.provisioning_failure_err_description)) def test_send_ready_signal_calls_build_report(self): with mock.patch.object( @@ -571,7 +686,7 @@ class TestGoalStateHealthReporter(CiTestCase): reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), azure_helper.AzureEndpointHttpClient(mock.MagicMock()), - self.test_endpoint) + self.test_azure_endpoint) reporter.send_ready_signal() self.assertEqual(1, m_build_report.call_count) @@ -583,6 +698,131 @@ class TestGoalStateHealthReporter(CiTestCase): status=self.provisioning_success_status), m_build_report.call_args) + def test_send_failure_signal_calls_build_report(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, 'build_report' + ) as m_build_report: + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + reporter.send_failure_signal( + description=self.provisioning_failure_err_description) + + self.assertEqual(1, m_build_report.call_count) + self.assertEqual( + mock.call( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=self.provisioning_failure_err_description), + m_build_report.call_args) + + def test_build_report_escapes_chars(self): + incarnation = 'jd8\'9*&^<\'A>' + instance_id = 
'Opo>>>jas\'&d;[p&fp\"a<&aa\'sd!@&!)((*<&>' + health_substatus = '&as\"d<d<\'^@!5&6<7' + health_description = '&&&>!#$\"&&><>&\"sd<67<]>>' + + health_detail_subsection = \ + self._get_formatted_health_detail_subsection_xml_string( + health_substatus=escape(health_substatus), + health_description=escape(health_description)) + health_document = self._get_formatted_health_report_xml_string( + incarnation=escape(incarnation), + container_id=escape(container_id), + instance_id=escape(instance_id), + health_status=escape(health_status), + health_detail_subsection=health_detail_subsection) + + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + generated_health_document = reporter.build_report( + incarnation=incarnation, + container_id=container_id, + instance_id=instance_id, + status=health_status, + substatus=health_substatus, + description=health_description) + + self.assertEqual(health_document, generated_health_document) + + def test_build_report_conforms_to_length_limits(self): + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + long_err_msg = 'a9&ea8>>>e as1< d\"q2*&(^%\'a=5<' * 100 + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=long_err_msg) + + generated_xroot = ElementTree.fromstring(generated_health_document) + generated_health_report_description = self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') + self.assertEqual( + len(unescape(generated_health_report_description)), + HEALTH_REPORT_DESCRIPTION_TRIM_LEN) + + def test_trim_description_then_escape_conforms_to_len_limits_worst_case( + self): + """When unescaped characters are XML-escaped, the length increases. + Char Escape String + < < + > > + " " + ' ' + & & + + We (step 1) trim the health report XML's description field, + and then (step 2) XML-escape the health report XML's description field. + + The health report XML's description field limit within cloud-init + is HEALTH_REPORT_DESCRIPTION_TRIM_LEN. + + The Azure platform's limit on the health report XML's description field + is 4096 chars. + + For worst-case chars, there is a 5x blowup in length + when the chars are XML-escaped. + ' and " when XML-escaped have a 5x blowup. + + Ensure that (1) trimming and then (2) XML-escaping does not blow past + the Azure platform's limit for health report XML's description field + (4096 chars). 
+ """ + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + long_err_msg = '\'\"' * 10000 + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=long_err_msg) + + generated_xroot = ElementTree.fromstring(generated_health_document) + generated_health_report_description = self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') + # The escaped description string should be less than + # the Azure platform limit for the escaped description string. + self.assertLessEqual(len(generated_health_report_description), 4096) + class TestWALinuxAgentShim(CiTestCase): @@ -598,7 +838,7 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState = patches.enter_context( mock.patch.object(azure_helper, 'GoalState')) self.OpenSSLManager = patches.enter_context( - mock.patch.object(azure_helper, 'OpenSSLManager')) + mock.patch.object(azure_helper, 'OpenSSLManager', autospec=True)) patches.enter_context( mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) @@ -609,24 +849,47 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.return_value.container_id = self.test_container_id self.GoalState.return_value.instance_id = self.test_instance_id - def test_http_client_does_not_use_certificate(self): + def test_http_client_does_not_use_certificate_for_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(None)], self.AzureEndpointHttpClient.call_args_list) + def test_http_client_does_not_use_certificate_for_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + self.assertEqual( + [mock.call(None)], + self.AzureEndpointHttpClient.call_args_list) + def test_correct_url_used_for_goalstate_during_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = wa_shim() shim.register_with_azure_and_fetch_data() - get = self.AzureEndpointHttpClient.return_value.get + m_get = self.AzureEndpointHttpClient.return_value.get + self.assertEqual( + [mock.call('http://test_endpoint/machine/?comp=goalstate')], + m_get.call_args_list) + self.assertEqual( + [mock.call( + m_get.return_value.contents, + self.AzureEndpointHttpClient.return_value, + False + )], + self.GoalState.call_args_list) + + def test_correct_url_used_for_goalstate_during_report_failure(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + m_get = self.AzureEndpointHttpClient.return_value.get self.assertEqual( [mock.call('http://test_endpoint/machine/?comp=goalstate')], - get.call_args_list) + m_get.call_args_list) self.assertEqual( [mock.call( - get.return_value.contents, + m_get.return_value.contents, self.AzureEndpointHttpClient.return_value, False )], @@ -670,6 +933,16 @@ class TestWALinuxAgentShim(CiTestCase): self.AzureEndpointHttpClient.return_value.post .call_args_list) + def test_correct_url_used_for_report_failure(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = wa_shim() + 
shim.register_with_azure_and_report_failure(description='TestDesc') + expected_url = 'http://test_endpoint/machine?comp=health' + self.assertEqual( + [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], + self.AzureEndpointHttpClient.return_value.post + .call_args_list) + def test_goal_state_values_used_for_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() @@ -681,44 +954,128 @@ class TestWALinuxAgentShim(CiTestCase): self.assertIn(self.test_container_id, posted_document) self.assertIn(self.test_instance_id, posted_document) - def test_xml_elems_in_report_ready(self): + def test_goal_state_values_used_for_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data'] + ) + self.assertIn(self.test_incarnation, posted_document) + self.assertIn(self.test_container_id, posted_document) + self.assertIn(self.test_instance_id, posted_document) + + def test_xml_elems_in_report_ready_post(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() health_document = HEALTH_REPORT_XML_TEMPLATE.format( - incarnation=self.test_incarnation, - container_id=self.test_container_id, - instance_id=self.test_instance_id, - health_status='Ready', + incarnation=escape(self.test_incarnation), + container_id=escape(self.test_container_id), + instance_id=escape(self.test_instance_id), + health_status=escape('Ready'), health_detail_subsection='') posted_document = ( self.AzureEndpointHttpClient.return_value.post .call_args[1]['data']) self.assertEqual(health_document, posted_document) + def test_xml_elems_in_report_failure_post(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + health_document = HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=escape(self.test_incarnation), + container_id=escape(self.test_container_id), + instance_id=escape(self.test_instance_id), + health_status=escape('NotReady'), + health_detail_subsection=HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE + .format( + health_substatus=escape('ProvisioningFailed'), + health_description=escape('TestDesc'))) + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data']) + self.assertEqual(health_document, posted_document) + + @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True) + def test_register_with_azure_and_fetch_data_calls_send_ready_signal( + self, m_goal_state_health_reporter): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + self.assertEqual( + 1, + m_goal_state_health_reporter.return_value.send_ready_signal + .call_count) + + @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True) + def test_register_with_azure_and_report_failure_calls_send_failure_signal( + self, m_goal_state_health_reporter): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + m_goal_state_health_reporter.return_value.send_failure_signal \ + .assert_called_once_with(description='TestDesc') + + def test_register_with_azure_and_report_failure_does_not_need_certificates( + self): + shim = wa_shim() + with mock.patch.object( + shim, '_fetch_goal_state_from_azure', autospec=True + ) as m_fetch_goal_state_from_azure: + shim.register_with_azure_and_report_failure(description='TestDesc') + m_fetch_goal_state_from_azure.assert_called_once_with( + need_certificate=False) + def test_clean_up_can_be_called_at_any_time(self): shim = 
wa_shim() shim.clean_up() + def test_openssl_manager_not_instantiated_by_shim_report_status(self): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + shim.register_with_azure_and_report_failure(description='TestDesc') + shim.clean_up() + self.OpenSSLManager.assert_not_called() + def test_clean_up_after_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() shim.clean_up() - self.assertEqual( - 0, self.OpenSSLManager.return_value.clean_up.call_count) + self.OpenSSLManager.return_value.clean_up.assert_not_called() + + def test_clean_up_after_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + shim.clean_up() + self.OpenSSLManager.return_value.clean_up.assert_not_called() def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): self.AzureEndpointHttpClient.return_value.get \ - .side_effect = (SentinelException) + .side_effect = SentinelException shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_failure_raises_exc_on_get_exc(self): + self.AzureEndpointHttpClient.return_value.get \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): self.GoalState.side_effect = SentinelException shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_failure_raises_exc_on_parse_exc( + self): + self.GoalState.side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + def test_failure_to_send_report_ready_health_doc_bubbles_up(self): self.AzureEndpointHttpClient.return_value.post \ .side_effect = SentinelException @@ -726,56 +1083,132 @@ class TestWALinuxAgentShim(CiTestCase): self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_failure_to_send_report_failure_health_doc_bubbles_up(self): + self.AzureEndpointHttpClient.return_value.post \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_data_from_shim_returned(self, shim): + def setUp(self): + super(TestGetMetadataGoalStateXMLAndReportReadyToFabric, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_shim = patches.enter_context( + mock.patch.object(azure_helper, 'WALinuxAgentShim')) + + def test_data_from_shim_returned(self): ret = azure_helper.get_metadata_from_fabric() self.assertEqual( - shim.return_value.register_with_azure_and_fetch_data.return_value, + self.m_shim.return_value.register_with_azure_and_fetch_data + .return_value, ret) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_success_calls_clean_up(self, shim): + def test_success_calls_clean_up(self): azure_helper.get_metadata_from_fabric() - self.assertEqual(1, shim.return_value.clean_up.call_count) + self.assertEqual(1, self.m_shim.return_value.clean_up.call_count) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_failure_in_registration_propagates_exc_and_calls_clean_up( - self, shim): 
- shim.return_value.register_with_azure_and_fetch_data.side_effect = ( - SentinelException) + self): + self.m_shim.return_value.register_with_azure_and_fetch_data \ + .side_effect = SentinelException self.assertRaises(SentinelException, azure_helper.get_metadata_from_fabric) - self.assertEqual(1, shim.return_value.clean_up.call_count) + self.assertEqual(1, self.m_shim.return_value.clean_up.call_count) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_calls_shim_register_with_azure_and_fetch_data(self, shim): + def test_calls_shim_register_with_azure_and_fetch_data(self): m_pubkey_info = mock.MagicMock() azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info) self.assertEqual( 1, - shim.return_value + self.m_shim.return_value .register_with_azure_and_fetch_data.call_count) self.assertEqual( mock.call(pubkey_info=m_pubkey_info), - shim.return_value + self.m_shim.return_value .register_with_azure_and_fetch_data.call_args) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_instantiates_shim_with_kwargs(self, shim): + def test_instantiates_shim_with_kwargs(self): m_fallback_lease_file = mock.MagicMock() m_dhcp_options = mock.MagicMock() azure_helper.get_metadata_from_fabric( fallback_lease_file=m_fallback_lease_file, dhcp_opts=m_dhcp_options) - self.assertEqual(1, shim.call_count) + self.assertEqual(1, self.m_shim.call_count) self.assertEqual( mock.call( fallback_lease_file=m_fallback_lease_file, dhcp_options=m_dhcp_options), - shim.call_args) + self.m_shim.call_args) + + +class TestGetMetadataGoalStateXMLAndReportFailureToFabric(CiTestCase): + + def setUp(self): + super( + TestGetMetadataGoalStateXMLAndReportFailureToFabric, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_shim = patches.enter_context( + mock.patch.object(azure_helper, 'WALinuxAgentShim')) + + def test_success_calls_clean_up(self): + azure_helper.report_failure_to_fabric() + self.assertEqual( + 1, + self.m_shim.return_value.clean_up.call_count) + + def test_failure_in_shim_report_failure_propagates_exc_and_calls_clean_up( + self): + self.m_shim.return_value.register_with_azure_and_report_failure \ + .side_effect = SentinelException + self.assertRaises(SentinelException, + azure_helper.report_failure_to_fabric) + self.assertEqual( + 1, + self.m_shim.return_value.clean_up.call_count) + + def test_report_failure_to_fabric_with_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric(description='TestDesc') + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with(description='TestDesc') + + def test_report_failure_to_fabric_with_no_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric() + # default err message description should be shown to the user + # if no description is passed in + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with( + description=azure_helper + .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_report_failure_to_fabric_empty_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric(description='') + # default err message description should be shown to the user + # if an empty description is passed in + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with( + description=azure_helper + .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_instantiates_shim_with_kwargs(self): + m_fallback_lease_file = mock.MagicMock() + m_dhcp_options = 
mock.MagicMock() + azure_helper.report_failure_to_fabric( + fallback_lease_file=m_fallback_lease_file, + dhcp_opts=m_dhcp_options) + self.m_shim.assert_called_once_with( + fallback_lease_file=m_fallback_lease_file, + dhcp_options=m_dhcp_options) class TestExtractIpAddressFromNetworkd(CiTestCase): -- cgit v1.2.3 From 6df0230b1201d6bed8661b19d8f3758797635377 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Wed, 18 Nov 2020 10:02:56 -0800 Subject: Azure helper: Increase Azure Endpoint HTTP retries (#619) Increase Azure Endpoint HTTP retries to handle occasional platform network blips. Introduce a common method http_with_retries in the azure.py helper, which will serve as the common HTTP request handler for all HTTP requests with the Azure endpoint. This method has builtin retries and reporting diagnostics logic. --- cloudinit/sources/helpers/azure.py | 55 ++++- .../unittests/test_datasource/test_azure_helper.py | 227 +++++++++++++++++---- 2 files changed, 241 insertions(+), 41 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 951c7a10..2b3303c7 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -284,6 +284,54 @@ def _get_dhcp_endpoint_option_name(): return azure_endpoint +@azure_ds_telemetry_reporter +def http_with_retries(url, **kwargs) -> str: + """Wrapper around url_helper.readurl() with custom telemetry logging + that url_helper.readurl() does not provide. + """ + exc = None + + max_readurl_attempts = 240 + default_readurl_timeout = 5 + periodic_logging_attempts = 12 + + if 'timeout' not in kwargs: + kwargs['timeout'] = default_readurl_timeout + + # remove kwargs that cause url_helper.readurl to retry, + # since we are already implementing our own retry logic. 
+ if kwargs.pop('retries', None): + LOG.warning( + 'Ignoring retries kwarg passed in for ' + 'communication with Azure endpoint.') + if kwargs.pop('infinite', None): + LOG.warning( + 'Ignoring infinite kwarg passed in for communication ' + 'with Azure endpoint.') + + for attempt in range(1, max_readurl_attempts + 1): + try: + ret = url_helper.readurl(url, **kwargs) + + report_diagnostic_event( + 'Successful HTTP request with Azure endpoint %s after ' + '%d attempts' % (url, attempt), + logger_func=LOG.debug) + + return ret + + except Exception as e: + exc = e + if attempt % periodic_logging_attempts == 0: + report_diagnostic_event( + 'Failed HTTP request with Azure endpoint %s during ' + 'attempt %d with exception: %s' % + (url, attempt, e), + logger_func=LOG.debug) + + raise exc + + class AzureEndpointHttpClient: headers = { @@ -302,16 +350,15 @@ class AzureEndpointHttpClient: if secure: headers = self.headers.copy() headers.update(self.extra_secure_headers) - return url_helper.readurl(url, headers=headers, - timeout=5, retries=10, sec_between=5) + return http_with_retries(url, headers=headers) def post(self, url, data=None, extra_headers=None): headers = self.headers if extra_headers is not None: headers = self.headers.copy() headers.update(extra_headers) - return url_helper.readurl(url, data=data, headers=headers, - timeout=5, retries=10, sec_between=5) + return http_with_retries( + url, data=data, headers=headers) class InvalidGoalStateXMLException(Exception): diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index adf68857..b8899807 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,5 +1,6 @@ # This file is part of cloud-init. See LICENSE file for license information. 
+import copy import os import re import unittest @@ -291,29 +292,25 @@ class TestAzureEndpointHttpClient(CiTestCase): super(TestAzureEndpointHttpClient, self).setUp() patches = ExitStack() self.addCleanup(patches.close) - - self.readurl = patches.enter_context( - mock.patch.object(azure_helper.url_helper, 'readurl')) - patches.enter_context( - mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + self.m_http_with_retries = patches.enter_context( + mock.patch.object(azure_helper, 'http_with_retries')) def test_non_secure_get(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) url = 'MyTestUrl' response = client.get(url, secure=False) - self.assertEqual(1, self.readurl.call_count) - self.assertEqual(self.readurl.return_value, response) + self.assertEqual(1, self.m_http_with_retries.call_count) + self.assertEqual(self.m_http_with_retries.return_value, response) self.assertEqual( - mock.call(url, headers=self.regular_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, headers=self.regular_headers), + self.m_http_with_retries.call_args) def test_non_secure_get_raises_exception(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException url = 'MyTestUrl' - with self.assertRaises(SentinelException): - client.get(url, secure=False) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises(SentinelException, client.get, url, secure=False) + self.assertEqual(1, self.m_http_with_retries.call_count) def test_secure_get(self): url = 'MyTestUrl' @@ -325,39 +322,37 @@ class TestAzureEndpointHttpClient(CiTestCase): }) client = azure_helper.AzureEndpointHttpClient(m_certificate) response = client.get(url, secure=True) - self.assertEqual(1, self.readurl.call_count) - self.assertEqual(self.readurl.return_value, response) + self.assertEqual(1, self.m_http_with_retries.call_count) + self.assertEqual(self.m_http_with_retries.return_value, response) self.assertEqual( - mock.call(url, headers=expected_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, headers=expected_headers), + self.m_http_with_retries.call_args) def test_secure_get_raises_exception(self): url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException - with self.assertRaises(SentinelException): - client.get(url, secure=True) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises(SentinelException, client.get, url, secure=True) + self.assertEqual(1, self.m_http_with_retries.call_count) def test_post(self): m_data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) response = client.post(url, data=m_data) - self.assertEqual(1, self.readurl.call_count) - self.assertEqual(self.readurl.return_value, response) + self.assertEqual(1, self.m_http_with_retries.call_count) + self.assertEqual(self.m_http_with_retries.return_value, response) self.assertEqual( - mock.call(url, data=m_data, headers=self.regular_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, data=m_data, headers=self.regular_headers), + self.m_http_with_retries.call_args) def test_post_raises_exception(self): m_data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException - with self.assertRaises(SentinelException): - 
client.post(url, data=m_data) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises(SentinelException, client.post, url, data=m_data) + self.assertEqual(1, self.m_http_with_retries.call_count) def test_post_with_extra_headers(self): url = 'MyTestUrl' @@ -366,21 +361,179 @@ class TestAzureEndpointHttpClient(CiTestCase): client.post(url, extra_headers=extra_headers) expected_headers = self.regular_headers.copy() expected_headers.update(extra_headers) - self.assertEqual(1, self.readurl.call_count) + self.assertEqual(1, self.m_http_with_retries.call_count) self.assertEqual( - mock.call(mock.ANY, data=mock.ANY, headers=expected_headers, - timeout=5, retries=10, sec_between=5), - self.readurl.call_args) + mock.call(url, data=mock.ANY, headers=expected_headers), + self.m_http_with_retries.call_args) def test_post_with_sleep_with_extra_headers_raises_exception(self): m_data = mock.MagicMock() url = 'MyTestUrl' extra_headers = {'test': 'header'} client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - self.readurl.side_effect = SentinelException - with self.assertRaises(SentinelException): - client.post( - url, data=m_data, extra_headers=extra_headers) + self.m_http_with_retries.side_effect = SentinelException + self.assertRaises( + SentinelException, client.post, + url, data=m_data, extra_headers=extra_headers) + self.assertEqual(1, self.m_http_with_retries.call_count) + + +class TestAzureHelperHttpWithRetries(CiTestCase): + + with_logs = True + + max_readurl_attempts = 240 + default_readurl_timeout = 5 + periodic_logging_attempts = 12 + + def setUp(self): + super(TestAzureHelperHttpWithRetries, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_readurl = patches.enter_context( + mock.patch.object( + azure_helper.url_helper, 'readurl', mock.MagicMock())) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + + def test_http_with_retries(self): + self.m_readurl.return_value = 'TestResp' + self.assertEqual( + azure_helper.http_with_retries('testurl'), + self.m_readurl.return_value) + self.assertEqual(self.m_readurl.call_count, 1) + + def test_http_with_retries_propagates_readurl_exc_and_logs_exc( + self): + self.m_readurl.side_effect = SentinelException + + self.assertRaises( + SentinelException, azure_helper.http_with_retries, 'testurl') + self.assertEqual(self.m_readurl.call_count, self.max_readurl_attempts) + + self.assertIsNotNone( + re.search( + r'Failed HTTP request with Azure endpoint \S* during ' + r'attempt \d+ with exception: \S*', + self.logs.getvalue())) + self.assertIsNone( + re.search( + r'Successful HTTP request with Azure endpoint \S* after ' + r'\d+ attempts', + self.logs.getvalue())) + + def test_http_with_retries_delayed_success_due_to_temporary_readurl_exc( + self): + self.m_readurl.side_effect = \ + [SentinelException] * self.periodic_logging_attempts + \ + ['TestResp'] + self.m_readurl.return_value = 'TestResp' + + response = azure_helper.http_with_retries('testurl') + self.assertEqual( + response, + self.m_readurl.return_value) + self.assertEqual( + self.m_readurl.call_count, + self.periodic_logging_attempts + 1) + + def test_http_with_retries_long_delay_logs_periodic_failure_msg(self): + self.m_readurl.side_effect = \ + [SentinelException] * self.periodic_logging_attempts + \ + ['TestResp'] + self.m_readurl.return_value = 'TestResp' + + azure_helper.http_with_retries('testurl') + + self.assertEqual( + self.m_readurl.call_count, + self.periodic_logging_attempts + 1) + 
self.assertIsNotNone( + re.search( + r'Failed HTTP request with Azure endpoint \S* during ' + r'attempt \d+ with exception: \S*', + self.logs.getvalue())) + self.assertIsNotNone( + re.search( + r'Successful HTTP request with Azure endpoint \S* after ' + r'\d+ attempts', + self.logs.getvalue())) + + def test_http_with_retries_short_delay_does_not_log_periodic_failure_msg( + self): + self.m_readurl.side_effect = \ + [SentinelException] * \ + (self.periodic_logging_attempts - 1) + \ + ['TestResp'] + self.m_readurl.return_value = 'TestResp' + + azure_helper.http_with_retries('testurl') + self.assertEqual( + self.m_readurl.call_count, + self.periodic_logging_attempts) + + self.assertIsNone( + re.search( + r'Failed HTTP request with Azure endpoint \S* during ' + r'attempt \d+ with exception: \S*', + self.logs.getvalue())) + self.assertIsNotNone( + re.search( + r'Successful HTTP request with Azure endpoint \S* after ' + r'\d+ attempts', + self.logs.getvalue())) + + def test_http_with_retries_calls_url_helper_readurl_with_args_kwargs(self): + testurl = mock.MagicMock() + kwargs = { + 'headers': mock.MagicMock(), + 'data': mock.MagicMock(), + # timeout kwarg should not be modified or deleted if present + 'timeout': mock.MagicMock() + } + azure_helper.http_with_retries(testurl, **kwargs) + self.m_readurl.assert_called_once_with(testurl, **kwargs) + + def test_http_with_retries_adds_timeout_kwarg_if_not_present(self): + testurl = mock.MagicMock() + kwargs = { + 'headers': mock.MagicMock(), + 'data': mock.MagicMock() + } + expected_kwargs = copy.deepcopy(kwargs) + expected_kwargs['timeout'] = self.default_readurl_timeout + + azure_helper.http_with_retries(testurl, **kwargs) + self.m_readurl.assert_called_once_with(testurl, **expected_kwargs) + + def test_http_with_retries_deletes_retries_kwargs_passed_in( + self): + """http_with_retries already implements retry logic, + so url_helper.readurl should not have retries. + http_with_retries should delete kwargs that + cause url_helper.readurl to retry. + """ + testurl = mock.MagicMock() + kwargs = { + 'headers': mock.MagicMock(), + 'data': mock.MagicMock(), + 'timeout': mock.MagicMock(), + 'retries': mock.MagicMock(), + 'infinite': mock.MagicMock() + } + expected_kwargs = copy.deepcopy(kwargs) + expected_kwargs.pop('retries', None) + expected_kwargs.pop('infinite', None) + + azure_helper.http_with_retries(testurl, **kwargs) + self.m_readurl.assert_called_once_with(testurl, **expected_kwargs) + self.assertIn( + 'retries kwarg passed in for communication with Azure endpoint.', + self.logs.getvalue()) + self.assertIn( + 'infinite kwarg passed in for communication with Azure endpoint.', + self.logs.getvalue()) class TestOpenSSLManager(CiTestCase): -- cgit v1.2.3 From 73e704e3690611625e3cda060a7a6a81492af9d2 Mon Sep 17 00:00:00 2001 From: Anh Vo Date: Thu, 19 Nov 2020 13:38:27 -0500 Subject: DataSourceAzure: push dmesg log to KVP (#670) Pushing dmesg log to KVP to help troubleshoot VM boot issues --- cloudinit/sources/helpers/azure.py | 12 +++++++++++- tests/unittests/test_reporting_hyperv.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) (limited to 'cloudinit/sources/helpers/azure.py') diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 2b3303c7..d3055d08 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -224,7 +224,8 @@ def push_log_to_kvp(file_name=CFG_BUILTIN['def_log_file']): based on the file size. 
The first time this function is called after VM boot, It will push the last n bytes of the log file such that n < MAX_LOG_TO_KVP_LENGTH - If called again on the same boot, it continues from where it left off.""" + If called again on the same boot, it continues from where it left off. + In addition to cloud-init.log, dmesg log will also be collected.""" start_index = get_last_log_byte_pushed_to_kvp_index() @@ -245,6 +246,15 @@ def push_log_to_kvp(file_name=CFG_BUILTIN['def_log_file']): "Exception when dumping log file: %s" % repr(ex), logger_func=LOG.warning) + LOG.debug("Dumping dmesg log to KVP") + try: + out, _ = subp.subp(['dmesg'], decode=False, capture=True) + report_compressed_event("dmesg", out) + except Exception as ex: + report_diagnostic_event( + "Exception when dumping dmesg log: %s" % repr(ex), + logger_func=LOG.warning) + @azure_ds_telemetry_reporter def get_last_log_byte_pushed_to_kvp_index(): diff --git a/tests/unittests/test_reporting_hyperv.py b/tests/unittests/test_reporting_hyperv.py index 8f7b3694..9324b78d 100644 --- a/tests/unittests/test_reporting_hyperv.py +++ b/tests/unittests/test_reporting_hyperv.py @@ -230,8 +230,33 @@ class TextKvpReporter(CiTestCase): instantiated_handler_registry.unregister_item("telemetry", force=False) + @mock.patch('cloudinit.sources.helpers.azure.report_compressed_event') + @mock.patch('cloudinit.sources.helpers.azure.report_diagnostic_event') + @mock.patch('cloudinit.subp.subp') + def test_push_log_to_kvp_exception_handling(self, m_subp, m_diag, m_com): + reporter = HyperVKvpReportingHandler(kvp_file_path=self.tmp_file_path) + try: + instantiated_handler_registry.register_item("telemetry", reporter) + log_file = self.tmp_path("cloud-init.log") + azure.MAX_LOG_TO_KVP_LENGTH = 100 + azure.LOG_PUSHED_TO_KVP_INDEX_FILE = self.tmp_path( + 'log_pushed_to_kvp') + with open(log_file, "w") as f: + log_content = "A" * 50 + "B" * 100 + f.write(log_content) + + m_com.side_effect = Exception("Mock Exception") + azure.push_log_to_kvp(log_file) + + # exceptions will trigger diagnostic reporting calls + self.assertEqual(m_diag.call_count, 3) + finally: + instantiated_handler_registry.unregister_item("telemetry", + force=False) + + @mock.patch('cloudinit.subp.subp') @mock.patch.object(LogHandler, 'publish_event') - def test_push_log_to_kvp(self, publish_event): + def test_push_log_to_kvp(self, publish_event, m_subp): reporter = HyperVKvpReportingHandler(kvp_file_path=self.tmp_file_path) try: instantiated_handler_registry.register_item("telemetry", reporter) @@ -249,6 +274,10 @@ class TextKvpReporter(CiTestCase): f.write(extra_content) azure.push_log_to_kvp(log_file) + # make sure dmesg is called every time + m_subp.assert_called_with( + ['dmesg'], capture=True, decode=False) + for call_arg in publish_event.call_args_list: event = call_arg[0][0] self.assertNotEqual( -- cgit v1.2.3