diff options
Diffstat (limited to 'cloudinit/url_helper.py')
| -rw-r--r-- | cloudinit/url_helper.py | 140 | 
1 files changed, 135 insertions, 5 deletions
| diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 0e65f431..dca4cc85 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -25,6 +25,10 @@ import time  import six  import requests +import oauthlib.oauth1 as oauth1 +import os +import json +from functools import partial  from requests import exceptions  from six.moves.urllib.parse import ( @@ -147,13 +151,14 @@ class UrlResponse(object):  class UrlError(IOError): -    def __init__(self, cause, code=None, headers=None): +    def __init__(self, cause, code=None, headers=None, url=None):          IOError.__init__(self, str(cause))          self.cause = cause          self.code = code          self.headers = headers          if self.headers is None:              self.headers = {} +        self.url = url  def _get_ssl_args(url, ssl_details): @@ -247,9 +252,10 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1,                      and hasattr(e, 'response')  # This appeared in v 0.10.8                      and hasattr(e.response, 'status_code')):                  excps.append(UrlError(e, code=e.response.status_code, -                                      headers=e.response.headers)) +                                      headers=e.response.headers, +                                      url=url))              else: -                excps.append(UrlError(e)) +                excps.append(UrlError(e, url=url))                  if SSL_ENABLED and isinstance(e, exceptions.SSLError):                      # ssl exceptions are not going to get fixed by waiting a                      # few seconds @@ -333,11 +339,11 @@ def wait_for_url(urls, max_wait=None, timeout=None,                  if not response.contents:                      reason = "empty response [%s]" % (response.code)                      url_exc = UrlError(ValueError(reason), code=response.code, -                                       headers=response.headers) +                                       headers=response.headers, url=url)                  elif not response.ok():                      reason = "bad status code [%s]" % (response.code)                      url_exc = UrlError(ValueError(reason), code=response.code, -                                       headers=response.headers) +                                       headers=response.headers, url=url)                  else:                      return url              except UrlError as e: @@ -368,3 +374,127 @@ def wait_for_url(urls, max_wait=None, timeout=None,          time.sleep(sleep_time)      return False + + +class OauthUrlHelper(object): +    def __init__(self, consumer_key=None, token_key=None, +                 token_secret=None, consumer_secret=None, +                 skew_data_file="/run/oauth_skew.json"): +        self.consumer_key = consumer_key +        self.consumer_secret = consumer_secret or "" +        self.token_key = token_key +        self.token_secret = token_secret +        self.skew_data_file = skew_data_file +        self._do_oauth = True +        self.skew_change_limit = 5 +        required = (self.token_key, self.token_secret, self.consumer_key) +        if not any(required): +            self._do_oauth = False +        elif not all(required): +            raise ValueError("all or none of token_key, token_secret, or " +                             "consumer_key can be set") + +        old = self.read_skew_file() +        self.skew_data = old or {} + +    def read_skew_file(self): +        if self.skew_data_file and os.path.isfile(self.skew_data_file): +            with open(self.skew_data_file, mode="r") as fp: +                return json.load(fp.read()) +        return None + +    def update_skew_file(self, host, value): +        # this is not atomic +        if not self.skew_data_file: +            return +        cur = self.read_skew_file() +        cur[host] = value +        with open(self.skew_data_file, mode="w") as fp: +            fp.write(json.dumps(cur)) + +    def exception_cb(self, msg, exception): +        if not (isinstance(exception, UrlError) and +                (exception.code == 403 or exception.code == 401)): +            return + +        if 'date' not in exception.headers: +            LOG.warn("Missing header 'date' in %s response", exception.code) +            return + +        date = exception.headers['date'] +        try: +            remote_time = time.mktime(parsedate(date)) +        except Exception as e: +            LOG.warn("Failed to convert datetime '%s': %s", date, e) +            return + +        skew = int(remote_time - time.time()) +        host = urlparse(exception.url).netloc +        old_skew = self.skew_data.get(host, 0) +        if abs(old_skew - skew) > self.skew_change_limit: +            self.update_skew_file(host, skew) +            LOG.warn("Setting oauth clockskew for %s to %d", host, skew) +        skew_data[host] = skew + +        return + +    def headers_cb(self, url): +        if not self._do_oauth: +            return {} + +        timestamp = None +        host = urlparse(url).netloc +        if self.skew_data and host in self.skew_data: +            timestamp = int(time.time()) + self.skew_data[host] + +        return oauth_headers( +            url=url, consumer_key=self.consumer_key, +            token_key=self.token_key, token_secret=self.token_secret, +            consumer_secret=self.consumer_secret, timestamp=timestamp) + +    def _wrapped(self, wrapped_func, args, kwargs): +        kwargs['headers_cb'] = partial( +            self._headers_cb, kwargs.get('headers_cb')) +        kwargs['exception_cb'] = partial( +            self._exception_cb, kwargs.get('exception_cb')) +        return wrapped_func(*args, **kwargs) + +    def wait_for_url(self, *args, **kwargs): +        return self._wrapped(wait_for_url, args, kwargs) + +    def readurl(self, *args, **kwargs): +        return self._wrapped(readurl, args, kwargs) + +    def _exception_cb(self, extra_exception_cb, msg, exception): +        ret = None +        try: +            if extra_exception_cb: +                ret = extra_exception_cb(msg, exception) +        finally: +                self.exception_cb(msg, exception) +        return ret + +    def _headers_cb(self, extra_headers_cb, url): +        headers = {} +        if extra_headers_cb: +            headers = extra_headers_cb(url) +        headers.update(self.headers_cb(url)) +        return headers + + +def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, +                  timestamp=None): +    if timestamp: +        timestamp = str(timestamp) +    else: +        timestamp = None + +    client = oauth1.Client( +        consumer_key, +        client_secret=consumer_secret, +        resource_owner_key=token_key, +        resource_owner_secret=token_secret, +        signature_method=oauth1.SIGNATURE_PLAINTEXT, +        timestamp=timestamp) +    uri, signed_headers, body = client.sign(url) +    return signed_headers | 
