diff options
| -rwxr-xr-x | cloudinit/sources/DataSourceAzure.py | 115 | ||||
| -rwxr-xr-x | cloudinit/sources/helpers/azure.py | 5 | ||||
| -rw-r--r-- | tests/unittests/sources/test_azure.py | 10 | ||||
| -rw-r--r-- | tests/unittests/sources/test_azure_helper.py | 8 | 
4 files changed, 62 insertions, 76 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 44efd358..f4be4cda 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -11,11 +11,9 @@ import os  import os.path  import re  import xml.etree.ElementTree as ET -from collections import namedtuple  from enum import Enum -from functools import partial  from time import sleep, time -from typing import Optional +from typing import List, Optional  from xml.dom import minidom  import requests @@ -71,10 +69,6 @@ IMDS_VER_MIN = "2019-06-01"  IMDS_VER_WANT = "2021-08-01"  IMDS_EXTENDED_VER_MIN = "2021-03-01" -# This holds SSH key data including if the source was -# from IMDS, as well as the SSH key data itself. -SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys")) -  class MetadataType(Enum):      ALL = "{}/instance".format(IMDS_URL) @@ -740,63 +734,59 @@ class DataSourceAzure(sources.DataSource):          return self.ds_cfg["disk_aliases"].get(name)      @azure_ds_telemetry_reporter -    def get_public_ssh_keys(self): +    def get_public_ssh_keys(self) -> List[str]:          """          Retrieve public SSH keys.          """ +        try: +            return self._get_public_keys_from_imds(self.metadata["imds"]) +        except (KeyError, ValueError): +            pass -        return self._get_public_ssh_keys_and_source().ssh_keys +        return self._get_public_keys_from_ovf() -    def _get_public_ssh_keys_and_source(self): -        """ -        Try to get the ssh keys from IMDS first, and if that fails -        (i.e. IMDS is unavailable) then fallback to getting the ssh -        keys from OVF. +    def _get_public_keys_from_imds(self, imds_md: dict) -> List[str]: +        """Get SSH keys from IMDS metadata. -        The benefit to getting keys from IMDS is a large performance -        advantage, so this is a strong preference. But we must keep -        OVF as a second option for environments that don't have IMDS. -        """ +        :raises KeyError: if IMDS metadata is malformed/missing. +        :raises ValueError: if key format is not supported. -        LOG.debug("Retrieving public SSH keys") -        ssh_keys = [] -        keys_from_imds = True -        LOG.debug("Attempting to get SSH keys from IMDS") +        :returns: List of keys. +        """          try:              ssh_keys = [                  public_key["keyData"] -                for public_key in self.metadata["imds"]["compute"][ -                    "publicKeys" -                ] +                for public_key in imds_md["compute"]["publicKeys"]              ] -            for key in ssh_keys: -                if not _key_is_openssh_formatted(key=key): -                    keys_from_imds = False -                    break - -            if not keys_from_imds: -                log_msg = "Keys not in OpenSSH format, using OVF" -            else: -                log_msg = "Retrieved {} keys from IMDS".format( -                    len(ssh_keys) if ssh_keys is not None else 0 -                )          except KeyError: -            log_msg = "Unable to get keys from IMDS, falling back to OVF" -            keys_from_imds = False -        finally: +            log_msg = "No SSH keys found in IMDS metadata"              report_diagnostic_event(log_msg, logger_func=LOG.debug) +            raise -        if not keys_from_imds: -            LOG.debug("Attempting to get SSH keys from OVF") -            try: -                ssh_keys = self.metadata["public-keys"] -                log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys)) -            except KeyError: -                log_msg = "No keys available from OVF" -            finally: -                report_diagnostic_event(log_msg, logger_func=LOG.debug) +        if any(not _key_is_openssh_formatted(key=key) for key in ssh_keys): +            log_msg = "Key(s) not in OpenSSH format" +            report_diagnostic_event(log_msg, logger_func=LOG.debug) +            raise ValueError(log_msg) + +        log_msg = "Retrieved {} keys from IMDS".format(len(ssh_keys)) +        report_diagnostic_event(log_msg, logger_func=LOG.debug) +        return ssh_keys -        return SSHKeys(keys_from_imds=keys_from_imds, ssh_keys=ssh_keys) +    def _get_public_keys_from_ovf(self) -> List[str]: +        """Get SSH keys that were fetched from wireserver. + +        :returns: List of keys. +        """ +        ssh_keys = [] +        try: +            ssh_keys = self.metadata["public-keys"] +            log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys)) +            report_diagnostic_event(log_msg, logger_func=LOG.debug) +        except KeyError: +            log_msg = "No keys available from OVF" +            report_diagnostic_event(log_msg, logger_func=LOG.debug) + +        return ssh_keys      def get_config_obj(self):          return self.cfg @@ -832,10 +822,10 @@ class DataSourceAzure(sources.DataSource):                  self.get_instance_id(),                  is_new_instance,              ) -            fabric_data = self._negotiate() -            LOG.debug("negotiating returned %s", fabric_data) -            if fabric_data: -                self.metadata.update(fabric_data) +            ssh_keys = self._negotiate() +            LOG.debug("negotiating returned %s", ssh_keys) +            if ssh_keys: +                self.metadata["public-keys"] = ssh_keys              self._negotiated = True          else:              LOG.debug( @@ -1462,24 +1452,21 @@ class DataSourceAzure(sources.DataSource):          On failure, returns False.          """          pubkey_info = None -        ssh_keys_and_source = self._get_public_ssh_keys_and_source() - -        if not ssh_keys_and_source.keys_from_imds: +        try: +            self._get_public_keys_from_imds(self.metadata["imds"]) +        except (KeyError, ValueError):              pubkey_info = self.cfg.get("_pubkeys", None)              log_msg = "Retrieved {} fingerprints from OVF".format(                  len(pubkey_info) if pubkey_info is not None else 0              )              report_diagnostic_event(log_msg, logger_func=LOG.debug) -        metadata_func = partial( -            get_metadata_from_fabric, -            fallback_lease_file=self.dhclient_lease_file, -            pubkey_info=pubkey_info, -        ) -          LOG.debug("negotiating with fabric")          try: -            fabric_data = metadata_func() +            ssh_keys = get_metadata_from_fabric( +                fallback_lease_file=self.dhclient_lease_file, +                pubkey_info=pubkey_info, +            )          except Exception as e:              report_diagnostic_event(                  "Error communicating with Azure fabric; You may experience " @@ -1491,7 +1478,7 @@ class DataSourceAzure(sources.DataSource):          util.del_file(REPORTED_READY_MARKER_FILE)          util.del_file(REPROVISION_MARKER_FILE)          util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE) -        return fabric_data +        return ssh_keys      @azure_ds_telemetry_reporter      def activate(self, cfg, is_new_instance): diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 8e8f5ce5..ec6ab80c 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -12,6 +12,7 @@ import zlib  from contextlib import contextmanager  from datetime import datetime  from errno import ENOENT +from typing import List, Optional  from xml.etree import ElementTree  from xml.sax.saxutils import escape @@ -1004,7 +1005,7 @@ class WALinuxAgentShim:      @azure_ds_telemetry_reporter      def register_with_azure_and_fetch_data(          self, pubkey_info=None, iso_dev=None -    ) -> dict: +    ) -> Optional[List[str]]:          """Gets the VM's GoalState from Azure, uses the GoalState information          to report ready/send the ready signal/provisioning complete signal to          Azure, and then uses pubkey_info to filter and obtain the user's @@ -1037,7 +1038,7 @@ class WALinuxAgentShim:              self.eject_iso(iso_dev)          health_reporter.send_ready_signal() -        return {"public-keys": ssh_keys} +        return ssh_keys      @azure_ds_telemetry_reporter      def register_with_azure_and_report_failure(self, description: str) -> None: diff --git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py index a6c43ea7..a47ed611 100644 --- a/tests/unittests/sources/test_azure.py +++ b/tests/unittests/sources/test_azure.py @@ -762,9 +762,7 @@ scbus-1 on xpt0 bus 0          dsaz.BUILTIN_DS_CONFIG["data_dir"] = self.waagent_d          self.m_is_platform_viable = mock.MagicMock(autospec=True) -        self.m_get_metadata_from_fabric = mock.MagicMock( -            return_value={"public-keys": []} -        ) +        self.m_get_metadata_from_fabric = mock.MagicMock(return_value=[])          self.m_report_failure_to_fabric = mock.MagicMock(autospec=True)          self.m_list_possible_azure_ds = mock.MagicMock(              side_effect=_load_possible_azure_ds @@ -1725,10 +1723,10 @@ scbus-1 on xpt0 bus 0      def test_fabric_data_included_in_metadata(self):          dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()}) -        self.m_get_metadata_from_fabric.return_value = {"test": "value"} +        self.m_get_metadata_from_fabric.return_value = ["ssh-key-value"]          ret = self._get_and_setup(dsrc)          self.assertTrue(ret) -        self.assertEqual("value", dsrc.metadata["test"]) +        self.assertEqual(["ssh-key-value"], dsrc.metadata["public-keys"])      def test_instance_id_case_insensitive(self):          """Return the previous iid when current is a case-insensitive match.""" @@ -2008,7 +2006,7 @@ scbus-1 on xpt0 bus 0              "sys_cfg": sys_cfg,          }          dsrc = self._get_ds(data) -        dsaz.get_metadata_from_fabric.return_value = {"public-keys": ["key2"]} +        dsaz.get_metadata_from_fabric.return_value = ["key2"]          dsrc.get_data()          dsrc.setup(True)          ssh_keys = dsrc.get_public_ssh_keys() diff --git a/tests/unittests/sources/test_azure_helper.py b/tests/unittests/sources/test_azure_helper.py index 6f7f2890..98143bc3 100644 --- a/tests/unittests/sources/test_azure_helper.py +++ b/tests/unittests/sources/test_azure_helper.py @@ -1204,16 +1204,16 @@ class TestWALinuxAgentShim(CiTestCase):              [mock.call(self.GoalState.return_value.certificates_xml)],              sslmgr.parse_certificates.call_args_list,          ) -        self.assertIn("expected-key", data["public-keys"]) -        self.assertIn("expected-no-value-key", data["public-keys"]) -        self.assertNotIn("should-not-be-found", data["public-keys"]) +        self.assertIn("expected-key", data) +        self.assertIn("expected-no-value-key", data) +        self.assertNotIn("should-not-be-found", data)      def test_absent_certificates_produces_empty_public_keys(self):          mypk = [{"fingerprint": "fp1", "path": "path1"}]          self.GoalState.return_value.certificates_xml = None          shim = wa_shim()          data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk) -        self.assertEqual([], data["public-keys"]) +        self.assertEqual([], data)      def test_correct_url_used_for_report_ready(self):          self.find_endpoint.return_value = "test_endpoint"  | 
