summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Patterson <cpatterson@microsoft.com>2022-02-10 15:03:10 -0500
committerGitHub <noreply@github.com>2022-02-10 14:03:10 -0600
commit50de985bc4e47dff1a8fc52abb7679032bb40cae (patch)
treeeec039c570ff733a202c6c7ef341fda3873b1c4f
parentc3482971f0f155475f367d6dec00bae25b79cfff (diff)
downloadvyos-cloud-init-50de985bc4e47dff1a8fc52abb7679032bb40cae.tar.gz
vyos-cloud-init-50de985bc4e47dff1a8fc52abb7679032bb40cae.zip
sources/azure: refactor ssh key handling (#1248)
Split _get_public_ssh_keys_and_source() into _get_public_keys_from_imds() and _get_public_keys_from_ovf(). Set _get_public_keys_from_imds() to take a parameter of the IMDS metadata rather than assuming it is already set in self.metadata. This will allow us to move negotation into local phase where self.metadata may not be set yet. Update this method to raise KeyError if IMDS metadata is missing/malformed, and ValueError if SSH key format is not supported. Update get_public_ssh_keys() to catch these errors and fall back to the OVF/Wireserver keys as needed. To improve clarity, update register_with_azure_and_fetch_data() to return the list of SSH keys, rather than bundling them into a dictionary for updating against the metadata dictionary. There should be no change in behavior with this refactor. Signed-off-by: Chris Patterson <cpatterson@microsoft.com>
-rwxr-xr-xcloudinit/sources/DataSourceAzure.py115
-rwxr-xr-xcloudinit/sources/helpers/azure.py5
-rw-r--r--tests/unittests/sources/test_azure.py10
-rw-r--r--tests/unittests/sources/test_azure_helper.py8
4 files changed, 62 insertions, 76 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py
index 44efd358..f4be4cda 100755
--- a/cloudinit/sources/DataSourceAzure.py
+++ b/cloudinit/sources/DataSourceAzure.py
@@ -11,11 +11,9 @@ import os
import os.path
import re
import xml.etree.ElementTree as ET
-from collections import namedtuple
from enum import Enum
-from functools import partial
from time import sleep, time
-from typing import Optional
+from typing import List, Optional
from xml.dom import minidom
import requests
@@ -71,10 +69,6 @@ IMDS_VER_MIN = "2019-06-01"
IMDS_VER_WANT = "2021-08-01"
IMDS_EXTENDED_VER_MIN = "2021-03-01"
-# This holds SSH key data including if the source was
-# from IMDS, as well as the SSH key data itself.
-SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys"))
-
class MetadataType(Enum):
ALL = "{}/instance".format(IMDS_URL)
@@ -740,63 +734,59 @@ class DataSourceAzure(sources.DataSource):
return self.ds_cfg["disk_aliases"].get(name)
@azure_ds_telemetry_reporter
- def get_public_ssh_keys(self):
+ def get_public_ssh_keys(self) -> List[str]:
"""
Retrieve public SSH keys.
"""
+ try:
+ return self._get_public_keys_from_imds(self.metadata["imds"])
+ except (KeyError, ValueError):
+ pass
- return self._get_public_ssh_keys_and_source().ssh_keys
+ return self._get_public_keys_from_ovf()
- def _get_public_ssh_keys_and_source(self):
- """
- Try to get the ssh keys from IMDS first, and if that fails
- (i.e. IMDS is unavailable) then fallback to getting the ssh
- keys from OVF.
+ def _get_public_keys_from_imds(self, imds_md: dict) -> List[str]:
+ """Get SSH keys from IMDS metadata.
- The benefit to getting keys from IMDS is a large performance
- advantage, so this is a strong preference. But we must keep
- OVF as a second option for environments that don't have IMDS.
- """
+ :raises KeyError: if IMDS metadata is malformed/missing.
+ :raises ValueError: if key format is not supported.
- LOG.debug("Retrieving public SSH keys")
- ssh_keys = []
- keys_from_imds = True
- LOG.debug("Attempting to get SSH keys from IMDS")
+ :returns: List of keys.
+ """
try:
ssh_keys = [
public_key["keyData"]
- for public_key in self.metadata["imds"]["compute"][
- "publicKeys"
- ]
+ for public_key in imds_md["compute"]["publicKeys"]
]
- for key in ssh_keys:
- if not _key_is_openssh_formatted(key=key):
- keys_from_imds = False
- break
-
- if not keys_from_imds:
- log_msg = "Keys not in OpenSSH format, using OVF"
- else:
- log_msg = "Retrieved {} keys from IMDS".format(
- len(ssh_keys) if ssh_keys is not None else 0
- )
except KeyError:
- log_msg = "Unable to get keys from IMDS, falling back to OVF"
- keys_from_imds = False
- finally:
+ log_msg = "No SSH keys found in IMDS metadata"
report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ raise
- if not keys_from_imds:
- LOG.debug("Attempting to get SSH keys from OVF")
- try:
- ssh_keys = self.metadata["public-keys"]
- log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
- except KeyError:
- log_msg = "No keys available from OVF"
- finally:
- report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ if any(not _key_is_openssh_formatted(key=key) for key in ssh_keys):
+ log_msg = "Key(s) not in OpenSSH format"
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ raise ValueError(log_msg)
+
+ log_msg = "Retrieved {} keys from IMDS".format(len(ssh_keys))
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ return ssh_keys
- return SSHKeys(keys_from_imds=keys_from_imds, ssh_keys=ssh_keys)
+ def _get_public_keys_from_ovf(self) -> List[str]:
+ """Get SSH keys that were fetched from wireserver.
+
+ :returns: List of keys.
+ """
+ ssh_keys = []
+ try:
+ ssh_keys = self.metadata["public-keys"]
+ log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+ except KeyError:
+ log_msg = "No keys available from OVF"
+ report_diagnostic_event(log_msg, logger_func=LOG.debug)
+
+ return ssh_keys
def get_config_obj(self):
return self.cfg
@@ -832,10 +822,10 @@ class DataSourceAzure(sources.DataSource):
self.get_instance_id(),
is_new_instance,
)
- fabric_data = self._negotiate()
- LOG.debug("negotiating returned %s", fabric_data)
- if fabric_data:
- self.metadata.update(fabric_data)
+ ssh_keys = self._negotiate()
+ LOG.debug("negotiating returned %s", ssh_keys)
+ if ssh_keys:
+ self.metadata["public-keys"] = ssh_keys
self._negotiated = True
else:
LOG.debug(
@@ -1462,24 +1452,21 @@ class DataSourceAzure(sources.DataSource):
On failure, returns False.
"""
pubkey_info = None
- ssh_keys_and_source = self._get_public_ssh_keys_and_source()
-
- if not ssh_keys_and_source.keys_from_imds:
+ try:
+ self._get_public_keys_from_imds(self.metadata["imds"])
+ except (KeyError, ValueError):
pubkey_info = self.cfg.get("_pubkeys", None)
log_msg = "Retrieved {} fingerprints from OVF".format(
len(pubkey_info) if pubkey_info is not None else 0
)
report_diagnostic_event(log_msg, logger_func=LOG.debug)
- metadata_func = partial(
- get_metadata_from_fabric,
- fallback_lease_file=self.dhclient_lease_file,
- pubkey_info=pubkey_info,
- )
-
LOG.debug("negotiating with fabric")
try:
- fabric_data = metadata_func()
+ ssh_keys = get_metadata_from_fabric(
+ fallback_lease_file=self.dhclient_lease_file,
+ pubkey_info=pubkey_info,
+ )
except Exception as e:
report_diagnostic_event(
"Error communicating with Azure fabric; You may experience "
@@ -1491,7 +1478,7 @@ class DataSourceAzure(sources.DataSource):
util.del_file(REPORTED_READY_MARKER_FILE)
util.del_file(REPROVISION_MARKER_FILE)
util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE)
- return fabric_data
+ return ssh_keys
@azure_ds_telemetry_reporter
def activate(self, cfg, is_new_instance):
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index 8e8f5ce5..ec6ab80c 100755
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -12,6 +12,7 @@ import zlib
from contextlib import contextmanager
from datetime import datetime
from errno import ENOENT
+from typing import List, Optional
from xml.etree import ElementTree
from xml.sax.saxutils import escape
@@ -1004,7 +1005,7 @@ class WALinuxAgentShim:
@azure_ds_telemetry_reporter
def register_with_azure_and_fetch_data(
self, pubkey_info=None, iso_dev=None
- ) -> dict:
+ ) -> Optional[List[str]]:
"""Gets the VM's GoalState from Azure, uses the GoalState information
to report ready/send the ready signal/provisioning complete signal to
Azure, and then uses pubkey_info to filter and obtain the user's
@@ -1037,7 +1038,7 @@ class WALinuxAgentShim:
self.eject_iso(iso_dev)
health_reporter.send_ready_signal()
- return {"public-keys": ssh_keys}
+ return ssh_keys
@azure_ds_telemetry_reporter
def register_with_azure_and_report_failure(self, description: str) -> None:
diff --git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py
index a6c43ea7..a47ed611 100644
--- a/tests/unittests/sources/test_azure.py
+++ b/tests/unittests/sources/test_azure.py
@@ -762,9 +762,7 @@ scbus-1 on xpt0 bus 0
dsaz.BUILTIN_DS_CONFIG["data_dir"] = self.waagent_d
self.m_is_platform_viable = mock.MagicMock(autospec=True)
- self.m_get_metadata_from_fabric = mock.MagicMock(
- return_value={"public-keys": []}
- )
+ self.m_get_metadata_from_fabric = mock.MagicMock(return_value=[])
self.m_report_failure_to_fabric = mock.MagicMock(autospec=True)
self.m_list_possible_azure_ds = mock.MagicMock(
side_effect=_load_possible_azure_ds
@@ -1725,10 +1723,10 @@ scbus-1 on xpt0 bus 0
def test_fabric_data_included_in_metadata(self):
dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()})
- self.m_get_metadata_from_fabric.return_value = {"test": "value"}
+ self.m_get_metadata_from_fabric.return_value = ["ssh-key-value"]
ret = self._get_and_setup(dsrc)
self.assertTrue(ret)
- self.assertEqual("value", dsrc.metadata["test"])
+ self.assertEqual(["ssh-key-value"], dsrc.metadata["public-keys"])
def test_instance_id_case_insensitive(self):
"""Return the previous iid when current is a case-insensitive match."""
@@ -2008,7 +2006,7 @@ scbus-1 on xpt0 bus 0
"sys_cfg": sys_cfg,
}
dsrc = self._get_ds(data)
- dsaz.get_metadata_from_fabric.return_value = {"public-keys": ["key2"]}
+ dsaz.get_metadata_from_fabric.return_value = ["key2"]
dsrc.get_data()
dsrc.setup(True)
ssh_keys = dsrc.get_public_ssh_keys()
diff --git a/tests/unittests/sources/test_azure_helper.py b/tests/unittests/sources/test_azure_helper.py
index 6f7f2890..98143bc3 100644
--- a/tests/unittests/sources/test_azure_helper.py
+++ b/tests/unittests/sources/test_azure_helper.py
@@ -1204,16 +1204,16 @@ class TestWALinuxAgentShim(CiTestCase):
[mock.call(self.GoalState.return_value.certificates_xml)],
sslmgr.parse_certificates.call_args_list,
)
- self.assertIn("expected-key", data["public-keys"])
- self.assertIn("expected-no-value-key", data["public-keys"])
- self.assertNotIn("should-not-be-found", data["public-keys"])
+ self.assertIn("expected-key", data)
+ self.assertIn("expected-no-value-key", data)
+ self.assertNotIn("should-not-be-found", data)
def test_absent_certificates_produces_empty_public_keys(self):
mypk = [{"fingerprint": "fp1", "path": "path1"}]
self.GoalState.return_value.certificates_xml = None
shim = wa_shim()
data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk)
- self.assertEqual([], data["public-keys"])
+ self.assertEqual([], data)
def test_correct_url_used_for_report_ready(self):
self.find_endpoint.return_value = "test_endpoint"