From 6062595b83e08e0f12e1fe6d8e367d8db9d91ef8 Mon Sep 17 00:00:00 2001
From: Chad Smith <chad.smith@canonical.com>
Date: Tue, 13 Nov 2018 03:14:58 +0000
Subject: azure: retry imds polling on requests.Timeout

There is an infrequent race where the booting instance can hit the IMDS
service before it is fully available. This results in a
requests.ConnectTimeout being raised.
Azure's retry callback (retry_on_url_exc) now retries on 404s or Timeouts.

LP:1800223
---
 cloudinit/sources/DataSourceAzure.py          | 18 +++-------------
 cloudinit/tests/test_url_helper.py            | 25 +++++++++++++++++++++-
 cloudinit/url_helper.py                       | 14 +++++++++++++
 tests/unittests/test_datasource/test_azure.py | 30 +++++++++++++++++++++++++++
 4 files changed, 71 insertions(+), 16 deletions(-)

diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py
index 6e1797ea..9e8a1a8b 100644
--- a/cloudinit/sources/DataSourceAzure.py
+++ b/cloudinit/sources/DataSourceAzure.py
@@ -22,7 +22,7 @@ from cloudinit.event import EventType
 from cloudinit.net.dhcp import EphemeralDHCPv4
 from cloudinit import sources
 from cloudinit.sources.helpers.azure import get_metadata_from_fabric
-from cloudinit.url_helper import readurl, UrlError
+from cloudinit.url_helper import UrlError, readurl, retry_on_url_exc
 from cloudinit import util
 
 LOG = logging.getLogger(__name__)
@@ -526,13 +526,6 @@ class DataSourceAzure(sources.DataSource):
         report_ready = bool(not os.path.isfile(REPORTED_READY_MARKER_FILE))
         LOG.debug("Start polling IMDS")
 
-        def exc_cb(msg, exception):
-            if isinstance(exception, UrlError) and exception.code == 404:
-                return True
-            # If we get an exception while trying to call IMDS, we
-            # call DHCP and setup the ephemeral network to acquire the new IP.
-            return False
-
         while True:
             try:
                 # Save our EphemeralDHCPv4 context so we avoid repeated dhcp
@@ -547,7 +540,7 @@ class DataSourceAzure(sources.DataSource):
                     self._report_ready(lease=lease)
                     report_ready = False
                 return readurl(url, timeout=1, headers=headers,
-                               exception_cb=exc_cb, infinite=True,
+                               exception_cb=retry_on_url_exc, infinite=True,
                                log_req_resp=False).contents
             except UrlError:
                 # Teardown our EphemeralDHCPv4 context on failure as we retry
@@ -1187,17 +1180,12 @@ def get_metadata_from_imds(fallback_nic, retries):
 
 def _get_metadata_from_imds(retries):
 
-    def retry_on_url_error(msg, exception):
-        if isinstance(exception, UrlError) and exception.code == 404:
-            return True  # Continue retries
-        return False  # Stop retries on all other exceptions
-
     url = IMDS_URL + "instance?api-version=2017-12-01"
     headers = {"Metadata": "true"}
     try:
         response = readurl(
             url, timeout=1, headers=headers, retries=retries,
-            exception_cb=retry_on_url_error)
+            exception_cb=retry_on_url_exc)
     except Exception as e:
         LOG.debug('Ignoring IMDS instance metadata: %s', e)
         return {}
diff --git a/cloudinit/tests/test_url_helper.py b/cloudinit/tests/test_url_helper.py
index 113249d9..aa9f3ec1 100644
--- a/cloudinit/tests/test_url_helper.py
+++ b/cloudinit/tests/test_url_helper.py
@@ -1,10 +1,12 @@
 # This file is part of cloud-init. See LICENSE file for license information.
 
-from cloudinit.url_helper import oauth_headers, read_file_or_url
+from cloudinit.url_helper import (
+    NOT_FOUND, UrlError, oauth_headers, read_file_or_url, retry_on_url_exc)
 from cloudinit.tests.helpers import CiTestCase, mock, skipIf
 from cloudinit import util
 
 import httpretty
