summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcloudinit/sources/helpers/azure.py401
-rw-r--r--tests/unittests/test_datasource/test_azure_helper.py391
-rw-r--r--tools/.github-cla-signers1
3 files changed, 631 insertions, 162 deletions
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index 7afa7ed8..6df28ccf 100755
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -195,7 +195,7 @@ def _get_dhcp_endpoint_option_name():
return azure_endpoint
-class AzureEndpointHttpClient(object):
+class AzureEndpointHttpClient:
headers = {
'x-ms-agent-name': 'WALinuxAgent',
@@ -213,57 +213,77 @@ class AzureEndpointHttpClient(object):
if secure:
headers = self.headers.copy()
headers.update(self.extra_secure_headers)
- return url_helper.read_file_or_url(url, headers=headers, timeout=5,
- retries=10)
+ return url_helper.readurl(url, headers=headers,
+ timeout=5, retries=10, sec_between=5)
def post(self, url, data=None, extra_headers=None):
headers = self.headers
if extra_headers is not None:
headers = self.headers.copy()
headers.update(extra_headers)
- return url_helper.read_file_or_url(url, data=data, headers=headers,
- timeout=5, retries=10)
+ return url_helper.readurl(url, data=data, headers=headers,
+ timeout=5, retries=10, sec_between=5)
-class GoalState(object):
+class InvalidGoalStateXMLException(Exception):
+ """Raised when GoalState XML is invalid or has missing data."""
- def __init__(self, xml, http_client):
- self.http_client = http_client
- self.root = ElementTree.fromstring(xml)
- self._certificates_xml = None
- def _text_from_xpath(self, xpath):
- element = self.root.find(xpath)
- if element is not None:
- return element.text
- return None
+class GoalState:
- @property
- def container_id(self):
- return self._text_from_xpath('./Container/ContainerId')
+ def __init__(self, unparsed_xml, azure_endpoint_client):
+ """Parses a GoalState XML string and returns a GoalState object.
- @property
- def incarnation(self):
- return self._text_from_xpath('./Incarnation')
+ @param unparsed_xml: string representing a GoalState XML.
+ @param azure_endpoint_client: instance of AzureEndpointHttpClient
+ @return: GoalState object representing the GoalState XML string.
+ """
+ self.azure_endpoint_client = azure_endpoint_client
- @property
- def instance_id(self):
- return self._text_from_xpath(
+ try:
+ self.root = ElementTree.fromstring(unparsed_xml)
+ except ElementTree.ParseError as e:
+ msg = 'Failed to parse GoalState XML: %s'
+ LOG.warning(msg, e)
+ report_diagnostic_event(msg % (e,))
+ raise
+
+ self.container_id = self._text_from_xpath('./Container/ContainerId')
+ self.instance_id = self._text_from_xpath(
'./Container/RoleInstanceList/RoleInstance/InstanceId')
+ self.incarnation = self._text_from_xpath('./Incarnation')
+
+ for attr in ("container_id", "instance_id", "incarnation"):
+ if getattr(self, attr) is None:
+ msg = 'Missing %s in GoalState XML'
+ LOG.warning(msg, attr)
+ report_diagnostic_event(msg % (attr,))
+ raise InvalidGoalStateXMLException(msg)
+
+ self.certificates_xml = None
+ url = self._text_from_xpath(
+ './Container/RoleInstanceList/RoleInstance'
+ '/Configuration/Certificates')
+ if url is not None:
+ with events.ReportEventStack(
+ name="get-certificates-xml",
+ description="get certificates xml",
+ parent=azure_ds_reporter):
+ self.certificates_xml = \
+ self.azure_endpoint_client.get(
+ url, secure=True).contents
+ if self.certificates_xml is None:
+ raise InvalidGoalStateXMLException(
+ 'Azure endpoint returned empty certificates xml.')
- @property
- def certificates_xml(self):
- if self._certificates_xml is None:
- url = self._text_from_xpath(
- './Container/RoleInstanceList/RoleInstance'
- '/Configuration/Certificates')
- if url is not None:
- self._certificates_xml = self.http_client.get(
- url, secure=True).contents
- return self._certificates_xml
+ def _text_from_xpath(self, xpath):
+ element = self.root.find(xpath)
+ if element is not None:
+ return element.text
+ return None
-class OpenSSLManager(object):
+class OpenSSLManager:
certificate_names = {
'private_key': 'TransportPrivate.pem',
@@ -370,25 +390,120 @@ class OpenSSLManager(object):
return keys
-class WALinuxAgentShim(object):
-
- REPORT_READY_XML_TEMPLATE = '\n'.join([
- '<?xml version="1.0" encoding="utf-8"?>',
- '<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"'
- ' xmlns:xsd="http://www.w3.org/2001/XMLSchema">',
- ' <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>',
- ' <Container>',
- ' <ContainerId>{container_id}</ContainerId>',
- ' <RoleInstanceList>',
- ' <Role>',
- ' <InstanceId>{instance_id}</InstanceId>',
- ' <Health>',
- ' <State>Ready</State>',
- ' </Health>',
- ' </Role>',
- ' </RoleInstanceList>',
- ' </Container>',
- '</Health>'])
+class GoalStateHealthReporter:
+
+ HEALTH_REPORT_XML_TEMPLATE = textwrap.dedent('''\
+ <?xml version="1.0" encoding="utf-8"?>
+ <Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:xsd="http://www.w3.org/2001/XMLSchema">
+ <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>
+ <Container>
+ <ContainerId>{container_id}</ContainerId>
+ <RoleInstanceList>
+ <Role>
+ <InstanceId>{instance_id}</InstanceId>
+ <Health>
+ <State>{health_status}</State>
+ {health_detail_subsection}
+ </Health>
+ </Role>
+ </RoleInstanceList>
+ </Container>
+ </Health>
+ ''')
+
+ HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = textwrap.dedent('''\
+ <Details>
+ <SubStatus>{health_substatus}</SubStatus>
+ <Description>{health_description}</Description>
+ </Details>
+ ''')
+
+ PROVISIONING_SUCCESS_STATUS = 'Ready'
+
+ def __init__(self, goal_state, azure_endpoint_client, endpoint):
+ """Creates instance that will report provisioning status to an endpoint
+
+ @param goal_state: An instance of class GoalState that contains
+ goal state info such as incarnation, container id, and instance id.
+ These 3 values are needed when reporting the provisioning status
+ to Azure
+ @param azure_endpoint_client: Instance of class AzureEndpointHttpClient
+ @param endpoint: Endpoint (string) where the provisioning status report
+ will be sent to
+ @return: Instance of class GoalStateHealthReporter
+ """
+ self._goal_state = goal_state
+ self._azure_endpoint_client = azure_endpoint_client
+ self._endpoint = endpoint
+
+ @azure_ds_telemetry_reporter
+ def send_ready_signal(self):
+ document = self.build_report(
+ incarnation=self._goal_state.incarnation,
+ container_id=self._goal_state.container_id,
+ instance_id=self._goal_state.instance_id,
+ status=self.PROVISIONING_SUCCESS_STATUS)
+ LOG.debug('Reporting ready to Azure fabric.')
+ try:
+ self._post_health_report(document=document)
+ except Exception as e:
+ msg = "exception while reporting ready: %s" % e
+ LOG.error(msg)
+ report_diagnostic_event(msg)
+ raise
+
+ LOG.info('Reported ready to Azure fabric.')
+
+ def build_report(
+ self, incarnation, container_id, instance_id,
+ status, substatus=None, description=None):
+ health_detail = ''
+ if substatus is not None:
+ health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format(
+ health_substatus=substatus, health_description=description)
+
+ health_report = self.HEALTH_REPORT_XML_TEMPLATE.format(
+ incarnation=incarnation,
+ container_id=container_id,
+ instance_id=instance_id,
+ health_status=status,
+ health_detail_subsection=health_detail)
+
+ return health_report
+
+ @azure_ds_telemetry_reporter
+ def _post_health_report(self, document):
+ # Whenever report_diagnostic_event(diagnostic_msg) is invoked in code,
+ # the diagnostic messages are written to special files
+ # (/var/opt/hyperv/.kvp_pool_*) as Hyper-V KVP messages.
+ # Hyper-V KVP message communication is done through these files,
+ # and KVP functionality is used to communicate and share diagnostic
+ # info with the Azure Host.
+ # The Azure Host will collect the VM's Hyper-V KVP diagnostic messages
+ # when cloud-init reports to fabric.
+ # When the Azure Host receives the health report signal, it will only
+ # collect and process whatever KVP diagnostic messages have been
+ # written to the KVP files.
+ # KVP messages that are published after the Azure Host receives the
+ # signal are ignored and unprocessed, so yield this thread to the
+ # Hyper-V KVP Reporting thread so that they are written.
+ # time.sleep(0) is a low-cost and proven method to yield the scheduler
+ # and ensure that events are flushed.
+ # See HyperVKvpReportingHandler class, which is a multi-threaded
+ # reporting handler that writes to the special KVP files.
+ time.sleep(0)
+
+ LOG.debug('Sending health report to Azure fabric.')
+ url = "http://{}/machine?comp=health".format(self._endpoint)
+ self._azure_endpoint_client.post(
+ url,
+ data=document,
+ extra_headers={'Content-Type': 'text/xml; charset=utf-8'})
+ LOG.debug('Successfully sent health report to Azure fabric')
+
+
+class WALinuxAgentShim:
def __init__(self, fallback_lease_file=None, dhcp_options=None):
LOG.debug('WALinuxAgentShim instantiated, fallback_lease_file=%s',
@@ -396,6 +511,7 @@ class WALinuxAgentShim(object):
self.dhcpoptions = dhcp_options
self._endpoint = None
self.openssl_manager = None
+ self.azure_endpoint_client = None
self.lease_file = fallback_lease_file
def clean_up(self):
@@ -494,7 +610,22 @@ class WALinuxAgentShim(object):
@staticmethod
@azure_ds_telemetry_reporter
def find_endpoint(fallback_lease_file=None, dhcp245=None):
+ """Finds and returns the Azure endpoint using various methods.
+
+ The Azure endpoint is searched in the following order:
+ 1. Endpoint from dhcp options (dhcp option 245).
+ 2. Endpoint from networkd.
+ 3. Endpoint from dhclient hook json.
+ 4. Endpoint from fallback lease file.
+ 5. The default Azure endpoint.
+
+ @param fallback_lease_file: Fallback lease file that will be used
+ during endpoint search.
+ @param dhcp245: dhcp options that will be used during endpoint search.
+ @return: Azure endpoint IP address.
+ """
value = None
+
if dhcp245 is not None:
value = dhcp245
LOG.debug("Using Azure Endpoint from dhcp options")
@@ -536,42 +667,128 @@ class WALinuxAgentShim(object):
@azure_ds_telemetry_reporter
def register_with_azure_and_fetch_data(self, pubkey_info=None):
+ """Gets the VM's GoalState from Azure, uses the GoalState information
+ to report ready/send the ready signal/provisioning complete signal to
+ Azure, and then uses pubkey_info to filter and obtain the user's
+ pubkeys from the GoalState.
+
+ @param pubkey_info: List of pubkey values and fingerprints which are
+ used to filter and obtain the user's pubkey values from the
+ GoalState.
+ @return: The list of user's authorized pubkey values.
+ """
if self.openssl_manager is None:
self.openssl_manager = OpenSSLManager()
- http_client = AzureEndpointHttpClient(self.openssl_manager.certificate)
+ if self.azure_endpoint_client is None:
+ self.azure_endpoint_client = AzureEndpointHttpClient(
+ self.openssl_manager.certificate)
+ goal_state = self._fetch_goal_state_from_azure()
+ ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info)
+ health_reporter = GoalStateHealthReporter(
+ goal_state, self.azure_endpoint_client, self.endpoint)
+ health_reporter.send_ready_signal()
+ return {'public-keys': ssh_keys}
+
+ @azure_ds_telemetry_reporter
+ def _fetch_goal_state_from_azure(self):
+ """Fetches the GoalState XML from the Azure endpoint, parses the XML,
+ and returns a GoalState object.
+
+ @return: GoalState object representing the GoalState XML
+ """
+ unparsed_goal_state_xml = self._get_raw_goal_state_xml_from_azure()
+ return self._parse_raw_goal_state_xml(unparsed_goal_state_xml)
+
+ @azure_ds_telemetry_reporter
+ def _get_raw_goal_state_xml_from_azure(self):
+ """Fetches the GoalState XML from the Azure endpoint and returns
+ the XML as a string.
+
+ @return: GoalState XML string
+ """
+
LOG.info('Registering with Azure...')
- attempts = 0
- while True:
- try:
- response = http_client.get(
- 'http://{0}/machine/?comp=goalstate'.format(self.endpoint))
- except Exception as e:
- if attempts < 10:
- time.sleep(attempts + 1)
- else:
- report_diagnostic_event(
- "failed to register with Azure: %s" % e)
- raise
- else:
- break
- attempts += 1
+ url = 'http://{}/machine/?comp=goalstate'.format(self.endpoint)
+ try:
+ response = self.azure_endpoint_client.get(url)
+ except Exception as e:
+ msg = 'failed to register with Azure: %s' % e
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
+ raise
LOG.debug('Successfully fetched GoalState XML.')
- goal_state = GoalState(response.contents, http_client)
- report_diagnostic_event("container_id %s" % goal_state.container_id)
+ return response.contents
+
+ @azure_ds_telemetry_reporter
+ def _parse_raw_goal_state_xml(self, unparsed_goal_state_xml):
+ """Parses a GoalState XML string and returns a GoalState object.
+
+ @param unparsed_goal_state_xml: GoalState XML string
+ @return: GoalState object representing the GoalState XML
+ """
+ try:
+ goal_state = GoalState(
+ unparsed_goal_state_xml, self.azure_endpoint_client)
+ except Exception as e:
+ msg = 'Error processing GoalState XML: %s' % e
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
+ raise
+ msg = ', '.join([
+ 'GoalState XML container id: %s' % goal_state.container_id,
+ 'GoalState XML instance id: %s' % goal_state.instance_id,
+ 'GoalState XML incarnation: %s' % goal_state.incarnation])
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
+ return goal_state
+
+ @azure_ds_telemetry_reporter
+ def _get_user_pubkeys(self, goal_state, pubkey_info):
+ """Gets and filters the VM admin user's authorized pubkeys.
+
+ The admin user in this case is the username specified as "admin"
+ when deploying VMs on Azure.
+ See https://docs.microsoft.com/en-us/cli/azure/vm#az-vm-create.
+ cloud-init expects a straightforward array of keys to be dropped
+ into the admin user's authorized_keys file. Azure control plane exposes
+ multiple public keys to the VM via wireserver. Select just the
+ admin user's key(s) and return them, ignoring any other certs.
+
+ @param goal_state: GoalState object. The GoalState object contains
+ a certificate XML, which contains both the VM user's authorized
+ pubkeys and other non-user pubkeys, which are used for
+ MSI and protected extension handling.
+ @param pubkey_info: List of VM user pubkey dicts that were previously
+ obtained from provisioning data.
+ Each pubkey dict in this list can either have the format
+ pubkey['value'] or pubkey['fingerprint'].
+ Each pubkey['fingerprint'] in the list is used to filter
+ and obtain the actual pubkey value from the GoalState
+ certificates XML.
+ Each pubkey['value'] requires no further processing and is
+ immediately added to the return list.
+ @return: A list of the VM user's authorized pubkey values.
+ """
ssh_keys = []
if goal_state.certificates_xml is not None and pubkey_info is not None:
LOG.debug('Certificate XML found; parsing out public keys.')
keys_by_fingerprint = self.openssl_manager.parse_certificates(
goal_state.certificates_xml)
ssh_keys = self._filter_pubkeys(keys_by_fingerprint, pubkey_info)
- self._report_ready(goal_state, http_client)
- return {'public-keys': ssh_keys}
+ return ssh_keys
- def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info):
- """cloud-init expects a straightforward array of keys to be dropped
- into the user's authorized_keys file. Azure control plane exposes
- multiple public keys to the VM via wireserver. Select just the
- user's key(s) and return them, ignoring any other certs.
+ @staticmethod
+ def _filter_pubkeys(keys_by_fingerprint, pubkey_info):
+ """ Filter and return only the user's actual pubkeys.
+
+ @param keys_by_fingerprint: pubkey fingerprint -> pubkey value dict
+ that was obtained from GoalState Certificates XML. May contain
+ non-user pubkeys.
+ @param pubkey_info: List of VM user pubkeys. Pubkey values are added
+ to the return list without further processing. Pubkey fingerprints
+ are used to filter and obtain the actual pubkey values from
+ keys_by_fingerprint.
+ @return: A list of the VM user's authorized pubkey values.
"""
keys = []
for pubkey in pubkey_info:
@@ -590,30 +807,6 @@ class WALinuxAgentShim(object):
return keys
- @azure_ds_telemetry_reporter
- def _report_ready(self, goal_state, http_client):
- LOG.debug('Reporting ready to Azure fabric.')
- document = self.REPORT_READY_XML_TEMPLATE.format(
- incarnation=goal_state.incarnation,
- container_id=goal_state.container_id,
- instance_id=goal_state.instance_id,
- )
- # Host will collect kvps when cloud-init reports ready.
- # some kvps might still be in the queue. We yield the scheduler
- # to make sure we process all kvps up till this point.
- time.sleep(0)
- try:
- http_client.post(
- "http://{0}/machine?comp=health".format(self.endpoint),
- data=document,
- extra_headers={'Content-Type': 'text/xml; charset=utf-8'},
- )
- except Exception as e:
- report_diagnostic_event("exception while reporting ready: %s" % e)
- raise
-
- LOG.info('Reported ready to Azure fabric.')
-
@azure_ds_telemetry_reporter
def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None,
@@ -631,7 +824,7 @@ def dhcp_log_cb(out, err):
report_diagnostic_event("dhclient error stream: %s" % err)
-class EphemeralDHCPv4WithReporting(object):
+class EphemeralDHCPv4WithReporting:
def __init__(self, reporter, nic=None):
self.reporter = reporter
self.ephemeralDHCPv4 = EphemeralDHCPv4(
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py
index f314cd4a..5e6d3d2d 100644
--- a/tests/unittests/test_datasource/test_azure_helper.py
+++ b/tests/unittests/test_datasource/test_azure_helper.py
@@ -1,8 +1,10 @@
# This file is part of cloud-init. See LICENSE file for license information.
import os
+import re
import unittest
from textwrap import dedent
+from xml.etree import ElementTree
from cloudinit.sources.helpers import azure as azure_helper
from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir
@@ -48,6 +50,30 @@ GOAL_STATE_TEMPLATE = """\
</GoalState>
"""
+HEALTH_REPORT_XML_TEMPLATE = '''\
+<?xml version="1.0" encoding="utf-8"?>
+<Health xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:xsd="http://www.w3.org/2001/XMLSchema">
+ <GoalStateIncarnation>{incarnation}</GoalStateIncarnation>
+ <Container>
+ <ContainerId>{container_id}</ContainerId>
+ <RoleInstanceList>
+ <Role>
+ <InstanceId>{instance_id}</InstanceId>
+ <Health>
+ <State>{health_status}</State>
+ {health_detail_subsection}
+ </Health>
+ </Role>
+ </RoleInstanceList>
+ </Container>
+</Health>
+'''
+
+
+class SentinelException(Exception):
+ pass
+
class TestFindEndpoint(CiTestCase):
@@ -140,9 +166,7 @@ class TestGoalStateParsing(CiTestCase):
'certificates_url': 'MyCertificatesUrl',
}
- def _get_goal_state(self, http_client=None, **kwargs):
- if http_client is None:
- http_client = mock.MagicMock()
+ def _get_formatted_goal_state_xml_string(self, **kwargs):
parameters = self.default_parameters.copy()
parameters.update(kwargs)
xml = GOAL_STATE_TEMPLATE.format(**parameters)
@@ -153,7 +177,13 @@ class TestGoalStateParsing(CiTestCase):
continue
new_xml_lines.append(line)
xml = '\n'.join(new_xml_lines)
- return azure_helper.GoalState(xml, http_client)
+ return xml
+
+ def _get_goal_state(self, m_azure_endpoint_client=None, **kwargs):
+ if m_azure_endpoint_client is None:
+ m_azure_endpoint_client = mock.MagicMock()
+ xml = self._get_formatted_goal_state_xml_string(**kwargs)
+ return azure_helper.GoalState(xml, m_azure_endpoint_client)
def test_incarnation_parsed_correctly(self):
incarnation = '123'
@@ -190,25 +220,55 @@ class TestGoalStateParsing(CiTestCase):
azure_helper.is_byte_swapped(previous_iid, current_iid))
def test_certificates_xml_parsed_and_fetched_correctly(self):
- http_client = mock.MagicMock()
+ m_azure_endpoint_client = mock.MagicMock()
certificates_url = 'TestCertificatesUrl'
goal_state = self._get_goal_state(
- http_client=http_client, certificates_url=certificates_url)
+ m_azure_endpoint_client=m_azure_endpoint_client,
+ certificates_url=certificates_url)
certificates_xml = goal_state.certificates_xml
- self.assertEqual(1, http_client.get.call_count)
- self.assertEqual(certificates_url, http_client.get.call_args[0][0])
- self.assertTrue(http_client.get.call_args[1].get('secure', False))
- self.assertEqual(http_client.get.return_value.contents,
- certificates_xml)
+ self.assertEqual(1, m_azure_endpoint_client.get.call_count)
+ self.assertEqual(
+ certificates_url,
+ m_azure_endpoint_client.get.call_args[0][0])
+ self.assertTrue(
+ m_azure_endpoint_client.get.call_args[1].get(
+ 'secure', False))
+ self.assertEqual(
+ m_azure_endpoint_client.get.return_value.contents,
+ certificates_xml)
def test_missing_certificates_skips_http_get(self):
- http_client = mock.MagicMock()
+ m_azure_endpoint_client = mock.MagicMock()
goal_state = self._get_goal_state(
- http_client=http_client, certificates_url=None)
+ m_azure_endpoint_client=m_azure_endpoint_client,
+ certificates_url=None)
certificates_xml = goal_state.certificates_xml
- self.assertEqual(0, http_client.get.call_count)
+ self.assertEqual(0, m_azure_endpoint_client.get.call_count)
self.assertIsNone(certificates_xml)
+ def test_invalid_goal_state_xml_raises_parse_error(self):
+ xml = 'random non-xml data'
+ with self.assertRaises(ElementTree.ParseError):
+ azure_helper.GoalState(xml, mock.MagicMock())
+
+ def test_missing_container_id_in_goal_state_xml_raises_exc(self):
+ xml = self._get_formatted_goal_state_xml_string()
+ xml = re.sub('<ContainerId>.*</ContainerId>', '', xml)
+ with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+ azure_helper.GoalState(xml, mock.MagicMock())
+
+ def test_missing_instance_id_in_goal_state_xml_raises_exc(self):
+ xml = self._get_formatted_goal_state_xml_string()
+ xml = re.sub('<InstanceId>.*</InstanceId>', '', xml)
+ with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+ azure_helper.GoalState(xml, mock.MagicMock())
+
+ def test_missing_incarnation_in_goal_state_xml_raises_exc(self):
+ xml = self._get_formatted_goal_state_xml_string()
+ xml = re.sub('<Incarnation>.*</Incarnation>', '', xml)
+ with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+ azure_helper.GoalState(xml, mock.MagicMock())
+
class TestAzureEndpointHttpClient(CiTestCase):
@@ -222,61 +282,95 @@ class TestAzureEndpointHttpClient(CiTestCase):
patches = ExitStack()
self.addCleanup(patches.close)
- self.read_file_or_url = patches.enter_context(
- mock.patch.object(azure_helper.url_helper, 'read_file_or_url'))
+ self.readurl = patches.enter_context(
+ mock.patch.object(azure_helper.url_helper, 'readurl'))
+ patches.enter_context(
+ mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
def test_non_secure_get(self):
client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
url = 'MyTestUrl'
response = client.get(url, secure=False)
- self.assertEqual(1, self.read_file_or_url.call_count)
- self.assertEqual(self.read_file_or_url.return_value, response)
+ self.assertEqual(1, self.readurl.call_count)
+ self.assertEqual(self.readurl.return_value, response)
self.assertEqual(
- mock.call(url, headers=self.regular_headers, retries=10,
- timeout=5),
- self.read_file_or_url.call_args)
+ mock.call(url, headers=self.regular_headers,
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_non_secure_get_raises_exception(self):
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ url = 'MyTestUrl'
+ with self.assertRaises(SentinelException):
+ client.get(url, secure=False)
def test_secure_get(self):
url = 'MyTestUrl'
- certificate = mock.MagicMock()
+ m_certificate = mock.MagicMock()
expected_headers = self.regular_headers.copy()
expected_headers.update({
"x-ms-cipher-name": "DES_EDE3_CBC",
- "x-ms-guest-agent-public-x509-cert": certificate,
+ "x-ms-guest-agent-public-x509-cert": m_certificate,
})
- client = azure_helper.AzureEndpointHttpClient(certificate)
+ client = azure_helper.AzureEndpointHttpClient(m_certificate)
response = client.get(url, secure=True)
- self.assertEqual(1, self.read_file_or_url.call_count)
- self.assertEqual(self.read_file_or_url.return_value, response)
+ self.assertEqual(1, self.readurl.call_count)
+ self.assertEqual(self.readurl.return_value, response)
self.assertEqual(
- mock.call(url, headers=expected_headers, retries=10,
- timeout=5),
- self.read_file_or_url.call_args)
+ mock.call(url, headers=expected_headers,
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_secure_get_raises_exception(self):
+ url = 'MyTestUrl'
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ with self.assertRaises(SentinelException):
+ client.get(url, secure=True)
def test_post(self):
- data = mock.MagicMock()
+ m_data = mock.MagicMock()
url = 'MyTestUrl'
client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
- response = client.post(url, data=data)
- self.assertEqual(1, self.read_file_or_url.call_count)
- self.assertEqual(self.read_file_or_url.return_value, response)
+ response = client.post(url, data=m_data)
+ self.assertEqual(1, self.readurl.call_count)
+ self.assertEqual(self.readurl.return_value, response)
self.assertEqual(
- mock.call(url, data=data, headers=self.regular_headers, retries=10,
- timeout=5),
- self.read_file_or_url.call_args)
+ mock.call(url, data=m_data, headers=self.regular_headers,
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_post_raises_exception(self):
+ m_data = mock.MagicMock()
+ url = 'MyTestUrl'
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ with self.assertRaises(SentinelException):
+ client.post(url, data=m_data)
def test_post_with_extra_headers(self):
url = 'MyTestUrl'
client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
extra_headers = {'test': 'header'}
client.post(url, extra_headers=extra_headers)
- self.assertEqual(1, self.read_file_or_url.call_count)
expected_headers = self.regular_headers.copy()
expected_headers.update(extra_headers)
+ self.assertEqual(1, self.readurl.call_count)
self.assertEqual(
mock.call(mock.ANY, data=mock.ANY, headers=expected_headers,
- retries=10, timeout=5),
- self.read_file_or_url.call_args)
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_post_with_sleep_with_extra_headers_raises_exception(self):
+ m_data = mock.MagicMock()
+ url = 'MyTestUrl'
+ extra_headers = {'test': 'header'}
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ with self.assertRaises(SentinelException):
+ client.post(
+ url, data=m_data, extra_headers=extra_headers)
class TestOpenSSLManager(CiTestCase):
@@ -365,6 +459,131 @@ class TestOpenSSLManagerActions(CiTestCase):
self.assertIn(fp, keys_by_fp)
+class TestGoalStateHealthReporter(CiTestCase):
+
+ default_parameters = {
+ 'incarnation': 1634,
+ 'container_id': 'MyContainerId',
+ 'instance_id': 'MyInstanceId'
+ }
+
+ test_endpoint = 'TestEndpoint'
+ test_url = 'http://{0}/machine?comp=health'.format(test_endpoint)
+ test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'}
+
+ provisioning_success_status = 'Ready'
+
+ def setUp(self):
+ super(TestGoalStateHealthReporter, self).setUp()
+ patches = ExitStack()
+ self.addCleanup(patches.close)
+
+ patches.enter_context(
+ mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
+ self.read_file_or_url = patches.enter_context(
+ mock.patch.object(azure_helper.url_helper, 'read_file_or_url'))
+
+ self.post = patches.enter_context(
+ mock.patch.object(azure_helper.AzureEndpointHttpClient,
+ 'post'))
+
+ self.GoalState = patches.enter_context(
+ mock.patch.object(azure_helper, 'GoalState'))
+ self.GoalState.return_value.container_id = \
+ self.default_parameters['container_id']
+ self.GoalState.return_value.instance_id = \
+ self.default_parameters['instance_id']
+ self.GoalState.return_value.incarnation = \
+ self.default_parameters['incarnation']
+
+ def _get_formatted_health_report_xml_string(self, **kwargs):
+ return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs)
+
+ def _get_report_ready_health_document(self):
+ return self._get_formatted_health_report_xml_string(
+ incarnation=self.default_parameters['incarnation'],
+ container_id=self.default_parameters['container_id'],
+ instance_id=self.default_parameters['instance_id'],
+ health_status=self.provisioning_success_status,
+ health_detail_subsection='')
+
+ def test_send_ready_signal_sends_post_request(self):
+ with mock.patch.object(
+ azure_helper.GoalStateHealthReporter,
+ 'build_report') as m_build_report:
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ reporter = azure_helper.GoalStateHealthReporter(
+ azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+ client, self.test_endpoint)
+ reporter.send_ready_signal()
+
+ self.assertEqual(1, self.post.call_count)
+ self.assertEqual(
+ mock.call(
+ self.test_url,
+ data=m_build_report.return_value,
+ extra_headers=self.test_default_headers),
+ self.post.call_args)
+
+ def test_build_report_for_health_document(self):
+ health_document = self._get_report_ready_health_document()
+ reporter = azure_helper.GoalStateHealthReporter(
+ azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+ azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+ self.test_endpoint)
+ generated_health_document = reporter.build_report(
+ incarnation=self.default_parameters['incarnation'],
+ container_id=self.default_parameters['container_id'],
+ instance_id=self.default_parameters['instance_id'],
+ status=self.provisioning_success_status)
+ self.assertEqual(health_document, generated_health_document)
+ self.assertIn(
+ '<GoalStateIncarnation>{}</GoalStateIncarnation>'.format(
+ str(self.default_parameters['incarnation'])),
+ generated_health_document)
+ self.assertIn(
+ ''.join([
+ '<ContainerId>',
+ self.default_parameters['container_id'],
+ '</ContainerId>']),
+ generated_health_document)
+ self.assertIn(
+ ''.join([
+ '<InstanceId>',
+ self.default_parameters['instance_id'],
+ '</InstanceId>']),
+ generated_health_document)
+ self.assertIn(
+ ''.join([
+ '<State>',
+ self.provisioning_success_status,
+ '</State>']),
+ generated_health_document
+ )
+ self.assertNotIn('<Details>', generated_health_document)
+ self.assertNotIn('<SubStatus>', generated_health_document)
+ self.assertNotIn('<Description>', generated_health_document)
+
+ def test_send_ready_signal_calls_build_report(self):
+ with mock.patch.object(
+ azure_helper.GoalStateHealthReporter, 'build_report'
+ ) as m_build_report:
+ reporter = azure_helper.GoalStateHealthReporter(
+ azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+ azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+ self.test_endpoint)
+ reporter.send_ready_signal()
+
+ self.assertEqual(1, m_build_report.call_count)
+ self.assertEqual(
+ mock.call(
+ incarnation=self.default_parameters['incarnation'],
+ container_id=self.default_parameters['container_id'],
+ instance_id=self.default_parameters['instance_id'],
+ status=self.provisioning_success_status),
+ m_build_report.call_args)
+
+
class TestWALinuxAgentShim(CiTestCase):
def setUp(self):
@@ -383,14 +602,21 @@ class TestWALinuxAgentShim(CiTestCase):
patches.enter_context(
mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
- def test_http_client_uses_certificate(self):
+ self.test_incarnation = 'TestIncarnation'
+ self.test_container_id = 'TestContainerId'
+ self.test_instance_id = 'TestInstanceId'
+ self.GoalState.return_value.incarnation = self.test_incarnation
+ self.GoalState.return_value.container_id = self.test_container_id
+ self.GoalState.return_value.instance_id = self.test_instance_id
+
+ def test_azure_endpoint_client_uses_certificate_during_report_ready(self):
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
self.assertEqual(
[mock.call(self.OpenSSLManager.return_value.certificate)],
self.AzureEndpointHttpClient.call_args_list)
- def test_correct_url_used_for_goalstate(self):
+ def test_correct_url_used_for_goalstate_during_report_ready(self):
self.find_endpoint.return_value = 'test_endpoint'
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
@@ -438,43 +664,67 @@ class TestWALinuxAgentShim(CiTestCase):
expected_url = 'http://test_endpoint/machine?comp=health'
self.assertEqual(
[mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)],
- self.AzureEndpointHttpClient.return_value.post.call_args_list)
+ self.AzureEndpointHttpClient.return_value.post
+ .call_args_list)
def test_goal_state_values_used_for_report_ready(self):
- self.GoalState.return_value.incarnation = 'TestIncarnation'
- self.GoalState.return_value.container_id = 'TestContainerId'
- self.GoalState.return_value.instance_id = 'TestInstanceId'
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
posted_document = (
- self.AzureEndpointHttpClient.return_value.post.call_args[1]['data']
+ self.AzureEndpointHttpClient.return_value.post
+ .call_args[1]['data']
)
- self.assertIn('TestIncarnation', posted_document)
- self.assertIn('TestContainerId', posted_document)
- self.assertIn('TestInstanceId', posted_document)
+ self.assertIn(self.test_incarnation, posted_document)
+ self.assertIn(self.test_container_id, posted_document)
+ self.assertIn(self.test_instance_id, posted_document)
+
+ def test_xml_elems_in_report_ready(self):
+ shim = wa_shim()
+ shim.register_with_azure_and_fetch_data()
+ health_document = HEALTH_REPORT_XML_TEMPLATE.format(
+ incarnation=self.test_incarnation,
+ container_id=self.test_container_id,
+ instance_id=self.test_instance_id,
+ health_status='Ready',
+ health_detail_subsection='')
+ posted_document = (
+ self.AzureEndpointHttpClient.return_value.post
+ .call_args[1]['data'])
+ self.assertEqual(health_document, posted_document)
def test_clean_up_can_be_called_at_any_time(self):
shim = wa_shim()
shim.clean_up()
- def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self):
+ def test_clean_up_after_report_ready(self):
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
shim.clean_up()
self.assertEqual(
1, self.OpenSSLManager.return_value.clean_up.call_count)
- def test_failure_to_fetch_goalstate_bubbles_up(self):
- class SentinelException(Exception):
- pass
- self.AzureEndpointHttpClient.return_value.get.side_effect = (
- SentinelException)
+ def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self):
+ self.AzureEndpointHttpClient.return_value.get \
+ .side_effect = (SentinelException)
shim = wa_shim()
self.assertRaises(SentinelException,
shim.register_with_azure_and_fetch_data)
+ def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self):
+ self.GoalState.side_effect = SentinelException
+ shim = wa_shim()
+ self.assertRaises(SentinelException,
+ shim.register_with_azure_and_fetch_data)
-class TestGetMetadataFromFabric(CiTestCase):
+ def test_failure_to_send_report_ready_health_doc_bubbles_up(self):
+ self.AzureEndpointHttpClient.return_value.post \
+ .side_effect = SentinelException
+ shim = wa_shim()
+ self.assertRaises(SentinelException,
+ shim.register_with_azure_and_fetch_data)
+
+
+class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase):
@mock.patch.object(azure_helper, 'WALinuxAgentShim')
def test_data_from_shim_returned(self, shim):
@@ -490,14 +740,39 @@ class TestGetMetadataFromFabric(CiTestCase):
@mock.patch.object(azure_helper, 'WALinuxAgentShim')
def test_failure_in_registration_calls_clean_up(self, shim):
- class SentinelException(Exception):
- pass
shim.return_value.register_with_azure_and_fetch_data.side_effect = (
SentinelException)
self.assertRaises(SentinelException,
azure_helper.get_metadata_from_fabric)
self.assertEqual(1, shim.return_value.clean_up.call_count)
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_calls_shim_register_with_azure_and_fetch_data(self, shim):
+ m_pubkey_info = mock.MagicMock()
+ azure_helper.get_metadata_from_fabric(pubkey_info=m_pubkey_info)
+ self.assertEqual(
+ 1,
+ shim.return_value
+ .register_with_azure_and_fetch_data.call_count)
+ self.assertEqual(
+ mock.call(pubkey_info=m_pubkey_info),
+ shim.return_value
+ .register_with_azure_and_fetch_data.call_args)
+
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_instantiates_shim_with_kwargs(self, shim):
+ m_fallback_lease_file = mock.MagicMock()
+ m_dhcp_options = mock.MagicMock()
+ azure_helper.get_metadata_from_fabric(
+ fallback_lease_file=m_fallback_lease_file,
+ dhcp_opts=m_dhcp_options)
+ self.assertEqual(1, shim.call_count)
+ self.assertEqual(
+ mock.call(
+ fallback_lease_file=m_fallback_lease_file,
+ dhcp_options=m_dhcp_options),
+ shim.call_args)
+
class TestExtractIpAddressFromNetworkd(CiTestCase):
diff --git a/tools/.github-cla-signers b/tools/.github-cla-signers
index 50473bf7..0c4d728f 100644
--- a/tools/.github-cla-signers
+++ b/tools/.github-cla-signers
@@ -7,6 +7,7 @@ dermotbradley
dhensby
eandersson
izzyleung
+johnsonshi
landon912
lucasmoura
marlluslustosa