summaryrefslogtreecommitdiff
path: root/tests/ga/test_update.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ga/test_update.py')
-rw-r--r--tests/ga/test_update.py381
1 files changed, 324 insertions, 57 deletions
diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py
index 89fe95d..a431a9b 100644
--- a/tests/ga/test_update.py
+++ b/tests/ga/test_update.py
@@ -23,6 +23,7 @@ import json
import os
import platform
import random
+import re
import subprocess
import sys
import tempfile
@@ -109,8 +110,13 @@ class UpdateTestCase(AgentTestCase):
AgentTestCase.setUp(self)
return
- def agent_bin(self, version):
- return "bin/{0}-{1}.egg".format(AGENT_NAME, version)
+ def agent_bin(self, version, suffix):
+ return "bin/{0}-{1}{2}.egg".format(AGENT_NAME, version, suffix)
+
+ def rename_agent_bin(self, path, src_v, dst_v):
+ src_bin = glob.glob(os.path.join(path, self.agent_bin(src_v, '*')))[0]
+ dst_bin = os.path.join(path, self.agent_bin(dst_v, ''))
+ shutil.move(src_bin, dst_bin)
def agent_count(self):
return len(self.agent_dirs())
@@ -158,8 +164,29 @@ class UpdateTestCase(AgentTestCase):
fileutil.trim_ext(agent, "zip")))
return
- def prepare_agents(self, base_version=AGENT_VERSION, count=5, is_available=True):
- base_v = FlexibleVersion(base_version)
+ def prepare_agent(self, version):
+ """
+ Create a download for the current agent version, copied from test data
+ """
+ self.copy_agents(get_agent_pkgs()[0])
+ self.expand_agents()
+
+ versions = self.agent_versions()
+ src_v = FlexibleVersion(str(versions[0]))
+
+ from_path = self.agent_dir(src_v)
+ dst_v = FlexibleVersion(str(version))
+ to_path = self.agent_dir(dst_v)
+
+ if from_path != to_path:
+ shutil.move(from_path + ".zip", to_path + ".zip")
+ shutil.move(from_path, to_path)
+ self.rename_agent_bin(to_path, src_v, dst_v)
+ return
+
+ def prepare_agents(self,
+ count=5,
+ is_available=True):
# Ensure the test data is copied over
agent_count = self.agent_count()
@@ -172,10 +199,6 @@ class UpdateTestCase(AgentTestCase):
versions = self.agent_versions()
src_v = FlexibleVersion(str(versions[0]))
- # If the most recent agent is newer the minimum requested, use the agent version
- if base_v < src_v:
- base_v = src_v
-
# Create agent packages and directories
return self.replicate_agents(
src_v=src_v,
@@ -193,25 +216,21 @@ class UpdateTestCase(AgentTestCase):
pass
return
- def replicate_agents(
- self,
- count=5,
- src_v=AGENT_VERSION,
- is_available=True,
- increment=1):
+ def replicate_agents(self,
+ count=5,
+ src_v=AGENT_VERSION,
+ is_available=True,
+ increment=1):
from_path = self.agent_dir(src_v)
dst_v = FlexibleVersion(str(src_v))
- for i in range(0,count):
+ for i in range(0, count):
dst_v += increment
to_path = self.agent_dir(dst_v)
shutil.copyfile(from_path + ".zip", to_path + ".zip")
shutil.copytree(from_path, to_path)
- shutil.move(
- os.path.join(to_path, self.agent_bin(src_v)),
- os.path.join(to_path, self.agent_bin(dst_v)))
+ self.rename_agent_bin(to_path, src_v, dst_v)
if not is_available:
GuestAgent(to_path).mark_failure(is_fatal=True)
-
return dst_v
@@ -437,7 +456,8 @@ class TestGuestAgent(UpdateTestCase):
agent = GuestAgent(path=self.agent_path)
agent._unpack()
agent._load_manifest()
- self.assertEqual(agent.manifest.get_enable_command(), agent.get_agent_cmd())
+ self.assertEqual(agent.manifest.get_enable_command(),
+ agent.get_agent_cmd())
return
@patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
@@ -515,6 +535,34 @@ class TestGuestAgent(UpdateTestCase):
self.assertFalse(agent.is_downloaded)
return
+ @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded")
+ @patch("azurelinuxagent.ga.update.restutil.http_get")
+ def test_download_fallback(self, mock_http_get, mock_ensure):
+ self.remove_agents()
+ self.assertFalse(os.path.isdir(self.agent_path))
+
+ mock_http_get.return_value = ResponseMock(
+ status=restutil.httpclient.SERVICE_UNAVAILABLE)
+
+ ext_uri = 'ext_uri'
+ host_uri = 'host_uri'
+ mock_host = HostPluginProtocol(host_uri,
+ 'container_id',
+ 'role_config')
+
+ pkg = ExtHandlerPackage(version=str(get_agent_version()))
+ pkg.uris.append(ExtHandlerPackageUri(uri=ext_uri))
+ agent = GuestAgent(pkg=pkg)
+ agent.host = mock_host
+
+ with patch.object(HostPluginProtocol,
+ "get_artifact_request",
+ return_value=[host_uri, {}]):
+ self.assertRaises(UpdateError, agent._download)
+ self.assertEqual(mock_http_get.call_count, 2)
+ self.assertEqual(mock_http_get.call_args_list[0][0][0], ext_uri)
+ self.assertEqual(mock_http_get.call_args_list[1][0][0], host_uri)
+
@patch("azurelinuxagent.ga.update.restutil.http_get")
def test_ensure_downloaded(self, mock_http_get):
self.remove_agents()
@@ -598,13 +646,13 @@ class TestGuestAgent(UpdateTestCase):
class TestUpdate(UpdateTestCase):
def setUp(self):
UpdateTestCase.setUp(self)
+ self.event_patch = patch('azurelinuxagent.common.event.add_event')
self.update_handler = get_update_handler()
return
def test_creation(self):
self.assertTrue(self.update_handler.running)
- self.assertEqual(None, self.update_handler.last_etag)
self.assertEqual(None, self.update_handler.last_attempt_time)
self.assertEqual(0, len(self.update_handler.agents))
@@ -617,82 +665,113 @@ class TestUpdate(UpdateTestCase):
self.assertEqual(None, self.update_handler.signal_handler)
return
- def _test_ensure_latest_agent(
+ def test_emit_restart_event_writes_sentinal_file(self):
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ self.update_handler._emit_restart_event()
+ self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path()))
+ return
+
+ def test_emit_restart_event_emits_event_if_not_clean_start(self):
+ try:
+ mock_event = self.event_patch.start()
+ self.update_handler._set_sentinal()
+ self.update_handler._emit_restart_event()
+ self.assertEqual(1, mock_event.call_count)
+ except Exception as e:
+ pass
+ self.event_patch.stop()
+ return
+
+ def _test_upgrade_available(
self,
base_version=FlexibleVersion(AGENT_VERSION),
protocol=None,
- versions=None):
+ versions=None,
+ count=5):
- latest_version = self.prepare_agents()
+ latest_version = self.prepare_agents(count=count)
if versions is None or len(versions) <= 0:
versions = [latest_version]
- etag = self.update_handler.last_etag if self.update_handler.last_etag is not None else 42
if protocol is None:
- protocol = ProtocolMock(etag=etag, versions=versions)
+ protocol = ProtocolMock(versions=versions)
self.update_handler.protocol_util = protocol
conf.get_autoupdate_gafamily = Mock(return_value=protocol.family)
- return self.update_handler._ensure_latest_agent(base_version=base_version)
+ return self.update_handler._upgrade_available(base_version=base_version)
+
+ def test_upgrade_available_returns_true_on_first_use(self):
+ self.assertTrue(self._test_upgrade_available())
+ return
+
+ def test_get_latest_agent_excluded(self):
+ self.prepare_agent(AGENT_VERSION)
+ self.assertFalse(self._test_upgrade_available(
+ versions=self.agent_versions(),
+ count=1))
+ self.assertEqual(None, self.update_handler.get_latest_agent())
+ return
- def test_ensure_latest_agent_returns_true_on_first_use(self):
- self.assertEqual(None, self.update_handler.last_etag)
- self.assertTrue(self._test_ensure_latest_agent())
+ def test_upgrade_available_handles_missing_family(self):
+ extensions_config = ExtensionsConfig(load_data("wire/ext_conf_missing_family.xml"))
+ protocol = ProtocolMock()
+ protocol.family = "Prod"
+ protocol.agent_manifests = extensions_config.vmagent_manifests
+ self.update_handler.protocol_util = protocol
+ with patch('azurelinuxagent.common.logger.warn') as mock_logger:
+ with patch('tests.ga.test_update.ProtocolMock.get_vmagent_pkgs', side_effect=ProtocolError):
+ self.assertFalse(self.update_handler._upgrade_available(base_version=CURRENT_VERSION))
+ self.assertEqual(0, mock_logger.call_count)
return
- def test_ensure_latest_agent_includes_old_agents(self):
+ def test_upgrade_available_includes_old_agents(self):
self.prepare_agents()
- old_count = FlexibleVersion(AGENT_VERSION).version[-1]
old_version = self.agent_versions()[-1]
+ old_count = old_version.version[-1]
self.replicate_agents(src_v=old_version, count=old_count, increment=-1)
all_count = len(self.agent_versions())
- self.assertTrue(self._test_ensure_latest_agent(versions=self.agent_versions()))
+ self.assertTrue(self._test_upgrade_available(versions=self.agent_versions()))
self.assertEqual(all_count, len(self.update_handler.agents))
return
- def test_ensure_lastest_agent_purges_old_agents(self):
+ def test_upgrade_available_purges_old_agents(self):
self.prepare_agents()
agent_count = self.agent_count()
self.assertEqual(5, agent_count)
agent_versions = self.agent_versions()[:3]
- self.assertTrue(self._test_ensure_latest_agent(versions=agent_versions))
+ self.assertTrue(self._test_upgrade_available(versions=agent_versions))
self.assertEqual(len(agent_versions), len(self.update_handler.agents))
self.assertEqual(agent_versions, self.agent_versions())
return
- def test_ensure_latest_agent_skips_if_too_frequent(self):
+ def test_upgrade_available_skips_if_too_frequent(self):
conf.get_autoupdate_frequency = Mock(return_value=10000)
self.update_handler.last_attempt_time = time.time()
- self.assertFalse(self._test_ensure_latest_agent())
+ self.assertFalse(self._test_upgrade_available())
return
- def test_ensure_latest_agent_skips_when_etag_matches(self):
- self.update_handler.last_etag = 42
- self.assertFalse(self._test_ensure_latest_agent())
- return
-
- def test_ensure_latest_agent_skips_if_when_no_new_versions(self):
+ def test_upgrade_available_skips_if_when_no_new_versions(self):
self.prepare_agents()
base_version = self.agent_versions()[0] + 1
- self.assertFalse(self._test_ensure_latest_agent(base_version=base_version))
+ self.assertFalse(self._test_upgrade_available(base_version=base_version))
return
- def test_ensure_latest_agent_skips_when_no_versions(self):
- self.assertFalse(self._test_ensure_latest_agent(protocol=ProtocolMock()))
+ def test_upgrade_available_skips_when_no_versions(self):
+ self.assertFalse(self._test_upgrade_available(protocol=ProtocolMock()))
return
- def test_ensure_latest_agent_skips_when_updates_are_disabled(self):
+ def test_upgrade_available_skips_when_updates_are_disabled(self):
conf.get_autoupdate_enabled = Mock(return_value=False)
- self.assertFalse(self._test_ensure_latest_agent())
+ self.assertFalse(self._test_upgrade_available())
return
- def test_ensure_latest_agent_sorts(self):
+ def test_upgrade_available_sorts(self):
self.prepare_agents()
- self._test_ensure_latest_agent()
+ self._test_upgrade_available()
v = FlexibleVersion("100000")
for a in self.update_handler.agents:
@@ -700,6 +779,58 @@ class TestUpdate(UpdateTestCase):
v = a.version
return
+ def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL):
+ with patch.object(self.update_handler, 'osutil') as mock_util:
+ # Note:
+ # - Python only allows mutations of objects to which a function has
+ # a reference. Incrementing an integer directly changes the
+ # reference. Incrementing an item of a list changes an item to
+ # which the code has a reference.
+ # See http://stackoverflow.com/questions/26408941/python-nested-functions-and-variable-scope
+ iterations = [0]
+ def iterator(*args, **kwargs):
+ iterations[0] += 1
+ return iterations[0] < invocations
+
+ mock_util.check_pid_alive = Mock(side_effect=iterator)
+
+ with patch('os.getpid', return_value=42):
+ with patch('time.sleep', return_value=None) as mock_sleep:
+ self.update_handler._ensure_no_orphans(orphan_wait_interval=interval)
+ return mock_util.check_pid_alive.call_count, mock_sleep.call_count
+ return
+
+ def test_ensure_no_orphans(self):
+ fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41))
+ calls, sleeps = self._test_ensure_no_orphans(invocations=3)
+ self.assertEqual(3, calls)
+ self.assertEqual(2, sleeps)
+ return
+
+ def test_ensure_no_orphans_skips_if_no_orphans(self):
+ calls, sleeps = self._test_ensure_no_orphans(invocations=3)
+ self.assertEqual(0, calls)
+ self.assertEqual(0, sleeps)
+ return
+
+ def test_ensure_no_orphans_ignores_exceptions(self):
+ with patch('azurelinuxagent.common.utils.fileutil.read_file', side_effect=Exception):
+ calls, sleeps = self._test_ensure_no_orphans(invocations=3)
+ self.assertEqual(0, calls)
+ self.assertEqual(0, sleeps)
+ return
+
+ def test_ensure_no_orphans_kills_after_interval(self):
+ fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41))
+ with patch('os.kill') as mock_kill:
+ calls, sleeps = self._test_ensure_no_orphans(
+ invocations=4,
+ interval=3*GOAL_STATE_INTERVAL)
+ self.assertEqual(3, calls)
+ self.assertEqual(2, sleeps)
+ self.assertEqual(1, mock_kill.call_count)
+ return
+
def _test_evaluate_agent_health(self, child_agent_index=0):
self.prepare_agents()
@@ -789,6 +920,59 @@ class TestUpdate(UpdateTestCase):
self.assertEqual(latest_agent.version, prior_agent.version)
return
+ def test_get_pid_files(self):
+ previous_pid_file, pid_file, = self.update_handler._get_pid_files()
+ self.assertEqual(None, previous_pid_file)
+ self.assertEqual("0_waagent.pid", os.path.basename(pid_file))
+ return
+
+ def test_get_pid_files_returns_previous(self):
+ for n in range(1250):
+ fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1))
+ previous_pid_file, pid_file, = self.update_handler._get_pid_files()
+ self.assertEqual("1249_waagent.pid", os.path.basename(previous_pid_file))
+ self.assertEqual("1250_waagent.pid", os.path.basename(pid_file))
+ return
+
+ def test_is_clean_start_returns_true_when_no_sentinal(self):
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ self.assertTrue(self.update_handler._is_clean_start)
+ return
+
+ def test_is_clean_start_returns_true_sentinal_agent_is_not_current(self):
+ self.update_handler._set_sentinal(agent="Not the Current Agent")
+ self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path()))
+ self.assertTrue(self.update_handler._is_clean_start)
+ return
+
+ def test_is_clean_start_returns_false_for_current_agent(self):
+ self.update_handler._set_sentinal(agent=CURRENT_AGENT)
+ self.assertFalse(self.update_handler._is_clean_start)
+ return
+
+ def test_is_clean_start_returns_false_for_exceptions(self):
+ self.update_handler._set_sentinal()
+ with patch("azurelinuxagent.common.utils.fileutil.read_file", side_effect=Exception):
+ self.assertFalse(self.update_handler._is_clean_start)
+ return
+
+ def test_is_orphaned_returns_false_if_parent_exists(self):
+ fileutil.write_file(conf.get_agent_pid_file_path(), ustr(42))
+ with patch('os.getppid', return_value=42):
+ self.assertFalse(self.update_handler._is_orphaned)
+ return
+
+ def test_is_orphaned_returns_true_if_parent_is_init(self):
+ with patch('os.getppid', return_value=1):
+ self.assertTrue(self.update_handler._is_orphaned)
+ return
+
+ def test_is_orphaned_returns_true_if_parent_does_not_exist(self):
+ fileutil.write_file(conf.get_agent_pid_file_path(), ustr(24))
+ with patch('os.getppid', return_value=42):
+ self.assertTrue(self.update_handler._is_orphaned)
+ return
+
def test_load_agents(self):
self.prepare_agents()
@@ -868,13 +1052,14 @@ class TestUpdate(UpdateTestCase):
agent = self.update_handler.get_latest_agent()
args, kwargs = self._test_run_latest()
- cmds = shlex.split(agent.get_agent_cmd())
+ cmds = textutil.safe_shlex_split(agent.get_agent_cmd())
if cmds[0].lower() == "python":
cmds[0] = get_python_cmd()
self.assertEqual(args[0], cmds)
self.assertEqual(True, 'cwd' in kwargs)
self.assertEqual(agent.get_agent_dir(), kwargs['cwd'])
+ self.assertEqual(False, '\x00' in cmds[0])
return
def test_run_latest_polls_and_waits_for_success(self):
@@ -993,7 +1178,7 @@ class TestUpdate(UpdateTestCase):
# reference. Incrementing an item of a list changes an item to
# which the code has a reference.
# See http://stackoverflow.com/questions/26408941/python-nested-functions-and-variable-scope
- iterations = [0]
+ iterations = [0]
def iterator(*args, **kwargs):
iterations[0] += 1
if iterations[0] >= invocations:
@@ -1001,14 +1186,19 @@ class TestUpdate(UpdateTestCase):
return
calls = calls * invocations
-
+
+ fileutil.write_file(conf.get_agent_pid_file_path(), ustr(42))
+
with patch('azurelinuxagent.ga.exthandlers.get_exthandlers_handler') as mock_handler:
with patch('azurelinuxagent.ga.monitor.get_monitor_handler') as mock_monitor:
with patch('azurelinuxagent.ga.env.get_env_handler') as mock_env:
with patch('time.sleep', side_effect=iterator) as mock_sleep:
with patch('sys.exit') as mock_exit:
-
- self.update_handler.run()
+ if isinstance(os.getppid, MagicMock):
+ self.update_handler.run()
+ else:
+ with patch('os.getppid', return_value=42):
+ self.update_handler.run()
self.assertEqual(1, mock_handler.call_count)
self.assertEqual(mock_handler.return_value.method_calls, calls)
@@ -1016,6 +1206,7 @@ class TestUpdate(UpdateTestCase):
self.assertEqual(1, mock_monitor.call_count)
self.assertEqual(1, mock_env.call_count)
self.assertEqual(1, mock_exit.call_count)
+
return
def test_run(self):
@@ -1027,8 +1218,30 @@ class TestUpdate(UpdateTestCase):
return
def test_run_stops_if_update_available(self):
- self.update_handler._ensure_latest_agent = Mock(return_value=True)
+ self.update_handler._upgrade_available = Mock(return_value=True)
+ self._test_run(invocations=0, calls=[], enable_updates=True)
+ return
+
+ def test_run_stops_if_orphaned(self):
+ with patch('os.getppid', return_value=1):
+ self._test_run(invocations=0, calls=[], enable_updates=True)
+ return
+
+ def test_run_clears_sentinal_on_successful_exit(self):
+ self._test_run()
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ return
+
+ def test_run_leaves_sentinal_on_unsuccessful_exit(self):
+ self.update_handler._upgrade_available = Mock(side_effect=Exception)
self._test_run(invocations=0, calls=[], enable_updates=True)
+ self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path()))
+ return
+
+ def test_run_emits_restart_event(self):
+ self.update_handler._emit_restart_event = Mock()
+ self._test_run()
+ self.assertEqual(1, self.update_handler._emit_restart_event.call_count)
return
def test_set_agents_sets_agents(self):
@@ -1050,6 +1263,59 @@ class TestUpdate(UpdateTestCase):
v = a.version
return
+ def test_set_sentinal(self):
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ self.update_handler._set_sentinal()
+ self.assertTrue(os.path.isfile(self.update_handler._sentinal_file_path()))
+ return
+
+ def test_set_sentinal_writes_current_agent(self):
+ self.update_handler._set_sentinal()
+ self.assertTrue(
+ fileutil.read_file(self.update_handler._sentinal_file_path()),
+ CURRENT_AGENT)
+ return
+
+ def test_shutdown(self):
+ self.update_handler._set_sentinal()
+ self.update_handler._shutdown()
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ return
+
+ def test_shutdown_ignores_missing_sentinal_file(self):
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ self.update_handler._shutdown()
+ self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path()))
+ return
+
+ def test_shutdown_ignores_exceptions(self):
+ self.update_handler._set_sentinal()
+
+ try:
+ with patch("os.remove", side_effect=Exception):
+ self.update_handler._shutdown()
+ except Exception as e:
+ self.assertTrue(False, "Unexpected exception")
+ return
+
+ def test_write_pid_file(self):
+ for n in range(1112):
+ fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1))
+ with patch('os.getpid', return_value=1112):
+ previous_pid_file, pid_file = self.update_handler._write_pid_file()
+ self.assertEqual("1111_waagent.pid", os.path.basename(previous_pid_file))
+ self.assertEqual("1112_waagent.pid", os.path.basename(pid_file))
+ self.assertEqual(fileutil.read_file(pid_file), ustr(1112))
+ return
+
+ def test_write_pid_file_ignores_exceptions(self):
+ with patch('azurelinuxagent.common.utils.fileutil.write_file', side_effect=Exception):
+ with patch('os.getpid', return_value=42):
+ previous_pid_file, pid_file = self.update_handler._write_pid_file()
+ self.assertEqual(None, previous_pid_file)
+ self.assertEqual(None, pid_file)
+ return
+
class ChildMock(Mock):
def __init__(self, return_value=0, side_effect=None):
@@ -1061,8 +1327,9 @@ class ChildMock(Mock):
class ProtocolMock(object):
- def __init__(self, family="TestAgent", etag=42, versions=None):
+ def __init__(self, family="TestAgent", etag=42, versions=None, client=None):
self.family = family
+ self.client = client
self.etag = etag
self.versions = versions if versions is not None else []
self.create_manifests()