summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/sources/DataSourceGCE.py134
-rw-r--r--tests/unittests/test_datasource/test_gce.py193
2 files changed, 267 insertions, 60 deletions
diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py
index ad6dae37..2da34a99 100644
--- a/cloudinit/sources/DataSourceGCE.py
+++ b/cloudinit/sources/DataSourceGCE.py
@@ -2,8 +2,12 @@
#
# This file is part of cloud-init. See LICENSE file for license information.
+import datetime
+import json
+
from base64 import b64decode
+from cloudinit.distros import ug_util
from cloudinit import log as logging
from cloudinit import sources
from cloudinit import url_helper
@@ -17,16 +21,18 @@ REQUIRED_FIELDS = ('instance-id', 'availability-zone', 'local-hostname')
class GoogleMetadataFetcher(object):
- headers = {'X-Google-Metadata-Request': 'True'}
+ headers = {'Metadata-Flavor': 'Google'}
def __init__(self, metadata_address):
self.metadata_address = metadata_address
- def get_value(self, path, is_text):
+ def get_value(self, path, is_text, is_recursive=False):
value = None
try:
- resp = url_helper.readurl(url=self.metadata_address + path,
- headers=self.headers)
+ url = self.metadata_address + path
+ if is_recursive:
+ url += '/?recursive=True'
+ resp = url_helper.readurl(url=url, headers=self.headers)
except url_helper.UrlError as exc:
msg = "url %s raised exception %s"
LOG.debug(msg, path, exc)
@@ -35,7 +41,7 @@ class GoogleMetadataFetcher(object):
if is_text:
value = util.decode_binary(resp.contents)
else:
- value = resp.contents
+ value = resp.contents.decode('utf-8')
else:
LOG.debug("url %s returned code %s", path, resp.code)
return value
@@ -47,6 +53,10 @@ class DataSourceGCE(sources.DataSource):
def __init__(self, sys_cfg, distro, paths):
sources.DataSource.__init__(self, sys_cfg, distro, paths)
+ self.default_user = None
+ if distro:
+ (users, _groups) = ug_util.normalize_users_groups(sys_cfg, distro)
+ (self.default_user, _user_config) = ug_util.extract_default(users)
self.metadata = dict()
self.ds_cfg = util.mergemanydict([
util.get_cfg_by_path(sys_cfg, ["datasource", "GCE"], {}),
@@ -70,17 +80,18 @@ class DataSourceGCE(sources.DataSource):
@property
def launch_index(self):
- # GCE does not provide lauch_index property
+ # GCE does not provide lauch_index property.
return None
def get_instance_id(self):
return self.metadata['instance-id']
def get_public_ssh_keys(self):
- return self.metadata['public-keys']
+ public_keys_data = self.metadata['public-keys-data']
+ return _parse_public_keys(public_keys_data, self.default_user)
def get_hostname(self, fqdn=False, resolve_ip=False):
- # GCE has long FDQN's and has asked for short hostnames
+ # GCE has long FDQN's and has asked for short hostnames.
return self.metadata['local-hostname'].split('.')[0]
@property
@@ -92,15 +103,58 @@ class DataSourceGCE(sources.DataSource):
return self.availability_zone.rsplit('-', 1)[0]
-def _trim_key(public_key):
- # GCE takes sshKeys attribute in the format of '<user>:<public_key>'
- # so we have to trim each key to remove the username part
+def _has_expired(public_key):
+ # Check whether an SSH key is expired. Public key input is a single SSH
+ # public key in the GCE specific key format documented here:
+ # https://cloud.google.com/compute/docs/instances/adding-removing-ssh-keys#sshkeyformat
+ try:
+ # Check for the Google-specific schema identifier.
+ schema, json_str = public_key.split(None, 3)[2:]
+ except (ValueError, AttributeError):
+ return False
+
+ # Do not expire keys if they do not have the expected schema identifier.
+ if schema != 'google-ssh':
+ return False
+
+ try:
+ json_obj = json.loads(json_str)
+ except ValueError:
+ return False
+
+ # Do not expire keys if there is no expriation timestamp.
+ if 'expireOn' not in json_obj:
+ return False
+
+ expire_str = json_obj['expireOn']
+ format_str = '%Y-%m-%dT%H:%M:%S+0000'
try:
- index = public_key.index(':')
- if index > 0:
- return public_key[(index + 1):]
- except Exception:
- return public_key
+ expire_time = datetime.datetime.strptime(expire_str, format_str)
+ except ValueError:
+ return False
+
+ # Expire the key if and only if we have exceeded the expiration timestamp.
+ return datetime.datetime.utcnow() > expire_time
+
+
+def _parse_public_keys(public_keys_data, default_user=None):
+ # Parse the SSH key data for the default user account. Public keys input is
+ # a list containing SSH public keys in the GCE specific key format
+ # documented here:
+ # https://cloud.google.com/compute/docs/instances/adding-removing-ssh-keys#sshkeyformat
+ public_keys = []
+ if not public_keys_data:
+ return public_keys
+ for public_key in public_keys_data:
+ if not public_key or not all(ord(c) < 128 for c in public_key):
+ continue
+ split_public_key = public_key.split(':', 1)
+ if len(split_public_key) != 2:
+ continue
+ user, key = split_public_key
+ if user in ('cloudinit', default_user) and not _has_expired(key):
+ public_keys.append(key)
+ return public_keys
def read_md(address=None, platform_check=True):
@@ -116,31 +170,28 @@ def read_md(address=None, platform_check=True):
ret['reason'] = "Not running on GCE."
return ret
- # if we cannot resolve the metadata server, then no point in trying
+ # If we cannot resolve the metadata server, then no point in trying.
if not util.is_resolvable_url(address):
LOG.debug("%s is not resolvable", address)
ret['reason'] = 'address "%s" is not resolvable' % address
return ret
- # url_map: (our-key, path, required, is_text)
+ # url_map: (our-key, path, required, is_text, is_recursive)
url_map = [
- ('instance-id', ('instance/id',), True, True),
- ('availability-zone', ('instance/zone',), True, True),
- ('local-hostname', ('instance/hostname',), True, True),
- ('public-keys', ('project/attributes/sshKeys',
- 'instance/attributes/ssh-keys'), False, True),
- ('user-data', ('instance/attributes/user-data',), False, False),
- ('user-data-encoding', ('instance/attributes/user-data-encoding',),
- False, True),
+ ('instance-id', ('instance/id',), True, True, False),
+ ('availability-zone', ('instance/zone',), True, True, False),
+ ('local-hostname', ('instance/hostname',), True, True, False),
+ ('instance-data', ('instance/attributes',), False, False, True),
+ ('project-data', ('project/attributes',), False, False, True),
]
metadata_fetcher = GoogleMetadataFetcher(address)
md = {}
- # iterate over url_map keys to get metadata items
- for (mkey, paths, required, is_text) in url_map:
+ # Iterate over url_map keys to get metadata items.
+ for (mkey, paths, required, is_text, is_recursive) in url_map:
value = None
for path in paths:
- new_value = metadata_fetcher.get_value(path, is_text)
+ new_value = metadata_fetcher.get_value(path, is_text, is_recursive)
if new_value is not None:
value = new_value
if required and value is None:
@@ -149,17 +200,23 @@ def read_md(address=None, platform_check=True):
return ret
md[mkey] = value
- if md['public-keys']:
- lines = md['public-keys'].splitlines()
- md['public-keys'] = [_trim_key(k) for k in lines]
+ instance_data = json.loads(md['instance-data'] or '{}')
+ project_data = json.loads(md['project-data'] or '{}')
+ valid_keys = [instance_data.get('sshKeys'), instance_data.get('ssh-keys')]
+ block_project = instance_data.get('block-project-ssh-keys', '').lower()
+ if block_project != 'true' and not instance_data.get('sshKeys'):
+ valid_keys.append(project_data.get('ssh-keys'))
+ valid_keys.append(project_data.get('sshKeys'))
+ public_keys_data = '\n'.join([key for key in valid_keys if key])
+ md['public-keys-data'] = public_keys_data.splitlines()
if md['availability-zone']:
md['availability-zone'] = md['availability-zone'].split('/')[-1]
- encoding = md.get('user-data-encoding')
+ encoding = instance_data.get('user-data-encoding')
if encoding:
if encoding == 'base64':
- md['user-data'] = b64decode(md['user-data'])
+ md['user-data'] = b64decode(instance_data.get('user-data'))
else:
LOG.warning('unknown user-data-encoding: %s, ignoring', encoding)
@@ -188,20 +245,19 @@ def platform_reports_gce():
return False
-# Used to match classes to dependencies
+# Used to match classes to dependencies.
datasources = [
(DataSourceGCE, (sources.DEP_FILESYSTEM, sources.DEP_NETWORK)),
]
-# Return a list of data sources that match this set of dependencies
+# Return a list of data sources that match this set of dependencies.
def get_datasource_list(depends):
return sources.list_from_depends(depends, datasources)
if __name__ == "__main__":
import argparse
- import json
import sys
from base64 import b64encode
@@ -217,7 +273,7 @@ if __name__ == "__main__":
data = read_md(address=args.endpoint, platform_check=args.platform_check)
if 'user-data' in data:
# user-data is bytes not string like other things. Handle it specially.
- # if it can be represented as utf-8 then do so. Otherwise print base64
+ # If it can be represented as utf-8 then do so. Otherwise print base64
# encoded value in the key user-data-b64.
try:
data['user-data'] = data['user-data'].decode()
@@ -225,7 +281,7 @@ if __name__ == "__main__":
sys.stderr.write("User-data cannot be decoded. "
"Writing as base64\n")
del data['user-data']
- # b64encode returns a bytes value. decode to get the string.
+ # b64encode returns a bytes value. Decode to get the string.
data['user-data-b64'] = b64encode(data['user-data']).decode()
print(json.dumps(data, indent=1, sort_keys=True, separators=(',', ': ')))
diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py
index 82c788dc..12d68009 100644
--- a/tests/unittests/test_datasource/test_gce.py
+++ b/tests/unittests/test_datasource/test_gce.py
@@ -4,13 +4,16 @@
#
# This file is part of cloud-init. See LICENSE file for license information.
+import datetime
import httpretty
+import json
import mock
import re
from base64 import b64encode, b64decode
from six.moves.urllib_parse import urlparse
+from cloudinit import distros
from cloudinit import helpers
from cloudinit import settings
from cloudinit.sources import DataSourceGCE
@@ -21,10 +24,7 @@ from cloudinit.tests import helpers as test_helpers
GCE_META = {
'instance/id': '123',
'instance/zone': 'foo/bar',
- 'project/attributes/sshKeys': 'user:ssh-rsa AA2..+aRD0fyVw== root@server',
'instance/hostname': 'server.project-foo.local',
- # UnicodeDecodeError below if set to ds.userdata instead of userdata_raw
- 'instance/attributes/user-data': b'/bin/echo \xff\n',
}
GCE_META_PARTIAL = {
@@ -37,11 +37,13 @@ GCE_META_ENCODING = {
'instance/id': '12345',
'instance/hostname': 'server.project-baz.local',
'instance/zone': 'baz/bang',
- 'instance/attributes/user-data': b64encode(b'/bin/echo baz\n'),
- 'instance/attributes/user-data-encoding': 'base64',
+ 'instance/attributes': {
+ 'user-data': b64encode(b'/bin/echo baz\n').decode('utf-8'),
+ 'user-data-encoding': 'base64',
+ }
}
-HEADERS = {'X-Google-Metadata-Request': 'True'}
+HEADERS = {'Metadata-Flavor': 'Google'}
MD_URL_RE = re.compile(
r'http://metadata.google.internal/computeMetadata/v1/.*')
@@ -54,10 +56,15 @@ def _set_mock_metadata(gce_meta=None):
url_path = urlparse(uri).path
if url_path.startswith('/computeMetadata/v1/'):
path = url_path.split('/computeMetadata/v1/')[1:][0]
+ recursive = path.endswith('/')
+ path = path.rstrip('/')
else:
path = None
if path in gce_meta:
- return (200, headers, gce_meta.get(path))
+ response = gce_meta.get(path)
+ if recursive:
+ response = json.dumps(response)
+ return (200, headers, response)
else:
return (404, headers, '')
@@ -69,6 +76,16 @@ def _set_mock_metadata(gce_meta=None):
@httpretty.activate
class TestDataSourceGCE(test_helpers.HttprettyTestCase):
+ def _make_distro(self, dtype, def_user=None):
+ cfg = dict(settings.CFG_BUILTIN)
+ cfg['system_info']['distro'] = dtype
+ paths = helpers.Paths(cfg['system_info']['paths'])
+ distro_cls = distros.fetch(dtype)
+ if def_user:
+ cfg['system_info']['default_user'] = def_user.copy()
+ distro = distro_cls(dtype, cfg['system_info'], paths)
+ return distro
+
def setUp(self):
tmp = self.tmp_dir()
self.ds = DataSourceGCE.DataSourceGCE(
@@ -90,6 +107,10 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
self.assertDictContainsSubset(HEADERS, req_header)
def test_metadata(self):
+ # UnicodeDecodeError if set to ds.userdata instead of userdata_raw
+ meta = GCE_META.copy()
+ meta['instance/attributes/user-data'] = b'/bin/echo \xff\n'
+
_set_mock_metadata()
self.ds.get_data()
@@ -118,8 +139,8 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
_set_mock_metadata(GCE_META_ENCODING)
self.ds.get_data()
- decoded = b64decode(
- GCE_META_ENCODING.get('instance/attributes/user-data'))
+ instance_data = GCE_META_ENCODING.get('instance/attributes')
+ decoded = b64decode(instance_data.get('user-data'))
self.assertEqual(decoded, self.ds.get_userdata_raw())
def test_missing_required_keys_return_false(self):
@@ -131,33 +152,124 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
self.assertEqual(False, self.ds.get_data())
httpretty.reset()
- def test_project_level_ssh_keys_are_used(self):
+ def test_no_ssh_keys_metadata(self):
_set_mock_metadata()
self.ds.get_data()
+ self.assertEqual([], self.ds.get_public_ssh_keys())
+
+ def test_cloudinit_ssh_keys(self):
+ valid_key = 'ssh-rsa VALID {0}'
+ invalid_key = 'ssh-rsa INVALID {0}'
+ project_attributes = {
+ 'sshKeys': '\n'.join([
+ 'cloudinit:{0}'.format(valid_key.format(0)),
+ 'user:{0}'.format(invalid_key.format(0)),
+ ]),
+ 'ssh-keys': '\n'.join([
+ 'cloudinit:{0}'.format(valid_key.format(1)),
+ 'user:{0}'.format(invalid_key.format(1)),
+ ]),
+ }
+ instance_attributes = {
+ 'ssh-keys': '\n'.join([
+ 'cloudinit:{0}'.format(valid_key.format(2)),
+ 'user:{0}'.format(invalid_key.format(2)),
+ ]),
+ 'block-project-ssh-keys': 'False',
+ }
+
+ meta = GCE_META.copy()
+ meta['project/attributes'] = project_attributes
+ meta['instance/attributes'] = instance_attributes
+
+ _set_mock_metadata(meta)
+ self.ds.get_data()
+
+ expected = [valid_key.format(key) for key in range(3)]
+ self.assertEqual(set(expected), set(self.ds.get_public_ssh_keys()))
+
+ @mock.patch("cloudinit.sources.DataSourceGCE.ug_util")
+ def test_default_user_ssh_keys(self, mock_ug_util):
+ mock_ug_util.normalize_users_groups.return_value = None, None
+ mock_ug_util.extract_default.return_value = 'ubuntu', None
+ ubuntu_ds = DataSourceGCE.DataSourceGCE(
+ settings.CFG_BUILTIN, self._make_distro('ubuntu'),
+ helpers.Paths({}))
+
+ valid_key = 'ssh-rsa VALID {0}'
+ invalid_key = 'ssh-rsa INVALID {0}'
+ project_attributes = {
+ 'sshKeys': '\n'.join([
+ 'ubuntu:{0}'.format(valid_key.format(0)),
+ 'user:{0}'.format(invalid_key.format(0)),
+ ]),
+ 'ssh-keys': '\n'.join([
+ 'ubuntu:{0}'.format(valid_key.format(1)),
+ 'user:{0}'.format(invalid_key.format(1)),
+ ]),
+ }
+ instance_attributes = {
+ 'ssh-keys': '\n'.join([
+ 'ubuntu:{0}'.format(valid_key.format(2)),
+ 'user:{0}'.format(invalid_key.format(2)),
+ ]),
+ 'block-project-ssh-keys': 'False',
+ }
- # we expect a list of public ssh keys with user names stripped
- self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'],
- self.ds.get_public_ssh_keys())
+ meta = GCE_META.copy()
+ meta['project/attributes'] = project_attributes
+ meta['instance/attributes'] = instance_attributes
+
+ _set_mock_metadata(meta)
+ ubuntu_ds.get_data()
+
+ expected = [valid_key.format(key) for key in range(3)]
+ self.assertEqual(set(expected), set(ubuntu_ds.get_public_ssh_keys()))
+
+ def test_instance_ssh_keys_override(self):
+ valid_key = 'ssh-rsa VALID {0}'
+ invalid_key = 'ssh-rsa INVALID {0}'
+ project_attributes = {
+ 'sshKeys': 'cloudinit:{0}'.format(invalid_key.format(0)),
+ 'ssh-keys': 'cloudinit:{0}'.format(invalid_key.format(1)),
+ }
+ instance_attributes = {
+ 'sshKeys': 'cloudinit:{0}'.format(valid_key.format(0)),
+ 'ssh-keys': 'cloudinit:{0}'.format(valid_key.format(1)),
+ 'block-project-ssh-keys': 'False',
+ }
- def test_instance_level_ssh_keys_are_used(self):
- key_content = 'ssh-rsa JustAUser root@server'
meta = GCE_META.copy()
- meta['instance/attributes/ssh-keys'] = 'user:{0}'.format(key_content)
+ meta['project/attributes'] = project_attributes
+ meta['instance/attributes'] = instance_attributes
_set_mock_metadata(meta)
self.ds.get_data()
- self.assertIn(key_content, self.ds.get_public_ssh_keys())
+ expected = [valid_key.format(key) for key in range(2)]
+ self.assertEqual(set(expected), set(self.ds.get_public_ssh_keys()))
+
+ def test_block_project_ssh_keys_override(self):
+ valid_key = 'ssh-rsa VALID {0}'
+ invalid_key = 'ssh-rsa INVALID {0}'
+ project_attributes = {
+ 'sshKeys': 'cloudinit:{0}'.format(invalid_key.format(0)),
+ 'ssh-keys': 'cloudinit:{0}'.format(invalid_key.format(1)),
+ }
+ instance_attributes = {
+ 'ssh-keys': 'cloudinit:{0}'.format(valid_key.format(0)),
+ 'block-project-ssh-keys': 'True',
+ }
- def test_instance_level_keys_replace_project_level_keys(self):
- key_content = 'ssh-rsa JustAUser root@server'
meta = GCE_META.copy()
- meta['instance/attributes/ssh-keys'] = 'user:{0}'.format(key_content)
+ meta['project/attributes'] = project_attributes
+ meta['instance/attributes'] = instance_attributes
_set_mock_metadata(meta)
self.ds.get_data()
- self.assertEqual([key_content], self.ds.get_public_ssh_keys())
+ expected = [valid_key.format(0)]
+ self.assertEqual(set(expected), set(self.ds.get_public_ssh_keys()))
def test_only_last_part_of_zone_used_for_availability_zone(self):
_set_mock_metadata()
@@ -172,5 +284,44 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
self.assertEqual(False, ret)
m_fetcher.assert_not_called()
+ def test_has_expired(self):
+
+ def _get_timestamp(days):
+ format_str = '%Y-%m-%dT%H:%M:%S+0000'
+ today = datetime.datetime.now()
+ timestamp = today + datetime.timedelta(days=days)
+ return timestamp.strftime(format_str)
+
+ past = _get_timestamp(-1)
+ future = _get_timestamp(1)
+ ssh_keys = {
+ None: False,
+ '': False,
+ 'Invalid': False,
+ 'user:ssh-rsa key user@domain.com': False,
+ 'user:ssh-rsa key google {"expireOn":"%s"}' % past: False,
+ 'user:ssh-rsa key google-ssh': False,
+ 'user:ssh-rsa key google-ssh {invalid:json}': False,
+ 'user:ssh-rsa key google-ssh {"userName":"user"}': False,
+ 'user:ssh-rsa key google-ssh {"expireOn":"invalid"}': False,
+ 'user:xyz key google-ssh {"expireOn":"%s"}' % future: False,
+ 'user:xyz key google-ssh {"expireOn":"%s"}' % past: True,
+ }
+
+ for key, expired in ssh_keys.items():
+ self.assertEqual(DataSourceGCE._has_expired(key), expired)
+
+ def test_parse_public_keys_non_ascii(self):
+ public_key_data = [
+ 'cloudinit:rsa ssh-ke%s invalid' % chr(165),
+ 'use%sname:rsa ssh-key' % chr(174),
+ 'cloudinit:test 1',
+ 'default:test 2',
+ 'user:test 3',
+ ]
+ expected = ['test 1', 'test 2']
+ found = DataSourceGCE._parse_public_keys(
+ public_key_data, default_user='default')
+ self.assertEqual(sorted(found), sorted(expected))
# vi: ts=4 expandtab