diff options
Diffstat (limited to 'cloudinit/sources')
21 files changed, 1092 insertions, 140 deletions
| diff --git a/cloudinit/sources/DataSourceAliYun.py b/cloudinit/sources/DataSourceAliYun.py index 858e0827..45cc9f00 100644 --- a/cloudinit/sources/DataSourceAliYun.py +++ b/cloudinit/sources/DataSourceAliYun.py @@ -1,7 +1,5 @@  # This file is part of cloud-init. See LICENSE file for license information. -import os -  from cloudinit import sources  from cloudinit.sources import DataSourceEc2 as EC2  from cloudinit import util @@ -18,25 +16,17 @@ class DataSourceAliYun(EC2.DataSourceEc2):      min_metadata_version = '2016-01-01'      extended_metadata_versions = [] -    def __init__(self, sys_cfg, distro, paths): -        super(DataSourceAliYun, self).__init__(sys_cfg, distro, paths) -        self.seed_dir = os.path.join(paths.seed_dir, "AliYun") -      def get_hostname(self, fqdn=False, resolve_ip=False, metadata_only=False):          return self.metadata.get('hostname', 'localhost.localdomain')      def get_public_ssh_keys(self):          return parse_public_keys(self.metadata.get('public-keys', {})) -    @property -    def cloud_platform(self): -        if self._cloud_platform is None: -            if _is_aliyun(): -                self._cloud_platform = EC2.Platforms.ALIYUN -            else: -                self._cloud_platform = EC2.Platforms.NO_EC2_METADATA - -        return self._cloud_platform +    def _get_cloud_name(self): +        if _is_aliyun(): +            return EC2.CloudNames.ALIYUN +        else: +            return EC2.CloudNames.NO_EC2_METADATA  def _is_aliyun(): diff --git a/cloudinit/sources/DataSourceAltCloud.py b/cloudinit/sources/DataSourceAltCloud.py index 8cd312d0..5270fda8 100644 --- a/cloudinit/sources/DataSourceAltCloud.py +++ b/cloudinit/sources/DataSourceAltCloud.py @@ -89,7 +89,9 @@ class DataSourceAltCloud(sources.DataSource):          '''          Description:              Get the type for the cloud back end this instance is running on -            by examining the string returned by reading the dmi data. +            by examining the string returned by reading either: +                CLOUD_INFO_FILE or +                the dmi data.          Input:              None @@ -99,7 +101,14 @@ class DataSourceAltCloud(sources.DataSource):              'RHEV', 'VSPHERE' or 'UNKNOWN'          ''' - +        if os.path.exists(CLOUD_INFO_FILE): +            try: +                cloud_type = util.load_file(CLOUD_INFO_FILE).strip().upper() +            except IOError: +                util.logexc(LOG, 'Unable to access cloud info file at %s.', +                            CLOUD_INFO_FILE) +                return 'UNKNOWN' +            return cloud_type          system_name = util.read_dmi_data("system-product-name")          if not system_name:              return 'UNKNOWN' @@ -134,15 +143,7 @@ class DataSourceAltCloud(sources.DataSource):          LOG.debug('Invoked get_data()') -        if os.path.exists(CLOUD_INFO_FILE): -            try: -                cloud_type = util.load_file(CLOUD_INFO_FILE).strip().upper() -            except IOError: -                util.logexc(LOG, 'Unable to access cloud info file at %s.', -                            CLOUD_INFO_FILE) -                return False -        else: -            cloud_type = self.get_cloud_type() +        cloud_type = self.get_cloud_type()          LOG.debug('cloud_type: %s', str(cloud_type)) @@ -161,6 +162,15 @@ class DataSourceAltCloud(sources.DataSource):          util.logexc(LOG, 'Failed accessing user data.')          return False +    def _get_subplatform(self): +        """Return the subplatform metadata details.""" +        cloud_type = self.get_cloud_type() +        if not hasattr(self, 'source'): +            self.source = sources.METADATA_UNKNOWN +        if cloud_type == 'RHEV': +            self.source = '/dev/fd0' +        return '%s (%s)' % (cloud_type.lower(), self.source) +      def user_data_rhevm(self):          '''          RHEVM specific userdata read @@ -232,6 +242,7 @@ class DataSourceAltCloud(sources.DataSource):              try:                  return_str = util.mount_cb(cdrom_dev, read_user_data_callback)                  if return_str: +                    self.source = cdrom_dev                      break              except OSError as err:                  if err.errno != errno.ENOENT: diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 629f006f..a06e6e1f 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -23,7 +23,8 @@ from cloudinit.event import EventType  from cloudinit.net.dhcp import EphemeralDHCPv4  from cloudinit import sources  from cloudinit.sources.helpers.azure import get_metadata_from_fabric -from cloudinit.url_helper import readurl, UrlError +from cloudinit.sources.helpers import netlink +from cloudinit.url_helper import UrlError, readurl, retry_on_url_exc  from cloudinit import util  LOG = logging.getLogger(__name__) @@ -58,7 +59,7 @@ IMDS_URL = "http://169.254.169.254/metadata/"  # List of static scripts and network config artifacts created by  # stock ubuntu suported images.  UBUNTU_EXTENDED_NETWORK_SCRIPTS = [ -    '/etc/netplan/90-azure-hotplug.yaml', +    '/etc/netplan/90-hotplug-azure.yaml',      '/usr/local/sbin/ephemeral_eth.sh',      '/etc/udev/rules.d/10-net-device-added.rules',      '/run/network/interfaces.ephemeral.d', @@ -208,7 +209,9 @@ BUILTIN_DS_CONFIG = {      },      'disk_aliases': {'ephemeral0': RESOURCE_DISK_PATH},      'dhclient_lease_file': LEASE_FILE, +    'apply_network_config': True,  # Use IMDS published network configuration  } +# RELEASE_BLOCKER: Xenial and earlier apply_network_config default is False  BUILTIN_CLOUD_CONFIG = {      'disk_setup': { @@ -284,6 +287,7 @@ class DataSourceAzure(sources.DataSource):          self._network_config = None          # Regenerate network config new_instance boot and every boot          self.update_events['network'].add(EventType.BOOT) +        self._ephemeral_dhcp_ctx = None      def __str__(self):          root = sources.DataSource.__str__(self) @@ -357,6 +361,14 @@ class DataSourceAzure(sources.DataSource):          metadata['public-keys'] = key_value or pubkeys_from_crt_files(fp_files)          return metadata +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        if self.seed.startswith('/dev'): +            subplatform_type = 'config-disk' +        else: +            subplatform_type = 'seed-dir' +        return '%s (%s)' % (subplatform_type, self.seed) +      def crawl_metadata(self):          """Walk all instance metadata sources returning a dict on success. @@ -402,7 +414,12 @@ class DataSourceAzure(sources.DataSource):                  LOG.warning("%s was not mountable", cdev)                  continue -            if reprovision or self._should_reprovision(ret): +            perform_reprovision = reprovision or self._should_reprovision(ret) +            if perform_reprovision: +                if util.is_FreeBSD(): +                    msg = "Free BSD is not supported for PPS VMs" +                    LOG.error(msg) +                    raise sources.InvalidMetaDataException(msg)                  ret = self._reprovision()              imds_md = get_metadata_from_imds(                  self.fallback_interface, retries=3) @@ -430,6 +447,18 @@ class DataSourceAzure(sources.DataSource):              crawled_data['metadata']['random_seed'] = seed          crawled_data['metadata']['instance-id'] = util.read_dmi_data(              'system-uuid') + +        if perform_reprovision: +            LOG.info("Reporting ready to Azure after getting ReprovisionData") +            use_cached_ephemeral = (net.is_up(self.fallback_interface) and +                                    getattr(self, '_ephemeral_dhcp_ctx', None)) +            if use_cached_ephemeral: +                self._report_ready(lease=self._ephemeral_dhcp_ctx.lease) +                self._ephemeral_dhcp_ctx.clean_network()  # Teardown ephemeral +            else: +                with EphemeralDHCPv4() as lease: +                    self._report_ready(lease=lease) +          return crawled_data      def _is_platform_viable(self): @@ -456,7 +485,8 @@ class DataSourceAzure(sources.DataSource):          except sources.InvalidMetaDataException as e:              LOG.warning('Could not crawl Azure metadata: %s', e)              return False -        if self.distro and self.distro.name == 'ubuntu': +        if (self.distro and self.distro.name == 'ubuntu' and +                self.ds_cfg.get('apply_network_config')):              maybe_remove_ubuntu_network_config_scripts()          # Process crawled data and augment with various config defaults @@ -504,8 +534,8 @@ class DataSourceAzure(sources.DataSource):          response. Then return the returned JSON object."""          url = IMDS_URL + "reprovisiondata?api-version=2017-04-02"          headers = {"Metadata": "true"} +        nl_sock = None          report_ready = bool(not os.path.isfile(REPORTED_READY_MARKER_FILE)) -        LOG.debug("Start polling IMDS")          def exc_cb(msg, exception):              if isinstance(exception, UrlError) and exception.code == 404: @@ -514,25 +544,47 @@ class DataSourceAzure(sources.DataSource):              # call DHCP and setup the ephemeral network to acquire the new IP.              return False +        LOG.debug("Wait for vnetswitch to happen")          while True:              try: -                with EphemeralDHCPv4() as lease: -                    if report_ready: -                        path = REPORTED_READY_MARKER_FILE -                        LOG.info( -                            "Creating a marker file to report ready: %s", path) -                        util.write_file(path, "{pid}: {time}\n".format( -                            pid=os.getpid(), time=time())) -                        self._report_ready(lease=lease) -                        report_ready = False +                # Save our EphemeralDHCPv4 context so we avoid repeated dhcp +                self._ephemeral_dhcp_ctx = EphemeralDHCPv4() +                lease = self._ephemeral_dhcp_ctx.obtain_lease() +                if report_ready: +                    try: +                        nl_sock = netlink.create_bound_netlink_socket() +                    except netlink.NetlinkCreateSocketError as e: +                        LOG.warning(e) +                        self._ephemeral_dhcp_ctx.clean_network() +                        return +                    path = REPORTED_READY_MARKER_FILE +                    LOG.info( +                        "Creating a marker file to report ready: %s", path) +                    util.write_file(path, "{pid}: {time}\n".format( +                        pid=os.getpid(), time=time())) +                    self._report_ready(lease=lease) +                    report_ready = False +                    try: +                        netlink.wait_for_media_disconnect_connect( +                            nl_sock, lease['interface']) +                    except AssertionError as error: +                        LOG.error(error) +                        return +                    self._ephemeral_dhcp_ctx.clean_network() +                else:                      return readurl(url, timeout=1, headers=headers, -                                   exception_cb=exc_cb, infinite=True).contents +                                   exception_cb=exc_cb, infinite=True, +                                   log_req_resp=False).contents              except UrlError: +                # Teardown our EphemeralDHCPv4 context on failure as we retry +                self._ephemeral_dhcp_ctx.clean_network()                  pass +            finally: +                if nl_sock: +                    nl_sock.close()      def _report_ready(self, lease): -        """Tells the fabric provisioning has completed -           before we go into our polling loop.""" +        """Tells the fabric provisioning has completed """          try:              get_metadata_from_fabric(None, lease['unknown-245'])          except Exception: @@ -617,7 +669,11 @@ class DataSourceAzure(sources.DataSource):                the blacklisted devices.          """          if not self._network_config: -            self._network_config = parse_network_config(self._metadata_imds) +            if self.ds_cfg.get('apply_network_config'): +                nc_src = self._metadata_imds +            else: +                nc_src = None +            self._network_config = parse_network_config(nc_src)          return self._network_config @@ -698,7 +754,7 @@ def can_dev_be_reformatted(devpath, preserve_ntfs):          file_count = util.mount_cb(cand_path, count_files, mtype="ntfs",                                     update_env_for_mount={'LANG': 'C'})      except util.MountFailedError as e: -        if "mount: unknown filesystem type 'ntfs'" in str(e): +        if "unknown filesystem type 'ntfs'" in str(e):              return True, (bmsg + ' but this system cannot mount NTFS,'                            ' assuming there are no important files.'                            ' Formatting allowed.') @@ -926,12 +982,12 @@ def read_azure_ovf(contents):                              lambda n:                              n.localName == "LinuxProvisioningConfigurationSet") -    if len(results) == 0: +    if len(lpcs_nodes) == 0:          raise NonAzureDataSource("No LinuxProvisioningConfigurationSet") -    if len(results) > 1: +    if len(lpcs_nodes) > 1:          raise BrokenAzureDataSource("found '%d' %ss" %                                      ("LinuxProvisioningConfigurationSet", -                                     len(results))) +                                     len(lpcs_nodes)))      lpcs = lpcs_nodes[0]      if not lpcs.hasChildNodes(): @@ -1160,17 +1216,12 @@ def get_metadata_from_imds(fallback_nic, retries):  def _get_metadata_from_imds(retries): -    def retry_on_url_error(msg, exception): -        if isinstance(exception, UrlError) and exception.code == 404: -            return True  # Continue retries -        return False  # Stop retries on all other exceptions -      url = IMDS_URL + "instance?api-version=2017-12-01"      headers = {"Metadata": "true"}      try:          response = readurl(              url, timeout=1, headers=headers, retries=retries, -            exception_cb=retry_on_url_error) +            exception_cb=retry_on_url_exc)      except Exception as e:          LOG.debug('Ignoring IMDS instance metadata: %s', e)          return {} @@ -1193,7 +1244,7 @@ def maybe_remove_ubuntu_network_config_scripts(paths=None):      additional interfaces which get attached by a customer at some point      after initial boot. Since the Azure datasource can now regenerate      network configuration as metadata reports these new devices, we no longer -    want the udev rules or netplan's 90-azure-hotplug.yaml to configure +    want the udev rules or netplan's 90-hotplug-azure.yaml to configure      networking on eth1 or greater as it might collide with cloud-init's      configuration. diff --git a/cloudinit/sources/DataSourceBigstep.py b/cloudinit/sources/DataSourceBigstep.py index 699a85b5..52fff20a 100644 --- a/cloudinit/sources/DataSourceBigstep.py +++ b/cloudinit/sources/DataSourceBigstep.py @@ -36,6 +36,10 @@ class DataSourceBigstep(sources.DataSource):          self.userdata_raw = decoded["userdata_raw"]          return True +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        return 'metadata (%s)' % get_url_from_file() +  def get_url_from_file():      try: diff --git a/cloudinit/sources/DataSourceCloudSigma.py b/cloudinit/sources/DataSourceCloudSigma.py index c816f349..2955d3f0 100644 --- a/cloudinit/sources/DataSourceCloudSigma.py +++ b/cloudinit/sources/DataSourceCloudSigma.py @@ -7,7 +7,7 @@  from base64 import b64decode  import re -from cloudinit.cs_utils import Cepko +from cloudinit.cs_utils import Cepko, SERIAL_PORT  from cloudinit import log as logging  from cloudinit import sources @@ -84,6 +84,10 @@ class DataSourceCloudSigma(sources.DataSource):          return True +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        return 'cepko (%s)' % SERIAL_PORT +      def get_hostname(self, fqdn=False, resolve_ip=False, metadata_only=False):          """          Cleans up and uses the server's name if the latter is set. Otherwise diff --git a/cloudinit/sources/DataSourceConfigDrive.py b/cloudinit/sources/DataSourceConfigDrive.py index 664dc4b7..564e3eb3 100644 --- a/cloudinit/sources/DataSourceConfigDrive.py +++ b/cloudinit/sources/DataSourceConfigDrive.py @@ -160,6 +160,18 @@ class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource):                  LOG.debug("no network configuration available")          return self._network_config +    @property +    def platform(self): +        return 'openstack' + +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        if self.seed_dir in self.source: +            subplatform_type = 'seed-dir' +        elif self.source.startswith('/dev'): +            subplatform_type = 'config-disk' +        return '%s (%s)' % (subplatform_type, self.source) +  def read_config_drive(source_dir):      reader = openstack.ConfigDriveReader(source_dir) diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index 98ea7bbc..b49a08db 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -30,18 +30,16 @@ STRICT_ID_DEFAULT = "warn"  DEFAULT_PRIMARY_NIC = 'eth0' -class Platforms(object): -    # TODO Rename and move to cloudinit.cloud.CloudNames -    ALIYUN = "AliYun" -    AWS = "AWS" -    BRIGHTBOX = "Brightbox" -    SEEDED = "Seeded" +class CloudNames(object): +    ALIYUN = "aliyun" +    AWS = "aws" +    BRIGHTBOX = "brightbox"      # UNKNOWN indicates no positive id.  If strict_id is 'warn' or 'false',      # then an attempt at the Ec2 Metadata service will be made. -    UNKNOWN = "Unknown" +    UNKNOWN = "unknown"      # NO_EC2_METADATA indicates this platform does not have a Ec2 metadata      # service available. No attempt at the Ec2 Metadata service will be made. -    NO_EC2_METADATA = "No-EC2-Metadata" +    NO_EC2_METADATA = "no-ec2-metadata"  class DataSourceEc2(sources.DataSource): @@ -69,8 +67,6 @@ class DataSourceEc2(sources.DataSource):      url_max_wait = 120      url_timeout = 50 -    _cloud_platform = None -      _network_config = sources.UNSET  # Used to cache calculated network cfg v1      # Whether we want to get network configuration from the metadata service. @@ -79,30 +75,21 @@ class DataSourceEc2(sources.DataSource):      def __init__(self, sys_cfg, distro, paths):          super(DataSourceEc2, self).__init__(sys_cfg, distro, paths)          self.metadata_address = None -        self.seed_dir = os.path.join(paths.seed_dir, "ec2")      def _get_cloud_name(self):          """Return the cloud name as identified during _get_data.""" -        return self.cloud_platform +        return identify_platform()      def _get_data(self): -        seed_ret = {} -        if util.read_optional_seed(seed_ret, base=(self.seed_dir + "/")): -            self.userdata_raw = seed_ret['user-data'] -            self.metadata = seed_ret['meta-data'] -            LOG.debug("Using seeded ec2 data from %s", self.seed_dir) -            self._cloud_platform = Platforms.SEEDED -            return True -          strict_mode, _sleep = read_strict_mode(              util.get_cfg_by_path(self.sys_cfg, STRICT_ID_PATH,                                   STRICT_ID_DEFAULT), ("warn", None)) -        LOG.debug("strict_mode: %s, cloud_platform=%s", -                  strict_mode, self.cloud_platform) -        if strict_mode == "true" and self.cloud_platform == Platforms.UNKNOWN: +        LOG.debug("strict_mode: %s, cloud_name=%s cloud_platform=%s", +                  strict_mode, self.cloud_name, self.platform) +        if strict_mode == "true" and self.cloud_name == CloudNames.UNKNOWN:              return False -        elif self.cloud_platform == Platforms.NO_EC2_METADATA: +        elif self.cloud_name == CloudNames.NO_EC2_METADATA:              return False          if self.perform_dhcp_setup:  # Setup networking in init-local stage. @@ -111,13 +98,22 @@ class DataSourceEc2(sources.DataSource):                  return False              try:                  with EphemeralDHCPv4(self.fallback_interface): -                    return util.log_time( +                    self._crawled_metadata = util.log_time(                          logfunc=LOG.debug, msg='Crawl of metadata service', -                        func=self._crawl_metadata) +                        func=self.crawl_metadata)              except NoDHCPLeaseError:                  return False          else: -            return self._crawl_metadata() +            self._crawled_metadata = util.log_time( +                logfunc=LOG.debug, msg='Crawl of metadata service', +                func=self.crawl_metadata) +        if not self._crawled_metadata: +            return False +        self.metadata = self._crawled_metadata.get('meta-data', None) +        self.userdata_raw = self._crawled_metadata.get('user-data', None) +        self.identity = self._crawled_metadata.get( +            'dynamic', {}).get('instance-identity', {}).get('document', {}) +        return True      @property      def launch_index(self): @@ -125,6 +121,15 @@ class DataSourceEc2(sources.DataSource):              return None          return self.metadata.get('ami-launch-index') +    @property +    def platform(self): +        # Handle upgrade path of pickled ds +        if not hasattr(self, '_platform_type'): +            self._platform_type = DataSourceEc2.dsname.lower() +        if not self._platform_type: +            self._platform_type = DataSourceEc2.dsname.lower() +        return self._platform_type +      def get_metadata_api_version(self):          """Get the best supported api version from the metadata service. @@ -152,7 +157,7 @@ class DataSourceEc2(sources.DataSource):          return self.min_metadata_version      def get_instance_id(self): -        if self.cloud_platform == Platforms.AWS: +        if self.cloud_name == CloudNames.AWS:              # Prefer the ID from the instance identity document, but fall back              if not getattr(self, 'identity', None):                  # If re-using cached datasource, it's get_data run didn't @@ -262,7 +267,7 @@ class DataSourceEc2(sources.DataSource):      @property      def availability_zone(self):          try: -            if self.cloud_platform == Platforms.AWS: +            if self.cloud_name == CloudNames.AWS:                  return self.identity.get(                      'availabilityZone',                      self.metadata['placement']['availability-zone']) @@ -273,7 +278,7 @@ class DataSourceEc2(sources.DataSource):      @property      def region(self): -        if self.cloud_platform == Platforms.AWS: +        if self.cloud_name == CloudNames.AWS:              region = self.identity.get('region')              # Fallback to trimming the availability zone if region is missing              if self.availability_zone and not region: @@ -285,16 +290,10 @@ class DataSourceEc2(sources.DataSource):                  return az[:-1]          return None -    @property -    def cloud_platform(self):  # TODO rename cloud_name -        if self._cloud_platform is None: -            self._cloud_platform = identify_platform() -        return self._cloud_platform -      def activate(self, cfg, is_new_instance):          if not is_new_instance:              return -        if self.cloud_platform == Platforms.UNKNOWN: +        if self.cloud_name == CloudNames.UNKNOWN:              warn_if_necessary(                  util.get_cfg_by_path(cfg, STRICT_ID_PATH, STRICT_ID_DEFAULT),                  cfg) @@ -314,13 +313,13 @@ class DataSourceEc2(sources.DataSource):          result = None          no_network_metadata_on_aws = bool(              'network' not in self.metadata and -            self.cloud_platform == Platforms.AWS) +            self.cloud_name == CloudNames.AWS)          if no_network_metadata_on_aws:              LOG.debug("Metadata 'network' not present:"                        " Refreshing stale metadata from prior to upgrade.")              util.log_time(                  logfunc=LOG.debug, msg='Re-crawl of metadata service', -                func=self._crawl_metadata) +                func=self.get_data)          # Limit network configuration to only the primary/fallback nic          iface = self.fallback_interface @@ -348,28 +347,32 @@ class DataSourceEc2(sources.DataSource):                  return super(DataSourceEc2, self).fallback_interface          return self._fallback_interface -    def _crawl_metadata(self): +    def crawl_metadata(self):          """Crawl metadata service when available. -        @returns: True on success, False otherwise. +        @returns: Dictionary of crawled metadata content containing the keys: +          meta-data, user-data and dynamic.          """          if not self.wait_for_metadata_service(): -            return False +            return {}          api_version = self.get_metadata_api_version() +        crawled_metadata = {}          try: -            self.userdata_raw = ec2.get_instance_userdata( +            crawled_metadata['user-data'] = ec2.get_instance_userdata(                  api_version, self.metadata_address) -            self.metadata = ec2.get_instance_metadata( +            crawled_metadata['meta-data'] = ec2.get_instance_metadata(                  api_version, self.metadata_address) -            if self.cloud_platform == Platforms.AWS: -                self.identity = ec2.get_instance_identity( -                    api_version, self.metadata_address).get('document', {}) +            if self.cloud_name == CloudNames.AWS: +                identity = ec2.get_instance_identity( +                    api_version, self.metadata_address) +                crawled_metadata['dynamic'] = {'instance-identity': identity}          except Exception:              util.logexc(                  LOG, "Failed reading from metadata address %s",                  self.metadata_address) -            return False -        return True +            return {} +        crawled_metadata['_metadata_api_version'] = api_version +        return crawled_metadata  class DataSourceEc2Local(DataSourceEc2): @@ -383,10 +386,10 @@ class DataSourceEc2Local(DataSourceEc2):      perform_dhcp_setup = True  # Use dhcp before querying metadata      def get_data(self): -        supported_platforms = (Platforms.AWS,) -        if self.cloud_platform not in supported_platforms: +        supported_platforms = (CloudNames.AWS,) +        if self.cloud_name not in supported_platforms:              LOG.debug("Local Ec2 mode only supported on %s, not %s", -                      supported_platforms, self.cloud_platform) +                      supported_platforms, self.cloud_name)              return False          return super(DataSourceEc2Local, self).get_data() @@ -447,20 +450,20 @@ def identify_aws(data):      if (data['uuid'].startswith('ec2') and              (data['uuid_source'] == 'hypervisor' or               data['uuid'] == data['serial'])): -            return Platforms.AWS +            return CloudNames.AWS      return None  def identify_brightbox(data):      if data['serial'].endswith('brightbox.com'): -        return Platforms.BRIGHTBOX +        return CloudNames.BRIGHTBOX  def identify_platform(): -    # identify the platform and return an entry in Platforms. +    # identify the platform and return an entry in CloudNames.      data = _collect_platform_data() -    checks = (identify_aws, identify_brightbox, lambda x: Platforms.UNKNOWN) +    checks = (identify_aws, identify_brightbox, lambda x: CloudNames.UNKNOWN)      for checker in checks:          try:              result = checker(data) diff --git a/cloudinit/sources/DataSourceIBMCloud.py b/cloudinit/sources/DataSourceIBMCloud.py index a5358148..21e6ae6b 100644 --- a/cloudinit/sources/DataSourceIBMCloud.py +++ b/cloudinit/sources/DataSourceIBMCloud.py @@ -157,6 +157,10 @@ class DataSourceIBMCloud(sources.DataSource):          return True +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        return '%s (%s)' % (self.platform, self.source) +      def check_instance_id(self, sys_cfg):          """quickly (local check only) if self.instance_id is still valid diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index bcb38544..61aa6d7e 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -109,6 +109,10 @@ class DataSourceMAAS(sources.DataSource):                  LOG.warning("Invalid content in vendor-data: %s", e)                  self.vendordata_raw = None +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        return 'seed-dir (%s)' % self.base_url +      def wait_for_metadata_service(self, url):          mcfg = self.ds_cfg          max_wait = 120 diff --git a/cloudinit/sources/DataSourceNoCloud.py b/cloudinit/sources/DataSourceNoCloud.py index 2daea59d..6860f0cc 100644 --- a/cloudinit/sources/DataSourceNoCloud.py +++ b/cloudinit/sources/DataSourceNoCloud.py @@ -186,6 +186,27 @@ class DataSourceNoCloud(sources.DataSource):          self._network_eni = mydata['meta-data'].get('network-interfaces')          return True +    @property +    def platform_type(self): +        # Handle upgrade path of pickled ds +        if not hasattr(self, '_platform_type'): +            self._platform_type = None +        if not self._platform_type: +            self._platform_type = 'lxd' if util.is_lxd() else 'nocloud' +        return self._platform_type + +    def _get_cloud_name(self): +        """Return unknown when 'cloud-name' key is absent from metadata.""" +        return sources.METADATA_UNKNOWN + +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        if self.seed.startswith('/dev'): +            subplatform_type = 'config-disk' +        else: +            subplatform_type = 'seed-dir' +        return '%s (%s)' % (subplatform_type, self.seed) +      def check_instance_id(self, sys_cfg):          # quickly (local check only) if self.instance_id is still valid          # we check kernel command line or files. @@ -290,6 +311,35 @@ def parse_cmdline_data(ds_id, fill, cmdline=None):      return True +def _maybe_remove_top_network(cfg): +    """If network-config contains top level 'network' key, then remove it. + +    Some providers of network configuration may provide a top level +    'network' key (LP: #1798117) even though it is not necessary. + +    Be friendly and remove it if it really seems so. + +    Return the original value if no change or the updated value if changed.""" +    nullval = object() +    network_val = cfg.get('network', nullval) +    if network_val is nullval: +        return cfg +    bmsg = 'Top level network key in network-config %s: %s' +    if not isinstance(network_val, dict): +        LOG.debug(bmsg, "was not a dict", cfg) +        return cfg +    if len(list(cfg.keys())) != 1: +        LOG.debug(bmsg, "had multiple top level keys", cfg) +        return cfg +    if network_val.get('config') == "disabled": +        LOG.debug(bmsg, "was config/disabled", cfg) +    elif not all(('config' in network_val, 'version' in network_val)): +        LOG.debug(bmsg, "but missing 'config' or 'version'", cfg) +        return cfg +    LOG.debug(bmsg, "fixed by removing shifting network.", cfg) +    return network_val + +  def _merge_new_seed(cur, seeded):      ret = cur.copy() @@ -299,7 +349,8 @@ def _merge_new_seed(cur, seeded):      ret['meta-data'] = util.mergemanydict([cur['meta-data'], newmd])      if seeded.get('network-config'): -        ret['network-config'] = util.load_yaml(seeded['network-config']) +        ret['network-config'] = _maybe_remove_top_network( +            util.load_yaml(seeded.get('network-config')))      if 'user-data' in seeded:          ret['user-data'] = seeded['user-data'] diff --git a/cloudinit/sources/DataSourceNone.py b/cloudinit/sources/DataSourceNone.py index e63a7e39..e6250801 100644 --- a/cloudinit/sources/DataSourceNone.py +++ b/cloudinit/sources/DataSourceNone.py @@ -28,6 +28,10 @@ class DataSourceNone(sources.DataSource):              self.metadata = self.ds_cfg['metadata']          return True +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        return 'config' +      def get_instance_id(self):          return 'iid-datasource-none' diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index 178ccb0f..045291e7 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -275,6 +275,12 @@ class DataSourceOVF(sources.DataSource):          self.cfg = cfg          return True +    def _get_subplatform(self): +        system_type = util.read_dmi_data("system-product-name").lower() +        if system_type == 'vmware': +            return 'vmware (%s)' % self.seed +        return 'ovf (%s)' % self.seed +      def get_public_ssh_keys(self):          if 'public-keys' not in self.metadata:              return [] diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index 77ccd128..e62e9729 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -95,6 +95,14 @@ class DataSourceOpenNebula(sources.DataSource):          self.userdata_raw = results.get('userdata')          return True +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        if self.seed_dir in self.seed: +            subplatform_type = 'seed-dir' +        else: +            subplatform_type = 'config-disk' +        return '%s (%s)' % (subplatform_type, self.seed) +      @property      def network_config(self):          if self.network is not None: diff --git a/cloudinit/sources/DataSourceOracle.py b/cloudinit/sources/DataSourceOracle.py index fab39af3..70b9c58a 100644 --- a/cloudinit/sources/DataSourceOracle.py +++ b/cloudinit/sources/DataSourceOracle.py @@ -91,6 +91,10 @@ class DataSourceOracle(sources.DataSource):      def crawl_metadata(self):          return read_metadata() +    def _get_subplatform(self): +        """Return the subplatform metadata source details.""" +        return 'metadata (%s)' % METADATA_ENDPOINT +      def check_instance_id(self, sys_cfg):          """quickly check (local only) if self.instance_id is still valid diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 593ac91a..32b57cdd 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -303,6 +303,9 @@ class DataSourceSmartOS(sources.DataSource):          self._set_provisioned()          return True +    def _get_subplatform(self): +        return 'serial (%s)' % SERIAL_DEVICE +      def device_name_to_device(self, name):          return self.ds_cfg['disk_aliases'].get(name) diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 5ac98826..e6966b31 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -54,9 +54,18 @@ REDACT_SENSITIVE_VALUE = 'redacted for non-root user'  METADATA_CLOUD_NAME_KEY = 'cloud-name'  UNSET = "_unset" +METADATA_UNKNOWN = 'unknown'  LOG = logging.getLogger(__name__) +# CLOUD_ID_REGION_PREFIX_MAP format is: +#  <region-match-prefix>: (<new-cloud-id>: <test_allowed_cloud_callable>) +CLOUD_ID_REGION_PREFIX_MAP = { +    'cn-': ('aws-china', lambda c: c == 'aws'),    # only change aws regions +    'us-gov-': ('aws-gov', lambda c: c == 'aws'),  # only change aws regions +    'china': ('azure-china', lambda c: c == 'azure'),  # only change azure +} +  class DataSourceNotFoundException(Exception):      pass @@ -133,6 +142,14 @@ class DataSource(object):      # Cached cloud_name as determined by _get_cloud_name      _cloud_name = None +    # Cached cloud platform api type: e.g. ec2, openstack, kvm, lxd, azure etc. +    _platform_type = None + +    # More details about the cloud platform: +    #  - metadata (http://169.254.169.254/) +    #  - seed-dir (<dirname>) +    _subplatform = None +      # Track the discovered fallback nic for use in configuration generation.      _fallback_interface = None @@ -192,21 +209,24 @@ class DataSource(object):          local_hostname = self.get_hostname()          instance_id = self.get_instance_id()          availability_zone = self.availability_zone -        cloud_name = self.cloud_name -        # When adding new standard keys prefer underscore-delimited instead -        # of hyphen-delimted to support simple variable references in jinja -        # templates. +        # In the event of upgrade from existing cloudinit, pickled datasource +        # will not contain these new class attributes. So we need to recrawl +        # metadata to discover that content.          return {              'v1': { +                '_beta_keys': ['subplatform'],                  'availability-zone': availability_zone,                  'availability_zone': availability_zone, -                'cloud-name': cloud_name, -                'cloud_name': cloud_name, +                'cloud-name': self.cloud_name, +                'cloud_name': self.cloud_name, +                'platform': self.platform_type, +                'public_ssh_keys': self.get_public_ssh_keys(),                  'instance-id': instance_id,                  'instance_id': instance_id,                  'local-hostname': local_hostname,                  'local_hostname': local_hostname, -                'region': self.region}} +                'region': self.region, +                'subplatform': self.subplatform}}      def clear_cached_attrs(self, attr_defaults=()):          """Reset any cached metadata attributes to datasource defaults. @@ -247,19 +267,27 @@ class DataSource(object):          @return True on successful write, False otherwise.          """ -        instance_data = { -            'ds': {'_doc': EXPERIMENTAL_TEXT, -                   'meta_data': self.metadata}} -        if hasattr(self, 'network_json'): -            network_json = getattr(self, 'network_json') -            if network_json != UNSET: -                instance_data['ds']['network_json'] = network_json -        if hasattr(self, 'ec2_metadata'): -            ec2_metadata = getattr(self, 'ec2_metadata') -            if ec2_metadata != UNSET: -                instance_data['ds']['ec2_metadata'] = ec2_metadata +        if hasattr(self, '_crawled_metadata'): +            # Any datasource with _crawled_metadata will best represent +            # most recent, 'raw' metadata +            crawled_metadata = copy.deepcopy( +                getattr(self, '_crawled_metadata')) +            crawled_metadata.pop('user-data', None) +            crawled_metadata.pop('vendor-data', None) +            instance_data = {'ds': crawled_metadata} +        else: +            instance_data = {'ds': {'meta_data': self.metadata}} +            if hasattr(self, 'network_json'): +                network_json = getattr(self, 'network_json') +                if network_json != UNSET: +                    instance_data['ds']['network_json'] = network_json +            if hasattr(self, 'ec2_metadata'): +                ec2_metadata = getattr(self, 'ec2_metadata') +                if ec2_metadata != UNSET: +                    instance_data['ds']['ec2_metadata'] = ec2_metadata          instance_data.update(              self._get_standardized_metadata()) +        instance_data['ds']['_doc'] = EXPERIMENTAL_TEXT          try:              # Process content base64encoding unserializable values              content = util.json_dumps(instance_data) @@ -347,6 +375,40 @@ class DataSource(object):          return self._fallback_interface      @property +    def platform_type(self): +        if not hasattr(self, '_platform_type'): +            # Handle upgrade path where pickled datasource has no _platform. +            self._platform_type = self.dsname.lower() +        if not self._platform_type: +            self._platform_type = self.dsname.lower() +        return self._platform_type + +    @property +    def subplatform(self): +        """Return a string representing subplatform details for the datasource. + +        This should be guidance for where the metadata is sourced. +        Examples of this on different clouds: +            ec2:       metadata (http://169.254.169.254) +            openstack: configdrive (/dev/path) +            openstack: metadata (http://169.254.169.254) +            nocloud:   seed-dir (/seed/dir/path) +            lxd:   nocloud (/seed/dir/path) +        """ +        if not hasattr(self, '_subplatform'): +            # Handle upgrade path where pickled datasource has no _platform. +            self._subplatform = self._get_subplatform() +        if not self._subplatform: +            self._subplatform = self._get_subplatform() +        return self._subplatform + +    def _get_subplatform(self): +        """Subclasses should implement to return a "slug (detail)" string.""" +        if hasattr(self, 'metadata_address'): +            return 'metadata (%s)' % getattr(self, 'metadata_address') +        return METADATA_UNKNOWN + +    @property      def cloud_name(self):          """Return lowercase cloud name as determined by the datasource. @@ -359,9 +421,11 @@ class DataSource(object):              cloud_name = self.metadata.get(METADATA_CLOUD_NAME_KEY)              if isinstance(cloud_name, six.string_types):                  self._cloud_name = cloud_name.lower() -            LOG.debug( -                'Ignoring metadata provided key %s: non-string type %s', -                METADATA_CLOUD_NAME_KEY, type(cloud_name)) +            else: +                self._cloud_name = self._get_cloud_name().lower() +                LOG.debug( +                    'Ignoring metadata provided key %s: non-string type %s', +                    METADATA_CLOUD_NAME_KEY, type(cloud_name))          else:              self._cloud_name = self._get_cloud_name().lower()          return self._cloud_name @@ -714,6 +778,25 @@ def instance_id_matches_system_uuid(instance_id, field='system-uuid'):      return instance_id.lower() == dmi_value.lower() +def canonical_cloud_id(cloud_name, region, platform): +    """Lookup the canonical cloud-id for a given cloud_name and region.""" +    if not cloud_name: +        cloud_name = METADATA_UNKNOWN +    if not region: +        region = METADATA_UNKNOWN +    if region == METADATA_UNKNOWN: +        if cloud_name != METADATA_UNKNOWN: +            return cloud_name +        return platform +    for prefix, cloud_id_test in CLOUD_ID_REGION_PREFIX_MAP.items(): +        (cloud_id, valid_cloud) = cloud_id_test +        if region.startswith(prefix) and valid_cloud(cloud_name): +            return cloud_id +    if cloud_name != METADATA_UNKNOWN: +        return cloud_name +    return platform + +  def convert_vendordata(data, recurse=True):      """data: a loaded object (strings, arrays, dicts).      return something suitable for cloudinit vendordata_raw. diff --git a/cloudinit/sources/helpers/netlink.py b/cloudinit/sources/helpers/netlink.py new file mode 100644 index 00000000..d377ae3d --- /dev/null +++ b/cloudinit/sources/helpers/netlink.py @@ -0,0 +1,250 @@ +# Author: Tamilmani Manoharan <tamanoha@microsoft.com> +# +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit import log as logging +from cloudinit import util +from collections import namedtuple + +import os +import select +import socket +import struct + +LOG = logging.getLogger(__name__) + +# http://man7.org/linux/man-pages/man7/netlink.7.html +RTMGRP_LINK = 1 +NLMSG_NOOP = 1 +NLMSG_ERROR = 2 +NLMSG_DONE = 3 +RTM_NEWLINK = 16 +RTM_DELLINK = 17 +RTM_GETLINK = 18 +RTM_SETLINK = 19 +MAX_SIZE = 65535 +RTA_DATA_OFFSET = 32 +MSG_TYPE_OFFSET = 16 +SELECT_TIMEOUT = 60 + +NLMSGHDR_FMT = "IHHII" +IFINFOMSG_FMT = "BHiII" +NLMSGHDR_SIZE = struct.calcsize(NLMSGHDR_FMT) +IFINFOMSG_SIZE = struct.calcsize(IFINFOMSG_FMT) +RTATTR_START_OFFSET = NLMSGHDR_SIZE + IFINFOMSG_SIZE +RTA_DATA_START_OFFSET = 4 +PAD_ALIGNMENT = 4 + +IFLA_IFNAME = 3 +IFLA_OPERSTATE = 16 + +# https://www.kernel.org/doc/Documentation/networking/operstates.txt +OPER_UNKNOWN = 0 +OPER_NOTPRESENT = 1 +OPER_DOWN = 2 +OPER_LOWERLAYERDOWN = 3 +OPER_TESTING = 4 +OPER_DORMANT = 5 +OPER_UP = 6 + +RTAAttr = namedtuple('RTAAttr', ['length', 'rta_type', 'data']) +InterfaceOperstate = namedtuple('InterfaceOperstate', ['ifname', 'operstate']) +NetlinkHeader = namedtuple('NetlinkHeader', ['length', 'type', 'flags', 'seq', +                                             'pid']) + + +class NetlinkCreateSocketError(RuntimeError): +    '''Raised if netlink socket fails during create or bind.''' +    pass + + +def create_bound_netlink_socket(): +    '''Creates netlink socket and bind on netlink group to catch interface +    down/up events. The socket will bound only on RTMGRP_LINK (which only +    includes RTM_NEWLINK/RTM_DELLINK/RTM_GETLINK events). The socket is set to +    non-blocking mode since we're only receiving messages. + +    :returns: netlink socket in non-blocking mode +    :raises: NetlinkCreateSocketError +    ''' +    try: +        netlink_socket = socket.socket(socket.AF_NETLINK, +                                       socket.SOCK_RAW, +                                       socket.NETLINK_ROUTE) +        netlink_socket.bind((os.getpid(), RTMGRP_LINK)) +        netlink_socket.setblocking(0) +    except socket.error as e: +        msg = "Exception during netlink socket create: %s" % e +        raise NetlinkCreateSocketError(msg) +    LOG.debug("Created netlink socket") +    return netlink_socket + + +def get_netlink_msg_header(data): +    '''Gets netlink message type and length + +    :param: data read from netlink socket +    :returns: netlink message type +    :raises: AssertionError if data is None or data is not >= NLMSGHDR_SIZE +    struct nlmsghdr { +               __u32 nlmsg_len;    /* Length of message including header */ +               __u16 nlmsg_type;   /* Type of message content */ +               __u16 nlmsg_flags;  /* Additional flags */ +               __u32 nlmsg_seq;    /* Sequence number */ +               __u32 nlmsg_pid;    /* Sender port ID */ +    }; +    ''' +    assert (data is not None), ("data is none") +    assert (len(data) >= NLMSGHDR_SIZE), ( +        "data is smaller than netlink message header") +    msg_len, msg_type, flags, seq, pid = struct.unpack(NLMSGHDR_FMT, +                                                       data[:MSG_TYPE_OFFSET]) +    LOG.debug("Got netlink msg of type %d", msg_type) +    return NetlinkHeader(msg_len, msg_type, flags, seq, pid) + + +def read_netlink_socket(netlink_socket, timeout=None): +    '''Select and read from the netlink socket if ready. + +    :param: netlink_socket: specify which socket object to read from +    :param: timeout: specify a timeout value (integer) to wait while reading, +            if none, it will block indefinitely until socket ready for read +    :returns: string of data read (max length = <MAX_SIZE>) from socket, +              if no data read, returns None +    :raises: AssertionError if netlink_socket is None +    ''' +    assert (netlink_socket is not None), ("netlink socket is none") +    read_set, _, _ = select.select([netlink_socket], [], [], timeout) +    # Incase of timeout,read_set doesn't contain netlink socket. +    # just return from this function +    if netlink_socket not in read_set: +        return None +    LOG.debug("netlink socket ready for read") +    data = netlink_socket.recv(MAX_SIZE) +    if data is None: +        LOG.error("Reading from Netlink socket returned no data") +    return data + + +def unpack_rta_attr(data, offset): +    '''Unpack a single rta attribute. + +    :param: data: string of data read from netlink socket +    :param: offset: starting offset of RTA Attribute +    :return: RTAAttr object with length, type and data. On error, return None. +    :raises: AssertionError if data is None or offset is not integer. +    ''' +    assert (data is not None), ("data is none") +    assert (type(offset) == int), ("offset is not integer") +    assert (offset >= RTATTR_START_OFFSET), ( +        "rta offset is less than expected length") +    length = rta_type = 0 +    attr_data = None +    try: +        length = struct.unpack_from("H", data, offset=offset)[0] +        rta_type = struct.unpack_from("H", data, offset=offset+2)[0] +    except struct.error: +        return None  # Should mean our offset is >= remaining data + +    # Unpack just the attribute's data. Offset by 4 to skip length/type header +    attr_data = data[offset+RTA_DATA_START_OFFSET:offset+length] +    return RTAAttr(length, rta_type, attr_data) + + +def read_rta_oper_state(data): +    '''Reads Interface name and operational state from RTA Data. + +    :param: data: string of data read from netlink socket +    :returns: InterfaceOperstate object containing if_name and oper_state. +              None if data does not contain valid IFLA_OPERSTATE and +              IFLA_IFNAME messages. +    :raises: AssertionError if data is None or length of data is +             smaller than RTATTR_START_OFFSET. +    ''' +    assert (data is not None), ("data is none") +    assert (len(data) > RTATTR_START_OFFSET), ( +        "length of data is smaller than RTATTR_START_OFFSET") +    ifname = operstate = None +    offset = RTATTR_START_OFFSET +    while offset <= len(data): +        attr = unpack_rta_attr(data, offset) +        if not attr or attr.length == 0: +            break +        # Each attribute is 4-byte aligned. Determine pad length. +        padlen = (PAD_ALIGNMENT - +                  (attr.length % PAD_ALIGNMENT)) % PAD_ALIGNMENT +        offset += attr.length + padlen + +        if attr.rta_type == IFLA_OPERSTATE: +            operstate = ord(attr.data) +        elif attr.rta_type == IFLA_IFNAME: +            interface_name = util.decode_binary(attr.data, 'utf-8') +            ifname = interface_name.strip('\0') +    if not ifname or operstate is None: +        return None +    LOG.debug("rta attrs: ifname %s operstate %d", ifname, operstate) +    return InterfaceOperstate(ifname, operstate) + + +def wait_for_media_disconnect_connect(netlink_socket, ifname): +    '''Block until media disconnect and connect has happened on an interface. +    Listens on netlink socket to receive netlink events and when the carrier +    changes from 0 to 1, it considers event has happened and +    return from this function + +    :param: netlink_socket: netlink_socket to receive events +    :param: ifname: Interface name to lookout for netlink events +    :raises: AssertionError if netlink_socket is None or ifname is None. +    ''' +    assert (netlink_socket is not None), ("netlink socket is none") +    assert (ifname is not None), ("interface name is none") +    assert (len(ifname) > 0), ("interface name cannot be empty") +    carrier = OPER_UP +    prevCarrier = OPER_UP +    data = bytes() +    LOG.debug("Wait for media disconnect and reconnect to happen") +    while True: +        recv_data = read_netlink_socket(netlink_socket, SELECT_TIMEOUT) +        if recv_data is None: +            continue +        LOG.debug('read %d bytes from socket', len(recv_data)) +        data += recv_data +        LOG.debug('Length of data after concat %d', len(data)) +        offset = 0 +        datalen = len(data) +        while offset < datalen: +            nl_msg = data[offset:] +            if len(nl_msg) < NLMSGHDR_SIZE: +                LOG.debug("Data is smaller than netlink header") +                break +            nlheader = get_netlink_msg_header(nl_msg) +            if len(nl_msg) < nlheader.length: +                LOG.debug("Partial data. Smaller than netlink message") +                break +            padlen = (nlheader.length+PAD_ALIGNMENT-1) & ~(PAD_ALIGNMENT-1) +            offset = offset + padlen +            LOG.debug('offset to next netlink message: %d', offset) +            # Ignore any messages not new link or del link +            if nlheader.type not in [RTM_NEWLINK, RTM_DELLINK]: +                continue +            interface_state = read_rta_oper_state(nl_msg) +            if interface_state is None: +                LOG.debug('Failed to read rta attributes: %s', interface_state) +                continue +            if interface_state.ifname != ifname: +                LOG.debug( +                    "Ignored netlink event on interface %s. Waiting for %s.", +                    interface_state.ifname, ifname) +                continue +            if interface_state.operstate not in [OPER_UP, OPER_DOWN]: +                continue +            prevCarrier = carrier +            carrier = interface_state.operstate +            # check for carrier down, up sequence +            isVnetSwitch = (prevCarrier == OPER_DOWN) and (carrier == OPER_UP) +            if isVnetSwitch: +                LOG.debug("Media switch happened on %s.", ifname) +                return +        data = data[offset:] + +# vi: ts=4 expandtab diff --git a/cloudinit/sources/helpers/tests/test_netlink.py b/cloudinit/sources/helpers/tests/test_netlink.py new file mode 100644 index 00000000..c2898a16 --- /dev/null +++ b/cloudinit/sources/helpers/tests/test_netlink.py @@ -0,0 +1,373 @@ +# Author: Tamilmani Manoharan <tamanoha@microsoft.com> +# +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit.tests.helpers import CiTestCase, mock +import socket +import struct +import codecs +from cloudinit.sources.helpers.netlink import ( +    NetlinkCreateSocketError, create_bound_netlink_socket, read_netlink_socket, +    read_rta_oper_state, unpack_rta_attr, wait_for_media_disconnect_connect, +    OPER_DOWN, OPER_UP, OPER_DORMANT, OPER_LOWERLAYERDOWN, OPER_NOTPRESENT, +    OPER_TESTING, OPER_UNKNOWN, RTATTR_START_OFFSET, RTM_NEWLINK, RTM_SETLINK, +    RTM_GETLINK, MAX_SIZE) + + +def int_to_bytes(i): +    '''convert integer to binary: eg: 1 to \x01''' +    hex_value = '{0:x}'.format(i) +    hex_value = '0' * (len(hex_value) % 2) + hex_value +    return codecs.decode(hex_value, 'hex_codec') + + +class TestCreateBoundNetlinkSocket(CiTestCase): + +    @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +    def test_socket_error_on_create(self, m_socket): +        '''create_bound_netlink_socket catches socket creation exception''' + +        """NetlinkCreateSocketError is raised when socket creation errors.""" +        m_socket.side_effect = socket.error("Fake socket failure") +        with self.assertRaises(NetlinkCreateSocketError) as ctx_mgr: +            create_bound_netlink_socket() +        self.assertEqual( +            'Exception during netlink socket create: Fake socket failure', +            str(ctx_mgr.exception)) + + +class TestReadNetlinkSocket(CiTestCase): + +    @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +    @mock.patch('cloudinit.sources.helpers.netlink.select.select') +    def test_read_netlink_socket(self, m_select, m_socket): +        '''read_netlink_socket able to receive data''' +        data = 'netlinktest' +        m_select.return_value = [m_socket], None, None +        m_socket.recv.return_value = data +        recv_data = read_netlink_socket(m_socket, 2) +        m_select.assert_called_with([m_socket], [], [], 2) +        m_socket.recv.assert_called_with(MAX_SIZE) +        self.assertIsNotNone(recv_data) +        self.assertEqual(recv_data, data) + +    @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +    @mock.patch('cloudinit.sources.helpers.netlink.select.select') +    def test_netlink_read_timeout(self, m_select, m_socket): +        '''read_netlink_socket should timeout if nothing to read''' +        m_select.return_value = [], None, None +        data = read_netlink_socket(m_socket, 1) +        m_select.assert_called_with([m_socket], [], [], 1) +        self.assertEqual(m_socket.recv.call_count, 0) +        self.assertIsNone(data) + +    def test_read_invalid_socket(self): +        '''read_netlink_socket raises assert error if socket is invalid''' +        socket = None +        with self.assertRaises(AssertionError) as context: +            read_netlink_socket(socket, 1) +        self.assertTrue('netlink socket is none' in str(context.exception)) + + +class TestParseNetlinkMessage(CiTestCase): + +    def test_read_rta_oper_state(self): +        '''read_rta_oper_state could parse netlink message and extract data''' +        ifname = "eth0" +        bytes = ifname.encode("utf-8") +        buf = bytearray(48) +        struct.pack_into("HH4sHHc", buf, RTATTR_START_OFFSET, 8, 3, bytes, 5, +                         16, int_to_bytes(OPER_DOWN)) +        interface_state = read_rta_oper_state(buf) +        self.assertEqual(interface_state.ifname, ifname) +        self.assertEqual(interface_state.operstate, OPER_DOWN) + +    def test_read_none_data(self): +        '''read_rta_oper_state raises assert error if data is none''' +        data = None +        with self.assertRaises(AssertionError) as context: +            read_rta_oper_state(data) +        self.assertTrue('data is none', str(context.exception)) + +    def test_read_invalid_rta_operstate_none(self): +        '''read_rta_oper_state returns none if operstate is none''' +        ifname = "eth0" +        buf = bytearray(40) +        bytes = ifname.encode("utf-8") +        struct.pack_into("HH4s", buf, RTATTR_START_OFFSET, 8, 3, bytes) +        interface_state = read_rta_oper_state(buf) +        self.assertIsNone(interface_state) + +    def test_read_invalid_rta_ifname_none(self): +        '''read_rta_oper_state returns none if ifname is none''' +        buf = bytearray(40) +        struct.pack_into("HHc", buf, RTATTR_START_OFFSET, 5, 16, +                         int_to_bytes(OPER_DOWN)) +        interface_state = read_rta_oper_state(buf) +        self.assertIsNone(interface_state) + +    def test_read_invalid_data_len(self): +        '''raise assert error if data size is smaller than required size''' +        buf = bytearray(32) +        with self.assertRaises(AssertionError) as context: +            read_rta_oper_state(buf) +        self.assertTrue('length of data is smaller than RTATTR_START_OFFSET' in +                        str(context.exception)) + +    def test_unpack_rta_attr_none_data(self): +        '''unpack_rta_attr raises assert error if data is none''' +        data = None +        with self.assertRaises(AssertionError) as context: +            unpack_rta_attr(data, RTATTR_START_OFFSET) +        self.assertTrue('data is none' in str(context.exception)) + +    def test_unpack_rta_attr_invalid_offset(self): +        '''unpack_rta_attr raises assert error if offset is invalid''' +        data = bytearray(48) +        with self.assertRaises(AssertionError) as context: +            unpack_rta_attr(data, "offset") +        self.assertTrue('offset is not integer' in str(context.exception)) +        with self.assertRaises(AssertionError) as context: +            unpack_rta_attr(data, 31) +        self.assertTrue('rta offset is less than expected length' in +                        str(context.exception)) + + +@mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +@mock.patch('cloudinit.sources.helpers.netlink.read_netlink_socket') +class TestWaitForMediaDisconnectConnect(CiTestCase): +    with_logs = True + +    def _media_switch_data(self, ifname, msg_type, operstate): +        '''construct netlink data with specified fields''' +        if ifname and operstate is not None: +            data = bytearray(48) +            bytes = ifname.encode("utf-8") +            struct.pack_into("HH4sHHc", data, RTATTR_START_OFFSET, 8, 3, +                             bytes, 5, 16, int_to_bytes(operstate)) +        elif ifname: +            data = bytearray(40) +            bytes = ifname.encode("utf-8") +            struct.pack_into("HH4s", data, RTATTR_START_OFFSET, 8, 3, bytes) +        elif operstate: +            data = bytearray(40) +            struct.pack_into("HHc", data, RTATTR_START_OFFSET, 5, 16, +                             int_to_bytes(operstate)) +        struct.pack_into("=LHHLL", data, 0, len(data), msg_type, 0, 0, 0) +        return data + +    def test_media_down_up_scenario(self, m_read_netlink_socket, +                                    m_socket): +        '''Test for media down up sequence for required interface name''' +        ifname = "eth0" +        # construct data for Oper State down +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        # construct data for Oper State up +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_op_down, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 2) + +    def test_wait_for_media_switch_diff_interface(self, m_read_netlink_socket, +                                                  m_socket): +        '''wait_for_media_disconnect_connect ignores unexpected interfaces. + +        The first two messages are for other interfaces and last two are for +        expected interface. So the function exit only after receiving last +        2 messages and therefore the call count for m_read_netlink_socket +        has to be 4 +        ''' +        other_ifname = "eth1" +        expected_ifname = "eth0" +        data_op_down_eth1 = self._media_switch_data( +                                other_ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up_eth1 = self._media_switch_data( +                                other_ifname, RTM_NEWLINK, OPER_UP) +        data_op_down_eth0 = self._media_switch_data( +                                expected_ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up_eth0 = self._media_switch_data( +                                expected_ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_op_down_eth1, +                                             data_op_up_eth1, +                                             data_op_down_eth0, +                                             data_op_up_eth0] +        wait_for_media_disconnect_connect(m_socket, expected_ifname) +        self.assertIn('Ignored netlink event on interface %s' % other_ifname, +                      self.logs.getvalue()) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_invalid_msgtype_getlink(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect ignores GETLINK events. + +        The first two messages are for oper down and up for RTM_GETLINK type +        which netlink module will ignore. The last 2 messages are RTM_NEWLINK +        with oper state down and up messages. Therefore the call count for +        m_read_netlink_socket has to be 4 ignoring first 2 messages +        of RTM_GETLINK +        ''' +        ifname = "eth0" +        data_getlink_down = self._media_switch_data( +                                    ifname, RTM_GETLINK, OPER_DOWN) +        data_getlink_up = self._media_switch_data( +                                    ifname, RTM_GETLINK, OPER_UP) +        data_newlink_down = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_DOWN) +        data_newlink_up = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_getlink_down, +                                             data_getlink_up, +                                             data_newlink_down, +                                             data_newlink_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_invalid_msgtype_setlink(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect ignores SETLINK events. + +        The first two messages are for oper down and up for RTM_GETLINK type +        which it will ignore. 3rd and 4th messages are RTM_NEWLINK with down +        and up messages. This function should exit after 4th messages since it +        sees down->up scenario. So the call count for m_read_netlink_socket +        has to be 4 ignoring first 2 messages of RTM_GETLINK and +        last 2 messages of RTM_NEWLINK +        ''' +        ifname = "eth0" +        data_setlink_down = self._media_switch_data( +                                    ifname, RTM_SETLINK, OPER_DOWN) +        data_setlink_up = self._media_switch_data( +                                    ifname, RTM_SETLINK, OPER_UP) +        data_newlink_down = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_DOWN) +        data_newlink_up = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_setlink_down, +                                             data_setlink_up, +                                             data_newlink_down, +                                             data_newlink_up, +                                             data_newlink_down, +                                             data_newlink_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_netlink_invalid_switch_scenario(self, m_read_netlink_socket, +                                             m_socket): +        '''returns only if it receives UP event after a DOWN event''' +        ifname = "eth0" +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        data_op_dormant = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_DORMANT) +        data_op_notpresent = self._media_switch_data(ifname, RTM_NEWLINK, +                                                     OPER_NOTPRESENT) +        data_op_lowerdown = self._media_switch_data(ifname, RTM_NEWLINK, +                                                    OPER_LOWERLAYERDOWN) +        data_op_testing = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_TESTING) +        data_op_unknown = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_UNKNOWN) +        m_read_netlink_socket.side_effect = [data_op_up, data_op_up, +                                             data_op_dormant, data_op_up, +                                             data_op_notpresent, data_op_up, +                                             data_op_lowerdown, data_op_up, +                                             data_op_testing, data_op_up, +                                             data_op_unknown, data_op_up, +                                             data_op_down, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 14) + +    def test_netlink_valid_inbetween_transitions(self, m_read_netlink_socket, +                                                 m_socket): +        '''wait_for_media_disconnect_connect handles in between transitions''' +        ifname = "eth0" +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        data_op_dormant = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_DORMANT) +        data_op_unknown = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_UNKNOWN) +        m_read_netlink_socket.side_effect = [data_op_down, data_op_dormant, +                                             data_op_unknown, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_netlink_invalid_operstate(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect should handle invalid operstates. + +        The function should not fail and return even if it receives invalid +        operstates. It always should wait for down up sequence. +        ''' +        ifname = "eth0" +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        data_op_invalid = self._media_switch_data(ifname, RTM_NEWLINK, 7) +        m_read_netlink_socket.side_effect = [data_op_invalid, data_op_up, +                                             data_op_down, data_op_invalid, +                                             data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 5) + +    def test_wait_invalid_socket(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect handle none netlink socket.''' +        socket = None +        ifname = "eth0" +        with self.assertRaises(AssertionError) as context: +            wait_for_media_disconnect_connect(socket, ifname) +        self.assertTrue('netlink socket is none' in str(context.exception)) + +    def test_wait_invalid_ifname(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect handle none interface name''' +        ifname = None +        with self.assertRaises(AssertionError) as context: +            wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertTrue('interface name is none' in str(context.exception)) +        ifname = "" +        with self.assertRaises(AssertionError) as context: +            wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertTrue('interface name cannot be empty' in +                        str(context.exception)) + +    def test_wait_invalid_rta_attr(self, m_read_netlink_socket, m_socket): +        ''' wait_for_media_disconnect_connect handles invalid rta data''' +        ifname = "eth0" +        data_invalid1 = self._media_switch_data(None, RTM_NEWLINK, OPER_DOWN) +        data_invalid2 = self._media_switch_data(ifname, RTM_NEWLINK, None) +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_invalid1, data_invalid2, +                                             data_op_down, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_read_multiple_netlink_msgs(self, m_read_netlink_socket, m_socket): +        '''Read multiple messages in single receive call''' +        ifname = "eth0" +        bytes = ifname.encode("utf-8") +        data = bytearray(96) +        struct.pack_into("=LHHLL", data, 0, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data, RTATTR_START_OFFSET, 8, 3, +                         bytes, 5, 16, int_to_bytes(OPER_DOWN)) +        struct.pack_into("=LHHLL", data, 48, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data, 48 + RTATTR_START_OFFSET, 8, +                         3, bytes, 5, 16, int_to_bytes(OPER_UP)) +        m_read_netlink_socket.return_value = data +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 1) + +    def test_read_partial_netlink_msgs(self, m_read_netlink_socket, m_socket): +        '''Read partial messages in receive call''' +        ifname = "eth0" +        bytes = ifname.encode("utf-8") +        data1 = bytearray(112) +        data2 = bytearray(32) +        struct.pack_into("=LHHLL", data1, 0, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data1, RTATTR_START_OFFSET, 8, 3, +                         bytes, 5, 16, int_to_bytes(OPER_DOWN)) +        struct.pack_into("=LHHLL", data1, 48, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data1, 80, 8, 3, bytes, 5, 16, +                         int_to_bytes(OPER_DOWN)) +        struct.pack_into("=LHHLL", data1, 96, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data2, 16, 8, 3, bytes, 5, 16, +                         int_to_bytes(OPER_UP)) +        m_read_netlink_socket.side_effect = [data1, data2] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 2) diff --git a/cloudinit/sources/helpers/vmware/imc/config_nic.py b/cloudinit/sources/helpers/vmware/imc/config_nic.py index e1890e23..77cbf3b6 100644 --- a/cloudinit/sources/helpers/vmware/imc/config_nic.py +++ b/cloudinit/sources/helpers/vmware/imc/config_nic.py @@ -165,9 +165,8 @@ class NicConfigurator(object):          # Add routes if there is no primary nic          if not self._primaryNic and v4.gateways: -            route_list.extend(self.gen_ipv4_route(nic, -                                                  v4.gateways, -                                                  v4.netmask)) +            subnet.update( +                {'routes': self.gen_ipv4_route(nic, v4.gateways, v4.netmask)})          return ([subnet], route_list) diff --git a/cloudinit/sources/tests/test_init.py b/cloudinit/sources/tests/test_init.py index 8082019e..6378e98b 100644 --- a/cloudinit/sources/tests/test_init.py +++ b/cloudinit/sources/tests/test_init.py @@ -11,7 +11,8 @@ from cloudinit.helpers import Paths  from cloudinit import importer  from cloudinit.sources import (      EXPERIMENTAL_TEXT, INSTANCE_JSON_FILE, INSTANCE_JSON_SENSITIVE_FILE, -    REDACT_SENSITIVE_VALUE, UNSET, DataSource, redact_sensitive_keys) +    METADATA_UNKNOWN, REDACT_SENSITIVE_VALUE, UNSET, DataSource, +    canonical_cloud_id, redact_sensitive_keys)  from cloudinit.tests.helpers import CiTestCase, skipIf, mock  from cloudinit.user_data import UserDataProcessor  from cloudinit import util @@ -295,6 +296,7 @@ class TestDataSource(CiTestCase):              'base64_encoded_keys': [],              'sensitive_keys': [],              'v1': { +                '_beta_keys': ['subplatform'],                  'availability-zone': 'myaz',                  'availability_zone': 'myaz',                  'cloud-name': 'subclasscloudname', @@ -303,7 +305,10 @@ class TestDataSource(CiTestCase):                  'instance_id': 'iid-datasource',                  'local-hostname': 'test-subclass-hostname',                  'local_hostname': 'test-subclass-hostname', -                'region': 'myregion'}, +                'platform': 'mytestsubclass', +                'public_ssh_keys': [], +                'region': 'myregion', +                'subplatform': 'unknown'},              'ds': {                  '_doc': EXPERIMENTAL_TEXT,                  'meta_data': {'availability_zone': 'myaz', @@ -339,6 +344,7 @@ class TestDataSource(CiTestCase):              'base64_encoded_keys': [],              'sensitive_keys': ['ds/meta_data/some/security-credentials'],              'v1': { +                '_beta_keys': ['subplatform'],                  'availability-zone': 'myaz',                  'availability_zone': 'myaz',                  'cloud-name': 'subclasscloudname', @@ -347,7 +353,10 @@ class TestDataSource(CiTestCase):                  'instance_id': 'iid-datasource',                  'local-hostname': 'test-subclass-hostname',                  'local_hostname': 'test-subclass-hostname', -                'region': 'myregion'}, +                'platform': 'mytestsubclass', +                'public_ssh_keys': [], +                'region': 'myregion', +                'subplatform': 'unknown'},              'ds': {                  '_doc': EXPERIMENTAL_TEXT,                  'meta_data': { @@ -599,4 +608,75 @@ class TestRedactSensitiveData(CiTestCase):              redact_sensitive_keys(md)) +class TestCanonicalCloudID(CiTestCase): + +    def test_cloud_id_returns_platform_on_unknowns(self): +        """When region and cloud_name are unknown, return platform.""" +        self.assertEqual( +            'platform', +            canonical_cloud_id(cloud_name=METADATA_UNKNOWN, +                               region=METADATA_UNKNOWN, +                               platform='platform')) + +    def test_cloud_id_returns_platform_on_none(self): +        """When region and cloud_name are unknown, return platform.""" +        self.assertEqual( +            'platform', +            canonical_cloud_id(cloud_name=None, +                               region=None, +                               platform='platform')) + +    def test_cloud_id_returns_cloud_name_on_unknown_region(self): +        """When region is unknown, return cloud_name.""" +        for region in (None, METADATA_UNKNOWN): +            self.assertEqual( +                'cloudname', +                canonical_cloud_id(cloud_name='cloudname', +                                   region=region, +                                   platform='platform')) + +    def test_cloud_id_returns_platform_on_unknown_cloud_name(self): +        """When region is set but cloud_name is unknown return cloud_name.""" +        self.assertEqual( +            'platform', +            canonical_cloud_id(cloud_name=METADATA_UNKNOWN, +                               region='region', +                               platform='platform')) + +    def test_cloud_id_aws_based_on_region_and_cloud_name(self): +        """When cloud_name is aws, return proper cloud-id based on region.""" +        self.assertEqual( +            'aws-china', +            canonical_cloud_id(cloud_name='aws', +                               region='cn-north-1', +                               platform='platform')) +        self.assertEqual( +            'aws', +            canonical_cloud_id(cloud_name='aws', +                               region='us-east-1', +                               platform='platform')) +        self.assertEqual( +            'aws-gov', +            canonical_cloud_id(cloud_name='aws', +                               region='us-gov-1', +                               platform='platform')) +        self.assertEqual(  # Overrideen non-aws cloud_name is returned +            '!aws', +            canonical_cloud_id(cloud_name='!aws', +                               region='us-gov-1', +                               platform='platform')) + +    def test_cloud_id_azure_based_on_region_and_cloud_name(self): +        """Report cloud-id when cloud_name is azure and region is in china.""" +        self.assertEqual( +            'azure-china', +            canonical_cloud_id(cloud_name='azure', +                               region='chinaeast', +                               platform='platform')) +        self.assertEqual( +            'azure', +            canonical_cloud_id(cloud_name='azure', +                               region='!chinaeast', +                               platform='platform')) +  # vi: ts=4 expandtab diff --git a/cloudinit/sources/tests/test_oracle.py b/cloudinit/sources/tests/test_oracle.py index 7599126c..97d62947 100644 --- a/cloudinit/sources/tests/test_oracle.py +++ b/cloudinit/sources/tests/test_oracle.py @@ -71,6 +71,14 @@ class TestDataSourceOracle(test_helpers.CiTestCase):          self.assertFalse(ds._get_data())          mocks._is_platform_viable.assert_called_once_with() +    def test_platform_info(self): +        """Return platform-related information for Oracle Datasource.""" +        ds, _mocks = self._get_ds() +        self.assertEqual('oracle', ds.cloud_name) +        self.assertEqual('oracle', ds.platform_type) +        self.assertEqual( +            'metadata (http://169.254.169.254/openstack/)', ds.subplatform) +      @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True)      def test_without_userdata(self, m_is_iscsi_root):          """If no user-data is provided, it should not be in return dict.""" | 
