from cloudinit import helpers
from cloudinit.util import b64e, decode_binary, load_file
from cloudinit.sources import DataSourceAzure
from ..helpers import TestCase, populate_dir
try:
from unittest import mock
except ImportError:
import mock
try:
from contextlib import ExitStack
except ImportError:
from contextlib2 import ExitStack
import crypt
import os
import stat
import struct
import yaml
import shutil
import tempfile
import unittest
from cloudinit import url_helper
# Template for the goal-state document that TestGoalStateParsing feeds to
# DataSourceAzure.GoalState.  The {incarnation}, {container_id},
# {instance_id}, {shared_config_url} and {certificates_url} placeholders are
# filled in with str.format().
# NOTE(review): the literal holds bare values with no element markup around
# them; the markup looks like it was lost in extraction -- confirm against the
# upstream file before relying on this text.
GOAL_STATE_TEMPLATE = """\
2012-11-30
{incarnation}
Started
300000
16001
FALSE
{container_id}
{instance_id}
Started
http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=hostingEnvironmentConfig&incarnation=1
{shared_config_url}
http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=extensionsConfig&incarnation=1
http://100.86.192.70:80/machine/46504ebc-f968-4f23-b9aa-cd2b3e4d470c/68ce47b32ea94952be7b20951c383628.utl%2Dtrusty%2D%2D292258?comp=config&type=fullConfig&incarnation=1
{certificates_url}
68ce47b32ea94952be7b20951c383628.0.68ce47b32ea94952be7b20951c383628.0.utl-trusty--292258.1.xml
"""
def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None):
    """Assemble the text of an ovf-env document for the Azure tests.

    :param data: mapping of element name to value.  A value is either plain
        text, or a dict whose 'text' entry becomes the element body and whose
        remaining entries are rendered as ``name='value'`` attributes.
        Defaults to ``{'HostName': 'FOOHOST'}``.
    :param pubkeys: iterable of ``(fingerprint, path)`` pairs.  Nothing is
        emitted for it when it is empty or None.
    :param userdata: optional user-data; it is passed through b64e() before
        being appended.  Nothing is emitted when it is falsy.
    :return: the assembled document as a single string.
    """
    # NOTE(review): apart from the per-element line built in the loop, the
    # literals in this function carry no element markup.  That markup looks
    # like it was lost in extraction -- confirm against the upstream file.
    if data is None:
        data = {'HostName': 'FOOHOST'}
    if pubkeys is None:
        pubkeys = {}

    header = """
1.0
LinuxProvisioningConfiguration
"""
    footer = """
1.0
kms.core.windows.net
false
"""

    parts = [header]
    for key, dval in data.items():
        if isinstance(dval, dict):
            val = dval.get('text')
            attrs = ' ' + ' '.join(
                "%s='%s'" % (k, v) for k, v in dval.items() if k != 'text')
        else:
            val = dval
            attrs = ""
        # Fix: the closing tag used to be rendered as "key>" (missing "</"),
        # which produced "<key>valuekey>" instead of "<key>value</key>".
        parts.append("<%s%s>%s</%s>\n" % (key, attrs, val, key))

    if userdata:
        parts.append("%s\n" % (b64e(userdata)))

    if pubkeys:
        parts.append("\n")
        for fp, path in pubkeys:
            parts.append(" ")
            parts.append("%s%s" % (fp, path))
            parts.append("\n")

    parts.append(footer)
    return "".join(parts)
