diff options
-rwxr-xr-x | cloudinit/sources/helpers/azure.py | 401 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 391 | ||||
-rw-r--r-- | tools/.github-cla-signers | 1 |
3 files changed, 631 insertions, 162 deletions
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 7afa7ed8..6df28ccf 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -195,7 +195,7 @@ def _get_dhcp_endpoint_option_name(): return azure_endpoint -class AzureEndpointHttpClient(object): +class AzureEndpointHttpClient: headers = { 'x-ms-agent-name': 'WALinuxAgent', @@ -213,57 +213,77 @@ class AzureEndpointHttpClient(object): if secure: headers = self.headers.copy() headers.update(self.extra_secure_headers) - return url_helper.read_file_or_url(url, headers=headers, timeout=5, - retries=10) + return url_helper.readurl(url, headers=headers, + timeout=5, retries=10, sec_between=5) def post(self, url, data=None, extra_headers=None): headers = self.headers if extra_headers is not None: headers = self.headers.copy() headers.update(extra_headers) - return url_helper.read_file_or_url(url, data=data, headers=headers, - timeout=5, retries=10) + return url_helper.readurl(url, data=data, headers=headers, + timeout=5, retries=10, sec_between=5) -class GoalState(object): +class InvalidGoalStateXMLException(Exception): + """Raised when GoalState XML is invalid or has missing data.""" - def __init__(self, xml, http_client): - self.http_client = http_client - self.root = ElementTree.fromstring(xml) - self._certificates_xml = None - def _text_from_xpath(self, xpath): - element = self.root.find(xpath) - if element is not None: - return element.text - return None +class GoalState: - @property - def container_id(self): - return self._text_from_xpath('./Container/ContainerId') + def __init__(self, unparsed_xml, azure_endpoint_client): + """Parses a GoalState XML string and returns a GoalState object. - @property - def incarnation(self): - return self._text_from_xpath('./Incarnation') + @param unparsed_xml: string representing a GoalState XML. + @param azure_endpoint_client: instance of AzureEndpointHttpClient + @return: GoalState object representing the GoalState XML string. + """ + self.azure_endpoint_client = azure_endpoint_client - @property - def instance_id(self): - return self._text_from_xpath( + try: + self.root = ElementTree.fromstring(unparsed_xml) + except ElementTree.ParseError as e: + msg = 'Failed to parse GoalState XML: %s' + LOG.warning(msg, e) + report_diagnostic_event(msg % (e,)) + raise + + self.container_id = self._text_from_xpath('./Container/ContainerId') + self.instance_id = self._text_from_xpath( './Container/RoleInstanceList/RoleInstance/InstanceId') + self.incarnation = self._text_from_xpath('./Incarnation') + + for attr in ("container_id", "instance_id", "incarnation"): + if getattr(self, attr) is None: + msg = 'Missing %s in GoalState XML' + LOG.warning(msg, attr) + report_diagnostic_event(msg % (attr,)) + raise InvalidGoalStateXMLException(msg) + + self.certificates_xml = None + url = self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance' + '/Configuration/Certificates') + if url is not None: + with events.ReportEventStack( + name="get-certificates-xml", + description="get certificates xml", + parent=azure_ds_reporter): + self.certificates_xml = \ + self.azure_endpoint_client.get( + url, secure=True).contents + if self.certificates_xml is None: + raise InvalidGoalStateXMLException( + 'Azure endpoint returned empty certificates xml.') - @property - def certificates_xml(self): - if self._certificates_xml is None: - url = self._text_from_xpath( - './Container/RoleInstanceList/RoleInstance' - '/Configuration/Certificates') - if url is not None: - self._certificates_xml = self.http_client.get( - url, secure=True).contents - return self._certificates_xml + def _text_from_xpath(self, xpath): + element = self.root.find(xpath) + if element is not None: + return element.text + return None -class OpenSSLManager(object): +class OpenSSLManager: certificate_names = { 'private_key': 'TransportPrivate.pem', @@ -370,25 +390,120 @@ class OpenSSLManager(object): return keys -class WALinuxAgentShim(object): - - REPORT_READY_XML_TEMPLATE = '\n'.join([ - '<?xml version="1.0" encoding="utf-8"?>', - '<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"' - ' xmlns:xsd="http://www.w3.org/2001/XMLSchema">', - ' <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>', - ' <Container>', - ' <ContainerId>{container_id}</ContainerId>', - ' <RoleInstanceList>', - ' <Role>', - ' <InstanceId>{instance_id}</InstanceId>', - ' <Health>', - ' <State>Ready</State>', - ' </Health>', - ' </Role>', - ' </RoleInstanceList>', - ' </Container>', - '</Health>']) +class GoalStateHealthReporter: + + HEALTH_REPORT_XML_TEMPLATE = textwrap.dedent('''\ + <?xml version="1.0" encoding="utf-8"?> + <Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xmlns:xsd="http://www.w3.org/2001/XMLSchema"> + <GoalStateIncarnation>{incarnation}</GoalStateIncarnation> + <Container> + <ContainerId>{container_id}</ContainerId> + <RoleInstanceList> + <Role> + <InstanceId>{instance_id}</InstanceId> + <Health> + <State>{health_status}</State> + {health_detail_subsection} + </Health> + </Role> + </RoleInstanceList> + </Container> + </Health> + ''') + + HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = textwrap.dedent('''\ + <Details> + <SubStatus>{health_substatus}</SubStatus> + <Description>{health_description}</Description> + </Details> + ''') + + PROVISIONING_SUCCESS_STATUS = 'Ready' + + def __init__(self, goal_state, azure_endpoint_client, endpoint): + """Creates instance that will report provisioning status to an endpoint + + @param goal_state: An instance of class GoalState that contains + goal state info such as incarnation, container id, and instance id. + These 3 values are needed when reporting the provisioning status + to Azure + @param azure_endpoint_client: Instance of class AzureEndpointHttpClient + @param endpoint: Endpoint (string) where the provisioning status report + will be sent to + @return: Instance of class GoalStateHealthReporter + """ + self._goal_state = goal_state + self._azure_endpoint_client = azure_endpoint_client + self._endpoint = endpoint + + @azure_ds_telemetry_reporter + def send_ready_signal(self): + document = self.build_report( + incarnation=self._goal_state.incarnation, + container_id=self._goal_state.container_id, + instance_id=self._goal_state.instance_id, + status=self.PROVISIONING_SUCCESS_STATUS) + LOG.debug('Reporting ready to Azure fabric.') + try: + self._post_health_report(document=document) + except Exception as e: + msg = "exception while reporting ready: %s" % e + LOG.error(msg) + report_diagnostic_event(msg) + raise + + LOG.info('Reported ready to Azure fabric.') + + def build_report( + self, incarnation, container_id, instance_id, + status, substatus=None, description=None): + health_detail = '' + if substatus is not None: + health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( + health_substatus=substatus, health_description=description) + + health_report = self.HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=incarnation, + container_id=container_id, + instance_id=instance_id, + health_status=status, + health_detail_subsection=health_detail) + + return health_report + + @azure_ds_telemetry_reporter + def _post_health_report(self, document): + # Whenever report_diagnostic_event(diagnostic_msg) is invoked in code, + # the diagnostic messages are written to special files + # (/var/opt/hyperv/.kvp_pool_*) as Hyper-V KVP messages. + # Hyper-V KVP message communication is done through these files, + # and KVP functionality is used to communicate and share diagnostic + # info with the Azure Host. + # The Azure Host will collect the VM's Hyper-V KVP diagnostic messages + # when cloud-init reports to fabric. + # When the Azure Host receives the health report signal, it will only + # collect and process whatever KVP diagnostic messages have been + # written to the KVP files. + # KVP messages that are published after the Azure Host receives the + # signal are ignored and unprocessed, so yield this thread to the + # Hyper-V KVP Reporting thread so that they are written. + # time.sleep(0) is a low-cost and proven method to yield the scheduler + # and ensure that events are flushed. + # See HyperVKvpReportingHandler class, which is a multi-threaded + # reporting handler that writes to the special KVP files. + time.sleep(0) + + LOG.debug('Sending health report to Azure fabric.') + url = "http://{}/machine?comp=health".format(self._endpoint) + self._azure_endpoint_client.post( + url, + data=document, + extra_headers={'Content-Type': 'text/xml; charset=utf-8'}) + LOG.debug('Successfully sent health report to Azure fabric') + + +class WALinuxAgentShim: def __init__(self, fallback_lease_file=None, dhcp_options=None): LOG.debug('WALinuxAgentShim instantiated, fallback_lease_file=%s', @@ -396,6 +511,7 @@ class WALinuxAgentShim(object): self.dhcpoptions = dhcp_options self._endpoint = None self.openssl_manager = None + self.azure_endpoint_client = None self.lease_file = fallback_lease_file def clean_up(self): @@ -494,7 +610,22 @@ class WALinuxAgentShim(object): @staticmethod @azure_ds_telemetry_reporter def find_endpoint(fallback_lease_file=None, dhcp245=None): + """Finds and returns the Azure endpoint using various methods. + + The Azure endpoint is searched in the following order: + 1. Endpoint from dhcp options (dhcp option 245). + 2. Endpoint from networkd. + 3. Endpoint from dhclient hook json. + 4. Endpoint from fallback lease file. + 5. The default Azure endpoint. + + @param fallback_lease_file: Fallback lease file that will be used + during endpoint search. + @param dhcp245: dhcp options that will be used during endpoint search. + @return: Azure endpoint IP address. + """ value = None + if dhcp245 is not None: value = dhcp245 LOG.debug("Using Azure Endpoint from dhcp options") @@ -536,42 +667,128 @@ class WALinuxAgentShim(object): @azure_ds_telemetry_reporter def register_with_azure_and_fetch_data(self, pubkey_info=None): + """Gets the VM's GoalState from Azure, uses the GoalState information + to report ready/send the ready signal/provisioning complete signal to + Azure, and then uses pubkey_info to filter and obtain the user's + pubkeys from the GoalState. + + @param pubkey_info: List of pubkey values and fingerprints which are + used to filter and obtain the user's pubkey values from the + GoalState. + @return: The list of user's authorized pubkey values. + """ if self.openssl_manager is None: self.openssl_manager = OpenSSLManager() - http_client = AzureEndpointHttpClient(self.openssl_manager.certificate) + if self.azure_endpoint_client is None: + self.azure_endpoint_client = AzureEndpointHttpClient( + self.openssl_manager.certificate) + goal_state = self._fetch_goal_state_from_azure() + ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) + health_reporter = GoalStateHealthReporter( + goal_state, self.azure_endpoint_client, self.endpoint) + health_reporter.send_ready_signal() + return {'public-keys': ssh_keys} + + @azure_ds_telemetry_reporter + def _fetch_goal_state_from_azure(self): + """Fetches the GoalState XML from the Azure endpoint, parses the XML, + and returns a GoalState object. + + @return: GoalState object representing the GoalState XML + """ + unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() + return self._parse_raw_goal_state_xml(unparsed_goal_state_xml) + + @azure_ds_telemetry_reporter + def _get_raw_goal_state_xml_from_azure(self): + """Fetches the GoalState XML from the Azure endpoint and returns + the XML as a string. + + @return: GoalState XML string + """ + LOG.info('Registering with Azure...') - attempts = 0 - while True: - try: - response = http_client.get( - 'http://{0}/machine/?comp=goalstate'.format(self.endpoint)) - except Exception as e: - if attempts < 10: - time.sleep(attempts + 1) - else: - report_diagnostic_event( - "failed to register with Azure: %s" % e) - raise - else: - break - attempts += 1 + url = 'http://{}/machine/?comp=goalstate'.format(self.endpoint) + try: + response = self.azure_endpoint_client.get(url) + except Exception as e: + msg = 'failed to register with Azure: %s' % e + LOG.warning(msg) + report_diagnostic_event(msg) + raise LOG.debug('Successfully fetched GoalState XML.') - goal_state = GoalState(response.contents, http_client) - report_diagnostic_event("container_id %s" % goal_state.container_id) + return response.contents + + @azure_ds_telemetry_reporter + def _parse_raw_goal_state_xml(self, unparsed_goal_state_xml): + """Parses a GoalState XML string and returns a GoalState object. + + @param unparsed_goal_state_xml: GoalState XML string + @return: GoalState object representing the GoalState XML + """ + try: + goal_state = GoalState( + unparsed_goal_state_xml, self.azure_endpoint_client) + except Exception as e: + msg = 'Error processing GoalState XML: %s' % e + LOG.warning(msg) + report_diagnostic_event(msg) + raise + msg = ', '.join([ + 'GoalState XML container id: %s' % goal_state.container_id, + 'GoalState XML instance id: %s' % goal_state.instance_id, + 'GoalState XML incarnation: %s' % goal_state.incarnation]) + LOG.debug(msg) + report_diagnostic_event(msg) + return goal_state + + @azure_ds_telemetry_reporter + def _get_user_pubkeys(self, goal_state, pubkey_info): + """Gets and filters the VM admin user's authorized pubkeys. + + The admin user in this case is the username specified as "admin" + when deploying VMs on Azure. + See https://docs.microsoft.com/en-us/cli/azure/vm#az-vm-create. + cloud-init expects a straightforward array of keys to be dropped + into the admin user's authorized_keys file. Azure control plane exposes + multiple public keys to the VM via wireserver. Select just the + admin user's key(s) and return them, ignoring any other certs. + + @param goal_state: GoalState object. The GoalState object contains + a certificate XML, which contains both the VM user's authorized + pubkeys and other non-user pubkeys, which are used for + MSI and protected extension handling. + @param pubkey_info: List of VM user pubkey dicts that were previously + obtained from provisioning data. + Each pubkey dict in this list can either have the format + pubkey['value'] or pubkey['fingerprint']. + Each pubkey['fingerprint'] in the list is used to filter + and obtain the actual pubkey value from the GoalState + certificates XML. + Each pubkey['value'] requires no further processing and is + immediately added to the return list. + @return: A list of the VM user's authorized pubkey values. + """ ssh_keys = [] if goal_state.certificates_xml is not None and pubkey_info is not None: LOG.debug('Certificate XML found; parsing out public keys.') keys_by_fingerprint = self.openssl_manager.parse_certificates( goal_state.certificates_xml) ssh_keys = self._filter_pubkeys(keys_by_fingerprint, pubkey_info) - self._report_ready(goal_state, http_client) - return {'public-keys': ssh_keys} + return ssh_keys - def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info): - """cloud-init expects a straightforward array of keys to be dropped - into the user's authorized_keys file. Azure control plane exposes - multiple public keys to the VM via wireserver. Select just the - user's key(s) and return them, ignoring any other certs. + @staticmethod + def _filter_pubkeys(keys_by_fingerprint, pubkey_info): + """ Filter and return only the user's actual pubkeys. + + @param keys_by_fingerprint: pubkey fingerprint -> pubkey value dict + that was obtained from GoalState Certificates XML. May contain + non-user pubkeys. + @param pubkey_info: List of VM user pubkeys. Pubkey values are added + to the return list without further processing. Pubkey fingerprints + are used to filter and obtain the actual pubkey values from + keys_by_fingerprint. + @return: A list of the VM user's authorized pubkey values. """ keys = [] for pubkey in pubkey_info: @@ -590,30 +807,6 @@ class WALinuxAgentShim(object): return keys - @azure_ds_telemetry_reporter - def _report_ready(self, goal_state, http_client): - LOG.debug('Reporting ready to Azure fabric.') - document = self.REPORT_READY_XML_TEMPLATE.format( - incarnation=goal_state.incarnation, - container_id=goal_state.container_id, - instance_id=goal_state.instance_id, - ) - # Host will collect kvps when cloud-init reports ready. - # some kvps might still be in the queue. We yield the scheduler - # to make sure we process all kvps up till this point. - time.sleep(0) - try: - http_client.post( - "http://{0}/machine?comp=health".format(self.endpoint), - data=document, - extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, - ) - except Exception as e: - report_diagnostic_event("exception while reporting ready: %s" % e) - raise - - LOG.info('Reported ready to Azure fabric.') - @azure_ds_telemetry_reporter def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, @@ -631,7 +824,7 @@ def dhcp_log_cb(out, err): report_diagnostic_event("dhclient error stream: %s" % err) -class EphemeralDHCPv4WithReporting(object): +class EphemeralDHCPv4WithReporting: def __init__(self, reporter, nic=None): self.reporter = reporter self.ephemeralDHCPv4 = EphemeralDHCPv4( diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index f314cd4a..5e6d3d2d 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,8 +1,10 @@ # This file is part of cloud-init. See LICENSE file for license information. import os +import re import unittest from textwrap import dedent +from xml.etree import ElementTree from cloudinit.sources.helpers import azure as azure_helper from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -48,6 +50,30 @@ GOAL_STATE_TEMPLATE = """\ </GoalState> """ +HEALTH_REPORT_XML_TEMPLATE = '''\ +<?xml version="1.0" encoding="utf-8"?> +<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xmlns:xsd="http://www.w3.org/2001/XMLSchema"> + <GoalStateIncarnation>{incarnation}</GoalStateIncarnation> + <Container> + <ContainerId>{container_id}</ContainerId> + <RoleInstanceList> + <Role> + <InstanceId>{instance_id}</InstanceId> + <Health> + <State>{health_status}</State> + {health_detail_subsection} + </Health> + </Role> + </RoleInstanceList> + </Container> +</Health> +''' + + +class SentinelException(Exception): + pass + class TestFindEndpoint(CiTestCase): @@ -140,9 +166,7 @@ class TestGoalStateParsing(CiTestCase): 'certificates_url': 'MyCertificatesUrl', } - def _get_goal_state(self, http_client=None, **kwargs): - if http_client is None: - http_client = mock.MagicMock() + def _get_formatted_goal_state_xml_string(self, **kwargs): parameters = self.default_parameters.copy() parameters.update(kwargs) xml = GOAL_STATE_TEMPLATE.format(**parameters) @@ -153,7 +177,13 @@ class TestGoalStateParsing(CiTestCase): continue new_xml_lines.append(line) xml = '\n'.join(new_xml_lines) - return azure_helper.GoalState(xml, http_client) + return xml + + def _get_goal_state(self, m_azure_endpoint_client=None, **kwargs): + if m_azure_endpoint_client is None: + m_azure_endpoint_client = mock.MagicMock() + xml = self._get_formatted_goal_state_xml_string(**kwargs) + return azure_helper.GoalState(xml, m_azure_endpoint_client) def test_incarnation_parsed_correctly(self): incarnation = '123' @@ -190,25 +220,55 @@ class TestGoalStateParsing(CiTestCase): azure_helper.is_byte_swapped(previous_iid, current_iid)) def test_certificates_xml_parsed_and_fetched_correctly(self): - http_client = mock.MagicMock() + m_azure_endpoint_client = mock.MagicMock() certificates_url = 'TestCertificatesUrl' goal_state = self._get_goal_state( - http_client=http_client, certificates_url=certificates_url) + m_azure_endpoint_client=m_azure_endpoint_client, + certificates_url=certificates_url) certificates_xml = goal_state.certificates_xml - self.assertEqual(1, http_client.get.call_count) - self.assertEqual(certificates_url, http_client.get.call_args[0][0]) - self.assertTrue(http_client.get.call_args[1].get('secure', False)) - self.assertEqual(http_client.get.return_value.contents, - certificates_xml) + self.assertEqual(1, m_azure_endpoint_client.get.call_count) + self.assertEqual( + certificates_url, + m_azure_endpoint_client.get.call_args[0][0]) + self.assertTrue( + m_azure_endpoint_client.get.call_args[1].get( + 'secure', False)) + self.assertEqual( + m_azure_endpoint_client.get.return_value.contents, + certificates_xml) def test_missing_certificates_skips_http_get(self): - http_client = mock.MagicMock() + m_azure_endpoint_client = mock.MagicMock() goal_state = self._get_goal_state( - http_client=http_client, certificates_url=None) + m_azure_endpoint_client=m_azure_endpoint_client, + certificates_url=None) certificates_xml = goal_state.certificates_xml - self.assertEqual(0, http_client.get.call_count) + self.assertEqual(0, m_azure_endpoint_client.get.call_count) self.assertIsNone(certificates_xml) + def test_invalid_goal_state_xml_raises_parse_error(self): + xml = 'random non-xml data' + with self.assertRaises(ElementTree.ParseError): + azure_helper.GoalState(xml, mock.MagicMock()) + + def test_missing_container_id_in_goal_state_xml_raises_exc(self): + xml = self._get_formatted_goal_state_xml_string() + xml = re.sub('<ContainerId>.*</ContainerId>', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, mock.MagicMock()) + + def test_missing_instance_id_in_goal_state_xml_raises_exc(self): + xml = self._get_formatted_goal_state_xml_string() + xml = re.sub('<InstanceId>.*</InstanceId>', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, mock.MagicMock()) + + def test_missing_incarnation_in_goal_state_xml_raises_exc(self): + xml = self._get_formatted_goal_state_xml_string() + xml = re.sub('<Incarnation>.*</Incarnation>', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, mock.MagicMock()) + class TestAzureEndpointHttpClient(CiTestCase): @@ -222,61 +282,95 @@ class TestAzureEndpointHttpClient(CiTestCase): patches = ExitStack() self.addCleanup(patches.close) - self.read_file_or_url = patches.enter_context( - mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + self.readurl = patches.enter_context( + mock.patch.object(azure_helper.url_helper, 'readurl')) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) def test_non_secure_get(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) url = 'MyTestUrl' response = client.get(url, secure=False) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, headers=self.regular_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, headers=self.regular_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_non_secure_get_raises_exception(self): + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + url = 'MyTestUrl' + with self.assertRaises(SentinelException): + client.get(url, secure=False) def test_secure_get(self): url = 'MyTestUrl' - certificate = mock.MagicMock() + m_certificate = mock.MagicMock() expected_headers = self.regular_headers.copy() expected_headers.update({ "x-ms-cipher-name": "DES_EDE3_CBC", - "x-ms-guest-agent-public-x509-cert": certificate, + "x-ms-guest-agent-public-x509-cert": m_certificate, }) - client = azure_helper.AzureEndpointHttpClient(certificate) + client = azure_helper.AzureEndpointHttpClient(m_certificate) response = client.get(url, secure=True) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, headers=expected_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, headers=expected_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_secure_get_raises_exception(self): + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.get(url, secure=True) def test_post(self): - data = mock.MagicMock() + m_data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) - response = client.post(url, data=data) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + response = client.post(url, data=m_data) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, data=data, headers=self.regular_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, data=m_data, headers=self.regular_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_post_raises_exception(self): + m_data = mock.MagicMock() + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.post(url, data=m_data) def test_post_with_extra_headers(self): url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) extra_headers = {'test': 'header'} client.post(url, extra_headers=extra_headers) - self.assertEqual(1, self.read_file_or_url.call_count) expected_headers = self.regular_headers.copy() expected_headers.update(extra_headers) + self.assertEqual(1, self.readurl.call_count) self.assertEqual( mock.call(mock.ANY, data=mock.ANY, headers=expected_headers, - retries=10, timeout=5), - self.read_file_or_url.call_args) + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_post_with_sleep_with_extra_headers_raises_exception(self): + m_data = mock.MagicMock() + url = 'MyTestUrl' + extra_headers = {'test': 'header'} + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.post( + url, data=m_data, extra_headers=extra_headers) class TestOpenSSLManager(CiTestCase): @@ -365,6 +459,131 @@ class TestOpenSSLManagerActions(CiTestCase): self.assertIn(fp, keys_by_fp) +class TestGoalStateHealthReporter(CiTestCase): + + default_parameters = { + 'incarnation': 1634, + 'container_id': 'MyContainerId', + 'instance_id': 'MyInstanceId' + } + + test_endpoint = 'TestEndpoint' + test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) + test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} + + provisioning_success_status = 'Ready' + + def setUp(self): + super(TestGoalStateHealthReporter, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + self.read_file_or_url = patches.enter_context( + mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + + self.post = patches.enter_context( + mock.patch.object(azure_helper.AzureEndpointHttpClient, + 'post')) + + self.GoalState = patches.enter_context( + mock.patch.object(azure_helper, 'GoalState')) + self.GoalState.return_value.container_id = \ + self.default_parameters['container_id'] + self.GoalState.return_value.instance_id = \ + self.default_parameters['instance_id'] + self.GoalState.return_value.incarnation = \ + self.default_parameters['incarnation'] + + def _get_formatted_health_report_xml_string(self, **kwargs): + return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + + def _get_report_ready_health_document(self): + return self._get_formatted_health_report_xml_string( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + health_status=self.provisioning_success_status, + health_detail_subsection='') + + def test_send_ready_signal_sends_post_request(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, + 'build_report') as m_build_report: + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + client, self.test_endpoint) + reporter.send_ready_signal() + + self.assertEqual(1, self.post.call_count) + self.assertEqual( + mock.call( + self.test_url, + data=m_build_report.return_value, + extra_headers=self.test_default_headers), + self.post.call_args) + + def test_build_report_for_health_document(self): + health_document = self._get_report_ready_health_document() + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_endpoint) + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_success_status) + self.assertEqual(health_document, generated_health_document) + self.assertIn( + '<GoalStateIncarnation>{}</GoalStateIncarnation>'.format( + str(self.default_parameters['incarnation'])), + generated_health_document) + self.assertIn( + ''.join([ + '<ContainerId>', + self.default_parameters['container_id'], + '</ContainerId>']), + generated_health_document) + self.assertIn( + ''.join([ + '<InstanceId>', + self.default_parameters['instance_id'], + '</InstanceId>']), + generated_health_document) + self.assertIn( + ''.join([ + '<State>', + self.provisioning_success_status, + '</State>']), + generated_health_document + ) + self.assertNotIn('<Details>', generated_health_document) + self.assertNotIn('<SubStatus>', generated_health_document) + self.assertNotIn('<Description>', generated_health_document) + + def test_send_ready_signal_calls_build_report(self): + with mock.patch.object( + azure_helper.GoalStateHealthReporter, 'build_report' + ) as m_build_report: + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_endpoint) + reporter.send_ready_signal() + + self.assertEqual(1, m_build_report.call_count) + self.assertEqual( + mock.call( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_success_status), + m_build_report.call_args) + + class TestWALinuxAgentShim(CiTestCase): def setUp(self): @@ -383,14 +602,21 @@ class TestWALinuxAgentShim(CiTestCase): patches.enter_context( mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) - def test_http_client_uses_certificate(self): + self.test_incarnation = 'TestIncarnation' + self.test_container_id = 'TestContainerId' + self.test_instance_id = 'TestInstanceId' + self.GoalState.return_value.incarnation = self.test_incarnation + self.GoalState.return_value.container_id = self.test_container_id + self.GoalState.return_value.instance_id = self.test_instance_id + + def test_azure_endpoint_client_uses_certificate_during_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(self.OpenSSLManager.return_value.certificate)], self.AzureEndpointHttpClient.call_args_list) - def test_correct_url_used_for_goalstate(self): + def test_correct_url_used_for_goalstate_during_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = wa_shim() shim.register_with_azure_and_fetch_data() @@ -438,43 +664,67 @@ class TestWALinuxAgentShim(CiTestCase): expected_url = 'http://test_endpoint/machine?comp=health' self.assertEqual( [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], - self.AzureEndpointHttpClient.return_value.post.call_args_list) + self.AzureEndpointHttpClient.return_value.post + .call_args_list) def test_goal_state_values_used_for_report_ready(self): - self.GoalState.return_value.incarnation = 'TestIncarnation' - self.GoalState.return_value.container_id = 'TestContainerId' - self.GoalState.return_value.instance_id = 'TestInstanceId' shim = wa_shim() shim.register_with_azure_and_fetch_data() posted_document = ( - self.AzureEndpointHttpClient.return_value.post.call_args[1]['data'] + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data'] ) - self.assertIn('TestIncarnation', posted_document) - self.assertIn('TestContainerId', posted_document) - self.assertIn('TestInstanceId', posted_document) + self.assertIn(self.test_incarnation, posted_document) + self.assertIn(self.test_container_id, posted_document) + self.assertIn(self.test_instance_id, posted_document) + + def test_xml_elems_in_report_ready(self): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + health_document = HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=self.test_incarnation, + container_id=self.test_container_id, + instance_id=self.test_instance_id, + health_status='Ready', + health_detail_subsection='') + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data']) + self.assertEqual(health_document, posted_document) def test_clean_up_can_be_called_at_any_time(self): shim = wa_shim() shim.clean_up() - def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self): + def test_clean_up_after_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() shim.clean_up() self.assertEqual( 1, self.OpenSSLManager.return_value.clean_up.call_count) - def test_failure_to_fetch_goalstate_bubbles_up(self): - class SentinelException(Exception): - pass - self.AzureEndpointHttpClient.return_value.get.side_effect = ( - SentinelException) + def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): + self.AzureEndpointHttpClient.return_value.get \ + .side_effect = (SentinelException) shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): + self.GoalState.side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) -class TestGetMetadataFromFabric(CiTestCase): + def test_failure_to_send_report_ready_health_doc_bubbles_up(self): + self.AzureEndpointHttpClient.return_value.post \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) + + +class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_data_from_shim_returned(self, shim): @@ -490,14 +740,39 @@ class TestGetMetadataFromFabric(CiTestCase): @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_failure_in_registration_calls_clean_up(self, shim): - class SentinelException(Exception): - pass shim.return_value.register_with_azure_and_fetch_data.side_effect = ( SentinelException) self.assertRaises(SentinelException, azure_helper.get_metadata_from_fabric) self.assertEqual(1, shim.return_value.clean_up.call_count) + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_calls_shim_register_with_azure_and_fetch_data(self, shim): + m_pubkey_info = mock.MagicMock() + azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info) + self.assertEqual( + 1, + shim.return_value + .register_with_azure_and_fetch_data.call_count) + self.assertEqual( + mock.call(pubkey_info=m_pubkey_info), + shim.return_value + .register_with_azure_and_fetch_data.call_args) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_instantiates_shim_with_kwargs(self, shim): + m_fallback_lease_file = mock.MagicMock() + m_dhcp_options = mock.MagicMock() + azure_helper.get_metadata_from_fabric( + fallback_lease_file=m_fallback_lease_file, + dhcp_opts=m_dhcp_options) + self.assertEqual(1, shim.call_count) + self.assertEqual( + mock.call( + fallback_lease_file=m_fallback_lease_file, + dhcp_options=m_dhcp_options), + shim.call_args) + class TestExtractIpAddressFromNetworkd(CiTestCase): diff --git a/tools/.github-cla-signers b/tools/.github-cla-signers index 50473bf7..0c4d728f 100644 --- a/tools/.github-cla-signers +++ b/tools/.github-cla-signers @@ -7,6 +7,7 @@ dermotbradley dhensby eandersson izzyleung +johnsonshi landon912 lucasmoura marlluslustosa |