diff options
Diffstat (limited to 'azurelinuxagent/protocol/v1.py')
-rw-r--r-- | azurelinuxagent/protocol/v1.py | 667 |
1 files changed, 412 insertions, 255 deletions
diff --git a/azurelinuxagent/protocol/v1.py b/azurelinuxagent/protocol/v1.py index 54a80b6..92fcc06 100644 --- a/azurelinuxagent/protocol/v1.py +++ b/azurelinuxagent/protocol/v1.py @@ -1,4 +1,4 @@ -# Windows Azure Linux Agent +# Microsoft Azure Linux Agent # # Copyright 2014 Microsoft Corporation # @@ -24,7 +24,7 @@ import traceback import xml.sax.saxutils as saxutils import xml.etree.ElementTree as ET import azurelinuxagent.logger as logger -from azurelinuxagent.future import text, httpclient +from azurelinuxagent.future import text, httpclient, bytebuffer import azurelinuxagent.utils.restutil as restutil from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext, \ getattrib, gettext, remove_bom @@ -54,6 +54,9 @@ TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" PROTOCOL_VERSION = "2012-11-30" +SHORT_WAITING_INTERVAL = 1 # 1 second +LONG_WAITING_INTERVAL = 15 # 15 seconds + class WireProtocolResourceGone(ProtocolError): pass @@ -77,113 +80,89 @@ class WireProtocol(Protocol): certificates = self.client.get_certs() return certificates.cert_list - def get_extensions(self): + def get_ext_handlers(self): #Update goal state to get latest extensions config self.client.update_goal_state() ext_conf = self.client.get_ext_conf() - return ext_conf.ext_list + return ext_conf.ext_handlers - def get_extension_pkgs(self, extension): + def get_ext_handler_pkgs(self, ext_handler): goal_state = self.client.get_goal_state() - man = self.client.get_ext_manifest(extension, goal_state) + man = self.client.get_ext_manifest(ext_handler, goal_state) return man.pkg_list - def report_provision_status(self, provisionStatus): - validata_param("provisionStatus", provisionStatus, ProvisionStatus) - - if provisionStatus.status is not None: - self.client.report_health(provisionStatus.status, - provisionStatus.subStatus, - provisionStatus.description) - if provisionStatus.properties.certificateThumbprint is not None: - thumbprint = provisionStatus.properties.certificateThumbprint + def report_provision_status(self, provision_status): + validata_param("provision_status", provision_status, ProvisionStatus) + + if provision_status.status is not None: + self.client.report_health(provision_status.status, + provision_status.subStatus, + provision_status.description) + if provision_status.properties.certificateThumbprint is not None: + thumbprint = provision_status.properties.certificateThumbprint self.client.report_role_prop(thumbprint) - def report_status(self, vmStatus): - validata_param("vmStatus", vmStatus, VMStatus) - self.client.upload_status_blob(vmStatus) + def report_vm_status(self, vm_status): + validata_param("vm_status", vm_status, VMStatus) + self.client.status_blob.set_vm_status(vm_status) + self.client.upload_status_blob() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validata_param("ext_status", ext_status, ExtensionStatus) + self.client.status_blob.set_ext_status(ext_handler_name, ext_status) def report_event(self, events): validata_param("events", events, TelemetryEventList) self.client.report_event(events) -def _fetch_cache(local_file): - if not os.path.isfile(local_file): - raise ProtocolError("{0} is missing.".format(local_file)) - return fileutil.read_file(local_file) - -def _fetch_uri(uri, headers, chk_proxy=False): - try: - resp = restutil.http_get(uri, headers, chk_proxy=chk_proxy) - except restutil.HttpError as e: - raise ProtocolError(text(e)) - - if(resp.status == httpclient.GONE): - raise WireProtocolResourceGone(uri) - if(resp.status != httpclient.OK): - raise ProtocolError("{0} - {1}".format(resp.status, uri)) - data = resp.read() - if data is None: - return None - data = remove_bom(data) - xml_text = text(data, encoding='utf-8') - return xml_text - -def _fetch_manifest(version_uris): - for version_uri in version_uris: - try: - xml_text = _fetch_uri(version_uri.uri, None, chk_proxy=True) - return xml_text - except IOError as e: - logger.warn("Failed to fetch ExtensionManifest: {0}, {1}", e, - version_uri.uri) - raise ProtocolError("Failed to fetch ExtensionManifest from all sources") - def _build_role_properties(container_id, role_instance_id, thumbprint): - xml = ("<?xml version=\"1.0\" encoding=\"utf-8\"?>" - "<RoleProperties>" - "<Container>" - "<ContainerId>{0}</ContainerId>" - "<RoleInstances>" - "<RoleInstance>" - "<Id>{1}</Id>" - "<Properties>" - "<Property name=\"CertificateThumbprint\" value=\"{2}\" />" - "</Properties>" - "</RoleInstance>" - "</RoleInstances>" - "</Container>" - "</RoleProperties>" - "").format(container_id, role_instance_id, thumbprint) + xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>" + u"<RoleProperties>" + u"<Container>" + u"<ContainerId>{0}</ContainerId>" + u"<RoleInstances>" + u"<RoleInstance>" + u"<Id>{1}</Id>" + u"<Properties>" + u"<Property name=\"CertificateThumbprint\" value=\"{2}\" />" + u"</Properties>" + u"</RoleInstance>" + u"</RoleInstances>" + u"</Container>" + u"</RoleProperties>" + u"").format(container_id, role_instance_id, thumbprint) return xml def _build_health_report(incarnation, container_id, role_instance_id, status, substatus, description): - detail = '' + #Escape '&', '<' and '>' + description = saxutils.escape(text(description)) + detail = u'' if substatus is not None: - detail = ("<Details>" - "<SubStatus>{0}</SubStatus>" - "<Description>{1}</Description>" - "</Details>").format(substatus, description) - xml = ("<?xml version=\"1.0\" encoding=\"utf-8\"?>" - "<Health " - "xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"" - " xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\">" - "<GoalStateIncarnation>{0}</GoalStateIncarnation>" - "<Container>" - "<ContainerId>{1}</ContainerId>" - "<RoleInstanceList>" - "<Role>" - "<InstanceId>{2}</InstanceId>" - "<Health>" - "<State>{3}</State>" - "{4}" - "</Health>" - "</Role>" - "</RoleInstanceList>" - "</Container>" - "</Health>" - "").format(incarnation, + substatus = saxutils.escape(text(substatus)) + detail = (u"<Details>" + u"<SubStatus>{0}</SubStatus>" + u"<Description>{1}</Description>" + u"</Details>").format(substatus, description) + xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>" + u"<Health " + u"xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"" + u" xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\">" + u"<GoalStateIncarnation>{0}</GoalStateIncarnation>" + u"<Container>" + u"<ContainerId>{1}</ContainerId>" + u"<RoleInstanceList>" + u"<Role>" + u"<InstanceId>{2}</InstanceId>" + u"<Health>" + u"<State>{3}</State>" + u"{4}" + u"</Health>" + u"</Role>" + u"</RoleInstanceList>" + u"</Container>" + u"</Health>" + u"").format(incarnation, container_id, role_instance_id, status, @@ -193,19 +172,19 @@ def _build_health_report(incarnation, container_id, role_instance_id, """ Convert VMStatus object to status blob format """ -def guest_agent_status_to_v1(ga_status): +def ga_status_to_v1(ga_status): formatted_msg = { 'lang' : 'en-US', 'message' : ga_status.message } v1_ga_status = { - 'version' : ga_status.agentVersion, + 'version' : ga_status.version, 'status' : ga_status.status, 'formattedMessage' : formatted_msg } return v1_ga_status -def extension_substatus_to_v1(sub_status_list): +def ext_substatus_to_v1(sub_status_list): status_list = [] for substatus in sub_status_list: status = { @@ -220,14 +199,14 @@ def extension_substatus_to_v1(sub_status_list): status_list.append(status) return status_list -def extension_handler_status_to_v1(handler_status, timestamp): - if handler_status is None or len(handler_status.extensionStatusList) == 0: - return - ext_status = handler_status.extensionStatusList[0] - sub_status = extension_substatus_to_v1(ext_status.substatusList) - ext_in_status = { +def ext_status_to_v1(ext_name, ext_status): + if ext_status is None: + return None + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + v1_sub_status = ext_substatus_to_v1(ext_status.substatusList) + v1_ext_status = { "status":{ - "name": ext_status.name, + "name": ext_name, "configurationAppliedTime": ext_status.configurationAppliedTime, "operation": ext_status.operation, "status": ext_status.status, @@ -237,33 +216,47 @@ def extension_handler_status_to_v1(handler_status, timestamp): "message": ext_status.message } }, + "version": 1.0, "timestampUTC": timestamp } - - if len(sub_status) != 0: - ext_in_status['substatus'] = sub_status - + if len(v1_sub_status) != 0: + v1_ext_status['substatus'] = v1_sub_status + return v1_ext_status + +def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): v1_handler_status = { - 'handlerVersion' : handler_status.handlerVersion, - 'handlerName' : handler_status.handlerName, + 'handlerVersion' : handler_status.version, + 'handlerName' : handler_status.name, 'status' : handler_status.status, - 'runtimeSettingsStatus' : { - 'settingsStatus' : ext_in_status, - 'sequenceNumber' : ext_status.sequenceNumber - } } - return v1_handler_status + if handler_status.message is not None: + v1_handler_status["formattedMessage"] = { + "lang":"en-US", + "message": handler_status.message + } + if len(handler_status.extensions) > 0: + #Currently, no more than one extension per handler + ext_name = handler_status.extensions[0] + ext_status = ext_statuses.get(ext_name) + v1_ext_status = ext_status_to_v1(ext_name, ext_status) + if ext_status is not None and v1_ext_status is not None: + v1_handler_status["runtimeSettingsStatus"] = { + 'settingsStatus' : v1_ext_status, + 'sequenceNumber' : ext_status.sequenceNumber + } + return v1_handler_status -def vm_status_to_v1(vm_status): +def vm_status_to_v1(vm_status, ext_statuses): timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - v1_ga_status = guest_agent_status_to_v1(vm_status.vmAgent) + v1_ga_status = ga_status_to_v1(vm_status.vmAgent) v1_handler_status_list = [] - for handler_status in vm_status.extensionHandlers: - v1_handler_status = extension_handler_status_to_v1(handler_status, - timestamp) - v1_handler_status_list.append(v1_handler_status) + for handler_status in vm_status.vmAgent.extensionHandlers: + v1_handler_status = ext_handler_status_to_v1(handler_status, + ext_statuses, timestamp) + if v1_handler_status is not None: + v1_handler_status_list.append(v1_handler_status) v1_agg_status = { 'guestAgentStatus': v1_ga_status, @@ -278,35 +271,53 @@ def vm_status_to_v1(vm_status): class StatusBlob(object): - def __init__(self, vm_status): - self.vm_status = vm_status + def __init__(self, client): + self.vm_status = None + self.ext_statuses = {} + self.client = client + def set_vm_status(self, vm_status): + validata_param("vmAgent", vm_status, VMStatus) + self.vm_status = vm_status + + def set_ext_status(self, ext_handler_name, ext_status): + validata_param("extensionStatus", ext_status, ExtensionStatus) + self.ext_statuses[ext_handler_name]= ext_status + def to_json(self): - report = vm_status_to_v1(self.vm_status) + report = vm_status_to_v1(self.vm_status, self.ext_statuses) return json.dumps(report) __storage_version__ = "2014-02-14" def upload(self, url): - logger.info("Upload status blob") + #TODO upload extension only if content has changed + logger.verb("Upload status blob") blob_type = self.get_blob_type(url) data = self.to_json() - if blob_type == "BlockBlob": - self.put_block_blob(url, data) - elif blob_type == "PageBlob": - self.put_page_blob(url, data) - else: - raise ProtocolError("Unknown blob type: {0}".format(blob_type)) + try: + if blob_type == "BlockBlob": + self.put_block_blob(url, data) + elif blob_type == "PageBlob": + self.put_page_blob(url, data) + else: + raise ProtocolError("Unknown blob type: {0}".format(blob_type)) + except restutil.HttpError as e: + raise ProtocolError("Failed to upload status blob: {0}".format(e)) def get_blob_type(self, url): #Check blob type logger.verb("Check blob type.") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - resp = restutil.http_head(url, { - "x-ms-date" : timestamp, - 'x-ms-version' : self.__class__.__storage_version__ - }) + try: + resp = self.client.call_storage_service(restutil.http_head, url, { + "x-ms-date" : timestamp, + 'x-ms-version' : self.__class__.__storage_version__ + }) + except restutil.HttpError as e: + raise ProtocolError((u"Failed to get status blob type: {0}" + u"").format(e)) if resp is None or resp.status != httpclient.OK: raise ProtocolError(("Failed to get status blob type: {0}" "").format(resp.status)) @@ -318,33 +329,47 @@ class StatusBlob(object): def put_block_blob(self, url, data): logger.verb("Upload block blob") timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - resp = restutil.http_put(url, data, { - "x-ms-date" : timestamp, - "x-ms-blob-type" : "BlockBlob", - "Content-Length": text(len(data)), - "x-ms-version" : self.__class__.__storage_version__ - }) - if resp is None or resp.status != httpclient.CREATED: + try: + resp = self.client.call_storage_service(restutil.http_put, url, + data, { + "x-ms-date" : timestamp, + "x-ms-blob-type" : "BlockBlob", + "Content-Length": text(len(data)), + "x-ms-version" : self.__class__.__storage_version__ + }) + except restutil.HttpError as e: + raise ProtocolError((u"Failed to upload block blob: {0}" + u"").format(e)) + if resp.status != httpclient.CREATED: raise ProtocolError(("Failed to upload block blob: {0}" "").format(resp.status)) def put_page_blob(self, url, data): logger.verb("Replace old page blob") + + #Convert string into bytes + data=bytearray(data, encoding='utf-8') timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + #Align to 512 bytes - page_blob_size = ((len(data) + 511) / 512) * 512 - resp = restutil.http_put(url, "", { - "x-ms-date" : timestamp, - "x-ms-blob-type" : "PageBlob", - "Content-Length": "0", - "x-ms-blob-content-length" : text(page_blob_size), - "x-ms-version" : self.__class__.__storage_version__ - }) - if resp is None or resp.status != httpclient.CREATED: + page_blob_size = int((len(data) + 511) / 512) * 512 + try: + resp = self.client.call_storage_service(restutil.http_put, url, + "", { + "x-ms-date" : timestamp, + "x-ms-blob-type" : "PageBlob", + "Content-Length": "0", + "x-ms-blob-content-length" : text(page_blob_size), + "x-ms-version" : self.__class__.__storage_version__ + }) + except restutil.HttpError as e: + raise ProtocolError((u"Failed to clean up page blob: {0}" + u"").format(e)) + if resp.status != httpclient.CREATED: raise ProtocolError(("Failed to clean up page blob: {0}" "").format(resp.status)) - if '?' in url < 0: + if url.count("?") < 0: url = "{0}?comp=page".format(url) else: url = "{0}&comp=page".format(url) @@ -359,15 +384,20 @@ class StatusBlob(object): #Align to 512 bytes page_end = int((end + 511) / 512) * 512 buf_size = page_end - start - buf = bytearray(source=data[start:end], encoding="utf-8") - #TODO buffer is not defined in python3, however we need this to make httplib to work on python 2.6 - resp = restutil.http_put(url, buf, { - "x-ms-date" : timestamp, - "x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1), - "x-ms-page-write" : "update", - "x-ms-version" : self.__class__.__storage_version__, - "Content-Length": text(page_end - start) - }) + buf = bytearray(buf_size) + buf[0: content_size] = data[start: end] + try: + resp = self.client.call_storage_service(restutil.http_put, url, + bytebuffer(buf), { + "x-ms-date" : timestamp, + "x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1), + "x-ms-page-write" : "update", + "x-ms-version" : self.__class__.__storage_version__, + "Content-Length": text(page_end - start) + }) + except restutil.HttpError as e: + raise ProtocolError((u"Failed to upload page blob: {0}" + u"").format(e)) if resp is None or resp.status != httpclient.CREATED: raise ProtocolError(("Failed to upload page blob: {0}" "").format(resp.status)) @@ -408,59 +438,176 @@ class WireClient(object): self.shared_conf = None self.certs = None self.ext_conf = None + self.last_request = 0 self.req_count = 0 + self.status_blob = StatusBlob(self) + + def prevent_throttling(self): + """ + Try to avoid throttling of wire server + """ + now = time.time() + if now - self.last_request < 1: + logger.info("Last request issued less than 1 second ago") + logger.info("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.last_request = now + + self.req_count += 1 + if self.req_count % 3 == 0: + logger.info("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.req_count = 0 + + def call_wireserver(self, http_req, *args, **kwargs): + """ + Call wire server. Handle throttling(403) and Resource Gone(410) + """ + self.prevent_throttling() + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.FORBIDDEN: + logger.warn("Sending too much request to wire server") + logger.info("Sleep {0} second to avoid throttling.", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + elif resp.status == httpclient.GONE: + msg = args[0] if len(args) > 0 else "" + raise WireProtocolResourceGone(msg) + else: + return resp + raise ProtocolError(("Calling wire server failed: {0}" + "").format(resp.status)) + + def decode_config(self, data): + if data is None: + return None + data = remove_bom(data) + xml_text = text(data, encoding='utf-8') + return xml_text + + def fetch_config(self, uri, headers): + try: + resp = self.call_wireserver(restutil.http_get, uri, + headers=headers) + except restutil.HttpError as e: + raise ProtocolError(text(e)) + + if(resp.status != httpclient.OK): + raise ProtocolError("{0} - {1}".format(resp.status, uri)) + + return self.decode_config(resp.read()) + + def fetch_cache(self, local_file): + if not os.path.isfile(local_file): + raise ProtocolError("{0} is missing.".format(local_file)) + try: + return fileutil.read_file(local_file) + except IOError as e: + raise ProtocolError("Failed to read cache: {0}".format(e)) + + def save_cache(self, local_file, data): + try: + fileutil.write_file(local_file, data) + except IOError as e: + raise ProtocolError("Failed to write cache: {0}".format(e)) + + def call_storage_service(self, http_req, *args, **kwargs): + """ + Call storage service, handle SERVICE_UNAVAILABLE(503) + """ + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.SERVICE_UNAVAILABLE: + logger.warn("Storage service is not avaible temporaryly") + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + else: + return resp + raise ProtocolError(("Calling storage endpoint failed: {0}" + "").format(resp.status)) + + def fetch_manifest(self, version_uris): + for version_uri in version_uris: + try: + resp = self.call_storage_service(restutil.http_get, + version_uri.uri, None, + chk_proxy=True) + except restutil.HttpError as e: + raise ProtocolError(text(e)) + + if resp.status == httpclient.OK: + return self.decode_config(resp.read()) + logger.warn("Failed to fetch ExtensionManifest: {0}, {1}", + resp.status, version_uri.uri) + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + raise ProtocolError(("Failed to fetch ExtensionManifest from " + "all sources")) + def update_hosting_env(self, goal_state): if goal_state.hosting_env_uri is None: raise ProtocolError("HostingEnvironmentConfig uri is empty") local_file = HOSTING_ENV_FILE_NAME - xml_text = _fetch_uri(goal_state.hosting_env_uri, self.get_header()) - fileutil.write_file(local_file, xml_text) + xml_text = self.fetch_config(goal_state.hosting_env_uri, + self.get_header()) + self.save_cache(local_file, xml_text) self.hosting_env = HostingEnv(xml_text) def update_shared_conf(self, goal_state): if goal_state.shared_conf_uri is None: raise ProtocolError("SharedConfig uri is empty") local_file = SHARED_CONF_FILE_NAME - xml_text = _fetch_uri(goal_state.shared_conf_uri, self.get_header()) - fileutil.write_file(local_file, xml_text) + xml_text = self.fetch_config(goal_state.shared_conf_uri, + self.get_header()) + self.save_cache(local_file, xml_text) self.shared_conf = SharedConfig(xml_text) def update_certs(self, goal_state): if goal_state.certs_uri is None: return local_file = CERTS_FILE_NAME - xml_text = _fetch_uri(goal_state.certs_uri, self.get_header_for_cert()) - fileutil.write_file(local_file, xml_text) - self.certs = Certificates(xml_text) + xml_text = self.fetch_config(goal_state.certs_uri, + self.get_header_for_cert()) + self.save_cache(local_file, xml_text) + self.certs = Certificates(self, xml_text) def update_ext_conf(self, goal_state): if goal_state.ext_uri is None: - raise ProtocolError("ExtensionsConfig uri is empty") + logger.info("ExtensionsConfig.xml uri is empty") + self.ext_conf = ExtensionsConfig(None) + return incarnation = goal_state.incarnation local_file = EXT_CONF_FILE_NAME.format(incarnation) - xml_text = _fetch_uri(goal_state.ext_uri, - self.get_header()) - fileutil.write_file(local_file, xml_text) + xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) + self.save_cache(local_file, xml_text) self.ext_conf = ExtensionsConfig(xml_text) - for extension in self.ext_conf.ext_list.extensions: - self.update_ext_manifest(extension, goal_state) + for ext_handler in self.ext_conf.ext_handlers.extHandlers: + self.update_ext_handler_manifest(ext_handler, goal_state) - def update_ext_manifest(self, extension, goal_state): - local_file = MANIFEST_FILE_NAME.format(extension.name, - goal_state.incarnation) - xml_text = _fetch_manifest(extension.version_uris) - fileutil.write_file(local_file, xml_text) + def update_ext_handler_manifest(self, ext_handler, goal_state): + local_file = MANIFEST_FILE_NAME.format(ext_handler.name, + goal_state.incarnation) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) def update_goal_state(self, forced=False, max_retry=3): uri = GOAL_STATE_URI.format(self.endpoint) - xml_text = _fetch_uri(uri, self.get_header()) + xml_text = self.fetch_config(uri, self.get_header()) goal_state = GoalState(xml_text) + incarnation_file = os.path.join(OSUTIL.get_lib_dir(), + INCARNATION_FILE_NAME) + if not forced: last_incarnation = None - if(os.path.isfile(INCARNATION_FILE_NAME)): - last_incarnation = fileutil.read_file(INCARNATION_FILE_NAME) + if(os.path.isfile(incarnation_file)): + last_incarnation = fileutil.read_file(incarnation_file) new_incarnation = goal_state.incarnation if last_incarnation is not None and \ last_incarnation == new_incarnation: @@ -471,10 +618,10 @@ class WireClient(object): for retry in range(0, max_retry): try: self.goal_state = goal_state - goal_state_file = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) - fileutil.write_file(goal_state_file, xml_text) - fileutil.write_file(INCARNATION_FILE_NAME, - goal_state.incarnation) + file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) + goal_state_file = os.path.join(OSUTIL.get_lib_dir(), file_name) + self.save_cache(goal_state_file, xml_text) + self.save_cache(incarnation_file, goal_state.incarnation) self.update_hosting_env(goal_state) self.update_shared_conf(goal_state) self.update_certs(goal_state) @@ -482,35 +629,35 @@ class WireClient(object): return except WireProtocolResourceGone: logger.info("Incarnation is out of date. Update goalstate.") - xml_text = _fetch_uri(GOAL_STATE_URI, self.get_header()) + xml_text = self.fetch_config(uri, self.get_header()) goal_state = GoalState(xml_text) raise ProtocolError("Exceeded max retry updating goal state") def get_goal_state(self): if(self.goal_state is None): - incarnation = _fetch_cache(INCARNATION_FILE_NAME) + incarnation = self.fetch_cache(INCARNATION_FILE_NAME) goal_state_file = GOAL_STATE_FILE_NAME.format(incarnation) - xml_text = _fetch_cache(goal_state_file) + xml_text = self.fetch_cache(goal_state_file) self.goal_state = GoalState(xml_text) return self.goal_state def get_hosting_env(self): if(self.hosting_env is None): - xml_text = _fetch_cache(HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_cache(HOSTING_ENV_FILE_NAME) self.hosting_env = HostingEnv(xml_text) return self.hosting_env def get_shared_conf(self): if(self.shared_conf is None): - xml_text = _fetch_cache(SHARED_CONF_FILE_NAME) + xml_text = self.fetch_cache(SHARED_CONF_FILE_NAME) self.shared_conf = SharedConfig(xml_text) return self.shared_conf def get_certs(self): if(self.certs is None): - xml_text = _fetch_cache(Certificates) - self.certs = Certificates(xml_text) + xml_text = self.fetch_cache(CERTS_FILE_NAME) + self.certs = Certificates(self, xml_text) if self.certs is None: return None return self.certs @@ -518,20 +665,23 @@ class WireClient(object): def get_ext_conf(self): if(self.ext_conf is None): goal_state = self.get_goal_state() - local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) - xml_text = _fetch_cache(local_file) - self.ext_conf = ExtensionsConfig(xml_text) + if goal_state.ext_uri is None: + self.ext_conf = ExtensionsConfig(None) + else: + local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) + xml_text = self.fetch_cache(local_file) + self.ext_conf = ExtensionsConfig(xml_text) return self.ext_conf def get_ext_manifest(self, extension, goal_state): local_file = MANIFEST_FILE_NAME.format(extension.name, goal_state.incarnation) - xml_text = _fetch_cache(local_file) + xml_text = self.fetch_cache(local_file) return ExtensionManifest(xml_text) def check_wire_protocol_version(self): uri = VERSION_INFO_URI.format(self.endpoint) - version_info_xml = _fetch_uri(uri, None) + version_info_xml = self.fetch_config(uri, None) version_info = VersionInfo(version_info_xml) preferred = version_info.get_preferred() @@ -544,42 +694,50 @@ class WireClient(object): error = ("Agent supported wire protocol version: {0} was not " "advised by Fabric.").format(PROTOCOL_VERSION) raise ProtocolNotFound(error) - - def upload_status_blob(self, vm_status): + + def upload_status_blob(self): ext_conf = self.get_ext_conf() - status_blob = StatusBlob(vm_status) - status_blob.upload(ext_conf.status_upload_blob) + if ext_conf.status_upload_blob is not None: + self.status_blob.upload(ext_conf.status_upload_blob) def report_role_prop(self, thumbprint): goal_state = self.get_goal_state() role_prop = _build_role_properties(goal_state.container_id, goal_state.role_instance_id, thumbprint) + role_prop = role_prop.encode("utf-8") role_prop_uri = ROLE_PROP_URI.format(self.endpoint) - ret = restutil.http_post(role_prop_uri, - role_prop, - headers=self.get_header_for_xml_content()) - + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, role_prop_uri, + role_prop, headers = headers) + except restutil.HttpError as e: + raise ProtocolError((u"Failed to send role properties: {0}" + u"").format(e)) + if resp.status != httpclient.ACCEPTED: + raise ProtocolError((u"Failed to send role properties: {0}" + u", {1}").format(resp.status, resp.read())) def report_health(self, status, substatus, description): goal_state = self.get_goal_state() health_report = _build_health_report(goal_state.incarnation, - goal_state.container_id, - goal_state.role_instance_id, - status, - substatus, - description) + goal_state.container_id, + goal_state.role_instance_id, + status, + substatus, + description) + health_report = health_report.encode("utf-8") health_report_uri = HEALTH_REPORT_URI.format(self.endpoint) headers = self.get_header_for_xml_content() - resp = restutil.http_post(health_report_uri, - health_report, - headers=headers) - def prevent_throttling(self): - self.req_count += 1 - if self.req_count % 3 == 0: - logger.info("Sleep 15 before sending event to avoid throttling.") - self.req_count = 0 - time.sleep(15) + try: + resp = self.call_wireserver(restutil.http_post, health_report_uri, + health_report, headers = headers) + except restutil.HttpError as e: + raise ProtocolError((u"Failed to send provision status: {0}" + u"").format(e)) + if resp.status != httpclient.OK: + raise ProtocolError((u"Failed to send provision status: {0}" + u", {1}").format(resp.status, resp.read())) def send_event(self, provider_id, event_str): uri = TELEMETRY_URI.format(self.endpoint) @@ -590,9 +748,8 @@ class WireClient(object): '</TelemetryData>') data = data_format.format(provider_id, event_str) try: - self.prevent_throttling() header = self.get_header_for_xml_content() - resp = restutil.http_post(uri, data, header) + resp = self.call_wireserver(restutil.http_post, uri, data, header) except restutil.HttpError as e: raise ProtocolError("Failed to send events:{0}".format(e)) @@ -635,7 +792,7 @@ class WireClient(object): def get_header_for_cert(self): cert = "" - content = _fetch_cache(TRANSPORT_CERT_FILE_NAME) + content = self.fetch_cache(TRANSPORT_CERT_FILE_NAME) for line in content.split('\n'): if "CERTIFICATE" not in line: cert += line.rstrip() @@ -762,10 +919,9 @@ class Certificates(object): """ Object containing certificates of host and provisioned user. """ - def __init__(self, xml_text=None): - if xml_text is None: - raise ValueError("Certificates.xml is None") + def __init__(self, client, xml_text): logger.verb("Load Certificates.xml") + self.client = client self.lib_dir = OSUTIL.get_lib_dir() self.openssl_cmd = OSUTIL.get_openssl_cmd() self.cert_list = CertList() @@ -787,7 +943,7 @@ class Certificates(object): "\n" "{2}").format(P7M_FILE_NAME, P7M_FILE_NAME, data) - fileutil.write_file(os.path.join(self.lib_dir, P7M_FILE_NAME), p7m) + self.client.save_cache(os.path.join(self.lib_dir, P7M_FILE_NAME), p7m) #decrypt certificates cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3}" "| {4} pkcs12 -nodes -password pass: -out {5}" @@ -844,13 +1000,12 @@ class Certificates(object): for v1_cert in v1_cert_list: cert = Cert() - set_properties(cert, v1_cert) + set_properties("certs", cert, v1_cert) self.cert_list.certificates.append(cert) def write_to_tmp_file(self, index, suffix, buf): file_name = os.path.join(self.lib_dir, "{0}.{1}".format(index, suffix)) - with open(file_name, 'w') as tmp: - tmp.writelines(buf) + self.client.save_cache(file_name, "".join(buf)) return file_name @@ -861,12 +1016,11 @@ class ExtensionsConfig(object): """ def __init__(self, xml_text): - if xml_text is None: - raise ValueError("ExtensionsConfig is None") logger.verb("Load ExtensionsConfig.xml") - self.ext_list = ExtensionList() + self.ext_handlers = ExtHandlerList() self.status_upload_blob = None - self.parse(xml_text) + if xml_text is not None: + self.parse(xml_text) def parse(self, xml_text): """ @@ -879,38 +1033,38 @@ class ExtensionsConfig(object): plugin_settings = findall(plugin_settings_list, "Plugin") for plugin in plugins: - ext = self.parse_ext(plugin) - self.ext_list.extensions.append(ext) - self.parse_ext_settings(ext, plugin_settings) + ext_handler = self.parse_plugin(plugin) + self.ext_handlers.extHandlers.append(ext_handler) + self.parse_plugin_settings(ext_handler, plugin_settings) self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob") - def parse_ext(self, plugin): - ext = Extension() - ext.name = getattrib(plugin, "name") - ext.properties.version = getattrib(plugin, "version") - ext.properties.state = getattrib(plugin, "state") + def parse_plugin(self, plugin): + ext_handler = ExtHandler() + ext_handler.name = getattrib(plugin, "name") + ext_handler.properties.version = getattrib(plugin, "version") + ext_handler.properties.state = getattrib(plugin, "state") auto_upgrade = getattrib(plugin, "autoUpgrade") if auto_upgrade is not None and auto_upgrade.lower() == "true": - ext.properties.upgradePolicy = "auto" + ext_handler.properties.upgradePolicy = "auto" else: - ext.properties.upgradePolicy = "manual" + ext_handler.properties.upgradePolicy = "manual" location = getattrib(plugin, "location") failover_location = getattrib(plugin, "failoverlocation") for uri in [location, failover_location]: - version_uri = ExtensionVersionUri() + version_uri = ExtHandlerVersionUri() version_uri.uri = uri - ext.version_uris.append(version_uri) - return ext + ext_handler.versionUris.append(version_uri) + return ext_handler - def parse_ext_settings(self, ext, plugin_settings): + def parse_plugin_settings(self, ext_handler, plugin_settings): if plugin_settings is None: return - name = ext.name - version = ext.properties.version + name = ext_handler.name + version = ext_handler.properties.version settings = [x for x in plugin_settings \ if getattrib(x, "name") == name and \ getattrib(x ,"version") == version] @@ -930,20 +1084,23 @@ class ExtensionsConfig(object): for plugin_settings_list in runtime_settings["runtimeSettings"]: handler_settings = plugin_settings_list["handlerSettings"] - ext_settings = ExtensionSettings() - ext_settings.sequenceNumber = seqNo - ext_settings.publicSettings = handler_settings.get("publicSettings", None) - ext_settings.privateSettings = handler_settings.get("protectedSettings", None) - thumbprint = handler_settings.get("protectedSettingsCertThumbprint", None) - ext_settings.certificateThumbprint = thumbprint - ext.properties.extensions.append(ext_settings) + ext = Extension() + #There is no "extension name" in wire protocol. + #Put + ext.name = ext_handler.name + ext.sequenceNumber = seqNo + ext.publicSettings = handler_settings.get("publicSettings") + ext.privateSettings = handler_settings.get("protectedSettings") + thumbprint = handler_settings.get("protectedSettingsCertThumbprint") + ext.certificateThumbprint = thumbprint + ext_handler.properties.extensions.append(ext) class ExtensionManifest(object): def __init__(self, xml_text): if xml_text is None: raise ValueError("ExtensionManifest is None") logger.verb("Load ExtensionManifest.xml") - self.pkg_list = ExtensionPackageList() + self.pkg_list = ExtHandlerPackageList() self.parse(xml_text) def parse(self, xml_text): @@ -954,10 +1111,10 @@ class ExtensionManifest(object): uris = find(package, "Uris") uri_list = findall(uris, "Uri") uri_list = [gettext(x) for x in uri_list] - package = ExtensionPackage() + package = ExtHandlerPackage() package.version = version for uri in uri_list: - pkg_uri = ExtensionPackageUri() + pkg_uri = ExtHandlerVersionUri() pkg_uri.uri = uri package.uris.append(pkg_uri) self.pkg_list.versions.append(package) |