class TestAzureDataSource(TestCase):
    """Exercise DataSourceAzureNet.get_data() against a seeded ovf-env.xml.

    Each test runs inside its own temporary directory, and _get_ds() swaps
    several DataSourceAzure module-level helpers for local fakes that record
    how they were called.
    """

    def setUp(self):
        """Create the scratch directory, the Paths object and the patch stack."""
        # Fix: the parent setUp() was previously invoked twice (first and
        # last statement of this method); it is now called exactly once.
        super(TestAzureDataSource, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp)

        # patch cloud_dir, so our 'seed_dir' is guaranteed empty
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')

        self.patches = ExitStack()
        self.addCleanup(self.patches.close)

    def apply_patches(self, patches):
        """Activate mock.patch.object for every (module, name, new) triple.

        The patches are entered on self.patches, so they are reverted when
        that ExitStack is closed during cleanup.
        """
        for module, name, new in patches:
            self.patches.enter_context(mock.patch.object(module, name, new))

    def _get_ds(self, data):
        """Return a DataSourceAzureNet wired up with recording fakes.

        :param data: dict that drives and records the run.  Read from it:
            'dsdevs' (devices reported by the fake device lister, default
            []), 'ovfcontent' (written to the seed dir as ovf-env.xml when
            not None) and 'sys_cfg' (passed to the datasource, default {}).
            Written to it by the fakes: 'agent_invoked', 'waited',
            'pubkey_files' and 'iid_from_shared_cfg'.
        """
        def dsdevs():
            return data.get('dsdevs', [])

        def _invoke_agent(cmd):
            data['agent_invoked'] = cmd

        def _wait_for_files(flist, _maxwait=None, _naplen=None):
            data['waited'] = flist
            return []

        def _pubkeys_from_crt_files(flist):
            data['pubkey_files'] = flist
            return ["pubkey_from: %s" % f for f in flist]

        def _iid_from_shared_config(path):
            data['iid_from_shared_cfg'] = path
            return 'i-my-azure-id'

        if data.get('ovfcontent') is not None:
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
                         {'ovf-env.xml': data['ovfcontent']})

        mod = DataSourceAzure
        # Fix: the module-global config used to be assigned directly and was
        # never restored, leaving it pointing at a deleted temp dir after the
        # test.  patch.dict puts the previous contents back on cleanup.
        self.patches.enter_context(mock.patch.dict(
            mod.BUILTIN_DS_CONFIG, {'data_dir': self.waagent_d}))

        self.apply_patches([
            (mod, 'list_possible_azure_ds_devs', dsdevs),
            (mod, 'invoke_agent', _invoke_agent),
            (mod, 'wait_for_files', _wait_for_files),
            (mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),
            (mod, 'iid_from_shared_config', _iid_from_shared_config),
            (mod, 'perform_hostname_bounce', mock.MagicMock()),
            (mod, 'get_hostname', mock.MagicMock()),
            (mod, 'set_hostname', mock.MagicMock()),
        ])

        dsrc = mod.DataSourceAzureNet(
            data.get('sys_cfg', {}), distro=None, paths=self.paths)
        return dsrc

    def test_basic_seed_dir(self):
        """A seeded ovf-env.xml yields hostname, instance-id and a cached copy."""
        odata = {'HostName': "myhost", 'UserName': "myuser"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
                'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, "")
        self.assertEqual(dsrc.metadata['local-hostname'], odata['HostName'])
        self.assertTrue(os.path.isfile(
            os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertEqual(dsrc.metadata['instance-id'], 'i-my-azure-id')

    def test_waagent_d_has_0700_perms(self):
        # we expect /var/lib/waagent to be created 0700
        dsrc = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(os.path.isdir(self.waagent_d))
        self.assertEqual(stat.S_IMODE(os.stat(self.waagent_d).st_mode), 0o700)

    def test_user_cfg_set_agent_command_plain(self):
        # set dscfg in via plaintext
        # we must have friendly-to-xml formatted plaintext in yaml_cfg
        # not all plaintext is expected to work.
        yaml_cfg = "{agent_command: my_command}\n"
        cfg = yaml.safe_load(yaml_cfg)
        odata = {'HostName': "myhost", 'UserName': "myuser",
                 'dscfg': {'text': yaml_cfg, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_user_cfg_set_agent_command(self):
        # set dscfg in via base64 encoded yaml
        cfg = {'agent_command': "my_command"}
        odata = {'HostName': "myhost", 'UserName': "myuser",
                 'dscfg': {'text': b64e(yaml.dump(cfg)),
                           'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], cfg['agent_command'])

    def test_sys_cfg_set_agent_command(self):
        """agent_command given through sys_cfg is the command that is invoked."""
        sys_cfg = {'datasource': {'Azure': {'agent_command': '_COMMAND'}}}
        data = {'ovfcontent': construct_valid_ovf_env(data={}),
                'sys_cfg': sys_cfg}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(data['agent_invoked'], '_COMMAND')

    def test_username_used(self):
        """UserName from the ovf data becomes the default user's name."""
        odata = {'HostName': "myhost", 'UserName': "myuser"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.cfg['system_info']['default_user']['name'],
                         "myuser")

    def test_password_given(self):
        """UserPassword ends up crypt()-hashed on an unlocked default user."""
        odata = {'HostName': "myhost", 'UserName': "myuser",
                 'UserPassword': "mypass"}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue('default_user' in dsrc.cfg['system_info'])
        defuser = dsrc.cfg['system_info']['default_user']

        # default user should be updated username and should not be locked.
        self.assertEqual(defuser['name'], odata['UserName'])
        self.assertFalse(defuser['lock_passwd'])
        # passwd is crypt formated string $id$salt$encrypted
        # encrypting plaintext with salt value of everything up to final '$'
        # should equal that after the '$'
        pos = defuser['passwd'].rfind("$") + 1
        self.assertEqual(
            defuser['passwd'],
            crypt.crypt(odata['UserPassword'], defuser['passwd'][0:pos]))

    def test_userdata_plain(self):
        """UserData marked encoding='plain' is surfaced as given."""
        mydata = "FOOBAR"
        odata = {'UserData': {'text': mydata, 'encoding': 'plain'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(decode_binary(dsrc.userdata_raw), mydata)

    def test_userdata_found(self):
        """UserData marked encoding='base64' is decoded into userdata_raw."""
        mydata = "FOOBAR"
        odata = {'UserData': {'text': b64e(mydata), 'encoding': 'base64'}}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, mydata.encode('utf-8'))

    def test_no_datasource_expected(self):
        # no source should be found if no seed_dir and no devs
        data = {}
        # Fix: pass the dict that is inspected below.  A fresh {} used to be
        # handed to _get_ds(), so the 'agent_invoked' check could never fail.
        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertFalse(ret)
        self.assertFalse('agent_invoked' in data)

    def test_cfg_has_pubkeys(self):
        """Public keys listed in the ovf data appear in cfg['_pubkeys']."""
        odata = {'HostName': "myhost", 'UserName': "myuser"}
        mypklist = [{'fingerprint': 'fp1', 'path': 'path1'}]
        pubkeys = [(x['fingerprint'], x['path']) for x in mypklist]
        data = {'ovfcontent': construct_valid_ovf_env(data=odata,
                                                      pubkeys=pubkeys)}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        for mypk in mypklist:
            self.assertIn(mypk, dsrc.cfg['_pubkeys'])

    def test_default_ephemeral(self):
        # make sure the ephemeral device works
        odata = {}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata),
                'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()

        # Fix: assertEquals is a deprecated alias (removed in Python 3.12),
        # and bare asserts vanish under "python -O"; use the unittest
        # assertion methods instead.
        self.assertEqual(dsrc.device_name_to_device("ephemeral0"),
                         "/dev/sdb")
        self.assertIn('disk_setup', cfg)
        self.assertIn('fs_setup', cfg)
        self.assertIsInstance(cfg['disk_setup'], dict)
        self.assertIsInstance(cfg['fs_setup'], list)

    def test_provide_disk_aliases(self):
        # Make sure that user can affect disk aliases
        dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}}
        odata = {'HostName': "myhost", 'UserName': "myuser",
                 'dscfg': {'text': b64e(yaml.dump(dscfg)),
                           'encoding': 'base64'}}
        usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'},
                                  'ephemeral0': False}}
        userdata = '#cloud-config' + yaml.dump(usercfg) + "\n"

        ovfcontent = construct_valid_ovf_env(data=odata, userdata=userdata)
        data = {'ovfcontent': ovfcontent, 'sys_cfg': {}}

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        cfg = dsrc.get_config_obj()
        self.assertTrue(cfg)

    def test_userdata_arrives(self):
        """User-data passed to the ovf builder comes back out as bytes."""
        userdata = "This is my user-data"
        xml = construct_valid_ovf_env(data={}, userdata=userdata)
        data = {'ovfcontent': xml}
        dsrc = self._get_ds(data)
        dsrc.get_data()

        self.assertEqual(userdata.encode('us-ascii'), dsrc.userdata_raw)

    def test_ovf_env_arrives_in_waagent_dir(self):
        xml = construct_valid_ovf_env(data={}, userdata="FOODATA")
        dsrc = self._get_ds({'ovfcontent': xml})
        dsrc.get_data()

        # 'data_dir' is '/var/lib/waagent' (walinux-agent's state dir)
        # we expect that the ovf-env.xml file is copied there.
        ovf_env_path = os.path.join(self.waagent_d, 'ovf-env.xml')
        self.assertTrue(os.path.exists(ovf_env_path))
        self.assertEqual(xml, load_file(ovf_env_path))

    def test_ovf_can_include_unicode(self):
        """A leading U+FEFF in the ovf content must not make get_data() raise."""
        xml = construct_valid_ovf_env(data={})
        xml = u'\ufeff{0}'.format(xml)
        dsrc = self._get_ds({'ovfcontent': xml})
        # There is no assertion: the test passes when nothing is raised.
        dsrc.get_data()

    def test_existing_ovf_same(self):
        # waagent/SharedConfig left alone if found ovf-env.xml same as cached
        odata = {'UserData': b64e("SOMEUSERDATA")}
        data = {'ovfcontent': construct_valid_ovf_env(data=odata)}

        populate_dir(self.waagent_d,
                     {'ovf-env.xml': data['ovfcontent'],
                      'otherfile': 'otherfile-content',
                      'SharedConfig.xml': 'mysharedconfig'})

        dsrc = self._get_ds(data)
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'otherfile')))
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'SharedConfig.xml')))

    def test_existing_ovf_diff(self):
        # waagent/SharedConfig must be removed if ovfenv is found elsewhere

        # 'get_data' should remove SharedConfig.xml in /var/lib/waagent
        # if ovf-env.xml differs.
        cached_ovfenv = construct_valid_ovf_env(
            {'userdata': b64e("FOO_USERDATA")})
        new_ovfenv = construct_valid_ovf_env(
            {'userdata': b64e("NEW_USERDATA")})

        populate_dir(self.waagent_d,
                     {'ovf-env.xml': cached_ovfenv,
                      'SharedConfig.xml': "mysharedconfigxml",
                      'otherfile': 'otherfilecontent'})

        dsrc = self._get_ds({'ovfcontent': new_ovfenv})
        ret = dsrc.get_data()
        self.assertTrue(ret)
        self.assertEqual(dsrc.userdata_raw, b"NEW_USERDATA")
        self.assertTrue(os.path.exists(
            os.path.join(self.waagent_d, 'otherfile')))
        self.assertFalse(
            os.path.exists(os.path.join(self.waagent_d, 'SharedConfig.xml')))
        self.assertTrue(
            os.path.exists(os.path.join(self.waagent_d, 'ovf-env.xml')))
        self.assertEqual(
            new_ovfenv,
            load_file(os.path.join(self.waagent_d, 'ovf-env.xml')))
