import os
import struct
import unittest
from cloudinit.sources.helpers import azure as azure_helper
from ..helpers import TestCase
try:
from unittest import mock
except ImportError:
import mock
try:
from contextlib import ExitStack
except ImportError:
from contextlib2 import ExitStack
# Goal-state document template used by TestGoalStateParsing._get_goal_state,
# which fills it in with str.format(). Placeholders: {incarnation},
# {container_id}, {instance_id}, {shared_config_url} and {certificates_url}.
# NOTE(review): only element text appears below, with no XML tags.
# _get_goal_state drops lines containing 'Certificates' to simulate a goal
# state without a certificates URL, which presumably requires a
# <Certificates> tag on that line. Confirm the template still has its markup.
GOAL_STATE_TEMPLATE = """\
2012-11-30
{incarnation}
Started
300000
16001
FALSE
{container_id}
{instance_id}
Started
http://100.86.192.70:80/...hostingEnvironmentConfig...
{shared_config_url}
http://100.86.192.70:80/...extensionsConfig...
http://100.86.192.70:80/...fullConfig...
{certificates_url}
68ce47.0.68ce47.0.utl-trusty--292258.1.xml
"""
class TestReadAzureSharedConfig(unittest.TestCase):
    """Tests for azure_helper.iid_from_shared_config_content."""

    def test_valid_content(self):
        """The instance id is read out of a shared-config document."""
        # NOTE(review): this document is only a newline, so it cannot
        # contain MY_INSTANCE_ID. The XML body looks like it was lost and
        # should be restored for this test to be meaningful.
        shared_config = "\n"
        self.assertEqual(
            "MY_INSTANCE_ID",
            azure_helper.iid_from_shared_config_content(shared_config))
