diff options
Diffstat (limited to 'cloudinit/sources')
-rw-r--r-- | cloudinit/sources/DataSourceAzure.py | 132 | ||||
-rw-r--r-- | cloudinit/sources/DataSourceGCE.py | 92 | ||||
-rw-r--r-- | cloudinit/sources/DataSourceNoCloud.py | 5 | ||||
-rw-r--r-- | cloudinit/sources/DataSourceOpenNebula.py | 1 | ||||
-rw-r--r-- | cloudinit/sources/DataSourceSmartOS.py | 90 |
5 files changed, 195 insertions, 125 deletions
diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 6e030217..a19d9ca2 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -17,6 +17,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. import base64 +import contextlib import crypt import fnmatch import os @@ -66,6 +67,36 @@ DS_CFG_PATH = ['datasource', DS_NAME] DEF_EPHEMERAL_LABEL = 'Temporary Storage' +def get_hostname(hostname_command='hostname'): + return util.subp(hostname_command, capture=True)[0].strip() + + +def set_hostname(hostname, hostname_command='hostname'): + util.subp([hostname_command, hostname]) + + +@contextlib.contextmanager +def temporary_hostname(temp_hostname, cfg, hostname_command='hostname'): + """ + Set a temporary hostname, restoring the previous hostname on exit. + + Will have the value of the previous hostname when used as a context + manager, or None if the hostname was not changed. + """ + policy = cfg['hostname_bounce']['policy'] + previous_hostname = get_hostname(hostname_command) + if (not util.is_true(cfg.get('set_hostname')) + or util.is_false(policy) + or (previous_hostname == temp_hostname and policy != 'force')): + yield None + return + set_hostname(temp_hostname, hostname_command) + try: + yield previous_hostname + finally: + set_hostname(previous_hostname, hostname_command) + + class DataSourceAzureNet(sources.DataSource): def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) @@ -154,33 +185,40 @@ class DataSourceAzureNet(sources.DataSource): # the directory to be protected. write_files(ddir, files, dirmode=0o700) - # handle the hostname 'publishing' - try: - handle_set_hostname(mycfg.get('set_hostname'), - self.metadata.get('local-hostname'), - mycfg['hostname_bounce']) - except Exception as e: - LOG.warn("Failed publishing hostname: %s", e) - util.logexc(LOG, "handling set_hostname failed") - - try: - invoke_agent(mycfg['agent_command']) - except util.ProcessExecutionError: - # claim the datasource even if the command failed - util.logexc(LOG, "agent command '%s' failed.", - mycfg['agent_command']) - - shcfgxml = os.path.join(ddir, "SharedConfig.xml") - wait_for = [shcfgxml] - - fp_files = [] - for pk in self.cfg.get('_pubkeys', []): - bname = str(pk['fingerprint'] + ".crt") - fp_files += [os.path.join(ddir, bname)] + temp_hostname = self.metadata.get('local-hostname') + hostname_command = mycfg['hostname_bounce']['hostname_command'] + with temporary_hostname(temp_hostname, mycfg, + hostname_command=hostname_command) \ + as previous_hostname: + if (previous_hostname is not None + and util.is_true(mycfg.get('set_hostname'))): + cfg = mycfg['hostname_bounce'] + try: + perform_hostname_bounce(hostname=temp_hostname, + cfg=cfg, + prev_hostname=previous_hostname) + except Exception as e: + LOG.warn("Failed publishing hostname: %s", e) + util.logexc(LOG, "handling set_hostname failed") - missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", - func=wait_for_files, - args=(wait_for + fp_files,)) + try: + invoke_agent(mycfg['agent_command']) + except util.ProcessExecutionError: + # claim the datasource even if the command failed + util.logexc(LOG, "agent command '%s' failed.", + mycfg['agent_command']) + + shcfgxml = os.path.join(ddir, "SharedConfig.xml") + wait_for = [shcfgxml] + + fp_files = [] + for pk in self.cfg.get('_pubkeys', []): + bname = str(pk['fingerprint'] + ".crt") + fp_files += [os.path.join(ddir, bname)] + + missing = util.log_time(logfunc=LOG.debug, msg="waiting for files", + func=wait_for_files, + args=(wait_for + fp_files,)) if len(missing): LOG.warn("Did not find files, but going on: %s", missing) @@ -299,39 +337,15 @@ def support_new_ephemeral(cfg): return mod_list -def handle_set_hostname(enabled, hostname, cfg): - if not util.is_true(enabled): - return - - if not hostname: - LOG.warn("set_hostname was true but no local-hostname") - return - - apply_hostname_bounce(hostname=hostname, policy=cfg['policy'], - interface=cfg['interface'], - command=cfg['command'], - hostname_command=cfg['hostname_command']) - - -def apply_hostname_bounce(hostname, policy, interface, command, - hostname_command="hostname"): +def perform_hostname_bounce(hostname, cfg, prev_hostname): # set the hostname to 'hostname' if it is not already set to that. # then, if policy is not off, bounce the interface using command - prev_hostname = util.subp(hostname_command, capture=True)[0].strip() - - util.subp([hostname_command, hostname]) - - msg = ("phostname=%s hostname=%s policy=%s interface=%s" % - (prev_hostname, hostname, policy, interface)) - - if util.is_false(policy): - LOG.debug("pubhname: policy false, skipping [%s]", msg) - return - - if prev_hostname == hostname and policy != "force": - LOG.debug("pubhname: no change, policy != force. skipping. [%s]", msg) - return + command = cfg['command'] + interface = cfg['interface'] + policy = cfg['policy'] + msg = ("hostname=%s policy=%s interface=%s" % + (hostname, policy, interface)) env = os.environ.copy() env['interface'] = interface env['hostname'] = hostname @@ -344,9 +358,9 @@ def apply_hostname_bounce(hostname, policy, interface, command, shell = not isinstance(command, (list, tuple)) # capture=False, see comments in bug 1202758 and bug 1206164. util.log_time(logfunc=LOG.debug, msg="publishing hostname", - get_uptime=True, func=util.subp, - kwargs={'args': command, 'shell': shell, 'capture': False, - 'env': env}) + get_uptime=True, func=util.subp, + kwargs={'args': command, 'shell': shell, 'capture': False, + 'env': env}) def crtfile_to_pubkey(fname): diff --git a/cloudinit/sources/DataSourceGCE.py b/cloudinit/sources/DataSourceGCE.py index 608c07f1..f4ed915d 100644 --- a/cloudinit/sources/DataSourceGCE.py +++ b/cloudinit/sources/DataSourceGCE.py @@ -30,6 +30,31 @@ BUILTIN_DS_CONFIG = { REQUIRED_FIELDS = ('instance-id', 'availability-zone', 'local-hostname') +class GoogleMetadataFetcher(object): + headers = {'X-Google-Metadata-Request': True} + + def __init__(self, metadata_address): + self.metadata_address = metadata_address + + def get_value(self, path, is_text): + value = None + try: + resp = url_helper.readurl(url=self.metadata_address + path, + headers=self.headers) + except url_helper.UrlError as exc: + msg = "url %s raised exception %s" + LOG.debug(msg, path, exc) + else: + if resp.code == 200: + if is_text: + value = util.decode_binary(resp.contents) + else: + value = resp.contents + else: + LOG.debug("url %s returned code %s", path, resp.code) + return value + + class DataSourceGCE(sources.DataSource): def __init__(self, sys_cfg, distro, paths): sources.DataSource.__init__(self, sys_cfg, distro, paths) @@ -50,17 +75,15 @@ class DataSourceGCE(sources.DataSource): return public_key def get_data(self): - # GCE metadata server requires a custom header since v1 - headers = {'X-Google-Metadata-Request': True} - # url_map: (our-key, path, required, is_text) url_map = [ - ('instance-id', 'instance/id', True, True), - ('availability-zone', 'instance/zone', True, True), - ('local-hostname', 'instance/hostname', True, True), - ('public-keys', 'project/attributes/sshKeys', False, True), - ('user-data', 'instance/attributes/user-data', False, False), - ('user-data-encoding', 'instance/attributes/user-data-encoding', + ('instance-id', ('instance/id',), True, True), + ('availability-zone', ('instance/zone',), True, True), + ('local-hostname', ('instance/hostname',), True, True), + ('public-keys', ('project/attributes/sshKeys', + 'instance/attributes/sshKeys'), False, True), + ('user-data', ('instance/attributes/user-data',), False, False), + ('user-data-encoding', ('instance/attributes/user-data-encoding',), False, True), ] @@ -69,40 +92,25 @@ class DataSourceGCE(sources.DataSource): LOG.debug("%s is not resolvable", self.metadata_address) return False + metadata_fetcher = GoogleMetadataFetcher(self.metadata_address) # iterate over url_map keys to get metadata items - found = False - for (mkey, path, required, is_text) in url_map: - try: - resp = url_helper.readurl(url=self.metadata_address + path, - headers=headers) - if resp.code == 200: - found = True - if is_text: - self.metadata[mkey] = util.decode_binary(resp.contents) - else: - self.metadata[mkey] = resp.contents + running_on_gce = False + for (mkey, paths, required, is_text) in url_map: + value = None + for path in paths: + new_value = metadata_fetcher.get_value(path, is_text) + if new_value is not None: + value = new_value + if value: + running_on_gce = True + if required and value is None: + msg = "required key %s returned nothing. not GCE" + if not running_on_gce: + LOG.debug(msg, mkey) else: - if required: - msg = "required url %s returned code %s. not GCE" - if not found: - LOG.debug(msg, path, resp.code) - else: - LOG.warn(msg, path, resp.code) - return False - else: - self.metadata[mkey] = None - except url_helper.UrlError as e: - if required: - msg = "required url %s raised exception %s. not GCE" - if not found: - LOG.debug(msg, path, e) - else: - LOG.warn(msg, path, e) - return False - msg = "Failed to get %s metadata item: %s." - LOG.debug(msg, path, e) - - self.metadata[mkey] = None + LOG.warn(msg, mkey) + return False + self.metadata[mkey] = value if self.metadata['public-keys']: lines = self.metadata['public-keys'].splitlines() @@ -116,7 +124,7 @@ class DataSourceGCE(sources.DataSource): else: LOG.warn('unknown user-data-encoding: %s, ignoring', encoding) - return found + return running_on_gce @property def launch_index(self): diff --git a/cloudinit/sources/DataSourceNoCloud.py b/cloudinit/sources/DataSourceNoCloud.py index c26a645c..6a861af3 100644 --- a/cloudinit/sources/DataSourceNoCloud.py +++ b/cloudinit/sources/DataSourceNoCloud.py @@ -124,7 +124,7 @@ class DataSourceNoCloud(sources.DataSource): # that is more likely to be what is desired. If they want # dsmode of local, then they must specify that. if 'dsmode' not in mydata['meta-data']: - mydata['dsmode'] = "net" + mydata['meta-data']['dsmode'] = "net" LOG.debug("Using data from %s", dev) found.append(dev) @@ -193,7 +193,8 @@ class DataSourceNoCloud(sources.DataSource): self.vendordata = mydata['vendor-data'] return True - LOG.debug("%s: not claiming datasource, dsmode=%s", self, md['dsmode']) + LOG.debug("%s: not claiming datasource, dsmode=%s", self, + mydata['meta-data']['dsmode']) return False diff --git a/cloudinit/sources/DataSourceOpenNebula.py b/cloudinit/sources/DataSourceOpenNebula.py index 61709c1b..ac2c3b45 100644 --- a/cloudinit/sources/DataSourceOpenNebula.py +++ b/cloudinit/sources/DataSourceOpenNebula.py @@ -24,7 +24,6 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. -import base64 import os import pwd import re diff --git a/cloudinit/sources/DataSourceSmartOS.py b/cloudinit/sources/DataSourceSmartOS.py index 896fde3f..c9b497df 100644 --- a/cloudinit/sources/DataSourceSmartOS.py +++ b/cloudinit/sources/DataSourceSmartOS.py @@ -29,9 +29,12 @@ # http://us-east.manta.joyent.com/jmc/public/mdata/datadict.html # Comments with "@datadictionary" are snippets of the definition -import base64 import binascii +import contextlib import os +import random +import re + import serial from cloudinit import log as logging @@ -301,6 +304,65 @@ def get_serial(seed_device, seed_timeout): return ser +class JoyentMetadataFetchException(Exception): + pass + + +class JoyentMetadataClient(object): + """ + A client implementing v2 of the Joyent Metadata Protocol Specification. + + The full specification can be found at + http://eng.joyent.com/mdata/protocol.html + """ + line_regex = re.compile( + r'V2 (?P<length>\d+) (?P<checksum>[0-9a-f]+)' + r' (?P<body>(?P<request_id>[0-9a-f]+) (?P<status>SUCCESS|NOTFOUND)' + r'( (?P<payload>.+))?)') + + def __init__(self, serial): + self.serial = serial + + def _checksum(self, body): + return '{0:08x}'.format( + binascii.crc32(body.encode('utf-8')) & 0xffffffff) + + def _get_value_from_frame(self, expected_request_id, frame): + frame_data = self.line_regex.match(frame).groupdict() + if int(frame_data['length']) != len(frame_data['body']): + raise JoyentMetadataFetchException( + 'Incorrect frame length given ({0} != {1}).'.format( + frame_data['length'], len(frame_data['body']))) + expected_checksum = self._checksum(frame_data['body']) + if frame_data['checksum'] != expected_checksum: + raise JoyentMetadataFetchException( + 'Invalid checksum (expected: {0}; got {1}).'.format( + expected_checksum, frame_data['checksum'])) + if frame_data['request_id'] != expected_request_id: + raise JoyentMetadataFetchException( + 'Request ID mismatch (expected: {0}; got {1}).'.format( + expected_request_id, frame_data['request_id'])) + if not frame_data.get('payload', None): + LOG.debug('No value found.') + return None + value = util.b64d(frame_data['payload']) + LOG.debug('Value "%s" found.', value) + return value + + def get_metadata(self, metadata_key): + LOG.debug('Fetching metadata key "%s"...', metadata_key) + request_id = '{0:08x}'.format(random.randint(0, 0xffffffff)) + message_body = '{0} GET {1}'.format(request_id, + util.b64e(metadata_key)) + msg = 'V2 {0} {1} {2}\n'.format( + len(message_body), self._checksum(message_body), message_body) + LOG.debug('Writing "%s" to serial port.', msg) + self.serial.write(msg.encode('ascii')) + response = self.serial.readline().decode('ascii') + LOG.debug('Read "%s" from serial port.', response) + return self._get_value_from_frame(request_id, response) + + def query_data(noun, seed_device, seed_timeout, strip=False, default=None, b64=None): """Makes a request to via the serial console via "GET <NOUN>" @@ -314,34 +376,20 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None, encoded, so this method relies on being told if the data is base64 or not. """ - if not noun: return False - ser = get_serial(seed_device, seed_timeout) - request_line = "GET %s\n" % noun.rstrip() - ser.write(request_line.encode('ascii')) - status = str(ser.readline()).rstrip() - response = [] - eom_found = False + with contextlib.closing(get_serial(seed_device, seed_timeout)) as ser: + client = JoyentMetadataClient(ser) + response = client.get_metadata(noun) - if 'SUCCESS' not in status: - ser.close() + if response is None: return default - while not eom_found: - m = ser.readline().decode('ascii') - if m.rstrip() == ".": - eom_found = True - else: - response.append(m) - - ser.close() - if b64 is None: b64 = query_data('b64-%s' % noun, seed_device=seed_device, - seed_timeout=seed_timeout, b64=False, - default=False, strip=True) + seed_timeout=seed_timeout, b64=False, + default=False, strip=True) b64 = util.is_true(b64) resp = None |