summaryrefslogtreecommitdiff
path: root/tests/test_v1.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_v1.py')
-rw-r--r--tests/test_v1.py64
1 files changed, 47 insertions, 17 deletions
diff --git a/tests/test_v1.py b/tests/test_v1.py
index 02225af..5a1b36b 100644
--- a/tests/test_v1.py
+++ b/tests/test_v1.py
@@ -35,7 +35,7 @@ from tests.test_sharedconfig import shared_config_sample
from tests.test_certificates import certs_sample, transport_cert
from tests.test_extensionsconfig import ext_conf_sample, manifest_sample
-def mock_fetch_uri(url, headers=None, chk_proxy=False):
+def mock_fetch_config(self, url, headers=None, chk_proxy=False):
content = None
if "versions" in url:
content = VersionInfoSample
@@ -55,10 +55,10 @@ def mock_fetch_uri(url, headers=None, chk_proxy=False):
raise Exception("Bad url {0}".format(url))
return content
-def mock_fetch_manifest(uris):
+def mock_fetch_manifest(self, uris):
return manifest_sample
-def mock_fetch_cache(file_path):
+def mock_fetch_cache(self, file_path):
content = None
if "Incarnation" in file_path:
content = 1
@@ -90,13 +90,23 @@ class MockResp(object):
def read(self):
return self.data
+def mock_403():
+ return MockResp(status = v1.httpclient.FORBIDDEN)
+
+def mock_410():
+ return MockResp(status = v1.httpclient.GONE)
+
+def mock_503():
+ return MockResp(status = v1.httpclient.SERVICE_UNAVAILABLE)
+
class TestWireClint(unittest.TestCase):
@mock(v1.restutil, 'http_get', MockFunc(retval=MockResp(data=data_with_bom)))
def test_fetch_uri_with_bom(self):
- v1._fetch_uri("http://foo.bar", None)
+ client = v1.WireClient("http://foo.bar/")
+ client.fetch_config("http://foo.bar", None)
- @mock(v1, '_fetch_cache', mock_fetch_cache)
+ @mock(v1.WireClient, 'fetch_cache', mock_fetch_cache)
def test_get(self):
os.chdir('/tmp')
client = v1.WireClient("foobar")
@@ -110,14 +120,15 @@ class TestWireClint(unittest.TestCase):
self.assertNotEquals(None, extensionsConfig)
- @mock(v1, '_fetch_cache', mock_fetch_cache)
+ @mock(v1.WireClient, 'fetch_cache', mock_fetch_cache)
def test_get_head_for_cert(self):
client = v1.WireClient("foobar")
header = client.get_header_for_cert()
self.assertNotEquals(None, header)
@mock(v1.WireClient, 'get_header_for_cert', MockFunc())
- @mock(v1, '_fetch_uri', mock_fetch_uri)
+ @mock(v1.WireClient, 'fetch_config', mock_fetch_config)
+ @mock(v1.WireClient, 'fetch_manifest', mock_fetch_manifest)
@mock(v1.fileutil, 'write_file', MockFunc())
def test_update_goal_state(self):
client = v1.WireClient("foobar")
@@ -131,35 +142,54 @@ class TestWireClint(unittest.TestCase):
ext_conf = client.get_ext_conf()
self.assertNotEquals(None, ext_conf)
+ @mock(v1.time, "sleep", MockFunc())
+ def test_call_wireserver(self):
+ client = v1.WireClient("foobar")
+ self.assertRaises(v1.ProtocolError, client.call_wireserver, mock_403)
+ self.assertRaises(v1.WireProtocolResourceGone, client.call_wireserver,
+ mock_410)
+
+ @mock(v1.time, "sleep", MockFunc())
+ def test_call_storage_service(self):
+ client = v1.WireClient("foobar")
+ self.assertRaises(v1.ProtocolError, client.call_storage_service,
+ mock_503)
+
+
class TestStatusBlob(unittest.TestCase):
def testToJson(self):
vm_status = v1.VMStatus()
- status_blob = v1.StatusBlob(vm_status)
+ status_blob = v1.StatusBlob(v1.WireClient("http://foo.bar/"))
+ status_blob.set_vm_status(vm_status)
self.assertNotEquals(None, status_blob.to_json())
@mock(v1.restutil, 'http_put', MockFunc(retval=MockResp(httpclient.CREATED)))
@mock(v1.restutil, 'http_head', MockFunc(retval=MockResp(httpclient.OK)))
def test_put_page_blob(self):
vm_status = v1.VMStatus()
- status_blob = v1.StatusBlob(vm_status)
+ status_blob = v1.StatusBlob(v1.WireClient("http://foo.bar/"))
+ status_blob.set_vm_status(vm_status)
data = 'a' * 100
status_blob.put_page_blob("http://foo.bar", data)
class TestConvert(unittest.TestCase):
def test_status(self):
vm_status = v1.VMStatus()
- handler_status = v1.ExtensionHandlerStatus()
- substatus = v1.ExtensionSubStatus()
- ext_status = v1.ExtensionStatus()
+ handler_status = v1.ExtHandlerStatus(name="foo")
- vm_status.extensionHandlers.append(handler_status)
- v1.vm_status_to_v1(vm_status)
+ ext_statuses = {}
- handler_status.extensionStatusList.append(ext_status)
- v1.vm_status_to_v1(vm_status)
+ ext_name="bar"
+ ext_status = v1.ExtensionStatus()
+ handler_status.extensions.append(ext_name)
+ ext_statuses[ext_name] = ext_status
+ substatus = v1.ExtensionSubStatus()
ext_status.substatusList.append(substatus)
- v1.vm_status_to_v1(vm_status)
+
+ vm_status.vmAgent.extensionHandlers.append(handler_status)
+ v1_status = v1.vm_status_to_v1(vm_status, ext_statuses)
+ print(v1_status)
def test_param(self):
param = v1.TelemetryEventParam()