class TestFindEndpoint(TestCase):
    """Tests for WALinuxAgentShim.find_endpoint.

    util.load_file is patched, so each test controls the (lease-file style)
    content that find_endpoint reads.
    """

    def setUp(self):
        super(TestFindEndpoint, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.load_file = patches.enter_context(
            mock.patch.object(azure_helper.util, 'load_file'))

    def test_missing_file(self):
        self.load_file.side_effect = IOError
        self.assertRaises(IOError,
                          azure_helper.WALinuxAgentShim.find_endpoint)

    def test_missing_special_azure_line(self):
        self.load_file.return_value = ''
        self.assertRaises(Exception,
                          azure_helper.WALinuxAgentShim.find_endpoint)

    def _build_lease_content(self, ip_address, use_hex=True):
        """Return a lease block carrying ip_address in option unknown-245.

        :param ip_address: dotted-quad IPv4 address, e.g. '98.76.54.32'.
        :param use_hex: when True, encode the address as colon-separated
            hex octets without zero padding (e.g. '4:3:2:1'). When False,
            encode it as a double-quoted string of the four raw octet bytes.
        :returns: the lease block as a newline-joined string.
        """
        octets = [int(part) for part in ip_address.split('.')]
        if use_hex:
            ip_address_repr = ':'.join(
                [hex(octet).replace('0x', '') for octet in octets])
        else:
            # Pack each octet into its own byte. Re-parsing the unpadded hex
            # form (e.g. '4:3:2:1' -> 0x4321) gives the wrong value whenever
            # an octet is below 16.
            packed = struct.pack('>4B', *octets)
            # NOTE(review): decoding may fail for octets of 128 and above.
            ip_address_repr = '"{0}"'.format(packed.decode('utf-8'))
        return '\n'.join([
            'lease {',
            ' interface "eth0";',
            ' option unknown-245 {0};'.format(ip_address_repr),
            '}'])

    def test_hex_string(self):
        ip_address = '98.76.54.32'
        file_content = self._build_lease_content(ip_address)
        self.load_file.return_value = file_content
        self.assertEqual(ip_address,
                         azure_helper.WALinuxAgentShim.find_endpoint())

    def test_hex_string_with_single_character_part(self):
        ip_address = '4.3.2.1'
        file_content = self._build_lease_content(ip_address)
        self.load_file.return_value = file_content
        self.assertEqual(ip_address,
                         azure_helper.WALinuxAgentShim.find_endpoint())

    def test_packed_string(self):
        ip_address = '98.76.54.32'
        file_content = self._build_lease_content(ip_address, use_hex=False)
        self.load_file.return_value = file_content
        self.assertEqual(ip_address,
                         azure_helper.WALinuxAgentShim.find_endpoint())

    def test_latest_lease_used(self):
        # Two lease blocks are concatenated; the last one should win.
        ip_addresses = ['4.3.2.1', '98.76.54.32']
        file_content = '\n'.join([self._build_lease_content(ip_address)
                                  for ip_address in ip_addresses])
        self.load_file.return_value = file_content
        self.assertEqual(ip_addresses[-1],
                         azure_helper.WALinuxAgentShim.find_endpoint())
class TestGoalStateParsing(TestCase):
    """Tests for azure_helper.GoalState parsing and document fetching.

    Goal-state documents are built from GOAL_STATE_TEMPLATE. The HTTP client
    is a MagicMock, so fetched URLs and keyword arguments can be inspected.
    """

    default_parameters = {
        'incarnation': 1,
        'container_id': 'MyContainerId',
        'instance_id': 'MyInstanceId',
        'shared_config_url': 'MySharedConfigUrl',
        'certificates_url': 'MyCertificatesUrl',
    }

    def _get_goal_state(self, http_client=None, **kwargs):
        """Build a GoalState from GOAL_STATE_TEMPLATE.

        :param http_client: client passed to GoalState; a fresh MagicMock
            is used when omitted.
        :param kwargs: overrides for default_parameters. Passing
            certificates_url=None drops every template line mentioning
            'Certificates', simulating a goal state without certificates.
        """
        if http_client is None:
            http_client = mock.MagicMock()
        parameters = self.default_parameters.copy()
        parameters.update(kwargs)
        xml = GOAL_STATE_TEMPLATE.format(**parameters)
        if parameters['certificates_url'] is None:
            xml = '\n'.join(
                line for line in xml.splitlines()
                if 'Certificates' not in line)
        return azure_helper.GoalState(xml, http_client)

    def test_incarnation_parsed_correctly(self):
        incarnation = '123'
        goal_state = self._get_goal_state(incarnation=incarnation)
        self.assertEqual(incarnation, goal_state.incarnation)

    def test_container_id_parsed_correctly(self):
        container_id = 'TestContainerId'
        goal_state = self._get_goal_state(container_id=container_id)
        self.assertEqual(container_id, goal_state.container_id)

    def test_instance_id_parsed_correctly(self):
        instance_id = 'TestInstanceId'
        goal_state = self._get_goal_state(instance_id=instance_id)
        self.assertEqual(instance_id, goal_state.instance_id)

    def test_shared_config_xml_parsed_and_fetched_correctly(self):
        http_client = mock.MagicMock()
        shared_config_url = 'TestSharedConfigUrl'
        goal_state = self._get_goal_state(
            http_client=http_client, shared_config_url=shared_config_url)
        shared_config_xml = goal_state.shared_config_xml
        self.assertEqual(1, http_client.get.call_count)
        self.assertEqual(shared_config_url, http_client.get.call_args[0][0])
        self.assertEqual(http_client.get.return_value.contents,
                         shared_config_xml)

    def test_certificates_xml_parsed_and_fetched_correctly(self):
        http_client = mock.MagicMock()
        # A value distinct from any shared-config URL, so the URL assertion
        # below pins the certificates URL specifically.
        certificates_url = 'TestCertificatesUrl'
        goal_state = self._get_goal_state(
            http_client=http_client, certificates_url=certificates_url)
        certificates_xml = goal_state.certificates_xml
        self.assertEqual(1, http_client.get.call_count)
        self.assertEqual(certificates_url, http_client.get.call_args[0][0])
        # Certificates must be requested with secure=True.
        self.assertTrue(http_client.get.call_args[1].get('secure', False))
        self.assertEqual(http_client.get.return_value.contents,
                         certificates_xml)

    def test_missing_certificates_skips_http_get(self):
        http_client = mock.MagicMock()
        goal_state = self._get_goal_state(
            http_client=http_client, certificates_url=None)
        certificates_xml = goal_state.certificates_xml
        self.assertEqual(0, http_client.get.call_count)
        self.assertIsNone(certificates_xml)
class TestAzureEndpointHttpClient(TestCase):
    """Tests for the headers AzureEndpointHttpClient sends on GET and POST.

    util.read_file_or_url is patched, so no request is made. Each test checks
    the URL, data and headers the client hands to it.
    """

    regular_headers = {
        'x-ms-agent-name': 'WALinuxAgent',
        'x-ms-version': '2012-11-30',
    }

    def setUp(self):
        super(TestAzureEndpointHttpClient, self).setUp()
        stack = ExitStack()
        self.addCleanup(stack.close)
        self.read_file_or_url = stack.enter_context(
            mock.patch.object(azure_helper.util, 'read_file_or_url'))

    def test_non_secure_get(self):
        target = 'MyTestUrl'
        http_client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
        result = http_client.get(target, secure=False)

        fetcher = self.read_file_or_url
        self.assertEqual(1, fetcher.call_count)
        self.assertEqual(fetcher.return_value, result)
        self.assertEqual(mock.call(target, headers=self.regular_headers),
                         fetcher.call_args)

    def test_secure_get(self):
        target = 'MyTestUrl'
        cert = mock.MagicMock()
        # A secure GET adds the cipher name and the agent's certificate.
        wanted_headers = dict(self.regular_headers)
        wanted_headers["x-ms-cipher-name"] = "DES_EDE3_CBC"
        wanted_headers["x-ms-guest-agent-public-x509-cert"] = cert

        http_client = azure_helper.AzureEndpointHttpClient(cert)
        result = http_client.get(target, secure=True)

        fetcher = self.read_file_or_url
        self.assertEqual(1, fetcher.call_count)
        self.assertEqual(fetcher.return_value, result)
        self.assertEqual(mock.call(target, headers=wanted_headers),
                         fetcher.call_args)

    def test_post(self):
        payload = mock.MagicMock()
        target = 'MyTestUrl'
        http_client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
        result = http_client.post(target, data=payload)

        fetcher = self.read_file_or_url
        self.assertEqual(1, fetcher.call_count)
        self.assertEqual(fetcher.return_value, result)
        self.assertEqual(
            mock.call(target, data=payload, headers=self.regular_headers),
            fetcher.call_args)

    def test_post_with_extra_headers(self):
        target = 'MyTestUrl'
        http_client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
        additional = {'test': 'header'}
        http_client.post(target, extra_headers=additional)

        fetcher = self.read_file_or_url
        self.assertEqual(1, fetcher.call_count)
        # Extra headers are merged on top of the regular ones.
        wanted_headers = dict(self.regular_headers)
        wanted_headers.update(additional)
        self.assertEqual(
            mock.call(mock.ANY, data=mock.ANY, headers=wanted_headers),
            fetcher.call_args)
class TestOpenSSLManager(TestCase):
    """Tests for OpenSSLManager's temporary-directory handling.

    util.subp and the builtin open() are patched for every test, so no
    external command is run and nothing is written through open().
    """

    def setUp(self):
        super(TestOpenSSLManager, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.subp = patches.enter_context(
            mock.patch.object(azure_helper.util, 'subp'))
        # The builtins module is '__builtin__' on Python 2 and 'builtins'
        # on Python 3. Patching the absent name is expected to raise
        # ImportError, in which case the other name is tried.
        try:
            self.open = patches.enter_context(
                mock.patch('__builtin__.open'))
        except ImportError:
            self.open = patches.enter_context(
                mock.patch('builtins.open'))

    @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
    @mock.patch.object(azure_helper.tempfile, 'mkdtemp')
    def test_openssl_manager_creates_a_tmpdir(self, mkdtemp):
        manager = azure_helper.OpenSSLManager()
        self.assertEqual(mkdtemp.return_value, manager.tmpdir)

    def test_generate_certificate_uses_tmpdir(self):
        subp_directory = {}

        def capture_directory(*args, **kwargs):
            # Record the working directory at the time subp is invoked.
            subp_directory['path'] = os.getcwd()

        self.subp.side_effect = capture_directory
        manager = azure_helper.OpenSSLManager()
        # Neither mkdtemp nor cd is patched here, so a real temporary
        # directory is created. Remove it even if the assertion fails.
        self.addCleanup(manager.clean_up)
        self.assertEqual(manager.tmpdir, subp_directory['path'])

    @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
    @mock.patch.object(azure_helper.tempfile, 'mkdtemp', mock.MagicMock())
    @mock.patch.object(azure_helper.util, 'del_dir')
    def test_clean_up(self, del_dir):
        manager = azure_helper.OpenSSLManager()
        manager.clean_up()
        self.assertEqual([mock.call(manager.tmpdir)], del_dir.call_args_list)
class TestWALinuxAgentShim(TestCase):
    """Tests for WALinuxAgentShim.register_with_azure_and_fetch_data.

    The HTTP client, endpoint lookup, GoalState, shared-config parser,
    OpenSSLManager and time.sleep are all patched. Each test inspects how
    the shim wires these collaborators together.
    """

    def setUp(self):
        super(TestWALinuxAgentShim, self).setUp()
        stack = ExitStack()
        self.addCleanup(stack.close)

        def patch_helper(name):
            return stack.enter_context(mock.patch.object(azure_helper, name))

        self.AzureEndpointHttpClient = patch_helper('AzureEndpointHttpClient')
        self.find_endpoint = stack.enter_context(
            mock.patch.object(azure_helper.WALinuxAgentShim, 'find_endpoint'))
        self.GoalState = patch_helper('GoalState')
        self.iid_from_shared_config_content = patch_helper(
            'iid_from_shared_config_content')
        self.OpenSSLManager = patch_helper('OpenSSLManager')
        # Replace time.sleep so the shim never actually sleeps in tests.
        stack.enter_context(
            mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))

    def _register(self):
        """Create a shim, run registration, and return (shim, data)."""
        agent_shim = azure_helper.WALinuxAgentShim()
        return agent_shim, agent_shim.register_with_azure_and_fetch_data()

    def test_http_client_uses_certificate(self):
        self._register()
        expected = [mock.call(self.OpenSSLManager.return_value.certificate)]
        self.assertEqual(expected, self.AzureEndpointHttpClient.call_args_list)

    def test_correct_url_used_for_goalstate(self):
        self.find_endpoint.return_value = 'test_endpoint'
        self._register()
        client = self.AzureEndpointHttpClient.return_value
        self.assertEqual(
            [mock.call('http://test_endpoint/machine/?comp=goalstate')],
            client.get.call_args_list)
        self.assertEqual(
            [mock.call(client.get.return_value.contents, client)],
            self.GoalState.call_args_list)

    def test_certificates_used_to_determine_public_keys(self):
        _, metadata = self._register()
        parse_certificates = (
            self.OpenSSLManager.return_value.parse_certificates)
        self.assertEqual(
            [mock.call(self.GoalState.return_value.certificates_xml)],
            parse_certificates.call_args_list)
        self.assertEqual(parse_certificates.return_value,
                         metadata['public-keys'])

    def test_absent_certificates_produces_empty_public_keys(self):
        self.GoalState.return_value.certificates_xml = None
        _, metadata = self._register()
        self.assertEqual([], metadata['public-keys'])

    def test_instance_id_returned_in_data(self):
        _, metadata = self._register()
        iid_parser = self.iid_from_shared_config_content
        self.assertEqual(
            [mock.call(self.GoalState.return_value.shared_config_xml)],
            iid_parser.call_args_list)
        self.assertEqual(iid_parser.return_value, metadata['instance-id'])

    def test_correct_url_used_for_report_ready(self):
        self.find_endpoint.return_value = 'test_endpoint'
        self._register()
        health_url = 'http://test_endpoint/machine?comp=health'
        self.assertEqual(
            [mock.call(health_url, data=mock.ANY, extra_headers=mock.ANY)],
            self.AzureEndpointHttpClient.return_value.post.call_args_list)

    def test_goal_state_values_used_for_report_ready(self):
        goal_state = self.GoalState.return_value
        goal_state.incarnation = 'TestIncarnation'
        goal_state.container_id = 'TestContainerId'
        goal_state.instance_id = 'TestInstanceId'
        self._register()
        post = self.AzureEndpointHttpClient.return_value.post
        posted_document = post.call_args[1]['data']
        for value in ('TestIncarnation', 'TestContainerId',
                      'TestInstanceId'):
            self.assertIn(value, posted_document)

    def test_clean_up_can_be_called_at_any_time(self):
        # clean_up before registration must not raise.
        azure_helper.WALinuxAgentShim().clean_up()

    def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self):
        agent_shim, _ = self._register()
        agent_shim.clean_up()
        self.assertEqual(
            1, self.OpenSSLManager.return_value.clean_up.call_count)

    def test_failure_to_fetch_goalstate_bubbles_up(self):
        class SentinelException(Exception):
            pass

        client = self.AzureEndpointHttpClient.return_value
        client.get.side_effect = SentinelException
        agent_shim = azure_helper.WALinuxAgentShim()
        self.assertRaises(SentinelException,
                          agent_shim.register_with_azure_and_fetch_data)
class TestGetMetadataFromFabric(TestCase):
    """Tests for azure_helper.get_metadata_from_fabric.

    WALinuxAgentShim is patched in every test. The function should return
    the shim's registration data and call clean_up on success and failure.
    """

    @mock.patch.object(azure_helper, 'WALinuxAgentShim')
    def test_data_from_shim_returned(self, shim):
        fetched = azure_helper.get_metadata_from_fabric()
        registration = shim.return_value.register_with_azure_and_fetch_data
        self.assertEqual(registration.return_value, fetched)

    @mock.patch.object(azure_helper, 'WALinuxAgentShim')
    def test_success_calls_clean_up(self, shim):
        azure_helper.get_metadata_from_fabric()
        cleanup_calls = shim.return_value.clean_up.call_count
        self.assertEqual(1, cleanup_calls)

    @mock.patch.object(azure_helper, 'WALinuxAgentShim')
    def test_failure_in_registration_calls_clean_up(self, shim):
        class SentinelException(Exception):
            pass

        registration = shim.return_value.register_with_azure_and_fetch_data
        registration.side_effect = SentinelException
        self.assertRaises(SentinelException,
                          azure_helper.get_metadata_from_fabric)
        self.assertEqual(1, shim.return_value.clean_up.call_count)