summaryrefslogtreecommitdiff
path: root/azurelinuxagent/common/protocol/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'azurelinuxagent/common/protocol/util.py')
-rw-r--r--azurelinuxagent/common/protocol/util.py285
1 files changed, 285 insertions, 0 deletions
diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py
new file mode 100644
index 0000000..7e7a74f
--- /dev/null
+++ b/azurelinuxagent/common/protocol/util.py
@@ -0,0 +1,285 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+import os
+import re
+import shutil
+import time
+import threading
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \
+ ProtocolNotFoundError, DhcpError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.osutil import get_osutil
+from azurelinuxagent.common.dhcp import get_dhcp_handler
+from azurelinuxagent.common.protocol.ovfenv import OvfEnv
+from azurelinuxagent.common.protocol.wire import WireProtocol
+from azurelinuxagent.common.protocol.metadata import MetadataProtocol, \
+ METADATA_ENDPOINT
+import azurelinuxagent.common.utils.shellutil as shellutil
+
+OVF_FILE_NAME = "ovf-env.xml"
+
+#Tag file to indicate usage of metadata protocol
+TAG_FILE_NAME = "useMetadataEndpoint.tag"
+
+PROTOCOL_FILE_NAME = "Protocol"
+
+#MAX retry times for protocol probing
+MAX_RETRY = 360
+
+PROBE_INTERVAL = 10
+
+ENDPOINT_FILE_NAME = "WireServerEndpoint"
+
+def get_protocol_util():
+ return ProtocolUtil()
+
+class ProtocolUtil(object):
+ """
+ ProtocolUtil handles initialization for protocol instance. 2 protocol types
+ are invoked, wire protocol and metadata protocols.
+ """
+ def __init__(self):
+ self.lock = threading.Lock()
+ self.protocol = None
+ self.osutil = get_osutil()
+ self.dhcp_handler = get_dhcp_handler()
+
+ def copy_ovf_env(self):
+ """
+ Copy ovf env file from dvd to hard disk.
+ Remove password before save it to the disk
+ """
+ dvd_mount_point = conf.get_dvd_mount_point()
+ ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME)
+ tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME)
+ try:
+ self.osutil.mount_dvd()
+ ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
+ ovfenv = OvfEnv(ovfxml)
+ ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml)
+ ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
+ fileutil.write_file(ovf_file_path, ovfxml)
+
+ if os.path.isfile(tag_file_path_on_dvd):
+ logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME)
+ tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME)
+ shutil.copyfile(tag_file_path_on_dvd, tag_file_path)
+
+ except (OSUtilError, IOError) as e:
+ raise ProtocolError(ustr(e))
+
+ try:
+ self.osutil.umount_dvd()
+ self.osutil.eject_dvd()
+ except OSUtilError as e:
+ logger.warn(ustr(e))
+
+ return ovfenv
+
+ def get_ovf_env(self):
+ """
+ Load saved ovf-env.xml
+ """
+ ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
+ if os.path.isfile(ovf_file_path):
+ xml_text = fileutil.read_file(ovf_file_path)
+ return OvfEnv(xml_text)
+ else:
+ raise ProtocolError("ovf-env.xml is missing.")
+
+ def _get_wireserver_endpoint(self):
+ try:
+ file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME)
+ return fileutil.read_file(file_path)
+ except IOError as e:
+ raise OSUtilError(ustr(e))
+
+ def _set_wireserver_endpoint(self, endpoint):
+ try:
+ file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME)
+ fileutil.write_file(file_path, endpoint)
+ except IOError as e:
+ raise OSUtilError(ustr(e))
+
+ def _detect_wire_protocol(self):
+ endpoint = self.dhcp_handler.endpoint
+ if endpoint is None:
+ logger.info("WireServer endpoint is not found. Rerun dhcp handler")
+ try:
+ self.dhcp_handler.run()
+ except DhcpError as e:
+ raise ProtocolError(ustr(e))
+ endpoint = self.dhcp_handler.endpoint
+
+ try:
+ protocol = WireProtocol(endpoint)
+ protocol.detect()
+ self._set_wireserver_endpoint(endpoint)
+ self.save_protocol("WireProtocol")
+ return protocol
+ except ProtocolError as e:
+ logger.info("WireServer is not responding. Reset endpoint")
+ self.dhcp_handler.endpoint = None
+ self.dhcp_handler.skip_cache = True
+ raise e
+
+ def _detect_metadata_protocol(self):
+ protocol = MetadataProtocol()
+ protocol.detect()
+
+ #Only allow root access METADATA_ENDPOINT
+ self.osutil.set_admin_access_to_ip(METADATA_ENDPOINT)
+
+ self.save_protocol("MetadataProtocol")
+
+ return protocol
+
+ def _detect_protocol(self, protocols):
+ """
+ Probe protocol endpoints in turn.
+ """
+ self.clear_protocol()
+
+ for retry in range(0, MAX_RETRY):
+ for protocol in protocols:
+ try:
+ if protocol == "WireProtocol":
+ return self._detect_wire_protocol()
+
+ if protocol == "MetadataProtocol":
+ return self._detect_metadata_protocol()
+
+ except ProtocolError as e:
+ logger.info("Protocol endpoint not found: {0}, {1}",
+ protocol, e)
+
+ if retry < MAX_RETRY -1:
+ logger.info("Retry detect protocols: retry={0}", retry)
+ time.sleep(PROBE_INTERVAL)
+ raise ProtocolNotFoundError("No protocol found.")
+
+ def _get_protocol(self):
+ """
+ Get protocol instance based on previous detecting result.
+ """
+ protocol_file_path = os.path.join(conf.get_lib_dir(),
+ PROTOCOL_FILE_NAME)
+ if not os.path.isfile(protocol_file_path):
+ raise ProtocolNotFoundError("No protocol found")
+
+ protocol_name = fileutil.read_file(protocol_file_path)
+ if protocol_name == "WireProtocol":
+ endpoint = self._get_wireserver_endpoint()
+ return WireProtocol(endpoint)
+ elif protocol_name == "MetadataProtocol":
+ return MetadataProtocol()
+ else:
+ raise ProtocolNotFoundError(("Unknown protocol: {0}"
+ "").format(protocol_name))
+
+ def save_protocol(self, protocol_name):
+ """
+ Save protocol endpoint
+ """
+ protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME)
+ try:
+ fileutil.write_file(protocol_file_path, protocol_name)
+ except IOError as e:
+ logger.error("Failed to save protocol endpoint: {0}", e)
+
+
+ def clear_protocol(self):
+ """
+ Cleanup previous saved endpoint.
+ """
+ logger.info("Clean protocol")
+ self.protocol = None
+ protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME)
+ if not os.path.isfile(protocol_file_path):
+ return
+
+ try:
+ os.remove(protocol_file_path)
+ except IOError as e:
+ logger.error("Failed to clear protocol endpoint: {0}", e)
+
+ def get_protocol(self):
+ """
+ Detect protocol by endpoints
+
+ :returns: protocol instance
+ """
+ self.lock.acquire()
+
+ try:
+ if self.protocol is not None:
+ return self.protocol
+
+ try:
+ self.protocol = self._get_protocol()
+ return self.protocol
+ except ProtocolNotFoundError:
+ pass
+
+ logger.info("Detect protocol endpoints")
+ protocols = ["WireProtocol", "MetadataProtocol"]
+ self.protocol = self._detect_protocol(protocols)
+
+ return self.protocol
+
+ finally:
+ self.lock.release()
+
+
+ def get_protocol_by_file(self):
+ """
+ Detect protocol by tag file.
+
+ If a file "useMetadataEndpoint.tag" is found on provision iso,
+ metedata protocol will be used. No need to probe for wire protocol
+
+ :returns: protocol instance
+ """
+ self.lock.acquire()
+
+ try:
+ if self.protocol is not None:
+ return self.protocol
+
+ try:
+ self.protocol = self._get_protocol()
+ return self.protocol
+ except ProtocolNotFoundError:
+ pass
+
+ logger.info("Detect protocol by file")
+ tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME)
+ protocols = []
+ if os.path.isfile(tag_file_path):
+ protocols.append("MetadataProtocol")
+ else:
+ protocols.append("WireProtocol")
+ self.protocol = self._detect_protocol(protocols)
+ return self.protocol
+
+ finally:
+ self.lock.release()