diff options
Diffstat (limited to 'cloudinit')
49 files changed, 2095 insertions, 312 deletions
diff --git a/cloudinit/cmd/cloud_id.py b/cloudinit/cmd/cloud_id.py new file mode 100755 index 00000000..97608921 --- /dev/null +++ b/cloudinit/cmd/cloud_id.py @@ -0,0 +1,90 @@ +# This file is part of cloud-init. See LICENSE file for license information. + +"""Commandline utility to list the canonical cloud-id for an instance.""" + +import argparse +import json +import sys + +from cloudinit.sources import ( + INSTANCE_JSON_FILE, METADATA_UNKNOWN, canonical_cloud_id) + +DEFAULT_INSTANCE_JSON = '/run/cloud-init/%s' % INSTANCE_JSON_FILE + +NAME = 'cloud-id' + + +def get_parser(parser=None): + """Build or extend an arg parser for the cloud-id utility. + + @param parser: Optional existing ArgumentParser instance representing the + query subcommand which will be extended to support the args of + this utility. + + @returns: ArgumentParser with proper argument configuration. + """ + if not parser: + parser = argparse.ArgumentParser( + prog=NAME, + description='Report the canonical cloud-id for this instance') + parser.add_argument( + '-j', '--json', action='store_true', default=False, + help='Report all standardized cloud-id information as json.') + parser.add_argument( + '-l', '--long', action='store_true', default=False, + help='Report extended cloud-id information as tab-delimited string.') + parser.add_argument( + '-i', '--instance-data', type=str, default=DEFAULT_INSTANCE_JSON, + help=('Path to instance-data.json file. Default is %s' % + DEFAULT_INSTANCE_JSON)) + return parser + + +def error(msg): + sys.stderr.write('ERROR: %s\n' % msg) + return 1 + + +def handle_args(name, args): + """Handle calls to 'cloud-id' cli. + + Print the canonical cloud-id on which the instance is running. + + @return: 0 on success, 1 otherwise. + """ + try: + instance_data = json.load(open(args.instance_data)) + except IOError: + return error( + "File not found '%s'. Provide a path to instance data json file" + ' using --instance-data' % args.instance_data) + except ValueError as e: + return error( + "File '%s' is not valid json. %s" % (args.instance_data, e)) + v1 = instance_data.get('v1', {}) + cloud_id = canonical_cloud_id( + v1.get('cloud_name', METADATA_UNKNOWN), + v1.get('region', METADATA_UNKNOWN), + v1.get('platform', METADATA_UNKNOWN)) + if args.json: + v1['cloud_id'] = cloud_id + response = json.dumps( # Pretty, sorted json + v1, indent=1, sort_keys=True, separators=(',', ': ')) + elif args.long: + response = '%s\t%s' % (cloud_id, v1.get('region', METADATA_UNKNOWN)) + else: + response = cloud_id + sys.stdout.write('%s\n' % response) + return 0 + + +def main(): + """Tool to query specific instance-data values.""" + parser = get_parser() + sys.exit(handle_args(NAME, parser.parse_args())) + + +if __name__ == '__main__': + main() + +# vi: ts=4 expandtab diff --git a/cloudinit/cmd/devel/logs.py b/cloudinit/cmd/devel/logs.py index df725204..4c086b51 100644 --- a/cloudinit/cmd/devel/logs.py +++ b/cloudinit/cmd/devel/logs.py @@ -5,14 +5,16 @@ """Define 'collect-logs' utility and handler to include in cloud-init cmd.""" import argparse -from cloudinit.util import ( - ProcessExecutionError, chdir, copy, ensure_dir, subp, write_file) -from cloudinit.temp_utils import tempdir from datetime import datetime import os import shutil import sys +from cloudinit.sources import INSTANCE_JSON_SENSITIVE_FILE +from cloudinit.temp_utils import tempdir +from cloudinit.util import ( + ProcessExecutionError, chdir, copy, ensure_dir, subp, write_file) + CLOUDINIT_LOGS = ['/var/log/cloud-init.log', '/var/log/cloud-init-output.log'] CLOUDINIT_RUN_DIR = '/run/cloud-init' @@ -46,6 +48,13 @@ def get_parser(parser=None): return parser +def _copytree_ignore_sensitive_files(curdir, files): + """Return a list of files to ignore if we are non-root""" + if os.getuid() == 0: + return () + return (INSTANCE_JSON_SENSITIVE_FILE,) # Ignore root-permissioned files + + def _write_command_output_to_file(cmd, filename, msg, verbosity): """Helper which runs a command and writes output or error to filename.""" try: @@ -78,6 +87,11 @@ def collect_logs(tarfile, include_userdata, verbosity=0): @param tarfile: The path of the tar-gzipped file to create. @param include_userdata: Boolean, true means include user-data. """ + if include_userdata and os.getuid() != 0: + sys.stderr.write( + "To include userdata, root user is required." + " Try sudo cloud-init collect-logs\n") + return 1 tarfile = os.path.abspath(tarfile) date = datetime.utcnow().date().strftime('%Y-%m-%d') log_dir = 'cloud-init-logs-{0}'.format(date) @@ -110,7 +124,8 @@ def collect_logs(tarfile, include_userdata, verbosity=0): ensure_dir(run_dir) if os.path.exists(CLOUDINIT_RUN_DIR): shutil.copytree(CLOUDINIT_RUN_DIR, - os.path.join(run_dir, 'cloud-init')) + os.path.join(run_dir, 'cloud-init'), + ignore=_copytree_ignore_sensitive_files) _debug("collected dir %s\n" % CLOUDINIT_RUN_DIR, 1, verbosity) else: _debug("directory '%s' did not exist\n" % CLOUDINIT_RUN_DIR, 1, @@ -118,21 +133,21 @@ def collect_logs(tarfile, include_userdata, verbosity=0): with chdir(tmp_dir): subp(['tar', 'czvf', tarfile, log_dir.replace(tmp_dir + '/', '')]) sys.stderr.write("Wrote %s\n" % tarfile) + return 0 def handle_collect_logs_args(name, args): """Handle calls to 'cloud-init collect-logs' as a subcommand.""" - collect_logs(args.tarfile, args.userdata, args.verbosity) + return collect_logs(args.tarfile, args.userdata, args.verbosity) def main(): """Tool to collect and tar all cloud-init related logs.""" parser = get_parser() - handle_collect_logs_args('collect-logs', parser.parse_args()) - return 0 + return handle_collect_logs_args('collect-logs', parser.parse_args()) if __name__ == '__main__': - main() + sys.exit(main()) # vi: ts=4 expandtab diff --git a/cloudinit/cmd/devel/net_convert.py b/cloudinit/cmd/devel/net_convert.py index a0f58a0a..1ad7e0bd 100755 --- a/cloudinit/cmd/devel/net_convert.py +++ b/cloudinit/cmd/devel/net_convert.py @@ -9,6 +9,7 @@ import yaml from cloudinit.sources.helpers import openstack from cloudinit.sources import DataSourceAzure as azure +from cloudinit.sources import DataSourceOVF as ovf from cloudinit import distros from cloudinit.net import eni, netplan, network_state, sysconfig @@ -31,7 +32,7 @@ def get_parser(parser=None): metavar="PATH", required=True) parser.add_argument("-k", "--kind", choices=['eni', 'network_data.json', 'yaml', - 'azure-imds'], + 'azure-imds', 'vmware-imc'], required=True) parser.add_argument("-d", "--directory", metavar="PATH", @@ -76,7 +77,6 @@ def handle_args(name, args): net_data = args.network_data.read() if args.kind == "eni": pre_ns = eni.convert_eni_data(net_data) - ns = network_state.parse_net_config_data(pre_ns) elif args.kind == "yaml": pre_ns = yaml.load(net_data) if 'network' in pre_ns: @@ -85,15 +85,16 @@ def handle_args(name, args): sys.stderr.write('\n'.join( ["Input YAML", yaml.dump(pre_ns, default_flow_style=False, indent=4), ""])) - ns = network_state.parse_net_config_data(pre_ns) elif args.kind == 'network_data.json': pre_ns = openstack.convert_net_json( json.loads(net_data), known_macs=known_macs) - ns = network_state.parse_net_config_data(pre_ns) elif args.kind == 'azure-imds': pre_ns = azure.parse_network_config(json.loads(net_data)) - ns = network_state.parse_net_config_data(pre_ns) + elif args.kind == 'vmware-imc': + config = ovf.Config(ovf.ConfigFile(args.network_data.name)) + pre_ns = ovf.get_network_config_from_conf(config, False) + ns = network_state.parse_net_config_data(pre_ns) if not ns: raise RuntimeError("No valid network_state object created from" "input data") @@ -111,6 +112,10 @@ def handle_args(name, args): elif args.output_kind == "netplan": r_cls = netplan.Renderer config = distro.renderer_configs.get('netplan') + # don't run netplan generate/apply + config['postcmds'] = False + # trim leading slash + config['netplan_path'] = config['netplan_path'][1:] else: r_cls = sysconfig.Renderer config = distro.renderer_configs.get('sysconfig') diff --git a/cloudinit/cmd/devel/render.py b/cloudinit/cmd/devel/render.py index 2ba6b681..1bc22406 100755 --- a/cloudinit/cmd/devel/render.py +++ b/cloudinit/cmd/devel/render.py @@ -8,11 +8,10 @@ import sys from cloudinit.handlers.jinja_template import render_jinja_payload_from_file from cloudinit import log -from cloudinit.sources import INSTANCE_JSON_FILE +from cloudinit.sources import INSTANCE_JSON_FILE, INSTANCE_JSON_SENSITIVE_FILE from . import addLogHandlerCLI, read_cfg_paths NAME = 'render' -DEFAULT_INSTANCE_DATA = '/run/cloud-init/instance-data.json' LOG = log.getLogger(NAME) @@ -47,12 +46,22 @@ def handle_args(name, args): @return 0 on success, 1 on failure. """ addLogHandlerCLI(LOG, log.DEBUG if args.debug else log.WARNING) - if not args.instance_data: - paths = read_cfg_paths() - instance_data_fn = os.path.join( - paths.run_dir, INSTANCE_JSON_FILE) - else: + if args.instance_data: instance_data_fn = args.instance_data + else: + paths = read_cfg_paths() + uid = os.getuid() + redacted_data_fn = os.path.join(paths.run_dir, INSTANCE_JSON_FILE) + if uid == 0: + instance_data_fn = os.path.join( + paths.run_dir, INSTANCE_JSON_SENSITIVE_FILE) + if not os.path.exists(instance_data_fn): + LOG.warning( + 'Missing root-readable %s. Using redacted %s instead.', + instance_data_fn, redacted_data_fn) + instance_data_fn = redacted_data_fn + else: + instance_data_fn = redacted_data_fn if not os.path.exists(instance_data_fn): LOG.error('Missing instance-data.json file: %s', instance_data_fn) return 1 @@ -62,10 +71,14 @@ def handle_args(name, args): except IOError: LOG.error('Missing user-data file: %s', args.user_data) return 1 - rendered_payload = render_jinja_payload_from_file( - payload=user_data, payload_fn=args.user_data, - instance_data_file=instance_data_fn, - debug=True if args.debug else False) + try: + rendered_payload = render_jinja_payload_from_file( + payload=user_data, payload_fn=args.user_data, + instance_data_file=instance_data_fn, + debug=True if args.debug else False) + except RuntimeError as e: + LOG.error('Cannot render from instance data: %s', str(e)) + return 1 if not rendered_payload: LOG.error('Unable to render user-data file: %s', args.user_data) return 1 diff --git a/cloudinit/cmd/devel/tests/test_logs.py b/cloudinit/cmd/devel/tests/test_logs.py index 98b47560..4951797b 100644 --- a/cloudinit/cmd/devel/tests/test_logs.py +++ b/cloudinit/cmd/devel/tests/test_logs.py @@ -1,13 +1,17 @@ # This file is part of cloud-init. See LICENSE file for license information. -from cloudinit.cmd.devel import logs -from cloudinit.util import ensure_dir, load_file, subp, write_file -from cloudinit.tests.helpers import FilesystemMockingTestCase, wrap_and_call from datetime import datetime -import mock import os +from six import StringIO + +from cloudinit.cmd.devel import logs +from cloudinit.sources import INSTANCE_JSON_SENSITIVE_FILE +from cloudinit.tests.helpers import ( + FilesystemMockingTestCase, mock, wrap_and_call) +from cloudinit.util import ensure_dir, load_file, subp, write_file +@mock.patch('cloudinit.cmd.devel.logs.os.getuid') class TestCollectLogs(FilesystemMockingTestCase): def setUp(self): @@ -15,14 +19,29 @@ class TestCollectLogs(FilesystemMockingTestCase): self.new_root = self.tmp_dir() self.run_dir = self.tmp_path('run', self.new_root) - def test_collect_logs_creates_tarfile(self): + def test_collect_logs_with_userdata_requires_root_user(self, m_getuid): + """collect-logs errors when non-root user collects userdata .""" + m_getuid.return_value = 100 # non-root + output_tarfile = self.tmp_path('logs.tgz') + with mock.patch('sys.stderr', new_callable=StringIO) as m_stderr: + self.assertEqual( + 1, logs.collect_logs(output_tarfile, include_userdata=True)) + self.assertEqual( + 'To include userdata, root user is required.' + ' Try sudo cloud-init collect-logs\n', + m_stderr.getvalue()) + + def test_collect_logs_creates_tarfile(self, m_getuid): """collect-logs creates a tarfile with all related cloud-init info.""" + m_getuid.return_value = 100 log1 = self.tmp_path('cloud-init.log', self.new_root) write_file(log1, 'cloud-init-log') log2 = self.tmp_path('cloud-init-output.log', self.new_root) write_file(log2, 'cloud-init-output-log') ensure_dir(self.run_dir) write_file(self.tmp_path('results.json', self.run_dir), 'results') + write_file(self.tmp_path(INSTANCE_JSON_SENSITIVE_FILE, self.run_dir), + 'sensitive') output_tarfile = self.tmp_path('logs.tgz') date = datetime.utcnow().date().strftime('%Y-%m-%d') @@ -59,6 +78,11 @@ class TestCollectLogs(FilesystemMockingTestCase): # unpack the tarfile and check file contents subp(['tar', 'zxvf', output_tarfile, '-C', self.new_root]) out_logdir = self.tmp_path(date_logdir, self.new_root) + self.assertFalse( + os.path.exists( + os.path.join(out_logdir, 'run', 'cloud-init', + INSTANCE_JSON_SENSITIVE_FILE)), + 'Unexpected file found: %s' % INSTANCE_JSON_SENSITIVE_FILE) self.assertEqual( '0.7fake\n', load_file(os.path.join(out_logdir, 'dpkg-version'))) @@ -82,8 +106,9 @@ class TestCollectLogs(FilesystemMockingTestCase): os.path.join(out_logdir, 'run', 'cloud-init', 'results.json'))) fake_stderr.write.assert_any_call('Wrote %s\n' % output_tarfile) - def test_collect_logs_includes_optional_userdata(self): + def test_collect_logs_includes_optional_userdata(self, m_getuid): """collect-logs include userdata when --include-userdata is set.""" + m_getuid.return_value = 0 log1 = self.tmp_path('cloud-init.log', self.new_root) write_file(log1, 'cloud-init-log') log2 = self.tmp_path('cloud-init-output.log', self.new_root) @@ -92,6 +117,8 @@ class TestCollectLogs(FilesystemMockingTestCase): write_file(userdata, 'user-data') ensure_dir(self.run_dir) write_file(self.tmp_path('results.json', self.run_dir), 'results') + write_file(self.tmp_path(INSTANCE_JSON_SENSITIVE_FILE, self.run_dir), + 'sensitive') output_tarfile = self.tmp_path('logs.tgz') date = datetime.utcnow().date().strftime('%Y-%m-%d') @@ -132,4 +159,8 @@ class TestCollectLogs(FilesystemMockingTestCase): self.assertEqual( 'user-data', load_file(os.path.join(out_logdir, 'user-data.txt'))) + self.assertEqual( + 'sensitive', + load_file(os.path.join(out_logdir, 'run', 'cloud-init', + INSTANCE_JSON_SENSITIVE_FILE))) fake_stderr.write.assert_any_call('Wrote %s\n' % output_tarfile) diff --git a/cloudinit/cmd/devel/tests/test_render.py b/cloudinit/cmd/devel/tests/test_render.py index fc5d2c0d..988bba03 100644 --- a/cloudinit/cmd/devel/tests/test_render.py +++ b/cloudinit/cmd/devel/tests/test_render.py @@ -6,7 +6,7 @@ import os from collections import namedtuple from cloudinit.cmd.devel import render from cloudinit.helpers import Paths -from cloudinit.sources import INSTANCE_JSON_FILE +from cloudinit.sources import INSTANCE_JSON_FILE, INSTANCE_JSON_SENSITIVE_FILE from cloudinit.tests.helpers import CiTestCase, mock, skipUnlessJinja from cloudinit.util import ensure_dir, write_file @@ -63,6 +63,49 @@ class TestRender(CiTestCase): 'Missing instance-data.json file: %s' % json_file, self.logs.getvalue()) + def test_handle_args_root_fallback_from_sensitive_instance_data(self): + """When root user defaults to sensitive.json.""" + user_data = self.tmp_path('user-data', dir=self.tmp) + run_dir = self.tmp_path('run_dir', dir=self.tmp) + ensure_dir(run_dir) + paths = Paths({'run_dir': run_dir}) + self.add_patch('cloudinit.cmd.devel.render.read_cfg_paths', 'm_paths') + self.m_paths.return_value = paths + args = self.args( + user_data=user_data, instance_data=None, debug=False) + with mock.patch('sys.stderr', new_callable=StringIO): + with mock.patch('os.getuid') as m_getuid: + m_getuid.return_value = 0 + self.assertEqual(1, render.handle_args('anyname', args)) + json_file = os.path.join(run_dir, INSTANCE_JSON_FILE) + json_sensitive = os.path.join(run_dir, INSTANCE_JSON_SENSITIVE_FILE) + self.assertIn( + 'WARNING: Missing root-readable %s. Using redacted %s' % ( + json_sensitive, json_file), self.logs.getvalue()) + self.assertIn( + 'ERROR: Missing instance-data.json file: %s' % json_file, + self.logs.getvalue()) + + def test_handle_args_root_uses_sensitive_instance_data(self): + """When root user, and no instance-data arg, use sensitive.json.""" + user_data = self.tmp_path('user-data', dir=self.tmp) + write_file(user_data, '##template: jinja\nrendering: {{ my_var }}') + run_dir = self.tmp_path('run_dir', dir=self.tmp) + ensure_dir(run_dir) + json_sensitive = os.path.join(run_dir, INSTANCE_JSON_SENSITIVE_FILE) + write_file(json_sensitive, '{"my-var": "jinja worked"}') + paths = Paths({'run_dir': run_dir}) + self.add_patch('cloudinit.cmd.devel.render.read_cfg_paths', 'm_paths') + self.m_paths.return_value = paths + args = self.args( + user_data=user_data, instance_data=None, debug=False) + with mock.patch('sys.stderr', new_callable=StringIO): + with mock.patch('sys.stdout', new_callable=StringIO) as m_stdout: + with mock.patch('os.getuid') as m_getuid: + m_getuid.return_value = 0 + self.assertEqual(0, render.handle_args('anyname', args)) + self.assertIn('rendering: jinja worked', m_stdout.getvalue()) + @skipUnlessJinja() def test_handle_args_renders_instance_data_vars_in_template(self): """If user_data file is a jinja template render instance-data vars.""" diff --git a/cloudinit/cmd/main.py b/cloudinit/cmd/main.py index 5a437020..933c019a 100644 --- a/cloudinit/cmd/main.py +++ b/cloudinit/cmd/main.py @@ -41,7 +41,7 @@ from cloudinit.settings import (PER_INSTANCE, PER_ALWAYS, PER_ONCE, from cloudinit import atomic_helper from cloudinit.config import cc_set_hostname -from cloudinit.dhclient_hook import LogDhclient +from cloudinit import dhclient_hook # Welcome message template @@ -586,12 +586,6 @@ def main_single(name, args): return 0 -def dhclient_hook(name, args): - record = LogDhclient(args) - record.check_hooks_dir() - record.record() - - def status_wrapper(name, args, data_d=None, link_d=None): if data_d is None: data_d = os.path.normpath("/var/lib/cloud/data") @@ -795,15 +789,9 @@ def main(sysv_args=None): 'query', help='Query standardized instance metadata from the command line.') - parser_dhclient = subparsers.add_parser('dhclient-hook', - help=('run the dhclient hook' - 'to record network info')) - parser_dhclient.add_argument("net_action", - help=('action taken on the interface')) - parser_dhclient.add_argument("net_interface", - help=('the network interface being acted' - ' upon')) - parser_dhclient.set_defaults(action=('dhclient_hook', dhclient_hook)) + parser_dhclient = subparsers.add_parser( + dhclient_hook.NAME, help=dhclient_hook.__doc__) + dhclient_hook.get_parser(parser_dhclient) parser_features = subparsers.add_parser('features', help=('list defined features')) diff --git a/cloudinit/cmd/query.py b/cloudinit/cmd/query.py index 7d2d4fe4..1d888b9d 100644 --- a/cloudinit/cmd/query.py +++ b/cloudinit/cmd/query.py @@ -3,6 +3,7 @@ """Query standardized instance metadata from the command line.""" import argparse +from errno import EACCES import os import six import sys @@ -79,27 +80,38 @@ def handle_args(name, args): uid = os.getuid() if not all([args.instance_data, args.user_data, args.vendor_data]): paths = read_cfg_paths() - if not args.instance_data: + if args.instance_data: + instance_data_fn = args.instance_data + else: + redacted_data_fn = os.path.join(paths.run_dir, INSTANCE_JSON_FILE) if uid == 0: - default_json_fn = INSTANCE_JSON_SENSITIVE_FILE + sensitive_data_fn = os.path.join( + paths.run_dir, INSTANCE_JSON_SENSITIVE_FILE) + if os.path.exists(sensitive_data_fn): + instance_data_fn = sensitive_data_fn + else: + LOG.warning( + 'Missing root-readable %s. Using redacted %s instead.', + sensitive_data_fn, redacted_data_fn) + instance_data_fn = redacted_data_fn else: - default_json_fn = INSTANCE_JSON_FILE # World readable - instance_data_fn = os.path.join(paths.run_dir, default_json_fn) + instance_data_fn = redacted_data_fn + if args.user_data: + user_data_fn = args.user_data else: - instance_data_fn = args.instance_data - if not args.user_data: user_data_fn = os.path.join(paths.instance_link, 'user-data.txt') + if args.vendor_data: + vendor_data_fn = args.vendor_data else: - user_data_fn = args.user_data - if not args.vendor_data: vendor_data_fn = os.path.join(paths.instance_link, 'vendor-data.txt') - else: - vendor_data_fn = args.vendor_data try: instance_json = util.load_file(instance_data_fn) - except IOError: - LOG.error('Missing instance-data.json file: %s', instance_data_fn) + except (IOError, OSError) as e: + if e.errno == EACCES: + LOG.error("No read permission on '%s'. Try sudo", instance_data_fn) + else: + LOG.error('Missing instance-data file: %s', instance_data_fn) return 1 instance_data = util.load_json(instance_json) diff --git a/cloudinit/cmd/tests/test_cloud_id.py b/cloudinit/cmd/tests/test_cloud_id.py new file mode 100644 index 00000000..73738170 --- /dev/null +++ b/cloudinit/cmd/tests/test_cloud_id.py @@ -0,0 +1,127 @@ +# This file is part of cloud-init. See LICENSE file for license information. + +"""Tests for cloud-id command line utility.""" + +from cloudinit import util +from collections import namedtuple +from six import StringIO + +from cloudinit.cmd import cloud_id + +from cloudinit.tests.helpers import CiTestCase, mock + + +class TestCloudId(CiTestCase): + + args = namedtuple('cloudidargs', ('instance_data json long')) + + def setUp(self): + super(TestCloudId, self).setUp() + self.tmp = self.tmp_dir() + self.instance_data = self.tmp_path('instance-data.json', dir=self.tmp) + + def test_cloud_id_arg_parser_defaults(self): + """Validate the argument defaults when not provided by the end-user.""" + cmd = ['cloud-id'] + with mock.patch('sys.argv', cmd): + args = cloud_id.get_parser().parse_args() + self.assertEqual( + '/run/cloud-init/instance-data.json', + args.instance_data) + self.assertEqual(False, args.long) + self.assertEqual(False, args.json) + + def test_cloud_id_arg_parse_overrides(self): + """Override argument defaults by specifying values for each param.""" + util.write_file(self.instance_data, '{}') + cmd = ['cloud-id', '--instance-data', self.instance_data, '--long', + '--json'] + with mock.patch('sys.argv', cmd): + args = cloud_id.get_parser().parse_args() + self.assertEqual(self.instance_data, args.instance_data) + self.assertEqual(True, args.long) + self.assertEqual(True, args.json) + + def test_cloud_id_missing_instance_data_json(self): + """Exit error when the provided instance-data.json does not exist.""" + cmd = ['cloud-id', '--instance-data', self.instance_data] + with mock.patch('sys.argv', cmd): + with mock.patch('sys.stderr', new_callable=StringIO) as m_stderr: + with self.assertRaises(SystemExit) as context_manager: + cloud_id.main() + self.assertEqual(1, context_manager.exception.code) + self.assertIn( + "ERROR: File not found '%s'" % self.instance_data, + m_stderr.getvalue()) + + def test_cloud_id_non_json_instance_data(self): + """Exit error when the provided instance-data.json is not json.""" + cmd = ['cloud-id', '--instance-data', self.instance_data] + util.write_file(self.instance_data, '{') + with mock.patch('sys.argv', cmd): + with mock.patch('sys.stderr', new_callable=StringIO) as m_stderr: + with self.assertRaises(SystemExit) as context_manager: + cloud_id.main() + self.assertEqual(1, context_manager.exception.code) + self.assertIn( + "ERROR: File '%s' is not valid json." % self.instance_data, + m_stderr.getvalue()) + + def test_cloud_id_from_cloud_name_in_instance_data(self): + """Report canonical cloud-id from cloud_name in instance-data.""" + util.write_file( + self.instance_data, + '{"v1": {"cloud_name": "mycloud", "region": "somereg"}}') + cmd = ['cloud-id', '--instance-data', self.instance_data] + with mock.patch('sys.argv', cmd): + with mock.patch('sys.stdout', new_callable=StringIO) as m_stdout: + with self.assertRaises(SystemExit) as context_manager: + cloud_id.main() + self.assertEqual(0, context_manager.exception.code) + self.assertEqual("mycloud\n", m_stdout.getvalue()) + + def test_cloud_id_long_name_from_instance_data(self): + """Report long cloud-id format from cloud_name and region.""" + util.write_file( + self.instance_data, + '{"v1": {"cloud_name": "mycloud", "region": "somereg"}}') + cmd = ['cloud-id', '--instance-data', self.instance_data, '--long'] + with mock.patch('sys.argv', cmd): + with mock.patch('sys.stdout', new_callable=StringIO) as m_stdout: + with self.assertRaises(SystemExit) as context_manager: + cloud_id.main() + self.assertEqual(0, context_manager.exception.code) + self.assertEqual("mycloud\tsomereg\n", m_stdout.getvalue()) + + def test_cloud_id_lookup_from_instance_data_region(self): + """Report discovered canonical cloud_id when region lookup matches.""" + util.write_file( + self.instance_data, + '{"v1": {"cloud_name": "aws", "region": "cn-north-1",' + ' "platform": "ec2"}}') + cmd = ['cloud-id', '--instance-data', self.instance_data, '--long'] + with mock.patch('sys.argv', cmd): + with mock.patch('sys.stdout', new_callable=StringIO) as m_stdout: + with self.assertRaises(SystemExit) as context_manager: + cloud_id.main() + self.assertEqual(0, context_manager.exception.code) + self.assertEqual("aws-china\tcn-north-1\n", m_stdout.getvalue()) + + def test_cloud_id_lookup_json_instance_data_adds_cloud_id_to_json(self): + """Report v1 instance-data content with cloud_id when --json set.""" + util.write_file( + self.instance_data, + '{"v1": {"cloud_name": "unknown", "region": "dfw",' + ' "platform": "openstack", "public_ssh_keys": []}}') + expected = util.json_dumps({ + 'cloud_id': 'openstack', 'cloud_name': 'unknown', + 'platform': 'openstack', 'public_ssh_keys': [], 'region': 'dfw'}) + cmd = ['cloud-id', '--instance-data', self.instance_data, '--json'] + with mock.patch('sys.argv', cmd): + with mock.patch('sys.stdout', new_callable=StringIO) as m_stdout: + with self.assertRaises(SystemExit) as context_manager: + cloud_id.main() + self.assertEqual(0, context_manager.exception.code) + self.assertEqual(expected + '\n', m_stdout.getvalue()) + +# vi: ts=4 expandtab diff --git a/cloudinit/cmd/tests/test_query.py b/cloudinit/cmd/tests/test_query.py index fb87c6ab..28738b1e 100644 --- a/cloudinit/cmd/tests/test_query.py +++ b/cloudinit/cmd/tests/test_query.py @@ -1,5 +1,6 @@ # This file is part of cloud-init. See LICENSE file for license information. +import errno from six import StringIO from textwrap import dedent import os @@ -7,7 +8,8 @@ import os from collections import namedtuple from cloudinit.cmd import query from cloudinit.helpers import Paths -from cloudinit.sources import REDACT_SENSITIVE_VALUE, INSTANCE_JSON_FILE +from cloudinit.sources import ( + REDACT_SENSITIVE_VALUE, INSTANCE_JSON_FILE, INSTANCE_JSON_SENSITIVE_FILE) from cloudinit.tests.helpers import CiTestCase, mock from cloudinit.util import ensure_dir, write_file @@ -50,10 +52,28 @@ class TestQuery(CiTestCase): with mock.patch('sys.stderr', new_callable=StringIO) as m_stderr: self.assertEqual(1, query.handle_args('anyname', args)) self.assertIn( - 'ERROR: Missing instance-data.json file: %s' % absent_fn, + 'ERROR: Missing instance-data file: %s' % absent_fn, self.logs.getvalue()) self.assertIn( - 'ERROR: Missing instance-data.json file: %s' % absent_fn, + 'ERROR: Missing instance-data file: %s' % absent_fn, + m_stderr.getvalue()) + + def test_handle_args_error_when_no_read_permission_instance_data(self): + """When instance_data file is unreadable, log an error.""" + noread_fn = self.tmp_path('unreadable', dir=self.tmp) + write_file(noread_fn, 'thou shall not pass') + args = self.args( + debug=False, dump_all=True, format=None, instance_data=noread_fn, + list_keys=False, user_data='ud', vendor_data='vd', varname=None) + with mock.patch('sys.stderr', new_callable=StringIO) as m_stderr: + with mock.patch('cloudinit.cmd.query.util.load_file') as m_load: + m_load.side_effect = OSError(errno.EACCES, 'Not allowed') + self.assertEqual(1, query.handle_args('anyname', args)) + self.assertIn( + "ERROR: No read permission on '%s'. Try sudo" % noread_fn, + self.logs.getvalue()) + self.assertIn( + "ERROR: No read permission on '%s'. Try sudo" % noread_fn, m_stderr.getvalue()) def test_handle_args_defaults_instance_data(self): @@ -70,12 +90,58 @@ class TestQuery(CiTestCase): self.assertEqual(1, query.handle_args('anyname', args)) json_file = os.path.join(run_dir, INSTANCE_JSON_FILE) self.assertIn( - 'ERROR: Missing instance-data.json file: %s' % json_file, + 'ERROR: Missing instance-data file: %s' % json_file, self.logs.getvalue()) self.assertIn( - 'ERROR: Missing instance-data.json file: %s' % json_file, + 'ERROR: Missing instance-data file: %s' % json_file, m_stderr.getvalue()) + def test_handle_args_root_fallsback_to_instance_data(self): + """When no instance_data argument, root falls back to redacted json.""" + args = self.args( + debug=False, dump_all=True, format=None, instance_data=None, + list_keys=False, user_data=None, vendor_data=None, varname=None) + run_dir = self.tmp_path('run_dir', dir=self.tmp) + ensure_dir(run_dir) + paths = Paths({'run_dir': run_dir}) + self.add_patch('cloudinit.cmd.query.read_cfg_paths', 'm_paths') + self.m_paths.return_value = paths + with mock.patch('sys.stderr', new_callable=StringIO) as m_stderr: + with mock.patch('os.getuid') as m_getuid: + m_getuid.return_value = 0 + self.assertEqual(1, query.handle_args('anyname', args)) + json_file = os.path.join(run_dir, INSTANCE_JSON_FILE) + sensitive_file = os.path.join(run_dir, INSTANCE_JSON_SENSITIVE_FILE) + self.assertIn( + 'WARNING: Missing root-readable %s. Using redacted %s instead.' % ( + sensitive_file, json_file), + m_stderr.getvalue()) + + def test_handle_args_root_uses_instance_sensitive_data(self): + """When no instance_data argument, root uses semsitive json.""" + user_data = self.tmp_path('user-data', dir=self.tmp) + vendor_data = self.tmp_path('vendor-data', dir=self.tmp) + write_file(user_data, 'ud') + write_file(vendor_data, 'vd') + run_dir = self.tmp_path('run_dir', dir=self.tmp) + sensitive_file = os.path.join(run_dir, INSTANCE_JSON_SENSITIVE_FILE) + write_file(sensitive_file, '{"my-var": "it worked"}') + ensure_dir(run_dir) + paths = Paths({'run_dir': run_dir}) + self.add_patch('cloudinit.cmd.query.read_cfg_paths', 'm_paths') + self.m_paths.return_value = paths + args = self.args( + debug=False, dump_all=True, format=None, instance_data=None, + list_keys=False, user_data=vendor_data, vendor_data=vendor_data, + varname=None) + with mock.patch('sys.stdout', new_callable=StringIO) as m_stdout: + with mock.patch('os.getuid') as m_getuid: + m_getuid.return_value = 0 + self.assertEqual(0, query.handle_args('anyname', args)) + self.assertEqual( + '{\n "my_var": "it worked",\n "userdata": "vd",\n ' + '"vendordata": "vd"\n}\n', m_stdout.getvalue()) + def test_handle_args_dumps_all_instance_data(self): """When --all is specified query will dump all instance data vars.""" write_file(self.instance_data, '{"my-var": "it worked"}') diff --git a/cloudinit/config/cc_disk_setup.py b/cloudinit/config/cc_disk_setup.py index 943089e0..29e192e8 100644 --- a/cloudinit/config/cc_disk_setup.py +++ b/cloudinit/config/cc_disk_setup.py @@ -743,7 +743,7 @@ def assert_and_settle_device(device): util.udevadm_settle() if not os.path.exists(device): raise RuntimeError("Device %s did not exist and was not created " - "with a udevamd settle." % device) + "with a udevadm settle." % device) # Whether or not the device existed above, it is possible that udev # events that would populate udev database (for reading by lsdname) have diff --git a/cloudinit/config/cc_resizefs.py b/cloudinit/config/cc_resizefs.py index 2edddd0c..076b9d5a 100644 --- a/cloudinit/config/cc_resizefs.py +++ b/cloudinit/config/cc_resizefs.py @@ -197,6 +197,13 @@ def maybe_get_writable_device_path(devpath, info, log): if devpath.startswith('gpt/'): log.debug('We have a gpt label - just go ahead') return devpath + # Alternatively, our device could simply be a name as returned by gpart, + # such as da0p3 + if not devpath.startswith('/dev/') and not os.path.exists(devpath): + fulldevpath = '/dev/' + devpath.lstrip('/') + log.debug("'%s' doesn't appear to be a valid device path. Trying '%s'", + devpath, fulldevpath) + devpath = fulldevpath try: statret = os.stat(devpath) diff --git a/cloudinit/config/cc_write_files.py b/cloudinit/config/cc_write_files.py index 31d1db61..0b6546e2 100644 --- a/cloudinit/config/cc_write_files.py +++ b/cloudinit/config/cc_write_files.py @@ -49,6 +49,10 @@ binary gzip data can be specified and will be decoded before being written. ... path: /bin/arch permissions: '0555' + - content: | + 15 * * * * root ship_logs + path: /etc/crontab + append: true """ import base64 @@ -113,7 +117,8 @@ def write_files(name, files): contents = extract_contents(f_info.get('content', ''), extractions) (u, g) = util.extract_usergroup(f_info.get('owner', DEFAULT_OWNER)) perms = decode_perms(f_info.get('permissions'), DEFAULT_PERMS) - util.write_file(path, contents, mode=perms) + omode = 'ab' if util.get_cfg_option_bool(f_info, 'append') else 'wb' + util.write_file(path, contents, omode=omode, mode=perms) util.chownbyname(path, u, g) diff --git a/cloudinit/dhclient_hook.py b/cloudinit/dhclient_hook.py index 7f02d7fa..72b51b6a 100644 --- a/cloudinit/dhclient_hook.py +++ b/cloudinit/dhclient_hook.py @@ -1,5 +1,8 @@ # This file is part of cloud-init. See LICENSE file for license information. +"""Run the dhclient hook to record network info.""" + +import argparse import os from cloudinit import atomic_helper @@ -8,44 +11,75 @@ from cloudinit import stages LOG = logging.getLogger(__name__) +NAME = "dhclient-hook" +UP = "up" +DOWN = "down" +EVENTS = (UP, DOWN) + + +def _get_hooks_dir(): + i = stages.Init() + return os.path.join(i.paths.get_runpath(), 'dhclient.hooks') + + +def _filter_env_vals(info): + """Given info (os.environ), return a dictionary with + lower case keys for each entry starting with DHCP4_ or new_.""" + new_info = {} + for k, v in info.items(): + if k.startswith("DHCP4_") or k.startswith("new_"): + key = (k.replace('DHCP4_', '').replace('new_', '')).lower() + new_info[key] = v + return new_info + + +def run_hook(interface, event, data_d=None, env=None): + if event not in EVENTS: + raise ValueError("Unexpected event '%s'. Expected one of: %s" % + (event, EVENTS)) + if data_d is None: + data_d = _get_hooks_dir() + if env is None: + env = os.environ + hook_file = os.path.join(data_d, interface + ".json") + + if event == UP: + if not os.path.exists(data_d): + os.makedirs(data_d) + atomic_helper.write_json(hook_file, _filter_env_vals(env)) + LOG.debug("Wrote dhclient options in %s", hook_file) + elif event == DOWN: + if os.path.exists(hook_file): + os.remove(hook_file) + LOG.debug("Removed dhclient options file %s", hook_file) + + +def get_parser(parser=None): + if parser is None: + parser = argparse.ArgumentParser(prog=NAME, description=__doc__) + parser.add_argument( + "event", help='event taken on the interface', choices=EVENTS) + parser.add_argument( + "interface", help='the network interface being acted upon') + # cloud-init main uses 'action' + parser.set_defaults(action=(NAME, handle_args)) + return parser + + +def handle_args(name, args, data_d=None): + """Handle the Namespace args. + Takes 'name' as passed by cloud-init main. not used here.""" + return run_hook(interface=args.interface, event=args.event, data_d=data_d) + + +if __name__ == '__main__': + import sys + parser = get_parser() + args = parser.parse_args(args=sys.argv[1:]) + return_value = handle_args( + NAME, args, data_d=os.environ.get('_CI_DHCP_HOOK_DATA_D')) + if return_value: + sys.exit(return_value) -class LogDhclient(object): - - def __init__(self, cli_args): - self.hooks_dir = self._get_hooks_dir() - self.net_interface = cli_args.net_interface - self.net_action = cli_args.net_action - self.hook_file = os.path.join(self.hooks_dir, - self.net_interface + ".json") - - @staticmethod - def _get_hooks_dir(): - i = stages.Init() - return os.path.join(i.paths.get_runpath(), 'dhclient.hooks') - - def check_hooks_dir(self): - if not os.path.exists(self.hooks_dir): - os.makedirs(self.hooks_dir) - else: - # If the action is down and the json file exists, we need to - # delete the file - if self.net_action is 'down' and os.path.exists(self.hook_file): - os.remove(self.hook_file) - - @staticmethod - def get_vals(info): - new_info = {} - for k, v in info.items(): - if k.startswith("DHCP4_") or k.startswith("new_"): - key = (k.replace('DHCP4_', '').replace('new_', '')).lower() - new_info[key] = v - return new_info - - def record(self): - envs = os.environ - if self.hook_file is None: - return - atomic_helper.write_json(self.hook_file, self.get_vals(envs)) - LOG.debug("Wrote dhclient options in %s", self.hook_file) # vi: ts=4 expandtab diff --git a/cloudinit/handlers/jinja_template.py b/cloudinit/handlers/jinja_template.py index 3fa4097e..ce3accf6 100644 --- a/cloudinit/handlers/jinja_template.py +++ b/cloudinit/handlers/jinja_template.py @@ -1,5 +1,6 @@ # This file is part of cloud-init. See LICENSE file for license information. +from errno import EACCES import os import re @@ -76,7 +77,14 @@ def render_jinja_payload_from_file( raise RuntimeError( 'Cannot render jinja template vars. Instance data not yet' ' present at %s' % instance_data_file) - instance_data = load_json(load_file(instance_data_file)) + try: + instance_data = load_json(load_file(instance_data_file)) + except (IOError, OSError) as e: + if e.errno == EACCES: + raise RuntimeError( + 'Cannot render jinja template vars. No read permission on' + " '%s'. Try sudo" % instance_data_file) + rendered_payload = render_jinja_payload( payload, payload_fn, instance_data, debug) if not rendered_payload: diff --git a/cloudinit/net/__init__.py b/cloudinit/net/__init__.py index f83d3681..3642fb1f 100644 --- a/cloudinit/net/__init__.py +++ b/cloudinit/net/__init__.py @@ -12,6 +12,7 @@ import re from cloudinit.net.network_state import mask_to_net_prefix from cloudinit import util +from cloudinit.url_helper import UrlError, readurl LOG = logging.getLogger(__name__) SYS_CLASS_NET = "/sys/class/net/" @@ -612,7 +613,8 @@ def get_interfaces(): Bridges and any devices that have a 'stolen' mac are excluded.""" ret = [] devs = get_devicelist() - empty_mac = '00:00:00:00:00:00' + # 16 somewhat arbitrarily chosen. Normally a mac is 6 '00:' tokens. + zero_mac = ':'.join(('00',) * 16) for name in devs: if not interface_has_own_mac(name): continue @@ -624,7 +626,8 @@ def get_interfaces(): # some devices may not have a mac (tun0) if not mac: continue - if mac == empty_mac and name != 'lo': + # skip nics that have no mac (00:00....) + if name != 'lo' and mac == zero_mac[:len(mac)]: continue ret.append((name, mac, device_driver(name), device_devid(name))) return ret @@ -645,16 +648,36 @@ def get_ib_hwaddrs_by_interface(): return ret +def has_url_connectivity(url): + """Return true when the instance has access to the provided URL + + Logs a warning if url is not the expected format. + """ + if not any([url.startswith('http://'), url.startswith('https://')]): + LOG.warning( + "Ignoring connectivity check. Expected URL beginning with http*://" + " received '%s'", url) + return False + try: + readurl(url, timeout=5) + except UrlError: + return False + return True + + class EphemeralIPv4Network(object): """Context manager which sets up temporary static network configuration. - No operations are performed if the provided interface is already connected. + No operations are performed if the provided interface already has the + specified configuration. + This can be verified with the connectivity_url. If unconnected, bring up the interface with valid ip, prefix and broadcast. If router is provided setup a default route for that interface. Upon context exit, clean up the interface leaving no configuration behind. """ - def __init__(self, interface, ip, prefix_or_mask, broadcast, router=None): + def __init__(self, interface, ip, prefix_or_mask, broadcast, router=None, + connectivity_url=None): """Setup context manager and validate call signature. @param interface: Name of the network interface to bring up. @@ -663,6 +686,8 @@ class EphemeralIPv4Network(object): prefix. @param broadcast: Broadcast address for the IPv4 network. @param router: Optionally the default gateway IP. + @param connectivity_url: Optionally, a URL to verify if a usable + connection already exists. """ if not all([interface, ip, prefix_or_mask, broadcast]): raise ValueError( @@ -673,6 +698,8 @@ class EphemeralIPv4Network(object): except ValueError as e: raise ValueError( 'Cannot setup network: {0}'.format(e)) + + self.connectivity_url = connectivity_url self.interface = interface self.ip = ip self.broadcast = broadcast @@ -681,6 +708,13 @@ class EphemeralIPv4Network(object): def __enter__(self): """Perform ephemeral network setup if interface is not connected.""" + if self.connectivity_url: + if has_url_connectivity(self.connectivity_url): + LOG.debug( + 'Skip ephemeral network setup, instance has connectivity' + ' to %s', self.connectivity_url) + return + self._bringup_device() if self.router: self._bringup_router() diff --git a/cloudinit/net/dhcp.py b/cloudinit/net/dhcp.py index 12cf5097..0db991db 100644 --- a/cloudinit/net/dhcp.py +++ b/cloudinit/net/dhcp.py @@ -11,7 +11,8 @@ import re import signal from cloudinit.net import ( - EphemeralIPv4Network, find_fallback_nic, get_devicelist) + EphemeralIPv4Network, find_fallback_nic, get_devicelist, + has_url_connectivity) from cloudinit.net.network_state import mask_and_ipv4_to_bcast_addr as bcip from cloudinit import temp_utils from cloudinit import util @@ -37,37 +38,69 @@ class NoDHCPLeaseError(Exception): class EphemeralDHCPv4(object): - def __init__(self, iface=None): + def __init__(self, iface=None, connectivity_url=None): self.iface = iface self._ephipv4 = None + self.lease = None + self.connectivity_url = connectivity_url def __enter__(self): + """Setup sandboxed dhcp context, unless connectivity_url can already be + reached.""" + if self.connectivity_url: + if has_url_connectivity(self.connectivity_url): + LOG.debug( + 'Skip ephemeral DHCP setup, instance has connectivity' + ' to %s', self.connectivity_url) + return + return self.obtain_lease() + + def __exit__(self, excp_type, excp_value, excp_traceback): + """Teardown sandboxed dhcp context.""" + self.clean_network() + + def clean_network(self): + """Exit _ephipv4 context to teardown of ip configuration performed.""" + if self.lease: + self.lease = None + if not self._ephipv4: + return + self._ephipv4.__exit__(None, None, None) + + def obtain_lease(self): + """Perform dhcp discovery in a sandboxed environment if possible. + + @return: A dict representing dhcp options on the most recent lease + obtained from the dhclient discovery if run, otherwise an error + is raised. + + @raises: NoDHCPLeaseError if no leases could be obtained. + """ + if self.lease: + return self.lease try: leases = maybe_perform_dhcp_discovery(self.iface) except InvalidDHCPLeaseFileError: raise NoDHCPLeaseError() if not leases: raise NoDHCPLeaseError() - lease = leases[-1] + self.lease = leases[-1] LOG.debug("Received dhcp lease on %s for %s/%s", - lease['interface'], lease['fixed-address'], - lease['subnet-mask']) + self.lease['interface'], self.lease['fixed-address'], + self.lease['subnet-mask']) nmap = {'interface': 'interface', 'ip': 'fixed-address', 'prefix_or_mask': 'subnet-mask', 'broadcast': 'broadcast-address', 'router': 'routers'} - kwargs = dict([(k, lease.get(v)) for k, v in nmap.items()]) + kwargs = dict([(k, self.lease.get(v)) for k, v in nmap.items()]) if not kwargs['broadcast']: kwargs['broadcast'] = bcip(kwargs['prefix_or_mask'], kwargs['ip']) + if self.connectivity_url: + kwargs['connectivity_url'] = self.connectivity_url ephipv4 = EphemeralIPv4Network(**kwargs) ephipv4.__enter__() self._ephipv4 = ephipv4 - return lease - - def __exit__(self, excp_type, excp_value, excp_traceback): - if not self._ephipv4: - return - self._ephipv4.__exit__(excp_type, excp_value, excp_traceback) + return self.lease def maybe_perform_dhcp_discovery(nic=None): diff --git a/cloudinit/net/eni.py b/cloudinit/net/eni.py index c6f631a9..64236320 100644 --- a/cloudinit/net/eni.py +++ b/cloudinit/net/eni.py @@ -371,22 +371,23 @@ class Renderer(renderer.Renderer): 'gateway': 'gw', 'metric': 'metric', } + + default_gw = '' if route['network'] == '0.0.0.0' and route['netmask'] == '0.0.0.0': - default_gw = " default gw %s" % route['gateway'] - content.append(up + default_gw + or_true) - content.append(down + default_gw + or_true) + default_gw = ' default' elif route['network'] == '::' and route['prefix'] == 0: - # ipv6! - default_gw = " -A inet6 default gw %s" % route['gateway'] - content.append(up + default_gw + or_true) - content.append(down + default_gw + or_true) - else: - route_line = "" - for k in ['network', 'netmask', 'gateway', 'metric']: - if k in route: - route_line += " %s %s" % (mapping[k], route[k]) - content.append(up + route_line + or_true) - content.append(down + route_line + or_true) + default_gw = ' -A inet6 default' + + route_line = '' + for k in ['network', 'netmask', 'gateway', 'metric']: + if default_gw and k in ['network', 'netmask']: + continue + if k == 'gateway': + route_line += '%s %s %s' % (default_gw, mapping[k], route[k]) + elif k in route: + route_line += ' %s %s' % (mapping[k], route[k]) + content.append(up + route_line + or_true) + content.append(down + route_line + or_true) return content def _render_iface(self, iface, render_hwaddress=False): diff --git a/cloudinit/net/netplan.py b/cloudinit/net/netplan.py index bc1087f9..21517fda 100644 --- a/cloudinit/net/netplan.py +++ b/cloudinit/net/netplan.py @@ -114,13 +114,13 @@ def _extract_addresses(config, entry, ifname): for route in subnet.get('routes', []): to_net = "%s/%s" % (route.get('network'), route.get('prefix')) - route = { + new_route = { 'via': route.get('gateway'), 'to': to_net, } if 'metric' in route: - route.update({'metric': route.get('metric', 100)}) - routes.append(route) + new_route.update({'metric': route.get('metric', 100)}) + routes.append(new_route) addresses.append(addr) diff --git a/cloudinit/net/sysconfig.py b/cloudinit/net/sysconfig.py index 9c16d3a7..17293e1d 100644 --- a/cloudinit/net/sysconfig.py +++ b/cloudinit/net/sysconfig.py @@ -156,13 +156,23 @@ class Route(ConfigMap): _quote_value(gateway_value))) buf.write("%s=%s\n" % ('NETMASK' + str(reindex), _quote_value(netmask_value))) + metric_key = 'METRIC' + index + if metric_key in self._conf: + metric_value = str(self._conf['METRIC' + index]) + buf.write("%s=%s\n" % ('METRIC' + str(reindex), + _quote_value(metric_value))) elif proto == "ipv6" and self.is_ipv6_route(address_value): netmask_value = str(self._conf['NETMASK' + index]) gateway_value = str(self._conf['GATEWAY' + index]) - buf.write("%s/%s via %s dev %s\n" % (address_value, - netmask_value, - gateway_value, - self._route_name)) + metric_value = ( + 'metric ' + str(self._conf['METRIC' + index]) + if 'METRIC' + index in self._conf else '') + buf.write( + "%s/%s via %s %s dev %s\n" % (address_value, + netmask_value, + gateway_value, + metric_value, + self._route_name)) return buf.getvalue() @@ -370,6 +380,9 @@ class Renderer(renderer.Renderer): else: iface_cfg['GATEWAY'] = subnet['gateway'] + if 'metric' in subnet: + iface_cfg['METRIC'] = subnet['metric'] + if 'dns_search' in subnet: iface_cfg['DOMAIN'] = ' '.join(subnet['dns_search']) @@ -414,15 +427,19 @@ class Renderer(renderer.Renderer): else: iface_cfg['GATEWAY'] = route['gateway'] route_cfg.has_set_default_ipv4 = True + if 'metric' in route: + iface_cfg['METRIC'] = route['metric'] else: gw_key = 'GATEWAY%s' % route_cfg.last_idx nm_key = 'NETMASK%s' % route_cfg.last_idx addr_key = 'ADDRESS%s' % route_cfg.last_idx + metric_key = 'METRIC%s' % route_cfg.last_idx route_cfg.last_idx += 1 # add default routes only to ifcfg files, not # to route-* or route6-* for (old_key, new_key) in [('gateway', gw_key), + ('metric', metric_key), ('netmask', nm_key), ('network', addr_key)]: if old_key in route: diff --git a/cloudinit/net/tests/test_dhcp.py b/cloudinit/net/tests/test_dhcp.py index db25b6f2..cd3e7328 100644 --- a/cloudinit/net/tests/test_dhcp.py +++ b/cloudinit/net/tests/test_dhcp.py @@ -1,15 +1,17 @@ # This file is part of cloud-init. See LICENSE file for license information. +import httpretty import os import signal from textwrap import dedent +import cloudinit.net as net from cloudinit.net.dhcp import ( InvalidDHCPLeaseFileError, maybe_perform_dhcp_discovery, parse_dhcp_lease_file, dhcp_discovery, networkd_load_leases) from cloudinit.util import ensure_file, write_file from cloudinit.tests.helpers import ( - CiTestCase, mock, populate_dir, wrap_and_call) + CiTestCase, HttprettyTestCase, mock, populate_dir, wrap_and_call) class TestParseDHCPLeasesFile(CiTestCase): @@ -321,3 +323,35 @@ class TestSystemdParseLeases(CiTestCase): '9': self.lxd_lease}) self.assertEqual({'1': self.azure_parsed, '9': self.lxd_parsed}, networkd_load_leases(self.lease_d)) + + +class TestEphemeralDhcpNoNetworkSetup(HttprettyTestCase): + + @mock.patch('cloudinit.net.dhcp.maybe_perform_dhcp_discovery') + def test_ephemeral_dhcp_no_network_if_url_connectivity(self, m_dhcp): + """No EphemeralDhcp4 network setup when connectivity_url succeeds.""" + url = 'http://example.org/index.html' + + httpretty.register_uri(httpretty.GET, url) + with net.dhcp.EphemeralDHCPv4(connectivity_url=url) as lease: + self.assertIsNone(lease) + # Ensure that no teardown happens: + m_dhcp.assert_not_called() + + @mock.patch('cloudinit.net.dhcp.util.subp') + @mock.patch('cloudinit.net.dhcp.maybe_perform_dhcp_discovery') + def test_ephemeral_dhcp_setup_network_if_url_connectivity( + self, m_dhcp, m_subp): + """No EphemeralDhcp4 network setup when connectivity_url succeeds.""" + url = 'http://example.org/index.html' + fake_lease = { + 'interface': 'eth9', 'fixed-address': '192.168.2.2', + 'subnet-mask': '255.255.0.0'} + m_dhcp.return_value = [fake_lease] + m_subp.return_value = ('', '') + + httpretty.register_uri(httpretty.GET, url, body={}, status=404) + with net.dhcp.EphemeralDHCPv4(connectivity_url=url) as lease: + self.assertEqual(fake_lease, lease) + # Ensure that dhcp discovery occurs + m_dhcp.called_once_with() diff --git a/cloudinit/net/tests/test_init.py b/cloudinit/net/tests/test_init.py index 58e0a591..f55c31e8 100644 --- a/cloudinit/net/tests/test_init.py +++ b/cloudinit/net/tests/test_init.py @@ -2,14 +2,16 @@ import copy import errno +import httpretty import mock import os +import requests import textwrap import yaml import cloudinit.net as net from cloudinit.util import ensure_file, write_file, ProcessExecutionError -from cloudinit.tests.helpers import CiTestCase +from cloudinit.tests.helpers import CiTestCase, HttprettyTestCase class TestSysDevPath(CiTestCase): @@ -458,6 +460,22 @@ class TestEphemeralIPV4Network(CiTestCase): self.assertEqual(expected_setup_calls, m_subp.call_args_list) m_subp.assert_has_calls(expected_teardown_calls) + @mock.patch('cloudinit.net.readurl') + def test_ephemeral_ipv4_no_network_if_url_connectivity( + self, m_readurl, m_subp): + """No network setup is performed if we can successfully connect to + connectivity_url.""" + params = { + 'interface': 'eth0', 'ip': '192.168.2.2', + 'prefix_or_mask': '255.255.255.0', 'broadcast': '192.168.2.255', + 'connectivity_url': 'http://example.org/index.html'} + + with net.EphemeralIPv4Network(**params): + self.assertEqual([mock.call('http://example.org/index.html', + timeout=5)], m_readurl.call_args_list) + # Ensure that no teardown happens: + m_subp.assert_has_calls([]) + def test_ephemeral_ipv4_network_noop_when_configured(self, m_subp): """EphemeralIPv4Network handles exception when address is setup. @@ -619,3 +637,35 @@ class TestApplyNetworkCfgNames(CiTestCase): def test_apply_v2_renames_raises_runtime_error_on_unknown_version(self): with self.assertRaises(RuntimeError): net.apply_network_config_names(yaml.load("version: 3")) + + +class TestHasURLConnectivity(HttprettyTestCase): + + def setUp(self): + super(TestHasURLConnectivity, self).setUp() + self.url = 'http://fake/' + self.kwargs = {'allow_redirects': True, 'timeout': 5.0} + + @mock.patch('cloudinit.net.readurl') + def test_url_timeout_on_connectivity_check(self, m_readurl): + """A timeout of 5 seconds is provided when reading a url.""" + self.assertTrue( + net.has_url_connectivity(self.url), 'Expected True on url connect') + + def test_true_on_url_connectivity_success(self): + httpretty.register_uri(httpretty.GET, self.url) + self.assertTrue( + net.has_url_connectivity(self.url), 'Expected True on url connect') + + @mock.patch('requests.Session.request') + def test_true_on_url_connectivity_timeout(self, m_request): + """A timeout raised accessing the url will return False.""" + m_request.side_effect = requests.Timeout('Fake Connection Timeout') + self.assertFalse( + net.has_url_connectivity(self.url), + 'Expected False on url timeout') + + def test_true_on_url_connectivity_failure(self): + httpretty.register_uri(httpretty.GET, self.url, body={}, status=404) + self.assertFalse( + net.has_url_connectivity(self.url), 'Expected False on url fail') diff --git a/cloudinit/sources/DataSourceAliYun.py b/cloudinit/sources/DataSourceAliYun.py index 858e0827..45cc9f00 100644 --- a/cloudinit/sources/DataSourceAliYun.py +++ b/cloudinit/sources/DataSourceAliYun.py @@ -1,7 +1,5 @@ # This file is part of cloud-init. See LICENSE file for license information. -import os - from cloudinit import sources from cloudinit.sources import DataSourceEc2 as EC2 from cloudinit import util @@ -18,25 +16,17 @@ class DataSourceAliYun(EC2.DataSourceEc2): min_metadata_version = '2016-01-01' extended_metadata_versions = [] - def __init__(self, sys_cfg, distro, paths): - super(DataSourceAliYun, self).__init__(sys_cfg, distro, paths) - self.seed_dir = os.path.join(paths.seed_dir, "AliYun") - def get_hostname(self, fqdn=False, resolve_ip=False, metadata_only=False): return self.metadata.get('hostname', 'localhost.localdomain') def get_public_ssh_keys(self): return parse_public_keys(self.metadata.get('public-keys', {})) - @property - def cloud_platform(self): - if self._cloud_platform is None: - if _is_aliyun(): - self._cloud_platform = EC2.Platforms.ALIYUN - else: - self._cloud_platform = EC2.Platforms.NO_EC2_METADATA - - return self._cloud_platform + def _get_cloud_name(self): + if _is_aliyun(): + return EC2.CloudNames.ALIYUN + else: + return EC2.CloudNames.NO_EC2_METADATA def _is_aliyun(): diff --git a/cloudinit/sources/DataSourceAltCloud.py b/cloudinit/sources/DataSourceAltCloud.py index 8cd312d0..5270fda8 100644 --- a/cloudinit/sources/DataSourceAltCloud.py +++ b/cloudinit/sources/DataSourceAltCloud.py @@ -89,7 +89,9 @@ class DataSourceAltCloud(sources.DataSource): ''' Description: Get the type for the cloud back end this instance is running on - by examining the string returned by reading the dmi data. + by examining the string returned by reading either: + CLOUD_INFO_FILE or + the dmi data. Input: None @@ -99,7 +101,14 @@ class DataSourceAltCloud(sources.DataSource): 'RHEV', 'VSPHERE' or 'UNKNOWN' ''' - + if os.path.exists(CLOUD_INFO_FILE): + try: + cloud_type = util.load_file(CLOUD_INFO_FILE).strip().upper() + except IOError: + util.logexc(LOG, 'Unable to access cloud info file at %s.', + CLOUD_INFO_FILE) + return 'UNKNOWN' + return cloud_type system_name = util.read_dmi_data("system-product-name") if not system_name: return 'UNKNOWN' @@ -134,15 +143,7 @@ class DataSourceAltCloud(sources.DataSource): LOG.debug('Invoked get_data()') - if os.path.exists(CLOUD_INFO_FILE): - try: - cloud_type = util.load_file(CLOUD_INFO_FILE).strip().upper() - except IOError: - util.logexc(LOG, 'Unable to access cloud info file at %s.', - CLOUD_INFO_FILE) - return False - else: - cloud_type = self.get_cloud_type() + cloud_type = self.get_cloud_type() LOG.debug('cloud_type: %s', str(cloud_type)) @@ -161,6 +162,15 @@ class DataSourceAltCloud(sources.DataSource): util.logexc(LOG, 'Failed accessing user data.') return False + def _get_subplatform(self): + """Return the subplatform metadata details.""" + cloud_type = self.get_cloud_type() + if not hasattr(self, 'source'): + self.source = sources.METADATA_UNKNOWN + if cloud_type == 'RHEV': + self.source = '/dev/fd0' + return '%s (%s)' % (cloud_type.lower(), self.source) + def user_data_rhevm(self): ''' RHEVM specific userdata read @@ -232,6 +242,7 @@ class DataSourceAltCloud(sources.DataSource): try: return_str = util.mount_cb(cdrom_dev, read_user_data_callback) if return_str: + self.source = cdrom_dev break except OSError as err: if err.errno != errno.ENOENT: diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 629f006f..a06e6e1f 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -23,7 +23,8 @@ from cloudinit.event import EventType from cloudinit.net.dhcp import EphemeralDHCPv4 from cloudinit import sources from cloudinit.sources.helpers.azure import get_metadata_from_fabric -from cloudinit.url_helper import readurl, UrlError +from cloudinit.sources.helpers import netlink +from cloudinit.url_helper import UrlError, readurl, retry_on_url_exc from cloudinit import util LOG = logging.getLogger(__name__) @@ -58,7 +59,7 @@ IMDS_URL = "http://169.254.169.254/metadata/" # List of static scripts and network config artifacts created by # stock ubuntu suported images. UBUNTU_EXTENDED_NETWORK_SCRIPTS = [ - '/etc/netplan/90-azure-hotplug.yaml', + '/etc/netplan/90-hotplug-azure.yaml', '/usr/local/sbin/ephemeral_eth.sh', '/etc/udev/rules.d/10-net-device-added.rules', '/run/network/interfaces.ephemeral.d', @@ -208,7 +209,9 @@ BUILTIN_DS_CONFIG = { }, 'disk_aliases': {'ephemeral0': RESOURCE_DISK_PATH}, 'dhclient_lease_file': LEASE_FILE, + 'apply_network_config': True, # Use IMDS published network configuration } +# RELEASE_BLOCKER: Xenial and earlier apply_network_config default is False BUILTIN_CLOUD_CONFIG = { 'disk_setup': { @@ -284,6 +287,7 @@ class DataSourceAzure(sources.DataSource): self._network_config = None # Regenerate network config new_instance boot and every boot self.update_events['network'].add(EventType.BOOT) + self._ephemeral_dhcp_ctx = None def __str__(self): root = sources.DataSource.__str__(self) @@ -357,6 +361,14 @@ class DataSourceAzure(sources.DataSource): metadata['public-keys'] = key_value or pubkeys_from_crt_files(fp_files) return metadata + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + if self.seed.startswith('/dev'): + subplatform_type = 'config-disk' + else: + subplatform_type = 'seed-dir' + return '%s (%s)' % (subplatform_type, self.seed) + def crawl_metadata(self): """Walk all instance metadata sources returning a dict on success. @@ -402,7 +414,12 @@ class DataSourceAzure(sources.DataSource): LOG.warning("%s was not mountable", cdev) continue - if reprovision or self._should_reprovision(ret): + perform_reprovision = reprovision or self._should_reprovision(ret) + if perform_reprovision: + if util.is_FreeBSD(): + msg = "Free BSD is not supported for PPS VMs" + LOG.error(msg) + raise sources.InvalidMetaDataException(msg) ret = self._reprovision() imds_md = get_metadata_from_imds( self.fallback_interface, retries=3) @@ -430,6 +447,18 @@ class DataSourceAzure(sources.DataSource): crawled_data['metadata']['random_seed'] = seed crawled_data['metadata']['instance-id'] = util.read_dmi_data( 'system-uuid') + + if perform_reprovision: + LOG.info("Reporting ready to Azure after getting ReprovisionData") + use_cached_ephemeral = (net.is_up(self.fallback_interface) and + getattr(self, '_ephemeral_dhcp_ctx', None)) + if use_cached_ephemeral: + self._report_ready(lease=self._ephemeral_dhcp_ctx.lease) + self._ephemeral_dhcp_ctx.clean_network() # Teardown ephemeral + else: + with EphemeralDHCPv4() as lease: + self._report_ready(lease=lease) + return crawled_data def _is_platform_viable(self): @@ -456,7 +485,8 @@ class DataSourceAzure(sources.DataSource): except sources.InvalidMetaDataException as e: LOG.warning('Could not crawl Azure metadata: %s', e) return False - if self.distro and self.distro.name == 'ubuntu': + if (self.distro and self.distro.name == 'ubuntu' and + self.ds_cfg.get('apply_network_config')): maybe_remove_ubuntu_network_config_scripts() # Process crawled data and augment with various config defaults @@ -504,8 +534,8 @@ class DataSourceAzure(sources.DataSource): response. Then return the returned JSON object.""" url = IMDS_URL + "reprovisiondata?api-version=2017-04-02" headers = {"Metadata": "true"} + nl_sock = None report_ready = bool(not os.path.isfile(REPORTED_READY_MARKER_FILE)) - LOG.debug("Start polling IMDS") def exc_cb(msg, exception): if isinstance(exception, UrlError) and exception.code == 404: @@ -514,25 +544,47 @@ class DataSourceAzure(sources.DataSource): # call DHCP and setup the ephemeral network to acquire the new IP. return False + LOG.debug("Wait for vnetswitch to happen") while True: try: - with EphemeralDHCPv4() as lease: - if report_ready: - path = REPORTED_READY_MARKER_FILE - LOG.info( - "Creating a marker file to report ready: %s", path) - util.write_file(path, "{pid}: {time}\n".format( - pid=os.getpid(), time=time())) - self._report_ready(lease=lease) - report_ready = False + # Save our EphemeralDHCPv4 context so we avoid repeated dhcp + self._ephemeral_dhcp_ctx = EphemeralDHCPv4() + lease = self._ephemeral_dhcp_ctx.obtain_lease() + if report_ready: + try: + nl_sock = netlink.create_bound_netlink_socket() + except netlink.NetlinkCreateSocketError as e: + LOG.warning(e) + self._ephemeral_dhcp_ctx.clean_network() + return + path = REPORTED_READY_MARKER_FILE + LOG.info( + "Creating a marker file to report ready: %s", path) + util.write_file(path, "{pid}: {time}\n".format( + pid=os.getpid(), time=time())) + self._report_ready(lease=lease) + report_ready = False + try: + netlink.wait_for_media_disconnect_connect( + nl_sock, lease['interface']) + except AssertionError as error: + LOG.error(error) + return + self._ephemeral_dhcp_ctx.clean_network() + else: return readurl(url, timeout=1, headers=headers, - exception_cb=exc_cb, infinite=True).contents + exception_cb=exc_cb, infinite=True, + log_req_resp=False).contents except UrlError: + # Teardown our EphemeralDHCPv4 context on failure as we retry + self._ephemeral_dhcp_ctx.clean_network() pass + finally: + if nl_sock: + nl_sock.close() def _report_ready(self, lease): - """Tells the fabric provisioning has completed - before we go into our polling loop.""" + """Tells the fabric provisioning has completed """ try: get_metadata_from_fabric(None, lease['unknown-245']) except Exception: @@ -617,7 +669,11 @@ class DataSourceAzure(sources.DataSource): the blacklisted devices. """ if not self._network_config: - self._network_config = parse_network_config(self._metadata_imds) + if self.ds_cfg.get('apply_network_config'): + nc_src = self._metadata_imds + else: + nc_src = None + self._network_config = parse_network_config(nc_src) return self._network_config @@ -698,7 +754,7 @@ def can_dev_be_reformatted(devpath, preserve_ntfs): file_count = util.mount_cb(cand_path, count_files, mtype="ntfs", update_env_for_mount={'LANG': 'C'}) except util.MountFailedError as e: - if "mount: unknown filesystem type 'ntfs'" in str(e): + if "unknown filesystem type 'ntfs'" in str(e): return True, (bmsg + ' but this system cannot mount NTFS,' ' assuming there are no important files.' ' Formatting allowed.') @@ -926,12 +982,12 @@ def read_azure_ovf(contents): lambda n: n.localName == "LinuxProvisioningConfigurationSet") - if len(results) == 0: + if len(lpcs_nodes) == 0: raise NonAzureDataSource("No LinuxProvisioningConfigurationSet") - if len(results) > 1: + if len(lpcs_nodes) > 1: raise BrokenAzureDataSource("found '%d' %ss" % ("LinuxProvisioningConfigurationSet", - len(results))) + len(lpcs_nodes))) lpcs = lpcs_nodes[0] if not lpcs.hasChildNodes(): @@ -1160,17 +1216,12 @@ def get_metadata_from_imds(fallback_nic, retries): def _get_metadata_from_imds(retries): - def retry_on_url_error(msg, exception): - if isinstance(exception, UrlError) and exception.code == 404: - return True # Continue retries - return False # Stop retries on all other exceptions - url = IMDS_URL + "instance?api-version=2017-12-01" headers = {"Metadata": "true"} try: response = readurl( url, timeout=1, headers=headers, retries=retries, - exception_cb=retry_on_url_error) + exception_cb=retry_on_url_exc) except Exception as e: LOG.debug('Ignoring IMDS instance metadata: %s', e) return {} @@ -1193,7 +1244,7 @@ def maybe_remove_ubuntu_network_config_scripts(paths=None): additional interfaces which get attached by a customer at some point after initial boot. Since the Azure datasource can now regenerate network configuration as metadata reports these new devices, we no longer - want the udev rules or netplan's 90-azure-hotplug.yaml to configure + want the udev rules or netplan's 90-hotplug-azure.yaml to configure networking on eth1 or greater as it might collide with cloud-init's configuration. diff --git a/cloudinit/sources/DataSourceBigstep.py b/cloudinit/sources/DataSourceBigstep.py index 699a85b5..52fff20a 100644 --- a/cloudinit/sources/DataSourceBigstep.py +++ b/cloudinit/sources/DataSourceBigstep.py @@ -36,6 +36,10 @@ class DataSourceBigstep(sources.DataSource): self.userdata_raw = decoded["userdata_raw"] return True + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + return 'metadata (%s)' % get_url_from_file() + def get_url_from_file(): try: diff --git a/cloudinit/sources/DataSourceCloudSigma.py b/cloudinit/sources/DataSourceCloudSigma.py index c816f349..2955d3f0 100644 --- a/cloudinit/sources/DataSourceCloudSigma.py +++ b/cloudinit/sources/DataSourceCloudSigma.py @@ -7,7 +7,7 @@ from base64 import b64decode import re -from cloudinit.cs_utils import Cepko +from cloudinit.cs_utils import Cepko, SERIAL_PORT from cloudinit import log as logging from cloudinit import sources @@ -84,6 +84,10 @@ class DataSourceCloudSigma(sources.DataSource): return True + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + return 'cepko (%s)' % SERIAL_PORT + def get_hostname(self, fqdn=False, resolve_ip=False, metadata_only=False): """ Cleans up and uses the server's name if the latter is set. Otherwise diff --git a/cloudinit/sources/DataSourceConfigDrive.py b/cloudinit/sources/DataSourceConfigDrive.py index 664dc4b7..564e3eb3 100644 --- a/cloudinit/sources/DataSourceConfigDrive.py +++ b/cloudinit/sources/DataSourceConfigDrive.py @@ -160,6 +160,18 @@ class DataSourceConfigDrive(openstack.SourceMixin, sources.DataSource): LOG.debug("no network configuration available") return self._network_config + @property + def platform(self): + return 'openstack' + + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + if self.seed_dir in self.source: + subplatform_type = 'seed-dir' + elif self.source.startswith('/dev'): + subplatform_type = 'config-disk' + return '%s (%s)' % (subplatform_type, self.source) + def read_config_drive(source_dir): reader = openstack.ConfigDriveReader(source_dir) diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index 98ea7bbc..b49a08db 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -30,18 +30,16 @@ STRICT_ID_DEFAULT = "warn" DEFAULT_PRIMARY_NIC = 'eth0' -class Platforms(object): - # TODO Rename and move to cloudinit.cloud.CloudNames - ALIYUN = "AliYun" - AWS = "AWS" - BRIGHTBOX = "Brightbox" - SEEDED = "Seeded" +class CloudNames(object): + ALIYUN = "aliyun" + AWS = "aws" + BRIGHTBOX = "brightbox" # UNKNOWN indicates no positive id. If strict_id is 'warn' or 'false', # then an attempt at the Ec2 Metadata service will be made. - UNKNOWN = "Unknown" + UNKNOWN = "unknown" # NO_EC2_METADATA indicates this platform does not have a Ec2 metadata # service available. No attempt at the Ec2 Metadata service will be made. - NO_EC2_METADATA = "No-EC2-Metadata" + NO_EC2_METADATA = "no-ec2-metadata" class DataSourceEc2(sources.DataSource): @@ -69,8 +67,6 @@ class DataSourceEc2(sources.DataSource): url_max_wait = 120 url_timeout = 50 - _cloud_platform = None - _network_config = sources.UNSET # Used to cache calculated network cfg v1 # Whether we want to get network configuration from the metadata service. @@ -79,30 +75,21 @@ class DataSourceEc2(sources.DataSource): def __init__(self, sys_cfg, distro, paths): super(DataSourceEc2, self).__init__(sys_cfg, distro, paths) self.metadata_address = None - self.seed_dir = os.path.join(paths.seed_dir, "ec2") def _get_cloud_name(self): """Return the cloud name as identified during _get_data.""" - return self.cloud_platform + return identify_platform() def _get_data(self): - seed_ret = {} - if util.read_optional_seed(seed_ret, base=(self.seed_dir + "/")): - self.userdata_raw = seed_ret['user-data'] - self.metadata = seed_ret['meta-data'] - LOG.debug("Using seeded ec2 data from %s", self.seed_dir) - self._cloud_platform = Platforms.SEEDED - return True - strict_mode, _sleep = read_strict_mode( util.get_cfg_by_path(self.sys_cfg, STRICT_ID_PATH, STRICT_ID_DEFAULT), ("warn", None)) - LOG.debug("strict_mode: %s, cloud_platform=%s", - strict_mode, self.cloud_platform) - if strict_mode == "true" and self.cloud_platform == Platforms.UNKNOWN: + LOG.debug("strict_mode: %s, cloud_name=%s cloud_platform=%s", + strict_mode, self.cloud_name, self.platform) + if strict_mode == "true" and self.cloud_name == CloudNames.UNKNOWN: return False - elif self.cloud_platform == Platforms.NO_EC2_METADATA: + elif self.cloud_name == CloudNames.NO_EC2_METADATA: return False if self.perform_dhcp_setup: # Setup networking in init-local stage. @@ -111,13 +98,22 @@ class DataSourceEc2(sources.DataSource): return False try: with EphemeralDHCPv4(self.fallback_interface): - return util.log_time( + self._crawled_metadata = util.log_time( logfunc=LOG.debug, msg='Crawl of metadata service', - func=self._crawl_metadata) + func=self.crawl_metadata) except NoDHCPLeaseError: return False else: - return self._crawl_metadata() + self._crawled_metadata = util.log_time( + logfunc=LOG.debug, msg='Crawl of metadata service', + func=self.crawl_metadata) + if not self._crawled_metadata: + return False + self.metadata = self._crawled_metadata.get('meta-data', None) + self.userdata_raw = self._crawled_metadata.get('user-data', None) + self.identity = self._crawled_metadata.get( + 'dynamic', {}).get('instance-identity', {}).get('document', {}) + return True @property def launch_index(self): @@ -125,6 +121,15 @@ class DataSourceEc2(sources.DataSource): return None return self.metadata.get('ami-launch-index') + @property + def platform(self): + # Handle upgrade path of pickled ds + if not hasattr(self, '_platform_type'): + self._platform_type = DataSourceEc2.dsname.lower() + if not self._platform_type: + self._platform_type = DataSourceEc2.dsname.lower() + return self._platform_type + def get_metadata_api_version(self): """Get the best supported api version from the metadata service. @@ -152,7 +157,7 @@ class DataSourceEc2(sources.DataSource): return self.min_metadata_version def get_instance_id(self): - if self.cloud_platform == Platforms.AWS: + if self.cloud_name == CloudNames.AWS: # Prefer the ID from the instance identity document, but fall back if not getattr(self, 'identity', None): # If re-using cached datasource, it's get_data run didn't @@ -262,7 +267,7 @@ class DataSourceEc2(sources.DataSource): @property def availability_zone(self): try: - if self.cloud_platform == Platforms.AWS: + if self.cloud_name == CloudNames.AWS: return self.identity.get( 'availabilityZone', self.metadata['placement']['availability-zone']) @@ -273,7 +278,7 @@ class DataSourceEc2(sources.DataSource): @property def region(self): - if self.cloud_platform == Platforms.AWS: + if self.cloud_name == CloudNames.AWS: region = self.identity.get('region') # Fallback to trimming the availability zone if region is missing if self.availability_zone and not region: @@ -285,16 +290,10 @@ class DataSourceEc2(sources.DataSource): return az[:-1] return None - @property - def cloud_platform(self): # TODO rename cloud_name - if self._cloud_platform is None: - self._cloud_platform = identify_platform() - return self._cloud_platform - def activate(self, cfg, is_new_instance): if not is_new_instance: return - if self.cloud_platform == Platforms.UNKNOWN: + if self.cloud_name == CloudNames.UNKNOWN: warn_if_necessary( util.get_cfg_by_path(cfg, STRICT_ID_PATH, STRICT_ID_DEFAULT), cfg) @@ -314,13 +313,13 @@ class DataSourceEc2(sources.DataSource): result = None no_network_metadata_on_aws = bool( 'network' not in self.metadata and - self.cloud_platform == Platforms.AWS) + self.cloud_name == CloudNames.AWS) if no_network_metadata_on_aws: LOG.debug("Metadata 'network' not present:" " Refreshing stale metadata from prior to upgrade.") util.log_time( logfunc=LOG.debug, msg='Re-crawl of metadata service', - func=self._crawl_metadata) + func=self.get_data) # Limit network configuration to only the primary/fallback nic iface = self.fallback_interface @@ -348,28 +347,32 @@ class DataSourceEc2(sources.DataSource): return super(DataSourceEc2, self).fallback_interface return self._fallback_interface - def _crawl_metadata(self): + def crawl_metadata(self): """Crawl metadata service when available. - @returns: True on success, False otherwise. + @returns: Dictionary of crawled metadata content containing the keys: + meta-data, user-data and dynamic. """ if not self.wait_for_metadata_service(): - return False + return {} api_version = self.get_metadata_api_version() + crawled_metadata = {} try: - self.userdata_raw = ec2.get_instance_userdata( + crawled_metadata['user-data'] = ec2.get_instance_userdata( api_version, self.metadata_address) - self.metadata = ec2.get_instance_metadata( + crawled_metadata['meta-data'] = ec2.get_instance_metadata( api_version, self.metadata_address) - if self.cloud_platform == Platforms.AWS: - self.identity = ec2.get_instance_identity( - api_version, self.metadata_address).get('document', {}) + if self.cloud_name == CloudNames.AWS: + identity = ec2.get_instance_identity( + api_version, self.metadata_address) + crawled_metadata['dynamic'] = {'instance-identity': identity} except Exception: util.logexc( LOG, "Failed reading from metadata address %s", self.metadata_address) - return False - return True + return {} + crawled_metadata['_metadata_api_version'] = api_version + return crawled_metadata class DataSourceEc2Local(DataSourceEc2): @@ -383,10 +386,10 @@ class DataSourceEc2Local(DataSourceEc2): perform_dhcp_setup = True # Use dhcp before querying metadata def get_data(self): - supported_platforms = (Platforms.AWS,) - if self.cloud_platform not in supported_platforms: + supported_platforms = (CloudNames.AWS,) + if self.cloud_name not in supported_platforms: LOG.debug("Local Ec2 mode only supported on %s, not %s", - supported_platforms, self.cloud_platform) + supported_platforms, self.cloud_name) return False return super(DataSourceEc2Local, self).get_data() @@ -447,20 +450,20 @@ def identify_aws(data): if (data['uuid'].startswith('ec2') and (data['uuid_source'] == 'hypervisor' or data['uuid'] == data['serial'])): - return Platforms.AWS + return CloudNames.AWS return None def identify_brightbox(data): if data['serial'].endswith('brightbox.com'): - return Platforms.BRIGHTBOX + return CloudNames.BRIGHTBOX def identify_platform(): - # identify the platform and return an entry in Platforms. + # identify the platform and return an entry in CloudNames. data = _collect_platform_data() - checks = (identify_aws, identify_brightbox, lambda x: Platforms.UNKNOWN) + checks = (identify_aws, identify_brightbox, lambda x: CloudNames.UNKNOWN) for checker in checks: try: result = checker(data) diff --git a/cloudinit/sources/DataSourceIBMCloud.py b/cloudinit/sources/DataSourceIBMCloud.py index a5358148..21e6ae6b 100644 --- a/cloudinit/sources/DataSourceIBMCloud.py +++ b/cloudinit/sources/DataSourceIBMCloud.py @@ -157,6 +157,10 @@ class DataSourceIBMCloud(sources.DataSource): return True + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + return '%s (%s)' % (self.platform, self.source) + def check_instance_id(self, sys_cfg): """quickly (local check only) if self.instance_id is still valid diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index bcb38544..61aa6d7e 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -109,6 +109,10 @@ class DataSourceMAAS(sources.DataSource): LOG.warning("Invalid content in vendor-data: %s", e) self.vendordata_raw = None + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + return 'seed-dir (%s)' % self.base_url + def wait_for_metadata_service(self, url): mcfg = self.ds_cfg max_wait = 120 diff --git a/cloudinit/sources/DataSourceNoCloud.py b/cloudinit/sources/DataSourceNoCloud.py index 2daea59d..6860f0cc 100644 --- a/cloudinit/sources/DataSourceNoCloud.py +++ b/cloudinit/sources/DataSourceNoCloud.py @@ -186,6 +186,27 @@ class DataSourceNoCloud(sources.DataSource): self._network_eni = mydata['meta-data'].get('network-interfaces') return True + @property + def platform_type(self): + # Handle upgrade path of pickled ds + if not hasattr(self, '_platform_type'): + self._platform_type = None + if not self._platform_type: + self._platform_type = 'lxd' if util.is_lxd() else 'nocloud' + return self._platform_type + + def _get_cloud_name(self): + """Return unknown when 'cloud-name' key is absent from metadata.""" + return sources.METADATA_UNKNOWN + + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + if self.seed.startswith('/dev'): + subplatform_type = 'config-disk' + else: + subplatform_type = 'seed-dir' + return '%s (%s)' % (subplatform_type, self.seed) + def check_instance_id(self, sys_cfg): # quickly (local check only) if self.instance_id is still valid # we check kernel command line or files. @@ -290,6 +311,35 @@ def parse_cmdline_data(ds_id, fill, cmdline=None): return True +def _maybe_remove_top_network(cfg): + """If network-config contains top level 'network' key, then remove it. + + Some providers of network configuration may provide a top level + 'network' key (LP: #1798117) even though it is not necessary. + + Be friendly and remove it if it really seems so. + + Return the original value if no change or the updated value if changed.""" + nullval = object() + network_val = cfg.get('network', nullval) + if network_val is nullval: + return cfg + bmsg = 'Top level network key in network-config %s: %s' + if not isinstance(network_val, dict): + LOG.debug(bmsg, "was not a dict", cfg) + return cfg + if len(list(cfg.keys())) != 1: + LOG.debug(bmsg, "had multiple top level keys", cfg) + return cfg + if network_val.get('config') == "disabled": + LOG.debug(bmsg, "was config/disabled", cfg) + elif not all(('config' in network_val, 'version' in network_val)): + LOG.debug(bmsg, "but missing 'config' or 'version'", cfg) + return cfg + LOG.debug(bmsg, "fixed by removing shifting network.", cfg) + return network_val + + def _merge_new_seed(cur, seeded): ret = cur.copy() @@ -299,7 +349,8 @@ def _merge_new_seed(cur, seeded): ret['meta-data'] = util.mergemanydict([cur['meta-data'], newmd]) if seeded.get('network-config'): - ret['network-config'] = util.load_yaml(seeded['network-config']) + ret['network-config'] = _maybe_remove_top_network( + util.load_yaml(seeded.get('network-config'))) if 'user-data' in seeded: ret['user-data'] = seeded['user-data'] diff --git a/cloudinit/sources/DataSourceNone.py b/cloudinit/sources/DataSourceNone.py index e63a7e39..e6250801 100644 --- a/cloudinit/sources/DataSourceNone.py +++ b/cloudinit/sources/DataSourceNone.py @@ -28,6 +28,10 @@ class DataSourceNone(sources.DataSource): self.metadata = self.ds_cfg['metadata'] return True + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + return 'config' + def get_instance_id(self): return 'iid-datasource-none' diff --git a/cloudinit/sources/DataSourceOVF.py b/cloudinit/sources/DataSourceOVF.py index 178ccb0f..045291e7 100644 --- a/cloudinit/sources/DataSourceOVF.py +++ b/cloudinit/sources/DataSourceOVF.py @@ -275,6 +275,12 @@ class DataSourceOVF(sources.DataSource): self.cfg = cfg return True + def _get_subplatform(self): + system_type = util.read_dmi_data("system-product-name").lower() + if system_type == 'vmware': + return 'vmware (%s)' % self.seed + return 'ovf (%s)' % self.seed + def get_public_ssh_keys(self): if 'public-keys' not in self.metadata: return [] diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index 77ccd128..e62e9729 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -95,6 +95,14 @@ class DataSourceOpenNebula(sources.DataSource): self.userdata_raw = results.get('userdata') return True + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + if self.seed_dir in self.seed: + subplatform_type = 'seed-dir' + else: + subplatform_type = 'config-disk' + return '%s (%s)' % (subplatform_type, self.seed) + @property def network_config(self): if self.network is not None: diff --git a/cloudinit/sources/DataSourceOracle.py b/cloudinit/sources/DataSourceOracle.py index fab39af3..70b9c58a 100644 --- a/cloudinit/sources/DataSourceOracle.py +++ b/cloudinit/sources/DataSourceOracle.py @@ -91,6 +91,10 @@ class DataSourceOracle(sources.DataSource): def crawl_metadata(self): return read_metadata() + def _get_subplatform(self): + """Return the subplatform metadata source details.""" + return 'metadata (%s)' % METADATA_ENDPOINT + def check_instance_id(self, sys_cfg): """quickly check (local only) if self.instance_id is still valid diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 593ac91a..32b57cdd 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -303,6 +303,9 @@ class DataSourceSmartOS(sources.DataSource): self._set_provisioned() return True + def _get_subplatform(self): + return 'serial (%s)' % SERIAL_DEVICE + def device_name_to_device(self, name): return self.ds_cfg['disk_aliases'].get(name) diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index 5ac98826..e6966b31 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -54,9 +54,18 @@ REDACT_SENSITIVE_VALUE = 'redacted for non-root user' METADATA_CLOUD_NAME_KEY = 'cloud-name' UNSET = "_unset" +METADATA_UNKNOWN = 'unknown' LOG = logging.getLogger(__name__) +# CLOUD_ID_REGION_PREFIX_MAP format is: +# <region-match-prefix>: (<new-cloud-id>: <test_allowed_cloud_callable>) +CLOUD_ID_REGION_PREFIX_MAP = { + 'cn-': ('aws-china', lambda c: c == 'aws'), # only change aws regions + 'us-gov-': ('aws-gov', lambda c: c == 'aws'), # only change aws regions + 'china': ('azure-china', lambda c: c == 'azure'), # only change azure +} + class DataSourceNotFoundException(Exception): pass @@ -133,6 +142,14 @@ class DataSource(object): # Cached cloud_name as determined by _get_cloud_name _cloud_name = None + # Cached cloud platform api type: e.g. ec2, openstack, kvm, lxd, azure etc. + _platform_type = None + + # More details about the cloud platform: + # - metadata (http://169.254.169.254/) + # - seed-dir (<dirname>) + _subplatform = None + # Track the discovered fallback nic for use in configuration generation. _fallback_interface = None @@ -192,21 +209,24 @@ class DataSource(object): local_hostname = self.get_hostname() instance_id = self.get_instance_id() availability_zone = self.availability_zone - cloud_name = self.cloud_name - # When adding new standard keys prefer underscore-delimited instead - # of hyphen-delimted to support simple variable references in jinja - # templates. + # In the event of upgrade from existing cloudinit, pickled datasource + # will not contain these new class attributes. So we need to recrawl + # metadata to discover that content. return { 'v1': { + '_beta_keys': ['subplatform'], 'availability-zone': availability_zone, 'availability_zone': availability_zone, - 'cloud-name': cloud_name, - 'cloud_name': cloud_name, + 'cloud-name': self.cloud_name, + 'cloud_name': self.cloud_name, + 'platform': self.platform_type, + 'public_ssh_keys': self.get_public_ssh_keys(), 'instance-id': instance_id, 'instance_id': instance_id, 'local-hostname': local_hostname, 'local_hostname': local_hostname, - 'region': self.region}} + 'region': self.region, + 'subplatform': self.subplatform}} def clear_cached_attrs(self, attr_defaults=()): """Reset any cached metadata attributes to datasource defaults. @@ -247,19 +267,27 @@ class DataSource(object): @return True on successful write, False otherwise. """ - instance_data = { - 'ds': {'_doc': EXPERIMENTAL_TEXT, - 'meta_data': self.metadata}} - if hasattr(self, 'network_json'): - network_json = getattr(self, 'network_json') - if network_json != UNSET: - instance_data['ds']['network_json'] = network_json - if hasattr(self, 'ec2_metadata'): - ec2_metadata = getattr(self, 'ec2_metadata') - if ec2_metadata != UNSET: - instance_data['ds']['ec2_metadata'] = ec2_metadata + if hasattr(self, '_crawled_metadata'): + # Any datasource with _crawled_metadata will best represent + # most recent, 'raw' metadata + crawled_metadata = copy.deepcopy( + getattr(self, '_crawled_metadata')) + crawled_metadata.pop('user-data', None) + crawled_metadata.pop('vendor-data', None) + instance_data = {'ds': crawled_metadata} + else: + instance_data = {'ds': {'meta_data': self.metadata}} + if hasattr(self, 'network_json'): + network_json = getattr(self, 'network_json') + if network_json != UNSET: + instance_data['ds']['network_json'] = network_json + if hasattr(self, 'ec2_metadata'): + ec2_metadata = getattr(self, 'ec2_metadata') + if ec2_metadata != UNSET: + instance_data['ds']['ec2_metadata'] = ec2_metadata instance_data.update( self._get_standardized_metadata()) + instance_data['ds']['_doc'] = EXPERIMENTAL_TEXT try: # Process content base64encoding unserializable values content = util.json_dumps(instance_data) @@ -347,6 +375,40 @@ class DataSource(object): return self._fallback_interface @property + def platform_type(self): + if not hasattr(self, '_platform_type'): + # Handle upgrade path where pickled datasource has no _platform. + self._platform_type = self.dsname.lower() + if not self._platform_type: + self._platform_type = self.dsname.lower() + return self._platform_type + + @property + def subplatform(self): + """Return a string representing subplatform details for the datasource. + + This should be guidance for where the metadata is sourced. + Examples of this on different clouds: + ec2: metadata (http://169.254.169.254) + openstack: configdrive (/dev/path) + openstack: metadata (http://169.254.169.254) + nocloud: seed-dir (/seed/dir/path) + lxd: nocloud (/seed/dir/path) + """ + if not hasattr(self, '_subplatform'): + # Handle upgrade path where pickled datasource has no _platform. + self._subplatform = self._get_subplatform() + if not self._subplatform: + self._subplatform = self._get_subplatform() + return self._subplatform + + def _get_subplatform(self): + """Subclasses should implement to return a "slug (detail)" string.""" + if hasattr(self, 'metadata_address'): + return 'metadata (%s)' % getattr(self, 'metadata_address') + return METADATA_UNKNOWN + + @property def cloud_name(self): """Return lowercase cloud name as determined by the datasource. @@ -359,9 +421,11 @@ class DataSource(object): cloud_name = self.metadata.get(METADATA_CLOUD_NAME_KEY) if isinstance(cloud_name, six.string_types): self._cloud_name = cloud_name.lower() - LOG.debug( - 'Ignoring metadata provided key %s: non-string type %s', - METADATA_CLOUD_NAME_KEY, type(cloud_name)) + else: + self._cloud_name = self._get_cloud_name().lower() + LOG.debug( + 'Ignoring metadata provided key %s: non-string type %s', + METADATA_CLOUD_NAME_KEY, type(cloud_name)) else: self._cloud_name = self._get_cloud_name().lower() return self._cloud_name @@ -714,6 +778,25 @@ def instance_id_matches_system_uuid(instance_id, field='system-uuid'): return instance_id.lower() == dmi_value.lower() +def canonical_cloud_id(cloud_name, region, platform): + """Lookup the canonical cloud-id for a given cloud_name and region.""" + if not cloud_name: + cloud_name = METADATA_UNKNOWN + if not region: + region = METADATA_UNKNOWN + if region == METADATA_UNKNOWN: + if cloud_name != METADATA_UNKNOWN: + return cloud_name + return platform + for prefix, cloud_id_test in CLOUD_ID_REGION_PREFIX_MAP.items(): + (cloud_id, valid_cloud) = cloud_id_test + if region.startswith(prefix) and valid_cloud(cloud_name): + return cloud_id + if cloud_name != METADATA_UNKNOWN: + return cloud_name + return platform + + def convert_vendordata(data, recurse=True): """data: a loaded object (strings, arrays, dicts). return something suitable for cloudinit vendordata_raw. diff --git a/cloudinit/sources/helpers/netlink.py b/cloudinit/sources/helpers/netlink.py new file mode 100644 index 00000000..d377ae3d --- /dev/null +++ b/cloudinit/sources/helpers/netlink.py @@ -0,0 +1,250 @@ +# Author: Tamilmani Manoharan <tamanoha@microsoft.com> +# +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit import log as logging +from cloudinit import util +from collections import namedtuple + +import os +import select +import socket +import struct + +LOG = logging.getLogger(__name__) + +# http://man7.org/linux/man-pages/man7/netlink.7.html +RTMGRP_LINK = 1 +NLMSG_NOOP = 1 +NLMSG_ERROR = 2 +NLMSG_DONE = 3 +RTM_NEWLINK = 16 +RTM_DELLINK = 17 +RTM_GETLINK = 18 +RTM_SETLINK = 19 +MAX_SIZE = 65535 +RTA_DATA_OFFSET = 32 +MSG_TYPE_OFFSET = 16 +SELECT_TIMEOUT = 60 + +NLMSGHDR_FMT = "IHHII" +IFINFOMSG_FMT = "BHiII" +NLMSGHDR_SIZE = struct.calcsize(NLMSGHDR_FMT) +IFINFOMSG_SIZE = struct.calcsize(IFINFOMSG_FMT) +RTATTR_START_OFFSET = NLMSGHDR_SIZE + IFINFOMSG_SIZE +RTA_DATA_START_OFFSET = 4 +PAD_ALIGNMENT = 4 + +IFLA_IFNAME = 3 +IFLA_OPERSTATE = 16 + +# https://www.kernel.org/doc/Documentation/networking/operstates.txt +OPER_UNKNOWN = 0 +OPER_NOTPRESENT = 1 +OPER_DOWN = 2 +OPER_LOWERLAYERDOWN = 3 +OPER_TESTING = 4 +OPER_DORMANT = 5 +OPER_UP = 6 + +RTAAttr = namedtuple('RTAAttr', ['length', 'rta_type', 'data']) +InterfaceOperstate = namedtuple('InterfaceOperstate', ['ifname', 'operstate']) +NetlinkHeader = namedtuple('NetlinkHeader', ['length', 'type', 'flags', 'seq', + 'pid']) + + +class NetlinkCreateSocketError(RuntimeError): + '''Raised if netlink socket fails during create or bind.''' + pass + + +def create_bound_netlink_socket(): + '''Creates netlink socket and bind on netlink group to catch interface + down/up events. The socket will bound only on RTMGRP_LINK (which only + includes RTM_NEWLINK/RTM_DELLINK/RTM_GETLINK events). The socket is set to + non-blocking mode since we're only receiving messages. + + :returns: netlink socket in non-blocking mode + :raises: NetlinkCreateSocketError + ''' + try: + netlink_socket = socket.socket(socket.AF_NETLINK, + socket.SOCK_RAW, + socket.NETLINK_ROUTE) + netlink_socket.bind((os.getpid(), RTMGRP_LINK)) + netlink_socket.setblocking(0) + except socket.error as e: + msg = "Exception during netlink socket create: %s" % e + raise NetlinkCreateSocketError(msg) + LOG.debug("Created netlink socket") + return netlink_socket + + +def get_netlink_msg_header(data): + '''Gets netlink message type and length + + :param: data read from netlink socket + :returns: netlink message type + :raises: AssertionError if data is None or data is not >= NLMSGHDR_SIZE + struct nlmsghdr { + __u32 nlmsg_len; /* Length of message including header */ + __u16 nlmsg_type; /* Type of message content */ + __u16 nlmsg_flags; /* Additional flags */ + __u32 nlmsg_seq; /* Sequence number */ + __u32 nlmsg_pid; /* Sender port ID */ + }; + ''' + assert (data is not None), ("data is none") + assert (len(data) >= NLMSGHDR_SIZE), ( + "data is smaller than netlink message header") + msg_len, msg_type, flags, seq, pid = struct.unpack(NLMSGHDR_FMT, + data[:MSG_TYPE_OFFSET]) + LOG.debug("Got netlink msg of type %d", msg_type) + return NetlinkHeader(msg_len, msg_type, flags, seq, pid) + + +def read_netlink_socket(netlink_socket, timeout=None): + '''Select and read from the netlink socket if ready. + + :param: netlink_socket: specify which socket object to read from + :param: timeout: specify a timeout value (integer) to wait while reading, + if none, it will block indefinitely until socket ready for read + :returns: string of data read (max length = <MAX_SIZE>) from socket, + if no data read, returns None + :raises: AssertionError if netlink_socket is None + ''' + assert (netlink_socket is not None), ("netlink socket is none") + read_set, _, _ = select.select([netlink_socket], [], [], timeout) + # Incase of timeout,read_set doesn't contain netlink socket. + # just return from this function + if netlink_socket not in read_set: + return None + LOG.debug("netlink socket ready for read") + data = netlink_socket.recv(MAX_SIZE) + if data is None: + LOG.error("Reading from Netlink socket returned no data") + return data + + +def unpack_rta_attr(data, offset): + '''Unpack a single rta attribute. + + :param: data: string of data read from netlink socket + :param: offset: starting offset of RTA Attribute + :return: RTAAttr object with length, type and data. On error, return None. + :raises: AssertionError if data is None or offset is not integer. + ''' + assert (data is not None), ("data is none") + assert (type(offset) == int), ("offset is not integer") + assert (offset >= RTATTR_START_OFFSET), ( + "rta offset is less than expected length") + length = rta_type = 0 + attr_data = None + try: + length = struct.unpack_from("H", data, offset=offset)[0] + rta_type = struct.unpack_from("H", data, offset=offset+2)[0] + except struct.error: + return None # Should mean our offset is >= remaining data + + # Unpack just the attribute's data. Offset by 4 to skip length/type header + attr_data = data[offset+RTA_DATA_START_OFFSET:offset+length] + return RTAAttr(length, rta_type, attr_data) + + +def read_rta_oper_state(data): + '''Reads Interface name and operational state from RTA Data. + + :param: data: string of data read from netlink socket + :returns: InterfaceOperstate object containing if_name and oper_state. + None if data does not contain valid IFLA_OPERSTATE and + IFLA_IFNAME messages. + :raises: AssertionError if data is None or length of data is + smaller than RTATTR_START_OFFSET. + ''' + assert (data is not None), ("data is none") + assert (len(data) > RTATTR_START_OFFSET), ( + "length of data is smaller than RTATTR_START_OFFSET") + ifname = operstate = None + offset = RTATTR_START_OFFSET + while offset <= len(data): + attr = unpack_rta_attr(data, offset) + if not attr or attr.length == 0: + break + # Each attribute is 4-byte aligned. Determine pad length. + padlen = (PAD_ALIGNMENT - + (attr.length % PAD_ALIGNMENT)) % PAD_ALIGNMENT + offset += attr.length + padlen + + if attr.rta_type == IFLA_OPERSTATE: + operstate = ord(attr.data) + elif attr.rta_type == IFLA_IFNAME: + interface_name = util.decode_binary(attr.data, 'utf-8') + ifname = interface_name.strip('\0') + if not ifname or operstate is None: + return None + LOG.debug("rta attrs: ifname %s operstate %d", ifname, operstate) + return InterfaceOperstate(ifname, operstate) + + +def wait_for_media_disconnect_connect(netlink_socket, ifname): + '''Block until media disconnect and connect has happened on an interface. + Listens on netlink socket to receive netlink events and when the carrier + changes from 0 to 1, it considers event has happened and + return from this function + + :param: netlink_socket: netlink_socket to receive events + :param: ifname: Interface name to lookout for netlink events + :raises: AssertionError if netlink_socket is None or ifname is None. + ''' + assert (netlink_socket is not None), ("netlink socket is none") + assert (ifname is not None), ("interface name is none") + assert (len(ifname) > 0), ("interface name cannot be empty") + carrier = OPER_UP + prevCarrier = OPER_UP + data = bytes() + LOG.debug("Wait for media disconnect and reconnect to happen") + while True: + recv_data = read_netlink_socket(netlink_socket, SELECT_TIMEOUT) + if recv_data is None: + continue + LOG.debug('read %d bytes from socket', len(recv_data)) + data += recv_data + LOG.debug('Length of data after concat %d', len(data)) + offset = 0 + datalen = len(data) + while offset < datalen: + nl_msg = data[offset:] + if len(nl_msg) < NLMSGHDR_SIZE: + LOG.debug("Data is smaller than netlink header") + break + nlheader = get_netlink_msg_header(nl_msg) + if len(nl_msg) < nlheader.length: + LOG.debug("Partial data. Smaller than netlink message") + break + padlen = (nlheader.length+PAD_ALIGNMENT-1) & ~(PAD_ALIGNMENT-1) + offset = offset + padlen + LOG.debug('offset to next netlink message: %d', offset) + # Ignore any messages not new link or del link + if nlheader.type not in [RTM_NEWLINK, RTM_DELLINK]: + continue + interface_state = read_rta_oper_state(nl_msg) + if interface_state is None: + LOG.debug('Failed to read rta attributes: %s', interface_state) + continue + if interface_state.ifname != ifname: + LOG.debug( + "Ignored netlink event on interface %s. Waiting for %s.", + interface_state.ifname, ifname) + continue + if interface_state.operstate not in [OPER_UP, OPER_DOWN]: + continue + prevCarrier = carrier + carrier = interface_state.operstate + # check for carrier down, up sequence + isVnetSwitch = (prevCarrier == OPER_DOWN) and (carrier == OPER_UP) + if isVnetSwitch: + LOG.debug("Media switch happened on %s.", ifname) + return + data = data[offset:] + +# vi: ts=4 expandtab diff --git a/cloudinit/sources/helpers/tests/test_netlink.py b/cloudinit/sources/helpers/tests/test_netlink.py new file mode 100644 index 00000000..c2898a16 --- /dev/null +++ b/cloudinit/sources/helpers/tests/test_netlink.py @@ -0,0 +1,373 @@ +# Author: Tamilmani Manoharan <tamanoha@microsoft.com> +# +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit.tests.helpers import CiTestCase, mock +import socket +import struct +import codecs +from cloudinit.sources.helpers.netlink import ( + NetlinkCreateSocketError, create_bound_netlink_socket, read_netlink_socket, + read_rta_oper_state, unpack_rta_attr, wait_for_media_disconnect_connect, + OPER_DOWN, OPER_UP, OPER_DORMANT, OPER_LOWERLAYERDOWN, OPER_NOTPRESENT, + OPER_TESTING, OPER_UNKNOWN, RTATTR_START_OFFSET, RTM_NEWLINK, RTM_SETLINK, + RTM_GETLINK, MAX_SIZE) + + +def int_to_bytes(i): + '''convert integer to binary: eg: 1 to \x01''' + hex_value = '{0:x}'.format(i) + hex_value = '0' * (len(hex_value) % 2) + hex_value + return codecs.decode(hex_value, 'hex_codec') + + +class TestCreateBoundNetlinkSocket(CiTestCase): + + @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') + def test_socket_error_on_create(self, m_socket): + '''create_bound_netlink_socket catches socket creation exception''' + + """NetlinkCreateSocketError is raised when socket creation errors.""" + m_socket.side_effect = socket.error("Fake socket failure") + with self.assertRaises(NetlinkCreateSocketError) as ctx_mgr: + create_bound_netlink_socket() + self.assertEqual( + 'Exception during netlink socket create: Fake socket failure', + str(ctx_mgr.exception)) + + +class TestReadNetlinkSocket(CiTestCase): + + @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') + @mock.patch('cloudinit.sources.helpers.netlink.select.select') + def test_read_netlink_socket(self, m_select, m_socket): + '''read_netlink_socket able to receive data''' + data = 'netlinktest' + m_select.return_value = [m_socket], None, None + m_socket.recv.return_value = data + recv_data = read_netlink_socket(m_socket, 2) + m_select.assert_called_with([m_socket], [], [], 2) + m_socket.recv.assert_called_with(MAX_SIZE) + self.assertIsNotNone(recv_data) + self.assertEqual(recv_data, data) + + @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') + @mock.patch('cloudinit.sources.helpers.netlink.select.select') + def test_netlink_read_timeout(self, m_select, m_socket): + '''read_netlink_socket should timeout if nothing to read''' + m_select.return_value = [], None, None + data = read_netlink_socket(m_socket, 1) + m_select.assert_called_with([m_socket], [], [], 1) + self.assertEqual(m_socket.recv.call_count, 0) + self.assertIsNone(data) + + def test_read_invalid_socket(self): + '''read_netlink_socket raises assert error if socket is invalid''' + socket = None + with self.assertRaises(AssertionError) as context: + read_netlink_socket(socket, 1) + self.assertTrue('netlink socket is none' in str(context.exception)) + + +class TestParseNetlinkMessage(CiTestCase): + + def test_read_rta_oper_state(self): + '''read_rta_oper_state could parse netlink message and extract data''' + ifname = "eth0" + bytes = ifname.encode("utf-8") + buf = bytearray(48) + struct.pack_into("HH4sHHc", buf, RTATTR_START_OFFSET, 8, 3, bytes, 5, + 16, int_to_bytes(OPER_DOWN)) + interface_state = read_rta_oper_state(buf) + self.assertEqual(interface_state.ifname, ifname) + self.assertEqual(interface_state.operstate, OPER_DOWN) + + def test_read_none_data(self): + '''read_rta_oper_state raises assert error if data is none''' + data = None + with self.assertRaises(AssertionError) as context: + read_rta_oper_state(data) + self.assertTrue('data is none', str(context.exception)) + + def test_read_invalid_rta_operstate_none(self): + '''read_rta_oper_state returns none if operstate is none''' + ifname = "eth0" + buf = bytearray(40) + bytes = ifname.encode("utf-8") + struct.pack_into("HH4s", buf, RTATTR_START_OFFSET, 8, 3, bytes) + interface_state = read_rta_oper_state(buf) + self.assertIsNone(interface_state) + + def test_read_invalid_rta_ifname_none(self): + '''read_rta_oper_state returns none if ifname is none''' + buf = bytearray(40) + struct.pack_into("HHc", buf, RTATTR_START_OFFSET, 5, 16, + int_to_bytes(OPER_DOWN)) + interface_state = read_rta_oper_state(buf) + self.assertIsNone(interface_state) + + def test_read_invalid_data_len(self): + '''raise assert error if data size is smaller than required size''' + buf = bytearray(32) + with self.assertRaises(AssertionError) as context: + read_rta_oper_state(buf) + self.assertTrue('length of data is smaller than RTATTR_START_OFFSET' in + str(context.exception)) + + def test_unpack_rta_attr_none_data(self): + '''unpack_rta_attr raises assert error if data is none''' + data = None + with self.assertRaises(AssertionError) as context: + unpack_rta_attr(data, RTATTR_START_OFFSET) + self.assertTrue('data is none' in str(context.exception)) + + def test_unpack_rta_attr_invalid_offset(self): + '''unpack_rta_attr raises assert error if offset is invalid''' + data = bytearray(48) + with self.assertRaises(AssertionError) as context: + unpack_rta_attr(data, "offset") + self.assertTrue('offset is not integer' in str(context.exception)) + with self.assertRaises(AssertionError) as context: + unpack_rta_attr(data, 31) + self.assertTrue('rta offset is less than expected length' in + str(context.exception)) + + +@mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +@mock.patch('cloudinit.sources.helpers.netlink.read_netlink_socket') +class TestWaitForMediaDisconnectConnect(CiTestCase): + with_logs = True + + def _media_switch_data(self, ifname, msg_type, operstate): + '''construct netlink data with specified fields''' + if ifname and operstate is not None: + data = bytearray(48) + bytes = ifname.encode("utf-8") + struct.pack_into("HH4sHHc", data, RTATTR_START_OFFSET, 8, 3, + bytes, 5, 16, int_to_bytes(operstate)) + elif ifname: + data = bytearray(40) + bytes = ifname.encode("utf-8") + struct.pack_into("HH4s", data, RTATTR_START_OFFSET, 8, 3, bytes) + elif operstate: + data = bytearray(40) + struct.pack_into("HHc", data, RTATTR_START_OFFSET, 5, 16, + int_to_bytes(operstate)) + struct.pack_into("=LHHLL", data, 0, len(data), msg_type, 0, 0, 0) + return data + + def test_media_down_up_scenario(self, m_read_netlink_socket, + m_socket): + '''Test for media down up sequence for required interface name''' + ifname = "eth0" + # construct data for Oper State down + data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) + # construct data for Oper State up + data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) + m_read_netlink_socket.side_effect = [data_op_down, data_op_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 2) + + def test_wait_for_media_switch_diff_interface(self, m_read_netlink_socket, + m_socket): + '''wait_for_media_disconnect_connect ignores unexpected interfaces. + + The first two messages are for other interfaces and last two are for + expected interface. So the function exit only after receiving last + 2 messages and therefore the call count for m_read_netlink_socket + has to be 4 + ''' + other_ifname = "eth1" + expected_ifname = "eth0" + data_op_down_eth1 = self._media_switch_data( + other_ifname, RTM_NEWLINK, OPER_DOWN) + data_op_up_eth1 = self._media_switch_data( + other_ifname, RTM_NEWLINK, OPER_UP) + data_op_down_eth0 = self._media_switch_data( + expected_ifname, RTM_NEWLINK, OPER_DOWN) + data_op_up_eth0 = self._media_switch_data( + expected_ifname, RTM_NEWLINK, OPER_UP) + m_read_netlink_socket.side_effect = [data_op_down_eth1, + data_op_up_eth1, + data_op_down_eth0, + data_op_up_eth0] + wait_for_media_disconnect_connect(m_socket, expected_ifname) + self.assertIn('Ignored netlink event on interface %s' % other_ifname, + self.logs.getvalue()) + self.assertEqual(m_read_netlink_socket.call_count, 4) + + def test_invalid_msgtype_getlink(self, m_read_netlink_socket, m_socket): + '''wait_for_media_disconnect_connect ignores GETLINK events. + + The first two messages are for oper down and up for RTM_GETLINK type + which netlink module will ignore. The last 2 messages are RTM_NEWLINK + with oper state down and up messages. Therefore the call count for + m_read_netlink_socket has to be 4 ignoring first 2 messages + of RTM_GETLINK + ''' + ifname = "eth0" + data_getlink_down = self._media_switch_data( + ifname, RTM_GETLINK, OPER_DOWN) + data_getlink_up = self._media_switch_data( + ifname, RTM_GETLINK, OPER_UP) + data_newlink_down = self._media_switch_data( + ifname, RTM_NEWLINK, OPER_DOWN) + data_newlink_up = self._media_switch_data( + ifname, RTM_NEWLINK, OPER_UP) + m_read_netlink_socket.side_effect = [data_getlink_down, + data_getlink_up, + data_newlink_down, + data_newlink_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 4) + + def test_invalid_msgtype_setlink(self, m_read_netlink_socket, m_socket): + '''wait_for_media_disconnect_connect ignores SETLINK events. + + The first two messages are for oper down and up for RTM_GETLINK type + which it will ignore. 3rd and 4th messages are RTM_NEWLINK with down + and up messages. This function should exit after 4th messages since it + sees down->up scenario. So the call count for m_read_netlink_socket + has to be 4 ignoring first 2 messages of RTM_GETLINK and + last 2 messages of RTM_NEWLINK + ''' + ifname = "eth0" + data_setlink_down = self._media_switch_data( + ifname, RTM_SETLINK, OPER_DOWN) + data_setlink_up = self._media_switch_data( + ifname, RTM_SETLINK, OPER_UP) + data_newlink_down = self._media_switch_data( + ifname, RTM_NEWLINK, OPER_DOWN) + data_newlink_up = self._media_switch_data( + ifname, RTM_NEWLINK, OPER_UP) + m_read_netlink_socket.side_effect = [data_setlink_down, + data_setlink_up, + data_newlink_down, + data_newlink_up, + data_newlink_down, + data_newlink_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 4) + + def test_netlink_invalid_switch_scenario(self, m_read_netlink_socket, + m_socket): + '''returns only if it receives UP event after a DOWN event''' + ifname = "eth0" + data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) + data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) + data_op_dormant = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_DORMANT) + data_op_notpresent = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_NOTPRESENT) + data_op_lowerdown = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_LOWERLAYERDOWN) + data_op_testing = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_TESTING) + data_op_unknown = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_UNKNOWN) + m_read_netlink_socket.side_effect = [data_op_up, data_op_up, + data_op_dormant, data_op_up, + data_op_notpresent, data_op_up, + data_op_lowerdown, data_op_up, + data_op_testing, data_op_up, + data_op_unknown, data_op_up, + data_op_down, data_op_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 14) + + def test_netlink_valid_inbetween_transitions(self, m_read_netlink_socket, + m_socket): + '''wait_for_media_disconnect_connect handles in between transitions''' + ifname = "eth0" + data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) + data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) + data_op_dormant = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_DORMANT) + data_op_unknown = self._media_switch_data(ifname, RTM_NEWLINK, + OPER_UNKNOWN) + m_read_netlink_socket.side_effect = [data_op_down, data_op_dormant, + data_op_unknown, data_op_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 4) + + def test_netlink_invalid_operstate(self, m_read_netlink_socket, m_socket): + '''wait_for_media_disconnect_connect should handle invalid operstates. + + The function should not fail and return even if it receives invalid + operstates. It always should wait for down up sequence. + ''' + ifname = "eth0" + data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) + data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) + data_op_invalid = self._media_switch_data(ifname, RTM_NEWLINK, 7) + m_read_netlink_socket.side_effect = [data_op_invalid, data_op_up, + data_op_down, data_op_invalid, + data_op_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 5) + + def test_wait_invalid_socket(self, m_read_netlink_socket, m_socket): + '''wait_for_media_disconnect_connect handle none netlink socket.''' + socket = None + ifname = "eth0" + with self.assertRaises(AssertionError) as context: + wait_for_media_disconnect_connect(socket, ifname) + self.assertTrue('netlink socket is none' in str(context.exception)) + + def test_wait_invalid_ifname(self, m_read_netlink_socket, m_socket): + '''wait_for_media_disconnect_connect handle none interface name''' + ifname = None + with self.assertRaises(AssertionError) as context: + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertTrue('interface name is none' in str(context.exception)) + ifname = "" + with self.assertRaises(AssertionError) as context: + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertTrue('interface name cannot be empty' in + str(context.exception)) + + def test_wait_invalid_rta_attr(self, m_read_netlink_socket, m_socket): + ''' wait_for_media_disconnect_connect handles invalid rta data''' + ifname = "eth0" + data_invalid1 = self._media_switch_data(None, RTM_NEWLINK, OPER_DOWN) + data_invalid2 = self._media_switch_data(ifname, RTM_NEWLINK, None) + data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) + data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) + m_read_netlink_socket.side_effect = [data_invalid1, data_invalid2, + data_op_down, data_op_up] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 4) + + def test_read_multiple_netlink_msgs(self, m_read_netlink_socket, m_socket): + '''Read multiple messages in single receive call''' + ifname = "eth0" + bytes = ifname.encode("utf-8") + data = bytearray(96) + struct.pack_into("=LHHLL", data, 0, 48, RTM_NEWLINK, 0, 0, 0) + struct.pack_into("HH4sHHc", data, RTATTR_START_OFFSET, 8, 3, + bytes, 5, 16, int_to_bytes(OPER_DOWN)) + struct.pack_into("=LHHLL", data, 48, 48, RTM_NEWLINK, 0, 0, 0) + struct.pack_into("HH4sHHc", data, 48 + RTATTR_START_OFFSET, 8, + 3, bytes, 5, 16, int_to_bytes(OPER_UP)) + m_read_netlink_socket.return_value = data + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 1) + + def test_read_partial_netlink_msgs(self, m_read_netlink_socket, m_socket): + '''Read partial messages in receive call''' + ifname = "eth0" + bytes = ifname.encode("utf-8") + data1 = bytearray(112) + data2 = bytearray(32) + struct.pack_into("=LHHLL", data1, 0, 48, RTM_NEWLINK, 0, 0, 0) + struct.pack_into("HH4sHHc", data1, RTATTR_START_OFFSET, 8, 3, + bytes, 5, 16, int_to_bytes(OPER_DOWN)) + struct.pack_into("=LHHLL", data1, 48, 48, RTM_NEWLINK, 0, 0, 0) + struct.pack_into("HH4sHHc", data1, 80, 8, 3, bytes, 5, 16, + int_to_bytes(OPER_DOWN)) + struct.pack_into("=LHHLL", data1, 96, 48, RTM_NEWLINK, 0, 0, 0) + struct.pack_into("HH4sHHc", data2, 16, 8, 3, bytes, 5, 16, + int_to_bytes(OPER_UP)) + m_read_netlink_socket.side_effect = [data1, data2] + wait_for_media_disconnect_connect(m_socket, ifname) + self.assertEqual(m_read_netlink_socket.call_count, 2) diff --git a/cloudinit/sources/helpers/vmware/imc/config_nic.py b/cloudinit/sources/helpers/vmware/imc/config_nic.py index e1890e23..77cbf3b6 100644 --- a/cloudinit/sources/helpers/vmware/imc/config_nic.py +++ b/cloudinit/sources/helpers/vmware/imc/config_nic.py @@ -165,9 +165,8 @@ class NicConfigurator(object): # Add routes if there is no primary nic if not self._primaryNic and v4.gateways: - route_list.extend(self.gen_ipv4_route(nic, - v4.gateways, - v4.netmask)) + subnet.update( + {'routes': self.gen_ipv4_route(nic, v4.gateways, v4.netmask)}) return ([subnet], route_list) diff --git a/cloudinit/sources/tests/test_init.py b/cloudinit/sources/tests/test_init.py index 8082019e..6378e98b 100644 --- a/cloudinit/sources/tests/test_init.py +++ b/cloudinit/sources/tests/test_init.py @@ -11,7 +11,8 @@ from cloudinit.helpers import Paths from cloudinit import importer from cloudinit.sources import ( EXPERIMENTAL_TEXT, INSTANCE_JSON_FILE, INSTANCE_JSON_SENSITIVE_FILE, - REDACT_SENSITIVE_VALUE, UNSET, DataSource, redact_sensitive_keys) + METADATA_UNKNOWN, REDACT_SENSITIVE_VALUE, UNSET, DataSource, + canonical_cloud_id, redact_sensitive_keys) from cloudinit.tests.helpers import CiTestCase, skipIf, mock from cloudinit.user_data import UserDataProcessor from cloudinit import util @@ -295,6 +296,7 @@ class TestDataSource(CiTestCase): 'base64_encoded_keys': [], 'sensitive_keys': [], 'v1': { + '_beta_keys': ['subplatform'], 'availability-zone': 'myaz', 'availability_zone': 'myaz', 'cloud-name': 'subclasscloudname', @@ -303,7 +305,10 @@ class TestDataSource(CiTestCase): 'instance_id': 'iid-datasource', 'local-hostname': 'test-subclass-hostname', 'local_hostname': 'test-subclass-hostname', - 'region': 'myregion'}, + 'platform': 'mytestsubclass', + 'public_ssh_keys': [], + 'region': 'myregion', + 'subplatform': 'unknown'}, 'ds': { '_doc': EXPERIMENTAL_TEXT, 'meta_data': {'availability_zone': 'myaz', @@ -339,6 +344,7 @@ class TestDataSource(CiTestCase): 'base64_encoded_keys': [], 'sensitive_keys': ['ds/meta_data/some/security-credentials'], 'v1': { + '_beta_keys': ['subplatform'], 'availability-zone': 'myaz', 'availability_zone': 'myaz', 'cloud-name': 'subclasscloudname', @@ -347,7 +353,10 @@ class TestDataSource(CiTestCase): 'instance_id': 'iid-datasource', 'local-hostname': 'test-subclass-hostname', 'local_hostname': 'test-subclass-hostname', - 'region': 'myregion'}, + 'platform': 'mytestsubclass', + 'public_ssh_keys': [], + 'region': 'myregion', + 'subplatform': 'unknown'}, 'ds': { '_doc': EXPERIMENTAL_TEXT, 'meta_data': { @@ -599,4 +608,75 @@ class TestRedactSensitiveData(CiTestCase): redact_sensitive_keys(md)) +class TestCanonicalCloudID(CiTestCase): + + def test_cloud_id_returns_platform_on_unknowns(self): + """When region and cloud_name are unknown, return platform.""" + self.assertEqual( + 'platform', + canonical_cloud_id(cloud_name=METADATA_UNKNOWN, + region=METADATA_UNKNOWN, + platform='platform')) + + def test_cloud_id_returns_platform_on_none(self): + """When region and cloud_name are unknown, return platform.""" + self.assertEqual( + 'platform', + canonical_cloud_id(cloud_name=None, + region=None, + platform='platform')) + + def test_cloud_id_returns_cloud_name_on_unknown_region(self): + """When region is unknown, return cloud_name.""" + for region in (None, METADATA_UNKNOWN): + self.assertEqual( + 'cloudname', + canonical_cloud_id(cloud_name='cloudname', + region=region, + platform='platform')) + + def test_cloud_id_returns_platform_on_unknown_cloud_name(self): + """When region is set but cloud_name is unknown return cloud_name.""" + self.assertEqual( + 'platform', + canonical_cloud_id(cloud_name=METADATA_UNKNOWN, + region='region', + platform='platform')) + + def test_cloud_id_aws_based_on_region_and_cloud_name(self): + """When cloud_name is aws, return proper cloud-id based on region.""" + self.assertEqual( + 'aws-china', + canonical_cloud_id(cloud_name='aws', + region='cn-north-1', + platform='platform')) + self.assertEqual( + 'aws', + canonical_cloud_id(cloud_name='aws', + region='us-east-1', + platform='platform')) + self.assertEqual( + 'aws-gov', + canonical_cloud_id(cloud_name='aws', + region='us-gov-1', + platform='platform')) + self.assertEqual( # Overrideen non-aws cloud_name is returned + '!aws', + canonical_cloud_id(cloud_name='!aws', + region='us-gov-1', + platform='platform')) + + def test_cloud_id_azure_based_on_region_and_cloud_name(self): + """Report cloud-id when cloud_name is azure and region is in china.""" + self.assertEqual( + 'azure-china', + canonical_cloud_id(cloud_name='azure', + region='chinaeast', + platform='platform')) + self.assertEqual( + 'azure', + canonical_cloud_id(cloud_name='azure', + region='!chinaeast', + platform='platform')) + # vi: ts=4 expandtab diff --git a/cloudinit/sources/tests/test_oracle.py b/cloudinit/sources/tests/test_oracle.py index 7599126c..97d62947 100644 --- a/cloudinit/sources/tests/test_oracle.py +++ b/cloudinit/sources/tests/test_oracle.py @@ -71,6 +71,14 @@ class TestDataSourceOracle(test_helpers.CiTestCase): self.assertFalse(ds._get_data()) mocks._is_platform_viable.assert_called_once_with() + def test_platform_info(self): + """Return platform-related information for Oracle Datasource.""" + ds, _mocks = self._get_ds() + self.assertEqual('oracle', ds.cloud_name) + self.assertEqual('oracle', ds.platform_type) + self.assertEqual( + 'metadata (http://169.254.169.254/openstack/)', ds.subplatform) + @mock.patch(DS_PATH + "._is_iscsi_root", return_value=True) def test_without_userdata(self, m_is_iscsi_root): """If no user-data is provided, it should not be in return dict.""" diff --git a/cloudinit/tests/test_dhclient_hook.py b/cloudinit/tests/test_dhclient_hook.py new file mode 100644 index 00000000..7aab8dd5 --- /dev/null +++ b/cloudinit/tests/test_dhclient_hook.py @@ -0,0 +1,105 @@ +# This file is part of cloud-init. See LICENSE file for license information. + +"""Tests for cloudinit.dhclient_hook.""" + +from cloudinit import dhclient_hook as dhc +from cloudinit.tests.helpers import CiTestCase, dir2dict, populate_dir + +import argparse +import json +import mock +import os + + +class TestDhclientHook(CiTestCase): + + ex_env = { + 'interface': 'eth0', + 'new_dhcp_lease_time': '3600', + 'new_host_name': 'x1', + 'new_ip_address': '10.145.210.163', + 'new_subnet_mask': '255.255.255.0', + 'old_host_name': 'x1', + 'PATH': '/usr/sbin:/usr/bin:/sbin:/bin', + 'pid': '614', + 'reason': 'BOUND', + } + + # some older versions of dhclient put the same content, + # but in upper case with DHCP4_ instead of new_ + ex_env_dhcp4 = { + 'REASON': 'BOUND', + 'DHCP4_dhcp_lease_time': '3600', + 'DHCP4_host_name': 'x1', + 'DHCP4_ip_address': '10.145.210.163', + 'DHCP4_subnet_mask': '255.255.255.0', + 'INTERFACE': 'eth0', + 'PATH': '/usr/sbin:/usr/bin:/sbin:/bin', + 'pid': '614', + } + + expected = { + 'dhcp_lease_time': '3600', + 'host_name': 'x1', + 'ip_address': '10.145.210.163', + 'subnet_mask': '255.255.255.0'} + + def setUp(self): + super(TestDhclientHook, self).setUp() + self.tmp = self.tmp_dir() + + def test_handle_args(self): + """quick test of call to handle_args.""" + nic = 'eth0' + args = argparse.Namespace(event=dhc.UP, interface=nic) + with mock.patch.dict("os.environ", clear=True, values=self.ex_env): + dhc.handle_args(dhc.NAME, args, data_d=self.tmp) + found = dir2dict(self.tmp + os.path.sep) + self.assertEqual([nic + ".json"], list(found.keys())) + self.assertEqual(self.expected, json.loads(found[nic + ".json"])) + + def test_run_hook_up_creates_dir(self): + """If dir does not exist, run_hook should create it.""" + subd = self.tmp_path("subdir", self.tmp) + nic = 'eth1' + dhc.run_hook(nic, 'up', data_d=subd, env=self.ex_env) + self.assertEqual( + set([nic + ".json"]), set(dir2dict(subd + os.path.sep))) + + def test_run_hook_up(self): + """Test expected use of run_hook_up.""" + nic = 'eth0' + dhc.run_hook(nic, 'up', data_d=self.tmp, env=self.ex_env) + found = dir2dict(self.tmp + os.path.sep) + self.assertEqual([nic + ".json"], list(found.keys())) + self.assertEqual(self.expected, json.loads(found[nic + ".json"])) + + def test_run_hook_up_dhcp4_prefix(self): + """Test run_hook filters correctly with older DHCP4_ data.""" + nic = 'eth0' + dhc.run_hook(nic, 'up', data_d=self.tmp, env=self.ex_env_dhcp4) + found = dir2dict(self.tmp + os.path.sep) + self.assertEqual([nic + ".json"], list(found.keys())) + self.assertEqual(self.expected, json.loads(found[nic + ".json"])) + + def test_run_hook_down_deletes(self): + """down should delete the created json file.""" + nic = 'eth1' + populate_dir( + self.tmp, {nic + ".json": "{'abcd'}", 'myfile.txt': 'text'}) + dhc.run_hook(nic, 'down', data_d=self.tmp, env={'old_host_name': 'x1'}) + self.assertEqual( + set(['myfile.txt']), + set(dir2dict(self.tmp + os.path.sep))) + + def test_get_parser(self): + """Smoke test creation of get_parser.""" + # cloud-init main uses 'action'. + event, interface = (dhc.UP, 'mynic0') + self.assertEqual( + argparse.Namespace(event=event, interface=interface, + action=(dhc.NAME, dhc.handle_args)), + dhc.get_parser().parse_args([event, interface])) + + +# vi: ts=4 expandtab diff --git a/cloudinit/tests/test_url_helper.py b/cloudinit/tests/test_url_helper.py index 113249d9..aa9f3ec1 100644 --- a/cloudinit/tests/test_url_helper.py +++ b/cloudinit/tests/test_url_helper.py @@ -1,10 +1,12 @@ # This file is part of cloud-init. See LICENSE file for license information. -from cloudinit.url_helper import oauth_headers, read_file_or_url +from cloudinit.url_helper import ( + NOT_FOUND, UrlError, oauth_headers, read_file_or_url, retry_on_url_exc) from cloudinit.tests.helpers import CiTestCase, mock, skipIf from cloudinit import util import httpretty +import requests try: @@ -64,3 +66,24 @@ class TestReadFileOrUrl(CiTestCase): result = read_file_or_url(url) self.assertEqual(result.contents, data) self.assertEqual(str(result), data.decode('utf-8')) + + +class TestRetryOnUrlExc(CiTestCase): + + def test_do_not_retry_non_urlerror(self): + """When exception is not UrlError return False.""" + myerror = IOError('something unexcpected') + self.assertFalse(retry_on_url_exc(msg='', exc=myerror)) + + def test_perform_retries_on_not_found(self): + """When exception is UrlError with a 404 status code return True.""" + myerror = UrlError(cause=RuntimeError( + 'something was not found'), code=NOT_FOUND) + self.assertTrue(retry_on_url_exc(msg='', exc=myerror)) + + def test_perform_retries_on_timeout(self): + """When exception is a requests.Timout return True.""" + myerror = UrlError(cause=requests.Timeout('something timed out')) + self.assertTrue(retry_on_url_exc(msg='', exc=myerror)) + +# vi: ts=4 expandtab diff --git a/cloudinit/tests/test_util.py b/cloudinit/tests/test_util.py index edb0c18f..e3d2dbaa 100644 --- a/cloudinit/tests/test_util.py +++ b/cloudinit/tests/test_util.py @@ -18,25 +18,51 @@ MOUNT_INFO = [ ] OS_RELEASE_SLES = dedent("""\ - NAME="SLES"\n - VERSION="12-SP3"\n - VERSION_ID="12.3"\n - PRETTY_NAME="SUSE Linux Enterprise Server 12 SP3"\n - ID="sles"\nANSI_COLOR="0;32"\n - CPE_NAME="cpe:/o:suse:sles:12:sp3"\n + NAME="SLES" + VERSION="12-SP3" + VERSION_ID="12.3" + PRETTY_NAME="SUSE Linux Enterprise Server 12 SP3" + ID="sles" + ANSI_COLOR="0;32" + CPE_NAME="cpe:/o:suse:sles:12:sp3" """) OS_RELEASE_OPENSUSE = dedent("""\ -NAME="openSUSE Leap" -VERSION="42.3" -ID=opensuse -ID_LIKE="suse" -VERSION_ID="42.3" -PRETTY_NAME="openSUSE Leap 42.3" -ANSI_COLOR="0;32" -CPE_NAME="cpe:/o:opensuse:leap:42.3" -BUG_REPORT_URL="https://bugs.opensuse.org" -HOME_URL="https://www.opensuse.org/" + NAME="openSUSE Leap" + VERSION="42.3" + ID=opensuse + ID_LIKE="suse" + VERSION_ID="42.3" + PRETTY_NAME="openSUSE Leap 42.3" + ANSI_COLOR="0;32" + CPE_NAME="cpe:/o:opensuse:leap:42.3" + BUG_REPORT_URL="https://bugs.opensuse.org" + HOME_URL="https://www.opensuse.org/" +""") + +OS_RELEASE_OPENSUSE_L15 = dedent("""\ + NAME="openSUSE Leap" + VERSION="15.0" + ID="opensuse-leap" + ID_LIKE="suse opensuse" + VERSION_ID="15.0" + PRETTY_NAME="openSUSE Leap 15.0" + ANSI_COLOR="0;32" + CPE_NAME="cpe:/o:opensuse:leap:15.0" + BUG_REPORT_URL="https://bugs.opensuse.org" + HOME_URL="https://www.opensuse.org/" +""") + +OS_RELEASE_OPENSUSE_TW = dedent("""\ + NAME="openSUSE Tumbleweed" + ID="opensuse-tumbleweed" + ID_LIKE="opensuse suse" + VERSION_ID="20180920" + PRETTY_NAME="openSUSE Tumbleweed" + ANSI_COLOR="0;32" + CPE_NAME="cpe:/o:opensuse:tumbleweed:20180920" + BUG_REPORT_URL="https://bugs.opensuse.org" + HOME_URL="https://www.opensuse.org/" """) OS_RELEASE_CENTOS = dedent("""\ @@ -447,12 +473,35 @@ class TestGetLinuxDistro(CiTestCase): @mock.patch('cloudinit.util.load_file') def test_get_linux_opensuse(self, m_os_release, m_path_exists): - """Verify we get the correct name and machine arch on OpenSUSE.""" + """Verify we get the correct name and machine arch on openSUSE + prior to openSUSE Leap 15. + """ m_os_release.return_value = OS_RELEASE_OPENSUSE m_path_exists.side_effect = TestGetLinuxDistro.os_release_exists dist = util.get_linux_distro() self.assertEqual(('opensuse', '42.3', platform.machine()), dist) + @mock.patch('cloudinit.util.load_file') + def test_get_linux_opensuse_l15(self, m_os_release, m_path_exists): + """Verify we get the correct name and machine arch on openSUSE + for openSUSE Leap 15.0 and later. + """ + m_os_release.return_value = OS_RELEASE_OPENSUSE_L15 + m_path_exists.side_effect = TestGetLinuxDistro.os_release_exists + dist = util.get_linux_distro() + self.assertEqual(('opensuse-leap', '15.0', platform.machine()), dist) + + @mock.patch('cloudinit.util.load_file') + def test_get_linux_opensuse_tw(self, m_os_release, m_path_exists): + """Verify we get the correct name and machine arch on openSUSE + for openSUSE Tumbleweed + """ + m_os_release.return_value = OS_RELEASE_OPENSUSE_TW + m_path_exists.side_effect = TestGetLinuxDistro.os_release_exists + dist = util.get_linux_distro() + self.assertEqual( + ('opensuse-tumbleweed', '20180920', platform.machine()), dist) + @mock.patch('platform.dist') def test_get_linux_distro_no_data(self, m_platform_dist, m_path_exists): """Verify we get no information if os-release does not exist""" @@ -478,4 +527,20 @@ class TestGetLinuxDistro(CiTestCase): dist = util.get_linux_distro() self.assertEqual(('foo', '1.1', 'aarch64'), dist) + +@mock.patch('os.path.exists') +class TestIsLXD(CiTestCase): + + def test_is_lxd_true_on_sock_device(self, m_exists): + """When lxd's /dev/lxd/sock exists, is_lxd returns true.""" + m_exists.return_value = True + self.assertTrue(util.is_lxd()) + m_exists.assert_called_once_with('/dev/lxd/sock') + + def test_is_lxd_false_when_sock_device_absent(self, m_exists): + """When lxd's /dev/lxd/sock is absent, is_lxd returns false.""" + m_exists.return_value = False + self.assertFalse(util.is_lxd()) + m_exists.assert_called_once_with('/dev/lxd/sock') + # vi: ts=4 expandtab diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 8067979e..396d69ae 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -199,7 +199,7 @@ def _get_ssl_args(url, ssl_details): def readurl(url, data=None, timeout=None, retries=0, sec_between=1, headers=None, headers_cb=None, ssl_details=None, check_status=True, allow_redirects=True, exception_cb=None, - session=None, infinite=False): + session=None, infinite=False, log_req_resp=True): url = _cleanurl(url) req_args = { 'url': url, @@ -256,9 +256,11 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1, continue filtered_req_args[k] = v try: - LOG.debug("[%s/%s] open '%s' with %s configuration", i, - "infinite" if infinite else manual_tries, url, - filtered_req_args) + + if log_req_resp: + LOG.debug("[%s/%s] open '%s' with %s configuration", i, + "infinite" if infinite else manual_tries, url, + filtered_req_args) if session is None: session = requests.Session() @@ -294,8 +296,11 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1, break if (infinite and sec_between > 0) or \ (i + 1 < manual_tries and sec_between > 0): - LOG.debug("Please wait %s seconds while we wait to try again", - sec_between) + + if log_req_resp: + LOG.debug( + "Please wait %s seconds while we wait to try again", + sec_between) time.sleep(sec_between) if excps: raise excps[-1] @@ -549,4 +554,18 @@ def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, _uri, signed_headers, _body = client.sign(url) return signed_headers + +def retry_on_url_exc(msg, exc): + """readurl exception_cb that will retry on NOT_FOUND and Timeout. + + Returns False to raise the exception from readurl, True to retry. + """ + if not isinstance(exc, UrlError): + return False + if exc.code == NOT_FOUND: + return True + if exc.cause and isinstance(exc.cause, requests.Timeout): + return True + return False + # vi: ts=4 expandtab diff --git a/cloudinit/util.py b/cloudinit/util.py index 50680960..7800f7bc 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -615,8 +615,8 @@ def get_linux_distro(): distro_name = os_release.get('ID', '') distro_version = os_release.get('VERSION_ID', '') if 'sles' in distro_name or 'suse' in distro_name: - # RELEASE_BLOCKER: We will drop this sles ivergent behavior in - # before 18.4 so that get_linux_distro returns a named tuple + # RELEASE_BLOCKER: We will drop this sles divergent behavior in + # the future so that get_linux_distro returns a named tuple # which will include both version codename and architecture # on all distributions. flavor = platform.machine() @@ -668,7 +668,8 @@ def system_info(): var = 'ubuntu' elif linux_dist == 'redhat': var = 'rhel' - elif linux_dist in ('opensuse', 'sles'): + elif linux_dist in ( + 'opensuse', 'opensuse-tumbleweed', 'opensuse-leap', 'sles'): var = 'suse' else: var = 'linux' @@ -2171,6 +2172,11 @@ def is_container(): return False +def is_lxd(): + """Check to see if we are running in a lxd container.""" + return os.path.exists('/dev/lxd/sock') + + def get_proc_env(pid, encoding='utf-8', errors='replace'): """ Return the environment in a dict that a given process id was started with. diff --git a/cloudinit/version.py b/cloudinit/version.py index 844a02e0..a2c5d43a 100644 --- a/cloudinit/version.py +++ b/cloudinit/version.py @@ -4,7 +4,7 @@ # # This file is part of cloud-init. See LICENSE file for license information. -__VERSION__ = "18.4" +__VERSION__ = "18.5" _PACKAGED_VERSION = '@@PACKAGED_VERSION@@' FEATURES = [ |