diff options
Diffstat (limited to 'cloudinit')
-rw-r--r-- | cloudinit/cloud.py | 9 | ||||
-rw-r--r-- | cloudinit/registry.py | 37 | ||||
-rw-r--r-- | cloudinit/reporting/__init__.py | 240 | ||||
-rw-r--r-- | cloudinit/reporting/handlers.py | 90 | ||||
-rw-r--r-- | cloudinit/sources/DataSourceMAAS.py | 146 | ||||
-rw-r--r-- | cloudinit/sources/__init__.py | 25 | ||||
-rw-r--r-- | cloudinit/stages.py | 59 | ||||
-rw-r--r-- | cloudinit/url_helper.py | 140 | ||||
-rw-r--r-- | cloudinit/util.py | 3 |
9 files changed, 620 insertions, 129 deletions
diff --git a/cloudinit/cloud.py b/cloudinit/cloud.py index 95e0cfb2..edee3887 100644 --- a/cloudinit/cloud.py +++ b/cloudinit/cloud.py @@ -24,6 +24,7 @@ import copy import os from cloudinit import log as logging +from cloudinit import reporting LOG = logging.getLogger(__name__) @@ -40,12 +41,18 @@ LOG = logging.getLogger(__name__) class Cloud(object): - def __init__(self, datasource, paths, cfg, distro, runners): + def __init__(self, datasource, paths, cfg, distro, runners, reporter=None): self.datasource = datasource self.paths = paths self.distro = distro self._cfg = cfg self._runners = runners + if reporter is None: + reporter = reporting.ReportEventStack( + name="unnamed-cloud-reporter", + description="unnamed-cloud-reporter", + reporting_enabled=False) + self.reporter = reporter # If a 'user' manipulates logging or logging services # it is typically useful to cause the logging to be diff --git a/cloudinit/registry.py b/cloudinit/registry.py new file mode 100644 index 00000000..04368ddf --- /dev/null +++ b/cloudinit/registry.py @@ -0,0 +1,37 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab +import copy + + +class DictRegistry(object): + """A simple registry for a mapping of objects.""" + + def __init__(self): + self.reset() + + def reset(self): + self._items = {} + + def register_item(self, key, item): + """Add item to the registry.""" + if key in self._items: + raise ValueError( + 'Item already registered with key {0}'.format(key)) + self._items[key] = item + + def unregister_item(self, key, force=True): + """Remove item from the registry.""" + if key in self._items: + del self._items[key] + elif not force: + raise KeyError("%s: key not present to unregister" % key) + + @property + def registered_items(self): + """All the items that have been registered. + + This cannot be used to modify the contents of the registry. + """ + return copy.copy(self._items) diff --git a/cloudinit/reporting/__init__.py b/cloudinit/reporting/__init__.py new file mode 100644 index 00000000..e23fab32 --- /dev/null +++ b/cloudinit/reporting/__init__.py @@ -0,0 +1,240 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init. See LICENCE file for license information. +# +# vi: ts=4 expandtab +""" +cloud-init reporting framework + +The reporting framework is intended to allow all parts of cloud-init to +report events in a structured manner. +""" + +from ..registry import DictRegistry +from ..reporting.handlers import available_handlers + + +FINISH_EVENT_TYPE = 'finish' +START_EVENT_TYPE = 'start' + +DEFAULT_CONFIG = { + 'logging': {'type': 'log'}, +} + + +class _nameset(set): + def __getattr__(self, name): + if name in self: + return name + raise AttributeError("%s not a valid value" % name) + + +status = _nameset(("SUCCESS", "WARN", "FAIL")) + + +class ReportingEvent(object): + """Encapsulation of event formatting.""" + + def __init__(self, event_type, name, description): + self.event_type = event_type + self.name = name + self.description = description + + def as_string(self): + """The event represented as a string.""" + return '{0}: {1}: {2}'.format( + self.event_type, self.name, self.description) + + def as_dict(self): + """The event represented as a dictionary.""" + return {'name': self.name, 'description': self.description, + 'event_type': self.event_type} + + +class FinishReportingEvent(ReportingEvent): + + def __init__(self, name, description, result=status.SUCCESS): + super(FinishReportingEvent, self).__init__( + FINISH_EVENT_TYPE, name, description) + self.result = result + if result not in status: + raise ValueError("Invalid result: %s" % result) + + def as_string(self): + return '{0}: {1}: {2}: {3}'.format( + self.event_type, self.name, self.result, self.description) + + def as_dict(self): + """The event represented as json friendly.""" + data = super(FinishReportingEvent, self).as_dict() + data['result'] = self.result + return data + + +def update_configuration(config): + """Update the instanciated_handler_registry. + + :param config: + The dictionary containing changes to apply. If a key is given + with a False-ish value, the registered handler matching that name + will be unregistered. + """ + for handler_name, handler_config in config.items(): + if not handler_config: + instantiated_handler_registry.unregister_item( + handler_name, force=True) + continue + registered = instantiated_handler_registry.registered_items + handler_config = handler_config.copy() + cls = available_handlers.registered_items[handler_config.pop('type')] + instantiated_handler_registry.unregister_item(handler_name) + instance = cls(**handler_config) + instantiated_handler_registry.register_item(handler_name, instance) + + +def report_event(event): + """Report an event to all registered event handlers. + + This should generally be called via one of the other functions in + the reporting module. + + :param event_type: + The type of the event; this should be a constant from the + reporting module. + """ + for _, handler in instantiated_handler_registry.registered_items.items(): + handler.publish_event(event) + + +def report_finish_event(event_name, event_description, + result=status.SUCCESS): + """Report a "finish" event. + + See :py:func:`.report_event` for parameter details. + """ + event = FinishReportingEvent(event_name, event_description, result) + return report_event(event) + + +def report_start_event(event_name, event_description): + """Report a "start" event. + + :param event_name: + The name of the event; this should be a topic which events would + share (e.g. it will be the same for start and finish events). + + :param event_description: + A human-readable description of the event that has occurred. + """ + event = ReportingEvent(START_EVENT_TYPE, event_name, event_description) + return report_event(event) + + +class ReportEventStack(object): + """Context Manager for using :py:func:`report_event` + + This enables calling :py:func:`report_start_event` and + :py:func:`report_finish_event` through a context manager. + + :param name: + the name of the event + + :param description: + the event's description, passed on to :py:func:`report_start_event` + + :param message: + the description to use for the finish event. defaults to + :param:description. + + :param parent: + :type parent: :py:class:ReportEventStack or None + The parent of this event. The parent is populated with + results of all its children. The name used in reporting + is <parent.name>/<name> + + :param reporting_enabled: + Indicates if reporting events should be generated. + If not provided, defaults to the parent's value, or True if no parent + is provided. + + :param result_on_exception: + The result value to set if an exception is caught. default + value is FAIL. + """ + def __init__(self, name, description, message=None, parent=None, + reporting_enabled=None, result_on_exception=status.FAIL): + self.parent = parent + self.name = name + self.description = description + self.message = message + self.result_on_exception = result_on_exception + self.result = status.SUCCESS + + # use parents reporting value if not provided + if reporting_enabled is None: + if parent: + reporting_enabled = parent.reporting_enabled + else: + reporting_enabled = True + self.reporting_enabled = reporting_enabled + + if parent: + self.fullname = '/'.join((parent.fullname, name,)) + else: + self.fullname = self.name + self.children = {} + + def __repr__(self): + return ("ReportEventStack(%s, %s, reporting_enabled=%s)" % + (self.name, self.description, self.reporting_enabled)) + + def __enter__(self): + self.result = status.SUCCESS + if self.reporting_enabled: + report_start_event(self.fullname, self.description) + if self.parent: + self.parent.children[self.name] = (None, None) + return self + + def _childrens_finish_info(self): + for cand_result in (status.FAIL, status.WARN): + for name, (value, msg) in self.children.items(): + if value == cand_result: + return (value, self.message) + return (self.result, self.message) + + @property + def result(self): + return self._result + + @result.setter + def result(self, value): + if value not in status: + raise ValueError("'%s' not a valid result" % value) + self._result = value + + @property + def message(self): + if self._message is not None: + return self._message + return self.description + + @message.setter + def message(self, value): + self._message = value + + def _finish_info(self, exc): + # return tuple of description, and value + if exc: + return (self.result_on_exception, self.message) + return self._childrens_finish_info() + + def __exit__(self, exc_type, exc_value, traceback): + (result, msg) = self._finish_info(exc_value) + if self.parent: + self.parent.children[self.name] = (result, msg) + if self.reporting_enabled: + report_finish_event(self.fullname, msg, result) + + +instantiated_handler_registry = DictRegistry() +update_configuration(DEFAULT_CONFIG) diff --git a/cloudinit/reporting/handlers.py b/cloudinit/reporting/handlers.py new file mode 100644 index 00000000..1343311f --- /dev/null +++ b/cloudinit/reporting/handlers.py @@ -0,0 +1,90 @@ +# vi: ts=4 expandtab + +import abc +import oauthlib.oauth1 as oauth1 +import six + +from ..registry import DictRegistry +from .. import (url_helper, util) +from .. import log as logging + + +LOG = logging.getLogger(__name__) + + +@six.add_metaclass(abc.ABCMeta) +class ReportingHandler(object): + """Base class for report handlers. + + Implement :meth:`~publish_event` for controlling what + the handler does with an event. + """ + + @abc.abstractmethod + def publish_event(self, event): + """Publish an event to the ``INFO`` log level.""" + + +class LogHandler(ReportingHandler): + """Publishes events to the cloud-init log at the ``INFO`` log level.""" + + def __init__(self, level="DEBUG"): + super(LogHandler, self).__init__() + if isinstance(level, int): + pass + else: + input_level = level + try: + level = gettattr(logging, level.upper()) + except: + LOG.warn("invalid level '%s', using WARN", input_level) + level = logging.WARN + self.level = level + + def publish_event(self, event): + """Publish an event to the ``INFO`` log level.""" + logger = logging.getLogger( + '.'.join(['cloudinit', 'reporting', event.event_type, event.name])) + logger.log(self.level, event.as_string()) + + +class PrintHandler(ReportingHandler): + def publish_event(self, event): + """Publish an event to the ``INFO`` log level.""" + + +class WebHookHandler(ReportingHandler): + def __init__(self, endpoint, consumer_key=None, token_key=None, + token_secret=None, consumer_secret=None, timeout=None, + retries=None): + super(WebHookHandler, self).__init__() + + if any([consumer_key, token_key, token_secret, consumer_secret]): + self.oauth_helper = url_helper.OauthUrlHelper( + consumer_key=consumer_key, token_key=token_key, + token_secret=token_secret, consumer_secret=consumer_secret) + else: + self.oauth_helper = None + self.endpoint = endpoint + self.timeout = timeout + self.retries = retries + self.ssl_details = util.fetch_ssl_details() + + def publish_event(self, event): + if self.oauth_helper: + readurl = self.oauth_helper.readurl + else: + readurl = url_helper.readurl + try: + return readurl( + self.endpoint, data=event.as_dict(), + timeout=self.timeout, + retries=self.retries, ssl_details=self.ssl_details) + except: + LOG.warn("failed posting event: %s" % event.as_string()) + + +available_handlers = DictRegistry() +available_handlers.register_item('log', LogHandler) +available_handlers.register_item('print', PrintHandler) +available_handlers.register_item('webhook', WebHookHandler) diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index c1a0eb61..2f36bbe2 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -52,7 +52,20 @@ class DataSourceMAAS(sources.DataSource): sources.DataSource.__init__(self, sys_cfg, distro, paths) self.base_url = None self.seed_dir = os.path.join(paths.seed_dir, 'maas') - self.oauth_clockskew = None + self.oauth_helper = self._get_helper() + + def _get_helper(self): + mcfg = self.ds_cfg + # If we are missing token_key, token_secret or consumer_key + # then just do non-authed requests + for required in ('token_key', 'token_secret', 'consumer_key'): + if required not in mcfg: + return url_helper.OauthUrlHelper() + + return url_helper.OauthHelper( + consumer_key=mcfg['consumer_key'], token_key=mcfg['token_key'], + token_secret=mcfg['token_secret'], + consumer_secret=mcfg.get('consumer_secret')) def __str__(self): root = sources.DataSource.__str__(self) @@ -84,9 +97,9 @@ class DataSourceMAAS(sources.DataSource): self.base_url = url - (userdata, metadata) = read_maas_seed_url(self.base_url, - self._md_headers, - paths=self.paths) + (userdata, metadata) = read_maas_seed_url( + self.base_url, self.oauth_helper.md_headers, + paths=self.paths) self.userdata_raw = userdata self.metadata = metadata return True @@ -94,31 +107,8 @@ class DataSourceMAAS(sources.DataSource): util.logexc(LOG, "Failed fetching metadata from url %s", url) return False - def _md_headers(self, url): - mcfg = self.ds_cfg - - # If we are missing token_key, token_secret or consumer_key - # then just do non-authed requests - for required in ('token_key', 'token_secret', 'consumer_key'): - if required not in mcfg: - return {} - - consumer_secret = mcfg.get('consumer_secret', "") - - timestamp = None - if self.oauth_clockskew: - timestamp = int(time.time()) + self.oauth_clockskew - - return oauth_headers(url=url, - consumer_key=mcfg['consumer_key'], - token_key=mcfg['token_key'], - token_secret=mcfg['token_secret'], - consumer_secret=consumer_secret, - timestamp=timestamp) - def wait_for_metadata_service(self, url): mcfg = self.ds_cfg - max_wait = 120 try: max_wait = int(mcfg.get("max_wait", max_wait)) @@ -138,10 +128,8 @@ class DataSourceMAAS(sources.DataSource): starttime = time.time() check_url = "%s/%s/meta-data/instance-id" % (url, MD_VERSION) urls = [check_url] - url = url_helper.wait_for_url(urls=urls, max_wait=max_wait, - timeout=timeout, - exception_cb=self._except_cb, - headers_cb=self._md_headers) + url = self.oauth_helper.wait_for_url( + urls=urls, max_wait=max_wait, timeout=timeout) if url: LOG.debug("Using metadata source: '%s'", url) @@ -151,26 +139,6 @@ class DataSourceMAAS(sources.DataSource): return bool(url) - def _except_cb(self, msg, exception): - if not (isinstance(exception, url_helper.UrlError) and - (exception.code == 403 or exception.code == 401)): - return - - if 'date' not in exception.headers: - LOG.warn("Missing header 'date' in %s response", exception.code) - return - - date = exception.headers['date'] - try: - ret_time = time.mktime(parsedate(date)) - except Exception as e: - LOG.warn("Failed to convert datetime '%s': %s", date, e) - return - - self.oauth_clockskew = int(ret_time - time.time()) - LOG.warn("Setting oauth clockskew to %d", self.oauth_clockskew) - return - def read_maas_seed_dir(seed_d): """ @@ -196,12 +164,12 @@ def read_maas_seed_dir(seed_d): return check_seed_contents(md, seed_d) -def read_maas_seed_url(seed_url, header_cb=None, timeout=None, +def read_maas_seed_url(seed_url, read_file_or_url=None, timeout=None, version=MD_VERSION, paths=None): """ Read the maas datasource at seed_url. - - header_cb is a method that should return a headers dictionary for - a given url + read_file_or_url is a method that should provide an interface + like util.read_file_or_url Expected format of seed_url is are the following files: * <seed_url>/<version>/meta-data/instance-id @@ -222,14 +190,12 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None, 'user-data': "%s/%s" % (base_url, 'user-data'), } + if read_file_or_url is None: + read_file_or_url = util.read_file_or_url + md = {} for name in file_order: url = files.get(name) - if not header_cb: - def _cb(url): - return {} - header_cb = _cb - if name == 'user-data': retries = 0 else: @@ -237,10 +203,8 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None, try: ssl_details = util.fetch_ssl_details(paths) - resp = util.read_file_or_url(url, retries=retries, - headers_cb=header_cb, - timeout=timeout, - ssl_details=ssl_details) + resp = read_file_or_url(url, retries=retries, + timeout=timeout, ssl_details=ssl_details) if resp.ok(): if name in BINARY_FIELDS: md[name] = resp.contents @@ -280,24 +244,6 @@ def check_seed_contents(content, seed): return (userdata, md) -def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, - timestamp=None): - if timestamp: - timestamp = str(timestamp) - else: - timestamp = None - - client = oauth1.Client( - consumer_key, - client_secret=consumer_secret, - resource_owner_key=token_key, - resource_owner_secret=token_secret, - signature_method=oauth1.SIGNATURE_PLAINTEXT, - timestamp=timestamp) - uri, signed_headers, body = client.sign(url) - return signed_headers - - class MAASSeedDirNone(Exception): pass @@ -361,47 +307,39 @@ if __name__ == "__main__": if key in cfg and creds[key] is None: creds[key] = cfg[key] - def geturl(url, headers_cb): - req = Request(url, data=None, headers=headers_cb(url)) - return urlopen(req).read() + oauth_helper = url_helper.OauthUrlHelper(**creds) + + def geturl(url): + return oauth_helper.readurl(url).contents def printurl(url, headers_cb): - print("== %s ==\n%s\n" % (url, geturl(url, headers_cb))) + print("== %s ==\n%s\n" % (url, geturl(url))) - def crawl(url, headers_cb=None): + def crawl(url): if url.endswith("/"): - for line in geturl(url, headers_cb).splitlines(): + for line in geturl(url).splitlines(): if line.endswith("/"): - crawl("%s%s" % (url, line), headers_cb) + crawl("%s%s" % (url, line)) else: - printurl("%s%s" % (url, line), headers_cb) + printurl("%s%s" % (url, line)) else: - printurl(url, headers_cb) - - def my_headers(url): - headers = {} - if creds.get('consumer_key', None) is not None: - headers = oauth_headers(url, **creds) - return headers + printurl(url) if args.subcmd == "check-seed": - if args.url.startswith("http"): - (userdata, metadata) = read_maas_seed_url(args.url, - header_cb=my_headers, - version=args.apiver) - else: - (userdata, metadata) = read_maas_seed_url(args.url) + (userdata, metadata) = read_maas_seed_url( + args.url, read_file_or_url=oauth_helper.read_file_or_url, + version=args.apiver) print("=== userdata ===") print(userdata) print("=== metadata ===") pprint.pprint(metadata) elif args.subcmd == "get": - printurl(args.url, my_headers) + printurl(args.url) elif args.subcmd == "crawl": if not args.url.endswith("/"): args.url = "%s/" % args.url - crawl(args.url, my_headers) + crawl(args.url) main() diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index a21c08c2..838cd198 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -27,6 +27,7 @@ import six from cloudinit import importer from cloudinit import log as logging +from cloudinit import reporting from cloudinit import type_utils from cloudinit import user_data as ud from cloudinit import util @@ -246,17 +247,25 @@ def normalize_pubkey_data(pubkey_data): return keys -def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list): +def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list, reporter): ds_list = list_sources(cfg_list, ds_deps, pkg_list) ds_names = [type_utils.obj_name(f) for f in ds_list] - LOG.debug("Searching for data source in: %s", ds_names) - - for cls in ds_list: + mode = "network" if DEP_NETWORK in ds_deps else "local" + LOG.debug("Searching for %s data source in: %s", mode, ds_names) + + for name, cls in zip(ds_names, ds_list): + myrep = reporting.ReportEventStack( + name="search-%s" % name.replace("DataSource", ""), + description="searching for %s data from %s" % (mode, name), + message="no %s data found from %s" % (mode, name), + parent=reporter) try: - LOG.debug("Seeing if we can get any data from %s", cls) - s = cls(sys_cfg, distro, paths) - if s.get_data(): - return (s, type_utils.obj_name(cls)) + with myrep: + LOG.debug("Seeing if we can get any data from %s", cls) + s = cls(sys_cfg, distro, paths) + if s.get_data(): + myrep.message = "found %s data from %s" % (mode, name) + return (s, type_utils.obj_name(cls)) except Exception: util.logexc(LOG, "Getting data from %s failed", cls) diff --git a/cloudinit/stages.py b/cloudinit/stages.py index d28e765b..d300709d 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -46,6 +46,7 @@ from cloudinit import log as logging from cloudinit import sources from cloudinit import type_utils from cloudinit import util +from cloudinit import reporting LOG = logging.getLogger(__name__) @@ -53,7 +54,7 @@ NULL_DATA_SOURCE = None class Init(object): - def __init__(self, ds_deps=None): + def __init__(self, ds_deps=None, reporter=None): if ds_deps is not None: self.ds_deps = ds_deps else: @@ -65,6 +66,12 @@ class Init(object): # Changed only when a fetch occurs self.datasource = NULL_DATA_SOURCE + if reporter is None: + reporter = reporting.ReportEventStack( + name="init-reporter", description="init-desc", + reporting_enabled=False) + self.reporter = reporter + def _reset(self, reset_ds=False): # Recreated on access self._cfg = None @@ -234,9 +241,17 @@ class Init(object): def _get_data_source(self): if self.datasource is not NULL_DATA_SOURCE: return self.datasource - ds = self._restore_from_cache() - if ds: - LOG.debug("Restored from cache, datasource: %s", ds) + + with reporting.ReportEventStack( + name="check-cache", + description="attempting to read from cache", + parent=self.reporter) as myrep: + ds = self._restore_from_cache() + if ds: + LOG.debug("Restored from cache, datasource: %s", ds) + myrep.description = "restored from cache" + else: + myrep.description = "no cache found" if not ds: (cfg_list, pkg_list) = self._get_datasources() # Deep copy so that user-data handlers can not modify @@ -246,7 +261,7 @@ class Init(object): self.paths, copy.deepcopy(self.ds_deps), cfg_list, - pkg_list) + pkg_list, self.reporter) LOG.info("Loaded datasource %s - %s", dsname, ds) self.datasource = ds # Ensure we adjust our path members datasource @@ -327,7 +342,8 @@ class Init(object): # Form the needed options to cloudify our members return cloud.Cloud(self.datasource, self.paths, self.cfg, - self.distro, helpers.Runners(self.paths)) + self.distro, helpers.Runners(self.paths), + reporter=self.reporter) def update(self): if not self._write_to_cache(): @@ -493,8 +509,14 @@ class Init(object): def consume_data(self, frequency=PER_INSTANCE): # Consume the userdata first, because we need want to let the part # handlers run first (for merging stuff) - self._consume_userdata(frequency) - self._consume_vendordata(frequency) + with reporting.ReportEventStack( + "consume-user-data", "reading and applying user-data", + parent=self.reporter): + self._consume_userdata(frequency) + with reporting.ReportEventStack( + "consume-vendor-data", "reading and applying vendor-data", + parent=self.reporter): + self._consume_vendordata(frequency) # Perform post-consumption adjustments so that # modules that run during the init stage reflect @@ -567,11 +589,16 @@ class Init(object): class Modules(object): - def __init__(self, init, cfg_files=None): + def __init__(self, init, cfg_files=None, reporter=None): self.init = init self.cfg_files = cfg_files # Created on first use self._cached_cfg = None + if reporter is None: + reporter = reporting.ReportEventStack( + name="module-reporter", description="module-desc", + reporting_enabled=False) + self.reporter = reporter @property def cfg(self): @@ -681,7 +708,19 @@ class Modules(object): which_ran.append(name) # This name will affect the semaphore name created run_name = "config-%s" % (name) - cc.run(run_name, mod.handle, func_args, freq=freq) + + desc = "running %s with frequency %s" % (run_name, freq) + myrep = reporting.ReportEventStack( + name=run_name, description=desc, parent=self.reporter) + + with myrep: + ran, _r = cc.run(run_name, mod.handle, func_args, + freq=freq) + if ran: + myrep.message = "%s ran successfully" % run_name + else: + myrep.message = "%s previously ran" % run_name + except Exception as e: util.logexc(LOG, "Running module %s (%s) failed", name, mod) failures.append((name, e)) diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 0e65f431..dca4cc85 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -25,6 +25,10 @@ import time import six import requests +import oauthlib.oauth1 as oauth1 +import os +import json +from functools import partial from requests import exceptions from six.moves.urllib.parse import ( @@ -147,13 +151,14 @@ class UrlResponse(object): class UrlError(IOError): - def __init__(self, cause, code=None, headers=None): + def __init__(self, cause, code=None, headers=None, url=None): IOError.__init__(self, str(cause)) self.cause = cause self.code = code self.headers = headers if self.headers is None: self.headers = {} + self.url = url def _get_ssl_args(url, ssl_details): @@ -247,9 +252,10 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1, and hasattr(e, 'response') # This appeared in v 0.10.8 and hasattr(e.response, 'status_code')): excps.append(UrlError(e, code=e.response.status_code, - headers=e.response.headers)) + headers=e.response.headers, + url=url)) else: - excps.append(UrlError(e)) + excps.append(UrlError(e, url=url)) if SSL_ENABLED and isinstance(e, exceptions.SSLError): # ssl exceptions are not going to get fixed by waiting a # few seconds @@ -333,11 +339,11 @@ def wait_for_url(urls, max_wait=None, timeout=None, if not response.contents: reason = "empty response [%s]" % (response.code) url_exc = UrlError(ValueError(reason), code=response.code, - headers=response.headers) + headers=response.headers, url=url) elif not response.ok(): reason = "bad status code [%s]" % (response.code) url_exc = UrlError(ValueError(reason), code=response.code, - headers=response.headers) + headers=response.headers, url=url) else: return url except UrlError as e: @@ -368,3 +374,127 @@ def wait_for_url(urls, max_wait=None, timeout=None, time.sleep(sleep_time) return False + + +class OauthUrlHelper(object): + def __init__(self, consumer_key=None, token_key=None, + token_secret=None, consumer_secret=None, + skew_data_file="/run/oauth_skew.json"): + self.consumer_key = consumer_key + self.consumer_secret = consumer_secret or "" + self.token_key = token_key + self.token_secret = token_secret + self.skew_data_file = skew_data_file + self._do_oauth = True + self.skew_change_limit = 5 + required = (self.token_key, self.token_secret, self.consumer_key) + if not any(required): + self._do_oauth = False + elif not all(required): + raise ValueError("all or none of token_key, token_secret, or " + "consumer_key can be set") + + old = self.read_skew_file() + self.skew_data = old or {} + + def read_skew_file(self): + if self.skew_data_file and os.path.isfile(self.skew_data_file): + with open(self.skew_data_file, mode="r") as fp: + return json.load(fp.read()) + return None + + def update_skew_file(self, host, value): + # this is not atomic + if not self.skew_data_file: + return + cur = self.read_skew_file() + cur[host] = value + with open(self.skew_data_file, mode="w") as fp: + fp.write(json.dumps(cur)) + + def exception_cb(self, msg, exception): + if not (isinstance(exception, UrlError) and + (exception.code == 403 or exception.code == 401)): + return + + if 'date' not in exception.headers: + LOG.warn("Missing header 'date' in %s response", exception.code) + return + + date = exception.headers['date'] + try: + remote_time = time.mktime(parsedate(date)) + except Exception as e: + LOG.warn("Failed to convert datetime '%s': %s", date, e) + return + + skew = int(remote_time - time.time()) + host = urlparse(exception.url).netloc + old_skew = self.skew_data.get(host, 0) + if abs(old_skew - skew) > self.skew_change_limit: + self.update_skew_file(host, skew) + LOG.warn("Setting oauth clockskew for %s to %d", host, skew) + skew_data[host] = skew + + return + + def headers_cb(self, url): + if not self._do_oauth: + return {} + + timestamp = None + host = urlparse(url).netloc + if self.skew_data and host in self.skew_data: + timestamp = int(time.time()) + self.skew_data[host] + + return oauth_headers( + url=url, consumer_key=self.consumer_key, + token_key=self.token_key, token_secret=self.token_secret, + consumer_secret=self.consumer_secret, timestamp=timestamp) + + def _wrapped(self, wrapped_func, args, kwargs): + kwargs['headers_cb'] = partial( + self._headers_cb, kwargs.get('headers_cb')) + kwargs['exception_cb'] = partial( + self._exception_cb, kwargs.get('exception_cb')) + return wrapped_func(*args, **kwargs) + + def wait_for_url(self, *args, **kwargs): + return self._wrapped(wait_for_url, args, kwargs) + + def readurl(self, *args, **kwargs): + return self._wrapped(readurl, args, kwargs) + + def _exception_cb(self, extra_exception_cb, msg, exception): + ret = None + try: + if extra_exception_cb: + ret = extra_exception_cb(msg, exception) + finally: + self.exception_cb(msg, exception) + return ret + + def _headers_cb(self, extra_headers_cb, url): + headers = {} + if extra_headers_cb: + headers = extra_headers_cb(url) + headers.update(self.headers_cb(url)) + return headers + + +def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, + timestamp=None): + if timestamp: + timestamp = str(timestamp) + else: + timestamp = None + + client = oauth1.Client( + consumer_key, + client_secret=consumer_secret, + resource_owner_key=token_key, + resource_owner_secret=token_secret, + signature_method=oauth1.SIGNATURE_PLAINTEXT, + timestamp=timestamp) + uri, signed_headers, body = client.sign(url) + return signed_headers diff --git a/cloudinit/util.py b/cloudinit/util.py index 02ba654a..09e583f5 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -782,7 +782,8 @@ def read_file_or_url(url, timeout=5, retries=10, code = e.errno if e.errno == errno.ENOENT: code = url_helper.NOT_FOUND - raise url_helper.UrlError(cause=e, code=code, headers=None) + raise url_helper.UrlError(cause=e, code=code, headers=None, + url=url) return url_helper.FileResponse(file_path, contents=contents) else: return url_helper.readurl(url, |