class TestAzureBounce(TestCase):
    """Tests for hostname setting and bouncing during get_data().

    get_hostname, set_hostname and util.subp are replaced by mocks in setUp
    so each test can inspect how they were called.
    """

    def mock_out_azure_moving_parts(self):
        """Patch the DataSourceAzure helpers these tests do not care about."""
        # Plain MagicMock replacements.
        for name in ('invoke_agent', 'wait_for_files'):
            self.patches.enter_context(
                mock.patch.object(DataSourceAzure, name))
        # Replacements that need a fixed return value.
        fixed_returns = (
            ('iid_from_shared_config', 'i-my-azure-id'),
            ('list_possible_azure_ds_devs', []),
            ('find_ephemeral_disk', None),
            ('find_ephemeral_part', None),
        )
        for name, value in fixed_returns:
            self.patches.enter_context(
                mock.patch.object(DataSourceAzure, name,
                                  mock.MagicMock(return_value=value)))

    def setUp(self):
        """Create the scratch directory and install all patches."""
        super(TestAzureBounce, self).setUp()
        self.tmp = tempfile.mkdtemp()
        self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')
        self.paths = helpers.Paths({'cloud_dir': self.tmp})
        self.addCleanup(shutil.rmtree, self.tmp)

        self.patches = ExitStack()
        # Fix: undo the patches through addCleanup instead of tearDown().
        # Cleanups still run when setUp() fails part-way, and the old
        # tearDown() did not chain to the parent class.
        self.addCleanup(self.patches.close)
        # Fix: the module-global config used to be assigned directly and was
        # never restored; patch.dict puts the previous contents back.
        self.patches.enter_context(mock.patch.dict(
            DataSourceAzure.BUILTIN_DS_CONFIG, {'data_dir': self.waagent_d}))

        self.mock_out_azure_moving_parts()
        self.get_hostname = self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'get_hostname'))
        self.set_hostname = self.patches.enter_context(
            mock.patch.object(DataSourceAzure, 'set_hostname'))
        self.subp = self.patches.enter_context(
            mock.patch('cloudinit.sources.DataSourceAzure.util.subp'))

    def _get_ds(self, ovfcontent=None):
        """Return a DataSourceAzureNet, seeding ovf-env.xml when content is given."""
        if ovfcontent is not None:
            populate_dir(os.path.join(self.paths.seed_dir, "azure"),
                         {'ovf-env.xml': ovfcontent})
        return DataSourceAzure.DataSourceAzureNet(
            {}, distro=None, paths=self.paths)

    def get_ovf_env_with_dscfg(self, hostname, cfg):
        """Build ovf content with *hostname* and *cfg* as base64 YAML dscfg."""
        odata = {
            'HostName': hostname,
            'dscfg': {
                'text': b64e(yaml.dump(cfg)),
                'encoding': 'base64'
            }
        }
        return construct_valid_ovf_env(data=odata)

    def test_disabled_bounce_does_not_change_hostname(self):
        """With policy 'off', set_hostname is never called."""
        cfg = {'hostname_bounce': {'policy': 'off'}}
        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
        self.assertEqual(0, self.set_hostname.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_disabled_bounce_does_not_perform_bounce(
            self, perform_hostname_bounce):
        """With policy 'off', perform_hostname_bounce is never called."""
        cfg = {'hostname_bounce': {'policy': 'off'}}
        self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
        self.assertEqual(0, perform_hostname_bounce.call_count)

    def test_same_hostname_does_not_change_hostname(self):
        """With policy 'yes' and an unchanged name, set_hostname is not called."""
        host_name = 'unchanged-host-name'
        self.get_hostname.return_value = host_name
        cfg = {'hostname_bounce': {'policy': 'yes'}}
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
        self.assertEqual(0, self.set_hostname.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_unchanged_hostname_does_not_perform_bounce(
            self, perform_hostname_bounce):
        """With policy 'yes' and an unchanged name, no bounce is performed."""
        host_name = 'unchanged-host-name'
        self.get_hostname.return_value = host_name
        cfg = {'hostname_bounce': {'policy': 'yes'}}
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
        self.assertEqual(0, perform_hostname_bounce.call_count)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_force_performs_bounce_regardless(self, perform_hostname_bounce):
        """With policy 'force', one bounce happens even if the name is unchanged."""
        host_name = 'unchanged-host-name'
        self.get_hostname.return_value = host_name
        cfg = {'hostname_bounce': {'policy': 'force'}}
        self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
        self.assertEqual(1, perform_hostname_bounce.call_count)

    def test_different_hostnames_sets_hostname(self):
        """The first set_hostname call receives the hostname from the ovf data."""
        expected_hostname = 'azure-expected-host-name'
        self.get_hostname.return_value = 'default-host-name'
        self._get_ds(
            self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data()
        self.assertEqual(expected_hostname,
                         self.set_hostname.call_args_list[0][0][0])

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_different_hostnames_performs_bounce(
            self, perform_hostname_bounce):
        """Differing hostnames trigger exactly one bounce."""
        expected_hostname = 'azure-expected-host-name'
        self.get_hostname.return_value = 'default-host-name'
        self._get_ds(
            self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data()
        self.assertEqual(1, perform_hostname_bounce.call_count)

    def test_different_hostnames_sets_hostname_back(self):
        """The last set_hostname call restores the initial hostname."""
        initial_host_name = 'default-host-name'
        self.get_hostname.return_value = initial_host_name
        self._get_ds(
            self.get_ovf_env_with_dscfg('some-host-name', {})).get_data()
        self.assertEqual(initial_host_name,
                         self.set_hostname.call_args_list[-1][0][0])

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_failure_in_bounce_still_resets_host_name(
            self, perform_hostname_bounce):
        """The initial hostname is restored even when the bounce raises."""
        perform_hostname_bounce.side_effect = Exception
        initial_host_name = 'default-host-name'
        self.get_hostname.return_value = initial_host_name
        self._get_ds(
            self.get_ovf_env_with_dscfg('some-host-name', {})).get_data()
        self.assertEqual(initial_host_name,
                         self.set_hostname.call_args_list[-1][0][0])

    def test_environment_correct_for_bounce_command(self):
        """subp gets an env holding the interface and old and new hostnames."""
        interface = 'int0'
        hostname = 'my-new-host'
        old_hostname = 'my-old-host'
        self.get_hostname.return_value = old_hostname
        cfg = {'hostname_bounce': {'interface': interface, 'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg(hostname, cfg)
        self._get_ds(data).get_data()
        self.assertEqual(1, self.subp.call_count)
        bounce_env = self.subp.call_args[1]['env']
        self.assertEqual(interface, bounce_env['interface'])
        self.assertEqual(hostname, bounce_env['hostname'])
        self.assertEqual(old_hostname, bounce_env['old_hostname'])

    def test_default_bounce_command_used_by_default(self):
        """The built-in bounce command is what gets passed to subp."""
        cmd = 'default-bounce-command'
        # Fix: the built-in command used to be overwritten permanently in
        # the module-global config; patch.dict restores it on cleanup.
        self.patches.enter_context(mock.patch.dict(
            DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce'],
            {'command': cmd}))
        cfg = {'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()
        self.assertEqual(1, self.subp.call_count)
        bounce_args = self.subp.call_args[1]['args']
        self.assertEqual(cmd, bounce_args)

    @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
    def test_set_hostname_option_can_disable_bounce(
            self, perform_hostname_bounce):
        """set_hostname: False suppresses the bounce even with policy 'force'."""
        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()

        self.assertEqual(0, perform_hostname_bounce.call_count)

    def test_set_hostname_option_can_disable_hostname_set(self):
        """set_hostname: False suppresses set_hostname even with policy 'force'."""
        cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
        data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
        self._get_ds(data).get_data()

        self.assertEqual(0, self.set_hostname.call_count)
class TestReadAzureOvf(TestCase):
    """Tests for DataSourceAzure.read_azure_ovf()."""

    def test_invalid_xml_raises_non_azure_ds(self):
        """Unparseable ovf content must raise BrokenAzureDataSource."""
        # NOTE(review): the prefix literal is empty, so this is just the
        # builder's output; a markup prefix that made the document invalid
        # may have been lost in extraction -- confirm against upstream.
        broken_doc = "" + construct_valid_ovf_env(data={})
        with self.assertRaises(DataSourceAzure.BrokenAzureDataSource):
            DataSourceAzure.read_azure_ovf(broken_doc)

    def test_load_with_pubkeys(self):
        """Each public key handed to the builder shows up in cfg['_pubkeys']."""
        expected_keys = [{'fingerprint': 'fp1', 'path': 'path1'}]
        key_pairs = [(entry['fingerprint'], entry['path'])
                     for entry in expected_keys]
        ovf_text = construct_valid_ovf_env(pubkeys=key_pairs)
        parsed = DataSourceAzure.read_azure_ovf(ovf_text)
        (_md, _ud, cfg) = parsed
        for expected in expected_keys:
            self.assertIn(expected, cfg['_pubkeys'])
class TestReadAzureSharedConfig(unittest.TestCase):
    """Tests for DataSourceAzure.iid_from_shared_config_content()."""

    def test_valid_content(self):
        """The instance id is extracted from the shared-config content."""
        # NOTE(review): this literal is a single newline; the shared-config
        # markup it presumably held looks lost in extraction -- confirm
        # against upstream.
        shared_config = """
"""
        instance_id = DataSourceAzure.iid_from_shared_config_content(
            shared_config)
        self.assertEqual("MY_INSTANCE_ID", instance_id)
class TestFindEndpoint(TestCase):
    """Tests for DataSourceAzure.find_endpoint().

    DataSourceAzure.util.load_file is mocked, so each test controls the
    lease-file text that find_endpoint() gets to see.
    """

    def setUp(self):
        """Patch util.load_file and expose the mock as self.load_file."""
        super(TestFindEndpoint, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.load_file = patches.enter_context(
            mock.patch.object(DataSourceAzure.util, 'load_file'))

    def test_missing_file(self):
        """An IOError from load_file propagates out of find_endpoint()."""
        self.load_file.side_effect = IOError
        self.assertRaises(IOError, DataSourceAzure.find_endpoint)

    def test_missing_special_azure_line(self):
        """Content without the unknown-245 option makes find_endpoint() raise."""
        self.load_file.return_value = ''
        self.assertRaises(Exception, DataSourceAzure.find_endpoint)

    def _build_lease_content(self, ip_address, use_hex=True):
        """Return lease text whose unknown-245 option encodes *ip_address*.

        :param ip_address: dotted-quad IPv4 address string.
        :param use_hex: when True the address is written as colon-separated
            hex octets without zero padding (e.g. ``62:4c:36:20``).  When
            False the four octets are packed into bytes and written as a
            double-quoted string.
        :return: the lease block as a newline-joined string.
        """
        octets = [int(part) for part in ip_address.split('.')]
        if use_hex:
            ip_address_repr = ':'.join('%x' % octet for octet in octets)
        else:
            # Fix: the packed form used to be derived by concatenating the
            # unpadded hex octets, so an address with an octet below 0x10
            # (e.g. '4.3.2.1' -> '4321') packed to the wrong bytes.  Pack
            # the four octets directly instead.
            packed = struct.pack('>4B', *octets)
            # NOTE(review): decoding as UTF-8 only works for octet sequences
            # that happen to be valid UTF-8 (the ASCII-range address used in
            # test_packed_string is) -- confirm which encoding
            # find_endpoint() expects before feeding it other addresses.
            ip_address_repr = '"{0}"'.format(packed.decode('utf-8'))
        return '\n'.join([
            'lease {',
            ' interface "eth0";',
            ' option unknown-245 {0};'.format(ip_address_repr),
            '}'])

    def test_hex_string(self):
        """A colon-separated hex address is decoded back to dotted-quad."""
        ip_address = '98.76.54.32'
        file_content = self._build_lease_content(ip_address)
        self.load_file.return_value = file_content
        self.assertEqual(ip_address, DataSourceAzure.find_endpoint())

    def test_hex_string_with_single_character_part(self):
        """Hex octets written with a single digit are decoded correctly."""
        ip_address = '4.3.2.1'
        file_content = self._build_lease_content(ip_address)
        self.load_file.return_value = file_content
        self.assertEqual(ip_address, DataSourceAzure.find_endpoint())

    def test_packed_string(self):
        """A quoted, packed address is decoded back to dotted-quad."""
        ip_address = '98.76.54.32'
        file_content = self._build_lease_content(ip_address, use_hex=False)
        self.load_file.return_value = file_content
        self.assertEqual(ip_address, DataSourceAzure.find_endpoint())

    def test_latest_lease_used(self):
        """With several lease blocks, the address from the last one is used."""
        ip_addresses = ['4.3.2.1', '98.76.54.32']
        file_content = '\n'.join([self._build_lease_content(ip_address)
                                  for ip_address in ip_addresses])
        self.load_file.return_value = file_content
        self.assertEqual(ip_addresses[-1], DataSourceAzure.find_endpoint())
class TestGoalStateParsing(TestCase):
    """Tests for how DataSourceAzure.GoalState reads a goal-state document."""

    default_parameters = {
        'incarnation': 1,
        'container_id': 'MyContainerId',
        'instance_id': 'MyInstanceId',
        'shared_config_url': 'MySharedConfigUrl',
        'certificates_url': 'MyCertificatesUrl',
    }

    def _get_goal_state(self, http_client=None, **kwargs):
        """Return a GoalState built from GOAL_STATE_TEMPLATE.

        :param http_client: client handed to GoalState; a fresh MagicMock is
            used when None.
        :param kwargs: overrides for the entries of default_parameters.  When
            certificates_url is None, every template line containing
            'Certificates' is dropped from the document.
        """
        client = mock.MagicMock() if http_client is None else http_client
        params = dict(self.default_parameters, **kwargs)
        document = GOAL_STATE_TEMPLATE.format(**params)
        if params['certificates_url'] is None:
            document = '\n'.join(
                [row for row in document.splitlines()
                 if 'Certificates' not in row])
        return DataSourceAzure.GoalState(document, client)

    def test_incarnation_parsed_correctly(self):
        """The incarnation put into the document is exposed by GoalState."""
        wanted = '123'
        state = self._get_goal_state(incarnation=wanted)
        self.assertEqual(wanted, state.incarnation)

    def test_container_id_parsed_correctly(self):
        """The container id put into the document is exposed by GoalState."""
        wanted = 'TestContainerId'
        state = self._get_goal_state(container_id=wanted)
        self.assertEqual(wanted, state.container_id)

    def test_instance_id_parsed_correctly(self):
        """The instance id put into the document is exposed by GoalState."""
        wanted = 'TestInstanceId'
        state = self._get_goal_state(instance_id=wanted)
        self.assertEqual(wanted, state.instance_id)

    def test_shared_config_xml_parsed_and_fetched_correctly(self):
        """shared_config_xml is fetched with one GET of the shared-config URL."""
        client = mock.MagicMock()
        url = 'TestSharedConfigUrl'
        state = self._get_goal_state(
            http_client=client, shared_config_url=url)
        fetched = state.shared_config_xml
        self.assertEqual(1, client.get.call_count)
        positional = client.get.call_args[0]
        self.assertEqual(url, positional[0])
        self.assertEqual(client.get.return_value.contents, fetched)

    def test_certificates_xml_parsed_and_fetched_correctly(self):
        """certificates_xml is fetched with one secure GET of its URL."""
        client = mock.MagicMock()
        url = 'TestSharedConfigUrl'
        state = self._get_goal_state(
            http_client=client, certificates_url=url)
        fetched = state.certificates_xml
        self.assertEqual(1, client.get.call_count)
        positional, keywords = client.get.call_args[0], client.get.call_args[1]
        self.assertEqual(url, positional[0])
        self.assertTrue(keywords.get('secure', False))
        self.assertEqual(client.get.return_value.contents, fetched)

    def test_missing_certificates_skips_http_get(self):
        """Without a certificates URL no GET is issued and the result is None."""
        client = mock.MagicMock()
        state = self._get_goal_state(
            http_client=client, certificates_url=None)
        fetched = state.certificates_xml
        self.assertEqual(0, client.get.call_count)
        self.assertIsNone(fetched)
class TestAzureEndpointHttpClient(TestCase):
    """Tests for the headers AzureEndpointHttpClient sends on get() and post().

    DataSourceAzure.util.read_file_or_url is mocked, so no request is made.
    """

    regular_headers = {
        'x-ms-agent-name': 'WALinuxAgent',
        'x-ms-version': '2012-11-30',
    }

    def setUp(self):
        """Patch util.read_file_or_url and expose the mock on self."""
        super(TestAzureEndpointHttpClient, self).setUp()
        stack = ExitStack()
        self.addCleanup(stack.close)

        patcher = mock.patch.object(DataSourceAzure.util, 'read_file_or_url')
        self.read_file_or_url = stack.enter_context(patcher)

    def test_non_secure_get(self):
        """A non-secure get() sends only the regular headers."""
        target = 'MyTestUrl'
        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock())
        result = client.get(target, secure=False)
        self.assertEqual(1, self.read_file_or_url.call_count)
        self.assertEqual(self.read_file_or_url.return_value, result)
        expected_call = mock.call(target, headers=self.regular_headers)
        self.assertEqual(expected_call, self.read_file_or_url.call_args)

    def test_secure_get(self):
        """A secure get() adds the cipher-name and certificate headers."""
        target = 'MyTestUrl'
        certificate = mock.MagicMock()
        wanted_headers = dict(self.regular_headers)
        wanted_headers["x-ms-cipher-name"] = "DES_EDE3_CBC"
        wanted_headers["x-ms-guest-agent-public-x509-cert"] = certificate
        client = DataSourceAzure.AzureEndpointHttpClient(certificate)
        result = client.get(target, secure=True)
        self.assertEqual(1, self.read_file_or_url.call_count)
        self.assertEqual(self.read_file_or_url.return_value, result)
        expected_call = mock.call(target, headers=wanted_headers)
        self.assertEqual(expected_call, self.read_file_or_url.call_args)

    def test_post(self):
        """post() forwards the data and sends the regular headers."""
        payload = mock.MagicMock()
        target = 'MyTestUrl'
        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock())
        result = client.post(target, data=payload)
        self.assertEqual(1, self.read_file_or_url.call_count)
        self.assertEqual(self.read_file_or_url.return_value, result)
        expected_call = mock.call(
            target, data=payload, headers=self.regular_headers)
        self.assertEqual(expected_call, self.read_file_or_url.call_args)

    def test_post_with_extra_headers(self):
        """extra_headers given to post() are merged into the regular headers."""
        target = 'MyTestUrl'
        client = DataSourceAzure.AzureEndpointHttpClient(mock.MagicMock())
        additional = {'test': 'header'}
        client.post(target, extra_headers=additional)
        self.assertEqual(1, self.read_file_or_url.call_count)
        wanted_headers = dict(self.regular_headers)
        wanted_headers.update(additional)
        expected_call = mock.call(
            mock.ANY, data=mock.ANY, headers=wanted_headers)
        self.assertEqual(expected_call, self.read_file_or_url.call_args)
class TestOpenSSLManager(TestCase):
    """Tests for the temporary directory used by DataSourceAzure.OpenSSLManager.

    DataSourceAzure.util.subp is mocked in setUp.
    """

    def setUp(self):
        """Patch util.subp and expose the mock as self.subp."""
        super(TestOpenSSLManager, self).setUp()
        stack = ExitStack()
        self.addCleanup(stack.close)

        patcher = mock.patch.object(DataSourceAzure.util, 'subp')
        self.subp = stack.enter_context(patcher)

    @mock.patch.object(DataSourceAzure, 'cd', mock.MagicMock())
    @mock.patch.object(DataSourceAzure.tempfile, 'TemporaryDirectory')
    def test_openssl_manager_creates_a_tmpdir(self, tmpdir_factory):
        """The manager's tmpdir is what tempfile.TemporaryDirectory() returned."""
        manager = DataSourceAzure.OpenSSLManager()
        self.assertEqual(tmpdir_factory.return_value, manager.tmpdir)

    @mock.patch('builtins.open')
    def test_generate_certificate_uses_tmpdir(self, _mocked_open):
        """subp is called while the working directory is the manager's tmpdir."""
        observed = {}

        def remember_cwd(*args, **kwargs):
            # Record where the process was when the mocked subp was called.
            observed['path'] = os.getcwd()

        self.subp.side_effect = remember_cwd
        manager = DataSourceAzure.OpenSSLManager()
        self.assertEqual(manager.tmpdir.name, observed['path'])
class TestWALinuxAgentShim(TestCase):
    """Tests for DataSourceAzure.WALinuxAgentShim.

    AzureEndpointHttpClient, find_endpoint, GoalState and OpenSSLManager are
    replaced with mocks, each exposed on the test case under the same name.
    """

    def setUp(self):
        """Patch the four collaborators of the shim on DataSourceAzure."""
        super(TestWALinuxAgentShim, self).setUp()
        stack = ExitStack()
        self.addCleanup(stack.close)

        patched_names = ('AzureEndpointHttpClient', 'find_endpoint',
                         'GoalState', 'OpenSSLManager')
        for attr in patched_names:
            patcher = mock.patch.object(DataSourceAzure, attr)
            setattr(self, attr, stack.enter_context(patcher))

    def test_http_client_uses_certificate(self):
        """The HTTP client is built from the OpenSSLManager's certificate."""
        shim = DataSourceAzure.WALinuxAgentShim()
        certificate = self.OpenSSLManager.return_value.certificate
        self.assertEqual([mock.call(certificate)],
                         self.AzureEndpointHttpClient.call_args_list)
        self.assertEqual(self.AzureEndpointHttpClient.return_value,
                         shim.http_client)

    def test_correct_url_used_for_goalstate(self):
        """The goal state is fetched from the endpoint and given to GoalState."""
        self.find_endpoint.return_value = 'test_endpoint'
        shim = DataSourceAzure.WALinuxAgentShim()
        shim.register_with_azure_and_fetch_data()
        http_get = self.AzureEndpointHttpClient.return_value.get
        expected_fetch = [
            mock.call('http://test_endpoint/machine/?comp=goalstate')]
        self.assertEqual(expected_fetch, http_get.call_args_list)
        expected_construct = [
            mock.call(http_get.return_value.contents, shim.http_client)]
        self.assertEqual(expected_construct, self.GoalState.call_args_list)

    def test_certificates_used_to_determine_public_keys(self):
        """public_keys comes from parse_certificates() on the certificates XML."""
        shim = DataSourceAzure.WALinuxAgentShim()
        shim.register_with_azure_and_fetch_data()
        parse = self.OpenSSLManager.return_value.parse_certificates
        self.assertEqual(
            [mock.call(self.GoalState.return_value.certificates_xml)],
            parse.call_args_list)
        self.assertEqual(parse.return_value, shim.public_keys)

    def test_absent_certificates_produces_empty_public_keys(self):
        """With certificates_xml set to None, public_keys is an empty list."""
        self.GoalState.return_value.certificates_xml = None
        shim = DataSourceAzure.WALinuxAgentShim()
        shim.register_with_azure_and_fetch_data()
        self.assertEqual([], shim.public_keys)

    def test_correct_url_used_for_report_ready(self):
        """The ready report is posted once, to the endpoint's health URL."""
        self.find_endpoint.return_value = 'test_endpoint'
        shim = DataSourceAzure.WALinuxAgentShim()
        shim.register_with_azure_and_fetch_data()
        health_url = 'http://test_endpoint/machine?comp=health'
        expected_posts = [
            mock.call(health_url, data=mock.ANY, extra_headers=mock.ANY)]
        self.assertEqual(expected_posts, shim.http_client.post.call_args_list)

    def test_goal_state_values_used_for_report_ready(self):
        """The posted document mentions the goal state's three identifiers."""
        goal_state = self.GoalState.return_value
        goal_state.incarnation = 'TestIncarnation'
        goal_state.container_id = 'TestContainerId'
        goal_state.instance_id = 'TestInstanceId'
        shim = DataSourceAzure.WALinuxAgentShim()
        shim.register_with_azure_and_fetch_data()
        posted_document = shim.http_client.post.call_args[1]['data']
        for marker in ('TestIncarnation', 'TestContainerId',
                       'TestInstanceId'):
            self.assertIn(marker, posted_document)