diff options
-rwxr-xr-x | cloudinit/sources/DataSourceAzure.py | 42 | ||||
-rw-r--r-- | tests/unittests/sources/test_azure.py | 30 |
2 files changed, 45 insertions, 27 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index a8b403e8..f5630840 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -297,13 +297,10 @@ class DataSourceAzure(sources.DataSource): self.dhclient_lease_file = self.ds_cfg.get("dhclient_lease_file") self._network_config = None self._ephemeral_dhcp_ctx = None - self.failed_desired_api_version = False self.iso_dev = None def _unpickle(self, ci_pkl_version: int) -> None: super()._unpickle(ci_pkl_version) - if not hasattr(self, "failed_desired_api_version"): - self.failed_desired_api_version = False if not hasattr(self, "iso_dev"): self.iso_dev = None @@ -647,29 +644,24 @@ class DataSourceAzure(sources.DataSource): this fault tolerant and fall back to a good known minimum api version. """ - - if not self.failed_desired_api_version: - for _ in range(retries): - try: - LOG.info("Attempting IMDS api-version: %s", IMDS_VER_WANT) - return get_metadata_from_imds( - fallback_nic=fallback_nic, - retries=0, - md_type=md_type, - api_version=IMDS_VER_WANT, - exc_cb=exc_cb, - ) - except UrlError as err: - LOG.info( - "UrlError with IMDS api-version: %s", IMDS_VER_WANT + for _ in range(retries): + try: + LOG.info("Attempting IMDS api-version: %s", IMDS_VER_WANT) + return get_metadata_from_imds( + fallback_nic=fallback_nic, + retries=0, + md_type=md_type, + api_version=IMDS_VER_WANT, + exc_cb=exc_cb, + ) + except UrlError as err: + LOG.info("UrlError with IMDS api-version: %s", IMDS_VER_WANT) + if err.code == 400: + log_msg = "Fall back to IMDS api-version: {}".format( + IMDS_VER_MIN ) - if err.code == 400: - log_msg = "Fall back to IMDS api-version: {}".format( - IMDS_VER_MIN - ) - report_diagnostic_event(log_msg, logger_func=LOG.info) - self.failed_desired_api_version = True - break + report_diagnostic_event(log_msg, logger_func=LOG.info) + break LOG.info("Using IMDS api-version: %s", IMDS_VER_MIN) return get_metadata_from_imds( diff --git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py index 8b0762b7..44c0a545 100644 --- a/tests/unittests/sources/test_azure.py +++ b/tests/unittests/sources/test_azure.py @@ -2149,7 +2149,24 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) dsrc.get_data() self.assertIsNotNone(dsrc.metadata) - self.assertTrue(dsrc.failed_desired_api_version) + + assert m_get_metadata_from_imds.mock_calls == [ + mock.call( + fallback_nic="eth9", + retries=0, + md_type=dsaz.metadata_type.all, + api_version="2021-08-01", + exc_cb=mock.ANY, + ), + mock.call( + fallback_nic="eth9", + retries=10, + md_type=dsaz.metadata_type.all, + api_version="2019-06-01", + exc_cb=mock.ANY, + infinite=False, + ), + ] @mock.patch( MOCKPATH + "get_metadata_from_imds", return_value=NETWORK_METADATA @@ -2164,7 +2181,16 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) dsrc.get_data() self.assertIsNotNone(dsrc.metadata) - self.assertFalse(dsrc.failed_desired_api_version) + + assert m_get_metadata_from_imds.mock_calls == [ + mock.call( + fallback_nic="eth9", + retries=0, + md_type=dsaz.metadata_type.all, + api_version="2021-08-01", + exc_cb=mock.ANY, + ) + ] @mock.patch(MOCKPATH + "get_metadata_from_imds") def test_hostname_from_imds(self, m_get_metadata_from_imds): |