diff options
author | Ćukasz 'sil2100' Zemczak <lukasz.zemczak@ubuntu.com> | 2017-09-04 10:27:07 +0200 |
---|---|---|
committer | usd-importer <ubuntu-server@lists.ubuntu.com> | 2017-09-04 09:38:24 +0000 |
commit | e919bdd14e48919244da9e499070fb64377993e5 (patch) | |
tree | 33c260c7c99410ac94d5f265fc506cc0b40bb6e4 /tests | |
parent | 70c0ea1ac879b2e1cba0a8edb1f3fbe82652413b (diff) | |
parent | 3a1d96a77ccaf023256d16183428e3d895f8a051 (diff) | |
download | vyos-walinuxagent-e919bdd14e48919244da9e499070fb64377993e5.tar.gz vyos-walinuxagent-e919bdd14e48919244da9e499070fb64377993e5.zip |
Import patches-applied version 2.2.16-0ubuntu1 to applied/ubuntu/artful-proposed
Imported using git-ubuntu import.
Changelog parent: 70c0ea1ac879b2e1cba0a8edb1f3fbe82652413b
Unapplied parent: 3a1d96a77ccaf023256d16183428e3d895f8a051
New changelog entries:
* New upstream release (LP: #1714299).
Diffstat (limited to 'tests')
-rw-r--r-- | tests/common/osutil/test_default.py | 229 | ||||
-rw-r--r-- | tests/common/test_conf.py | 51 | ||||
-rw-r--r-- | tests/common/test_event.py | 117 | ||||
-rw-r--r-- | tests/data/ga/WALinuxAgent-2.2.11.zip | bin | 450878 -> 0 bytes | |||
-rw-r--r-- | tests/data/ga/WALinuxAgent-2.2.14.zip | bin | 0 -> 500633 bytes | |||
-rw-r--r-- | tests/data/test_waagent.conf | 16 | ||||
-rw-r--r-- | tests/ga/test_update.py | 243 | ||||
-rw-r--r-- | tests/pa/test_provision.py | 28 | ||||
-rw-r--r-- | tests/protocol/mockwiredata.py | 81 | ||||
-rw-r--r-- | tests/protocol/test_hostplugin.py | 61 | ||||
-rw-r--r-- | tests/protocol/test_metadata.py | 20 | ||||
-rw-r--r-- | tests/protocol/test_wire.py | 77 | ||||
-rw-r--r-- | tests/test_agent.py | 77 | ||||
-rw-r--r-- | tests/tools.py | 9 | ||||
-rw-r--r-- | tests/utils/test_file_util.py | 115 | ||||
-rw-r--r-- | tests/utils/test_rest_util.py | 356 | ||||
-rw-r--r-- | tests/utils/test_text_util.py | 44 |
17 files changed, 1243 insertions, 281 deletions
diff --git a/tests/common/osutil/test_default.py b/tests/common/osutil/test_default.py index 87acc60..ec4408b 100644 --- a/tests/common/osutil/test_default.py +++ b/tests/common/osutil/test_default.py @@ -25,6 +25,7 @@ from azurelinuxagent.common.exception import OSUtilError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.osutil import get_osutil from azurelinuxagent.common.utils import fileutil +from azurelinuxagent.common.utils.flexible_version import FlexibleVersion from tests.tools import * @@ -112,6 +113,21 @@ class TestOSUtil(AgentTestCase): self.assertFalse(osutil.DefaultOSUtil().is_primary_interface('lo')) self.assertTrue(osutil.DefaultOSUtil().is_primary_interface('eth0')) + def test_sriov(self): + routing_table = "\ + Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT \n" \ + "bond0 00000000 0100000A 0003 0 0 0 00000000 0 0 0 \n" \ + "bond0 0000000A 00000000 0001 0 0 0 00000000 0 0 0 \n" \ + "eth0 0000000A 00000000 0001 0 0 0 00000000 0 0 0 \n" \ + "bond0 10813FA8 0100000A 0007 0 0 0 00000000 0 0 0 \n" \ + "bond0 FEA9FEA9 0100000A 0007 0 0 0 00000000 0 0 0 \n" + + mo = mock.mock_open(read_data=routing_table) + with patch(open_patch(), mo): + self.assertFalse(osutil.DefaultOSUtil().is_primary_interface('eth0')) + self.assertTrue(osutil.DefaultOSUtil().is_primary_interface('bond0')) + + def test_multiple_default_routes(self): routing_table = "\ Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT \n\ @@ -362,23 +378,50 @@ Match host 192.168.1.2\n\ conf.get_sshd_conf_file_path(), expected_output) + def test_correct_instance_id(self): + util = osutil.DefaultOSUtil() + self.assertEqual( + "12345678-1234-1234-1234-123456789012", + util._correct_instance_id("78563412-3412-3412-1234-123456789012")) + self.assertEqual( + "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8", + util._correct_instance_id("544CDFD0-CB4E-4B4A-9954-5BDF3ED5C3B8")) + @patch('os.path.isfile', return_value=True) @patch('azurelinuxagent.common.utils.fileutil.read_file', - return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502") + return_value="33C2F3B9-1399-429F-8EB3-BA656DF32502") def test_get_instance_id_from_file(self, mock_read, mock_isfile): util = osutil.DefaultOSUtil() self.assertEqual( - "B9F3C233-9913-9F42-8EB3-BA656DF32502", + util.get_instance_id(), + "B9F3C233-9913-9F42-8EB3-BA656DF32502") + + @patch('os.path.isfile', return_value=True) + @patch('azurelinuxagent.common.utils.fileutil.read_file', + return_value="") + def test_get_instance_id_empty_from_file(self, mock_read, mock_isfile): + util = osutil.DefaultOSUtil() + self.assertEqual( + "", + util.get_instance_id()) + + @patch('os.path.isfile', return_value=True) + @patch('azurelinuxagent.common.utils.fileutil.read_file', + return_value="Value") + def test_get_instance_id_malformed_from_file(self, mock_read, mock_isfile): + util = osutil.DefaultOSUtil() + self.assertEqual( + "Value", util.get_instance_id()) @patch('os.path.isfile', return_value=False) @patch('azurelinuxagent.common.utils.shellutil.run_get_output', - return_value=[0, 'B9F3C233-9913-9F42-8EB3-BA656DF32502']) + return_value=[0, '33C2F3B9-1399-429F-8EB3-BA656DF32502']) def test_get_instance_id_from_dmidecode(self, mock_shell, mock_isfile): util = osutil.DefaultOSUtil() self.assertEqual( - "B9F3C233-9913-9F42-8EB3-BA656DF32502", - util.get_instance_id()) + util.get_instance_id(), + "B9F3C233-9913-9F42-8EB3-BA656DF32502") @patch('os.path.isfile', return_value=False) @patch('azurelinuxagent.common.utils.shellutil.run_get_output', @@ -394,5 +437,181 @@ Match host 192.168.1.2\n\ util = osutil.DefaultOSUtil() self.assertEqual("", util.get_instance_id()) + @patch('os.path.isfile', return_value=True) + @patch('azurelinuxagent.common.utils.fileutil.read_file') + def test_is_current_instance_id_from_file(self, mock_read, mock_isfile): + util = osutil.DefaultOSUtil() + + mock_read.return_value = "B9F3C233-9913-9F42-8EB3-BA656DF32502" + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + mock_read.return_value = "33C2F3B9-1399-429F-8EB3-BA656DF32502" + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + @patch('os.path.isfile', return_value=False) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + def test_is_current_instance_id_from_dmidecode(self, mock_shell, mock_isfile): + util = osutil.DefaultOSUtil() + + mock_shell.return_value = [0, 'B9F3C233-9913-9F42-8EB3-BA656DF32502'] + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + mock_shell.return_value = [0, '33C2F3B9-1399-429F-8EB3-BA656DF32502'] + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + @patch('azurelinuxagent.common.conf.get_sudoers_dir') + def test_conf_sudoer(self, mock_dir): + tmp_dir = tempfile.mkdtemp() + mock_dir.return_value = tmp_dir + + util = osutil.DefaultOSUtil() + + # Assert the sudoer line is added if missing + util.conf_sudoer("FooBar") + waagent_sudoers = os.path.join(tmp_dir, 'waagent') + self.assertTrue(os.path.isfile(waagent_sudoers)) + + count = -1 + with open(waagent_sudoers, 'r') as f: + count = len(f.readlines()) + self.assertEqual(1, count) + + # Assert the line does not get added a second time + util.conf_sudoer("FooBar") + + count = -1 + with open(waagent_sudoers, 'r') as f: + count = len(f.readlines()) + print("WRITING TO {0}".format(waagent_sudoers)) + self.assertEqual(1, count) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [1, 0, 0] + mock_output.side_effect = [(0, version), (0, "Output")] + self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)), + call(osutil.FIREWALL_DROP.format(wait, "A", dst)) + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION), + call(osutil.FIREWALL_LIST.format(wait)) + ]) + self.assertTrue(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_no_wait(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION-1) + wait = "" + + mock_run.side_effect = [1, 0, 0] + mock_output.side_effect = [(0, version), (0, "Output")] + self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)), + call(osutil.FIREWALL_DROP.format(wait, "A", dst)) + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION), + call(osutil.FIREWALL_LIST.format(wait)) + ]) + self.assertTrue(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_skips_if_drop_exists(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [0, 0, 0] + mock_output.return_value = (0, version) + self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION) + ]) + self.assertTrue(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_ignores_exceptions(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [1, Exception] + mock_output.return_value = (0, version) + self.assertFalse(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)) + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION) + ]) + self.assertFalse(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_skips_if_disabled(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = False + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [1, 0, 0] + mock_output.side_effect = [(0, version), (0, "Output")] + self.assertFalse(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_not_called() + mock_output.assert_not_called() + mock_uid.assert_not_called() + self.assertFalse(osutil._enable_firewall) + if __name__ == '__main__': unittest.main() diff --git a/tests/common/test_conf.py b/tests/common/test_conf.py index 1287b0d..93759de 100644 --- a/tests/common/test_conf.py +++ b/tests/common/test_conf.py @@ -24,6 +24,49 @@ from tests.tools import * class TestConf(AgentTestCase): + # Note: + # -- These values *MUST* match those from data/test_waagent.conf + EXPECTED_CONFIGURATION = { + "Provisioning.Enabled" : True, + "Provisioning.UseCloudInit" : True, + "Provisioning.DeleteRootPassword" : True, + "Provisioning.RegenerateSshHostKeyPair" : True, + "Provisioning.SshHostKeyPairType" : "rsa", + "Provisioning.MonitorHostName" : True, + "Provisioning.DecodeCustomData" : False, + "Provisioning.ExecuteCustomData" : False, + "Provisioning.PasswordCryptId" : '6', + "Provisioning.PasswordCryptSaltLength" : 10, + "Provisioning.AllowResetSysUser" : False, + "ResourceDisk.Format" : True, + "ResourceDisk.Filesystem" : "ext4", + "ResourceDisk.MountPoint" : "/mnt/resource", + "ResourceDisk.EnableSwap" : False, + "ResourceDisk.SwapSizeMB" : 0, + "ResourceDisk.MountOptions" : None, + "Logs.Verbose" : False, + "OS.EnableFIPS" : True, + "OS.RootDeviceScsiTimeout" : '300', + "OS.OpensslPath" : '/usr/bin/openssl', + "OS.SshDir" : "/notareal/path", + "HttpProxy.Host" : None, + "HttpProxy.Port" : None, + "DetectScvmmEnv" : False, + "Lib.Dir" : "/var/lib/waagent", + "DVD.MountPoint" : "/mnt/cdrom/secure", + "Pid.File" : "/var/run/waagent.pid", + "Extension.LogDir" : "/var/log/azure", + "OS.HomeDir" : "/home", + "OS.EnableRDMA" : False, + "OS.UpdateRdmaDriver" : False, + "OS.CheckRdmaDriver" : False, + "AutoUpdate.Enabled" : True, + "AutoUpdate.GAFamily" : "Prod", + "EnableOverProvisioning" : False, + "OS.AllowHTTP" : False, + "OS.EnableFirewall" : True + } + def setUp(self): AgentTestCase.setUp(self) self.conf = ConfigurationProvider() @@ -59,3 +102,11 @@ class TestConf(AgentTestCase): def test_get_provision_cloudinit(self): self.assertTrue(get_provision_cloudinit(self.conf)) + + def test_get_configuration(self): + configuration = conf.get_configuration(self.conf) + self.assertTrue(len(configuration.keys()) > 0) + for k in TestConf.EXPECTED_CONFIGURATION.keys(): + self.assertEqual( + TestConf.EXPECTED_CONFIGURATION[k], + configuration[k]) diff --git a/tests/common/test_event.py b/tests/common/test_event.py index a485edf..55a99c4 100644 --- a/tests/common/test_event.py +++ b/tests/common/test_event.py @@ -22,7 +22,8 @@ from datetime import datetime import azurelinuxagent.common.event as event import azurelinuxagent.common.logger as logger -from azurelinuxagent.common.event import init_event_logger, add_event +from azurelinuxagent.common.event import add_event, \ + mark_event_status, should_emit_event from azurelinuxagent.common.future import ustr from azurelinuxagent.common.version import CURRENT_VERSION @@ -30,10 +31,84 @@ from tests.tools import * class TestEvent(AgentTestCase): + def test_event_status_event_marked(self): + es = event.__event_status__ + + self.assertFalse(es.event_marked("Foo", "1.2", "FauxOperation")) + es.mark_event_status("Foo", "1.2", "FauxOperation", True) + self.assertTrue(es.event_marked("Foo", "1.2", "FauxOperation")) + + event.__event_status__ = event.EventStatus() + event.init_event_status(self.tmp_dir) + es = event.__event_status__ + self.assertTrue(es.event_marked("Foo", "1.2", "FauxOperation")) + + def test_event_status_defaults_to_success(self): + es = event.__event_status__ + self.assertTrue(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + def test_event_status_records_status(self): + d = tempfile.mkdtemp() + es = event.EventStatus(tempfile.mkdtemp()) + + es.mark_event_status("Foo", "1.2", "FauxOperation", True) + self.assertTrue(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + es.mark_event_status("Foo", "1.2", "FauxOperation", False) + self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + def test_event_status_preserves_state(self): + es = event.__event_status__ + + es.mark_event_status("Foo", "1.2", "FauxOperation", False) + self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + event.__event_status__ = event.EventStatus() + event.init_event_status(self.tmp_dir) + es = event.__event_status__ + self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + def test_should_emit_event_ignores_unknown_operations(self): + event.__event_status__ = event.EventStatus(tempfile.mkdtemp()) + + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", False)) + + # Marking the event has no effect + event.mark_event_status("Foo", "1.2", "FauxOperation", True) + + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", False)) + + + def test_should_emit_event_handles_known_operations(self): + event.__event_status__ = event.EventStatus(tempfile.mkdtemp()) + + # Known operations always initially "fire" + for op in event.__event_status_operations__: + self.assertTrue(event.should_emit_event("Foo", "1.2", op, True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", op, False)) + + # Note a success event... + for op in event.__event_status_operations__: + event.mark_event_status("Foo", "1.2", op, True) + + # Subsequent success events should not fire, but failures will + for op in event.__event_status_operations__: + self.assertFalse(event.should_emit_event("Foo", "1.2", op, True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", op, False)) + + # Note a failure event... + for op in event.__event_status_operations__: + event.mark_event_status("Foo", "1.2", op, False) + + # Subsequent success events fire and failure do not + for op in event.__event_status_operations__: + self.assertTrue(event.should_emit_event("Foo", "1.2", op, True)) + self.assertFalse(event.should_emit_event("Foo", "1.2", op, False)) @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_emits_if_not_previously_sent(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -41,7 +116,6 @@ class TestEvent(AgentTestCase): @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_does_not_emit_if_previously_sent(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -52,7 +126,6 @@ class TestEvent(AgentTestCase): @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_emits_if_forced(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -63,7 +136,6 @@ class TestEvent(AgentTestCase): @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_emits_after_elapsed_delta(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -73,14 +145,13 @@ class TestEvent(AgentTestCase): self.assertEqual(1, mock_event.call_count) h = hash("FauxEvent"+""+ustr(True)+"") - event.__event_logger__.periodic_messages[h] = \ + event.__event_logger__.periodic_events[h] = \ datetime.now() - logger.EVERY_DAY - logger.EVERY_HOUR event.add_periodic(logger.EVERY_DAY, "FauxEvent") self.assertEqual(2, mock_event.call_count) @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_forwards_args(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -90,68 +161,58 @@ class TestEvent(AgentTestCase): log_event=True, message='', op='', version=str(CURRENT_VERSION)) def test_save_event(self): - tmp_evt = tempfile.mkdtemp() - init_event_logger(tmp_evt) add_event('test', message='test event') - self.assertTrue(len(os.listdir(tmp_evt)) == 1) - shutil.rmtree(tmp_evt) + self.assertTrue(len(os.listdir(self.tmp_dir)) == 1) def test_save_event_rollover(self): - tmp_evt = tempfile.mkdtemp() - init_event_logger(tmp_evt) add_event('test', message='first event') for i in range(0, 999): add_event('test', message='test event {0}'.format(i)) - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) events.sort() self.assertTrue(len(events) == 1000) - first_event = os.path.join(tmp_evt, events[0]) + first_event = os.path.join(self.tmp_dir, events[0]) with open(first_event) as first_fh: first_event_text = first_fh.read() self.assertTrue('first event' in first_event_text) add_event('test', message='last event') - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) events.sort() self.assertTrue(len(events) == 1000, "{0} events found, 1000 expected".format(len(events))) - first_event = os.path.join(tmp_evt, events[0]) + first_event = os.path.join(self.tmp_dir, events[0]) with open(first_event) as first_fh: first_event_text = first_fh.read() self.assertFalse('first event' in first_event_text) self.assertTrue('test event 0' in first_event_text) - last_event = os.path.join(tmp_evt, events[-1]) + last_event = os.path.join(self.tmp_dir, events[-1]) with open(last_event) as last_fh: last_event_text = last_fh.read() self.assertTrue('last event' in last_event_text) - shutil.rmtree(tmp_evt) - def test_save_event_cleanup(self): - tmp_evt = tempfile.mkdtemp() - init_event_logger(tmp_evt) - for i in range(0, 2000): - evt = os.path.join(tmp_evt, '{0}.tld'.format(ustr(1491004920536531 + i))) + evt = os.path.join(self.tmp_dir, '{0}.tld'.format(ustr(1491004920536531 + i))) with open(evt, 'w') as fh: fh.write('test event {0}'.format(i)) - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) self.assertTrue(len(events) == 2000, "{0} events found, 2000 expected".format(len(events))) add_event('test', message='last event') - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) events.sort() self.assertTrue(len(events) == 1000, "{0} events found, 1000 expected".format(len(events))) - first_event = os.path.join(tmp_evt, events[0]) + first_event = os.path.join(self.tmp_dir, events[0]) with open(first_event) as first_fh: first_event_text = first_fh.read() self.assertTrue('test event 1001' in first_event_text) - last_event = os.path.join(tmp_evt, events[-1]) + last_event = os.path.join(self.tmp_dir, events[-1]) with open(last_event) as last_fh: last_event_text = last_fh.read() self.assertTrue('last event' in last_event_text) diff --git a/tests/data/ga/WALinuxAgent-2.2.11.zip b/tests/data/ga/WALinuxAgent-2.2.11.zip Binary files differdeleted file mode 100644 index f018116..0000000 --- a/tests/data/ga/WALinuxAgent-2.2.11.zip +++ /dev/null diff --git a/tests/data/ga/WALinuxAgent-2.2.14.zip b/tests/data/ga/WALinuxAgent-2.2.14.zip Binary files differnew file mode 100644 index 0000000..a978207 --- /dev/null +++ b/tests/data/ga/WALinuxAgent-2.2.14.zip diff --git a/tests/data/test_waagent.conf b/tests/data/test_waagent.conf index 6368c39..edc3676 100644 --- a/tests/data/test_waagent.conf +++ b/tests/data/test_waagent.conf @@ -94,10 +94,13 @@ OS.SshDir=/notareal/path # Extension.LogDir=/var/log/azure # -# Home.Dir=/home +# OS.HomeDir=/home # Enable RDMA management and set up, should only be used in HPC images -# OS.EnableRDMA=y +# OS.EnableRDMA=n +# OS.UpdateRdmaDriver=n +# OS.CheckRdmaDriver=n + # Enable or disable goal state processing auto-update, default is enabled # AutoUpdate.Enabled=y @@ -109,3 +112,12 @@ OS.SshDir=/notareal/path # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py index 0c8642c..59251cb 100644 --- a/tests/ga/test_update.py +++ b/tests/ga/test_update.py @@ -22,6 +22,7 @@ from datetime import datetime import json import shutil +from azurelinuxagent.common.event import * from azurelinuxagent.common.protocol.hostplugin import * from azurelinuxagent.common.protocol.metadata import * from azurelinuxagent.common.protocol.wire import * @@ -148,7 +149,9 @@ class UpdateTestCase(AgentTestCase): def create_error(self, error_data=NO_ERROR): with self.get_error_file(error_data) as path: - return GuestAgentError(path.name) + err = GuestAgentError(path.name) + err.load() + return err def copy_agents(self, *agents): if len(agents) <= 0: @@ -157,11 +160,11 @@ class UpdateTestCase(AgentTestCase): fileutil.copy_file(agent, to_dir=self.tmp_dir) return - def expand_agents(self, mark_test=False): + def expand_agents(self, mark_optional=False): for agent in self.agent_pkgs(): path = os.path.join(self.tmp_dir, fileutil.trim_ext(agent, "zip")) zipfile.ZipFile(agent).extractall(path) - if mark_test: + if mark_optional: src = os.path.join(data_dir, 'ga', 'supported.json') dst = os.path.join(path, 'supported.json') shutil.copy(src, dst) @@ -170,12 +173,12 @@ class UpdateTestCase(AgentTestCase): fileutil.write_file(dst, json.dumps(SENTINEL_ERROR)) return - def prepare_agent(self, version, mark_test=False): + def prepare_agent(self, version, mark_optional=False): """ Create a download for the current agent version, copied from test data """ self.copy_agents(get_agent_pkgs()[0]) - self.expand_agents(mark_test=mark_test) + self.expand_agents(mark_optional=mark_optional) versions = self.agent_versions() src_v = FlexibleVersion(str(versions[0])) @@ -246,7 +249,6 @@ class TestSupportedDistribution(UpdateTestCase): self.sd = SupportedDistribution({ 'slice':10, 'versions': ['^Ubuntu,16.10,yakkety$']}) - def test_creation(self): self.assertRaises(TypeError, SupportedDistribution) @@ -276,6 +278,7 @@ class TestSupported(UpdateTestCase): def setUp(self): UpdateTestCase.setUp(self) self.sp = Supported(os.path.join(data_dir, 'ga', 'supported.json')) + self.sp.load() def test_creation(self): self.assertRaises(TypeError, Supported) @@ -305,6 +308,7 @@ class TestGuestAgentError(UpdateTestCase): with self.get_error_file(error_data=WITH_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertEqual(path.name, err.path) self.assertNotEqual(None, err) @@ -316,6 +320,7 @@ class TestGuestAgentError(UpdateTestCase): def test_clear(self): with self.get_error_file(error_data=WITH_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertEqual(path.name, err.path) self.assertNotEqual(None, err) @@ -328,27 +333,16 @@ class TestGuestAgentError(UpdateTestCase): def test_is_sentinel(self): with self.get_error_file(error_data=SENTINEL_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertTrue(err.is_blacklisted) self.assertTrue(err.is_sentinel) with self.get_error_file(error_data=FATAL_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertTrue(err.is_blacklisted) self.assertFalse(err.is_sentinel) - def test_load_preserves_error_state(self): - with self.get_error_file(error_data=WITH_ERROR) as path: - err = GuestAgentError(path.name) - self.assertEqual(path.name, err.path) - self.assertNotEqual(None, err) - - with self.get_error_file(error_data=NO_ERROR): - err.load() - self.assertEqual(WITH_ERROR["last_failure"], err.last_failure) - self.assertEqual(WITH_ERROR["failure_count"], err.failure_count) - self.assertEqual(WITH_ERROR["was_fatal"], err.was_fatal) - return - def test_save(self): err1 = self.create_error() err1.mark_failure() @@ -406,22 +400,20 @@ class TestGuestAgent(UpdateTestCase): self.agent_path = os.path.join(self.tmp_dir, get_agent_name()) return - def tearDown(self): - self.remove_agents() - return - def test_creation(self): self.assertRaises(UpdateError, GuestAgent, "A very bad file name") n = "{0}-a.bad.version".format(AGENT_NAME) self.assertRaises(UpdateError, GuestAgent, n) + self.expand_agents() + agent = GuestAgent(path=self.agent_path) self.assertNotEqual(None, agent) self.assertEqual(get_agent_name(), agent.name) self.assertEqual(get_agent_version(), agent.version) - self.assertFalse(agent.is_test) - self.assertFalse(agent.in_slice) + self.assertFalse(agent._is_optional) + self.assertFalse(agent._in_slice) self.assertEqual(self.agent_path, agent.get_agent_dir()) @@ -436,13 +428,14 @@ class TestGuestAgent(UpdateTestCase): self.assertEqual(path, agent.get_agent_pkg_path()) self.assertTrue(agent.is_downloaded) - # Note: Agent will get blacklisted since the package for this test is invalid - self.assertTrue(agent.is_blacklisted) - self.assertFalse(agent.is_available) + self.assertFalse(agent.is_blacklisted) + self.assertTrue(agent.is_available) return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_clear_error(self, mock_ensure): + def test_clear_error(self, mock_downloaded): + self.expand_agents() + agent = GuestAgent(path=self.agent_path) agent.mark_failure(is_fatal=True) @@ -459,7 +452,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_is_available(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_is_available(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_available) @@ -471,7 +465,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_is_blacklisted(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_is_blacklisted(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_blacklisted) @@ -485,7 +480,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_is_downloaded(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_is_downloaded(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_downloaded) agent._unpack() @@ -493,50 +489,47 @@ class TestGuestAgent(UpdateTestCase): return @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) - def test_is_test(self, mock_dist): - self.expand_agents(mark_test=True) + @patch('azurelinuxagent.ga.update.GuestAgent._enable') + def test_is_optional(self, mock_enable, mock_dist): + self.expand_agents(mark_optional=True) agent = GuestAgent(path=self.agent_path) self.assertTrue(agent.is_blacklisted) - self.assertTrue(agent.is_test) + self.assertTrue(agent._is_optional) @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) @patch('azurelinuxagent.ga.update.datetime') def test_in_slice(self, mock_dt, mock_dist): - self.expand_agents(mark_test=True) + self.expand_agents(mark_optional=True) agent = GuestAgent(path=self.agent_path) mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.assertTrue(agent.in_slice) + self.assertTrue(agent._in_slice) mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 42)) - self.assertFalse(agent.in_slice) + self.assertFalse(agent._in_slice) @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) @patch('azurelinuxagent.ga.update.datetime') def test_enable(self, mock_dt, mock_dist): mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.expand_agents(mark_test=True) + self.expand_agents(mark_optional=True) agent = GuestAgent(path=self.agent_path) - self.assertTrue(agent.is_blacklisted) - self.assertTrue(agent.is_test) - self.assertTrue(agent.in_slice) - - agent.enable() - self.assertFalse(agent.is_blacklisted) - self.assertFalse(agent.is_test) + self.assertFalse(agent._is_optional) # Ensure the new state is preserved to disk agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_blacklisted) - self.assertFalse(agent.is_test) + self.assertFalse(agent._is_optional) @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_mark_failure(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_mark_failure(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) + agent.mark_failure() self.assertEqual(1, agent.error.failure_count) @@ -546,7 +539,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_unpack(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_unpack(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -555,7 +549,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_unpack_fail(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_unpack_fail(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) os.remove(agent.get_agent_pkg_path()) @@ -563,7 +558,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) agent._unpack() agent._load_manifest() @@ -572,7 +568,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest_missing(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest_missing(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -581,7 +578,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest_is_empty(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest_is_empty(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -593,7 +591,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest_is_malformed(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest_is_malformed(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -613,8 +612,9 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download(self, mock_http_get, mock_ensure): + def test_download(self, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) @@ -630,8 +630,9 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download_fail(self, mock_http_get, mock_ensure): + def test_download_fail(self, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) @@ -647,8 +648,9 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download_fallback(self, mock_http_get, mock_ensure): + def test_download_fallback(self, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) @@ -681,8 +683,12 @@ class TestGuestAgent(UpdateTestCase): return_value=True): self.assertRaises(UpdateError, agent._download) self.assertEqual(mock_http_get.call_count, 4) + self.assertEqual(mock_http_get.call_args_list[2][0][0], ext_uri) + self.assertEqual(mock_http_get.call_args_list[3][0][0], art_uri) + a, k = mock_http_get.call_args_list[3] + self.assertEqual(False, k['use_proxy']) # ensure fallback works as expected with patch.object(HostPluginProtocol, @@ -690,8 +696,16 @@ class TestGuestAgent(UpdateTestCase): return_value=[art_uri, {}]): self.assertRaises(UpdateError, agent._download) self.assertEqual(mock_http_get.call_count, 6) + + a, k = mock_http_get.call_args_list[3] + self.assertEqual(False, k['use_proxy']) + self.assertEqual(mock_http_get.call_args_list[4][0][0], ext_uri) + a, k = mock_http_get.call_args_list[4] + self.assertEqual(mock_http_get.call_args_list[5][0][0], art_uri) + a, k = mock_http_get.call_args_list[5] + self.assertEqual(False, k['use_proxy']) @patch("azurelinuxagent.ga.update.restutil.http_get") def test_ensure_downloaded(self, mock_http_get): @@ -725,7 +739,7 @@ class TestGuestAgent(UpdateTestCase): @patch("azurelinuxagent.ga.update.GuestAgent._download") @patch("azurelinuxagent.ga.update.GuestAgent._unpack", side_effect=UpdateError) - def test_ensure_downloaded_unpack_fails(self, mock_download, mock_unpack): + def test_ensure_downloaded_unpack_fails(self, mock_unpack, mock_download): self.assertFalse(os.path.isdir(self.agent_path)) pkg = ExtHandlerPackage(version=str(get_agent_version())) @@ -740,7 +754,7 @@ class TestGuestAgent(UpdateTestCase): @patch("azurelinuxagent.ga.update.GuestAgent._download") @patch("azurelinuxagent.ga.update.GuestAgent._unpack") @patch("azurelinuxagent.ga.update.GuestAgent._load_manifest", side_effect=UpdateError) - def test_ensure_downloaded_load_manifest_fails(self, mock_download, mock_unpack, mock_manifest): + def test_ensure_downloaded_load_manifest_fails(self, mock_manifest, mock_unpack, mock_download): self.assertFalse(os.path.isdir(self.agent_path)) pkg = ExtHandlerPackage(version=str(get_agent_version())) @@ -755,10 +769,13 @@ class TestGuestAgent(UpdateTestCase): @patch("azurelinuxagent.ga.update.GuestAgent._download") @patch("azurelinuxagent.ga.update.GuestAgent._unpack") @patch("azurelinuxagent.ga.update.GuestAgent._load_manifest") - def test_ensure_download_skips_blacklisted(self, mock_download, mock_unpack, mock_manifest): + def test_ensure_download_skips_blacklisted(self, mock_manifest, mock_unpack, mock_download): agent = GuestAgent(path=self.agent_path) + self.assertEqual(0, mock_download.call_count) + agent.clear_error() agent.mark_failure(is_fatal=True) + self.assertTrue(agent.is_blacklisted) pkg = ExtHandlerPackage(version=str(get_agent_version())) pkg.uris.append(ExtHandlerPackageUri()) @@ -769,7 +786,6 @@ class TestGuestAgent(UpdateTestCase): self.assertTrue(agent.is_blacklisted) self.assertEqual(0, mock_download.call_count) self.assertEqual(0, mock_unpack.call_count) - self.assertEqual(0, mock_manifest.call_count) return @@ -812,6 +828,12 @@ class TestUpdate(UpdateTestCase): self.event_patch.stop() return + def _create_protocol(self, count=5, versions=None): + latest_version = self.prepare_agents(count=count) + if versions is None or len(versions) <= 0: + versions = [latest_version] + return ProtocolMock(versions=versions) + def _test_upgrade_available( self, base_version=FlexibleVersion(AGENT_VERSION), @@ -819,12 +841,9 @@ class TestUpdate(UpdateTestCase): versions=None, count=5): - latest_version = self.prepare_agents(count=count) - if versions is None or len(versions) <= 0: - versions = [latest_version] - if protocol is None: - protocol = ProtocolMock(versions=versions) + protocol = self._create_protocol(count=count, versions=versions) + self.update_handler.protocol_util = protocol conf.get_autoupdate_gafamily = Mock(return_value=protocol.family) @@ -834,6 +853,16 @@ class TestUpdate(UpdateTestCase): self.assertTrue(self._test_upgrade_available()) return + def test_upgrade_available_will_refresh_goal_state(self): + protocol = self._create_protocol() + protocol.emulate_stale_goal_state() + self.assertTrue(self._test_upgrade_available(protocol=protocol)) + self.assertEqual(2, protocol.call_counts["get_vmagent_manifests"]) + self.assertEqual(1, protocol.call_counts["get_vmagent_pkgs"]) + self.assertEqual(1, protocol.call_counts["update_goal_state"]) + self.assertTrue(protocol.goal_state_forced) + return + def test_get_latest_agent_excluded(self): self.prepare_agent(AGENT_VERSION) self.assertFalse(self._test_upgrade_available( @@ -909,7 +938,7 @@ class TestUpdate(UpdateTestCase): v = a.version return - def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL): + def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL, pid_count=0): with patch.object(self.update_handler, 'osutil') as mock_util: # Note: # - Python only allows mutations of objects to which a function has @@ -924,15 +953,20 @@ class TestUpdate(UpdateTestCase): mock_util.check_pid_alive = Mock(side_effect=iterator) + pid_files = self.update_handler._get_pid_files() + self.assertEqual(pid_count, len(pid_files)) + with patch('os.getpid', return_value=42): with patch('time.sleep', return_value=None) as mock_sleep: self.update_handler._ensure_no_orphans(orphan_wait_interval=interval) + for pid_file in pid_files: + self.assertFalse(os.path.exists(pid_file)) return mock_util.check_pid_alive.call_count, mock_sleep.call_count return def test_ensure_no_orphans(self): fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41)) - calls, sleeps = self._test_ensure_no_orphans(invocations=3) + calls, sleeps = self._test_ensure_no_orphans(invocations=3, pid_count=1) self.assertEqual(3, calls) self.assertEqual(2, sleeps) return @@ -955,7 +989,8 @@ class TestUpdate(UpdateTestCase): with patch('os.kill') as mock_kill: calls, sleeps = self._test_ensure_no_orphans( invocations=4, - interval=3*GOAL_STATE_INTERVAL) + interval=3*GOAL_STATE_INTERVAL, + pid_count=1) self.assertEqual(3, calls) self.assertEqual(2, sleeps) self.assertEqual(1, mock_kill.call_count) @@ -1066,26 +1101,20 @@ class TestUpdate(UpdateTestCase): return def test_get_pid_files(self): - previous_pid_file, pid_file, = self.update_handler._get_pid_files() - self.assertEqual(None, previous_pid_file) - self.assertEqual("0_waagent.pid", os.path.basename(pid_file)) + pid_files = self.update_handler._get_pid_files() + self.assertEqual(0, len(pid_files)) return def test_get_pid_files_returns_previous(self): for n in range(1250): fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1)) - previous_pid_file, pid_file, = self.update_handler._get_pid_files() - self.assertEqual("1249_waagent.pid", os.path.basename(previous_pid_file)) - self.assertEqual("1250_waagent.pid", os.path.basename(pid_file)) - return - - @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) - @patch('azurelinuxagent.ga.update.datetime') - def test_get_test_agent(self, mock_dt, mock_dist): - mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.prepare_agent(AGENT_VERSION, mark_test=True) + pid_files = self.update_handler._get_pid_files() + self.assertEqual(1250, len(pid_files)) - self.assertNotEqual(None, self.update_handler.get_test_agent()) + pid_dir, pid_name, pid_re = self.update_handler._get_pid_parts() + for p in pid_files: + self.assertTrue(pid_re.match(os.path.basename(p))) + return def test_is_clean_start_returns_true_when_no_sentinal(self): self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) @@ -1421,23 +1450,6 @@ class TestUpdate(UpdateTestCase): self.update_handler._upgrade_available = Mock(return_value=True) self._test_run(invocations=0, calls=[], enable_updates=True) return - - @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) - @patch('azurelinuxagent.ga.update.datetime') - def test_run_stops_if_test_agent_available(self, mock_dt, mock_dist): - mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.prepare_agent(AGENT_VERSION, mark_test=True) - - agent = GuestAgent(path=self.agent_dir(AGENT_VERSION)) - agent.enable = Mock() - self.assertTrue(agent.is_test) - self.assertTrue(agent.in_slice) - - with patch('azurelinuxagent.ga.update.UpdateHandler.get_test_agent', - return_value=agent) as mock_test: - self._test_run(invocations=0) - self.assertEqual(mock_test.call_count, 1) - self.assertEqual(agent.enable.call_count, 1) def test_run_stops_if_orphaned(self): with patch('os.getppid', return_value=1): @@ -1521,8 +1533,9 @@ class TestUpdate(UpdateTestCase): for n in range(1112): fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1)) with patch('os.getpid', return_value=1112): - previous_pid_file, pid_file = self.update_handler._write_pid_file() - self.assertEqual("1111_waagent.pid", os.path.basename(previous_pid_file)) + pid_files, pid_file = self.update_handler._write_pid_file() + self.assertEqual(1112, len(pid_files)) + self.assertEqual("1111_waagent.pid", os.path.basename(pid_files[-1])) self.assertEqual("1112_waagent.pid", os.path.basename(pid_file)) self.assertEqual(fileutil.read_file(pid_file), ustr(1112)) return @@ -1530,8 +1543,8 @@ class TestUpdate(UpdateTestCase): def test_write_pid_file_ignores_exceptions(self): with patch('azurelinuxagent.common.utils.fileutil.write_file', side_effect=Exception): with patch('os.getpid', return_value=42): - previous_pid_file, pid_file = self.update_handler._write_pid_file() - self.assertEqual(None, previous_pid_file) + pid_files, pid_file = self.update_handler._write_pid_file() + self.assertEqual(0, len(pid_files)) self.assertEqual(None, pid_file) return @@ -1549,12 +1562,22 @@ class ProtocolMock(object): def __init__(self, family="TestAgent", etag=42, versions=None, client=None): self.family = family self.client = client + self.call_counts = { + "get_vmagent_manifests" : 0, + "get_vmagent_pkgs" : 0, + "update_goal_state" : 0 + } + self.goal_state_is_stale = False + self.goal_state_forced = False self.etag = etag self.versions = versions if versions is not None else [] self.create_manifests() self.create_packages() return + def emulate_stale_goal_state(self): + self.goal_state_is_stale = True + def create_manifests(self): self.agent_manifests = VMAgentManifestList() if len(self.versions) <= 0: @@ -1585,11 +1608,23 @@ class ProtocolMock(object): return self def get_vmagent_manifests(self): + self.call_counts["get_vmagent_manifests"] += 1 + if self.goal_state_is_stale: + self.goal_state_is_stale = False + raise ResourceGoneError() return self.agent_manifests, self.etag def get_vmagent_pkgs(self, manifest): + self.call_counts["get_vmagent_pkgs"] += 1 + if self.goal_state_is_stale: + self.goal_state_is_stale = False + raise ResourceGoneError() return self.agent_packages + def update_goal_state(self, forced=False, max_retry=3): + self.call_counts["update_goal_state"] += 1 + self.goal_state_forced = self.goal_state_forced or forced + return class ResponseMock(Mock): def __init__(self, status=restutil.httpclient.OK, response=None, reason=None): diff --git a/tests/pa/test_provision.py b/tests/pa/test_provision.py index 0446442..7045fcc 100644 --- a/tests/pa/test_provision.py +++ b/tests/pa/test_provision.py @@ -53,6 +53,24 @@ class TestProvision(AgentTestCase): data = DefaultOSUtil().decode_customdata(base64data) fileutil.write_file(tempfile.mktemp(), data) + @patch('azurelinuxagent.common.conf.get_provision_enabled', + return_value=False) + def test_provisioning_is_skipped_when_not_enabled(self, mock_conf): + ph = ProvisionHandler() + ph.osutil = DefaultOSUtil() + ph.osutil.get_instance_id = Mock( + return_value='B9F3C233-9913-9F42-8EB3-BA656DF32502') + + ph.is_provisioned = Mock() + ph.report_ready = Mock() + ph.write_provisioned = Mock() + + ph.run() + + ph.is_provisioned.assert_not_called() + ph.report_ready.assert_called_once() + ph.write_provisioned.assert_called_once() + @patch('os.path.isfile', return_value=False) def test_is_provisioned_not_provisioned(self, mock_isfile): ph = ProvisionHandler() @@ -64,33 +82,37 @@ class TestProvision(AgentTestCase): @patch('azurelinuxagent.pa.deprovision.get_deprovision_handler') def test_is_provisioned_is_provisioned(self, mock_deprovision, mock_read, mock_isfile): + ph = ProvisionHandler() ph.osutil = Mock() - ph.osutil.get_instance_id = \ - Mock(return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502") + ph.osutil.is_current_instance_id = Mock(return_value=True) ph.write_provisioned = Mock() deprovision_handler = Mock() mock_deprovision.return_value = deprovision_handler self.assertTrue(ph.is_provisioned()) + ph.osutil.is_current_instance_id.assert_called_once() deprovision_handler.run_changed_unique_id.assert_not_called() @patch('os.path.isfile', return_value=True) @patch('azurelinuxagent.common.utils.fileutil.read_file', - side_effect=["Value"]) + return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502") @patch('azurelinuxagent.pa.deprovision.get_deprovision_handler') def test_is_provisioned_not_deprovisioned(self, mock_deprovision, mock_read, mock_isfile): ph = ProvisionHandler() ph.osutil = Mock() + ph.osutil.is_current_instance_id = Mock(return_value=False) + ph.report_ready = Mock() ph.write_provisioned = Mock() deprovision_handler = Mock() mock_deprovision.return_value = deprovision_handler self.assertTrue(ph.is_provisioned()) + ph.osutil.is_current_instance_id.assert_called_once() deprovision_handler.run_changed_unique_id.assert_called_once() if __name__ == '__main__': diff --git a/tests/protocol/mockwiredata.py b/tests/protocol/mockwiredata.py index 4e45623..5924719 100644 --- a/tests/protocol/mockwiredata.py +++ b/tests/protocol/mockwiredata.py @@ -16,6 +16,7 @@ # from tests.tools import * +from azurelinuxagent.common.exception import HttpError, ResourceGoneError from azurelinuxagent.common.future import httpclient from azurelinuxagent.common.utils.cryptutil import CryptUtil @@ -53,6 +54,20 @@ DATA_FILE_EXT_AUTOUPGRADE_INTERNALVERSION["ext_conf"] = "wire/ext_conf_autoupgra class WireProtocolData(object): def __init__(self, data_files=DATA_FILE): + self.emulate_stale_goal_state = False + self.call_counts = { + "comp=versions" : 0, + "/versions" : 0, + "goalstate" : 0, + "hostingenvuri" : 0, + "sharedconfiguri" : 0, + "certificatesuri" : 0, + "extensionsconfiguri" : 0, + "extensionArtifact" : 0, + "manifest.xml" : 0, + "manifest_of_ga.xml" : 0, + "ExampleHandlerLinux" : 0 + } self.version_info = load_data(data_files.get("version_info")) self.goal_state = load_data(data_files.get("goal_state")) self.hosting_env = load_data(data_files.get("hosting_env")) @@ -67,32 +82,70 @@ class WireProtocolData(object): def mock_http_get(self, url, *args, **kwargs): content = None - if "versions" in url: + + resp = MagicMock() + resp.status = httpclient.OK + + # wire server versions + if "comp=versions" in url: content = self.version_info + self.call_counts["comp=versions"] += 1 + + # HostPlugin versions + elif "/versions" in url: + content = '["2015-09-01"]' + self.call_counts["/versions"] += 1 elif "goalstate" in url: content = self.goal_state + self.call_counts["goalstate"] += 1 elif "hostingenvuri" in url: content = self.hosting_env + self.call_counts["hostingenvuri"] += 1 elif "sharedconfiguri" in url: content = self.shared_config + self.call_counts["sharedconfiguri"] += 1 elif "certificatesuri" in url: content = self.certs + self.call_counts["certificatesuri"] += 1 elif "extensionsconfiguri" in url: content = self.ext_conf - elif "manifest.xml" in url: - content = self.manifest - elif "manifest_of_ga.xml" in url: - content = self.ga_manifest - elif "ExampleHandlerLinux" in url: - content = self.ext - resp = MagicMock() - resp.status = httpclient.OK - resp.read = Mock(return_value=content) - return resp + self.call_counts["extensionsconfiguri"] += 1 + else: - raise Exception("Bad url {0}".format(url)) - resp = MagicMock() - resp.status = httpclient.OK + # A stale GoalState results in a 400 from the HostPlugin + # for which the HTTP handler in restutil raises ResourceGoneError + if self.emulate_stale_goal_state: + if "extensionArtifact" in url: + self.emulate_stale_goal_state = False + self.call_counts["extensionArtifact"] += 1 + raise ResourceGoneError() + else: + raise HttpError() + + # For HostPlugin requests, replace the URL with that passed + # via the x-ms-artifact-location header + if "extensionArtifact" in url: + self.call_counts["extensionArtifact"] += 1 + if "headers" not in kwargs or \ + "x-ms-artifact-location" not in kwargs["headers"]: + raise Exception("Bad HEADERS passed to HostPlugin: {0}", + kwargs) + url = kwargs["headers"]["x-ms-artifact-location"] + + if "manifest.xml" in url: + content = self.manifest + self.call_counts["manifest.xml"] += 1 + elif "manifest_of_ga.xml" in url: + content = self.ga_manifest + self.call_counts["manifest_of_ga.xml"] += 1 + elif "ExampleHandlerLinux" in url: + content = self.ext + self.call_counts["ExampleHandlerLinux"] += 1 + resp.read = Mock(return_value=content) + return resp + else: + raise Exception("Bad url {0}".format(url)) + resp.read = Mock(return_value=content.encode("utf-8")) return resp diff --git a/tests/protocol/test_hostplugin.py b/tests/protocol/test_hostplugin.py index b18b691..74f7f24 100644 --- a/tests/protocol/test_hostplugin.py +++ b/tests/protocol/test_hostplugin.py @@ -146,6 +146,7 @@ class TestHostPlugin(AgentTestCase): test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state) status = restapi.VMStatus(status="Ready", message="Guest Agent is running") + wire.HostPluginProtocol.set_default_channel(False) with patch.object(wire.HostPluginProtocol, "ensure_initialized", return_value=True): @@ -173,6 +174,7 @@ class TestHostPlugin(AgentTestCase): test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state) status = restapi.VMStatus(status="Ready", message="Guest Agent is running") + wire.HostPluginProtocol.set_default_channel(False) with patch.object(wire.StatusBlob, "upload", return_value=False): @@ -211,6 +213,8 @@ class TestHostPlugin(AgentTestCase): bytearray(faux_status, encoding='utf-8')) with patch.object(restutil, "http_request") as patch_http: + patch_http.return_value = Mock(status=httpclient.OK) + wire_protocol_client.get_goal_state = Mock(return_value=test_goal_state) plugin = wire_protocol_client.get_host_plugin() @@ -224,61 +228,6 @@ class TestHostPlugin(AgentTestCase): test_goal_state, exp_method, exp_url, exp_data) - def test_read_response_error(self): - """ - Validate the read_response_error method handles encoding correctly - """ - responses = ['message', b'message', '\x80message\x80'] - response = MagicMock() - response.status = 'status' - response.reason = 'reason' - with patch.object(response, 'read') as patch_response: - for s in responses: - patch_response.return_value = s - result = hostplugin.HostPluginProtocol.read_response_error(response) - self.assertTrue('[status: reason]' in result) - self.assertTrue('message' in result) - - def test_read_response_bytes(self): - response_bytes = '7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:' \ - '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:' \ - '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:' \ - '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:' \ - '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:' \ - 'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:' \ - '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:' \ - '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:' \ - '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:' \ - '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:' \ - '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:' \ - '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:' \ - '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:' \ - '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:' \ - '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:' \ - '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:' \ - '6c:73:22:3a:20:22:22:0a:7d'.split(':') - expected_response = '[status: reason] {\n "errorCode": "The blob ' \ - 'type is invalid for this operation.",\n ' \ - '"message": "<?xml version="1.0" ' \ - 'encoding="utf-8"?>' \ - '<Error><Code>InvalidBlobType</Code><Message>The ' \ - 'blob type is invalid for this operation.\n' \ - 'RequestId:c74290cb-0001-00b5-06da-dd666a000",' \ - '\n "details": ""\n}' - - response_string = ''.join(chr(int(b, 16)) for b in response_bytes) - response = MagicMock() - response.status = 'status' - response.reason = 'reason' - with patch.object(response, 'read') as patch_response: - patch_response.return_value = response_string - result = hostplugin.HostPluginProtocol.read_response_error(response) - self.assertEqual(result, expected_response) - try: - raise HttpError("{0}".format(result)) - except HttpError as e: - self.assertTrue(result in ustr(e)) - def test_no_fallback(self): """ Validate fallback to upload status using HostGAPlugin is not happening @@ -318,6 +267,8 @@ class TestHostPlugin(AgentTestCase): bytearray(faux_status, encoding='utf-8')) with patch.object(restutil, "http_request") as patch_http: + patch_http.return_value = Mock(status=httpclient.OK) + with patch.object(wire.HostPluginProtocol, "get_api_versions") as patch_get: patch_get.return_value = api_versions diff --git a/tests/protocol/test_metadata.py b/tests/protocol/test_metadata.py index ee4ba3e..5047b86 100644 --- a/tests/protocol/test_metadata.py +++ b/tests/protocol/test_metadata.py @@ -31,17 +31,15 @@ class TestMetadataProtocolGetters(AgentTestCase): return json.loads(ustr(load_data(path)), encoding="utf-8") @patch("time.sleep") - @patch("azurelinuxagent.common.protocol.metadata.restutil") - def _test_getters(self, test_data, mock_restutil ,_): - mock_restutil.http_get.side_effect = test_data.mock_http_get - - protocol = MetadataProtocol() - protocol.detect() - protocol.get_vminfo() - protocol.get_certs() - ext_handlers, etag = protocol.get_ext_handlers() - for ext_handler in ext_handlers.extHandlers: - protocol.get_ext_handler_pkgs(ext_handler) + def _test_getters(self, test_data ,_): + with patch.object(restutil, 'http_get', test_data.mock_http_get): + protocol = MetadataProtocol() + protocol.detect() + protocol.get_vminfo() + protocol.get_certs() + ext_handlers, etag = protocol.get_ext_handlers() + for ext_handler in ext_handlers.extHandlers: + protocol.get_ext_handler_pkgs(ext_handler) def test_getters(self, *args): test_data = MetadataProtocolData(DATA_FILE) diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py index 02976ca..d19bab1 100644 --- a/tests/protocol/test_wire.py +++ b/tests/protocol/test_wire.py @@ -25,30 +25,34 @@ wireserver_url = '168.63.129.16' @patch("time.sleep") @patch("azurelinuxagent.common.protocol.wire.CryptUtil") -@patch("azurelinuxagent.common.protocol.wire.restutil") class TestWireProtocolGetters(AgentTestCase): - def _test_getters(self, test_data, mock_restutil, MockCryptUtil, _): - mock_restutil.http_get.side_effect = test_data.mock_http_get + + def setUp(self): + super(TestWireProtocolGetters, self).setUp() + HostPluginProtocol.set_default_channel(False) + + def _test_getters(self, test_data, MockCryptUtil, _): MockCryptUtil.side_effect = test_data.mock_crypt_util - protocol = WireProtocol(wireserver_url) - protocol.detect() - protocol.get_vminfo() - protocol.get_certs() - ext_handlers, etag = protocol.get_ext_handlers() - for ext_handler in ext_handlers.extHandlers: - protocol.get_ext_handler_pkgs(ext_handler) - - crt1 = os.path.join(self.tmp_dir, - '33B0ABCE4673538650971C10F7D7397E71561F35.crt') - crt2 = os.path.join(self.tmp_dir, - '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt') - prv2 = os.path.join(self.tmp_dir, - '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv') - - self.assertTrue(os.path.isfile(crt1)) - self.assertTrue(os.path.isfile(crt2)) - self.assertTrue(os.path.isfile(prv2)) + with patch.object(restutil, 'http_get', test_data.mock_http_get): + protocol = WireProtocol(wireserver_url) + protocol.detect() + protocol.get_vminfo() + protocol.get_certs() + ext_handlers, etag = protocol.get_ext_handlers() + for ext_handler in ext_handlers.extHandlers: + protocol.get_ext_handler_pkgs(ext_handler) + + crt1 = os.path.join(self.tmp_dir, + '33B0ABCE4673538650971C10F7D7397E71561F35.crt') + crt2 = os.path.join(self.tmp_dir, + '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt') + prv2 = os.path.join(self.tmp_dir, + '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv') + + self.assertTrue(os.path.isfile(crt1)) + self.assertTrue(os.path.isfile(crt2)) + self.assertTrue(os.path.isfile(prv2)) def test_getters(self, *args): """Normal case""" @@ -70,8 +74,21 @@ class TestWireProtocolGetters(AgentTestCase): test_data = WireProtocolData(DATA_FILE_EXT_NO_PUBLIC) self._test_getters(test_data, *args) + def test_getters_with_stale_goal_state(self, *args): + test_data = WireProtocolData(DATA_FILE) + test_data.emulate_stale_goal_state = True + + self._test_getters(test_data, *args) + # Ensure HostPlugin was invoked + self.assertEqual(1, test_data.call_counts["/versions"]) + self.assertEqual(2, test_data.call_counts["extensionArtifact"]) + # Ensure the expected number of HTTP calls were made + # -- Tracking calls to retrieve GoalState is problematic since it is + # fetched often; however, the dependent documents, such as the + # HostingEnvironmentConfig, will be retrieved the expected number + self.assertEqual(2, test_data.call_counts["hostingenvuri"]) + def test_call_storage_kwargs(self, - mock_restutil, mock_cryptutil, mock_sleep): from azurelinuxagent.common.utils import restutil @@ -83,32 +100,32 @@ class TestWireProtocolGetters(AgentTestCase): # no kwargs -- Default to True WireClient.call_storage_service(http_req) - # kwargs, no chk_proxy -- Default to True + # kwargs, no use_proxy -- Default to True WireClient.call_storage_service(http_req, url, headers) - # kwargs, chk_proxy None -- Default to True + # kwargs, use_proxy None -- Default to True WireClient.call_storage_service(http_req, url, headers, - chk_proxy=None) + use_proxy=None) - # kwargs, chk_proxy False -- Keep False + # kwargs, use_proxy False -- Keep False WireClient.call_storage_service(http_req, url, headers, - chk_proxy=False) + use_proxy=False) - # kwargs, chk_proxy True -- Keep True + # kwargs, use_proxy True -- Keep True WireClient.call_storage_service(http_req, url, headers, - chk_proxy=True) + use_proxy=True) # assert self.assertTrue(http_patch.call_count == 5) for i in range(0,5): - c = http_patch.call_args_list[i][-1]['chk_proxy'] + c = http_patch.call_args_list[i][-1]['use_proxy'] self.assertTrue(c == (True if i != 3 else False)) def test_status_blob_parsing(self, *args): diff --git a/tests/test_agent.py b/tests/test_agent.py index 1b35933..77be07a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -17,12 +17,56 @@ import mock import os.path +import sys from azurelinuxagent.agent import * from azurelinuxagent.common.conf import * from tests.tools import * +EXPECTED_CONFIGURATION = \ +"""AutoUpdate.Enabled = True +AutoUpdate.GAFamily = Prod +Autoupdate.Frequency = 3600 +DVD.MountPoint = /mnt/cdrom/secure +DetectScvmmEnv = False +EnableOverProvisioning = False +Extension.LogDir = /var/log/azure +HttpProxy.Host = None +HttpProxy.Port = None +Lib.Dir = /var/lib/waagent +Logs.Verbose = False +OS.AllowHTTP = False +OS.CheckRdmaDriver = False +OS.EnableFIPS = True +OS.EnableFirewall = True +OS.EnableRDMA = False +OS.HomeDir = /home +OS.OpensslPath = /usr/bin/openssl +OS.PasswordPath = /etc/shadow +OS.RootDeviceScsiTimeout = 300 +OS.SshDir = /notareal/path +OS.SudoersDir = /etc/sudoers.d +OS.UpdateRdmaDriver = False +Pid.File = /var/run/waagent.pid +Provisioning.AllowResetSysUser = False +Provisioning.DecodeCustomData = False +Provisioning.DeleteRootPassword = True +Provisioning.Enabled = True +Provisioning.ExecuteCustomData = False +Provisioning.MonitorHostName = True +Provisioning.PasswordCryptId = 6 +Provisioning.PasswordCryptSaltLength = 10 +Provisioning.RegenerateSshHostKeyPair = True +Provisioning.SshHostKeyPairType = rsa +Provisioning.UseCloudInit = True +ResourceDisk.EnableSwap = False +ResourceDisk.Filesystem = ext4 +ResourceDisk.Format = True +ResourceDisk.MountOptions = None +ResourceDisk.MountPoint = /mnt/resource +ResourceDisk.SwapSizeMB = 0 +""".split('\n') class TestAgent(AgentTestCase): @@ -90,3 +134,36 @@ class TestAgent(AgentTestCase): mock_daemon.run.assert_called_once_with(child_args="-configuration-path:/foo/bar.conf") mock_load.assert_called_once() + + @patch("azurelinuxagent.common.conf.get_ext_log_dir") + def test_agent_ensures_extension_log_directory(self, mock_dir): + ext_log_dir = os.path.join(self.tmp_dir, "FauxLogDir") + mock_dir.return_value = ext_log_dir + + self.assertFalse(os.path.isdir(ext_log_dir)) + agent = Agent(False, + conf_file_path=os.path.join(data_dir, "test_waagent.conf")) + self.assertTrue(os.path.isdir(ext_log_dir)) + + @patch("azurelinuxagent.common.logger.error") + @patch("azurelinuxagent.common.conf.get_ext_log_dir") + def test_agent_logs_if_extension_log_directory_is_a_file(self, mock_dir, mock_log): + ext_log_dir = os.path.join(self.tmp_dir, "FauxLogDir") + mock_dir.return_value = ext_log_dir + fileutil.write_file(ext_log_dir, "Foo") + + self.assertTrue(os.path.isfile(ext_log_dir)) + self.assertFalse(os.path.isdir(ext_log_dir)) + agent = Agent(False, + conf_file_path=os.path.join(data_dir, "test_waagent.conf")) + self.assertTrue(os.path.isfile(ext_log_dir)) + self.assertFalse(os.path.isdir(ext_log_dir)) + mock_log.assert_called_once() + + def test_agent_show_configuration(self): + if not hasattr(sys.stdout, 'getvalue'): + self.fail('Test requires at least Python 2.7 with buffered output') + agent = Agent(False, + conf_file_path=os.path.join(data_dir, "test_waagent.conf")) + agent.show_configuration() + self.assertEqual(EXPECTED_CONFIGURATION, sys.stdout.getvalue().split('\n')) diff --git a/tests/tools.py b/tests/tools.py index a505700..94fab7f 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -26,8 +26,10 @@ import tempfile import unittest from functools import wraps +import azurelinuxagent.common.event as event import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger + from azurelinuxagent.common.version import PY_VERSION_MAJOR #Import mock module for Python2 and Python3 @@ -51,14 +53,21 @@ if debug: class AgentTestCase(unittest.TestCase): def setUp(self): prefix = "{0}_".format(self.__class__.__name__) + self.tmp_dir = tempfile.mkdtemp(prefix=prefix) self.test_file = 'test_file' + conf.get_autoupdate_enabled = Mock(return_value=True) conf.get_lib_dir = Mock(return_value=self.tmp_dir) + ext_log_dir = os.path.join(self.tmp_dir, "azure") conf.get_ext_log_dir = Mock(return_value=ext_log_dir) + conf.get_agent_pid_file_path = Mock(return_value=os.path.join(self.tmp_dir, "waagent.pid")) + event.init_event_status(self.tmp_dir) + event.init_event_logger(self.tmp_dir) + def tearDown(self): if not debug and self.tmp_dir is not None: shutil.rmtree(self.tmp_dir) diff --git a/tests/utils/test_file_util.py b/tests/utils/test_file_util.py index 0b92513..87bce8c 100644 --- a/tests/utils/test_file_util.py +++ b/tests/utils/test_file_util.py @@ -15,6 +15,7 @@ # Requires Python 2.4+ and Openssl 1.0+ # +import errno as errno import glob import random import string @@ -64,6 +65,50 @@ class TestFileOperations(AgentTestCase): os.remove(test_file) + def test_findre_in_file(self): + fp = tempfile.mktemp() + with open(fp, 'w') as f: + f.write( +''' +First line +Second line +Third line with more words +''' + ) + + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, ".*rst line$")) + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, ".*ond line$")) + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, ".*with more.*")) + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, "^Third.*")) + self.assertEquals( + None, + fileutil.findre_in_file(fp, "^Do not match.*")) + + def test_findstr_in_file(self): + fp = tempfile.mktemp() + with open(fp, 'w') as f: + f.write( +''' +First line +Second line +Third line with more words +''' + ) + + self.assertTrue(fileutil.findstr_in_file(fp, "First line")) + self.assertTrue(fileutil.findstr_in_file(fp, "Second line")) + self.assertTrue( + fileutil.findstr_in_file(fp, "Third line with more words")) + self.assertFalse(fileutil.findstr_in_file(fp, "Not a line")) + def test_get_last_path_element(self): filepath = '/tmp/abc.def' filename = fileutil.base_name(filepath) @@ -197,5 +242,75 @@ DHCP_HOSTNAME=test\n" fileutil.update_conf_file(path, 'DHCP_HOSTNAME', 'DHCP_HOSTNAME=test') patch_write.assert_called_once_with(path, updated_file) + def test_clean_ioerror_ignores_missing(self): + e = IOError() + e.errno = errno.ENOSPC + + # Send no paths + fileutil.clean_ioerror(e) + + # Send missing file(s) / directories + fileutil.clean_ioerror(e, paths=['/foo/not/here', None, '/bar/not/there']) + + def test_clean_ioerror_ignores_unless_ioerror(self): + try: + d = tempfile.mkdtemp() + fd, f = tempfile.mkstemp() + os.close(fd) + fileutil.write_file(f, 'Not empty') + + # Send non-IOError exception + e = Exception() + fileutil.clean_ioerror(e, paths=[d, f]) + self.assertTrue(os.path.isdir(d)) + self.assertTrue(os.path.isfile(f)) + + # Send unrecognized IOError + e = IOError() + e.errno = errno.EFAULT + self.assertFalse(e.errno in fileutil.KNOWN_IOERRORS) + fileutil.clean_ioerror(e, paths=[d, f]) + self.assertTrue(os.path.isdir(d)) + self.assertTrue(os.path.isfile(f)) + + finally: + shutil.rmtree(d) + os.remove(f) + + def test_clean_ioerror_removes_files(self): + fd, f = tempfile.mkstemp() + os.close(fd) + fileutil.write_file(f, 'Not empty') + + e = IOError() + e.errno = errno.ENOSPC + fileutil.clean_ioerror(e, paths=[f]) + self.assertFalse(os.path.isdir(f)) + self.assertFalse(os.path.isfile(f)) + + def test_clean_ioerror_removes_directories(self): + d1 = tempfile.mkdtemp() + d2 = tempfile.mkdtemp() + for n in ['foo', 'bar']: + fileutil.write_file(os.path.join(d2, n), 'Not empty') + + e = IOError() + e.errno = errno.ENOSPC + fileutil.clean_ioerror(e, paths=[d1, d2]) + self.assertFalse(os.path.isdir(d1)) + self.assertFalse(os.path.isfile(d1)) + self.assertFalse(os.path.isdir(d2)) + self.assertFalse(os.path.isfile(d2)) + + def test_clean_ioerror_handles_a_range_of_errors(self): + for err in fileutil.KNOWN_IOERRORS: + e = IOError() + e.errno = err + + d = tempfile.mkdtemp() + fileutil.clean_ioerror(e, paths=[d]) + self.assertFalse(os.path.isdir(d)) + self.assertFalse(os.path.isfile(d)) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_rest_util.py b/tests/utils/test_rest_util.py index 5f084a6..52674da 100644 --- a/tests/utils/test_rest_util.py +++ b/tests/utils/test_rest_util.py @@ -15,10 +15,16 @@ # Requires Python 2.4+ and Openssl 1.0+ # +import os import unittest + +from azurelinuxagent.common.exception import HttpError, \ + ProtocolError, \ + ResourceGoneError import azurelinuxagent.common.utils.restutil as restutil -from azurelinuxagent.common.future import httpclient -from tests.tools import AgentTestCase, patch, Mock, MagicMock + +from azurelinuxagent.common.future import httpclient, ustr +from tests.tools import * class TestHttpOperations(AgentTestCase): @@ -50,45 +56,163 @@ class TestHttpOperations(AgentTestCase): self.assertEquals(None, host) self.assertEquals(rel_uri, "None") + @patch('azurelinuxagent.common.conf.get_httpproxy_port') + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_none_is_default(self, mock_host, mock_port): + mock_host.return_value = None + mock_port.return_value = None + h, p = restutil._get_http_proxy() + self.assertEqual(None, h) + self.assertEqual(None, p) + + @patch('azurelinuxagent.common.conf.get_httpproxy_port') + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_configuration_overrides_env(self, mock_host, mock_port): + mock_host.return_value = "host" + mock_port.return_value = None + h, p = restutil._get_http_proxy() + self.assertEqual("host", h) + self.assertEqual(None, p) + mock_host.assert_called_once() + mock_port.assert_called_once() + + @patch('azurelinuxagent.common.conf.get_httpproxy_port') + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_configuration_requires_host(self, mock_host, mock_port): + mock_host.return_value = None + mock_port.return_value = None + h, p = restutil._get_http_proxy() + self.assertEqual(None, h) + self.assertEqual(None, p) + mock_host.assert_called_once() + mock_port.assert_not_called() + + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_http_uses_httpproxy(self, mock_host): + mock_host.return_value = None + with patch.dict(os.environ, { + 'http_proxy' : 'http://foo.com:80', + 'https_proxy' : 'https://bar.com:443' + }): + h, p = restutil._get_http_proxy() + self.assertEqual("foo.com", h) + self.assertEqual(80, p) + + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_https_uses_httpsproxy(self, mock_host): + mock_host.return_value = None + with patch.dict(os.environ, { + 'http_proxy' : 'http://foo.com:80', + 'https_proxy' : 'https://bar.com:443' + }): + h, p = restutil._get_http_proxy(secure=True) + self.assertEqual("bar.com", h) + self.assertEqual(443, p) + + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_ignores_user_in_httpproxy(self, mock_host): + mock_host.return_value = None + with patch.dict(os.environ, { + 'http_proxy' : 'http://user:pw@foo.com:80' + }): + h, p = restutil._get_http_proxy() + self.assertEqual("foo.com", h) + self.assertEqual(80, p) + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") - def test_http_request(self, HTTPConnection, HTTPSConnection): - mock_http_conn = MagicMock() - mock_http_resp = MagicMock() - mock_http_conn.getresponse = Mock(return_value=mock_http_resp) - HTTPConnection.return_value = mock_http_conn - HTTPSConnection.return_value = mock_http_conn + def test_http_request_direct(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) - mock_http_resp.read = Mock(return_value="_(:3| <)_") + HTTPConnection.return_value = mock_conn - # Test http get - resp = restutil._http_request("GET", "foo", "bar") - self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + resp = restutil._http_request("GET", "foo", "/bar") - # Test https get - resp = restutil._http_request("GET", "foo", "bar", secure=True) + HTTPConnection.assert_has_calls([ + call("foo", 80, timeout=10) + ]) + HTTPSConnection.assert_not_called() + mock_conn.request.assert_has_calls([ + call(method="GET", url="/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) + + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") + @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") + def test_http_request_direct_secure(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) + + HTTPSConnection.return_value = mock_conn + + resp = restutil._http_request("GET", "foo", "/bar", secure=True) - # Test http get with proxy - mock_http_resp.read = Mock(return_value="_(:3| <)_") - resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", - proxy_port=23333) + HTTPConnection.assert_not_called() + HTTPSConnection.assert_has_calls([ + call("foo", 443, timeout=10) + ]) + mock_conn.request.assert_has_calls([ + call(method="GET", url="/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) - # Test https get - resp = restutil._http_request("GET", "foo", "bar", secure=True) + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") + @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") + def test_http_request_proxy(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) + + HTTPConnection.return_value = mock_conn + + resp = restutil._http_request("GET", "foo", "/bar", + proxy_host="foo.bar", proxy_port=23333) + + HTTPConnection.assert_has_calls([ + call("foo.bar", 23333, timeout=10) + ]) + HTTPSConnection.assert_not_called() + mock_conn.request.assert_has_calls([ + call(method="GET", url="http://foo:80/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) + + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") + @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") + def test_http_request_proxy_secure(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) + + HTTPSConnection.return_value = mock_conn - # Test https get with proxy - mock_http_resp.read = Mock(return_value="_(:3| <)_") - resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", - proxy_port=23333, secure=True) + resp = restutil._http_request("GET", "foo", "/bar", + proxy_host="foo.bar", proxy_port=23333, + secure=True) + + HTTPConnection.assert_not_called() + HTTPSConnection.assert_has_calls([ + call("foo.bar", 23333, timeout=10) + ]) + mock_conn.request.assert_has_calls([ + call(method="GET", url="https://foo:443/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) @patch("time.sleep") @patch("azurelinuxagent.common.utils.restutil._http_request") @@ -115,6 +239,180 @@ class TestHttpOperations(AgentTestCase): self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar") + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_status_codes(self, _http_request, _sleep): + _http_request.side_effect = [ + Mock(status=httpclient.SERVICE_UNAVAILABLE), + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar") + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_passed_status_codes(self, _http_request, _sleep): + # Ensure the code is not part of the standard set + self.assertFalse(httpclient.UNAUTHORIZED in restutil.RETRY_CODES) + + _http_request.side_effect = [ + Mock(status=httpclient.UNAUTHORIZED), + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar", retry_codes=[httpclient.UNAUTHORIZED]) + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_raises_for_bad_request(self, _http_request, _sleep): + _http_request.side_effect = [ + Mock(status=httpclient.BAD_REQUEST) + ] + + self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar") + self.assertEqual(1, _http_request.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_raises_for_resource_gone(self, _http_request, _sleep): + _http_request.side_effect = [ + Mock(status=httpclient.GONE) + ] + + self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar") + self.assertEqual(1, _http_request.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_exceptions(self, _http_request, _sleep): + # Testing each exception is difficult because they have varying + # signatures; for now, test one and ensure the set is unchanged + recognized_exceptions = [ + httpclient.NotConnected, + httpclient.IncompleteRead, + httpclient.ImproperConnectionState, + httpclient.BadStatusLine + ] + self.assertEqual(recognized_exceptions, restutil.RETRY_EXCEPTIONS) + + _http_request.side_effect = [ + httpclient.IncompleteRead(''), + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar") + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_ioerrors(self, _http_request, _sleep): + ioerror = IOError() + ioerror.errno = 42 + + _http_request.side_effect = [ + ioerror, + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar") + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + def test_request_failed(self): + self.assertTrue(restutil.request_failed(None)) + + resp = Mock() + for status in restutil.OK_CODES: + resp.status = status + self.assertFalse(restutil.request_failed(resp)) + + self.assertFalse(httpclient.BAD_REQUEST in restutil.OK_CODES) + resp.status = httpclient.BAD_REQUEST + self.assertTrue(restutil.request_failed(resp)) + + self.assertFalse( + restutil.request_failed( + resp, ok_codes=[httpclient.BAD_REQUEST])) + + def test_request_succeeded(self): + self.assertFalse(restutil.request_succeeded(None)) + + resp = Mock() + for status in restutil.OK_CODES: + resp.status = status + self.assertTrue(restutil.request_succeeded(resp)) + + self.assertFalse(httpclient.BAD_REQUEST in restutil.OK_CODES) + resp.status = httpclient.BAD_REQUEST + self.assertFalse(restutil.request_succeeded(resp)) + + self.assertTrue( + restutil.request_succeeded( + resp, ok_codes=[httpclient.BAD_REQUEST])) + + def test_read_response_error(self): + """ + Validate the read_response_error method handles encoding correctly + """ + responses = ['message', b'message', '\x80message\x80'] + response = MagicMock() + response.status = 'status' + response.reason = 'reason' + with patch.object(response, 'read') as patch_response: + for s in responses: + patch_response.return_value = s + result = restutil.read_response_error(response) + print("RESPONSE: {0}".format(s)) + print("RESULT: {0}".format(result)) + print("PRESENT: {0}".format('[status: reason]' in result)) + self.assertTrue('[status: reason]' in result) + self.assertTrue('message' in result) + + def test_read_response_bytes(self): + response_bytes = '7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:' \ + '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:' \ + '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:' \ + '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:' \ + '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:' \ + 'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:' \ + '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:' \ + '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:' \ + '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:' \ + '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:' \ + '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:' \ + '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:' \ + '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:' \ + '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:' \ + '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:' \ + '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:' \ + '6c:73:22:3a:20:22:22:0a:7d'.split(':') + expected_response = '[HTTP Failed] [status: reason] {\n "errorCode": "The blob ' \ + 'type is invalid for this operation.",\n ' \ + '"message": "<?xml version="1.0" ' \ + 'encoding="utf-8"?>' \ + '<Error><Code>InvalidBlobType</Code><Message>The ' \ + 'blob type is invalid for this operation.\n' \ + 'RequestId:c74290cb-0001-00b5-06da-dd666a000",' \ + '\n "details": ""\n}' + + response_string = ''.join(chr(int(b, 16)) for b in response_bytes) + response = MagicMock() + response.status = 'status' + response.reason = 'reason' + with patch.object(response, 'read') as patch_response: + patch_response.return_value = response_string + result = restutil.read_response_error(response) + self.assertEqual(result, expected_response) + try: + raise HttpError("{0}".format(result)) + except HttpError as e: + self.assertTrue(result in ustr(e)) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_text_util.py b/tests/utils/test_text_util.py index 6f204c7..d182a67 100644 --- a/tests/utils/test_text_util.py +++ b/tests/utils/test_text_util.py @@ -34,6 +34,19 @@ class TestTextUtil(AgentTestCase): password_hash = textutil.gen_password_hash(data, 6, 10) self.assertNotEquals(None, password_hash) + def test_replace_non_ascii(self): + data = ustr(b'\xef\xbb\xbfhehe', encoding='utf-8') + self.assertEqual('hehe', textutil.replace_non_ascii(data)) + + data = "abcd\xa0e\xf0fghijk\xbblm" + self.assertEqual("abcdefghijklm", textutil.replace_non_ascii(data)) + + data = "abcd\xa0e\xf0fghijk\xbblm" + self.assertEqual("abcdXeXfghijkXlm", + textutil.replace_non_ascii(data, replace_char='X')) + + self.assertEqual('', textutil.replace_non_ascii(None)) + def test_remove_bom(self): #Test bom could be removed data = ustr(b'\xef\xbb\xbfhehe', encoding='utf-8') @@ -94,6 +107,37 @@ class TestTextUtil(AgentTestCase): "-----END PRIVATE Key-----\n") base64_bytes = textutil.get_bytes_from_pem(content) self.assertEquals("private key", base64_bytes) + + def test_swap_hexstring(self): + data = [ + ['12', 1, '21'], + ['12', 2, '12'], + ['12', 3, '012'], + ['12', 4, '0012'], + + ['123', 1, '321'], + ['123', 2, '2301'], + ['123', 3, '123'], + ['123', 4, '0123'], + + ['1234', 1, '4321'], + ['1234', 2, '3412'], + ['1234', 3, '234001'], + ['1234', 4, '1234'], + + ['abcdef12', 1, '21fedcba'], + ['abcdef12', 2, '12efcdab'], + ['abcdef12', 3, 'f12cde0ab'], + ['abcdef12', 4, 'ef12abcd'], + + ['aBcdEf12', 1, '21fEdcBa'], + ['aBcdEf12', 2, '12EfcdaB'], + ['aBcdEf12', 3, 'f12cdE0aB'], + ['aBcdEf12', 4, 'Ef12aBcd'] + ] + + for t in data: + self.assertEqual(t[2], textutil.swap_hexstring(t[0], width=t[1])) if __name__ == '__main__': unittest.main() |