summaryrefslogtreecommitdiff
path: root/azurelinuxagent/common/protocol/metadata.py
diff options
context:
space:
mode:
Diffstat (limited to 'azurelinuxagent/common/protocol/metadata.py')
-rw-r--r--azurelinuxagent/common/protocol/metadata.py249
1 files changed, 205 insertions, 44 deletions
diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py
index f86f72f..c61e373 100644
--- a/azurelinuxagent/common/protocol/metadata.py
+++ b/azurelinuxagent/common/protocol/metadata.py
@@ -16,39 +16,42 @@
#
# Requires Python 2.4+ and Openssl 1.0+
+import base64
import json
-import shutil
import os
-import time
-from azurelinuxagent.common.exception import ProtocolError, HttpError
-from azurelinuxagent.common.future import httpclient, ustr
+import shutil
+import re
import azurelinuxagent.common.conf as conf
-import azurelinuxagent.common.logger as logger
-import azurelinuxagent.common.utils.restutil as restutil
-import azurelinuxagent.common.utils.textutil as textutil
import azurelinuxagent.common.utils.fileutil as fileutil
-from azurelinuxagent.common.utils.cryptutil import CryptUtil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+from azurelinuxagent.common.future import httpclient
from azurelinuxagent.common.protocol.restapi import *
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
-METADATA_ENDPOINT='169.254.169.254'
-APIVERSION='2015-05-01-preview'
+METADATA_ENDPOINT = '169.254.169.254'
+APIVERSION = '2015-05-01-preview'
BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}"
TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem"
TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem"
+P7M_FILE_NAME = "Certificates.p7m"
+P7B_FILE_NAME = "Certificates.p7b"
+PEM_FILE_NAME = "Certificates.pem"
-#TODO remote workarround for azure stack
+# TODO remote workaround for azure stack
MAX_PING = 30
RETRY_PING_INTERVAL = 10
+
def _add_content_type(headers):
if headers is None:
headers = {}
headers["content-type"] = "application/json"
return headers
-class MetadataProtocol(Protocol):
+class MetadataProtocol(Protocol):
def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT):
self.apiversion = apiversion
self.endpoint = endpoint
@@ -65,11 +68,12 @@ class MetadataProtocol(Protocol):
self.apiversion, "")
self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent",
self.apiversion, "")
- self.ext_status_uri = BASE_URI.format(self.endpoint,
+ self.ext_status_uri = BASE_URI.format(self.endpoint,
"status/extensions/{0}",
self.apiversion, "")
self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry",
self.apiversion, "")
+ self.certs = None
def _get_data(self, url, headers=None):
try:
@@ -82,13 +86,12 @@ class MetadataProtocol(Protocol):
data = resp.read()
etag = resp.getheader('ETag')
- if data is None:
- return None
- data = json.loads(ustr(data, encoding="utf-8"))
+ if data is not None:
+ data = json.loads(ustr(data, encoding="utf-8"))
return data, etag
def _put_data(self, url, data, headers=None):
- headers = _add_content_type(headers)
+ headers = _add_content_type(headers)
try:
resp = restutil.http_put(url, json.dumps(data), headers=headers)
except HttpError as e:
@@ -97,16 +100,16 @@ class MetadataProtocol(Protocol):
raise ProtocolError("{0} - PUT: {1}".format(resp.status, url))
def _post_data(self, url, data, headers=None):
- headers = _add_content_type(headers)
+ headers = _add_content_type(headers)
try:
resp = restutil.http_post(url, json.dumps(data), headers=headers)
except HttpError as e:
raise ProtocolError(ustr(e))
if resp.status != httpclient.CREATED:
raise ProtocolError("{0} - POST: {1}".format(resp.status, url))
-
+
def _get_trans_cert(self):
- trans_crt_file = os.path.join(conf.get_lib_dir(),
+ trans_crt_file = os.path.join(conf.get_lib_dir(),
TRANSPORT_CERT_FILE_NAME)
if not os.path.isfile(trans_crt_file):
raise ProtocolError("{0} is missing.".format(trans_crt_file))
@@ -115,22 +118,22 @@ class MetadataProtocol(Protocol):
def detect(self):
self.get_vminfo()
- trans_prv_file = os.path.join(conf.get_lib_dir(),
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
TRANSPORT_PRV_FILE_NAME)
- trans_cert_file = os.path.join(conf.get_lib_dir(),
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
TRANSPORT_CERT_FILE_NAME)
cryptutil = CryptUtil(conf.get_openssl_cmd())
cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)
- #"Install" the cert and private key to /var/lib/waagent
+ # "Install" the cert and private key to /var/lib/waagent
thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file)
- prv_file = os.path.join(conf.get_lib_dir(),
+ prv_file = os.path.join(conf.get_lib_dir(),
"{0}.prv".format(thumbprint))
- crt_file = os.path.join(conf.get_lib_dir(),
+ crt_file = os.path.join(conf.get_lib_dir(),
"{0}.crt".format(thumbprint))
shutil.copyfile(trans_prv_file, prv_file)
shutil.copyfile(trans_cert_file, crt_file)
-
+ self.update_goal_state(forced=True)
def get_vminfo(self):
vminfo = VMInfo()
@@ -139,18 +142,42 @@ class MetadataProtocol(Protocol):
return vminfo
def get_certs(self):
- #TODO download and save certs
- return CertList()
+ certlist = CertList()
+ certificatedata = CertificateData()
+ data, etag = self._get_data(self.cert_uri)
+
+ set_properties("certlist", certlist, data)
+
+ cert_list = get_properties(certlist)
+
+ headers = {
+ "x-ms-vmagent-public-x509-cert": self._get_trans_cert()
+ }
+
+ for cert_i in cert_list["certificates"]:
+ certificate_data_uri = cert_i['certificateDataUri']
+ data, etag = self._get_data(certificate_data_uri, headers=headers)
+ set_properties("certificatedata", certificatedata, data)
+ json_certificate_data = get_properties(certificatedata)
+
+ self.certs = Certificates(self, json_certificate_data)
+
+ if self.certs is None:
+ return None
+ return self.certs
def get_vmagent_manifests(self, last_etag=None):
manifests = VMAgentManifestList()
+ self.update_goal_state()
data, etag = self._get_data(self.vmagent_uri)
- if last_etag == None or last_etag < etag:
- set_properties("vmAgentManifests", manifests.vmAgentManifests, data)
+ if last_etag is None or last_etag < etag:
+ set_properties("vmAgentManifests",
+ manifests.vmAgentManifests,
+ data)
return manifests, etag
def get_vmagent_pkgs(self, vmagent_manifest):
- #Agent package is the same with extension handler
+ # Agent package is the same with extension handler
vmagent_pkgs = ExtHandlerPackageList()
data = None
for manifest_uri in vmagent_manifest.versionsManifestUris:
@@ -168,27 +195,35 @@ class MetadataProtocol(Protocol):
return vmagent_pkgs
def get_ext_handlers(self, last_etag=None):
+ self.update_goal_state()
headers = {
"x-ms-vmagent-public-x509-cert": self._get_trans_cert()
}
ext_list = ExtHandlerList()
data, etag = self._get_data(self.ext_uri, headers=headers)
- if last_etag == None or last_etag < etag:
+ if last_etag is None or last_etag < etag:
set_properties("extensionHandlers", ext_list.extHandlers, data)
return ext_list, etag
def get_ext_handler_pkgs(self, ext_handler):
- ext_handler_pkgs = ExtHandlerPackageList()
- data = None
+ logger.info("Get extension handler packages")
+ pkg_list = ExtHandlerPackageList()
+
+ manifest = None
for version_uri in ext_handler.versionUris:
try:
- data, etag = self._get_data(version_uri.uri)
+ manifest, etag = self._get_data(version_uri.uri)
+ logger.info("Successfully downloaded manifest")
break
except ProtocolError as e:
- logger.warn("Failed to get version uris: {0}", e)
- logger.info("Retry getting version uris")
- set_properties("extensionPackages", ext_handler_pkgs, data)
- return ext_handler_pkgs
+ logger.warn("Failed to fetch manifest: {0}", e)
+
+ if manifest is None:
+ raise ValueError("Extension manifest is empty")
+
+ set_properties("extensionPackages", pkg_list, manifest)
+
+ return pkg_list
def report_provision_status(self, provision_status):
validate_param('provisionStatus', provision_status, ProvisionStatus)
@@ -198,7 +233,8 @@ class MetadataProtocol(Protocol):
def report_vm_status(self, vm_status):
validate_param('vmStatus', vm_status, VMStatus)
data = get_properties(vm_status)
- #TODO code field is not implemented for metadata protocol yet. Remove it
+ # TODO code field is not implemented for metadata protocol yet.
+ # Remove it
handler_statuses = data['vmAgent']['extensionHandlers']
for handler_status in handler_statuses:
try:
@@ -215,9 +251,134 @@ class MetadataProtocol(Protocol):
self._put_data(uri, data)
def report_event(self, events):
- #TODO disable telemetry for azure stack test
- #validate_param('events', events, TelemetryEventList)
- #data = get_properties(events)
- #self._post_data(self.event_uri, data)
+ # TODO disable telemetry for azure stack test
+ # validate_param('events', events, TelemetryEventList)
+ # data = get_properties(events)
+ # self._post_data(self.event_uri, data)
pass
+ def update_certs(self):
+ certificates = self.get_certs()
+ return certificates.cert_list
+
+ def update_goal_state(self, forced=False, max_retry=3):
+ # Start updating goalstate, retry on 410
+ for retry in range(0, max_retry):
+ try:
+ self.update_certs()
+ return
+ except:
+ logger.verbose("Incarnation is out of date. Update goalstate.")
+ raise ProtocolError("Exceeded max retry updating goal state")
+
+
+class Certificates(object):
+ """
+ Object containing certificates of host and provisioned user.
+ """
+
+ def __init__(self, client, json_text):
+ self.cert_list = CertList()
+ self.parse(json_text)
+
+ def parse(self, json_text):
+ """
+ Parse multiple certificates into seperate files.
+ """
+
+ data = json_text["certificateData"]
+ if data is None:
+ logger.verbose("No data in json_text received!")
+ return
+
+ cryptutil = CryptUtil(conf.get_openssl_cmd())
+ p7b_file = os.path.join(conf.get_lib_dir(), P7B_FILE_NAME)
+
+ # Wrapping the certificate lines.
+ # decode and save the result into p7b_file
+ fileutil.write_file(p7b_file, base64.b64decode(data), asbin=True)
+
+ ssl_cmd = "openssl pkcs7 -text -in {0} -inform der | grep -v '^-----' "
+ ret, data = shellutil.run_get_output(ssl_cmd.format(p7b_file))
+
+ p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
+ p7m = ("MIME-Version:1.0\n"
+ "Content-Disposition: attachment; filename=\"{0}\"\n"
+ "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
+ "Content-Transfer-Encoding: base64\n"
+ "\n"
+ "{2}").format(p7m_file, p7m_file, data)
+
+ self.save_cache(p7m_file, p7m)
+
+ trans_prv_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_PRV_FILE_NAME)
+ trans_cert_file = os.path.join(conf.get_lib_dir(),
+ TRANSPORT_CERT_FILE_NAME)
+ pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
+ # decrypt certificates
+ cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
+ pem_file)
+
+ # The parsing process use public key to match prv and crt.
+ buf = []
+ begin_crt = False
+ begin_prv = False
+ prvs = {}
+ thumbprints = {}
+ index = 0
+ v1_cert_list = []
+ with open(pem_file) as pem:
+ for line in pem.readlines():
+ buf.append(line)
+ if re.match(r'[-]+BEGIN.*KEY[-]+', line):
+ begin_prv = True
+ elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line):
+ begin_crt = True
+ elif re.match(r'[-]+END.*KEY[-]+', line):
+ tmp_file = self.write_to_tmp_file(index, 'prv', buf)
+ pub = cryptutil.get_pubkey_from_prv(tmp_file)
+ prvs[pub] = tmp_file
+ buf = []
+ index += 1
+ begin_prv = False
+ elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
+ tmp_file = self.write_to_tmp_file(index, 'crt', buf)
+ pub = cryptutil.get_pubkey_from_crt(tmp_file)
+ thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
+ thumbprints[pub] = thumbprint
+ # Rename crt with thumbprint as the file name
+ crt = "{0}.crt".format(thumbprint)
+ v1_cert_list.append({
+ "name": None,
+ "thumbprint": thumbprint
+ })
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
+ buf = []
+ index += 1
+ begin_crt = False
+
+ # Rename prv key with thumbprint as the file name
+ for pubkey in prvs:
+ thumbprint = thumbprints[pubkey]
+ if thumbprint:
+ tmp_file = prvs[pubkey]
+ prv = "{0}.prv".format(thumbprint)
+ os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))
+
+ for v1_cert in v1_cert_list:
+ cert = Cert()
+ set_properties("certs", cert, v1_cert)
+ self.cert_list.certificates.append(cert)
+
+ def save_cache(self, local_file, data):
+ try:
+ fileutil.write_file(local_file, data)
+ except IOError as e:
+ raise ProtocolError("Failed to write cache: {0}".format(e))
+
+ def write_to_tmp_file(self, index, suffix, buf):
+ file_name = os.path.join(conf.get_lib_dir(),
+ "{0}.{1}".format(index, suffix))
+ self.save_cache(file_name, "".join(buf))
+ return file_name