summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcloudinit/sources/DataSourceAzure.py53
-rwxr-xr-xcloudinit/sources/helpers/azure.py50
-rw-r--r--doc/rtd/topics/datasources/azure.rst6
-rw-r--r--tests/unittests/test_datasource/test_azure.py64
-rw-r--r--tests/unittests/test_datasource/test_azure_helper.py13
5 files changed, 156 insertions, 30 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py
index f3c6452b..e98fd497 100755
--- a/cloudinit/sources/DataSourceAzure.py
+++ b/cloudinit/sources/DataSourceAzure.py
@@ -561,6 +561,40 @@ class DataSourceAzure(sources.DataSource):
def device_name_to_device(self, name):
return self.ds_cfg['disk_aliases'].get(name)
+ @azure_ds_telemetry_reporter
+ def get_public_ssh_keys(self):
+ """
+ Try to get the ssh keys from IMDS first, and if that fails
+ (i.e. IMDS is unavailable) then fallback to getting the ssh
+ keys from OVF.
+
+ The benefit to getting keys from IMDS is a large performance
+ advantage, so this is a strong preference. But we must keep
+ OVF as a second option for environments that don't have IMDS.
+ """
+ LOG.debug('Retrieving public SSH keys')
+ ssh_keys = []
+ try:
+ ssh_keys = [
+ public_key['keyData']
+ for public_key
+ in self.metadata['imds']['compute']['publicKeys']
+ ]
+ LOG.debug('Retrieved SSH keys from IMDS')
+ except KeyError:
+ log_msg = 'Unable to get keys from IMDS, falling back to OVF'
+ LOG.debug(log_msg)
+ report_diagnostic_event(log_msg)
+ try:
+ ssh_keys = self.metadata['public-keys']
+ LOG.debug('Retrieved keys from OVF')
+ except KeyError:
+ log_msg = 'No keys available from OVF'
+ LOG.debug(log_msg)
+ report_diagnostic_event(log_msg)
+
+ return ssh_keys
+
def get_config_obj(self):
return self.cfg
@@ -764,7 +798,22 @@ class DataSourceAzure(sources.DataSource):
if self.ds_cfg['agent_command'] == AGENT_START_BUILTIN:
self.bounce_network_with_azure_hostname()
- pubkey_info = self.cfg.get('_pubkeys', None)
+ pubkey_info = None
+ try:
+ public_keys = self.metadata['imds']['compute']['publicKeys']
+ LOG.debug(
+ 'Successfully retrieved %s key(s) from IMDS',
+ len(public_keys)
+ if public_keys is not None
+ else 0
+ )
+ except KeyError:
+ LOG.debug(
+ 'Unable to retrieve SSH keys from IMDS during '
+ 'negotiation, falling back to OVF'
+ )
+ pubkey_info = self.cfg.get('_pubkeys', None)
+
metadata_func = partial(get_metadata_from_fabric,
fallback_lease_file=self.
dhclient_lease_file,
@@ -1443,7 +1492,7 @@ def get_metadata_from_imds(fallback_nic, retries):
@azure_ds_telemetry_reporter
def _get_metadata_from_imds(retries):
- url = IMDS_URL + "instance?api-version=2017-12-01"
+ url = IMDS_URL + "instance?api-version=2019-06-01"
headers = {"Metadata": "true"}
try:
response = readurl(
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index 507f6ac8..79445a81 100755
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -288,12 +288,16 @@ class InvalidGoalStateXMLException(Exception):
class GoalState:
- def __init__(self, unparsed_xml: str,
- azure_endpoint_client: AzureEndpointHttpClient) -> None:
+ def __init__(
+ self,
+ unparsed_xml: str,
+ azure_endpoint_client: AzureEndpointHttpClient,
+ need_certificate: bool = True) -> None:
"""Parses a GoalState XML string and returns a GoalState object.
@param unparsed_xml: string representing a GoalState XML.
- @param azure_endpoint_client: instance of AzureEndpointHttpClient
+ @param azure_endpoint_client: instance of AzureEndpointHttpClient.
+ @param need_certificate: switch to know if certificates is needed.
@return: GoalState object representing the GoalState XML string.
"""
self.azure_endpoint_client = azure_endpoint_client
@@ -322,7 +326,7 @@ class GoalState:
url = self._text_from_xpath(
'./Container/RoleInstanceList/RoleInstance'
'/Configuration/Certificates')
- if url is not None:
+ if url is not None and need_certificate:
with events.ReportEventStack(
name="get-certificates-xml",
description="get certificates xml",
@@ -741,27 +745,38 @@ class WALinuxAgentShim:
GoalState.
@return: The list of user's authorized pubkey values.
"""
- if self.openssl_manager is None:
+ http_client_certificate = None
+ if self.openssl_manager is None and pubkey_info is not None:
self.openssl_manager = OpenSSLManager()
+ http_client_certificate = self.openssl_manager.certificate
if self.azure_endpoint_client is None:
self.azure_endpoint_client = AzureEndpointHttpClient(
- self.openssl_manager.certificate)
- goal_state = self._fetch_goal_state_from_azure()
- ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info)
+ http_client_certificate)
+ goal_state = self._fetch_goal_state_from_azure(
+ need_certificate=http_client_certificate is not None
+ )
+ ssh_keys = None
+ if pubkey_info is not None:
+ ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info)
health_reporter = GoalStateHealthReporter(
goal_state, self.azure_endpoint_client, self.endpoint)
health_reporter.send_ready_signal()
return {'public-keys': ssh_keys}
@azure_ds_telemetry_reporter
- def _fetch_goal_state_from_azure(self) -> GoalState:
+ def _fetch_goal_state_from_azure(
+ self,
+ need_certificate: bool) -> GoalState:
"""Fetches the GoalState XML from the Azure endpoint, parses the XML,
and returns a GoalState object.
@return: GoalState object representing the GoalState XML
"""
unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure()
- return self._parse_raw_goal_state_xml(unparsed_goal_state_xml)
+ return self._parse_raw_goal_state_xml(
+ unparsed_goal_state_xml,
+ need_certificate
+ )
@azure_ds_telemetry_reporter
def _get_raw_goal_state_xml_from_azure(self) -> str:
@@ -774,7 +789,11 @@ class WALinuxAgentShim:
LOG.info('Registering with Azure...')
url = 'http://{}/machine/?comp=goalstate'.format(self.endpoint)
try:
- response = self.azure_endpoint_client.get(url)
+ with events.ReportEventStack(
+ name="goalstate-retrieval",
+ description="retrieve goalstate",
+ parent=azure_ds_reporter):
+ response = self.azure_endpoint_client.get(url)
except Exception as e:
msg = 'failed to register with Azure: %s' % e
LOG.warning(msg)
@@ -785,7 +804,9 @@ class WALinuxAgentShim:
@azure_ds_telemetry_reporter
def _parse_raw_goal_state_xml(
- self, unparsed_goal_state_xml: str) -> GoalState:
+ self,
+ unparsed_goal_state_xml: str,
+ need_certificate: bool) -> GoalState:
"""Parses a GoalState XML string and returns a GoalState object.
@param unparsed_goal_state_xml: GoalState XML string
@@ -793,7 +814,10 @@ class WALinuxAgentShim:
"""
try:
goal_state = GoalState(
- unparsed_goal_state_xml, self.azure_endpoint_client)
+ unparsed_goal_state_xml,
+ self.azure_endpoint_client,
+ need_certificate
+ )
except Exception as e:
msg = 'Error processing GoalState XML: %s' % e
LOG.warning(msg)
diff --git a/doc/rtd/topics/datasources/azure.rst b/doc/rtd/topics/datasources/azure.rst
index fdb919a5..e04c3a33 100644
--- a/doc/rtd/topics/datasources/azure.rst
+++ b/doc/rtd/topics/datasources/azure.rst
@@ -68,6 +68,12 @@ configuration information to the instance. Cloud-init uses the IMDS for:
- network configuration for the instance which is applied per boot
- a preprovisioing gate which blocks instance configuration until Azure fabric
is ready to provision
+- retrieving SSH public keys. Cloud-init will first try to utilize SSH keys
+ returned from IMDS, and if they are not provided from IMDS then it will
+ fallback to using the OVF file provided from the CD-ROM. There is a large
+ performance benefit to using IMDS for SSH key retrieval, but in order to
+ support environments where IMDS is not available then we must continue to
+ all for keys from OVF
Configuration
diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py
index 47e03bd1..2dda9925 100644
--- a/tests/unittests/test_datasource/test_azure.py
+++ b/tests/unittests/test_datasource/test_azure.py
@@ -102,7 +102,13 @@ NETWORK_METADATA = {
"vmId": "ff702a6b-cb6a-4fcd-ad68-b4ce38227642",
"vmScaleSetName": "",
"vmSize": "Standard_DS1_v2",
- "zone": ""
+ "zone": "",
+ "publicKeys": [
+ {
+ "keyData": "key1",
+ "path": "path1"
+ }
+ ]
},
"network": {
"interface": [
@@ -302,7 +308,7 @@ class TestGetMetadataFromIMDS(HttprettyTestCase):
def setUp(self):
super(TestGetMetadataFromIMDS, self).setUp()
- self.network_md_url = dsaz.IMDS_URL + "instance?api-version=2017-12-01"
+ self.network_md_url = dsaz.IMDS_URL + "instance?api-version=2019-06-01"
@mock.patch(MOCKPATH + 'readurl')
@mock.patch(MOCKPATH + 'EphemeralDHCPv4')
@@ -1304,6 +1310,40 @@ scbus-1 on xpt0 bus 0
dsaz.get_hostname(hostname_command=("hostname",))
m_subp.assert_called_once_with(("hostname",), capture=True)
+ @mock.patch(
+ 'cloudinit.sources.helpers.azure.OpenSSLManager.parse_certificates')
+ def test_get_public_ssh_keys_with_imds(self, m_parse_certificates):
+ sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}}
+ odata = {'HostName': "myhost", 'UserName': "myuser"}
+ data = {
+ 'ovfcontent': construct_valid_ovf_env(data=odata),
+ 'sys_cfg': sys_cfg
+ }
+ dsrc = self._get_ds(data)
+ dsrc.get_data()
+ dsrc.setup(True)
+ ssh_keys = dsrc.get_public_ssh_keys()
+ self.assertEqual(ssh_keys, ['key1'])
+ self.assertEqual(m_parse_certificates.call_count, 0)
+
+ @mock.patch(MOCKPATH + 'get_metadata_from_imds')
+ def test_get_public_ssh_keys_without_imds(
+ self,
+ m_get_metadata_from_imds):
+ m_get_metadata_from_imds.return_value = dict()
+ sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}}
+ odata = {'HostName': "myhost", 'UserName': "myuser"}
+ data = {
+ 'ovfcontent': construct_valid_ovf_env(data=odata),
+ 'sys_cfg': sys_cfg
+ }
+ dsrc = self._get_ds(data)
+ dsaz.get_metadata_from_fabric.return_value = {'public-keys': ['key2']}
+ dsrc.get_data()
+ dsrc.setup(True)
+ ssh_keys = dsrc.get_public_ssh_keys()
+ self.assertEqual(ssh_keys, ['key2'])
+
class TestAzureBounce(CiTestCase):
@@ -2094,14 +2134,18 @@ class TestAzureDataSourcePreprovisioning(CiTestCase):
md, _ud, cfg, _d = dsa._reprovision()
self.assertEqual(md['local-hostname'], hostname)
self.assertEqual(cfg['system_info']['default_user']['name'], username)
- self.assertEqual(fake_resp.call_args_list,
- [mock.call(allow_redirects=True,
- headers={'Metadata': 'true',
- 'User-Agent':
- 'Cloud-Init/%s' % vs()},
- method='GET',
- timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS,
- url=full_url)])
+ self.assertIn(
+ mock.call(
+ allow_redirects=True,
+ headers={
+ 'Metadata': 'true',
+ 'User-Agent': 'Cloud-Init/%s' % vs()
+ },
+ method='GET',
+ timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS,
+ url=full_url
+ ),
+ fake_resp.call_args_list)
self.assertEqual(m_dhcp.call_count, 2)
m_net.assert_any_call(
broadcast='192.168.2.255', interface='eth9', ip='192.168.2.9',
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py
index 5e6d3d2d..5c31b8be 100644
--- a/tests/unittests/test_datasource/test_azure_helper.py
+++ b/tests/unittests/test_datasource/test_azure_helper.py
@@ -609,11 +609,11 @@ class TestWALinuxAgentShim(CiTestCase):
self.GoalState.return_value.container_id = self.test_container_id
self.GoalState.return_value.instance_id = self.test_instance_id
- def test_azure_endpoint_client_uses_certificate_during_report_ready(self):
+ def test_http_client_does_not_use_certificate(self):
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
self.assertEqual(
- [mock.call(self.OpenSSLManager.return_value.certificate)],
+ [mock.call(None)],
self.AzureEndpointHttpClient.call_args_list)
def test_correct_url_used_for_goalstate_during_report_ready(self):
@@ -625,8 +625,11 @@ class TestWALinuxAgentShim(CiTestCase):
[mock.call('http://test_endpoint/machine/?comp=goalstate')],
get.call_args_list)
self.assertEqual(
- [mock.call(get.return_value.contents,
- self.AzureEndpointHttpClient.return_value)],
+ [mock.call(
+ get.return_value.contents,
+ self.AzureEndpointHttpClient.return_value,
+ False
+ )],
self.GoalState.call_args_list)
def test_certificates_used_to_determine_public_keys(self):
@@ -701,7 +704,7 @@ class TestWALinuxAgentShim(CiTestCase):
shim.register_with_azure_and_fetch_data()
shim.clean_up()
self.assertEqual(
- 1, self.OpenSSLManager.return_value.clean_up.call_count)
+ 0, self.OpenSSLManager.return_value.clean_up.call_count)
def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self):
self.AzureEndpointHttpClient.return_value.get \