summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcloudinit/sources/DataSourceAzure.py115
-rwxr-xr-xcloudinit/sources/helpers/azure.py5
-rw-r--r--tests/unittests/sources/test_azure.py10
-rw-r--r--tests/unittests/sources/test_azure_helper.py8
4 files changed, 62 insertions, 76 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py
index 44efd358..f4be4cda 100755
--- a/cloudinit/sources/DataSourceAzure.py
+++ b/cloudinit/sources/DataSourceAzure.py
@@ -11,11 +11,9 @@ import os
import os.path
import re
import xml.etree.ElementTree as ET
-from collections import namedtuple
from enum import Enum
-from functools import partial
from time import sleep, time
-from typing import Optional
+from typing import List, Optional
from xml.dom import minidom
import requests
@@ -71,10 +69,6 @@ IMDS_VER_MIN = "2019-06-01"
IMDS_VER_WANT = "2021-08-01"
IMDS_EXTENDED_VER_MIN = "2021-03-01"
-# This holds SSH key data including if the source was
-# from IMDS, as well as the SSH key data itself.
-SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys"))
-
class MetadataType(Enum):
ALL = "{}/instance".format(IMDS_URL)
@@ -740,63 +734,59 @@ class DataSourceAzure(sources.DataSource):
return self.ds_cfg["disk_aliases"].get(name)
@azure_ds_telemetry_reporter
- def get_public_ssh_keys(self):
+ def get_public_ssh_keys(self) -> List[str]:
"""
Retrieve public SSH keys.
"""
+ try:
+ return self._get_public_keys_from_imds(self.metadata["imds"])
+ except (KeyError, ValueError):
+ pass
- return self._get_public_ssh_keys_and_source().ssh_keys
+ return self._get_public_keys_from_ovf()
- def _get_public_ssh_keys_and_source(self):
- """
- Try to get the ssh keys from IMDS first, and if that fails
- (i.e. IMDS is unavailable) then fallback to getting the ssh
- keys from OVF.
+ def _get_public_keys_from_imds(self, imds_md: dict) -> List[str]:
+ """Get SSH keys from IMDS metadata.
- The benefit to getting keys from IMDS is a large performance
- advantage, so this is a strong preference. But we must keep
- OVF as a second option for environments that don't have IMDS.
- """
+ :raises KeyError: if IMDS metadata is malformed/missing.
+ :raises ValueError: if key format is not supported.
- LOG.debug("Retrieving public SSH keys")
- ssh_keys = []
- keys_from_imds = True
- LOG.debug("Attempting to get SSH keys from IMDS")
+ :returns: List of keys.
+ """
try:
ssh_keys = [
public_key["keyData"]
- for public_key in self.metadata["imds"]["compute"][
- "publicKeys"
- ]
+ for public_key in imds_md["compute"]["publicKeys"]
]
- for key in ssh_keys:
- if not _key_is_openssh_formatted(key=key):
- keys_from_imds = False
- break
-
- if not keys_from_imds:
- log_msg = "Keys not in OpenSSH format, using OVF"
- else:
- log_msg = "Retrieved {} keys from IMDS".format(
- len(ssh_keys) if ssh_keys is not None else 0
- )
except KeyError:
- log_msg = "Unable to get keys from IMDS, falling back to OVF"
- keys_from_imds = False
- finally:
+ log_msg = "No SSH keys found in IMDS metadata"
report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ raise
- if not keys_from_imds:
- LOG.debug("Attempting to get SSH keys from OVF")
- try:
- ssh_keys = self.metadata["public-keys"]
- log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
- except KeyError:
- log_msg = "No keys available from OVF"
- finally:
- report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ if any(not _key_is_openssh_formatted(key=key) for key in ssh_keys):
+ log_msg = "Key(s) not in OpenSSH format"
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ raise ValueError(log_msg)
+
+ log_msg = "Retrieved {} keys from IMDS".format(len(ssh_keys))
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ return ssh_keys
- return SSHKeys(keys_from_imds=keys_from_imds, ssh_keys=ssh_keys)
+ def _get_public_keys_from_ovf(self) -> List[str]:
+ """Get SSH keys that were fetched from wireserver.
+
+ :returns: List of keys.
+ """
+ ssh_keys = []
+ try:
+ ssh_keys = self.metadata["public-keys"]
+ log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ except KeyError:
+ log_msg = "No keys available from OVF"
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+
+ return ssh_keys
def get_config_obj(self):
return self.cfg
@@ -832,10 +822,10 @@ class DataSourceAzure(sources.DataSource):
self.get_instance_id(),
is_new_instance,
)
- fabric_data = self._negotiate()
- LOG.debug("negotiating returned %s", fabric_data)
- if fabric_data:
- self.metadata.update(fabric_data)
+ ssh_keys = self._negotiate()
+ LOG.debug("negotiating returned %s", ssh_keys)
+ if ssh_keys:
+ self.metadata["public-keys"] = ssh_keys
self._negotiated = True
else:
LOG.debug(
@@ -1462,24 +1452,21 @@ class DataSourceAzure(sources.DataSource):
On failure, returns False.
"""
pubkey_info = None
- ssh_keys_and_source = self._get_public_ssh_keys_and_source()
-
- if not ssh_keys_and_source.keys_from_imds:
+ try:
+ self._get_public_keys_from_imds(self.metadata["imds"])
+ except (KeyError, ValueError):
pubkey_info = self.cfg.get("_pubkeys", None)
log_msg = "Retrieved {} fingerprints from OVF".format(
len(pubkey_info) if pubkey_info is not None else 0
)
report_diagnostic_event(log_msg, logger_func=LOG.debug)
- metadata_func = partial(
- get_metadata_from_fabric,
- fallback_lease_file=self.dhclient_lease_file,
- pubkey_info=pubkey_info,
- )
-
LOG.debug("negotiating with fabric")
try:
- fabric_data = metadata_func()
+ ssh_keys = get_metadata_from_fabric(
+ fallback_lease_file=self.dhclient_lease_file,
+ pubkey_info=pubkey_info,
+ )
except Exception as e:
report_diagnostic_event(
"Error communicating with Azure fabric; You may experience "
@@ -1491,7 +1478,7 @@ class DataSourceAzure(sources.DataSource):
util.del_file(REPORTED_READY_MARKER_FILE)
util.del_file(REPROVISION_MARKER_FILE)
util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE)
- return fabric_data
+ return ssh_keys
@azure_ds_telemetry_reporter
def activate(self, cfg, is_new_instance):
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index 8e8f5ce5..ec6ab80c 100755
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -12,6 +12,7 @@ import zlib
from contextlib import contextmanager
from datetime import datetime
from errno import ENOENT
+from typing import List, Optional
from xml.etree import ElementTree
from xml.sax.saxutils import escape
@@ -1004,7 +1005,7 @@ class WALinuxAgentShim:
@azure_ds_telemetry_reporter
def register_with_azure_and_fetch_data(
self, pubkey_info=None, iso_dev=None
- ) -> dict:
+ ) -> Optional[List[str]]:
"""Gets the VM's GoalState from Azure, uses the GoalState information
to report ready/send the ready signal/provisioning complete signal to
Azure, and then uses pubkey_info to filter and obtain the user's
@@ -1037,7 +1038,7 @@ class WALinuxAgentShim:
self.eject_iso(iso_dev)
health_reporter.send_ready_signal()
- return {"public-keys": ssh_keys}
+ return ssh_keys
@azure_ds_telemetry_reporter
def register_with_azure_and_report_failure(self, description: str) -> None:
diff --git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py
index a6c43ea7..a47ed611 100644
--- a/tests/unittests/sources/test_azure.py
+++ b/tests/unittests/sources/test_azure.py
@@ -762,9 +762,7 @@ scbus-1 on xpt0 bus 0
dsaz.BUILTIN_DS_CONFIG["data_dir"] = self.waagent_d
self.m_is_platform_viable = mock.MagicMock(autospec=True)
- self.m_get_metadata_from_fabric = mock.MagicMock(
- return_value={"public-keys": []}
- )
+ self.m_get_metadata_from_fabric = mock.MagicMock(return_value=[])
self.m_report_failure_to_fabric = mock.MagicMock(autospec=True)
self.m_list_possible_azure_ds = mock.MagicMock(
side_effect=_load_possible_azure_ds
@@ -1725,10 +1723,10 @@ scbus-1 on xpt0 bus 0
def test_fabric_data_included_in_metadata(self):
dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()})
- self.m_get_metadata_from_fabric.return_value = {"test": "value"}
+ self.m_get_metadata_from_fabric.return_value = ["ssh-key-value"]
ret = self._get_and_setup(dsrc)
self.assertTrue(ret)
- self.assertEqual("value", dsrc.metadata["test"])
+ self.assertEqual(["ssh-key-value"], dsrc.metadata["public-keys"])
def test_instance_id_case_insensitive(self):
"""Return the previous iid when current is a case-insensitive match."""
@@ -2008,7 +2006,7 @@ scbus-1 on xpt0 bus 0
"sys_cfg": sys_cfg,
}
dsrc = self._get_ds(data)
- dsaz.get_metadata_from_fabric.return_value = {"public-keys": ["key2"]}
+ dsaz.get_metadata_from_fabric.return_value = ["key2"]
dsrc.get_data()
dsrc.setup(True)
ssh_keys = dsrc.get_public_ssh_keys()
diff --git a/tests/unittests/sources/test_azure_helper.py b/tests/unittests/sources/test_azure_helper.py
index 6f7f2890..98143bc3 100644
--- a/tests/unittests/sources/test_azure_helper.py
+++ b/tests/unittests/sources/test_azure_helper.py
@@ -1204,16 +1204,16 @@ class TestWALinuxAgentShim(CiTestCase):
[mock.call(self.GoalState.return_value.certificates_xml)],
sslmgr.parse_certificates.call_args_list,
)
- self.assertIn("expected-key", data["public-keys"])
- self.assertIn("expected-no-value-key", data["public-keys"])
- self.assertNotIn("should-not-be-found", data["public-keys"])
+ self.assertIn("expected-key", data)
+ self.assertIn("expected-no-value-key", data)
+ self.assertNotIn("should-not-be-found", data)
def test_absent_certificates_produces_empty_public_keys(self):
mypk = [{"fingerprint": "fp1", "path": "path1"}]
self.GoalState.return_value.certificates_xml = None
shim = wa_shim()
data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk)
- self.assertEqual([], data["public-keys"])
+ self.assertEqual([], data)
def test_correct_url_used_for_report_ready(self):
self.find_endpoint.return_value = "test_endpoint"