summaryrefslogtreecommitdiff
path: root/cloudinit/sources/DataSourceSmartOS.py
diff options
context:
space:
mode:
Diffstat (limited to 'cloudinit/sources/DataSourceSmartOS.py')
-rw-r--r--cloudinit/sources/DataSourceSmartOS.py80
1 files changed, 72 insertions, 8 deletions
diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py
index 1ce20c10..1cf9e4f0 100644
--- a/cloudinit/sources/DataSourceSmartOS.py
+++ b/cloudinit/sources/DataSourceSmartOS.py
@@ -27,6 +27,7 @@
#
+import base64
from cloudinit import log as logging
from cloudinit import sources
from cloudinit import util
@@ -34,7 +35,7 @@ import os
import os.path
import serial
-
+DS_NAME = 'SmartOS'
DEF_TTY_LOC = '/dev/ttyS1'
DEF_TTY_TIMEOUT = 60
LOG = logging.getLogger(__name__)
@@ -49,15 +50,24 @@ SMARTOS_ATTRIB_MAP = {
'motd_sys_info': ('motd_sys_info', True),
}
+# These are values which will never be base64 encoded.
+SMARTOS_NO_BASE64 = ['root_authorized_keys', 'motd_sys_info',
+ 'iptables_disable']
+
class DataSourceSmartOS(sources.DataSource):
def __init__(self, sys_cfg, distro, paths):
sources.DataSource.__init__(self, sys_cfg, distro, paths)
self.seed_dir = os.path.join(paths.seed_dir, 'sdc')
self.is_smartdc = None
+ self.base_64_encoded = []
self.seed = self.sys_cfg.get("serial_device", DEF_TTY_LOC)
+ self.all_base64 = self.sys_cfg.get("decode_base64", False)
self.seed_timeout = self.sys_cfg.get("serial_timeout",
DEF_TTY_TIMEOUT)
+ self.smartos_no_base64 = SMARTOS_NO_BASE64
+ if 'no_base64_decode' in self.ds_cfg:
+ self.smartos_no_base64 = self.ds_cfg['no_base64_decode']
def __str__(self):
root = sources.DataSource.__str__(self)
@@ -84,17 +94,41 @@ class DataSourceSmartOS(sources.DataSource):
self.is_smartdc = True
md['instance-id'] = system_uuid
+ self.base_64_encoded = query_data('base_64_enocded',
+ self.seed,
+ self.seed_timeout,
+ strip=True)
+ if self.base_64_encoded:
+ self.base_64_encoded = str(self.base_64_encoded).split(',')
+ else:
+ self.base_64_encoded = []
+
+ if not self.all_base64:
+ self.all_base64 = util.is_true(query_data('meta_encoded_base64',
+ self.seed,
+ self.seed_timeout,
+ strip=True))
+
for ci_noun, attribute in SMARTOS_ATTRIB_MAP.iteritems():
smartos_noun, strip = attribute
+
+ b64encoded = False
+ if self.all_base64 and \
+ (smartos_noun not in self.smartos_no_base64 and \
+ ci_noun not in self.smartos_no_base64):
+ b64encoded = True
+
md[ci_noun] = query_data(smartos_noun, self.seed,
- self.seed_timeout, strip=strip)
+ self.seed_timeout, strip=strip,
+ b64encoded=b64encoded)
if not md['local-hostname']:
md['local-hostname'] = system_uuid
+ ud = None
if md['user-data']:
ud = md['user-data']
- else:
+ elif md['user-script']:
ud = md['user-script']
self.metadata = md
@@ -104,10 +138,25 @@ class DataSourceSmartOS(sources.DataSource):
def get_instance_id(self):
return self.metadata['instance-id']
+ def not_b64_var(self, var):
+ """Return true if value is read as b64."""
+ if var in self.smartos_no_base64 or \
+ not self.all_base64:
+ return True
+ return False
+
+ def is_b64_var(self, var):
+ """Return true if value is read as b64."""
+ if self.all_base64 or (
+ var not in self.smartos_no_base64 and
+ var in self.base_64_encoded):
+ return True
+ return False
+
def get_serial(seed_device, seed_timeout):
"""This is replaced in unit testing, allowing us to replace
- serial.Serial with a mocked class
+ serial.Serial with a mocked class.
The timeout value of 60 seconds should never be hit. The value
is taken from SmartOS own provisioning tools. Since we are reading
@@ -124,12 +173,17 @@ def get_serial(seed_device, seed_timeout):
return ser
-def query_data(noun, seed_device, seed_timeout, strip=False):
+def query_data(noun, seed_device, seed_timeout, strip=False, b64encoded=False):
"""Makes a request to via the serial console via "GET <NOUN>"
In the response, the first line is the status, while subsequent lines
are is the value. A blank line with a "." is used to indicate end of
response.
+
+ If the response is expected to be base64 encoded, then set b64encoded
+ to true. Unfortantely, there is no way to know if something is 100%
+ encoded, so this method relies on being told if the data is base64 or
+ not.
"""
if not noun:
@@ -153,12 +207,22 @@ def query_data(noun, seed_device, seed_timeout, strip=False):
response.append(m)
ser.close()
+
+ resp = None
if not strip:
- return "".join(response)
+ resp = "".join(response)
+ elif b64encoded:
+ resp = "".join(response).rstrip()
else:
- return "".join(response).rstrip()
+ resp = "".join(response).rstrip()
+
+ if b64encoded:
+ try:
+ return base64.b64decode(resp)
+ except TypeError:
+ return resp
- return None
+ return resp
def dmi_data():