diff options
Diffstat (limited to 'azurelinuxagent/common/protocol/metadata.py')
-rw-r--r-- | azurelinuxagent/common/protocol/metadata.py | 249 |
1 files changed, 205 insertions, 44 deletions
diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py index f86f72f..c61e373 100644 --- a/azurelinuxagent/common/protocol/metadata.py +++ b/azurelinuxagent/common/protocol/metadata.py @@ -16,39 +16,42 @@ # # Requires Python 2.4+ and Openssl 1.0+ +import base64 import json -import shutil import os -import time -from azurelinuxagent.common.exception import ProtocolError, HttpError -from azurelinuxagent.common.future import httpclient, ustr +import shutil +import re import azurelinuxagent.common.conf as conf -import azurelinuxagent.common.logger as logger -import azurelinuxagent.common.utils.restutil as restutil -import azurelinuxagent.common.utils.textutil as textutil import azurelinuxagent.common.utils.fileutil as fileutil -from azurelinuxagent.common.utils.cryptutil import CryptUtil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.future import httpclient from azurelinuxagent.common.protocol.restapi import * +from azurelinuxagent.common.utils.cryptutil import CryptUtil -METADATA_ENDPOINT='169.254.169.254' -APIVERSION='2015-05-01-preview' +METADATA_ENDPOINT = '169.254.169.254' +APIVERSION = '2015-05-01-preview' BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem" TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem" +P7M_FILE_NAME = "Certificates.p7m" +P7B_FILE_NAME = "Certificates.p7b" +PEM_FILE_NAME = "Certificates.pem" -#TODO remote workarround for azure stack +# TODO remote workaround for azure stack MAX_PING = 30 RETRY_PING_INTERVAL = 10 + def _add_content_type(headers): if headers is None: headers = {} headers["content-type"] = "application/json" return headers -class MetadataProtocol(Protocol): +class MetadataProtocol(Protocol): def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT): self.apiversion = apiversion self.endpoint = endpoint @@ -65,11 +68,12 @@ class MetadataProtocol(Protocol): self.apiversion, "") self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent", self.apiversion, "") - self.ext_status_uri = BASE_URI.format(self.endpoint, + self.ext_status_uri = BASE_URI.format(self.endpoint, "status/extensions/{0}", self.apiversion, "") self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry", self.apiversion, "") + self.certs = None def _get_data(self, url, headers=None): try: @@ -82,13 +86,12 @@ class MetadataProtocol(Protocol): data = resp.read() etag = resp.getheader('ETag') - if data is None: - return None - data = json.loads(ustr(data, encoding="utf-8")) + if data is not None: + data = json.loads(ustr(data, encoding="utf-8")) return data, etag def _put_data(self, url, data, headers=None): - headers = _add_content_type(headers) + headers = _add_content_type(headers) try: resp = restutil.http_put(url, json.dumps(data), headers=headers) except HttpError as e: @@ -97,16 +100,16 @@ class MetadataProtocol(Protocol): raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) def _post_data(self, url, data, headers=None): - headers = _add_content_type(headers) + headers = _add_content_type(headers) try: resp = restutil.http_post(url, json.dumps(data), headers=headers) except HttpError as e: raise ProtocolError(ustr(e)) if resp.status != httpclient.CREATED: raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) - + def _get_trans_cert(self): - trans_crt_file = os.path.join(conf.get_lib_dir(), + trans_crt_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME) if not os.path.isfile(trans_crt_file): raise ProtocolError("{0} is missing.".format(trans_crt_file)) @@ -115,22 +118,22 @@ class MetadataProtocol(Protocol): def detect(self): self.get_vminfo() - trans_prv_file = os.path.join(conf.get_lib_dir(), + trans_prv_file = os.path.join(conf.get_lib_dir(), TRANSPORT_PRV_FILE_NAME) - trans_cert_file = os.path.join(conf.get_lib_dir(), + trans_cert_file = os.path.join(conf.get_lib_dir(), TRANSPORT_CERT_FILE_NAME) cryptutil = CryptUtil(conf.get_openssl_cmd()) cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) - #"Install" the cert and private key to /var/lib/waagent + # "Install" the cert and private key to /var/lib/waagent thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file) - prv_file = os.path.join(conf.get_lib_dir(), + prv_file = os.path.join(conf.get_lib_dir(), "{0}.prv".format(thumbprint)) - crt_file = os.path.join(conf.get_lib_dir(), + crt_file = os.path.join(conf.get_lib_dir(), "{0}.crt".format(thumbprint)) shutil.copyfile(trans_prv_file, prv_file) shutil.copyfile(trans_cert_file, crt_file) - + self.update_goal_state(forced=True) def get_vminfo(self): vminfo = VMInfo() @@ -139,18 +142,42 @@ class MetadataProtocol(Protocol): return vminfo def get_certs(self): - #TODO download and save certs - return CertList() + certlist = CertList() + certificatedata = CertificateData() + data, etag = self._get_data(self.cert_uri) + + set_properties("certlist", certlist, data) + + cert_list = get_properties(certlist) + + headers = { + "x-ms-vmagent-public-x509-cert": self._get_trans_cert() + } + + for cert_i in cert_list["certificates"]: + certificate_data_uri = cert_i['certificateDataUri'] + data, etag = self._get_data(certificate_data_uri, headers=headers) + set_properties("certificatedata", certificatedata, data) + json_certificate_data = get_properties(certificatedata) + + self.certs = Certificates(self, json_certificate_data) + + if self.certs is None: + return None + return self.certs def get_vmagent_manifests(self, last_etag=None): manifests = VMAgentManifestList() + self.update_goal_state() data, etag = self._get_data(self.vmagent_uri) - if last_etag == None or last_etag < etag: - set_properties("vmAgentManifests", manifests.vmAgentManifests, data) + if last_etag is None or last_etag < etag: + set_properties("vmAgentManifests", + manifests.vmAgentManifests, + data) return manifests, etag def get_vmagent_pkgs(self, vmagent_manifest): - #Agent package is the same with extension handler + # Agent package is the same with extension handler vmagent_pkgs = ExtHandlerPackageList() data = None for manifest_uri in vmagent_manifest.versionsManifestUris: @@ -168,27 +195,35 @@ class MetadataProtocol(Protocol): return vmagent_pkgs def get_ext_handlers(self, last_etag=None): + self.update_goal_state() headers = { "x-ms-vmagent-public-x509-cert": self._get_trans_cert() } ext_list = ExtHandlerList() data, etag = self._get_data(self.ext_uri, headers=headers) - if last_etag == None or last_etag < etag: + if last_etag is None or last_etag < etag: set_properties("extensionHandlers", ext_list.extHandlers, data) return ext_list, etag def get_ext_handler_pkgs(self, ext_handler): - ext_handler_pkgs = ExtHandlerPackageList() - data = None + logger.info("Get extension handler packages") + pkg_list = ExtHandlerPackageList() + + manifest = None for version_uri in ext_handler.versionUris: try: - data, etag = self._get_data(version_uri.uri) + manifest, etag = self._get_data(version_uri.uri) + logger.info("Successfully downloaded manifest") break except ProtocolError as e: - logger.warn("Failed to get version uris: {0}", e) - logger.info("Retry getting version uris") - set_properties("extensionPackages", ext_handler_pkgs, data) - return ext_handler_pkgs + logger.warn("Failed to fetch manifest: {0}", e) + + if manifest is None: + raise ValueError("Extension manifest is empty") + + set_properties("extensionPackages", pkg_list, manifest) + + return pkg_list def report_provision_status(self, provision_status): validate_param('provisionStatus', provision_status, ProvisionStatus) @@ -198,7 +233,8 @@ class MetadataProtocol(Protocol): def report_vm_status(self, vm_status): validate_param('vmStatus', vm_status, VMStatus) data = get_properties(vm_status) - #TODO code field is not implemented for metadata protocol yet. Remove it + # TODO code field is not implemented for metadata protocol yet. + # Remove it handler_statuses = data['vmAgent']['extensionHandlers'] for handler_status in handler_statuses: try: @@ -215,9 +251,134 @@ class MetadataProtocol(Protocol): self._put_data(uri, data) def report_event(self, events): - #TODO disable telemetry for azure stack test - #validate_param('events', events, TelemetryEventList) - #data = get_properties(events) - #self._post_data(self.event_uri, data) + # TODO disable telemetry for azure stack test + # validate_param('events', events, TelemetryEventList) + # data = get_properties(events) + # self._post_data(self.event_uri, data) pass + def update_certs(self): + certificates = self.get_certs() + return certificates.cert_list + + def update_goal_state(self, forced=False, max_retry=3): + # Start updating goalstate, retry on 410 + for retry in range(0, max_retry): + try: + self.update_certs() + return + except: + logger.verbose("Incarnation is out of date. Update goalstate.") + raise ProtocolError("Exceeded max retry updating goal state") + + +class Certificates(object): + """ + Object containing certificates of host and provisioned user. + """ + + def __init__(self, client, json_text): + self.cert_list = CertList() + self.parse(json_text) + + def parse(self, json_text): + """ + Parse multiple certificates into seperate files. + """ + + data = json_text["certificateData"] + if data is None: + logger.verbose("No data in json_text received!") + return + + cryptutil = CryptUtil(conf.get_openssl_cmd()) + p7b_file = os.path.join(conf.get_lib_dir(), P7B_FILE_NAME) + + # Wrapping the certificate lines. + # decode and save the result into p7b_file + fileutil.write_file(p7b_file, base64.b64decode(data), asbin=True) + + ssl_cmd = "openssl pkcs7 -text -in {0} -inform der | grep -v '^-----' " + ret, data = shellutil.run_get_output(ssl_cmd.format(p7b_file)) + + p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME) + p7m = ("MIME-Version:1.0\n" + "Content-Disposition: attachment; filename=\"{0}\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" + "Content-Transfer-Encoding: base64\n" + "\n" + "{2}").format(p7m_file, p7m_file, data) + + self.save_cache(p7m_file, p7m) + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME) + # decrypt certificates + cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, + pem_file) + + # The parsing process use public key to match prv and crt. + buf = [] + begin_crt = False + begin_prv = False + prvs = {} + thumbprints = {} + index = 0 + v1_cert_list = [] + with open(pem_file) as pem: + for line in pem.readlines(): + buf.append(line) + if re.match(r'[-]+BEGIN.*KEY[-]+', line): + begin_prv = True + elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line): + begin_crt = True + elif re.match(r'[-]+END.*KEY[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'prv', buf) + pub = cryptutil.get_pubkey_from_prv(tmp_file) + prvs[pub] = tmp_file + buf = [] + index += 1 + begin_prv = False + elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'crt', buf) + pub = cryptutil.get_pubkey_from_crt(tmp_file) + thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file) + thumbprints[pub] = thumbprint + # Rename crt with thumbprint as the file name + crt = "{0}.crt".format(thumbprint) + v1_cert_list.append({ + "name": None, + "thumbprint": thumbprint + }) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt)) + buf = [] + index += 1 + begin_crt = False + + # Rename prv key with thumbprint as the file name + for pubkey in prvs: + thumbprint = thumbprints[pubkey] + if thumbprint: + tmp_file = prvs[pubkey] + prv = "{0}.prv".format(thumbprint) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv)) + + for v1_cert in v1_cert_list: + cert = Cert() + set_properties("certs", cert, v1_cert) + self.cert_list.certificates.append(cert) + + def save_cache(self, local_file, data): + try: + fileutil.write_file(local_file, data) + except IOError as e: + raise ProtocolError("Failed to write cache: {0}".format(e)) + + def write_to_tmp_file(self, index, suffix, buf): + file_name = os.path.join(conf.get_lib_dir(), + "{0}.{1}".format(index, suffix)) + self.save_cache(file_name, "".join(buf)) + return file_name |