diff options
-rw-r--r-- | cloudinit/sources/DataSourceGCE.py | 25 | ||||
-rw-r--r-- | tests/integration_tests/modules/test_combined.py | 41 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_common.py | 1 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_gce.py | 24 |
4 files changed, 87 insertions, 4 deletions
diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py index 9f838bd4..b82fa410 100644 --- a/cloudinit/sources/DataSourceGCE.py +++ b/cloudinit/sources/DataSourceGCE.py @@ -4,6 +4,7 @@ import datetime import json +from contextlib import suppress as noop from base64 import b64decode @@ -13,6 +14,7 @@ from cloudinit import log as logging from cloudinit import sources from cloudinit import url_helper from cloudinit import util +from cloudinit.net.dhcp import EphemeralDHCPv4 LOG = logging.getLogger(__name__) @@ -58,6 +60,7 @@ class GoogleMetadataFetcher(object): class DataSourceGCE(sources.DataSource): dsname = 'GCE' + perform_dhcp_setup = False def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) @@ -73,10 +76,19 @@ class DataSourceGCE(sources.DataSource): def _get_data(self): url_params = self.get_url_params() - ret = util.log_time( - LOG.debug, 'Crawl of GCE metadata service', - read_md, kwargs={'address': self.metadata_address, - 'url_params': url_params}) + network_context = noop() + if self.perform_dhcp_setup: + network_context = EphemeralDHCPv4(self.fallback_interface) + with network_context: + ret = util.log_time( + LOG.debug, + "Crawl of GCE metadata service", + read_md, + kwargs={ + "address": self.metadata_address, + "url_params": url_params, + }, + ) if not ret['success']: if ret['platform_reports_gce']: @@ -117,6 +129,10 @@ class DataSourceGCE(sources.DataSource): return self.availability_zone.rsplit('-', 1)[0] +class DataSourceGCELocal(DataSourceGCE): + perform_dhcp_setup = True + + def _write_host_key_to_guest_attributes(key_type, key_value): url = '%s/%s/%s' % (GUEST_ATTRIBUTES_URL, HOSTKEY_NAMESPACE, key_type) key_value = key_value.encode('utf-8') @@ -272,6 +288,7 @@ def platform_reports_gce(): # Used to match classes to dependencies. datasources = [ + (DataSourceGCELocal, (sources.DEP_FILESYSTEM,)), (DataSourceGCE, (sources.DEP_FILESYSTEM, sources.DEP_NETWORK)), ] diff --git a/tests/integration_tests/modules/test_combined.py b/tests/integration_tests/modules/test_combined.py index bc19c2a2..758c96fa 100644 --- a/tests/integration_tests/modules/test_combined.py +++ b/tests/integration_tests/modules/test_combined.py @@ -209,6 +209,31 @@ class TestCombined: log = client.read_from_file('/var/log/cloud-init.log') verify_clean_log(log) + def test_correct_datasource_detected( + self, class_client: IntegrationInstance + ): + """Test datasource is detected at the proper boot stage.""" + client = class_client + status_file = client.read_from_file("/run/cloud-init/status.json") + + platform_datasources = { + "azure": "DataSourceAzure [seed=/dev/sr0]", + "ec2": "DataSourceEc2Local", + "gce": "DataSourceGCELocal", + "oci": "DataSourceOracle", + "openstack": "DataSourceOpenStackLocal [net,ver=2]", + "lxd_container": ( + "DataSourceNoCloud " + "[seed=/var/lib/cloud/seed/nocloud-net][dsmode=net]" + ), + "lxd_vm": "DataSourceNoCloud [seed=/dev/sr0][dsmode=net]", + } + + assert ( + platform_datasources[client.settings.PLATFORM] + == json.loads(status_file)["v1"]["datasource"] + ) + def _check_common_metadata(self, data): assert data['base64_encoded_keys'] == [] assert data['merged_cfg'] == 'redacted for non-root user' @@ -277,3 +302,19 @@ class TestCombined: assert v1_data['instance_id'] == client.instance.name assert v1_data['local_hostname'].startswith('ip-') assert v1_data['region'] == client.cloud.cloud_instance.region + + @pytest.mark.gce + def test_instance_json_gce(self, class_client: IntegrationInstance): + client = class_client + instance_json_file = client.read_from_file( + "/run/cloud-init/instance-data.json" + ) + data = json.loads(instance_json_file) + self._check_common_metadata(data) + v1_data = data["v1"] + assert v1_data["cloud_name"] == "gce" + assert v1_data["platform"] == "gce" + assert v1_data["subplatform"].startswith("metadata") + assert v1_data["availability_zone"] == client.instance.zone + assert v1_data["instance_id"] == client.instance.instance_id + assert v1_data["local_hostname"] == client.instance.name diff --git a/tests/unittests/test_datasource/test_common.py b/tests/unittests/test_datasource/test_common.py index 17d53160..9089e5de 100644 --- a/tests/unittests/test_datasource/test_common.py +++ b/tests/unittests/test_datasource/test_common.py @@ -41,6 +41,7 @@ DEFAULT_LOCAL = [ CloudSigma.DataSourceCloudSigma, ConfigDrive.DataSourceConfigDrive, DigitalOcean.DataSourceDigitalOcean, + GCE.DataSourceGCELocal, Hetzner.DataSourceHetzner, IBMCloud.DataSourceIBMCloud, LXD.DataSourceLXD, diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py index 80b38f9e..1d91b301 100644 --- a/tests/unittests/test_datasource/test_gce.py +++ b/tests/unittests/test_datasource/test_gce.py @@ -360,5 +360,29 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase): self.ds.publish_host_keys(hostkeys) m_readurl.assert_has_calls(readurl_expected_calls, any_order=True) + @mock.patch( + "cloudinit.sources.DataSourceGCE.EphemeralDHCPv4", + autospec=True, + ) + @mock.patch( + "cloudinit.sources.DataSourceGCE.DataSourceGCELocal.fallback_interface" + ) + def test_local_datasource_uses_ephemeral_dhcp(self, _m_fallback, m_dhcp): + _set_mock_metadata() + ds = DataSourceGCE.DataSourceGCELocal( + sys_cfg={}, distro=None, paths=None + ) + ds._get_data() + assert m_dhcp.call_count == 1 + + @mock.patch( + "cloudinit.sources.DataSourceGCE.EphemeralDHCPv4", + autospec=True, + ) + def test_datasource_doesnt_use_ephemeral_dhcp(self, m_dhcp): + _set_mock_metadata() + ds = DataSourceGCE.DataSourceGCE(sys_cfg={}, distro=None, paths=None) + ds._get_data() + assert m_dhcp.call_count == 0 # vi: ts=4 expandtab |