diff options
Diffstat (limited to 'tests/ga/test_update.py')
-rw-r--r-- | tests/ga/test_update.py | 381 |
1 files changed, 324 insertions, 57 deletions
diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py index 89fe95d..a431a9b 100644 --- a/tests/ga/test_update.py +++ b/tests/ga/test_update.py @@ -23,6 +23,7 @@ import json import os import platform import random +import re import subprocess import sys import tempfile @@ -109,8 +110,13 @@ class UpdateTestCase(AgentTestCase): AgentTestCase.setUp(self) return - def agent_bin(self, version): - return "bin/{0}-{1}.egg".format(AGENT_NAME, version) + def agent_bin(self, version, suffix): + return "bin/{0}-{1}{2}.egg".format(AGENT_NAME, version, suffix) + + def rename_agent_bin(self, path, src_v, dst_v): + src_bin = glob.glob(os.path.join(path, self.agent_bin(src_v, '*')))[0] + dst_bin = os.path.join(path, self.agent_bin(dst_v, '')) + shutil.move(src_bin, dst_bin) def agent_count(self): return len(self.agent_dirs()) @@ -158,8 +164,29 @@ class UpdateTestCase(AgentTestCase): fileutil.trim_ext(agent, "zip"))) return - def prepare_agents(self, base_version=AGENT_VERSION, count=5, is_available=True): - base_v = FlexibleVersion(base_version) + def prepare_agent(self, version): + """ + Create a download for the current agent version, copied from test data + """ + self.copy_agents(get_agent_pkgs()[0]) + self.expand_agents() + + versions = self.agent_versions() + src_v = FlexibleVersion(str(versions[0])) + + from_path = self.agent_dir(src_v) + dst_v = FlexibleVersion(str(version)) + to_path = self.agent_dir(dst_v) + + if from_path != to_path: + shutil.move(from_path + ".zip", to_path + ".zip") + shutil.move(from_path, to_path) + self.rename_agent_bin(to_path, src_v, dst_v) + return + + def prepare_agents(self, + count=5, + is_available=True): # Ensure the test data is copied over agent_count = self.agent_count() @@ -172,10 +199,6 @@ class UpdateTestCase(AgentTestCase): versions = self.agent_versions() src_v = FlexibleVersion(str(versions[0])) - # If the most recent agent is newer the minimum requested, use the agent version - if base_v < src_v: - base_v = src_v - # Create agent packages and directories return self.replicate_agents( src_v=src_v, @@ -193,25 +216,21 @@ class UpdateTestCase(AgentTestCase): pass return - def replicate_agents( - self, - count=5, - src_v=AGENT_VERSION, - is_available=True, - increment=1): + def replicate_agents(self, + count=5, + src_v=AGENT_VERSION, + is_available=True, + increment=1): from_path = self.agent_dir(src_v) dst_v = FlexibleVersion(str(src_v)) - for i in range(0,count): + for i in range(0, count): dst_v += increment to_path = self.agent_dir(dst_v) shutil.copyfile(from_path + ".zip", to_path + ".zip") shutil.copytree(from_path, to_path) - shutil.move( - os.path.join(to_path, self.agent_bin(src_v)), - os.path.join(to_path, self.agent_bin(dst_v))) + self.rename_agent_bin(to_path, src_v, dst_v) if not is_available: GuestAgent(to_path).mark_failure(is_fatal=True) - return dst_v @@ -437,7 +456,8 @@ class TestGuestAgent(UpdateTestCase): agent = GuestAgent(path=self.agent_path) agent._unpack() agent._load_manifest() - self.assertEqual(agent.manifest.get_enable_command(), agent.get_agent_cmd()) + self.assertEqual(agent.manifest.get_enable_command(), + agent.get_agent_cmd()) return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") @@ -515,6 +535,34 @@ class TestGuestAgent(UpdateTestCase): self.assertFalse(agent.is_downloaded) return + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.restutil.http_get") + def test_download_fallback(self, mock_http_get, mock_ensure): + self.remove_agents() + self.assertFalse(os.path.isdir(self.agent_path)) + + mock_http_get.return_value = ResponseMock( + status=restutil.httpclient.SERVICE_UNAVAILABLE) + + ext_uri = 'ext_uri' + host_uri = 'host_uri' + mock_host = HostPluginProtocol(host_uri, + 'container_id', + 'role_config') + + pkg = ExtHandlerPackage(version=str(get_agent_version())) + pkg.uris.append(ExtHandlerPackageUri(uri=ext_uri)) + agent = GuestAgent(pkg=pkg) + agent.host = mock_host + + with patch.object(HostPluginProtocol, + "get_artifact_request", + return_value=[host_uri, {}]): + self.assertRaises(UpdateError, agent._download) + self.assertEqual(mock_http_get.call_count, 2) + self.assertEqual(mock_http_get.call_args_list[0][0][0], ext_uri) + self.assertEqual(mock_http_get.call_args_list[1][0][0], host_uri) + @patch("azurelinuxagent.ga.update.restutil.http_get") def test_ensure_downloaded(self, mock_http_get): self.remove_agents() @@ -598,13 +646,13 @@ class TestGuestAgent(UpdateTestCase): class TestUpdate(UpdateTestCase): def setUp(self): UpdateTestCase.setUp(self) + self.event_patch = patch('azurelinuxagent.common.event.add_event') self.update_handler = get_update_handler() return def test_creation(self): self.assertTrue(self.update_handler.running) - self.assertEqual(None, self.update_handler.last_etag) self.assertEqual(None, self.update_handler.last_attempt_time) self.assertEqual(0, len(self.update_handler.agents)) @@ -617,82 +665,113 @@ class TestUpdate(UpdateTestCase): self.assertEqual(None, self.update_handler.signal_handler) return - def _test_ensure_latest_agent( + def test_emit_restart_event_writes_sentinal_file(self): + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + self.update_handler._emit_restart_event() + self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path())) + return + + def test_emit_restart_event_emits_event_if_not_clean_start(self): + try: + mock_event = self.event_patch.start() + self.update_handler._set_sentinal() + self.update_handler._emit_restart_event() + self.assertEqual(1, mock_event.call_count) + except Exception as e: + pass + self.event_patch.stop() + return + + def _test_upgrade_available( self, base_version=FlexibleVersion(AGENT_VERSION), protocol=None, - versions=None): + versions=None, + count=5): - latest_version = self.prepare_agents() + latest_version = self.prepare_agents(count=count) if versions is None or len(versions) <= 0: versions = [latest_version] - etag = self.update_handler.last_etag if self.update_handler.last_etag is not None else 42 if protocol is None: - protocol = ProtocolMock(etag=etag, versions=versions) + protocol = ProtocolMock(versions=versions) self.update_handler.protocol_util = protocol conf.get_autoupdate_gafamily = Mock(return_value=protocol.family) - return self.update_handler._ensure_latest_agent(base_version=base_version) + return self.update_handler._upgrade_available(base_version=base_version) + + def test_upgrade_available_returns_true_on_first_use(self): + self.assertTrue(self._test_upgrade_available()) + return + + def test_get_latest_agent_excluded(self): + self.prepare_agent(AGENT_VERSION) + self.assertFalse(self._test_upgrade_available( + versions=self.agent_versions(), + count=1)) + self.assertEqual(None, self.update_handler.get_latest_agent()) + return - def test_ensure_latest_agent_returns_true_on_first_use(self): - self.assertEqual(None, self.update_handler.last_etag) - self.assertTrue(self._test_ensure_latest_agent()) + def test_upgrade_available_handles_missing_family(self): + extensions_config = ExtensionsConfig(load_data("wire/ext_conf_missing_family.xml")) + protocol = ProtocolMock() + protocol.family = "Prod" + protocol.agent_manifests = extensions_config.vmagent_manifests + self.update_handler.protocol_util = protocol + with patch('azurelinuxagent.common.logger.warn') as mock_logger: + with patch('tests.ga.test_update.ProtocolMock.get_vmagent_pkgs', side_effect=ProtocolError): + self.assertFalse(self.update_handler._upgrade_available(base_version=CURRENT_VERSION)) + self.assertEqual(0, mock_logger.call_count) return - def test_ensure_latest_agent_includes_old_agents(self): + def test_upgrade_available_includes_old_agents(self): self.prepare_agents() - old_count = FlexibleVersion(AGENT_VERSION).version[-1] old_version = self.agent_versions()[-1] + old_count = old_version.version[-1] self.replicate_agents(src_v=old_version, count=old_count, increment=-1) all_count = len(self.agent_versions()) - self.assertTrue(self._test_ensure_latest_agent(versions=self.agent_versions())) + self.assertTrue(self._test_upgrade_available(versions=self.agent_versions())) self.assertEqual(all_count, len(self.update_handler.agents)) return - def test_ensure_lastest_agent_purges_old_agents(self): + def test_upgrade_available_purges_old_agents(self): self.prepare_agents() agent_count = self.agent_count() self.assertEqual(5, agent_count) agent_versions = self.agent_versions()[:3] - self.assertTrue(self._test_ensure_latest_agent(versions=agent_versions)) + self.assertTrue(self._test_upgrade_available(versions=agent_versions)) self.assertEqual(len(agent_versions), len(self.update_handler.agents)) self.assertEqual(agent_versions, self.agent_versions()) return - def test_ensure_latest_agent_skips_if_too_frequent(self): + def test_upgrade_available_skips_if_too_frequent(self): conf.get_autoupdate_frequency = Mock(return_value=10000) self.update_handler.last_attempt_time = time.time() - self.assertFalse(self._test_ensure_latest_agent()) + self.assertFalse(self._test_upgrade_available()) return - def test_ensure_latest_agent_skips_when_etag_matches(self): - self.update_handler.last_etag = 42 - self.assertFalse(self._test_ensure_latest_agent()) - return - - def test_ensure_latest_agent_skips_if_when_no_new_versions(self): + def test_upgrade_available_skips_if_when_no_new_versions(self): self.prepare_agents() base_version = self.agent_versions()[0] + 1 - self.assertFalse(self._test_ensure_latest_agent(base_version=base_version)) + self.assertFalse(self._test_upgrade_available(base_version=base_version)) return - def test_ensure_latest_agent_skips_when_no_versions(self): - self.assertFalse(self._test_ensure_latest_agent(protocol=ProtocolMock())) + def test_upgrade_available_skips_when_no_versions(self): + self.assertFalse(self._test_upgrade_available(protocol=ProtocolMock())) return - def test_ensure_latest_agent_skips_when_updates_are_disabled(self): + def test_upgrade_available_skips_when_updates_are_disabled(self): conf.get_autoupdate_enabled = Mock(return_value=False) - self.assertFalse(self._test_ensure_latest_agent()) + self.assertFalse(self._test_upgrade_available()) return - def test_ensure_latest_agent_sorts(self): + def test_upgrade_available_sorts(self): self.prepare_agents() - self._test_ensure_latest_agent() + self._test_upgrade_available() v = FlexibleVersion("100000") for a in self.update_handler.agents: @@ -700,6 +779,58 @@ class TestUpdate(UpdateTestCase): v = a.version return + def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL): + with patch.object(self.update_handler, 'osutil') as mock_util: + # Note: + # - Python only allows mutations of objects to which a function has + # a reference. Incrementing an integer directly changes the + # reference. Incrementing an item of a list changes an item to + # which the code has a reference. + # See http://stackoverflow.com/questions/26408941/python-nested-functions-and-variable-scope + iterations = [0] + def iterator(*args, **kwargs): + iterations[0] += 1 + return iterations[0] < invocations + + mock_util.check_pid_alive = Mock(side_effect=iterator) + + with patch('os.getpid', return_value=42): + with patch('time.sleep', return_value=None) as mock_sleep: + self.update_handler._ensure_no_orphans(orphan_wait_interval=interval) + return mock_util.check_pid_alive.call_count, mock_sleep.call_count + return + + def test_ensure_no_orphans(self): + fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41)) + calls, sleeps = self._test_ensure_no_orphans(invocations=3) + self.assertEqual(3, calls) + self.assertEqual(2, sleeps) + return + + def test_ensure_no_orphans_skips_if_no_orphans(self): + calls, sleeps = self._test_ensure_no_orphans(invocations=3) + self.assertEqual(0, calls) + self.assertEqual(0, sleeps) + return + + def test_ensure_no_orphans_ignores_exceptions(self): + with patch('azurelinuxagent.common.utils.fileutil.read_file', side_effect=Exception): + calls, sleeps = self._test_ensure_no_orphans(invocations=3) + self.assertEqual(0, calls) + self.assertEqual(0, sleeps) + return + + def test_ensure_no_orphans_kills_after_interval(self): + fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41)) + with patch('os.kill') as mock_kill: + calls, sleeps = self._test_ensure_no_orphans( + invocations=4, + interval=3*GOAL_STATE_INTERVAL) + self.assertEqual(3, calls) + self.assertEqual(2, sleeps) + self.assertEqual(1, mock_kill.call_count) + return + def _test_evaluate_agent_health(self, child_agent_index=0): self.prepare_agents() @@ -789,6 +920,59 @@ class TestUpdate(UpdateTestCase): self.assertEqual(latest_agent.version, prior_agent.version) return + def test_get_pid_files(self): + previous_pid_file, pid_file, = self.update_handler._get_pid_files() + self.assertEqual(None, previous_pid_file) + self.assertEqual("0_waagent.pid", os.path.basename(pid_file)) + return + + def test_get_pid_files_returns_previous(self): + for n in range(1250): + fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1)) + previous_pid_file, pid_file, = self.update_handler._get_pid_files() + self.assertEqual("1249_waagent.pid", os.path.basename(previous_pid_file)) + self.assertEqual("1250_waagent.pid", os.path.basename(pid_file)) + return + + def test_is_clean_start_returns_true_when_no_sentinal(self): + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + self.assertTrue(self.update_handler._is_clean_start) + return + + def test_is_clean_start_returns_true_sentinal_agent_is_not_current(self): + self.update_handler._set_sentinal(agent="Not the Current Agent") + self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path())) + self.assertTrue(self.update_handler._is_clean_start) + return + + def test_is_clean_start_returns_false_for_current_agent(self): + self.update_handler._set_sentinal(agent=CURRENT_AGENT) + self.assertFalse(self.update_handler._is_clean_start) + return + + def test_is_clean_start_returns_false_for_exceptions(self): + self.update_handler._set_sentinal() + with patch("azurelinuxagent.common.utils.fileutil.read_file", side_effect=Exception): + self.assertFalse(self.update_handler._is_clean_start) + return + + def test_is_orphaned_returns_false_if_parent_exists(self): + fileutil.write_file(conf.get_agent_pid_file_path(), ustr(42)) + with patch('os.getppid', return_value=42): + self.assertFalse(self.update_handler._is_orphaned) + return + + def test_is_orphaned_returns_true_if_parent_is_init(self): + with patch('os.getppid', return_value=1): + self.assertTrue(self.update_handler._is_orphaned) + return + + def test_is_orphaned_returns_true_if_parent_does_not_exist(self): + fileutil.write_file(conf.get_agent_pid_file_path(), ustr(24)) + with patch('os.getppid', return_value=42): + self.assertTrue(self.update_handler._is_orphaned) + return + def test_load_agents(self): self.prepare_agents() @@ -868,13 +1052,14 @@ class TestUpdate(UpdateTestCase): agent = self.update_handler.get_latest_agent() args, kwargs = self._test_run_latest() - cmds = shlex.split(agent.get_agent_cmd()) + cmds = textutil.safe_shlex_split(agent.get_agent_cmd()) if cmds[0].lower() == "python": cmds[0] = get_python_cmd() self.assertEqual(args[0], cmds) self.assertEqual(True, 'cwd' in kwargs) self.assertEqual(agent.get_agent_dir(), kwargs['cwd']) + self.assertEqual(False, '\x00' in cmds[0]) return def test_run_latest_polls_and_waits_for_success(self): @@ -993,7 +1178,7 @@ class TestUpdate(UpdateTestCase): # reference. Incrementing an item of a list changes an item to # which the code has a reference. # See http://stackoverflow.com/questions/26408941/python-nested-functions-and-variable-scope - iterations = [0] + iterations = [0] def iterator(*args, **kwargs): iterations[0] += 1 if iterations[0] >= invocations: @@ -1001,14 +1186,19 @@ class TestUpdate(UpdateTestCase): return calls = calls * invocations - + + fileutil.write_file(conf.get_agent_pid_file_path(), ustr(42)) + with patch('azurelinuxagent.ga.exthandlers.get_exthandlers_handler') as mock_handler: with patch('azurelinuxagent.ga.monitor.get_monitor_handler') as mock_monitor: with patch('azurelinuxagent.ga.env.get_env_handler') as mock_env: with patch('time.sleep', side_effect=iterator) as mock_sleep: with patch('sys.exit') as mock_exit: - - self.update_handler.run() + if isinstance(os.getppid, MagicMock): + self.update_handler.run() + else: + with patch('os.getppid', return_value=42): + self.update_handler.run() self.assertEqual(1, mock_handler.call_count) self.assertEqual(mock_handler.return_value.method_calls, calls) @@ -1016,6 +1206,7 @@ class TestUpdate(UpdateTestCase): self.assertEqual(1, mock_monitor.call_count) self.assertEqual(1, mock_env.call_count) self.assertEqual(1, mock_exit.call_count) + return def test_run(self): @@ -1027,8 +1218,30 @@ class TestUpdate(UpdateTestCase): return def test_run_stops_if_update_available(self): - self.update_handler._ensure_latest_agent = Mock(return_value=True) + self.update_handler._upgrade_available = Mock(return_value=True) + self._test_run(invocations=0, calls=[], enable_updates=True) + return + + def test_run_stops_if_orphaned(self): + with patch('os.getppid', return_value=1): + self._test_run(invocations=0, calls=[], enable_updates=True) + return + + def test_run_clears_sentinal_on_successful_exit(self): + self._test_run() + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + return + + def test_run_leaves_sentinal_on_unsuccessful_exit(self): + self.update_handler._upgrade_available = Mock(side_effect=Exception) self._test_run(invocations=0, calls=[], enable_updates=True) + self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path())) + return + + def test_run_emits_restart_event(self): + self.update_handler._emit_restart_event = Mock() + self._test_run() + self.assertEqual(1, self.update_handler._emit_restart_event.call_count) return def test_set_agents_sets_agents(self): @@ -1050,6 +1263,59 @@ class TestUpdate(UpdateTestCase): v = a.version return + def test_set_sentinal(self): + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + self.update_handler._set_sentinal() + self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path())) + return + + def test_set_sentinal_writes_current_agent(self): + self.update_handler._set_sentinal() + self.assertTrue( + fileutil.read_file(self.update_handler._sentinal_file_path()), + CURRENT_AGENT) + return + + def test_shutdown(self): + self.update_handler._set_sentinal() + self.update_handler._shutdown() + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + return + + def test_shutdown_ignores_missing_sentinal_file(self): + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + self.update_handler._shutdown() + self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) + return + + def test_shutdown_ignores_exceptions(self): + self.update_handler._set_sentinal() + + try: + with patch("os.remove", side_effect=Exception): + self.update_handler._shutdown() + except Exception as e: + self.assertTrue(False, "Unexpected exception") + return + + def test_write_pid_file(self): + for n in range(1112): + fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1)) + with patch('os.getpid', return_value=1112): + previous_pid_file, pid_file = self.update_handler._write_pid_file() + self.assertEqual("1111_waagent.pid", os.path.basename(previous_pid_file)) + self.assertEqual("1112_waagent.pid", os.path.basename(pid_file)) + self.assertEqual(fileutil.read_file(pid_file), ustr(1112)) + return + + def test_write_pid_file_ignores_exceptions(self): + with patch('azurelinuxagent.common.utils.fileutil.write_file', side_effect=Exception): + with patch('os.getpid', return_value=42): + previous_pid_file, pid_file = self.update_handler._write_pid_file() + self.assertEqual(None, previous_pid_file) + self.assertEqual(None, pid_file) + return + class ChildMock(Mock): def __init__(self, return_value=0, side_effect=None): @@ -1061,8 +1327,9 @@ class ChildMock(Mock): class ProtocolMock(object): - def __init__(self, family="TestAgent", etag=42, versions=None): + def __init__(self, family="TestAgent", etag=42, versions=None, client=None): self.family = family + self.client = client self.etag = etag self.versions = versions if versions is not None else [] self.create_manifests() |