diff options
-rw-r--r-- | cloudinit/sources/DataSourceSmartOS.py | 75 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_smartos.py | 216 |
2 files changed, 239 insertions, 52 deletions
diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 896fde3f..694a011a 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -29,9 +29,10 @@ # http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html # Comments with "@datadictionary" are snippets of the definition -import base64 import binascii import os +import random +import re import serial from cloudinit import log as logging @@ -301,6 +302,53 @@ def get_serial(seed_device, seed_timeout): return ser +class JoyentMetadataFetchException(Exception): + pass + + +class JoyentMetadataClient(object): + + def __init__(self, serial): + self.serial = serial + + def _checksum(self, body): + return '{0:08x}'.format( + binascii.crc32(body.encode('utf-8')) & 0xffffffff) + + def _get_value_from_frame(self, expected_request_id, frame): + regex = ( + r'V2 (?P<length>\d+) (?P<checksum>[0-9a-f]+)' + r' (?P<body>(?P<request_id>[0-9a-f]+) (?P<status>SUCCESS|NOTFOUND)' + r'( (?P<payload>.+))?)') + frame_data = re.match(regex, frame).groupdict() + if int(frame_data['length']) != len(frame_data['body']): + raise JoyentMetadataFetchException( + 'Incorrect frame length given ({0} != {1}).'.format( + frame_data['length'], len(frame_data['body']))) + expected_checksum = self._checksum(frame_data['body']) + if frame_data['checksum'] != expected_checksum: + raise JoyentMetadataFetchException( + 'Invalid checksum (expected: {0}; got {1}).'.format( + expected_checksum, frame_data['checksum'])) + if frame_data['request_id'] != expected_request_id: + raise JoyentMetadataFetchException( + 'Request ID mismatch (expected: {0}; got {1}).'.format( + expected_request_id, frame_data['request_id'])) + if not frame_data.get('payload', None): + return None + return util.b64d(frame_data['payload']) + + def get_metadata(self, metadata_key): + request_id = '{0:08x}'.format(random.randint(0, 0xffffffff)) + message_body = '{0} GET {1}'.format(request_id, + util.b64e(metadata_key)) + msg = 'V2 {0} {1} {2}\n'.format( + len(message_body), self._checksum(message_body), message_body) + self.serial.write(msg.encode('ascii')) + response = self.serial.readline().decode('ascii') + return self._get_value_from_frame(request_id, response) + + def query_data(noun, seed_device, seed_timeout, strip=False, default=None, b64=None): """Makes a request to via the serial console via "GET <NOUN>" @@ -314,34 +362,21 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, encoded, so this method relies on being told if the data is base64 or not. """ - if not noun: return False ser = get_serial(seed_device, seed_timeout) - request_line = "GET %s\n" % noun.rstrip() - ser.write(request_line.encode('ascii')) - status = str(ser.readline()).rstrip() - response = [] - eom_found = False - - if 'SUCCESS' not in status: - ser.close() - return default - - while not eom_found: - m = ser.readline().decode('ascii') - if m.rstrip() == ".": - eom_found = True - else: - response.append(m) + client = JoyentMetadataClient(ser) + response = client.get_metadata(noun) ser.close() + if response is None: + return default if b64 is None: b64 = query_data('b64-%s' % noun, seed_device=seed_device, - seed_timeout=seed_timeout, b64=False, - default=False, strip=True) + seed_timeout=seed_timeout, b64=False, + default=False, strip=True) b64 = util.is_true(b64) resp = None diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index cdd83bf8..c79cf3aa 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -31,15 +31,24 @@ import shutil import stat import tempfile import uuid +from binascii import crc32 + +import serial +import six import six from cloudinit import helpers as c_helpers from cloudinit.sources import DataSourceSmartOS -from cloudinit.util import b64e +from cloudinit.util import b64d, b64e from .. import helpers +try: + from unittest import mock +except ImportError: + import mock + MOCK_RETURNS = { 'hostname': 'test-host', 'root_authorized_keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname', @@ -57,6 +66,37 @@ MOCK_RETURNS = { DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc') +def _checksum(body): + return '{0:08x}'.format(crc32(body.encode('utf-8')) & 0xffffffff) + + +def _generate_v2_frame(request_id, command, body=None): + body_parts = [request_id, command] + if body: + body_parts.append(b64e(body)) + message_body = ' '.join(body_parts) + return 'V2 {0} {1} {2}\n'.format( + len(message_body), _checksum(message_body), message_body).encode( + 'ascii') + + +def _parse_v2_frame(line): + line = line.decode('ascii') + if not line.endswith('\n'): + raise Exception('Frames must end with a newline.') + version, length, checksum, body = line.strip().split(' ', 3) + if version != 'V2': + raise Exception('Frames must begin with V2.') + if int(length) != len(body): + raise Exception('Incorrect frame length given ({0} != {1}).'.format( + length, len(body))) + expected_checksum = _checksum(body) + if checksum != expected_checksum: + raise Exception('Invalid checksum.') + request_id, command, payload = body.split() + return request_id, command, b64d(payload) + + class MockSerial(object): """Fake a serial terminal for testing the code that interfaces with the serial""" @@ -81,39 +121,21 @@ class MockSerial(object): return True def write(self, line): - if not isinstance(line, six.binary_type): - raise TypeError("Should be writing binary lines.") - line = line.decode('ascii').replace('GET ', '') - self.last = line.rstrip() + self.last = line def readline(self): - if self.new: - self.new = False - if self.last in self.mockdata: - line = 'SUCCESS\n' - else: - line = 'NOTFOUND %s\n' % self.last - - elif self.last in self.mockdata: - if not self.mocked_out: - self.mocked_out = [x for x in self._format_out()] - - if len(self.mocked_out) > self.count: - self.count += 1 - line = self.mocked_out[self.count - 1] - return line.encode('ascii') - - def _format_out(self): - if self.last in self.mockdata: - _mret = self.mockdata[self.last] - try: - for l in _mret.splitlines(): - yield "%s\n" % l.rstrip() - except: - yield "%s\n" % _mret.rstrip() - - yield '.' - yield '\n' + if self.last == '\n': + return 'invalid command\n' + elif self.last == 'NEGOTIATE V2\n': + return 'V2_OK\n' + request_id, command, request_body = _parse_v2_frame(self.last) + if command != 'GET': + raise Exception('MockSerial only supports GET requests.') + metadata_key = request_body.strip() + if metadata_key in self.mockdata: + return _generate_v2_frame( + request_id, 'SUCCESS', self.mockdata[metadata_key]) + return _generate_v2_frame(request_id, 'NOTFOUND') class TestSmartOSDataSource(helpers.FilesystemMockingTestCase): @@ -459,3 +481,133 @@ def apply_patches(patches): setattr(ref, name, replace) ret.append((ref, name, orig)) return ret + + +class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase): + + def setUp(self): + super(TestJoyentMetadataClient, self).setUp() + self.serial = mock.MagicMock(spec=serial.Serial) + self.request_id = 0xabcdef12 + self.metadata_value = 'value' + self.response_parts = { + 'command': 'SUCCESS', + 'crc': 'b5a9ff00', + 'length': 17 + len(b64e(self.metadata_value)), + 'payload': b64e(self.metadata_value), + 'request_id': '{0:08x}'.format(self.request_id), + } + + def make_response(): + payload = '' + if self.response_parts['payload']: + payload = ' {0}'.format(self.response_parts['payload']) + del self.response_parts['payload'] + return ( + 'V2 {length} {crc} {request_id} {command}{payload}\n'.format( + payload=payload, **self.response_parts).encode('ascii')) + self.serial.readline.side_effect = make_response + self.patched_funcs.enter_context( + mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint', + mock.Mock(return_value=self.request_id))) + + def _get_client(self): + return DataSourceSmartOS.JoyentMetadataClient(self.serial) + + def assertEndsWith(self, haystack, prefix): + self.assertTrue(haystack.endswith(prefix), + "{0} does not end with '{1}'".format( + repr(haystack), prefix)) + + def assertStartsWith(self, haystack, prefix): + self.assertTrue(haystack.startswith(prefix), + "{0} does not start with '{1}'".format( + repr(haystack), prefix)) + + def test_get_metadata_writes_a_single_line(self): + client = self._get_client() + client.get_metadata('some_key') + self.assertEqual(1, self.serial.write.call_count) + written_line = self.serial.write.call_args[0][0] + self.assertEndsWith(written_line, b'\n') + self.assertEqual(1, written_line.count(b'\n')) + + def _get_written_line(self, key='some_key'): + client = self._get_client() + client.get_metadata(key) + return self.serial.write.call_args[0][0] + + def test_get_metadata_writes_bytes(self): + self.assertIsInstance(self._get_written_line(), six.binary_type) + + def test_get_metadata_line_starts_with_v2(self): + self.assertStartsWith(self._get_written_line(), b'V2') + + def test_get_metadata_uses_get_command(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + self.assertEqual('GET', parts[4]) + + def test_get_metadata_base64_encodes_argument(self): + key = 'my_key' + parts = self._get_written_line(key).decode('ascii').strip().split(' ') + self.assertEqual(b64e(key), parts[5]) + + def test_get_metadata_calculates_length_correctly(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + expected_length = len(' '.join(parts[3:])) + self.assertEqual(expected_length, int(parts[1])) + + def test_get_metadata_uses_appropriate_request_id(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + request_id = parts[3] + self.assertEqual(8, len(request_id)) + self.assertEqual(request_id, request_id.lower()) + + def test_get_metadata_uses_random_number_for_request_id(self): + line = self._get_written_line() + request_id = line.decode('ascii').strip().split(' ')[3] + self.assertEqual('{0:08x}'.format(self.request_id), request_id) + + def test_get_metadata_checksums_correctly(self): + parts = self._get_written_line().decode('ascii').strip().split(' ') + expected_checksum = '{0:08x}'.format( + crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff) + checksum = parts[2] + self.assertEqual(expected_checksum, checksum) + + def test_get_metadata_reads_a_line(self): + client = self._get_client() + client.get_metadata('some_key') + self.assertEqual(1, self.serial.readline.call_count) + + def test_get_metadata_returns_valid_value(self): + client = self._get_client() + value = client.get_metadata('some_key') + self.assertEqual(self.metadata_value, value) + + def test_get_metadata_throws_exception_for_incorrect_length(self): + self.response_parts['length'] = 0 + client = self._get_client() + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client.get_metadata, 'some_key') + + def test_get_metadata_throws_exception_for_incorrect_crc(self): + self.response_parts['crc'] = 'deadbeef' + client = self._get_client() + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client.get_metadata, 'some_key') + + def test_get_metadata_throws_exception_for_request_id_mismatch(self): + self.response_parts['request_id'] = 'deadbeef' + client = self._get_client() + client._checksum = lambda _: self.response_parts['crc'] + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client.get_metadata, 'some_key') + + def test_get_metadata_returns_None_if_value_not_found(self): + self.response_parts['payload'] = '' + self.response_parts['command'] = 'NOTFOUND' + self.response_parts['length'] = 17 + client = self._get_client() + client._checksum = lambda _: self.response_parts['crc'] + self.assertIsNone(client.get_metadata('some_key')) |