diff options
Diffstat (limited to 'tests/unittests/test_datasource/test_azure_helper.py')
-rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 406 |
1 files changed, 340 insertions, 66 deletions
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 007df09f..5e6d3d2d 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,8 +1,10 @@ # This file is part of cloud-init. See LICENSE file for license information. import os -import unittest2 +import re +import unittest from textwrap import dedent +from xml.etree import ElementTree from cloudinit.sources.helpers import azure as azure_helper from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -48,6 +50,30 @@ GOAL_STATE_TEMPLATE = """\ </GoalState> """ +HEALTH_REPORT_XML_TEMPLATE = '''\ +<?xml version="1.0" encoding="utf-8"?> +<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xmlns:xsd="http://www.w3.org/2001/XMLSchema"> + <GoalStateIncarnation>{incarnation}</GoalStateIncarnation> + <Container> + <ContainerId>{container_id}</ContainerId> + <RoleInstanceList> + <Role> + <InstanceId>{instance_id}</InstanceId> + <Health> + <State>{health_status}</State> + {health_detail_subsection} + </Health> + </Role> + </RoleInstanceList> + </Container> +</Health> +''' + + +class SentinelException(Exception): + pass + class TestFindEndpoint(CiTestCase): @@ -140,9 +166,7 @@ class TestGoalStateParsing(CiTestCase): 'certificates_url': 'MyCertificatesUrl', } - def _get_goal_state(self, http_client=None, **kwargs): - if http_client is None: - http_client = mock.MagicMock() + def _get_formatted_goal_state_xml_string(self, **kwargs): parameters = self.default_parameters.copy() parameters.update(kwargs) xml = GOAL_STATE_TEMPLATE.format(**parameters) @@ -153,7 +177,13 @@ class TestGoalStateParsing(CiTestCase): continue new_xml_lines.append(line) xml = '\n'.join(new_xml_lines) - return azure_helper.GoalState(xml, http_client) + return xml + + def _get_goal_state(self, m_azure_endpoint_client=None, **kwargs): + if m_azure_endpoint_client is None: + m_azure_endpoint_client = mock.MagicMock() + xml = self._get_formatted_goal_state_xml_string(**kwargs) + return azure_helper.GoalState(xml, m_azure_endpoint_client) def test_incarnation_parsed_correctly(self): incarnation = '123' @@ -190,25 +220,55 @@ class TestGoalStateParsing(CiTestCase): azure_helper.is_byte_swapped(previous_iid, current_iid)) def test_certificates_xml_parsed_and_fetched_correctly(self): - http_client = mock.MagicMock() + m_azure_endpoint_client = mock.MagicMock() certificates_url = 'TestCertificatesUrl' goal_state = self._get_goal_state( - http_client=http_client, certificates_url=certificates_url) + m_azure_endpoint_client=m_azure_endpoint_client, + certificates_url=certificates_url) certificates_xml = goal_state.certificates_xml - self.assertEqual(1, http_client.get.call_count) - self.assertEqual(certificates_url, http_client.get.call_args[0][0]) - self.assertTrue(http_client.get.call_args[1].get('secure', False)) - self.assertEqual(http_client.get.return_value.contents, - certificates_xml) + self.assertEqual(1, m_azure_endpoint_client.get.call_count) + self.assertEqual( + certificates_url, + m_azure_endpoint_client.get.call_args[0][0]) + self.assertTrue( + m_azure_endpoint_client.get.call_args[1].get( + 'secure', False)) + self.assertEqual( + m_azure_endpoint_client.get.return_value.contents, + certificates_xml) def test_missing_certificates_skips_http_get(self): - http_client = mock.MagicMock() + m_azure_endpoint_client = mock.MagicMock() goal_state = self._get_goal_state( - http_client=http_client, certificates_url=None) + m_azure_endpoint_client=m_azure_endpoint_client, + certificates_url=None) certificates_xml = goal_state.certificates_xml - self.assertEqual(0, http_client.get.call_count) + self.assertEqual(0, m_azure_endpoint_client.get.call_count) self.assertIsNone(certificates_xml) + def test_invalid_goal_state_xml_raises_parse_error(self): + xml = 'random non-xml data' + with self.assertRaises(ElementTree.ParseError): + azure_helper.GoalState(xml, mock.MagicMock()) + + def test_missing_container_id_in_goal_state_xml_raises_exc(self): + xml = self._get_formatted_goal_state_xml_string() + xml = re.sub('<ContainerId>.*</ContainerId>', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, mock.MagicMock()) + + def test_missing_instance_id_in_goal_state_xml_raises_exc(self): + xml = self._get_formatted_goal_state_xml_string() + xml = re.sub('<InstanceId>.*</InstanceId>', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, mock.MagicMock()) + + def test_missing_incarnation_in_goal_state_xml_raises_exc(self): + xml = self._get_formatted_goal_state_xml_string() + xml = re.sub('<Incarnation>.*</Incarnation>', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, mock.MagicMock()) + class TestAzureEndpointHttpClient(CiTestCase): @@ -222,61 +282,95 @@ class TestAzureEndpointHttpClient(CiTestCase): patches = ExitStack() self.addCleanup(patches.close) - self.read_file_or_url = patches.enter_context( - mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + self.readurl = patches.enter_context( + mock.patch.object(azure_helper.url_helper, 'readurl')) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) def test_non_secure_get(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) url = 'MyTestUrl' response = client.get(url, secure=False) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, headers=self.regular_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, headers=self.regular_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_non_secure_get_raises_exception(self): + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + url = 'MyTestUrl' + with self.assertRaises(SentinelException): + client.get(url, secure=False) def test_secure_get(self): url = 'MyTestUrl' - certificate = mock.MagicMock() + m_certificate = mock.MagicMock() expected_headers = self.regular_headers.copy() expected_headers.update({ "x-ms-cipher-name": "DES_EDE3_CBC", - "x-ms-guest-agent-public-x509-cert": certificate, + "x-ms-guest-agent-public-x509-cert": m_certificate, }) - client = azure_helper.AzureEndpointHttpClient(certificate) + client = azure_helper.AzureEndpointHttpClient(m_certificate) response = client.get(url, secure=True) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, headers=expected_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, headers=expected_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_secure_get_raises_exception(self): + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.get(url, secure=True) def test_post(self): - data = mock.MagicMock() + m_data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - response = client.post(url, data=data) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + response = client.post(url, data=m_data) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, data=data, headers=self.regular_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, data=m_data, headers=self.regular_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_post_raises_exception(self): + m_data = mock.MagicMock() + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.post(url, data=m_data) def test_post_with_extra_headers(self): url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) extra_headers = {'test': 'header'} client.post(url, extra_headers=extra_headers) - self.assertEqual(1, self.read_file_or_url.call_count) expected_headers = self.regular_headers.copy() expected_headers.update(extra_headers) + self.assertEqual(1, self.readurl.call_count) self.assertEqual( mock.call(mock.ANY, data=mock.ANY, headers=expected_headers, - retries=10, timeout=5), - self.read_file_or_url.call_args) + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_post_with_sleep_with_extra_headers_raises_exception(self): + m_data = mock.MagicMock() + url = 'MyTestUrl' + extra_headers = {'test': 'header'} + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.post( + url, data=m_data, extra_headers=extra_headers) class TestOpenSSLManager(CiTestCase): @@ -287,7 +381,7 @@ class TestOpenSSLManager(CiTestCase): self.addCleanup(patches.close) self.subp = patches.enter_context( - mock.patch.object(azure_helper.util, 'subp')) + mock.patch.object(azure_helper.subp, 'subp')) try: self.open = patches.enter_context( mock.patch('__builtin__.open')) @@ -332,7 +426,7 @@ class TestOpenSSLManagerActions(CiTestCase): path = 'tests/data/azure' return os.path.join(path, name) - @unittest2.skip("todo move to cloud_test") + @unittest.skip("todo move to cloud_test") def test_pubkey_extract(self): cert = load_file(self._data_file('pubkey_extract_cert')) good_key = load_file(self._data_file('pubkey_extract_ssh_key')) @@ -344,7 +438,7 @@ class TestOpenSSLManagerActions(CiTestCase): fingerprint = sslmgr._get_fingerprint_from_cert(cert) self.assertEqual(good_fingerprint, fingerprint) - @unittest2.skip("todo move to cloud_test") + @unittest.skip("todo move to cloud_test") @mock.patch.object(azure_helper.OpenSSLManager, '_decrypt_certs_from_xml') def test_parse_certificates(self, mock_decrypt_certs): """Azure control plane puts private keys as well as certificates @@ -365,6 +459,131 @@ class TestOpenSSLManagerActions(CiTestCase): self.assertIn(fp, keys_by_fp) +class TestGoalStateHealthReporter(CiTestCase): + + default_parameters = { + 'incarnation': 1634, + 'container_id': 'MyContainerId', + 'instance_id': 'MyInstanceId' + } + + test_endpoint = 'TestEndpoint' + test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) + test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} + + provisioning_success_status = 'Ready' + + def setUp(self): + super(TestGoalStateHealthReporter, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + self.read_file_or_url = patches.enter_context( + mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + + self.post = patches.enter_context( + mock.patch.object(azure_helper.AzureEndpointHttpClient, + 'post')) + + self.GoalState = patches.enter_context( + mock.patch.object(azure_helper, 'GoalState')) + self.GoalState.return_value.container_id = \ + self.default_parameters['container_id'] + self.GoalState.return_value.instance_id = \ + self.default_parameters['instance_id'] + self.GoalState.return_value.incarnation = \ + self.default_parameters['incarnation'] + + def _get_formatted_health_report_xml_string(self, **kwargs): + return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + + def _get_report_ready_health_document(self): + return self._get_formatted_health_report_xml_string( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + health_status=self.provisioning_success_status, + health_detail_subsection='') + + def test_send_ready_signal_sends_post_request(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, + 'build_report') as m_build_report: + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + client, self.test_endpoint) + reporter.send_ready_signal() + + self.assertEqual(1, self.post.call_count) + self.assertEqual( + mock.call( + self.test_url, + data=m_build_report.return_value, + extra_headers=self.test_default_headers), + self.post.call_args) + + def test_build_report_for_health_document(self): + health_document = self._get_report_ready_health_document() + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_endpoint) + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_success_status) + self.assertEqual(health_document, generated_health_document) + self.assertIn( + '<GoalStateIncarnation>{}</GoalStateIncarnation>'.format( + str(self.default_parameters['incarnation'])), + generated_health_document) + self.assertIn( + ''.join([ + '<ContainerId>', + self.default_parameters['container_id'], + '</ContainerId>']), + generated_health_document) + self.assertIn( + ''.join([ + '<InstanceId>', + self.default_parameters['instance_id'], + '</InstanceId>']), + generated_health_document) + self.assertIn( + ''.join([ + '<State>', + self.provisioning_success_status, + '</State>']), + generated_health_document + ) + self.assertNotIn('<Details>', generated_health_document) + self.assertNotIn('<SubStatus>', generated_health_document) + self.assertNotIn('<Description>', generated_health_document) + + def test_send_ready_signal_calls_build_report(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, 'build_report' + ) as m_build_report: + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_endpoint) + reporter.send_ready_signal() + + self.assertEqual(1, m_build_report.call_count) + self.assertEqual( + mock.call( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_success_status), + m_build_report.call_args) + + class TestWALinuxAgentShim(CiTestCase): def setUp(self): @@ -383,14 +602,21 @@ class TestWALinuxAgentShim(CiTestCase): patches.enter_context( mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) - def test_http_client_uses_certificate(self): + self.test_incarnation = 'TestIncarnation' + self.test_container_id = 'TestContainerId' + self.test_instance_id = 'TestInstanceId' + self.GoalState.return_value.incarnation = self.test_incarnation + self.GoalState.return_value.container_id = self.test_container_id + self.GoalState.return_value.instance_id = self.test_instance_id + + def test_azure_endpoint_client_uses_certificate_during_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(self.OpenSSLManager.return_value.certificate)], self.AzureEndpointHttpClient.call_args_list) - def test_correct_url_used_for_goalstate(self): + def test_correct_url_used_for_goalstate_during_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = wa_shim() shim.register_with_azure_and_fetch_data() @@ -404,11 +630,10 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.call_args_list) def test_certificates_used_to_determine_public_keys(self): + # if register_with_azure_and_fetch_data() isn't passed some info about + # the user's public keys, there's no point in even trying to parse the + # certificates shim = wa_shim() - """if register_with_azure_and_fetch_data() isn't passed some info about - the user's public keys, there's no point in even trying to parse - the certificates - """ mypk = [{'fingerprint': 'fp1', 'path': 'path1'}, {'fingerprint': 'fp3', 'path': 'path3', 'value': ''}] certs = {'fp1': 'expected-key', @@ -439,43 +664,67 @@ class TestWALinuxAgentShim(CiTestCase): expected_url = 'http://test_endpoint/machine?comp=health' self.assertEqual( [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], - self.AzureEndpointHttpClient.return_value.post.call_args_list) + self.AzureEndpointHttpClient.return_value.post + .call_args_list) def test_goal_state_values_used_for_report_ready(self): - self.GoalState.return_value.incarnation = 'TestIncarnation' - self.GoalState.return_value.container_id = 'TestContainerId' - self.GoalState.return_value.instance_id = 'TestInstanceId' shim = wa_shim() shim.register_with_azure_and_fetch_data() posted_document = ( - self.AzureEndpointHttpClient.return_value.post.call_args[1]['data'] + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data'] ) - self.assertIn('TestIncarnation', posted_document) - self.assertIn('TestContainerId', posted_document) - self.assertIn('TestInstanceId', posted_document) + self.assertIn(self.test_incarnation, posted_document) + self.assertIn(self.test_container_id, posted_document) + self.assertIn(self.test_instance_id, posted_document) + + def test_xml_elems_in_report_ready(self): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + health_document = HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=self.test_incarnation, + container_id=self.test_container_id, + instance_id=self.test_instance_id, + health_status='Ready', + health_detail_subsection='') + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data']) + self.assertEqual(health_document, posted_document) def test_clean_up_can_be_called_at_any_time(self): shim = wa_shim() shim.clean_up() - def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self): + def test_clean_up_after_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() shim.clean_up() self.assertEqual( 1, self.OpenSSLManager.return_value.clean_up.call_count) - def test_failure_to_fetch_goalstate_bubbles_up(self): - class SentinelException(Exception): - pass - self.AzureEndpointHttpClient.return_value.get.side_effect = ( - SentinelException) + def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): + self.AzureEndpointHttpClient.return_value.get \ + .side_effect = (SentinelException) shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): + self.GoalState.side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) -class TestGetMetadataFromFabric(CiTestCase): + def test_failure_to_send_report_ready_health_doc_bubbles_up(self): + self.AzureEndpointHttpClient.return_value.post \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) + + +class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_data_from_shim_returned(self, shim): @@ -491,14 +740,39 @@ class TestGetMetadataFromFabric(CiTestCase): @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_failure_in_registration_calls_clean_up(self, shim): - class SentinelException(Exception): - pass shim.return_value.register_with_azure_and_fetch_data.side_effect = ( SentinelException) self.assertRaises(SentinelException, azure_helper.get_metadata_from_fabric) self.assertEqual(1, shim.return_value.clean_up.call_count) + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_calls_shim_register_with_azure_and_fetch_data(self, shim): + m_pubkey_info = mock.MagicMock() + azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info) + self.assertEqual( + 1, + shim.return_value + .register_with_azure_and_fetch_data.call_count) + self.assertEqual( + mock.call(pubkey_info=m_pubkey_info), + shim.return_value + .register_with_azure_and_fetch_data.call_args) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_instantiates_shim_with_kwargs(self, shim): + m_fallback_lease_file = mock.MagicMock() + m_dhcp_options = mock.MagicMock() + azure_helper.get_metadata_from_fabric( + fallback_lease_file=m_fallback_lease_file, + dhcp_opts=m_dhcp_options) + self.assertEqual(1, shim.call_count) + self.assertEqual( + mock.call( + fallback_lease_file=m_fallback_lease_file, + dhcp_options=m_dhcp_options), + shim.call_args) + class TestExtractIpAddressFromNetworkd(CiTestCase): |