summaryrefslogtreecommitdiff
path: root/cloudinit
diff options
context:
space:
mode:
Diffstat (limited to 'cloudinit')
-rw-r--r--cloudinit/sources/DataSourceGCE.py134
1 files changed, 95 insertions, 39 deletions
diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py
index ad6dae37..2da34a99 100644
--- a/cloudinit/sources/DataSourceGCE.py
+++ b/cloudinit/sources/DataSourceGCE.py
@@ -2,8 +2,12 @@
#
# This file is part of cloud-init. See LICENSE file for license information.
+import datetime
+import json
+
from base64 import b64decode
+from cloudinit.distros import ug_util
from cloudinit import log as logging
from cloudinit import sources
from cloudinit import url_helper
@@ -17,16 +21,18 @@ REQUIRED_FIELDS = ('instance-id', 'availability-zone', 'local-hostname')
class GoogleMetadataFetcher(object):
- headers = {'X-Google-Metadata-Request': 'True'}
+ headers = {'Metadata-Flavor': 'Google'}
def __init__(self, metadata_address):
self.metadata_address = metadata_address
- def get_value(self, path, is_text):
+ def get_value(self, path, is_text, is_recursive=False):
value = None
try:
- resp = url_helper.readurl(url=self.metadata_address + path,
- headers=self.headers)
+ url = self.metadata_address + path
+ if is_recursive:
+ url += '/?recursive=True'
+ resp = url_helper.readurl(url=url, headers=self.headers)
except url_helper.UrlError as exc:
msg = "url %s raised exception %s"
LOG.debug(msg, path, exc)
@@ -35,7 +41,7 @@ class GoogleMetadataFetcher(object):
if is_text:
value = util.decode_binary(resp.contents)
else:
- value = resp.contents
+ value = resp.contents.decode('utf-8')
else:
LOG.debug("url %s returned code %s", path, resp.code)
return value
@@ -47,6 +53,10 @@ class DataSourceGCE(sources.DataSource):
def __init__(self, sys_cfg, distro, paths):
sources.DataSource.__init__(self, sys_cfg, distro, paths)
+ self.default_user = None
+ if distro:
+ (users, _groups) = ug_util.normalize_users_groups(sys_cfg, distro)
+ (self.default_user, _user_config) = ug_util.extract_default(users)
self.metadata = dict()
self.ds_cfg = util.mergemanydict([
util.get_cfg_by_path(sys_cfg, ["datasource", "GCE"], {}),
@@ -70,17 +80,18 @@ class DataSourceGCE(sources.DataSource):
@property
def launch_index(self):
- # GCE does not provide lauch_index property
+ # GCE does not provide lauch_index property.
return None
def get_instance_id(self):
return self.metadata['instance-id']
def get_public_ssh_keys(self):
- return self.metadata['public-keys']
+ public_keys_data = self.metadata['public-keys-data']
+ return _parse_public_keys(public_keys_data, self.default_user)
def get_hostname(self, fqdn=False, resolve_ip=False):
- # GCE has long FDQN's and has asked for short hostnames
+ # GCE has long FDQN's and has asked for short hostnames.
return self.metadata['local-hostname'].split('.')[0]
@property
@@ -92,15 +103,58 @@ class DataSourceGCE(sources.DataSource):
return self.availability_zone.rsplit('-', 1)[0]
-def _trim_key(public_key):
- # GCE takes sshKeys attribute in the format of '<user>:<public_key>'
- # so we have to trim each key to remove the username part
+def _has_expired(public_key):
+ # Check whether an SSH key is expired. Public key input is a single SSH
+ # public key in the GCE specific key format documented here:
+ # https://cloud.google.com/compute/docs/instances/adding-removing-ssh-keys#sshkeyformat
+ try:
+ # Check for the Google-specific schema identifier.
+ schema, json_str = public_key.split(None, 3)[2:]
+ except (ValueError, AttributeError):
+ return False
+
+ # Do not expire keys if they do not have the expected schema identifier.
+ if schema != 'google-ssh':
+ return False
+
+ try:
+ json_obj = json.loads(json_str)
+ except ValueError:
+ return False
+
+ # Do not expire keys if there is no expriation timestamp.
+ if 'expireOn' not in json_obj:
+ return False
+
+ expire_str = json_obj['expireOn']
+ format_str = '%Y-%m-%dT%H:%M:%S+0000'
try:
- index = public_key.index(':')
- if index > 0:
- return public_key[(index + 1):]
- except Exception:
- return public_key
+ expire_time = datetime.datetime.strptime(expire_str, format_str)
+ except ValueError:
+ return False
+
+ # Expire the key if and only if we have exceeded the expiration timestamp.
+ return datetime.datetime.utcnow() > expire_time
+
+
+def _parse_public_keys(public_keys_data, default_user=None):
+ # Parse the SSH key data for the default user account. Public keys input is
+ # a list containing SSH public keys in the GCE specific key format
+ # documented here:
+ # https://cloud.google.com/compute/docs/instances/adding-removing-ssh-keys#sshkeyformat
+ public_keys = []
+ if not public_keys_data:
+ return public_keys
+ for public_key in public_keys_data:
+ if not public_key or not all(ord(c) < 128 for c in public_key):
+ continue
+ split_public_key = public_key.split(':', 1)
+ if len(split_public_key) != 2:
+ continue
+ user, key = split_public_key
+ if user in ('cloudinit', default_user) and not _has_expired(key):
+ public_keys.append(key)
+ return public_keys
def read_md(address=None, platform_check=True):
@@ -116,31 +170,28 @@ def read_md(address=None, platform_check=True):
ret['reason'] = "Not running on GCE."
return ret
- # if we cannot resolve the metadata server, then no point in trying
+ # If we cannot resolve the metadata server, then no point in trying.
if not util.is_resolvable_url(address):
LOG.debug("%s is not resolvable", address)
ret['reason'] = 'address "%s" is not resolvable' % address
return ret
- # url_map: (our-key, path, required, is_text)
+ # url_map: (our-key, path, required, is_text, is_recursive)
url_map = [
- ('instance-id', ('instance/id',), True, True),
- ('availability-zone', ('instance/zone',), True, True),
- ('local-hostname', ('instance/hostname',), True, True),
- ('public-keys', ('project/attributes/sshKeys',
- 'instance/attributes/ssh-keys'), False, True),
- ('user-data', ('instance/attributes/user-data',), False, False),
- ('user-data-encoding', ('instance/attributes/user-data-encoding',),
- False, True),
+ ('instance-id', ('instance/id',), True, True, False),
+ ('availability-zone', ('instance/zone',), True, True, False),
+ ('local-hostname', ('instance/hostname',), True, True, False),
+ ('instance-data', ('instance/attributes',), False, False, True),
+ ('project-data', ('project/attributes',), False, False, True),
]
metadata_fetcher = GoogleMetadataFetcher(address)
md = {}
- # iterate over url_map keys to get metadata items
- for (mkey, paths, required, is_text) in url_map:
+ # Iterate over url_map keys to get metadata items.
+ for (mkey, paths, required, is_text, is_recursive) in url_map:
value = None
for path in paths:
- new_value = metadata_fetcher.get_value(path, is_text)
+ new_value = metadata_fetcher.get_value(path, is_text, is_recursive)
if new_value is not None:
value = new_value
if required and value is None:
@@ -149,17 +200,23 @@ def read_md(address=None, platform_check=True):
return ret
md[mkey] = value
- if md['public-keys']:
- lines = md['public-keys'].splitlines()
- md['public-keys'] = [_trim_key(k) for k in lines]
+ instance_data = json.loads(md['instance-data'] or '{}')
+ project_data = json.loads(md['project-data'] or '{}')
+ valid_keys = [instance_data.get('sshKeys'), instance_data.get('ssh-keys')]
+ block_project = instance_data.get('block-project-ssh-keys', '').lower()
+ if block_project != 'true' and not instance_data.get('sshKeys'):
+ valid_keys.append(project_data.get('ssh-keys'))
+ valid_keys.append(project_data.get('sshKeys'))
+ public_keys_data = '\n'.join([key for key in valid_keys if key])
+ md['public-keys-data'] = public_keys_data.splitlines()
if md['availability-zone']:
md['availability-zone'] = md['availability-zone'].split('/')[-1]
- encoding = md.get('user-data-encoding')
+ encoding = instance_data.get('user-data-encoding')
if encoding:
if encoding == 'base64':
- md['user-data'] = b64decode(md['user-data'])
+ md['user-data'] = b64decode(instance_data.get('user-data'))
else:
LOG.warning('unknown user-data-encoding: %s, ignoring', encoding)
@@ -188,20 +245,19 @@ def platform_reports_gce():
return False
-# Used to match classes to dependencies
+# Used to match classes to dependencies.
datasources = [
(DataSourceGCE, (sources.DEP_FILESYSTEM, sources.DEP_NETWORK)),
]
-# Return a list of data sources that match this set of dependencies
+# Return a list of data sources that match this set of dependencies.
def get_datasource_list(depends):
return sources.list_from_depends(depends, datasources)
if __name__ == "__main__":
import argparse
- import json
import sys
from base64 import b64encode
@@ -217,7 +273,7 @@ if __name__ == "__main__":
data = read_md(address=args.endpoint, platform_check=args.platform_check)
if 'user-data' in data:
# user-data is bytes not string like other things. Handle it specially.
- # if it can be represented as utf-8 then do so. Otherwise print base64
+ # If it can be represented as utf-8 then do so. Otherwise print base64
# encoded value in the key user-data-b64.
try:
data['user-data'] = data['user-data'].decode()
@@ -225,7 +281,7 @@ if __name__ == "__main__":
sys.stderr.write("User-data cannot be decoded. "
"Writing as base64\n")
del data['user-data']
- # b64encode returns a bytes value. decode to get the string.
+ # b64encode returns a bytes value. Decode to get the string.
data['user-data-b64'] = b64encode(data['user-data']).decode()
print(json.dumps(data, indent=1, sort_keys=True, separators=(',', ': ')))