summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cloudinit/sources/DataSourceSmartOS.py75
-rw-r--r--tests/unittests/test_datasource/test_smartos.py216
2 files changed, 239 insertions, 52 deletions
diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py
index 896fde3f..694a011a 100644
--- a/cloudinit/sources/DataSourceSmartOS.py
+++ b/cloudinit/sources/DataSourceSmartOS.py
@@ -29,9 +29,10 @@
# http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html
# Comments with "@datadictionary" are snippets of the definition
-import base64
import binascii
import os
+import random
+import re
import serial
from cloudinit import log as logging
@@ -301,6 +302,53 @@ def get_serial(seed_device, seed_timeout):
return ser
+class JoyentMetadataFetchException(Exception):
+ pass
+
+
+class JoyentMetadataClient(object):
+
+ def __init__(self, serial):
+ self.serial = serial
+
+ def _checksum(self, body):
+ return '{0:08x}'.format(
+ binascii.crc32(body.encode('utf-8')) & 0xffffffff)
+
+ def _get_value_from_frame(self, expected_request_id, frame):
+ regex = (
+ r'V2 (?P<length>\d+) (?P<checksum>[0-9a-f]+)'
+ r' (?P<body>(?P<request_id>[0-9a-f]+) (?P<status>SUCCESS|NOTFOUND)'
+ r'( (?P<payload>.+))?)')
+ frame_data = re.match(regex, frame).groupdict()
+ if int(frame_data['length']) != len(frame_data['body']):
+ raise JoyentMetadataFetchException(
+ 'Incorrect frame length given ({0} != {1}).'.format(
+ frame_data['length'], len(frame_data['body'])))
+ expected_checksum = self._checksum(frame_data['body'])
+ if frame_data['checksum'] != expected_checksum:
+ raise JoyentMetadataFetchException(
+ 'Invalid checksum (expected: {0}; got {1}).'.format(
+ expected_checksum, frame_data['checksum']))
+ if frame_data['request_id'] != expected_request_id:
+ raise JoyentMetadataFetchException(
+ 'Request ID mismatch (expected: {0}; got {1}).'.format(
+ expected_request_id, frame_data['request_id']))
+ if not frame_data.get('payload', None):
+ return None
+ return util.b64d(frame_data['payload'])
+
+ def get_metadata(self, metadata_key):
+ request_id = '{0:08x}'.format(random.randint(0, 0xffffffff))
+ message_body = '{0} GET {1}'.format(request_id,
+ util.b64e(metadata_key))
+ msg = 'V2 {0} {1} {2}\n'.format(
+ len(message_body), self._checksum(message_body), message_body)
+ self.serial.write(msg.encode('ascii'))
+ response = self.serial.readline().decode('ascii')
+ return self._get_value_from_frame(request_id, response)
+
+
def query_data(noun, seed_device, seed_timeout, strip=False, default=None,
b64=None):
"""Makes a request to via the serial console via "GET <NOUN>"
@@ -314,34 +362,21 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None,
encoded, so this method relies on being told if the data is base64 or
not.
"""
-
if not noun:
return False
ser = get_serial(seed_device, seed_timeout)
- request_line = "GET %s\n" % noun.rstrip()
- ser.write(request_line.encode('ascii'))
- status = str(ser.readline()).rstrip()
- response = []
- eom_found = False
-
- if 'SUCCESS' not in status:
- ser.close()
- return default
-
- while not eom_found:
- m = ser.readline().decode('ascii')
- if m.rstrip() == ".":
- eom_found = True
- else:
- response.append(m)
+ client = JoyentMetadataClient(ser)
+ response = client.get_metadata(noun)
ser.close()
+ if response is None:
+ return default
if b64 is None:
b64 = query_data('b64-%s' % noun, seed_device=seed_device,
- seed_timeout=seed_timeout, b64=False,
- default=False, strip=True)
+ seed_timeout=seed_timeout, b64=False,
+ default=False, strip=True)
b64 = util.is_true(b64)
resp = None
diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py
index cdd83bf8..c79cf3aa 100644
--- a/tests/unittests/test_datasource/test_smartos.py
+++ b/tests/unittests/test_datasource/test_smartos.py
@@ -31,15 +31,24 @@ import shutil
import stat
import tempfile
import uuid
+from binascii import crc32
+
+import serial
+import six
import six
from cloudinit import helpers as c_helpers
from cloudinit.sources import DataSourceSmartOS
-from cloudinit.util import b64e
+from cloudinit.util import b64d, b64e
from .. import helpers
+try:
+ from unittest import mock
+except ImportError:
+ import mock
+
MOCK_RETURNS = {
'hostname': 'test-host',
'root_authorized_keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname',
@@ -57,6 +66,37 @@ MOCK_RETURNS = {
DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc')
+def _checksum(body):
+ return '{0:08x}'.format(crc32(body.encode('utf-8')) & 0xffffffff)
+
+
+def _generate_v2_frame(request_id, command, body=None):
+ body_parts = [request_id, command]
+ if body:
+ body_parts.append(b64e(body))
+ message_body = ' '.join(body_parts)
+ return 'V2 {0} {1} {2}\n'.format(
+ len(message_body), _checksum(message_body), message_body).encode(
+ 'ascii')
+
+
+def _parse_v2_frame(line):
+ line = line.decode('ascii')
+ if not line.endswith('\n'):
+ raise Exception('Frames must end with a newline.')
+ version, length, checksum, body = line.strip().split(' ', 3)
+ if version != 'V2':
+ raise Exception('Frames must begin with V2.')
+ if int(length) != len(body):
+ raise Exception('Incorrect frame length given ({0} != {1}).'.format(
+ length, len(body)))
+ expected_checksum = _checksum(body)
+ if checksum != expected_checksum:
+ raise Exception('Invalid checksum.')
+ request_id, command, payload = body.split()
+ return request_id, command, b64d(payload)
+
+
class MockSerial(object):
"""Fake a serial terminal for testing the code that
interfaces with the serial"""
@@ -81,39 +121,21 @@ class MockSerial(object):
return True
def write(self, line):
- if not isinstance(line, six.binary_type):
- raise TypeError("Should be writing binary lines.")
- line = line.decode('ascii').replace('GET ', '')
- self.last = line.rstrip()
+ self.last = line
def readline(self):
- if self.new:
- self.new = False
- if self.last in self.mockdata:
- line = 'SUCCESS\n'
- else:
- line = 'NOTFOUND %s\n' % self.last
-
- elif self.last in self.mockdata:
- if not self.mocked_out:
- self.mocked_out = [x for x in self._format_out()]
-
- if len(self.mocked_out) > self.count:
- self.count += 1
- line = self.mocked_out[self.count - 1]
- return line.encode('ascii')
-
- def _format_out(self):
- if self.last in self.mockdata:
- _mret = self.mockdata[self.last]
- try:
- for l in _mret.splitlines():
- yield "%s\n" % l.rstrip()
- except:
- yield "%s\n" % _mret.rstrip()
-
- yield '.'
- yield '\n'
+ if self.last == '\n':
+ return 'invalid command\n'
+ elif self.last == 'NEGOTIATE V2\n':
+ return 'V2_OK\n'
+ request_id, command, request_body = _parse_v2_frame(self.last)
+ if command != 'GET':
+ raise Exception('MockSerial only supports GET requests.')
+ metadata_key = request_body.strip()
+ if metadata_key in self.mockdata:
+ return _generate_v2_frame(
+ request_id, 'SUCCESS', self.mockdata[metadata_key])
+ return _generate_v2_frame(request_id, 'NOTFOUND')
class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
@@ -459,3 +481,133 @@ def apply_patches(patches):
setattr(ref, name, replace)
ret.append((ref, name, orig))
return ret
+
+
+class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase):
+
+ def setUp(self):
+ super(TestJoyentMetadataClient, self).setUp()
+ self.serial = mock.MagicMock(spec=serial.Serial)
+ self.request_id = 0xabcdef12
+ self.metadata_value = 'value'
+ self.response_parts = {
+ 'command': 'SUCCESS',
+ 'crc': 'b5a9ff00',
+ 'length': 17 + len(b64e(self.metadata_value)),
+ 'payload': b64e(self.metadata_value),
+ 'request_id': '{0:08x}'.format(self.request_id),
+ }
+
+ def make_response():
+ payload = ''
+ if self.response_parts['payload']:
+ payload = ' {0}'.format(self.response_parts['payload'])
+ del self.response_parts['payload']
+ return (
+ 'V2 {length} {crc} {request_id} {command}{payload}\n'.format(
+ payload=payload, **self.response_parts).encode('ascii'))
+ self.serial.readline.side_effect = make_response
+ self.patched_funcs.enter_context(
+ mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint',
+ mock.Mock(return_value=self.request_id)))
+
+ def _get_client(self):
+ return DataSourceSmartOS.JoyentMetadataClient(self.serial)
+
+ def assertEndsWith(self, haystack, prefix):
+ self.assertTrue(haystack.endswith(prefix),
+ "{0} does not end with '{1}'".format(
+ repr(haystack), prefix))
+
+ def assertStartsWith(self, haystack, prefix):
+ self.assertTrue(haystack.startswith(prefix),
+ "{0} does not start with '{1}'".format(
+ repr(haystack), prefix))
+
+ def test_get_metadata_writes_a_single_line(self):
+ client = self._get_client()
+ client.get_metadata('some_key')
+ self.assertEqual(1, self.serial.write.call_count)
+ written_line = self.serial.write.call_args[0][0]
+ self.assertEndsWith(written_line, b'\n')
+ self.assertEqual(1, written_line.count(b'\n'))
+
+ def _get_written_line(self, key='some_key'):
+ client = self._get_client()
+ client.get_metadata(key)
+ return self.serial.write.call_args[0][0]
+
+ def test_get_metadata_writes_bytes(self):
+ self.assertIsInstance(self._get_written_line(), six.binary_type)
+
+ def test_get_metadata_line_starts_with_v2(self):
+ self.assertStartsWith(self._get_written_line(), b'V2')
+
+ def test_get_metadata_uses_get_command(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ self.assertEqual('GET', parts[4])
+
+ def test_get_metadata_base64_encodes_argument(self):
+ key = 'my_key'
+ parts = self._get_written_line(key).decode('ascii').strip().split(' ')
+ self.assertEqual(b64e(key), parts[5])
+
+ def test_get_metadata_calculates_length_correctly(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ expected_length = len(' '.join(parts[3:]))
+ self.assertEqual(expected_length, int(parts[1]))
+
+ def test_get_metadata_uses_appropriate_request_id(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ request_id = parts[3]
+ self.assertEqual(8, len(request_id))
+ self.assertEqual(request_id, request_id.lower())
+
+ def test_get_metadata_uses_random_number_for_request_id(self):
+ line = self._get_written_line()
+ request_id = line.decode('ascii').strip().split(' ')[3]
+ self.assertEqual('{0:08x}'.format(self.request_id), request_id)
+
+ def test_get_metadata_checksums_correctly(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ expected_checksum = '{0:08x}'.format(
+ crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff)
+ checksum = parts[2]
+ self.assertEqual(expected_checksum, checksum)
+
+ def test_get_metadata_reads_a_line(self):
+ client = self._get_client()
+ client.get_metadata('some_key')
+ self.assertEqual(1, self.serial.readline.call_count)
+
+ def test_get_metadata_returns_valid_value(self):
+ client = self._get_client()
+ value = client.get_metadata('some_key')
+ self.assertEqual(self.metadata_value, value)
+
+ def test_get_metadata_throws_exception_for_incorrect_length(self):
+ self.response_parts['length'] = 0
+ client = self._get_client()
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_throws_exception_for_incorrect_crc(self):
+ self.response_parts['crc'] = 'deadbeef'
+ client = self._get_client()
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_throws_exception_for_request_id_mismatch(self):
+ self.response_parts['request_id'] = 'deadbeef'
+ client = self._get_client()
+ client._checksum = lambda _: self.response_parts['crc']
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_returns_None_if_value_not_found(self):
+ self.response_parts['payload'] = ''
+ self.response_parts['command'] = 'NOTFOUND'
+ self.response_parts['length'] = 17
+ client = self._get_client()
+ client._checksum = lambda _: self.response_parts['crc']
+ self.assertIsNone(client.get_metadata('some_key'))