diff options
author | Johnson Shi <Johnson.Shi@microsoft.com> | 2020-11-18 09:34:04 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-11-18 12:34:04 -0500 |
commit | d807df288f8cef29ca74f0b00c326b084e825782 (patch) | |
tree | 4102e9fda39c03da57cd875692015c41b73f5055 | |
parent | 96d21dfbee308cd8fe00809184f78da9231ece4a (diff) | |
download | vyos-cloud-init-d807df288f8cef29ca74f0b00c326b084e825782.tar.gz vyos-cloud-init-d807df288f8cef29ca74f0b00c326b084e825782.zip |
DataSourceAzure: send failure signal on Azure datasource failure (#594)
On systems where the Azure datasource
is a viable platform for crawling metadata,
cloud-init occasionally encounters fatal
irrecoverable errors during the crawling
of the Azure datasource.
When this happens, cloud-init crashes,
and Azure VM provisioning would fail.
However, instead of failing immediately,
the user will continue seeing provisioning
for a long time until it times out with
"OS Provisioning Timed Out" message.
In these situations, cloud-init should
report failure to the Azure datasource
endpoint indicating provisioning failure.
The user will immediately see provisioning
terminate, giving them a much better
failure experience instead of pointlessly
waiting for OS provisioning timeout.
-rwxr-xr-x | cloudinit/sources/DataSourceAzure.py | 73 | ||||
-rwxr-xr-x | cloudinit/sources/helpers/azure.py | 80 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 322 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 569 |
4 files changed, 921 insertions, 123 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index fa3e0a2b..ab139b8d 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -29,6 +29,7 @@ from cloudinit import util from cloudinit.reporting import events from cloudinit.sources.helpers.azure import ( + DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE, azure_ds_reporter, azure_ds_telemetry_reporter, get_metadata_from_fabric, @@ -38,7 +39,8 @@ from cloudinit.sources.helpers.azure import ( EphemeralDHCPv4WithReporting, is_byte_swapped, dhcp_log_cb, - push_log_to_kvp) + push_log_to_kvp, + report_failure_to_fabric) LOG = logging.getLogger(__name__) @@ -508,8 +510,9 @@ class DataSourceAzure(sources.DataSource): if perform_reprovision: LOG.info("Reporting ready to Azure after getting ReprovisionData") - use_cached_ephemeral = (net.is_up(self.fallback_interface) and - getattr(self, '_ephemeral_dhcp_ctx', None)) + use_cached_ephemeral = ( + self.distro.networking.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None)) if use_cached_ephemeral: self._report_ready(lease=self._ephemeral_dhcp_ctx.lease) self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral @@ -560,9 +563,14 @@ class DataSourceAzure(sources.DataSource): logfunc=LOG.debug, msg='Crawl of metadata service', func=self.crawl_metadata ) - except sources.InvalidMetaDataException as e: - LOG.warning('Could not crawl Azure metadata: %s', e) + except Exception as e: + report_diagnostic_event( + 'Could not crawl Azure metadata: %s' % e, + logger_func=LOG.error) + self._report_failure( + description=DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) return False + if (self.distro and self.distro.name == 'ubuntu' and self.ds_cfg.get('apply_network_config')): maybe_remove_ubuntu_network_config_scripts() @@ -785,6 +793,61 @@ class DataSourceAzure(sources.DataSource): return return_val @azure_ds_telemetry_reporter + def _report_failure(self, description=None) -> bool: + """Tells the Azure fabric that provisioning has failed. + + @param description: A description of the error encountered. + @return: The success status of sending the failure signal. + """ + unknown_245_key = 'unknown-245' + + try: + if (self.distro.networking.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None) and + getattr(self._ephemeral_dhcp_ctx, 'lease', None) and + unknown_245_key in self._ephemeral_dhcp_ctx.lease): + report_diagnostic_event( + 'Using cached ephemeral dhcp context ' + 'to report failure to Azure', logger_func=LOG.debug) + report_failure_to_fabric( + dhcp_opts=self._ephemeral_dhcp_ctx.lease[unknown_245_key], + description=description) + self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using ' + 'cached ephemeral dhcp context: %s' % e, + logger_func=LOG.error) + + try: + report_diagnostic_event( + 'Using new ephemeral dhcp to report failure to Azure', + logger_func=LOG.debug) + with EphemeralDHCPv4WithReporting(azure_ds_reporter) as lease: + report_failure_to_fabric( + dhcp_opts=lease[unknown_245_key], + description=description) + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using new ephemeral dhcp: %s' % e, + logger_func=LOG.debug) + + try: + report_diagnostic_event( + 'Using fallback lease to report failure to Azure') + report_failure_to_fabric( + fallback_lease_file=self.dhclient_lease_file, + description=description) + return True + except Exception as e: + report_diagnostic_event( + 'Failed to report failure using fallback lease: %s' % e, + logger_func=LOG.debug) + + return False + def _report_ready(self, lease: dict) -> bool: """Tells the fabric provisioning has completed. diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 4071a50e..951c7a10 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -17,6 +17,7 @@ from cloudinit import stages from cloudinit import temp_utils from contextlib import contextmanager from xml.etree import ElementTree +from xml.sax.saxutils import escape from cloudinit import subp from cloudinit import url_helper @@ -50,6 +51,11 @@ azure_ds_reporter = events.ReportEventStack( description="initialize reporter for azure ds", reporting_enabled=True) +DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE = ( + 'The VM encountered an error during deployment. ' + 'Please visit https://aka.ms/linuxprovisioningerror ' + 'for more information on remediation.') + def azure_ds_telemetry_reporter(func): def impl(*args, **kwargs): @@ -379,12 +385,20 @@ class OpenSSLManager: def __init__(self): self.tmpdir = temp_utils.mkdtemp() - self.certificate = None + self._certificate = None self.generate_certificate() def clean_up(self): util.del_dir(self.tmpdir) + @property + def certificate(self): + return self._certificate + + @certificate.setter + def certificate(self, value): + self._certificate = value + @azure_ds_telemetry_reporter def generate_certificate(self): LOG.debug('Generating certificate for communication with fabric...') @@ -507,6 +521,10 @@ class GoalStateHealthReporter: ''') PROVISIONING_SUCCESS_STATUS = 'Ready' + PROVISIONING_NOT_READY_STATUS = 'NotReady' + PROVISIONING_FAILURE_SUBSTATUS = 'ProvisioningFailed' + + HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512 def __init__( self, goal_state: GoalState, @@ -545,19 +563,39 @@ class GoalStateHealthReporter: LOG.info('Reported ready to Azure fabric.') + @azure_ds_telemetry_reporter + def send_failure_signal(self, description: str) -> None: + document = self.build_report( + incarnation=self._goal_state.incarnation, + container_id=self._goal_state.container_id, + instance_id=self._goal_state.instance_id, + status=self.PROVISIONING_NOT_READY_STATUS, + substatus=self.PROVISIONING_FAILURE_SUBSTATUS, + description=description) + try: + self._post_health_report(document=document) + except Exception as e: + msg = "exception while reporting failure: %s" % e + report_diagnostic_event(msg, logger_func=LOG.error) + raise + + LOG.warning('Reported failure to Azure fabric.') + def build_report( self, incarnation: str, container_id: str, instance_id: str, status: str, substatus=None, description=None) -> str: health_detail = '' if substatus is not None: health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( - health_substatus=substatus, health_description=description) + health_substatus=escape(substatus), + health_description=escape( + description[:self.HEALTH_REPORT_DESCRIPTION_TRIM_LEN])) health_report = self.HEALTH_REPORT_XML_TEMPLATE.format( - incarnation=incarnation, - container_id=container_id, - instance_id=instance_id, - health_status=status, + incarnation=escape(str(incarnation)), + container_id=escape(container_id), + instance_id=escape(instance_id), + health_status=escape(status), health_detail_subsection=health_detail) return health_report @@ -798,12 +836,27 @@ class WALinuxAgentShim: return {'public-keys': ssh_keys} @azure_ds_telemetry_reporter + def register_with_azure_and_report_failure(self, description: str) -> None: + """Gets the VM's GoalState from Azure, uses the GoalState information + to report failure/send provisioning failure signal to Azure. + + @param: user visible error description of provisioning failure. + """ + if self.azure_endpoint_client is None: + self.azure_endpoint_client = AzureEndpointHttpClient(None) + goal_state = self._fetch_goal_state_from_azure(need_certificate=False) + health_reporter = GoalStateHealthReporter( + goal_state, self.azure_endpoint_client, self.endpoint) + health_reporter.send_failure_signal(description=description) + + @azure_ds_telemetry_reporter def _fetch_goal_state_from_azure( self, need_certificate: bool) -> GoalState: """Fetches the GoalState XML from the Azure endpoint, parses the XML, and returns a GoalState object. + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML """ unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() @@ -844,6 +897,7 @@ class WALinuxAgentShim: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_goal_state_xml: GoalState XML string + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML """ try: @@ -942,6 +996,20 @@ def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, shim.clean_up() +@azure_ds_telemetry_reporter +def report_failure_to_fabric(fallback_lease_file=None, dhcp_opts=None, + description=None): + shim = WALinuxAgentShim(fallback_lease_file=fallback_lease_file, + dhcp_options=dhcp_opts) + if not description: + description = DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE + try: + shim.register_with_azure_and_report_failure( + description=description) + finally: + shim.clean_up() + + def dhcp_log_cb(out, err): report_diagnostic_event( "dhclient output stream: %s" % out, logger_func=LOG.debug) diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 433fbc66..d9752ab7 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -461,6 +461,8 @@ class TestGetMetadataFromIMDS(HttprettyTestCase): class TestAzureDataSource(CiTestCase): + with_logs = True + def setUp(self): super(TestAzureDataSource, self).setUp() self.tmp = self.tmp_dir() @@ -549,9 +551,12 @@ scbus-1 on xpt0 bus 0 dsaz.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d - self.get_metadata_from_fabric = mock.MagicMock(return_value={ - 'public-keys': [], - }) + self.m_is_platform_viable = mock.MagicMock(autospec=True) + self.m_get_metadata_from_fabric = mock.MagicMock( + return_value={'public-keys': []}) + self.m_report_failure_to_fabric = mock.MagicMock(autospec=True) + self.m_ephemeral_dhcpv4 = mock.MagicMock() + self.m_ephemeral_dhcpv4_with_reporting = mock.MagicMock() self.instance_id = 'D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8' @@ -568,7 +573,17 @@ scbus-1 on xpt0 bus 0 (dsaz, 'perform_hostname_bounce', mock.MagicMock()), (dsaz, 'get_hostname', mock.MagicMock()), (dsaz, 'set_hostname', mock.MagicMock()), - (dsaz, 'get_metadata_from_fabric', self.get_metadata_from_fabric), + (dsaz, '_is_platform_viable', + self.m_is_platform_viable), + (dsaz, 'get_metadata_from_fabric', + self.m_get_metadata_from_fabric), + (dsaz, 'report_failure_to_fabric', + self.m_report_failure_to_fabric), + (dsaz, 'EphemeralDHCPv4', self.m_ephemeral_dhcpv4), + (dsaz, 'EphemeralDHCPv4WithReporting', + self.m_ephemeral_dhcpv4_with_reporting), + (dsaz, 'get_boot_telemetry', mock.MagicMock()), + (dsaz, 'get_system_info', mock.MagicMock()), (dsaz.subp, 'which', lambda x: True), (dsaz.dmi, 'read_dmi_data', mock.MagicMock( side_effect=_dmi_mocks)), @@ -632,15 +647,87 @@ scbus-1 on xpt0 bus 0 dev = ds.get_resource_disk_on_freebsd(1) self.assertEqual("da1", dev) - @mock.patch(MOCKPATH + '_is_platform_viable') - def test_call_is_platform_viable_seed(self, m_is_platform_viable): + def test_not_is_platform_viable_seed_should_return_no_datasource(self): """Check seed_dir using _is_platform_viable and return False.""" # Return a non-matching asset tag value - m_is_platform_viable.return_value = False - dsrc = dsaz.DataSourceAzure( - {}, distro=mock.Mock(), paths=self.paths) - self.assertFalse(dsrc.get_data()) - m_is_platform_viable.assert_called_with(dsrc.seed_dir) + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = False + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_report_failure') as m_report_failure: + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + # Assert that for non viable platforms, + # there is no communication with the Azure datasource. + self.assertEqual( + 0, + m_crawl_metadata.call_count) + self.assertEqual( + 0, + m_report_failure.call_count) + + def test_platform_viable_but_no_devs_should_return_no_datasource(self): + """For platforms where the Azure platform is viable + (which is indicated by the matching asset tag), + the absence of any devs at all (devs == candidate sources + for crawling Azure datasource) is NOT expected. + Report failure to Azure as this is an unexpected fatal error. + """ + data = {} + dsrc = self._get_ds(data) + with mock.patch.object(dsrc, '_report_failure') as m_report_failure: + self.m_is_platform_viable.return_value = True + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + self.assertEqual( + 1, + m_report_failure.call_count) + + def test_crawl_metadata_exception_returns_no_datasource(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + ret = dsrc.get_data() + self.m_is_platform_viable.assert_called_with(dsrc.seed_dir) + self.assertEqual( + 1, + m_crawl_metadata.call_count) + self.assertFalse(ret) + self.assertNotIn('agent_invoked', data) + + def test_crawl_metadata_exception_should_report_failure_with_msg(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_report_failure') as m_report_failure: + m_crawl_metadata.side_effect = Exception + dsrc.get_data() + self.assertEqual( + 1, + m_crawl_metadata.call_count) + m_report_failure.assert_called_once_with( + description=dsaz.DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_crawl_metadata_exc_should_log_could_not_crawl_msg(self): + data = {} + dsrc = self._get_ds(data) + self.m_is_platform_viable.return_value = True + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + dsrc.get_data() + self.assertEqual( + 1, + m_crawl_metadata.call_count) + self.assertIn( + "Could not crawl Azure metadata", + self.logs.getvalue()) def test_basic_seed_dir(self): odata = {'HostName': "myhost", 'UserName': "myuser"} @@ -761,7 +848,7 @@ scbus-1 on xpt0 bus 0 'cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready') @mock.patch('cloudinit.sources.DataSourceAzure.DataSourceAzure._poll_imds') def test_crawl_metadata_on_reprovision_reports_ready( - self, poll_imds_func, report_ready_func, m_write, m_dhcp + self, poll_imds_func, m_report_ready, m_write, m_dhcp ): """If reprovisioning, report ready at the end""" ovfenv = construct_valid_ovf_env( @@ -775,18 +862,16 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) poll_imds_func.return_value = ovfenv dsrc.crawl_metadata() - self.assertEqual(1, report_ready_func.call_count) + self.assertEqual(1, m_report_ready.call_count) @mock.patch('cloudinit.sources.DataSourceAzure.util.write_file') @mock.patch('cloudinit.sources.helpers.netlink.' 'wait_for_media_disconnect_connect') @mock.patch( 'cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready') - @mock.patch('cloudinit.net.dhcp.EphemeralIPv4Network') - @mock.patch('cloudinit.net.dhcp.maybe_perform_dhcp_discovery') @mock.patch('cloudinit.sources.DataSourceAzure.readurl') def test_crawl_metadata_on_reprovision_reports_ready_using_lease( - self, m_readurl, m_dhcp, m_net, report_ready_func, + self, m_readurl, m_report_ready, m_media_switch, m_write ): """If reprovisioning, report ready using the obtained lease""" @@ -800,20 +885,30 @@ scbus-1 on xpt0 bus 0 } dsrc = self._get_ds(data) - lease = { - 'interface': 'eth9', 'fixed-address': '192.168.2.9', - 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', - 'unknown-245': '624c3620'} - m_dhcp.return_value = [lease] - m_media_switch.return_value = None + with mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: - reprovision_ovfenv = construct_valid_ovf_env() - m_readurl.return_value = url_helper.StringResponse( - reprovision_ovfenv.encode('utf-8')) + # For this mock, net should not be up, + # so that cached ephemeral won't be used. + # This is so that a NEW ephemeral dhcp lease will be discovered + # and used instead. + m_dsrc_distro_networking_is_up.return_value = False - dsrc.crawl_metadata() - self.assertEqual(2, report_ready_func.call_count) - report_ready_func.assert_called_with(lease=lease) + lease = { + 'interface': 'eth9', 'fixed-address': '192.168.2.9', + 'routers': '192.168.2.1', 'subnet-mask': '255.255.255.0', + 'unknown-245': '624c3620'} + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.return_value = lease + m_media_switch.return_value = None + + reprovision_ovfenv = construct_valid_ovf_env() + m_readurl.return_value = url_helper.StringResponse( + reprovision_ovfenv.encode('utf-8')) + + dsrc.crawl_metadata() + self.assertEqual(2, m_report_ready.call_count) + m_report_ready.assert_called_with(lease=lease) def test_waagent_d_has_0700_perms(self): # we expect /var/lib/waagent to be created 0700 @@ -971,7 +1066,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertTrue('default_user' in dsrc.cfg['system_info']) + self.assertIn('default_user', dsrc.cfg['system_info']) defuser = dsrc.cfg['system_info']['default_user'] # default user should be updated username and should not be locked. @@ -993,7 +1088,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) - self.assertTrue('default_user' in dsrc.cfg['system_info']) + self.assertIn('default_user', dsrc.cfg['system_info']) defuser = dsrc.cfg['system_info']['default_user'] # default user should be updated username and should not be locked. @@ -1021,14 +1116,6 @@ scbus-1 on xpt0 bus 0 self.assertTrue(ret) self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8')) - def test_no_datasource_expected(self): - # no source should be found if no seed_dir and no devs - data = {} - dsrc = self._get_ds({}) - ret = dsrc.get_data() - self.assertFalse(ret) - self.assertFalse('agent_invoked' in data) - def test_cfg_has_pubkeys_fingerprint(self): odata = {'HostName': "myhost", 'UserName': "myuser"} mypklist = [{'fingerprint': 'fp1', 'path': 'path1', 'value': ''}] @@ -1171,21 +1258,168 @@ scbus-1 on xpt0 bus 0 self): dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.side_effect = Exception + self.m_get_metadata_from_fabric.side_effect = Exception self.assertFalse(dsrc._report_ready(lease=mock.MagicMock())) + def test_dsaz_report_failure_returns_true_when_report_succeeds(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) + self.assertEqual( + 1, + self.m_report_failure_to_fabric.call_count) + + def test_dsaz_report_failure_returns_false_and_does_not_propagate_exc( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_ephemeral_dhcp_ctx') \ + as m_ephemeral_dhcp_ctx, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # setup mocks to allow using cached ephemeral dhcp lease + m_dsrc_distro_networking_is_up.return_value = True + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + m_ephemeral_dhcp_ctx.lease = test_lease + + # We expect 3 calls to report_failure_to_fabric, + # because we try 3 different methods of calling report failure. + # The different methods are attempted in the following order: + # 1. Using cached ephemeral dhcp context to report failure to Azure + # 2. Using new ephemeral dhcp to report failure to Azure + # 3. Using fallback lease to report failure to Azure + self.m_report_failure_to_fabric.side_effect = Exception + self.assertFalse(dsrc._report_failure()) + self.assertEqual( + 3, + self.m_report_failure_to_fabric.call_count) + + def test_dsaz_report_failure_description_msg(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + test_msg = 'Test report failure description message' + self.assertTrue(dsrc._report_failure(description=test_msg)) + self.m_report_failure_to_fabric.assert_called_once_with( + dhcp_opts=mock.ANY, description=test_msg) + + def test_dsaz_report_failure_no_description_msg(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata: + m_crawl_metadata.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) # no description msg + self.m_report_failure_to_fabric.assert_called_once_with( + dhcp_opts=mock.ANY, description=None) + + def test_dsaz_report_failure_uses_cached_ephemeral_dhcp_ctx_lease(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc, '_ephemeral_dhcp_ctx') \ + as m_ephemeral_dhcp_ctx, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # setup mocks to allow using cached ephemeral dhcp lease + m_dsrc_distro_networking_is_up.return_value = True + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + m_ephemeral_dhcp_ctx.lease = test_lease + + self.assertTrue(dsrc._report_failure()) + + # ensure called with cached ephemeral dhcp lease option 245 + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, dhcp_opts=test_lease_dhcp_option_245) + + # ensure cached ephemeral is cleaned + self.assertEqual( + 1, + m_ephemeral_dhcp_ctx.clean_network.call_count) + + def test_dsaz_report_failure_no_net_uses_new_ephemeral_dhcp_lease(self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # net is not up and cannot use cached ephemeral dhcp + m_dsrc_distro_networking_is_up.return_value = False + # setup ephemeral dhcp lease discovery mock + test_lease_dhcp_option_245 = 'test_lease_dhcp_option_245' + test_lease = {'unknown-245': test_lease_dhcp_option_245} + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.return_value = test_lease + + self.assertTrue(dsrc._report_failure()) + + # ensure called with the newly discovered + # ephemeral dhcp lease option 245 + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, dhcp_opts=test_lease_dhcp_option_245) + + def test_dsaz_report_failure_no_net_and_no_dhcp_uses_fallback_lease( + self): + dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + dsrc.ds_cfg['agent_command'] = '__builtin__' + + with mock.patch.object(dsrc, 'crawl_metadata') as m_crawl_metadata, \ + mock.patch.object(dsrc.distro.networking, 'is_up') \ + as m_dsrc_distro_networking_is_up: + # mock crawl metadata failure to cause report failure + m_crawl_metadata.side_effect = Exception + + # net is not up and cannot use cached ephemeral dhcp + m_dsrc_distro_networking_is_up.return_value = False + # ephemeral dhcp discovery failure, + # so cannot use a new ephemeral dhcp + self.m_ephemeral_dhcpv4_with_reporting.return_value \ + .__enter__.side_effect = Exception + + self.assertTrue(dsrc._report_failure()) + + # ensure called with fallback lease + self.m_report_failure_to_fabric.assert_called_once_with( + description=mock.ANY, + fallback_lease_file=dsrc.dhclient_lease_file) + def test_exception_fetching_fabric_data_doesnt_propagate(self): """Errors communicating with fabric should warn, but return True.""" dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.side_effect = Exception + self.m_get_metadata_from_fabric.side_effect = Exception ret = self._get_and_setup(dsrc) self.assertTrue(ret) def test_fabric_data_included_in_metadata(self): dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) dsrc.ds_cfg['agent_command'] = '__builtin__' - self.get_metadata_from_fabric.return_value = {'test': 'value'} + self.m_get_metadata_from_fabric.return_value = {'test': 'value'} ret = self._get_and_setup(dsrc) self.assertTrue(ret) self.assertEqual('value', dsrc.metadata['test']) @@ -2053,7 +2287,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): @mock.patch('time.sleep', mock.MagicMock()) @mock.patch(MOCKPATH + 'EphemeralDHCPv4') - def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, report_ready_func, + def test_poll_imds_re_dhcp_on_timeout(self, m_dhcpv4, m_report_ready, m_request, m_media_switch, m_dhcp, m_net): """The poll_imds will retry DHCP on IMDS timeout.""" @@ -2088,8 +2322,8 @@ class TestPreprovisioningPollIMDS(CiTestCase): dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) with mock.patch(MOCKPATH + 'REPORTED_READY_MARKER_FILE', report_file): dsa._poll_imds() - self.assertEqual(report_ready_func.call_count, 1) - report_ready_func.assert_called_with(lease=lease) + self.assertEqual(m_report_ready.call_count, 1) + m_report_ready.assert_called_with(lease=lease) self.assertEqual(3, m_dhcpv4.call_count, 'Expected 3 DHCP calls') self.assertEqual(4, self.tries, 'Expected 4 total reads from IMDS') diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 6e004e34..adf68857 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -5,6 +5,7 @@ import re import unittest from textwrap import dedent from xml.etree import ElementTree +from xml.sax.saxutils import escape, unescape from cloudinit.sources.helpers import azure as azure_helper from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -70,6 +71,15 @@ HEALTH_REPORT_XML_TEMPLATE = '''\ </Health> ''' +HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = dedent('''\ + <Details> + <SubStatus>{health_substatus}</SubStatus> + <Description>{health_description}</Description> + </Details> + ''') + +HEALTH_REPORT_DESCRIPTION_TRIM_LEN = 512 + class SentinelException(Exception): pass @@ -461,17 +471,24 @@ class TestOpenSSLManagerActions(CiTestCase): class TestGoalStateHealthReporter(CiTestCase): + maxDiff = None + default_parameters = { 'incarnation': 1634, 'container_id': 'MyContainerId', 'instance_id': 'MyInstanceId' } - test_endpoint = 'TestEndpoint' - test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) + test_azure_endpoint = 'TestEndpoint' + test_health_report_url = 'http://{0}/machine?comp=health'.format( + test_azure_endpoint) test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} provisioning_success_status = 'Ready' + provisioning_not_ready_status = 'NotReady' + provisioning_failure_substatus = 'ProvisioningFailed' + provisioning_failure_err_description = ( + 'Test error message containing provisioning failure details') def setUp(self): super(TestGoalStateHealthReporter, self).setUp() @@ -496,17 +513,40 @@ class TestGoalStateHealthReporter(CiTestCase): self.GoalState.return_value.incarnation = \ self.default_parameters['incarnation'] + def _text_from_xpath_in_xroot(self, xroot, xpath): + element = xroot.find(xpath) + if element is not None: + return element.text + return None + def _get_formatted_health_report_xml_string(self, **kwargs): return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + def _get_formatted_health_detail_subsection_xml_string(self, **kwargs): + return HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format(**kwargs) + def _get_report_ready_health_document(self): return self._get_formatted_health_report_xml_string( - incarnation=self.default_parameters['incarnation'], - container_id=self.default_parameters['container_id'], - instance_id=self.default_parameters['instance_id'], - health_status=self.provisioning_success_status, + incarnation=escape(str(self.default_parameters['incarnation'])), + container_id=escape(self.default_parameters['container_id']), + instance_id=escape(self.default_parameters['instance_id']), + health_status=escape(self.provisioning_success_status), health_detail_subsection='') + def _get_report_failure_health_document(self): + health_detail_subsection = \ + self._get_formatted_health_detail_subsection_xml_string( + health_substatus=escape(self.provisioning_failure_substatus), + health_description=escape( + self.provisioning_failure_err_description)) + + return self._get_formatted_health_report_xml_string( + incarnation=escape(str(self.default_parameters['incarnation'])), + container_id=escape(self.default_parameters['container_id']), + instance_id=escape(self.default_parameters['instance_id']), + health_status=escape(self.provisioning_not_ready_status), + health_detail_subsection=health_detail_subsection) + def test_send_ready_signal_sends_post_request(self): with mock.patch.object( azure_helper.GoalStateHealthReporter, @@ -514,55 +554,130 @@ class TestGoalStateHealthReporter(CiTestCase): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), - client, self.test_endpoint) + client, self.test_azure_endpoint) reporter.send_ready_signal() self.assertEqual(1, self.post.call_count) self.assertEqual( mock.call( - self.test_url, + self.test_health_report_url, + data=m_build_report.return_value, + extra_headers=self.test_default_headers), + self.post.call_args) + + def test_send_failure_signal_sends_post_request(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, + 'build_report') as m_build_report: + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + client, self.test_azure_endpoint) + reporter.send_failure_signal( + description=self.provisioning_failure_err_description) + + self.assertEqual(1, self.post.call_count) + self.assertEqual( + mock.call( + self.test_health_report_url, data=m_build_report.return_value, extra_headers=self.test_default_headers), self.post.call_args) - def test_build_report_for_health_document(self): + def test_build_report_for_ready_signal_health_document(self): health_document = self._get_report_ready_health_document() reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), azure_helper.AzureEndpointHttpClient(mock.MagicMock()), - self.test_endpoint) + self.test_azure_endpoint) generated_health_document = reporter.build_report( incarnation=self.default_parameters['incarnation'], container_id=self.default_parameters['container_id'], instance_id=self.default_parameters['instance_id'], status=self.provisioning_success_status) + self.assertEqual(health_document, generated_health_document) - self.assertIn( - '<GoalStateIncarnation>{}</GoalStateIncarnation>'.format( - str(self.default_parameters['incarnation'])), - generated_health_document) - self.assertIn( - ''.join([ - '<ContainerId>', - self.default_parameters['container_id'], - '</ContainerId>']), - generated_health_document) - self.assertIn( - ''.join([ - '<InstanceId>', - self.default_parameters['instance_id'], - '</InstanceId>']), - generated_health_document) - self.assertIn( - ''.join([ - '<State>', - self.provisioning_success_status, - '</State>']), - generated_health_document + + generated_xroot = ElementTree.fromstring(generated_health_document) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './GoalStateIncarnation'), + str(self.default_parameters['incarnation'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './Container/ContainerId'), + str(self.default_parameters['container_id'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/InstanceId'), + str(self.default_parameters['instance_id'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/State'), + escape(self.provisioning_success_status)) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details')) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/SubStatus')) + self.assertIsNone( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') ) - self.assertNotIn('<Details>', generated_health_document) - self.assertNotIn('<SubStatus>', generated_health_document) - self.assertNotIn('<Description>', generated_health_document) + + def test_build_report_for_failure_signal_health_document(self): + health_document = self._get_report_failure_health_document() + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=self.provisioning_failure_err_description) + + self.assertEqual(health_document, generated_health_document) + + generated_xroot = ElementTree.fromstring(generated_health_document) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './GoalStateIncarnation'), + str(self.default_parameters['incarnation'])) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, './Container/ContainerId'), + self.default_parameters['container_id']) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/InstanceId'), + self.default_parameters['instance_id']) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/State'), + escape(self.provisioning_not_ready_status)) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/' + 'SubStatus'), + escape(self.provisioning_failure_substatus)) + self.assertEqual( + self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/' + 'Description'), + escape(self.provisioning_failure_err_description)) def test_send_ready_signal_calls_build_report(self): with mock.patch.object( @@ -571,7 +686,7 @@ class TestGoalStateHealthReporter(CiTestCase): reporter = azure_helper.GoalStateHealthReporter( azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), azure_helper.AzureEndpointHttpClient(mock.MagicMock()), - self.test_endpoint) + self.test_azure_endpoint) reporter.send_ready_signal() self.assertEqual(1, m_build_report.call_count) @@ -583,6 +698,131 @@ class TestGoalStateHealthReporter(CiTestCase): status=self.provisioning_success_status), m_build_report.call_args) + def test_send_failure_signal_calls_build_report(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, 'build_report' + ) as m_build_report: + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + reporter.send_failure_signal( + description=self.provisioning_failure_err_description) + + self.assertEqual(1, m_build_report.call_count) + self.assertEqual( + mock.call( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=self.provisioning_failure_err_description), + m_build_report.call_args) + + def test_build_report_escapes_chars(self): + incarnation = 'jd8\'9*&^<\'A><A[p&o+\"SD()*&&&LKAJSD23' + container_id = '&&<\"><><ds8\'9+7&d9a86!@($09asdl;<>' + instance_id = 'Opo>>>jas\'&d;[p&fp\"a<<!!@&&' + health_status = '&<897\"6&>&aa\'sd!@&!)((*<&>' + health_substatus = '&as\"d<<a&s>d<\'^@!5&6<7' + health_description = '&&&>!#$\"&&<as\'1!@$d&>><>&\"sd<67<]>>' + + health_detail_subsection = \ + self._get_formatted_health_detail_subsection_xml_string( + health_substatus=escape(health_substatus), + health_description=escape(health_description)) + health_document = self._get_formatted_health_report_xml_string( + incarnation=escape(incarnation), + container_id=escape(container_id), + instance_id=escape(instance_id), + health_status=escape(health_status), + health_detail_subsection=health_detail_subsection) + + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + generated_health_document = reporter.build_report( + incarnation=incarnation, + container_id=container_id, + instance_id=instance_id, + status=health_status, + substatus=health_substatus, + description=health_description) + + self.assertEqual(health_document, generated_health_document) + + def test_build_report_conforms_to_length_limits(self): + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + long_err_msg = 'a9&ea8>>>e as1< d\"q2*&(^%\'a=5<' * 100 + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=long_err_msg) + + generated_xroot = ElementTree.fromstring(generated_health_document) + generated_health_report_description = self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') + self.assertEqual( + len(unescape(generated_health_report_description)), + HEALTH_REPORT_DESCRIPTION_TRIM_LEN) + + def test_trim_description_then_escape_conforms_to_len_limits_worst_case( + self): + """When unescaped characters are XML-escaped, the length increases. + Char Escape String + < < + > > + " " + ' ' + & & + + We (step 1) trim the health report XML's description field, + and then (step 2) XML-escape the health report XML's description field. + + The health report XML's description field limit within cloud-init + is HEALTH_REPORT_DESCRIPTION_TRIM_LEN. + + The Azure platform's limit on the health report XML's description field + is 4096 chars. + + For worst-case chars, there is a 5x blowup in length + when the chars are XML-escaped. + ' and " when XML-escaped have a 5x blowup. + + Ensure that (1) trimming and then (2) XML-escaping does not blow past + the Azure platform's limit for health report XML's description field + (4096 chars). + """ + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_azure_endpoint) + long_err_msg = '\'\"' * 10000 + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_not_ready_status, + substatus=self.provisioning_failure_substatus, + description=long_err_msg) + + generated_xroot = ElementTree.fromstring(generated_health_document) + generated_health_report_description = self._text_from_xpath_in_xroot( + generated_xroot, + './Container/RoleInstanceList/Role/Health/Details/Description') + # The escaped description string should be less than + # the Azure platform limit for the escaped description string. + self.assertLessEqual(len(generated_health_report_description), 4096) + class TestWALinuxAgentShim(CiTestCase): @@ -598,7 +838,7 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState = patches.enter_context( mock.patch.object(azure_helper, 'GoalState')) self.OpenSSLManager = patches.enter_context( - mock.patch.object(azure_helper, 'OpenSSLManager')) + mock.patch.object(azure_helper, 'OpenSSLManager', autospec=True)) patches.enter_context( mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) @@ -609,24 +849,47 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.return_value.container_id = self.test_container_id self.GoalState.return_value.instance_id = self.test_instance_id - def test_http_client_does_not_use_certificate(self): + def test_http_client_does_not_use_certificate_for_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(None)], self.AzureEndpointHttpClient.call_args_list) + def test_http_client_does_not_use_certificate_for_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + self.assertEqual( + [mock.call(None)], + self.AzureEndpointHttpClient.call_args_list) + def test_correct_url_used_for_goalstate_during_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = wa_shim() shim.register_with_azure_and_fetch_data() - get = self.AzureEndpointHttpClient.return_value.get + m_get = self.AzureEndpointHttpClient.return_value.get + self.assertEqual( + [mock.call('http://test_endpoint/machine/?comp=goalstate')], + m_get.call_args_list) + self.assertEqual( + [mock.call( + m_get.return_value.contents, + self.AzureEndpointHttpClient.return_value, + False + )], + self.GoalState.call_args_list) + + def test_correct_url_used_for_goalstate_during_report_failure(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + m_get = self.AzureEndpointHttpClient.return_value.get self.assertEqual( [mock.call('http://test_endpoint/machine/?comp=goalstate')], - get.call_args_list) + m_get.call_args_list) self.assertEqual( [mock.call( - get.return_value.contents, + m_get.return_value.contents, self.AzureEndpointHttpClient.return_value, False )], @@ -670,6 +933,16 @@ class TestWALinuxAgentShim(CiTestCase): self.AzureEndpointHttpClient.return_value.post .call_args_list) + def test_correct_url_used_for_report_failure(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + expected_url = 'http://test_endpoint/machine?comp=health' + self.assertEqual( + [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], + self.AzureEndpointHttpClient.return_value.post + .call_args_list) + def test_goal_state_values_used_for_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() @@ -681,44 +954,128 @@ class TestWALinuxAgentShim(CiTestCase): self.assertIn(self.test_container_id, posted_document) self.assertIn(self.test_instance_id, posted_document) - def test_xml_elems_in_report_ready(self): + def test_goal_state_values_used_for_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data'] + ) + self.assertIn(self.test_incarnation, posted_document) + self.assertIn(self.test_container_id, posted_document) + self.assertIn(self.test_instance_id, posted_document) + + def test_xml_elems_in_report_ready_post(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() health_document = HEALTH_REPORT_XML_TEMPLATE.format( - incarnation=self.test_incarnation, - container_id=self.test_container_id, - instance_id=self.test_instance_id, - health_status='Ready', + incarnation=escape(self.test_incarnation), + container_id=escape(self.test_container_id), + instance_id=escape(self.test_instance_id), + health_status=escape('Ready'), health_detail_subsection='') posted_document = ( self.AzureEndpointHttpClient.return_value.post .call_args[1]['data']) self.assertEqual(health_document, posted_document) + def test_xml_elems_in_report_failure_post(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + health_document = HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=escape(self.test_incarnation), + container_id=escape(self.test_container_id), + instance_id=escape(self.test_instance_id), + health_status=escape('NotReady'), + health_detail_subsection=HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE + .format( + health_substatus=escape('ProvisioningFailed'), + health_description=escape('TestDesc'))) + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data']) + self.assertEqual(health_document, posted_document) + + @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True) + def test_register_with_azure_and_fetch_data_calls_send_ready_signal( + self, m_goal_state_health_reporter): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + self.assertEqual( + 1, + m_goal_state_health_reporter.return_value.send_ready_signal + .call_count) + + @mock.patch.object(azure_helper, 'GoalStateHealthReporter', autospec=True) + def test_register_with_azure_and_report_failure_calls_send_failure_signal( + self, m_goal_state_health_reporter): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + m_goal_state_health_reporter.return_value.send_failure_signal \ + .assert_called_once_with(description='TestDesc') + + def test_register_with_azure_and_report_failure_does_not_need_certificates( + self): + shim = wa_shim() + with mock.patch.object( + shim, '_fetch_goal_state_from_azure', autospec=True + ) as m_fetch_goal_state_from_azure: + shim.register_with_azure_and_report_failure(description='TestDesc') + m_fetch_goal_state_from_azure.assert_called_once_with( + need_certificate=False) + def test_clean_up_can_be_called_at_any_time(self): shim = wa_shim() shim.clean_up() + def test_openssl_manager_not_instantiated_by_shim_report_status(self): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + shim.register_with_azure_and_report_failure(description='TestDesc') + shim.clean_up() + self.OpenSSLManager.assert_not_called() + def test_clean_up_after_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() shim.clean_up() - self.assertEqual( - 0, self.OpenSSLManager.return_value.clean_up.call_count) + self.OpenSSLManager.return_value.clean_up.assert_not_called() + + def test_clean_up_after_report_failure(self): + shim = wa_shim() + shim.register_with_azure_and_report_failure(description='TestDesc') + shim.clean_up() + self.OpenSSLManager.return_value.clean_up.assert_not_called() def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): self.AzureEndpointHttpClient.return_value.get \ - .side_effect = (SentinelException) + .side_effect = SentinelException shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_failure_raises_exc_on_get_exc(self): + self.AzureEndpointHttpClient.return_value.get \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): self.GoalState.side_effect = SentinelException shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_failure_raises_exc_on_parse_exc( + self): + self.GoalState.side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + def test_failure_to_send_report_ready_health_doc_bubbles_up(self): self.AzureEndpointHttpClient.return_value.post \ .side_effect = SentinelException @@ -726,56 +1083,132 @@ class TestWALinuxAgentShim(CiTestCase): self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_failure_to_send_report_failure_health_doc_bubbles_up(self): + self.AzureEndpointHttpClient.return_value.post \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_report_failure, + description='TestDesc') + class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_data_from_shim_returned(self, shim): + def setUp(self): + super(TestGetMetadataGoalStateXMLAndReportReadyToFabric, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_shim = patches.enter_context( + mock.patch.object(azure_helper, 'WALinuxAgentShim')) + + def test_data_from_shim_returned(self): ret = azure_helper.get_metadata_from_fabric() self.assertEqual( - shim.return_value.register_with_azure_and_fetch_data.return_value, + self.m_shim.return_value.register_with_azure_and_fetch_data + .return_value, ret) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_success_calls_clean_up(self, shim): + def test_success_calls_clean_up(self): azure_helper.get_metadata_from_fabric() - self.assertEqual(1, shim.return_value.clean_up.call_count) + self.assertEqual(1, self.m_shim.return_value.clean_up.call_count) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_failure_in_registration_propagates_exc_and_calls_clean_up( - self, shim): - shim.return_value.register_with_azure_and_fetch_data.side_effect = ( - SentinelException) + self): + self.m_shim.return_value.register_with_azure_and_fetch_data \ + .side_effect = SentinelException self.assertRaises(SentinelException, azure_helper.get_metadata_from_fabric) - self.assertEqual(1, shim.return_value.clean_up.call_count) + self.assertEqual(1, self.m_shim.return_value.clean_up.call_count) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_calls_shim_register_with_azure_and_fetch_data(self, shim): + def test_calls_shim_register_with_azure_and_fetch_data(self): m_pubkey_info = mock.MagicMock() azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info) self.assertEqual( 1, - shim.return_value + self.m_shim.return_value .register_with_azure_and_fetch_data.call_count) self.assertEqual( mock.call(pubkey_info=m_pubkey_info), - shim.return_value + self.m_shim.return_value .register_with_azure_and_fetch_data.call_args) - @mock.patch.object(azure_helper, 'WALinuxAgentShim') - def test_instantiates_shim_with_kwargs(self, shim): + def test_instantiates_shim_with_kwargs(self): m_fallback_lease_file = mock.MagicMock() m_dhcp_options = mock.MagicMock() azure_helper.get_metadata_from_fabric( fallback_lease_file=m_fallback_lease_file, dhcp_opts=m_dhcp_options) - self.assertEqual(1, shim.call_count) + self.assertEqual(1, self.m_shim.call_count) self.assertEqual( mock.call( fallback_lease_file=m_fallback_lease_file, dhcp_options=m_dhcp_options), - shim.call_args) + self.m_shim.call_args) + + +class TestGetMetadataGoalStateXMLAndReportFailureToFabric(CiTestCase): + + def setUp(self): + super( + TestGetMetadataGoalStateXMLAndReportFailureToFabric, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.m_shim = patches.enter_context( + mock.patch.object(azure_helper, 'WALinuxAgentShim')) + + def test_success_calls_clean_up(self): + azure_helper.report_failure_to_fabric() + self.assertEqual( + 1, + self.m_shim.return_value.clean_up.call_count) + + def test_failure_in_shim_report_failure_propagates_exc_and_calls_clean_up( + self): + self.m_shim.return_value.register_with_azure_and_report_failure \ + .side_effect = SentinelException + self.assertRaises(SentinelException, + azure_helper.report_failure_to_fabric) + self.assertEqual( + 1, + self.m_shim.return_value.clean_up.call_count) + + def test_report_failure_to_fabric_with_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric(description='TestDesc') + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with(description='TestDesc') + + def test_report_failure_to_fabric_with_no_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric() + # default err message description should be shown to the user + # if no description is passed in + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with( + description=azure_helper + .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_report_failure_to_fabric_empty_desc_calls_shim_report_failure( + self): + azure_helper.report_failure_to_fabric(description='') + # default err message description should be shown to the user + # if an empty description is passed in + self.m_shim.return_value.register_with_azure_and_report_failure \ + .assert_called_once_with( + description=azure_helper + .DEFAULT_REPORT_FAILURE_USER_VISIBLE_MESSAGE) + + def test_instantiates_shim_with_kwargs(self): + m_fallback_lease_file = mock.MagicMock() + m_dhcp_options = mock.MagicMock() + azure_helper.report_failure_to_fabric( + fallback_lease_file=m_fallback_lease_file, + dhcp_opts=m_dhcp_options) + self.m_shim.assert_called_once_with( + fallback_lease_file=m_fallback_lease_file, + dhcp_options=m_dhcp_options) class TestExtractIpAddressFromNetworkd(CiTestCase): |