Diffstat (limited to 'cloudinit/url_helper.py')
-rw-r--r--  cloudinit/url_helper.py  196
1 file changed, 172 insertions(+), 24 deletions(-)
diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py
index 3074dd08..936f7da5 100644
--- a/cloudinit/url_helper.py
+++ b/cloudinit/url_helper.py
@@ -20,21 +20,33 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-import httplib
+import json
+import os
+import requests
+import six
import time
-import urllib
-import requests
+from email.utils import parsedate
+from functools import partial
from requests import exceptions
+import oauthlib.oauth1 as oauth1
-from urlparse import (urlparse, urlunparse)
+from six.moves.urllib.parse import (
+ urlparse, urlunparse,
+ quote as urlquote)
from cloudinit import log as logging
from cloudinit import version
LOG = logging.getLogger(__name__)
-NOT_FOUND = httplib.NOT_FOUND
+if six.PY2:
+ import httplib
+ NOT_FOUND = httplib.NOT_FOUND
+else:
+ import http.client
+ NOT_FOUND = http.client.NOT_FOUND
+
# Check if requests has ssl support (added in requests >= 0.8.8)
SSL_ENABLED = False
@@ -70,7 +82,7 @@ def combine_url(base, *add_ons):
path = url_parsed[2]
if path and not path.endswith("/"):
path += "/"
- path += urllib.quote(str(add_on), safe="/:")
+ path += urlquote(str(add_on), safe="/:")
url_parsed[2] = path
return urlunparse(url_parsed)
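A quick sanity check (a hedged sketch, not part of the change; the metadata URL is only an example) of how combine_url behaves with urlquote swapped in for urllib.quote, escaping each add-on while leaving '/' and ':' intact:

    combine_url("http://169.254.169.254", "2009-04-04", "meta-data")
    # -> "http://169.254.169.254/2009-04-04/meta-data"
    combine_url("http://169.254.169.254", "user data")
    # -> "http://169.254.169.254/user%20data"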
@@ -135,17 +147,18 @@ class UrlResponse(object):
return self._response.status_code
def __str__(self):
- return self.contents
+ return self._response.text
class UrlError(IOError):
- def __init__(self, cause, code=None, headers=None):
+ def __init__(self, cause, code=None, headers=None, url=None):
IOError.__init__(self, str(cause))
self.cause = cause
self.code = code
self.headers = headers
if self.headers is None:
self.headers = {}
+ self.url = url
def _get_ssl_args(url, ssl_details):
@@ -198,10 +211,14 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1,
manual_tries = 1
if retries:
manual_tries = max(int(retries) + 1, 1)
- if not headers:
- headers = {
- 'User-Agent': 'Cloud-Init/%s' % (version.version_string()),
- }
+
+ def_headers = {
+ 'User-Agent': 'Cloud-Init/%s' % (version.version_string()),
+ }
+ if headers:
+ def_headers.update(headers)
+ headers = def_headers
+
if not headers_cb:
def _cb(url):
return headers
@@ -235,18 +252,21 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1,
# attrs
return UrlResponse(r)
except exceptions.RequestException as e:
- if (isinstance(e, (exceptions.HTTPError))
- and hasattr(e, 'response') # This appeared in v 0.10.8
- and hasattr(e.response, 'status_code')):
+ if (isinstance(e, (exceptions.HTTPError)) and
+ hasattr(e, 'response') and # This appeared in v 0.10.8
+ hasattr(e.response, 'status_code')):
excps.append(UrlError(e, code=e.response.status_code,
- headers=e.response.headers))
+ headers=e.response.headers,
+ url=url))
else:
- excps.append(UrlError(e))
+ excps.append(UrlError(e, url=url))
if SSL_ENABLED and isinstance(e, exceptions.SSLError):
# ssl exceptions are not going to get fixed by waiting a
# few seconds
break
- if exception_cb and not exception_cb(req_args.copy(), excps[-1]):
+ if exception_cb and exception_cb(req_args.copy(), excps[-1]):
+ # if an exception callback was given it should return None
+ # a true-ish value means to break and re-raise the exception
break
if i + 1 < manual_tries and sec_between > 0:
LOG.debug("Please wait %s seconds while we wait to try again",
@@ -313,7 +333,7 @@ def wait_for_url(urls, max_wait=None, timeout=None,
timeout = int((start_time + max_wait) - now)
reason = ""
- e = None
+ url_exc = None
try:
if headers_cb is not None:
headers = headers_cb(url)
@@ -324,18 +344,20 @@ def wait_for_url(urls, max_wait=None, timeout=None,
check_status=False)
if not response.contents:
reason = "empty response [%s]" % (response.code)
- e = UrlError(ValueError(reason),
- code=response.code, headers=response.headers)
+ url_exc = UrlError(ValueError(reason), code=response.code,
+ headers=response.headers, url=url)
elif not response.ok():
reason = "bad status code [%s]" % (response.code)
- e = UrlError(ValueError(reason),
- code=response.code, headers=response.headers)
+ url_exc = UrlError(ValueError(reason), code=response.code,
+ headers=response.headers, url=url)
else:
return url
except UrlError as e:
reason = "request error [%s]" % e
+ url_exc = e
except Exception as e:
reason = "unexpected error [%s]" % e
+ url_exc = e
time_taken = int(time.time() - start_time)
status_msg = "Calling '%s' failed [%s/%ss]: %s" % (url,
@@ -347,7 +369,7 @@ def wait_for_url(urls, max_wait=None, timeout=None,
# This can be used to alter the headers that will be sent
# in the future, for example this is what the MAAS datasource
# does.
- exception_cb(msg=status_msg, exception=e)
+ exception_cb(msg=status_msg, exception=url_exc)
if timeup(max_wait, start_time):
break
@@ -358,3 +380,129 @@ def wait_for_url(urls, max_wait=None, timeout=None,
time.sleep(sleep_time)
return False
+
+
+class OauthUrlHelper(object):
+ def __init__(self, consumer_key=None, token_key=None,
+ token_secret=None, consumer_secret=None,
+ skew_data_file="/run/oauth_skew.json"):
+ self.consumer_key = consumer_key
+ self.consumer_secret = consumer_secret or ""
+ self.token_key = token_key
+ self.token_secret = token_secret
+ self.skew_data_file = skew_data_file
+ self._do_oauth = True
+ self.skew_change_limit = 5
+ required = (self.token_key, self.token_secret, self.consumer_key)
+ if not any(required):
+ self._do_oauth = False
+ elif not all(required):
+ raise ValueError("all or none of token_key, token_secret, or "
+ "consumer_key can be set")
+
+ old = self.read_skew_file()
+ self.skew_data = old or {}
+
+ def read_skew_file(self):
+ if self.skew_data_file and os.path.isfile(self.skew_data_file):
+ with open(self.skew_data_file, mode="r") as fp:
+ return json.load(fp)
+ return None
+
+ def update_skew_file(self, host, value):
+ # this is not atomic
+ if not self.skew_data_file:
+ return
+ cur = self.read_skew_file()
+ if cur is None:
+ cur = {}
+ cur[host] = value
+ with open(self.skew_data_file, mode="w") as fp:
+ fp.write(json.dumps(cur))
+
+ def exception_cb(self, msg, exception):
+ if not (isinstance(exception, UrlError) and
+ (exception.code == 403 or exception.code == 401)):
+ return
+
+ if 'date' not in exception.headers:
+ LOG.warn("Missing header 'date' in %s response", exception.code)
+ return
+
+ date = exception.headers['date']
+ try:
+ remote_time = time.mktime(parsedate(date))
+ except Exception as e:
+ LOG.warn("Failed to convert datetime '%s': %s", date, e)
+ return
+
+ skew = int(remote_time - time.time())
+ host = urlparse(exception.url).netloc
+ old_skew = self.skew_data.get(host, 0)
+ if abs(old_skew - skew) > self.skew_change_limit:
+ self.update_skew_file(host, skew)
+ LOG.warn("Setting oauth clockskew for %s to %d", host, skew)
+ self.skew_data[host] = skew
+
+ return
+
+ def headers_cb(self, url):
+ if not self._do_oauth:
+ return {}
+
+ timestamp = None
+ host = urlparse(url).netloc
+ if self.skew_data and host in self.skew_data:
+ timestamp = int(time.time()) + self.skew_data[host]
+
+ return oauth_headers(
+ url=url, consumer_key=self.consumer_key,
+ token_key=self.token_key, token_secret=self.token_secret,
+ consumer_secret=self.consumer_secret, timestamp=timestamp)
+
+ def _wrapped(self, wrapped_func, args, kwargs):
+ kwargs['headers_cb'] = partial(
+ self._headers_cb, kwargs.get('headers_cb'))
+ kwargs['exception_cb'] = partial(
+ self._exception_cb, kwargs.get('exception_cb'))
+ return wrapped_func(*args, **kwargs)
+
+ def wait_for_url(self, *args, **kwargs):
+ return self._wrapped(wait_for_url, args, kwargs)
+
+ def readurl(self, *args, **kwargs):
+ return self._wrapped(readurl, args, kwargs)
+
+ def _exception_cb(self, extra_exception_cb, msg, exception):
+ ret = None
+ try:
+ if extra_exception_cb:
+ ret = extra_exception_cb(msg, exception)
+ finally:
+ self.exception_cb(msg, exception)
+ return ret
+
+ def _headers_cb(self, extra_headers_cb, url):
+ headers = {}
+ if extra_headers_cb:
+ headers = extra_headers_cb(url)
+ headers.update(self.headers_cb(url))
+ return headers
+
+
+def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret,
+ timestamp=None):
+ if timestamp:
+ timestamp = str(timestamp)
+ else:
+ timestamp = None
+
+ client = oauth1.Client(
+ consumer_key,
+ client_secret=consumer_secret,
+ resource_owner_key=token_key,
+ resource_owner_secret=token_secret,
+ signature_method=oauth1.SIGNATURE_PLAINTEXT,
+ timestamp=timestamp)
+ uri, signed_headers, body = client.sign(url)
+ return signed_headers
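
A minimal usage sketch of the new OauthUrlHelper (the credentials and MAAS-style metadata URL below are placeholders, not taken from this change). The wrapper injects the signed headers from oauth_headers() through headers_cb, and its exception_cb records per-host clock skew from the 'date' header of 401/403 responses so later requests are signed with a corrected timestamp:

    from cloudinit.url_helper import OauthUrlHelper

    helper = OauthUrlHelper(consumer_key="ck", token_key="tk",
                            token_secret="ts",
                            skew_data_file="/run/oauth_skew.json")

    # readurl/wait_for_url are wrapped, so the oauth headers (adjusted by any
    # skew learned from earlier 401/403 responses) are added automatically.
    resp = helper.readurl("http://169.254.169.254/MAAS/metadata/", retries=2)
    print(resp.contents)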