diff options
| -rw-r--r-- | ChangeLog | 1 | ||||
| -rwxr-xr-x | bin/cloud-init | 51 | ||||
| -rw-r--r-- | cloudinit/cloud.py | 9 | ||||
| -rw-r--r-- | cloudinit/registry.py | 37 | ||||
| -rw-r--r-- | cloudinit/reporting/__init__.py | 240 | ||||
| -rw-r--r-- | cloudinit/reporting/handlers.py | 90 | ||||
| -rw-r--r-- | cloudinit/sources/DataSourceMAAS.py | 146 | ||||
| -rw-r--r-- | cloudinit/sources/__init__.py | 25 | ||||
| -rw-r--r-- | cloudinit/stages.py | 59 | ||||
| -rw-r--r-- | cloudinit/url_helper.py | 140 | ||||
| -rw-r--r-- | cloudinit/util.py | 3 | ||||
| -rw-r--r-- | doc/examples/cloud-config-reporting.txt | 17 | ||||
| -rw-r--r-- | tests/unittests/test_datasource/test_maas.py | 2 | ||||
| -rw-r--r-- | tests/unittests/test_registry.py | 28 | ||||
| -rw-r--r-- | tests/unittests/test_reporting.py | 359 | 
15 files changed, 1069 insertions, 138 deletions
| @@ -59,6 +59,7 @@   - _read_dmi_syspath: fix bad log message causing unintended exception   - rsyslog: add additional configuration mode (LP: #1478103)   - status_wrapper in main: fix use of print_exc when handling exception + - reporting: add reporting module for web hook or logging of events.  0.7.6:   - open 0.7.6   - Enable vendordata on CloudSigma datasource (LP: #1303986) diff --git a/bin/cloud-init b/bin/cloud-init index 63c13b09..1f64461e 100755 --- a/bin/cloud-init +++ b/bin/cloud-init @@ -46,6 +46,7 @@ from cloudinit import sources  from cloudinit import stages  from cloudinit import templater  from cloudinit import util +from cloudinit import reporting  from cloudinit import version  from cloudinit.settings import (PER_INSTANCE, PER_ALWAYS, PER_ONCE, @@ -136,6 +137,11 @@ def run_module_section(mods, action_name, section):          return failures +def apply_reporting_cfg(cfg): +    if cfg.get('reporting'): +        reporting.update_configuration(cfg.get('reporting')) + +  def main_init(name, args):      deps = [sources.DEP_FILESYSTEM, sources.DEP_NETWORK]      if args.local: @@ -171,7 +177,7 @@ def main_init(name, args):          w_msg = welcome_format(name)      else:          w_msg = welcome_format("%s-local" % (name)) -    init = stages.Init(deps) +    init = stages.Init(ds_deps=deps, reporter=args.reporter)      # Stage 1      init.read_cfg(extract_fns(args))      # Stage 2 @@ -190,6 +196,7 @@ def main_init(name, args):                      " longer be active shortly"))          logging.resetLogging()      logging.setupLogging(init.cfg) +    apply_reporting_cfg(init.cfg)      # Any log usage prior to setupLogging above did not have local user log      # config applied.  We send the welcome message now, as stderr/out have @@ -282,8 +289,10 @@ def main_init(name, args):          util.logexc(LOG, "Consuming user data failed!")          return (init.datasource, ["Consuming user data failed!"]) +    apply_reporting_cfg(init.cfg) +      # Stage 8 - re-read and apply relevant cloud-config to include user-data -    mods = stages.Modules(init, extract_fns(args)) +    mods = stages.Modules(init, extract_fns(args), reporter=args.reporter)      # Stage 9      try:          outfmt_orig = outfmt @@ -313,7 +322,7 @@ def main_modules(action_name, args):      # 5. Run the modules for the given stage name      # 6. Done!      w_msg = welcome_format("%s:%s" % (action_name, name)) -    init = stages.Init(ds_deps=[]) +    init = stages.Init(ds_deps=[], reporter=args.reporter)      # Stage 1      init.read_cfg(extract_fns(args))      # Stage 2 @@ -328,7 +337,7 @@ def main_modules(action_name, args):          if not args.force:              return [(msg)]      # Stage 3 -    mods = stages.Modules(init, extract_fns(args)) +    mods = stages.Modules(init, extract_fns(args), reporter=args.reporter)      # Stage 4      try:          LOG.debug("Closing stdin") @@ -342,6 +351,7 @@ def main_modules(action_name, args):                      " longer be active shortly"))          logging.resetLogging()      logging.setupLogging(mods.cfg) +    apply_reporting_cfg(init.cfg)      # now that logging is setup and stdout redirected, send welcome      welcome(name, msg=w_msg) @@ -366,7 +376,7 @@ def main_single(name, args):      # 6. Done!      mod_name = args.name      w_msg = welcome_format(name) -    init = stages.Init(ds_deps=[]) +    init = stages.Init(ds_deps=[], reporter=args.reporter)      # Stage 1      init.read_cfg(extract_fns(args))      # Stage 2 @@ -383,7 +393,7 @@ def main_single(name, args):          if not args.force:              return 1      # Stage 3 -    mods = stages.Modules(init, extract_fns(args)) +    mods = stages.Modules(init, extract_fns(args), reporter=args.reporter)      mod_args = args.module_args      if mod_args:          LOG.debug("Using passed in arguments %s", mod_args) @@ -404,6 +414,7 @@ def main_single(name, args):                     " longer be active shortly"))          logging.resetLogging()      logging.setupLogging(mods.cfg) +    apply_reporting_cfg(init.cfg)      # now that logging is setup and stdout redirected, send welcome      welcome(name, msg=w_msg) @@ -549,6 +560,8 @@ def main():                                ' found (use at your own risk)'),                          dest='force',                          default=False) + +    parser.set_defaults(reporter=None)      subparsers = parser.add_subparsers()      # Each action and its sub-options (if any) @@ -595,6 +608,9 @@ def main():                                help=("frequency of the module"),                                required=False,                                choices=list(FREQ_SHORT_NAMES.keys())) +    parser_single.add_argument("--report", action="store_true", +                               help="enable reporting", +                               required=False)      parser_single.add_argument("module_args", nargs="*",                                metavar='argument',                                help=('any additional arguments to' @@ -617,8 +633,27 @@ def main():      if name in ("modules", "init"):          functor = status_wrapper -    return util.log_time(logfunc=LOG.debug, msg="cloud-init mode '%s'" % name, -                         get_uptime=True, func=functor, args=(name, args)) +    report_on = True +    if name == "init": +        if args.local: +            rname, rdesc = ("init-local", "searching for local datasources") +        else: +            rname, rdesc = ("init-network", +                            "searching for network datasources") +    elif name == "modules": +        rname, rdesc = ("modules-%s" % args.mode, +                        "running modules for %s" % args.mode) +    elif name == "single": +        rname, rdesc = ("single/%s" % args.name, +                        "running single module %s" % args.name) +        report_on = args.report + +    args.reporter = reporting.ReportEventStack( +        rname, rdesc, reporting_enabled=report_on) +    with args.reporter: +        return util.log_time( +            logfunc=LOG.debug, msg="cloud-init mode '%s'" % name, +            get_uptime=True, func=functor, args=(name, args))  if __name__ == '__main__': diff --git a/cloudinit/cloud.py b/cloudinit/cloud.py index 95e0cfb2..edee3887 100644 --- a/cloudinit/cloud.py +++ b/cloudinit/cloud.py @@ -24,6 +24,7 @@ import copy  import os  from cloudinit import log as logging +from cloudinit import reporting  LOG = logging.getLogger(__name__) @@ -40,12 +41,18 @@ LOG = logging.getLogger(__name__)  class Cloud(object): -    def __init__(self, datasource, paths, cfg, distro, runners): +    def __init__(self, datasource, paths, cfg, distro, runners, reporter=None):          self.datasource = datasource          self.paths = paths          self.distro = distro          self._cfg = cfg          self._runners = runners +        if reporter is None: +            reporter = reporting.ReportEventStack( +                name="unnamed-cloud-reporter", +                description="unnamed-cloud-reporter", +                reporting_enabled=False) +        self.reporter = reporter      # If a 'user' manipulates logging or logging services      # it is typically useful to cause the logging to be diff --git a/cloudinit/registry.py b/cloudinit/registry.py new file mode 100644 index 00000000..04368ddf --- /dev/null +++ b/cloudinit/registry.py @@ -0,0 +1,37 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init.  See LICENCE file for license information. +# +# vi: ts=4 expandtab +import copy + + +class DictRegistry(object): +    """A simple registry for a mapping of objects.""" + +    def __init__(self): +        self.reset() + +    def reset(self): +        self._items = {} + +    def register_item(self, key, item): +        """Add item to the registry.""" +        if key in self._items: +            raise ValueError( +                'Item already registered with key {0}'.format(key)) +        self._items[key] = item + +    def unregister_item(self, key, force=True): +        """Remove item from the registry.""" +        if key in self._items: +            del self._items[key] +        elif not force: +            raise KeyError("%s: key not present to unregister" % key) + +    @property +    def registered_items(self): +        """All the items that have been registered. + +        This cannot be used to modify the contents of the registry. +        """ +        return copy.copy(self._items) diff --git a/cloudinit/reporting/__init__.py b/cloudinit/reporting/__init__.py new file mode 100644 index 00000000..e23fab32 --- /dev/null +++ b/cloudinit/reporting/__init__.py @@ -0,0 +1,240 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init.  See LICENCE file for license information. +# +# vi: ts=4 expandtab +""" +cloud-init reporting framework + +The reporting framework is intended to allow all parts of cloud-init to +report events in a structured manner. +""" + +from ..registry import DictRegistry +from ..reporting.handlers import available_handlers + + +FINISH_EVENT_TYPE = 'finish' +START_EVENT_TYPE = 'start' + +DEFAULT_CONFIG = { +    'logging': {'type': 'log'}, +} + + +class _nameset(set): +    def __getattr__(self, name): +        if name in self: +            return name +        raise AttributeError("%s not a valid value" % name) + + +status = _nameset(("SUCCESS", "WARN", "FAIL")) + + +class ReportingEvent(object): +    """Encapsulation of event formatting.""" + +    def __init__(self, event_type, name, description): +        self.event_type = event_type +        self.name = name +        self.description = description + +    def as_string(self): +        """The event represented as a string.""" +        return '{0}: {1}: {2}'.format( +            self.event_type, self.name, self.description) + +    def as_dict(self): +        """The event represented as a dictionary.""" +        return {'name': self.name, 'description': self.description, +                'event_type': self.event_type} + + +class FinishReportingEvent(ReportingEvent): + +    def __init__(self, name, description, result=status.SUCCESS): +        super(FinishReportingEvent, self).__init__( +            FINISH_EVENT_TYPE, name, description) +        self.result = result +        if result not in status: +            raise ValueError("Invalid result: %s" % result) + +    def as_string(self): +        return '{0}: {1}: {2}: {3}'.format( +            self.event_type, self.name, self.result, self.description) + +    def as_dict(self): +        """The event represented as json friendly.""" +        data = super(FinishReportingEvent, self).as_dict() +        data['result'] = self.result +        return data + + +def update_configuration(config): +    """Update the instanciated_handler_registry. + +    :param config: +        The dictionary containing changes to apply.  If a key is given +        with a False-ish value, the registered handler matching that name +        will be unregistered. +    """ +    for handler_name, handler_config in config.items(): +        if not handler_config: +            instantiated_handler_registry.unregister_item( +                handler_name, force=True) +            continue +        registered = instantiated_handler_registry.registered_items +        handler_config = handler_config.copy() +        cls = available_handlers.registered_items[handler_config.pop('type')] +        instantiated_handler_registry.unregister_item(handler_name) +        instance = cls(**handler_config) +        instantiated_handler_registry.register_item(handler_name, instance) + + +def report_event(event): +    """Report an event to all registered event handlers. + +    This should generally be called via one of the other functions in +    the reporting module. + +    :param event_type: +        The type of the event; this should be a constant from the +        reporting module. +    """ +    for _, handler in instantiated_handler_registry.registered_items.items(): +        handler.publish_event(event) + + +def report_finish_event(event_name, event_description, +                        result=status.SUCCESS): +    """Report a "finish" event. + +    See :py:func:`.report_event` for parameter details. +    """ +    event = FinishReportingEvent(event_name, event_description, result) +    return report_event(event) + + +def report_start_event(event_name, event_description): +    """Report a "start" event. + +    :param event_name: +        The name of the event; this should be a topic which events would +        share (e.g. it will be the same for start and finish events). + +    :param event_description: +        A human-readable description of the event that has occurred. +    """ +    event = ReportingEvent(START_EVENT_TYPE, event_name, event_description) +    return report_event(event) + + +class ReportEventStack(object): +    """Context Manager for using :py:func:`report_event` + +    This enables calling :py:func:`report_start_event` and +    :py:func:`report_finish_event` through a context manager. + +    :param name: +        the name of the event + +    :param description: +        the event's description, passed on to :py:func:`report_start_event` + +    :param message: +        the description to use for the finish event. defaults to +        :param:description. + +    :param parent: +    :type parent: :py:class:ReportEventStack or None +        The parent of this event.  The parent is populated with +        results of all its children.  The name used in reporting +        is <parent.name>/<name> + +    :param reporting_enabled: +        Indicates if reporting events should be generated. +        If not provided, defaults to the parent's value, or True if no parent +        is provided. + +    :param result_on_exception: +        The result value to set if an exception is caught. default +        value is FAIL. +    """ +    def __init__(self, name, description, message=None, parent=None, +                 reporting_enabled=None, result_on_exception=status.FAIL): +        self.parent = parent +        self.name = name +        self.description = description +        self.message = message +        self.result_on_exception = result_on_exception +        self.result = status.SUCCESS + +        # use parents reporting value if not provided +        if reporting_enabled is None: +            if parent: +                reporting_enabled = parent.reporting_enabled +            else: +                reporting_enabled = True +        self.reporting_enabled = reporting_enabled + +        if parent: +            self.fullname = '/'.join((parent.fullname, name,)) +        else: +            self.fullname = self.name +        self.children = {} + +    def __repr__(self): +        return ("ReportEventStack(%s, %s, reporting_enabled=%s)" % +                (self.name, self.description, self.reporting_enabled)) + +    def __enter__(self): +        self.result = status.SUCCESS +        if self.reporting_enabled: +            report_start_event(self.fullname, self.description) +        if self.parent: +            self.parent.children[self.name] = (None, None) +        return self + +    def _childrens_finish_info(self): +        for cand_result in (status.FAIL, status.WARN): +            for name, (value, msg) in self.children.items(): +                if value == cand_result: +                    return (value, self.message) +        return (self.result, self.message) + +    @property +    def result(self): +        return self._result + +    @result.setter +    def result(self, value): +        if value not in status: +            raise ValueError("'%s' not a valid result" % value) +        self._result = value + +    @property +    def message(self): +        if self._message is not None: +            return self._message +        return self.description + +    @message.setter +    def message(self, value): +        self._message = value + +    def _finish_info(self, exc): +        # return tuple of description, and value +        if exc: +            return (self.result_on_exception, self.message) +        return self._childrens_finish_info() + +    def __exit__(self, exc_type, exc_value, traceback): +        (result, msg) = self._finish_info(exc_value) +        if self.parent: +            self.parent.children[self.name] = (result, msg) +        if self.reporting_enabled: +            report_finish_event(self.fullname, msg, result) + + +instantiated_handler_registry = DictRegistry() +update_configuration(DEFAULT_CONFIG) diff --git a/cloudinit/reporting/handlers.py b/cloudinit/reporting/handlers.py new file mode 100644 index 00000000..1343311f --- /dev/null +++ b/cloudinit/reporting/handlers.py @@ -0,0 +1,90 @@ +# vi: ts=4 expandtab + +import abc +import oauthlib.oauth1 as oauth1 +import six + +from ..registry import DictRegistry +from .. import (url_helper, util) +from .. import log as logging + + +LOG = logging.getLogger(__name__) + + +@six.add_metaclass(abc.ABCMeta) +class ReportingHandler(object): +    """Base class for report handlers. + +    Implement :meth:`~publish_event` for controlling what +    the handler does with an event. +    """ + +    @abc.abstractmethod +    def publish_event(self, event): +        """Publish an event to the ``INFO`` log level.""" + + +class LogHandler(ReportingHandler): +    """Publishes events to the cloud-init log at the ``INFO`` log level.""" + +    def __init__(self, level="DEBUG"): +        super(LogHandler, self).__init__() +        if isinstance(level, int): +            pass +        else: +            input_level = level +            try: +                level = gettattr(logging, level.upper()) +            except: +                LOG.warn("invalid level '%s', using WARN", input_level) +                level = logging.WARN +        self.level = level + +    def publish_event(self, event): +        """Publish an event to the ``INFO`` log level.""" +        logger = logging.getLogger( +            '.'.join(['cloudinit', 'reporting', event.event_type, event.name])) +        logger.log(self.level, event.as_string()) + + +class PrintHandler(ReportingHandler): +    def publish_event(self, event): +        """Publish an event to the ``INFO`` log level.""" + + +class WebHookHandler(ReportingHandler): +    def __init__(self, endpoint, consumer_key=None, token_key=None, +                 token_secret=None, consumer_secret=None, timeout=None, +                 retries=None): +        super(WebHookHandler, self).__init__() + +        if any([consumer_key, token_key, token_secret, consumer_secret]): +            self.oauth_helper = url_helper.OauthUrlHelper( +                consumer_key=consumer_key, token_key=token_key, +                token_secret=token_secret, consumer_secret=consumer_secret) +        else: +            self.oauth_helper = None +        self.endpoint = endpoint +        self.timeout = timeout +        self.retries = retries +        self.ssl_details = util.fetch_ssl_details() + +    def publish_event(self, event): +        if self.oauth_helper: +            readurl = self.oauth_helper.readurl +        else: +            readurl = url_helper.readurl +        try: +            return readurl( +                self.endpoint, data=event.as_dict(), +                timeout=self.timeout, +                retries=self.retries, ssl_details=self.ssl_details) +        except: +            LOG.warn("failed posting event: %s" % event.as_string()) + + +available_handlers = DictRegistry() +available_handlers.register_item('log', LogHandler) +available_handlers.register_item('print', PrintHandler) +available_handlers.register_item('webhook', WebHookHandler) diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py index c1a0eb61..2f36bbe2 100644 --- a/cloudinit/sources/DataSourceMAAS.py +++ b/cloudinit/sources/DataSourceMAAS.py @@ -52,7 +52,20 @@ class DataSourceMAAS(sources.DataSource):          sources.DataSource.__init__(self, sys_cfg, distro, paths)          self.base_url = None          self.seed_dir = os.path.join(paths.seed_dir, 'maas') -        self.oauth_clockskew = None +        self.oauth_helper = self._get_helper() + +    def _get_helper(self): +        mcfg = self.ds_cfg +        # If we are missing token_key, token_secret or consumer_key +        # then just do non-authed requests +        for required in ('token_key', 'token_secret', 'consumer_key'): +            if required not in mcfg: +                return url_helper.OauthUrlHelper() + +        return url_helper.OauthHelper( +            consumer_key=mcfg['consumer_key'], token_key=mcfg['token_key'], +            token_secret=mcfg['token_secret'], +            consumer_secret=mcfg.get('consumer_secret'))      def __str__(self):          root = sources.DataSource.__str__(self) @@ -84,9 +97,9 @@ class DataSourceMAAS(sources.DataSource):              self.base_url = url -            (userdata, metadata) = read_maas_seed_url(self.base_url, -                                                      self._md_headers, -                                                      paths=self.paths) +            (userdata, metadata) = read_maas_seed_url( +                self.base_url, self.oauth_helper.md_headers, +                paths=self.paths)              self.userdata_raw = userdata              self.metadata = metadata              return True @@ -94,31 +107,8 @@ class DataSourceMAAS(sources.DataSource):              util.logexc(LOG, "Failed fetching metadata from url %s", url)              return False -    def _md_headers(self, url): -        mcfg = self.ds_cfg - -        # If we are missing token_key, token_secret or consumer_key -        # then just do non-authed requests -        for required in ('token_key', 'token_secret', 'consumer_key'): -            if required not in mcfg: -                return {} - -        consumer_secret = mcfg.get('consumer_secret', "") - -        timestamp = None -        if self.oauth_clockskew: -            timestamp = int(time.time()) + self.oauth_clockskew - -        return oauth_headers(url=url, -                             consumer_key=mcfg['consumer_key'], -                             token_key=mcfg['token_key'], -                             token_secret=mcfg['token_secret'], -                             consumer_secret=consumer_secret, -                             timestamp=timestamp) -      def wait_for_metadata_service(self, url):          mcfg = self.ds_cfg -          max_wait = 120          try:              max_wait = int(mcfg.get("max_wait", max_wait)) @@ -138,10 +128,8 @@ class DataSourceMAAS(sources.DataSource):          starttime = time.time()          check_url = "%s/%s/meta-data/instance-id" % (url, MD_VERSION)          urls = [check_url] -        url = url_helper.wait_for_url(urls=urls, max_wait=max_wait, -                                      timeout=timeout, -                                      exception_cb=self._except_cb, -                                      headers_cb=self._md_headers) +        url = self.oauth_helper.wait_for_url( +            urls=urls, max_wait=max_wait, timeout=timeout)          if url:              LOG.debug("Using metadata source: '%s'", url) @@ -151,26 +139,6 @@ class DataSourceMAAS(sources.DataSource):          return bool(url) -    def _except_cb(self, msg, exception): -        if not (isinstance(exception, url_helper.UrlError) and -                (exception.code == 403 or exception.code == 401)): -            return - -        if 'date' not in exception.headers: -            LOG.warn("Missing header 'date' in %s response", exception.code) -            return - -        date = exception.headers['date'] -        try: -            ret_time = time.mktime(parsedate(date)) -        except Exception as e: -            LOG.warn("Failed to convert datetime '%s': %s", date, e) -            return - -        self.oauth_clockskew = int(ret_time - time.time()) -        LOG.warn("Setting oauth clockskew to %d", self.oauth_clockskew) -        return -  def read_maas_seed_dir(seed_d):      """ @@ -196,12 +164,12 @@ def read_maas_seed_dir(seed_d):      return check_seed_contents(md, seed_d) -def read_maas_seed_url(seed_url, header_cb=None, timeout=None, +def read_maas_seed_url(seed_url, read_file_or_url=None, timeout=None,                         version=MD_VERSION, paths=None):      """      Read the maas datasource at seed_url. -      - header_cb is a method that should return a headers dictionary for -        a given url +      read_file_or_url is a method that should provide an interface +      like util.read_file_or_url      Expected format of seed_url is are the following files:        * <seed_url>/<version>/meta-data/instance-id @@ -222,14 +190,12 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None,          'user-data': "%s/%s" % (base_url, 'user-data'),      } +    if read_file_or_url is None: +        read_file_or_url = util.read_file_or_url +      md = {}      for name in file_order:          url = files.get(name) -        if not header_cb: -            def _cb(url): -                return {} -            header_cb = _cb -          if name == 'user-data':              retries = 0          else: @@ -237,10 +203,8 @@ def read_maas_seed_url(seed_url, header_cb=None, timeout=None,          try:              ssl_details = util.fetch_ssl_details(paths) -            resp = util.read_file_or_url(url, retries=retries, -                                         headers_cb=header_cb, -                                         timeout=timeout, -                                         ssl_details=ssl_details) +            resp = read_file_or_url(url, retries=retries, +                                    timeout=timeout, ssl_details=ssl_details)              if resp.ok():                  if name in BINARY_FIELDS:                      md[name] = resp.contents @@ -280,24 +244,6 @@ def check_seed_contents(content, seed):      return (userdata, md) -def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, -                  timestamp=None): -    if timestamp: -        timestamp = str(timestamp) -    else: -        timestamp = None - -    client = oauth1.Client( -        consumer_key, -        client_secret=consumer_secret, -        resource_owner_key=token_key, -        resource_owner_secret=token_secret, -        signature_method=oauth1.SIGNATURE_PLAINTEXT, -        timestamp=timestamp) -    uri, signed_headers, body = client.sign(url) -    return signed_headers - -  class MAASSeedDirNone(Exception):      pass @@ -361,47 +307,39 @@ if __name__ == "__main__":                  if key in cfg and creds[key] is None:                      creds[key] = cfg[key] -        def geturl(url, headers_cb): -            req = Request(url, data=None, headers=headers_cb(url)) -            return urlopen(req).read() +        oauth_helper = url_helper.OauthUrlHelper(**creds) + +        def geturl(url): +            return oauth_helper.readurl(url).contents          def printurl(url, headers_cb): -            print("== %s ==\n%s\n" % (url, geturl(url, headers_cb))) +            print("== %s ==\n%s\n" % (url, geturl(url))) -        def crawl(url, headers_cb=None): +        def crawl(url):              if url.endswith("/"): -                for line in geturl(url, headers_cb).splitlines(): +                for line in geturl(url).splitlines():                      if line.endswith("/"): -                        crawl("%s%s" % (url, line), headers_cb) +                        crawl("%s%s" % (url, line))                      else: -                        printurl("%s%s" % (url, line), headers_cb) +                        printurl("%s%s" % (url, line))              else: -                printurl(url, headers_cb) - -        def my_headers(url): -            headers = {} -            if creds.get('consumer_key', None) is not None: -                headers = oauth_headers(url, **creds) -            return headers +                printurl(url)          if args.subcmd == "check-seed": -            if args.url.startswith("http"): -                (userdata, metadata) = read_maas_seed_url(args.url, -                                                          header_cb=my_headers, -                                                          version=args.apiver) -            else: -                (userdata, metadata) = read_maas_seed_url(args.url) +            (userdata, metadata) = read_maas_seed_url( +                args.url, read_file_or_url=oauth_helper.read_file_or_url, +                version=args.apiver)              print("=== userdata ===")              print(userdata)              print("=== metadata ===")              pprint.pprint(metadata)          elif args.subcmd == "get": -            printurl(args.url, my_headers) +            printurl(args.url)          elif args.subcmd == "crawl":              if not args.url.endswith("/"):                  args.url = "%s/" % args.url -            crawl(args.url, my_headers) +            crawl(args.url)      main() diff --git a/cloudinit/sources/__init__.py b/cloudinit/sources/__init__.py index a21c08c2..838cd198 100644 --- a/cloudinit/sources/__init__.py +++ b/cloudinit/sources/__init__.py @@ -27,6 +27,7 @@ import six  from cloudinit import importer  from cloudinit import log as logging +from cloudinit import reporting  from cloudinit import type_utils  from cloudinit import user_data as ud  from cloudinit import util @@ -246,17 +247,25 @@ def normalize_pubkey_data(pubkey_data):      return keys -def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list): +def find_source(sys_cfg, distro, paths, ds_deps, cfg_list, pkg_list, reporter):      ds_list = list_sources(cfg_list, ds_deps, pkg_list)      ds_names = [type_utils.obj_name(f) for f in ds_list] -    LOG.debug("Searching for data source in: %s", ds_names) - -    for cls in ds_list: +    mode = "network" if DEP_NETWORK in ds_deps else "local" +    LOG.debug("Searching for %s data source in: %s", mode, ds_names) + +    for name, cls in zip(ds_names, ds_list): +        myrep = reporting.ReportEventStack( +            name="search-%s" % name.replace("DataSource", ""), +            description="searching for %s data from %s" % (mode, name), +            message="no %s data found from %s" % (mode, name), +            parent=reporter)          try: -            LOG.debug("Seeing if we can get any data from %s", cls) -            s = cls(sys_cfg, distro, paths) -            if s.get_data(): -                return (s, type_utils.obj_name(cls)) +            with myrep: +                LOG.debug("Seeing if we can get any data from %s", cls) +                s = cls(sys_cfg, distro, paths) +                if s.get_data(): +                    myrep.message = "found %s data from %s" % (mode, name) +                    return (s, type_utils.obj_name(cls))          except Exception:              util.logexc(LOG, "Getting data from %s failed", cls) diff --git a/cloudinit/stages.py b/cloudinit/stages.py index d28e765b..d300709d 100644 --- a/cloudinit/stages.py +++ b/cloudinit/stages.py @@ -46,6 +46,7 @@ from cloudinit import log as logging  from cloudinit import sources  from cloudinit import type_utils  from cloudinit import util +from cloudinit import reporting  LOG = logging.getLogger(__name__) @@ -53,7 +54,7 @@ NULL_DATA_SOURCE = None  class Init(object): -    def __init__(self, ds_deps=None): +    def __init__(self, ds_deps=None, reporter=None):          if ds_deps is not None:              self.ds_deps = ds_deps          else: @@ -65,6 +66,12 @@ class Init(object):          # Changed only when a fetch occurs          self.datasource = NULL_DATA_SOURCE +        if reporter is None: +            reporter = reporting.ReportEventStack( +                name="init-reporter", description="init-desc", +                reporting_enabled=False) +        self.reporter = reporter +      def _reset(self, reset_ds=False):          # Recreated on access          self._cfg = None @@ -234,9 +241,17 @@ class Init(object):      def _get_data_source(self):          if self.datasource is not NULL_DATA_SOURCE:              return self.datasource -        ds = self._restore_from_cache() -        if ds: -            LOG.debug("Restored from cache, datasource: %s", ds) + +        with reporting.ReportEventStack( +                name="check-cache", +                description="attempting to read from cache", +                parent=self.reporter) as myrep: +            ds = self._restore_from_cache() +            if ds: +                LOG.debug("Restored from cache, datasource: %s", ds) +                myrep.description = "restored from cache" +            else: +                myrep.description = "no cache found"          if not ds:              (cfg_list, pkg_list) = self._get_datasources()              # Deep copy so that user-data handlers can not modify @@ -246,7 +261,7 @@ class Init(object):                                                 self.paths,                                                 copy.deepcopy(self.ds_deps),                                                 cfg_list, -                                               pkg_list) +                                               pkg_list, self.reporter)              LOG.info("Loaded datasource %s - %s", dsname, ds)          self.datasource = ds          # Ensure we adjust our path members datasource @@ -327,7 +342,8 @@ class Init(object):          # Form the needed options to cloudify our members          return cloud.Cloud(self.datasource,                             self.paths, self.cfg, -                           self.distro, helpers.Runners(self.paths)) +                           self.distro, helpers.Runners(self.paths), +                           reporter=self.reporter)      def update(self):          if not self._write_to_cache(): @@ -493,8 +509,14 @@ class Init(object):      def consume_data(self, frequency=PER_INSTANCE):          # Consume the userdata first, because we need want to let the part          # handlers run first (for merging stuff) -        self._consume_userdata(frequency) -        self._consume_vendordata(frequency) +        with reporting.ReportEventStack( +            "consume-user-data", "reading and applying user-data", +            parent=self.reporter): +                self._consume_userdata(frequency) +        with reporting.ReportEventStack( +            "consume-vendor-data", "reading and applying vendor-data", +            parent=self.reporter): +                self._consume_vendordata(frequency)          # Perform post-consumption adjustments so that          # modules that run during the init stage reflect @@ -567,11 +589,16 @@ class Init(object):  class Modules(object): -    def __init__(self, init, cfg_files=None): +    def __init__(self, init, cfg_files=None, reporter=None):          self.init = init          self.cfg_files = cfg_files          # Created on first use          self._cached_cfg = None +        if reporter is None: +            reporter = reporting.ReportEventStack( +                name="module-reporter", description="module-desc", +                reporting_enabled=False) +        self.reporter = reporter      @property      def cfg(self): @@ -681,7 +708,19 @@ class Modules(object):                  which_ran.append(name)                  # This name will affect the semaphore name created                  run_name = "config-%s" % (name) -                cc.run(run_name, mod.handle, func_args, freq=freq) + +                desc = "running %s with frequency %s" % (run_name, freq) +                myrep = reporting.ReportEventStack( +                    name=run_name, description=desc, parent=self.reporter) + +                with myrep: +                    ran, _r = cc.run(run_name, mod.handle, func_args, +                                     freq=freq) +                    if ran: +                        myrep.message = "%s ran successfully" % run_name +                    else: +                        myrep.message = "%s previously ran" % run_name +              except Exception as e:                  util.logexc(LOG, "Running module %s (%s) failed", name, mod)                  failures.append((name, e)) diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 0e65f431..dca4cc85 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -25,6 +25,10 @@ import time  import six  import requests +import oauthlib.oauth1 as oauth1 +import os +import json +from functools import partial  from requests import exceptions  from six.moves.urllib.parse import ( @@ -147,13 +151,14 @@ class UrlResponse(object):  class UrlError(IOError): -    def __init__(self, cause, code=None, headers=None): +    def __init__(self, cause, code=None, headers=None, url=None):          IOError.__init__(self, str(cause))          self.cause = cause          self.code = code          self.headers = headers          if self.headers is None:              self.headers = {} +        self.url = url  def _get_ssl_args(url, ssl_details): @@ -247,9 +252,10 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1,                      and hasattr(e, 'response')  # This appeared in v 0.10.8                      and hasattr(e.response, 'status_code')):                  excps.append(UrlError(e, code=e.response.status_code, -                                      headers=e.response.headers)) +                                      headers=e.response.headers, +                                      url=url))              else: -                excps.append(UrlError(e)) +                excps.append(UrlError(e, url=url))                  if SSL_ENABLED and isinstance(e, exceptions.SSLError):                      # ssl exceptions are not going to get fixed by waiting a                      # few seconds @@ -333,11 +339,11 @@ def wait_for_url(urls, max_wait=None, timeout=None,                  if not response.contents:                      reason = "empty response [%s]" % (response.code)                      url_exc = UrlError(ValueError(reason), code=response.code, -                                       headers=response.headers) +                                       headers=response.headers, url=url)                  elif not response.ok():                      reason = "bad status code [%s]" % (response.code)                      url_exc = UrlError(ValueError(reason), code=response.code, -                                       headers=response.headers) +                                       headers=response.headers, url=url)                  else:                      return url              except UrlError as e: @@ -368,3 +374,127 @@ def wait_for_url(urls, max_wait=None, timeout=None,          time.sleep(sleep_time)      return False + + +class OauthUrlHelper(object): +    def __init__(self, consumer_key=None, token_key=None, +                 token_secret=None, consumer_secret=None, +                 skew_data_file="/run/oauth_skew.json"): +        self.consumer_key = consumer_key +        self.consumer_secret = consumer_secret or "" +        self.token_key = token_key +        self.token_secret = token_secret +        self.skew_data_file = skew_data_file +        self._do_oauth = True +        self.skew_change_limit = 5 +        required = (self.token_key, self.token_secret, self.consumer_key) +        if not any(required): +            self._do_oauth = False +        elif not all(required): +            raise ValueError("all or none of token_key, token_secret, or " +                             "consumer_key can be set") + +        old = self.read_skew_file() +        self.skew_data = old or {} + +    def read_skew_file(self): +        if self.skew_data_file and os.path.isfile(self.skew_data_file): +            with open(self.skew_data_file, mode="r") as fp: +                return json.load(fp.read()) +        return None + +    def update_skew_file(self, host, value): +        # this is not atomic +        if not self.skew_data_file: +            return +        cur = self.read_skew_file() +        cur[host] = value +        with open(self.skew_data_file, mode="w") as fp: +            fp.write(json.dumps(cur)) + +    def exception_cb(self, msg, exception): +        if not (isinstance(exception, UrlError) and +                (exception.code == 403 or exception.code == 401)): +            return + +        if 'date' not in exception.headers: +            LOG.warn("Missing header 'date' in %s response", exception.code) +            return + +        date = exception.headers['date'] +        try: +            remote_time = time.mktime(parsedate(date)) +        except Exception as e: +            LOG.warn("Failed to convert datetime '%s': %s", date, e) +            return + +        skew = int(remote_time - time.time()) +        host = urlparse(exception.url).netloc +        old_skew = self.skew_data.get(host, 0) +        if abs(old_skew - skew) > self.skew_change_limit: +            self.update_skew_file(host, skew) +            LOG.warn("Setting oauth clockskew for %s to %d", host, skew) +        skew_data[host] = skew + +        return + +    def headers_cb(self, url): +        if not self._do_oauth: +            return {} + +        timestamp = None +        host = urlparse(url).netloc +        if self.skew_data and host in self.skew_data: +            timestamp = int(time.time()) + self.skew_data[host] + +        return oauth_headers( +            url=url, consumer_key=self.consumer_key, +            token_key=self.token_key, token_secret=self.token_secret, +            consumer_secret=self.consumer_secret, timestamp=timestamp) + +    def _wrapped(self, wrapped_func, args, kwargs): +        kwargs['headers_cb'] = partial( +            self._headers_cb, kwargs.get('headers_cb')) +        kwargs['exception_cb'] = partial( +            self._exception_cb, kwargs.get('exception_cb')) +        return wrapped_func(*args, **kwargs) + +    def wait_for_url(self, *args, **kwargs): +        return self._wrapped(wait_for_url, args, kwargs) + +    def readurl(self, *args, **kwargs): +        return self._wrapped(readurl, args, kwargs) + +    def _exception_cb(self, extra_exception_cb, msg, exception): +        ret = None +        try: +            if extra_exception_cb: +                ret = extra_exception_cb(msg, exception) +        finally: +                self.exception_cb(msg, exception) +        return ret + +    def _headers_cb(self, extra_headers_cb, url): +        headers = {} +        if extra_headers_cb: +            headers = extra_headers_cb(url) +        headers.update(self.headers_cb(url)) +        return headers + + +def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, +                  timestamp=None): +    if timestamp: +        timestamp = str(timestamp) +    else: +        timestamp = None + +    client = oauth1.Client( +        consumer_key, +        client_secret=consumer_secret, +        resource_owner_key=token_key, +        resource_owner_secret=token_secret, +        signature_method=oauth1.SIGNATURE_PLAINTEXT, +        timestamp=timestamp) +    uri, signed_headers, body = client.sign(url) +    return signed_headers diff --git a/cloudinit/util.py b/cloudinit/util.py index 02ba654a..09e583f5 100644 --- a/cloudinit/util.py +++ b/cloudinit/util.py @@ -782,7 +782,8 @@ def read_file_or_url(url, timeout=5, retries=10,              code = e.errno              if e.errno == errno.ENOENT:                  code = url_helper.NOT_FOUND -            raise url_helper.UrlError(cause=e, code=code, headers=None) +            raise url_helper.UrlError(cause=e, code=code, headers=None, +                                      url=url)          return url_helper.FileResponse(file_path, contents=contents)      else:          return url_helper.readurl(url, diff --git a/doc/examples/cloud-config-reporting.txt b/doc/examples/cloud-config-reporting.txt new file mode 100644 index 00000000..ee00078f --- /dev/null +++ b/doc/examples/cloud-config-reporting.txt @@ -0,0 +1,17 @@ +#cloud-config +## +## The following sets up 2 reporting end points. +## A 'webhook' and a 'log' type. +## It also disables the built in default 'log' +reporting: +   smtest: +     type: webhook +     endpoint: "http://myhost:8000/" +     consumer_key: "ckey_foo" +     consumer_secret: "csecret_foo" +     token_key: "tkey_foo" +     token_secret: "tkey_foo" +   smlogger: +     type: log +     level: WARN +   log: null diff --git a/tests/unittests/test_datasource/test_maas.py b/tests/unittests/test_datasource/test_maas.py index f109bb04..eb97b692 100644 --- a/tests/unittests/test_datasource/test_maas.py +++ b/tests/unittests/test_datasource/test_maas.py @@ -141,7 +141,7 @@ class TestMAASDataSource(TestCase):          with mock.patch.object(url_helper, 'readurl',                                 side_effect=side_effect()) as mockobj:              userdata, metadata = DataSourceMAAS.read_maas_seed_url( -                my_seed, header_cb=my_headers_cb, version=my_ver) +                my_seed, version=my_ver)              self.assertEqual(b"foodata", userdata)              self.assertEqual(metadata['instance-id'], diff --git a/tests/unittests/test_registry.py b/tests/unittests/test_registry.py new file mode 100644 index 00000000..bcf01475 --- /dev/null +++ b/tests/unittests/test_registry.py @@ -0,0 +1,28 @@ +from cloudinit.registry import DictRegistry + +from .helpers import (mock, TestCase) + + +class TestDictRegistry(TestCase): + +    def test_added_item_included_in_output(self): +        registry = DictRegistry() +        item_key, item_to_register = 'test_key', mock.Mock() +        registry.register_item(item_key, item_to_register) +        self.assertEqual({item_key: item_to_register}, +                         registry.registered_items) + +    def test_registry_starts_out_empty(self): +        self.assertEqual({}, DictRegistry().registered_items) + +    def test_modifying_registered_items_isnt_exposed_to_other_callers(self): +        registry = DictRegistry() +        registry.registered_items['test_item'] = mock.Mock() +        self.assertEqual({}, registry.registered_items) + +    def test_keys_cannot_be_replaced(self): +        registry = DictRegistry() +        item_key = 'test_key' +        registry.register_item(item_key, mock.Mock()) +        self.assertRaises(ValueError, +                          registry.register_item, item_key, mock.Mock()) diff --git a/tests/unittests/test_reporting.py b/tests/unittests/test_reporting.py new file mode 100644 index 00000000..1a4ee8c4 --- /dev/null +++ b/tests/unittests/test_reporting.py @@ -0,0 +1,359 @@ +# Copyright 2015 Canonical Ltd. +# This file is part of cloud-init.  See LICENCE file for license information. +# +# vi: ts=4 expandtab + +from cloudinit import reporting +from cloudinit.reporting import handlers + +from .helpers import (mock, TestCase) + + +def _fake_registry(): +    return mock.Mock(registered_items={'a': mock.MagicMock(), +                                       'b': mock.MagicMock()}) + + +class TestReportStartEvent(TestCase): + +    @mock.patch('cloudinit.reporting.instantiated_handler_registry', +                new_callable=_fake_registry) +    def test_report_start_event_passes_something_with_as_string_to_handlers( +            self, instantiated_handler_registry): +        event_name, event_description = 'my_test_event', 'my description' +        reporting.report_start_event(event_name, event_description) +        expected_string_representation = ': '.join( +            ['start', event_name, event_description]) +        for _, handler in ( +                instantiated_handler_registry.registered_items.items()): +            self.assertEqual(1, handler.publish_event.call_count) +            event = handler.publish_event.call_args[0][0] +            self.assertEqual(expected_string_representation, event.as_string()) + + +class TestReportFinishEvent(TestCase): + +    def _report_finish_event(self, result=reporting.status.SUCCESS): +        event_name, event_description = 'my_test_event', 'my description' +        reporting.report_finish_event( +            event_name, event_description, result=result) +        return event_name, event_description + +    def assertHandlersPassedObjectWithAsString( +            self, handlers, expected_as_string): +        for _, handler in handlers.items(): +            self.assertEqual(1, handler.publish_event.call_count) +            event = handler.publish_event.call_args[0][0] +            self.assertEqual(expected_as_string, event.as_string()) + +    @mock.patch('cloudinit.reporting.instantiated_handler_registry', +                new_callable=_fake_registry) +    def test_report_finish_event_passes_something_with_as_string_to_handlers( +            self, instantiated_handler_registry): +        event_name, event_description = self._report_finish_event() +        expected_string_representation = ': '.join( +            ['finish', event_name, reporting.status.SUCCESS, +             event_description]) +        self.assertHandlersPassedObjectWithAsString( +            instantiated_handler_registry.registered_items, +            expected_string_representation) + +    @mock.patch('cloudinit.reporting.instantiated_handler_registry', +                new_callable=_fake_registry) +    def test_reporting_successful_finish_has_sensible_string_repr( +            self, instantiated_handler_registry): +        event_name, event_description = self._report_finish_event( +            result=reporting.status.SUCCESS) +        expected_string_representation = ': '.join( +            ['finish', event_name, reporting.status.SUCCESS, +             event_description]) +        self.assertHandlersPassedObjectWithAsString( +            instantiated_handler_registry.registered_items, +            expected_string_representation) + +    @mock.patch('cloudinit.reporting.instantiated_handler_registry', +                new_callable=_fake_registry) +    def test_reporting_unsuccessful_finish_has_sensible_string_repr( +            self, instantiated_handler_registry): +        event_name, event_description = self._report_finish_event( +            result=reporting.status.FAIL) +        expected_string_representation = ': '.join( +            ['finish', event_name, reporting.status.FAIL, event_description]) +        self.assertHandlersPassedObjectWithAsString( +            instantiated_handler_registry.registered_items, +            expected_string_representation) + +    def test_invalid_result_raises_attribute_error(self): +        self.assertRaises(ValueError, self._report_finish_event, ("BOGUS",)) + + +class TestReportingEvent(TestCase): + +    def test_as_string(self): +        event_type, name, description = 'test_type', 'test_name', 'test_desc' +        event = reporting.ReportingEvent(event_type, name, description) +        expected_string_representation = ': '.join( +            [event_type, name, description]) +        self.assertEqual(expected_string_representation, event.as_string()) + +    def test_as_dict(self): +        event_type, name, desc = 'test_type', 'test_name', 'test_desc' +        event = reporting.ReportingEvent(event_type, name, desc) +        self.assertEqual( +            {'event_type': event_type, 'name': name, 'description': desc}, +            event.as_dict()) + + +class TestFinishReportingEvent(TestCase): +    def test_as_has_result(self): +        result = reporting.status.SUCCESS +        name, desc = 'test_name', 'test_desc' +        event = reporting.FinishReportingEvent(name, desc, result) +        ret = event.as_dict() +        self.assertTrue('result' in ret) +        self.assertEqual(ret['result'], result) + + +class TestBaseReportingHandler(TestCase): + +    def test_base_reporting_handler_is_abstract(self): +        regexp = r".*abstract.*publish_event.*" +        self.assertRaisesRegexp(TypeError, regexp, handlers.ReportingHandler) + + +class TestLogHandler(TestCase): + +    @mock.patch.object(reporting.handlers.logging, 'getLogger') +    def test_appropriate_logger_used(self, getLogger): +        event_type, event_name = 'test_type', 'test_name' +        event = reporting.ReportingEvent(event_type, event_name, 'description') +        reporting.handlers.LogHandler().publish_event(event) +        self.assertEqual( +            [mock.call( +                'cloudinit.reporting.{0}.{1}'.format(event_type, event_name))], +            getLogger.call_args_list) + +    @mock.patch.object(reporting.handlers.logging, 'getLogger') +    def test_single_log_message_at_info_published(self, getLogger): +        event = reporting.ReportingEvent('type', 'name', 'description') +        reporting.handlers.LogHandler().publish_event(event) +        self.assertEqual(1, getLogger.return_value.info.call_count) + +    @mock.patch.object(reporting.handlers.logging, 'getLogger') +    def test_log_message_uses_event_as_string(self, getLogger): +        event = reporting.ReportingEvent('type', 'name', 'description') +        reporting.handlers.LogHandler().publish_event(event) +        self.assertIn(event.as_string(), +                      getLogger.return_value.info.call_args[0][0]) + + +class TestDefaultRegisteredHandler(TestCase): + +    def test_log_handler_registered_by_default(self): +        registered_items = ( +            reporting.instantiated_handler_registry.registered_items) +        for _, item in registered_items.items(): +            if isinstance(item, reporting.handlers.LogHandler): +                break +        else: +            self.fail('No reporting LogHandler registered by default.') + + +class TestReportingConfiguration(TestCase): + +    @mock.patch.object(reporting, 'instantiated_handler_registry') +    def test_empty_configuration_doesnt_add_handlers( +            self, instantiated_handler_registry): +        reporting.update_configuration({}) +        self.assertEqual( +            0, instantiated_handler_registry.register_item.call_count) + +    @mock.patch.object( +        reporting, 'instantiated_handler_registry', reporting.DictRegistry()) +    @mock.patch.object(reporting, 'available_handlers') +    def test_looks_up_handler_by_type_and_adds_it(self, available_handlers): +        handler_type_name = 'test_handler' +        handler_cls = mock.Mock() +        available_handlers.registered_items = {handler_type_name: handler_cls} +        handler_name = 'my_test_handler' +        reporting.update_configuration( +            {handler_name: {'type': handler_type_name}}) +        self.assertEqual( +            {handler_name: handler_cls.return_value}, +            reporting.instantiated_handler_registry.registered_items) + +    @mock.patch.object( +        reporting, 'instantiated_handler_registry', reporting.DictRegistry()) +    @mock.patch.object(reporting, 'available_handlers') +    def test_uses_non_type_parts_of_config_dict_as_kwargs( +            self, available_handlers): +        handler_type_name = 'test_handler' +        handler_cls = mock.Mock() +        available_handlers.registered_items = {handler_type_name: handler_cls} +        extra_kwargs = {'foo': 'bar', 'bar': 'baz'} +        handler_config = extra_kwargs.copy() +        handler_config.update({'type': handler_type_name}) +        handler_name = 'my_test_handler' +        reporting.update_configuration({handler_name: handler_config}) +        self.assertEqual( +            handler_cls.return_value, +            reporting.instantiated_handler_registry.registered_items[ +                handler_name]) +        self.assertEqual([mock.call(**extra_kwargs)], +                         handler_cls.call_args_list) + +    @mock.patch.object( +        reporting, 'instantiated_handler_registry', reporting.DictRegistry()) +    @mock.patch.object(reporting, 'available_handlers') +    def test_handler_config_not_modified(self, available_handlers): +        handler_type_name = 'test_handler' +        handler_cls = mock.Mock() +        available_handlers.registered_items = {handler_type_name: handler_cls} +        handler_config = {'type': handler_type_name, 'foo': 'bar'} +        expected_handler_config = handler_config.copy() +        reporting.update_configuration({'my_test_handler': handler_config}) +        self.assertEqual(expected_handler_config, handler_config) + +    @mock.patch.object( +        reporting, 'instantiated_handler_registry', reporting.DictRegistry()) +    @mock.patch.object(reporting, 'available_handlers') +    def test_handlers_removed_if_falseish_specified(self, available_handlers): +        handler_type_name = 'test_handler' +        handler_cls = mock.Mock() +        available_handlers.registered_items = {handler_type_name: handler_cls} +        handler_name = 'my_test_handler' +        reporting.update_configuration( +            {handler_name: {'type': handler_type_name}}) +        self.assertEqual( +            1, len(reporting.instantiated_handler_registry.registered_items)) +        reporting.update_configuration({handler_name: None}) +        self.assertEqual( +            0, len(reporting.instantiated_handler_registry.registered_items)) + + +class TestReportingEventStack(TestCase): +    @mock.patch('cloudinit.reporting.report_finish_event') +    @mock.patch('cloudinit.reporting.report_start_event') +    def test_start_and_finish_success(self, report_start, report_finish): +        with reporting.ReportEventStack(name="myname", description="mydesc"): +            pass +        self.assertEqual( +            [mock.call('myname', 'mydesc')], report_start.call_args_list) +        self.assertEqual( +            [mock.call('myname', 'mydesc', reporting.status.SUCCESS)], +            report_finish.call_args_list) + +    @mock.patch('cloudinit.reporting.report_finish_event') +    @mock.patch('cloudinit.reporting.report_start_event') +    def test_finish_exception_defaults_fail(self, report_start, report_finish): +        name = "myname" +        desc = "mydesc" +        try: +            with reporting.ReportEventStack(name, description=desc): +                raise ValueError("This didnt work") +        except ValueError: +            pass +        self.assertEqual([mock.call(name, desc)], report_start.call_args_list) +        self.assertEqual( +            [mock.call(name, desc, reporting.status.FAIL)], +            report_finish.call_args_list) + +    @mock.patch('cloudinit.reporting.report_finish_event') +    @mock.patch('cloudinit.reporting.report_start_event') +    def test_result_on_exception_used(self, report_start, report_finish): +        name = "myname" +        desc = "mydesc" +        try: +            with reporting.ReportEventStack( +                    name, desc, result_on_exception=reporting.status.WARN): +                raise ValueError("This didnt work") +        except ValueError: +            pass +        self.assertEqual([mock.call(name, desc)], report_start.call_args_list) +        self.assertEqual( +            [mock.call(name, desc, reporting.status.WARN)], +            report_finish.call_args_list) + +    @mock.patch('cloudinit.reporting.report_start_event') +    def test_child_fullname_respects_parent(self, report_start): +        parent_name = "topname" +        c1_name = "c1name" +        c2_name = "c2name" +        c2_expected_fullname = '/'.join([parent_name, c1_name, c2_name]) +        c1_expected_fullname = '/'.join([parent_name, c1_name]) + +        parent = reporting.ReportEventStack(parent_name, "topdesc") +        c1 = reporting.ReportEventStack(c1_name, "c1desc", parent=parent) +        c2 = reporting.ReportEventStack(c2_name, "c2desc", parent=c1) +        with c1: +            report_start.assert_called_with(c1_expected_fullname, "c1desc") +            with c2: +                report_start.assert_called_with(c2_expected_fullname, "c2desc") + +    @mock.patch('cloudinit.reporting.report_finish_event') +    @mock.patch('cloudinit.reporting.report_start_event') +    def test_child_result_bubbles_up(self, report_start, report_finish): +        parent = reporting.ReportEventStack("topname", "topdesc") +        child = reporting.ReportEventStack("c_name", "c_desc", parent=parent) +        with parent: +            with child: +                child.result = reporting.status.WARN + +        report_finish.assert_called_with( +            "topname", "topdesc", reporting.status.WARN) + +    @mock.patch('cloudinit.reporting.report_finish_event') +    def test_message_used_in_finish(self, report_finish): +        with reporting.ReportEventStack("myname", "mydesc", +                                        message="mymessage"): +            pass +        self.assertEqual( +            [mock.call("myname", "mymessage", reporting.status.SUCCESS)], +            report_finish.call_args_list) + +    @mock.patch('cloudinit.reporting.report_finish_event') +    def test_message_updatable(self, report_finish): +        with reporting.ReportEventStack("myname", "mydesc") as c: +            c.message = "all good" +        self.assertEqual( +            [mock.call("myname", "all good", reporting.status.SUCCESS)], +            report_finish.call_args_list) + +    @mock.patch('cloudinit.reporting.report_start_event') +    @mock.patch('cloudinit.reporting.report_finish_event') +    def test_reporting_disabled_does_not_report_events( +            self, report_start, report_finish): +        with reporting.ReportEventStack("a", "b", reporting_enabled=False): +            pass +        self.assertEqual(report_start.call_count, 0) +        self.assertEqual(report_finish.call_count, 0) + +    @mock.patch('cloudinit.reporting.report_start_event') +    @mock.patch('cloudinit.reporting.report_finish_event') +    def test_reporting_child_default_to_parent( +            self, report_start, report_finish): +        parent = reporting.ReportEventStack( +            "pname", "pdesc", reporting_enabled=False) +        child = reporting.ReportEventStack("cname", "cdesc", parent=parent) +        with parent: +            with child: +                pass +            pass +        self.assertEqual(report_start.call_count, 0) +        self.assertEqual(report_finish.call_count, 0) + +    def test_reporting_event_has_sane_repr(self): +        myrep = reporting.ReportEventStack("fooname", "foodesc", +                                           reporting_enabled=True).__repr__() +        self.assertIn("fooname", myrep) +        self.assertIn("foodesc", myrep) +        self.assertIn("True", myrep) + +    def test_set_invalid_result_raises_value_error(self): +        f = reporting.ReportEventStack("myname", "mydesc") +        self.assertRaises(ValueError, setattr, f, "result", "BOGUS") + + +class TestStatusAccess(TestCase): +    def test_invalid_status_access_raises_value_error(self): +        self.assertRaises(AttributeError, getattr, reporting.status, "BOGUS") | 
