diff options
Diffstat (limited to 'cloudinit')
25 files changed, 418 insertions, 56 deletions
diff --git a/cloudinit/analyze/__main__.py b/cloudinit/analyze/__main__.py index 69b9e43e..3ba5903f 100644 --- a/cloudinit/analyze/__main__.py +++ b/cloudinit/analyze/__main__.py @@ -6,6 +6,8 @@ import argparse import re import sys +from cloudinit.util import json_dumps + from . import dump from . import show @@ -112,7 +114,7 @@ def analyze_show(name, args): def analyze_dump(name, args): """Dump cloud-init events in json format""" (infh, outfh) = configure_io(args) - outfh.write(dump.json_dumps(_get_events(infh)) + '\n') + outfh.write(json_dumps(_get_events(infh)) + '\n') def _get_events(infile): diff --git a/cloudinit/analyze/dump.py b/cloudinit/analyze/dump.py index ca4da496..b071aa19 100644 --- a/cloudinit/analyze/dump.py +++ b/cloudinit/analyze/dump.py @@ -2,7 +2,6 @@ import calendar from datetime import datetime -import json import sys from cloudinit import util @@ -132,11 +131,6 @@ def parse_ci_logline(line): return event -def json_dumps(data): - return json.dumps(data, indent=1, sort_keys=True, - separators=(',', ': ')) - - def dump_events(cisource=None, rawdata=None): events = [] event = None @@ -169,7 +163,7 @@ def main(): else: cisource = sys.stdin - return json_dumps(dump_events(cisource)) + return util.json_dumps(dump_events(cisource)) if __name__ == "__main__": diff --git a/cloudinit/sources/DataSourceAliYun.py b/cloudinit/sources/DataSourceAliYun.py index 43a7e42c..7ac8288d 100644 --- a/cloudinit/sources/DataSourceAliYun.py +++ b/cloudinit/sources/DataSourceAliYun.py @@ -11,6 +11,7 @@ ALIYUN_PRODUCT = "Alibaba Cloud ECS" class DataSourceAliYun(EC2.DataSourceEc2): + dsname = 'AliYun' metadata_urls = ['http://100.100.100.200'] # The minimum supported metadata_version from the ec2 metadata apis diff --git a/cloudinit/sources/DataSourceAltCloud.py b/cloudinit/sources/DataSourceAltCloud.py index c78ad9eb..be2d6cf8 100644 --- a/cloudinit/sources/DataSourceAltCloud.py +++ b/cloudinit/sources/DataSourceAltCloud.py @@ -74,6 +74,9 @@ def read_user_data_callback(mount_dir): class DataSourceAltCloud(sources.DataSource): + + dsname = 'AltCloud' + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.seed = None @@ -112,7 +115,7 @@ class DataSourceAltCloud(sources.DataSource): return 'UNKNOWN' - def get_data(self): + def _get_data(self): ''' Description: User Data is passed to the launching instance which diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 14367e9c..6978d4e5 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -246,6 +246,8 @@ def temporary_hostname(temp_hostname, cfg, hostname_command='hostname'): class DataSourceAzure(sources.DataSource): + + dsname = 'Azure' _negotiated = False def __init__(self, sys_cfg, distro, paths): @@ -330,7 +332,7 @@ class DataSourceAzure(sources.DataSource): metadata['public-keys'] = key_value or pubkeys_from_crt_files(fp_files) return metadata - def get_data(self): + def _get_data(self): # azure removes/ejects the cdrom containing the ovf-env.xml # file on reboot. So, in order to successfully reboot we # need to look in the datadir and consider that valid diff --git a/cloudinit/sources/DataSourceBigstep.py b/cloudinit/sources/DataSourceBigstep.py index d7fcd45a..699a85b5 100644 --- a/cloudinit/sources/DataSourceBigstep.py +++ b/cloudinit/sources/DataSourceBigstep.py @@ -16,13 +16,16 @@ LOG = logging.getLogger(__name__) class DataSourceBigstep(sources.DataSource): + + dsname = 'Bigstep' + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.metadata = {} self.vendordata_raw = "" self.userdata_raw = "" - def get_data(self, apply_filter=False): + def _get_data(self, apply_filter=False): url = get_url_from_file() if url is None: return False diff --git a/cloudinit/sources/DataSourceCloudSigma.py b/cloudinit/sources/DataSourceCloudSigma.py index 19df16b1..4eaad475 100644 --- a/cloudinit/sources/DataSourceCloudSigma.py +++ b/cloudinit/sources/DataSourceCloudSigma.py @@ -23,6 +23,9 @@ class DataSourceCloudSigma(sources.DataSource): For more information about CloudSigma's Server Context: http://cloudsigma-docs.readthedocs.org/en/latest/server_context.html """ + + dsname = 'CloudSigma' + def __init__(self, sys_cfg, distro, paths): self.cepko = Cepko() self.ssh_public_key = '' @@ -46,7 +49,7 @@ class DataSourceCloudSigma(sources.DataSource): LOG.warning("failed to query dmi data for system product name") return False - def get_data(self): + def _get_data(self): """ Metadata is the whole server context and /meta/cloud-config is used as userdata. diff --git a/cloudinit/sources/DataSourceCloudStack.py b/cloudinit/sources/DataSourceCloudStack.py index 9dc473fc..0df545fc 100644 --- a/cloudinit/sources/DataSourceCloudStack.py +++ b/cloudinit/sources/DataSourceCloudStack.py @@ -65,6 +65,9 @@ class CloudStackPasswordServerClient(object): class DataSourceCloudStack(sources.DataSource): + + dsname = 'CloudStack' + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.seed_dir = os.path.join(paths.seed_dir, 'cs') @@ -117,7 +120,7 @@ class DataSourceCloudStack(sources.DataSource): def get_config_obj(self): return self.cfg - def get_data(self): + def _get_data(self): seed_ret = {} if util.read_optional_seed(seed_ret, base=(self.seed_dir + "/")): self.userdata_raw = seed_ret['user-data'] diff --git a/cloudinit/sources/DataSourceConfigDrive.py b/cloudinit/sources/DataSourceConfigDrive.py index ef374f3f..870b3688 100644 --- a/cloudinit/sources/DataSourceConfigDrive.py +++ b/cloudinit/sources/DataSourceConfigDrive.py @@ -32,6 +32,9 @@ OPTICAL_DEVICES = tuple(('/dev/%s%s' % (z, i) for z in POSSIBLE_MOUNTS class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource): + + dsname = 'ConfigDrive' + def __init__(self, sys_cfg, distro, paths): super(DataSourceConfigDrive, self).__init__(sys_cfg, distro, paths) self.source = None @@ -50,7 +53,7 @@ class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource): mstr += "[source=%s]" % (self.source) return mstr - def get_data(self): + def _get_data(self): found = None md = {} results = {} diff --git a/cloudinit/sources/DataSourceDigitalOcean.py b/cloudinit/sources/DataSourceDigitalOcean.py index 5e7e66be..e0ef665e 100644 --- a/cloudinit/sources/DataSourceDigitalOcean.py +++ b/cloudinit/sources/DataSourceDigitalOcean.py @@ -27,6 +27,9 @@ MD_USE_IPV4LL = True class DataSourceDigitalOcean(sources.DataSource): + + dsname = 'DigitalOcean' + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.distro = distro @@ -44,7 +47,7 @@ class DataSourceDigitalOcean(sources.DataSource): def _get_sysinfo(self): return do_helper.read_sysinfo() - def get_data(self): + def _get_data(self): (is_do, droplet_id) = self._get_sysinfo() # only proceed if we know we are on DigitalOcean diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index 7bbbfb63..e5c88334 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -31,6 +31,7 @@ _unset = "_unset" class Platforms(object): + # TODO Rename and move to cloudinit.cloud.CloudNames ALIYUN = "AliYun" AWS = "AWS" BRIGHTBOX = "Brightbox" @@ -45,6 +46,7 @@ class Platforms(object): class DataSourceEc2(sources.DataSource): + dsname = 'Ec2' # Default metadata urls that will be used if none are provided # They will be checked for 'resolveability' and some of the # following may be discarded if they do not resolve @@ -68,11 +70,15 @@ class DataSourceEc2(sources.DataSource): _fallback_interface = None def __init__(self, sys_cfg, distro, paths): - sources.DataSource.__init__(self, sys_cfg, distro, paths) + super(DataSourceEc2, self).__init__(sys_cfg, distro, paths) self.metadata_address = None self.seed_dir = os.path.join(paths.seed_dir, "ec2") - def get_data(self): + def _get_cloud_name(self): + """Return the cloud name as identified during _get_data.""" + return self.cloud_platform + + def _get_data(self): seed_ret = {} if util.read_optional_seed(seed_ret, base=(self.seed_dir + "/")): self.userdata_raw = seed_ret['user-data'] @@ -274,7 +280,7 @@ class DataSourceEc2(sources.DataSource): return None @property - def cloud_platform(self): + def cloud_platform(self): # TODO rename cloud_name if self._cloud_platform is None: self._cloud_platform = identify_platform() return self._cloud_platform diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py index ccae4200..ad6dae37 100644 --- a/cloudinit/sources/DataSourceGCE.py +++ b/cloudinit/sources/DataSourceGCE.py @@ -42,6 +42,9 @@ class GoogleMetadataFetcher(object): class DataSourceGCE(sources.DataSource): + + dsname = 'GCE' + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.metadata = dict() @@ -50,7 +53,7 @@ class DataSourceGCE(sources.DataSource): BUILTIN_DS_CONFIG]) self.metadata_address = self.ds_cfg['metadata_url'] - def get_data(self): + def _get_data(self): ret = util.log_time( LOG.debug, 'Crawl of GCE metadata service', read_md, kwargs={'address': self.metadata_address}) diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index 77df5a51..496bd06a 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -39,6 +39,9 @@ class DataSourceMAAS(sources.DataSource): hostname vendor-data """ + + dsname = "MAAS" + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.base_url = None @@ -62,7 +65,7 @@ class DataSourceMAAS(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [%s]" % (root, self.base_url) - def get_data(self): + def _get_data(self): mcfg = self.ds_cfg try: diff --git a/cloudinit/sources/DataSourceNoCloud.py b/cloudinit/sources/DataSourceNoCloud.py index e641244d..5d3a8ddb 100644 --- a/cloudinit/sources/DataSourceNoCloud.py +++ b/cloudinit/sources/DataSourceNoCloud.py @@ -20,6 +20,9 @@ LOG = logging.getLogger(__name__) class DataSourceNoCloud(sources.DataSource): + + dsname = "NoCloud" + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.seed = None @@ -32,7 +35,7 @@ class DataSourceNoCloud(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [seed=%s][dsmode=%s]" % (root, self.seed, self.dsmode) - def get_data(self): + def _get_data(self): defaults = { "instance-id": "nocloud", "dsmode": self.dsmode, diff --git a/cloudinit/sources/DataSourceNone.py b/cloudinit/sources/DataSourceNone.py index 906bb278..e63a7e39 100644 --- a/cloudinit/sources/DataSourceNone.py +++ b/cloudinit/sources/DataSourceNone.py @@ -11,12 +11,15 @@ LOG = logging.getLogger(__name__) class DataSourceNone(sources.DataSource): + + dsname = "None" + def __init__(self, sys_cfg, distro, paths, ud_proc=None): sources.DataSource.__init__(self, sys_cfg, distro, paths, ud_proc) self.metadata = {} self.userdata_raw = '' - def get_data(self): + def _get_data(self): # If the datasource config has any provided 'fallback' # userdata or metadata, use it... if 'userdata_raw' in self.ds_cfg: diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index ccebf11a..6ac621f2 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -43,6 +43,9 @@ LOG = logging.getLogger(__name__) class DataSourceOVF(sources.DataSource): + + dsname = "OVF" + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.seed = None @@ -60,7 +63,7 @@ class DataSourceOVF(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [seed=%s]" % (root, self.seed) - def get_data(self): + def _get_data(self): found = [] md = {} ud = "" diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index 5fdac192..5da11847 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -31,6 +31,9 @@ CONTEXT_DISK_FILES = ["context.sh"] class DataSourceOpenNebula(sources.DataSource): + + dsname = "OpenNebula" + def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.seed = None @@ -40,7 +43,7 @@ class DataSourceOpenNebula(sources.DataSource): root = sources.DataSource.__str__(self) return "%s [seed=%s][dsmode=%s]" % (root, self.seed, self.dsmode) - def get_data(self): + def _get_data(self): defaults = {"instance-id": DEFAULT_IID} results = None seed = None diff --git a/cloudinit/sources/DataSourceOpenStack.py b/cloudinit/sources/DataSourceOpenStack.py index b64a7f24..e55a7638 100644 --- a/cloudinit/sources/DataSourceOpenStack.py +++ b/cloudinit/sources/DataSourceOpenStack.py @@ -24,6 +24,9 @@ DEFAULT_METADATA = { class DataSourceOpenStack(openstack.SourceMixin, sources.DataSource): + + dsname = "OpenStack" + def __init__(self, sys_cfg, distro, paths): super(DataSourceOpenStack, self).__init__(sys_cfg, distro, paths) self.metadata_address = None @@ -96,7 +99,7 @@ class DataSourceOpenStack(openstack.SourceMixin, sources.DataSource): self.metadata_address = url2base.get(avail_url) return bool(avail_url) - def get_data(self): + def _get_data(self): try: if not self.wait_for_metadata_service(): return False diff --git a/cloudinit/sources/DataSourceScaleway.py b/cloudinit/sources/DataSourceScaleway.py index 3a8a8e8f..b0b19c93 100644 --- a/cloudinit/sources/DataSourceScaleway.py +++ b/cloudinit/sources/DataSourceScaleway.py @@ -169,6 +169,8 @@ def query_data_api(api_type, api_address, retries, timeout): class DataSourceScaleway(sources.DataSource): + dsname = "Scaleway" + def __init__(self, sys_cfg, distro, paths): super(DataSourceScaleway, self).__init__(sys_cfg, distro, paths) @@ -184,7 +186,7 @@ class DataSourceScaleway(sources.DataSource): self.retries = int(self.ds_cfg.get('retries', DEF_MD_RETRIES)) self.timeout = int(self.ds_cfg.get('timeout', DEF_MD_TIMEOUT)) - def get_data(self): + def _get_data(self): if not on_scaleway(): return False diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 6c6902fd..86bfa5d8 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -159,6 +159,9 @@ LEGACY_USER_D = "/var/db" class DataSourceSmartOS(sources.DataSource): + + dsname = "Joyent" + _unset = "_unset" smartos_type = _unset md_client = _unset @@ -211,7 +214,7 @@ class DataSourceSmartOS(sources.DataSource): os.rename('/'.join([svc_path, 'provisioning']), '/'.join([svc_path, 'provision_success'])) - def get_data(self): + def _get_data(self): self._init() md = {} diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 9a43fbee..4b819ce6 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -10,9 +10,11 @@ import abc import copy +import json import os import six +from cloudinit.atomic_helper import write_json from cloudinit import importer from cloudinit import log as logging from cloudinit import type_utils @@ -33,6 +35,12 @@ DEP_FILESYSTEM = "FILESYSTEM" DEP_NETWORK = "NETWORK" DS_PREFIX = 'DataSource' +# File in which instance meta-data, user-data and vendor-data is written +INSTANCE_JSON_FILE = 'instance-data.json' + +# Key which can be provide a cloud's official product name to cloud-init +METADATA_CLOUD_NAME_KEY = 'cloud-name' + LOG = logging.getLogger(__name__) @@ -40,12 +48,39 @@ class DataSourceNotFoundException(Exception): pass +def process_base64_metadata(metadata, key_path=''): + """Strip ci-b64 prefix and return metadata with base64-encoded-keys set.""" + md_copy = copy.deepcopy(metadata) + md_copy['base64-encoded-keys'] = [] + for key, val in metadata.items(): + if key_path: + sub_key_path = key_path + '/' + key + else: + sub_key_path = key + if isinstance(val, str) and val.startswith('ci-b64:'): + md_copy['base64-encoded-keys'].append(sub_key_path) + md_copy[key] = val.replace('ci-b64:', '') + if isinstance(val, dict): + return_val = process_base64_metadata(val, sub_key_path) + md_copy['base64-encoded-keys'].extend( + return_val.pop('base64-encoded-keys')) + md_copy[key] = return_val + return md_copy + + @six.add_metaclass(abc.ABCMeta) class DataSource(object): dsmode = DSMODE_NETWORK default_locale = 'en_US.UTF-8' + # Datasource name needs to be set by subclasses to determine which + # cloud-config datasource key is loaded + dsname = '_undef' + + # Cached cloud_name as determined by _get_cloud_name + _cloud_name = None + def __init__(self, sys_cfg, distro, paths, ud_proc=None): self.sys_cfg = sys_cfg self.distro = distro @@ -56,17 +91,8 @@ class DataSource(object): self.vendordata = None self.vendordata_raw = None - # find the datasource config name. - # remove 'DataSource' from classname on front, and remove 'Net' on end. - # Both Foo and FooNet sources expect config in cfg['sources']['Foo'] - name = type_utils.obj_name(self) - if name.startswith(DS_PREFIX): - name = name[len(DS_PREFIX):] - if name.endswith('Net'): - name = name[0:-3] - - self.ds_cfg = util.get_cfg_by_path(self.sys_cfg, - ("datasource", name), {}) + self.ds_cfg = util.get_cfg_by_path( + self.sys_cfg, ("datasource", self.dsname), {}) if not self.ds_cfg: self.ds_cfg = {} @@ -78,6 +104,51 @@ class DataSource(object): def __str__(self): return type_utils.obj_name(self) + def _get_standardized_metadata(self): + """Return a dictionary of standardized metadata keys.""" + return {'v1': { + 'local-hostname': self.get_hostname(), + 'instance-id': self.get_instance_id(), + 'cloud-name': self.cloud_name, + 'region': self.region, + 'availability-zone': self.availability_zone}} + + def get_data(self): + """Datasources implement _get_data to setup metadata and userdata_raw. + + Minimally, the datasource should return a boolean True on success. + """ + return_value = self._get_data() + json_file = os.path.join(self.paths.run_dir, INSTANCE_JSON_FILE) + if not return_value: + return return_value + + instance_data = { + 'ds': { + 'meta-data': self.metadata, + 'user-data': self.get_userdata_raw(), + 'vendor-data': self.get_vendordata_raw()}} + instance_data.update( + self._get_standardized_metadata()) + try: + # Process content base64encoding unserializable values + content = util.json_dumps(instance_data) + # Strip base64: prefix and return base64-encoded-keys + processed_data = process_base64_metadata(json.loads(content)) + except TypeError as e: + LOG.warning('Error persisting instance-data.json: %s', str(e)) + return return_value + except UnicodeDecodeError as e: + LOG.warning('Error persisting instance-data.json: %s', str(e)) + return return_value + write_json(json_file, processed_data, mode=0o600) + return return_value + + def _get_data(self): + raise NotImplementedError( + 'Subclasses of DataSource must implement _get_data which' + ' sets self.metadata, vendordata_raw and userdata_raw.') + def get_userdata(self, apply_filter=False): if self.userdata is None: self.userdata = self.ud_proc.process(self.get_userdata_raw()) @@ -91,6 +162,34 @@ class DataSource(object): return self.vendordata @property + def cloud_name(self): + """Return lowercase cloud name as determined by the datasource. + + Datasource can determine or define its own cloud product name in + metadata. + """ + if self._cloud_name: + return self._cloud_name + if self.metadata and self.metadata.get(METADATA_CLOUD_NAME_KEY): + cloud_name = self.metadata.get(METADATA_CLOUD_NAME_KEY) + if isinstance(cloud_name, six.string_types): + self._cloud_name = cloud_name.lower() + LOG.debug( + 'Ignoring metadata provided key %s: non-string type %s', + METADATA_CLOUD_NAME_KEY, type(cloud_name)) + else: + self._cloud_name = self._get_cloud_name().lower() + return self._cloud_name + + def _get_cloud_name(self): + """Return the datasource name as it frequently matches cloud name. + + Should be overridden in subclasses which can run on multiple + cloud names, such as DatasourceEc2. + """ + return self.dsname + + @property def launch_index(self): if not self.metadata: return None @@ -161,8 +260,11 @@ class DataSource(object): @property def availability_zone(self): - return self.metadata.get('availability-zone', - self.metadata.get('availability_zone')) + top_level_az = self.metadata.get( + 'availability-zone', self.metadata.get('availability_zone')) + if top_level_az: + return top_level_az + return self.metadata.get('placement', {}).get('availability-zone') @property def region(self): @@ -417,4 +519,5 @@ def list_from_depends(depends, ds_list): ret_list.append(cls) return ret_list + # vi: ts=4 expandtab diff --git a/cloudinit/sources/tests/__init__.py b/cloudinit/sources/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/cloudinit/sources/tests/__init__.py diff --git a/cloudinit/sources/tests/test_init.py b/cloudinit/sources/tests/test_init.py new file mode 100644 index 00000000..af151154 --- /dev/null +++ b/cloudinit/sources/tests/test_init.py @@ -0,0 +1,202 @@ +# This file is part of cloud-init. See LICENSE file for license information. + +import os +import six +import stat + +from cloudinit.helpers import Paths +from cloudinit.sources import ( + INSTANCE_JSON_FILE, DataSource) +from cloudinit.tests.helpers import CiTestCase, skipIf +from cloudinit.user_data import UserDataProcessor +from cloudinit import util + + +class DataSourceTestSubclassNet(DataSource): + + dsname = 'MyTestSubclass' + + def __init__(self, sys_cfg, distro, paths, custom_userdata=None): + super(DataSourceTestSubclassNet, self).__init__( + sys_cfg, distro, paths) + self._custom_userdata = custom_userdata + + def _get_cloud_name(self): + return 'SubclassCloudName' + + def _get_data(self): + self.metadata = {'availability_zone': 'myaz', + 'local-hostname': 'test-subclass-hostname', + 'region': 'myregion'} + if self._custom_userdata: + self.userdata_raw = self._custom_userdata + else: + self.userdata_raw = 'userdata_raw' + self.vendordata_raw = 'vendordata_raw' + return True + + +class InvalidDataSourceTestSubclassNet(DataSource): + pass + + +class TestDataSource(CiTestCase): + + with_logs = True + + def setUp(self): + super(TestDataSource, self).setUp() + self.sys_cfg = {'datasource': {'_undef': {'key1': False}}} + self.distro = 'distrotest' # generally should be a Distro object + self.paths = Paths({}) + self.datasource = DataSource(self.sys_cfg, self.distro, self.paths) + + def test_datasource_init(self): + """DataSource initializes metadata attributes, ds_cfg and ud_proc.""" + self.assertEqual(self.paths, self.datasource.paths) + self.assertEqual(self.sys_cfg, self.datasource.sys_cfg) + self.assertEqual(self.distro, self.datasource.distro) + self.assertIsNone(self.datasource.userdata) + self.assertEqual({}, self.datasource.metadata) + self.assertIsNone(self.datasource.userdata_raw) + self.assertIsNone(self.datasource.vendordata) + self.assertIsNone(self.datasource.vendordata_raw) + self.assertEqual({'key1': False}, self.datasource.ds_cfg) + self.assertIsInstance(self.datasource.ud_proc, UserDataProcessor) + + def test_datasource_init_gets_ds_cfg_using_dsname(self): + """Init uses DataSource.dsname for sourcing ds_cfg.""" + sys_cfg = {'datasource': {'MyTestSubclass': {'key2': False}}} + distro = 'distrotest' # generally should be a Distro object + paths = Paths({}) + datasource = DataSourceTestSubclassNet(sys_cfg, distro, paths) + self.assertEqual({'key2': False}, datasource.ds_cfg) + + def test_str_is_classname(self): + """The string representation of the datasource is the classname.""" + self.assertEqual('DataSource', str(self.datasource)) + self.assertEqual( + 'DataSourceTestSubclassNet', + str(DataSourceTestSubclassNet('', '', self.paths))) + + def test__get_data_unimplemented(self): + """Raise an error when _get_data is not implemented.""" + with self.assertRaises(NotImplementedError) as context_manager: + self.datasource.get_data() + self.assertIn( + 'Subclasses of DataSource must implement _get_data', + str(context_manager.exception)) + datasource2 = InvalidDataSourceTestSubclassNet( + self.sys_cfg, self.distro, self.paths) + with self.assertRaises(NotImplementedError) as context_manager: + datasource2.get_data() + self.assertIn( + 'Subclasses of DataSource must implement _get_data', + str(context_manager.exception)) + + def test_get_data_calls_subclass__get_data(self): + """Datasource.get_data uses the subclass' version of _get_data.""" + tmp = self.tmp_dir() + datasource = DataSourceTestSubclassNet( + self.sys_cfg, self.distro, Paths({'run_dir': tmp})) + self.assertTrue(datasource.get_data()) + self.assertEqual( + {'availability_zone': 'myaz', + 'local-hostname': 'test-subclass-hostname', + 'region': 'myregion'}, + datasource.metadata) + self.assertEqual('userdata_raw', datasource.userdata_raw) + self.assertEqual('vendordata_raw', datasource.vendordata_raw) + + def test_get_data_write_json_instance_data(self): + """get_data writes INSTANCE_JSON_FILE to run_dir as readonly root.""" + tmp = self.tmp_dir() + datasource = DataSourceTestSubclassNet( + self.sys_cfg, self.distro, Paths({'run_dir': tmp})) + datasource.get_data() + json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) + content = util.load_file(json_file) + expected = { + 'base64-encoded-keys': [], + 'v1': { + 'availability-zone': 'myaz', + 'cloud-name': 'subclasscloudname', + 'instance-id': 'iid-datasource', + 'local-hostname': 'test-subclass-hostname', + 'region': 'myregion'}, + 'ds': { + 'meta-data': {'availability_zone': 'myaz', + 'local-hostname': 'test-subclass-hostname', + 'region': 'myregion'}, + 'user-data': 'userdata_raw', + 'vendor-data': 'vendordata_raw'}} + self.assertEqual(expected, util.load_json(content)) + file_stat = os.stat(json_file) + self.assertEqual(0o600, stat.S_IMODE(file_stat.st_mode)) + + def test_get_data_handles_redacted_unserializable_content(self): + """get_data warns unserializable content in INSTANCE_JSON_FILE.""" + tmp = self.tmp_dir() + datasource = DataSourceTestSubclassNet( + self.sys_cfg, self.distro, Paths({'run_dir': tmp}), + custom_userdata={'key1': 'val1', 'key2': {'key2.1': self.paths}}) + self.assertTrue(datasource.get_data()) + json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) + content = util.load_file(json_file) + expected_userdata = { + 'key1': 'val1', + 'key2': { + 'key2.1': "Warning: redacted unserializable type <class" + " 'cloudinit.helpers.Paths'>"}} + instance_json = util.load_json(content) + self.assertEqual( + expected_userdata, instance_json['ds']['user-data']) + + @skipIf(not six.PY3, "json serialization on <= py2.7 handles bytes") + def test_get_data_base64encodes_unserializable_bytes(self): + """On py3, get_data base64encodes any unserializable content.""" + tmp = self.tmp_dir() + datasource = DataSourceTestSubclassNet( + self.sys_cfg, self.distro, Paths({'run_dir': tmp}), + custom_userdata={'key1': 'val1', 'key2': {'key2.1': b'\x123'}}) + self.assertTrue(datasource.get_data()) + json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) + content = util.load_file(json_file) + instance_json = util.load_json(content) + self.assertEqual( + ['ds/user-data/key2/key2.1'], + instance_json['base64-encoded-keys']) + self.assertEqual( + {'key1': 'val1', 'key2': {'key2.1': 'EjM='}}, + instance_json['ds']['user-data']) + + @skipIf(not six.PY2, "json serialization on <= py2.7 handles bytes") + def test_get_data_handles_bytes_values(self): + """On py2 get_data handles bytes values without having to b64encode.""" + tmp = self.tmp_dir() + datasource = DataSourceTestSubclassNet( + self.sys_cfg, self.distro, Paths({'run_dir': tmp}), + custom_userdata={'key1': 'val1', 'key2': {'key2.1': b'\x123'}}) + self.assertTrue(datasource.get_data()) + json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) + content = util.load_file(json_file) + instance_json = util.load_json(content) + self.assertEqual([], instance_json['base64-encoded-keys']) + self.assertEqual( + {'key1': 'val1', 'key2': {'key2.1': '\x123'}}, + instance_json['ds']['user-data']) + + @skipIf(not six.PY2, "Only python2 hits UnicodeDecodeErrors on non-utf8") + def test_non_utf8_encoding_logs_warning(self): + """When non-utf-8 values exist in py2 instance-data is not written.""" + tmp = self.tmp_dir() + datasource = DataSourceTestSubclassNet( + self.sys_cfg, self.distro, Paths({'run_dir': tmp}), + custom_userdata={'key1': 'val1', 'key2': {'key2.1': b'ab\xaadef'}}) + self.assertTrue(datasource.get_data()) + json_file = self.tmp_path(INSTANCE_JSON_FILE, tmp) + self.assertFalse(os.path.exists(json_file)) + self.assertIn( + "WARNING: Error persisting instance-data.json: 'utf8' codec can't" + " decode byte 0xaa in position 2: invalid start byte", + self.logs.getvalue()) diff --git a/cloudinit/tests/helpers.py b/cloudinit/tests/helpers.py index 6f88a5b7..feb884ab 100644 --- a/cloudinit/tests/helpers.py +++ b/cloudinit/tests/helpers.py @@ -3,7 +3,6 @@ from __future__ import print_function import functools -import json import logging import os import shutil @@ -337,12 +336,6 @@ def dir2dict(startdir, prefix=None): return flist -def json_dumps(data): - # print data in nicely formatted json. - return json.dumps(data, indent=1, sort_keys=True, - separators=(',', ': ')) - - def wrap_and_call(prefix, mocks, func, *args, **kwargs): """ call func(args, **kwargs) with mocks applied, then unapplies mocks diff --git a/cloudinit/util.py b/cloudinit/util.py index 320d64e0..11e96a77 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -533,15 +533,6 @@ def multi_log(text, console=True, stderr=True, log.log(log_level, text) -def load_json(text, root_types=(dict,)): - decoded = json.loads(decode_binary(text)) - if not isinstance(decoded, tuple(root_types)): - expected_types = ", ".join([str(t) for t in root_types]) - raise TypeError("(%s) root types expected, got %s instead" - % (expected_types, type(decoded))) - return decoded - - def is_ipv4(instr): """determine if input string is a ipv4 address. return boolean.""" toks = instr.split('.') @@ -1480,7 +1471,31 @@ def ensure_dirs(dirlist, mode=0o755): ensure_dir(d, mode) +def load_json(text, root_types=(dict,)): + decoded = json.loads(decode_binary(text)) + if not isinstance(decoded, tuple(root_types)): + expected_types = ", ".join([str(t) for t in root_types]) + raise TypeError("(%s) root types expected, got %s instead" + % (expected_types, type(decoded))) + return decoded + + +def json_serialize_default(_obj): + """Handler for types which aren't json serializable.""" + try: + return 'ci-b64:{0}'.format(b64e(_obj)) + except AttributeError: + return 'Warning: redacted unserializable type {0}'.format(type(_obj)) + + +def json_dumps(data): + """Return data in nicely formatted json.""" + return json.dumps(data, indent=1, sort_keys=True, + separators=(',', ': '), default=json_serialize_default) + + def yaml_dumps(obj, explicit_start=True, explicit_end=True): + """Return data in nicely formatted yaml.""" return yaml.safe_dump(obj, line_break="\n", indent=4, |