diff options
-rwxr-xr-x | cloudinit/sources/DataSourceAzure.py | 53 | ||||
-rwxr-xr-x | cloudinit/sources/helpers/azure.py | 50 | ||||
-rw-r--r-- | doc/rtd/topics/datasources/azure.rst | 6 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 64 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure_helper.py | 13 |
5 files changed, 156 insertions, 30 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index f3c6452b..e98fd497 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -561,6 +561,40 @@ class DataSourceAzure(sources.DataSource): def device_name_to_device(self, name): return self.ds_cfg['disk_aliases'].get(name) + @azure_ds_telemetry_reporter + def get_public_ssh_keys(self): + """ + Try to get the ssh keys from IMDS first, and if that fails + (i.e. IMDS is unavailable) then fallback to getting the ssh + keys from OVF. + + The benefit to getting keys from IMDS is a large performance + advantage, so this is a strong preference. But we must keep + OVF as a second option for environments that don't have IMDS. + """ + LOG.debug('Retrieving public SSH keys') + ssh_keys = [] + try: + ssh_keys = [ + public_key['keyData'] + for public_key + in self.metadata['imds']['compute']['publicKeys'] + ] + LOG.debug('Retrieved SSH keys from IMDS') + except KeyError: + log_msg = 'Unable to get keys from IMDS, falling back to OVF' + LOG.debug(log_msg) + report_diagnostic_event(log_msg) + try: + ssh_keys = self.metadata['public-keys'] + LOG.debug('Retrieved keys from OVF') + except KeyError: + log_msg = 'No keys available from OVF' + LOG.debug(log_msg) + report_diagnostic_event(log_msg) + + return ssh_keys + def get_config_obj(self): return self.cfg @@ -764,7 +798,22 @@ class DataSourceAzure(sources.DataSource): if self.ds_cfg['agent_command'] == AGENT_START_BUILTIN: self.bounce_network_with_azure_hostname() - pubkey_info = self.cfg.get('_pubkeys', None) + pubkey_info = None + try: + public_keys = self.metadata['imds']['compute']['publicKeys'] + LOG.debug( + 'Successfully retrieved %s key(s) from IMDS', + len(public_keys) + if public_keys is not None + else 0 + ) + except KeyError: + LOG.debug( + 'Unable to retrieve SSH keys from IMDS during ' + 'negotiation, falling back to OVF' + ) + pubkey_info = self.cfg.get('_pubkeys', None) + metadata_func = partial(get_metadata_from_fabric, fallback_lease_file=self. dhclient_lease_file, @@ -1443,7 +1492,7 @@ def get_metadata_from_imds(fallback_nic, retries): @azure_ds_telemetry_reporter def _get_metadata_from_imds(retries): - url = IMDS_URL + "instance?api-version=2017-12-01" + url = IMDS_URL + "instance?api-version=2019-06-01" headers = {"Metadata": "true"} try: response = readurl( diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 507f6ac8..79445a81 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -288,12 +288,16 @@ class InvalidGoalStateXMLException(Exception): class GoalState: - def __init__(self, unparsed_xml: str, - azure_endpoint_client: AzureEndpointHttpClient) -> None: + def __init__( + self, + unparsed_xml: str, + azure_endpoint_client: AzureEndpointHttpClient, + need_certificate: bool = True) -> None: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_xml: string representing a GoalState XML. - @param azure_endpoint_client: instance of AzureEndpointHttpClient + @param azure_endpoint_client: instance of AzureEndpointHttpClient. + @param need_certificate: switch to know if certificates is needed. @return: GoalState object representing the GoalState XML string. """ self.azure_endpoint_client = azure_endpoint_client @@ -322,7 +326,7 @@ class GoalState: url = self._text_from_xpath( './Container/RoleInstanceList/RoleInstance' '/Configuration/Certificates') - if url is not None: + if url is not None and need_certificate: with events.ReportEventStack( name="get-certificates-xml", description="get certificates xml", @@ -741,27 +745,38 @@ class WALinuxAgentShim: GoalState. @return: The list of user's authorized pubkey values. """ - if self.openssl_manager is None: + http_client_certificate = None + if self.openssl_manager is None and pubkey_info is not None: self.openssl_manager = OpenSSLManager() + http_client_certificate = self.openssl_manager.certificate if self.azure_endpoint_client is None: self.azure_endpoint_client = AzureEndpointHttpClient( - self.openssl_manager.certificate) - goal_state = self._fetch_goal_state_from_azure() - ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) + http_client_certificate) + goal_state = self._fetch_goal_state_from_azure( + need_certificate=http_client_certificate is not None + ) + ssh_keys = None + if pubkey_info is not None: + ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) health_reporter = GoalStateHealthReporter( goal_state, self.azure_endpoint_client, self.endpoint) health_reporter.send_ready_signal() return {'public-keys': ssh_keys} @azure_ds_telemetry_reporter - def _fetch_goal_state_from_azure(self) -> GoalState: + def _fetch_goal_state_from_azure( + self, + need_certificate: bool) -> GoalState: """Fetches the GoalState XML from the Azure endpoint, parses the XML, and returns a GoalState object. @return: GoalState object representing the GoalState XML """ unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure() - return self._parse_raw_goal_state_xml(unparsed_goal_state_xml) + return self._parse_raw_goal_state_xml( + unparsed_goal_state_xml, + need_certificate + ) @azure_ds_telemetry_reporter def _get_raw_goal_state_xml_from_azure(self) -> str: @@ -774,7 +789,11 @@ class WALinuxAgentShim: LOG.info('Registering with Azure...') url = 'http://{}/machine/?comp=goalstate'.format(self.endpoint) try: - response = self.azure_endpoint_client.get(url) + with events.ReportEventStack( + name="goalstate-retrieval", + description="retrieve goalstate", + parent=azure_ds_reporter): + response = self.azure_endpoint_client.get(url) except Exception as e: msg = 'failed to register with Azure: %s' % e LOG.warning(msg) @@ -785,7 +804,9 @@ class WALinuxAgentShim: @azure_ds_telemetry_reporter def _parse_raw_goal_state_xml( - self, unparsed_goal_state_xml: str) -> GoalState: + self, + unparsed_goal_state_xml: str, + need_certificate: bool) -> GoalState: """Parses a GoalState XML string and returns a GoalState object. @param unparsed_goal_state_xml: GoalState XML string @@ -793,7 +814,10 @@ class WALinuxAgentShim: """ try: goal_state = GoalState( - unparsed_goal_state_xml, self.azure_endpoint_client) + unparsed_goal_state_xml, + self.azure_endpoint_client, + need_certificate + ) except Exception as e: msg = 'Error processing GoalState XML: %s' % e LOG.warning(msg) diff --git a/doc/rtd/topics/datasources/azure.rst b/doc/rtd/topics/datasources/azure.rst index fdb919a5..e04c3a33 100644 --- a/doc/rtd/topics/datasources/azure.rst +++ b/doc/rtd/topics/datasources/azure.rst @@ -68,6 +68,12 @@ configuration information to the instance. Cloud-init uses the IMDS for: - network configuration for the instance which is applied per boot - a preprovisioing gate which blocks instance configuration until Azure fabric is ready to provision +- retrieving SSH public keys. Cloud-init will first try to utilize SSH keys + returned from IMDS, and if they are not provided from IMDS then it will + fallback to using the OVF file provided from the CD-ROM. There is a large + performance benefit to using IMDS for SSH key retrieval, but in order to + support environments where IMDS is not available then we must continue to + all for keys from OVF Configuration diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 47e03bd1..2dda9925 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -102,7 +102,13 @@ NETWORK_METADATA = { "vmId": "ff702a6b-cb6a-4fcd-ad68-b4ce38227642", "vmScaleSetName": "", "vmSize": "Standard_DS1_v2", - "zone": "" + "zone": "", + "publicKeys": [ + { + "keyData": "key1", + "path": "path1" + } + ] }, "network": { "interface": [ @@ -302,7 +308,7 @@ class TestGetMetadataFromIMDS(HttprettyTestCase): def setUp(self): super(TestGetMetadataFromIMDS, self).setUp() - self.network_md_url = dsaz.IMDS_URL + "instance?api-version=2017-12-01" + self.network_md_url = dsaz.IMDS_URL + "instance?api-version=2019-06-01" @mock.patch(MOCKPATH + 'readurl') @mock.patch(MOCKPATH + 'EphemeralDHCPv4') @@ -1304,6 +1310,40 @@ scbus-1 on xpt0 bus 0 dsaz.get_hostname(hostname_command=("hostname",)) m_subp.assert_called_once_with(("hostname",), capture=True) + @mock.patch( + 'cloudinit.sources.helpers.azure.OpenSSLManager.parse_certificates') + def test_get_public_ssh_keys_with_imds(self, m_parse_certificates): + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + dsrc = self._get_ds(data) + dsrc.get_data() + dsrc.setup(True) + ssh_keys = dsrc.get_public_ssh_keys() + self.assertEqual(ssh_keys, ['key1']) + self.assertEqual(m_parse_certificates.call_count, 0) + + @mock.patch(MOCKPATH + 'get_metadata_from_imds') + def test_get_public_ssh_keys_without_imds( + self, + m_get_metadata_from_imds): + m_get_metadata_from_imds.return_value = dict() + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + dsrc = self._get_ds(data) + dsaz.get_metadata_from_fabric.return_value = {'public-keys': ['key2']} + dsrc.get_data() + dsrc.setup(True) + ssh_keys = dsrc.get_public_ssh_keys() + self.assertEqual(ssh_keys, ['key2']) + class TestAzureBounce(CiTestCase): @@ -2094,14 +2134,18 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): md, _ud, cfg, _d = dsa._reprovision() self.assertEqual(md['local-hostname'], hostname) self.assertEqual(cfg['system_info']['default_user']['name'], username) - self.assertEqual(fake_resp.call_args_list, - [mock.call(allow_redirects=True, - headers={'Metadata': 'true', - 'User-Agent': - 'Cloud-Init/%s' % vs()}, - method='GET', - timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, - url=full_url)]) + self.assertIn( + mock.call( + allow_redirects=True, + headers={ + 'Metadata': 'true', + 'User-Agent': 'Cloud-Init/%s' % vs() + }, + method='GET', + timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, + url=full_url + ), + fake_resp.call_args_list) self.assertEqual(m_dhcp.call_count, 2) m_net.assert_any_call( broadcast='192.168.2.255', interface='eth9', ip='192.168.2.9', diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 5e6d3d2d..5c31b8be 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -609,11 +609,11 @@ class TestWALinuxAgentShim(CiTestCase): self.GoalState.return_value.container_id = self.test_container_id self.GoalState.return_value.instance_id = self.test_instance_id - def test_azure_endpoint_client_uses_certificate_during_report_ready(self): + def test_http_client_does_not_use_certificate(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( - [mock.call(self.OpenSSLManager.return_value.certificate)], + [mock.call(None)], self.AzureEndpointHttpClient.call_args_list) def test_correct_url_used_for_goalstate_during_report_ready(self): @@ -625,8 +625,11 @@ class TestWALinuxAgentShim(CiTestCase): [mock.call('http://test_endpoint/machine/?comp=goalstate')], get.call_args_list) self.assertEqual( - [mock.call(get.return_value.contents, - self.AzureEndpointHttpClient.return_value)], + [mock.call( + get.return_value.contents, + self.AzureEndpointHttpClient.return_value, + False + )], self.GoalState.call_args_list) def test_certificates_used_to_determine_public_keys(self): @@ -701,7 +704,7 @@ class TestWALinuxAgentShim(CiTestCase): shim.register_with_azure_and_fetch_data() shim.clean_up() self.assertEqual( - 1, self.OpenSSLManager.return_value.clean_up.call_count) + 0, self.OpenSSLManager.return_value.clean_up.call_count) def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): self.AzureEndpointHttpClient.return_value.get \ |