+import requests
 
 
 try:
@@ -64,3 +66,24 @@ class TestReadFileOrUrl(CiTestCase):
         result = read_file_or_url(url)
         self.assertEqual(result.contents, data)
         self.assertEqual(str(result), data.decode('utf-8'))
+
+
+class TestRetryOnUrlExc(CiTestCase):
+
+    def test_do_not_retry_non_urlerror(self):
+        """When exception is not UrlError return False."""
+        myerror = IOError('something unexcpected')
+        self.assertFalse(retry_on_url_exc(msg='', exc=myerror))
+
+    def test_perform_retries_on_not_found(self):
+        """When exception is UrlError with a 404 status code return True."""
+        myerror = UrlError(cause=RuntimeError(
+            'something was not found'), code=NOT_FOUND)
+        self.assertTrue(retry_on_url_exc(msg='', exc=myerror))
+
+    def test_perform_retries_on_timeout(self):
+        """When exception is a requests.Timeout return True."""
+        myerror = UrlError(cause=requests.Timeout('something timed out'))
+        self.assertTrue(retry_on_url_exc(msg='', exc=myerror))
+
+# vi: ts=4 expandtab
diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py
index cf57dbd5..396d69ae 100644
--- a/cloudinit/url_helper.py
+++ b/cloudinit/url_helper.py
@@ -554,4 +554,18 @@ def oauth_headers(url, consumer_key, token_key, token_secret, consumer_secret,
     _uri, signed_headers, _body = client.sign(url)
     return signed_headers
 
+
+def retry_on_url_exc(msg, exc):
+    """readurl exception_cb that will retry on NOT_FOUND and Timeout.
+
+    Returns False to raise the exception from readurl, True to retry.
+    """
+    if not isinstance(exc, UrlError):
+        return False
+    if exc.code == NOT_FOUND:
+        return True
+    if exc.cause and isinstance(exc.cause, requests.Timeout):
+        return True
+    return False
+
 # vi: ts=4 expandtab
diff --git a/tests/unittests/test_datasource/test_azure.py b/tests/unittests/test_datasource/test_azure.py
index 8ad4368c..56484b27 100644
--- a/tests/unittests/test_datasource/test_azure.py
+++ b/tests/unittests/test_datasource/test_azure.py
@@ -17,6 +17,7 @@ import crypt
 import httpretty
 import json
 import os
+import requests
 import stat
 import xml.etree.ElementTree as ET
 import yaml
@@ -184,6 +185,35 @@ class TestGetMetadataFromIMDS(HttprettyTestCase):
             "Crawl of Azure Instance Metadata Service (IMDS) took",  # log_time
             self.logs.getvalue())
 
+    @mock.patch('requests.Session.request')
+    @mock.patch('cloudinit.url_helper.time.sleep')
+    @mock.patch(MOCKPATH + 'net.is_up')
+    def test_get_metadata_from_imds_retries_on_timeout(
+            self, m_net_is_up, m_sleep, m_request):
+        """Retry IMDS network metadata on timeout errors."""
+
+        self.attempt = 0
+        m_request.side_effect = requests.Timeout('Fake Connection Timeout')
+
+        def retry_callback(request, uri, headers):
+            self.attempt += 1
+            raise requests.Timeout('Fake connection timeout')
+
+        httpretty.register_uri(
+            httpretty.GET,
+            dsaz.IMDS_URL + 'instance?api-version=2017-12-01',
+            body=retry_callback)
+
+        m_net_is_up.return_value = True  # skips dhcp
+
+        self.assertEqual({}, dsaz.get_metadata_from_imds('eth9', retries=3))
+
+        m_net_is_up.assert_called_with('eth9')
+        self.assertEqual([mock.call(1)]*3, m_sleep.call_args_list)
+        self.assertIn(
+            "Crawl of Azure Instance Metadata Service (IMDS) took",  # log_time
+            self.logs.getvalue())
+
 
 class TestAzureDataSource(CiTestCase):
 
-- 
cgit v1.2.3