diff options
Diffstat (limited to 'tests/unittests/test_datasource/test_smartos.py')
-rw-r--r-- | tests/unittests/test_datasource/test_smartos.py | 336 |
1 files changed, 305 insertions, 31 deletions
diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 88bae5f9..46d67b94 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -1,4 +1,5 @@ # Copyright (C) 2013 Canonical Ltd. +# Copyright (c) 2018, Joyent, Inc. # # Author: Ben Howard <ben.howard@canonical.com> # @@ -15,26 +16,40 @@ from __future__ import print_function from binascii import crc32 import json +import multiprocessing import os import os.path import re -import shutil +import signal import stat -import tempfile +import unittest2 import uuid from cloudinit import serial from cloudinit.sources import DataSourceSmartOS from cloudinit.sources.DataSourceSmartOS import ( - convert_smartos_network_data as convert_net) + convert_smartos_network_data as convert_net, + SMARTOS_ENV_KVM, SERIAL_DEVICE, get_smartos_environ, + identify_file) import six from cloudinit import helpers as c_helpers -from cloudinit.util import b64e +from cloudinit.util import ( + b64e, subp, ProcessExecutionError, which, write_file) -from cloudinit.tests.helpers import mock, FilesystemMockingTestCase, TestCase +from cloudinit.tests.helpers import ( + CiTestCase, mock, FilesystemMockingTestCase, skipIf) + +try: + import serial as _pyserial + assert _pyserial # avoid pyflakes error F401: import unused + HAS_PYSERIAL = True +except ImportError: + HAS_PYSERIAL = False + +DSMOS = 'cloudinit.sources.DataSourceSmartOS' SDC_NICS = json.loads(""" [ { @@ -318,12 +333,19 @@ MOCK_RETURNS = { DMI_DATA_RETURN = 'smartdc' +# Useful for calculating the length of a frame body. A SUCCESS body will be +# followed by more characters or be one character less if SUCCESS with no +# payload. See Section 4.3 of https://eng.joyent.com/mdata/protocol.html. +SUCCESS_LEN = len('0123abcd SUCCESS ') +NOTFOUND_LEN = len('0123abcd NOTFOUND') + class PsuedoJoyentClient(object): def __init__(self, data=None): if data is None: data = MOCK_RETURNS.copy() self.data = data + self._is_open = False return def get(self, key, default=None, strip=False): @@ -344,39 +366,43 @@ class PsuedoJoyentClient(object): def exists(self): return True + def open_transport(self): + assert(not self._is_open) + self._is_open = True + + def close_transport(self): + assert(self._is_open) + self._is_open = False + class TestSmartOSDataSource(FilesystemMockingTestCase): + jmc_cfact = None + get_smartos_environ = None + def setUp(self): super(TestSmartOSDataSource, self).setUp() - dsmos = 'cloudinit.sources.DataSourceSmartOS' - patcher = mock.patch(dsmos + ".jmc_client_factory") - self.jmc_cfact = patcher.start() - self.addCleanup(patcher.stop) - patcher = mock.patch(dsmos + ".get_smartos_environ") - self.get_smartos_environ = patcher.start() - self.addCleanup(patcher.stop) - - self.tmp = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, self.tmp) - self.paths = c_helpers.Paths( - {'cloud_dir': self.tmp, 'run_dir': self.tmp}) - - self.legacy_user_d = os.path.join(self.tmp, 'legacy_user_tmp') + self.add_patch(DSMOS + ".get_smartos_environ", "get_smartos_environ") + self.add_patch(DSMOS + ".jmc_client_factory", "jmc_cfact") + self.legacy_user_d = self.tmp_path('legacy_user_tmp') os.mkdir(self.legacy_user_d) - - self.orig_lud = DataSourceSmartOS.LEGACY_USER_D - DataSourceSmartOS.LEGACY_USER_D = self.legacy_user_d - - def tearDown(self): - DataSourceSmartOS.LEGACY_USER_D = self.orig_lud - super(TestSmartOSDataSource, self).tearDown() + self.add_patch(DSMOS + ".LEGACY_USER_D", "m_legacy_user_d", + autospec=False, new=self.legacy_user_d) + self.add_patch(DSMOS + ".identify_file", "m_identify_file", + return_value="text/plain") def _get_ds(self, mockdata=None, mode=DataSourceSmartOS.SMARTOS_ENV_KVM, sys_cfg=None, ds_cfg=None): self.jmc_cfact.return_value = PsuedoJoyentClient(mockdata) self.get_smartos_environ.return_value = mode + tmpd = self.tmp_dir() + dirs = {'cloud_dir': self.tmp_path('cloud_dir', tmpd), + 'run_dir': self.tmp_path('run_dir')} + for d in dirs.values(): + os.mkdir(d) + paths = c_helpers.Paths(dirs) + if sys_cfg is None: sys_cfg = {} @@ -385,7 +411,7 @@ class TestSmartOSDataSource(FilesystemMockingTestCase): sys_cfg['datasource']['SmartOS'] = ds_cfg return DataSourceSmartOS.DataSourceSmartOS( - sys_cfg, distro=None, paths=self.paths) + sys_cfg, distro=None, paths=paths) def test_no_base64(self): ds_cfg = {'no_base64_decode': ['test_var1'], 'all_base': True} @@ -421,6 +447,34 @@ class TestSmartOSDataSource(FilesystemMockingTestCase): self.assertEqual(MOCK_RETURNS['hostname'], dsrc.metadata['local-hostname']) + def test_hostname_if_no_sdc_hostname(self): + my_returns = MOCK_RETURNS.copy() + my_returns['sdc:hostname'] = 'sdc-' + my_returns['hostname'] + dsrc = self._get_ds(mockdata=my_returns) + ret = dsrc.get_data() + self.assertTrue(ret) + self.assertEqual(my_returns['hostname'], + dsrc.metadata['local-hostname']) + + def test_sdc_hostname_if_no_hostname(self): + my_returns = MOCK_RETURNS.copy() + my_returns['sdc:hostname'] = 'sdc-' + my_returns['hostname'] + del my_returns['hostname'] + dsrc = self._get_ds(mockdata=my_returns) + ret = dsrc.get_data() + self.assertTrue(ret) + self.assertEqual(my_returns['sdc:hostname'], + dsrc.metadata['local-hostname']) + + def test_sdc_uuid_if_no_hostname_or_sdc_hostname(self): + my_returns = MOCK_RETURNS.copy() + del my_returns['hostname'] + dsrc = self._get_ds(mockdata=my_returns) + ret = dsrc.get_data() + self.assertTrue(ret) + self.assertEqual(my_returns['sdc:uuid'], + dsrc.metadata['local-hostname']) + def test_userdata(self): dsrc = self._get_ds(mockdata=MOCK_RETURNS) ret = dsrc.get_data() @@ -445,6 +499,7 @@ class TestSmartOSDataSource(FilesystemMockingTestCase): dsrc.metadata['user-script']) legacy_script_f = "%s/user-script" % self.legacy_user_d + print("legacy_script_f=%s" % legacy_script_f) self.assertTrue(os.path.exists(legacy_script_f)) self.assertTrue(os.path.islink(legacy_script_f)) user_script_perm = oct(os.stat(legacy_script_f)[stat.ST_MODE])[-3:] @@ -592,8 +647,68 @@ class TestSmartOSDataSource(FilesystemMockingTestCase): mydscfg['disk_aliases']['FOO']) +class TestIdentifyFile(CiTestCase): + """Test the 'identify_file' utility.""" + @skipIf(not which("file"), "command 'file' not available.") + def test_file_happy_path(self): + """Test file is available and functional on plain text.""" + fname = self.tmp_path("myfile") + write_file(fname, "plain text content here\n") + with self.allow_subp(["file"]): + self.assertEqual("text/plain", identify_file(fname)) + + @mock.patch(DSMOS + ".util.subp") + def test_returns_none_on_error(self, m_subp): + """On 'file' execution error, None should be returned.""" + m_subp.side_effect = ProcessExecutionError("FILE_FAILED", exit_code=99) + fname = self.tmp_path("myfile") + write_file(fname, "plain text content here\n") + self.assertEqual(None, identify_file(fname)) + self.assertEqual( + [mock.call(["file", "--brief", "--mime-type", fname])], + m_subp.call_args_list) + + +class ShortReader(object): + """Implements a 'read' interface for bytes provided. + much like io.BytesIO but the 'endbyte' acts as if EOF. + When it is reached a short will be returned.""" + def __init__(self, initial_bytes, endbyte=b'\0'): + self.data = initial_bytes + self.index = 0 + self.len = len(self.data) + self.endbyte = endbyte + + @property + def emptied(self): + return self.index >= self.len + + def read(self, size=-1): + """Read size bytes but not past a null.""" + if size == 0 or self.index >= self.len: + return b'' + + rsize = size + if size < 0 or size + self.index > self.len: + rsize = self.len - self.index + + next_null = self.data.find(self.endbyte, self.index, rsize) + if next_null >= 0: + rsize = next_null - self.index + 1 + i = self.index + self.index += rsize + ret = self.data[i:i + rsize] + if len(ret) and ret[-1:] == self.endbyte: + ret = ret[:-1] + return ret + + class TestJoyentMetadataClient(FilesystemMockingTestCase): + invalid = b'invalid command\n' + failure = b'FAILURE\n' + v2_ok = b'V2_OK\n' + def setUp(self): super(TestJoyentMetadataClient, self).setUp() @@ -603,7 +718,7 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): self.response_parts = { 'command': 'SUCCESS', 'crc': 'b5a9ff00', - 'length': 17 + len(b64e(self.metadata_value)), + 'length': SUCCESS_LEN + len(b64e(self.metadata_value)), 'payload': b64e(self.metadata_value), 'request_id': '{0:08x}'.format(self.request_id), } @@ -636,6 +751,11 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): return DataSourceSmartOS.JoyentMetadataClient( fp=self.serial, smartos_type=DataSourceSmartOS.SMARTOS_ENV_KVM) + def _get_serial_client(self): + self.serial.timeout = 1 + return DataSourceSmartOS.JoyentMetadataSerialClient(None, + fp=self.serial) + def assertEndsWith(self, haystack, prefix): self.assertTrue(haystack.endswith(prefix), "{0} does not end with '{1}'".format( @@ -646,12 +766,14 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): "{0} does not start with '{1}'".format( repr(haystack), prefix)) + def assertNoMoreSideEffects(self, obj): + self.assertRaises(StopIteration, obj) + def test_get_metadata_writes_a_single_line(self): client = self._get_client() client.get('some_key') self.assertEqual(1, self.serial.write.call_count) written_line = self.serial.write.call_args[0][0] - print(type(written_line)) self.assertEndsWith(written_line.decode('ascii'), b'\n'.decode('ascii')) self.assertEqual(1, written_line.count(b'\n')) @@ -732,13 +854,75 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): def test_get_metadata_returns_None_if_value_not_found(self): self.response_parts['payload'] = '' self.response_parts['command'] = 'NOTFOUND' - self.response_parts['length'] = 17 + self.response_parts['length'] = NOTFOUND_LEN client = self._get_client() client._checksum = lambda _: self.response_parts['crc'] self.assertIsNone(client.get('some_key')) + def test_negotiate(self): + client = self._get_client() + reader = ShortReader(self.v2_ok) + client.fp.read.side_effect = reader.read + client._negotiate() + self.assertTrue(reader.emptied) + + def test_negotiate_short_response(self): + client = self._get_client() + # chopped '\n' from v2_ok. + reader = ShortReader(self.v2_ok[:-1] + b'\0') + client.fp.read.side_effect = reader.read + self.assertRaises(DataSourceSmartOS.JoyentMetadataTimeoutException, + client._negotiate) + self.assertTrue(reader.emptied) + + def test_negotiate_bad_response(self): + client = self._get_client() + reader = ShortReader(b'garbage\n' + self.v2_ok) + client.fp.read.side_effect = reader.read + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client._negotiate) + self.assertEqual(self.v2_ok, client.fp.read()) + + def test_serial_open_transport(self): + client = self._get_serial_client() + reader = ShortReader(b'garbage\0' + self.invalid + self.v2_ok) + client.fp.read.side_effect = reader.read + client.open_transport() + self.assertTrue(reader.emptied) + + def test_flush_failure(self): + client = self._get_serial_client() + reader = ShortReader(b'garbage' + b'\0' + self.failure + + self.invalid + self.v2_ok) + client.fp.read.side_effect = reader.read + client.open_transport() + self.assertTrue(reader.emptied) + + def test_flush_many_timeouts(self): + client = self._get_serial_client() + reader = ShortReader(b'\0' * 100 + self.invalid + self.v2_ok) + client.fp.read.side_effect = reader.read + client.open_transport() + self.assertTrue(reader.emptied) + + def test_list_metadata_returns_list(self): + parts = ['foo', 'bar'] + value = b64e('\n'.join(parts)) + self.response_parts['payload'] = value + self.response_parts['crc'] = '40873553' + self.response_parts['length'] = SUCCESS_LEN + len(value) + client = self._get_client() + self.assertEqual(client.list(), parts) + + def test_list_metadata_returns_empty_list_if_no_customer_metadata(self): + del self.response_parts['payload'] + self.response_parts['length'] = SUCCESS_LEN - 1 + self.response_parts['crc'] = '14e563ba' + client = self._get_client() + self.assertEqual(client.list(), []) + -class TestNetworkConversion(TestCase): +class TestNetworkConversion(CiTestCase): def test_convert_simple(self): expected = { 'version': 1, @@ -872,4 +1056,94 @@ class TestNetworkConversion(TestCase): found = convert_net(SDC_NICS_SINGLE_GATEWAY) self.assertEqual(expected, found) + def test_routes_on_all_nics(self): + routes = [ + {'linklocal': False, 'dst': '3.0.0.0/8', 'gateway': '8.12.42.3'}, + {'linklocal': False, 'dst': '4.0.0.0/8', 'gateway': '10.210.1.4'}] + expected = { + 'version': 1, + 'config': [ + {'mac_address': '90:b8:d0:d8:82:b4', 'mtu': 1500, + 'name': 'net0', 'type': 'physical', + 'subnets': [{'address': '8.12.42.26/24', + 'gateway': '8.12.42.1', 'type': 'static', + 'routes': [{'network': '3.0.0.0/8', + 'gateway': '8.12.42.3'}, + {'network': '4.0.0.0/8', + 'gateway': '10.210.1.4'}]}]}, + {'mac_address': '90:b8:d0:0a:51:31', 'mtu': 1500, + 'name': 'net1', 'type': 'physical', + 'subnets': [{'address': '10.210.1.27/24', 'type': 'static', + 'routes': [{'network': '3.0.0.0/8', + 'gateway': '8.12.42.3'}, + {'network': '4.0.0.0/8', + 'gateway': '10.210.1.4'}]}]}]} + found = convert_net(SDC_NICS_SINGLE_GATEWAY, routes=routes) + self.maxDiff = None + self.assertEqual(expected, found) + + +@unittest2.skipUnless(get_smartos_environ() == SMARTOS_ENV_KVM, + "Only supported on KVM and bhyve guests under SmartOS") +@unittest2.skipUnless(os.access(SERIAL_DEVICE, os.W_OK), + "Requires write access to " + SERIAL_DEVICE) +@unittest2.skipUnless(HAS_PYSERIAL is True, "pyserial not available") +class TestSerialConcurrency(CiTestCase): + """ + This class tests locking on an actual serial port, and as such can only + be run in a kvm or bhyve guest running on a SmartOS host. A test run on + a metadata socket will not be valid because a metadata socket ensures + there is only one session over a connection. In contrast, in the + absence of proper locking multiple processes opening the same serial + port can corrupt each others' exchanges with the metadata server. + + This takes on the order of 2 to 3 minutes to run. + """ + allowed_subp = ['mdata-get'] + + def setUp(self): + self.mdata_proc = multiprocessing.Process(target=self.start_mdata_loop) + self.mdata_proc.start() + super(TestSerialConcurrency, self).setUp() + + def tearDown(self): + # os.kill() rather than mdata_proc.terminate() to avoid console spam. + os.kill(self.mdata_proc.pid, signal.SIGKILL) + self.mdata_proc.join() + super(TestSerialConcurrency, self).tearDown() + + def start_mdata_loop(self): + """ + The mdata-get command is repeatedly run in a separate process so + that it may try to race with metadata operations performed in the + main test process. Use of mdata-get is better than two processes + using the protocol implementation in DataSourceSmartOS because we + are testing to be sure that cloud-init and mdata-get respect each + others locks. + """ + rcs = list(range(0, 256)) + while True: + subp(['mdata-get', 'sdc:routes'], rcs=rcs) + + def test_all_keys(self): + self.assertIsNotNone(self.mdata_proc.pid) + ds = DataSourceSmartOS + keys = [tup[0] for tup in ds.SMARTOS_ATTRIB_MAP.values()] + keys.extend(ds.SMARTOS_ATTRIB_JSON.values()) + + client = ds.jmc_client_factory(smartos_type=SMARTOS_ENV_KVM) + self.assertIsNotNone(client) + + # The behavior that we are testing for was observed mdata-get running + # 10 times at roughly the same time as cloud-init fetched each key + # once. cloud-init would regularly see failures before making it + # through all keys once. + for _ in range(0, 3): + for key in keys: + # We don't care about the return value, just that it doesn't + # thrown any exceptions. + client.get(key) + + self.assertIsNone(self.mdata_proc.exitcode) + # vi: ts=4 expandtab |