summary refs log tree commit diff
path: root/cloudinit/sources/DataSourceGCE.py
diff options
context:
space:
mode:
Diffstat (limited to 'cloudinit/sources/DataSourceGCE.py')
-rw-r--r-- cloudinit/sources/DataSourceGCE.py 251
1 files changed, 148 insertions, 103 deletions
diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py
index 746caddb..c470bea8 100644
--- a/cloudinit/sources/DataSourceGCE.py
+++ b/cloudinit/sources/DataSourceGCE.py
@@ -4,39 +4,46 @@
import datetime
import json
-
from base64 import b64decode
+from contextlib import suppress as noop
from cloudinit import dmi
-from cloudinit.distros import ug_util
from cloudinit import log as logging
-from cloudinit import sources
-from cloudinit import url_helper
-from cloudinit import util
+from cloudinit import sources, url_helper, util
+from cloudinit.distros import ug_util
+from cloudinit.net.dhcp import EphemeralDHCPv4
LOG = logging.getLogger(__name__)
-MD_V1_URL = 'http://metadata.google.internal/computeMetadata/v1/'
-BUILTIN_DS_CONFIG = {'metadata_url': MD_V1_URL}
-REQUIRED_FIELDS = ('instance-id', 'availability-zone', 'local-hostname')
-GUEST_ATTRIBUTES_URL = ('http://metadata.google.internal/computeMetadata/'
- 'v1/instance/guest-attributes')
-HOSTKEY_NAMESPACE = 'hostkeys'
-HEADERS = {'Metadata-Flavor': 'Google'}
+MD_V1_URL = "http://metadata.google.internal/computeMetadata/v1/"
+BUILTIN_DS_CONFIG = {"metadata_url": MD_V1_URL}
+REQUIRED_FIELDS = ("instance-id", "availability-zone", "local-hostname")
+GUEST_ATTRIBUTES_URL = (
+ "http://metadata.google.internal/computeMetadata/"
+ "v1/instance/guest-attributes"
+)
+HOSTKEY_NAMESPACE = "hostkeys"
+HEADERS = {"Metadata-Flavor": "Google"}
class GoogleMetadataFetcher(object):
-
- def __init__(self, metadata_address):
+ def __init__(self, metadata_address, num_retries, sec_between_retries):
self.metadata_address = metadata_address
+ self.num_retries = num_retries
+ self.sec_between_retries = sec_between_retries
def get_value(self, path, is_text, is_recursive=False):
value = None
try:
url = self.metadata_address + path
if is_recursive:
- url += '/?recursive=True'
- resp = url_helper.readurl(url=url, headers=HEADERS)
+ url += "/?recursive=True"
+ resp = url_helper.readurl(
+ url=url,
+ headers=HEADERS,
+ retries=self.num_retries,
+ sec_between=self.sec_between_retries,
+ )
except url_helper.UrlError as exc:
msg = "url %s raised exception %s"
LOG.debug(msg, path, exc)
@@ -45,7 +52,7 @@ class GoogleMetadataFetcher(object):
if is_text:
value = util.decode_binary(resp.contents)
else:
- value = resp.contents.decode('utf-8')
+ value = resp.contents.decode("utf-8")
else:
LOG.debug("url %s returned code %s", path, resp.code)
return value
@@ -53,7 +60,8 @@ class GoogleMetadataFetcher(object):
class DataSourceGCE(sources.DataSource):
- dsname = 'GCE'
+ dsname = "GCE"
+ perform_dhcp_setup = False
def __init__(self, sys_cfg, distro, paths):
sources.DataSource.__init__(self, sys_cfg, distro, paths)
@@ -62,24 +70,38 @@ class DataSourceGCE(sources.DataSource):
(users, _groups) = ug_util.normalize_users_groups(sys_cfg, distro)
(self.default_user, _user_config) = ug_util.extract_default(users)
self.metadata = dict()
- self.ds_cfg = util.mergemanydict([
- util.get_cfg_by_path(sys_cfg, ["datasource", "GCE"], {}),
- BUILTIN_DS_CONFIG])
- self.metadata_address = self.ds_cfg['metadata_url']
+ self.ds_cfg = util.mergemanydict(
+ [
+ util.get_cfg_by_path(sys_cfg, ["datasource", "GCE"], {}),
+ BUILTIN_DS_CONFIG,
+ ]
+ )
+ self.metadata_address = self.ds_cfg["metadata_url"]
def _get_data(self):
- ret = util.log_time(
- LOG.debug, 'Crawl of GCE metadata service',
- read_md, kwargs={'address': self.metadata_address})
-
- if not ret['success']:
- if ret['platform_reports_gce']:
- LOG.warning(ret['reason'])
+ url_params = self.get_url_params()
+ network_context = noop()
+ if self.perform_dhcp_setup:
+ network_context = EphemeralDHCPv4(self.fallback_interface)
+ with network_context:
+ ret = util.log_time(
+ LOG.debug,
+ "Crawl of GCE metadata service",
+ read_md,
+ kwargs={
+ "address": self.metadata_address,
+ "url_params": url_params,
+ },
+ )
+
+ if not ret["success"]:
+ if ret["platform_reports_gce"]:
+ LOG.warning(ret["reason"])
else:
- LOG.debug(ret['reason'])
+ LOG.debug(ret["reason"])
return False
- self.metadata = ret['meta-data']
- self.userdata_raw = ret['user-data']
+ self.metadata = ret["meta-data"]
+ self.userdata_raw = ret["user-data"]
return True
@property
@@ -88,10 +110,10 @@ class DataSourceGCE(sources.DataSource):
return None
def get_instance_id(self):
- return self.metadata['instance-id']
+ return self.metadata["instance-id"]
def get_public_ssh_keys(self):
- public_keys_data = self.metadata['public-keys-data']
+ public_keys_data = self.metadata["public-keys-data"]
return _parse_public_keys(public_keys_data, self.default_user)
def publish_host_keys(self, hostkeys):
@@ -100,26 +122,35 @@ class DataSourceGCE(sources.DataSource):
def get_hostname(self, fqdn=False, resolve_ip=False, metadata_only=False):
# GCE has long FQDNs and has asked for short hostnames.
- return self.metadata['local-hostname'].split('.')[0]
+ return self.metadata["local-hostname"].split(".")[0]
@property
def availability_zone(self):
- return self.metadata['availability-zone']
+ return self.metadata["availability-zone"]
@property
def region(self):
- return self.availability_zone.rsplit('-', 1)[0]
+ return self.availability_zone.rsplit("-", 1)[0]
+
+
+class DataSourceGCELocal(DataSourceGCE):
+ perform_dhcp_setup = True
def _write_host_key_to_guest_attributes(key_type, key_value):
- url = '%s/%s/%s' % (GUEST_ATTRIBUTES_URL, HOSTKEY_NAMESPACE, key_type)
- key_value = key_value.encode('utf-8')
- resp = url_helper.readurl(url=url, data=key_value, headers=HEADERS,
- request_method='PUT', check_status=False)
+ url = "%s/%s/%s" % (GUEST_ATTRIBUTES_URL, HOSTKEY_NAMESPACE, key_type)
+ key_value = key_value.encode("utf-8")
+ resp = url_helper.readurl(
+ url=url,
+ data=key_value,
+ headers=HEADERS,
+ request_method="PUT",
+ check_status=False,
+ )
if resp.ok():
- LOG.debug('Wrote %s host key to guest attributes.', key_type)
+ LOG.debug("Wrote %s host key to guest attributes.", key_type)
else:
- LOG.debug('Unable to write %s host key to guest attributes.', key_type)
+ LOG.debug("Unable to write %s host key to guest attributes.", key_type)
def _has_expired(public_key):
@@ -133,7 +164,7 @@ def _has_expired(public_key):
return False
# Do not expire keys if they do not have the expected schema identifier.
- if schema != 'google-ssh':
+ if schema != "google-ssh":
return False
try:
@@ -142,11 +173,11 @@ def _has_expired(public_key):
return False
# Do not expire keys if there is no expiration timestamp.
- if 'expireOn' not in json_obj:
+ if "expireOn" not in json_obj:
return False
- expire_str = json_obj['expireOn']
- format_str = '%Y-%m-%dT%H:%M:%S+0000'
+ expire_str = json_obj["expireOn"]
+ format_str = "%Y-%m-%dT%H:%M:%S+0000"
try:
expire_time = datetime.datetime.strptime(expire_str, format_str)
except ValueError:
@@ -167,44 +198,49 @@ def _parse_public_keys(public_keys_data, default_user=None):
for public_key in public_keys_data:
if not public_key or not all(ord(c) < 128 for c in public_key):
continue
- split_public_key = public_key.split(':', 1)
+ split_public_key = public_key.split(":", 1)
if len(split_public_key) != 2:
continue
user, key = split_public_key
- if user in ('cloudinit', default_user) and not _has_expired(key):
+ if user in ("cloudinit", default_user) and not _has_expired(key):
public_keys.append(key)
return public_keys
-def read_md(address=None, platform_check=True):
+def read_md(address=None, url_params=None, platform_check=True):
if address is None:
address = MD_V1_URL
- ret = {'meta-data': None, 'user-data': None,
- 'success': False, 'reason': None}
- ret['platform_reports_gce'] = platform_reports_gce()
+ ret = {
+ "meta-data": None,
+ "user-data": None,
+ "success": False,
+ "reason": None,
+ }
+ ret["platform_reports_gce"] = platform_reports_gce()
- if platform_check and not ret['platform_reports_gce']:
- ret['reason'] = "Not running on GCE."
+ if platform_check and not ret["platform_reports_gce"]:
+ ret["reason"] = "Not running on GCE."
return ret
# If we cannot resolve the metadata server, then no point in trying.
if not util.is_resolvable_url(address):
LOG.debug("%s is not resolvable", address)
- ret['reason'] = 'address "%s" is not resolvable' % address
+ ret["reason"] = 'address "%s" is not resolvable' % address
return ret
# url_map: (our-key, path, required, is_text, is_recursive)
url_map = [
- ('instance-id', ('instance/id',), True, True, False),
- ('availability-zone', ('instance/zone',), True, True, False),
- ('local-hostname', ('instance/hostname',), True, True, False),
- ('instance-data', ('instance/attributes',), False, False, True),
- ('project-data', ('project/attributes',), False, False, True),
+ ("instance-id", ("instance/id",), True, True, False),
+ ("availability-zone", ("instance/zone",), True, True, False),
+ ("local-hostname", ("instance/hostname",), True, True, False),
+ ("instance-data", ("instance/attributes",), False, False, True),
+ ("project-data", ("project/attributes",), False, False, True),
]
-
- metadata_fetcher = GoogleMetadataFetcher(address)
+ metadata_fetcher = GoogleMetadataFetcher(
+ address, url_params.num_retries, url_params.sec_between_retries
+ )
md = {}
# Iterate over url_map keys to get metadata items.
for (mkey, paths, required, is_text, is_recursive) in url_map:
@@ -215,56 +251,58 @@ def read_md(address=None, platform_check=True):
value = new_value
if required and value is None:
msg = "required key %s returned nothing. not GCE"
- ret['reason'] = msg % mkey
+ ret["reason"] = msg % mkey
return ret
md[mkey] = value
- instance_data = json.loads(md['instance-data'] or '{}')
- project_data = json.loads(md['project-data'] or '{}')
- valid_keys = [instance_data.get('sshKeys'), instance_data.get('ssh-keys')]
- block_project = instance_data.get('block-project-ssh-keys', '').lower()
- if block_project != 'true' and not instance_data.get('sshKeys'):
- valid_keys.append(project_data.get('ssh-keys'))
- valid_keys.append(project_data.get('sshKeys'))
- public_keys_data = '\n'.join([key for key in valid_keys if key])
- md['public-keys-data'] = public_keys_data.splitlines()
+ instance_data = json.loads(md["instance-data"] or "{}")
+ project_data = json.loads(md["project-data"] or "{}")
+ valid_keys = [instance_data.get("sshKeys"), instance_data.get("ssh-keys")]
+ block_project = instance_data.get("block-project-ssh-keys", "").lower()
+ if block_project != "true" and not instance_data.get("sshKeys"):
+ valid_keys.append(project_data.get("ssh-keys"))
+ valid_keys.append(project_data.get("sshKeys"))
+ public_keys_data = "\n".join([key for key in valid_keys if key])
+ md["public-keys-data"] = public_keys_data.splitlines()
- if md['availability-zone']:
- md['availability-zone'] = md['availability-zone'].split('/')[-1]
+ if md["availability-zone"]:
+ md["availability-zone"] = md["availability-zone"].split("/")[-1]
- if 'user-data' in instance_data:
+ if "user-data" in instance_data:
# instance_data was json, so values are all utf-8 strings.
- ud = instance_data['user-data'].encode("utf-8")
- encoding = instance_data.get('user-data-encoding')
- if encoding == 'base64':
+ ud = instance_data["user-data"].encode("utf-8")
+ encoding = instance_data.get("user-data-encoding")
+ if encoding == "base64":
ud = b64decode(ud)
elif encoding:
- LOG.warning('unknown user-data-encoding: %s, ignoring', encoding)
- ret['user-data'] = ud
+ LOG.warning("unknown user-data-encoding: %s, ignoring", encoding)
+ ret["user-data"] = ud
- ret['meta-data'] = md
- ret['success'] = True
+ ret["meta-data"] = md
+ ret["success"] = True
return ret
def platform_reports_gce():
- pname = dmi.read_dmi_data('system-product-name') or "N/A"
- if pname == "Google Compute Engine":
+ pname = dmi.read_dmi_data("system-product-name") or "N/A"
+ if pname == "Google Compute Engine" or pname == "Google":
return True
# system-product-name is not always guaranteed (LP: #1674861)
- serial = dmi.read_dmi_data('system-serial-number') or "N/A"
+ serial = dmi.read_dmi_data("system-serial-number") or "N/A"
if serial.startswith("GoogleCloud-"):
return True
- LOG.debug("Not running on google cloud. product-name=%s serial=%s",
- pname, serial)
+ LOG.debug(
+ "Not running on google cloud. product-name=%s serial=%s", pname, serial
+ )
return False
# Used to match classes to dependencies.
datasources = [
+ (DataSourceGCELocal, (sources.DEP_FILESYSTEM,)),
(DataSourceGCE, (sources.DEP_FILESYSTEM, sources.DEP_NETWORK)),
]
@@ -277,31 +315,38 @@ def get_datasource_list(depends):
if __name__ == "__main__":
import argparse
import sys
-
from base64 import b64encode
- parser = argparse.ArgumentParser(description='Query GCE Metadata Service')
- parser.add_argument("--endpoint", metavar="URL",
- help="The url of the metadata service.",
- default=MD_V1_URL)
- parser.add_argument("--no-platform-check", dest="platform_check",
- help="Ignore smbios platform check",
- action='store_false', default=True)
+ parser = argparse.ArgumentParser(description="Query GCE Metadata Service")
+ parser.add_argument(
+ "--endpoint",
+ metavar="URL",
+ help="The url of the metadata service.",
+ default=MD_V1_URL,
+ )
+ parser.add_argument(
+ "--no-platform-check",
+ dest="platform_check",
+ help="Ignore smbios platform check",
+ action="store_false",
+ default=True,
+ )
args = parser.parse_args()
data = read_md(address=args.endpoint, platform_check=args.platform_check)
- if 'user-data' in data:
+ if "user-data" in data:
# user-data is bytes not string like other things. Handle it specially.
# If it can be represented as utf-8 then do so. Otherwise print base64
# encoded value in the key user-data-b64.
try:
- data['user-data'] = data['user-data'].decode()
+ data["user-data"] = data["user-data"].decode()
except UnicodeDecodeError:
- sys.stderr.write("User-data cannot be decoded. "
- "Writing as base64\n")
- del data['user-data']
+ sys.stderr.write(
+ "User-data cannot be decoded. Writing as base64\n"
+ )
+ del data["user-data"]
# b64encode returns a bytes value. Decode to get the string.
- data['user-data-b64'] = b64encode(data['user-data']).decode()
+ data["user-data-b64"] = b64encode(data["user-data"]).decode()
- print(json.dumps(data, indent=1, sort_keys=True, separators=(',', ': ')))
+ print(json.dumps(data, indent=1, sort_keys=True, separators=(",", ": ")))
# vi: ts=4 expandtab