summaryrefslogtreecommitdiff
path: root/azurelinuxagent/protocol
diff options
context:
space:
mode:
Diffstat (limited to 'azurelinuxagent/protocol')
-rw-r--r--azurelinuxagent/protocol/__init__.py5
-rw-r--r--azurelinuxagent/protocol/metadata.py (renamed from azurelinuxagent/protocol/v2.py)90
-rw-r--r--azurelinuxagent/protocol/ovfenv.py46
-rw-r--r--azurelinuxagent/protocol/protocolFactory.py114
-rw-r--r--azurelinuxagent/protocol/restapi.py (renamed from azurelinuxagent/protocol/common.py)38
-rw-r--r--azurelinuxagent/protocol/wire.py (renamed from azurelinuxagent/protocol/v1.py)196
6 files changed, 213 insertions, 276 deletions
diff --git a/azurelinuxagent/protocol/__init__.py b/azurelinuxagent/protocol/__init__.py
index a4572e6..8c1bbdb 100644
--- a/azurelinuxagent/protocol/__init__.py
+++ b/azurelinuxagent/protocol/__init__.py
@@ -16,8 +16,3 @@
#
# Requires Python 2.4+ and Openssl 1.0+
#
-
-from azurelinuxagent.protocol.common import *
-from azurelinuxagent.protocol.protocolFactory import FACTORY, \
- detect_default_protocol
-
diff --git a/azurelinuxagent/protocol/v2.py b/azurelinuxagent/protocol/metadata.py
index 34102b7..8a1656f 100644
--- a/azurelinuxagent/protocol/v2.py
+++ b/azurelinuxagent/protocol/metadata.py
@@ -17,16 +17,30 @@
# Requires Python 2.4+ and Openssl 1.0+
import json
-from azurelinuxagent.future import httpclient, text
+import shutil
+import os
+import time
+from azurelinuxagent.exception import ProtocolError, HttpError
+from azurelinuxagent.future import httpclient, ustr
+import azurelinuxagent.conf as conf
+import azurelinuxagent.logger as logger
import azurelinuxagent.utils.restutil as restutil
-from azurelinuxagent.protocol.common import *
+import azurelinuxagent.utils.textutil as textutil
+import azurelinuxagent.utils.fileutil as fileutil
+from azurelinuxagent.utils.cryptutil import CryptUtil
+from azurelinuxagent.protocol.restapi import *
-ENDPOINT='169.254.169.254'
-#TODO use http for azure pack test
-#ENDPOINT='localhost'
+METADATA_ENDPOINT='169.254.169.254'
APIVERSION='2015-05-01-preview'
BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}"
+TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem"
+TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem"
+
+#TODO remote workarround for azure stack
+MAX_PING = 30
+RETRY_PING_INTERVAL = 10
+
def _add_content_type(headers):
if headers is None:
headers = {}
@@ -35,7 +49,7 @@ def _add_content_type(headers):
class MetadataProtocol(Protocol):
- def __init__(self, apiversion=APIVERSION, endpoint=ENDPOINT):
+ def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT):
self.apiversion = apiversion
self.endpoint = endpoint
self.identity_uri = BASE_URI.format(self.endpoint, "identity",
@@ -58,24 +72,25 @@ class MetadataProtocol(Protocol):
def _get_data(self, url, headers=None):
try:
resp = restutil.http_get(url, headers=headers)
- except restutil.HttpError as e:
- raise ProtocolError(text(e))
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
if resp.status != httpclient.OK:
raise ProtocolError("{0} - GET: {1}".format(resp.status, url))
data = resp.read()
+ etag = resp.getheader('ETag')
if data is None:
return None
- data = json.loads(text(data, encoding="utf-8"))
- return data
+ data = json.loads(ustr(data, encoding="utf-8"))
+ return data, etag
def _put_data(self, url, data, headers=None):
headers = _add_content_type(headers)
try:
resp = restutil.http_put(url, json.dumps(data), headers=headers)
- except restutil.HttpError as e:
- raise ProtocolError(text(e))
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
if resp.status != httpclient.OK:
raise ProtocolError("{0} - PUT: {1}".format(resp.status, url))
@@ -83,17 +98,41 @@ class MetadataProtocol(Protocol):
headers = _add_content_type(headers)
try:
resp = restutil.http_post(url, json.dumps(data), headers=headers)
- except restutil.HttpError as e:
- raise ProtocolError(text(e))
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
if resp.status != httpclient.CREATED:
raise ProtocolError("{0} - POST: {1}".format(resp.status, url))
+
+ def _get_trans_cert(self):
+ trans_crt_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ if not os.path.isfile(trans_crt_file):
+ raise ProtocolError("{0} is missing.".format(trans_crt_file))
+ content = fileutil.read_file(trans_crt_file)
+ return textutil.get_bytes_from_pem(content)
+
+ def detect(self):
+ self.get_vminfo()
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)
+
+ #"Install" the cert and private key to /var/lib/waagent
+ thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file)
+ prv_file = os.path.join(conf.get_lib_dir(),
+ "{0}.prv".format(thumbprint))
+ crt_file = os.path.join(conf.get_lib_dir(),
+ "{0}.crt".format(thumbprint))
+ shutil.copyfile(trans_prv_file, prv_file)
+ shutil.copyfile(trans_cert_file, crt_file)
- def initialize(self):
- pass
def get_vminfo(self):
vminfo = VMInfo()
- data = self._get_data(self.identity_uri)
+ data, etag = self._get_data(self.identity_uri)
set_properties("vminfo", vminfo, data)
return vminfo
@@ -102,17 +141,20 @@ class MetadataProtocol(Protocol):
return CertList()
def get_ext_handlers(self):
+ headers = {
+ "x-ms-vmagent-public-x509-cert": self._get_trans_cert()
+ }
ext_list = ExtHandlerList()
- data = self._get_data(self.ext_uri)
+ data, etag = self._get_data(self.ext_uri, headers=headers)
set_properties("extensionHandlers", ext_list.extHandlers, data)
- return ext_list
+ return ext_list, etag
def get_ext_handler_pkgs(self, ext_handler):
ext_handler_pkgs = ExtHandlerPackageList()
data = None
for version_uri in ext_handler.versionUris:
try:
- data = self._get_data(version_uri.uri)
+ data, etag = self._get_data(version_uri.uri)
break
except ProtocolError as e:
logger.warn("Failed to get version uris: {0}", e)
@@ -128,6 +170,14 @@ class MetadataProtocol(Protocol):
def report_vm_status(self, vm_status):
validata_param('vmStatus', vm_status, VMStatus)
data = get_properties(vm_status)
+ #TODO code field is not implemented for metadata protocol yet. Remove it
+ handler_statuses = data['vmAgent']['extensionHandlers']
+ for handler_status in handler_statuses:
+ try:
+ handler_status.pop('code', None)
+ except KeyError:
+ pass
+
self._put_data(self.vm_status_uri, data)
def report_ext_status(self, ext_handler_name, ext_name, ext_status):
diff --git a/azurelinuxagent/protocol/ovfenv.py b/azurelinuxagent/protocol/ovfenv.py
index 9c845ee..de6791c 100644
--- a/azurelinuxagent/protocol/ovfenv.py
+++ b/azurelinuxagent/protocol/ovfenv.py
@@ -17,60 +17,22 @@
# Requires Python 2.4+ and Openssl 1.0+
#
"""
-Copy and parse ovf-env.xml from provisiong ISO and local cache
+Copy and parse ovf-env.xml from provisioning ISO and local cache
"""
import os
import re
+import shutil
import xml.dom.minidom as minidom
import azurelinuxagent.logger as logger
-from azurelinuxagent.future import text
+from azurelinuxagent.exception import ProtocolError
+from azurelinuxagent.future import ustr
import azurelinuxagent.utils.fileutil as fileutil
from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext
-from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError
-from azurelinuxagent.protocol import ProtocolError
-OVF_FILE_NAME = "ovf-env.xml"
OVF_VERSION = "1.0"
OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1"
WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure"
-def get_ovf_env():
- """
- Load saved ovf-env.xml
- """
- ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME)
- if os.path.isfile(ovf_file_path):
- xml_text = fileutil.read_file(ovf_file_path)
- return OvfEnv(xml_text)
- else:
- raise ProtocolError("ovf-env.xml is missing.")
-
-def copy_ovf_env():
- """
- Copy ovf env file from dvd to hard disk.
- Remove password before save it to the disk
- """
- try:
- OSUTIL.mount_dvd()
- ovf_file_path_on_dvd = OSUTIL.get_ovf_env_file_path_on_dvd()
- ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
- ovfenv = OvfEnv(ovfxml)
- ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml)
- ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME)
- fileutil.write_file(ovf_file_path, ovfxml)
- except IOError as e:
- raise ProtocolError(text(e))
- except OSUtilError as e:
- raise ProtocolError(text(e))
-
- try:
- OSUTIL.umount_dvd()
- OSUTIL.eject_dvd()
- except OSUtilError as e:
- logger.warn(text(e))
-
- return ovfenv
-
def _validate_ovf(val, msg):
if val is None:
raise ProtocolError("Failed to parse OVF XML: {0}".format(msg))
diff --git a/azurelinuxagent/protocol/protocolFactory.py b/azurelinuxagent/protocol/protocolFactory.py
deleted file mode 100644
index 0bf6e52..0000000
--- a/azurelinuxagent/protocol/protocolFactory.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Microsoft Azure Linux Agent
-#
-# Copyright 2014 Microsoft Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# Requires Python 2.4+ and Openssl 1.0+
-#
-import os
-import traceback
-import threading
-import azurelinuxagent.logger as logger
-from azurelinuxagent.future import text
-import azurelinuxagent.utils.fileutil as fileutil
-from azurelinuxagent.utils.osutil import OSUTIL
-from azurelinuxagent.protocol.common import *
-from azurelinuxagent.protocol.v1 import WireProtocol
-from azurelinuxagent.protocol.v2 import MetadataProtocol
-
-WIRE_SERVER_ADDR_FILE_NAME = "WireServer"
-
-def get_wire_protocol_endpoint():
- path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME)
- try:
- endpoint = fileutil.read_file(path)
- except IOError as e:
- raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e))
-
- if endpoint is None:
- raise ProtocolNotFound("Wire server endpoint is None")
-
- return endpoint
-
-def detect_wire_protocol():
- endpoint = get_wire_protocol_endpoint()
-
- OSUTIL.gen_transport_cert()
- protocol = WireProtocol(endpoint)
- protocol.initialize()
- logger.info("Protocol V1 found.")
- return protocol
-
-def detect_metadata_protocol():
- protocol = MetadataProtocol()
- protocol.initialize()
-
- logger.info("Protocol V2 found.")
- return protocol
-
-def detect_available_protocols(prob_funcs=[detect_wire_protocol,
- detect_metadata_protocol]):
- available_protocols = []
- for probe_func in prob_funcs:
- try:
- protocol = probe_func()
- available_protocols.append(protocol)
- except ProtocolNotFound as e:
- logger.info(text(e))
- return available_protocols
-
-def detect_default_protocol():
- logger.info("Detect default protocol.")
- available_protocols = detect_available_protocols()
- return choose_default_protocol(available_protocols)
-
-def choose_default_protocol(protocols):
- if len(protocols) > 0:
- return protocols[0]
- else:
- raise ProtocolNotFound("No available protocol detected.")
-
-def get_wire_protocol():
- endpoint = get_wire_protocol_endpoint()
- return WireProtocol(endpoint)
-
-def get_metadata_protocol():
- return MetadataProtocol()
-
-def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]):
- available_protocols = []
- for getter in getters:
- try:
- protocol = getter()
- available_protocols.append(protocol)
- except ProtocolNotFound as e:
- logger.info(text(e))
- return available_protocols
-
-class ProtocolFactory(object):
- def __init__(self):
- self._protocol = None
- self._lock = threading.Lock()
-
- def get_default_protocol(self):
- if self._protocol is None:
- self._lock.acquire()
- if self._protocol is None:
- available_protocols = get_available_protocols()
- self._protocol = choose_default_protocol(available_protocols)
- self._lock.release()
-
- return self._protocol
-
-FACTORY = ProtocolFactory()
diff --git a/azurelinuxagent/protocol/common.py b/azurelinuxagent/protocol/restapi.py
index 367794f..fbd29ed 100644
--- a/azurelinuxagent/protocol/common.py
+++ b/azurelinuxagent/protocol/restapi.py
@@ -22,14 +22,9 @@ import re
import json
import xml.dom.minidom
import azurelinuxagent.logger as logger
-from azurelinuxagent.future import text
-import azurelinuxagent.utils.fileutil as fileutil
-
-class ProtocolError(Exception):
- pass
-
-class ProtocolNotFound(Exception):
- pass
+from azurelinuxagent.exception import ProtocolError, HttpError
+from azurelinuxagent.future import ustr
+import azurelinuxagent.utils.restutil as restutil
def validata_param(name, val, expected_type):
if val is None:
@@ -88,9 +83,14 @@ class DataContractList(list):
Data contract between guest and host
"""
class VMInfo(DataContract):
- def __init__(self, subscriptionId=None, vmName=None):
+ def __init__(self, subscriptionId=None, vmName=None, containerId=None,
+ roleName=None, roleInstanceName=None, tenantName=None):
self.subscriptionId = subscriptionId
self.vmName = vmName
+ self.containerId = containerId
+ self.roleName = roleName
+ self.roleInstanceName = roleInstanceName
+ self.tenantName = tenantName
class Cert(DataContract):
def __init__(self, name=None, thumbprint=None, certificateDataUri=None):
@@ -104,11 +104,11 @@ class CertList(DataContract):
class Extension(DataContract):
def __init__(self, name=None, sequenceNumber=None, publicSettings=None,
- privateSettings=None, certificateThumbprint=None):
+ protectedSettings=None, certificateThumbprint=None):
self.name = name
self.sequenceNumber = sequenceNumber
self.publicSettings = publicSettings
- self.privateSettings = privateSettings
+ self.protectedSettings = protectedSettings
self.certificateThumbprint = certificateThumbprint
class ExtHandlerProperties(DataContract):
@@ -176,12 +176,14 @@ class ExtensionStatus(DataContract):
self.substatusList = DataContractList(ExtensionSubStatus)
class ExtHandlerStatus(DataContract):
- def __init__(self, name=None, version=None, status=None, message=None):
+ def __init__(self, name=None, version=None, status=None, code=0,
+ message=None):
self.name = name
self.version = version
self.status = status
+ self.code = code
self.message = message
- self.extensions = DataContractList(text)
+ self.extensions = DataContractList(ustr)
class VMAgentStatus(DataContract):
def __init__(self, version=None, status=None, message=None):
@@ -211,7 +213,7 @@ class TelemetryEventList(DataContract):
class Protocol(DataContract):
- def initialize(self):
+ def detect(self):
raise NotImplementedError()
def get_vminfo(self):
@@ -226,6 +228,14 @@ class Protocol(DataContract):
def get_ext_handler_pkgs(self, extension):
raise NotImplementedError()
+ def download_ext_handler_pkg(self, uri):
+ try:
+ resp = restutil.http_get(uri, chk_proxy=True)
+ if resp.status == restutil.httpclient.OK:
+ return resp.read()
+ except HttpError as e:
+ raise ProtocolError("Failed to download from: {0}".format(uri), e)
+
def report_provision_status(self, provision_status):
raise NotImplementedError()
diff --git a/azurelinuxagent/protocol/v1.py b/azurelinuxagent/protocol/wire.py
index 92fcc06..7b5ffe8 100644
--- a/azurelinuxagent/protocol/v1.py
+++ b/azurelinuxagent/protocol/wire.py
@@ -22,16 +22,19 @@ import re
import time
import traceback
import xml.sax.saxutils as saxutils
-import xml.etree.ElementTree as ET
+import azurelinuxagent.conf as conf
import azurelinuxagent.logger as logger
-from azurelinuxagent.future import text, httpclient, bytebuffer
+from azurelinuxagent.exception import ProtocolError, HttpError, \
+ ProtocolNotFoundError
+from azurelinuxagent.future import ustr, httpclient, bytebuffer
import azurelinuxagent.utils.restutil as restutil
from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext, \
- getattrib, gettext, remove_bom
-from azurelinuxagent.utils.osutil import OSUTIL
+ getattrib, gettext, remove_bom, \
+ get_bytes_from_pem
import azurelinuxagent.utils.fileutil as fileutil
import azurelinuxagent.utils.shellutil as shellutil
-from azurelinuxagent.protocol.common import *
+from azurelinuxagent.utils.cryptutil import CryptUtil
+from azurelinuxagent.protocol.restapi import *
VERSION_INFO_URI = "http://{0}/?comp=versions"
GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate"
@@ -53,6 +56,7 @@ TRANSPORT_CERT_FILE_NAME = "TransportCert.pem"
TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem"
PROTOCOL_VERSION = "2012-11-30"
+ENDPOINT_FINE_NAME = "WireServer"
SHORT_WAITING_INTERVAL = 1 # 1 second
LONG_WAITING_INTERVAL = 15 # 15 seconds
@@ -61,19 +65,37 @@ class WireProtocolResourceGone(ProtocolError):
pass
class WireProtocol(Protocol):
+ """Slim layer to adapte wire protocol data to metadata protocol interface"""
def __init__(self, endpoint):
- self.client = WireClient(endpoint)
+ if endpoint is None:
+ raise ProtocolError("WireProtocl endpoint is None")
+ self.endpoint = endpoint
+ self.client = WireClient(self.endpoint)
- def initialize(self):
+ def detect(self):
self.client.check_wire_protocol_version()
+
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)
+
self.client.update_goal_state(forced=True)
def get_vminfo(self):
+ goal_state = self.client.get_goal_state()
hosting_env = self.client.get_hosting_env()
+
vminfo = VMInfo()
vminfo.subscriptionId = None
vminfo.vmName = hosting_env.vm_name
+ vminfo.tenantName = hosting_env.deployment_name
+ vminfo.roleName = hosting_env.role_name
+ vminfo.roleInstanceName = goal_state.role_instance_id
+ vminfo.containerId = goal_state.container_id
return vminfo
def get_certs(self):
@@ -81,12 +103,16 @@ class WireProtocol(Protocol):
return certificates.cert_list
def get_ext_handlers(self):
+ logger.verb("Get extension handler config")
#Update goal state to get latest extensions config
self.client.update_goal_state()
+ goal_state = self.client.get_goal_state()
ext_conf = self.client.get_ext_conf()
- return ext_conf.ext_handlers
+ #In wire protocol, incarnation is equivalent to ETag
+ return ext_conf.ext_handlers, goal_state.incarnation
def get_ext_handler_pkgs(self, ext_handler):
+ logger.verb("Get extension handler package")
goal_state = self.client.get_goal_state()
man = self.client.get_ext_manifest(ext_handler, goal_state)
return man.pkg_list
@@ -134,12 +160,12 @@ def _build_role_properties(container_id, role_instance_id, thumbprint):
return xml
def _build_health_report(incarnation, container_id, role_instance_id,
- status, substatus, description):
+ status, substatus, description):
#Escape '&', '<' and '>'
- description = saxutils.escape(text(description))
+ description = saxutils.escape(ustr(description))
detail = u''
if substatus is not None:
- substatus = saxutils.escape(text(substatus))
+ substatus = saxutils.escape(ustr(substatus))
detail = (u"<Details>"
u"<SubStatus>{0}</SubStatus>"
u"<Description>{1}</Description>"
@@ -228,6 +254,7 @@ def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp):
'handlerVersion' : handler_status.version,
'handlerName' : handler_status.name,
'status' : handler_status.status,
+ 'code': handler_status.code
}
if handler_status.message is not None:
v1_handler_status["formattedMessage"] = {
@@ -303,7 +330,7 @@ class StatusBlob(object):
self.put_page_blob(url, data)
else:
raise ProtocolError("Unknown blob type: {0}".format(blob_type))
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError("Failed to upload status blob: {0}".format(e))
def get_blob_type(self, url):
@@ -315,7 +342,7 @@ class StatusBlob(object):
"x-ms-date" : timestamp,
'x-ms-version' : self.__class__.__storage_version__
})
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError((u"Failed to get status blob type: {0}"
u"").format(e))
if resp is None or resp.status != httpclient.OK:
@@ -334,10 +361,10 @@ class StatusBlob(object):
data, {
"x-ms-date" : timestamp,
"x-ms-blob-type" : "BlockBlob",
- "Content-Length": text(len(data)),
+ "Content-Length": ustr(len(data)),
"x-ms-version" : self.__class__.__storage_version__
})
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError((u"Failed to upload block blob: {0}"
u"").format(e))
if resp.status != httpclient.CREATED:
@@ -359,10 +386,10 @@ class StatusBlob(object):
"x-ms-date" : timestamp,
"x-ms-blob-type" : "PageBlob",
"Content-Length": "0",
- "x-ms-blob-content-length" : text(page_blob_size),
+ "x-ms-blob-content-length" : ustr(page_blob_size),
"x-ms-version" : self.__class__.__storage_version__
})
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError((u"Failed to clean up page blob: {0}"
u"").format(e))
if resp.status != httpclient.CREATED:
@@ -393,9 +420,9 @@ class StatusBlob(object):
"x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1),
"x-ms-page-write" : "update",
"x-ms-version" : self.__class__.__storage_version__,
- "Content-Length": text(page_end - start)
+ "Content-Length": ustr(page_end - start)
})
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError((u"Failed to upload page blob: {0}"
u"").format(e))
if resp is None or resp.status != httpclient.CREATED:
@@ -411,13 +438,13 @@ def event_param_to_v1(param):
attr_type = 'mt:uint64'
elif param_type is str:
attr_type = 'mt:wstr'
- elif text(param_type).count("'unicode'") > 0:
+ elif ustr(param_type).count("'unicode'") > 0:
attr_type = 'mt:wstr'
elif param_type is bool:
attr_type = 'mt:bool'
elif param_type is float:
attr_type = 'mt:float64'
- return param_format.format(param.name, saxutils.quoteattr(text(param.value)),
+ return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)),
attr_type)
def event_to_v1(event):
@@ -431,6 +458,7 @@ def event_to_v1(event):
class WireClient(object):
def __init__(self, endpoint):
+ logger.info("Wire server endpoint:{0}", endpoint)
self.endpoint = endpoint
self.goal_state = None
self.updated = None
@@ -448,15 +476,15 @@ class WireClient(object):
"""
now = time.time()
if now - self.last_request < 1:
- logger.info("Last request issued less than 1 second ago")
- logger.info("Sleep {0} second to avoid throttling.",
+ logger.verb("Last request issued less than 1 second ago")
+ logger.verb("Sleep {0} second to avoid throttling.",
SHORT_WAITING_INTERVAL)
time.sleep(SHORT_WAITING_INTERVAL)
self.last_request = now
self.req_count += 1
if self.req_count % 3 == 0:
- logger.info("Sleep {0} second to avoid throttling.",
+ logger.verb("Sleep {0} second to avoid throttling.",
SHORT_WAITING_INTERVAL)
time.sleep(SHORT_WAITING_INTERVAL)
self.req_count = 0
@@ -485,15 +513,15 @@ class WireClient(object):
if data is None:
return None
data = remove_bom(data)
- xml_text = text(data, encoding='utf-8')
+ xml_text = ustr(data, encoding='utf-8')
return xml_text
def fetch_config(self, uri, headers):
try:
resp = self.call_wireserver(restutil.http_get, uri,
headers=headers)
- except restutil.HttpError as e:
- raise ProtocolError(text(e))
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
if(resp.status != httpclient.OK):
raise ProtocolError("{0} - {1}".format(resp.status, uri))
@@ -532,12 +560,13 @@ class WireClient(object):
def fetch_manifest(self, version_uris):
for version_uri in version_uris:
+ logger.verb("Fetch ext handler manifest: {0}", version_uri.uri)
try:
resp = self.call_storage_service(restutil.http_get,
version_uri.uri, None,
chk_proxy=True)
- except restutil.HttpError as e:
- raise ProtocolError(text(e))
+ except HttpError as e:
+ raise ProtocolError(ustr(e))
if resp.status == httpclient.OK:
return self.decode_config(resp.read())
@@ -553,7 +582,7 @@ class WireClient(object):
def update_hosting_env(self, goal_state):
if goal_state.hosting_env_uri is None:
raise ProtocolError("HostingEnvironmentConfig uri is empty")
- local_file = HOSTING_ENV_FILE_NAME
+ local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
xml_text = self.fetch_config(goal_state.hosting_env_uri,
self.get_header())
self.save_cache(local_file, xml_text)
@@ -562,7 +591,7 @@ class WireClient(object):
def update_shared_conf(self, goal_state):
if goal_state.shared_conf_uri is None:
raise ProtocolError("SharedConfig uri is empty")
- local_file = SHARED_CONF_FILE_NAME
+ local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
xml_text = self.fetch_config(goal_state.shared_conf_uri,
self.get_header())
self.save_cache(local_file, xml_text)
@@ -571,7 +600,7 @@ class WireClient(object):
def update_certs(self, goal_state):
if goal_state.certs_uri is None:
return
- local_file = CERTS_FILE_NAME
+ local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
xml_text = self.fetch_config(goal_state.certs_uri,
self.get_header_for_cert())
self.save_cache(local_file, xml_text)
@@ -583,25 +612,18 @@ class WireClient(object):
self.ext_conf = ExtensionsConfig(None)
return
incarnation = goal_state.incarnation
- local_file = EXT_CONF_FILE_NAME.format(incarnation)
+ local_file = os.path.join(conf.get_lib_dir(),
+ EXT_CONF_FILE_NAME.format(incarnation))
xml_text = self.fetch_config(goal_state.ext_uri, self.get_header())
self.save_cache(local_file, xml_text)
self.ext_conf = ExtensionsConfig(xml_text)
- for ext_handler in self.ext_conf.ext_handlers.extHandlers:
- self.update_ext_handler_manifest(ext_handler, goal_state)
-
- def update_ext_handler_manifest(self, ext_handler, goal_state):
- local_file = MANIFEST_FILE_NAME.format(ext_handler.name,
- goal_state.incarnation)
- xml_text = self.fetch_manifest(ext_handler.versionUris)
- self.save_cache(local_file, xml_text)
-
+
def update_goal_state(self, forced=False, max_retry=3):
uri = GOAL_STATE_URI.format(self.endpoint)
xml_text = self.fetch_config(uri, self.get_header())
goal_state = GoalState(xml_text)
- incarnation_file = os.path.join(OSUTIL.get_lib_dir(),
+ incarnation_file = os.path.join(conf.get_lib_dir(),
INCARNATION_FILE_NAME)
if not forced:
@@ -619,7 +641,7 @@ class WireClient(object):
try:
self.goal_state = goal_state
file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation)
- goal_state_file = os.path.join(OSUTIL.get_lib_dir(), file_name)
+ goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
self.save_cache(goal_state_file, xml_text)
self.save_cache(incarnation_file, goal_state.incarnation)
self.update_hosting_env(goal_state)
@@ -636,27 +658,34 @@ class WireClient(object):
def get_goal_state(self):
if(self.goal_state is None):
- incarnation = self.fetch_cache(INCARNATION_FILE_NAME)
- goal_state_file = GOAL_STATE_FILE_NAME.format(incarnation)
+ incarnation_file = os.path.join(conf.get_lib_dir(),
+ INCARNATION_FILE_NAME)
+ incarnation = self.fetch_cache(incarnation_file)
+
+ file_name = GOAL_STATE_FILE_NAME.format(incarnation)
+ goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
xml_text = self.fetch_cache(goal_state_file)
self.goal_state = GoalState(xml_text)
return self.goal_state
def get_hosting_env(self):
if(self.hosting_env is None):
- xml_text = self.fetch_cache(HOSTING_ENV_FILE_NAME)
+ local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
+ xml_text = self.fetch_cache(local_file)
self.hosting_env = HostingEnv(xml_text)
return self.hosting_env
def get_shared_conf(self):
if(self.shared_conf is None):
- xml_text = self.fetch_cache(SHARED_CONF_FILE_NAME)
+ local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
+ xml_text = self.fetch_cache(local_file)
self.shared_conf = SharedConfig(xml_text)
return self.shared_conf
def get_certs(self):
if(self.certs is None):
- xml_text = self.fetch_cache(CERTS_FILE_NAME)
+ local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
+ xml_text = self.fetch_cache(local_file)
self.certs = Certificates(self, xml_text)
if self.certs is None:
return None
@@ -669,14 +698,17 @@ class WireClient(object):
self.ext_conf = ExtensionsConfig(None)
else:
local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation)
+ local_file = os.path.join(conf.get_lib_dir(), local_file)
xml_text = self.fetch_cache(local_file)
self.ext_conf = ExtensionsConfig(xml_text)
return self.ext_conf
- def get_ext_manifest(self, extension, goal_state):
- local_file = MANIFEST_FILE_NAME.format(extension.name,
- goal_state.incarnation)
- xml_text = self.fetch_cache(local_file)
+ def get_ext_manifest(self, ext_handler, goal_state):
+ local_file = MANIFEST_FILE_NAME.format(ext_handler.name,
+ goal_state.incarnation)
+ local_file = os.path.join(conf.get_lib_dir(), local_file)
+ xml_text = self.fetch_manifest(ext_handler.versionUris)
+ self.save_cache(local_file, xml_text)
return ExtensionManifest(xml_text)
def check_wire_protocol_version(self):
@@ -693,7 +725,7 @@ class WireClient(object):
else:
error = ("Agent supported wire protocol version: {0} was not "
"advised by Fabric.").format(PROTOCOL_VERSION)
- raise ProtocolNotFound(error)
+ raise ProtocolNotFoundError(error)
def upload_status_blob(self):
ext_conf = self.get_ext_conf()
@@ -711,7 +743,7 @@ class WireClient(object):
try:
resp = self.call_wireserver(restutil.http_post, role_prop_uri,
role_prop, headers = headers)
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError((u"Failed to send role properties: {0}"
u"").format(e))
if resp.status != httpclient.ACCEPTED:
@@ -732,7 +764,7 @@ class WireClient(object):
try:
resp = self.call_wireserver(restutil.http_post, health_report_uri,
health_report, headers = headers)
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError((u"Failed to send provision status: {0}"
u"").format(e))
if resp.status != httpclient.OK:
@@ -750,7 +782,7 @@ class WireClient(object):
try:
header = self.get_header_for_xml_content()
resp = self.call_wireserver(restutil.http_post, uri, data, header)
- except restutil.HttpError as e:
+ except HttpError as e:
raise ProtocolError("Failed to send events:{0}".format(e))
if resp.status != httpclient.OK:
@@ -791,11 +823,10 @@ class WireClient(object):
}
def get_header_for_cert(self):
- cert = ""
- content = self.fetch_cache(TRANSPORT_CERT_FILE_NAME)
- for line in content.split('\n'):
- if "CERTIFICATE" not in line:
- cert += line.rstrip()
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ content = self.fetch_cache(trans_cert_file)
+ cert = get_bytes_from_pem(content)
return {
"x-ms-agent-name":"WALinuxAgent",
"x-ms-version":PROTOCOL_VERSION,
@@ -922,8 +953,6 @@ class Certificates(object):
def __init__(self, client, xml_text):
logger.verb("Load Certificates.xml")
self.client = client
- self.lib_dir = OSUTIL.get_lib_dir()
- self.openssl_cmd = OSUTIL.get_openssl_cmd()
self.cert_list = CertList()
self.parse(xml_text)
@@ -935,22 +964,26 @@ class Certificates(object):
data = findtext(xml_doc, "Data")
if data is None:
return
-
+
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
p7m = ("MIME-Version:1.0\n"
"Content-Disposition: attachment; filename=\"{0}\"\n"
"Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
"Content-Transfer-Encoding: base64\n"
"\n"
- "{2}").format(P7M_FILE_NAME, P7M_FILE_NAME, data)
+ "{2}").format(p7m_file, p7m_file, data)
- self.client.save_cache(os.path.join(self.lib_dir, P7M_FILE_NAME), p7m)
+ self.client.save_cache(p7m_file, p7m)
+
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
#decrypt certificates
- cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3}"
- "| {4} pkcs12 -nodes -password pass: -out {5}"
- "").format(self.openssl_cmd, P7M_FILE_NAME,
- TRANSPORT_PRV_FILE_NAME, TRANSPORT_CERT_FILE_NAME,
- self.openssl_cmd, PEM_FILE_NAME)
- shellutil.run(cmd)
+ cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
+ pem_file)
#The parsing process use public key to match prv and crt.
buf = []
@@ -960,7 +993,7 @@ class Certificates(object):
thumbprints = {}
index = 0
v1_cert_list = []
- with open(PEM_FILE_NAME) as pem:
+ with open(pem_file) as pem:
for line in pem.readlines():
buf.append(line)
if re.match(r'[-]+BEGIN.*KEY[-]+', line):
@@ -969,15 +1002,15 @@ class Certificates(object):
begin_crt = True
elif re.match(r'[-]+END.*KEY[-]+', line):
tmp_file = self.write_to_tmp_file(index, 'prv', buf)
- pub = OSUTIL.get_pubkey_from_prv(tmp_file)
+ pub = cryptutil.get_pubkey_from_prv(tmp_file)
prvs[pub] = tmp_file
buf = []
index += 1
begin_prv = False
elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
tmp_file = self.write_to_tmp_file(index, 'crt', buf)
- pub = OSUTIL.get_pubkey_from_crt(tmp_file)
- thumbprint = OSUTIL.get_thumbprint_from_crt(tmp_file)
+ pub = cryptutil.get_pubkey_from_crt(tmp_file)
+ thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
thumbprints[pub] = thumbprint
#Rename crt with thumbprint as the file name
crt = "{0}.crt".format(thumbprint)
@@ -985,7 +1018,7 @@ class Certificates(object):
"name":None,
"thumbprint":thumbprint
})
- os.rename(tmp_file, os.path.join(self.lib_dir, crt))
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
buf = []
index += 1
begin_crt = False
@@ -996,7 +1029,7 @@ class Certificates(object):
if thumbprint:
tmp_file = prvs[pubkey]
prv = "{0}.prv".format(thumbprint)
- os.rename(tmp_file, os.path.join(self.lib_dir, prv))
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))
for v1_cert in v1_cert_list:
cert = Cert()
@@ -1004,7 +1037,8 @@ class Certificates(object):
self.cert_list.certificates.append(cert)
def write_to_tmp_file(self, index, suffix, buf):
- file_name = os.path.join(self.lib_dir, "{0}.{1}".format(index, suffix))
+ file_name = os.path.join(conf.get_lib_dir(),
+ "{0}.{1}".format(index, suffix))
self.client.save_cache(file_name, "".join(buf))
return file_name
@@ -1090,7 +1124,7 @@ class ExtensionsConfig(object):
ext.name = ext_handler.name
ext.sequenceNumber = seqNo
ext.publicSettings = handler_settings.get("publicSettings")
- ext.privateSettings = handler_settings.get("protectedSettings")
+ ext.protectedSettings = handler_settings.get("protectedSettings")
thumbprint = handler_settings.get("protectedSettingsCertThumbprint")
ext.certificateThumbprint = thumbprint
ext_handler.properties.extensions.append(ext)