diff options
author | Scott Moser <smoser@ubuntu.com> | 2015-05-15 16:16:14 -0400 |
---|---|---|
committer | Scott Moser <smoser@ubuntu.com> | 2015-05-15 16:16:14 -0400 |
commit | b811fddc03511d7325775a68f076829dcb94349a (patch) | |
tree | c2c27003b48b89499ab2941c7641a5f08cbb1e40 | |
parent | 02da8491a11caf783e952a964da5bc90a618d46f (diff) | |
parent | dad01d2cf14a7e0bdca455040fb5a173775cefdc (diff) | |
download | vyos-cloud-init-b811fddc03511d7325775a68f076829dcb94349a.tar.gz vyos-cloud-init-b811fddc03511d7325775a68f076829dcb94349a.zip |
Azure: remove dependency on walinux-agent
This takes away our dependency on walinux-agent, by providing a builtin
path for doing cloud-init had delegated to it.
Currently the default is to still use the old path, but adding this code
in will allow us to move to the new code path with more confidence.
-rw-r--r-- | ChangeLog | 1 | ||||
-rw-r--r-- | cloudinit/sources/DataSourceAzure.py | 141 | ||||
-rw-r--r-- | cloudinit/sources/helpers/azure.py | 293 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 39 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 439 |
5 files changed, 822 insertions, 91 deletions
@@ -40,6 +40,7 @@ - Azure: do not re-set hostname if user has changed it (LP: #1375252) - Fix exception when running with no arguments on Python 3. [Daniel Watkins] - Centos: detect/expect use of systemd on centos 7. [Brian Rak] + - Azure: remove dependency on walinux-agent [Daniel Watkins] 0.7.6: - open 0.7.6 - Enable vendordata on CloudSigma datasource (LP: #1303986) diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index a19d9ca2..f2388c63 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -29,6 +29,8 @@ from cloudinit import log as logging from cloudinit.settings import PER_ALWAYS from cloudinit import sources from cloudinit import util +from cloudinit.sources.helpers.azure import ( + get_metadata_from_fabric, iid_from_shared_config_content) LOG = logging.getLogger(__name__) @@ -111,6 +113,56 @@ class DataSourceAzureNet(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [seed=%s]" % (root, self.seed) + def get_metadata_from_agent(self): + temp_hostname = self.metadata.get('local-hostname') + hostname_command = self.ds_cfg['hostname_bounce']['hostname_command'] + with temporary_hostname(temp_hostname, self.ds_cfg, + hostname_command=hostname_command) \ + as previous_hostname: + if (previous_hostname is not None + and util.is_true(self.ds_cfg.get('set_hostname'))): + cfg = self.ds_cfg['hostname_bounce'] + try: + perform_hostname_bounce(hostname=temp_hostname, + cfg=cfg, + prev_hostname=previous_hostname) + except Exception as e: + LOG.warn("Failed publishing hostname: %s", e) + util.logexc(LOG, "handling set_hostname failed") + + try: + invoke_agent(self.ds_cfg['agent_command']) + except util.ProcessExecutionError: + # claim the datasource even if the command failed + util.logexc(LOG, "agent command '%s' failed.", + self.ds_cfg['agent_command']) + + ddir = self.ds_cfg['data_dir'] + shcfgxml = os.path.join(ddir, "SharedConfig.xml") + wait_for = [shcfgxml] + + fp_files = [] + for pk in self.cfg.get('_pubkeys', []): + bname = str(pk['fingerprint'] + ".crt") + fp_files += [os.path.join(ddir, bname)] + + missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", + func=wait_for_files, + args=(wait_for + fp_files,)) + if len(missing): + LOG.warn("Did not find files, but going on: %s", missing) + + metadata = {} + if shcfgxml in missing: + LOG.warn("SharedConfig.xml missing, using static instance-id") + else: + try: + metadata['instance-id'] = iid_from_shared_config(shcfgxml) + except ValueError as e: + LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) + metadata['public-keys'] = pubkeys_from_crt_files(fp_files) + return metadata + def get_data(self): # azure removes/ejects the cdrom containing the ovf-env.xml # file on reboot. So, in order to successfully reboot we @@ -163,8 +215,6 @@ class DataSourceAzureNet(sources.DataSource): # now update ds_cfg to reflect contents pass in config user_ds_cfg = util.get_cfg_by_path(self.cfg, DS_CFG_PATH, {}) self.ds_cfg = util.mergemanydict([user_ds_cfg, self.ds_cfg]) - mycfg = self.ds_cfg - ddir = mycfg['data_dir'] if found != ddir: cached_ovfenv = util.load_file( @@ -185,53 +235,18 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) - temp_hostname = self.metadata.get('local-hostname') - hostname_command = mycfg['hostname_bounce']['hostname_command'] - with temporary_hostname(temp_hostname, mycfg, - hostname_command=hostname_command) \ - as previous_hostname: - if (previous_hostname is not None - and util.is_true(mycfg.get('set_hostname'))): - cfg = mycfg['hostname_bounce'] - try: - perform_hostname_bounce(hostname=temp_hostname, - cfg=cfg, - prev_hostname=previous_hostname) - except Exception as e: - LOG.warn("Failed publishing hostname: %s", e) - util.logexc(LOG, "handling set_hostname failed") - - try: - invoke_agent(mycfg['agent_command']) - except util.ProcessExecutionError: - # claim the datasource even if the command failed - util.logexc(LOG, "agent command '%s' failed.", - mycfg['agent_command']) - - shcfgxml = os.path.join(ddir, "SharedConfig.xml") - wait_for = [shcfgxml] - - fp_files = [] - for pk in self.cfg.get('_pubkeys', []): - bname = str(pk['fingerprint'] + ".crt") - fp_files += [os.path.join(ddir, bname)] - - missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", - func=wait_for_files, - args=(wait_for + fp_files,)) - if len(missing): - LOG.warn("Did not find files, but going on: %s", missing) - - if shcfgxml in missing: - LOG.warn("SharedConfig.xml missing, using static instance-id") + if self.ds_cfg['agent_command'] == '__builtin__': + metadata_func = get_metadata_from_fabric else: - try: - self.metadata['instance-id'] = iid_from_shared_config(shcfgxml) - except ValueError as e: - LOG.warn("failed to get instance id in %s: %s", shcfgxml, e) + metadata_func = self.get_metadata_from_agent + try: + fabric_data = metadata_func() + except Exception as exc: + LOG.info("Error communicating with Azure fabric; assume we aren't" + " on Azure.", exc_info=True) + return False - pubkeys = pubkeys_from_crt_files(fp_files) - self.metadata['public-keys'] = pubkeys + self.metadata.update(fabric_data) found_ephemeral = find_ephemeral_disk() if found_ephemeral: @@ -363,10 +378,11 @@ def perform_hostname_bounce(hostname, cfg, prev_hostname): 'env': env}) -def crtfile_to_pubkey(fname): +def crtfile_to_pubkey(fname, data=None): pipeline = ('openssl x509 -noout -pubkey < "$0" |' 'ssh-keygen -i -m PKCS8 -f /dev/stdin') - (out, _err) = util.subp(['sh', '-c', pipeline, fname], capture=True) + (out, _err) = util.subp(['sh', '-c', pipeline, fname], + capture=True, data=data) return out.rstrip() @@ -476,20 +492,6 @@ def load_azure_ovf_pubkeys(sshnode): return found -def single_node_at_path(node, pathlist): - curnode = node - for tok in pathlist: - results = find_child(curnode, lambda n: n.localName == tok) - if len(results) == 0: - raise ValueError("missing %s token in %s" % (tok, str(pathlist))) - if len(results) > 1: - raise ValueError("found %s nodes of type %s looking for %s" % - (len(results), tok, str(pathlist))) - curnode = results[0] - - return curnode - - def read_azure_ovf(contents): try: dom = minidom.parseString(contents) @@ -620,19 +622,6 @@ def iid_from_shared_config(path): return iid_from_shared_config_content(content) -def iid_from_shared_config_content(content): - """ - find INSTANCE_ID in: - <?xml version="1.0" encoding="utf-8"?> - <SharedConfig version="1.0.0.0" goalStateIncarnation="1"> - <Deployment name="INSTANCE_ID" guid="{...}" incarnation="0"> - <Service name="..." guid="{00000000-0000-0000-0000-000000000000}" /> - """ - dom = minidom.parseString(content) - depnode = single_node_at_path(dom, ["SharedConfig", "Deployment"]) - return depnode.attributes.get('name').value - - class BrokenAzureDataSource(Exception): pass diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py new file mode 100644 index 00000000..281d733e --- /dev/null +++ b/cloudinit/sources/helpers/azure.py @@ -0,0 +1,293 @@ +import logging +import os +import re +import socket +import struct +import tempfile +import time +from contextlib import contextmanager +from xml.etree import ElementTree + +from cloudinit import util + + +LOG = logging.getLogger(__name__) + + +@contextmanager +def cd(newdir): + prevdir = os.getcwd() + os.chdir(os.path.expanduser(newdir)) + try: + yield + finally: + os.chdir(prevdir) + + +class AzureEndpointHttpClient(object): + + headers = { + 'x-ms-agent-name': 'WALinuxAgent', + 'x-ms-version': '2012-11-30', + } + + def __init__(self, certificate): + self.extra_secure_headers = { + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": certificate, + } + + def get(self, url, secure=False): + headers = self.headers + if secure: + headers = self.headers.copy() + headers.update(self.extra_secure_headers) + return util.read_file_or_url(url, headers=headers) + + def post(self, url, data=None, extra_headers=None): + headers = self.headers + if extra_headers is not None: + headers = self.headers.copy() + headers.update(extra_headers) + return util.read_file_or_url(url, data=data, headers=headers) + + +class GoalState(object): + + def __init__(self, xml, http_client): + self.http_client = http_client + self.root = ElementTree.fromstring(xml) + self._certificates_xml = None + + def _text_from_xpath(self, xpath): + element = self.root.find(xpath) + if element is not None: + return element.text + return None + + @property + def container_id(self): + return self._text_from_xpath('./Container/ContainerId') + + @property + def incarnation(self): + return self._text_from_xpath('./Incarnation') + + @property + def instance_id(self): + return self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance/InstanceId') + + @property + def shared_config_xml(self): + url = self._text_from_xpath('./Container/RoleInstanceList/RoleInstance' + '/Configuration/SharedConfig') + return self.http_client.get(url).contents + + @property + def certificates_xml(self): + if self._certificates_xml is None: + url = self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance' + '/Configuration/Certificates') + if url is not None: + self._certificates_xml = self.http_client.get( + url, secure=True).contents + return self._certificates_xml + + +class OpenSSLManager(object): + + certificate_names = { + 'private_key': 'TransportPrivate.pem', + 'certificate': 'TransportCert.pem', + } + + def __init__(self): + self.tmpdir = tempfile.mkdtemp() + self.certificate = None + self.generate_certificate() + + def clean_up(self): + util.del_dir(self.tmpdir) + + def generate_certificate(self): + LOG.debug('Generating certificate for communication with fabric...') + if self.certificate is not None: + LOG.debug('Certificate already generated.') + return + with cd(self.tmpdir): + util.subp([ + 'openssl', 'req', '-x509', '-nodes', '-subj', + '/CN=LinuxTransport', '-days', '32768', '-newkey', 'rsa:2048', + '-keyout', self.certificate_names['private_key'], + '-out', self.certificate_names['certificate'], + ]) + certificate = '' + for line in open(self.certificate_names['certificate']): + if "CERTIFICATE" not in line: + certificate += line.rstrip() + self.certificate = certificate + LOG.debug('New certificate generated.') + + def parse_certificates(self, certificates_xml): + tag = ElementTree.fromstring(certificates_xml).find( + './/Data') + certificates_content = tag.text + lines = [ + b'MIME-Version: 1.0', + b'Content-Disposition: attachment; filename="Certificates.p7m"', + b'Content-Type: application/x-pkcs7-mime; name="Certificates.p7m"', + b'Content-Transfer-Encoding: base64', + b'', + certificates_content.encode('utf-8'), + ] + with cd(self.tmpdir): + with open('Certificates.p7m', 'wb') as f: + f.write(b'\n'.join(lines)) + out, _ = util.subp( + 'openssl cms -decrypt -in Certificates.p7m -inkey' + ' {private_key} -recip {certificate} | openssl pkcs12 -nodes' + ' -password pass:'.format(**self.certificate_names), + shell=True) + private_keys, certificates = [], [] + current = [] + for line in out.splitlines(): + current.append(line) + if re.match(r'[-]+END .*?KEY[-]+$', line): + private_keys.append('\n'.join(current)) + current = [] + elif re.match(r'[-]+END .*?CERTIFICATE[-]+$', line): + certificates.append('\n'.join(current)) + current = [] + keys = [] + for certificate in certificates: + with cd(self.tmpdir): + public_key, _ = util.subp( + 'openssl x509 -noout -pubkey |' + 'ssh-keygen -i -m PKCS8 -f /dev/stdin', + data=certificate, + shell=True) + keys.append(public_key) + return keys + + +def iid_from_shared_config_content(content): + """ + find INSTANCE_ID in: + <?xml version="1.0" encoding="utf-8"?> + <SharedConfig version="1.0.0.0" goalStateIncarnation="1"> + <Deployment name="INSTANCE_ID" guid="{...}" incarnation="0"> + <Service name="..." guid="{00000000-0000-0000-0000-000000000000}"/> + """ + root = ElementTree.fromstring(content) + depnode = root.find('Deployment') + return depnode.get('name') + + +class WALinuxAgentShim(object): + + REPORT_READY_XML_TEMPLATE = '\n'.join([ + '<?xml version="1.0" encoding="utf-8"?>', + '<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"' + ' xmlns:xsd="http://www.w3.org/2001/XMLSchema">', + ' <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>', + ' <Container>', + ' <ContainerId>{container_id}</ContainerId>', + ' <RoleInstanceList>', + ' <Role>', + ' <InstanceId>{instance_id}</InstanceId>', + ' <Health>', + ' <State>Ready</State>', + ' </Health>', + ' </Role>', + ' </RoleInstanceList>', + ' </Container>', + '</Health>']) + + def __init__(self): + LOG.debug('WALinuxAgentShim instantiated...') + self.endpoint = self.find_endpoint() + self.openssl_manager = None + self.values = {} + + def clean_up(self): + if self.openssl_manager is not None: + self.openssl_manager.clean_up() + + @staticmethod + def find_endpoint(): + LOG.debug('Finding Azure endpoint...') + content = util.load_file('/var/lib/dhcp/dhclient.eth0.leases') + value = None + for line in content.splitlines(): + if 'unknown-245' in line: + value = line.strip(' ').split(' ', 2)[-1].strip(';\n"') + if value is None: + raise Exception('No endpoint found in DHCP config.') + if ':' in value: + hex_string = '' + for hex_pair in value.split(':'): + if len(hex_pair) == 1: + hex_pair = '0' + hex_pair + hex_string += hex_pair + value = struct.pack('>L', int(hex_string.replace(':', ''), 16)) + else: + value = value.encode('utf-8') + endpoint_ip_address = socket.inet_ntoa(value) + LOG.debug('Azure endpoint found at %s', endpoint_ip_address) + return endpoint_ip_address + + def register_with_azure_and_fetch_data(self): + self.openssl_manager = OpenSSLManager() + http_client = AzureEndpointHttpClient(self.openssl_manager.certificate) + LOG.info('Registering with Azure...') + attempts = 0 + while True: + try: + response = http_client.get( + 'http://{0}/machine/?comp=goalstate'.format(self.endpoint)) + except Exception: + if attempts < 10: + time.sleep(attempts + 1) + else: + raise + else: + break + attempts += 1 + LOG.debug('Successfully fetched GoalState XML.') + goal_state = GoalState(response.contents, http_client) + public_keys = [] + if goal_state.certificates_xml is not None: + LOG.debug('Certificate XML found; parsing out public keys.') + public_keys = self.openssl_manager.parse_certificates( + goal_state.certificates_xml) + data = { + 'instance-id': iid_from_shared_config_content( + goal_state.shared_config_xml), + 'public-keys': public_keys, + } + self._report_ready(goal_state, http_client) + return data + + def _report_ready(self, goal_state, http_client): + LOG.debug('Reporting ready to Azure fabric.') + document = self.REPORT_READY_XML_TEMPLATE.format( + incarnation=goal_state.incarnation, + container_id=goal_state.container_id, + instance_id=goal_state.instance_id, + ) + http_client.post( + "http://{0}/machine?comp=health".format(self.endpoint), + data=document, + extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, + ) + LOG.info('Reported ready to Azure fabric.') + + +def get_metadata_from_fabric(): + shim = WALinuxAgentShim() + try: + return shim.register_with_azure_and_fetch_data() + finally: + shim.clean_up() diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 7e789853..c72dc801 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -18,7 +18,6 @@ import stat import yaml import shutil import tempfile -import unittest def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): @@ -123,6 +122,11 @@ class TestAzureDataSource(TestCase): mod = DataSourceAzure mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d + self.get_metadata_from_fabric = mock.MagicMock(return_value={ + 'instance-id': 'i-my-azure-id', + 'public-keys': [], + }) + self.apply_patches([ (mod, 'list_possible_azure_ds_devs', dsdevs), (mod, 'invoke_agent', _invoke_agent), @@ -132,7 +136,8 @@ class TestAzureDataSource(TestCase): (mod, 'perform_hostname_bounce', mock.MagicMock()), (mod, 'get_hostname', mock.MagicMock()), (mod, 'set_hostname', mock.MagicMock()), - ]) + (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric), + ]) dsrc = mod.DataSourceAzureNet( data.get('sys_cfg', {}), distro=None, paths=self.paths) @@ -382,6 +387,20 @@ class TestAzureDataSource(TestCase): self.assertEqual(new_ovfenv, load_file(os.path.join(self.waagent_d, 'ovf-env.xml'))) + def test_exception_fetching_fabric_data_doesnt_propagate(self): + ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + ds.ds_cfg['agent_command'] = '__builtin__' + self.get_metadata_from_fabric.side_effect = Exception + self.assertFalse(ds.get_data()) + + def test_fabric_data_included_in_metadata(self): + ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()}) + ds.ds_cfg['agent_command'] = '__builtin__' + self.get_metadata_from_fabric.return_value = {'test': 'value'} + ret = ds.get_data() + self.assertTrue(ret) + self.assertEqual('value', ds.metadata['test']) + class TestAzureBounce(TestCase): @@ -402,6 +421,9 @@ class TestAzureBounce(TestCase): self.patches.enter_context( mock.patch.object(DataSourceAzure, 'find_ephemeral_part', mock.MagicMock(return_value=None))) + self.patches.enter_context( + mock.patch.object(DataSourceAzure, 'get_metadata_from_fabric', + mock.MagicMock(return_value={}))) def setUp(self): super(TestAzureBounce, self).setUp() @@ -566,16 +588,3 @@ class TestReadAzureOvf(TestCase): for mypk in mypklist: self.assertIn(mypk, cfg['_pubkeys']) - -class TestReadAzureSharedConfig(unittest.TestCase): - def test_valid_content(self): - xml = """<?xml version="1.0" encoding="utf-8"?> - <SharedConfig> - <Deployment name="MY_INSTANCE_ID"> - <Service name="myservice"/> - <ServiceInstance name="INSTANCE_ID.0" guid="{abcd-uuid}" /> - </Deployment> - <Incarnation number="1"/> - </SharedConfig>""" - ret = DataSourceAzure.iid_from_shared_config_content(xml) - self.assertEqual("MY_INSTANCE_ID", ret) diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py new file mode 100644 index 00000000..23bc997c --- /dev/null +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -0,0 +1,439 @@ +import os +import struct +import unittest + +from cloudinit.sources.helpers import azure as azure_helper +from ..helpers import TestCase + +try: + from unittest import mock +except ImportError: + import mock + +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + + +GOAL_STATE_TEMPLATE = """\ +<?xml version="1.0" encoding="utf-8"?> +<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="goalstate10.xsd"> + <Version>2012-11-30</Version> + <Incarnation>{incarnation}</Incarnation> + <Machine> + <ExpectedState>Started</ExpectedState> + <StopRolesDeadlineHint>300000</StopRolesDeadlineHint> + <LBProbePorts> + <Port>16001</Port> + </LBProbePorts> + <ExpectHealthReport>FALSE</ExpectHealthReport> + </Machine> + <Container> + <ContainerId>{container_id}</ContainerId> + <RoleInstanceList> + <RoleInstance> + <InstanceId>{instance_id}</InstanceId> + <State>Started</State> + <Configuration> + <HostingEnvironmentConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1</HostingEnvironmentConfig> + <SharedConfig>{shared_config_url}</SharedConfig> + <ExtensionsConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1</ExtensionsConfig> + <FullConfig>http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1</FullConfig> + <Certificates>{certificates_url}</Certificates> + <ConfigName>68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml</ConfigName> + </Configuration> + </RoleInstance> + </RoleInstanceList> + </Container> +</GoalState> +""" + + +class TestReadAzureSharedConfig(unittest.TestCase): + + def test_valid_content(self): + xml = """<?xml version="1.0" encoding="utf-8"?> + <SharedConfig> + <Deployment name="MY_INSTANCE_ID"> + <Service name="myservice"/> + <ServiceInstance name="INSTANCE_ID.0" guid="{abcd-uuid}" /> + </Deployment> + <Incarnation number="1"/> + </SharedConfig>""" + ret = azure_helper.iid_from_shared_config_content(xml) + self.assertEqual("MY_INSTANCE_ID", ret) + + +class TestFindEndpoint(TestCase): + + def setUp(self): + super(TestFindEndpoint, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.load_file = patches.enter_context( + mock.patch.object(azure_helper.util, 'load_file')) + + def test_missing_file(self): + self.load_file.side_effect = IOError + self.assertRaises(IOError, + azure_helper.WALinuxAgentShim.find_endpoint) + + def test_missing_special_azure_line(self): + self.load_file.return_value = '' + self.assertRaises(Exception, + azure_helper.WALinuxAgentShim.find_endpoint) + + def _build_lease_content(self, ip_address, use_hex=True): + ip_address_repr = ':'.join( + [hex(int(part)).replace('0x', '') + for part in ip_address.split('.')]) + if not use_hex: + ip_address_repr = struct.pack( + '>L', int(ip_address_repr.replace(':', ''), 16)) + ip_address_repr = '"{0}"'.format(ip_address_repr.decode('utf-8')) + return '\n'.join([ + 'lease {', + ' interface "eth0";', + ' option unknown-245 {0};'.format(ip_address_repr), + '}']) + + def test_hex_string(self): + ip_address = '98.76.54.32' + file_content = self._build_lease_content(ip_address) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + + def test_hex_string_with_single_character_part(self): + ip_address = '4.3.2.1' + file_content = self._build_lease_content(ip_address) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + + def test_packed_string(self): + ip_address = '98.76.54.32' + file_content = self._build_lease_content(ip_address, use_hex=False) + self.load_file.return_value = file_content + self.assertEqual(ip_address, + azure_helper.WALinuxAgentShim.find_endpoint()) + + def test_latest_lease_used(self): + ip_addresses = ['4.3.2.1', '98.76.54.32'] + file_content = '\n'.join([self._build_lease_content(ip_address) + for ip_address in ip_addresses]) + self.load_file.return_value = file_content + self.assertEqual(ip_addresses[-1], + azure_helper.WALinuxAgentShim.find_endpoint()) + + +class TestGoalStateParsing(TestCase): + + default_parameters = { + 'incarnation': 1, + 'container_id': 'MyContainerId', + 'instance_id': 'MyInstanceId', + 'shared_config_url': 'MySharedConfigUrl', + 'certificates_url': 'MyCertificatesUrl', + } + + def _get_goal_state(self, http_client=None, **kwargs): + if http_client is None: + http_client = mock.MagicMock() + parameters = self.default_parameters.copy() + parameters.update(kwargs) + xml = GOAL_STATE_TEMPLATE.format(**parameters) + if parameters['certificates_url'] is None: + new_xml_lines = [] + for line in xml.splitlines(): + if 'Certificates' in line: + continue + new_xml_lines.append(line) + xml = '\n'.join(new_xml_lines) + return azure_helper.GoalState(xml, http_client) + + def test_incarnation_parsed_correctly(self): + incarnation = '123' + goal_state = self._get_goal_state(incarnation=incarnation) + self.assertEqual(incarnation, goal_state.incarnation) + + def test_container_id_parsed_correctly(self): + container_id = 'TestContainerId' + goal_state = self._get_goal_state(container_id=container_id) + self.assertEqual(container_id, goal_state.container_id) + + def test_instance_id_parsed_correctly(self): + instance_id = 'TestInstanceId' + goal_state = self._get_goal_state(instance_id=instance_id) + self.assertEqual(instance_id, goal_state.instance_id) + + def test_shared_config_xml_parsed_and_fetched_correctly(self): + http_client = mock.MagicMock() + shared_config_url = 'TestSharedConfigUrl' + goal_state = self._get_goal_state( + http_client=http_client, shared_config_url=shared_config_url) + shared_config_xml = goal_state.shared_config_xml + self.assertEqual(1, http_client.get.call_count) + self.assertEqual(shared_config_url, http_client.get.call_args[0][0]) + self.assertEqual(http_client.get.return_value.contents, + shared_config_xml) + + def test_certificates_xml_parsed_and_fetched_correctly(self): + http_client = mock.MagicMock() + certificates_url = 'TestSharedConfigUrl' + goal_state = self._get_goal_state( + http_client=http_client, certificates_url=certificates_url) + certificates_xml = goal_state.certificates_xml + self.assertEqual(1, http_client.get.call_count) + self.assertEqual(certificates_url, http_client.get.call_args[0][0]) + self.assertTrue(http_client.get.call_args[1].get('secure', False)) + self.assertEqual(http_client.get.return_value.contents, + certificates_xml) + + def test_missing_certificates_skips_http_get(self): + http_client = mock.MagicMock() + goal_state = self._get_goal_state( + http_client=http_client, certificates_url=None) + certificates_xml = goal_state.certificates_xml + self.assertEqual(0, http_client.get.call_count) + self.assertIsNone(certificates_xml) + + +class TestAzureEndpointHttpClient(TestCase): + + regular_headers = { + 'x-ms-agent-name': 'WALinuxAgent', + 'x-ms-version': '2012-11-30', + } + + def setUp(self): + super(TestAzureEndpointHttpClient, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.read_file_or_url = patches.enter_context( + mock.patch.object(azure_helper.util, 'read_file_or_url')) + + def test_non_secure_get(self): + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + url = 'MyTestUrl' + response = client.get(url, secure=False) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(mock.call(url, headers=self.regular_headers), + self.read_file_or_url.call_args) + + def test_secure_get(self): + url = 'MyTestUrl' + certificate = mock.MagicMock() + expected_headers = self.regular_headers.copy() + expected_headers.update({ + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": certificate, + }) + client = azure_helper.AzureEndpointHttpClient(certificate) + response = client.get(url, secure=True) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(mock.call(url, headers=expected_headers), + self.read_file_or_url.call_args) + + def test_post(self): + data = mock.MagicMock() + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + response = client.post(url, data=data) + self.assertEqual(1, self.read_file_or_url.call_count) + self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual( + mock.call(url, data=data, headers=self.regular_headers), + self.read_file_or_url.call_args) + + def test_post_with_extra_headers(self): + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + extra_headers = {'test': 'header'} + client.post(url, extra_headers=extra_headers) + self.assertEqual(1, self.read_file_or_url.call_count) + expected_headers = self.regular_headers.copy() + expected_headers.update(extra_headers) + self.assertEqual( + mock.call(mock.ANY, data=mock.ANY, headers=expected_headers), + self.read_file_or_url.call_args) + + +class TestOpenSSLManager(TestCase): + + def setUp(self): + super(TestOpenSSLManager, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.subp = patches.enter_context( + mock.patch.object(azure_helper.util, 'subp')) + try: + self.open = patches.enter_context( + mock.patch('__builtin__.open')) + except ImportError: + self.open = patches.enter_context( + mock.patch('builtins.open')) + + @mock.patch.object(azure_helper, 'cd', mock.MagicMock()) + @mock.patch.object(azure_helper.tempfile, 'mkdtemp') + def test_openssl_manager_creates_a_tmpdir(self, mkdtemp): + manager = azure_helper.OpenSSLManager() + self.assertEqual(mkdtemp.return_value, manager.tmpdir) + + def test_generate_certificate_uses_tmpdir(self): + subp_directory = {} + + def capture_directory(*args, **kwargs): + subp_directory['path'] = os.getcwd() + + self.subp.side_effect = capture_directory + manager = azure_helper.OpenSSLManager() + self.assertEqual(manager.tmpdir, subp_directory['path']) + + @mock.patch.object(azure_helper, 'cd', mock.MagicMock()) + @mock.patch.object(azure_helper.tempfile, 'mkdtemp', mock.MagicMock()) + @mock.patch.object(azure_helper.util, 'del_dir') + def test_clean_up(self, del_dir): + manager = azure_helper.OpenSSLManager() + manager.clean_up() + self.assertEqual([mock.call(manager.tmpdir)], del_dir.call_args_list) + + +class TestWALinuxAgentShim(TestCase): + + def setUp(self): + super(TestWALinuxAgentShim, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + self.AzureEndpointHttpClient = patches.enter_context( + mock.patch.object(azure_helper, 'AzureEndpointHttpClient')) + self.find_endpoint = patches.enter_context( + mock.patch.object( + azure_helper.WALinuxAgentShim, 'find_endpoint')) + self.GoalState = patches.enter_context( + mock.patch.object(azure_helper, 'GoalState')) + self.iid_from_shared_config_content = patches.enter_context( + mock.patch.object(azure_helper, 'iid_from_shared_config_content')) + self.OpenSSLManager = patches.enter_context( + mock.patch.object(azure_helper, 'OpenSSLManager')) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + + def test_http_client_uses_certificate(self): + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.OpenSSLManager.return_value.certificate)], + self.AzureEndpointHttpClient.call_args_list) + + def test_correct_url_used_for_goalstate(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + get = self.AzureEndpointHttpClient.return_value.get + self.assertEqual( + [mock.call('http://test_endpoint/machine/?comp=goalstate')], + get.call_args_list) + self.assertEqual( + [mock.call(get.return_value.contents, + self.AzureEndpointHttpClient.return_value)], + self.GoalState.call_args_list) + + def test_certificates_used_to_determine_public_keys(self): + shim = azure_helper.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.certificates_xml)], + self.OpenSSLManager.return_value.parse_certificates.call_args_list) + self.assertEqual( + self.OpenSSLManager.return_value.parse_certificates.return_value, + data['public-keys']) + + def test_absent_certificates_produces_empty_public_keys(self): + self.GoalState.return_value.certificates_xml = None + shim = azure_helper.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual([], data['public-keys']) + + def test_instance_id_returned_in_data(self): + shim = azure_helper.WALinuxAgentShim() + data = shim.register_with_azure_and_fetch_data() + self.assertEqual( + [mock.call(self.GoalState.return_value.shared_config_xml)], + self.iid_from_shared_config_content.call_args_list) + self.assertEqual(self.iid_from_shared_config_content.return_value, + data['instance-id']) + + def test_correct_url_used_for_report_ready(self): + self.find_endpoint.return_value = 'test_endpoint' + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + expected_url = 'http://test_endpoint/machine?comp=health' + self.assertEqual( + [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], + self.AzureEndpointHttpClient.return_value.post.call_args_list) + + def test_goal_state_values_used_for_report_ready(self): + self.GoalState.return_value.incarnation = 'TestIncarnation' + self.GoalState.return_value.container_id = 'TestContainerId' + self.GoalState.return_value.instance_id = 'TestInstanceId' + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + posted_document = ( + self.AzureEndpointHttpClient.return_value.post.call_args[1]['data'] + ) + self.assertIn('TestIncarnation', posted_document) + self.assertIn('TestContainerId', posted_document) + self.assertIn('TestInstanceId', posted_document) + + def test_clean_up_can_be_called_at_any_time(self): + shim = azure_helper.WALinuxAgentShim() + shim.clean_up() + + def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self): + shim = azure_helper.WALinuxAgentShim() + shim.register_with_azure_and_fetch_data() + shim.clean_up() + self.assertEqual( + 1, self.OpenSSLManager.return_value.clean_up.call_count) + + def test_failure_to_fetch_goalstate_bubbles_up(self): + class SentinelException(Exception): + pass + self.AzureEndpointHttpClient.return_value.get.side_effect = ( + SentinelException) + shim = azure_helper.WALinuxAgentShim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) + + +class TestGetMetadataFromFabric(TestCase): + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_data_from_shim_returned(self, shim): + ret = azure_helper.get_metadata_from_fabric() + self.assertEqual( + shim.return_value.register_with_azure_and_fetch_data.return_value, + ret) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_success_calls_clean_up(self, shim): + azure_helper.get_metadata_from_fabric() + self.assertEqual(1, shim.return_value.clean_up.call_count) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_failure_in_registration_calls_clean_up(self, shim): + class SentinelException(Exception): + pass + shim.return_value.register_with_azure_and_fetch_data.side_effect = ( + SentinelException) + self.assertRaises(SentinelException, + azure_helper.get_metadata_from_fabric) + self.assertEqual(1, shim.return_value.clean_up.call_count) |