diff options
Diffstat (limited to 'azurelinuxagent/protocol/protocolFactory.py')
-rw-r--r-- | azurelinuxagent/protocol/protocolFactory.py | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/azurelinuxagent/protocol/protocolFactory.py b/azurelinuxagent/protocol/protocolFactory.py new file mode 100644 index 0000000..d2ca201 --- /dev/null +++ b/azurelinuxagent/protocol/protocolFactory.py @@ -0,0 +1,114 @@ +# Windows Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import traceback +import threading +import azurelinuxagent.logger as logger +from azurelinuxagent.future import text +import azurelinuxagent.utils.fileutil as fileutil +from azurelinuxagent.utils.osutil import OSUTIL +from azurelinuxagent.protocol.common import * +from azurelinuxagent.protocol.v1 import WireProtocol +from azurelinuxagent.protocol.v2 import MetadataProtocol + +WIRE_SERVER_ADDR_FILE_NAME = "WireServer" + +def get_wire_protocol_endpoint(): + path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME) + try: + endpoint = fileutil.read_file(path) + except IOError as e: + raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e)) + + if endpoint is None: + raise ProtocolNotFound("Wire server endpoint is None") + + return endpoint + +def detect_wire_protocol(): + endpoint = get_wire_protocol_endpoint() + + OSUTIL.gen_transport_cert() + protocol = WireProtocol(endpoint) + protocol.initialize() + logger.info("Protocol V1 found.") + return protocol + +def detect_metadata_protocol(): + protocol = MetadataProtocol() + protocol.initialize() + + logger.info("Protocol V2 found.") + return protocol + +def detect_available_protocols(prob_funcs=[detect_wire_protocol, + detect_metadata_protocol]): + available_protocols = [] + for probe_func in prob_funcs: + try: + protocol = probe_func() + available_protocols.append(protocol) + except ProtocolNotFound as e: + logger.info(text(e)) + return available_protocols + +def detect_default_protocol(): + logger.info("Detect default protocol.") + available_protocols = detect_available_protocols() + return choose_default_protocol(available_protocols) + +def choose_default_protocol(protocols): + if len(protocols) > 0: + return protocols[0] + else: + raise ProtocolNotFound("No available protocol detected.") + +def get_wire_protocol(): + endpoint = get_wire_protocol_endpoint() + return WireProtocol(endpoint) + +def get_metadata_protocol(): + return MetadataProtocol() + +def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]): + available_protocols = [] + for getter in getters: + try: + protocol = getter() + available_protocols.append(protocol) + except ProtocolNotFound as e: + logger.info(text(e)) + return available_protocols + +class ProtocolFactory(object): + def __init__(self): + self._protocol = None + self._lock = threading.Lock() + + def get_default_protocol(self): + if self._protocol is None: + self._lock.acquire() + if self._protocol is None: + available_protocols = get_available_protocols() + self._protocol = choose_default_protocol(available_protocols) + self._lock.release() + + return self._protocol + +FACTORY = ProtocolFactory() |