diff options
Diffstat (limited to 'cloudinit/url_helper.py')
-rw-r--r-- | cloudinit/url_helper.py | 196 |
1 files changed, 172 insertions, 24 deletions
diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 3074dd08..936f7da5 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -20,21 +20,33 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. -import httplib +import json +import os +import requests +import six import time -import urllib -import requests +from email.utils import parsedate +from functools import partial from requests import exceptions +import oauthlib.oauth1 as oauth1 -from urlparse import (urlparse, urlunparse) +from six.moves.urllib.parse import ( + urlparse, urlunparse, + quote as urlquote) from cloudinit import log as logging from cloudinit import version LOG = logging.getLogger(__name__) -NOT_FOUND = httplib.NOT_FOUND +if six.PY2: + import httplib + NOT_FOUND = httplib.NOT_FOUND +else: + import http.client + NOT_FOUND = http.client.NOT_FOUND + # Check if requests has ssl support (added in requests >= 0.8.8) SSL_ENABLED = False @@ -70,7 +82,7 @@ def combine_url(base, *add_ons): path = url_parsed[2] if path and not path.endswith("/"): path += "/" - path += urllib.quote(str(add_on), safe="/:") + path += urlquote(str(add_on), safe="/:") url_parsed[2] = path return urlunparse(url_parsed) @@ -135,17 +147,18 @@ class UrlResponse(object): return self._response.status_code def __str__(self): - return self.contents + return self._response.text class UrlError(IOError): - def __init__(self, cause, code=None, headers=None): + def __init__(self, cause, code=None, headers=None, url=None): IOError.__init__(self, str(cause)) self.cause = cause self.code = code self.headers = headers if self.headers is None: self.headers = {} + self.url = url def _get_ssl_args(url, ssl_details): @@ -198,10 +211,14 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1, manual_tries = 1 if retries: manual_tries = max(int(retries) + 1, 1) - if not headers: - headers = { - 'User-Agent': 'Cloud-Init/%s' % (version.version_string()), - } + + def_headers = { + 'User-Agent': 'Cloud-Init/%s' % (version.version_string()), + } + if headers: + def_headers.update(headers) + headers = def_headers + if not headers_cb: def _cb(url): return headers @@ -235,18 +252,21 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1, # attrs return UrlResponse(r) except exceptions.RequestException as e: - if (isinstance(e, (exceptions.HTTPError)) - and hasattr(e, 'response') # This appeared in v 0.10.8 - and hasattr(e.response, 'status_code')): + if (isinstance(e, (exceptions.HTTPError)) and + hasattr(e, 'response') and # This appeared in v 0.10.8 + hasattr(e.response, 'status_code')): excps.append(UrlError(e, code=e.response.status_code, - headers=e.response.headers)) + headers=e.response.headers, + url=url)) else: - excps.append(UrlError(e)) + excps.append(UrlError(e, url=url)) if SSL_ENABLED and isinstance(e, exceptions.SSLError): # ssl exceptions are not going to get fixed by waiting a # few seconds break - if exception_cb and not exception_cb(req_args.copy(), excps[-1]): + if exception_cb and exception_cb(req_args.copy(), excps[-1]): + # if an exception callback was given it should return None + # a true-ish value means to break and re-raise the exception break if i + 1 < manual_tries and sec_between > 0: LOG.debug("Please wait %s seconds while we wait to try again", @@ -313,7 +333,7 @@ def wait_for_url(urls, max_wait=None, timeout=None, timeout = int((start_time + max_wait) - now) reason = "" - e = None + url_exc = None try: if headers_cb is not None: headers = headers_cb(url) @@ -324,18 +344,20 @@ def wait_for_url(urls, max_wait=None, timeout=None, check_status=False) if not response.contents: reason = "empty response [%s]" % (response.code) - e = UrlError(ValueError(reason), - code=response.code, headers=response.headers) + url_exc = UrlError(ValueError(reason), code=response.code, + headers=response.headers, url=url) elif not response.ok(): reason = "bad status code [%s]" % (response.code) - e = UrlError(ValueError(reason), - code=response.code, headers=response.headers) + url_exc = UrlError(ValueError(reason), code=response.code, + headers=response.headers, url=url) else: return url except UrlError as e: reason = "request error [%s]" % e + url_exc = e except Exception as e: reason = "unexpected error [%s]" % e + url_exc = e time_taken = int(time.time() - start_time) status_msg = "Calling '%s' failed [%s/%ss]: %s" % (url, @@ -347,7 +369,7 @@ def wait_for_url(urls, max_wait=None, timeout=None, # This can be used to alter the headers that will be sent # in the future, for example this is what the MAAS datasource # does. - exception_cb(msg=status_msg, exception=e) + exception_cb(msg=status_msg, exception=url_exc) if timeup(max_wait, start_time): break @@ -358,3 +380,129 @@ def wait_for_url(urls, max_wait=None, timeout=None, time.sleep(sleep_time) return False + + +class OauthUrlHelper(object): + def __init__(self, consumer_key=None, token_key=None, + token_secret=None, consumer_secret=None, + skew_data_file="/run/oauth_skew.json"): + self.consumer_key = consumer_key + self.consumer_secret = consumer_secret or "" + self.token_key = token_key + self.token_secret = token_secret + self.skew_data_file = skew_data_file + self._do_oauth = True + self.skew_change_limit = 5 + required = (self.token_key, self.token_secret, self.consumer_key) + if not any(required): + self._do_oauth = False + elif not all(required): + raise ValueError("all or none of token_key, token_secret, or " + "consumer_key can be set") + + old = self.read_skew_file() + self.skew_data = old or {} + + def read_skew_file(self): + if self.skew_data_file and os.path.isfile(self.skew_data_file): + with open(self.skew_data_file, mode="r") as fp: + return json.load(fp) + return None + + def update_skew_file(self, host, value): + # this is not atomic + if not self.skew_data_file: + return + cur = self.read_skew_file() + if cur is None: + cur = {} + cur[host] = value + with open(self.skew_data_file, mode="w") as fp: + fp.write(json.dumps(cur)) + + def exception_cb(self, msg, exception): + if not (isinstance(exception, UrlError) and + (exception.code == 403 or exception.code == 401)): + return + + if 'date' not in exception.headers: + LOG.warn("Missing header 'date' in %s response", exception.code) + return + + date = exception.headers['date'] + try: + remote_time = time.mktime(parsedate(date)) + except Exception as e: + LOG.warn("Failed to convert datetime '%s': %s", date, e) + return + + skew = int(remote_time - time.time()) + host = urlparse(exception.url).netloc + old_skew = self.skew_data.get(host, 0) + if abs(old_skew - skew) > self.skew_change_limit: + self.update_skew_file(host, skew) + LOG.warn("Setting oauth clockskew for %s to %d", host, skew) + self.skew_data[host] = skew + + return + + def headers_cb(self, url): + if not self._do_oauth: + return {} + + timestamp = None + host = urlparse(url).netloc + if self.skew_data and host in self.skew_data: + timestamp = int(time.time()) + self.skew_data[host] + + return oauth_headers( + url=url, consumer_key=self.consumer_key, + token_key=self.token_key, token_secret=self.token_secret, + consumer_secret=self.consumer_secret, timestamp=timestamp) + + def _wrapped(self, wrapped_func, args, kwargs): + kwargs['headers_cb'] = partial( + self._headers_cb, kwargs.get('headers_cb')) + kwargs['exception_cb'] = partial( + self._exception_cb, kwargs.get('exception_cb')) + return wrapped_func(*args, **kwargs) + + def wait_for_url(self, *args, **kwargs): + return self._wrapped(wait_for_url, args, kwargs) + + def readurl(self, *args, **kwargs): + return self._wrapped(readurl, args, kwargs) + + def _exception_cb(self, extra_exception_cb, msg, exception): + ret = None + try: + if extra_exception_cb: + ret = extra_exception_cb(msg, exception) + finally: + self.exception_cb(msg, exception) + return ret + + def _headers_cb(self, extra_headers_cb, url): + headers = {} + if extra_headers_cb: + headers = extra_headers_cb(url) + headers.update(self.headers_cb(url)) + return headers + + +def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret, + timestamp=None): + if timestamp: + timestamp = str(timestamp) + else: + timestamp = None + + client = oauth1.Client( + consumer_key, + client_secret=consumer_secret, + resource_owner_key=token_key, + resource_owner_secret=token_secret, + signature_method=oauth1.SIGNATURE_PLAINTEXT, + timestamp=timestamp) + uri, signed_headers, body = client.sign(url) + return signed_headers |