diff options
Diffstat (limited to 'azurelinuxagent/common/protocol/util.py')
-rw-r--r-- | azurelinuxagent/common/protocol/util.py | 285 |
1 files changed, 285 insertions, 0 deletions
diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py new file mode 100644 index 0000000..7e7a74f --- /dev/null +++ b/azurelinuxagent/common/protocol/util.py @@ -0,0 +1,285 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import re +import shutil +import time +import threading +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \ + ProtocolNotFoundError, DhcpError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.osutil import get_osutil +from azurelinuxagent.common.dhcp import get_dhcp_handler +from azurelinuxagent.common.protocol.ovfenv import OvfEnv +from azurelinuxagent.common.protocol.wire import WireProtocol +from azurelinuxagent.common.protocol.metadata import MetadataProtocol, \ + METADATA_ENDPOINT +import azurelinuxagent.common.utils.shellutil as shellutil + +OVF_FILE_NAME = "ovf-env.xml" + +#Tag file to indicate usage of metadata protocol +TAG_FILE_NAME = "useMetadataEndpoint.tag" + +PROTOCOL_FILE_NAME = "Protocol" + +#MAX retry times for protocol probing +MAX_RETRY = 360 + +PROBE_INTERVAL = 10 + +ENDPOINT_FILE_NAME = "WireServerEndpoint" + +def get_protocol_util(): + return ProtocolUtil() + +class ProtocolUtil(object): + """ + ProtocolUtil handles initialization for protocol instance. 2 protocol types + are invoked, wire protocol and metadata protocols. + """ + def __init__(self): + self.lock = threading.Lock() + self.protocol = None + self.osutil = get_osutil() + self.dhcp_handler = get_dhcp_handler() + + def copy_ovf_env(self): + """ + Copy ovf env file from dvd to hard disk. + Remove password before save it to the disk + """ + dvd_mount_point = conf.get_dvd_mount_point() + ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME) + tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME) + try: + self.osutil.mount_dvd() + ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) + ovfenv = OvfEnv(ovfxml) + ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml) + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + fileutil.write_file(ovf_file_path, ovfxml) + + if os.path.isfile(tag_file_path_on_dvd): + logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME) + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + shutil.copyfile(tag_file_path_on_dvd, tag_file_path) + + except (OSUtilError, IOError) as e: + raise ProtocolError(ustr(e)) + + try: + self.osutil.umount_dvd() + self.osutil.eject_dvd() + except OSUtilError as e: + logger.warn(ustr(e)) + + return ovfenv + + def get_ovf_env(self): + """ + Load saved ovf-env.xml + """ + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + if os.path.isfile(ovf_file_path): + xml_text = fileutil.read_file(ovf_file_path) + return OvfEnv(xml_text) + else: + raise ProtocolError("ovf-env.xml is missing.") + + def _get_wireserver_endpoint(self): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + return fileutil.read_file(file_path) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _set_wireserver_endpoint(self, endpoint): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + fileutil.write_file(file_path, endpoint) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _detect_wire_protocol(self): + endpoint = self.dhcp_handler.endpoint + if endpoint is None: + logger.info("WireServer endpoint is not found. Rerun dhcp handler") + try: + self.dhcp_handler.run() + except DhcpError as e: + raise ProtocolError(ustr(e)) + endpoint = self.dhcp_handler.endpoint + + try: + protocol = WireProtocol(endpoint) + protocol.detect() + self._set_wireserver_endpoint(endpoint) + self.save_protocol("WireProtocol") + return protocol + except ProtocolError as e: + logger.info("WireServer is not responding. Reset endpoint") + self.dhcp_handler.endpoint = None + self.dhcp_handler.skip_cache = True + raise e + + def _detect_metadata_protocol(self): + protocol = MetadataProtocol() + protocol.detect() + + #Only allow root access METADATA_ENDPOINT + self.osutil.set_admin_access_to_ip(METADATA_ENDPOINT) + + self.save_protocol("MetadataProtocol") + + return protocol + + def _detect_protocol(self, protocols): + """ + Probe protocol endpoints in turn. + """ + self.clear_protocol() + + for retry in range(0, MAX_RETRY): + for protocol in protocols: + try: + if protocol == "WireProtocol": + return self._detect_wire_protocol() + + if protocol == "MetadataProtocol": + return self._detect_metadata_protocol() + + except ProtocolError as e: + logger.info("Protocol endpoint not found: {0}, {1}", + protocol, e) + + if retry < MAX_RETRY -1: + logger.info("Retry detect protocols: retry={0}", retry) + time.sleep(PROBE_INTERVAL) + raise ProtocolNotFoundError("No protocol found.") + + def _get_protocol(self): + """ + Get protocol instance based on previous detecting result. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), + PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + raise ProtocolNotFoundError("No protocol found") + + protocol_name = fileutil.read_file(protocol_file_path) + if protocol_name == "WireProtocol": + endpoint = self._get_wireserver_endpoint() + return WireProtocol(endpoint) + elif protocol_name == "MetadataProtocol": + return MetadataProtocol() + else: + raise ProtocolNotFoundError(("Unknown protocol: {0}" + "").format(protocol_name)) + + def save_protocol(self, protocol_name): + """ + Save protocol endpoint + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + try: + fileutil.write_file(protocol_file_path, protocol_name) + except IOError as e: + logger.error("Failed to save protocol endpoint: {0}", e) + + + def clear_protocol(self): + """ + Cleanup previous saved endpoint. + """ + logger.info("Clean protocol") + self.protocol = None + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + return + + try: + os.remove(protocol_file_path) + except IOError as e: + logger.error("Failed to clear protocol endpoint: {0}", e) + + def get_protocol(self): + """ + Detect protocol by endpoints + + :returns: protocol instance + """ + self.lock.acquire() + + try: + if self.protocol is not None: + return self.protocol + + try: + self.protocol = self._get_protocol() + return self.protocol + except ProtocolNotFoundError: + pass + + logger.info("Detect protocol endpoints") + protocols = ["WireProtocol", "MetadataProtocol"] + self.protocol = self._detect_protocol(protocols) + + return self.protocol + + finally: + self.lock.release() + + + def get_protocol_by_file(self): + """ + Detect protocol by tag file. + + If a file "useMetadataEndpoint.tag" is found on provision iso, + metedata protocol will be used. No need to probe for wire protocol + + :returns: protocol instance + """ + self.lock.acquire() + + try: + if self.protocol is not None: + return self.protocol + + try: + self.protocol = self._get_protocol() + return self.protocol + except ProtocolNotFoundError: + pass + + logger.info("Detect protocol by file") + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + protocols = [] + if os.path.isfile(tag_file_path): + protocols.append("MetadataProtocol") + else: + protocols.append("WireProtocol") + self.protocol = self._detect_protocol(protocols) + return self.protocol + + finally: + self.lock.release() |