summaryrefslogtreecommitdiff
path: root/tests/unittests/test_datasource/test_smartos.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittests/test_datasource/test_smartos.py')
-rw-r--r--tests/unittests/test_datasource/test_smartos.py216
1 files changed, 184 insertions, 32 deletions
diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py
index cdd83bf8..c79cf3aa 100644
--- a/tests/unittests/test_datasource/test_smartos.py
+++ b/tests/unittests/test_datasource/test_smartos.py
@@ -31,15 +31,24 @@ import shutil
import stat
import tempfile
import uuid
+from binascii import crc32
+
+import serial
+import six
import six
from cloudinit import helpers as c_helpers
from cloudinit.sources import DataSourceSmartOS
-from cloudinit.util import b64e
+from cloudinit.util import b64d, b64e
from .. import helpers
+try:
+ from unittest import mock
+except ImportError:
+ import mock
+
MOCK_RETURNS = {
'hostname': 'test-host',
'root_authorized_keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname',
@@ -57,6 +66,37 @@ MOCK_RETURNS = {
DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc')
+def _checksum(body):
+ return '{0:08x}'.format(crc32(body.encode('utf-8')) & 0xffffffff)
+
+
+def _generate_v2_frame(request_id, command, body=None):
+ body_parts = [request_id, command]
+ if body:
+ body_parts.append(b64e(body))
+ message_body = ' '.join(body_parts)
+ return 'V2 {0} {1} {2}\n'.format(
+ len(message_body), _checksum(message_body), message_body).encode(
+ 'ascii')
+
+
+def _parse_v2_frame(line):
+ line = line.decode('ascii')
+ if not line.endswith('\n'):
+ raise Exception('Frames must end with a newline.')
+ version, length, checksum, body = line.strip().split(' ', 3)
+ if version != 'V2':
+ raise Exception('Frames must begin with V2.')
+ if int(length) != len(body):
+ raise Exception('Incorrect frame length given ({0} != {1}).'.format(
+ length, len(body)))
+ expected_checksum = _checksum(body)
+ if checksum != expected_checksum:
+ raise Exception('Invalid checksum.')
+ request_id, command, payload = body.split()
+ return request_id, command, b64d(payload)
+
+
class MockSerial(object):
"""Fake a serial terminal for testing the code that
interfaces with the serial"""
@@ -81,39 +121,21 @@ class MockSerial(object):
return True
def write(self, line):
- if not isinstance(line, six.binary_type):
- raise TypeError("Should be writing binary lines.")
- line = line.decode('ascii').replace('GET ', '')
- self.last = line.rstrip()
+ self.last = line
def readline(self):
- if self.new:
- self.new = False
- if self.last in self.mockdata:
- line = 'SUCCESS\n'
- else:
- line = 'NOTFOUND %s\n' % self.last
-
- elif self.last in self.mockdata:
- if not self.mocked_out:
- self.mocked_out = [x for x in self._format_out()]
-
- if len(self.mocked_out) > self.count:
- self.count += 1
- line = self.mocked_out[self.count - 1]
- return line.encode('ascii')
-
- def _format_out(self):
- if self.last in self.mockdata:
- _mret = self.mockdata[self.last]
- try:
- for l in _mret.splitlines():
- yield "%s\n" % l.rstrip()
- except:
- yield "%s\n" % _mret.rstrip()
-
- yield '.'
- yield '\n'
+ if self.last == '\n':
+ return 'invalid command\n'
+ elif self.last == 'NEGOTIATE V2\n':
+ return 'V2_OK\n'
+ request_id, command, request_body = _parse_v2_frame(self.last)
+ if command != 'GET':
+ raise Exception('MockSerial only supports GET requests.')
+ metadata_key = request_body.strip()
+ if metadata_key in self.mockdata:
+ return _generate_v2_frame(
+ request_id, 'SUCCESS', self.mockdata[metadata_key])
+ return _generate_v2_frame(request_id, 'NOTFOUND')
class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
@@ -459,3 +481,133 @@ def apply_patches(patches):
setattr(ref, name, replace)
ret.append((ref, name, orig))
return ret
+
+
+class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase):
+
+ def setUp(self):
+ super(TestJoyentMetadataClient, self).setUp()
+ self.serial = mock.MagicMock(spec=serial.Serial)
+ self.request_id = 0xabcdef12
+ self.metadata_value = 'value'
+ self.response_parts = {
+ 'command': 'SUCCESS',
+ 'crc': 'b5a9ff00',
+ 'length': 17 + len(b64e(self.metadata_value)),
+ 'payload': b64e(self.metadata_value),
+ 'request_id': '{0:08x}'.format(self.request_id),
+ }
+
+ def make_response():
+ payload = ''
+ if self.response_parts['payload']:
+ payload = ' {0}'.format(self.response_parts['payload'])
+ del self.response_parts['payload']
+ return (
+ 'V2 {length} {crc} {request_id} {command}{payload}\n'.format(
+ payload=payload, **self.response_parts).encode('ascii'))
+ self.serial.readline.side_effect = make_response
+ self.patched_funcs.enter_context(
+ mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint',
+ mock.Mock(return_value=self.request_id)))
+
+ def _get_client(self):
+ return DataSourceSmartOS.JoyentMetadataClient(self.serial)
+
+ def assertEndsWith(self, haystack, prefix):
+ self.assertTrue(haystack.endswith(prefix),
+ "{0} does not end with '{1}'".format(
+ repr(haystack), prefix))
+
+ def assertStartsWith(self, haystack, prefix):
+ self.assertTrue(haystack.startswith(prefix),
+ "{0} does not start with '{1}'".format(
+ repr(haystack), prefix))
+
+ def test_get_metadata_writes_a_single_line(self):
+ client = self._get_client()
+ client.get_metadata('some_key')
+ self.assertEqual(1, self.serial.write.call_count)
+ written_line = self.serial.write.call_args[0][0]
+ self.assertEndsWith(written_line, b'\n')
+ self.assertEqual(1, written_line.count(b'\n'))
+
+ def _get_written_line(self, key='some_key'):
+ client = self._get_client()
+ client.get_metadata(key)
+ return self.serial.write.call_args[0][0]
+
+ def test_get_metadata_writes_bytes(self):
+ self.assertIsInstance(self._get_written_line(), six.binary_type)
+
+ def test_get_metadata_line_starts_with_v2(self):
+ self.assertStartsWith(self._get_written_line(), b'V2')
+
+ def test_get_metadata_uses_get_command(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ self.assertEqual('GET', parts[4])
+
+ def test_get_metadata_base64_encodes_argument(self):
+ key = 'my_key'
+ parts = self._get_written_line(key).decode('ascii').strip().split(' ')
+ self.assertEqual(b64e(key), parts[5])
+
+ def test_get_metadata_calculates_length_correctly(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ expected_length = len(' '.join(parts[3:]))
+ self.assertEqual(expected_length, int(parts[1]))
+
+ def test_get_metadata_uses_appropriate_request_id(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ request_id = parts[3]
+ self.assertEqual(8, len(request_id))
+ self.assertEqual(request_id, request_id.lower())
+
+ def test_get_metadata_uses_random_number_for_request_id(self):
+ line = self._get_written_line()
+ request_id = line.decode('ascii').strip().split(' ')[3]
+ self.assertEqual('{0:08x}'.format(self.request_id), request_id)
+
+ def test_get_metadata_checksums_correctly(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ expected_checksum = '{0:08x}'.format(
+ crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff)
+ checksum = parts[2]
+ self.assertEqual(expected_checksum, checksum)
+
+ def test_get_metadata_reads_a_line(self):
+ client = self._get_client()
+ client.get_metadata('some_key')
+ self.assertEqual(1, self.serial.readline.call_count)
+
+ def test_get_metadata_returns_valid_value(self):
+ client = self._get_client()
+ value = client.get_metadata('some_key')
+ self.assertEqual(self.metadata_value, value)
+
+ def test_get_metadata_throws_exception_for_incorrect_length(self):
+ self.response_parts['length'] = 0
+ client = self._get_client()
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_throws_exception_for_incorrect_crc(self):
+ self.response_parts['crc'] = 'deadbeef'
+ client = self._get_client()
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_throws_exception_for_request_id_mismatch(self):
+ self.response_parts['request_id'] = 'deadbeef'
+ client = self._get_client()
+ client._checksum = lambda _: self.response_parts['crc']
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_returns_None_if_value_not_found(self):
+ self.response_parts['payload'] = ''
+ self.response_parts['command'] = 'NOTFOUND'
+ self.response_parts['length'] = 17
+ client = self._get_client()
+ client._checksum = lambda _: self.response_parts['crc']
+ self.assertIsNone(client.get_metadata('some_key'))