summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorƁukasz 'sil2100' Zemczak <lukasz.zemczak@ubuntu.com>2017-09-04 10:27:07 +0200
committerusd-importer <ubuntu-server@lists.ubuntu.com>2017-09-04 09:38:24 +0000
commite919bdd14e48919244da9e499070fb64377993e5 (patch)
tree33c260c7c99410ac94d5f265fc506cc0b40bb6e4 /tests
parent70c0ea1ac879b2e1cba0a8edb1f3fbe82652413b (diff)
parent3a1d96a77ccaf023256d16183428e3d895f8a051 (diff)
downloadvyos-walinuxagent-e919bdd14e48919244da9e499070fb64377993e5.tar.gz
vyos-walinuxagent-e919bdd14e48919244da9e499070fb64377993e5.zip
Import patches-applied version 2.2.16-0ubuntu1 to applied/ubuntu/artful-proposed
Imported using git-ubuntu import. Changelog parent: 70c0ea1ac879b2e1cba0a8edb1f3fbe82652413b Unapplied parent: 3a1d96a77ccaf023256d16183428e3d895f8a051 New changelog entries: * New upstream release (LP: #1714299).
Diffstat (limited to 'tests')
-rw-r--r--tests/common/osutil/test_default.py229
-rw-r--r--tests/common/test_conf.py51
-rw-r--r--tests/common/test_event.py117
-rw-r--r--tests/data/ga/WALinuxAgent-2.2.11.zipbin450878 -> 0 bytes
-rw-r--r--tests/data/ga/WALinuxAgent-2.2.14.zipbin0 -> 500633 bytes
-rw-r--r--tests/data/test_waagent.conf16
-rw-r--r--tests/ga/test_update.py243
-rw-r--r--tests/pa/test_provision.py28
-rw-r--r--tests/protocol/mockwiredata.py81
-rw-r--r--tests/protocol/test_hostplugin.py61
-rw-r--r--tests/protocol/test_metadata.py20
-rw-r--r--tests/protocol/test_wire.py77
-rw-r--r--tests/test_agent.py77
-rw-r--r--tests/tools.py9
-rw-r--r--tests/utils/test_file_util.py115
-rw-r--r--tests/utils/test_rest_util.py356
-rw-r--r--tests/utils/test_text_util.py44
17 files changed, 1243 insertions, 281 deletions
diff --git a/tests/common/osutil/test_default.py b/tests/common/osutil/test_default.py
index 87acc60..ec4408b 100644
--- a/tests/common/osutil/test_default.py
+++ b/tests/common/osutil/test_default.py
@@ -25,6 +25,7 @@ from azurelinuxagent.common.exception import OSUtilError
from azurelinuxagent.common.future import ustr
from azurelinuxagent.common.osutil import get_osutil
from azurelinuxagent.common.utils import fileutil
+from azurelinuxagent.common.utils.flexible_version import FlexibleVersion
from tests.tools import *
@@ -112,6 +113,21 @@ class TestOSUtil(AgentTestCase):
self.assertFalse(osutil.DefaultOSUtil().is_primary_interface('lo'))
self.assertTrue(osutil.DefaultOSUtil().is_primary_interface('eth0'))
+ def test_sriov(self):
+ routing_table = "\
+ Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT \n" \
+ "bond0 00000000 0100000A 0003 0 0 0 00000000 0 0 0 \n" \
+ "bond0 0000000A 00000000 0001 0 0 0 00000000 0 0 0 \n" \
+ "eth0 0000000A 00000000 0001 0 0 0 00000000 0 0 0 \n" \
+ "bond0 10813FA8 0100000A 0007 0 0 0 00000000 0 0 0 \n" \
+ "bond0 FEA9FEA9 0100000A 0007 0 0 0 00000000 0 0 0 \n"
+
+ mo = mock.mock_open(read_data=routing_table)
+ with patch(open_patch(), mo):
+ self.assertFalse(osutil.DefaultOSUtil().is_primary_interface('eth0'))
+ self.assertTrue(osutil.DefaultOSUtil().is_primary_interface('bond0'))
+
+
def test_multiple_default_routes(self):
routing_table = "\
Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT \n\
@@ -362,23 +378,50 @@ Match host 192.168.1.2\n\
conf.get_sshd_conf_file_path(),
expected_output)
+ def test_correct_instance_id(self):
+ util = osutil.DefaultOSUtil()
+ self.assertEqual(
+ "12345678-1234-1234-1234-123456789012",
+ util._correct_instance_id("78563412-3412-3412-1234-123456789012"))
+ self.assertEqual(
+ "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8",
+ util._correct_instance_id("544CDFD0-CB4E-4B4A-9954-5BDF3ED5C3B8"))
+
@patch('os.path.isfile', return_value=True)
@patch('azurelinuxagent.common.utils.fileutil.read_file',
- return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502")
+ return_value="33C2F3B9-1399-429F-8EB3-BA656DF32502")
def test_get_instance_id_from_file(self, mock_read, mock_isfile):
util = osutil.DefaultOSUtil()
self.assertEqual(
- "B9F3C233-9913-9F42-8EB3-BA656DF32502",
+ util.get_instance_id(),
+ "B9F3C233-9913-9F42-8EB3-BA656DF32502")
+
+ @patch('os.path.isfile', return_value=True)
+ @patch('azurelinuxagent.common.utils.fileutil.read_file',
+ return_value="")
+ def test_get_instance_id_empty_from_file(self, mock_read, mock_isfile):
+ util = osutil.DefaultOSUtil()
+ self.assertEqual(
+ "",
+ util.get_instance_id())
+
+ @patch('os.path.isfile', return_value=True)
+ @patch('azurelinuxagent.common.utils.fileutil.read_file',
+ return_value="Value")
+ def test_get_instance_id_malformed_from_file(self, mock_read, mock_isfile):
+ util = osutil.DefaultOSUtil()
+ self.assertEqual(
+ "Value",
util.get_instance_id())
@patch('os.path.isfile', return_value=False)
@patch('azurelinuxagent.common.utils.shellutil.run_get_output',
- return_value=[0, 'B9F3C233-9913-9F42-8EB3-BA656DF32502'])
+ return_value=[0, '33C2F3B9-1399-429F-8EB3-BA656DF32502'])
def test_get_instance_id_from_dmidecode(self, mock_shell, mock_isfile):
util = osutil.DefaultOSUtil()
self.assertEqual(
- "B9F3C233-9913-9F42-8EB3-BA656DF32502",
- util.get_instance_id())
+ util.get_instance_id(),
+ "B9F3C233-9913-9F42-8EB3-BA656DF32502")
@patch('os.path.isfile', return_value=False)
@patch('azurelinuxagent.common.utils.shellutil.run_get_output',
@@ -394,5 +437,181 @@ Match host 192.168.1.2\n\
util = osutil.DefaultOSUtil()
self.assertEqual("", util.get_instance_id())
+ @patch('os.path.isfile', return_value=True)
+ @patch('azurelinuxagent.common.utils.fileutil.read_file')
+ def test_is_current_instance_id_from_file(self, mock_read, mock_isfile):
+ util = osutil.DefaultOSUtil()
+
+ mock_read.return_value = "B9F3C233-9913-9F42-8EB3-BA656DF32502"
+ self.assertTrue(util.is_current_instance_id(
+ "B9F3C233-9913-9F42-8EB3-BA656DF32502"))
+
+ mock_read.return_value = "33C2F3B9-1399-429F-8EB3-BA656DF32502"
+ self.assertTrue(util.is_current_instance_id(
+ "B9F3C233-9913-9F42-8EB3-BA656DF32502"))
+
+ @patch('os.path.isfile', return_value=False)
+ @patch('azurelinuxagent.common.utils.shellutil.run_get_output')
+ def test_is_current_instance_id_from_dmidecode(self, mock_shell, mock_isfile):
+ util = osutil.DefaultOSUtil()
+
+ mock_shell.return_value = [0, 'B9F3C233-9913-9F42-8EB3-BA656DF32502']
+ self.assertTrue(util.is_current_instance_id(
+ "B9F3C233-9913-9F42-8EB3-BA656DF32502"))
+
+ mock_shell.return_value = [0, '33C2F3B9-1399-429F-8EB3-BA656DF32502']
+ self.assertTrue(util.is_current_instance_id(
+ "B9F3C233-9913-9F42-8EB3-BA656DF32502"))
+
+ @patch('azurelinuxagent.common.conf.get_sudoers_dir')
+ def test_conf_sudoer(self, mock_dir):
+ tmp_dir = tempfile.mkdtemp()
+ mock_dir.return_value = tmp_dir
+
+ util = osutil.DefaultOSUtil()
+
+ # Assert the sudoer line is added if missing
+ util.conf_sudoer("FooBar")
+ waagent_sudoers = os.path.join(tmp_dir, 'waagent')
+ self.assertTrue(os.path.isfile(waagent_sudoers))
+
+ count = -1
+ with open(waagent_sudoers, 'r') as f:
+ count = len(f.readlines())
+ self.assertEqual(1, count)
+
+ # Assert the line does not get added a second time
+ util.conf_sudoer("FooBar")
+
+ count = -1
+ with open(waagent_sudoers, 'r') as f:
+ count = len(f.readlines())
+ print("WRITING TO {0}".format(waagent_sudoers))
+ self.assertEqual(1, count)
+
+ @patch('os.getuid', return_value=42)
+ @patch('azurelinuxagent.common.utils.shellutil.run_get_output')
+ @patch('azurelinuxagent.common.utils.shellutil.run')
+ def test_enable_firewall(self, mock_run, mock_output, mock_uid):
+ osutil._enable_firewall = True
+ util = osutil.DefaultOSUtil()
+
+ dst = '1.2.3.4'
+ uid = 42
+ version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION)
+ wait = "-w"
+
+ mock_run.side_effect = [1, 0, 0]
+ mock_output.side_effect = [(0, version), (0, "Output")]
+ self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid))
+
+ mock_run.assert_has_calls([
+ call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False),
+ call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)),
+ call(osutil.FIREWALL_DROP.format(wait, "A", dst))
+ ])
+ mock_output.assert_has_calls([
+ call(osutil.IPTABLES_VERSION),
+ call(osutil.FIREWALL_LIST.format(wait))
+ ])
+ self.assertTrue(osutil._enable_firewall)
+
+ @patch('os.getuid', return_value=42)
+ @patch('azurelinuxagent.common.utils.shellutil.run_get_output')
+ @patch('azurelinuxagent.common.utils.shellutil.run')
+ def test_enable_firewall_no_wait(self, mock_run, mock_output, mock_uid):
+ osutil._enable_firewall = True
+ util = osutil.DefaultOSUtil()
+
+ dst = '1.2.3.4'
+ uid = 42
+ version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION-1)
+ wait = ""
+
+ mock_run.side_effect = [1, 0, 0]
+ mock_output.side_effect = [(0, version), (0, "Output")]
+ self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid))
+
+ mock_run.assert_has_calls([
+ call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False),
+ call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)),
+ call(osutil.FIREWALL_DROP.format(wait, "A", dst))
+ ])
+ mock_output.assert_has_calls([
+ call(osutil.IPTABLES_VERSION),
+ call(osutil.FIREWALL_LIST.format(wait))
+ ])
+ self.assertTrue(osutil._enable_firewall)
+
+ @patch('os.getuid', return_value=42)
+ @patch('azurelinuxagent.common.utils.shellutil.run_get_output')
+ @patch('azurelinuxagent.common.utils.shellutil.run')
+ def test_enable_firewall_skips_if_drop_exists(self, mock_run, mock_output, mock_uid):
+ osutil._enable_firewall = True
+ util = osutil.DefaultOSUtil()
+
+ dst = '1.2.3.4'
+ uid = 42
+ version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION)
+ wait = "-w"
+
+ mock_run.side_effect = [0, 0, 0]
+ mock_output.return_value = (0, version)
+ self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid))
+
+ mock_run.assert_has_calls([
+ call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False),
+ ])
+ mock_output.assert_has_calls([
+ call(osutil.IPTABLES_VERSION)
+ ])
+ self.assertTrue(osutil._enable_firewall)
+
+ @patch('os.getuid', return_value=42)
+ @patch('azurelinuxagent.common.utils.shellutil.run_get_output')
+ @patch('azurelinuxagent.common.utils.shellutil.run')
+ def test_enable_firewall_ignores_exceptions(self, mock_run, mock_output, mock_uid):
+ osutil._enable_firewall = True
+ util = osutil.DefaultOSUtil()
+
+ dst = '1.2.3.4'
+ uid = 42
+ version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION)
+ wait = "-w"
+
+ mock_run.side_effect = [1, Exception]
+ mock_output.return_value = (0, version)
+ self.assertFalse(util.enable_firewall(dst_ip=dst, uid=uid))
+
+ mock_run.assert_has_calls([
+ call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False),
+ call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid))
+ ])
+ mock_output.assert_has_calls([
+ call(osutil.IPTABLES_VERSION)
+ ])
+ self.assertFalse(osutil._enable_firewall)
+
+ @patch('os.getuid', return_value=42)
+ @patch('azurelinuxagent.common.utils.shellutil.run_get_output')
+ @patch('azurelinuxagent.common.utils.shellutil.run')
+ def test_enable_firewall_skips_if_disabled(self, mock_run, mock_output, mock_uid):
+ osutil._enable_firewall = False
+ util = osutil.DefaultOSUtil()
+
+ dst = '1.2.3.4'
+ uid = 42
+ version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION)
+ wait = "-w"
+
+ mock_run.side_effect = [1, 0, 0]
+ mock_output.side_effect = [(0, version), (0, "Output")]
+ self.assertFalse(util.enable_firewall(dst_ip=dst, uid=uid))
+
+ mock_run.assert_not_called()
+ mock_output.assert_not_called()
+ mock_uid.assert_not_called()
+ self.assertFalse(osutil._enable_firewall)
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/common/test_conf.py b/tests/common/test_conf.py
index 1287b0d..93759de 100644
--- a/tests/common/test_conf.py
+++ b/tests/common/test_conf.py
@@ -24,6 +24,49 @@ from tests.tools import *
class TestConf(AgentTestCase):
+ # Note:
+ # -- These values *MUST* match those from data/test_waagent.conf
+ EXPECTED_CONFIGURATION = {
+ "Provisioning.Enabled" : True,
+ "Provisioning.UseCloudInit" : True,
+ "Provisioning.DeleteRootPassword" : True,
+ "Provisioning.RegenerateSshHostKeyPair" : True,
+ "Provisioning.SshHostKeyPairType" : "rsa",
+ "Provisioning.MonitorHostName" : True,
+ "Provisioning.DecodeCustomData" : False,
+ "Provisioning.ExecuteCustomData" : False,
+ "Provisioning.PasswordCryptId" : '6',
+ "Provisioning.PasswordCryptSaltLength" : 10,
+ "Provisioning.AllowResetSysUser" : False,
+ "ResourceDisk.Format" : True,
+ "ResourceDisk.Filesystem" : "ext4",
+ "ResourceDisk.MountPoint" : "/mnt/resource",
+ "ResourceDisk.EnableSwap" : False,
+ "ResourceDisk.SwapSizeMB" : 0,
+ "ResourceDisk.MountOptions" : None,
+ "Logs.Verbose" : False,
+ "OS.EnableFIPS" : True,
+ "OS.RootDeviceScsiTimeout" : '300',
+ "OS.OpensslPath" : '/usr/bin/openssl',
+ "OS.SshDir" : "/notareal/path",
+ "HttpProxy.Host" : None,
+ "HttpProxy.Port" : None,
+ "DetectScvmmEnv" : False,
+ "Lib.Dir" : "/var/lib/waagent",
+ "DVD.MountPoint" : "/mnt/cdrom/secure",
+ "Pid.File" : "/var/run/waagent.pid",
+ "Extension.LogDir" : "/var/log/azure",
+ "OS.HomeDir" : "/home",
+ "OS.EnableRDMA" : False,
+ "OS.UpdateRdmaDriver" : False,
+ "OS.CheckRdmaDriver" : False,
+ "AutoUpdate.Enabled" : True,
+ "AutoUpdate.GAFamily" : "Prod",
+ "EnableOverProvisioning" : False,
+ "OS.AllowHTTP" : False,
+ "OS.EnableFirewall" : True
+ }
+
def setUp(self):
AgentTestCase.setUp(self)
self.conf = ConfigurationProvider()
@@ -59,3 +102,11 @@ class TestConf(AgentTestCase):
def test_get_provision_cloudinit(self):
self.assertTrue(get_provision_cloudinit(self.conf))
+
+ def test_get_configuration(self):
+ configuration = conf.get_configuration(self.conf)
+ self.assertTrue(len(configuration.keys()) > 0)
+ for k in TestConf.EXPECTED_CONFIGURATION.keys():
+ self.assertEqual(
+ TestConf.EXPECTED_CONFIGURATION[k],
+ configuration[k])
diff --git a/tests/common/test_event.py b/tests/common/test_event.py
index a485edf..55a99c4 100644
--- a/tests/common/test_event.py
+++ b/tests/common/test_event.py
@@ -22,7 +22,8 @@ from datetime import datetime
import azurelinuxagent.common.event as event
import azurelinuxagent.common.logger as logger
-from azurelinuxagent.common.event import init_event_logger, add_event
+from azurelinuxagent.common.event import add_event, \
+ mark_event_status, should_emit_event
from azurelinuxagent.common.future import ustr
from azurelinuxagent.common.version import CURRENT_VERSION
@@ -30,10 +31,84 @@ from tests.tools import *
class TestEvent(AgentTestCase):
+ def test_event_status_event_marked(self):
+ es = event.__event_status__
+
+ self.assertFalse(es.event_marked("Foo", "1.2", "FauxOperation"))
+ es.mark_event_status("Foo", "1.2", "FauxOperation", True)
+ self.assertTrue(es.event_marked("Foo", "1.2", "FauxOperation"))
+
+ event.__event_status__ = event.EventStatus()
+ event.init_event_status(self.tmp_dir)
+ es = event.__event_status__
+ self.assertTrue(es.event_marked("Foo", "1.2", "FauxOperation"))
+
+ def test_event_status_defaults_to_success(self):
+ es = event.__event_status__
+ self.assertTrue(es.event_succeeded("Foo", "1.2", "FauxOperation"))
+
+ def test_event_status_records_status(self):
+ d = tempfile.mkdtemp()
+ es = event.EventStatus(tempfile.mkdtemp())
+
+ es.mark_event_status("Foo", "1.2", "FauxOperation", True)
+ self.assertTrue(es.event_succeeded("Foo", "1.2", "FauxOperation"))
+
+ es.mark_event_status("Foo", "1.2", "FauxOperation", False)
+ self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation"))
+
+ def test_event_status_preserves_state(self):
+ es = event.__event_status__
+
+ es.mark_event_status("Foo", "1.2", "FauxOperation", False)
+ self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation"))
+
+ event.__event_status__ = event.EventStatus()
+ event.init_event_status(self.tmp_dir)
+ es = event.__event_status__
+ self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation"))
+
+ def test_should_emit_event_ignores_unknown_operations(self):
+ event.__event_status__ = event.EventStatus(tempfile.mkdtemp())
+
+ self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", True))
+ self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", False))
+
+ # Marking the event has no effect
+ event.mark_event_status("Foo", "1.2", "FauxOperation", True)
+
+ self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", True))
+ self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", False))
+
+
+ def test_should_emit_event_handles_known_operations(self):
+ event.__event_status__ = event.EventStatus(tempfile.mkdtemp())
+
+ # Known operations always initially "fire"
+ for op in event.__event_status_operations__:
+ self.assertTrue(event.should_emit_event("Foo", "1.2", op, True))
+ self.assertTrue(event.should_emit_event("Foo", "1.2", op, False))
+
+ # Note a success event...
+ for op in event.__event_status_operations__:
+ event.mark_event_status("Foo", "1.2", op, True)
+
+ # Subsequent success events should not fire, but failures will
+ for op in event.__event_status_operations__:
+ self.assertFalse(event.should_emit_event("Foo", "1.2", op, True))
+ self.assertTrue(event.should_emit_event("Foo", "1.2", op, False))
+
+ # Note a failure event...
+ for op in event.__event_status_operations__:
+ event.mark_event_status("Foo", "1.2", op, False)
+
+ # Subsequent success events fire and failure do not
+ for op in event.__event_status_operations__:
+ self.assertTrue(event.should_emit_event("Foo", "1.2", op, True))
+ self.assertFalse(event.should_emit_event("Foo", "1.2", op, False))
@patch('azurelinuxagent.common.event.EventLogger.add_event')
def test_periodic_emits_if_not_previously_sent(self, mock_event):
- init_event_logger(tempfile.mkdtemp())
event.__event_logger__.reset_periodic()
event.add_periodic(logger.EVERY_DAY, "FauxEvent")
@@ -41,7 +116,6 @@ class TestEvent(AgentTestCase):
@patch('azurelinuxagent.common.event.EventLogger.add_event')
def test_periodic_does_not_emit_if_previously_sent(self, mock_event):
- init_event_logger(tempfile.mkdtemp())
event.__event_logger__.reset_periodic()
event.add_periodic(logger.EVERY_DAY, "FauxEvent")
@@ -52,7 +126,6 @@ class TestEvent(AgentTestCase):
@patch('azurelinuxagent.common.event.EventLogger.add_event')
def test_periodic_emits_if_forced(self, mock_event):
- init_event_logger(tempfile.mkdtemp())
event.__event_logger__.reset_periodic()
event.add_periodic(logger.EVERY_DAY, "FauxEvent")
@@ -63,7 +136,6 @@ class TestEvent(AgentTestCase):
@patch('azurelinuxagent.common.event.EventLogger.add_event')
def test_periodic_emits_after_elapsed_delta(self, mock_event):
- init_event_logger(tempfile.mkdtemp())
event.__event_logger__.reset_periodic()
event.add_periodic(logger.EVERY_DAY, "FauxEvent")
@@ -73,14 +145,13 @@ class TestEvent(AgentTestCase):
self.assertEqual(1, mock_event.call_count)
h = hash("FauxEvent"+""+ustr(True)+"")
- event.__event_logger__.periodic_messages[h] = \
+ event.__event_logger__.periodic_events[h] = \
datetime.now() - logger.EVERY_DAY - logger.EVERY_HOUR
event.add_periodic(logger.EVERY_DAY, "FauxEvent")
self.assertEqual(2, mock_event.call_count)
@patch('azurelinuxagent.common.event.EventLogger.add_event')
def test_periodic_forwards_args(self, mock_event):
- init_event_logger(tempfile.mkdtemp())
event.__event_logger__.reset_periodic()
event.add_periodic(logger.EVERY_DAY, "FauxEvent")
@@ -90,68 +161,58 @@ class TestEvent(AgentTestCase):
log_event=True, message='', op='', version=str(CURRENT_VERSION))
def test_save_event(self):
- tmp_evt = tempfile.mkdtemp()
- init_event_logger(tmp_evt)
add_event('test', message='test event')
- self.assertTrue(len(os.listdir(tmp_evt)) == 1)
- shutil.rmtree(tmp_evt)
+ self.assertTrue(len(os.listdir(self.tmp_dir)) == 1)
def test_save_event_rollover(self):
- tmp_evt = tempfile.mkdtemp()
- init_event_logger(tmp_evt)
add_event('test', message='first event')
for i in range(0, 999):
add_event('test', message='test event {0}'.format(i))
- events = os.listdir(tmp_evt)
+ events = os.listdir(self.tmp_dir)
events.sort()
self.assertTrue(len(events) == 1000)
- first_event = os.path.join(tmp_evt, events[0])
+ first_event = os.path.join(self.tmp_dir, events[0])
with open(first_event) as first_fh:
first_event_text = first_fh.read()
self.assertTrue('first event' in first_event_text)
add_event('test', message='last event')
- events = os.listdir(tmp_evt)
+ events = os.listdir(self.tmp_dir)
events.sort()
self.assertTrue(len(events) == 1000, "{0} events found, 1000 expected".format(len(events)))
- first_event = os.path.join(tmp_evt, events[0])
+ first_event = os.path.join(self.tmp_dir, events[0])
with open(first_event) as first_fh:
first_event_text = first_fh.read()
self.assertFalse('first event' in first_event_text)
self.assertTrue('test event 0' in first_event_text)
- last_event = os.path.join(tmp_evt, events[-1])
+ last_event = os.path.join(self.tmp_dir, events[-1])
with open(last_event) as last_fh:
last_event_text = last_fh.read()
self.assertTrue('last event' in last_event_text)
- shutil.rmtree(tmp_evt)
-
def test_save_event_cleanup(self):
- tmp_evt = tempfile.mkdtemp()
- init_event_logger(tmp_evt)
-
for i in range(0, 2000):
- evt = os.path.join(tmp_evt, '{0}.tld'.format(ustr(1491004920536531 + i)))
+ evt = os.path.join(self.tmp_dir, '{0}.tld'.format(ustr(1491004920536531 + i)))
with open(evt, 'w') as fh:
fh.write('test event {0}'.format(i))
- events = os.listdir(tmp_evt)
+ events = os.listdir(self.tmp_dir)
self.assertTrue(len(events) == 2000, "{0} events found, 2000 expected".format(len(events)))
add_event('test', message='last event')
- events = os.listdir(tmp_evt)
+ events = os.listdir(self.tmp_dir)
events.sort()
self.assertTrue(len(events) == 1000, "{0} events found, 1000 expected".format(len(events)))
- first_event = os.path.join(tmp_evt, events[0])
+ first_event = os.path.join(self.tmp_dir, events[0])
with open(first_event) as first_fh:
first_event_text = first_fh.read()
self.assertTrue('test event 1001' in first_event_text)
- last_event = os.path.join(tmp_evt, events[-1])
+ last_event = os.path.join(self.tmp_dir, events[-1])
with open(last_event) as last_fh:
last_event_text = last_fh.read()
self.assertTrue('last event' in last_event_text)
diff --git a/tests/data/ga/WALinuxAgent-2.2.11.zip b/tests/data/ga/WALinuxAgent-2.2.11.zip
deleted file mode 100644
index f018116..0000000
--- a/tests/data/ga/WALinuxAgent-2.2.11.zip
+++ /dev/null
Binary files differ
diff --git a/tests/data/ga/WALinuxAgent-2.2.14.zip b/tests/data/ga/WALinuxAgent-2.2.14.zip
new file mode 100644
index 0000000..a978207
--- /dev/null
+++ b/tests/data/ga/WALinuxAgent-2.2.14.zip
Binary files differ
diff --git a/tests/data/test_waagent.conf b/tests/data/test_waagent.conf
index 6368c39..edc3676 100644
--- a/tests/data/test_waagent.conf
+++ b/tests/data/test_waagent.conf
@@ -94,10 +94,13 @@ OS.SshDir=/notareal/path
# Extension.LogDir=/var/log/azure
#
-# Home.Dir=/home
+# OS.HomeDir=/home
# Enable RDMA management and set up, should only be used in HPC images
-# OS.EnableRDMA=y
+# OS.EnableRDMA=n
+# OS.UpdateRdmaDriver=n
+# OS.CheckRdmaDriver=n
+
# Enable or disable goal state processing auto-update, default is enabled
# AutoUpdate.Enabled=y
@@ -109,3 +112,12 @@ OS.SshDir=/notareal/path
# handling until inVMArtifactsProfile.OnHold is false.
# Default is disabled
# EnableOverProvisioning=n
+
+# Allow fallback to HTTP if HTTPS is unavailable
+# Note: Allowing HTTP (vs. HTTPS) may cause security risks
+# OS.AllowHTTP=n
+
+# Add firewall rules to protect access to Azure host node services
+# Note:
+# - The default is false to protect the state of exising VMs
+OS.EnableFirewall=y
diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py
index 0c8642c..59251cb 100644
--- a/tests/ga/test_update.py
+++ b/tests/ga/test_update.py
@@ -22,6 +22,7 @@ from datetime import datetime
import json
import shutil
+from azurelinuxagent.common.event import *
from azurelinuxagent.common.protocol.hostplugin import *
from azurelinuxagent.common.protocol.metadata import *
from azurelinuxagent.common.protocol.wire import *
@@ -148,7 +149,9 @@ class UpdateTestCase(AgentTestCase):
def create_error(self, error_data=NO_ERROR):
with self.get_error_file(error_data) as path:
- return GuestAgentError(path.name)
+ err = GuestAgentError(path.name)
+ err.load()
+ return err
def copy_agents(self, *agents):
if len(agents) <= 0:
@@ -157,11 +160,11 @@ class UpdateTestCase(AgentTestCase):
fileutil.copy_file(agent, to_dir=self.tmp_dir)
return
- def expand_agents(self, mark_test=False):
+ def expand_agents(self, mark_optional=False):
for agent in self.agent_pkgs():
path = os.path.join(self.tmp_dir, fileutil.trim_ext(agent, "zip"))
zipfile.ZipFile(agent).extractall(path)
- if mark_test:
+ if mark_optional:
src = os.path.join(data_dir, 'ga', 'supported.json')
dst = os.path.join(path, 'supported.json')
shutil.copy(src, dst)
@@ -170,12 +173,12 @@ class UpdateTestCase(AgentTestCase):
fileutil.write_file(dst, json.dumps(SENTINEL_ERROR))
return
- def prepare_agent(self, version, mark_test=False):
+ def prepare_agent(self, version, mark_optional=False):
"""
Create a download for the current agent version, copied from test data
"""
self.copy_agents(get_agent_pkgs()[0])
- self.expand_agents(mark_test=mark_test)
+ self.expand_agents(mark_optional=mark_optional)
versions = self.agent_versions()
src_v = FlexibleVersion(str(versions[0]))
@@ -246,7 +249,6 @@ class TestSupportedDistribution(UpdateTestCase):
self.sd = SupportedDistribution({
'slice':10,
'versions': ['^Ubuntu,16.10,yakkety$']})
-
def test_creation(self):
self.assertRaises(TypeError, SupportedDistribution)
@@ -276,6 +278,7 @@ class TestSupported(UpdateTestCase):
def setUp(self):
UpdateTestCase.setUp(self)
self.sp = Supported(os.path.join(data_dir, 'ga', 'supported.json'))
+ self.sp.load()
def test_creation(self):
self.assertRaises(TypeError, Supported)
@@ -305,6 +308,7 @@ class TestGuestAgentError(UpdateTestCase):
with self.get_error_file(error_data=WITH_ERROR) as path:
err = GuestAgentError(path.name)
+ err.load()
self.assertEqual(path.name, err.path)
self.assertNotEqual(None, err)
@@ -316,6 +320,7 @@ class TestGuestAgentError(UpdateTestCase):
def test_clear(self):
with self.get_error_file(error_data=WITH_ERROR) as path:
err = GuestAgentError(path.name)
+ err.load()
self.assertEqual(path.name, err.path)
self.assertNotEqual(None, err)
@@ -328,27 +333,16 @@ class TestGuestAgentError(UpdateTestCase):
def test_is_sentinel(self):
with self.get_error_file(error_data=SENTINEL_ERROR) as path:
err = GuestAgentError(path.name)
+ err.load()
self.assertTrue(err.is_blacklisted)
self.assertTrue(err.is_sentinel)
with self.get_error_file(error_data=FATAL_ERROR) as path:
err = GuestAgentError(path.name)
+ err.load()
self.assertTrue(err.is_blacklisted)
self.assertFalse(err.is_sentinel)
- def test_load_preserves_error_state(self):
- with self.get_error_file(error_data=WITH_ERROR) as path:
- err = GuestAgentError(path.name)
- self.assertEqual(path.name, err.path)
- self.assertNotEqual(None, err)
-
- with self.get_error_file(error_data=NO_ERROR):
- err.load()
- self.assertEqual(WITH_ERROR["last_failure"], err.last_failure)
- self.assertEqual(WITH_ERROR["failure_count"], err.failure_count)
- self.assertEqual(WITH_ERROR["was_fatal"], err.was_fatal)
- return
-
def test_save(self):
err1 = self.create_error()
err1.mark_failure()
@@ -406,22 +400,20 @@ class TestGuestAgent(UpdateTestCase):
self.agent_path = os.path.join(self.tmp_dir, get_agent_name())
return
- def tearDown(self):
- self.remove_agents()
- return
-
def test_creation(self):
self.assertRaises(UpdateError, GuestAgent, "A very bad file name")
n = "{0}-a.bad.version".format(AGENT_NAME)
self.assertRaises(UpdateError, GuestAgent, n)
+ self.expand_agents()
+
agent = GuestAgent(path=self.agent_path)
self.assertNotEqual(None, agent)
self.assertEqual(get_agent_name(), agent.name)
self.assertEqual(get_agent_version(), agent.version)
- self.assertFalse(agent.is_test)
- self.assertFalse(agent.in_slice)
+ self.assertFalse(agent._is_optional)
+ self.assertFalse(agent._in_slice)
self.assertEqual(self.agent_path, agent.get_agent_dir())
@@ -436,13 +428,14 @@ class TestGuestAgent(UpdateTestCase):
self.assertEqual(path, agent.get_agent_pkg_path())
self.assertTrue(agent.is_downloaded)
- # Note: Agent will get blacklisted since the package for this test is invalid
- self.assertTrue(agent.is_blacklisted)
- self.assertFalse(agent.is_available)
+ self.assertFalse(agent.is_blacklisted)
+ self.assertTrue(agent.is_available)
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_clear_error(self, mock_ensure):
+ def test_clear_error(self, mock_downloaded):
+ self.expand_agents()
+
agent = GuestAgent(path=self.agent_path)
agent.mark_failure(is_fatal=True)
@@ -459,7 +452,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_is_available(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_is_available(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(agent.is_available)
@@ -471,7 +465,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_is_blacklisted(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_is_blacklisted(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(agent.is_blacklisted)
@@ -485,7 +480,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_is_downloaded(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_is_downloaded(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(agent.is_downloaded)
agent._unpack()
@@ -493,50 +489,47 @@ class TestGuestAgent(UpdateTestCase):
return
@patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety'])
- def test_is_test(self, mock_dist):
- self.expand_agents(mark_test=True)
+ @patch('azurelinuxagent.ga.update.GuestAgent._enable')
+ def test_is_optional(self, mock_enable, mock_dist):
+ self.expand_agents(mark_optional=True)
agent = GuestAgent(path=self.agent_path)
self.assertTrue(agent.is_blacklisted)
- self.assertTrue(agent.is_test)
+ self.assertTrue(agent._is_optional)
@patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety'])
@patch('azurelinuxagent.ga.update.datetime')
def test_in_slice(self, mock_dt, mock_dist):
- self.expand_agents(mark_test=True)
+ self.expand_agents(mark_optional=True)
agent = GuestAgent(path=self.agent_path)
mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5))
- self.assertTrue(agent.in_slice)
+ self.assertTrue(agent._in_slice)
mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 42))
- self.assertFalse(agent.in_slice)
+ self.assertFalse(agent._in_slice)
@patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety'])
@patch('azurelinuxagent.ga.update.datetime')
def test_enable(self, mock_dt, mock_dist):
mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5))
- self.expand_agents(mark_test=True)
+ self.expand_agents(mark_optional=True)
agent = GuestAgent(path=self.agent_path)
- self.assertTrue(agent.is_blacklisted)
- self.assertTrue(agent.is_test)
- self.assertTrue(agent.in_slice)
-
- agent.enable()
-
self.assertFalse(agent.is_blacklisted)
- self.assertFalse(agent.is_test)
+ self.assertFalse(agent._is_optional)
# Ensure the new state is preserved to disk
agent = GuestAgent(path=self.agent_path)
self.assertFalse(agent.is_blacklisted)
- self.assertFalse(agent.is_test)
+ self.assertFalse(agent._is_optional)
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_mark_failure(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_mark_failure(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
+
agent.mark_failure()
self.assertEqual(1, agent.error.failure_count)
@@ -546,7 +539,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_unpack(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_unpack(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(os.path.isdir(agent.get_agent_dir()))
agent._unpack()
@@ -555,7 +549,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_unpack_fail(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_unpack_fail(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(os.path.isdir(agent.get_agent_dir()))
os.remove(agent.get_agent_pkg_path())
@@ -563,7 +558,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_load_manifest(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_load_manifest(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
agent._unpack()
agent._load_manifest()
@@ -572,7 +568,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_load_manifest_missing(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_load_manifest_missing(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(os.path.isdir(agent.get_agent_dir()))
agent._unpack()
@@ -581,7 +578,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_load_manifest_is_empty(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_load_manifest_is_empty(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(os.path.isdir(agent.get_agent_dir()))
agent._unpack()
@@ -593,7 +591,8 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
- def test_load_manifest_is_malformed(self, mock_ensure):
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
+ def test_load_manifest_is_malformed(self, mock_loaded, mock_downloaded):
agent = GuestAgent(path=self.agent_path)
self.assertFalse(os.path.isdir(agent.get_agent_dir()))
agent._unpack()
@@ -613,8 +612,9 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
@patch("azurelinuxagent.ga.update.restutil.http_get")
- def test_download(self, mock_http_get, mock_ensure):
+ def test_download(self, mock_http_get, mock_loaded, mock_downloaded):
self.remove_agents()
self.assertFalse(os.path.isdir(self.agent_path))
@@ -630,8 +630,9 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
@patch("azurelinuxagent.ga.update.restutil.http_get")
- def test_download_fail(self, mock_http_get, mock_ensure):
+ def test_download_fail(self, mock_http_get, mock_loaded, mock_downloaded):
self.remove_agents()
self.assertFalse(os.path.isdir(self.agent_path))
@@ -647,8 +648,9 @@ class TestGuestAgent(UpdateTestCase):
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded")
@patch("azurelinuxagent.ga.update.restutil.http_get")
- def test_download_fallback(self, mock_http_get, mock_ensure):
+ def test_download_fallback(self, mock_http_get, mock_loaded, mock_downloaded):
self.remove_agents()
self.assertFalse(os.path.isdir(self.agent_path))
@@ -681,8 +683,12 @@ class TestGuestAgent(UpdateTestCase):
return_value=True):
self.assertRaises(UpdateError, agent._download)
self.assertEqual(mock_http_get.call_count, 4)
+
self.assertEqual(mock_http_get.call_args_list[2][0][0], ext_uri)
+
self.assertEqual(mock_http_get.call_args_list[3][0][0], art_uri)
+ a, k = mock_http_get.call_args_list[3]
+ self.assertEqual(False, k['use_proxy'])
# ensure fallback works as expected
with patch.object(HostPluginProtocol,
@@ -690,8 +696,16 @@ class TestGuestAgent(UpdateTestCase):
return_value=[art_uri, {}]):
self.assertRaises(UpdateError, agent._download)
self.assertEqual(mock_http_get.call_count, 6)
+
+ a, k = mock_http_get.call_args_list[3]
+ self.assertEqual(False, k['use_proxy'])
+
self.assertEqual(mock_http_get.call_args_list[4][0][0], ext_uri)
+ a, k = mock_http_get.call_args_list[4]
+
self.assertEqual(mock_http_get.call_args_list[5][0][0], art_uri)
+ a, k = mock_http_get.call_args_list[5]
+ self.assertEqual(False, k['use_proxy'])
@patch("azurelinuxagent.ga.update.restutil.http_get")
def test_ensure_downloaded(self, mock_http_get):
@@ -725,7 +739,7 @@ class TestGuestAgent(UpdateTestCase):
@patch("azurelinuxagent.ga.update.GuestAgent._download")
@patch("azurelinuxagent.ga.update.GuestAgent._unpack", side_effect=UpdateError)
- def test_ensure_downloaded_unpack_fails(self, mock_download, mock_unpack):
+ def test_ensure_downloaded_unpack_fails(self, mock_unpack, mock_download):
self.assertFalse(os.path.isdir(self.agent_path))
pkg = ExtHandlerPackage(version=str(get_agent_version()))
@@ -740,7 +754,7 @@ class TestGuestAgent(UpdateTestCase):
@patch("azurelinuxagent.ga.update.GuestAgent._download")
@patch("azurelinuxagent.ga.update.GuestAgent._unpack")
@patch("azurelinuxagent.ga.update.GuestAgent._load_manifest", side_effect=UpdateError)
- def test_ensure_downloaded_load_manifest_fails(self, mock_download, mock_unpack, mock_manifest):
+ def test_ensure_downloaded_load_manifest_fails(self, mock_manifest, mock_unpack, mock_download):
self.assertFalse(os.path.isdir(self.agent_path))
pkg = ExtHandlerPackage(version=str(get_agent_version()))
@@ -755,10 +769,13 @@ class TestGuestAgent(UpdateTestCase):
@patch("azurelinuxagent.ga.update.GuestAgent._download")
@patch("azurelinuxagent.ga.update.GuestAgent._unpack")
@patch("azurelinuxagent.ga.update.GuestAgent._load_manifest")
- def test_ensure_download_skips_blacklisted(self, mock_download, mock_unpack, mock_manifest):
+ def test_ensure_download_skips_blacklisted(self, mock_manifest, mock_unpack, mock_download):
agent = GuestAgent(path=self.agent_path)
+ self.assertEqual(0, mock_download.call_count)
+
agent.clear_error()
agent.mark_failure(is_fatal=True)
+ self.assertTrue(agent.is_blacklisted)
pkg = ExtHandlerPackage(version=str(get_agent_version()))
pkg.uris.append(ExtHandlerPackageUri())
@@ -769,7 +786,6 @@ class TestGuestAgent(UpdateTestCase):
self.assertTrue(agent.is_blacklisted)
self.assertEqual(0, mock_download.call_count)
self.assertEqual(0, mock_unpack.call_count)
- self.assertEqual(0, mock_manifest.call_count)
return
@@ -812,6 +828,12 @@ class TestUpdate(UpdateTestCase):
self.event_patch.stop()
return
+ def _create_protocol(self, count=5, versions=None):
+ latest_version = self.prepare_agents(count=count)
+ if versions is None or len(versions) <= 0:
+ versions = [latest_version]
+ return ProtocolMock(versions=versions)
+
def _test_upgrade_available(
self,
base_version=FlexibleVersion(AGENT_VERSION),
@@ -819,12 +841,9 @@ class TestUpdate(UpdateTestCase):
versions=None,
count=5):
- latest_version = self.prepare_agents(count=count)
- if versions is None or len(versions) <= 0:
- versions = [latest_version]
-
if protocol is None:
- protocol = ProtocolMock(versions=versions)
+ protocol = self._create_protocol(count=count, versions=versions)
+
self.update_handler.protocol_util = protocol
conf.get_autoupdate_gafamily = Mock(return_value=protocol.family)
@@ -834,6 +853,16 @@ class TestUpdate(UpdateTestCase):
self.assertTrue(self._test_upgrade_available())
return
+ def test_upgrade_available_will_refresh_goal_state(self):
+ protocol = self._create_protocol()
+ protocol.emulate_stale_goal_state()
+ self.assertTrue(self._test_upgrade_available(protocol=protocol))
+ self.assertEqual(2, protocol.call_counts["get_vmagent_manifests"])
+ self.assertEqual(1, protocol.call_counts["get_vmagent_pkgs"])
+ self.assertEqual(1, protocol.call_counts["update_goal_state"])
+ self.assertTrue(protocol.goal_state_forced)
+ return
+
def test_get_latest_agent_excluded(self):
self.prepare_agent(AGENT_VERSION)
self.assertFalse(self._test_upgrade_available(
@@ -909,7 +938,7 @@ class TestUpdate(UpdateTestCase):
v = a.version
return
- def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL):
+ def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL, pid_count=0):
with patch.object(self.update_handler, 'osutil') as mock_util:
# Note:
# - Python only allows mutations of objects to which a function has
@@ -924,15 +953,20 @@ class TestUpdate(UpdateTestCase):
mock_util.check_pid_alive = Mock(side_effect=iterator)
+ pid_files = self.update_handler._get_pid_files()
+ self.assertEqual(pid_count, len(pid_files))
+
with patch('os.getpid', return_value=42):
with patch('time.sleep', return_value=None) as mock_sleep:
self.update_handler._ensure_no_orphans(orphan_wait_interval=interval)
+ for pid_file in pid_files:
+ self.assertFalse(os.path.exists(pid_file))
return mock_util.check_pid_alive.call_count, mock_sleep.call_count
return
def test_ensure_no_orphans(self):
fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41))
- calls, sleeps = self._test_ensure_no_orphans(invocations=3)
+ calls, sleeps = self._test_ensure_no_orphans(invocations=3, pid_count=1)
self.assertEqual(3, calls)
self.assertEqual(2, sleeps)
return
@@ -955,7 +989,8 @@ class TestUpdate(UpdateTestCase):
with patch('os.kill') as mock_kill:
calls, sleeps = self._test_ensure_no_orphans(
invocations=4,
- interval=3*GOAL_STATE_INTERVAL)
+ interval=3*GOAL_STATE_INTERVAL,
+ pid_count=1)
self.assertEqual(3, calls)
self.assertEqual(2, sleeps)
self.assertEqual(1, mock_kill.call_count)
@@ -1066,26 +1101,20 @@ class TestUpdate(UpdateTestCase):
return
def test_get_pid_files(self):
- previous_pid_file, pid_file, = self.update_handler._get_pid_files()
- self.assertEqual(None, previous_pid_file)
- self.assertEqual("0_waagent.pid", os.path.basename(pid_file))
+ pid_files = self.update_handler._get_pid_files()
+ self.assertEqual(0, len(pid_files))
return
def test_get_pid_files_returns_previous(self):
for n in range(1250):
fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1))
- previous_pid_file, pid_file, = self.update_handler._get_pid_files()
- self.assertEqual("1249_waagent.pid", os.path.basename(previous_pid_file))
- self.assertEqual("1250_waagent.pid", os.path.basename(pid_file))
- return
-
- @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety'])
- @patch('azurelinuxagent.ga.update.datetime')
- def test_get_test_agent(self, mock_dt, mock_dist):
- mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5))
- self.prepare_agent(AGENT_VERSION, mark_test=True)
+ pid_files = self.update_handler._get_pid_files()
+ self.assertEqual(1250, len(pid_files))
- self.assertNotEqual(None, self.update_handler.get_test_agent())
+ pid_dir, pid_name, pid_re = self.update_handler._get_pid_parts()
+ for p in pid_files:
+ self.assertTrue(pid_re.match(os.path.basename(p)))
+ return
def test_is_clean_start_returns_true_when_no_sentinal(self):
self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
@@ -1421,23 +1450,6 @@ class TestUpdate(UpdateTestCase):
self.update_handler._upgrade_available = Mock(return_value=True)
self._test_run(invocations=0, calls=[], enable_updates=True)
return
-
- @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety'])
- @patch('azurelinuxagent.ga.update.datetime')
- def test_run_stops_if_test_agent_available(self, mock_dt, mock_dist):
- mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5))
- self.prepare_agent(AGENT_VERSION, mark_test=True)
-
- agent = GuestAgent(path=self.agent_dir(AGENT_VERSION))
- agent.enable = Mock()
- self.assertTrue(agent.is_test)
- self.assertTrue(agent.in_slice)
-
- with patch('azurelinuxagent.ga.update.UpdateHandler.get_test_agent',
- return_value=agent) as mock_test:
- self._test_run(invocations=0)
- self.assertEqual(mock_test.call_count, 1)
- self.assertEqual(agent.enable.call_count, 1)
def test_run_stops_if_orphaned(self):
with patch('os.getppid', return_value=1):
@@ -1521,8 +1533,9 @@ class TestUpdate(UpdateTestCase):
for n in range(1112):
fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1))
with patch('os.getpid', return_value=1112):
- previous_pid_file, pid_file = self.update_handler._write_pid_file()
- self.assertEqual("1111_waagent.pid", os.path.basename(previous_pid_file))
+ pid_files, pid_file = self.update_handler._write_pid_file()
+ self.assertEqual(1112, len(pid_files))
+ self.assertEqual("1111_waagent.pid", os.path.basename(pid_files[-1]))
self.assertEqual("1112_waagent.pid", os.path.basename(pid_file))
self.assertEqual(fileutil.read_file(pid_file), ustr(1112))
return
@@ -1530,8 +1543,8 @@ class TestUpdate(UpdateTestCase):
def test_write_pid_file_ignores_exceptions(self):
with patch('azurelinuxagent.common.utils.fileutil.write_file', side_effect=Exception):
with patch('os.getpid', return_value=42):
- previous_pid_file, pid_file = self.update_handler._write_pid_file()
- self.assertEqual(None, previous_pid_file)
+ pid_files, pid_file = self.update_handler._write_pid_file()
+ self.assertEqual(0, len(pid_files))
self.assertEqual(None, pid_file)
return
@@ -1549,12 +1562,22 @@ class ProtocolMock(object):
def __init__(self, family="TestAgent", etag=42, versions=None, client=None):
self.family = family
self.client = client
+ self.call_counts = {
+ "get_vmagent_manifests" : 0,
+ "get_vmagent_pkgs" : 0,
+ "update_goal_state" : 0
+ }
+ self.goal_state_is_stale = False
+ self.goal_state_forced = False
self.etag = etag
self.versions = versions if versions is not None else []
self.create_manifests()
self.create_packages()
return
+ def emulate_stale_goal_state(self):
+ self.goal_state_is_stale = True
+
def create_manifests(self):
self.agent_manifests = VMAgentManifestList()
if len(self.versions) <= 0:
@@ -1585,11 +1608,23 @@ class ProtocolMock(object):
return self
def get_vmagent_manifests(self):
+ self.call_counts["get_vmagent_manifests"] += 1
+ if self.goal_state_is_stale:
+ self.goal_state_is_stale = False
+ raise ResourceGoneError()
return self.agent_manifests, self.etag
def get_vmagent_pkgs(self, manifest):
+ self.call_counts["get_vmagent_pkgs"] += 1
+ if self.goal_state_is_stale:
+ self.goal_state_is_stale = False
+ raise ResourceGoneError()
return self.agent_packages
+ def update_goal_state(self, forced=False, max_retry=3):
+ self.call_counts["update_goal_state"] += 1
+ self.goal_state_forced = self.goal_state_forced or forced
+ return
class ResponseMock(Mock):
def __init__(self, status=restutil.httpclient.OK, response=None, reason=None):
diff --git a/tests/pa/test_provision.py b/tests/pa/test_provision.py
index 0446442..7045fcc 100644
--- a/tests/pa/test_provision.py
+++ b/tests/pa/test_provision.py
@@ -53,6 +53,24 @@ class TestProvision(AgentTestCase):
data = DefaultOSUtil().decode_customdata(base64data)
fileutil.write_file(tempfile.mktemp(), data)
+ @patch('azurelinuxagent.common.conf.get_provision_enabled',
+ return_value=False)
+ def test_provisioning_is_skipped_when_not_enabled(self, mock_conf):
+ ph = ProvisionHandler()
+ ph.osutil = DefaultOSUtil()
+ ph.osutil.get_instance_id = Mock(
+ return_value='B9F3C233-9913-9F42-8EB3-BA656DF32502')
+
+ ph.is_provisioned = Mock()
+ ph.report_ready = Mock()
+ ph.write_provisioned = Mock()
+
+ ph.run()
+
+ ph.is_provisioned.assert_not_called()
+ ph.report_ready.assert_called_once()
+ ph.write_provisioned.assert_called_once()
+
@patch('os.path.isfile', return_value=False)
def test_is_provisioned_not_provisioned(self, mock_isfile):
ph = ProvisionHandler()
@@ -64,33 +82,37 @@ class TestProvision(AgentTestCase):
@patch('azurelinuxagent.pa.deprovision.get_deprovision_handler')
def test_is_provisioned_is_provisioned(self,
mock_deprovision, mock_read, mock_isfile):
+
ph = ProvisionHandler()
ph.osutil = Mock()
- ph.osutil.get_instance_id = \
- Mock(return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502")
+ ph.osutil.is_current_instance_id = Mock(return_value=True)
ph.write_provisioned = Mock()
deprovision_handler = Mock()
mock_deprovision.return_value = deprovision_handler
self.assertTrue(ph.is_provisioned())
+ ph.osutil.is_current_instance_id.assert_called_once()
deprovision_handler.run_changed_unique_id.assert_not_called()
@patch('os.path.isfile', return_value=True)
@patch('azurelinuxagent.common.utils.fileutil.read_file',
- side_effect=["Value"])
+ return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502")
@patch('azurelinuxagent.pa.deprovision.get_deprovision_handler')
def test_is_provisioned_not_deprovisioned(self,
mock_deprovision, mock_read, mock_isfile):
ph = ProvisionHandler()
ph.osutil = Mock()
+ ph.osutil.is_current_instance_id = Mock(return_value=False)
+ ph.report_ready = Mock()
ph.write_provisioned = Mock()
deprovision_handler = Mock()
mock_deprovision.return_value = deprovision_handler
self.assertTrue(ph.is_provisioned())
+ ph.osutil.is_current_instance_id.assert_called_once()
deprovision_handler.run_changed_unique_id.assert_called_once()
if __name__ == '__main__':
diff --git a/tests/protocol/mockwiredata.py b/tests/protocol/mockwiredata.py
index 4e45623..5924719 100644
--- a/tests/protocol/mockwiredata.py
+++ b/tests/protocol/mockwiredata.py
@@ -16,6 +16,7 @@
#
from tests.tools import *
+from azurelinuxagent.common.exception import HttpError, ResourceGoneError
from azurelinuxagent.common.future import httpclient
from azurelinuxagent.common.utils.cryptutil import CryptUtil
@@ -53,6 +54,20 @@ DATA_FILE_EXT_AUTOUPGRADE_INTERNALVERSION["ext_conf"] = "wire/ext_conf_autoupgra
class WireProtocolData(object):
def __init__(self, data_files=DATA_FILE):
+ self.emulate_stale_goal_state = False
+ self.call_counts = {
+ "comp=versions" : 0,
+ "/versions" : 0,
+ "goalstate" : 0,
+ "hostingenvuri" : 0,
+ "sharedconfiguri" : 0,
+ "certificatesuri" : 0,
+ "extensionsconfiguri" : 0,
+ "extensionArtifact" : 0,
+ "manifest.xml" : 0,
+ "manifest_of_ga.xml" : 0,
+ "ExampleHandlerLinux" : 0
+ }
self.version_info = load_data(data_files.get("version_info"))
self.goal_state = load_data(data_files.get("goal_state"))
self.hosting_env = load_data(data_files.get("hosting_env"))
@@ -67,32 +82,70 @@ class WireProtocolData(object):
def mock_http_get(self, url, *args, **kwargs):
content = None
- if "versions" in url:
+
+ resp = MagicMock()
+ resp.status = httpclient.OK
+
+ # wire server versions
+ if "comp=versions" in url:
content = self.version_info
+ self.call_counts["comp=versions"] += 1
+
+ # HostPlugin versions
+ elif "/versions" in url:
+ content = '["2015-09-01"]'
+ self.call_counts["/versions"] += 1
elif "goalstate" in url:
content = self.goal_state
+ self.call_counts["goalstate"] += 1
elif "hostingenvuri" in url:
content = self.hosting_env
+ self.call_counts["hostingenvuri"] += 1
elif "sharedconfiguri" in url:
content = self.shared_config
+ self.call_counts["sharedconfiguri"] += 1
elif "certificatesuri" in url:
content = self.certs
+ self.call_counts["certificatesuri"] += 1
elif "extensionsconfiguri" in url:
content = self.ext_conf
- elif "manifest.xml" in url:
- content = self.manifest
- elif "manifest_of_ga.xml" in url:
- content = self.ga_manifest
- elif "ExampleHandlerLinux" in url:
- content = self.ext
- resp = MagicMock()
- resp.status = httpclient.OK
- resp.read = Mock(return_value=content)
- return resp
+ self.call_counts["extensionsconfiguri"] += 1
+
else:
- raise Exception("Bad url {0}".format(url))
- resp = MagicMock()
- resp.status = httpclient.OK
+ # A stale GoalState results in a 400 from the HostPlugin
+ # for which the HTTP handler in restutil raises ResourceGoneError
+ if self.emulate_stale_goal_state:
+ if "extensionArtifact" in url:
+ self.emulate_stale_goal_state = False
+ self.call_counts["extensionArtifact"] += 1
+ raise ResourceGoneError()
+ else:
+ raise HttpError()
+
+ # For HostPlugin requests, replace the URL with that passed
+ # via the x-ms-artifact-location header
+ if "extensionArtifact" in url:
+ self.call_counts["extensionArtifact"] += 1
+ if "headers" not in kwargs or \
+ "x-ms-artifact-location" not in kwargs["headers"]:
+ raise Exception("Bad HEADERS passed to HostPlugin: {0}",
+ kwargs)
+ url = kwargs["headers"]["x-ms-artifact-location"]
+
+ if "manifest.xml" in url:
+ content = self.manifest
+ self.call_counts["manifest.xml"] += 1
+ elif "manifest_of_ga.xml" in url:
+ content = self.ga_manifest
+ self.call_counts["manifest_of_ga.xml"] += 1
+ elif "ExampleHandlerLinux" in url:
+ content = self.ext
+ self.call_counts["ExampleHandlerLinux"] += 1
+ resp.read = Mock(return_value=content)
+ return resp
+ else:
+ raise Exception("Bad url {0}".format(url))
+
resp.read = Mock(return_value=content.encode("utf-8"))
return resp
diff --git a/tests/protocol/test_hostplugin.py b/tests/protocol/test_hostplugin.py
index b18b691..74f7f24 100644
--- a/tests/protocol/test_hostplugin.py
+++ b/tests/protocol/test_hostplugin.py
@@ -146,6 +146,7 @@ class TestHostPlugin(AgentTestCase):
test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state)
status = restapi.VMStatus(status="Ready",
message="Guest Agent is running")
+ wire.HostPluginProtocol.set_default_channel(False)
with patch.object(wire.HostPluginProtocol,
"ensure_initialized",
return_value=True):
@@ -173,6 +174,7 @@ class TestHostPlugin(AgentTestCase):
test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state)
status = restapi.VMStatus(status="Ready",
message="Guest Agent is running")
+ wire.HostPluginProtocol.set_default_channel(False)
with patch.object(wire.StatusBlob,
"upload",
return_value=False):
@@ -211,6 +213,8 @@ class TestHostPlugin(AgentTestCase):
bytearray(faux_status, encoding='utf-8'))
with patch.object(restutil, "http_request") as patch_http:
+ patch_http.return_value = Mock(status=httpclient.OK)
+
wire_protocol_client.get_goal_state = Mock(return_value=test_goal_state)
plugin = wire_protocol_client.get_host_plugin()
@@ -224,61 +228,6 @@ class TestHostPlugin(AgentTestCase):
test_goal_state,
exp_method, exp_url, exp_data)
- def test_read_response_error(self):
- """
- Validate the read_response_error method handles encoding correctly
- """
- responses = ['message', b'message', '\x80message\x80']
- response = MagicMock()
- response.status = 'status'
- response.reason = 'reason'
- with patch.object(response, 'read') as patch_response:
- for s in responses:
- patch_response.return_value = s
- result = hostplugin.HostPluginProtocol.read_response_error(response)
- self.assertTrue('[status: reason]' in result)
- self.assertTrue('message' in result)
-
- def test_read_response_bytes(self):
- response_bytes = '7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:' \
- '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:' \
- '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:' \
- '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:' \
- '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:' \
- 'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:' \
- '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:' \
- '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:' \
- '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:' \
- '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:' \
- '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:' \
- '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:' \
- '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:' \
- '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:' \
- '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:' \
- '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:' \
- '6c:73:22:3a:20:22:22:0a:7d'.split(':')
- expected_response = '[status: reason] {\n "errorCode": "The blob ' \
- 'type is invalid for this operation.",\n ' \
- '"message": "<?xml version="1.0" ' \
- 'encoding="utf-8"?>' \
- '<Error><Code>InvalidBlobType</Code><Message>The ' \
- 'blob type is invalid for this operation.\n' \
- 'RequestId:c74290cb-0001-00b5-06da-dd666a000",' \
- '\n "details": ""\n}'
-
- response_string = ''.join(chr(int(b, 16)) for b in response_bytes)
- response = MagicMock()
- response.status = 'status'
- response.reason = 'reason'
- with patch.object(response, 'read') as patch_response:
- patch_response.return_value = response_string
- result = hostplugin.HostPluginProtocol.read_response_error(response)
- self.assertEqual(result, expected_response)
- try:
- raise HttpError("{0}".format(result))
- except HttpError as e:
- self.assertTrue(result in ustr(e))
-
def test_no_fallback(self):
"""
Validate fallback to upload status using HostGAPlugin is not happening
@@ -318,6 +267,8 @@ class TestHostPlugin(AgentTestCase):
bytearray(faux_status, encoding='utf-8'))
with patch.object(restutil, "http_request") as patch_http:
+ patch_http.return_value = Mock(status=httpclient.OK)
+
with patch.object(wire.HostPluginProtocol,
"get_api_versions") as patch_get:
patch_get.return_value = api_versions
diff --git a/tests/protocol/test_metadata.py b/tests/protocol/test_metadata.py
index ee4ba3e..5047b86 100644
--- a/tests/protocol/test_metadata.py
+++ b/tests/protocol/test_metadata.py
@@ -31,17 +31,15 @@ class TestMetadataProtocolGetters(AgentTestCase):
return json.loads(ustr(load_data(path)), encoding="utf-8")
@patch("time.sleep")
- @patch("azurelinuxagent.common.protocol.metadata.restutil")
- def _test_getters(self, test_data, mock_restutil ,_):
- mock_restutil.http_get.side_effect = test_data.mock_http_get
-
- protocol = MetadataProtocol()
- protocol.detect()
- protocol.get_vminfo()
- protocol.get_certs()
- ext_handlers, etag = protocol.get_ext_handlers()
- for ext_handler in ext_handlers.extHandlers:
- protocol.get_ext_handler_pkgs(ext_handler)
+ def _test_getters(self, test_data ,_):
+ with patch.object(restutil, 'http_get', test_data.mock_http_get):
+ protocol = MetadataProtocol()
+ protocol.detect()
+ protocol.get_vminfo()
+ protocol.get_certs()
+ ext_handlers, etag = protocol.get_ext_handlers()
+ for ext_handler in ext_handlers.extHandlers:
+ protocol.get_ext_handler_pkgs(ext_handler)
def test_getters(self, *args):
test_data = MetadataProtocolData(DATA_FILE)
diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py
index 02976ca..d19bab1 100644
--- a/tests/protocol/test_wire.py
+++ b/tests/protocol/test_wire.py
@@ -25,30 +25,34 @@ wireserver_url = '168.63.129.16'
@patch("time.sleep")
@patch("azurelinuxagent.common.protocol.wire.CryptUtil")
-@patch("azurelinuxagent.common.protocol.wire.restutil")
class TestWireProtocolGetters(AgentTestCase):
- def _test_getters(self, test_data, mock_restutil, MockCryptUtil, _):
- mock_restutil.http_get.side_effect = test_data.mock_http_get
+
+ def setUp(self):
+ super(TestWireProtocolGetters, self).setUp()
+ HostPluginProtocol.set_default_channel(False)
+
+ def _test_getters(self, test_data, MockCryptUtil, _):
MockCryptUtil.side_effect = test_data.mock_crypt_util
- protocol = WireProtocol(wireserver_url)
- protocol.detect()
- protocol.get_vminfo()
- protocol.get_certs()
- ext_handlers, etag = protocol.get_ext_handlers()
- for ext_handler in ext_handlers.extHandlers:
- protocol.get_ext_handler_pkgs(ext_handler)
-
- crt1 = os.path.join(self.tmp_dir,
- '33B0ABCE4673538650971C10F7D7397E71561F35.crt')
- crt2 = os.path.join(self.tmp_dir,
- '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt')
- prv2 = os.path.join(self.tmp_dir,
- '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv')
-
- self.assertTrue(os.path.isfile(crt1))
- self.assertTrue(os.path.isfile(crt2))
- self.assertTrue(os.path.isfile(prv2))
+ with patch.object(restutil, 'http_get', test_data.mock_http_get):
+ protocol = WireProtocol(wireserver_url)
+ protocol.detect()
+ protocol.get_vminfo()
+ protocol.get_certs()
+ ext_handlers, etag = protocol.get_ext_handlers()
+ for ext_handler in ext_handlers.extHandlers:
+ protocol.get_ext_handler_pkgs(ext_handler)
+
+ crt1 = os.path.join(self.tmp_dir,
+ '33B0ABCE4673538650971C10F7D7397E71561F35.crt')
+ crt2 = os.path.join(self.tmp_dir,
+ '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt')
+ prv2 = os.path.join(self.tmp_dir,
+ '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv')
+
+ self.assertTrue(os.path.isfile(crt1))
+ self.assertTrue(os.path.isfile(crt2))
+ self.assertTrue(os.path.isfile(prv2))
def test_getters(self, *args):
"""Normal case"""
@@ -70,8 +74,21 @@ class TestWireProtocolGetters(AgentTestCase):
test_data = WireProtocolData(DATA_FILE_EXT_NO_PUBLIC)
self._test_getters(test_data, *args)
+ def test_getters_with_stale_goal_state(self, *args):
+ test_data = WireProtocolData(DATA_FILE)
+ test_data.emulate_stale_goal_state = True
+
+ self._test_getters(test_data, *args)
+ # Ensure HostPlugin was invoked
+ self.assertEqual(1, test_data.call_counts["/versions"])
+ self.assertEqual(2, test_data.call_counts["extensionArtifact"])
+ # Ensure the expected number of HTTP calls were made
+ # -- Tracking calls to retrieve GoalState is problematic since it is
+ # fetched often; however, the dependent documents, such as the
+ # HostingEnvironmentConfig, will be retrieved the expected number
+ self.assertEqual(2, test_data.call_counts["hostingenvuri"])
+
def test_call_storage_kwargs(self,
- mock_restutil,
mock_cryptutil,
mock_sleep):
from azurelinuxagent.common.utils import restutil
@@ -83,32 +100,32 @@ class TestWireProtocolGetters(AgentTestCase):
# no kwargs -- Default to True
WireClient.call_storage_service(http_req)
- # kwargs, no chk_proxy -- Default to True
+ # kwargs, no use_proxy -- Default to True
WireClient.call_storage_service(http_req,
url,
headers)
- # kwargs, chk_proxy None -- Default to True
+ # kwargs, use_proxy None -- Default to True
WireClient.call_storage_service(http_req,
url,
headers,
- chk_proxy=None)
+ use_proxy=None)
- # kwargs, chk_proxy False -- Keep False
+ # kwargs, use_proxy False -- Keep False
WireClient.call_storage_service(http_req,
url,
headers,
- chk_proxy=False)
+ use_proxy=False)
- # kwargs, chk_proxy True -- Keep True
+ # kwargs, use_proxy True -- Keep True
WireClient.call_storage_service(http_req,
url,
headers,
- chk_proxy=True)
+ use_proxy=True)
# assert
self.assertTrue(http_patch.call_count == 5)
for i in range(0,5):
- c = http_patch.call_args_list[i][-1]['chk_proxy']
+ c = http_patch.call_args_list[i][-1]['use_proxy']
self.assertTrue(c == (True if i != 3 else False))
def test_status_blob_parsing(self, *args):
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 1b35933..77be07a 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -17,12 +17,56 @@
import mock
import os.path
+import sys
from azurelinuxagent.agent import *
from azurelinuxagent.common.conf import *
from tests.tools import *
+EXPECTED_CONFIGURATION = \
+"""AutoUpdate.Enabled = True
+AutoUpdate.GAFamily = Prod
+Autoupdate.Frequency = 3600
+DVD.MountPoint = /mnt/cdrom/secure
+DetectScvmmEnv = False
+EnableOverProvisioning = False
+Extension.LogDir = /var/log/azure
+HttpProxy.Host = None
+HttpProxy.Port = None
+Lib.Dir = /var/lib/waagent
+Logs.Verbose = False
+OS.AllowHTTP = False
+OS.CheckRdmaDriver = False
+OS.EnableFIPS = True
+OS.EnableFirewall = True
+OS.EnableRDMA = False
+OS.HomeDir = /home
+OS.OpensslPath = /usr/bin/openssl
+OS.PasswordPath = /etc/shadow
+OS.RootDeviceScsiTimeout = 300
+OS.SshDir = /notareal/path
+OS.SudoersDir = /etc/sudoers.d
+OS.UpdateRdmaDriver = False
+Pid.File = /var/run/waagent.pid
+Provisioning.AllowResetSysUser = False
+Provisioning.DecodeCustomData = False
+Provisioning.DeleteRootPassword = True
+Provisioning.Enabled = True
+Provisioning.ExecuteCustomData = False
+Provisioning.MonitorHostName = True
+Provisioning.PasswordCryptId = 6
+Provisioning.PasswordCryptSaltLength = 10
+Provisioning.RegenerateSshHostKeyPair = True
+Provisioning.SshHostKeyPairType = rsa
+Provisioning.UseCloudInit = True
+ResourceDisk.EnableSwap = False
+ResourceDisk.Filesystem = ext4
+ResourceDisk.Format = True
+ResourceDisk.MountOptions = None
+ResourceDisk.MountPoint = /mnt/resource
+ResourceDisk.SwapSizeMB = 0
+""".split('\n')
class TestAgent(AgentTestCase):
@@ -90,3 +134,36 @@ class TestAgent(AgentTestCase):
mock_daemon.run.assert_called_once_with(child_args="-configuration-path:/foo/bar.conf")
mock_load.assert_called_once()
+
+ @patch("azurelinuxagent.common.conf.get_ext_log_dir")
+ def test_agent_ensures_extension_log_directory(self, mock_dir):
+ ext_log_dir = os.path.join(self.tmp_dir, "FauxLogDir")
+ mock_dir.return_value = ext_log_dir
+
+ self.assertFalse(os.path.isdir(ext_log_dir))
+ agent = Agent(False,
+ conf_file_path=os.path.join(data_dir, "test_waagent.conf"))
+ self.assertTrue(os.path.isdir(ext_log_dir))
+
+ @patch("azurelinuxagent.common.logger.error")
+ @patch("azurelinuxagent.common.conf.get_ext_log_dir")
+ def test_agent_logs_if_extension_log_directory_is_a_file(self, mock_dir, mock_log):
+ ext_log_dir = os.path.join(self.tmp_dir, "FauxLogDir")
+ mock_dir.return_value = ext_log_dir
+ fileutil.write_file(ext_log_dir, "Foo")
+
+ self.assertTrue(os.path.isfile(ext_log_dir))
+ self.assertFalse(os.path.isdir(ext_log_dir))
+ agent = Agent(False,
+ conf_file_path=os.path.join(data_dir, "test_waagent.conf"))
+ self.assertTrue(os.path.isfile(ext_log_dir))
+ self.assertFalse(os.path.isdir(ext_log_dir))
+ mock_log.assert_called_once()
+
+ def test_agent_show_configuration(self):
+ if not hasattr(sys.stdout, 'getvalue'):
+ self.fail('Test requires at least Python 2.7 with buffered output')
+ agent = Agent(False,
+ conf_file_path=os.path.join(data_dir, "test_waagent.conf"))
+ agent.show_configuration()
+ self.assertEqual(EXPECTED_CONFIGURATION, sys.stdout.getvalue().split('\n'))
diff --git a/tests/tools.py b/tests/tools.py
index a505700..94fab7f 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -26,8 +26,10 @@ import tempfile
import unittest
from functools import wraps
+import azurelinuxagent.common.event as event
import azurelinuxagent.common.conf as conf
import azurelinuxagent.common.logger as logger
+
from azurelinuxagent.common.version import PY_VERSION_MAJOR
#Import mock module for Python2 and Python3
@@ -51,14 +53,21 @@ if debug:
class AgentTestCase(unittest.TestCase):
def setUp(self):
prefix = "{0}_".format(self.__class__.__name__)
+
self.tmp_dir = tempfile.mkdtemp(prefix=prefix)
self.test_file = 'test_file'
+
conf.get_autoupdate_enabled = Mock(return_value=True)
conf.get_lib_dir = Mock(return_value=self.tmp_dir)
+
ext_log_dir = os.path.join(self.tmp_dir, "azure")
conf.get_ext_log_dir = Mock(return_value=ext_log_dir)
+
conf.get_agent_pid_file_path = Mock(return_value=os.path.join(self.tmp_dir, "waagent.pid"))
+ event.init_event_status(self.tmp_dir)
+ event.init_event_logger(self.tmp_dir)
+
def tearDown(self):
if not debug and self.tmp_dir is not None:
shutil.rmtree(self.tmp_dir)
diff --git a/tests/utils/test_file_util.py b/tests/utils/test_file_util.py
index 0b92513..87bce8c 100644
--- a/tests/utils/test_file_util.py
+++ b/tests/utils/test_file_util.py
@@ -15,6 +15,7 @@
# Requires Python 2.4+ and Openssl 1.0+
#
+import errno as errno
import glob
import random
import string
@@ -64,6 +65,50 @@ class TestFileOperations(AgentTestCase):
os.remove(test_file)
+ def test_findre_in_file(self):
+ fp = tempfile.mktemp()
+ with open(fp, 'w') as f:
+ f.write(
+'''
+First line
+Second line
+Third line with more words
+'''
+ )
+
+ self.assertNotEquals(
+ None,
+ fileutil.findre_in_file(fp, ".*rst line$"))
+ self.assertNotEquals(
+ None,
+ fileutil.findre_in_file(fp, ".*ond line$"))
+ self.assertNotEquals(
+ None,
+ fileutil.findre_in_file(fp, ".*with more.*"))
+ self.assertNotEquals(
+ None,
+ fileutil.findre_in_file(fp, "^Third.*"))
+ self.assertEquals(
+ None,
+ fileutil.findre_in_file(fp, "^Do not match.*"))
+
+ def test_findstr_in_file(self):
+ fp = tempfile.mktemp()
+ with open(fp, 'w') as f:
+ f.write(
+'''
+First line
+Second line
+Third line with more words
+'''
+ )
+
+ self.assertTrue(fileutil.findstr_in_file(fp, "First line"))
+ self.assertTrue(fileutil.findstr_in_file(fp, "Second line"))
+ self.assertTrue(
+ fileutil.findstr_in_file(fp, "Third line with more words"))
+ self.assertFalse(fileutil.findstr_in_file(fp, "Not a line"))
+
def test_get_last_path_element(self):
filepath = '/tmp/abc.def'
filename = fileutil.base_name(filepath)
@@ -197,5 +242,75 @@ DHCP_HOSTNAME=test\n"
fileutil.update_conf_file(path, 'DHCP_HOSTNAME', 'DHCP_HOSTNAME=test')
patch_write.assert_called_once_with(path, updated_file)
+ def test_clean_ioerror_ignores_missing(self):
+ e = IOError()
+ e.errno = errno.ENOSPC
+
+ # Send no paths
+ fileutil.clean_ioerror(e)
+
+ # Send missing file(s) / directories
+ fileutil.clean_ioerror(e, paths=['/foo/not/here', None, '/bar/not/there'])
+
+ def test_clean_ioerror_ignores_unless_ioerror(self):
+ try:
+ d = tempfile.mkdtemp()
+ fd, f = tempfile.mkstemp()
+ os.close(fd)
+ fileutil.write_file(f, 'Not empty')
+
+ # Send non-IOError exception
+ e = Exception()
+ fileutil.clean_ioerror(e, paths=[d, f])
+ self.assertTrue(os.path.isdir(d))
+ self.assertTrue(os.path.isfile(f))
+
+ # Send unrecognized IOError
+ e = IOError()
+ e.errno = errno.EFAULT
+ self.assertFalse(e.errno in fileutil.KNOWN_IOERRORS)
+ fileutil.clean_ioerror(e, paths=[d, f])
+ self.assertTrue(os.path.isdir(d))
+ self.assertTrue(os.path.isfile(f))
+
+ finally:
+ shutil.rmtree(d)
+ os.remove(f)
+
+ def test_clean_ioerror_removes_files(self):
+ fd, f = tempfile.mkstemp()
+ os.close(fd)
+ fileutil.write_file(f, 'Not empty')
+
+ e = IOError()
+ e.errno = errno.ENOSPC
+ fileutil.clean_ioerror(e, paths=[f])
+ self.assertFalse(os.path.isdir(f))
+ self.assertFalse(os.path.isfile(f))
+
+ def test_clean_ioerror_removes_directories(self):
+ d1 = tempfile.mkdtemp()
+ d2 = tempfile.mkdtemp()
+ for n in ['foo', 'bar']:
+ fileutil.write_file(os.path.join(d2, n), 'Not empty')
+
+ e = IOError()
+ e.errno = errno.ENOSPC
+ fileutil.clean_ioerror(e, paths=[d1, d2])
+ self.assertFalse(os.path.isdir(d1))
+ self.assertFalse(os.path.isfile(d1))
+ self.assertFalse(os.path.isdir(d2))
+ self.assertFalse(os.path.isfile(d2))
+
+ def test_clean_ioerror_handles_a_range_of_errors(self):
+ for err in fileutil.KNOWN_IOERRORS:
+ e = IOError()
+ e.errno = err
+
+ d = tempfile.mkdtemp()
+ fileutil.clean_ioerror(e, paths=[d])
+ self.assertFalse(os.path.isdir(d))
+ self.assertFalse(os.path.isfile(d))
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/utils/test_rest_util.py b/tests/utils/test_rest_util.py
index 5f084a6..52674da 100644
--- a/tests/utils/test_rest_util.py
+++ b/tests/utils/test_rest_util.py
@@ -15,10 +15,16 @@
# Requires Python 2.4+ and Openssl 1.0+
#
+import os
import unittest
+
+from azurelinuxagent.common.exception import HttpError, \
+ ProtocolError, \
+ ResourceGoneError
import azurelinuxagent.common.utils.restutil as restutil
-from azurelinuxagent.common.future import httpclient
-from tests.tools import AgentTestCase, patch, Mock, MagicMock
+
+from azurelinuxagent.common.future import httpclient, ustr
+from tests.tools import *
class TestHttpOperations(AgentTestCase):
@@ -50,45 +56,163 @@ class TestHttpOperations(AgentTestCase):
self.assertEquals(None, host)
self.assertEquals(rel_uri, "None")
+ @patch('azurelinuxagent.common.conf.get_httpproxy_port')
+ @patch('azurelinuxagent.common.conf.get_httpproxy_host')
+ def test_get_http_proxy_none_is_default(self, mock_host, mock_port):
+ mock_host.return_value = None
+ mock_port.return_value = None
+ h, p = restutil._get_http_proxy()
+ self.assertEqual(None, h)
+ self.assertEqual(None, p)
+
+ @patch('azurelinuxagent.common.conf.get_httpproxy_port')
+ @patch('azurelinuxagent.common.conf.get_httpproxy_host')
+ def test_get_http_proxy_configuration_overrides_env(self, mock_host, mock_port):
+ mock_host.return_value = "host"
+ mock_port.return_value = None
+ h, p = restutil._get_http_proxy()
+ self.assertEqual("host", h)
+ self.assertEqual(None, p)
+ mock_host.assert_called_once()
+ mock_port.assert_called_once()
+
+ @patch('azurelinuxagent.common.conf.get_httpproxy_port')
+ @patch('azurelinuxagent.common.conf.get_httpproxy_host')
+ def test_get_http_proxy_configuration_requires_host(self, mock_host, mock_port):
+ mock_host.return_value = None
+ mock_port.return_value = None
+ h, p = restutil._get_http_proxy()
+ self.assertEqual(None, h)
+ self.assertEqual(None, p)
+ mock_host.assert_called_once()
+ mock_port.assert_not_called()
+
+ @patch('azurelinuxagent.common.conf.get_httpproxy_host')
+ def test_get_http_proxy_http_uses_httpproxy(self, mock_host):
+ mock_host.return_value = None
+ with patch.dict(os.environ, {
+ 'http_proxy' : 'http://foo.com:80',
+ 'https_proxy' : 'https://bar.com:443'
+ }):
+ h, p = restutil._get_http_proxy()
+ self.assertEqual("foo.com", h)
+ self.assertEqual(80, p)
+
+ @patch('azurelinuxagent.common.conf.get_httpproxy_host')
+ def test_get_http_proxy_https_uses_httpsproxy(self, mock_host):
+ mock_host.return_value = None
+ with patch.dict(os.environ, {
+ 'http_proxy' : 'http://foo.com:80',
+ 'https_proxy' : 'https://bar.com:443'
+ }):
+ h, p = restutil._get_http_proxy(secure=True)
+ self.assertEqual("bar.com", h)
+ self.assertEqual(443, p)
+
+ @patch('azurelinuxagent.common.conf.get_httpproxy_host')
+ def test_get_http_proxy_ignores_user_in_httpproxy(self, mock_host):
+ mock_host.return_value = None
+ with patch.dict(os.environ, {
+ 'http_proxy' : 'http://user:pw@foo.com:80'
+ }):
+ h, p = restutil._get_http_proxy()
+ self.assertEqual("foo.com", h)
+ self.assertEqual(80, p)
+
@patch("azurelinuxagent.common.future.httpclient.HTTPSConnection")
@patch("azurelinuxagent.common.future.httpclient.HTTPConnection")
- def test_http_request(self, HTTPConnection, HTTPSConnection):
- mock_http_conn = MagicMock()
- mock_http_resp = MagicMock()
- mock_http_conn.getresponse = Mock(return_value=mock_http_resp)
- HTTPConnection.return_value = mock_http_conn
- HTTPSConnection.return_value = mock_http_conn
+ def test_http_request_direct(self, HTTPConnection, HTTPSConnection):
+ mock_conn = \
+ MagicMock(getresponse=\
+ Mock(return_value=\
+ Mock(read=Mock(return_value="TheResults"))))
- mock_http_resp.read = Mock(return_value="_(:3| <)_")
+ HTTPConnection.return_value = mock_conn
- # Test http get
- resp = restutil._http_request("GET", "foo", "bar")
- self.assertNotEquals(None, resp)
- self.assertEquals("_(:3| <)_", resp.read())
+ resp = restutil._http_request("GET", "foo", "/bar")
- # Test https get
- resp = restutil._http_request("GET", "foo", "bar", secure=True)
+ HTTPConnection.assert_has_calls([
+ call("foo", 80, timeout=10)
+ ])
+ HTTPSConnection.assert_not_called()
+ mock_conn.request.assert_has_calls([
+ call(method="GET", url="/bar", body=None, headers={})
+ ])
+ mock_conn.getresponse.assert_called_once()
self.assertNotEquals(None, resp)
- self.assertEquals("_(:3| <)_", resp.read())
+ self.assertEquals("TheResults", resp.read())
+
+ @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection")
+ @patch("azurelinuxagent.common.future.httpclient.HTTPConnection")
+ def test_http_request_direct_secure(self, HTTPConnection, HTTPSConnection):
+ mock_conn = \
+ MagicMock(getresponse=\
+ Mock(return_value=\
+ Mock(read=Mock(return_value="TheResults"))))
+
+ HTTPSConnection.return_value = mock_conn
+
+ resp = restutil._http_request("GET", "foo", "/bar", secure=True)
- # Test http get with proxy
- mock_http_resp.read = Mock(return_value="_(:3| <)_")
- resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar",
- proxy_port=23333)
+ HTTPConnection.assert_not_called()
+ HTTPSConnection.assert_has_calls([
+ call("foo", 443, timeout=10)
+ ])
+ mock_conn.request.assert_has_calls([
+ call(method="GET", url="/bar", body=None, headers={})
+ ])
+ mock_conn.getresponse.assert_called_once()
self.assertNotEquals(None, resp)
- self.assertEquals("_(:3| <)_", resp.read())
+ self.assertEquals("TheResults", resp.read())
- # Test https get
- resp = restutil._http_request("GET", "foo", "bar", secure=True)
+ @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection")
+ @patch("azurelinuxagent.common.future.httpclient.HTTPConnection")
+ def test_http_request_proxy(self, HTTPConnection, HTTPSConnection):
+ mock_conn = \
+ MagicMock(getresponse=\
+ Mock(return_value=\
+ Mock(read=Mock(return_value="TheResults"))))
+
+ HTTPConnection.return_value = mock_conn
+
+ resp = restutil._http_request("GET", "foo", "/bar",
+ proxy_host="foo.bar", proxy_port=23333)
+
+ HTTPConnection.assert_has_calls([
+ call("foo.bar", 23333, timeout=10)
+ ])
+ HTTPSConnection.assert_not_called()
+ mock_conn.request.assert_has_calls([
+ call(method="GET", url="http://foo:80/bar", body=None, headers={})
+ ])
+ mock_conn.getresponse.assert_called_once()
self.assertNotEquals(None, resp)
- self.assertEquals("_(:3| <)_", resp.read())
+ self.assertEquals("TheResults", resp.read())
+
+ @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection")
+ @patch("azurelinuxagent.common.future.httpclient.HTTPConnection")
+ def test_http_request_proxy_secure(self, HTTPConnection, HTTPSConnection):
+ mock_conn = \
+ MagicMock(getresponse=\
+ Mock(return_value=\
+ Mock(read=Mock(return_value="TheResults"))))
+
+ HTTPSConnection.return_value = mock_conn
- # Test https get with proxy
- mock_http_resp.read = Mock(return_value="_(:3| <)_")
- resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar",
- proxy_port=23333, secure=True)
+ resp = restutil._http_request("GET", "foo", "/bar",
+ proxy_host="foo.bar", proxy_port=23333,
+ secure=True)
+
+ HTTPConnection.assert_not_called()
+ HTTPSConnection.assert_has_calls([
+ call("foo.bar", 23333, timeout=10)
+ ])
+ mock_conn.request.assert_has_calls([
+ call(method="GET", url="https://foo:443/bar", body=None, headers={})
+ ])
+ mock_conn.getresponse.assert_called_once()
self.assertNotEquals(None, resp)
- self.assertEquals("_(:3| <)_", resp.read())
+ self.assertEquals("TheResults", resp.read())
@patch("time.sleep")
@patch("azurelinuxagent.common.utils.restutil._http_request")
@@ -115,6 +239,180 @@ class TestHttpOperations(AgentTestCase):
self.assertRaises(restutil.HttpError, restutil.http_get,
"http://foo.bar")
+ @patch("time.sleep")
+ @patch("azurelinuxagent.common.utils.restutil._http_request")
+ def test_http_request_retries_status_codes(self, _http_request, _sleep):
+ _http_request.side_effect = [
+ Mock(status=httpclient.SERVICE_UNAVAILABLE),
+ Mock(status=httpclient.OK)
+ ]
+
+ restutil.http_get("https://foo.bar")
+ self.assertEqual(2, _http_request.call_count)
+ self.assertEqual(1, _sleep.call_count)
+
+ @patch("time.sleep")
+ @patch("azurelinuxagent.common.utils.restutil._http_request")
+ def test_http_request_retries_passed_status_codes(self, _http_request, _sleep):
+ # Ensure the code is not part of the standard set
+ self.assertFalse(httpclient.UNAUTHORIZED in restutil.RETRY_CODES)
+
+ _http_request.side_effect = [
+ Mock(status=httpclient.UNAUTHORIZED),
+ Mock(status=httpclient.OK)
+ ]
+
+ restutil.http_get("https://foo.bar", retry_codes=[httpclient.UNAUTHORIZED])
+ self.assertEqual(2, _http_request.call_count)
+ self.assertEqual(1, _sleep.call_count)
+
+ @patch("time.sleep")
+ @patch("azurelinuxagent.common.utils.restutil._http_request")
+ def test_http_request_raises_for_bad_request(self, _http_request, _sleep):
+ _http_request.side_effect = [
+ Mock(status=httpclient.BAD_REQUEST)
+ ]
+
+ self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar")
+ self.assertEqual(1, _http_request.call_count)
+
+ @patch("time.sleep")
+ @patch("azurelinuxagent.common.utils.restutil._http_request")
+ def test_http_request_raises_for_resource_gone(self, _http_request, _sleep):
+ _http_request.side_effect = [
+ Mock(status=httpclient.GONE)
+ ]
+
+ self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar")
+ self.assertEqual(1, _http_request.call_count)
+
+ @patch("time.sleep")
+ @patch("azurelinuxagent.common.utils.restutil._http_request")
+ def test_http_request_retries_exceptions(self, _http_request, _sleep):
+ # Testing each exception is difficult because they have varying
+ # signatures; for now, test one and ensure the set is unchanged
+ recognized_exceptions = [
+ httpclient.NotConnected,
+ httpclient.IncompleteRead,
+ httpclient.ImproperConnectionState,
+ httpclient.BadStatusLine
+ ]
+ self.assertEqual(recognized_exceptions, restutil.RETRY_EXCEPTIONS)
+
+ _http_request.side_effect = [
+ httpclient.IncompleteRead(''),
+ Mock(status=httpclient.OK)
+ ]
+
+ restutil.http_get("https://foo.bar")
+ self.assertEqual(2, _http_request.call_count)
+ self.assertEqual(1, _sleep.call_count)
+
+ @patch("time.sleep")
+ @patch("azurelinuxagent.common.utils.restutil._http_request")
+ def test_http_request_retries_ioerrors(self, _http_request, _sleep):
+ ioerror = IOError()
+ ioerror.errno = 42
+
+ _http_request.side_effect = [
+ ioerror,
+ Mock(status=httpclient.OK)
+ ]
+
+ restutil.http_get("https://foo.bar")
+ self.assertEqual(2, _http_request.call_count)
+ self.assertEqual(1, _sleep.call_count)
+
+ def test_request_failed(self):
+ self.assertTrue(restutil.request_failed(None))
+
+ resp = Mock()
+ for status in restutil.OK_CODES:
+ resp.status = status
+ self.assertFalse(restutil.request_failed(resp))
+
+ self.assertFalse(httpclient.BAD_REQUEST in restutil.OK_CODES)
+ resp.status = httpclient.BAD_REQUEST
+ self.assertTrue(restutil.request_failed(resp))
+
+ self.assertFalse(
+ restutil.request_failed(
+ resp, ok_codes=[httpclient.BAD_REQUEST]))
+
+ def test_request_succeeded(self):
+ self.assertFalse(restutil.request_succeeded(None))
+
+ resp = Mock()
+ for status in restutil.OK_CODES:
+ resp.status = status
+ self.assertTrue(restutil.request_succeeded(resp))
+
+ self.assertFalse(httpclient.BAD_REQUEST in restutil.OK_CODES)
+ resp.status = httpclient.BAD_REQUEST
+ self.assertFalse(restutil.request_succeeded(resp))
+
+ self.assertTrue(
+ restutil.request_succeeded(
+ resp, ok_codes=[httpclient.BAD_REQUEST]))
+
+ def test_read_response_error(self):
+ """
+ Validate the read_response_error method handles encoding correctly
+ """
+ responses = ['message', b'message', '\x80message\x80']
+ response = MagicMock()
+ response.status = 'status'
+ response.reason = 'reason'
+ with patch.object(response, 'read') as patch_response:
+ for s in responses:
+ patch_response.return_value = s
+ result = restutil.read_response_error(response)
+ print("RESPONSE: {0}".format(s))
+ print("RESULT: {0}".format(result))
+ print("PRESENT: {0}".format('[status: reason]' in result))
+ self.assertTrue('[status: reason]' in result)
+ self.assertTrue('message' in result)
+
+ def test_read_response_bytes(self):
+ response_bytes = '7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:' \
+ '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:' \
+ '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:' \
+ '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:' \
+ '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:' \
+ 'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:' \
+ '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:' \
+ '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:' \
+ '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:' \
+ '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:' \
+ '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:' \
+ '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:' \
+ '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:' \
+ '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:' \
+ '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:' \
+ '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:' \
+ '6c:73:22:3a:20:22:22:0a:7d'.split(':')
+ expected_response = '[HTTP Failed] [status: reason] {\n "errorCode": "The blob ' \
+ 'type is invalid for this operation.",\n ' \
+ '"message": "<?xml version="1.0" ' \
+ 'encoding="utf-8"?>' \
+ '<Error><Code>InvalidBlobType</Code><Message>The ' \
+ 'blob type is invalid for this operation.\n' \
+ 'RequestId:c74290cb-0001-00b5-06da-dd666a000",' \
+ '\n "details": ""\n}'
+
+ response_string = ''.join(chr(int(b, 16)) for b in response_bytes)
+ response = MagicMock()
+ response.status = 'status'
+ response.reason = 'reason'
+ with patch.object(response, 'read') as patch_response:
+ patch_response.return_value = response_string
+ result = restutil.read_response_error(response)
+ self.assertEqual(result, expected_response)
+ try:
+ raise HttpError("{0}".format(result))
+ except HttpError as e:
+ self.assertTrue(result in ustr(e))
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/utils/test_text_util.py b/tests/utils/test_text_util.py
index 6f204c7..d182a67 100644
--- a/tests/utils/test_text_util.py
+++ b/tests/utils/test_text_util.py
@@ -34,6 +34,19 @@ class TestTextUtil(AgentTestCase):
password_hash = textutil.gen_password_hash(data, 6, 10)
self.assertNotEquals(None, password_hash)
+ def test_replace_non_ascii(self):
+ data = ustr(b'\xef\xbb\xbfhehe', encoding='utf-8')
+ self.assertEqual('hehe', textutil.replace_non_ascii(data))
+
+ data = "abcd\xa0e\xf0fghijk\xbblm"
+ self.assertEqual("abcdefghijklm", textutil.replace_non_ascii(data))
+
+ data = "abcd\xa0e\xf0fghijk\xbblm"
+ self.assertEqual("abcdXeXfghijkXlm",
+ textutil.replace_non_ascii(data, replace_char='X'))
+
+ self.assertEqual('', textutil.replace_non_ascii(None))
+
def test_remove_bom(self):
#Test bom could be removed
data = ustr(b'\xef\xbb\xbfhehe', encoding='utf-8')
@@ -94,6 +107,37 @@ class TestTextUtil(AgentTestCase):
"-----END PRIVATE Key-----\n")
base64_bytes = textutil.get_bytes_from_pem(content)
self.assertEquals("private key", base64_bytes)
+
+ def test_swap_hexstring(self):
+ data = [
+ ['12', 1, '21'],
+ ['12', 2, '12'],
+ ['12', 3, '012'],
+ ['12', 4, '0012'],
+
+ ['123', 1, '321'],
+ ['123', 2, '2301'],
+ ['123', 3, '123'],
+ ['123', 4, '0123'],
+
+ ['1234', 1, '4321'],
+ ['1234', 2, '3412'],
+ ['1234', 3, '234001'],
+ ['1234', 4, '1234'],
+
+ ['abcdef12', 1, '21fedcba'],
+ ['abcdef12', 2, '12efcdab'],
+ ['abcdef12', 3, 'f12cde0ab'],
+ ['abcdef12', 4, 'ef12abcd'],
+
+ ['aBcdEf12', 1, '21fEdcBa'],
+ ['aBcdEf12', 2, '12EfcdaB'],
+ ['aBcdEf12', 3, 'f12cdE0aB'],
+ ['aBcdEf12', 4, 'Ef12aBcd']
+ ]
+
+ for t in data:
+ self.assertEqual(t[2], textutil.swap_hexstring(t[0], width=t[1]))
if __name__ == '__main__':
unittest.main()