diff options
Diffstat (limited to 'azurelinuxagent/common/protocol')
-rw-r--r-- | azurelinuxagent/common/protocol/__init__.py | 21 | ||||
-rw-r--r-- | azurelinuxagent/common/protocol/hostplugin.py | 124 | ||||
-rw-r--r-- | azurelinuxagent/common/protocol/metadata.py | 223 | ||||
-rw-r--r-- | azurelinuxagent/common/protocol/ovfenv.py | 113 | ||||
-rw-r--r-- | azurelinuxagent/common/protocol/restapi.py | 272 | ||||
-rw-r--r-- | azurelinuxagent/common/protocol/util.py | 285 | ||||
-rw-r--r-- | azurelinuxagent/common/protocol/wire.py | 1218 |
7 files changed, 2256 insertions, 0 deletions
diff --git a/azurelinuxagent/common/protocol/__init__.py b/azurelinuxagent/common/protocol/__init__.py new file mode 100644 index 0000000..fb7c273 --- /dev/null +++ b/azurelinuxagent/common/protocol/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.common.protocol.util import get_protocol_util, \ + OVF_FILE_NAME, \ + TAG_FILE_NAME + diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py new file mode 100644 index 0000000..6569604 --- /dev/null +++ b/azurelinuxagent/common/protocol/hostplugin.py @@ -0,0 +1,124 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.common.protocol.wire import * +from azurelinuxagent.common.utils import textutil + +HOST_PLUGIN_PORT = 32526 +URI_FORMAT_GET_API_VERSIONS = "http://{0}:{1}/versions" +URI_FORMAT_PUT_VM_STATUS = "http://{0}:{1}/status" +URI_FORMAT_PUT_LOG = "http://{0}:{1}/vmAgentLog" +API_VERSION = "2015-09-01" + + +class HostPluginProtocol(object): + def __init__(self, endpoint): + if endpoint is None: + raise ProtocolError("Host plugin endpoint not provided") + self.is_initialized = False + self.is_available = False + self.api_versions = None + self.endpoint = endpoint + + def ensure_initialized(self): + if not self.is_initialized: + self.api_versions = self.get_api_versions() + self.is_available = API_VERSION in self.api_versions + self.is_initialized = True + return self.is_available + + def get_api_versions(self): + url = URI_FORMAT_GET_API_VERSIONS.format(self.endpoint, + HOST_PLUGIN_PORT) + logger.info("getting API versions at [{0}]".format(url)) + try: + response = restutil.http_get(url) + if response.status != httpclient.OK: + logger.error( + "get API versions returned status code [{0}]".format( + response.status)) + return [] + return response.read() + except HttpError as e: + logger.error("get API versions failed with [{0}]".format(e)) + return [] + + def put_vm_status(self, status_blob, sas_url): + """ + Try to upload the VM status via the host plugin /status channel + :param sas_url: the blob SAS url to pass to the host plugin + :type status_blob: StatusBlob + """ + if not self.ensure_initialized(): + logger.error("host plugin channel is not available") + return + if status_blob is None or status_blob.vm_status is None: + logger.error("no status data was provided") + return + url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT) + status = textutil.b64encode(status_blob.vm_status) + headers = {"x-ms-version": API_VERSION} + blob_headers = [{'headerName': 'x-ms-version', + 'headerValue': status_blob.__storage_version__}, + {'headerName': 'x-ms-blob-type', + 'headerValue': status_blob.type}] + data = json.dumps({'requestUri': sas_url, 'headers': blob_headers, + 'content': status}, sort_keys=True) + logger.info("put VM status at [{0}]".format(url)) + try: + response = restutil.http_put(url, data, headers) + if response.status != httpclient.OK: + logger.error("put VM status returned status code [{0}]".format( + response.status)) + except HttpError as e: + logger.error("put VM status failed with [{0}]".format(e)) + + def put_vm_log(self, content, container_id, deployment_id): + """ + Try to upload the given content to the host plugin + :param deployment_id: the deployment id, which is obtained from the + goal state (tenant name) + :param container_id: the container id, which is obtained from the + goal state + :param content: the binary content of the zip file to upload + :return: + """ + if not self.ensure_initialized(): + logger.error("host plugin channel is not available") + return + if content is None or container_id is None or deployment_id is None: + logger.error( + "invalid arguments passed: " + "[{0}], [{1}], [{2}]".format( + content, + container_id, + deployment_id)) + return + url = URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT) + + headers = {"x-ms-vmagentlog-deploymentid": deployment_id, + "x-ms-vmagentlog-containerid": container_id} + logger.info("put VM log at [{0}]".format(url)) + try: + response = restutil.http_put(url, content, headers) + if response.status != httpclient.OK: + logger.error("put log returned status code [{0}]".format( + response.status)) + except HttpError as e: + logger.error("put log failed with [{0}]".format(e)) diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py new file mode 100644 index 0000000..f86f72f --- /dev/null +++ b/azurelinuxagent/common/protocol/metadata.py @@ -0,0 +1,223 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import json +import shutil +import os +import time +from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.future import httpclient, ustr +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.restutil as restutil +import azurelinuxagent.common.utils.textutil as textutil +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.protocol.restapi import * + +METADATA_ENDPOINT='169.254.169.254' +APIVERSION='2015-05-01-preview' +BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" + +TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem" +TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem" + +#TODO remote workarround for azure stack +MAX_PING = 30 +RETRY_PING_INTERVAL = 10 + +def _add_content_type(headers): + if headers is None: + headers = {} + headers["content-type"] = "application/json" + return headers + +class MetadataProtocol(Protocol): + + def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT): + self.apiversion = apiversion + self.endpoint = endpoint + self.identity_uri = BASE_URI.format(self.endpoint, "identity", + self.apiversion, "&$expand=*") + self.cert_uri = BASE_URI.format(self.endpoint, "certificates", + self.apiversion, "&$expand=*") + self.ext_uri = BASE_URI.format(self.endpoint, "extensionHandlers", + self.apiversion, "&$expand=*") + self.vmagent_uri = BASE_URI.format(self.endpoint, "vmAgentVersions", + self.apiversion, "&$expand=*") + self.provision_status_uri = BASE_URI.format(self.endpoint, + "provisioningStatus", + self.apiversion, "") + self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent", + self.apiversion, "") + self.ext_status_uri = BASE_URI.format(self.endpoint, + "status/extensions/{0}", + self.apiversion, "") + self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry", + self.apiversion, "") + + def _get_data(self, url, headers=None): + try: + resp = restutil.http_get(url, headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if resp.status != httpclient.OK: + raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) + + data = resp.read() + etag = resp.getheader('ETag') + if data is None: + return None + data = json.loads(ustr(data, encoding="utf-8")) + return data, etag + + def _put_data(self, url, data, headers=None): + headers = _add_content_type(headers) + try: + resp = restutil.http_put(url, json.dumps(data), headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + if resp.status != httpclient.OK: + raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) + + def _post_data(self, url, data, headers=None): + headers = _add_content_type(headers) + try: + resp = restutil.http_post(url, json.dumps(data), headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + if resp.status != httpclient.CREATED: + raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) + + def _get_trans_cert(self): + trans_crt_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + if not os.path.isfile(trans_crt_file): + raise ProtocolError("{0} is missing.".format(trans_crt_file)) + content = fileutil.read_file(trans_crt_file) + return textutil.get_bytes_from_pem(content) + + def detect(self): + self.get_vminfo() + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + #"Install" the cert and private key to /var/lib/waagent + thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file) + prv_file = os.path.join(conf.get_lib_dir(), + "{0}.prv".format(thumbprint)) + crt_file = os.path.join(conf.get_lib_dir(), + "{0}.crt".format(thumbprint)) + shutil.copyfile(trans_prv_file, prv_file) + shutil.copyfile(trans_cert_file, crt_file) + + + def get_vminfo(self): + vminfo = VMInfo() + data, etag = self._get_data(self.identity_uri) + set_properties("vminfo", vminfo, data) + return vminfo + + def get_certs(self): + #TODO download and save certs + return CertList() + + def get_vmagent_manifests(self, last_etag=None): + manifests = VMAgentManifestList() + data, etag = self._get_data(self.vmagent_uri) + if last_etag == None or last_etag < etag: + set_properties("vmAgentManifests", manifests.vmAgentManifests, data) + return manifests, etag + + def get_vmagent_pkgs(self, vmagent_manifest): + #Agent package is the same with extension handler + vmagent_pkgs = ExtHandlerPackageList() + data = None + for manifest_uri in vmagent_manifest.versionsManifestUris: + try: + data = self._get_data(manifest_uri.uri) + break + except ProtocolError as e: + logger.warn("Failed to get vmagent versions: {0}", e) + logger.info("Retry getting vmagent versions") + if data is None: + raise ProtocolError(("Failed to get versions for vm agent: {0}" + "").format(vmagent_manifest.family)) + set_properties("vmAgentVersions", vmagent_pkgs, data) + # TODO: What etag should this return? + return vmagent_pkgs + + def get_ext_handlers(self, last_etag=None): + headers = { + "x-ms-vmagent-public-x509-cert": self._get_trans_cert() + } + ext_list = ExtHandlerList() + data, etag = self._get_data(self.ext_uri, headers=headers) + if last_etag == None or last_etag < etag: + set_properties("extensionHandlers", ext_list.extHandlers, data) + return ext_list, etag + + def get_ext_handler_pkgs(self, ext_handler): + ext_handler_pkgs = ExtHandlerPackageList() + data = None + for version_uri in ext_handler.versionUris: + try: + data, etag = self._get_data(version_uri.uri) + break + except ProtocolError as e: + logger.warn("Failed to get version uris: {0}", e) + logger.info("Retry getting version uris") + set_properties("extensionPackages", ext_handler_pkgs, data) + return ext_handler_pkgs + + def report_provision_status(self, provision_status): + validate_param('provisionStatus', provision_status, ProvisionStatus) + data = get_properties(provision_status) + self._put_data(self.provision_status_uri, data) + + def report_vm_status(self, vm_status): + validate_param('vmStatus', vm_status, VMStatus) + data = get_properties(vm_status) + #TODO code field is not implemented for metadata protocol yet. Remove it + handler_statuses = data['vmAgent']['extensionHandlers'] + for handler_status in handler_statuses: + try: + handler_status.pop('code', None) + except KeyError: + pass + + self._put_data(self.vm_status_uri, data) + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validate_param('extensionStatus', ext_status, ExtensionStatus) + data = get_properties(ext_status) + uri = self.ext_status_uri.format(ext_name) + self._put_data(uri, data) + + def report_event(self, events): + #TODO disable telemetry for azure stack test + #validate_param('events', events, TelemetryEventList) + #data = get_properties(events) + #self._post_data(self.event_uri, data) + pass + diff --git a/azurelinuxagent/common/protocol/ovfenv.py b/azurelinuxagent/common/protocol/ovfenv.py new file mode 100644 index 0000000..4901871 --- /dev/null +++ b/azurelinuxagent/common/protocol/ovfenv.py @@ -0,0 +1,113 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +""" +Copy and parse ovf-env.xml from provisioning ISO and local cache +""" +import os +import re +import shutil +import xml.dom.minidom as minidom +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext + +OVF_VERSION = "1.0" +OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1" +WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure" + +def _validate_ovf(val, msg): + if val is None: + raise ProtocolError("Failed to parse OVF XML: {0}".format(msg)) + +class OvfEnv(object): + """ + Read, and process provisioning info from provisioning file OvfEnv.xml + """ + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("ovf-env is None") + logger.verbose("Load ovf-env.xml") + self.hostname = None + self.username = None + self.user_password = None + self.customdata = None + self.disable_ssh_password_auth = True + self.ssh_pubkeys = [] + self.ssh_keypairs = [] + self.parse(xml_text) + + def parse(self, xml_text): + """ + Parse xml tree, retreiving user and ssh key information. + Return self. + """ + wans = WA_NAME_SPACE + ovfns = OVF_NAME_SPACE + + xml_doc = parse_doc(xml_text) + + environment = find(xml_doc, "Environment", namespace=ovfns) + _validate_ovf(environment, "Environment not found") + + section = find(environment, "ProvisioningSection", namespace=wans) + _validate_ovf(section, "ProvisioningSection not found") + + version = findtext(environment, "Version", namespace=wans) + _validate_ovf(version, "Version not found") + + if version > OVF_VERSION: + logger.warn("Newer provisioning configuration detected. " + "Please consider updating waagent") + + conf_set = find(section, "LinuxProvisioningConfigurationSet", + namespace=wans) + _validate_ovf(conf_set, "LinuxProvisioningConfigurationSet not found") + + self.hostname = findtext(conf_set, "HostName", namespace=wans) + _validate_ovf(self.hostname, "HostName not found") + + self.username = findtext(conf_set, "UserName", namespace=wans) + _validate_ovf(self.username, "UserName not found") + + self.user_password = findtext(conf_set, "UserPassword", namespace=wans) + + self.customdata = findtext(conf_set, "CustomData", namespace=wans) + + auth_option = findtext(conf_set, "DisableSshPasswordAuthentication", + namespace=wans) + if auth_option is not None and auth_option.lower() == "true": + self.disable_ssh_password_auth = True + else: + self.disable_ssh_password_auth = False + + public_keys = findall(conf_set, "PublicKey", namespace=wans) + for public_key in public_keys: + path = findtext(public_key, "Path", namespace=wans) + fingerprint = findtext(public_key, "Fingerprint", namespace=wans) + value = findtext(public_key, "Value", namespace=wans) + self.ssh_pubkeys.append((path, fingerprint, value)) + + keypairs = findall(conf_set, "KeyPair", namespace=wans) + for keypair in keypairs: + path = findtext(keypair, "Path", namespace=wans) + fingerprint = findtext(keypair, "Fingerprint", namespace=wans) + self.ssh_keypairs.append((path, fingerprint)) + diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py new file mode 100644 index 0000000..7f00488 --- /dev/null +++ b/azurelinuxagent/common/protocol/restapi.py @@ -0,0 +1,272 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import copy +import re +import json +import xml.dom.minidom +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.restutil as restutil + +def validate_param(name, val, expected_type): + if val is None: + raise ProtocolError("{0} is None".format(name)) + if not isinstance(val, expected_type): + raise ProtocolError(("{0} type should be {1} not {2}" + "").format(name, expected_type, type(val))) + +def set_properties(name, obj, data): + if isinstance(obj, DataContract): + validate_param("Property '{0}'".format(name), data, dict) + for prob_name, prob_val in data.items(): + prob_full_name = "{0}.{1}".format(name, prob_name) + try: + prob = getattr(obj, prob_name) + except AttributeError: + logger.warn("Unknown property: {0}", prob_full_name) + continue + prob = set_properties(prob_full_name, prob, prob_val) + setattr(obj, prob_name, prob) + return obj + elif isinstance(obj, DataContractList): + validate_param("List '{0}'".format(name), data, list) + for item_data in data: + item = obj.item_cls() + item = set_properties(name, item, item_data) + obj.append(item) + return obj + else: + return data + +def get_properties(obj): + if isinstance(obj, DataContract): + data = {} + props = vars(obj) + for prob_name, prob in list(props.items()): + data[prob_name] = get_properties(prob) + return data + elif isinstance(obj, DataContractList): + data = [] + for item in obj: + item_data = get_properties(item) + data.append(item_data) + return data + else: + return obj + +class DataContract(object): + pass + +class DataContractList(list): + def __init__(self, item_cls): + self.item_cls = item_cls + +""" +Data contract between guest and host +""" +class VMInfo(DataContract): + def __init__(self, subscriptionId=None, vmName=None, containerId=None, + roleName=None, roleInstanceName=None, tenantName=None): + self.subscriptionId = subscriptionId + self.vmName = vmName + self.containerId = containerId + self.roleName = roleName + self.roleInstanceName = roleInstanceName + self.tenantName = tenantName + +class Cert(DataContract): + def __init__(self, name=None, thumbprint=None, certificateDataUri=None): + self.name = name + self.thumbprint = thumbprint + self.certificateDataUri = certificateDataUri + +class CertList(DataContract): + def __init__(self): + self.certificates = DataContractList(Cert) + +#TODO: confirm vmagent manifest schema +class VMAgentManifestUri(DataContract): + def __init__(self, uri=None): + self.uri = uri + +class VMAgentManifest(DataContract): + def __init__(self, family=None): + self.family = family + self.versionsManifestUris = DataContractList(VMAgentManifestUri) + +class VMAgentManifestList(DataContract): + def __init__(self): + self.vmAgentManifests = DataContractList(VMAgentManifest) + +class Extension(DataContract): + def __init__(self, name=None, sequenceNumber=None, publicSettings=None, + protectedSettings=None, certificateThumbprint=None): + self.name = name + self.sequenceNumber = sequenceNumber + self.publicSettings = publicSettings + self.protectedSettings = protectedSettings + self.certificateThumbprint = certificateThumbprint + +class ExtHandlerProperties(DataContract): + def __init__(self): + self.version = None + self.upgradePolicy = None + self.state = None + self.extensions = DataContractList(Extension) + +class ExtHandlerVersionUri(DataContract): + def __init__(self): + self.uri = None + +class ExtHandler(DataContract): + def __init__(self, name=None): + self.name = name + self.properties = ExtHandlerProperties() + self.versionUris = DataContractList(ExtHandlerVersionUri) + +class ExtHandlerList(DataContract): + def __init__(self): + self.extHandlers = DataContractList(ExtHandler) + +class ExtHandlerPackageUri(DataContract): + def __init__(self, uri=None): + self.uri = uri + +class ExtHandlerPackage(DataContract): + def __init__(self, version = None): + self.version = version + self.uris = DataContractList(ExtHandlerPackageUri) + # TODO update the naming to align with metadata protocol + self.isinternal = False + +class ExtHandlerPackageList(DataContract): + def __init__(self): + self.versions = DataContractList(ExtHandlerPackage) + +class VMProperties(DataContract): + def __init__(self, certificateThumbprint=None): + #TODO need to confirm the property name + self.certificateThumbprint = certificateThumbprint + +class ProvisionStatus(DataContract): + def __init__(self, status=None, subStatus=None, description=None): + self.status = status + self.subStatus = subStatus + self.description = description + self.properties = VMProperties() + +class ExtensionSubStatus(DataContract): + def __init__(self, name=None, status=None, code=None, message=None): + self.name = name + self.status = status + self.code = code + self.message = message + +class ExtensionStatus(DataContract): + def __init__(self, configurationAppliedTime=None, operation=None, + status=None, seq_no=None, code=None, message=None): + self.configurationAppliedTime = configurationAppliedTime + self.operation = operation + self.status = status + self.sequenceNumber = seq_no + self.code = code + self.message = message + self.substatusList = DataContractList(ExtensionSubStatus) + +class ExtHandlerStatus(DataContract): + def __init__(self, name=None, version=None, status=None, code=0, + message=None): + self.name = name + self.version = version + self.status = status + self.code = code + self.message = message + self.extensions = DataContractList(ustr) + +class VMAgentStatus(DataContract): + def __init__(self, version=None, status=None, message=None): + self.version = version + self.status = status + self.message = message + self.extensionHandlers = DataContractList(ExtHandlerStatus) + +class VMStatus(DataContract): + def __init__(self): + self.vmAgent = VMAgentStatus() + +class TelemetryEventParam(DataContract): + def __init__(self, name=None, value=None): + self.name = name + self.value = value + +class TelemetryEvent(DataContract): + def __init__(self, eventId=None, providerId=None): + self.eventId = eventId + self.providerId = providerId + self.parameters = DataContractList(TelemetryEventParam) + +class TelemetryEventList(DataContract): + def __init__(self): + self.events = DataContractList(TelemetryEvent) + +class Protocol(DataContract): + + def detect(self): + raise NotImplementedError() + + def get_vminfo(self): + raise NotImplementedError() + + def get_certs(self): + raise NotImplementedError() + + def get_vmagent_manifests(self): + raise NotImplementedError() + + def get_vmagent_pkgs(self): + raise NotImplementedError() + + def get_ext_handlers(self): + raise NotImplementedError() + + def get_ext_handler_pkgs(self, extension): + raise NotImplementedError() + + def download_ext_handler_pkg(self, uri): + try: + resp = restutil.http_get(uri, chk_proxy=True) + if resp.status == restutil.httpclient.OK: + return resp.read() + except HttpError as e: + raise ProtocolError("Failed to download from: {0}".format(uri), e) + + def report_provision_status(self, provision_status): + raise NotImplementedError() + + def report_vm_status(self, vm_status): + raise NotImplementedError() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + raise NotImplementedError() + + def report_event(self, event): + raise NotImplementedError() + diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py new file mode 100644 index 0000000..7e7a74f --- /dev/null +++ b/azurelinuxagent/common/protocol/util.py @@ -0,0 +1,285 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import re +import shutil +import time +import threading +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \ + ProtocolNotFoundError, DhcpError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.osutil import get_osutil +from azurelinuxagent.common.dhcp import get_dhcp_handler +from azurelinuxagent.common.protocol.ovfenv import OvfEnv +from azurelinuxagent.common.protocol.wire import WireProtocol +from azurelinuxagent.common.protocol.metadata import MetadataProtocol, \ + METADATA_ENDPOINT +import azurelinuxagent.common.utils.shellutil as shellutil + +OVF_FILE_NAME = "ovf-env.xml" + +#Tag file to indicate usage of metadata protocol +TAG_FILE_NAME = "useMetadataEndpoint.tag" + +PROTOCOL_FILE_NAME = "Protocol" + +#MAX retry times for protocol probing +MAX_RETRY = 360 + +PROBE_INTERVAL = 10 + +ENDPOINT_FILE_NAME = "WireServerEndpoint" + +def get_protocol_util(): + return ProtocolUtil() + +class ProtocolUtil(object): + """ + ProtocolUtil handles initialization for protocol instance. 2 protocol types + are invoked, wire protocol and metadata protocols. + """ + def __init__(self): + self.lock = threading.Lock() + self.protocol = None + self.osutil = get_osutil() + self.dhcp_handler = get_dhcp_handler() + + def copy_ovf_env(self): + """ + Copy ovf env file from dvd to hard disk. + Remove password before save it to the disk + """ + dvd_mount_point = conf.get_dvd_mount_point() + ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME) + tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME) + try: + self.osutil.mount_dvd() + ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) + ovfenv = OvfEnv(ovfxml) + ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml) + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + fileutil.write_file(ovf_file_path, ovfxml) + + if os.path.isfile(tag_file_path_on_dvd): + logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME) + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + shutil.copyfile(tag_file_path_on_dvd, tag_file_path) + + except (OSUtilError, IOError) as e: + raise ProtocolError(ustr(e)) + + try: + self.osutil.umount_dvd() + self.osutil.eject_dvd() + except OSUtilError as e: + logger.warn(ustr(e)) + + return ovfenv + + def get_ovf_env(self): + """ + Load saved ovf-env.xml + """ + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + if os.path.isfile(ovf_file_path): + xml_text = fileutil.read_file(ovf_file_path) + return OvfEnv(xml_text) + else: + raise ProtocolError("ovf-env.xml is missing.") + + def _get_wireserver_endpoint(self): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + return fileutil.read_file(file_path) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _set_wireserver_endpoint(self, endpoint): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + fileutil.write_file(file_path, endpoint) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _detect_wire_protocol(self): + endpoint = self.dhcp_handler.endpoint + if endpoint is None: + logger.info("WireServer endpoint is not found. Rerun dhcp handler") + try: + self.dhcp_handler.run() + except DhcpError as e: + raise ProtocolError(ustr(e)) + endpoint = self.dhcp_handler.endpoint + + try: + protocol = WireProtocol(endpoint) + protocol.detect() + self._set_wireserver_endpoint(endpoint) + self.save_protocol("WireProtocol") + return protocol + except ProtocolError as e: + logger.info("WireServer is not responding. Reset endpoint") + self.dhcp_handler.endpoint = None + self.dhcp_handler.skip_cache = True + raise e + + def _detect_metadata_protocol(self): + protocol = MetadataProtocol() + protocol.detect() + + #Only allow root access METADATA_ENDPOINT + self.osutil.set_admin_access_to_ip(METADATA_ENDPOINT) + + self.save_protocol("MetadataProtocol") + + return protocol + + def _detect_protocol(self, protocols): + """ + Probe protocol endpoints in turn. + """ + self.clear_protocol() + + for retry in range(0, MAX_RETRY): + for protocol in protocols: + try: + if protocol == "WireProtocol": + return self._detect_wire_protocol() + + if protocol == "MetadataProtocol": + return self._detect_metadata_protocol() + + except ProtocolError as e: + logger.info("Protocol endpoint not found: {0}, {1}", + protocol, e) + + if retry < MAX_RETRY -1: + logger.info("Retry detect protocols: retry={0}", retry) + time.sleep(PROBE_INTERVAL) + raise ProtocolNotFoundError("No protocol found.") + + def _get_protocol(self): + """ + Get protocol instance based on previous detecting result. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), + PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + raise ProtocolNotFoundError("No protocol found") + + protocol_name = fileutil.read_file(protocol_file_path) + if protocol_name == "WireProtocol": + endpoint = self._get_wireserver_endpoint() + return WireProtocol(endpoint) + elif protocol_name == "MetadataProtocol": + return MetadataProtocol() + else: + raise ProtocolNotFoundError(("Unknown protocol: {0}" + "").format(protocol_name)) + + def save_protocol(self, protocol_name): + """ + Save protocol endpoint + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + try: + fileutil.write_file(protocol_file_path, protocol_name) + except IOError as e: + logger.error("Failed to save protocol endpoint: {0}", e) + + + def clear_protocol(self): + """ + Cleanup previous saved endpoint. + """ + logger.info("Clean protocol") + self.protocol = None + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + return + + try: + os.remove(protocol_file_path) + except IOError as e: + logger.error("Failed to clear protocol endpoint: {0}", e) + + def get_protocol(self): + """ + Detect protocol by endpoints + + :returns: protocol instance + """ + self.lock.acquire() + + try: + if self.protocol is not None: + return self.protocol + + try: + self.protocol = self._get_protocol() + return self.protocol + except ProtocolNotFoundError: + pass + + logger.info("Detect protocol endpoints") + protocols = ["WireProtocol", "MetadataProtocol"] + self.protocol = self._detect_protocol(protocols) + + return self.protocol + + finally: + self.lock.release() + + + def get_protocol_by_file(self): + """ + Detect protocol by tag file. + + If a file "useMetadataEndpoint.tag" is found on provision iso, + metedata protocol will be used. No need to probe for wire protocol + + :returns: protocol instance + """ + self.lock.acquire() + + try: + if self.protocol is not None: + return self.protocol + + try: + self.protocol = self._get_protocol() + return self.protocol + except ProtocolNotFoundError: + pass + + logger.info("Detect protocol by file") + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + protocols = [] + if os.path.isfile(tag_file_path): + protocols.append("MetadataProtocol") + else: + protocols.append("WireProtocol") + self.protocol = self._detect_protocol(protocols) + return self.protocol + + finally: + self.lock.release() diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py new file mode 100644 index 0000000..29a1663 --- /dev/null +++ b/azurelinuxagent/common/protocol/wire.py @@ -0,0 +1,1218 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import time +import xml.sax.saxutils as saxutils +import azurelinuxagent.common.conf as conf +from azurelinuxagent.common.exception import ProtocolNotFoundError +from azurelinuxagent.common.future import httpclient, bytebuffer +from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext, \ + getattrib, gettext, remove_bom, get_bytes_from_pem +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.protocol.restapi import * +from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol + +VERSION_INFO_URI = "http://{0}/?comp=versions" +GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate" +HEALTH_REPORT_URI = "http://{0}/machine?comp=health" +ROLE_PROP_URI = "http://{0}/machine?comp=roleProperties" +TELEMETRY_URI = "http://{0}/machine?comp=telemetrydata" + +WIRE_SERVER_ADDR_FILE_NAME = "WireServer" +INCARNATION_FILE_NAME = "Incarnation" +GOAL_STATE_FILE_NAME = "GoalState.{0}.xml" +HOSTING_ENV_FILE_NAME = "HostingEnvironmentConfig.xml" +SHARED_CONF_FILE_NAME = "SharedConfig.xml" +CERTS_FILE_NAME = "Certificates.xml" +P7M_FILE_NAME = "Certificates.p7m" +PEM_FILE_NAME = "Certificates.pem" +EXT_CONF_FILE_NAME = "ExtensionsConfig.{0}.xml" +MANIFEST_FILE_NAME = "{0}.{1}.manifest.xml" +TRANSPORT_CERT_FILE_NAME = "TransportCert.pem" +TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" + +PROTOCOL_VERSION = "2012-11-30" +ENDPOINT_FINE_NAME = "WireServer" + +SHORT_WAITING_INTERVAL = 1 # 1 second +LONG_WAITING_INTERVAL = 15 # 15 seconds + + +class UploadError(HttpError): + pass + + +class WireProtocolResourceGone(ProtocolError): + pass + + +class WireProtocol(Protocol): + """Slim layer to adapt wire protocol data to metadata protocol interface""" + + # TODO: Clean-up goal state processing + # At present, some methods magically update GoalState (e.g., get_vmagent_manifests), others (e.g., get_vmagent_pkgs) + # assume its presence. A better approach would make an explicit update call that returns the incarnation number and + # establishes that number the "context" for all other calls (either by updating the internal state of the protocol or + # by having callers pass the incarnation number to the method). + + def __init__(self, endpoint): + if endpoint is None: + raise ProtocolError("WireProtocol endpoint is None") + self.endpoint = endpoint + self.client = WireClient(self.endpoint) + + def detect(self): + self.client.check_wire_protocol_version() + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + self.client.update_goal_state(forced=True) + + def get_vminfo(self): + goal_state = self.client.get_goal_state() + hosting_env = self.client.get_hosting_env() + + vminfo = VMInfo() + vminfo.subscriptionId = None + vminfo.vmName = hosting_env.vm_name + vminfo.tenantName = hosting_env.deployment_name + vminfo.roleName = hosting_env.role_name + vminfo.roleInstanceName = goal_state.role_instance_id + vminfo.containerId = goal_state.container_id + return vminfo + + def get_certs(self): + certificates = self.client.get_certs() + return certificates.cert_list + + def get_vmagent_manifests(self): + # Update goal state to get latest extensions config + self.client.update_goal_state() + goal_state = self.client.get_goal_state() + ext_conf = self.client.get_ext_conf() + return ext_conf.vmagent_manifests, goal_state.incarnation + + def get_vmagent_pkgs(self, vmagent_manifest): + goal_state = self.client.get_goal_state() + man = self.client.get_gafamily_manifest(vmagent_manifest, goal_state) + return man.pkg_list + + def get_ext_handlers(self): + logger.verbose("Get extension handler config") + # Update goal state to get latest extensions config + self.client.update_goal_state() + goal_state = self.client.get_goal_state() + ext_conf = self.client.get_ext_conf() + # In wire protocol, incarnation is equivalent to ETag + return ext_conf.ext_handlers, goal_state.incarnation + + def get_ext_handler_pkgs(self, ext_handler): + logger.verbose("Get extension handler package") + goal_state = self.client.get_goal_state() + man = self.client.get_ext_manifest(ext_handler, goal_state) + return man.pkg_list + + def report_provision_status(self, provision_status): + validate_param("provision_status", provision_status, ProvisionStatus) + + if provision_status.status is not None: + self.client.report_health(provision_status.status, + provision_status.subStatus, + provision_status.description) + if provision_status.properties.certificateThumbprint is not None: + thumbprint = provision_status.properties.certificateThumbprint + self.client.report_role_prop(thumbprint) + + def report_vm_status(self, vm_status): + validate_param("vm_status", vm_status, VMStatus) + self.client.status_blob.set_vm_status(vm_status) + self.client.upload_status_blob() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validate_param("ext_status", ext_status, ExtensionStatus) + self.client.status_blob.set_ext_status(ext_handler_name, ext_status) + + def report_event(self, events): + validate_param("events", events, TelemetryEventList) + self.client.report_event(events) + + +def _build_role_properties(container_id, role_instance_id, thumbprint): + xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>" + u"<RoleProperties>" + u"<Container>" + u"<ContainerId>{0}</ContainerId>" + u"<RoleInstances>" + u"<RoleInstance>" + u"<Id>{1}</Id>" + u"<Properties>" + u"<Property name=\"CertificateThumbprint\" value=\"{2}\" />" + u"</Properties>" + u"</RoleInstance>" + u"</RoleInstances>" + u"</Container>" + u"</RoleProperties>" + u"").format(container_id, role_instance_id, thumbprint) + return xml + + +def _build_health_report(incarnation, container_id, role_instance_id, + status, substatus, description): + # Escape '&', '<' and '>' + description = saxutils.escape(ustr(description)) + detail = u'' + if substatus is not None: + substatus = saxutils.escape(ustr(substatus)) + detail = (u"<Details>" + u"<SubStatus>{0}</SubStatus>" + u"<Description>{1}</Description>" + u"</Details>").format(substatus, description) + xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>" + u"<Health " + u"xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"" + u" xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\">" + u"<GoalStateIncarnation>{0}</GoalStateIncarnation>" + u"<Container>" + u"<ContainerId>{1}</ContainerId>" + u"<RoleInstanceList>" + u"<Role>" + u"<InstanceId>{2}</InstanceId>" + u"<Health>" + u"<State>{3}</State>" + u"{4}" + u"</Health>" + u"</Role>" + u"</RoleInstanceList>" + u"</Container>" + u"</Health>" + u"").format(incarnation, + container_id, + role_instance_id, + status, + detail) + return xml + + +""" +Convert VMStatus object to status blob format +""" + + +def ga_status_to_v1(ga_status): + formatted_msg = { + 'lang': 'en-US', + 'message': ga_status.message + } + v1_ga_status = { + 'version': ga_status.version, + 'status': ga_status.status, + 'formattedMessage': formatted_msg + } + return v1_ga_status + + +def ext_substatus_to_v1(sub_status_list): + status_list = [] + for substatus in sub_status_list: + status = { + "name": substatus.name, + "status": substatus.status, + "code": substatus.code, + "formattedMessage": { + "lang": "en-US", + "message": substatus.message + } + } + status_list.append(status) + return status_list + + +def ext_status_to_v1(ext_name, ext_status): + if ext_status is None: + return None + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + v1_sub_status = ext_substatus_to_v1(ext_status.substatusList) + v1_ext_status = { + "status": { + "name": ext_name, + "configurationAppliedTime": ext_status.configurationAppliedTime, + "operation": ext_status.operation, + "status": ext_status.status, + "code": ext_status.code, + "formattedMessage": { + "lang": "en-US", + "message": ext_status.message + } + }, + "version": 1.0, + "timestampUTC": timestamp + } + if len(v1_sub_status) != 0: + v1_ext_status['substatus'] = v1_sub_status + return v1_ext_status + + +def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): + v1_handler_status = { + 'handlerVersion': handler_status.version, + 'handlerName': handler_status.name, + 'status': handler_status.status, + 'code': handler_status.code + } + if handler_status.message is not None: + v1_handler_status["formattedMessage"] = { + "lang": "en-US", + "message": handler_status.message + } + + if len(handler_status.extensions) > 0: + # Currently, no more than one extension per handler + ext_name = handler_status.extensions[0] + ext_status = ext_statuses.get(ext_name) + v1_ext_status = ext_status_to_v1(ext_name, ext_status) + if ext_status is not None and v1_ext_status is not None: + v1_handler_status["runtimeSettingsStatus"] = { + 'settingsStatus': v1_ext_status, + 'sequenceNumber': ext_status.sequenceNumber + } + return v1_handler_status + + +def vm_status_to_v1(vm_status, ext_statuses): + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + v1_ga_status = ga_status_to_v1(vm_status.vmAgent) + v1_handler_status_list = [] + for handler_status in vm_status.vmAgent.extensionHandlers: + v1_handler_status = ext_handler_status_to_v1(handler_status, + ext_statuses, timestamp) + if v1_handler_status is not None: + v1_handler_status_list.append(v1_handler_status) + + v1_agg_status = { + 'guestAgentStatus': v1_ga_status, + 'handlerAggregateStatus': v1_handler_status_list + } + v1_vm_status = { + 'version': '1.0', + 'timestampUTC': timestamp, + 'aggregateStatus': v1_agg_status + } + return v1_vm_status + + +class StatusBlob(object): + def __init__(self, client): + self.vm_status = None + self.ext_statuses = {} + self.client = client + self.type = None + self.data = None + + def set_vm_status(self, vm_status): + validate_param("vmAgent", vm_status, VMStatus) + self.vm_status = vm_status + + def set_ext_status(self, ext_handler_name, ext_status): + validate_param("extensionStatus", ext_status, ExtensionStatus) + self.ext_statuses[ext_handler_name] = ext_status + + def to_json(self): + report = vm_status_to_v1(self.vm_status, self.ext_statuses) + return json.dumps(report) + + __storage_version__ = "2014-02-14" + + def upload(self, url): + # TODO upload extension only if content has changed + logger.verbose("Upload status blob") + upload_successful = False + self.type = self.get_blob_type(url) + self.data = self.to_json() + try: + if self.type == "BlockBlob": + self.put_block_blob(url, self.data) + elif self.type == "PageBlob": + self.put_page_blob(url, self.data) + else: + raise ProtocolError("Unknown blob type: {0}".format(self.type)) + except HttpError as e: + logger.warn("Initial upload failed [{0}]".format(e)) + else: + logger.verbose("Uploading status blob succeeded") + upload_successful = True + return upload_successful + + def get_blob_type(self, url): + # Check blob type + logger.verbose("Check blob type.") + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + try: + resp = self.client.call_storage_service(restutil.http_head, url, { + "x-ms-date": timestamp, + 'x-ms-version': self.__class__.__storage_version__ + }) + except HttpError as e: + raise ProtocolError((u"Failed to get status blob type: {0}" + u"").format(e)) + if resp is None or resp.status != httpclient.OK: + raise ProtocolError(("Failed to get status blob type: {0}" + "").format(resp.status)) + + blob_type = resp.getheader("x-ms-blob-type") + logger.verbose("Blob type={0}".format(blob_type)) + return blob_type + + def put_block_blob(self, url, data): + logger.verbose("Upload block blob") + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + resp = self.client.call_storage_service(restutil.http_put, url, data, + { + "x-ms-date": timestamp, + "x-ms-blob-type": "BlockBlob", + "Content-Length": ustr(len(data)), + "x-ms-version": self.__class__.__storage_version__ + }) + if resp.status != httpclient.CREATED: + raise UploadError( + "Failed to upload block blob: {0}".format(resp.status)) + + def put_page_blob(self, url, data): + logger.verbose("Replace old page blob") + + # Convert string into bytes + data = bytearray(data, encoding='utf-8') + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + # Align to 512 bytes + page_blob_size = int((len(data) + 511) / 512) * 512 + resp = self.client.call_storage_service(restutil.http_put, url, "", + { + "x-ms-date": timestamp, + "x-ms-blob-type": "PageBlob", + "Content-Length": "0", + "x-ms-blob-content-length": ustr(page_blob_size), + "x-ms-version": self.__class__.__storage_version__ + }) + if resp.status != httpclient.CREATED: + raise UploadError( + "Failed to clean up page blob: {0}".format(resp.status)) + + if url.count("?") < 0: + url = "{0}?comp=page".format(url) + else: + url = "{0}&comp=page".format(url) + + logger.verbose("Upload page blob") + page_max = 4 * 1024 * 1024 # Max page size: 4MB + start = 0 + end = 0 + while end < len(data): + end = min(len(data), start + page_max) + content_size = end - start + # Align to 512 bytes + page_end = int((end + 511) / 512) * 512 + buf_size = page_end - start + buf = bytearray(buf_size) + buf[0: content_size] = data[start: end] + resp = self.client.call_storage_service( + restutil.http_put, url, bytebuffer(buf), + { + "x-ms-date": timestamp, + "x-ms-range": "bytes={0}-{1}".format(start, page_end - 1), + "x-ms-page-write": "update", + "x-ms-version": self.__class__.__storage_version__, + "Content-Length": ustr(page_end - start) + }) + if resp is None or resp.status != httpclient.CREATED: + raise UploadError( + "Failed to upload page blob: {0}".format(resp.status)) + start = end + + +def event_param_to_v1(param): + param_format = '<Param Name="{0}" Value={1} T="{2}" />' + param_type = type(param.value) + attr_type = "" + if param_type is int: + attr_type = 'mt:uint64' + elif param_type is str: + attr_type = 'mt:wstr' + elif ustr(param_type).count("'unicode'") > 0: + attr_type = 'mt:wstr' + elif param_type is bool: + attr_type = 'mt:bool' + elif param_type is float: + attr_type = 'mt:float64' + return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)), + attr_type) + + +def event_to_v1(event): + params = "" + for param in event.parameters: + params += event_param_to_v1(param) + event_str = ('<Event id="{0}">' + '<![CDATA[{1}]]>' + '</Event>').format(event.eventId, params) + return event_str + + +class WireClient(object): + def __init__(self, endpoint): + logger.info("Wire server endpoint:{0}", endpoint) + self.endpoint = endpoint + self.goal_state = None + self.updated = None + self.hosting_env = None + self.shared_conf = None + self.certs = None + self.ext_conf = None + self.last_request = 0 + self.req_count = 0 + self.status_blob = StatusBlob(self) + self.host_plugin = HostPluginProtocol(self.endpoint) + + def prevent_throttling(self): + """ + Try to avoid throttling of wire server + """ + now = time.time() + if now - self.last_request < 1: + logger.verbose("Last request issued less than 1 second ago") + logger.verbose("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.last_request = now + + self.req_count += 1 + if self.req_count % 3 == 0: + logger.verbose("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.req_count = 0 + + def call_wireserver(self, http_req, *args, **kwargs): + """ + Call wire server. Handle throttling(403) and Resource Gone(410) + """ + self.prevent_throttling() + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.FORBIDDEN: + logger.warn("Sending too much request to wire server") + logger.info("Sleep {0} second to avoid throttling.", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + elif resp.status == httpclient.GONE: + msg = args[0] if len(args) > 0 else "" + raise WireProtocolResourceGone(msg) + else: + return resp + raise ProtocolError(("Calling wire server failed: {0}" + "").format(resp.status)) + + def decode_config(self, data): + if data is None: + return None + data = remove_bom(data) + xml_text = ustr(data, encoding='utf-8') + return xml_text + + def fetch_config(self, uri, headers): + try: + resp = self.call_wireserver(restutil.http_get, uri, + headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if (resp.status != httpclient.OK): + raise ProtocolError("{0} - {1}".format(resp.status, uri)) + + return self.decode_config(resp.read()) + + def fetch_cache(self, local_file): + if not os.path.isfile(local_file): + raise ProtocolError("{0} is missing.".format(local_file)) + try: + return fileutil.read_file(local_file) + except IOError as e: + raise ProtocolError("Failed to read cache: {0}".format(e)) + + def save_cache(self, local_file, data): + try: + fileutil.write_file(local_file, data) + except IOError as e: + raise ProtocolError("Failed to write cache: {0}".format(e)) + + def call_storage_service(self, http_req, *args, **kwargs): + """ + Call storage service, handle SERVICE_UNAVAILABLE(503) + """ + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.SERVICE_UNAVAILABLE: + logger.warn("Storage service is not avaible temporaryly") + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + else: + return resp + raise ProtocolError(("Calling storage endpoint failed: {0}" + "").format(resp.status)) + + def fetch_manifest(self, version_uris): + for version_uri in version_uris: + logger.verbose("Fetch ext handler manifest: {0}", version_uri.uri) + try: + resp = self.call_storage_service(restutil.http_get, + version_uri.uri, None, + chk_proxy=True) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if resp.status == httpclient.OK: + return self.decode_config(resp.read()) + logger.warn("Failed to fetch ExtensionManifest: {0}, {1}", + resp.status, version_uri.uri) + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + raise ProtocolError(("Failed to fetch ExtensionManifest from " + "all sources")) + + def update_hosting_env(self, goal_state): + if goal_state.hosting_env_uri is None: + raise ProtocolError("HostingEnvironmentConfig uri is empty") + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_config(goal_state.hosting_env_uri, + self.get_header()) + self.save_cache(local_file, xml_text) + self.hosting_env = HostingEnv(xml_text) + + def update_shared_conf(self, goal_state): + if goal_state.shared_conf_uri is None: + raise ProtocolError("SharedConfig uri is empty") + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_config(goal_state.shared_conf_uri, + self.get_header()) + self.save_cache(local_file, xml_text) + self.shared_conf = SharedConfig(xml_text) + + def update_certs(self, goal_state): + if goal_state.certs_uri is None: + return + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_config(goal_state.certs_uri, + self.get_header_for_cert()) + self.save_cache(local_file, xml_text) + self.certs = Certificates(self, xml_text) + + def update_ext_conf(self, goal_state): + if goal_state.ext_uri is None: + logger.info("ExtensionsConfig.xml uri is empty") + self.ext_conf = ExtensionsConfig(None) + return + incarnation = goal_state.incarnation + local_file = os.path.join(conf.get_lib_dir(), + EXT_CONF_FILE_NAME.format(incarnation)) + xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) + self.save_cache(local_file, xml_text) + self.ext_conf = ExtensionsConfig(xml_text) + + def update_goal_state(self, forced=False, max_retry=3): + uri = GOAL_STATE_URI.format(self.endpoint) + xml_text = self.fetch_config(uri, self.get_header()) + goal_state = GoalState(xml_text) + + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + + if not forced: + last_incarnation = None + if (os.path.isfile(incarnation_file)): + last_incarnation = fileutil.read_file(incarnation_file) + new_incarnation = goal_state.incarnation + if last_incarnation is not None and \ + last_incarnation == new_incarnation: + # Goalstate is not updated. + return + + # Start updating goalstate, retry on 410 + for retry in range(0, max_retry): + try: + self.goal_state = goal_state + file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) + self.save_cache(goal_state_file, xml_text) + self.save_cache(incarnation_file, goal_state.incarnation) + self.update_hosting_env(goal_state) + self.update_shared_conf(goal_state) + self.update_certs(goal_state) + self.update_ext_conf(goal_state) + return + except WireProtocolResourceGone: + logger.info("Incarnation is out of date. Update goalstate.") + xml_text = self.fetch_config(uri, self.get_header()) + goal_state = GoalState(xml_text) + + raise ProtocolError("Exceeded max retry updating goal state") + + def get_goal_state(self): + if (self.goal_state is None): + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + incarnation = self.fetch_cache(incarnation_file) + + file_name = GOAL_STATE_FILE_NAME.format(incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) + xml_text = self.fetch_cache(goal_state_file) + self.goal_state = GoalState(xml_text) + return self.goal_state + + def get_hosting_env(self): + if (self.hosting_env is None): + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.hosting_env = HostingEnv(xml_text) + return self.hosting_env + + def get_shared_conf(self): + if (self.shared_conf is None): + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.shared_conf = SharedConfig(xml_text) + return self.shared_conf + + def get_certs(self): + if (self.certs is None): + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.certs = Certificates(self, xml_text) + if self.certs is None: + return None + return self.certs + + def get_ext_conf(self): + if (self.ext_conf is None): + goal_state = self.get_goal_state() + if goal_state.ext_uri is None: + self.ext_conf = ExtensionsConfig(None) + else: + local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_cache(local_file) + self.ext_conf = ExtensionsConfig(xml_text) + return self.ext_conf + + def get_ext_manifest(self, ext_handler, goal_state): + local_file = MANIFEST_FILE_NAME.format(ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) + return ExtensionManifest(xml_text) + + def get_gafamily_manifest(self, vmagent_manifest, goal_state): + local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris) + fileutil.write_file(local_file, xml_text) + return ExtensionManifest(xml_text) + + def check_wire_protocol_version(self): + uri = VERSION_INFO_URI.format(self.endpoint) + version_info_xml = self.fetch_config(uri, None) + version_info = VersionInfo(version_info_xml) + + preferred = version_info.get_preferred() + if PROTOCOL_VERSION == preferred: + logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) + elif PROTOCOL_VERSION in version_info.get_supported(): + logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) + logger.warn("Server prefered version:{0}", preferred) + else: + error = ("Agent supported wire protocol version: {0} was not " + "advised by Fabric.").format(PROTOCOL_VERSION) + raise ProtocolNotFoundError(error) + + def upload_status_blob(self): + ext_conf = self.get_ext_conf() + if ext_conf.status_upload_blob is not None: + if not self.status_blob.upload(ext_conf.status_upload_blob): + self.host_plugin.put_vm_status(self.status_blob, + ext_conf.status_upload_blob) + + def report_role_prop(self, thumbprint): + goal_state = self.get_goal_state() + role_prop = _build_role_properties(goal_state.container_id, + goal_state.role_instance_id, + thumbprint) + role_prop = role_prop.encode("utf-8") + role_prop_uri = ROLE_PROP_URI.format(self.endpoint) + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, role_prop_uri, + role_prop, headers=headers) + except HttpError as e: + raise ProtocolError((u"Failed to send role properties: {0}" + u"").format(e)) + if resp.status != httpclient.ACCEPTED: + raise ProtocolError((u"Failed to send role properties: {0}" + u", {1}").format(resp.status, resp.read())) + + def report_health(self, status, substatus, description): + goal_state = self.get_goal_state() + health_report = _build_health_report(goal_state.incarnation, + goal_state.container_id, + goal_state.role_instance_id, + status, + substatus, + description) + health_report = health_report.encode("utf-8") + health_report_uri = HEALTH_REPORT_URI.format(self.endpoint) + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, health_report_uri, + health_report, headers=headers, max_retry=8) + except HttpError as e: + raise ProtocolError((u"Failed to send provision status: {0}" + u"").format(e)) + if resp.status != httpclient.OK: + raise ProtocolError((u"Failed to send provision status: {0}" + u", {1}").format(resp.status, resp.read())) + + def send_event(self, provider_id, event_str): + uri = TELEMETRY_URI.format(self.endpoint) + data_format = ('<?xml version="1.0"?>' + '<TelemetryData version="1.0">' + '<Provider id="{0}">{1}' + '</Provider>' + '</TelemetryData>') + data = data_format.format(provider_id, event_str) + try: + header = self.get_header_for_xml_content() + resp = self.call_wireserver(restutil.http_post, uri, data, header) + except HttpError as e: + raise ProtocolError("Failed to send events:{0}".format(e)) + + if resp.status != httpclient.OK: + logger.verbose(resp.read()) + raise ProtocolError("Failed to send events:{0}".format(resp.status)) + + def report_event(self, event_list): + buf = {} + # Group events by providerId + for event in event_list.events: + if event.providerId not in buf: + buf[event.providerId] = "" + event_str = event_to_v1(event) + if len(event_str) >= 63 * 1024: + logger.warn("Single event too large: {0}", event_str[300:]) + continue + if len(buf[event.providerId] + event_str) >= 63 * 1024: + self.send_event(event.providerId, buf[event.providerId]) + buf[event.providerId] = "" + buf[event.providerId] = buf[event.providerId] + event_str + + # Send out all events left in buffer. + for provider_id in list(buf.keys()): + if len(buf[provider_id]) > 0: + self.send_event(provider_id, buf[provider_id]) + + def get_header(self): + return { + "x-ms-agent-name": "WALinuxAgent", + "x-ms-version": PROTOCOL_VERSION + } + + def get_header_for_xml_content(self): + return { + "x-ms-agent-name": "WALinuxAgent", + "x-ms-version": PROTOCOL_VERSION, + "Content-Type": "text/xml;charset=utf-8" + } + + def get_header_for_cert(self): + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + content = self.fetch_cache(trans_cert_file) + cert = get_bytes_from_pem(content) + return { + "x-ms-agent-name": "WALinuxAgent", + "x-ms-version": PROTOCOL_VERSION, + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": cert + } + +class VersionInfo(object): + def __init__(self, xml_text): + """ + Query endpoint server for wire protocol version. + Fail if our desired protocol version is not seen. + """ + logger.verbose("Load Version.xml") + self.parse(xml_text) + + def parse(self, xml_text): + xml_doc = parse_doc(xml_text) + preferred = find(xml_doc, "Preferred") + self.preferred = findtext(preferred, "Version") + logger.info("Fabric preferred wire protocol version:{0}", self.preferred) + + self.supported = [] + supported = find(xml_doc, "Supported") + supported_version = findall(supported, "Version") + for node in supported_version: + version = gettext(node) + logger.verbose("Fabric supported wire protocol version:{0}", version) + self.supported.append(version) + + def get_preferred(self): + return self.preferred + + def get_supported(self): + return self.supported + + +class GoalState(object): + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("GoalState.xml is None") + logger.verbose("Load GoalState.xml") + self.incarnation = None + self.expected_state = None + self.hosting_env_uri = None + self.shared_conf_uri = None + self.certs_uri = None + self.ext_uri = None + self.role_instance_id = None + self.container_id = None + self.load_balancer_probe_port = None + self.parse(xml_text) + + def parse(self, xml_text): + """ + Request configuration data from endpoint server. + """ + self.xml_text = xml_text + xml_doc = parse_doc(xml_text) + self.incarnation = findtext(xml_doc, "Incarnation") + self.expected_state = findtext(xml_doc, "ExpectedState") + self.hosting_env_uri = findtext(xml_doc, "HostingEnvironmentConfig") + self.shared_conf_uri = findtext(xml_doc, "SharedConfig") + self.certs_uri = findtext(xml_doc, "Certificates") + self.ext_uri = findtext(xml_doc, "ExtensionsConfig") + role_instance = find(xml_doc, "RoleInstance") + self.role_instance_id = findtext(role_instance, "InstanceId") + container = find(xml_doc, "Container") + self.container_id = findtext(container, "ContainerId") + lbprobe_ports = find(xml_doc, "LBProbePorts") + self.load_balancer_probe_port = findtext(lbprobe_ports, "Port") + return self + + +class HostingEnv(object): + """ + parse Hosting enviromnet config and store in + HostingEnvironmentConfig.xml + """ + + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("HostingEnvironmentConfig.xml is None") + logger.verbose("Load HostingEnvironmentConfig.xml") + self.vm_name = None + self.role_name = None + self.deployment_name = None + self.parse(xml_text) + + def parse(self, xml_text): + """ + parse and create HostingEnvironmentConfig.xml. + """ + self.xml_text = xml_text + xml_doc = parse_doc(xml_text) + incarnation = find(xml_doc, "Incarnation") + self.vm_name = getattrib(incarnation, "instance") + role = find(xml_doc, "Role") + self.role_name = getattrib(role, "name") + deployment = find(xml_doc, "Deployment") + self.deployment_name = getattrib(deployment, "name") + return self + + +class SharedConfig(object): + """ + parse role endpoint server and goal state config. + """ + + def __init__(self, xml_text): + logger.verbose("Load SharedConfig.xml") + self.parse(xml_text) + + def parse(self, xml_text): + """ + parse and write configuration to file SharedConfig.xml. + """ + # Not used currently + return self + +class Certificates(object): + """ + Object containing certificates of host and provisioned user. + """ + + def __init__(self, client, xml_text): + logger.verbose("Load Certificates.xml") + self.client = client + self.cert_list = CertList() + self.parse(xml_text) + + def parse(self, xml_text): + """ + Parse multiple certificates into seperate files. + """ + xml_doc = parse_doc(xml_text) + data = findtext(xml_doc, "Data") + if data is None: + return + + cryptutil = CryptUtil(conf.get_openssl_cmd()) + p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME) + p7m = ("MIME-Version:1.0\n" + "Content-Disposition: attachment; filename=\"{0}\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" + "Content-Transfer-Encoding: base64\n" + "\n" + "{2}").format(p7m_file, p7m_file, data) + + self.client.save_cache(p7m_file, p7m) + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME) + # decrypt certificates + cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, + pem_file) + + # The parsing process use public key to match prv and crt. + buf = [] + begin_crt = False + begin_prv = False + prvs = {} + thumbprints = {} + index = 0 + v1_cert_list = [] + with open(pem_file) as pem: + for line in pem.readlines(): + buf.append(line) + if re.match(r'[-]+BEGIN.*KEY[-]+', line): + begin_prv = True + elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line): + begin_crt = True + elif re.match(r'[-]+END.*KEY[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'prv', buf) + pub = cryptutil.get_pubkey_from_prv(tmp_file) + prvs[pub] = tmp_file + buf = [] + index += 1 + begin_prv = False + elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'crt', buf) + pub = cryptutil.get_pubkey_from_crt(tmp_file) + thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file) + thumbprints[pub] = thumbprint + # Rename crt with thumbprint as the file name + crt = "{0}.crt".format(thumbprint) + v1_cert_list.append({ + "name": None, + "thumbprint": thumbprint + }) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt)) + buf = [] + index += 1 + begin_crt = False + + # Rename prv key with thumbprint as the file name + for pubkey in prvs: + thumbprint = thumbprints[pubkey] + if thumbprint: + tmp_file = prvs[pubkey] + prv = "{0}.prv".format(thumbprint) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv)) + + for v1_cert in v1_cert_list: + cert = Cert() + set_properties("certs", cert, v1_cert) + self.cert_list.certificates.append(cert) + + def write_to_tmp_file(self, index, suffix, buf): + file_name = os.path.join(conf.get_lib_dir(), + "{0}.{1}".format(index, suffix)) + self.client.save_cache(file_name, "".join(buf)) + return file_name + + +class ExtensionsConfig(object): + """ + parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent. + Install if <enabled>true</enabled>, remove if it is set to false. + """ + + def __init__(self, xml_text): + logger.verbose("Load ExtensionsConfig.xml") + self.ext_handlers = ExtHandlerList() + self.vmagent_manifests = VMAgentManifestList() + self.status_upload_blob = None + if xml_text is not None: + self.parse(xml_text) + + def parse(self, xml_text): + """ + Write configuration to file ExtensionsConfig.xml. + """ + xml_doc = parse_doc(xml_text) + + ga_families_list = find(xml_doc, "GAFamilies") + ga_families = findall(ga_families_list, "GAFamily") + + for ga_family in ga_families: + family = findtext(ga_family, "Name") + uris_list = find(ga_family, "Uris") + uris = findall(uris_list, "Uri") + manifest = VMAgentManifest() + manifest.family = family + for uri in uris: + manifestUri = VMAgentManifestUri(uri=gettext(uri)) + manifest.versionsManifestUris.append(manifestUri) + self.vmagent_manifests.vmAgentManifests.append(manifest) + + plugins_list = find(xml_doc, "Plugins") + plugins = findall(plugins_list, "Plugin") + plugin_settings_list = find(xml_doc, "PluginSettings") + plugin_settings = findall(plugin_settings_list, "Plugin") + + for plugin in plugins: + ext_handler = self.parse_plugin(plugin) + self.ext_handlers.extHandlers.append(ext_handler) + self.parse_plugin_settings(ext_handler, plugin_settings) + + self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob") + + def parse_plugin(self, plugin): + ext_handler = ExtHandler() + ext_handler.name = getattrib(plugin, "name") + ext_handler.properties.version = getattrib(plugin, "version") + ext_handler.properties.state = getattrib(plugin, "state") + + auto_upgrade = getattrib(plugin, "autoUpgrade") + if auto_upgrade is not None and auto_upgrade.lower() == "true": + ext_handler.properties.upgradePolicy = "auto" + else: + ext_handler.properties.upgradePolicy = "manual" + + location = getattrib(plugin, "location") + failover_location = getattrib(plugin, "failoverlocation") + for uri in [location, failover_location]: + version_uri = ExtHandlerVersionUri() + version_uri.uri = uri + ext_handler.versionUris.append(version_uri) + return ext_handler + + def parse_plugin_settings(self, ext_handler, plugin_settings): + if plugin_settings is None: + return + + name = ext_handler.name + version = ext_handler.properties.version + settings = [x for x in plugin_settings \ + if getattrib(x, "name") == name and \ + getattrib(x, "version") == version] + + if settings is None or len(settings) == 0: + return + + runtime_settings = None + runtime_settings_node = find(settings[0], "RuntimeSettings") + seqNo = getattrib(runtime_settings_node, "seqNo") + runtime_settings_str = gettext(runtime_settings_node) + try: + runtime_settings = json.loads(runtime_settings_str) + except ValueError as e: + logger.error("Invalid extension settings") + return + + for plugin_settings_list in runtime_settings["runtimeSettings"]: + handler_settings = plugin_settings_list["handlerSettings"] + ext = Extension() + # There is no "extension name" in wire protocol. + # Put + ext.name = ext_handler.name + ext.sequenceNumber = seqNo + ext.publicSettings = handler_settings.get("publicSettings") + ext.protectedSettings = handler_settings.get("protectedSettings") + thumbprint = handler_settings.get("protectedSettingsCertThumbprint") + ext.certificateThumbprint = thumbprint + ext_handler.properties.extensions.append(ext) + + +class ExtensionManifest(object): + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("ExtensionManifest is None") + logger.verbose("Load ExtensionManifest.xml") + self.pkg_list = ExtHandlerPackageList() + self.parse(xml_text) + + def parse(self, xml_text): + xml_doc = parse_doc(xml_text) + self._handle_packages(findall(find(xml_doc, "Plugins"), "Plugin"), False) + self._handle_packages(findall(find(xml_doc, "InternalPlugins"), "Plugin"), True) + + def _handle_packages(self, packages, isinternal): + for package in packages: + version = findtext(package, "Version") + + disallow_major_upgrade = findtext(package, "DisallowMajorVersionUpgrade") + if disallow_major_upgrade is None: + disallow_major_upgrade = '' + disallow_major_upgrade = disallow_major_upgrade.lower() == "true" + + uris = find(package, "Uris") + uri_list = findall(uris, "Uri") + uri_list = [gettext(x) for x in uri_list] + pkg = ExtHandlerPackage() + pkg.version = version + pkg.disallow_major_upgrade = disallow_major_upgrade + for uri in uri_list: + pkg_uri = ExtHandlerVersionUri() + pkg_uri.uri = uri + pkg.uris.append(pkg_uri) + + pkg.isinternal = isinternal + self.pkg_list.versions.append(pkg) |