diff options
Diffstat (limited to 'azurelinuxagent/protocol')
-rw-r--r-- | azurelinuxagent/protocol/__init__.py | 5 | ||||
-rw-r--r-- | azurelinuxagent/protocol/metadata.py (renamed from azurelinuxagent/protocol/v2.py) | 90 | ||||
-rw-r--r-- | azurelinuxagent/protocol/ovfenv.py | 46 | ||||
-rw-r--r-- | azurelinuxagent/protocol/protocolFactory.py | 114 | ||||
-rw-r--r-- | azurelinuxagent/protocol/restapi.py (renamed from azurelinuxagent/protocol/common.py) | 38 | ||||
-rw-r--r-- | azurelinuxagent/protocol/wire.py (renamed from azurelinuxagent/protocol/v1.py) | 196 |
6 files changed, 213 insertions, 276 deletions
diff --git a/azurelinuxagent/protocol/__init__.py b/azurelinuxagent/protocol/__init__.py index a4572e6..8c1bbdb 100644 --- a/azurelinuxagent/protocol/__init__.py +++ b/azurelinuxagent/protocol/__init__.py @@ -16,8 +16,3 @@ # # Requires Python 2.4+ and Openssl 1.0+ # - -from azurelinuxagent.protocol.common import * -from azurelinuxagent.protocol.protocolFactory import FACTORY, \ - detect_default_protocol - diff --git a/azurelinuxagent/protocol/v2.py b/azurelinuxagent/protocol/metadata.py index 34102b7..8a1656f 100644 --- a/azurelinuxagent/protocol/v2.py +++ b/azurelinuxagent/protocol/metadata.py @@ -17,16 +17,30 @@ # Requires Python 2.4+ and Openssl 1.0+ import json -from azurelinuxagent.future import httpclient, text +import shutil +import os +import time +from azurelinuxagent.exception import ProtocolError, HttpError +from azurelinuxagent.future import httpclient, ustr +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger import azurelinuxagent.utils.restutil as restutil -from azurelinuxagent.protocol.common import * +import azurelinuxagent.utils.textutil as textutil +import azurelinuxagent.utils.fileutil as fileutil +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.protocol.restapi import * -ENDPOINT='169.254.169.254' -#TODO use http for azure pack test -#ENDPOINT='localhost' +METADATA_ENDPOINT='169.254.169.254' APIVERSION='2015-05-01-preview' BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" +TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem" +TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem" + +#TODO remote workarround for azure stack +MAX_PING = 30 +RETRY_PING_INTERVAL = 10 + def _add_content_type(headers): if headers is None: headers = {} @@ -35,7 +49,7 @@ def _add_content_type(headers): class MetadataProtocol(Protocol): - def __init__(self, apiversion=APIVERSION, endpoint=ENDPOINT): + def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT): self.apiversion = apiversion self.endpoint = endpoint self.identity_uri = BASE_URI.format(self.endpoint, "identity", @@ -58,24 +72,25 @@ class MetadataProtocol(Protocol): def _get_data(self, url, headers=None): try: resp = restutil.http_get(url, headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status != httpclient.OK: raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) data = resp.read() + etag = resp.getheader('ETag') if data is None: return None - data = json.loads(text(data, encoding="utf-8")) - return data + data = json.loads(ustr(data, encoding="utf-8")) + return data, etag def _put_data(self, url, data, headers=None): headers = _add_content_type(headers) try: resp = restutil.http_put(url, json.dumps(data), headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status != httpclient.OK: raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) @@ -83,17 +98,41 @@ class MetadataProtocol(Protocol): headers = _add_content_type(headers) try: resp = restutil.http_post(url, json.dumps(data), headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status != httpclient.CREATED: raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) + + def _get_trans_cert(self): + trans_crt_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + if not os.path.isfile(trans_crt_file): + raise ProtocolError("{0} is missing.".format(trans_crt_file)) + content = fileutil.read_file(trans_crt_file) + return textutil.get_bytes_from_pem(content) + + def detect(self): + self.get_vminfo() + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + #"Install" the cert and private key to /var/lib/waagent + thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file) + prv_file = os.path.join(conf.get_lib_dir(), + "{0}.prv".format(thumbprint)) + crt_file = os.path.join(conf.get_lib_dir(), + "{0}.crt".format(thumbprint)) + shutil.copyfile(trans_prv_file, prv_file) + shutil.copyfile(trans_cert_file, crt_file) - def initialize(self): - pass def get_vminfo(self): vminfo = VMInfo() - data = self._get_data(self.identity_uri) + data, etag = self._get_data(self.identity_uri) set_properties("vminfo", vminfo, data) return vminfo @@ -102,17 +141,20 @@ class MetadataProtocol(Protocol): return CertList() def get_ext_handlers(self): + headers = { + "x-ms-vmagent-public-x509-cert": self._get_trans_cert() + } ext_list = ExtHandlerList() - data = self._get_data(self.ext_uri) + data, etag = self._get_data(self.ext_uri, headers=headers) set_properties("extensionHandlers", ext_list.extHandlers, data) - return ext_list + return ext_list, etag def get_ext_handler_pkgs(self, ext_handler): ext_handler_pkgs = ExtHandlerPackageList() data = None for version_uri in ext_handler.versionUris: try: - data = self._get_data(version_uri.uri) + data, etag = self._get_data(version_uri.uri) break except ProtocolError as e: logger.warn("Failed to get version uris: {0}", e) @@ -128,6 +170,14 @@ class MetadataProtocol(Protocol): def report_vm_status(self, vm_status): validata_param('vmStatus', vm_status, VMStatus) data = get_properties(vm_status) + #TODO code field is not implemented for metadata protocol yet. Remove it + handler_statuses = data['vmAgent']['extensionHandlers'] + for handler_status in handler_statuses: + try: + handler_status.pop('code', None) + except KeyError: + pass + self._put_data(self.vm_status_uri, data) def report_ext_status(self, ext_handler_name, ext_name, ext_status): diff --git a/azurelinuxagent/protocol/ovfenv.py b/azurelinuxagent/protocol/ovfenv.py index 9c845ee..de6791c 100644 --- a/azurelinuxagent/protocol/ovfenv.py +++ b/azurelinuxagent/protocol/ovfenv.py @@ -17,60 +17,22 @@ # Requires Python 2.4+ and Openssl 1.0+ # """ -Copy and parse ovf-env.xml from provisiong ISO and local cache +Copy and parse ovf-env.xml from provisioning ISO and local cache """ import os import re +import shutil import xml.dom.minidom as minidom import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.exception import ProtocolError +from azurelinuxagent.future import ustr import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext -from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError -from azurelinuxagent.protocol import ProtocolError -OVF_FILE_NAME = "ovf-env.xml" OVF_VERSION = "1.0" OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1" WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure" -def get_ovf_env(): - """ - Load saved ovf-env.xml - """ - ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME) - if os.path.isfile(ovf_file_path): - xml_text = fileutil.read_file(ovf_file_path) - return OvfEnv(xml_text) - else: - raise ProtocolError("ovf-env.xml is missing.") - -def copy_ovf_env(): - """ - Copy ovf env file from dvd to hard disk. - Remove password before save it to the disk - """ - try: - OSUTIL.mount_dvd() - ovf_file_path_on_dvd = OSUTIL.get_ovf_env_file_path_on_dvd() - ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) - ovfenv = OvfEnv(ovfxml) - ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml) - ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME) - fileutil.write_file(ovf_file_path, ovfxml) - except IOError as e: - raise ProtocolError(text(e)) - except OSUtilError as e: - raise ProtocolError(text(e)) - - try: - OSUTIL.umount_dvd() - OSUTIL.eject_dvd() - except OSUtilError as e: - logger.warn(text(e)) - - return ovfenv - def _validate_ovf(val, msg): if val is None: raise ProtocolError("Failed to parse OVF XML: {0}".format(msg)) diff --git a/azurelinuxagent/protocol/protocolFactory.py b/azurelinuxagent/protocol/protocolFactory.py deleted file mode 100644 index 0bf6e52..0000000 --- a/azurelinuxagent/protocol/protocolFactory.py +++ /dev/null @@ -1,114 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# -import os -import traceback -import threading -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.utils.fileutil as fileutil -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.protocol.common import * -from azurelinuxagent.protocol.v1 import WireProtocol -from azurelinuxagent.protocol.v2 import MetadataProtocol - -WIRE_SERVER_ADDR_FILE_NAME = "WireServer" - -def get_wire_protocol_endpoint(): - path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME) - try: - endpoint = fileutil.read_file(path) - except IOError as e: - raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e)) - - if endpoint is None: - raise ProtocolNotFound("Wire server endpoint is None") - - return endpoint - -def detect_wire_protocol(): - endpoint = get_wire_protocol_endpoint() - - OSUTIL.gen_transport_cert() - protocol = WireProtocol(endpoint) - protocol.initialize() - logger.info("Protocol V1 found.") - return protocol - -def detect_metadata_protocol(): - protocol = MetadataProtocol() - protocol.initialize() - - logger.info("Protocol V2 found.") - return protocol - -def detect_available_protocols(prob_funcs=[detect_wire_protocol, - detect_metadata_protocol]): - available_protocols = [] - for probe_func in prob_funcs: - try: - protocol = probe_func() - available_protocols.append(protocol) - except ProtocolNotFound as e: - logger.info(text(e)) - return available_protocols - -def detect_default_protocol(): - logger.info("Detect default protocol.") - available_protocols = detect_available_protocols() - return choose_default_protocol(available_protocols) - -def choose_default_protocol(protocols): - if len(protocols) > 0: - return protocols[0] - else: - raise ProtocolNotFound("No available protocol detected.") - -def get_wire_protocol(): - endpoint = get_wire_protocol_endpoint() - return WireProtocol(endpoint) - -def get_metadata_protocol(): - return MetadataProtocol() - -def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]): - available_protocols = [] - for getter in getters: - try: - protocol = getter() - available_protocols.append(protocol) - except ProtocolNotFound as e: - logger.info(text(e)) - return available_protocols - -class ProtocolFactory(object): - def __init__(self): - self._protocol = None - self._lock = threading.Lock() - - def get_default_protocol(self): - if self._protocol is None: - self._lock.acquire() - if self._protocol is None: - available_protocols = get_available_protocols() - self._protocol = choose_default_protocol(available_protocols) - self._lock.release() - - return self._protocol - -FACTORY = ProtocolFactory() diff --git a/azurelinuxagent/protocol/common.py b/azurelinuxagent/protocol/restapi.py index 367794f..fbd29ed 100644 --- a/azurelinuxagent/protocol/common.py +++ b/azurelinuxagent/protocol/restapi.py @@ -22,14 +22,9 @@ import re import json import xml.dom.minidom import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.utils.fileutil as fileutil - -class ProtocolError(Exception): - pass - -class ProtocolNotFound(Exception): - pass +from azurelinuxagent.exception import ProtocolError, HttpError +from azurelinuxagent.future import ustr +import azurelinuxagent.utils.restutil as restutil def validata_param(name, val, expected_type): if val is None: @@ -88,9 +83,14 @@ class DataContractList(list): Data contract between guest and host """ class VMInfo(DataContract): - def __init__(self, subscriptionId=None, vmName=None): + def __init__(self, subscriptionId=None, vmName=None, containerId=None, + roleName=None, roleInstanceName=None, tenantName=None): self.subscriptionId = subscriptionId self.vmName = vmName + self.containerId = containerId + self.roleName = roleName + self.roleInstanceName = roleInstanceName + self.tenantName = tenantName class Cert(DataContract): def __init__(self, name=None, thumbprint=None, certificateDataUri=None): @@ -104,11 +104,11 @@ class CertList(DataContract): class Extension(DataContract): def __init__(self, name=None, sequenceNumber=None, publicSettings=None, - privateSettings=None, certificateThumbprint=None): + protectedSettings=None, certificateThumbprint=None): self.name = name self.sequenceNumber = sequenceNumber self.publicSettings = publicSettings - self.privateSettings = privateSettings + self.protectedSettings = protectedSettings self.certificateThumbprint = certificateThumbprint class ExtHandlerProperties(DataContract): @@ -176,12 +176,14 @@ class ExtensionStatus(DataContract): self.substatusList = DataContractList(ExtensionSubStatus) class ExtHandlerStatus(DataContract): - def __init__(self, name=None, version=None, status=None, message=None): + def __init__(self, name=None, version=None, status=None, code=0, + message=None): self.name = name self.version = version self.status = status + self.code = code self.message = message - self.extensions = DataContractList(text) + self.extensions = DataContractList(ustr) class VMAgentStatus(DataContract): def __init__(self, version=None, status=None, message=None): @@ -211,7 +213,7 @@ class TelemetryEventList(DataContract): class Protocol(DataContract): - def initialize(self): + def detect(self): raise NotImplementedError() def get_vminfo(self): @@ -226,6 +228,14 @@ class Protocol(DataContract): def get_ext_handler_pkgs(self, extension): raise NotImplementedError() + def download_ext_handler_pkg(self, uri): + try: + resp = restutil.http_get(uri, chk_proxy=True) + if resp.status == restutil.httpclient.OK: + return resp.read() + except HttpError as e: + raise ProtocolError("Failed to download from: {0}".format(uri), e) + def report_provision_status(self, provision_status): raise NotImplementedError() diff --git a/azurelinuxagent/protocol/v1.py b/azurelinuxagent/protocol/wire.py index 92fcc06..7b5ffe8 100644 --- a/azurelinuxagent/protocol/v1.py +++ b/azurelinuxagent/protocol/wire.py @@ -22,16 +22,19 @@ import re import time import traceback import xml.sax.saxutils as saxutils -import xml.etree.ElementTree as ET +import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.future import text, httpclient, bytebuffer +from azurelinuxagent.exception import ProtocolError, HttpError, \ + ProtocolNotFoundError +from azurelinuxagent.future import ustr, httpclient, bytebuffer import azurelinuxagent.utils.restutil as restutil from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext, \ - getattrib, gettext, remove_bom -from azurelinuxagent.utils.osutil import OSUTIL + getattrib, gettext, remove_bom, \ + get_bytes_from_pem import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil -from azurelinuxagent.protocol.common import * +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.protocol.restapi import * VERSION_INFO_URI = "http://{0}/?comp=versions" GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate" @@ -53,6 +56,7 @@ TRANSPORT_CERT_FILE_NAME = "TransportCert.pem" TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" PROTOCOL_VERSION = "2012-11-30" +ENDPOINT_FINE_NAME = "WireServer" SHORT_WAITING_INTERVAL = 1 # 1 second LONG_WAITING_INTERVAL = 15 # 15 seconds @@ -61,19 +65,37 @@ class WireProtocolResourceGone(ProtocolError): pass class WireProtocol(Protocol): + """Slim layer to adapte wire protocol data to metadata protocol interface""" def __init__(self, endpoint): - self.client = WireClient(endpoint) + if endpoint is None: + raise ProtocolError("WireProtocl endpoint is None") + self.endpoint = endpoint + self.client = WireClient(self.endpoint) - def initialize(self): + def detect(self): self.client.check_wire_protocol_version() + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + self.client.update_goal_state(forced=True) def get_vminfo(self): + goal_state = self.client.get_goal_state() hosting_env = self.client.get_hosting_env() + vminfo = VMInfo() vminfo.subscriptionId = None vminfo.vmName = hosting_env.vm_name + vminfo.tenantName = hosting_env.deployment_name + vminfo.roleName = hosting_env.role_name + vminfo.roleInstanceName = goal_state.role_instance_id + vminfo.containerId = goal_state.container_id return vminfo def get_certs(self): @@ -81,12 +103,16 @@ class WireProtocol(Protocol): return certificates.cert_list def get_ext_handlers(self): + logger.verb("Get extension handler config") #Update goal state to get latest extensions config self.client.update_goal_state() + goal_state = self.client.get_goal_state() ext_conf = self.client.get_ext_conf() - return ext_conf.ext_handlers + #In wire protocol, incarnation is equivalent to ETag + return ext_conf.ext_handlers, goal_state.incarnation def get_ext_handler_pkgs(self, ext_handler): + logger.verb("Get extension handler package") goal_state = self.client.get_goal_state() man = self.client.get_ext_manifest(ext_handler, goal_state) return man.pkg_list @@ -134,12 +160,12 @@ def _build_role_properties(container_id, role_instance_id, thumbprint): return xml def _build_health_report(incarnation, container_id, role_instance_id, - status, substatus, description): + status, substatus, description): #Escape '&', '<' and '>' - description = saxutils.escape(text(description)) + description = saxutils.escape(ustr(description)) detail = u'' if substatus is not None: - substatus = saxutils.escape(text(substatus)) + substatus = saxutils.escape(ustr(substatus)) detail = (u"<Details>" u"<SubStatus>{0}</SubStatus>" u"<Description>{1}</Description>" @@ -228,6 +254,7 @@ def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): 'handlerVersion' : handler_status.version, 'handlerName' : handler_status.name, 'status' : handler_status.status, + 'code': handler_status.code } if handler_status.message is not None: v1_handler_status["formattedMessage"] = { @@ -303,7 +330,7 @@ class StatusBlob(object): self.put_page_blob(url, data) else: raise ProtocolError("Unknown blob type: {0}".format(blob_type)) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError("Failed to upload status blob: {0}".format(e)) def get_blob_type(self, url): @@ -315,7 +342,7 @@ class StatusBlob(object): "x-ms-date" : timestamp, 'x-ms-version' : self.__class__.__storage_version__ }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to get status blob type: {0}" u"").format(e)) if resp is None or resp.status != httpclient.OK: @@ -334,10 +361,10 @@ class StatusBlob(object): data, { "x-ms-date" : timestamp, "x-ms-blob-type" : "BlockBlob", - "Content-Length": text(len(data)), + "Content-Length": ustr(len(data)), "x-ms-version" : self.__class__.__storage_version__ }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to upload block blob: {0}" u"").format(e)) if resp.status != httpclient.CREATED: @@ -359,10 +386,10 @@ class StatusBlob(object): "x-ms-date" : timestamp, "x-ms-blob-type" : "PageBlob", "Content-Length": "0", - "x-ms-blob-content-length" : text(page_blob_size), + "x-ms-blob-content-length" : ustr(page_blob_size), "x-ms-version" : self.__class__.__storage_version__ }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to clean up page blob: {0}" u"").format(e)) if resp.status != httpclient.CREATED: @@ -393,9 +420,9 @@ class StatusBlob(object): "x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1), "x-ms-page-write" : "update", "x-ms-version" : self.__class__.__storage_version__, - "Content-Length": text(page_end - start) + "Content-Length": ustr(page_end - start) }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to upload page blob: {0}" u"").format(e)) if resp is None or resp.status != httpclient.CREATED: @@ -411,13 +438,13 @@ def event_param_to_v1(param): attr_type = 'mt:uint64' elif param_type is str: attr_type = 'mt:wstr' - elif text(param_type).count("'unicode'") > 0: + elif ustr(param_type).count("'unicode'") > 0: attr_type = 'mt:wstr' elif param_type is bool: attr_type = 'mt:bool' elif param_type is float: attr_type = 'mt:float64' - return param_format.format(param.name, saxutils.quoteattr(text(param.value)), + return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)), attr_type) def event_to_v1(event): @@ -431,6 +458,7 @@ def event_to_v1(event): class WireClient(object): def __init__(self, endpoint): + logger.info("Wire server endpoint:{0}", endpoint) self.endpoint = endpoint self.goal_state = None self.updated = None @@ -448,15 +476,15 @@ class WireClient(object): """ now = time.time() if now - self.last_request < 1: - logger.info("Last request issued less than 1 second ago") - logger.info("Sleep {0} second to avoid throttling.", + logger.verb("Last request issued less than 1 second ago") + logger.verb("Sleep {0} second to avoid throttling.", SHORT_WAITING_INTERVAL) time.sleep(SHORT_WAITING_INTERVAL) self.last_request = now self.req_count += 1 if self.req_count % 3 == 0: - logger.info("Sleep {0} second to avoid throttling.", + logger.verb("Sleep {0} second to avoid throttling.", SHORT_WAITING_INTERVAL) time.sleep(SHORT_WAITING_INTERVAL) self.req_count = 0 @@ -485,15 +513,15 @@ class WireClient(object): if data is None: return None data = remove_bom(data) - xml_text = text(data, encoding='utf-8') + xml_text = ustr(data, encoding='utf-8') return xml_text def fetch_config(self, uri, headers): try: resp = self.call_wireserver(restutil.http_get, uri, headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if(resp.status != httpclient.OK): raise ProtocolError("{0} - {1}".format(resp.status, uri)) @@ -532,12 +560,13 @@ class WireClient(object): def fetch_manifest(self, version_uris): for version_uri in version_uris: + logger.verb("Fetch ext handler manifest: {0}", version_uri.uri) try: resp = self.call_storage_service(restutil.http_get, version_uri.uri, None, chk_proxy=True) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status == httpclient.OK: return self.decode_config(resp.read()) @@ -553,7 +582,7 @@ class WireClient(object): def update_hosting_env(self, goal_state): if goal_state.hosting_env_uri is None: raise ProtocolError("HostingEnvironmentConfig uri is empty") - local_file = HOSTING_ENV_FILE_NAME + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) xml_text = self.fetch_config(goal_state.hosting_env_uri, self.get_header()) self.save_cache(local_file, xml_text) @@ -562,7 +591,7 @@ class WireClient(object): def update_shared_conf(self, goal_state): if goal_state.shared_conf_uri is None: raise ProtocolError("SharedConfig uri is empty") - local_file = SHARED_CONF_FILE_NAME + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) xml_text = self.fetch_config(goal_state.shared_conf_uri, self.get_header()) self.save_cache(local_file, xml_text) @@ -571,7 +600,7 @@ class WireClient(object): def update_certs(self, goal_state): if goal_state.certs_uri is None: return - local_file = CERTS_FILE_NAME + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) xml_text = self.fetch_config(goal_state.certs_uri, self.get_header_for_cert()) self.save_cache(local_file, xml_text) @@ -583,25 +612,18 @@ class WireClient(object): self.ext_conf = ExtensionsConfig(None) return incarnation = goal_state.incarnation - local_file = EXT_CONF_FILE_NAME.format(incarnation) + local_file = os.path.join(conf.get_lib_dir(), + EXT_CONF_FILE_NAME.format(incarnation)) xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) self.save_cache(local_file, xml_text) self.ext_conf = ExtensionsConfig(xml_text) - for ext_handler in self.ext_conf.ext_handlers.extHandlers: - self.update_ext_handler_manifest(ext_handler, goal_state) - - def update_ext_handler_manifest(self, ext_handler, goal_state): - local_file = MANIFEST_FILE_NAME.format(ext_handler.name, - goal_state.incarnation) - xml_text = self.fetch_manifest(ext_handler.versionUris) - self.save_cache(local_file, xml_text) - + def update_goal_state(self, forced=False, max_retry=3): uri = GOAL_STATE_URI.format(self.endpoint) xml_text = self.fetch_config(uri, self.get_header()) goal_state = GoalState(xml_text) - incarnation_file = os.path.join(OSUTIL.get_lib_dir(), + incarnation_file = os.path.join(conf.get_lib_dir(), INCARNATION_FILE_NAME) if not forced: @@ -619,7 +641,7 @@ class WireClient(object): try: self.goal_state = goal_state file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) - goal_state_file = os.path.join(OSUTIL.get_lib_dir(), file_name) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) self.save_cache(goal_state_file, xml_text) self.save_cache(incarnation_file, goal_state.incarnation) self.update_hosting_env(goal_state) @@ -636,27 +658,34 @@ class WireClient(object): def get_goal_state(self): if(self.goal_state is None): - incarnation = self.fetch_cache(INCARNATION_FILE_NAME) - goal_state_file = GOAL_STATE_FILE_NAME.format(incarnation) + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + incarnation = self.fetch_cache(incarnation_file) + + file_name = GOAL_STATE_FILE_NAME.format(incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) xml_text = self.fetch_cache(goal_state_file) self.goal_state = GoalState(xml_text) return self.goal_state def get_hosting_env(self): if(self.hosting_env is None): - xml_text = self.fetch_cache(HOSTING_ENV_FILE_NAME) + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_cache(local_file) self.hosting_env = HostingEnv(xml_text) return self.hosting_env def get_shared_conf(self): if(self.shared_conf is None): - xml_text = self.fetch_cache(SHARED_CONF_FILE_NAME) + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_cache(local_file) self.shared_conf = SharedConfig(xml_text) return self.shared_conf def get_certs(self): if(self.certs is None): - xml_text = self.fetch_cache(CERTS_FILE_NAME) + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_cache(local_file) self.certs = Certificates(self, xml_text) if self.certs is None: return None @@ -669,14 +698,17 @@ class WireClient(object): self.ext_conf = ExtensionsConfig(None) else: local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) xml_text = self.fetch_cache(local_file) self.ext_conf = ExtensionsConfig(xml_text) return self.ext_conf - def get_ext_manifest(self, extension, goal_state): - local_file = MANIFEST_FILE_NAME.format(extension.name, - goal_state.incarnation) - xml_text = self.fetch_cache(local_file) + def get_ext_manifest(self, ext_handler, goal_state): + local_file = MANIFEST_FILE_NAME.format(ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) return ExtensionManifest(xml_text) def check_wire_protocol_version(self): @@ -693,7 +725,7 @@ class WireClient(object): else: error = ("Agent supported wire protocol version: {0} was not " "advised by Fabric.").format(PROTOCOL_VERSION) - raise ProtocolNotFound(error) + raise ProtocolNotFoundError(error) def upload_status_blob(self): ext_conf = self.get_ext_conf() @@ -711,7 +743,7 @@ class WireClient(object): try: resp = self.call_wireserver(restutil.http_post, role_prop_uri, role_prop, headers = headers) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to send role properties: {0}" u"").format(e)) if resp.status != httpclient.ACCEPTED: @@ -732,7 +764,7 @@ class WireClient(object): try: resp = self.call_wireserver(restutil.http_post, health_report_uri, health_report, headers = headers) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to send provision status: {0}" u"").format(e)) if resp.status != httpclient.OK: @@ -750,7 +782,7 @@ class WireClient(object): try: header = self.get_header_for_xml_content() resp = self.call_wireserver(restutil.http_post, uri, data, header) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError("Failed to send events:{0}".format(e)) if resp.status != httpclient.OK: @@ -791,11 +823,10 @@ class WireClient(object): } def get_header_for_cert(self): - cert = "" - content = self.fetch_cache(TRANSPORT_CERT_FILE_NAME) - for line in content.split('\n'): - if "CERTIFICATE" not in line: - cert += line.rstrip() + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + content = self.fetch_cache(trans_cert_file) + cert = get_bytes_from_pem(content) return { "x-ms-agent-name":"WALinuxAgent", "x-ms-version":PROTOCOL_VERSION, @@ -922,8 +953,6 @@ class Certificates(object): def __init__(self, client, xml_text): logger.verb("Load Certificates.xml") self.client = client - self.lib_dir = OSUTIL.get_lib_dir() - self.openssl_cmd = OSUTIL.get_openssl_cmd() self.cert_list = CertList() self.parse(xml_text) @@ -935,22 +964,26 @@ class Certificates(object): data = findtext(xml_doc, "Data") if data is None: return - + + cryptutil = CryptUtil(conf.get_openssl_cmd()) + p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME) p7m = ("MIME-Version:1.0\n" "Content-Disposition: attachment; filename=\"{0}\"\n" "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" "Content-Transfer-Encoding: base64\n" "\n" - "{2}").format(P7M_FILE_NAME, P7M_FILE_NAME, data) + "{2}").format(p7m_file, p7m_file, data) - self.client.save_cache(os.path.join(self.lib_dir, P7M_FILE_NAME), p7m) + self.client.save_cache(p7m_file, p7m) + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME) #decrypt certificates - cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3}" - "| {4} pkcs12 -nodes -password pass: -out {5}" - "").format(self.openssl_cmd, P7M_FILE_NAME, - TRANSPORT_PRV_FILE_NAME, TRANSPORT_CERT_FILE_NAME, - self.openssl_cmd, PEM_FILE_NAME) - shellutil.run(cmd) + cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, + pem_file) #The parsing process use public key to match prv and crt. buf = [] @@ -960,7 +993,7 @@ class Certificates(object): thumbprints = {} index = 0 v1_cert_list = [] - with open(PEM_FILE_NAME) as pem: + with open(pem_file) as pem: for line in pem.readlines(): buf.append(line) if re.match(r'[-]+BEGIN.*KEY[-]+', line): @@ -969,15 +1002,15 @@ class Certificates(object): begin_crt = True elif re.match(r'[-]+END.*KEY[-]+', line): tmp_file = self.write_to_tmp_file(index, 'prv', buf) - pub = OSUTIL.get_pubkey_from_prv(tmp_file) + pub = cryptutil.get_pubkey_from_prv(tmp_file) prvs[pub] = tmp_file buf = [] index += 1 begin_prv = False elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): tmp_file = self.write_to_tmp_file(index, 'crt', buf) - pub = OSUTIL.get_pubkey_from_crt(tmp_file) - thumbprint = OSUTIL.get_thumbprint_from_crt(tmp_file) + pub = cryptutil.get_pubkey_from_crt(tmp_file) + thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file) thumbprints[pub] = thumbprint #Rename crt with thumbprint as the file name crt = "{0}.crt".format(thumbprint) @@ -985,7 +1018,7 @@ class Certificates(object): "name":None, "thumbprint":thumbprint }) - os.rename(tmp_file, os.path.join(self.lib_dir, crt)) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt)) buf = [] index += 1 begin_crt = False @@ -996,7 +1029,7 @@ class Certificates(object): if thumbprint: tmp_file = prvs[pubkey] prv = "{0}.prv".format(thumbprint) - os.rename(tmp_file, os.path.join(self.lib_dir, prv)) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv)) for v1_cert in v1_cert_list: cert = Cert() @@ -1004,7 +1037,8 @@ class Certificates(object): self.cert_list.certificates.append(cert) def write_to_tmp_file(self, index, suffix, buf): - file_name = os.path.join(self.lib_dir, "{0}.{1}".format(index, suffix)) + file_name = os.path.join(conf.get_lib_dir(), + "{0}.{1}".format(index, suffix)) self.client.save_cache(file_name, "".join(buf)) return file_name @@ -1090,7 +1124,7 @@ class ExtensionsConfig(object): ext.name = ext_handler.name ext.sequenceNumber = seqNo ext.publicSettings = handler_settings.get("publicSettings") - ext.privateSettings = handler_settings.get("protectedSettings") + ext.protectedSettings = handler_settings.get("protectedSettings") thumbprint = handler_settings.get("protectedSettingsCertThumbprint") ext.certificateThumbprint = thumbprint ext_handler.properties.extensions.append(ext) |