diff options
| -rwxr-xr-x | cloudinit/sources/helpers/azure.py | 401 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 391 | ||||
| -rw-r--r-- | tools/.github-cla-signers | 1 | 
3 files changed, 631 insertions, 162 deletions
| diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 7afa7ed8..6df28ccf 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -195,7 +195,7 @@ def _get_dhcp_endpoint_option_name():      return azure_endpoint -class AzureEndpointHttpClient(object): +class AzureEndpointHttpClient:      headers = {          'x-ms-agent-name': 'WALinuxAgent', @@ -213,57 +213,77 @@ class AzureEndpointHttpClient(object):          if secure:              headers = self.headers.copy()              headers.update(self.extra_secure_headers) -        return url_helper.read_file_or_url(url, headers=headers, timeout=5, -                                           retries=10) +        return url_helper.readurl(url, headers=headers, +                                  timeout=5, retries=10, sec_between=5)      def post(self, url, data=None, extra_headers=None):          headers = self.headers          if extra_headers is not None:              headers = self.headers.copy()              headers.update(extra_headers) -        return url_helper.read_file_or_url(url, data=data, headers=headers, -                                           timeout=5, retries=10) +        return url_helper.readurl(url, data=data, headers=headers, +                                  timeout=5, retries=10, sec_between=5) -class GoalState(object): +class InvalidGoalStateXMLException(Exception): +    """Raised when GoalState XML is invalid or has missing data.""" -    def __init__(self, xml, http_client): -        self.http_client = http_client -        self.root = ElementTree.fromstring(xml) -        self._certificates_xml = None -    def _text_from_xpath(self, xpath): -        element = self.root.find(xpath) -        if element is not None: -            return element.text -        return None +class GoalState: -    @property -    def container_id(self): -        return self._text_from_xpath('./Container/ContainerId') +    def __init__(self, unparsed_xml, azure_endpoint_client): +        """Parses a GoalState XML string and returns a GoalState object. -    @property -    def incarnation(self): -        return self._text_from_xpath('./Incarnation') +        @param unparsed_xml: string representing a GoalState XML. +        @param azure_endpoint_client: instance of AzureEndpointHttpClient +        @return: GoalState object representing the GoalState XML string. +        """ +        self.azure_endpoint_client = azure_endpoint_client -    @property -    def instance_id(self): -        return self._text_from_xpath( +        try: +            self.root = ElementTree.fromstring(unparsed_xml) +        except ElementTree.ParseError as e: +            msg = 'Failed to parse GoalState XML: %s' +            LOG.warning(msg, e) +            report_diagnostic_event(msg % (e,)) +            raise + +        self.container_id = self._text_from_xpath('./Container/ContainerId') +        self.instance_id = self._text_from_xpath(              './Container/RoleInstanceList/RoleInstance/InstanceId') +        self.incarnation = self._text_from_xpath('./Incarnation') + +        for attr in ("container_id", "instance_id", "incarnation"): +            if getattr(self, attr) is None: +                msg = 'Missing %s in GoalState XML' +                LOG.warning(msg, attr) +                report_diagnostic_event(msg % (attr,)) +                raise InvalidGoalStateXMLException(msg) + +        self.certificates_xml = None +        url = self._text_from_xpath( +            './Container/RoleInstanceList/RoleInstance' +            '/Configuration/Certificates') +        if url is not None: +            with events.ReportEventStack( +                    name="get-certificates-xml", +                    description="get certificates xml", +                    parent=azure_ds_reporter): +                self.certificates_xml = \ +                    self.azure_endpoint_client.get( +                        url, secure=True).contents +                if self.certificates_xml is None: +                    raise InvalidGoalStateXMLException( +                        'Azure endpoint returned empty certificates xml.') -    @property -    def certificates_xml(self): -        if self._certificates_xml is None: -            url = self._text_from_xpath( -                './Container/RoleInstanceList/RoleInstance' -                '/Configuration/Certificates') -            if url is not None: -                self._certificates_xml = self.http_client.get( -                    url, secure=True).contents -        return self._certificates_xml +    def _text_from_xpath(self, xpath): +        element = self.root.find(xpath) +        if element is not None: +            return element.text +        return None -class OpenSSLManager(object): +class OpenSSLManager:      certificate_names = {          'private_key': 'TransportPrivate.pem', @@ -370,25 +390,120 @@ class OpenSSLManager(object):          return keys -class WALinuxAgentShim(object): - -    REPORT_READY_XML_TEMPLATE = '\n'.join([ -        '<?xml version="1.0" encoding="utf-8"?>', -        '<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"' -        ' xmlns:xsd="http://www.w3.org/2001/XMLSchema">', -        '  <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>', -        '  <Container>', -        '    <ContainerId>{container_id}</ContainerId>', -        '    <RoleInstanceList>', -        '      <Role>', -        '        <InstanceId>{instance_id}</InstanceId>', -        '        <Health>', -        '          <State>Ready</State>', -        '        </Health>', -        '      </Role>', -        '    </RoleInstanceList>', -        '  </Container>', -        '</Health>']) +class GoalStateHealthReporter: + +    HEALTH_REPORT_XML_TEMPLATE = textwrap.dedent('''\ +        <?xml version="1.0" encoding="utf-8"?> +        <Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" +         xmlns:xsd="http://www.w3.org/2001/XMLSchema"> +          <GoalStateIncarnation>{incarnation}</GoalStateIncarnation> +          <Container> +            <ContainerId>{container_id}</ContainerId> +            <RoleInstanceList> +              <Role> +                <InstanceId>{instance_id}</InstanceId> +                <Health> +                  <State>{health_status}</State> +                  {health_detail_subsection} +                </Health> +              </Role> +            </RoleInstanceList> +          </Container> +        </Health> +        ''') + +    HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = textwrap.dedent('''\ +        <Details> +          <SubStatus>{health_substatus}</SubStatus> +          <Description>{health_description}</Description> +        </Details> +        ''') + +    PROVISIONING_SUCCESS_STATUS = 'Ready' + +    def __init__(self, goal_state, azure_endpoint_client, endpoint): +        """Creates instance that will report provisioning status to an endpoint + +        @param goal_state: An instance of class GoalState that contains +            goal state info such as incarnation, container id, and instance id. +            These 3 values are needed when reporting the provisioning status +            to Azure +        @param azure_endpoint_client: Instance of class AzureEndpointHttpClient +        @param endpoint: Endpoint (string) where the provisioning status report +            will be sent to +        @return: Instance of class GoalStateHealthReporter +        """ +        self._goal_state = goal_state +        self._azure_endpoint_client = azure_endpoint_client +        self._endpoint = endpoint + +    @azure_ds_telemetry_reporter +    def send_ready_signal(self): +        document = self.build_report( +            incarnation=self._goal_state.incarnation, +            container_id=self._goal_state.container_id, +            instance_id=self._goal_state.instance_id, +            status=self.PROVISIONING_SUCCESS_STATUS) +        LOG.debug('Reporting ready to Azure fabric.') +        try: +            self._post_health_report(document=document) +        except Exception as e: +            msg = "exception while reporting ready: %s" % e +            LOG.error(msg) +            report_diagnostic_event(msg) +            raise + +        LOG.info('Reported ready to Azure fabric.') + +    def build_report( +            self, incarnation, container_id, instance_id, +            status, substatus=None, description=None): +        health_detail = '' +        if substatus is not None: +            health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( +                health_substatus=substatus, health_description=description) + +        health_report = self.HEALTH_REPORT_XML_TEMPLATE.format( +            incarnation=incarnation, +            container_id=container_id, +            instance_id=instance_id, +            health_status=status, +            health_detail_subsection=health_detail) + +        return health_report + +    @azure_ds_telemetry_reporter +    def _post_health_report(self, document): +        # Whenever report_diagnostic_event(diagnostic_msg) is invoked in code, +        # the diagnostic messages are written to special files +        # (/var/opt/hyperv/.kvp_pool_*) as Hyper-V KVP messages. +        # Hyper-V KVP message communication is done through these files, +        # and KVP functionality is used to communicate and share diagnostic +        # info with the Azure Host. +        # The Azure Host will collect the VM's Hyper-V KVP diagnostic messages +        # when cloud-init reports to fabric. +        # When the Azure Host receives the health report signal, it will only +        # collect and process whatever KVP diagnostic messages have been +        # written to the KVP files. +        # KVP messages that are published after the Azure Host receives the +        # signal are ignored and unprocessed, so yield this thread to the +        # Hyper-V KVP Reporting thread so that they are written. +        # time.sleep(0) is a low-cost and proven method to yield the scheduler +        # and ensure that events are flushed. +        # See HyperVKvpReportingHandler class, which is a multi-threaded +        # reporting handler that writes to the special KVP files. +        time.sleep(0) + +        LOG.debug('Sending health report to Azure fabric.') +        url = "http://{}/machine?comp=health".format(self._endpoint) +        self._azure_endpoint_client.post( +            url, +            data=document, +            extra_headers={'Content-Type': 'text/xml; charset=utf-8'}) +        LOG.debug('Successfully sent health report to Azure fabric') + + +class WALinuxAgentShim:      def __init__(self, fallback_lease_file=None, dhcp_options=None):          LOG.debug('WALinuxAgentShim instantiated, fallback_lease_file=%s', @@ -396,6 +511,7 @@ class WALinuxAgentShim(object):          self.dhcpoptions = dhcp_options          self._endpoint = None          self.openssl_manager = None +        self.azure_endpoint_client = None          self.lease_file = fallback_lease_file      def clean_up(self): @@ -494,7 +610,22 @@ class WALinuxAgentShim(object):      @staticmethod      @azure_ds_telemetry_reporter      def find_endpoint(fallback_lease_file=None, dhcp245=None): +        """Finds and returns the Azure endpoint using various methods. + +        The Azure endpoint is searched in the following order: +        1. Endpoint from dhcp options (dhcp option 245). +        2. Endpoint from networkd. +        3. Endpoint from dhclient hook json. +        4. Endpoint from fallback lease file. +        5. The default Azure endpoint. + +        @param fallback_lease_file: Fallback lease file that will be used +            during endpoint search. +        @param dhcp245: dhcp options that will be used during endpoint search. +        @return: Azure endpoint IP address. +        """          value = None +          if dhcp245 is not None:              value = dhcp245              LOG.debug("Using Azure Endpoint from dhcp options") @@ -536,42 +667,128 @@ class WALinuxAgentShim(object):      @azure_ds_telemetry_reporter      def register_with_azure_and_fetch_data(self, pubkey_info=None): +        """Gets the VM's GoalState from Azure, uses the GoalState information +        to report ready/send the ready signal/provisioning complete signal to +        Azure, and then uses pubkey_info to filter and obtain the user's +        pubkeys from the GoalState. + +        @param pubkey_info: List of pubkey values and fingerprints which are +            used to filter and obtain the user's pubkey values from the +            GoalState. +        @return: The list of user's authorized pubkey values. +        """          if self.openssl_manager is None:              self.openssl_manager = OpenSSLManager() -        http_client = AzureEndpointHttpClient(self.openssl_manager.certificate) +        if self.azure_endpoint_client is None: +            self.azure_endpoint_client = AzureEndpointHttpClient( +                self.openssl_manager.certificate) +        goal_state = self._fetch_goal_state_from_azure() +        ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) +        health_reporter = GoalStateHealthReporter( +            goal_state, self.azure_endpoint_client, self.endpoint) +        health_reporter.send_ready_signal() +        return {'public-keys': ssh_keys} + +    @azure_ds_telemetry_reporter +    def _fetch_goal_state_from_azure(self): +        """Fetches the GoalState XML from the Azure endpoint, parses the XML, +        and returns a GoalState object. + +        @return: GoalState object representing the GoalState XML +        """ +        unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() +        return self._parse_raw_goal_state_xml(unparsed_goal_state_xml) + +    @azure_ds_telemetry_reporter +    def _get_raw_goal_state_xml_from_azure(self): +        """Fetches the GoalState XML from the Azure endpoint and returns +        the XML as a string. + +        @return: GoalState XML string +        """ +          LOG.info('Registering with Azure...') -        attempts = 0 -        while True: -            try: -                response = http_client.get( -                    'http://{0}/machine/?comp=goalstate'.format(self.endpoint)) -            except Exception as e: -                if attempts < 10: -                    time.sleep(attempts + 1) -                else: -                    report_diagnostic_event( -                        "failed to register with Azure: %s" % e) -                    raise -            else: -                break -            attempts += 1 +        url = 'http://{}/machine/?comp=goalstate'.format(self.endpoint) +        try: +            response = self.azure_endpoint_client.get(url) +        except Exception as e: +            msg = 'failed to register with Azure: %s' % e +            LOG.warning(msg) +            report_diagnostic_event(msg) +            raise          LOG.debug('Successfully fetched GoalState XML.') -        goal_state = GoalState(response.contents, http_client) -        report_diagnostic_event("container_id %s" % goal_state.container_id) +        return response.contents + +    @azure_ds_telemetry_reporter +    def _parse_raw_goal_state_xml(self, unparsed_goal_state_xml): +        """Parses a GoalState XML string and returns a GoalState object. + +        @param unparsed_goal_state_xml: GoalState XML string +        @return: GoalState object representing the GoalState XML +        """ +        try: +            goal_state = GoalState( +                unparsed_goal_state_xml, self.azure_endpoint_client) +        except Exception as e: +            msg = 'Error processing GoalState XML: %s' % e +            LOG.warning(msg) +            report_diagnostic_event(msg) +            raise +        msg = ', '.join([ +            'GoalState XML container id: %s' % goal_state.container_id, +            'GoalState XML instance id: %s' % goal_state.instance_id, +            'GoalState XML incarnation: %s' % goal_state.incarnation]) +        LOG.debug(msg) +        report_diagnostic_event(msg) +        return goal_state + +    @azure_ds_telemetry_reporter +    def _get_user_pubkeys(self, goal_state, pubkey_info): +        """Gets and filters the VM admin user's authorized pubkeys. + +        The admin user in this case is the username specified as "admin" +        when deploying VMs on Azure. +        See https://docs.microsoft.com/en-us/cli/azure/vm#az-vm-create. +        cloud-init expects a straightforward array of keys to be dropped +        into the admin user's authorized_keys file. Azure control plane exposes +        multiple public keys to the VM via wireserver. Select just the +        admin user's key(s) and return them, ignoring any other certs. + +        @param goal_state: GoalState object. The GoalState object contains +            a certificate XML, which contains both the VM user's authorized +            pubkeys and other non-user pubkeys, which are used for +            MSI and protected extension handling. +        @param pubkey_info: List of VM user pubkey dicts that were previously +            obtained from provisioning data. +            Each pubkey dict in this list can either have the format +            pubkey['value'] or pubkey['fingerprint']. +            Each pubkey['fingerprint'] in the list is used to filter +            and obtain the actual pubkey value from the GoalState +            certificates XML. +            Each pubkey['value'] requires no further processing and is +            immediately added to the return list. +        @return: A list of the VM user's authorized pubkey values. +        """          ssh_keys = []          if goal_state.certificates_xml is not None and pubkey_info is not None:              LOG.debug('Certificate XML found; parsing out public keys.')              keys_by_fingerprint = self.openssl_manager.parse_certificates(                  goal_state.certificates_xml)              ssh_keys = self._filter_pubkeys(keys_by_fingerprint, pubkey_info) -        self._report_ready(goal_state, http_client) -        return {'public-keys': ssh_keys} +        return ssh_keys -    def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info): -        """cloud-init expects a straightforward array of keys to be dropped -           into the user's authorized_keys file. Azure control plane exposes -           multiple public keys to the VM via wireserver. Select just the -           user's key(s) and return them, ignoring any other certs. +    @staticmethod +    def _filter_pubkeys(keys_by_fingerprint, pubkey_info): +        """ Filter and return only the user's actual pubkeys. + +        @param keys_by_fingerprint: pubkey fingerprint -> pubkey value dict +            that was obtained from GoalState Certificates XML. May contain +            non-user pubkeys. +        @param pubkey_info: List of VM user pubkeys. Pubkey values are added +            to the return list without further processing. Pubkey fingerprints +            are used to filter and obtain the actual pubkey values from +            keys_by_fingerprint. +        @return: A list of the VM user's authorized pubkey values.          """          keys = []          for pubkey in pubkey_info: @@ -590,30 +807,6 @@ class WALinuxAgentShim(object):          return keys -    @azure_ds_telemetry_reporter -    def _report_ready(self, goal_state, http_client): -        LOG.debug('Reporting ready to Azure fabric.') -        document = self.REPORT_READY_XML_TEMPLATE.format( -            incarnation=goal_state.incarnation, -            container_id=goal_state.container_id, -            instance_id=goal_state.instance_id, -        ) -        # Host will collect kvps when cloud-init reports ready. -        # some kvps might still be in the queue. We yield the scheduler -        # to make sure we process all kvps up till this point. -        time.sleep(0) -        try: -            http_client.post( -                "http://{0}/machine?comp=health".format(self.endpoint), -                data=document, -                extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, -            ) -        except Exception as e: -            report_diagnostic_event("exception while reporting ready: %s" % e) -            raise - -        LOG.info('Reported ready to Azure fabric.') -  @azure_ds_telemetry_reporter  def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, @@ -631,7 +824,7 @@ def dhcp_log_cb(out, err):      report_diagnostic_event("dhclient error stream: %s" % err) -class EphemeralDHCPv4WithReporting(object): +class EphemeralDHCPv4WithReporting:      def __init__(self, reporter, nic=None):          self.reporter = reporter          self.ephemeralDHCPv4 = EphemeralDHCPv4( diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index f314cd4a..5e6d3d2d 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,8 +1,10 @@  # This file is part of cloud-init. See LICENSE file for license information.  import os +import re  import unittest  from textwrap import dedent +from xml.etree import ElementTree  from cloudinit.sources.helpers import azure as azure_helper  from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -48,6 +50,30 @@ GOAL_STATE_TEMPLATE = """\  </GoalState>  """ +HEALTH_REPORT_XML_TEMPLATE = '''\ +<?xml version="1.0" encoding="utf-8"?> +<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xmlns:xsd="http://www.w3.org/2001/XMLSchema"> +  <GoalStateIncarnation>{incarnation}</GoalStateIncarnation> +  <Container> +    <ContainerId>{container_id}</ContainerId> +    <RoleInstanceList> +      <Role> +        <InstanceId>{instance_id}</InstanceId> +        <Health> +          <State>{health_status}</State> +          {health_detail_subsection} +        </Health> +      </Role> +    </RoleInstanceList> +  </Container> +</Health> +''' + + +class SentinelException(Exception): +    pass +  class TestFindEndpoint(CiTestCase): @@ -140,9 +166,7 @@ class TestGoalStateParsing(CiTestCase):          'certificates_url': 'MyCertificatesUrl',      } -    def _get_goal_state(self, http_client=None, **kwargs): -        if http_client is None: -            http_client = mock.MagicMock() +    def _get_formatted_goal_state_xml_string(self, **kwargs):          parameters = self.default_parameters.copy()          parameters.update(kwargs)          xml = GOAL_STATE_TEMPLATE.format(**parameters) @@ -153,7 +177,13 @@ class TestGoalStateParsing(CiTestCase):                      continue                  new_xml_lines.append(line)              xml = '\n'.join(new_xml_lines) -        return azure_helper.GoalState(xml, http_client) +        return xml + +    def _get_goal_state(self, m_azure_endpoint_client=None, **kwargs): +        if m_azure_endpoint_client is None: +            m_azure_endpoint_client = mock.MagicMock() +        xml = self._get_formatted_goal_state_xml_string(**kwargs) +        return azure_helper.GoalState(xml, m_azure_endpoint_client)      def test_incarnation_parsed_correctly(self):          incarnation = '123' @@ -190,25 +220,55 @@ class TestGoalStateParsing(CiTestCase):              azure_helper.is_byte_swapped(previous_iid, current_iid))      def test_certificates_xml_parsed_and_fetched_correctly(self): -        http_client = mock.MagicMock() +        m_azure_endpoint_client = mock.MagicMock()          certificates_url = 'TestCertificatesUrl'          goal_state = self._get_goal_state( -            http_client=http_client, certificates_url=certificates_url) +            m_azure_endpoint_client=m_azure_endpoint_client, +            certificates_url=certificates_url)          certificates_xml = goal_state.certificates_xml -        self.assertEqual(1, http_client.get.call_count) -        self.assertEqual(certificates_url, http_client.get.call_args[0][0]) -        self.assertTrue(http_client.get.call_args[1].get('secure', False)) -        self.assertEqual(http_client.get.return_value.contents, -                         certificates_xml) +        self.assertEqual(1, m_azure_endpoint_client.get.call_count) +        self.assertEqual( +            certificates_url, +            m_azure_endpoint_client.get.call_args[0][0]) +        self.assertTrue( +            m_azure_endpoint_client.get.call_args[1].get( +                'secure', False)) +        self.assertEqual( +            m_azure_endpoint_client.get.return_value.contents, +            certificates_xml)      def test_missing_certificates_skips_http_get(self): -        http_client = mock.MagicMock() +        m_azure_endpoint_client = mock.MagicMock()          goal_state = self._get_goal_state( -            http_client=http_client, certificates_url=None) +            m_azure_endpoint_client=m_azure_endpoint_client, +            certificates_url=None)          certificates_xml = goal_state.certificates_xml -        self.assertEqual(0, http_client.get.call_count) +        self.assertEqual(0, m_azure_endpoint_client.get.call_count)          self.assertIsNone(certificates_xml) +    def test_invalid_goal_state_xml_raises_parse_error(self): +        xml = 'random non-xml data' +        with self.assertRaises(ElementTree.ParseError): +            azure_helper.GoalState(xml, mock.MagicMock()) + +    def test_missing_container_id_in_goal_state_xml_raises_exc(self): +        xml = self._get_formatted_goal_state_xml_string() +        xml = re.sub('<ContainerId>.*</ContainerId>', '', xml) +        with self.assertRaises(azure_helper.InvalidGoalStateXMLException): +            azure_helper.GoalState(xml, mock.MagicMock()) + +    def test_missing_instance_id_in_goal_state_xml_raises_exc(self): +        xml = self._get_formatted_goal_state_xml_string() +        xml = re.sub('<InstanceId>.*</InstanceId>', '', xml) +        with self.assertRaises(azure_helper.InvalidGoalStateXMLException): +            azure_helper.GoalState(xml, mock.MagicMock()) + +    def test_missing_incarnation_in_goal_state_xml_raises_exc(self): +        xml = self._get_formatted_goal_state_xml_string() +        xml = re.sub('<Incarnation>.*</Incarnation>', '', xml) +        with self.assertRaises(azure_helper.InvalidGoalStateXMLException): +            azure_helper.GoalState(xml, mock.MagicMock()) +  class TestAzureEndpointHttpClient(CiTestCase): @@ -222,61 +282,95 @@ class TestAzureEndpointHttpClient(CiTestCase):          patches = ExitStack()          self.addCleanup(patches.close) -        self.read_file_or_url = patches.enter_context( -            mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) +        self.readurl = patches.enter_context( +            mock.patch.object(azure_helper.url_helper, 'readurl')) +        patches.enter_context( +            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))      def test_non_secure_get(self):          client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())          url = 'MyTestUrl'          response = client.get(url, secure=False) -        self.assertEqual(1, self.read_file_or_url.call_count) -        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual(1, self.readurl.call_count) +        self.assertEqual(self.readurl.return_value, response)          self.assertEqual( -            mock.call(url, headers=self.regular_headers, retries=10, -                      timeout=5), -            self.read_file_or_url.call_args) +            mock.call(url, headers=self.regular_headers, +                      timeout=5, retries=10, sec_between=5), +            self.readurl.call_args) + +    def test_non_secure_get_raises_exception(self): +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        self.readurl.side_effect = SentinelException +        url = 'MyTestUrl' +        with self.assertRaises(SentinelException): +            client.get(url, secure=False)      def test_secure_get(self):          url = 'MyTestUrl' -        certificate = mock.MagicMock() +        m_certificate = mock.MagicMock()          expected_headers = self.regular_headers.copy()          expected_headers.update({              "x-ms-cipher-name": "DES_EDE3_CBC", -            "x-ms-guest-agent-public-x509-cert": certificate, +            "x-ms-guest-agent-public-x509-cert": m_certificate,          }) -        client = azure_helper.AzureEndpointHttpClient(certificate) +        client = azure_helper.AzureEndpointHttpClient(m_certificate)          response = client.get(url, secure=True) -        self.assertEqual(1, self.read_file_or_url.call_count) -        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual(1, self.readurl.call_count) +        self.assertEqual(self.readurl.return_value, response)          self.assertEqual( -            mock.call(url, headers=expected_headers, retries=10, -                      timeout=5), -            self.read_file_or_url.call_args) +            mock.call(url, headers=expected_headers, +                      timeout=5, retries=10, sec_between=5), +            self.readurl.call_args) + +    def test_secure_get_raises_exception(self): +        url = 'MyTestUrl' +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        self.readurl.side_effect = SentinelException +        with self.assertRaises(SentinelException): +            client.get(url, secure=True)      def test_post(self): -        data = mock.MagicMock() +        m_data = mock.MagicMock()          url = 'MyTestUrl'          client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) -        response = client.post(url, data=data) -        self.assertEqual(1, self.read_file_or_url.call_count) -        self.assertEqual(self.read_file_or_url.return_value, response) +        response = client.post(url, data=m_data) +        self.assertEqual(1, self.readurl.call_count) +        self.assertEqual(self.readurl.return_value, response)          self.assertEqual( -            mock.call(url, data=data, headers=self.regular_headers, retries=10, -                      timeout=5), -            self.read_file_or_url.call_args) +            mock.call(url, data=m_data, headers=self.regular_headers, +                      timeout=5, retries=10, sec_between=5), +            self.readurl.call_args) + +    def test_post_raises_exception(self): +        m_data = mock.MagicMock() +        url = 'MyTestUrl' +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        self.readurl.side_effect = SentinelException +        with self.assertRaises(SentinelException): +            client.post(url, data=m_data)      def test_post_with_extra_headers(self):          url = 'MyTestUrl'          client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())          extra_headers = {'test': 'header'}          client.post(url, extra_headers=extra_headers) -        self.assertEqual(1, self.read_file_or_url.call_count)          expected_headers = self.regular_headers.copy()          expected_headers.update(extra_headers) +        self.assertEqual(1, self.readurl.call_count)          self.assertEqual(              mock.call(mock.ANY, data=mock.ANY, headers=expected_headers, -                      retries=10, timeout=5), -            self.read_file_or_url.call_args) +                      timeout=5, retries=10, sec_between=5), +            self.readurl.call_args) + +    def test_post_with_sleep_with_extra_headers_raises_exception(self): +        m_data = mock.MagicMock() +        url = 'MyTestUrl' +        extra_headers = {'test': 'header'} +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        self.readurl.side_effect = SentinelException +        with self.assertRaises(SentinelException): +            client.post( +                url, data=m_data, extra_headers=extra_headers)  class TestOpenSSLManager(CiTestCase): @@ -365,6 +459,131 @@ class TestOpenSSLManagerActions(CiTestCase):              self.assertIn(fp, keys_by_fp) +class TestGoalStateHealthReporter(CiTestCase): + +    default_parameters = { +        'incarnation': 1634, +        'container_id': 'MyContainerId', +        'instance_id': 'MyInstanceId' +    } + +    test_endpoint = 'TestEndpoint' +    test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) +    test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} + +    provisioning_success_status = 'Ready' + +    def setUp(self): +        super(TestGoalStateHealthReporter, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        patches.enter_context( +            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) +        self.read_file_or_url = patches.enter_context( +            mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + +        self.post = patches.enter_context( +            mock.patch.object(azure_helper.AzureEndpointHttpClient, +                              'post')) + +        self.GoalState = patches.enter_context( +            mock.patch.object(azure_helper, 'GoalState')) +        self.GoalState.return_value.container_id = \ +            self.default_parameters['container_id'] +        self.GoalState.return_value.instance_id = \ +            self.default_parameters['instance_id'] +        self.GoalState.return_value.incarnation = \ +            self.default_parameters['incarnation'] + +    def _get_formatted_health_report_xml_string(self, **kwargs): +        return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + +    def _get_report_ready_health_document(self): +        return self._get_formatted_health_report_xml_string( +            incarnation=self.default_parameters['incarnation'], +            container_id=self.default_parameters['container_id'], +            instance_id=self.default_parameters['instance_id'], +            health_status=self.provisioning_success_status, +            health_detail_subsection='') + +    def test_send_ready_signal_sends_post_request(self): +        with mock.patch.object( +                azure_helper.GoalStateHealthReporter, +                'build_report') as m_build_report: +            client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +            reporter = azure_helper.GoalStateHealthReporter( +                azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), +                client, self.test_endpoint) +            reporter.send_ready_signal() + +            self.assertEqual(1, self.post.call_count) +            self.assertEqual( +                mock.call( +                    self.test_url, +                    data=m_build_report.return_value, +                    extra_headers=self.test_default_headers), +                self.post.call_args) + +    def test_build_report_for_health_document(self): +        health_document = self._get_report_ready_health_document() +        reporter = azure_helper.GoalStateHealthReporter( +            azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), +            azure_helper.AzureEndpointHttpClient(mock.MagicMock()), +            self.test_endpoint) +        generated_health_document = reporter.build_report( +            incarnation=self.default_parameters['incarnation'], +            container_id=self.default_parameters['container_id'], +            instance_id=self.default_parameters['instance_id'], +            status=self.provisioning_success_status) +        self.assertEqual(health_document, generated_health_document) +        self.assertIn( +            '<GoalStateIncarnation>{}</GoalStateIncarnation>'.format( +                str(self.default_parameters['incarnation'])), +            generated_health_document) +        self.assertIn( +            ''.join([ +                '<ContainerId>', +                self.default_parameters['container_id'], +                '</ContainerId>']), +            generated_health_document) +        self.assertIn( +            ''.join([ +                '<InstanceId>', +                self.default_parameters['instance_id'], +                '</InstanceId>']), +            generated_health_document) +        self.assertIn( +            ''.join([ +                '<State>', +                self.provisioning_success_status, +                '</State>']), +            generated_health_document +        ) +        self.assertNotIn('<Details>', generated_health_document) +        self.assertNotIn('<SubStatus>', generated_health_document) +        self.assertNotIn('<Description>', generated_health_document) + +    def test_send_ready_signal_calls_build_report(self): +        with mock.patch.object( +            azure_helper.GoalStateHealthReporter, 'build_report' +        ) as m_build_report: +            reporter = azure_helper.GoalStateHealthReporter( +                azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), +                azure_helper.AzureEndpointHttpClient(mock.MagicMock()), +                self.test_endpoint) +            reporter.send_ready_signal() + +            self.assertEqual(1, m_build_report.call_count) +            self.assertEqual( +                mock.call( +                    incarnation=self.default_parameters['incarnation'], +                    container_id=self.default_parameters['container_id'], +                    instance_id=self.default_parameters['instance_id'], +                    status=self.provisioning_success_status), +                m_build_report.call_args) + +  class TestWALinuxAgentShim(CiTestCase):      def setUp(self): @@ -383,14 +602,21 @@ class TestWALinuxAgentShim(CiTestCase):          patches.enter_context(              mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) -    def test_http_client_uses_certificate(self): +        self.test_incarnation = 'TestIncarnation' +        self.test_container_id = 'TestContainerId' +        self.test_instance_id = 'TestInstanceId' +        self.GoalState.return_value.incarnation = self.test_incarnation +        self.GoalState.return_value.container_id = self.test_container_id +        self.GoalState.return_value.instance_id = self.test_instance_id + +    def test_azure_endpoint_client_uses_certificate_during_report_ready(self):          shim = wa_shim()          shim.register_with_azure_and_fetch_data()          self.assertEqual(              [mock.call(self.OpenSSLManager.return_value.certificate)],              self.AzureEndpointHttpClient.call_args_list) -    def test_correct_url_used_for_goalstate(self): +    def test_correct_url_used_for_goalstate_during_report_ready(self):          self.find_endpoint.return_value = 'test_endpoint'          shim = wa_shim()          shim.register_with_azure_and_fetch_data() @@ -438,43 +664,67 @@ class TestWALinuxAgentShim(CiTestCase):          expected_url = 'http://test_endpoint/machine?comp=health'          self.assertEqual(              [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], -            self.AzureEndpointHttpClient.return_value.post.call_args_list) +            self.AzureEndpointHttpClient.return_value.post +                .call_args_list)      def test_goal_state_values_used_for_report_ready(self): -        self.GoalState.return_value.incarnation = 'TestIncarnation' -        self.GoalState.return_value.container_id = 'TestContainerId' -        self.GoalState.return_value.instance_id = 'TestInstanceId'          shim = wa_shim()          shim.register_with_azure_and_fetch_data()          posted_document = ( -            self.AzureEndpointHttpClient.return_value.post.call_args[1]['data'] +            self.AzureEndpointHttpClient.return_value.post +                .call_args[1]['data']          ) -        self.assertIn('TestIncarnation', posted_document) -        self.assertIn('TestContainerId', posted_document) -        self.assertIn('TestInstanceId', posted_document) +        self.assertIn(self.test_incarnation, posted_document) +        self.assertIn(self.test_container_id, posted_document) +        self.assertIn(self.test_instance_id, posted_document) + +    def test_xml_elems_in_report_ready(self): +        shim = wa_shim() +        shim.register_with_azure_and_fetch_data() +        health_document = HEALTH_REPORT_XML_TEMPLATE.format( +            incarnation=self.test_incarnation, +            container_id=self.test_container_id, +            instance_id=self.test_instance_id, +            health_status='Ready', +            health_detail_subsection='') +        posted_document = ( +            self.AzureEndpointHttpClient.return_value.post +                .call_args[1]['data']) +        self.assertEqual(health_document, posted_document)      def test_clean_up_can_be_called_at_any_time(self):          shim = wa_shim()          shim.clean_up() -    def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self): +    def test_clean_up_after_report_ready(self):          shim = wa_shim()          shim.register_with_azure_and_fetch_data()          shim.clean_up()          self.assertEqual(              1, self.OpenSSLManager.return_value.clean_up.call_count) -    def test_failure_to_fetch_goalstate_bubbles_up(self): -        class SentinelException(Exception): -            pass -        self.AzureEndpointHttpClient.return_value.get.side_effect = ( -            SentinelException) +    def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): +        self.AzureEndpointHttpClient.return_value.get \ +            .side_effect = (SentinelException)          shim = wa_shim()          self.assertRaises(SentinelException,                            shim.register_with_azure_and_fetch_data) +    def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): +        self.GoalState.side_effect = SentinelException +        shim = wa_shim() +        self.assertRaises(SentinelException, +                          shim.register_with_azure_and_fetch_data) -class TestGetMetadataFromFabric(CiTestCase): +    def test_failure_to_send_report_ready_health_doc_bubbles_up(self): +        self.AzureEndpointHttpClient.return_value.post \ +            .side_effect = SentinelException +        shim = wa_shim() +        self.assertRaises(SentinelException, +                          shim.register_with_azure_and_fetch_data) + + +class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase):      @mock.patch.object(azure_helper, 'WALinuxAgentShim')      def test_data_from_shim_returned(self, shim): @@ -490,14 +740,39 @@ class TestGetMetadataFromFabric(CiTestCase):      @mock.patch.object(azure_helper, 'WALinuxAgentShim')      def test_failure_in_registration_calls_clean_up(self, shim): -        class SentinelException(Exception): -            pass          shim.return_value.register_with_azure_and_fetch_data.side_effect = (              SentinelException)          self.assertRaises(SentinelException,                            azure_helper.get_metadata_from_fabric)          self.assertEqual(1, shim.return_value.clean_up.call_count) +    @mock.patch.object(azure_helper, 'WALinuxAgentShim') +    def test_calls_shim_register_with_azure_and_fetch_data(self, shim): +        m_pubkey_info = mock.MagicMock() +        azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info) +        self.assertEqual( +            1, +            shim.return_value +                .register_with_azure_and_fetch_data.call_count) +        self.assertEqual( +            mock.call(pubkey_info=m_pubkey_info), +            shim.return_value +                .register_with_azure_and_fetch_data.call_args) + +    @mock.patch.object(azure_helper, 'WALinuxAgentShim') +    def test_instantiates_shim_with_kwargs(self, shim): +        m_fallback_lease_file = mock.MagicMock() +        m_dhcp_options = mock.MagicMock() +        azure_helper.get_metadata_from_fabric( +            fallback_lease_file=m_fallback_lease_file, +            dhcp_opts=m_dhcp_options) +        self.assertEqual(1, shim.call_count) +        self.assertEqual( +            mock.call( +                fallback_lease_file=m_fallback_lease_file, +                dhcp_options=m_dhcp_options), +            shim.call_args) +  class TestExtractIpAddressFromNetworkd(CiTestCase): diff --git a/tools/.github-cla-signers b/tools/.github-cla-signers index 50473bf7..0c4d728f 100644 --- a/tools/.github-cla-signers +++ b/tools/.github-cla-signers @@ -7,6 +7,7 @@ dermotbradley  dhensby  eandersson  izzyleung +johnsonshi  landon912  lucasmoura  marlluslustosa | 
