diff options
author | Mike Gerdts <mike.gerdts@joyent.com> | 2018-04-18 13:55:17 -0400 |
---|---|---|
committer | Scott Moser <smoser@brickies.net> | 2018-04-18 13:55:17 -0400 |
commit | 4c573d0e0173d2b1e99a383c54a0a6c957aa1cbb (patch) | |
tree | 299a60fe312a0f0752f70a968da45269a7c8ab5a | |
parent | 025ddc0329d9314f131cea35075734916797b439 (diff) | |
download | vyos-cloud-init-4c573d0e0173d2b1e99a383c54a0a6c957aa1cbb.tar.gz vyos-cloud-init-4c573d0e0173d2b1e99a383c54a0a6c957aa1cbb.zip |
DataSourceSmartOS: fix hang when metadata service is down
If the metadata service in the host is down while a guest that uses
DataSourceSmartOS is booting, the request from the guest falls into the
bit bucket. When the metadata service is eventually started, the guest
has no awareness of this and does not resend the request. This results in
cloud-init hanging forever with a guest reboot as the only recovery
option.
This fix updates the metadata protocol to implement the initialization
phase, just as is implemented by mdata-get and related utilities. The
initialization phase includes draining all pending data from the serial
port, writing an empty command and getting an expected error message in
reply. If the initialization phase times out, it is retried every five
seconds. Each timeout results in a warning message: "Timeout while
initializing metadata client. Is the host metadata service running?" By
default, warning messages are logged to the console, thus the reason for a
hung boot is readily apparent.
LP: #1667735
-rw-r--r-- | cloudinit/sources/DataSourceSmartOS.py | 117 | ||||
-rw-r--r-- | tests/unittests/test_datasource/test_smartos.py | 103 |
2 files changed, 204 insertions, 16 deletions
diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 5dd8a125..c8998b40 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -1,4 +1,5 @@ # Copyright (C) 2013 Canonical Ltd. +# Copyright (c) 2018, Joyent, Inc. # # Author: Ben Howard <ben.howard@canonical.com> # @@ -21,6 +22,7 @@ import base64 import binascii +import errno import json import os import random @@ -229,6 +231,9 @@ class DataSourceSmartOS(sources.DataSource): self.md_client) return False + # Open once for many requests, rather than once for each request + self.md_client.open_transport() + for ci_noun, attribute in SMARTOS_ATTRIB_MAP.items(): smartos_noun, strip = attribute md[ci_noun] = self.md_client.get(smartos_noun, strip=strip) @@ -236,6 +241,8 @@ class DataSourceSmartOS(sources.DataSource): for ci_noun, smartos_noun in SMARTOS_ATTRIB_JSON.items(): md[ci_noun] = self.md_client.get_json(smartos_noun) + self.md_client.close_transport() + # @datadictionary: This key may contain a program that is written # to a file in the filesystem of the guest on each boot and then # executed. It may be of any format that would be considered @@ -316,6 +323,10 @@ class JoyentMetadataFetchException(Exception): pass +class JoyentMetadataTimeoutException(JoyentMetadataFetchException): + pass + + class JoyentMetadataClient(object): """ A client implementing v2 of the Joyent Metadata Protocol Specification. @@ -360,6 +371,47 @@ class JoyentMetadataClient(object): LOG.debug('Value "%s" found.', value) return value + def _readline(self): + """ + Reads a line a byte at a time until \n is encountered. Returns an + ascii string with the trailing newline removed. + + If a timeout (per-byte) is set and it expires, a + JoyentMetadataFetchException will be thrown. + """ + response = [] + + def as_ascii(): + return b''.join(response).decode('ascii') + + msg = "Partial response: '%s'" + while True: + try: + byte = self.fp.read(1) + if len(byte) == 0: + raise JoyentMetadataTimeoutException(msg % as_ascii()) + if byte == b'\n': + return as_ascii() + response.append(byte) + except OSError as exc: + if exc.errno == errno.EAGAIN: + raise JoyentMetadataTimeoutException(msg % as_ascii()) + raise + + def _write(self, msg): + self.fp.write(msg.encode('ascii')) + self.fp.flush() + + def _negotiate(self): + LOG.debug('Negotiating protocol V2') + self._write('NEGOTIATE V2\n') + response = self._readline() + LOG.debug('read "%s"', response) + if response != 'V2_OK': + raise JoyentMetadataFetchException( + 'Invalid response "%s" to "NEGOTIATE V2"' % response) + LOG.debug('Negotiation complete') + def request(self, rtype, param=None): request_id = '{0:08x}'.format(random.randint(0, 0xffffffff)) message_body = ' '.join((request_id, rtype,)) @@ -374,18 +426,11 @@ class JoyentMetadataClient(object): self.open_transport() need_close = True - self.fp.write(msg.encode('ascii')) - self.fp.flush() - - response = bytearray() - response.extend(self.fp.read(1)) - while response[-1:] != b'\n': - response.extend(self.fp.read(1)) - + self._write(msg) + response = self._readline() if need_close: self.close_transport() - response = response.rstrip().decode('ascii') LOG.debug('Read "%s" from metadata transport.', response) if 'SUCCESS' not in response: @@ -450,6 +495,7 @@ class JoyentMetadataSocketClient(JoyentMetadataClient): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.connect(self.socketpath) self.fp = sock.makefile('rwb') + self._negotiate() def exists(self): return os.path.exists(self.socketpath) @@ -459,8 +505,9 @@ class JoyentMetadataSocketClient(JoyentMetadataClient): class JoyentMetadataSerialClient(JoyentMetadataClient): - def __init__(self, device, timeout=10, smartos_type=SMARTOS_ENV_KVM): - super(JoyentMetadataSerialClient, self).__init__(smartos_type) + def __init__(self, device, timeout=10, smartos_type=SMARTOS_ENV_KVM, + fp=None): + super(JoyentMetadataSerialClient, self).__init__(smartos_type, fp) self.device = device self.timeout = timeout @@ -468,10 +515,50 @@ class JoyentMetadataSerialClient(JoyentMetadataClient): return os.path.exists(self.device) def open_transport(self): - ser = serial.Serial(self.device, timeout=self.timeout) - if not ser.isOpen(): - raise SystemError("Unable to open %s" % self.device) - self.fp = ser + if self.fp is None: + ser = serial.Serial(self.device, timeout=self.timeout) + if not ser.isOpen(): + raise SystemError("Unable to open %s" % self.device) + self.fp = ser + self._flush() + self._negotiate() + + def _flush(self): + LOG.debug('Flushing input') + # Read any pending data + timeout = self.fp.timeout + self.fp.timeout = 0.1 + while True: + try: + self._readline() + except JoyentMetadataTimeoutException: + break + LOG.debug('Input empty') + + # Send a newline and expect "invalid command". Keep trying until + # successful. Retry rather frequently so that the "Is the host + # metadata service running" appears on the console soon after someone + # attaches in an effort to debug. + if timeout > 5: + self.fp.timeout = 5 + else: + self.fp.timeout = timeout + while True: + LOG.debug('Writing newline, expecting "invalid command"') + self._write('\n') + try: + response = self._readline() + if response == 'invalid command': + break + if response == 'FAILURE': + LOG.debug('Got "FAILURE". Retrying.') + continue + LOG.warning('Unexpected response "%s" during flush', response) + except JoyentMetadataTimeoutException: + LOG.warning('Timeout while initializing metadata client. ' + + 'Is the host metadata service running?') + LOG.debug('Got "invalid command". Flush complete.') + self.fp.timeout = timeout def __repr__(self): return "%s(device=%s, timeout=%s)" % ( diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py index 88bae5f9..2bea7a17 100644 --- a/tests/unittests/test_datasource/test_smartos.py +++ b/tests/unittests/test_datasource/test_smartos.py @@ -1,4 +1,5 @@ # Copyright (C) 2013 Canonical Ltd. +# Copyright (c) 2018, Joyent, Inc. # # Author: Ben Howard <ben.howard@canonical.com> # @@ -324,6 +325,7 @@ class PsuedoJoyentClient(object): if data is None: data = MOCK_RETURNS.copy() self.data = data + self._is_open = False return def get(self, key, default=None, strip=False): @@ -344,6 +346,14 @@ class PsuedoJoyentClient(object): def exists(self): return True + def open_transport(self): + assert(not self._is_open) + self._is_open = True + + def close_transport(self): + assert(self._is_open) + self._is_open = False + class TestSmartOSDataSource(FilesystemMockingTestCase): def setUp(self): @@ -592,8 +602,46 @@ class TestSmartOSDataSource(FilesystemMockingTestCase): mydscfg['disk_aliases']['FOO']) +class ShortReader(object): + """Implements a 'read' interface for bytes provided. + much like io.BytesIO but the 'endbyte' acts as if EOF. + When it is reached a short will be returned.""" + def __init__(self, initial_bytes, endbyte=b'\0'): + self.data = initial_bytes + self.index = 0 + self.len = len(self.data) + self.endbyte = endbyte + + @property + def emptied(self): + return self.index >= self.len + + def read(self, size=-1): + """Read size bytes but not past a null.""" + if size == 0 or self.index >= self.len: + return b'' + + rsize = size + if size < 0 or size + self.index > self.len: + rsize = self.len - self.index + + next_null = self.data.find(self.endbyte, self.index, rsize) + if next_null >= 0: + rsize = next_null - self.index + 1 + i = self.index + self.index += rsize + ret = self.data[i:i + rsize] + if len(ret) and ret[-1:] == self.endbyte: + ret = ret[:-1] + return ret + + class TestJoyentMetadataClient(FilesystemMockingTestCase): + invalid = b'invalid command\n' + failure = b'FAILURE\n' + v2_ok = b'V2_OK\n' + def setUp(self): super(TestJoyentMetadataClient, self).setUp() @@ -636,6 +684,11 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): return DataSourceSmartOS.JoyentMetadataClient( fp=self.serial, smartos_type=DataSourceSmartOS.SMARTOS_ENV_KVM) + def _get_serial_client(self): + self.serial.timeout = 1 + return DataSourceSmartOS.JoyentMetadataSerialClient(None, + fp=self.serial) + def assertEndsWith(self, haystack, prefix): self.assertTrue(haystack.endswith(prefix), "{0} does not end with '{1}'".format( @@ -646,12 +699,14 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): "{0} does not start with '{1}'".format( repr(haystack), prefix)) + def assertNoMoreSideEffects(self, obj): + self.assertRaises(StopIteration, obj) + def test_get_metadata_writes_a_single_line(self): client = self._get_client() client.get('some_key') self.assertEqual(1, self.serial.write.call_count) written_line = self.serial.write.call_args[0][0] - print(type(written_line)) self.assertEndsWith(written_line.decode('ascii'), b'\n'.decode('ascii')) self.assertEqual(1, written_line.count(b'\n')) @@ -737,6 +792,52 @@ class TestJoyentMetadataClient(FilesystemMockingTestCase): client._checksum = lambda _: self.response_parts['crc'] self.assertIsNone(client.get('some_key')) + def test_negotiate(self): + client = self._get_client() + reader = ShortReader(self.v2_ok) + client.fp.read.side_effect = reader.read + client._negotiate() + self.assertTrue(reader.emptied) + + def test_negotiate_short_response(self): + client = self._get_client() + # chopped '\n' from v2_ok. + reader = ShortReader(self.v2_ok[:-1] + b'\0') + client.fp.read.side_effect = reader.read + self.assertRaises(DataSourceSmartOS.JoyentMetadataTimeoutException, + client._negotiate) + self.assertTrue(reader.emptied) + + def test_negotiate_bad_response(self): + client = self._get_client() + reader = ShortReader(b'garbage\n' + self.v2_ok) + client.fp.read.side_effect = reader.read + self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException, + client._negotiate) + self.assertEqual(self.v2_ok, client.fp.read()) + + def test_serial_open_transport(self): + client = self._get_serial_client() + reader = ShortReader(b'garbage\0' + self.invalid + self.v2_ok) + client.fp.read.side_effect = reader.read + client.open_transport() + self.assertTrue(reader.emptied) + + def test_flush_failure(self): + client = self._get_serial_client() + reader = ShortReader(b'garbage' + b'\0' + self.failure + + self.invalid + self.v2_ok) + client.fp.read.side_effect = reader.read + client.open_transport() + self.assertTrue(reader.emptied) + + def test_flush_many_timeouts(self): + client = self._get_serial_client() + reader = ShortReader(b'\0' * 100 + self.invalid + self.v2_ok) + client.fp.read.side_effect = reader.read + client.open_transport() + self.assertTrue(reader.emptied) + class TestNetworkConversion(TestCase): def test_convert_simple(self): |