diff options
| -rw-r--r-- | cloudinit/sources/DataSourceAzure.py | 292 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 331 | 
2 files changed, 574 insertions, 49 deletions
| diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index a19d9ca2..bd3c742b 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -22,8 +22,14 @@ import crypt  import fnmatch  import os  import os.path +import re +import socket +import struct +import tempfile  import time +from contextlib import contextmanager  from xml.dom import minidom +from xml.etree import ElementTree  from cloudinit import log as logging  from cloudinit.settings import PER_ALWAYS @@ -34,13 +40,11 @@ LOG = logging.getLogger(__name__)  DS_NAME = 'Azure'  DEFAULT_METADATA = {"instance-id": "iid-AZURE-NODE"} -AGENT_START = ['service', 'walinuxagent', 'start']  BOUNCE_COMMAND = ['sh', '-xc',      "i=$interface; x=0; ifdown $i || x=$?; ifup $i || x=$?; exit $x"]  DATA_DIR_CLEAN_LIST = ['SharedConfig.xml']  BUILTIN_DS_CONFIG = { -    'agent_command': AGENT_START,      'data_dir': "/var/lib/waagent",      'set_hostname': True,      'hostname_bounce': { @@ -66,6 +70,231 @@ BUILTIN_CLOUD_CONFIG = {  DS_CFG_PATH = ['datasource', DS_NAME]  DEF_EPHEMERAL_LABEL = 'Temporary Storage' +REPORT_READY_XML_TEMPLATE = """\ +<?xml version=\"1.0\" encoding=\"utf-8\"?> +<Health xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\"> +  <GoalStateIncarnation>{incarnation}</GoalStateIncarnation> +  <Container> +    <ContainerId>{container_id}</ContainerId> +    <RoleInstanceList> +      <Role> +        <InstanceId>{instance_id}</InstanceId> +        <Health> +          <State>Ready</State> +        </Health> +      </Role> +    </RoleInstanceList> +  </Container> +</Health>""" + + +@contextmanager +def cd(newdir): +    prevdir = os.getcwd() +    os.chdir(os.path.expanduser(newdir)) +    try: +        yield +    finally: +        os.chdir(prevdir) + + +class AzureEndpointHttpClient(object): + +    headers = { +        'x-ms-agent-name': 'WALinuxAgent', +        'x-ms-version': '2012-11-30', +    } + +    def __init__(self, certificate): +        self.extra_secure_headers = { +            "x-ms-cipher-name": "DES_EDE3_CBC", +            "x-ms-guest-agent-public-x509-cert": certificate, +        } + +    def get(self, url, secure=False): +        headers = self.headers +        if secure: +            headers = self.headers.copy() +            headers.update(self.extra_secure_headers) +        return util.read_file_or_url(url, headers=headers) + +    def post(self, url, data=None, extra_headers=None): +        headers = self.headers +        if extra_headers is not None: +            headers = self.headers.copy() +            headers.update(extra_headers) +        return util.read_file_or_url(url, data=data, headers=headers) + + +def find_endpoint(): +    content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') +    value = None +    for line in content.splitlines(): +        if 'unknown-245' in line: +            value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') +    if value is None: +        raise Exception('No endpoint found in DHCP config.') +    if ':' in value: +        hex_string = '' +        for hex_pair in value.split(':'): +            if len(hex_pair) == 1: +                hex_pair = '0' + hex_pair +            hex_string += hex_pair +        value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) +    else: +        value = value.encode('utf-8') +    return socket.inet_ntoa(value) + + +class GoalState(object): + +    def __init__(self, xml, http_client): +        self.http_client = http_client +        self.root = ElementTree.fromstring(xml) + +    def _text_from_xpath(self, xpath): +        element = self.root.find(xpath) +        if element is not None: +            return element.text +        return None + +    @property +    def container_id(self): +        return self._text_from_xpath('./Container/ContainerId') + +    @property +    def incarnation(self): +        return self._text_from_xpath('./Incarnation') + +    @property +    def instance_id(self): +        return self._text_from_xpath( +            './Container/RoleInstanceList/RoleInstance/InstanceId') + +    @property +    def shared_config_xml(self): +        url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' +                                    '/Configuration/SharedConfig') +        return self.http_client.get(url).contents + +    @property +    def certificates_xml(self): +        url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' +                                    '/Configuration/Certificates') +        if url is not None: +            return self.http_client.get(url, secure=True).contents +        return None + + +class OpenSSLManager(object): + +    certificate_names = { +        'private_key': 'TransportPrivate.pem', +        'certificate': 'TransportCert.pem', +    } + +    def __init__(self): +        self.tmpdir = tempfile.TemporaryDirectory() +        self.certificate = None +        self.generate_certificate() + +    def generate_certificate(self): +        if self.certificate is not None: +            return +        with cd(self.tmpdir.name): +            util.subp([ +                'openssl', 'req', '-x509', '-nodes', '-subj', +                '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', +                '-keyout', self.certificate_names['private_key'], +                '-out', self.certificate_names['certificate'], +            ]) +            certificate = '' +            for line in open(self.certificate_names['certificate']): +                if "CERTIFICATE" not in line: +                    certificate += line.rstrip() +            self.certificate = certificate + +    def parse_certificates(self, certificates_xml): +        tag = ElementTree.fromstring(certificates_xml).find( +            './/Data') +        certificates_content = tag.text +        lines = [ +            b'MIME-Version: 1.0', +            b'Content-Disposition: attachment; filename="Certificates.p7m"', +            b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', +            b'Content-Transfer-Encoding: base64', +            b'', +            certificates_content.encode('utf-8'), +        ] +        with cd(self.tmpdir.name): +            with open('Certificates.p7m', 'wb') as f: +                f.write(b'\n'.join(lines)) +            out, _ = util.subp( +                'openssl cms -decrypt -in Certificates.p7m -inkey' +                ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' +                ' -password pass:'.format(**self.certificate_names), +                shell=True) +        private_keys, certificates = [], [] +        current = [] +        for line in out.splitlines(): +            current.append(line) +            if re.match(r'[-]+END .*?KEY[-]+$', line): +                private_keys.append('\n'.join(current)) +                current = [] +            elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): +                certificates.append('\n'.join(current)) +                current = [] +        keys = [] +        for certificate in certificates: +            with cd(self.tmpdir.name): +                public_key, _ = util.subp( +                    'openssl x509 -noout -pubkey |' +                    'ssh-keygen -i -m PKCS8 -f /dev/stdin', +                    data=certificate, +                    shell=True) +            keys.append(public_key) +        return keys + + +class WALinuxAgentShim(object): + +    def __init__(self): +        self.endpoint = find_endpoint() +        self.goal_state = None +        self.openssl_manager = OpenSSLManager() +        self.http_client = AzureEndpointHttpClient( +            self.openssl_manager.certificate) +        self.values = {} + +    def register_with_azure_and_fetch_data(self): +        LOG.info('Registering with Azure...') +        for i in range(10): +            try: +                response = self.http_client.get( +                    'http://{}/machine/?comp=goalstate'.format(self.endpoint)) +            except Exception: +                time.sleep(i + 1) +            else: +                break +        self.goal_state = GoalState(response.contents, self.http_client) +        self.public_keys = [] +        if self.goal_state.certificates_xml is not None: +            self.public_keys = self.openssl_manager.parse_certificates( +                self.goal_state.certificates_xml) +        self._report_ready() + +    def _report_ready(self): +        document = REPORT_READY_XML_TEMPLATE.format( +            incarnation=self.goal_state.incarnation, +            container_id=self.goal_state.container_id, +            instance_id=self.goal_state.instance_id, +        ) +        self.http_client.post( +            "http://{}/machine?comp=health".format(self.endpoint), +            data=document, +            extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, +        ) +  def get_hostname(hostname_command='hostname'):      return util.subp(hostname_command, capture=True)[0].strip() @@ -185,53 +414,17 @@ class DataSourceAzureNet(sources.DataSource):          # the directory to be protected.          write_files(ddir, files, dirmode=0o700) -        temp_hostname = self.metadata.get('local-hostname') -        hostname_command = mycfg['hostname_bounce']['hostname_command'] -        with temporary_hostname(temp_hostname, mycfg, -                                hostname_command=hostname_command) \ -                as previous_hostname: -            if (previous_hostname is not None -                    and util.is_true(mycfg.get('set_hostname'))): -                cfg = mycfg['hostname_bounce'] -                try: -                    perform_hostname_bounce(hostname=temp_hostname, -                                            cfg=cfg, -                                            prev_hostname=previous_hostname) -                except Exception as e: -                    LOG.warn("Failed publishing hostname: %s", e) -                    util.logexc(LOG, "handling set_hostname failed") +        shim = WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() -            try: -                invoke_agent(mycfg['agent_command']) -            except util.ProcessExecutionError: -                # claim the datasource even if the command failed -                util.logexc(LOG, "agent command '%s' failed.", -                            mycfg['agent_command']) - -            shcfgxml = os.path.join(ddir, "SharedConfig.xml") -            wait_for = [shcfgxml] - -            fp_files = [] -            for pk in self.cfg.get('_pubkeys', []): -                bname = str(pk['fingerprint'] + ".crt") -                fp_files += [os.path.join(ddir, bname)] - -            missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", -                                    func=wait_for_files, -                                    args=(wait_for + fp_files,)) -        if len(missing): -            LOG.warn("Did not find files, but going on: %s", missing) - -        if shcfgxml in missing: -            LOG.warn("SharedConfig.xml missing, using static instance-id") -        else: -            try: -                self.metadata['instance-id'] = iid_from_shared_config(shcfgxml) -            except ValueError as e: -                LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) +        try: +            self.metadata['instance-id'] = iid_from_shared_config_content( +                shim.goal_state.shared_config_xml) +        except ValueError as e: +            LOG.warn( +                "failed to get instance id in %s: %s", shim.shared_config, e) -        pubkeys = pubkeys_from_crt_files(fp_files) -        self.metadata['public-keys'] = pubkeys +        self.metadata['public-keys'] = shim.public_keys          found_ephemeral = find_ephemeral_disk()          if found_ephemeral: @@ -363,10 +556,11 @@ def perform_hostname_bounce(hostname, cfg, prev_hostname):                            'env': env}) -def crtfile_to_pubkey(fname): +def crtfile_to_pubkey(fname, data=None):      pipeline = ('openssl x509 -noout -pubkey < "$0" |'                  'ssh-keygen -i -m PKCS8 -f /dev/stdin') -    (out, _err) = util.subp(['sh', '-c', pipeline, fname], capture=True) +    (out, _err) = util.subp(['sh', '-c', pipeline, fname], +                            capture=True, data=data)      return out.rstrip() diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 7e789853..dc7f2663 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -15,11 +15,48 @@ except ImportError:  import crypt  import os  import stat +import struct  import yaml  import shutil  import tempfile  import unittest +from cloudinit import url_helper + + +GOAL_STATE_TEMPLATE = """\ +<?xml version="1.0" encoding="utf-8"?> +<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="goalstate10.xsd"> +  <Version>2012-11-30</Version> +  <Incarnation>{incarnation}</Incarnation> +  <Machine> +    <ExpectedState>Started</ExpectedState> +    <StopRolesDeadlineHint>300000</StopRolesDeadlineHint> +    <LBProbePorts> +      <Port>16001</Port> +    </LBProbePorts> +    <ExpectHealthReport>FALSE</ExpectHealthReport> +  </Machine> +  <Container> +    <ContainerId>{container_id}</ContainerId> +    <RoleInstanceList> +      <RoleInstance> +        <InstanceId>{instance_id}</InstanceId> +        <State>Started</State> +        <Configuration> +          <HostingEnvironmentConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1</HostingEnvironmentConfig> +          <SharedConfig>{shared_config_url}</SharedConfig> +          <ExtensionsConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1</ExtensionsConfig> +          <FullConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1</FullConfig> +          <Certificates>{certificates_url}</Certificates> +          <ConfigName>68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml</ConfigName> +        </Configuration> +      </RoleInstance> +    </RoleInstanceList> +  </Container> +</GoalState> +""" +  def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None):      if data is None: @@ -579,3 +616,297 @@ class TestReadAzureSharedConfig(unittest.TestCase):              </SharedConfig>"""          ret = DataSourceAzure.iid_from_shared_config_content(xml)          self.assertEqual("MY_INSTANCE_ID", ret) + + +class TestFindEndpoint(TestCase): + +    def setUp(self): +        super(TestFindEndpoint, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.load_file = patches.enter_context( +            mock.patch.object(DataSourceAzure.util, 'load_file')) + +    def test_missing_file(self): +        self.load_file.side_effect = IOError +        self.assertRaises(IOError, DataSourceAzure.find_endpoint) + +    def test_missing_special_azure_line(self): +        self.load_file.return_value = '' +        self.assertRaises(Exception, DataSourceAzure.find_endpoint) + +    def _build_lease_content(self, ip_address, use_hex=True): +        ip_address_repr = ':'.join( +            [hex(int(part)).replace('0x', '') +             for part in ip_address.split('.')]) +        if not use_hex: +            ip_address_repr = struct.pack( +                '>L', int(ip_address_repr.replace(':', ''), 16)) +            ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) +        return '\n'.join([ +            'lease {', +            ' interface "eth0";', +            ' option unknown-245 {0};'.format(ip_address_repr), +            '}']) + +    def test_hex_string(self): +        ip_address = '98.76.54.32' +        file_content = self._build_lease_content(ip_address) +        self.load_file.return_value = file_content +        self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + +    def test_hex_string_with_single_character_part(self): +        ip_address = '4.3.2.1' +        file_content = self._build_lease_content(ip_address) +        self.load_file.return_value = file_content +        self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + +    def test_packed_string(self): +        ip_address = '98.76.54.32' +        file_content = self._build_lease_content(ip_address, use_hex=False) +        self.load_file.return_value = file_content +        self.assertEqual(ip_address, DataSourceAzure.find_endpoint()) + +    def test_latest_lease_used(self): +        ip_addresses = ['4.3.2.1', '98.76.54.32'] +        file_content = '\n'.join([self._build_lease_content(ip_address) +                                  for ip_address in ip_addresses]) +        self.load_file.return_value = file_content +        self.assertEqual(ip_addresses[-1], DataSourceAzure.find_endpoint()) + + +class TestGoalStateParsing(TestCase): + +    default_parameters = { +        'incarnation': 1, +        'container_id': 'MyContainerId', +        'instance_id': 'MyInstanceId', +        'shared_config_url': 'MySharedConfigUrl', +        'certificates_url': 'MyCertificatesUrl', +    } + +    def _get_goal_state(self, http_client=None, **kwargs): +        if http_client is None: +            http_client = mock.MagicMock() +        parameters = self.default_parameters.copy() +        parameters.update(kwargs) +        xml = GOAL_STATE_TEMPLATE.format(**parameters) +        if parameters['certificates_url'] is None: +            new_xml_lines = [] +            for line in xml.splitlines(): +                if 'Certificates' in line: +                    continue +                new_xml_lines.append(line) +            xml = '\n'.join(new_xml_lines) +        return DataSourceAzure.GoalState(xml, http_client) + +    def test_incarnation_parsed_correctly(self): +        incarnation = '123' +        goal_state = self._get_goal_state(incarnation=incarnation) +        self.assertEqual(incarnation, goal_state.incarnation) + +    def test_container_id_parsed_correctly(self): +        container_id = 'TestContainerId' +        goal_state = self._get_goal_state(container_id=container_id) +        self.assertEqual(container_id, goal_state.container_id) + +    def test_instance_id_parsed_correctly(self): +        instance_id = 'TestInstanceId' +        goal_state = self._get_goal_state(instance_id=instance_id) +        self.assertEqual(instance_id, goal_state.instance_id) + +    def test_shared_config_xml_parsed_and_fetched_correctly(self): +        http_client = mock.MagicMock() +        shared_config_url = 'TestSharedConfigUrl' +        goal_state = self._get_goal_state( +            http_client=http_client, shared_config_url=shared_config_url) +        shared_config_xml = goal_state.shared_config_xml +        self.assertEqual(1, http_client.get.call_count) +        self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) +        self.assertEqual(http_client.get.return_value.contents, +                         shared_config_xml) + +    def test_certificates_xml_parsed_and_fetched_correctly(self): +        http_client = mock.MagicMock() +        certificates_url = 'TestSharedConfigUrl' +        goal_state = self._get_goal_state( +            http_client=http_client, certificates_url=certificates_url) +        certificates_xml = goal_state.certificates_xml +        self.assertEqual(1, http_client.get.call_count) +        self.assertEqual(certificates_url, http_client.get.call_args[0][0]) +        self.assertTrue(http_client.get.call_args[1].get('secure', False)) +        self.assertEqual(http_client.get.return_value.contents, +                         certificates_xml) + +    def test_missing_certificates_skips_http_get(self): +        http_client = mock.MagicMock() +        goal_state = self._get_goal_state( +            http_client=http_client, certificates_url=None) +        certificates_xml = goal_state.certificates_xml +        self.assertEqual(0, http_client.get.call_count) +        self.assertIsNone(certificates_xml) + + +class TestAzureEndpointHttpClient(TestCase): + +    regular_headers = { +        'x-ms-agent-name': 'WALinuxAgent', +        'x-ms-version': '2012-11-30', +    } + +    def setUp(self): +        super(TestAzureEndpointHttpClient, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.read_file_or_url = patches.enter_context( +            mock.patch.object(DataSourceAzure.util, 'read_file_or_url')) + +    def test_non_secure_get(self): +        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) +        url = 'MyTestUrl' +        response = client.get(url, secure=False) +        self.assertEqual(1, self.read_file_or_url.call_count) +        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual(mock.call(url, headers=self.regular_headers), +                         self.read_file_or_url.call_args) + +    def test_secure_get(self): +        url = 'MyTestUrl' +        certificate = mock.MagicMock() +        expected_headers = self.regular_headers.copy() +        expected_headers.update({ +            "x-ms-cipher-name": "DES_EDE3_CBC", +            "x-ms-guest-agent-public-x509-cert": certificate, +        }) +        client = DataSourceAzure.AzureEndpointHttpClient(certificate) +        response = client.get(url, secure=True) +        self.assertEqual(1, self.read_file_or_url.call_count) +        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual(mock.call(url, headers=expected_headers), +                         self.read_file_or_url.call_args) + +    def test_post(self): +        data = mock.MagicMock() +        url = 'MyTestUrl' +        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) +        response = client.post(url, data=data) +        self.assertEqual(1, self.read_file_or_url.call_count) +        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual( +            mock.call(url, data=data, headers=self.regular_headers), +            self.read_file_or_url.call_args) + +    def test_post_with_extra_headers(self): +        url = 'MyTestUrl' +        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) +        extra_headers = {'test': 'header'} +        client.post(url, extra_headers=extra_headers) +        self.assertEqual(1, self.read_file_or_url.call_count) +        expected_headers = self.regular_headers.copy() +        expected_headers.update(extra_headers) +        self.assertEqual( +            mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), +            self.read_file_or_url.call_args) + + +class TestOpenSSLManager(TestCase): + +    def setUp(self): +        super(TestOpenSSLManager, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.subp = patches.enter_context( +            mock.patch.object(DataSourceAzure.util, 'subp')) + +    @mock.patch.object(DataSourceAzure, 'cd', mock.MagicMock()) +    @mock.patch.object(DataSourceAzure.tempfile, 'TemporaryDirectory') +    def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): +        manager = DataSourceAzure.OpenSSLManager() +        self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) + +    @mock.patch('builtins.open') +    def test_generate_certificate_uses_tmpdir(self, open): +        subp_directory = {} + +        def capture_directory(*args, **kwargs): +            subp_directory['path'] = os.getcwd() + +        self.subp.side_effect = capture_directory +        manager = DataSourceAzure.OpenSSLManager() +        self.assertEqual(manager.tmpdir.name, subp_directory['path']) + + +class TestWALinuxAgentShim(TestCase): + +    def setUp(self): +        super(TestWALinuxAgentShim, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.AzureEndpointHttpClient = patches.enter_context( +            mock.patch.object(DataSourceAzure, 'AzureEndpointHttpClient')) +        self.find_endpoint = patches.enter_context( +            mock.patch.object(DataSourceAzure, 'find_endpoint')) +        self.GoalState = patches.enter_context( +            mock.patch.object(DataSourceAzure, 'GoalState')) +        self.OpenSSLManager = patches.enter_context( +            mock.patch.object(DataSourceAzure, 'OpenSSLManager')) + +    def test_http_client_uses_certificate(self): +        shim = DataSourceAzure.WALinuxAgentShim() +        self.assertEqual( +            [mock.call(self.OpenSSLManager.return_value.certificate)], +            self.AzureEndpointHttpClient.call_args_list) +        self.assertEqual(self.AzureEndpointHttpClient.return_value, +                         shim.http_client) + +    def test_correct_url_used_for_goalstate(self): +        self.find_endpoint.return_value = 'test_endpoint' +        shim = DataSourceAzure.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        get = self.AzureEndpointHttpClient.return_value.get +        self.assertEqual( +            [mock.call('http://test_endpoint/machine/?comp=goalstate')], +            get.call_args_list) +        self.assertEqual( +            [mock.call(get.return_value.contents, shim.http_client)], +            self.GoalState.call_args_list) + +    def test_certificates_used_to_determine_public_keys(self): +        shim = DataSourceAzure.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        self.assertEqual( +            [mock.call(self.GoalState.return_value.certificates_xml)], +            self.OpenSSLManager.return_value.parse_certificates.call_args_list) +        self.assertEqual( +            self.OpenSSLManager.return_value.parse_certificates.return_value, +            shim.public_keys) + +    def test_absent_certificates_produces_empty_public_keys(self): +        self.GoalState.return_value.certificates_xml = None +        shim = DataSourceAzure.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        self.assertEqual([], shim.public_keys) + +    def test_correct_url_used_for_report_ready(self): +        self.find_endpoint.return_value = 'test_endpoint' +        shim = DataSourceAzure.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        expected_url = 'http://test_endpoint/machine?comp=health' +        self.assertEqual( +            [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], +            shim.http_client.post.call_args_list) + +    def test_goal_state_values_used_for_report_ready(self): +        self.GoalState.return_value.incarnation = 'TestIncarnation' +        self.GoalState.return_value.container_id = 'TestContainerId' +        self.GoalState.return_value.instance_id = 'TestInstanceId' +        shim = DataSourceAzure.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        posted_document = shim.http_client.post.call_args[1]['data'] +        self.assertIn('TestIncarnation', posted_document) +        self.assertIn('TestContainerId', posted_document) +        self.assertIn('TestInstanceId', posted_document) | 
