diff options
-rwxr-xr-x | cloudinit/sources/DataSourceAzure.py | 149 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_azure.py | 87 |
2 files changed, 205 insertions, 31 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index eac6405a..f4fc91cd 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -5,6 +5,7 @@ # This file is part of cloud-init. See LICENSE file for license information. import base64 +from collections import namedtuple import contextlib import crypt from functools import partial @@ -25,6 +26,7 @@ from cloudinit.net import device_driver from cloudinit.net.dhcp import EphemeralDHCPv4 from cloudinit import sources from cloudinit.sources.helpers import netlink +from cloudinit import ssh_util from cloudinit import subp from cloudinit.url_helper import UrlError, readurl, retry_on_url_exc from cloudinit import util @@ -80,7 +82,12 @@ AGENT_SEED_DIR = '/var/lib/waagent' IMDS_TIMEOUT_IN_SECONDS = 2 IMDS_URL = "http://169.254.169.254/metadata" IMDS_VER_MIN = "2019-06-01" -IMDS_VER_WANT = "2020-09-01" +IMDS_VER_WANT = "2020-10-01" + + +# This holds SSH key data including if the source was +# from IMDS, as well as the SSH key data itself. +SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys")) class metadata_type(Enum): @@ -391,6 +398,8 @@ class DataSourceAzure(sources.DataSource): """Return the subplatform metadata source details.""" if self.seed.startswith('/dev'): subplatform_type = 'config-disk' + elif self.seed.lower() == 'imds': + subplatform_type = 'imds' else: subplatform_type = 'seed-dir' return '%s (%s)' % (subplatform_type, self.seed) @@ -433,9 +442,11 @@ class DataSourceAzure(sources.DataSource): found = None reprovision = False + ovf_is_accessible = True reprovision_after_nic_attach = False for cdev in candidates: try: + LOG.debug("cdev: %s", cdev) if cdev == "IMDS": ret = None reprovision = True @@ -462,8 +473,18 @@ class DataSourceAzure(sources.DataSource): raise sources.InvalidMetaDataException(msg) except util.MountFailedError: report_diagnostic_event( - '%s was not mountable' % cdev, logger_func=LOG.warning) - continue + '%s was not mountable' % cdev, logger_func=LOG.debug) + cdev = 'IMDS' + ovf_is_accessible = False + empty_md = {'local-hostname': ''} + empty_cfg = dict( + system_info=dict( + default_user=dict( + name='' + ) + ) + ) + ret = (empty_md, '', empty_cfg, {}) report_diagnostic_event("Found provisioning metadata in %s" % cdev, logger_func=LOG.debug) @@ -490,6 +511,10 @@ class DataSourceAzure(sources.DataSource): self.fallback_interface, retries=10 ) + if not imds_md and not ovf_is_accessible: + msg = 'No OVF or IMDS available' + report_diagnostic_event(msg) + raise sources.InvalidMetaDataException(msg) (md, userdata_raw, cfg, files) = ret self.seed = cdev crawled_data.update({ @@ -498,6 +523,21 @@ class DataSourceAzure(sources.DataSource): 'metadata': util.mergemanydict( [md, {'imds': imds_md}]), 'userdata_raw': userdata_raw}) + imds_username = _username_from_imds(imds_md) + imds_hostname = _hostname_from_imds(imds_md) + imds_disable_password = _disable_password_from_imds(imds_md) + if imds_username: + LOG.debug('Username retrieved from IMDS: %s', imds_username) + cfg['system_info']['default_user']['name'] = imds_username + if imds_hostname: + LOG.debug('Hostname retrieved from IMDS: %s', imds_hostname) + crawled_data['metadata']['local-hostname'] = imds_hostname + if imds_disable_password: + LOG.debug( + 'Disable password retrieved from IMDS: %s', + imds_disable_password + ) + crawled_data['metadata']['disable_password'] = imds_disable_password # noqa: E501 found = cdev report_diagnostic_event( @@ -677,6 +717,13 @@ class DataSourceAzure(sources.DataSource): @azure_ds_telemetry_reporter def get_public_ssh_keys(self): """ + Retrieve public SSH keys. + """ + + return self._get_public_ssh_keys_and_source().ssh_keys + + def _get_public_ssh_keys_and_source(self): + """ Try to get the ssh keys from IMDS first, and if that fails (i.e. IMDS is unavailable) then fallback to getting the ssh keys from OVF. @@ -685,30 +732,50 @@ class DataSourceAzure(sources.DataSource): advantage, so this is a strong preference. But we must keep OVF as a second option for environments that don't have IMDS. """ + LOG.debug('Retrieving public SSH keys') ssh_keys = [] + keys_from_imds = True + LOG.debug('Attempting to get SSH keys from IMDS') try: - raise KeyError( - "Not using public SSH keys from IMDS" - ) - # pylint:disable=unreachable ssh_keys = [ public_key['keyData'] for public_key in self.metadata['imds']['compute']['publicKeys'] ] - LOG.debug('Retrieved SSH keys from IMDS') + for key in ssh_keys: + if not _key_is_openssh_formatted(key=key): + keys_from_imds = False + break + + if not keys_from_imds: + log_msg = 'Keys not in OpenSSH format, using OVF' + else: + log_msg = 'Retrieved {} keys from IMDS'.format( + len(ssh_keys) + if ssh_keys is not None + else 0 + ) except KeyError: log_msg = 'Unable to get keys from IMDS, falling back to OVF' + keys_from_imds = False + finally: report_diagnostic_event(log_msg, logger_func=LOG.debug) + + if not keys_from_imds: + LOG.debug('Attempting to get SSH keys from OVF') try: ssh_keys = self.metadata['public-keys'] - LOG.debug('Retrieved keys from OVF') + log_msg = 'Retrieved {} keys from OVF'.format(len(ssh_keys)) except KeyError: log_msg = 'No keys available from OVF' + finally: report_diagnostic_event(log_msg, logger_func=LOG.debug) - return ssh_keys + return SSHKeys( + keys_from_imds=keys_from_imds, + ssh_keys=ssh_keys + ) def get_config_obj(self): return self.cfg @@ -1325,30 +1392,21 @@ class DataSourceAzure(sources.DataSource): self.bounce_network_with_azure_hostname() pubkey_info = None - try: - raise KeyError( - "Not using public SSH keys from IMDS" - ) - # pylint:disable=unreachable - public_keys = self.metadata['imds']['compute']['publicKeys'] - LOG.debug( - 'Successfully retrieved %s key(s) from IMDS', - len(public_keys) - if public_keys is not None + ssh_keys_and_source = self._get_public_ssh_keys_and_source() + + if not ssh_keys_and_source.keys_from_imds: + pubkey_info = self.cfg.get('_pubkeys', None) + log_msg = 'Retrieved {} fingerprints from OVF'.format( + len(pubkey_info) + if pubkey_info is not None else 0 ) - except KeyError: - LOG.debug( - 'Unable to retrieve SSH keys from IMDS during ' - 'negotiation, falling back to OVF' - ) - pubkey_info = self.cfg.get('_pubkeys', None) + report_diagnostic_event(log_msg, logger_func=LOG.debug) metadata_func = partial(get_metadata_from_fabric, fallback_lease_file=self. dhclient_lease_file, - pubkey_info=pubkey_info, - iso_dev=self.iso_dev) + pubkey_info=pubkey_info) LOG.debug("negotiating with fabric via agent command %s", self.ds_cfg['agent_command']) @@ -1404,6 +1462,41 @@ class DataSourceAzure(sources.DataSource): return self.metadata.get('imds', {}).get('compute', {}).get('location') +def _username_from_imds(imds_data): + try: + return imds_data['compute']['osProfile']['adminUsername'] + except KeyError: + return None + + +def _hostname_from_imds(imds_data): + try: + return imds_data['compute']['osProfile']['computerName'] + except KeyError: + return None + + +def _disable_password_from_imds(imds_data): + try: + return imds_data['compute']['osProfile']['disablePasswordAuthentication'] == 'true' # noqa: E501 + except KeyError: + return None + + +def _key_is_openssh_formatted(key): + """ + Validate whether or not the key is OpenSSH-formatted. + """ + + parser = ssh_util.AuthKeyLineParser() + try: + akl = parser.parse(key) + except TypeError: + return False + + return akl.keytype is not None + + def _partitions_on_device(devpath, maxnum=16): # return a list of tuples (ptnum, path) for each part on devpath for suff in ("-part", "p", ""): diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py index 320fa857..d9817d84 100644 --- a/tests/unittests/test_datasource/test_azure.py +++ b/tests/unittests/test_datasource/test_azure.py @@ -108,7 +108,7 @@ NETWORK_METADATA = { "zone": "", "publicKeys": [ { - "keyData": "key1", + "keyData": "ssh-rsa key1", "path": "path1" } ] @@ -1761,8 +1761,29 @@ scbus-1 on xpt0 bus 0 dsrc.get_data() dsrc.setup(True) ssh_keys = dsrc.get_public_ssh_keys() - # Temporarily alter this test so that SSH public keys - # from IMDS are *not* going to be in use to fix a regression. + self.assertEqual(ssh_keys, ["ssh-rsa key1"]) + self.assertEqual(m_parse_certificates.call_count, 0) + + @mock.patch( + 'cloudinit.sources.helpers.azure.OpenSSLManager.parse_certificates') + @mock.patch(MOCKPATH + 'get_metadata_from_imds') + def test_get_public_ssh_keys_with_no_openssh_format( + self, + m_get_metadata_from_imds, + m_parse_certificates): + imds_data = copy.deepcopy(NETWORK_METADATA) + imds_data['compute']['publicKeys'][0]['keyData'] = 'no-openssh-format' + m_get_metadata_from_imds.return_value = imds_data + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + dsrc = self._get_ds(data) + dsrc.get_data() + dsrc.setup(True) + ssh_keys = dsrc.get_public_ssh_keys() self.assertEqual(ssh_keys, []) self.assertEqual(m_parse_certificates.call_count, 0) @@ -1818,6 +1839,66 @@ scbus-1 on xpt0 bus 0 self.assertIsNotNone(dsrc.metadata) self.assertFalse(dsrc.failed_desired_api_version) + @mock.patch(MOCKPATH + 'get_metadata_from_imds') + def test_hostname_from_imds(self, m_get_metadata_from_imds): + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + imds_data_with_os_profile = copy.deepcopy(NETWORK_METADATA) + imds_data_with_os_profile["compute"]["osProfile"] = dict( + adminUsername="username1", + computerName="hostname1", + disablePasswordAuthentication="true" + ) + m_get_metadata_from_imds.return_value = imds_data_with_os_profile + dsrc = self._get_ds(data) + dsrc.get_data() + self.assertEqual(dsrc.metadata["local-hostname"], "hostname1") + + @mock.patch(MOCKPATH + 'get_metadata_from_imds') + def test_username_from_imds(self, m_get_metadata_from_imds): + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + imds_data_with_os_profile = copy.deepcopy(NETWORK_METADATA) + imds_data_with_os_profile["compute"]["osProfile"] = dict( + adminUsername="username1", + computerName="hostname1", + disablePasswordAuthentication="true" + ) + m_get_metadata_from_imds.return_value = imds_data_with_os_profile + dsrc = self._get_ds(data) + dsrc.get_data() + self.assertEqual( + dsrc.cfg["system_info"]["default_user"]["name"], + "username1" + ) + + @mock.patch(MOCKPATH + 'get_metadata_from_imds') + def test_disable_password_from_imds(self, m_get_metadata_from_imds): + sys_cfg = {'datasource': {'Azure': {'apply_network_config': True}}} + odata = {'HostName': "myhost", 'UserName': "myuser"} + data = { + 'ovfcontent': construct_valid_ovf_env(data=odata), + 'sys_cfg': sys_cfg + } + imds_data_with_os_profile = copy.deepcopy(NETWORK_METADATA) + imds_data_with_os_profile["compute"]["osProfile"] = dict( + adminUsername="username1", + computerName="hostname1", + disablePasswordAuthentication="true" + ) + m_get_metadata_from_imds.return_value = imds_data_with_os_profile + dsrc = self._get_ds(data) + dsrc.get_data() + self.assertTrue(dsrc.metadata["disable_password"]) + class TestAzureBounce(CiTestCase): |