diff options
| -rw-r--r-- | cloudinit/sources/DataSourceAzure.py | 271 | ||||
| -rw-r--r-- | cloudinit/sources/helpers/azure.py | 273 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 364 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 377 | 
4 files changed, 653 insertions, 632 deletions
| diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index c2dc6b4c..5e147950 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -17,23 +17,19 @@  #    along with this program.  If not, see <http://www.gnu.org/licenses/>.  import base64 +import contextlib  import crypt  import fnmatch  import os  import os.path -import re -import socket -import struct -import tempfile -import time -from contextlib import contextmanager  from xml.dom import minidom -from xml.etree import ElementTree  from cloudinit import log as logging  from cloudinit.settings import PER_ALWAYS  from cloudinit import sources  from cloudinit import util +from cloudinit.sources.helpers.azure import ( +    iid_from_shared_config_content, WALinuxAgentShim)  LOG = logging.getLogger(__name__) @@ -70,253 +66,6 @@ DS_CFG_PATH = ['datasource', DS_NAME]  DEF_EPHEMERAL_LABEL = 'Temporary Storage' - -@contextmanager -def cd(newdir): -    prevdir = os.getcwd() -    os.chdir(os.path.expanduser(newdir)) -    try: -        yield -    finally: -        os.chdir(prevdir) - - -class AzureEndpointHttpClient(object): - -    headers = { -        'x-ms-agent-name': 'WALinuxAgent', -        'x-ms-version': '2012-11-30', -    } - -    def __init__(self, certificate): -        self.extra_secure_headers = { -            "x-ms-cipher-name": "DES_EDE3_CBC", -            "x-ms-guest-agent-public-x509-cert": certificate, -        } - -    def get(self, url, secure=False): -        headers = self.headers -        if secure: -            headers = self.headers.copy() -            headers.update(self.extra_secure_headers) -        return util.read_file_or_url(url, headers=headers) - -    def post(self, url, data=None, extra_headers=None): -        headers = self.headers -        if extra_headers is not None: -            headers = self.headers.copy() -            headers.update(extra_headers) -        return util.read_file_or_url(url, data=data, headers=headers) - - -class GoalState(object): - -    def __init__(self, xml, http_client): -        self.http_client = http_client -        self.root = ElementTree.fromstring(xml) -        self._certificates_xml = None - -    def _text_from_xpath(self, xpath): -        element = self.root.find(xpath) -        if element is not None: -            return element.text -        return None - -    @property -    def container_id(self): -        return self._text_from_xpath('./Container/ContainerId') - -    @property -    def incarnation(self): -        return self._text_from_xpath('./Incarnation') - -    @property -    def instance_id(self): -        return self._text_from_xpath( -            './Container/RoleInstanceList/RoleInstance/InstanceId') - -    @property -    def shared_config_xml(self): -        url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' -                                    '/Configuration/SharedConfig') -        return self.http_client.get(url).contents - -    @property -    def certificates_xml(self): -        if self._certificates_xml is None: -            url = self._text_from_xpath( -                './Container/RoleInstanceList/RoleInstance' -                '/Configuration/Certificates') -            if url is not None: -                self._certificates_xml = self.http_client.get( -                    url, secure=True).contents -        return self._certificates_xml - - -class OpenSSLManager(object): - -    certificate_names = { -        'private_key': 'TransportPrivate.pem', -        'certificate': 'TransportCert.pem', -    } - -    def __init__(self): -        self.tmpdir = tempfile.TemporaryDirectory() -        self.certificate = None -        self.generate_certificate() - -    def generate_certificate(self): -        LOG.debug('Generating certificate for communication with fabric...') -        if self.certificate is not None: -            LOG.debug('Certificate already generated.') -            return -        with cd(self.tmpdir.name): -            util.subp([ -                'openssl', 'req', '-x509', '-nodes', '-subj', -                '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', -                '-keyout', self.certificate_names['private_key'], -                '-out', self.certificate_names['certificate'], -            ]) -            certificate = '' -            for line in open(self.certificate_names['certificate']): -                if "CERTIFICATE" not in line: -                    certificate += line.rstrip() -            self.certificate = certificate -        LOG.debug('New certificate generated.') - -    def parse_certificates(self, certificates_xml): -        tag = ElementTree.fromstring(certificates_xml).find( -            './/Data') -        certificates_content = tag.text -        lines = [ -            b'MIME-Version: 1.0', -            b'Content-Disposition: attachment; filename="Certificates.p7m"', -            b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', -            b'Content-Transfer-Encoding: base64', -            b'', -            certificates_content.encode('utf-8'), -        ] -        with cd(self.tmpdir.name): -            with open('Certificates.p7m', 'wb') as f: -                f.write(b'\n'.join(lines)) -            out, _ = util.subp( -                'openssl cms -decrypt -in Certificates.p7m -inkey' -                ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' -                ' -password pass:'.format(**self.certificate_names), -                shell=True) -        private_keys, certificates = [], [] -        current = [] -        for line in out.splitlines(): -            current.append(line) -            if re.match(r'[-]+END .*?KEY[-]+$', line): -                private_keys.append('\n'.join(current)) -                current = [] -            elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): -                certificates.append('\n'.join(current)) -                current = [] -        keys = [] -        for certificate in certificates: -            with cd(self.tmpdir.name): -                public_key, _ = util.subp( -                    'openssl x509 -noout -pubkey |' -                    'ssh-keygen -i -m PKCS8 -f /dev/stdin', -                    data=certificate, -                    shell=True) -            keys.append(public_key) -        return keys - - -class WALinuxAgentShim(object): - -    REPORT_READY_XML_TEMPLATE = '\n'.join([ -        '<?xml version="1.0" encoding="utf-8"?>', -        '<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"' -        ' xmlns:xsd="http://www.w3.org/2001/XMLSchema">', -        '  <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>', -        '  <Container>', -        '    <ContainerId>{container_id}</ContainerId>', -        '    <RoleInstanceList>', -        '      <Role>', -        '        <InstanceId>{instance_id}</InstanceId>', -        '        <Health>', -        '          <State>Ready</State>', -        '        </Health>', -        '      </Role>', -        '    </RoleInstanceList>', -        '  </Container>', -        '</Health>']) - -    def __init__(self): -        LOG.debug('WALinuxAgentShim instantiated...') -        self.endpoint = self.find_endpoint() -        self.openssl_manager = OpenSSLManager() -        self.http_client = AzureEndpointHttpClient( -            self.openssl_manager.certificate) -        self.values = {} - -    @staticmethod -    def find_endpoint(): -        LOG.debug('Finding Azure endpoint...') -        content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') -        value = None -        for line in content.splitlines(): -            if 'unknown-245' in line: -                value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') -        if value is None: -            raise Exception('No endpoint found in DHCP config.') -        if ':' in value: -            hex_string = '' -            for hex_pair in value.split(':'): -                if len(hex_pair) == 1: -                    hex_pair = '0' + hex_pair -                hex_string += hex_pair -            value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) -        else: -            value = value.encode('utf-8') -        endpoint_ip_address = socket.inet_ntoa(value) -        LOG.debug('Azure endpoint found at %s', endpoint_ip_address) -        return endpoint_ip_address - -    def register_with_azure_and_fetch_data(self): -        LOG.info('Registering with Azure...') -        for i in range(10): -            try: -                response = self.http_client.get( -                    'http://{}/machine/?comp=goalstate'.format(self.endpoint)) -            except Exception: -                time.sleep(i + 1) -            else: -                break -        LOG.debug('Successfully fetched GoalState XML.') -        goal_state = GoalState(response.contents, self.http_client) -        public_keys = [] -        if goal_state.certificates_xml is not None: -            LOG.debug('Certificate XML found; parsing out public keys.') -            public_keys = self.openssl_manager.parse_certificates( -                goal_state.certificates_xml) -        data = { -            'instance-id': iid_from_shared_config_content( -                goal_state.shared_config_xml), -            'public-keys': public_keys, -        } -        self._report_ready(goal_state) -        return data - -    def _report_ready(self, goal_state): -        LOG.debug('Reporting ready to Azure fabric.') -        document = self.REPORT_READY_XML_TEMPLATE.format( -            incarnation=goal_state.incarnation, -            container_id=goal_state.container_id, -            instance_id=goal_state.instance_id, -        ) -        self.http_client.post( -            "http://{}/machine?comp=health".format(self.endpoint), -            data=document, -            extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, -        ) -        LOG.info('Reported ready to Azure fabric.') - -  def get_hostname(hostname_command='hostname'):      return util.subp(hostname_command, capture=True)[0].strip() @@ -690,20 +439,6 @@ def load_azure_ovf_pubkeys(sshnode):      return found -def single_node_at_path(node, pathlist): -    curnode = node -    for tok in pathlist: -        results = find_child(curnode, lambda n: n.localName == tok) -        if len(results) == 0: -            raise ValueError("missing %s token in %s" % (tok, str(pathlist))) -        if len(results) > 1: -            raise ValueError("found %s nodes of type %s looking for %s" % -                             (len(results), tok, str(pathlist))) -        curnode = results[0] - -    return curnode - -  def read_azure_ovf(contents):      try:          dom = minidom.parseString(contents) diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py new file mode 100644 index 00000000..60f116e0 --- /dev/null +++ b/cloudinit/sources/helpers/azure.py @@ -0,0 +1,273 @@ +import logging +import os +import re +import socket +import struct +import tempfile +import time +from contextlib import contextmanager +from xml.etree import ElementTree + +from cloudinit import util + + +LOG = logging.getLogger(__name__) + + +@contextmanager +def cd(newdir): +    prevdir = os.getcwd() +    os.chdir(os.path.expanduser(newdir)) +    try: +        yield +    finally: +        os.chdir(prevdir) + + +class AzureEndpointHttpClient(object): + +    headers = { +        'x-ms-agent-name': 'WALinuxAgent', +        'x-ms-version': '2012-11-30', +    } + +    def __init__(self, certificate): +        self.extra_secure_headers = { +            "x-ms-cipher-name": "DES_EDE3_CBC", +            "x-ms-guest-agent-public-x509-cert": certificate, +        } + +    def get(self, url, secure=False): +        headers = self.headers +        if secure: +            headers = self.headers.copy() +            headers.update(self.extra_secure_headers) +        return util.read_file_or_url(url, headers=headers) + +    def post(self, url, data=None, extra_headers=None): +        headers = self.headers +        if extra_headers is not None: +            headers = self.headers.copy() +            headers.update(extra_headers) +        return util.read_file_or_url(url, data=data, headers=headers) + + +class GoalState(object): + +    def __init__(self, xml, http_client): +        self.http_client = http_client +        self.root = ElementTree.fromstring(xml) +        self._certificates_xml = None + +    def _text_from_xpath(self, xpath): +        element = self.root.find(xpath) +        if element is not None: +            return element.text +        return None + +    @property +    def container_id(self): +        return self._text_from_xpath('./Container/ContainerId') + +    @property +    def incarnation(self): +        return self._text_from_xpath('./Incarnation') + +    @property +    def instance_id(self): +        return self._text_from_xpath( +            './Container/RoleInstanceList/RoleInstance/InstanceId') + +    @property +    def shared_config_xml(self): +        url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' +                                    '/Configuration/SharedConfig') +        return self.http_client.get(url).contents + +    @property +    def certificates_xml(self): +        if self._certificates_xml is None: +            url = self._text_from_xpath( +                './Container/RoleInstanceList/RoleInstance' +                '/Configuration/Certificates') +            if url is not None: +                self._certificates_xml = self.http_client.get( +                    url, secure=True).contents +        return self._certificates_xml + + +class OpenSSLManager(object): + +    certificate_names = { +        'private_key': 'TransportPrivate.pem', +        'certificate': 'TransportCert.pem', +    } + +    def __init__(self): +        self.tmpdir = tempfile.TemporaryDirectory() +        self.certificate = None +        self.generate_certificate() + +    def generate_certificate(self): +        LOG.debug('Generating certificate for communication with fabric...') +        if self.certificate is not None: +            LOG.debug('Certificate already generated.') +            return +        with cd(self.tmpdir.name): +            util.subp([ +                'openssl', 'req', '-x509', '-nodes', '-subj', +                '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', +                '-keyout', self.certificate_names['private_key'], +                '-out', self.certificate_names['certificate'], +            ]) +            certificate = '' +            for line in open(self.certificate_names['certificate']): +                if "CERTIFICATE" not in line: +                    certificate += line.rstrip() +            self.certificate = certificate +        LOG.debug('New certificate generated.') + +    def parse_certificates(self, certificates_xml): +        tag = ElementTree.fromstring(certificates_xml).find( +            './/Data') +        certificates_content = tag.text +        lines = [ +            b'MIME-Version: 1.0', +            b'Content-Disposition: attachment; filename="Certificates.p7m"', +            b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', +            b'Content-Transfer-Encoding: base64', +            b'', +            certificates_content.encode('utf-8'), +        ] +        with cd(self.tmpdir.name): +            with open('Certificates.p7m', 'wb') as f: +                f.write(b'\n'.join(lines)) +            out, _ = util.subp( +                'openssl cms -decrypt -in Certificates.p7m -inkey' +                ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' +                ' -password pass:'.format(**self.certificate_names), +                shell=True) +        private_keys, certificates = [], [] +        current = [] +        for line in out.splitlines(): +            current.append(line) +            if re.match(r'[-]+END .*?KEY[-]+$', line): +                private_keys.append('\n'.join(current)) +                current = [] +            elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): +                certificates.append('\n'.join(current)) +                current = [] +        keys = [] +        for certificate in certificates: +            with cd(self.tmpdir.name): +                public_key, _ = util.subp( +                    'openssl x509 -noout -pubkey |' +                    'ssh-keygen -i -m PKCS8 -f /dev/stdin', +                    data=certificate, +                    shell=True) +            keys.append(public_key) +        return keys + + +def iid_from_shared_config_content(content): +    """ +    find INSTANCE_ID in: +    <?xml version="1.0" encoding="utf-8"?> +    <SharedConfig version="1.0.0.0" goalStateIncarnation="1"> +    <Deployment name="INSTANCE_ID" guid="{...}" incarnation="0"> +        <Service name="..." guid="{00000000-0000-0000-0000-000000000000}"/> +    """ +    root = ElementTree.fromstring(content) +    depnode = root.find('Deployment') +    return depnode.get('name') + + +class WALinuxAgentShim(object): + +    REPORT_READY_XML_TEMPLATE = '\n'.join([ +        '<?xml version="1.0" encoding="utf-8"?>', +        '<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"' +        ' xmlns:xsd="http://www.w3.org/2001/XMLSchema">', +        '  <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>', +        '  <Container>', +        '    <ContainerId>{container_id}</ContainerId>', +        '    <RoleInstanceList>', +        '      <Role>', +        '        <InstanceId>{instance_id}</InstanceId>', +        '        <Health>', +        '          <State>Ready</State>', +        '        </Health>', +        '      </Role>', +        '    </RoleInstanceList>', +        '  </Container>', +        '</Health>']) + +    def __init__(self): +        LOG.debug('WALinuxAgentShim instantiated...') +        self.endpoint = self.find_endpoint() +        self.openssl_manager = OpenSSLManager() +        self.http_client = AzureEndpointHttpClient( +            self.openssl_manager.certificate) +        self.values = {} + +    @staticmethod +    def find_endpoint(): +        LOG.debug('Finding Azure endpoint...') +        content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') +        value = None +        for line in content.splitlines(): +            if 'unknown-245' in line: +                value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') +        if value is None: +            raise Exception('No endpoint found in DHCP config.') +        if ':' in value: +            hex_string = '' +            for hex_pair in value.split(':'): +                if len(hex_pair) == 1: +                    hex_pair = '0' + hex_pair +                hex_string += hex_pair +            value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) +        else: +            value = value.encode('utf-8') +        endpoint_ip_address = socket.inet_ntoa(value) +        LOG.debug('Azure endpoint found at %s', endpoint_ip_address) +        return endpoint_ip_address + +    def register_with_azure_and_fetch_data(self): +        LOG.info('Registering with Azure...') +        for i in range(10): +            try: +                response = self.http_client.get( +                    'http://{}/machine/?comp=goalstate'.format(self.endpoint)) +            except Exception: +                time.sleep(i + 1) +            else: +                break +        LOG.debug('Successfully fetched GoalState XML.') +        goal_state = GoalState(response.contents, self.http_client) +        public_keys = [] +        if goal_state.certificates_xml is not None: +            LOG.debug('Certificate XML found; parsing out public keys.') +            public_keys = self.openssl_manager.parse_certificates( +                goal_state.certificates_xml) +        data = { +            'instance-id': iid_from_shared_config_content( +                goal_state.shared_config_xml), +            'public-keys': public_keys, +        } +        self._report_ready(goal_state) +        return data + +    def _report_ready(self, goal_state): +        LOG.debug('Reporting ready to Azure fabric.') +        document = self.REPORT_READY_XML_TEMPLATE.format( +            incarnation=goal_state.incarnation, +            container_id=goal_state.container_id, +            instance_id=goal_state.instance_id, +        ) +        self.http_client.post( +            "http://{}/machine?comp=health".format(self.endpoint), +            data=document, +            extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, +        ) +        LOG.info('Reported ready to Azure fabric.') diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 28703029..ee7109e1 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -15,47 +15,9 @@ except ImportError:  import crypt  import os  import stat -import struct  import yaml  import shutil  import tempfile -import unittest - -from cloudinit import url_helper - - -GOAL_STATE_TEMPLATE = """\ -<?xml version="1.0" encoding="utf-8"?> -<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="goalstate10.xsd"> -  <Version>2012-11-30</Version> -  <Incarnation>{incarnation}</Incarnation> -  <Machine> -    <ExpectedState>Started</ExpectedState> -    <StopRolesDeadlineHint>300000</StopRolesDeadlineHint> -    <LBProbePorts> -      <Port>16001</Port> -    </LBProbePorts> -    <ExpectHealthReport>FALSE</ExpectHealthReport> -  </Machine> -  <Container> -    <ContainerId>{container_id}</ContainerId> -    <RoleInstanceList> -      <RoleInstance> -        <InstanceId>{instance_id}</InstanceId> -        <State>Started</State> -        <Configuration> -          <HostingEnvironmentConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1</HostingEnvironmentConfig> -          <SharedConfig>{shared_config_url}</SharedConfig> -          <ExtensionsConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1</ExtensionsConfig> -          <FullConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1</FullConfig> -          <Certificates>{certificates_url}</Certificates> -          <ConfigName>68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml</ConfigName> -        </Configuration> -      </RoleInstance> -    </RoleInstanceList> -  </Container> -</GoalState> -"""  def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): @@ -610,329 +572,3 @@ class TestReadAzureOvf(TestCase):          for mypk in mypklist:              self.assertIn(mypk, cfg['_pubkeys']) - -class TestReadAzureSharedConfig(unittest.TestCase): -    def test_valid_content(self): -        xml = """<?xml version="1.0" encoding="utf-8"?> -            <SharedConfig> -             <Deployment name="MY_INSTANCE_ID"> -              <Service name="myservice"/> -              <ServiceInstance name="INSTANCE_ID.0" guid="{abcd-uuid}" /> -             </Deployment> -            <Incarnation number="1"/> -            </SharedConfig>""" -        ret = DataSourceAzure.iid_from_shared_config_content(xml) -        self.assertEqual("MY_INSTANCE_ID", ret) - - -class TestFindEndpoint(TestCase): - -    def setUp(self): -        super(TestFindEndpoint, self).setUp() -        patches = ExitStack() -        self.addCleanup(patches.close) - -        self.load_file = patches.enter_context( -            mock.patch.object(DataSourceAzure.util, 'load_file')) - -    def test_missing_file(self): -        self.load_file.side_effect = IOError -        self.assertRaises(IOError, -                          DataSourceAzure.WALinuxAgentShim.find_endpoint) - -    def test_missing_special_azure_line(self): -        self.load_file.return_value = '' -        self.assertRaises(Exception, -                          DataSourceAzure.WALinuxAgentShim.find_endpoint) - -    def _build_lease_content(self, ip_address, use_hex=True): -        ip_address_repr = ':'.join( -            [hex(int(part)).replace('0x', '') -             for part in ip_address.split('.')]) -        if not use_hex: -            ip_address_repr = struct.pack( -                '>L', int(ip_address_repr.replace(':', ''), 16)) -            ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) -        return '\n'.join([ -            'lease {', -            ' interface "eth0";', -            ' option unknown-245 {0};'.format(ip_address_repr), -            '}']) - -    def test_hex_string(self): -        ip_address = '98.76.54.32' -        file_content = self._build_lease_content(ip_address) -        self.load_file.return_value = file_content -        self.assertEqual(ip_address, -                         DataSourceAzure.WALinuxAgentShim.find_endpoint()) - -    def test_hex_string_with_single_character_part(self): -        ip_address = '4.3.2.1' -        file_content = self._build_lease_content(ip_address) -        self.load_file.return_value = file_content -        self.assertEqual(ip_address, -                         DataSourceAzure.WALinuxAgentShim.find_endpoint()) - -    def test_packed_string(self): -        ip_address = '98.76.54.32' -        file_content = self._build_lease_content(ip_address, use_hex=False) -        self.load_file.return_value = file_content -        self.assertEqual(ip_address, -                         DataSourceAzure.WALinuxAgentShim.find_endpoint()) - -    def test_latest_lease_used(self): -        ip_addresses = ['4.3.2.1', '98.76.54.32'] -        file_content = '\n'.join([self._build_lease_content(ip_address) -                                  for ip_address in ip_addresses]) -        self.load_file.return_value = file_content -        self.assertEqual(ip_addresses[-1], -                         DataSourceAzure.WALinuxAgentShim.find_endpoint()) - - -class TestGoalStateParsing(TestCase): - -    default_parameters = { -        'incarnation': 1, -        'container_id': 'MyContainerId', -        'instance_id': 'MyInstanceId', -        'shared_config_url': 'MySharedConfigUrl', -        'certificates_url': 'MyCertificatesUrl', -    } - -    def _get_goal_state(self, http_client=None, **kwargs): -        if http_client is None: -            http_client = mock.MagicMock() -        parameters = self.default_parameters.copy() -        parameters.update(kwargs) -        xml = GOAL_STATE_TEMPLATE.format(**parameters) -        if parameters['certificates_url'] is None: -            new_xml_lines = [] -            for line in xml.splitlines(): -                if 'Certificates' in line: -                    continue -                new_xml_lines.append(line) -            xml = '\n'.join(new_xml_lines) -        return DataSourceAzure.GoalState(xml, http_client) - -    def test_incarnation_parsed_correctly(self): -        incarnation = '123' -        goal_state = self._get_goal_state(incarnation=incarnation) -        self.assertEqual(incarnation, goal_state.incarnation) - -    def test_container_id_parsed_correctly(self): -        container_id = 'TestContainerId' -        goal_state = self._get_goal_state(container_id=container_id) -        self.assertEqual(container_id, goal_state.container_id) - -    def test_instance_id_parsed_correctly(self): -        instance_id = 'TestInstanceId' -        goal_state = self._get_goal_state(instance_id=instance_id) -        self.assertEqual(instance_id, goal_state.instance_id) - -    def test_shared_config_xml_parsed_and_fetched_correctly(self): -        http_client = mock.MagicMock() -        shared_config_url = 'TestSharedConfigUrl' -        goal_state = self._get_goal_state( -            http_client=http_client, shared_config_url=shared_config_url) -        shared_config_xml = goal_state.shared_config_xml -        self.assertEqual(1, http_client.get.call_count) -        self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) -        self.assertEqual(http_client.get.return_value.contents, -                         shared_config_xml) - -    def test_certificates_xml_parsed_and_fetched_correctly(self): -        http_client = mock.MagicMock() -        certificates_url = 'TestSharedConfigUrl' -        goal_state = self._get_goal_state( -            http_client=http_client, certificates_url=certificates_url) -        certificates_xml = goal_state.certificates_xml -        self.assertEqual(1, http_client.get.call_count) -        self.assertEqual(certificates_url, http_client.get.call_args[0][0]) -        self.assertTrue(http_client.get.call_args[1].get('secure', False)) -        self.assertEqual(http_client.get.return_value.contents, -                         certificates_xml) - -    def test_missing_certificates_skips_http_get(self): -        http_client = mock.MagicMock() -        goal_state = self._get_goal_state( -            http_client=http_client, certificates_url=None) -        certificates_xml = goal_state.certificates_xml -        self.assertEqual(0, http_client.get.call_count) -        self.assertIsNone(certificates_xml) - - -class TestAzureEndpointHttpClient(TestCase): - -    regular_headers = { -        'x-ms-agent-name': 'WALinuxAgent', -        'x-ms-version': '2012-11-30', -    } - -    def setUp(self): -        super(TestAzureEndpointHttpClient, self).setUp() -        patches = ExitStack() -        self.addCleanup(patches.close) - -        self.read_file_or_url = patches.enter_context( -            mock.patch.object(DataSourceAzure.util, 'read_file_or_url')) - -    def test_non_secure_get(self): -        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) -        url = 'MyTestUrl' -        response = client.get(url, secure=False) -        self.assertEqual(1, self.read_file_or_url.call_count) -        self.assertEqual(self.read_file_or_url.return_value, response) -        self.assertEqual(mock.call(url, headers=self.regular_headers), -                         self.read_file_or_url.call_args) - -    def test_secure_get(self): -        url = 'MyTestUrl' -        certificate = mock.MagicMock() -        expected_headers = self.regular_headers.copy() -        expected_headers.update({ -            "x-ms-cipher-name": "DES_EDE3_CBC", -            "x-ms-guest-agent-public-x509-cert": certificate, -        }) -        client = DataSourceAzure.AzureEndpointHttpClient(certificate) -        response = client.get(url, secure=True) -        self.assertEqual(1, self.read_file_or_url.call_count) -        self.assertEqual(self.read_file_or_url.return_value, response) -        self.assertEqual(mock.call(url, headers=expected_headers), -                         self.read_file_or_url.call_args) - -    def test_post(self): -        data = mock.MagicMock() -        url = 'MyTestUrl' -        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) -        response = client.post(url, data=data) -        self.assertEqual(1, self.read_file_or_url.call_count) -        self.assertEqual(self.read_file_or_url.return_value, response) -        self.assertEqual( -            mock.call(url, data=data, headers=self.regular_headers), -            self.read_file_or_url.call_args) - -    def test_post_with_extra_headers(self): -        url = 'MyTestUrl' -        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock()) -        extra_headers = {'test': 'header'} -        client.post(url, extra_headers=extra_headers) -        self.assertEqual(1, self.read_file_or_url.call_count) -        expected_headers = self.regular_headers.copy() -        expected_headers.update(extra_headers) -        self.assertEqual( -            mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), -            self.read_file_or_url.call_args) - - -class TestOpenSSLManager(TestCase): - -    def setUp(self): -        super(TestOpenSSLManager, self).setUp() -        patches = ExitStack() -        self.addCleanup(patches.close) - -        self.subp = patches.enter_context( -            mock.patch.object(DataSourceAzure.util, 'subp')) - -    @mock.patch.object(DataSourceAzure, 'cd', mock.MagicMock()) -    @mock.patch.object(DataSourceAzure.tempfile, 'TemporaryDirectory') -    def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): -        manager = DataSourceAzure.OpenSSLManager() -        self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) - -    @mock.patch('builtins.open') -    def test_generate_certificate_uses_tmpdir(self, open): -        subp_directory = {} - -        def capture_directory(*args, **kwargs): -            subp_directory['path'] = os.getcwd() - -        self.subp.side_effect = capture_directory -        manager = DataSourceAzure.OpenSSLManager() -        self.assertEqual(manager.tmpdir.name, subp_directory['path']) - - -class TestWALinuxAgentShim(TestCase): - -    def setUp(self): -        super(TestWALinuxAgentShim, self).setUp() -        patches = ExitStack() -        self.addCleanup(patches.close) - -        self.AzureEndpointHttpClient = patches.enter_context( -            mock.patch.object(DataSourceAzure, 'AzureEndpointHttpClient')) -        self.find_endpoint = patches.enter_context( -            mock.patch.object( -                DataSourceAzure.WALinuxAgentShim, 'find_endpoint')) -        self.GoalState = patches.enter_context( -            mock.patch.object(DataSourceAzure, 'GoalState')) -        self.iid_from_shared_config_content = patches.enter_context( -            mock.patch.object(DataSourceAzure, -                              'iid_from_shared_config_content')) -        self.OpenSSLManager = patches.enter_context( -            mock.patch.object(DataSourceAzure, 'OpenSSLManager')) - -    def test_http_client_uses_certificate(self): -        shim = DataSourceAzure.WALinuxAgentShim() -        self.assertEqual( -            [mock.call(self.OpenSSLManager.return_value.certificate)], -            self.AzureEndpointHttpClient.call_args_list) -        self.assertEqual(self.AzureEndpointHttpClient.return_value, -                         shim.http_client) - -    def test_correct_url_used_for_goalstate(self): -        self.find_endpoint.return_value = 'test_endpoint' -        shim = DataSourceAzure.WALinuxAgentShim() -        shim.register_with_azure_and_fetch_data() -        get = self.AzureEndpointHttpClient.return_value.get -        self.assertEqual( -            [mock.call('http://test_endpoint/machine/?comp=goalstate')], -            get.call_args_list) -        self.assertEqual( -            [mock.call(get.return_value.contents, shim.http_client)], -            self.GoalState.call_args_list) - -    def test_certificates_used_to_determine_public_keys(self): -        shim = DataSourceAzure.WALinuxAgentShim() -        data = shim.register_with_azure_and_fetch_data() -        self.assertEqual( -            [mock.call(self.GoalState.return_value.certificates_xml)], -            self.OpenSSLManager.return_value.parse_certificates.call_args_list) -        self.assertEqual( -            self.OpenSSLManager.return_value.parse_certificates.return_value, -            data['public-keys']) - -    def test_absent_certificates_produces_empty_public_keys(self): -        self.GoalState.return_value.certificates_xml = None -        shim = DataSourceAzure.WALinuxAgentShim() -        data = shim.register_with_azure_and_fetch_data() -        self.assertEqual([], data['public-keys']) - -    def test_instance_id_returned_in_data(self): -        shim = DataSourceAzure.WALinuxAgentShim() -        data = shim.register_with_azure_and_fetch_data() -        self.assertEqual( -            [mock.call(self.GoalState.return_value.shared_config_xml)], -            self.iid_from_shared_config_content.call_args_list) -        self.assertEqual(self.iid_from_shared_config_content.return_value, -                         data['instance-id']) - -    def test_correct_url_used_for_report_ready(self): -        self.find_endpoint.return_value = 'test_endpoint' -        shim = DataSourceAzure.WALinuxAgentShim() -        shim.register_with_azure_and_fetch_data() -        expected_url = 'http://test_endpoint/machine?comp=health' -        self.assertEqual( -            [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], -            shim.http_client.post.call_args_list) - -    def test_goal_state_values_used_for_report_ready(self): -        self.GoalState.return_value.incarnation = 'TestIncarnation' -        self.GoalState.return_value.container_id = 'TestContainerId' -        self.GoalState.return_value.instance_id = 'TestInstanceId' -        shim = DataSourceAzure.WALinuxAgentShim() -        shim.register_with_azure_and_fetch_data() -        posted_document = shim.http_client.post.call_args[1]['data'] -        self.assertIn('TestIncarnation', posted_document) -        self.assertIn('TestContainerId', posted_document) -        self.assertIn('TestInstanceId', posted_document) diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py new file mode 100644 index 00000000..47b77840 --- /dev/null +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -0,0 +1,377 @@ +import os +import struct +import unittest + +from cloudinit.sources.helpers import azure as azure_helper +from ..helpers import TestCase + +try: +    from unittest import mock +except ImportError: +    import mock + +try: +    from contextlib import ExitStack +except ImportError: +    from contextlib2 import ExitStack + + +GOAL_STATE_TEMPLATE = """\ +<?xml version="1.0" encoding="utf-8"?> +<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="goalstate10.xsd"> +  <Version>2012-11-30</Version> +  <Incarnation>{incarnation}</Incarnation> +  <Machine> +    <ExpectedState>Started</ExpectedState> +    <StopRolesDeadlineHint>300000</StopRolesDeadlineHint> +    <LBProbePorts> +      <Port>16001</Port> +    </LBProbePorts> +    <ExpectHealthReport>FALSE</ExpectHealthReport> +  </Machine> +  <Container> +    <ContainerId>{container_id}</ContainerId> +    <RoleInstanceList> +      <RoleInstance> +        <InstanceId>{instance_id}</InstanceId> +        <State>Started</State> +        <Configuration> +          <HostingEnvironmentConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1</HostingEnvironmentConfig> +          <SharedConfig>{shared_config_url}</SharedConfig> +          <ExtensionsConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1</ExtensionsConfig> +          <FullConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1</FullConfig> +          <Certificates>{certificates_url}</Certificates> +          <ConfigName>68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml</ConfigName> +        </Configuration> +      </RoleInstance> +    </RoleInstanceList> +  </Container> +</GoalState> +""" + + +class TestReadAzureSharedConfig(unittest.TestCase): + +    def test_valid_content(self): +        xml = """<?xml version="1.0" encoding="utf-8"?> +            <SharedConfig> +             <Deployment name="MY_INSTANCE_ID"> +              <Service name="myservice"/> +              <ServiceInstance name="INSTANCE_ID.0" guid="{abcd-uuid}" /> +             </Deployment> +            <Incarnation number="1"/> +            </SharedConfig>""" +        ret = azure_helper.iid_from_shared_config_content(xml) +        self.assertEqual("MY_INSTANCE_ID", ret) + + +class TestFindEndpoint(TestCase): + +    def setUp(self): +        super(TestFindEndpoint, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.load_file = patches.enter_context( +            mock.patch.object(azure_helper.util, 'load_file')) + +    def test_missing_file(self): +        self.load_file.side_effect = IOError +        self.assertRaises(IOError, +                          azure_helper.WALinuxAgentShim.find_endpoint) + +    def test_missing_special_azure_line(self): +        self.load_file.return_value = '' +        self.assertRaises(Exception, +                          azure_helper.WALinuxAgentShim.find_endpoint) + +    def _build_lease_content(self, ip_address, use_hex=True): +        ip_address_repr = ':'.join( +            [hex(int(part)).replace('0x', '') +             for part in ip_address.split('.')]) +        if not use_hex: +            ip_address_repr = struct.pack( +                '>L', int(ip_address_repr.replace(':', ''), 16)) +            ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) +        return '\n'.join([ +            'lease {', +            ' interface "eth0";', +            ' option unknown-245 {0};'.format(ip_address_repr), +            '}']) + +    def test_hex_string(self): +        ip_address = '98.76.54.32' +        file_content = self._build_lease_content(ip_address) +        self.load_file.return_value = file_content +        self.assertEqual(ip_address, +                         azure_helper.WALinuxAgentShim.find_endpoint()) + +    def test_hex_string_with_single_character_part(self): +        ip_address = '4.3.2.1' +        file_content = self._build_lease_content(ip_address) +        self.load_file.return_value = file_content +        self.assertEqual(ip_address, +                         azure_helper.WALinuxAgentShim.find_endpoint()) + +    def test_packed_string(self): +        ip_address = '98.76.54.32' +        file_content = self._build_lease_content(ip_address, use_hex=False) +        self.load_file.return_value = file_content +        self.assertEqual(ip_address, +                         azure_helper.WALinuxAgentShim.find_endpoint()) + +    def test_latest_lease_used(self): +        ip_addresses = ['4.3.2.1', '98.76.54.32'] +        file_content = '\n'.join([self._build_lease_content(ip_address) +                                  for ip_address in ip_addresses]) +        self.load_file.return_value = file_content +        self.assertEqual(ip_addresses[-1], +                         azure_helper.WALinuxAgentShim.find_endpoint()) + + +class TestGoalStateParsing(TestCase): + +    default_parameters = { +        'incarnation': 1, +        'container_id': 'MyContainerId', +        'instance_id': 'MyInstanceId', +        'shared_config_url': 'MySharedConfigUrl', +        'certificates_url': 'MyCertificatesUrl', +    } + +    def _get_goal_state(self, http_client=None, **kwargs): +        if http_client is None: +            http_client = mock.MagicMock() +        parameters = self.default_parameters.copy() +        parameters.update(kwargs) +        xml = GOAL_STATE_TEMPLATE.format(**parameters) +        if parameters['certificates_url'] is None: +            new_xml_lines = [] +            for line in xml.splitlines(): +                if 'Certificates' in line: +                    continue +                new_xml_lines.append(line) +            xml = '\n'.join(new_xml_lines) +        return azure_helper.GoalState(xml, http_client) + +    def test_incarnation_parsed_correctly(self): +        incarnation = '123' +        goal_state = self._get_goal_state(incarnation=incarnation) +        self.assertEqual(incarnation, goal_state.incarnation) + +    def test_container_id_parsed_correctly(self): +        container_id = 'TestContainerId' +        goal_state = self._get_goal_state(container_id=container_id) +        self.assertEqual(container_id, goal_state.container_id) + +    def test_instance_id_parsed_correctly(self): +        instance_id = 'TestInstanceId' +        goal_state = self._get_goal_state(instance_id=instance_id) +        self.assertEqual(instance_id, goal_state.instance_id) + +    def test_shared_config_xml_parsed_and_fetched_correctly(self): +        http_client = mock.MagicMock() +        shared_config_url = 'TestSharedConfigUrl' +        goal_state = self._get_goal_state( +            http_client=http_client, shared_config_url=shared_config_url) +        shared_config_xml = goal_state.shared_config_xml +        self.assertEqual(1, http_client.get.call_count) +        self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) +        self.assertEqual(http_client.get.return_value.contents, +                         shared_config_xml) + +    def test_certificates_xml_parsed_and_fetched_correctly(self): +        http_client = mock.MagicMock() +        certificates_url = 'TestSharedConfigUrl' +        goal_state = self._get_goal_state( +            http_client=http_client, certificates_url=certificates_url) +        certificates_xml = goal_state.certificates_xml +        self.assertEqual(1, http_client.get.call_count) +        self.assertEqual(certificates_url, http_client.get.call_args[0][0]) +        self.assertTrue(http_client.get.call_args[1].get('secure', False)) +        self.assertEqual(http_client.get.return_value.contents, +                         certificates_xml) + +    def test_missing_certificates_skips_http_get(self): +        http_client = mock.MagicMock() +        goal_state = self._get_goal_state( +            http_client=http_client, certificates_url=None) +        certificates_xml = goal_state.certificates_xml +        self.assertEqual(0, http_client.get.call_count) +        self.assertIsNone(certificates_xml) + + +class TestAzureEndpointHttpClient(TestCase): + +    regular_headers = { +        'x-ms-agent-name': 'WALinuxAgent', +        'x-ms-version': '2012-11-30', +    } + +    def setUp(self): +        super(TestAzureEndpointHttpClient, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.read_file_or_url = patches.enter_context( +            mock.patch.object(azure_helper.util, 'read_file_or_url')) + +    def test_non_secure_get(self): +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        url = 'MyTestUrl' +        response = client.get(url, secure=False) +        self.assertEqual(1, self.read_file_or_url.call_count) +        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual(mock.call(url, headers=self.regular_headers), +                         self.read_file_or_url.call_args) + +    def test_secure_get(self): +        url = 'MyTestUrl' +        certificate = mock.MagicMock() +        expected_headers = self.regular_headers.copy() +        expected_headers.update({ +            "x-ms-cipher-name": "DES_EDE3_CBC", +            "x-ms-guest-agent-public-x509-cert": certificate, +        }) +        client = azure_helper.AzureEndpointHttpClient(certificate) +        response = client.get(url, secure=True) +        self.assertEqual(1, self.read_file_or_url.call_count) +        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual(mock.call(url, headers=expected_headers), +                         self.read_file_or_url.call_args) + +    def test_post(self): +        data = mock.MagicMock() +        url = 'MyTestUrl' +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        response = client.post(url, data=data) +        self.assertEqual(1, self.read_file_or_url.call_count) +        self.assertEqual(self.read_file_or_url.return_value, response) +        self.assertEqual( +            mock.call(url, data=data, headers=self.regular_headers), +            self.read_file_or_url.call_args) + +    def test_post_with_extra_headers(self): +        url = 'MyTestUrl' +        client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) +        extra_headers = {'test': 'header'} +        client.post(url, extra_headers=extra_headers) +        self.assertEqual(1, self.read_file_or_url.call_count) +        expected_headers = self.regular_headers.copy() +        expected_headers.update(extra_headers) +        self.assertEqual( +            mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), +            self.read_file_or_url.call_args) + + +class TestOpenSSLManager(TestCase): + +    def setUp(self): +        super(TestOpenSSLManager, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.subp = patches.enter_context( +            mock.patch.object(azure_helper.util, 'subp')) + +    @mock.patch.object(azure_helper, 'cd', mock.MagicMock()) +    @mock.patch.object(azure_helper.tempfile, 'TemporaryDirectory') +    def test_openssl_manager_creates_a_tmpdir(self, TemporaryDirectory): +        manager = azure_helper.OpenSSLManager() +        self.assertEqual(TemporaryDirectory.return_value, manager.tmpdir) + +    @mock.patch('builtins.open') +    def test_generate_certificate_uses_tmpdir(self, open): +        subp_directory = {} + +        def capture_directory(*args, **kwargs): +            subp_directory['path'] = os.getcwd() + +        self.subp.side_effect = capture_directory +        manager = azure_helper.OpenSSLManager() +        self.assertEqual(manager.tmpdir.name, subp_directory['path']) + + +class TestWALinuxAgentShim(TestCase): + +    def setUp(self): +        super(TestWALinuxAgentShim, self).setUp() +        patches = ExitStack() +        self.addCleanup(patches.close) + +        self.AzureEndpointHttpClient = patches.enter_context( +            mock.patch.object(azure_helper, 'AzureEndpointHttpClient')) +        self.find_endpoint = patches.enter_context( +            mock.patch.object( +                azure_helper.WALinuxAgentShim, 'find_endpoint')) +        self.GoalState = patches.enter_context( +            mock.patch.object(azure_helper, 'GoalState')) +        self.iid_from_shared_config_content = patches.enter_context( +            mock.patch.object(azure_helper, 'iid_from_shared_config_content')) +        self.OpenSSLManager = patches.enter_context( +            mock.patch.object(azure_helper, 'OpenSSLManager')) + +    def test_http_client_uses_certificate(self): +        shim = azure_helper.WALinuxAgentShim() +        self.assertEqual( +            [mock.call(self.OpenSSLManager.return_value.certificate)], +            self.AzureEndpointHttpClient.call_args_list) +        self.assertEqual(self.AzureEndpointHttpClient.return_value, +                         shim.http_client) + +    def test_correct_url_used_for_goalstate(self): +        self.find_endpoint.return_value = 'test_endpoint' +        shim = azure_helper.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        get = self.AzureEndpointHttpClient.return_value.get +        self.assertEqual( +            [mock.call('http://test_endpoint/machine/?comp=goalstate')], +            get.call_args_list) +        self.assertEqual( +            [mock.call(get.return_value.contents, shim.http_client)], +            self.GoalState.call_args_list) + +    def test_certificates_used_to_determine_public_keys(self): +        shim = azure_helper.WALinuxAgentShim() +        data = shim.register_with_azure_and_fetch_data() +        self.assertEqual( +            [mock.call(self.GoalState.return_value.certificates_xml)], +            self.OpenSSLManager.return_value.parse_certificates.call_args_list) +        self.assertEqual( +            self.OpenSSLManager.return_value.parse_certificates.return_value, +            data['public-keys']) + +    def test_absent_certificates_produces_empty_public_keys(self): +        self.GoalState.return_value.certificates_xml = None +        shim = azure_helper.WALinuxAgentShim() +        data = shim.register_with_azure_and_fetch_data() +        self.assertEqual([], data['public-keys']) + +    def test_instance_id_returned_in_data(self): +        shim = azure_helper.WALinuxAgentShim() +        data = shim.register_with_azure_and_fetch_data() +        self.assertEqual( +            [mock.call(self.GoalState.return_value.shared_config_xml)], +            self.iid_from_shared_config_content.call_args_list) +        self.assertEqual(self.iid_from_shared_config_content.return_value, +                         data['instance-id']) + +    def test_correct_url_used_for_report_ready(self): +        self.find_endpoint.return_value = 'test_endpoint' +        shim = azure_helper.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        expected_url = 'http://test_endpoint/machine?comp=health' +        self.assertEqual( +            [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], +            shim.http_client.post.call_args_list) + +    def test_goal_state_values_used_for_report_ready(self): +        self.GoalState.return_value.incarnation = 'TestIncarnation' +        self.GoalState.return_value.container_id = 'TestContainerId' +        self.GoalState.return_value.instance_id = 'TestInstanceId' +        shim = azure_helper.WALinuxAgentShim() +        shim.register_with_azure_and_fetch_data() +        posted_document = shim.http_client.post.call_args[1]['data'] +        self.assertIn('TestIncarnation', posted_document) +        self.assertIn('TestContainerId', posted_document) +        self.assertIn('TestInstanceId', posted_document) | 
