summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/sources/DataSourceAzure.py8
-rw-r--r--cloudinit/sources/helpers/azure.py31
-rw-r--r--tests/unittests/test_datasource/test_azure.py19
-rw-r--r--tests/unittests/test_datasource/test_azure_helper.py56
4 files changed, 92 insertions, 22 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py
index 5e147950..4053cfa6 100644
--- a/cloudinit/sources/DataSourceAzure.py
+++ b/cloudinit/sources/DataSourceAzure.py
@@ -29,7 +29,7 @@ from cloudinit.settings import PER_ALWAYS
from cloudinit import sources
from cloudinit import util
from cloudinit.sources.helpers.azure import (
- iid_from_shared_config_content, WALinuxAgentShim)
+ get_metadata_from_fabric, iid_from_shared_config_content)
LOG = logging.getLogger(__name__)
@@ -185,15 +185,13 @@ class DataSourceAzureNet(sources.DataSource):
write_files(ddir, files, dirmode=0o700)
try:
- shim = WALinuxAgentShim()
- data = shim.register_with_azure_and_fetch_data()
+ fabric_data = get_metadata_from_fabric()
except Exception as exc:
LOG.info("Error communicating with Azure fabric; assume we aren't"
" on Azure.", exc_info=True)
return False
- self.metadata['instance-id'] = data['instance-id']
- self.metadata['public-keys'] = data['public-keys']
+ self.metadata.update(fabric_data)
found_ephemeral = find_ephemeral_disk()
if found_ephemeral:
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index cb13187f..dfdfa7c2 100644
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -108,6 +108,9 @@ class OpenSSLManager(object):
self.certificate = None
self.generate_certificate()
+ def clean_up(self):
+ util.del_dir(self.tmpdir)
+
def generate_certificate(self):
LOG.debug('Generating certificate for communication with fabric...')
if self.certificate is not None:
@@ -205,11 +208,13 @@ class WALinuxAgentShim(object):
def __init__(self):
LOG.debug('WALinuxAgentShim instantiated...')
self.endpoint = self.find_endpoint()
- self.openssl_manager = OpenSSLManager()
- self.http_client = AzureEndpointHttpClient(
- self.openssl_manager.certificate)
+ self.openssl_manager = None
self.values = {}
+ def clean_up(self):
+ if self.openssl_manager is not None:
+ self.openssl_manager.clean_up()
+
@staticmethod
def find_endpoint():
LOG.debug('Finding Azure endpoint...')
@@ -234,17 +239,19 @@ class WALinuxAgentShim(object):
return endpoint_ip_address
def register_with_azure_and_fetch_data(self):
+ self.openssl_manager = OpenSSLManager()
+ http_client = AzureEndpointHttpClient(self.openssl_manager.certificate)
LOG.info('Registering with Azure...')
for i in range(10):
try:
- response = self.http_client.get(
+ response = http_client.get(
'http://{}/machine/?comp=goalstate'.format(self.endpoint))
except Exception:
time.sleep(i + 1)
else:
break
LOG.debug('Successfully fetched GoalState XML.')
- goal_state = GoalState(response.contents, self.http_client)
+ goal_state = GoalState(response.contents, http_client)
public_keys = []
if goal_state.certificates_xml is not None:
LOG.debug('Certificate XML found; parsing out public keys.')
@@ -255,19 +262,27 @@ class WALinuxAgentShim(object):
goal_state.shared_config_xml),
'public-keys': public_keys,
}
- self._report_ready(goal_state)
+ self._report_ready(goal_state, http_client)
return data
- def _report_ready(self, goal_state):
+ def _report_ready(self, goal_state, http_client):
LOG.debug('Reporting ready to Azure fabric.')
document = self.REPORT_READY_XML_TEMPLATE.format(
incarnation=goal_state.incarnation,
container_id=goal_state.container_id,
instance_id=goal_state.instance_id,
)
- self.http_client.post(
+ http_client.post(
"http://{}/machine?comp=health".format(self.endpoint),
data=document,
extra_headers={'Content-Type': 'text/xml; charset=utf-8'},
)
LOG.info('Reported ready to Azure fabric.')
+
+
+def get_metadata_from_fabric():
+ shim = WALinuxAgentShim()
+ try:
+ return shim.register_with_azure_and_fetch_data()
+ finally:
+ shim.clean_up()
diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py
index ee7109e1..983be4cd 100644
--- a/tests/unittests/test_datasource/test_azure.py
+++ b/tests/unittests/test_datasource/test_azure.py
@@ -122,11 +122,10 @@ class TestAzureDataSource(TestCase):
mod = DataSourceAzure
mod.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d
- fake_shim = mock.MagicMock()
- fake_shim().register_with_azure_and_fetch_data.return_value = {
+ self.get_metadata_from_fabric = mock.MagicMock(return_value={
'instance-id': 'i-my-azure-id',
'public-keys': [],
- }
+ })
self.apply_patches([
(mod, 'list_possible_azure_ds_devs', dsdevs),
@@ -137,7 +136,7 @@ class TestAzureDataSource(TestCase):
(mod, 'perform_hostname_bounce', mock.MagicMock()),
(mod, 'get_hostname', mock.MagicMock()),
(mod, 'set_hostname', mock.MagicMock()),
- (mod, 'WALinuxAgentShim', fake_shim),
+ (mod, 'get_metadata_from_fabric', self.get_metadata_from_fabric),
])
dsrc = mod.DataSourceAzureNet(
@@ -388,6 +387,18 @@ class TestAzureDataSource(TestCase):
self.assertEqual(new_ovfenv,
load_file(os.path.join(self.waagent_d, 'ovf-env.xml')))
+ def test_exception_fetching_fabric_data_doesnt_propagate(self):
+ ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
+ self.get_metadata_from_fabric.side_effect = Exception
+ self.assertFalse(ds.get_data())
+
+ def test_fabric_data_included_in_metadata(self):
+ ds = self._get_ds({'ovfcontent': construct_valid_ovf_env()})
+ self.get_metadata_from_fabric.return_value = {'test': 'value'}
+ ret = ds.get_data()
+ self.assertTrue(ret)
+ self.assertEqual('value', ds.metadata['test'])
+
class TestAzureBounce(TestCase):
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py
index 398a9007..5fac2ade 100644
--- a/tests/unittests/test_datasource/test_azure_helper.py
+++ b/tests/unittests/test_datasource/test_azure_helper.py
@@ -296,6 +296,14 @@ class TestOpenSSLManager(TestCase):
manager = azure_helper.OpenSSLManager()
self.assertEqual(manager.tmpdir, subp_directory['path'])
+ @mock.patch.object(azure_helper, 'cd', mock.MagicMock())
+ @mock.patch.object(azure_helper.tempfile, 'mkdtemp', mock.MagicMock())
+ @mock.patch.object(azure_helper.util, 'del_dir')
+ def test_clean_up(self, del_dir):
+ manager = azure_helper.OpenSSLManager()
+ manager.clean_up()
+ self.assertEqual([mock.call(manager.tmpdir)], del_dir.call_args_list)
+
class TestWALinuxAgentShim(TestCase):
@@ -318,11 +326,10 @@ class TestWALinuxAgentShim(TestCase):
def test_http_client_uses_certificate(self):
shim = azure_helper.WALinuxAgentShim()
+ shim.register_with_azure_and_fetch_data()
self.assertEqual(
[mock.call(self.OpenSSLManager.return_value.certificate)],
self.AzureEndpointHttpClient.call_args_list)
- self.assertEqual(self.AzureEndpointHttpClient.return_value,
- shim.http_client)
def test_correct_url_used_for_goalstate(self):
self.find_endpoint.return_value = 'test_endpoint'
@@ -333,7 +340,8 @@ class TestWALinuxAgentShim(TestCase):
[mock.call('http://test_endpoint/machine/?comp=goalstate')],
get.call_args_list)
self.assertEqual(
- [mock.call(get.return_value.contents, shim.http_client)],
+ [mock.call(get.return_value.contents,
+ self.AzureEndpointHttpClient.return_value)],
self.GoalState.call_args_list)
def test_certificates_used_to_determine_public_keys(self):
@@ -368,7 +376,7 @@ class TestWALinuxAgentShim(TestCase):
expected_url = 'http://test_endpoint/machine?comp=health'
self.assertEqual(
[mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)],
- shim.http_client.post.call_args_list)
+ self.AzureEndpointHttpClient.return_value.post.call_args_list)
def test_goal_state_values_used_for_report_ready(self):
self.GoalState.return_value.incarnation = 'TestIncarnation'
@@ -376,7 +384,45 @@ class TestWALinuxAgentShim(TestCase):
self.GoalState.return_value.instance_id = 'TestInstanceId'
shim = azure_helper.WALinuxAgentShim()
shim.register_with_azure_and_fetch_data()
- posted_document = shim.http_client.post.call_args[1]['data']
+ posted_document = (
+ self.AzureEndpointHttpClient.return_value.post.call_args[1]['data']
+ )
self.assertIn('TestIncarnation', posted_document)
self.assertIn('TestContainerId', posted_document)
self.assertIn('TestInstanceId', posted_document)
+
+ def test_clean_up_can_be_called_at_any_time(self):
+ shim = azure_helper.WALinuxAgentShim()
+ shim.clean_up()
+
+ def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self):
+ shim = azure_helper.WALinuxAgentShim()
+ shim.register_with_azure_and_fetch_data()
+ shim.clean_up()
+ self.assertEqual(
+ 1, self.OpenSSLManager.return_value.clean_up.call_count)
+
+
+class TestGetMetadataFromFabric(TestCase):
+
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_data_from_shim_returned(self, shim):
+ ret = azure_helper.get_metadata_from_fabric()
+ self.assertEqual(
+ shim.return_value.register_with_azure_and_fetch_data.return_value,
+ ret)
+
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_success_calls_clean_up(self, shim):
+ azure_helper.get_metadata_from_fabric()
+ self.assertEqual(1, shim.return_value.clean_up.call_count)
+
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_failure_in_registration_calls_clean_up(self, shim):
+ class SentinelException(Exception):
+ pass
+ shim.return_value.register_with_azure_and_fetch_data.side_effect = (
+ SentinelException)
+ self.assertRaises(SentinelException,
+ azure_helper.get_metadata_from_fabric)
+ self.assertEqual(1, shim.return_value.clean_up.call_count)