summaryrefslogtreecommitdiff
path: root/tests/ga/test_update.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ga/test_update.py')
-rw-r--r--tests/ga/test_update.py1135
1 files changed, 1135 insertions, 0 deletions
diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py
new file mode 100644
index 0000000..74804fb
--- /dev/null
+++ b/tests/ga/test_update.py
@@ -0,0 +1,1135 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+from __future__ import print_function
+
+import copy
+import glob
+import json
+import os
+import platform
+import random
+import subprocess
+import sys
+import tempfile
+import zipfile
+
+from tests.protocol.mockwiredata import *
+from tests.tools import *
+
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.fileutil as fileutil
+
+from azurelinuxagent.common.exception import UpdateError
+from azurelinuxagent.common.protocol.restapi import *
+from azurelinuxagent.common.protocol.wire import *
+from azurelinuxagent.common.utils.flexible_version import FlexibleVersion
+from azurelinuxagent.common.version import AGENT_NAME, AGENT_VERSION
+from azurelinuxagent.ga.update import *
+
+NO_ERROR = {
+ "last_failure" : 0.0,
+ "failure_count" : 0,
+ "was_fatal" : False
+}
+
+WITH_ERROR = {
+ "last_failure" : 42.42,
+ "failure_count" : 2,
+ "was_fatal" : False
+}
+
+EMPTY_MANIFEST = {
+ "name": "WALinuxAgent",
+ "version": 1.0,
+ "handlerManifest": {
+ "installCommand": "",
+ "uninstallCommand": "",
+ "updateCommand": "",
+ "enableCommand": "",
+ "disableCommand": "",
+ "rebootAfterInstall": False,
+ "reportHeartbeat": False
+ }
+}
+
+
+def get_agent_pkgs(in_dir=os.path.join(data_dir, "ga")):
+ path = os.path.join(in_dir, AGENT_PKG_GLOB)
+ return glob.glob(path)
+
+
+def get_agents(in_dir=os.path.join(data_dir, "ga")):
+ path = os.path.join(in_dir, AGENT_DIR_GLOB)
+ return [a for a in glob.glob(path) if os.path.isdir(a)]
+
+
+def get_agent_file_path():
+ return get_agent_pkgs()[0]
+
+
+def get_agent_file_name():
+ return os.path.basename(get_agent_file_path())
+
+
+def get_agent_path():
+ return fileutil.trim_ext(get_agent_file_path(), "zip")
+
+
+def get_agent_name():
+ return os.path.basename(get_agent_path())
+
+
+def get_agent_version():
+ return FlexibleVersion(get_agent_name().split("-")[1])
+
+
+def faux_logger():
+ print("STDOUT message")
+ print("STDERR message", file=sys.stderr)
+ return DEFAULT
+
+
+class UpdateTestCase(AgentTestCase):
+ def setUp(self):
+ AgentTestCase.setUp(self)
+ return
+
+ def agent_bin(self, version):
+ return "bin/{0}-{1}.egg".format(AGENT_NAME, version)
+
+ def agent_count(self):
+ return len(self.agent_dirs())
+
+ def agent_dirs(self):
+ return get_agents(in_dir=self.tmp_dir)
+
+ def agent_dir(self, version):
+ return os.path.join(self.tmp_dir, "{0}-{1}".format(AGENT_NAME, version))
+
+ def agent_paths(self):
+ paths = glob.glob(os.path.join(self.tmp_dir, "*"))
+ paths.sort()
+ return paths
+
+ def agent_pkgs(self):
+ return get_agent_pkgs(in_dir=self.tmp_dir)
+
+ def agent_versions(self):
+ v = [FlexibleVersion(AGENT_DIR_PATTERN.match(a).group(1)) for a in self.agent_dirs()]
+ v.sort(reverse=True)
+ return v
+
+ def get_error_file(self, error_data=NO_ERROR):
+ fp = tempfile.NamedTemporaryFile(mode="w")
+ json.dump(error_data if error_data is not None else NO_ERROR, fp)
+ fp.seek(0)
+ return fp
+
+ def create_error(self, error_data=NO_ERROR):
+ with self.get_error_file(error_data) as path:
+ return GuestAgentError(path.name)
+
+ def copy_agents(self, *agents):
+ if len(agents) <= 0:
+ agents = get_agent_pkgs()
+ for agent in agents:
+ fileutil.copy_file(agent, to_dir=self.tmp_dir)
+ return
+
+ def expand_agents(self):
+ for agent in self.agent_pkgs():
+ zipfile.ZipFile(agent).extractall(os.path.join(
+ self.tmp_dir,
+ fileutil.trim_ext(agent, "zip")))
+ return
+
+ def prepare_agents(self, base_version=AGENT_VERSION, count=5, is_available=True):
+ base_v = FlexibleVersion(base_version)
+
+ # Ensure the test data is copied over
+ agent_count = self.agent_count()
+ if agent_count <= 0:
+ self.copy_agents(get_agent_pkgs()[0])
+ self.expand_agents()
+ count -= 1
+
+ # Determine the most recent agent version
+ versions = self.agent_versions()
+ src_v = FlexibleVersion(str(versions[0]))
+
+ # If the most recent agent is newer the minimum requested, use the agent version
+ if base_v < src_v:
+ base_v = src_v
+
+ # Create agent packages and directories
+ return self.replicate_agents(
+ src_v=src_v,
+ count=count-agent_count,
+ is_available=is_available)
+
+ def remove_agents(self):
+ for agent in self.agent_paths():
+ try:
+ if os.path.isfile(agent):
+ os.remove(agent)
+ else:
+ shutil.rmtree(agent)
+ except:
+ pass
+ return
+
+ def replicate_agents(
+ self,
+ count=5,
+ src_v=AGENT_VERSION,
+ is_available=True,
+ increment=1):
+ from_path = self.agent_dir(src_v)
+ dst_v = FlexibleVersion(str(src_v))
+ for i in range(0,count):
+ dst_v += increment
+ to_path = self.agent_dir(dst_v)
+ shutil.copyfile(from_path + ".zip", to_path + ".zip")
+ shutil.copytree(from_path, to_path)
+ shutil.move(
+ os.path.join(to_path, self.agent_bin(src_v)),
+ os.path.join(to_path, self.agent_bin(dst_v)))
+ if not is_available:
+ GuestAgent(to_path).mark_failure(is_fatal=True)
+
+ return dst_v
+
+
+class TestGuestAgentError(UpdateTestCase):
+ def test_creation(self):
+ self.assertRaises(TypeError, GuestAgentError)
+ self.assertRaises(UpdateError, GuestAgentError, None)
+
+ with self.get_error_file(error_data=WITH_ERROR) as path:
+ err = GuestAgentError(path.name)
+ self.assertEqual(path.name, err.path)
+ self.assertNotEqual(None, err)
+
+ self.assertEqual(WITH_ERROR["last_failure"], err.last_failure)
+ self.assertEqual(WITH_ERROR["failure_count"], err.failure_count)
+ self.assertEqual(WITH_ERROR["was_fatal"], err.was_fatal)
+ return
+
+ def test_clear(self):
+ with self.get_error_file(error_data=WITH_ERROR) as path:
+ err = GuestAgentError(path.name)
+ self.assertEqual(path.name, err.path)
+ self.assertNotEqual(None, err)
+
+ err.clear()
+ self.assertEqual(NO_ERROR["last_failure"], err.last_failure)
+ self.assertEqual(NO_ERROR["failure_count"], err.failure_count)
+ self.assertEqual(NO_ERROR["was_fatal"], err.was_fatal)
+ return
+
+ def test_load_preserves_error_state(self):
+ with self.get_error_file(error_data=WITH_ERROR) as path:
+ err = GuestAgentError(path.name)
+ self.assertEqual(path.name, err.path)
+ self.assertNotEqual(None, err)
+
+ with self.get_error_file(error_data=NO_ERROR):
+ err.load()
+ self.assertEqual(WITH_ERROR["last_failure"], err.last_failure)
+ self.assertEqual(WITH_ERROR["failure_count"], err.failure_count)
+ self.assertEqual(WITH_ERROR["was_fatal"], err.was_fatal)
+ return
+
+ def test_save(self):
+ err1 = self.create_error()
+ err1.mark_failure()
+ err1.mark_failure(is_fatal=True)
+
+ err2 = self.create_error(err1.to_json())
+ self.assertEqual(err1.last_failure, err2.last_failure)
+ self.assertEqual(err1.failure_count, err2.failure_count)
+ self.assertEqual(err1.was_fatal, err2.was_fatal)
+
+ def test_mark_failure(self):
+ err = self.create_error()
+ self.assertFalse(err.is_blacklisted)
+
+ for i in range(0, MAX_FAILURE):
+ err.mark_failure()
+
+ # Agent failed >= MAX_FAILURE, it should be blacklisted
+ self.assertTrue(err.is_blacklisted)
+ self.assertEqual(MAX_FAILURE, err.failure_count)
+
+ # Clear old failure does not clear recent failure
+ err.clear_old_failure()
+ self.assertTrue(err.is_blacklisted)
+
+ # Clear does remove old, outdated failures
+ err.last_failure -= RETAIN_INTERVAL * 2
+ err.clear_old_failure()
+ self.assertFalse(err.is_blacklisted)
+ return
+
+ def test_mark_failure_permanent(self):
+ err = self.create_error()
+
+ self.assertFalse(err.is_blacklisted)
+
+ # Fatal errors immediately blacklist
+ err.mark_failure(is_fatal=True)
+ self.assertTrue(err.is_blacklisted)
+ self.assertTrue(err.failure_count < MAX_FAILURE)
+ return
+
+ def test_str(self):
+ err = self.create_error(error_data=NO_ERROR)
+ s = "Last Failure: {0}, Total Failures: {1}, Fatal: {2}".format(
+ NO_ERROR["last_failure"],
+ NO_ERROR["failure_count"],
+ NO_ERROR["was_fatal"])
+ self.assertEqual(s, str(err))
+
+ err = self.create_error(error_data=WITH_ERROR)
+ s = "Last Failure: {0}, Total Failures: {1}, Fatal: {2}".format(
+ WITH_ERROR["last_failure"],
+ WITH_ERROR["failure_count"],
+ WITH_ERROR["was_fatal"])
+ self.assertEqual(s, str(err))
+ return
+
+
+class TestGuestAgent(UpdateTestCase):
+ def setUp(self):
+ UpdateTestCase.setUp(self)
+ self.copy_agents(get_agent_file_path())
+ self.agent_path = os.path.join(self.tmp_dir, get_agent_name())
+ return
+
+ def tearDown(self):
+ self.remove_agents()
+ return
+
+ def test_creation(self):
+ self.assertRaises(UpdateError, GuestAgent, "A very bad file name")
+ n = "{0}-a.bad.version".format(AGENT_NAME)
+ self.assertRaises(UpdateError, GuestAgent, n)
+
+ agent = GuestAgent(path=self.agent_path)
+ self.assertNotEqual(None, agent)
+ self.assertEqual(get_agent_name(), agent.name)
+ self.assertEqual(get_agent_version(), agent.version)
+
+ self.assertEqual(self.agent_path, agent.get_agent_dir())
+
+ path = os.path.join(self.agent_path, AGENT_MANIFEST_FILE)
+ self.assertEqual(path, agent.get_agent_manifest_path())
+
+ self.assertEqual(
+ os.path.join(self.agent_path, AGENT_ERROR_FILE),
+ agent.get_agent_error_file())
+
+ path = ".".join((os.path.join(conf.get_lib_dir(), get_agent_name()), "zip"))
+ self.assertEqual(path, agent.get_agent_pkg_path())
+
+ self.assertTrue(agent.is_downloaded)
+ # Note: Agent will get blacklisted since the package for this test is invalid
+ self.assertTrue(agent.is_blacklisted)
+ self.assertFalse(agent.is_available)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_clear_error(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ agent.mark_failure(is_fatal=True)
+
+ self.assertTrue(agent.error.last_failure > 0.0)
+ self.assertEqual(1, agent.error.failure_count)
+ self.assertTrue(agent.is_blacklisted)
+ self.assertEqual(agent.is_blacklisted, agent.error.is_blacklisted)
+
+ agent.clear_error()
+ self.assertEqual(0.0, agent.error.last_failure)
+ self.assertEqual(0, agent.error.failure_count)
+ self.assertFalse(agent.is_blacklisted)
+ self.assertEqual(agent.is_blacklisted, agent.error.is_blacklisted)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_is_available(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+
+ self.assertFalse(agent.is_available)
+ agent._unpack()
+ self.assertTrue(agent.is_available)
+
+ agent.mark_failure(is_fatal=True)
+ self.assertFalse(agent.is_available)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_is_blacklisted(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(agent.is_blacklisted)
+
+ agent._unpack()
+ self.assertFalse(agent.is_blacklisted)
+ self.assertEqual(agent.is_blacklisted, agent.error.is_blacklisted)
+
+ agent.mark_failure(is_fatal=True)
+ self.assertTrue(agent.is_blacklisted)
+ self.assertEqual(agent.is_blacklisted, agent.error.is_blacklisted)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_is_downloaded(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(agent.is_downloaded)
+ agent._unpack()
+ self.assertTrue(agent.is_downloaded)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_mark_failure(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ agent.mark_failure()
+ self.assertEqual(1, agent.error.failure_count)
+
+ agent.mark_failure(is_fatal=True)
+ self.assertEqual(2, agent.error.failure_count)
+ self.assertTrue(agent.is_blacklisted)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_unpack(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(os.path.isdir(agent.get_agent_dir()))
+ agent._unpack()
+ self.assertTrue(os.path.isdir(agent.get_agent_dir()))
+ self.assertTrue(os.path.isfile(agent.get_agent_manifest_path()))
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_unpack_fail(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(os.path.isdir(agent.get_agent_dir()))
+ os.remove(agent.get_agent_pkg_path())
+ self.assertRaises(UpdateError, agent._unpack)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_load_manifest(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ agent._unpack()
+ agent._load_manifest()
+ self.assertEqual(agent.manifest.get_enable_command(), agent.get_agent_cmd())
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_load_manifest_missing(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(os.path.isdir(agent.get_agent_dir()))
+ agent._unpack()
+ os.remove(agent.get_agent_manifest_path())
+ self.assertRaises(UpdateError, agent._load_manifest)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_load_manifest_is_empty(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(os.path.isdir(agent.get_agent_dir()))
+ agent._unpack()
+ self.assertTrue(os.path.isfile(agent.get_agent_manifest_path()))
+
+ with open(agent.get_agent_manifest_path(), "w") as file:
+ json.dump(EMPTY_MANIFEST, file)
+ self.assertRaises(UpdateError, agent._load_manifest)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ def test_load_manifest_is_malformed(self, mock_ensure):
+ agent = GuestAgent(path=self.agent_path)
+ self.assertFalse(os.path.isdir(agent.get_agent_dir()))
+ agent._unpack()
+ self.assertTrue(os.path.isfile(agent.get_agent_manifest_path()))
+
+ with open(agent.get_agent_manifest_path(), "w") as file:
+ file.write("This is not JSON data")
+ self.assertRaises(UpdateError, agent._load_manifest)
+ return
+
+ def test_load_error(self):
+ agent = GuestAgent(path=self.agent_path)
+ agent.error = None
+
+ agent._load_error()
+ self.assertTrue(agent.error is not None)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ @patch("azurelinuxagent.ga.update.restutil.http_get")
+ def test_download(self, mock_http_get, mock_ensure):
+ self.remove_agents()
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ agent_pkg = load_bin_data(os.path.join("ga", get_agent_file_name()))
+ mock_http_get.return_value= ResponseMock(response=agent_pkg)
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+ agent._download()
+
+ self.assertTrue(os.path.isfile(agent.get_agent_pkg_path()))
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ @patch("azurelinuxagent.ga.update.restutil.http_get")
+ def test_download_fail(self, mock_http_get, mock_ensure):
+ self.remove_agents()
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ mock_http_get.return_value= ResponseMock(status=restutil.httpclient.SERVICE_UNAVAILABLE)
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+
+ self.assertRaises(UpdateError, agent._download)
+ self.assertFalse(os.path.isfile(agent.get_agent_pkg_path()))
+ self.assertFalse(agent.is_downloaded)
+ return
+
+ @patch("azurelinuxagent.ga.update.restutil.http_get")
+ def test_ensure_downloaded(self, mock_http_get):
+ self.remove_agents()
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ agent_pkg = load_bin_data(os.path.join("ga", get_agent_file_name()))
+ mock_http_get.return_value= ResponseMock(response=agent_pkg)
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+
+ self.assertTrue(os.path.isfile(agent.get_agent_manifest_path()))
+ self.assertTrue(agent.is_downloaded)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._download", side_effect=UpdateError)
+ def test_ensure_downloaded_download_fails(self, mock_download):
+ self.remove_agents()
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+
+ self.assertEqual(1, agent.error.failure_count)
+ self.assertFalse(agent.error.was_fatal)
+ self.assertFalse(agent.is_blacklisted)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._download")
+ @patch("azurelinuxagent.ga.update.GuestAgent._unpack", side_effect=UpdateError)
+ def test_ensure_downloaded_unpack_fails(self, mock_download, mock_unpack):
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+
+ self.assertEqual(1, agent.error.failure_count)
+ self.assertTrue(agent.error.was_fatal)
+ self.assertTrue(agent.is_blacklisted)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._download")
+ @patch("azurelinuxagent.ga.update.GuestAgent._unpack")
+ @patch("azurelinuxagent.ga.update.GuestAgent._load_manifest", side_effect=UpdateError)
+ def test_ensure_downloaded_load_manifest_fails(self, mock_download, mock_unpack, mock_manifest):
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+
+ self.assertEqual(1, agent.error.failure_count)
+ self.assertTrue(agent.error.was_fatal)
+ self.assertTrue(agent.is_blacklisted)
+ return
+
+ @patch("azurelinuxagent.ga.update.GuestAgent._download")
+ @patch("azurelinuxagent.ga.update.GuestAgent._unpack")
+ @patch("azurelinuxagent.ga.update.GuestAgent._load_manifest")
+ def test_ensure_download_skips_blacklisted(self, mock_download, mock_unpack, mock_manifest):
+ agent = GuestAgent(path=self.agent_path)
+ agent.clear_error()
+ agent.mark_failure(is_fatal=True)
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri())
+ agent = GuestAgent(pkg=pkg)
+
+ self.assertEqual(1, agent.error.failure_count)
+ self.assertTrue(agent.error.was_fatal)
+ self.assertTrue(agent.is_blacklisted)
+ self.assertEqual(0, mock_download.call_count)
+ self.assertEqual(0, mock_unpack.call_count)
+ self.assertEqual(0, mock_manifest.call_count)
+ return
+
+
+class TestUpdate(UpdateTestCase):
+ def setUp(self):
+ UpdateTestCase.setUp(self)
+ self.update_handler = get_update_handler()
+ return
+
+ def test_creation(self):
+ self.assertTrue(self.update_handler.running)
+
+ self.assertEqual(None, self.update_handler.last_etag)
+ self.assertEqual(None, self.update_handler.last_attempt_time)
+
+ self.assertEqual(0, len(self.update_handler.agents))
+
+ self.assertEqual(None, self.update_handler.child_agent)
+ self.assertEqual(None, self.update_handler.child_launch_time)
+ self.assertEqual(0, self.update_handler.child_launch_attempts)
+ self.assertEqual(None, self.update_handler.child_process)
+
+ self.assertEqual(None, self.update_handler.signal_handler)
+ return
+
+ def _test_ensure_latest_agent(
+ self,
+ base_version=FlexibleVersion(AGENT_VERSION),
+ protocol=None,
+ versions=None):
+
+ latest_version = self.prepare_agents()
+ if versions is None or len(versions) <= 0:
+ versions = [latest_version]
+
+ etag = self.update_handler.last_etag if self.update_handler.last_etag is not None else 42
+ if protocol is None:
+ protocol = ProtocolMock(etag=etag, versions=versions)
+ self.update_handler.protocol_util = protocol
+ conf.get_autoupdate_gafamily = Mock(return_value=protocol.family)
+
+ return self.update_handler._ensure_latest_agent(base_version=base_version)
+
+ def test_ensure_latest_agent_returns_true_on_first_use(self):
+ self.assertEqual(None, self.update_handler.last_etag)
+ self.assertTrue(self._test_ensure_latest_agent())
+ return
+
+ def test_ensure_latest_agent_includes_old_agents(self):
+ self.prepare_agents()
+
+ old_count = FlexibleVersion(AGENT_VERSION).version[-1]
+ old_version = self.agent_versions()[-1]
+
+ self.replicate_agents(src_v=old_version, count=old_count, increment=-1)
+ all_count = len(self.agent_versions())
+
+ self.assertTrue(self._test_ensure_latest_agent(versions=self.agent_versions()))
+ self.assertEqual(all_count, len(self.update_handler.agents))
+ return
+
+ def test_ensure_lastest_agent_purges_old_agents(self):
+ self.prepare_agents()
+ agent_count = self.agent_count()
+ self.assertEqual(5, agent_count)
+
+ agent_versions = self.agent_versions()[:3]
+ self.assertTrue(self._test_ensure_latest_agent(versions=agent_versions))
+ self.assertEqual(len(agent_versions), len(self.update_handler.agents))
+ self.assertEqual(agent_versions, self.agent_versions())
+ return
+
+ def test_ensure_latest_agent_skips_if_too_frequent(self):
+ conf.get_autoupdate_frequency = Mock(return_value=10000)
+ self.update_handler.last_attempt_time = time.time()
+ self.assertFalse(self._test_ensure_latest_agent())
+ return
+
+ def test_ensure_latest_agent_skips_when_etag_matches(self):
+ self.update_handler.last_etag = 42
+ self.assertFalse(self._test_ensure_latest_agent())
+ return
+
+ def test_ensure_latest_agent_skips_if_when_no_new_versions(self):
+ self.prepare_agents()
+ base_version = self.agent_versions()[0] + 1
+ self.assertFalse(self._test_ensure_latest_agent(base_version=base_version))
+ return
+
+ def test_ensure_latest_agent_skips_when_no_versions(self):
+ self.assertFalse(self._test_ensure_latest_agent(protocol=ProtocolMock()))
+ return
+
+ def test_ensure_latest_agent_skips_when_updates_are_disabled(self):
+ conf.get_autoupdate_enabled = Mock(return_value=False)
+ self.assertFalse(self._test_ensure_latest_agent())
+ return
+
+ def test_ensure_latest_agent_sorts(self):
+ self.prepare_agents()
+ self._test_ensure_latest_agent()
+
+ v = FlexibleVersion("100000")
+ for a in self.update_handler.agents:
+ self.assertTrue(v > a.version)
+ v = a.version
+ return
+
+ def _test_evaluate_agent_health(self, child_agent_index=0):
+ self.prepare_agents()
+
+ latest_agent = self.update_handler.get_latest_agent()
+ self.assertTrue(latest_agent.is_available)
+ self.assertFalse(latest_agent.is_blacklisted)
+ self.assertTrue(len(self.update_handler.agents) > 1)
+
+ child_agent = self.update_handler.agents[child_agent_index]
+ self.assertTrue(child_agent.is_available)
+ self.assertFalse(child_agent.is_blacklisted)
+ self.update_handler.child_agent = child_agent
+
+ self.update_handler._evaluate_agent_health(latest_agent)
+ return
+
+ def test_evaluate_agent_health_ignores_installed_agent(self):
+ self.update_handler._evaluate_agent_health(None)
+ return
+
+ def test_evaluate_agent_health_raises_exception_for_restarting_agent(self):
+ self.update_handler.child_launch_time = time.time() - (4 * 60)
+ self.update_handler.child_launch_attempts = CHILD_LAUNCH_RESTART_MAX - 1
+ self.assertRaises(Exception, self._test_evaluate_agent_health)
+ return
+
+ def test_evaluate_agent_health_will_not_raise_exception_for_long_restarts(self):
+ self.update_handler.child_launch_time = time.time() - 24 * 60
+ self.update_handler.child_launch_attempts = CHILD_LAUNCH_RESTART_MAX
+ self._test_evaluate_agent_health()
+ return
+
+ def test_evaluate_agent_health_will_not_raise_exception_too_few_restarts(self):
+ self.update_handler.child_launch_time = time.time()
+ self.update_handler.child_launch_attempts = CHILD_LAUNCH_RESTART_MAX - 2
+ self._test_evaluate_agent_health()
+ return
+
+ def test_evaluate_agent_health_resets_with_new_agent(self):
+ self.update_handler.child_launch_time = time.time() - (4 * 60)
+ self.update_handler.child_launch_attempts = CHILD_LAUNCH_RESTART_MAX - 1
+ self._test_evaluate_agent_health(child_agent_index=1)
+ self.assertEqual(1, self.update_handler.child_launch_attempts)
+ return
+
+ def test_filter_blacklisted_agents(self):
+ self.prepare_agents()
+
+ self.update_handler._set_agents([GuestAgent(path=path) for path in self.agent_dirs()])
+ self.assertEqual(len(self.agent_dirs()), len(self.update_handler.agents))
+
+ kept_agents = self.update_handler.agents[1::2]
+ blacklisted_agents = self.update_handler.agents[::2]
+ for agent in blacklisted_agents:
+ agent.mark_failure(is_fatal=True)
+ self.update_handler._filter_blacklisted_agents()
+ self.assertEqual(kept_agents, self.update_handler.agents)
+ return
+
+ def test_get_latest_agent(self):
+ latest_version = self.prepare_agents()
+
+ latest_agent = self.update_handler.get_latest_agent()
+ self.assertEqual(len(get_agents(self.tmp_dir)), len(self.update_handler.agents))
+ self.assertEqual(latest_version, latest_agent.version)
+ return
+
+ def test_get_latest_agent_no_updates(self):
+ self.assertEqual(None, self.update_handler.get_latest_agent())
+ return
+
+ def test_get_latest_agent_skip_updates(self):
+ conf.get_autoupdate_enabled = Mock(return_value=False)
+ self.assertEqual(None, self.update_handler.get_latest_agent())
+ return
+
+ def test_get_latest_agent_skips_unavailable(self):
+ self.prepare_agents()
+ prior_agent = self.update_handler.get_latest_agent()
+
+ latest_version = self.prepare_agents(count=self.agent_count()+1, is_available=False)
+ latest_path = os.path.join(self.tmp_dir, "{0}-{1}".format(AGENT_NAME, latest_version))
+ self.assertFalse(GuestAgent(latest_path).is_available)
+
+ latest_agent = self.update_handler.get_latest_agent()
+ self.assertTrue(latest_agent.version < latest_version)
+ self.assertEqual(latest_agent.version, prior_agent.version)
+ return
+
+ def test_load_agents(self):
+ self.prepare_agents()
+
+ self.assertTrue(0 <= len(self.update_handler.agents))
+ self.update_handler._load_agents()
+ self.assertEqual(len(get_agents(self.tmp_dir)), len(self.update_handler.agents))
+ return
+
+ def test_load_agents_does_not_reload(self):
+ self.prepare_agents()
+
+ self.update_handler._load_agents()
+ agents = self.update_handler.agents
+
+ self.update_handler._load_agents()
+ self.assertEqual(agents, self.update_handler.agents)
+ return
+
+ def test_load_agents_sorts(self):
+ self.prepare_agents()
+ self.update_handler._load_agents()
+
+ v = FlexibleVersion("100000")
+ for a in self.update_handler.agents:
+ self.assertTrue(v > a.version)
+ v = a.version
+ return
+
+ def test_purge_agents(self):
+ self.prepare_agents()
+ self.update_handler._load_agents()
+
+ # Ensure at least three agents initially exist
+ self.assertTrue(2 < len(self.update_handler.agents))
+
+ # Purge every other agent
+ kept_agents = self.update_handler.agents[1::2]
+ purged_agents = self.update_handler.agents[::2]
+
+ # Reload and assert only the kept agents remain on disk
+ self.update_handler.agents = kept_agents
+ self.update_handler._purge_agents()
+ self.update_handler._load_agents()
+ self.assertEqual(
+ [agent.version for agent in kept_agents],
+ [agent.version for agent in self.update_handler.agents])
+
+ # Ensure both directories and packages are removed
+ for agent in purged_agents:
+ agent_path = os.path.join(self.tmp_dir, "{0}-{1}".format(AGENT_NAME, agent.version))
+ self.assertFalse(os.path.exists(agent_path))
+ self.assertFalse(os.path.exists(agent_path + ".zip"))
+
+ # Ensure kept agent directories and packages remain
+ for agent in kept_agents:
+ agent_path = os.path.join(self.tmp_dir, "{0}-{1}".format(AGENT_NAME, agent.version))
+ self.assertTrue(os.path.exists(agent_path))
+ self.assertTrue(os.path.exists(agent_path + ".zip"))
+ return
+
+ def _test_run_latest(self, mock_child=None, mock_time=None):
+ if mock_child is None:
+ mock_child = ChildMock()
+ if mock_time is None:
+ mock_time = TimeMock()
+
+ with patch('subprocess.Popen', return_value=mock_child) as mock_popen:
+ with patch('time.time', side_effect=mock_time.time):
+ with patch('time.sleep', return_value=mock_time.sleep):
+ self.update_handler.run_latest()
+ self.assertEqual(1, mock_popen.call_count)
+
+ return mock_popen.call_args
+
+ def test_run_latest(self):
+ self.prepare_agents()
+
+ agent = self.update_handler.get_latest_agent()
+ args, kwargs = self._test_run_latest()
+ cmds = shlex.split(agent.get_agent_cmd())
+ if cmds[0].lower() == "python":
+ cmds[0] = get_python_cmd()
+
+ self.assertEqual(args[0], cmds)
+ self.assertEqual(True, 'cwd' in kwargs)
+ self.assertEqual(agent.get_agent_dir(), kwargs['cwd'])
+ return
+
+ def test_run_latest_polls_and_waits_for_success(self):
+ mock_child = ChildMock(return_value=None)
+ mock_time = TimeMock(time_increment=CHILD_HEALTH_INTERVAL/3)
+ self._test_run_latest(mock_child=mock_child, mock_time=mock_time)
+ self.assertEqual(2, mock_child.poll.call_count)
+ self.assertEqual(1, mock_child.wait.call_count)
+ return
+
+ def test_run_latest_polling_stops_at_success(self):
+ mock_child = ChildMock(return_value=0)
+ mock_time = TimeMock(time_increment=CHILD_HEALTH_INTERVAL/3)
+ self._test_run_latest(mock_child=mock_child, mock_time=mock_time)
+ self.assertEqual(1, mock_child.poll.call_count)
+ self.assertEqual(0, mock_child.wait.call_count)
+ return
+
+ def test_run_latest_polling_stops_at_failure(self):
+ mock_child = ChildMock(return_value=42)
+ mock_time = TimeMock()
+ self._test_run_latest(mock_child=mock_child, mock_time=mock_time)
+ self.assertEqual(1, mock_child.poll.call_count)
+ self.assertEqual(0, mock_child.wait.call_count)
+ self.assertEqual(2, mock_time.time_call_count)
+ return
+
+ def test_run_latest_defaults_to_current(self):
+ self.assertEqual(None, self.update_handler.get_latest_agent())
+
+ args, kwargs = self._test_run_latest()
+
+ self.assertEqual(args[0], [get_python_cmd(), "-u", sys.argv[0], "-run-exthandlers"])
+ self.assertEqual(True, 'cwd' in kwargs)
+ self.assertEqual(os.getcwd(), kwargs['cwd'])
+ return
+
+ def test_run_latest_forwards_output(self):
+ try:
+ tempdir = tempfile.mkdtemp()
+ stdout_path = os.path.join(tempdir, "stdout")
+ stderr_path = os.path.join(tempdir, "stderr")
+
+ with open(stdout_path, "w") as stdout:
+ with open(stderr_path, "w") as stderr:
+ saved_stdout, sys.stdout = sys.stdout, stdout
+ saved_stderr, sys.stderr = sys.stderr, stderr
+ try:
+ self._test_run_latest(mock_child=ChildMock(side_effect=faux_logger))
+ finally:
+ sys.stdout = saved_stdout
+ sys.stderr = saved_stderr
+
+ with open(stdout_path, "r") as stdout:
+ self.assertEqual(1, len(stdout.readlines()))
+ with open(stderr_path, "r") as stderr:
+ self.assertEqual(1, len(stderr.readlines()))
+ finally:
+ shutil.rmtree(tempdir, True)
+ return
+
+ def test_run_latest_nonzero_code_marks_failures(self):
+ # logger.add_logger_appender(logger.AppenderType.STDOUT)
+ self.prepare_agents()
+
+ latest_agent = self.update_handler.get_latest_agent()
+ self.assertTrue(latest_agent.is_available)
+ self.assertEqual(0.0, latest_agent.error.last_failure)
+ self.assertEqual(0, latest_agent.error.failure_count)
+
+ self._test_run_latest(mock_child=ChildMock(return_value=1))
+
+ self.assertTrue(latest_agent.is_available)
+ self.assertNotEqual(0.0, latest_agent.error.last_failure)
+ self.assertEqual(1, latest_agent.error.failure_count)
+ return
+
+ def test_run_latest_exception_blacklists(self):
+ # logger.add_logger_appender(logger.AppenderType.STDOUT)
+ self.prepare_agents()
+
+ latest_agent = self.update_handler.get_latest_agent()
+ self.assertTrue(latest_agent.is_available)
+ self.assertEqual(0.0, latest_agent.error.last_failure)
+ self.assertEqual(0, latest_agent.error.failure_count)
+
+ self._test_run_latest(mock_child=ChildMock(side_effect=Exception("Force blacklisting")))
+
+ self.assertFalse(latest_agent.is_available)
+ self.assertTrue(latest_agent.error.is_blacklisted)
+ self.assertNotEqual(0.0, latest_agent.error.last_failure)
+ self.assertEqual(1, latest_agent.error.failure_count)
+ return
+
+ @patch('signal.signal')
+ def test_run_latest_captures_signals(self, mock_signal):
+ self._test_run_latest()
+ self.assertEqual(1, mock_signal.call_count)
+ return
+
+ @patch('signal.signal')
+ def test_run_latest_creates_only_one_signal_handler(self, mock_signal):
+ self.update_handler.signal_handler = "Not None"
+ self._test_run_latest()
+ self.assertEqual(0, mock_signal.call_count)
+ return
+
+ def _test_run(self, invocations=1, calls=[call.run()], enable_updates=False):
+ conf.get_autoupdate_enabled = Mock(return_value=enable_updates)
+
+ # Note:
+ # - Python only allows mutations of objects to which a function has
+ # a reference. Incrementing an integer directly changes the
+ # reference. Incrementing an item of a list changes an item to
+ # which the code has a reference.
+ # See http://stackoverflow.com/questions/26408941/python-nested-functions-and-variable-scope
+ iterations = [0]
+ def iterator(*args, **kwargs):
+ iterations[0] += 1
+ if iterations[0] >= invocations:
+ self.update_handler.running = False
+ return
+
+ calls = calls * invocations
+
+ with patch('azurelinuxagent.ga.exthandlers.get_exthandlers_handler') as mock_handler:
+ with patch('azurelinuxagent.ga.monitor.get_monitor_handler') as mock_monitor:
+ with patch('azurelinuxagent.ga.env.get_env_handler') as mock_env:
+ with patch('time.sleep', side_effect=iterator) as mock_sleep:
+ with patch('sys.exit') as mock_exit:
+
+ self.update_handler.run()
+
+ self.assertEqual(1, mock_handler.call_count)
+ self.assertEqual(mock_handler.return_value.method_calls, calls)
+ self.assertEqual(invocations, mock_sleep.call_count)
+ self.assertEqual(1, mock_monitor.call_count)
+ self.assertEqual(1, mock_env.call_count)
+ self.assertEqual(1, mock_exit.call_count)
+ return
+
+ def test_run(self):
+ self._test_run()
+ return
+
+ def test_run_keeps_running(self):
+ self._test_run(invocations=15)
+ return
+
+ def test_run_stops_if_update_available(self):
+ self.update_handler._ensure_latest_agent = Mock(return_value=True)
+ self._test_run(invocations=0, calls=[], enable_updates=True)
+ return
+
+ def test_set_agents_sets_agents(self):
+ self.prepare_agents()
+
+ self.update_handler._set_agents([GuestAgent(path=path) for path in self.agent_dirs()])
+ self.assertTrue(len(self.update_handler.agents) > 0)
+ self.assertEqual(len(self.agent_dirs()), len(self.update_handler.agents))
+ return
+
+ def test_set_agents_sorts_agents(self):
+ self.prepare_agents()
+
+ self.update_handler._set_agents([GuestAgent(path=path) for path in self.agent_dirs()])
+
+ v = FlexibleVersion("100000")
+ for a in self.update_handler.agents:
+ self.assertTrue(v > a.version)
+ v = a.version
+ return
+
+
+class ChildMock(Mock):
+ def __init__(self, return_value=0, side_effect=None):
+ Mock.__init__(self, return_value=return_value, side_effect=side_effect)
+
+ self.poll = Mock(return_value=return_value, side_effect=side_effect)
+ self.wait = Mock(return_value=return_value, side_effect=side_effect)
+ return
+
+
+class ProtocolMock(object):
+ def __init__(self, family="TestAgent", etag=42, versions=None):
+ self.family = family
+ self.etag = etag
+ self.versions = versions if versions is not None else []
+ self.create_manifests()
+ self.create_packages()
+ return
+
+ def create_manifests(self):
+ self.agent_manifests = VMAgentManifestList()
+ if len(self.versions) <= 0:
+ return
+
+ if self.family is not None:
+ manifest = VMAgentManifest(family=self.family)
+ for i in range(0,10):
+ manifest_uri = "https://nowhere.msft/agent/{0}".format(i)
+ manifest.versionsManifestUris.append(VMAgentManifestUri(uri=manifest_uri))
+ self.agent_manifests.vmAgentManifests.append(manifest)
+ return
+
+ def create_packages(self):
+ self.agent_packages = ExtHandlerPackageList()
+ if len(self.versions) <= 0:
+ return
+
+ for version in self.versions:
+ package = ExtHandlerPackage(str(version))
+ for i in range(0,5):
+ package_uri = "https://nowhere.msft/agent_pkg/{0}".format(i)
+ package.uris.append(ExtHandlerPackageUri(uri=package_uri))
+ self.agent_packages.versions.append(package)
+ return
+
+ def get_protocol(self):
+ return self
+
+ def get_vmagent_manifests(self):
+ return self.agent_manifests, self.etag
+
+ def get_vmagent_pkgs(self, manifest):
+ return self.agent_packages
+
+
+class ResponseMock(Mock):
+ def __init__(self, status=restutil.httpclient.OK, response=None):
+ Mock.__init__(self)
+ self.status = status
+ self.response = response
+ return
+
+ def read(self):
+ return self.response
+
+
+class TimeMock(Mock):
+ def __init__(self, time_increment=1):
+ Mock.__init__(self)
+ self.next_time = time.time()
+ self.time_call_count = 0
+ self.time_increment = time_increment
+
+ self.sleep = Mock(return_value=0)
+ return
+
+ def time(self):
+ self.time_call_count += 1
+ current_time = self.next_time
+ self.next_time += self.time_increment
+ return current_time
+
+
+if __name__ == '__main__':
+ unittest.main()