summaryrefslogtreecommitdiff
path: root/tests/unittests/test_datasource
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittests/test_datasource')
-rw-r--r--tests/unittests/test_datasource/test_azure.py217
-rw-r--r--tests/unittests/test_datasource/test_configdrive.py15
-rw-r--r--tests/unittests/test_datasource/test_digitalocean.py3
-rw-r--r--tests/unittests/test_datasource/test_gce.py54
-rw-r--r--tests/unittests/test_datasource/test_maas.py8
-rw-r--r--tests/unittests/test_datasource/test_nocloud.py14
-rw-r--r--tests/unittests/test_datasource/test_openstack.py8
-rw-r--r--tests/unittests/test_datasource/test_smartos.py228
8 files changed, 414 insertions, 133 deletions
diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py
index 8112c69b..7e789853 100644
--- a/tests/unittests/test_datasource/test_azure.py
+++ b/tests/unittests/test_datasource/test_azure.py
@@ -116,9 +116,6 @@ class TestAzureDataSource(TestCase):
data['iid_from_shared_cfg'] = path
return 'i-my-azure-id'
- def _apply_hostname_bounce(**kwargs):
- data['apply_hostname_bounce'] = kwargs
-
if data.get('ovfcontent') is not None:
populate_dir(os.path.join(self.paths.seed_dir, "azure"),
{'ovf-env.xml': data['ovfcontent']})
@@ -132,7 +129,9 @@ class TestAzureDataSource(TestCase):
(mod, 'wait_for_files', _wait_for_files),
(mod, 'pubkeys_from_crt_files', _pubkeys_from_crt_files),
(mod, 'iid_from_shared_config', _iid_from_shared_config),
- (mod, 'apply_hostname_bounce', _apply_hostname_bounce),
+ (mod, 'perform_hostname_bounce', mock.MagicMock()),
+ (mod, 'get_hostname', mock.MagicMock()),
+ (mod, 'set_hostname', mock.MagicMock()),
])
dsrc = mod.DataSourceAzureNet(
@@ -272,47 +271,6 @@ class TestAzureDataSource(TestCase):
for mypk in mypklist:
self.assertIn(mypk, dsrc.cfg['_pubkeys'])
- def test_disabled_bounce(self):
- pass
-
- def test_apply_bounce_call_1(self):
- # hostname needs to get through to apply_hostname_bounce
- odata = {'HostName': 'my-random-hostname'}
- data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
-
- self._get_ds(data).get_data()
- self.assertIn('hostname', data['apply_hostname_bounce'])
- self.assertEqual(data['apply_hostname_bounce']['hostname'],
- odata['HostName'])
-
- def test_apply_bounce_call_configurable(self):
- # hostname_bounce should be configurable in datasource cfg
- cfg = {'hostname_bounce': {'interface': 'eth1', 'policy': 'off',
- 'command': 'my-bounce-command',
- 'hostname_command': 'my-hostname-command'}}
- odata = {'HostName': "xhost",
- 'dscfg': {'text': b64e(yaml.dump(cfg)),
- 'encoding': 'base64'}}
- data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
- self._get_ds(data).get_data()
-
- for k in cfg['hostname_bounce']:
- self.assertIn(k, data['apply_hostname_bounce'])
-
- for k, v in cfg['hostname_bounce'].items():
- self.assertEqual(data['apply_hostname_bounce'][k], v)
-
- def test_set_hostname_disabled(self):
- # config specifying set_hostname off should not bounce
- cfg = {'set_hostname': False}
- odata = {'HostName': "xhost",
- 'dscfg': {'text': b64e(yaml.dump(cfg)),
- 'encoding': 'base64'}}
- data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
- self._get_ds(data).get_data()
-
- self.assertEqual(data.get('apply_hostname_bounce', "N/A"), "N/A")
-
def test_default_ephemeral(self):
# make sure the ephemeral device works
odata = {}
@@ -425,6 +383,175 @@ class TestAzureDataSource(TestCase):
load_file(os.path.join(self.waagent_d, 'ovf-env.xml')))
+class TestAzureBounce(TestCase):
+
+ def mock_out_azure_moving_parts(self):
+ self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'invoke_agent'))
+ self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'wait_for_files'))
+ self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'iid_from_shared_config',
+ mock.MagicMock(return_value='i-my-azure-id')))
+ self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'list_possible_azure_ds_devs',
+ mock.MagicMock(return_value=[])))
+ self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'find_ephemeral_disk',
+ mock.MagicMock(return_value=None)))
+ self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'find_ephemeral_part',
+ mock.MagicMock(return_value=None)))
+
+ def setUp(self):
+ super(TestAzureBounce, self).setUp()
+ self.tmp = tempfile.mkdtemp()
+ self.waagent_d = os.path.join(self.tmp, 'var', 'lib', 'waagent')
+ self.paths = helpers.Paths({'cloud_dir': self.tmp})
+ self.addCleanup(shutil.rmtree, self.tmp)
+ DataSourceAzure.BUILTIN_DS_CONFIG['data_dir'] = self.waagent_d
+ self.patches = ExitStack()
+ self.mock_out_azure_moving_parts()
+ self.get_hostname = self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'get_hostname'))
+ self.set_hostname = self.patches.enter_context(
+ mock.patch.object(DataSourceAzure, 'set_hostname'))
+ self.subp = self.patches.enter_context(
+ mock.patch('cloudinit.sources.DataSourceAzure.util.subp'))
+
+ def tearDown(self):
+ self.patches.close()
+
+ def _get_ds(self, ovfcontent=None):
+ if ovfcontent is not None:
+ populate_dir(os.path.join(self.paths.seed_dir, "azure"),
+ {'ovf-env.xml': ovfcontent})
+ return DataSourceAzure.DataSourceAzureNet(
+ {}, distro=None, paths=self.paths)
+
+ def get_ovf_env_with_dscfg(self, hostname, cfg):
+ odata = {
+ 'HostName': hostname,
+ 'dscfg': {
+ 'text': b64e(yaml.dump(cfg)),
+ 'encoding': 'base64'
+ }
+ }
+ return construct_valid_ovf_env(data=odata)
+
+ def test_disabled_bounce_does_not_change_hostname(self):
+ cfg = {'hostname_bounce': {'policy': 'off'}}
+ self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
+ self.assertEqual(0, self.set_hostname.call_count)
+
+ @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
+ def test_disabled_bounce_does_not_perform_bounce(
+ self, perform_hostname_bounce):
+ cfg = {'hostname_bounce': {'policy': 'off'}}
+ self._get_ds(self.get_ovf_env_with_dscfg('test-host', cfg)).get_data()
+ self.assertEqual(0, perform_hostname_bounce.call_count)
+
+ def test_same_hostname_does_not_change_hostname(self):
+ host_name = 'unchanged-host-name'
+ self.get_hostname.return_value = host_name
+ cfg = {'hostname_bounce': {'policy': 'yes'}}
+ self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
+ self.assertEqual(0, self.set_hostname.call_count)
+
+ @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
+ def test_unchanged_hostname_does_not_perform_bounce(
+ self, perform_hostname_bounce):
+ host_name = 'unchanged-host-name'
+ self.get_hostname.return_value = host_name
+ cfg = {'hostname_bounce': {'policy': 'yes'}}
+ self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
+ self.assertEqual(0, perform_hostname_bounce.call_count)
+
+ @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
+ def test_force_performs_bounce_regardless(self, perform_hostname_bounce):
+ host_name = 'unchanged-host-name'
+ self.get_hostname.return_value = host_name
+ cfg = {'hostname_bounce': {'policy': 'force'}}
+ self._get_ds(self.get_ovf_env_with_dscfg(host_name, cfg)).get_data()
+ self.assertEqual(1, perform_hostname_bounce.call_count)
+
+ def test_different_hostnames_sets_hostname(self):
+ expected_hostname = 'azure-expected-host-name'
+ self.get_hostname.return_value = 'default-host-name'
+ self._get_ds(
+ self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data()
+ self.assertEqual(expected_hostname,
+ self.set_hostname.call_args_list[0][0][0])
+
+ @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
+ def test_different_hostnames_performs_bounce(
+ self, perform_hostname_bounce):
+ expected_hostname = 'azure-expected-host-name'
+ self.get_hostname.return_value = 'default-host-name'
+ self._get_ds(
+ self.get_ovf_env_with_dscfg(expected_hostname, {})).get_data()
+ self.assertEqual(1, perform_hostname_bounce.call_count)
+
+ def test_different_hostnames_sets_hostname_back(self):
+ initial_host_name = 'default-host-name'
+ self.get_hostname.return_value = initial_host_name
+ self._get_ds(
+ self.get_ovf_env_with_dscfg('some-host-name', {})).get_data()
+ self.assertEqual(initial_host_name,
+ self.set_hostname.call_args_list[-1][0][0])
+
+ @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
+ def test_failure_in_bounce_still_resets_host_name(
+ self, perform_hostname_bounce):
+ perform_hostname_bounce.side_effect = Exception
+ initial_host_name = 'default-host-name'
+ self.get_hostname.return_value = initial_host_name
+ self._get_ds(
+ self.get_ovf_env_with_dscfg('some-host-name', {})).get_data()
+ self.assertEqual(initial_host_name,
+ self.set_hostname.call_args_list[-1][0][0])
+
+ def test_environment_correct_for_bounce_command(self):
+ interface = 'int0'
+ hostname = 'my-new-host'
+ old_hostname = 'my-old-host'
+ self.get_hostname.return_value = old_hostname
+ cfg = {'hostname_bounce': {'interface': interface, 'policy': 'force'}}
+ data = self.get_ovf_env_with_dscfg(hostname, cfg)
+ self._get_ds(data).get_data()
+ self.assertEqual(1, self.subp.call_count)
+ bounce_env = self.subp.call_args[1]['env']
+ self.assertEqual(interface, bounce_env['interface'])
+ self.assertEqual(hostname, bounce_env['hostname'])
+ self.assertEqual(old_hostname, bounce_env['old_hostname'])
+
+ def test_default_bounce_command_used_by_default(self):
+ cmd = 'default-bounce-command'
+ DataSourceAzure.BUILTIN_DS_CONFIG['hostname_bounce']['command'] = cmd
+ cfg = {'hostname_bounce': {'policy': 'force'}}
+ data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
+ self._get_ds(data).get_data()
+ self.assertEqual(1, self.subp.call_count)
+ bounce_args = self.subp.call_args[1]['args']
+ self.assertEqual(cmd, bounce_args)
+
+ @mock.patch('cloudinit.sources.DataSourceAzure.perform_hostname_bounce')
+ def test_set_hostname_option_can_disable_bounce(
+ self, perform_hostname_bounce):
+ cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
+ data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
+ self._get_ds(data).get_data()
+
+ self.assertEqual(0, perform_hostname_bounce.call_count)
+
+ def test_set_hostname_option_can_disable_hostname_set(self):
+ cfg = {'set_hostname': False, 'hostname_bounce': {'policy': 'force'}}
+ data = self.get_ovf_env_with_dscfg('some-hostname', cfg)
+ self._get_ds(data).get_data()
+
+ self.assertEqual(0, self.set_hostname.call_count)
+
+
class TestReadAzureOvf(TestCase):
def test_invalid_xml_raises_non_azure_ds(self):
invalid_xml = "<foo>" + construct_valid_ovf_env(data={})
diff --git a/tests/unittests/test_datasource/test_configdrive.py b/tests/unittests/test_datasource/test_configdrive.py
index e28bdd84..83aca505 100644
--- a/tests/unittests/test_datasource/test_configdrive.py
+++ b/tests/unittests/test_datasource/test_configdrive.py
@@ -2,6 +2,7 @@ from copy import copy
import json
import os
import shutil
+import six
import tempfile
try:
@@ -45,7 +46,7 @@ EC2_META = {
'reservation-id': 'r-iru5qm4m',
'security-groups': ['default']
}
-USER_DATA = '#!/bin/sh\necho This is user data\n'
+USER_DATA = b'#!/bin/sh\necho This is user data\n'
OSTACK_META = {
'availability_zone': 'nova',
'files': [{'content_path': '/content/0000', 'path': '/etc/foo.cfg'},
@@ -56,8 +57,8 @@ OSTACK_META = {
'public_keys': {'mykey': PUBKEY},
'uuid': 'b0fa911b-69d4-4476-bbe2-1c92bff6535c'}
-CONTENT_0 = 'This is contents of /etc/foo.cfg\n'
-CONTENT_1 = '# this is /etc/bar/bar.cfg\n'
+CONTENT_0 = b'This is contents of /etc/foo.cfg\n'
+CONTENT_1 = b'# this is /etc/bar/bar.cfg\n'
CFG_DRIVE_FILES_V2 = {
'ec2/2009-04-04/meta-data.json': json.dumps(EC2_META),
@@ -346,8 +347,12 @@ def populate_dir(seed_dir, files):
dirname = os.path.dirname(path)
if not os.path.isdir(dirname):
os.makedirs(dirname)
- with open(path, "w") as fp:
+ if isinstance(content, six.text_type):
+ mode = "w"
+ else:
+ mode = "wb"
+
+ with open(path, mode) as fp:
fp.write(content)
- fp.close()
# vi: ts=4 expandtab
diff --git a/tests/unittests/test_datasource/test_digitalocean.py b/tests/unittests/test_datasource/test_digitalocean.py
index 98f9cfac..679d1b82 100644
--- a/tests/unittests/test_datasource/test_digitalocean.py
+++ b/tests/unittests/test_datasource/test_digitalocean.py
@@ -15,7 +15,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-import httpretty
import re
from six.moves.urllib_parse import urlparse
@@ -26,6 +25,8 @@ from cloudinit.sources import DataSourceDigitalOcean
from .. import helpers as test_helpers
+httpretty = test_helpers.import_httpretty()
+
# Abbreviated for the test
DO_INDEX = """id
hostname
diff --git a/tests/unittests/test_datasource/test_gce.py b/tests/unittests/test_datasource/test_gce.py
index 6dd4b5ed..1fb100f7 100644
--- a/tests/unittests/test_datasource/test_gce.py
+++ b/tests/unittests/test_datasource/test_gce.py
@@ -15,7 +15,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-import httpretty
import re
from base64 import b64encode, b64decode
@@ -27,12 +26,14 @@ from cloudinit.sources import DataSourceGCE
from .. import helpers as test_helpers
+httpretty = test_helpers.import_httpretty()
+
GCE_META = {
'instance/id': '123',
'instance/zone': 'foo/bar',
'project/attributes/sshKeys': 'user:ssh-rsa AA2..+aRD0fyVw== root@server',
'instance/hostname': 'server.project-foo.local',
- 'instance/attributes/user-data': '/bin/echo foo\n',
+ 'instance/attributes/user-data': b'/bin/echo foo\n',
}
GCE_META_PARTIAL = {
@@ -112,10 +113,6 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
self.assertEqual(GCE_META.get('instance/attributes/user-data'),
self.ds.get_userdata_raw())
- # we expect a list of public ssh keys with user names stripped
- self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'],
- self.ds.get_public_ssh_keys())
-
# test partial metadata (missing user-data in particular)
@httpretty.activate
def test_metadata_partial(self):
@@ -140,3 +137,48 @@ class TestDataSourceGCE(test_helpers.HttprettyTestCase):
decoded = b64decode(
GCE_META_ENCODING.get('instance/attributes/user-data'))
self.assertEqual(decoded, self.ds.get_userdata_raw())
+
+ @httpretty.activate
+ def test_missing_required_keys_return_false(self):
+ for required_key in ['instance/id', 'instance/zone',
+ 'instance/hostname']:
+ meta = GCE_META_PARTIAL.copy()
+ del meta[required_key]
+ httpretty.register_uri(httpretty.GET, MD_URL_RE,
+ body=_new_request_callback(meta))
+ self.assertEqual(False, self.ds.get_data())
+ httpretty.reset()
+
+ @httpretty.activate
+ def test_project_level_ssh_keys_are_used(self):
+ httpretty.register_uri(httpretty.GET, MD_URL_RE,
+ body=_new_request_callback())
+ self.ds.get_data()
+
+ # we expect a list of public ssh keys with user names stripped
+ self.assertEqual(['ssh-rsa AA2..+aRD0fyVw== root@server'],
+ self.ds.get_public_ssh_keys())
+
+ @httpretty.activate
+ def test_instance_level_ssh_keys_are_used(self):
+ key_content = 'ssh-rsa JustAUser root@server'
+ meta = GCE_META.copy()
+ meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content)
+
+ httpretty.register_uri(httpretty.GET, MD_URL_RE,
+ body=_new_request_callback(meta))
+ self.ds.get_data()
+
+ self.assertIn(key_content, self.ds.get_public_ssh_keys())
+
+ @httpretty.activate
+ def test_instance_level_keys_replace_project_level_keys(self):
+ key_content = 'ssh-rsa JustAUser root@server'
+ meta = GCE_META.copy()
+ meta['instance/attributes/sshKeys'] = 'user:{0}'.format(key_content)
+
+ httpretty.register_uri(httpretty.GET, MD_URL_RE,
+ body=_new_request_callback(meta))
+ self.ds.get_data()
+
+ self.assertEqual([key_content], self.ds.get_public_ssh_keys())
diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py
index d25e1adc..f109bb04 100644
--- a/tests/unittests/test_datasource/test_maas.py
+++ b/tests/unittests/test_datasource/test_maas.py
@@ -26,7 +26,7 @@ class TestMAASDataSource(TestCase):
data = {'instance-id': 'i-valid01',
'local-hostname': 'valid01-hostname',
- 'user-data': 'valid01-userdata',
+ 'user-data': b'valid01-userdata',
'public-keys': 'ssh-rsa AAAAB3Nz...aC1yc2E= keyname'}
my_d = os.path.join(self.tmp, "valid")
@@ -46,7 +46,7 @@ class TestMAASDataSource(TestCase):
data = {'instance-id': 'i-valid-extra',
'local-hostname': 'valid-extra-hostname',
- 'user-data': 'valid-extra-userdata', 'foo': 'bar'}
+ 'user-data': b'valid-extra-userdata', 'foo': 'bar'}
my_d = os.path.join(self.tmp, "valid_extra")
populate_dir(my_d, data)
@@ -103,7 +103,7 @@ class TestMAASDataSource(TestCase):
'meta-data/instance-id': 'i-instanceid',
'meta-data/local-hostname': 'test-hostname',
'meta-data/public-keys': 'test-hostname',
- 'user-data': 'foodata',
+ 'user-data': b'foodata',
}
valid_order = [
'meta-data/local-hostname',
@@ -143,7 +143,7 @@ class TestMAASDataSource(TestCase):
userdata, metadata = DataSourceMAAS.read_maas_seed_url(
my_seed, header_cb=my_headers_cb, version=my_ver)
- self.assertEqual("foodata", userdata)
+ self.assertEqual(b"foodata", userdata)
self.assertEqual(metadata['instance-id'],
valid['meta-data/instance-id'])
self.assertEqual(metadata['local-hostname'],
diff --git a/tests/unittests/test_datasource/test_nocloud.py b/tests/unittests/test_datasource/test_nocloud.py
index 4f967f58..85b4c25a 100644
--- a/tests/unittests/test_datasource/test_nocloud.py
+++ b/tests/unittests/test_datasource/test_nocloud.py
@@ -37,7 +37,7 @@ class TestNoCloudDataSource(TestCase):
def test_nocloud_seed_dir(self):
md = {'instance-id': 'IID', 'dsmode': 'local'}
- ud = "USER_DATA_HERE"
+ ud = b"USER_DATA_HERE"
populate_dir(os.path.join(self.paths.seed_dir, "nocloud"),
{'user-data': ud, 'meta-data': yaml.safe_dump(md)})
@@ -92,20 +92,20 @@ class TestNoCloudDataSource(TestCase):
data = {
'fs_label': None,
'meta-data': yaml.safe_dump({'instance-id': 'IID'}),
- 'user-data': "USER_DATA_RAW",
+ 'user-data': b"USER_DATA_RAW",
}
sys_cfg = {'datasource': {'NoCloud': data}}
dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
ret = dsrc.get_data()
- self.assertEqual(dsrc.userdata_raw, "USER_DATA_RAW")
+ self.assertEqual(dsrc.userdata_raw, b"USER_DATA_RAW")
self.assertEqual(dsrc.metadata.get('instance-id'), 'IID')
self.assertTrue(ret)
def test_nocloud_seed_with_vendordata(self):
md = {'instance-id': 'IID', 'dsmode': 'local'}
- ud = "USER_DATA_HERE"
- vd = "THIS IS MY VENDOR_DATA"
+ ud = b"USER_DATA_HERE"
+ vd = b"THIS IS MY VENDOR_DATA"
populate_dir(os.path.join(self.paths.seed_dir, "nocloud"),
{'user-data': ud, 'meta-data': yaml.safe_dump(md),
@@ -126,7 +126,7 @@ class TestNoCloudDataSource(TestCase):
def test_nocloud_no_vendordata(self):
populate_dir(os.path.join(self.paths.seed_dir, "nocloud"),
- {'user-data': "ud", 'meta-data': "instance-id: IID\n"})
+ {'user-data': b"ud", 'meta-data': "instance-id: IID\n"})
sys_cfg = {'datasource': {'NoCloud': {'fs_label': None}}}
@@ -134,7 +134,7 @@ class TestNoCloudDataSource(TestCase):
dsrc = ds(sys_cfg=sys_cfg, distro=None, paths=self.paths)
ret = dsrc.get_data()
- self.assertEqual(dsrc.userdata_raw, "ud")
+ self.assertEqual(dsrc.userdata_raw, b"ud")
self.assertFalse(dsrc.vendordata)
self.assertTrue(ret)
diff --git a/tests/unittests/test_datasource/test_openstack.py b/tests/unittests/test_datasource/test_openstack.py
index 81ef1546..0aa1ba84 100644
--- a/tests/unittests/test_datasource/test_openstack.py
+++ b/tests/unittests/test_datasource/test_openstack.py
@@ -31,7 +31,7 @@ from cloudinit.sources import DataSourceOpenStack as ds
from cloudinit.sources.helpers import openstack
from cloudinit import util
-import httpretty as hp
+hp = test_helpers.import_httpretty()
BASE_URL = "http://169.254.169.254"
PUBKEY = u'ssh-rsa AAAAB3NzaC1....sIkJhq8wdX+4I3A4cYbYP ubuntu@server-460\n'
@@ -49,7 +49,7 @@ EC2_META = {
'public-ipv4': '0.0.0.1',
'reservation-id': 'r-iru5qm4m',
}
-USER_DATA = '#!/bin/sh\necho This is user data\n'
+USER_DATA = b'#!/bin/sh\necho This is user data\n'
VENDOR_DATA = {
'magic': '',
}
@@ -63,8 +63,8 @@ OSTACK_META = {
'public_keys': {'mykey': PUBKEY},
'uuid': 'b0fa911b-69d4-4476-bbe2-1c92bff6535c',
}
-CONTENT_0 = 'This is contents of /etc/foo.cfg\n'
-CONTENT_1 = '# this is /etc/bar/bar.cfg\n'
+CONTENT_0 = b'This is contents of /etc/foo.cfg\n'
+CONTENT_1 = b'# this is /etc/bar/bar.cfg\n'
OS_FILES = {
'openstack/latest/meta_data.json': json.dumps(OSTACK_META),
'openstack/latest/user_data': USER_DATA,
diff --git a/tests/unittests/test_datasource/test_smartos.py b/tests/unittests/test_datasource/test_smartos.py
index 8b62b1b1..adee9019 100644
--- a/tests/unittests/test_datasource/test_smartos.py
+++ b/tests/unittests/test_datasource/test_smartos.py
@@ -24,18 +24,28 @@
from __future__ import print_function
-from cloudinit import helpers as c_helpers
-from cloudinit.sources import DataSourceSmartOS
-from cloudinit.util import b64e
-from .. import helpers
import os
import os.path
import re
import shutil
-import tempfile
import stat
+import tempfile
import uuid
+from binascii import crc32
+
+import serial
+import six
+from cloudinit import helpers as c_helpers
+from cloudinit.sources import DataSourceSmartOS
+from cloudinit.util import b64e
+
+from .. import helpers
+
+try:
+ from unittest import mock
+except ImportError:
+ import mock
MOCK_RETURNS = {
'hostname': 'test-host',
@@ -54,60 +64,15 @@ MOCK_RETURNS = {
DMI_DATA_RETURN = (str(uuid.uuid4()), 'smartdc')
-class MockSerial(object):
- """Fake a serial terminal for testing the code that
- interfaces with the serial"""
-
- port = None
+def get_mock_client(mockdata):
+ class MockMetadataClient(object):
- def __init__(self, mockdata):
- self.last = None
- self.last = None
- self.new = True
- self.count = 0
- self.mocked_out = []
- self.mockdata = mockdata
+ def __init__(self, serial):
+ pass
- def open(self):
- return True
-
- def close(self):
- return True
-
- def isOpen(self):
- return True
-
- def write(self, line):
- line = line.replace('GET ', '')
- self.last = line.rstrip()
-
- def readline(self):
- if self.new:
- self.new = False
- if self.last in self.mockdata:
- return 'SUCCESS\n'
- else:
- return 'NOTFOUND %s\n' % self.last
-
- if self.last in self.mockdata:
- if not self.mocked_out:
- self.mocked_out = [x for x in self._format_out()]
-
- if len(self.mocked_out) > self.count:
- self.count += 1
- return self.mocked_out[self.count - 1]
-
- def _format_out(self):
- if self.last in self.mockdata:
- _mret = self.mockdata[self.last]
- try:
- for l in _mret.splitlines():
- yield "%s\n" % l.rstrip()
- except:
- yield "%s\n" % _mret.rstrip()
-
- yield '.'
- yield '\n'
+ def get_metadata(self, metadata_key):
+ return mockdata.get(metadata_key)
+ return MockMetadataClient
class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
@@ -155,9 +120,6 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
if dmi_data is None:
dmi_data = DMI_DATA_RETURN
- def _get_serial(*_):
- return MockSerial(mockdata)
-
def _dmi_data():
return dmi_data
@@ -174,7 +136,9 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
sys_cfg['datasource']['SmartOS'] = ds_cfg
self.apply_patches([(mod, 'LEGACY_USER_D', self.legacy_user_d)])
- self.apply_patches([(mod, 'get_serial', _get_serial)])
+ self.apply_patches([(mod, 'get_serial', mock.MagicMock())])
+ self.apply_patches([
+ (mod, 'JoyentMetadataClient', get_mock_client(mockdata))])
self.apply_patches([(mod, 'dmi_data', _dmi_data)])
self.apply_patches([(os, 'uname', _os_uname)])
self.apply_patches([(mod, 'device_exists', lambda d: True)])
@@ -443,6 +407,18 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
self.assertEqual(dsrc.device_name_to_device('FOO'),
mydscfg['disk_aliases']['FOO'])
+ @mock.patch('cloudinit.sources.DataSourceSmartOS.JoyentMetadataClient')
+ @mock.patch('cloudinit.sources.DataSourceSmartOS.get_serial')
+ def test_serial_console_closed_on_error(self, get_serial, metadata_client):
+ class OurException(Exception):
+ pass
+ metadata_client.side_effect = OurException
+ try:
+ DataSourceSmartOS.query_data('noun', 'device', 0)
+ except OurException:
+ pass
+ self.assertEqual(1, get_serial.return_value.close.call_count)
+
def apply_patches(patches):
ret = []
@@ -453,3 +429,133 @@ def apply_patches(patches):
setattr(ref, name, replace)
ret.append((ref, name, orig))
return ret
+
+
+class TestJoyentMetadataClient(helpers.FilesystemMockingTestCase):
+
+ def setUp(self):
+ super(TestJoyentMetadataClient, self).setUp()
+ self.serial = mock.MagicMock(spec=serial.Serial)
+ self.request_id = 0xabcdef12
+ self.metadata_value = 'value'
+ self.response_parts = {
+ 'command': 'SUCCESS',
+ 'crc': 'b5a9ff00',
+ 'length': 17 + len(b64e(self.metadata_value)),
+ 'payload': b64e(self.metadata_value),
+ 'request_id': '{0:08x}'.format(self.request_id),
+ }
+
+ def make_response():
+ payload = ''
+ if self.response_parts['payload']:
+ payload = ' {0}'.format(self.response_parts['payload'])
+ del self.response_parts['payload']
+ return (
+ 'V2 {length} {crc} {request_id} {command}{payload}\n'.format(
+ payload=payload, **self.response_parts).encode('ascii'))
+ self.serial.readline.side_effect = make_response
+ self.patched_funcs.enter_context(
+ mock.patch('cloudinit.sources.DataSourceSmartOS.random.randint',
+ mock.Mock(return_value=self.request_id)))
+
+ def _get_client(self):
+ return DataSourceSmartOS.JoyentMetadataClient(self.serial)
+
+ def assertEndsWith(self, haystack, prefix):
+ self.assertTrue(haystack.endswith(prefix),
+ "{0} does not end with '{1}'".format(
+ repr(haystack), prefix))
+
+ def assertStartsWith(self, haystack, prefix):
+ self.assertTrue(haystack.startswith(prefix),
+ "{0} does not start with '{1}'".format(
+ repr(haystack), prefix))
+
+ def test_get_metadata_writes_a_single_line(self):
+ client = self._get_client()
+ client.get_metadata('some_key')
+ self.assertEqual(1, self.serial.write.call_count)
+ written_line = self.serial.write.call_args[0][0]
+ self.assertEndsWith(written_line, b'\n')
+ self.assertEqual(1, written_line.count(b'\n'))
+
+ def _get_written_line(self, key='some_key'):
+ client = self._get_client()
+ client.get_metadata(key)
+ return self.serial.write.call_args[0][0]
+
+ def test_get_metadata_writes_bytes(self):
+ self.assertIsInstance(self._get_written_line(), six.binary_type)
+
+ def test_get_metadata_line_starts_with_v2(self):
+ self.assertStartsWith(self._get_written_line(), b'V2')
+
+ def test_get_metadata_uses_get_command(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ self.assertEqual('GET', parts[4])
+
+ def test_get_metadata_base64_encodes_argument(self):
+ key = 'my_key'
+ parts = self._get_written_line(key).decode('ascii').strip().split(' ')
+ self.assertEqual(b64e(key), parts[5])
+
+ def test_get_metadata_calculates_length_correctly(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ expected_length = len(' '.join(parts[3:]))
+ self.assertEqual(expected_length, int(parts[1]))
+
+ def test_get_metadata_uses_appropriate_request_id(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ request_id = parts[3]
+ self.assertEqual(8, len(request_id))
+ self.assertEqual(request_id, request_id.lower())
+
+ def test_get_metadata_uses_random_number_for_request_id(self):
+ line = self._get_written_line()
+ request_id = line.decode('ascii').strip().split(' ')[3]
+ self.assertEqual('{0:08x}'.format(self.request_id), request_id)
+
+ def test_get_metadata_checksums_correctly(self):
+ parts = self._get_written_line().decode('ascii').strip().split(' ')
+ expected_checksum = '{0:08x}'.format(
+ crc32(' '.join(parts[3:]).encode('utf-8')) & 0xffffffff)
+ checksum = parts[2]
+ self.assertEqual(expected_checksum, checksum)
+
+ def test_get_metadata_reads_a_line(self):
+ client = self._get_client()
+ client.get_metadata('some_key')
+ self.assertEqual(1, self.serial.readline.call_count)
+
+ def test_get_metadata_returns_valid_value(self):
+ client = self._get_client()
+ value = client.get_metadata('some_key')
+ self.assertEqual(self.metadata_value, value)
+
+ def test_get_metadata_throws_exception_for_incorrect_length(self):
+ self.response_parts['length'] = 0
+ client = self._get_client()
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_throws_exception_for_incorrect_crc(self):
+ self.response_parts['crc'] = 'deadbeef'
+ client = self._get_client()
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_throws_exception_for_request_id_mismatch(self):
+ self.response_parts['request_id'] = 'deadbeef'
+ client = self._get_client()
+ client._checksum = lambda _: self.response_parts['crc']
+ self.assertRaises(DataSourceSmartOS.JoyentMetadataFetchException,
+ client.get_metadata, 'some_key')
+
+ def test_get_metadata_returns_None_if_value_not_found(self):
+ self.response_parts['payload'] = ''
+ self.response_parts['command'] = 'NOTFOUND'
+ self.response_parts['length'] = 17
+ client = self._get_client()
+ client._checksum = lambda _: self.response_parts['crc']
+ self.assertIsNone(client.get_metadata('some_key'))