diff options
Diffstat (limited to 'cloudinit/sources')
-rwxr-xr-x | cloudinit/sources/DataSourceAzure.py | 115 | ||||
-rwxr-xr-x | cloudinit/sources/helpers/azure.py | 5 |
2 files changed, 54 insertions, 66 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 44efd358..f4be4cda 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -11,11 +11,9 @@ import os import os.path import re import xml.etree.ElementTree as ET -from collections import namedtuple from enum import Enum -from functools import partial from time import sleep, time -from typing import Optional +from typing import List, Optional from xml.dom import minidom import requests @@ -71,10 +69,6 @@ IMDS_VER_MIN = "2019-06-01" IMDS_VER_WANT = "2021-08-01" IMDS_EXTENDED_VER_MIN = "2021-03-01" -# This holds SSH key data including if the source was -# from IMDS, as well as the SSH key data itself. -SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys")) - class MetadataType(Enum): ALL = "{}/instance".format(IMDS_URL) @@ -740,63 +734,59 @@ class DataSourceAzure(sources.DataSource): return self.ds_cfg["disk_aliases"].get(name) @azure_ds_telemetry_reporter - def get_public_ssh_keys(self): + def get_public_ssh_keys(self) -> List[str]: """ Retrieve public SSH keys. """ + try: + return self._get_public_keys_from_imds(self.metadata["imds"]) + except (KeyError, ValueError): + pass - return self._get_public_ssh_keys_and_source().ssh_keys + return self._get_public_keys_from_ovf() - def _get_public_ssh_keys_and_source(self): - """ - Try to get the ssh keys from IMDS first, and if that fails - (i.e. IMDS is unavailable) then fallback to getting the ssh - keys from OVF. + def _get_public_keys_from_imds(self, imds_md: dict) -> List[str]: + """Get SSH keys from IMDS metadata. - The benefit to getting keys from IMDS is a large performance - advantage, so this is a strong preference. But we must keep - OVF as a second option for environments that don't have IMDS. - """ + :raises KeyError: if IMDS metadata is malformed/missing. + :raises ValueError: if key format is not supported. - LOG.debug("Retrieving public SSH keys") - ssh_keys = [] - keys_from_imds = True - LOG.debug("Attempting to get SSH keys from IMDS") + :returns: List of keys. + """ try: ssh_keys = [ public_key["keyData"] - for public_key in self.metadata["imds"]["compute"][ - "publicKeys" - ] + for public_key in imds_md["compute"]["publicKeys"] ] - for key in ssh_keys: - if not _key_is_openssh_formatted(key=key): - keys_from_imds = False - break - - if not keys_from_imds: - log_msg = "Keys not in OpenSSH format, using OVF" - else: - log_msg = "Retrieved {} keys from IMDS".format( - len(ssh_keys) if ssh_keys is not None else 0 - ) except KeyError: - log_msg = "Unable to get keys from IMDS, falling back to OVF" - keys_from_imds = False - finally: + log_msg = "No SSH keys found in IMDS metadata" report_diagnostic_event(log_msg, logger_func=LOG.debug) + raise - if not keys_from_imds: - LOG.debug("Attempting to get SSH keys from OVF") - try: - ssh_keys = self.metadata["public-keys"] - log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys)) - except KeyError: - log_msg = "No keys available from OVF" - finally: - report_diagnostic_event(log_msg, logger_func=LOG.debug) + if any(not _key_is_openssh_formatted(key=key) for key in ssh_keys): + log_msg = "Key(s) not in OpenSSH format" + report_diagnostic_event(log_msg, logger_func=LOG.debug) + raise ValueError(log_msg) + + log_msg = "Retrieved {} keys from IMDS".format(len(ssh_keys)) + report_diagnostic_event(log_msg, logger_func=LOG.debug) + return ssh_keys - return SSHKeys(keys_from_imds=keys_from_imds, ssh_keys=ssh_keys) + def _get_public_keys_from_ovf(self) -> List[str]: + """Get SSH keys that were fetched from wireserver. + + :returns: List of keys. + """ + ssh_keys = [] + try: + ssh_keys = self.metadata["public-keys"] + log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys)) + report_diagnostic_event(log_msg, logger_func=LOG.debug) + except KeyError: + log_msg = "No keys available from OVF" + report_diagnostic_event(log_msg, logger_func=LOG.debug) + + return ssh_keys def get_config_obj(self): return self.cfg @@ -832,10 +822,10 @@ class DataSourceAzure(sources.DataSource): self.get_instance_id(), is_new_instance, ) - fabric_data = self._negotiate() - LOG.debug("negotiating returned %s", fabric_data) - if fabric_data: - self.metadata.update(fabric_data) + ssh_keys = self._negotiate() + LOG.debug("negotiating returned %s", ssh_keys) + if ssh_keys: + self.metadata["public-keys"] = ssh_keys self._negotiated = True else: LOG.debug( @@ -1462,24 +1452,21 @@ class DataSourceAzure(sources.DataSource): On failure, returns False. """ pubkey_info = None - ssh_keys_and_source = self._get_public_ssh_keys_and_source() - - if not ssh_keys_and_source.keys_from_imds: + try: + self._get_public_keys_from_imds(self.metadata["imds"]) + except (KeyError, ValueError): pubkey_info = self.cfg.get("_pubkeys", None) log_msg = "Retrieved {} fingerprints from OVF".format( len(pubkey_info) if pubkey_info is not None else 0 ) report_diagnostic_event(log_msg, logger_func=LOG.debug) - metadata_func = partial( - get_metadata_from_fabric, - fallback_lease_file=self.dhclient_lease_file, - pubkey_info=pubkey_info, - ) - LOG.debug("negotiating with fabric") try: - fabric_data = metadata_func() + ssh_keys = get_metadata_from_fabric( + fallback_lease_file=self.dhclient_lease_file, + pubkey_info=pubkey_info, + ) except Exception as e: report_diagnostic_event( "Error communicating with Azure fabric; You may experience " @@ -1491,7 +1478,7 @@ class DataSourceAzure(sources.DataSource): util.del_file(REPORTED_READY_MARKER_FILE) util.del_file(REPROVISION_MARKER_FILE) util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE) - return fabric_data + return ssh_keys @azure_ds_telemetry_reporter def activate(self, cfg, is_new_instance): diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 8e8f5ce5..ec6ab80c 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -12,6 +12,7 @@ import zlib from contextlib import contextmanager from datetime import datetime from errno import ENOENT +from typing import List, Optional from xml.etree import ElementTree from xml.sax.saxutils import escape @@ -1004,7 +1005,7 @@ class WALinuxAgentShim: @azure_ds_telemetry_reporter def register_with_azure_and_fetch_data( self, pubkey_info=None, iso_dev=None - ) -> dict: + ) -> Optional[List[str]]: """Gets the VM's GoalState from Azure, uses the GoalState information to report ready/send the ready signal/provisioning complete signal to Azure, and then uses pubkey_info to filter and obtain the user's @@ -1037,7 +1038,7 @@ class WALinuxAgentShim: self.eject_iso(iso_dev) health_reporter.send_ready_signal() - return {"public-keys": ssh_keys} + return ssh_keys @azure_ds_telemetry_reporter def register_with_azure_and_report_failure(self, description: str) -> None: |