summaryrefslogtreecommitdiff
path: root/azurelinuxagent/common/protocol
diff options
context:
space:
mode:
Diffstat (limited to 'azurelinuxagent/common/protocol')
-rw-r--r--azurelinuxagent/common/protocol/__init__.py21
-rw-r--r--azurelinuxagent/common/protocol/hostplugin.py124
-rw-r--r--azurelinuxagent/common/protocol/metadata.py223
-rw-r--r--azurelinuxagent/common/protocol/ovfenv.py113
-rw-r--r--azurelinuxagent/common/protocol/restapi.py272
-rw-r--r--azurelinuxagent/common/protocol/util.py285
-rw-r--r--azurelinuxagent/common/protocol/wire.py1218
7 files changed, 2256 insertions, 0 deletions
diff --git a/azurelinuxagent/common/protocol/__init__.py b/azurelinuxagent/common/protocol/__init__.py
new file mode 100644
index 0000000..fb7c273
--- /dev/null
+++ b/azurelinuxagent/common/protocol/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+from azurelinuxagent.common.protocol.util import get_protocol_util, \
+ OVF_FILE_NAME, \
+ TAG_FILE_NAME
+
diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py
new file mode 100644
index 0000000..6569604
--- /dev/null
+++ b/azurelinuxagent/common/protocol/hostplugin.py
@@ -0,0 +1,124 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+from azurelinuxagent.common.protocol.wire import *
+from azurelinuxagent.common.utils import textutil
+
+HOST_PLUGIN_PORT = 32526
+URI_FORMAT_GET_API_VERSIONS = "http://{0}:{1}/versions"
+URI_FORMAT_PUT_VM_STATUS = "http://{0}:{1}/status"
+URI_FORMAT_PUT_LOG = "http://{0}:{1}/vmAgentLog"
+API_VERSION = "2015-09-01"
+
+
+class HostPluginProtocol(object):
+ def __init__(self, endpoint):
+ if endpoint is None:
+ raise ProtocolError("Host plugin endpoint not provided")
+ self.is_initialized = False
+ self.is_available = False
+ self.api_versions = None
+ self.endpoint = endpoint
+
+ def ensure_initialized(self):
+ if not self.is_initialized:
+ self.api_versions = self.get_api_versions()
+ self.is_available = API_VERSION in self.api_versions
+ self.is_initialized = True
+ return self.is_available
+
+ def get_api_versions(self):
+ url = URI_FORMAT_GET_API_VERSIONS.format(self.endpoint,
+ HOST_PLUGIN_PORT)
+ logger.info("getting API versions at [{0}]".format(url))
+ try:
+ response = restutil.http_get(url)
+ if response.status != httpclient.OK:
+ logger.error(
+ "get API versions returned status code [{0}]".format(
+ response.status))
+ return []
+ return response.read()
+ except HttpError as e:
+ logger.error("get API versions failed with [{0}]".format(e))
+ return []
+
+ def put_vm_status(self, status_blob, sas_url):
+ """
+ Try to upload the VM status via the host plugin /status channel
+ :param sas_url: the blob SAS url to pass to the host plugin
+ :type status_blob: StatusBlob
+ """
+ if not self.ensure_initialized():
+ logger.error("host plugin channel is not available")
+ return
+ if status_blob is None or status_blob.vm_status is None:
+ logger.error("no status data was provided")
+ return
+ url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT)
+ status = textutil.b64encode(status_blob.vm_status)
+ headers = {"x-ms-version": API_VERSION}
+ blob_headers = [{'headerName': 'x-ms-version',
+ 'headerValue': status_blob.__storage_version__},
+ {'headerName': 'x-ms-blob-type',
+ 'headerValue': status_blob.type}]
+ data = json.dumps({'requestUri': sas_url, 'headers': blob_headers,
+ 'content': status}, sort_keys=True)
+ logger.info("put VM status at [{0}]".format(url))
+ try:
+ response = restutil.http_put(url, data, headers)
+ if response.status != httpclient.OK:
+ logger.error("put VM status returned status code [{0}]".format(
+ response.status))
+ except HttpError as e:
+ logger.error("put VM status failed with [{0}]".format(e))
+
+ def put_vm_log(self, content, container_id, deployment_id):
+ """
+ Try to upload the given content to the host plugin
+ :param deployment_id: the deployment id, which is obtained from the
+ goal state (tenant name)
+ :param container_id: the container id, which is obtained from the
+ goal state
+ :param content: the binary content of the zip file to upload
+ :return:
+ """
+ if not self.ensure_initialized():
+ logger.error("host plugin channel is not available")
+ return
+ if content is None or container_id is None or deployment_id is None:
+ logger.error(
+ "invalid arguments passed: "
+ "[{0}], [{1}], [{2}]".format(
+ content,
+ container_id,
+ deployment_id))
+ return
+ url = URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT)
+
+ headers = {"x-ms-vmagentlog-deploymentid": deployment_id,
+ "x-ms-vmagentlog-containerid": container_id}
+ logger.info("put VM log at [{0}]".format(url))
+ try:
+ response = restutil.http_put(url, content, headers)
+ if response.status != httpclient.OK:
+ logger.error("put log returned status code [{0}]".format(
+ response.status))
+ except HttpError as e:
+ logger.error("put log failed with [{0}]".format(e))
diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py
new file mode 100644
index 0000000..f86f72f
--- /dev/null
+++ b/azurelinuxagent/common/protocol/metadata.py
@@ -0,0 +1,223 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import json
+import shutil
+import os
+import time
+from azurelinuxagent.common.exception import ProtocolError, HttpError
+from azurelinuxagent.common.future import httpclient, ustr
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.restutil as restutil
+import azurelinuxagent.common.utils.textutil as textutil
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
+from azurelinuxagent.common.protocol.restapi import *
+
+METADATA_ENDPOINT='169.254.169.254'
+APIVERSION='2015-05-01-preview'
+BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}"
+
+TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem"
+TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem"
+
+#TODO remote workarround for azure stack
+MAX_PING = 30
+RETRY_PING_INTERVAL = 10
+
+def _add_content_type(headers):
+ if headers is None:
+ headers = {}
+ headers["content-type"] = "application/json"
+ return headers
+
+class MetadataProtocol(Protocol):
+
+ def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT):
+ self.apiversion = apiversion
+ self.endpoint = endpoint
+ self.identity_uri = BASE_URI.format(self.endpoint, "identity",
+ self.apiversion, "&$expand=*")
+ self.cert_uri = BASE_URI.format(self.endpoint, "certificates",
+ self.apiversion, "&$expand=*")
+ self.ext_uri = BASE_URI.format(self.endpoint, "extensionHandlers",
+ self.apiversion, "&$expand=*")
+ self.vmagent_uri = BASE_URI.format(self.endpoint, "vmAgentVersions",
+ self.apiversion, "&$expand=*")
+ self.provision_status_uri = BASE_URI.format(self.endpoint,
+ "provisioningStatus",
+ self.apiversion, "")
+ self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent",
+ self.apiversion, "")
+ self.ext_status_uri = BASE_URI.format(self.endpoint,
+ "status/extensions/{0}",
+ self.apiversion, "")
+ self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry",
+ self.apiversion, "")
+
+ def _get_data(self, url, headers=None):
+ try:
+ resp = restutil.http_get(url, headers=headers)
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
+
+ if resp.status != httpclient.OK:
+ raise ProtocolError("{0} - GET: {1}".format(resp.status, url))
+
+ data = resp.read()
+ etag = resp.getheader('ETag')
+ if data is None:
+ return None
+ data = json.loads(ustr(data, encoding="utf-8"))
+ return data, etag
+
+ def _put_data(self, url, data, headers=None):
+ headers = _add_content_type(headers)
+ try:
+ resp = restutil.http_put(url, json.dumps(data), headers=headers)
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
+ if resp.status != httpclient.OK:
+ raise ProtocolError("{0} - PUT: {1}".format(resp.status, url))
+
+ def _post_data(self, url, data, headers=None):
+ headers = _add_content_type(headers)
+ try:
+ resp = restutil.http_post(url, json.dumps(data), headers=headers)
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
+ if resp.status != httpclient.CREATED:
+ raise ProtocolError("{0} - POST: {1}".format(resp.status, url))
+
+ def _get_trans_cert(self):
+ trans_crt_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ if not os.path.isfile(trans_crt_file):
+ raise ProtocolError("{0} is missing.".format(trans_crt_file))
+ content = fileutil.read_file(trans_crt_file)
+ return textutil.get_bytes_from_pem(content)
+
+ def detect(self):
+ self.get_vminfo()
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)
+
+ #"Install" the cert and private key to /var/lib/waagent
+ thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file)
+ prv_file = os.path.join(conf.get_lib_dir(),
+ "{0}.prv".format(thumbprint))
+ crt_file = os.path.join(conf.get_lib_dir(),
+ "{0}.crt".format(thumbprint))
+ shutil.copyfile(trans_prv_file, prv_file)
+ shutil.copyfile(trans_cert_file, crt_file)
+
+
+ def get_vminfo(self):
+ vminfo = VMInfo()
+ data, etag = self._get_data(self.identity_uri)
+ set_properties("vminfo", vminfo, data)
+ return vminfo
+
+ def get_certs(self):
+ #TODO download and save certs
+ return CertList()
+
+ def get_vmagent_manifests(self, last_etag=None):
+ manifests = VMAgentManifestList()
+ data, etag = self._get_data(self.vmagent_uri)
+ if last_etag == None or last_etag < etag:
+ set_properties("vmAgentManifests", manifests.vmAgentManifests, data)
+ return manifests, etag
+
+ def get_vmagent_pkgs(self, vmagent_manifest):
+ #Agent package is the same with extension handler
+ vmagent_pkgs = ExtHandlerPackageList()
+ data = None
+ for manifest_uri in vmagent_manifest.versionsManifestUris:
+ try:
+ data = self._get_data(manifest_uri.uri)
+ break
+ except ProtocolError as e:
+ logger.warn("Failed to get vmagent versions: {0}", e)
+ logger.info("Retry getting vmagent versions")
+ if data is None:
+ raise ProtocolError(("Failed to get versions for vm agent: {0}"
+ "").format(vmagent_manifest.family))
+ set_properties("vmAgentVersions", vmagent_pkgs, data)
+ # TODO: What etag should this return?
+ return vmagent_pkgs
+
+ def get_ext_handlers(self, last_etag=None):
+ headers = {
+ "x-ms-vmagent-public-x509-cert": self._get_trans_cert()
+ }
+ ext_list = ExtHandlerList()
+ data, etag = self._get_data(self.ext_uri, headers=headers)
+ if last_etag == None or last_etag < etag:
+ set_properties("extensionHandlers", ext_list.extHandlers, data)
+ return ext_list, etag
+
+ def get_ext_handler_pkgs(self, ext_handler):
+ ext_handler_pkgs = ExtHandlerPackageList()
+ data = None
+ for version_uri in ext_handler.versionUris:
+ try:
+ data, etag = self._get_data(version_uri.uri)
+ break
+ except ProtocolError as e:
+ logger.warn("Failed to get version uris: {0}", e)
+ logger.info("Retry getting version uris")
+ set_properties("extensionPackages", ext_handler_pkgs, data)
+ return ext_handler_pkgs
+
+ def report_provision_status(self, provision_status):
+ validate_param('provisionStatus', provision_status, ProvisionStatus)
+ data = get_properties(provision_status)
+ self._put_data(self.provision_status_uri, data)
+
+ def report_vm_status(self, vm_status):
+ validate_param('vmStatus', vm_status, VMStatus)
+ data = get_properties(vm_status)
+ #TODO code field is not implemented for metadata protocol yet. Remove it
+ handler_statuses = data['vmAgent']['extensionHandlers']
+ for handler_status in handler_statuses:
+ try:
+ handler_status.pop('code', None)
+ except KeyError:
+ pass
+
+ self._put_data(self.vm_status_uri, data)
+
+ def report_ext_status(self, ext_handler_name, ext_name, ext_status):
+ validate_param('extensionStatus', ext_status, ExtensionStatus)
+ data = get_properties(ext_status)
+ uri = self.ext_status_uri.format(ext_name)
+ self._put_data(uri, data)
+
+ def report_event(self, events):
+ #TODO disable telemetry for azure stack test
+ #validate_param('events', events, TelemetryEventList)
+ #data = get_properties(events)
+ #self._post_data(self.event_uri, data)
+ pass
+
diff --git a/azurelinuxagent/common/protocol/ovfenv.py b/azurelinuxagent/common/protocol/ovfenv.py
new file mode 100644
index 0000000..4901871
--- /dev/null
+++ b/azurelinuxagent/common/protocol/ovfenv.py
@@ -0,0 +1,113 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+"""
+Copy and parse ovf-env.xml from provisioning ISO and local cache
+"""
+import os
+import re
+import shutil
+import xml.dom.minidom as minidom
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext
+
+OVF_VERSION = "1.0"
+OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1"
+WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure"
+
+def _validate_ovf(val, msg):
+ if val is None:
+ raise ProtocolError("Failed to parse OVF XML: {0}".format(msg))
+
+class OvfEnv(object):
+ """
+ Read, and process provisioning info from provisioning file OvfEnv.xml
+ """
+ def __init__(self, xml_text):
+ if xml_text is None:
+ raise ValueError("ovf-env is None")
+ logger.verbose("Load ovf-env.xml")
+ self.hostname = None
+ self.username = None
+ self.user_password = None
+ self.customdata = None
+ self.disable_ssh_password_auth = True
+ self.ssh_pubkeys = []
+ self.ssh_keypairs = []
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ """
+ Parse xml tree, retreiving user and ssh key information.
+ Return self.
+ """
+ wans = WA_NAME_SPACE
+ ovfns = OVF_NAME_SPACE
+
+ xml_doc = parse_doc(xml_text)
+
+ environment = find(xml_doc, "Environment", namespace=ovfns)
+ _validate_ovf(environment, "Environment not found")
+
+ section = find(environment, "ProvisioningSection", namespace=wans)
+ _validate_ovf(section, "ProvisioningSection not found")
+
+ version = findtext(environment, "Version", namespace=wans)
+ _validate_ovf(version, "Version not found")
+
+ if version > OVF_VERSION:
+ logger.warn("Newer provisioning configuration detected. "
+ "Please consider updating waagent")
+
+ conf_set = find(section, "LinuxProvisioningConfigurationSet",
+ namespace=wans)
+ _validate_ovf(conf_set, "LinuxProvisioningConfigurationSet not found")
+
+ self.hostname = findtext(conf_set, "HostName", namespace=wans)
+ _validate_ovf(self.hostname, "HostName not found")
+
+ self.username = findtext(conf_set, "UserName", namespace=wans)
+ _validate_ovf(self.username, "UserName not found")
+
+ self.user_password = findtext(conf_set, "UserPassword", namespace=wans)
+
+ self.customdata = findtext(conf_set, "CustomData", namespace=wans)
+
+ auth_option = findtext(conf_set, "DisableSshPasswordAuthentication",
+ namespace=wans)
+ if auth_option is not None and auth_option.lower() == "true":
+ self.disable_ssh_password_auth = True
+ else:
+ self.disable_ssh_password_auth = False
+
+ public_keys = findall(conf_set, "PublicKey", namespace=wans)
+ for public_key in public_keys:
+ path = findtext(public_key, "Path", namespace=wans)
+ fingerprint = findtext(public_key, "Fingerprint", namespace=wans)
+ value = findtext(public_key, "Value", namespace=wans)
+ self.ssh_pubkeys.append((path, fingerprint, value))
+
+ keypairs = findall(conf_set, "KeyPair", namespace=wans)
+ for keypair in keypairs:
+ path = findtext(keypair, "Path", namespace=wans)
+ fingerprint = findtext(keypair, "Fingerprint", namespace=wans)
+ self.ssh_keypairs.append((path, fingerprint))
+
diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py
new file mode 100644
index 0000000..7f00488
--- /dev/null
+++ b/azurelinuxagent/common/protocol/restapi.py
@@ -0,0 +1,272 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+import os
+import copy
+import re
+import json
+import xml.dom.minidom
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError, HttpError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.restutil as restutil
+
+def validate_param(name, val, expected_type):
+ if val is None:
+ raise ProtocolError("{0} is None".format(name))
+ if not isinstance(val, expected_type):
+ raise ProtocolError(("{0} type should be {1} not {2}"
+ "").format(name, expected_type, type(val)))
+
+def set_properties(name, obj, data):
+ if isinstance(obj, DataContract):
+ validate_param("Property '{0}'".format(name), data, dict)
+ for prob_name, prob_val in data.items():
+ prob_full_name = "{0}.{1}".format(name, prob_name)
+ try:
+ prob = getattr(obj, prob_name)
+ except AttributeError:
+ logger.warn("Unknown property: {0}", prob_full_name)
+ continue
+ prob = set_properties(prob_full_name, prob, prob_val)
+ setattr(obj, prob_name, prob)
+ return obj
+ elif isinstance(obj, DataContractList):
+ validate_param("List '{0}'".format(name), data, list)
+ for item_data in data:
+ item = obj.item_cls()
+ item = set_properties(name, item, item_data)
+ obj.append(item)
+ return obj
+ else:
+ return data
+
+def get_properties(obj):
+ if isinstance(obj, DataContract):
+ data = {}
+ props = vars(obj)
+ for prob_name, prob in list(props.items()):
+ data[prob_name] = get_properties(prob)
+ return data
+ elif isinstance(obj, DataContractList):
+ data = []
+ for item in obj:
+ item_data = get_properties(item)
+ data.append(item_data)
+ return data
+ else:
+ return obj
+
+class DataContract(object):
+ pass
+
+class DataContractList(list):
+ def __init__(self, item_cls):
+ self.item_cls = item_cls
+
+"""
+Data contract between guest and host
+"""
+class VMInfo(DataContract):
+ def __init__(self, subscriptionId=None, vmName=None, containerId=None,
+ roleName=None, roleInstanceName=None, tenantName=None):
+ self.subscriptionId = subscriptionId
+ self.vmName = vmName
+ self.containerId = containerId
+ self.roleName = roleName
+ self.roleInstanceName = roleInstanceName
+ self.tenantName = tenantName
+
+class Cert(DataContract):
+ def __init__(self, name=None, thumbprint=None, certificateDataUri=None):
+ self.name = name
+ self.thumbprint = thumbprint
+ self.certificateDataUri = certificateDataUri
+
+class CertList(DataContract):
+ def __init__(self):
+ self.certificates = DataContractList(Cert)
+
+#TODO: confirm vmagent manifest schema
+class VMAgentManifestUri(DataContract):
+ def __init__(self, uri=None):
+ self.uri = uri
+
+class VMAgentManifest(DataContract):
+ def __init__(self, family=None):
+ self.family = family
+ self.versionsManifestUris = DataContractList(VMAgentManifestUri)
+
+class VMAgentManifestList(DataContract):
+ def __init__(self):
+ self.vmAgentManifests = DataContractList(VMAgentManifest)
+
+class Extension(DataContract):
+ def __init__(self, name=None, sequenceNumber=None, publicSettings=None,
+ protectedSettings=None, certificateThumbprint=None):
+ self.name = name
+ self.sequenceNumber = sequenceNumber
+ self.publicSettings = publicSettings
+ self.protectedSettings = protectedSettings
+ self.certificateThumbprint = certificateThumbprint
+
+class ExtHandlerProperties(DataContract):
+ def __init__(self):
+ self.version = None
+ self.upgradePolicy = None
+ self.state = None
+ self.extensions = DataContractList(Extension)
+
+class ExtHandlerVersionUri(DataContract):
+ def __init__(self):
+ self.uri = None
+
+class ExtHandler(DataContract):
+ def __init__(self, name=None):
+ self.name = name
+ self.properties = ExtHandlerProperties()
+ self.versionUris = DataContractList(ExtHandlerVersionUri)
+
+class ExtHandlerList(DataContract):
+ def __init__(self):
+ self.extHandlers = DataContractList(ExtHandler)
+
+class ExtHandlerPackageUri(DataContract):
+ def __init__(self, uri=None):
+ self.uri = uri
+
+class ExtHandlerPackage(DataContract):
+ def __init__(self, version = None):
+ self.version = version
+ self.uris = DataContractList(ExtHandlerPackageUri)
+ # TODO update the naming to align with metadata protocol
+ self.isinternal = False
+
+class ExtHandlerPackageList(DataContract):
+ def __init__(self):
+ self.versions = DataContractList(ExtHandlerPackage)
+
+class VMProperties(DataContract):
+ def __init__(self, certificateThumbprint=None):
+ #TODO need to confirm the property name
+ self.certificateThumbprint = certificateThumbprint
+
+class ProvisionStatus(DataContract):
+ def __init__(self, status=None, subStatus=None, description=None):
+ self.status = status
+ self.subStatus = subStatus
+ self.description = description
+ self.properties = VMProperties()
+
+class ExtensionSubStatus(DataContract):
+ def __init__(self, name=None, status=None, code=None, message=None):
+ self.name = name
+ self.status = status
+ self.code = code
+ self.message = message
+
+class ExtensionStatus(DataContract):
+ def __init__(self, configurationAppliedTime=None, operation=None,
+ status=None, seq_no=None, code=None, message=None):
+ self.configurationAppliedTime = configurationAppliedTime
+ self.operation = operation
+ self.status = status
+ self.sequenceNumber = seq_no
+ self.code = code
+ self.message = message
+ self.substatusList = DataContractList(ExtensionSubStatus)
+
+class ExtHandlerStatus(DataContract):
+ def __init__(self, name=None, version=None, status=None, code=0,
+ message=None):
+ self.name = name
+ self.version = version
+ self.status = status
+ self.code = code
+ self.message = message
+ self.extensions = DataContractList(ustr)
+
+class VMAgentStatus(DataContract):
+ def __init__(self, version=None, status=None, message=None):
+ self.version = version
+ self.status = status
+ self.message = message
+ self.extensionHandlers = DataContractList(ExtHandlerStatus)
+
+class VMStatus(DataContract):
+ def __init__(self):
+ self.vmAgent = VMAgentStatus()
+
+class TelemetryEventParam(DataContract):
+ def __init__(self, name=None, value=None):
+ self.name = name
+ self.value = value
+
+class TelemetryEvent(DataContract):
+ def __init__(self, eventId=None, providerId=None):
+ self.eventId = eventId
+ self.providerId = providerId
+ self.parameters = DataContractList(TelemetryEventParam)
+
+class TelemetryEventList(DataContract):
+ def __init__(self):
+ self.events = DataContractList(TelemetryEvent)
+
+class Protocol(DataContract):
+
+ def detect(self):
+ raise NotImplementedError()
+
+ def get_vminfo(self):
+ raise NotImplementedError()
+
+ def get_certs(self):
+ raise NotImplementedError()
+
+ def get_vmagent_manifests(self):
+ raise NotImplementedError()
+
+ def get_vmagent_pkgs(self):
+ raise NotImplementedError()
+
+ def get_ext_handlers(self):
+ raise NotImplementedError()
+
+ def get_ext_handler_pkgs(self, extension):
+ raise NotImplementedError()
+
+ def download_ext_handler_pkg(self, uri):
+ try:
+ resp = restutil.http_get(uri, chk_proxy=True)
+ if resp.status == restutil.httpclient.OK:
+ return resp.read()
+ except HttpError as e:
+ raise ProtocolError("Failed to download from: {0}".format(uri), e)
+
+ def report_provision_status(self, provision_status):
+ raise NotImplementedError()
+
+ def report_vm_status(self, vm_status):
+ raise NotImplementedError()
+
+ def report_ext_status(self, ext_handler_name, ext_name, ext_status):
+ raise NotImplementedError()
+
+ def report_event(self, event):
+ raise NotImplementedError()
+
diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py
new file mode 100644
index 0000000..7e7a74f
--- /dev/null
+++ b/azurelinuxagent/common/protocol/util.py
@@ -0,0 +1,285 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+import os
+import re
+import shutil
+import time
+import threading
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \
+ ProtocolNotFoundError, DhcpError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.osutil import get_osutil
+from azurelinuxagent.common.dhcp import get_dhcp_handler
+from azurelinuxagent.common.protocol.ovfenv import OvfEnv
+from azurelinuxagent.common.protocol.wire import WireProtocol
+from azurelinuxagent.common.protocol.metadata import MetadataProtocol, \
+ METADATA_ENDPOINT
+import azurelinuxagent.common.utils.shellutil as shellutil
+
+OVF_FILE_NAME = "ovf-env.xml"
+
+#Tag file to indicate usage of metadata protocol
+TAG_FILE_NAME = "useMetadataEndpoint.tag"
+
+PROTOCOL_FILE_NAME = "Protocol"
+
+#MAX retry times for protocol probing
+MAX_RETRY = 360
+
+PROBE_INTERVAL = 10
+
+ENDPOINT_FILE_NAME = "WireServerEndpoint"
+
+def get_protocol_util():
+ return ProtocolUtil()
+
+class ProtocolUtil(object):
+ """
+ ProtocolUtil handles initialization for protocol instance. 2 protocol types
+ are invoked, wire protocol and metadata protocols.
+ """
+ def __init__(self):
+ self.lock = threading.Lock()
+ self.protocol = None
+ self.osutil = get_osutil()
+ self.dhcp_handler = get_dhcp_handler()
+
+ def copy_ovf_env(self):
+ """
+ Copy ovf env file from dvd to hard disk.
+ Remove password before save it to the disk
+ """
+ dvd_mount_point = conf.get_dvd_mount_point()
+ ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME)
+ tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME)
+ try:
+ self.osutil.mount_dvd()
+ ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
+ ovfenv = OvfEnv(ovfxml)
+ ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml)
+ ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
+ fileutil.write_file(ovf_file_path, ovfxml)
+
+ if os.path.isfile(tag_file_path_on_dvd):
+ logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME)
+ tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME)
+ shutil.copyfile(tag_file_path_on_dvd, tag_file_path)
+
+ except (OSUtilError, IOError) as e:
+ raise ProtocolError(ustr(e))
+
+ try:
+ self.osutil.umount_dvd()
+ self.osutil.eject_dvd()
+ except OSUtilError as e:
+ logger.warn(ustr(e))
+
+ return ovfenv
+
+ def get_ovf_env(self):
+ """
+ Load saved ovf-env.xml
+ """
+ ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
+ if os.path.isfile(ovf_file_path):
+ xml_text = fileutil.read_file(ovf_file_path)
+ return OvfEnv(xml_text)
+ else:
+ raise ProtocolError("ovf-env.xml is missing.")
+
+ def _get_wireserver_endpoint(self):
+ try:
+ file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME)
+ return fileutil.read_file(file_path)
+ except IOError as e:
+ raise OSUtilError(ustr(e))
+
+ def _set_wireserver_endpoint(self, endpoint):
+ try:
+ file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME)
+ fileutil.write_file(file_path, endpoint)
+ except IOError as e:
+ raise OSUtilError(ustr(e))
+
+ def _detect_wire_protocol(self):
+ endpoint = self.dhcp_handler.endpoint
+ if endpoint is None:
+ logger.info("WireServer endpoint is not found. Rerun dhcp handler")
+ try:
+ self.dhcp_handler.run()
+ except DhcpError as e:
+ raise ProtocolError(ustr(e))
+ endpoint = self.dhcp_handler.endpoint
+
+ try:
+ protocol = WireProtocol(endpoint)
+ protocol.detect()
+ self._set_wireserver_endpoint(endpoint)
+ self.save_protocol("WireProtocol")
+ return protocol
+ except ProtocolError as e:
+ logger.info("WireServer is not responding. Reset endpoint")
+ self.dhcp_handler.endpoint = None
+ self.dhcp_handler.skip_cache = True
+ raise e
+
+ def _detect_metadata_protocol(self):
+ protocol = MetadataProtocol()
+ protocol.detect()
+
+ #Only allow root access METADATA_ENDPOINT
+ self.osutil.set_admin_access_to_ip(METADATA_ENDPOINT)
+
+ self.save_protocol("MetadataProtocol")
+
+ return protocol
+
+ def _detect_protocol(self, protocols):
+ """
+ Probe protocol endpoints in turn.
+ """
+ self.clear_protocol()
+
+ for retry in range(0, MAX_RETRY):
+ for protocol in protocols:
+ try:
+ if protocol == "WireProtocol":
+ return self._detect_wire_protocol()
+
+ if protocol == "MetadataProtocol":
+ return self._detect_metadata_protocol()
+
+ except ProtocolError as e:
+ logger.info("Protocol endpoint not found: {0}, {1}",
+ protocol, e)
+
+ if retry < MAX_RETRY -1:
+ logger.info("Retry detect protocols: retry={0}", retry)
+ time.sleep(PROBE_INTERVAL)
+ raise ProtocolNotFoundError("No protocol found.")
+
+ def _get_protocol(self):
+ """
+ Get protocol instance based on previous detecting result.
+ """
+ protocol_file_path = os.path.join(conf.get_lib_dir(),
+ PROTOCOL_FILE_NAME)
+ if not os.path.isfile(protocol_file_path):
+ raise ProtocolNotFoundError("No protocol found")
+
+ protocol_name = fileutil.read_file(protocol_file_path)
+ if protocol_name == "WireProtocol":
+ endpoint = self._get_wireserver_endpoint()
+ return WireProtocol(endpoint)
+ elif protocol_name == "MetadataProtocol":
+ return MetadataProtocol()
+ else:
+ raise ProtocolNotFoundError(("Unknown protocol: {0}"
+ "").format(protocol_name))
+
+ def save_protocol(self, protocol_name):
+ """
+ Save protocol endpoint
+ """
+ protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME)
+ try:
+ fileutil.write_file(protocol_file_path, protocol_name)
+ except IOError as e:
+ logger.error("Failed to save protocol endpoint: {0}", e)
+
+
+ def clear_protocol(self):
+ """
+ Cleanup previous saved endpoint.
+ """
+ logger.info("Clean protocol")
+ self.protocol = None
+ protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME)
+ if not os.path.isfile(protocol_file_path):
+ return
+
+ try:
+ os.remove(protocol_file_path)
+ except IOError as e:
+ logger.error("Failed to clear protocol endpoint: {0}", e)
+
+ def get_protocol(self):
+ """
+ Detect protocol by endpoints
+
+ :returns: protocol instance
+ """
+ self.lock.acquire()
+
+ try:
+ if self.protocol is not None:
+ return self.protocol
+
+ try:
+ self.protocol = self._get_protocol()
+ return self.protocol
+ except ProtocolNotFoundError:
+ pass
+
+ logger.info("Detect protocol endpoints")
+ protocols = ["WireProtocol", "MetadataProtocol"]
+ self.protocol = self._detect_protocol(protocols)
+
+ return self.protocol
+
+ finally:
+ self.lock.release()
+
+
+ def get_protocol_by_file(self):
+ """
+ Detect protocol by tag file.
+
+ If a file "useMetadataEndpoint.tag" is found on provision iso,
+ metedata protocol will be used. No need to probe for wire protocol
+
+ :returns: protocol instance
+ """
+ self.lock.acquire()
+
+ try:
+ if self.protocol is not None:
+ return self.protocol
+
+ try:
+ self.protocol = self._get_protocol()
+ return self.protocol
+ except ProtocolNotFoundError:
+ pass
+
+ logger.info("Detect protocol by file")
+ tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME)
+ protocols = []
+ if os.path.isfile(tag_file_path):
+ protocols.append("MetadataProtocol")
+ else:
+ protocols.append("WireProtocol")
+ self.protocol = self._detect_protocol(protocols)
+ return self.protocol
+
+ finally:
+ self.lock.release()
diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py
new file mode 100644
index 0000000..29a1663
--- /dev/null
+++ b/azurelinuxagent/common/protocol/wire.py
@@ -0,0 +1,1218 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import time
+import xml.sax.saxutils as saxutils
+import azurelinuxagent.common.conf as conf
+from azurelinuxagent.common.exception import ProtocolNotFoundError
+from azurelinuxagent.common.future import httpclient, bytebuffer
+from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext, \
+ getattrib, gettext, remove_bom, get_bytes_from_pem
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
+from azurelinuxagent.common.protocol.restapi import *
+from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol
+
+VERSION_INFO_URI = "http://{0}/?comp=versions"
+GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate"
+HEALTH_REPORT_URI = "http://{0}/machine?comp=health"
+ROLE_PROP_URI = "http://{0}/machine?comp=roleProperties"
+TELEMETRY_URI = "http://{0}/machine?comp=telemetrydata"
+
+WIRE_SERVER_ADDR_FILE_NAME = "WireServer"
+INCARNATION_FILE_NAME = "Incarnation"
+GOAL_STATE_FILE_NAME = "GoalState.{0}.xml"
+HOSTING_ENV_FILE_NAME = "HostingEnvironmentConfig.xml"
+SHARED_CONF_FILE_NAME = "SharedConfig.xml"
+CERTS_FILE_NAME = "Certificates.xml"
+P7M_FILE_NAME = "Certificates.p7m"
+PEM_FILE_NAME = "Certificates.pem"
+EXT_CONF_FILE_NAME = "ExtensionsConfig.{0}.xml"
+MANIFEST_FILE_NAME = "{0}.{1}.manifest.xml"
+TRANSPORT_CERT_FILE_NAME = "TransportCert.pem"
+TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem"
+
+PROTOCOL_VERSION = "2012-11-30"
+ENDPOINT_FINE_NAME = "WireServer"
+
+SHORT_WAITING_INTERVAL = 1 # 1 second
+LONG_WAITING_INTERVAL = 15 # 15 seconds
+
+
+class UploadError(HttpError):
+ pass
+
+
+class WireProtocolResourceGone(ProtocolError):
+ pass
+
+
+class WireProtocol(Protocol):
+ """Slim layer to adapt wire protocol data to metadata protocol interface"""
+
+ # TODO: Clean-up goal state processing
+ # At present, some methods magically update GoalState (e.g., get_vmagent_manifests), others (e.g., get_vmagent_pkgs)
+ # assume its presence. A better approach would make an explicit update call that returns the incarnation number and
+ # establishes that number the "context" for all other calls (either by updating the internal state of the protocol or
+ # by having callers pass the incarnation number to the method).
+
+ def __init__(self, endpoint):
+ if endpoint is None:
+ raise ProtocolError("WireProtocol endpoint is None")
+ self.endpoint = endpoint
+ self.client = WireClient(self.endpoint)
+
+ def detect(self):
+ self.client.check_wire_protocol_version()
+
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)
+
+ self.client.update_goal_state(forced=True)
+
+ def get_vminfo(self):
+ goal_state = self.client.get_goal_state()
+ hosting_env = self.client.get_hosting_env()
+
+ vminfo = VMInfo()
+ vminfo.subscriptionId = None
+ vminfo.vmName = hosting_env.vm_name
+ vminfo.tenantName = hosting_env.deployment_name
+ vminfo.roleName = hosting_env.role_name
+ vminfo.roleInstanceName = goal_state.role_instance_id
+ vminfo.containerId = goal_state.container_id
+ return vminfo
+
+ def get_certs(self):
+ certificates = self.client.get_certs()
+ return certificates.cert_list
+
+ def get_vmagent_manifests(self):
+ # Update goal state to get latest extensions config
+ self.client.update_goal_state()
+ goal_state = self.client.get_goal_state()
+ ext_conf = self.client.get_ext_conf()
+ return ext_conf.vmagent_manifests, goal_state.incarnation
+
+ def get_vmagent_pkgs(self, vmagent_manifest):
+ goal_state = self.client.get_goal_state()
+ man = self.client.get_gafamily_manifest(vmagent_manifest, goal_state)
+ return man.pkg_list
+
+ def get_ext_handlers(self):
+ logger.verbose("Get extension handler config")
+ # Update goal state to get latest extensions config
+ self.client.update_goal_state()
+ goal_state = self.client.get_goal_state()
+ ext_conf = self.client.get_ext_conf()
+ # In wire protocol, incarnation is equivalent to ETag
+ return ext_conf.ext_handlers, goal_state.incarnation
+
+ def get_ext_handler_pkgs(self, ext_handler):
+ logger.verbose("Get extension handler package")
+ goal_state = self.client.get_goal_state()
+ man = self.client.get_ext_manifest(ext_handler, goal_state)
+ return man.pkg_list
+
+ def report_provision_status(self, provision_status):
+ validate_param("provision_status", provision_status, ProvisionStatus)
+
+ if provision_status.status is not None:
+ self.client.report_health(provision_status.status,
+ provision_status.subStatus,
+ provision_status.description)
+ if provision_status.properties.certificateThumbprint is not None:
+ thumbprint = provision_status.properties.certificateThumbprint
+ self.client.report_role_prop(thumbprint)
+
+ def report_vm_status(self, vm_status):
+ validate_param("vm_status", vm_status, VMStatus)
+ self.client.status_blob.set_vm_status(vm_status)
+ self.client.upload_status_blob()
+
+ def report_ext_status(self, ext_handler_name, ext_name, ext_status):
+ validate_param("ext_status", ext_status, ExtensionStatus)
+ self.client.status_blob.set_ext_status(ext_handler_name, ext_status)
+
+ def report_event(self, events):
+ validate_param("events", events, TelemetryEventList)
+ self.client.report_event(events)
+
+
+def _build_role_properties(container_id, role_instance_id, thumbprint):
+ xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>"
+ u"<RoleProperties>"
+ u"<Container>"
+ u"<ContainerId>{0}</ContainerId>"
+ u"<RoleInstances>"
+ u"<RoleInstance>"
+ u"<Id>{1}</Id>"
+ u"<Properties>"
+ u"<Property name=\"CertificateThumbprint\" value=\"{2}\" />"
+ u"</Properties>"
+ u"</RoleInstance>"
+ u"</RoleInstances>"
+ u"</Container>"
+ u"</RoleProperties>"
+ u"").format(container_id, role_instance_id, thumbprint)
+ return xml
+
+
+def _build_health_report(incarnation, container_id, role_instance_id,
+ status, substatus, description):
+ # Escape '&', '<' and '>'
+ description = saxutils.escape(ustr(description))
+ detail = u''
+ if substatus is not None:
+ substatus = saxutils.escape(ustr(substatus))
+ detail = (u"<Details>"
+ u"<SubStatus>{0}</SubStatus>"
+ u"<Description>{1}</Description>"
+ u"</Details>").format(substatus, description)
+ xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>"
+ u"<Health "
+ u"xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\""
+ u" xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\">"
+ u"<GoalStateIncarnation>{0}</GoalStateIncarnation>"
+ u"<Container>"
+ u"<ContainerId>{1}</ContainerId>"
+ u"<RoleInstanceList>"
+ u"<Role>"
+ u"<InstanceId>{2}</InstanceId>"
+ u"<Health>"
+ u"<State>{3}</State>"
+ u"{4}"
+ u"</Health>"
+ u"</Role>"
+ u"</RoleInstanceList>"
+ u"</Container>"
+ u"</Health>"
+ u"").format(incarnation,
+ container_id,
+ role_instance_id,
+ status,
+ detail)
+ return xml
+
+
+"""
+Convert VMStatus object to status blob format
+"""
+
+
+def ga_status_to_v1(ga_status):
+ formatted_msg = {
+ 'lang': 'en-US',
+ 'message': ga_status.message
+ }
+ v1_ga_status = {
+ 'version': ga_status.version,
+ 'status': ga_status.status,
+ 'formattedMessage': formatted_msg
+ }
+ return v1_ga_status
+
+
+def ext_substatus_to_v1(sub_status_list):
+ status_list = []
+ for substatus in sub_status_list:
+ status = {
+ "name": substatus.name,
+ "status": substatus.status,
+ "code": substatus.code,
+ "formattedMessage": {
+ "lang": "en-US",
+ "message": substatus.message
+ }
+ }
+ status_list.append(status)
+ return status_list
+
+
+def ext_status_to_v1(ext_name, ext_status):
+ if ext_status is None:
+ return None
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ v1_sub_status = ext_substatus_to_v1(ext_status.substatusList)
+ v1_ext_status = {
+ "status": {
+ "name": ext_name,
+ "configurationAppliedTime": ext_status.configurationAppliedTime,
+ "operation": ext_status.operation,
+ "status": ext_status.status,
+ "code": ext_status.code,
+ "formattedMessage": {
+ "lang": "en-US",
+ "message": ext_status.message
+ }
+ },
+ "version": 1.0,
+ "timestampUTC": timestamp
+ }
+ if len(v1_sub_status) != 0:
+ v1_ext_status['substatus'] = v1_sub_status
+ return v1_ext_status
+
+
+def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp):
+ v1_handler_status = {
+ 'handlerVersion': handler_status.version,
+ 'handlerName': handler_status.name,
+ 'status': handler_status.status,
+ 'code': handler_status.code
+ }
+ if handler_status.message is not None:
+ v1_handler_status["formattedMessage"] = {
+ "lang": "en-US",
+ "message": handler_status.message
+ }
+
+ if len(handler_status.extensions) > 0:
+ # Currently, no more than one extension per handler
+ ext_name = handler_status.extensions[0]
+ ext_status = ext_statuses.get(ext_name)
+ v1_ext_status = ext_status_to_v1(ext_name, ext_status)
+ if ext_status is not None and v1_ext_status is not None:
+ v1_handler_status["runtimeSettingsStatus"] = {
+ 'settingsStatus': v1_ext_status,
+ 'sequenceNumber': ext_status.sequenceNumber
+ }
+ return v1_handler_status
+
+
+def vm_status_to_v1(vm_status, ext_statuses):
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+
+ v1_ga_status = ga_status_to_v1(vm_status.vmAgent)
+ v1_handler_status_list = []
+ for handler_status in vm_status.vmAgent.extensionHandlers:
+ v1_handler_status = ext_handler_status_to_v1(handler_status,
+ ext_statuses, timestamp)
+ if v1_handler_status is not None:
+ v1_handler_status_list.append(v1_handler_status)
+
+ v1_agg_status = {
+ 'guestAgentStatus': v1_ga_status,
+ 'handlerAggregateStatus': v1_handler_status_list
+ }
+ v1_vm_status = {
+ 'version': '1.0',
+ 'timestampUTC': timestamp,
+ 'aggregateStatus': v1_agg_status
+ }
+ return v1_vm_status
+
+
+class StatusBlob(object):
+ def __init__(self, client):
+ self.vm_status = None
+ self.ext_statuses = {}
+ self.client = client
+ self.type = None
+ self.data = None
+
+ def set_vm_status(self, vm_status):
+ validate_param("vmAgent", vm_status, VMStatus)
+ self.vm_status = vm_status
+
+ def set_ext_status(self, ext_handler_name, ext_status):
+ validate_param("extensionStatus", ext_status, ExtensionStatus)
+ self.ext_statuses[ext_handler_name] = ext_status
+
+ def to_json(self):
+ report = vm_status_to_v1(self.vm_status, self.ext_statuses)
+ return json.dumps(report)
+
+ __storage_version__ = "2014-02-14"
+
+ def upload(self, url):
+ # TODO upload extension only if content has changed
+ logger.verbose("Upload status blob")
+ upload_successful = False
+ self.type = self.get_blob_type(url)
+ self.data = self.to_json()
+ try:
+ if self.type == "BlockBlob":
+ self.put_block_blob(url, self.data)
+ elif self.type == "PageBlob":
+ self.put_page_blob(url, self.data)
+ else:
+ raise ProtocolError("Unknown blob type: {0}".format(self.type))
+ except HttpError as e:
+ logger.warn("Initial upload failed [{0}]".format(e))
+ else:
+ logger.verbose("Uploading status blob succeeded")
+ upload_successful = True
+ return upload_successful
+
+ def get_blob_type(self, url):
+ # Check blob type
+ logger.verbose("Check blob type.")
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ try:
+ resp = self.client.call_storage_service(restutil.http_head, url, {
+ "x-ms-date": timestamp,
+ 'x-ms-version': self.__class__.__storage_version__
+ })
+ except HttpError as e:
+ raise ProtocolError((u"Failed to get status blob type: {0}"
+ u"").format(e))
+ if resp is None or resp.status != httpclient.OK:
+ raise ProtocolError(("Failed to get status blob type: {0}"
+ "").format(resp.status))
+
+ blob_type = resp.getheader("x-ms-blob-type")
+ logger.verbose("Blob type={0}".format(blob_type))
+ return blob_type
+
+ def put_block_blob(self, url, data):
+ logger.verbose("Upload block blob")
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ resp = self.client.call_storage_service(restutil.http_put, url, data,
+ {
+ "x-ms-date": timestamp,
+ "x-ms-blob-type": "BlockBlob",
+ "Content-Length": ustr(len(data)),
+ "x-ms-version": self.__class__.__storage_version__
+ })
+ if resp.status != httpclient.CREATED:
+ raise UploadError(
+ "Failed to upload block blob: {0}".format(resp.status))
+
+ def put_page_blob(self, url, data):
+ logger.verbose("Replace old page blob")
+
+ # Convert string into bytes
+ data = bytearray(data, encoding='utf-8')
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+
+ # Align to 512 bytes
+ page_blob_size = int((len(data) + 511) / 512) * 512
+ resp = self.client.call_storage_service(restutil.http_put, url, "",
+ {
+ "x-ms-date": timestamp,
+ "x-ms-blob-type": "PageBlob",
+ "Content-Length": "0",
+ "x-ms-blob-content-length": ustr(page_blob_size),
+ "x-ms-version": self.__class__.__storage_version__
+ })
+ if resp.status != httpclient.CREATED:
+ raise UploadError(
+ "Failed to clean up page blob: {0}".format(resp.status))
+
+ if url.count("?") < 0:
+ url = "{0}?comp=page".format(url)
+ else:
+ url = "{0}&comp=page".format(url)
+
+ logger.verbose("Upload page blob")
+ page_max = 4 * 1024 * 1024 # Max page size: 4MB
+ start = 0
+ end = 0
+ while end < len(data):
+ end = min(len(data), start + page_max)
+ content_size = end - start
+ # Align to 512 bytes
+ page_end = int((end + 511) / 512) * 512
+ buf_size = page_end - start
+ buf = bytearray(buf_size)
+ buf[0: content_size] = data[start: end]
+ resp = self.client.call_storage_service(
+ restutil.http_put, url, bytebuffer(buf),
+ {
+ "x-ms-date": timestamp,
+ "x-ms-range": "bytes={0}-{1}".format(start, page_end - 1),
+ "x-ms-page-write": "update",
+ "x-ms-version": self.__class__.__storage_version__,
+ "Content-Length": ustr(page_end - start)
+ })
+ if resp is None or resp.status != httpclient.CREATED:
+ raise UploadError(
+ "Failed to upload page blob: {0}".format(resp.status))
+ start = end
+
+
+def event_param_to_v1(param):
+ param_format = '<Param Name="{0}" Value={1} T="{2}" />'
+ param_type = type(param.value)
+ attr_type = ""
+ if param_type is int:
+ attr_type = 'mt:uint64'
+ elif param_type is str:
+ attr_type = 'mt:wstr'
+ elif ustr(param_type).count("'unicode'") > 0:
+ attr_type = 'mt:wstr'
+ elif param_type is bool:
+ attr_type = 'mt:bool'
+ elif param_type is float:
+ attr_type = 'mt:float64'
+ return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)),
+ attr_type)
+
+
+def event_to_v1(event):
+ params = ""
+ for param in event.parameters:
+ params += event_param_to_v1(param)
+ event_str = ('<Event id="{0}">'
+ '<![CDATA[{1}]]>'
+ '</Event>').format(event.eventId, params)
+ return event_str
+
+
+class WireClient(object):
+ def __init__(self, endpoint):
+ logger.info("Wire server endpoint:{0}", endpoint)
+ self.endpoint = endpoint
+ self.goal_state = None
+ self.updated = None
+ self.hosting_env = None
+ self.shared_conf = None
+ self.certs = None
+ self.ext_conf = None
+ self.last_request = 0
+ self.req_count = 0
+ self.status_blob = StatusBlob(self)
+ self.host_plugin = HostPluginProtocol(self.endpoint)
+
+ def prevent_throttling(self):
+ """
+ Try to avoid throttling of wire server
+ """
+ now = time.time()
+ if now - self.last_request < 1:
+ logger.verbose("Last request issued less than 1 second ago")
+ logger.verbose("Sleep {0} second to avoid throttling.",
+ SHORT_WAITING_INTERVAL)
+ time.sleep(SHORT_WAITING_INTERVAL)
+ self.last_request = now
+
+ self.req_count += 1
+ if self.req_count % 3 == 0:
+ logger.verbose("Sleep {0} second to avoid throttling.",
+ SHORT_WAITING_INTERVAL)
+ time.sleep(SHORT_WAITING_INTERVAL)
+ self.req_count = 0
+
+ def call_wireserver(self, http_req, *args, **kwargs):
+ """
+ Call wire server. Handle throttling(403) and Resource Gone(410)
+ """
+ self.prevent_throttling()
+ for retry in range(0, 3):
+ resp = http_req(*args, **kwargs)
+ if resp.status == httpclient.FORBIDDEN:
+ logger.warn("Sending too much request to wire server")
+ logger.info("Sleep {0} second to avoid throttling.",
+ LONG_WAITING_INTERVAL)
+ time.sleep(LONG_WAITING_INTERVAL)
+ elif resp.status == httpclient.GONE:
+ msg = args[0] if len(args) > 0 else ""
+ raise WireProtocolResourceGone(msg)
+ else:
+ return resp
+ raise ProtocolError(("Calling wire server failed: {0}"
+ "").format(resp.status))
+
+ def decode_config(self, data):
+ if data is None:
+ return None
+ data = remove_bom(data)
+ xml_text = ustr(data, encoding='utf-8')
+ return xml_text
+
+ def fetch_config(self, uri, headers):
+ try:
+ resp = self.call_wireserver(restutil.http_get, uri,
+ headers=headers)
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
+
+ if (resp.status != httpclient.OK):
+ raise ProtocolError("{0} - {1}".format(resp.status, uri))
+
+ return self.decode_config(resp.read())
+
+ def fetch_cache(self, local_file):
+ if not os.path.isfile(local_file):
+ raise ProtocolError("{0} is missing.".format(local_file))
+ try:
+ return fileutil.read_file(local_file)
+ except IOError as e:
+ raise ProtocolError("Failed to read cache: {0}".format(e))
+
+ def save_cache(self, local_file, data):
+ try:
+ fileutil.write_file(local_file, data)
+ except IOError as e:
+ raise ProtocolError("Failed to write cache: {0}".format(e))
+
+ def call_storage_service(self, http_req, *args, **kwargs):
+ """
+ Call storage service, handle SERVICE_UNAVAILABLE(503)
+ """
+ for retry in range(0, 3):
+ resp = http_req(*args, **kwargs)
+ if resp.status == httpclient.SERVICE_UNAVAILABLE:
+ logger.warn("Storage service is not avaible temporaryly")
+ logger.info("Will retry later, in {0} seconds",
+ LONG_WAITING_INTERVAL)
+ time.sleep(LONG_WAITING_INTERVAL)
+ else:
+ return resp
+ raise ProtocolError(("Calling storage endpoint failed: {0}"
+ "").format(resp.status))
+
+ def fetch_manifest(self, version_uris):
+ for version_uri in version_uris:
+ logger.verbose("Fetch ext handler manifest: {0}", version_uri.uri)
+ try:
+ resp = self.call_storage_service(restutil.http_get,
+ version_uri.uri, None,
+ chk_proxy=True)
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
+
+ if resp.status == httpclient.OK:
+ return self.decode_config(resp.read())
+ logger.warn("Failed to fetch ExtensionManifest: {0}, {1}",
+ resp.status, version_uri.uri)
+ logger.info("Will retry later, in {0} seconds",
+ LONG_WAITING_INTERVAL)
+ time.sleep(LONG_WAITING_INTERVAL)
+ raise ProtocolError(("Failed to fetch ExtensionManifest from "
+ "all sources"))
+
+ def update_hosting_env(self, goal_state):
+ if goal_state.hosting_env_uri is None:
+ raise ProtocolError("HostingEnvironmentConfig uri is empty")
+ local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
+ xml_text = self.fetch_config(goal_state.hosting_env_uri,
+ self.get_header())
+ self.save_cache(local_file, xml_text)
+ self.hosting_env = HostingEnv(xml_text)
+
+ def update_shared_conf(self, goal_state):
+ if goal_state.shared_conf_uri is None:
+ raise ProtocolError("SharedConfig uri is empty")
+ local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
+ xml_text = self.fetch_config(goal_state.shared_conf_uri,
+ self.get_header())
+ self.save_cache(local_file, xml_text)
+ self.shared_conf = SharedConfig(xml_text)
+
+ def update_certs(self, goal_state):
+ if goal_state.certs_uri is None:
+ return
+ local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
+ xml_text = self.fetch_config(goal_state.certs_uri,
+ self.get_header_for_cert())
+ self.save_cache(local_file, xml_text)
+ self.certs = Certificates(self, xml_text)
+
+ def update_ext_conf(self, goal_state):
+ if goal_state.ext_uri is None:
+ logger.info("ExtensionsConfig.xml uri is empty")
+ self.ext_conf = ExtensionsConfig(None)
+ return
+ incarnation = goal_state.incarnation
+ local_file = os.path.join(conf.get_lib_dir(),
+ EXT_CONF_FILE_NAME.format(incarnation))
+ xml_text = self.fetch_config(goal_state.ext_uri, self.get_header())
+ self.save_cache(local_file, xml_text)
+ self.ext_conf = ExtensionsConfig(xml_text)
+
+ def update_goal_state(self, forced=False, max_retry=3):
+ uri = GOAL_STATE_URI.format(self.endpoint)
+ xml_text = self.fetch_config(uri, self.get_header())
+ goal_state = GoalState(xml_text)
+
+ incarnation_file = os.path.join(conf.get_lib_dir(),
+ INCARNATION_FILE_NAME)
+
+ if not forced:
+ last_incarnation = None
+ if (os.path.isfile(incarnation_file)):
+ last_incarnation = fileutil.read_file(incarnation_file)
+ new_incarnation = goal_state.incarnation
+ if last_incarnation is not None and \
+ last_incarnation == new_incarnation:
+ # Goalstate is not updated.
+ return
+
+ # Start updating goalstate, retry on 410
+ for retry in range(0, max_retry):
+ try:
+ self.goal_state = goal_state
+ file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation)
+ goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
+ self.save_cache(goal_state_file, xml_text)
+ self.save_cache(incarnation_file, goal_state.incarnation)
+ self.update_hosting_env(goal_state)
+ self.update_shared_conf(goal_state)
+ self.update_certs(goal_state)
+ self.update_ext_conf(goal_state)
+ return
+ except WireProtocolResourceGone:
+ logger.info("Incarnation is out of date. Update goalstate.")
+ xml_text = self.fetch_config(uri, self.get_header())
+ goal_state = GoalState(xml_text)
+
+ raise ProtocolError("Exceeded max retry updating goal state")
+
+ def get_goal_state(self):
+ if (self.goal_state is None):
+ incarnation_file = os.path.join(conf.get_lib_dir(),
+ INCARNATION_FILE_NAME)
+ incarnation = self.fetch_cache(incarnation_file)
+
+ file_name = GOAL_STATE_FILE_NAME.format(incarnation)
+ goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
+ xml_text = self.fetch_cache(goal_state_file)
+ self.goal_state = GoalState(xml_text)
+ return self.goal_state
+
+ def get_hosting_env(self):
+ if (self.hosting_env is None):
+ local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
+ xml_text = self.fetch_cache(local_file)
+ self.hosting_env = HostingEnv(xml_text)
+ return self.hosting_env
+
+ def get_shared_conf(self):
+ if (self.shared_conf is None):
+ local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
+ xml_text = self.fetch_cache(local_file)
+ self.shared_conf = SharedConfig(xml_text)
+ return self.shared_conf
+
+ def get_certs(self):
+ if (self.certs is None):
+ local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
+ xml_text = self.fetch_cache(local_file)
+ self.certs = Certificates(self, xml_text)
+ if self.certs is None:
+ return None
+ return self.certs
+
+ def get_ext_conf(self):
+ if (self.ext_conf is None):
+ goal_state = self.get_goal_state()
+ if goal_state.ext_uri is None:
+ self.ext_conf = ExtensionsConfig(None)
+ else:
+ local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation)
+ local_file = os.path.join(conf.get_lib_dir(), local_file)
+ xml_text = self.fetch_cache(local_file)
+ self.ext_conf = ExtensionsConfig(xml_text)
+ return self.ext_conf
+
+ def get_ext_manifest(self, ext_handler, goal_state):
+ local_file = MANIFEST_FILE_NAME.format(ext_handler.name,
+ goal_state.incarnation)
+ local_file = os.path.join(conf.get_lib_dir(), local_file)
+ xml_text = self.fetch_manifest(ext_handler.versionUris)
+ self.save_cache(local_file, xml_text)
+ return ExtensionManifest(xml_text)
+
+ def get_gafamily_manifest(self, vmagent_manifest, goal_state):
+ local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family,
+ goal_state.incarnation)
+ local_file = os.path.join(conf.get_lib_dir(), local_file)
+ xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris)
+ fileutil.write_file(local_file, xml_text)
+ return ExtensionManifest(xml_text)
+
+ def check_wire_protocol_version(self):
+ uri = VERSION_INFO_URI.format(self.endpoint)
+ version_info_xml = self.fetch_config(uri, None)
+ version_info = VersionInfo(version_info_xml)
+
+ preferred = version_info.get_preferred()
+ if PROTOCOL_VERSION == preferred:
+ logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
+ elif PROTOCOL_VERSION in version_info.get_supported():
+ logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
+ logger.warn("Server prefered version:{0}", preferred)
+ else:
+ error = ("Agent supported wire protocol version: {0} was not "
+ "advised by Fabric.").format(PROTOCOL_VERSION)
+ raise ProtocolNotFoundError(error)
+
+ def upload_status_blob(self):
+ ext_conf = self.get_ext_conf()
+ if ext_conf.status_upload_blob is not None:
+ if not self.status_blob.upload(ext_conf.status_upload_blob):
+ self.host_plugin.put_vm_status(self.status_blob,
+ ext_conf.status_upload_blob)
+
+ def report_role_prop(self, thumbprint):
+ goal_state = self.get_goal_state()
+ role_prop = _build_role_properties(goal_state.container_id,
+ goal_state.role_instance_id,
+ thumbprint)
+ role_prop = role_prop.encode("utf-8")
+ role_prop_uri = ROLE_PROP_URI.format(self.endpoint)
+ headers = self.get_header_for_xml_content()
+ try:
+ resp = self.call_wireserver(restutil.http_post, role_prop_uri,
+ role_prop, headers=headers)
+ except HttpError as e:
+ raise ProtocolError((u"Failed to send role properties: {0}"
+ u"").format(e))
+ if resp.status != httpclient.ACCEPTED:
+ raise ProtocolError((u"Failed to send role properties: {0}"
+ u", {1}").format(resp.status, resp.read()))
+
+ def report_health(self, status, substatus, description):
+ goal_state = self.get_goal_state()
+ health_report = _build_health_report(goal_state.incarnation,
+ goal_state.container_id,
+ goal_state.role_instance_id,
+ status,
+ substatus,
+ description)
+ health_report = health_report.encode("utf-8")
+ health_report_uri = HEALTH_REPORT_URI.format(self.endpoint)
+ headers = self.get_header_for_xml_content()
+ try:
+ resp = self.call_wireserver(restutil.http_post, health_report_uri,
+ health_report, headers=headers, max_retry=8)
+ except HttpError as e:
+ raise ProtocolError((u"Failed to send provision status: {0}"
+ u"").format(e))
+ if resp.status != httpclient.OK:
+ raise ProtocolError((u"Failed to send provision status: {0}"
+ u", {1}").format(resp.status, resp.read()))
+
+ def send_event(self, provider_id, event_str):
+ uri = TELEMETRY_URI.format(self.endpoint)
+ data_format = ('<?xml version="1.0"?>'
+ '<TelemetryData version="1.0">'
+ '<Provider id="{0}">{1}'
+ '</Provider>'
+ '</TelemetryData>')
+ data = data_format.format(provider_id, event_str)
+ try:
+ header = self.get_header_for_xml_content()
+ resp = self.call_wireserver(restutil.http_post, uri, data, header)
+ except HttpError as e:
+ raise ProtocolError("Failed to send events:{0}".format(e))
+
+ if resp.status != httpclient.OK:
+ logger.verbose(resp.read())
+ raise ProtocolError("Failed to send events:{0}".format(resp.status))
+
+ def report_event(self, event_list):
+ buf = {}
+ # Group events by providerId
+ for event in event_list.events:
+ if event.providerId not in buf:
+ buf[event.providerId] = ""
+ event_str = event_to_v1(event)
+ if len(event_str) >= 63 * 1024:
+ logger.warn("Single event too large: {0}", event_str[300:])
+ continue
+ if len(buf[event.providerId] + event_str) >= 63 * 1024:
+ self.send_event(event.providerId, buf[event.providerId])
+ buf[event.providerId] = ""
+ buf[event.providerId] = buf[event.providerId] + event_str
+
+ # Send out all events left in buffer.
+ for provider_id in list(buf.keys()):
+ if len(buf[provider_id]) > 0:
+ self.send_event(provider_id, buf[provider_id])
+
+ def get_header(self):
+ return {
+ "x-ms-agent-name": "WALinuxAgent",
+ "x-ms-version": PROTOCOL_VERSION
+ }
+
+ def get_header_for_xml_content(self):
+ return {
+ "x-ms-agent-name": "WALinuxAgent",
+ "x-ms-version": PROTOCOL_VERSION,
+ "Content-Type": "text/xml;charset=utf-8"
+ }
+
+ def get_header_for_cert(self):
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ content = self.fetch_cache(trans_cert_file)
+ cert = get_bytes_from_pem(content)
+ return {
+ "x-ms-agent-name": "WALinuxAgent",
+ "x-ms-version": PROTOCOL_VERSION,
+ "x-ms-cipher-name": "DES_EDE3_CBC",
+ "x-ms-guest-agent-public-x509-cert": cert
+ }
+
+class VersionInfo(object):
+ def __init__(self, xml_text):
+ """
+ Query endpoint server for wire protocol version.
+ Fail if our desired protocol version is not seen.
+ """
+ logger.verbose("Load Version.xml")
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ xml_doc = parse_doc(xml_text)
+ preferred = find(xml_doc, "Preferred")
+ self.preferred = findtext(preferred, "Version")
+ logger.info("Fabric preferred wire protocol version:{0}", self.preferred)
+
+ self.supported = []
+ supported = find(xml_doc, "Supported")
+ supported_version = findall(supported, "Version")
+ for node in supported_version:
+ version = gettext(node)
+ logger.verbose("Fabric supported wire protocol version:{0}", version)
+ self.supported.append(version)
+
+ def get_preferred(self):
+ return self.preferred
+
+ def get_supported(self):
+ return self.supported
+
+
+class GoalState(object):
+ def __init__(self, xml_text):
+ if xml_text is None:
+ raise ValueError("GoalState.xml is None")
+ logger.verbose("Load GoalState.xml")
+ self.incarnation = None
+ self.expected_state = None
+ self.hosting_env_uri = None
+ self.shared_conf_uri = None
+ self.certs_uri = None
+ self.ext_uri = None
+ self.role_instance_id = None
+ self.container_id = None
+ self.load_balancer_probe_port = None
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ """
+ Request configuration data from endpoint server.
+ """
+ self.xml_text = xml_text
+ xml_doc = parse_doc(xml_text)
+ self.incarnation = findtext(xml_doc, "Incarnation")
+ self.expected_state = findtext(xml_doc, "ExpectedState")
+ self.hosting_env_uri = findtext(xml_doc, "HostingEnvironmentConfig")
+ self.shared_conf_uri = findtext(xml_doc, "SharedConfig")
+ self.certs_uri = findtext(xml_doc, "Certificates")
+ self.ext_uri = findtext(xml_doc, "ExtensionsConfig")
+ role_instance = find(xml_doc, "RoleInstance")
+ self.role_instance_id = findtext(role_instance, "InstanceId")
+ container = find(xml_doc, "Container")
+ self.container_id = findtext(container, "ContainerId")
+ lbprobe_ports = find(xml_doc, "LBProbePorts")
+ self.load_balancer_probe_port = findtext(lbprobe_ports, "Port")
+ return self
+
+
+class HostingEnv(object):
+ """
+ parse Hosting enviromnet config and store in
+ HostingEnvironmentConfig.xml
+ """
+
+ def __init__(self, xml_text):
+ if xml_text is None:
+ raise ValueError("HostingEnvironmentConfig.xml is None")
+ logger.verbose("Load HostingEnvironmentConfig.xml")
+ self.vm_name = None
+ self.role_name = None
+ self.deployment_name = None
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ """
+ parse and create HostingEnvironmentConfig.xml.
+ """
+ self.xml_text = xml_text
+ xml_doc = parse_doc(xml_text)
+ incarnation = find(xml_doc, "Incarnation")
+ self.vm_name = getattrib(incarnation, "instance")
+ role = find(xml_doc, "Role")
+ self.role_name = getattrib(role, "name")
+ deployment = find(xml_doc, "Deployment")
+ self.deployment_name = getattrib(deployment, "name")
+ return self
+
+
+class SharedConfig(object):
+ """
+ parse role endpoint server and goal state config.
+ """
+
+ def __init__(self, xml_text):
+ logger.verbose("Load SharedConfig.xml")
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ """
+ parse and write configuration to file SharedConfig.xml.
+ """
+ # Not used currently
+ return self
+
+class Certificates(object):
+ """
+ Object containing certificates of host and provisioned user.
+ """
+
+ def __init__(self, client, xml_text):
+ logger.verbose("Load Certificates.xml")
+ self.client = client
+ self.cert_list = CertList()
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ """
+ Parse multiple certificates into seperate files.
+ """
+ xml_doc = parse_doc(xml_text)
+ data = findtext(xml_doc, "Data")
+ if data is None:
+ return
+
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
+ p7m = ("MIME-Version:1.0\n"
+ "Content-Disposition: attachment; filename=\"{0}\"\n"
+ "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
+ "Content-Transfer-Encoding: base64\n"
+ "\n"
+ "{2}").format(p7m_file, p7m_file, data)
+
+ self.client.save_cache(p7m_file, p7m)
+
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
+ # decrypt certificates
+ cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
+ pem_file)
+
+ # The parsing process use public key to match prv and crt.
+ buf = []
+ begin_crt = False
+ begin_prv = False
+ prvs = {}
+ thumbprints = {}
+ index = 0
+ v1_cert_list = []
+ with open(pem_file) as pem:
+ for line in pem.readlines():
+ buf.append(line)
+ if re.match(r'[-]+BEGIN.*KEY[-]+', line):
+ begin_prv = True
+ elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line):
+ begin_crt = True
+ elif re.match(r'[-]+END.*KEY[-]+', line):
+ tmp_file = self.write_to_tmp_file(index, 'prv', buf)
+ pub = cryptutil.get_pubkey_from_prv(tmp_file)
+ prvs[pub] = tmp_file
+ buf = []
+ index += 1
+ begin_prv = False
+ elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
+ tmp_file = self.write_to_tmp_file(index, 'crt', buf)
+ pub = cryptutil.get_pubkey_from_crt(tmp_file)
+ thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
+ thumbprints[pub] = thumbprint
+ # Rename crt with thumbprint as the file name
+ crt = "{0}.crt".format(thumbprint)
+ v1_cert_list.append({
+ "name": None,
+ "thumbprint": thumbprint
+ })
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
+ buf = []
+ index += 1
+ begin_crt = False
+
+ # Rename prv key with thumbprint as the file name
+ for pubkey in prvs:
+ thumbprint = thumbprints[pubkey]
+ if thumbprint:
+ tmp_file = prvs[pubkey]
+ prv = "{0}.prv".format(thumbprint)
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))
+
+ for v1_cert in v1_cert_list:
+ cert = Cert()
+ set_properties("certs", cert, v1_cert)
+ self.cert_list.certificates.append(cert)
+
+ def write_to_tmp_file(self, index, suffix, buf):
+ file_name = os.path.join(conf.get_lib_dir(),
+ "{0}.{1}".format(index, suffix))
+ self.client.save_cache(file_name, "".join(buf))
+ return file_name
+
+
+class ExtensionsConfig(object):
+ """
+ parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent.
+ Install if <enabled>true</enabled>, remove if it is set to false.
+ """
+
+ def __init__(self, xml_text):
+ logger.verbose("Load ExtensionsConfig.xml")
+ self.ext_handlers = ExtHandlerList()
+ self.vmagent_manifests = VMAgentManifestList()
+ self.status_upload_blob = None
+ if xml_text is not None:
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ """
+ Write configuration to file ExtensionsConfig.xml.
+ """
+ xml_doc = parse_doc(xml_text)
+
+ ga_families_list = find(xml_doc, "GAFamilies")
+ ga_families = findall(ga_families_list, "GAFamily")
+
+ for ga_family in ga_families:
+ family = findtext(ga_family, "Name")
+ uris_list = find(ga_family, "Uris")
+ uris = findall(uris_list, "Uri")
+ manifest = VMAgentManifest()
+ manifest.family = family
+ for uri in uris:
+ manifestUri = VMAgentManifestUri(uri=gettext(uri))
+ manifest.versionsManifestUris.append(manifestUri)
+ self.vmagent_manifests.vmAgentManifests.append(manifest)
+
+ plugins_list = find(xml_doc, "Plugins")
+ plugins = findall(plugins_list, "Plugin")
+ plugin_settings_list = find(xml_doc, "PluginSettings")
+ plugin_settings = findall(plugin_settings_list, "Plugin")
+
+ for plugin in plugins:
+ ext_handler = self.parse_plugin(plugin)
+ self.ext_handlers.extHandlers.append(ext_handler)
+ self.parse_plugin_settings(ext_handler, plugin_settings)
+
+ self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob")
+
+ def parse_plugin(self, plugin):
+ ext_handler = ExtHandler()
+ ext_handler.name = getattrib(plugin, "name")
+ ext_handler.properties.version = getattrib(plugin, "version")
+ ext_handler.properties.state = getattrib(plugin, "state")
+
+ auto_upgrade = getattrib(plugin, "autoUpgrade")
+ if auto_upgrade is not None and auto_upgrade.lower() == "true":
+ ext_handler.properties.upgradePolicy = "auto"
+ else:
+ ext_handler.properties.upgradePolicy = "manual"
+
+ location = getattrib(plugin, "location")
+ failover_location = getattrib(plugin, "failoverlocation")
+ for uri in [location, failover_location]:
+ version_uri = ExtHandlerVersionUri()
+ version_uri.uri = uri
+ ext_handler.versionUris.append(version_uri)
+ return ext_handler
+
+ def parse_plugin_settings(self, ext_handler, plugin_settings):
+ if plugin_settings is None:
+ return
+
+ name = ext_handler.name
+ version = ext_handler.properties.version
+ settings = [x for x in plugin_settings \
+ if getattrib(x, "name") == name and \
+ getattrib(x, "version") == version]
+
+ if settings is None or len(settings) == 0:
+ return
+
+ runtime_settings = None
+ runtime_settings_node = find(settings[0], "RuntimeSettings")
+ seqNo = getattrib(runtime_settings_node, "seqNo")
+ runtime_settings_str = gettext(runtime_settings_node)
+ try:
+ runtime_settings = json.loads(runtime_settings_str)
+ except ValueError as e:
+ logger.error("Invalid extension settings")
+ return
+
+ for plugin_settings_list in runtime_settings["runtimeSettings"]:
+ handler_settings = plugin_settings_list["handlerSettings"]
+ ext = Extension()
+ # There is no "extension name" in wire protocol.
+ # Put
+ ext.name = ext_handler.name
+ ext.sequenceNumber = seqNo
+ ext.publicSettings = handler_settings.get("publicSettings")
+ ext.protectedSettings = handler_settings.get("protectedSettings")
+ thumbprint = handler_settings.get("protectedSettingsCertThumbprint")
+ ext.certificateThumbprint = thumbprint
+ ext_handler.properties.extensions.append(ext)
+
+
+class ExtensionManifest(object):
+ def __init__(self, xml_text):
+ if xml_text is None:
+ raise ValueError("ExtensionManifest is None")
+ logger.verbose("Load ExtensionManifest.xml")
+ self.pkg_list = ExtHandlerPackageList()
+ self.parse(xml_text)
+
+ def parse(self, xml_text):
+ xml_doc = parse_doc(xml_text)
+ self._handle_packages(findall(find(xml_doc, "Plugins"), "Plugin"), False)
+ self._handle_packages(findall(find(xml_doc, "InternalPlugins"), "Plugin"), True)
+
+ def _handle_packages(self, packages, isinternal):
+ for package in packages:
+ version = findtext(package, "Version")
+
+ disallow_major_upgrade = findtext(package, "DisallowMajorVersionUpgrade")
+ if disallow_major_upgrade is None:
+ disallow_major_upgrade = ''
+ disallow_major_upgrade = disallow_major_upgrade.lower() == "true"
+
+ uris = find(package, "Uris")
+ uri_list = findall(uris, "Uri")
+ uri_list = [gettext(x) for x in uri_list]
+ pkg = ExtHandlerPackage()
+ pkg.version = version
+ pkg.disallow_major_upgrade = disallow_major_upgrade
+ for uri in uri_list:
+ pkg_uri = ExtHandlerVersionUri()
+ pkg_uri.uri = uri
+ pkg.uris.append(pkg_uri)
+
+ pkg.isinternal = isinternal
+ self.pkg_list.versions.append(pkg)