From a00729ff7421b3661e8b1a1e0fa46393379f2e96 Mon Sep 17 00:00:00 2001 From: Ben Howard Date: Mon, 8 Feb 2016 16:33:07 -0700 Subject: Import patches-unapplied version 2.1.3-0ubuntu1 to ubuntu/xenial-proposed Imported using git-ubuntu import. Changelog parent: 53f54030cae2de3d5fa474a61fe51f16c7a07c79 New changelog entries: * New upstream release (LP: #1543359): - Bug fixes for extension handling - Feature enablement for AzureStack. --- azurelinuxagent/agent.py | 94 +- azurelinuxagent/conf.py | 115 ++- azurelinuxagent/distro/centos/__init__.py | 19 - azurelinuxagent/distro/centos/loader.py | 25 - azurelinuxagent/distro/coreos/deprovision.py | 3 + azurelinuxagent/distro/coreos/distro.py | 29 + azurelinuxagent/distro/coreos/handlerFactory.py | 27 - azurelinuxagent/distro/coreos/loader.py | 28 - azurelinuxagent/distro/coreos/osutil.py | 5 +- azurelinuxagent/distro/debian/distro.py | 27 + azurelinuxagent/distro/default/daemon.py | 103 ++ azurelinuxagent/distro/default/deprovision.py | 31 +- azurelinuxagent/distro/default/dhcp.py | 194 ++-- azurelinuxagent/distro/default/distro.py | 51 + azurelinuxagent/distro/default/env.py | 41 +- azurelinuxagent/distro/default/extension.py | 943 +++++++++--------- azurelinuxagent/distro/default/handlerFactory.py | 40 - azurelinuxagent/distro/default/init.py | 22 +- azurelinuxagent/distro/default/loader.py | 28 - azurelinuxagent/distro/default/monitor.py | 182 ++++ azurelinuxagent/distro/default/osutil.py | 157 ++- azurelinuxagent/distro/default/protocolUtil.py | 243 +++++ azurelinuxagent/distro/default/provision.py | 158 +-- azurelinuxagent/distro/default/resourceDisk.py | 21 +- azurelinuxagent/distro/default/run.py | 71 -- azurelinuxagent/distro/default/scvmm.py | 11 +- azurelinuxagent/distro/loader.py | 71 +- azurelinuxagent/distro/oracle/__init__.py | 19 - azurelinuxagent/distro/oracle/loader.py | 25 - azurelinuxagent/distro/redhat/distro.py | 32 + azurelinuxagent/distro/redhat/loader.py | 28 - azurelinuxagent/distro/redhat/osutil.py | 78 +- azurelinuxagent/distro/suse/distro.py | 32 + azurelinuxagent/distro/suse/loader.py | 29 - azurelinuxagent/distro/ubuntu/deprovision.py | 3 + azurelinuxagent/distro/ubuntu/distro.py | 55 ++ azurelinuxagent/distro/ubuntu/handlerFactory.py | 29 - azurelinuxagent/distro/ubuntu/loader.py | 40 - azurelinuxagent/distro/ubuntu/osutil.py | 12 +- azurelinuxagent/distro/ubuntu/provision.py | 71 +- azurelinuxagent/event.py | 171 +--- azurelinuxagent/exception.py | 74 +- azurelinuxagent/future.py | 14 +- azurelinuxagent/handler.py | 28 - azurelinuxagent/logger.py | 8 +- azurelinuxagent/metadata.py | 6 +- azurelinuxagent/protocol/__init__.py | 5 - azurelinuxagent/protocol/common.py | 240 ----- azurelinuxagent/protocol/metadata.py | 195 ++++ azurelinuxagent/protocol/ovfenv.py | 46 +- azurelinuxagent/protocol/protocolFactory.py | 114 --- azurelinuxagent/protocol/restapi.py | 250 +++++ azurelinuxagent/protocol/v1.py | 1121 --------------------- azurelinuxagent/protocol/v2.py | 145 --- azurelinuxagent/protocol/wire.py | 1155 ++++++++++++++++++++++ azurelinuxagent/utils/cryptutil.py | 121 +++ azurelinuxagent/utils/fileutil.py | 7 +- azurelinuxagent/utils/osutil.py | 27 - azurelinuxagent/utils/restutil.py | 22 +- azurelinuxagent/utils/shellutil.py | 6 +- azurelinuxagent/utils/textutil.py | 8 + 61 files changed, 3668 insertions(+), 3287 deletions(-) delete mode 100644 azurelinuxagent/distro/centos/__init__.py delete mode 100644 azurelinuxagent/distro/centos/loader.py create mode 100644 azurelinuxagent/distro/coreos/distro.py delete mode 100644 azurelinuxagent/distro/coreos/handlerFactory.py delete mode 100644 azurelinuxagent/distro/coreos/loader.py create mode 100644 azurelinuxagent/distro/debian/distro.py create mode 100644 azurelinuxagent/distro/default/daemon.py create mode 100644 azurelinuxagent/distro/default/distro.py delete mode 100644 azurelinuxagent/distro/default/handlerFactory.py delete mode 100644 azurelinuxagent/distro/default/loader.py create mode 100644 azurelinuxagent/distro/default/monitor.py create mode 100644 azurelinuxagent/distro/default/protocolUtil.py delete mode 100644 azurelinuxagent/distro/default/run.py delete mode 100644 azurelinuxagent/distro/oracle/__init__.py delete mode 100644 azurelinuxagent/distro/oracle/loader.py create mode 100644 azurelinuxagent/distro/redhat/distro.py delete mode 100644 azurelinuxagent/distro/redhat/loader.py create mode 100644 azurelinuxagent/distro/suse/distro.py delete mode 100644 azurelinuxagent/distro/suse/loader.py create mode 100644 azurelinuxagent/distro/ubuntu/distro.py delete mode 100644 azurelinuxagent/distro/ubuntu/handlerFactory.py delete mode 100644 azurelinuxagent/distro/ubuntu/loader.py delete mode 100644 azurelinuxagent/handler.py delete mode 100644 azurelinuxagent/protocol/common.py create mode 100644 azurelinuxagent/protocol/metadata.py delete mode 100644 azurelinuxagent/protocol/protocolFactory.py create mode 100644 azurelinuxagent/protocol/restapi.py delete mode 100644 azurelinuxagent/protocol/v1.py delete mode 100644 azurelinuxagent/protocol/v2.py create mode 100644 azurelinuxagent/protocol/wire.py create mode 100644 azurelinuxagent/utils/cryptutil.py delete mode 100644 azurelinuxagent/utils/osutil.py (limited to 'azurelinuxagent') diff --git a/azurelinuxagent/agent.py b/azurelinuxagent/agent.py index 849a192..93e9c16 100644 --- a/azurelinuxagent/agent.py +++ b/azurelinuxagent/agent.py @@ -29,27 +29,60 @@ from azurelinuxagent.metadata import AGENT_NAME, AGENT_LONG_VERSION, \ DISTRO_NAME, DISTRO_VERSION, \ PY_VERSION_MAJOR, PY_VERSION_MINOR, \ PY_VERSION_MICRO -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.handler import HANDLERS +from azurelinuxagent.distro.loader import get_distro -def init(verbose): - """ - Initialize agent running environment. - """ - HANDLERS.init_handler.init(verbose) +class Agent(object): + def __init__(self, verbose): + """ + Initialize agent running environment. + """ + self.distro = get_distro(); + self.distro.init_handler.run(verbose) -def run(): - """ - Run agent daemon - """ - HANDLERS.main_handler.run() + def daemon(self): + """ + Run agent daemon + """ + self.distro.daemon_handler.run() + + def deprovision(self, force=False, deluser=False): + """ + Run deprovision command + """ + self.distro.deprovision_handler.run(force=force, deluser=deluser) -def deprovision(force=False, deluser=False): + def register_service(self): + """ + Register agent as a service + """ + print("Register {0} service".format(AGENT_NAME)) + self.distro.osutil.register_agent_service() + print("Start {0} service".format(AGENT_NAME)) + self.distro.osutil.start_agent_service() + +def main(): """ - Run deprovision command + Parse command line arguments, exit with usage() on error. + Invoke different methods according to different command """ - HANDLERS.deprovision_handler.deprovision(force=force, deluser=deluser) + command, force, verbose = parse_args(sys.argv[1:]) + if command == "version": + version() + elif command == "help": + usage() + elif command == "start": + start() + else: + agent = Agent(verbose) + if command == "deprovision+user": + agent.deprovision(force, deluser=True) + elif command == "deprovision": + agent.deprovision(force, deluser=False) + elif command == "register-service": + agent.register_service() + elif command == "daemon": + agent.daemon() def parse_args(sys_args): """ @@ -108,34 +141,3 @@ def start(): devnull = open(os.devnull, 'w') subprocess.Popen([sys.argv[0], '-daemon'], stdout=devnull, stderr=devnull) -def register_service(): - """ - Register agent as a service - """ - print("Register {0} service".format(AGENT_NAME)) - OSUTIL.register_agent_service() - print("Start {0} service".format(AGENT_NAME)) - OSUTIL.start_agent_service() - -def main(): - """ - Parse command line arguments, exit with usage() on error. - Invoke different methods according to different command - """ - command, force, verbose = parse_args(sys.argv[1:]) - if command == "version": - version() - elif command == "help": - usage() - else: - init(verbose) - if command == "deprovision+user": - deprovision(force, deluser=True) - elif command == "deprovision": - deprovision(force, deluser=False) - elif command == "start": - start() - elif command == "register-service": - register_service() - elif command == "daemon": - run() diff --git a/azurelinuxagent/conf.py b/azurelinuxagent/conf.py index 2b0eb01..7921e79 100644 --- a/azurelinuxagent/conf.py +++ b/azurelinuxagent/conf.py @@ -43,11 +43,11 @@ class ConfigurationProvider(object): else: self.values[parts[0]] = None - def get(self, key, default_val=None): + def get(self, key, default_val): val = self.values.get(key) return val if val is not None else default_val - def get_switch(self, key, default_val=False): + def get_switch(self, key, default_val): val = self.values.get(key) if val is not None and val.lower() == 'y': return True @@ -55,7 +55,7 @@ class ConfigurationProvider(object): return False return default_val - def get_int(self, key, default_val=-1): + def get_int(self, key, default_val): try: return int(self.values.get(key)) except TypeError: @@ -64,9 +64,9 @@ class ConfigurationProvider(object): return default_val -__config__ = ConfigurationProvider() +__conf__ = ConfigurationProvider() -def load_conf(conf_file_path, conf=__config__): +def load_conf_from_file(conf_file_path, conf=__conf__): """ Load conf file from: conf_file_path """ @@ -80,30 +80,87 @@ def load_conf(conf_file_path, conf=__config__): raise AgentConfigError(("Failed to load conf file:{0}, {1}" "").format(conf_file_path, err)) -def get(key, default_val=None, conf=__config__): - """ - Get option value by key, return default_val if not found - """ - if conf is not None: - return conf.get(key, default_val) - else: - return default_val +def get_logs_verbose(conf=__conf__): + return conf.get_switch("Logs.Verbose", False) -def get_switch(key, default_val=None, conf=__config__): - """ - Get bool option value by key, return default_val if not found - """ - if conf is not None: - return conf.get_switch(key, default_val) - else: - return default_val +def get_lib_dir(conf=__conf__): + return conf.get("Lib.Dir", "/var/lib/waagent") -def get_int(key, default_val=None, conf=__config__): - """ - Get int option value by key, return default_val if not found - """ - if conf is not None: - return conf.get_int(key, default_val) - else: - return default_val +def get_dvd_mount_point(conf=__conf__): + return conf.get("DVD.MountPoint", "/mnt/cdrom/secure") + +def get_agent_pid_file_path(conf=__conf__): + return conf.get("Pid.File", "/var/run/waagent.pid") + +def get_ext_log_dir(conf=__conf__): + return conf.get("Extension.LogDir", "/var/log/azure") + +def get_openssl_cmd(conf=__conf__): + return conf.get("OS.OpensslPath", "/usr/bin/openssl") + +def get_home_dir(conf=__conf__): + return conf.get("OS.HomeDir", "/home") + +def get_passwd_file_path(conf=__conf__): + return conf.get("OS.PasswordPath", "/etc/shadow") + +def get_sshd_conf_file_path(conf=__conf__): + return conf.get("OS.SshdConfigPath", "/etc/ssh/sshd_config") + +def get_root_device_scsi_timeout(conf=__conf__): + return conf.get("OS.RootDeviceScsiTimeout", None) + +def get_ssh_host_keypair_type(conf=__conf__): + return conf.get("Provisioning.SshHostKeyPairType", "rsa") + +def get_provision_enabled(conf=__conf__): + return conf.get_switch("Provisioning.Enabled", True) + +def get_allow_reset_sys_user(conf=__conf__): + return conf.get_switch("Provisioning.AllowResetSysUser", False) + +def get_regenerate_ssh_host_key(conf=__conf__): + return conf.get_switch("Provisioning.RegenerateSshHostKeyPair", False) + +def get_delete_root_password(conf=__conf__): + return conf.get_switch("Provisioning.DeleteRootPassword", False) + +def get_decode_customdata(conf=__conf__): + return conf.get_switch("Provisioning.DecodeCustomData", False) + +def get_execute_customdata(conf=__conf__): + return conf.get_switch("Provisioning.ExecuteCustomData", False) + +def get_password_cryptid(conf=__conf__): + return conf.get("Provisioning.PasswordCryptId", "6") + +def get_password_crypt_salt_len(conf=__conf__): + return conf.get_int("Provisioning.PasswordCryptSaltLength", 10) + +def get_monitor_hostname(conf=__conf__): + return conf.get_switch("Provisioning.MonitorHostName", False) + +def get_httpproxy_host(conf=__conf__): + return conf.get("HttpProxy.Host", None) + +def get_httpproxy_port(conf=__conf__): + return conf.get("HttpProxy.Port", None) + +def get_detect_scvmm_env(conf=__conf__): + return conf.get_switch("DetectScvmmEnv", False) + +def get_resourcedisk_format(conf=__conf__): + return conf.get_switch("ResourceDisk.Format", False) + +def get_resourcedisk_enable_swap(conf=__conf__): + return conf.get_switch("ResourceDisk.EnableSwap", False) + +def get_resourcedisk_mountpoint(conf=__conf__): + return conf.get("ResourceDisk.MountPoint", "/mnt/resource") + +def get_resourcedisk_filesystem(conf=__conf__): + return conf.get("ResourceDisk.Filesystem", "ext3") + +def get_resourcedisk_swap_size_mb(conf=__conf__): + return conf.get_int("ResourceDisk.SwapSizeMB", 0) diff --git a/azurelinuxagent/distro/centos/__init__.py b/azurelinuxagent/distro/centos/__init__.py deleted file mode 100644 index d9b82f5..0000000 --- a/azurelinuxagent/distro/centos/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - diff --git a/azurelinuxagent/distro/centos/loader.py b/azurelinuxagent/distro/centos/loader.py deleted file mode 100644 index 9dc428f..0000000 --- a/azurelinuxagent/distro/centos/loader.py +++ /dev/null @@ -1,25 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION -import azurelinuxagent.distro.redhat.loader as redhat - -def get_osutil(): - return redhat.get_osutil() - diff --git a/azurelinuxagent/distro/coreos/deprovision.py b/azurelinuxagent/distro/coreos/deprovision.py index 99d3a40..9642579 100644 --- a/azurelinuxagent/distro/coreos/deprovision.py +++ b/azurelinuxagent/distro/coreos/deprovision.py @@ -21,6 +21,9 @@ import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.distro.default.deprovision import DeprovisionHandler, DeprovisionAction class CoreOSDeprovisionHandler(DeprovisionHandler): + def __init__(self, distro): + self.distro = distro + def setup(self, deluser): warnings, actions = super(CoreOSDeprovisionHandler, self).setup(deluser) warnings.append("WARNING! /etc/machine-id will be removed.") diff --git a/azurelinuxagent/distro/coreos/distro.py b/azurelinuxagent/distro/coreos/distro.py new file mode 100644 index 0000000..04c7bff --- /dev/null +++ b/azurelinuxagent/distro/coreos/distro.py @@ -0,0 +1,29 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.coreos.osutil import CoreOSUtil +from azurelinuxagent.distro.coreos.deprovision import CoreOSDeprovisionHandler + +class CoreOSDistro(DefaultDistro): + def __init__(self): + super(CoreOSDistro, self).__init__() + self.osutil = CoreOSUtil() + self.deprovision_handler = CoreOSDeprovisionHandler(self) + diff --git a/azurelinuxagent/distro/coreos/handlerFactory.py b/azurelinuxagent/distro/coreos/handlerFactory.py deleted file mode 100644 index 58f476c..0000000 --- a/azurelinuxagent/distro/coreos/handlerFactory.py +++ /dev/null @@ -1,27 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from .deprovision import CoreOSDeprovisionHandler -from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory - -class CoreOSHandlerFactory(DefaultHandlerFactory): - def __init__(self): - super(CoreOSHandlerFactory, self).__init__() - self.deprovision_handler = CoreOSDeprovisionHandler() - diff --git a/azurelinuxagent/distro/coreos/loader.py b/azurelinuxagent/distro/coreos/loader.py deleted file mode 100644 index 802f276..0000000 --- a/azurelinuxagent/distro/coreos/loader.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - - -def get_osutil(): - from azurelinuxagent.distro.coreos.osutil import CoreOSUtil - return CoreOSUtil() - -def get_handlers(): - from azurelinuxagent.distro.coreos.handlerFactory import CoreOSHandlerFactory - return CoreOSHandlerFactory() - diff --git a/azurelinuxagent/distro/coreos/osutil.py b/azurelinuxagent/distro/coreos/osutil.py index c244311..ffc83e3 100644 --- a/azurelinuxagent/distro/coreos/osutil.py +++ b/azurelinuxagent/distro/coreos/osutil.py @@ -35,9 +35,9 @@ from azurelinuxagent.distro.default.osutil import DefaultOSUtil class CoreOSUtil(DefaultOSUtil): def __init__(self): super(CoreOSUtil, self).__init__() + self.agent_conf_file_path = '/usr/share/oem/waagent.conf' self.waagent_path='/usr/share/oem/bin/waagent' self.python_path='/usr/share/oem/python/bin' - self.conf_file_path = '/usr/share/oem/waagent.conf' if 'PATH' in os.environ: path = "{0}:{1}".format(os.environ['PATH'], self.python_path) else: @@ -85,9 +85,6 @@ class CoreOSUtil(DefaultOSUtil): ret= shellutil.run_get_output("pidof systemd-networkd") return ret[1] if ret[0] == 0 else None - def decode_customdata(self, data): - return base64.b64decode(data) - def set_ssh_client_alive_interval(self): #In CoreOS, /etc/sshd_config is mount readonly. Skip the setting pass diff --git a/azurelinuxagent/distro/debian/distro.py b/azurelinuxagent/distro/debian/distro.py new file mode 100644 index 0000000..01f4e3e --- /dev/null +++ b/azurelinuxagent/distro/debian/distro.py @@ -0,0 +1,27 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.debian.osutil import DebianOSUtil + +class DebianDistro(DefaultDistro): + def __init__(self): + super(DebianDistro, self).__init__() + self.osutil = DebianOSUtil() + diff --git a/azurelinuxagent/distro/default/daemon.py b/azurelinuxagent/distro/default/daemon.py new file mode 100644 index 0000000..cf9eb16 --- /dev/null +++ b/azurelinuxagent/distro/default/daemon.py @@ -0,0 +1,103 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import time +import sys +import traceback +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.future import ustr +from azurelinuxagent.event import add_event, WALAEventOperation +from azurelinuxagent.exception import ProtocolError +from azurelinuxagent.metadata import AGENT_LONG_NAME, AGENT_VERSION, \ + DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_FULL_NAME, PY_VERSION_MAJOR, \ + PY_VERSION_MINOR, PY_VERSION_MICRO +import azurelinuxagent.event as event +import azurelinuxagent.utils.fileutil as fileutil + + +class DaemonHandler(object): + def __init__(self, distro): + self.distro = distro + self.running = True + + + def run(self): + logger.info("{0} Version:{1}", AGENT_LONG_NAME, AGENT_VERSION) + logger.info("OS: {0} {1}", DISTRO_NAME, DISTRO_VERSION) + logger.info("Python: {0}.{1}.{2}", PY_VERSION_MAJOR, PY_VERSION_MINOR, + PY_VERSION_MICRO) + + self.check_pid() + + while self.running: + try: + self.daemon() + except Exception as e: + err_msg = traceback.format_exc() + add_event("WALA", is_success=False, message=ustr(err_msg), + op=WALAEventOperation.UnhandledError) + logger.info("Sleep 15 seconds and restart daemon") + time.sleep(15) + + def check_pid(self): + """Check whether daemon is already running""" + pid = None + pid_file = conf.get_agent_pid_file_path() + if os.path.isfile(pid_file): + pid = fileutil.read_file(pid_file) + + if pid is not None and os.path.isdir(os.path.join("/proc", pid)): + logger.info("Daemon is already running: {0}", pid) + sys.exit(0) + + fileutil.write_file(pid_file, ustr(os.getpid())) + + def daemon(self): + logger.info("Run daemon") + #Create lib dir + if not os.path.isdir(conf.get_lib_dir()): + fileutil.mkdir(conf.get_lib_dir(), mode=0o700) + os.chdir(conf.get_lib_dir()) + + if conf.get_detect_scvmm_env(): + if self.distro.scvmm_handler.run(): + return + + self.distro.provision_handler.run() + + if conf.get_resourcedisk_format(): + self.distro.resource_disk_handler.run() + + try: + protocol = self.distro.protocol_util.detect_protocol() + except ProtocolError as e: + logger.error("Failed to detect protocol, exit", e) + return + + self.distro.event_handler.run() + self.distro.env_handler.run() + + while self.running: + #Handle extensions + self.distro.ext_handlers_handler.run() + time.sleep(25) + diff --git a/azurelinuxagent/distro/default/deprovision.py b/azurelinuxagent/distro/default/deprovision.py index b62c5f6..4db4cdc 100644 --- a/azurelinuxagent/distro/default/deprovision.py +++ b/azurelinuxagent/distro/default/deprovision.py @@ -18,10 +18,8 @@ # import azurelinuxagent.conf as conf -from azurelinuxagent.utils.osutil import OSUTIL +from azurelinuxagent.exception import ProtocolError from azurelinuxagent.future import read_input -import azurelinuxagent.protocol as prot -import azurelinuxagent.protocol.ovfenv as ovf import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil @@ -35,18 +33,20 @@ class DeprovisionAction(object): self.func(*self.args, **self.kwargs) class DeprovisionHandler(object): + def __init__(self, distro): + self.distro = distro def del_root_password(self, warnings, actions): warnings.append("WARNING! root password will be disabled. " "You will not be able to login as root.") - actions.append(DeprovisionAction(OSUTIL.del_root_password)) + actions.append(DeprovisionAction(self.distro.osutil.del_root_password)) def del_user(self, warnings, actions): try: - ovfenv = ovf.get_ovf_env() - except prot.ProtocolError: + ovfenv = self.distro.protocol_util.get_ovf_env() + except ProtocolError: warnings.append("WARNING! ovf-env.xml is not found.") warnings.append("WARNING! Skip delete user.") return @@ -54,7 +54,8 @@ class DeprovisionHandler(object): username = ovfenv.username warnings.append(("WARNING! {0} account and entire home directory " "will be deleted.").format(username)) - actions.append(DeprovisionAction(OSUTIL.del_account, [username])) + actions.append(DeprovisionAction(self.distro.osutil.del_account, + [username])) def regen_ssh_host_key(self, warnings, actions): @@ -64,7 +65,7 @@ class DeprovisionHandler(object): def stop_agent_service(self, warnings, actions): warnings.append("WARNING! The waagent service will be stopped.") - actions.append(DeprovisionAction(OSUTIL.stop_agent_service)) + actions.append(DeprovisionAction(self.distro.osutil.stop_agent_service)) def del_files(self, warnings, actions): files_to_del = ['/root/.bash_history', '/var/log/waagent.log'] @@ -76,26 +77,28 @@ class DeprovisionHandler(object): actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del)) def del_lib_dir(self, warnings, actions): - dirs_to_del = [OSUTIL.get_lib_dir()] + dirs_to_del = [conf.get_lib_dir()] actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del)) def reset_hostname(self, warnings, actions): localhost = ["localhost.localdomain"] - actions.append(DeprovisionAction(OSUTIL.set_hostname, localhost)) - actions.append(DeprovisionAction(OSUTIL.set_dhcp_hostname, localhost)) + actions.append(DeprovisionAction(self.distro.osutil.set_hostname, + localhost)) + actions.append(DeprovisionAction(self.distro.osutil.set_dhcp_hostname, + localhost)) def setup(self, deluser): warnings = [] actions = [] self.stop_agent_service(warnings, actions) - if conf.get_switch("Provisioning.RegenerateSshHostkey", False): + if conf.get_regenerate_ssh_host_key(): self.regen_ssh_host_key(warnings, actions) self.del_dhcp_lease(warnings, actions) self.reset_hostname(warnings, actions) - if conf.get_switch("Provisioning.DeleteRootPassword", False): + if conf.get_delete_root_password(): self.del_root_password(warnings, actions) self.del_lib_dir(warnings, actions) @@ -106,7 +109,7 @@ class DeprovisionHandler(object): return warnings, actions - def deprovision(self, force=False, deluser=False): + def run(self, force=False, deluser=False): warnings, actions = self.setup(deluser) for warning in warnings: print(warning) diff --git a/azurelinuxagent/distro/default/dhcp.py b/azurelinuxagent/distro/default/dhcp.py index 4fd23ef..fc439d2 100644 --- a/azurelinuxagent/distro/default/dhcp.py +++ b/azurelinuxagent/distro/default/dhcp.py @@ -19,61 +19,106 @@ import os import socket import array import time +import threading import azurelinuxagent.logger as logger -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.exception import AgentNetworkError +import azurelinuxagent.conf as conf import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil -from azurelinuxagent.utils.textutil import * +from azurelinuxagent.utils.textutil import hex_dump, hex_dump2, hex_dump3, \ + compare_bytes, str_to_ord, \ + unpack_big_endian, \ + unpack_little_endian, \ + int_to_ip4_addr +from azurelinuxagent.exception import DhcpError -WIRE_SERVER_ADDR_FILE_NAME="WireServer" class DhcpHandler(object): - def __init__(self): + """ + Azure use DHCP option 245 to pass endpoint ip to VMs. + """ + def __init__(self, distro): + self.distro = distro self.endpoint = None self.gateway = None self.routes = None + def run(self): + """ + Send dhcp request + Configure default gateway and routes + Save wire server endpoint if found + """ + self.send_dhcp_req() + self.conf_routes() + def wait_for_network(self): - ipv4 = OSUTIL.get_ip4_addr() + """ + Wait for network stack to be initialized. + """ + ipv4 = self.distro.osutil.get_ip4_addr() while ipv4 == '' or ipv4 == '0.0.0.0': logger.info("Waiting for network.") time.sleep(10) - OSUTIL.start_network() - ipv4 = OSUTIL.get_ip4_addr() - - def probe(self): - logger.info("Send dhcp request") - self.wait_for_network() - mac_addr = OSUTIL.get_mac_addr() - req = build_dhcp_request(mac_addr) - resp = send_dhcp_request(req) - if resp is None: - logger.warn("Failed to detect wire server.") - return - endpoint, gateway, routes = parse_dhcp_resp(resp) - self.endpoint = endpoint - logger.info("Wire server endpoint:{0}", endpoint) - logger.info("Gateway:{0}", gateway) - logger.info("Routes:{0}", routes) - if endpoint is not None: - path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME) - fileutil.write_file(path, endpoint) - self.gateway = gateway - self.routes = routes - self.conf_routes() - - def get_endpoint(self): - return self.endpoint + logger.info("Try to start network interface.") + self.distro.osutil.start_network() + ipv4 = self.distro.osutil.get_ip4_addr() def conf_routes(self): logger.info("Configure routes") + logger.info("Gateway:{0}", self.gateway) + logger.info("Routes:{0}", self.routes) #Add default gateway if self.gateway is not None: - OSUTIL.route_add(0 , 0, self.gateway) + self.distro.osutil.route_add(0 , 0, self.gateway) if self.routes is not None: for route in self.routes: - OSUTIL.route_add(route[0], route[1], route[2]) + self.distro.osutil.route_add(route[0], route[1], route[2]) + + def _send_dhcp_req(self, request): + __waiting_duration__ = [0, 10, 30, 60, 60] + for duration in __waiting_duration__: + try: + self.distro.osutil.allow_dhcp_broadcast() + response = socket_send(request) + validate_dhcp_resp(request, response) + return response + except DhcpError as e: + logger.warn("Failed to send DHCP request: {0}", e) + time.sleep(duration) + return None + + def send_dhcp_req(self): + """ + Build dhcp request with mac addr + Configure route to allow dhcp traffic + Stop dhcp service if necessary + """ + logger.info("Send dhcp request") + mac_addr = self.distro.osutil.get_mac_addr() + req = build_dhcp_request(mac_addr) + + # Temporary allow broadcast for dhcp. Remove the route when done. + missing_default_route = self.distro.osutil.is_missing_default_route() + ifname = self.distro.osutil.get_if_name() + if missing_default_route: + self.distro.osutil.set_route_for_dhcp_broadcast(ifname) + + # In some distros, dhcp service needs to be shutdown before agent probe + # endpoint through dhcp. + if self.distro.osutil.is_dhcp_enabled(): + self.distro.osutil.stop_dhcp_service() + + resp = self._send_dhcp_req(req) + + if self.distro.osutil.is_dhcp_enabled(): + self.distro.osutil.start_dhcp_service() + + if missing_default_route: + self.distro.osutil.remove_route_for_dhcp_broadcast(ifname) + + if resp is None: + raise DhcpError("Failed to receive dhcp response.") + self.endpoint, self.gateway, self.routes = parse_dhcp_resp(resp) def validate_dhcp_resp(request, response): bytes_recv = len(response) @@ -92,28 +137,25 @@ def validate_dhcp_resp(request, response): logger.verb("Cookie not match:\nsend={0},\nreceive={1}", hex_dump3(request, 0xEC, 4), hex_dump3(response, 0xEC, 4)) - raise AgentNetworkError("Cookie in dhcp respones " - "doesn't match the request") + raise DhcpError("Cookie in dhcp respones doesn't match the request") if not compare_bytes(request, response, 4, 4): logger.verb("TransactionID not match:\nsend={0},\nreceive={1}", hex_dump3(request, 4, 4), hex_dump3(response, 4, 4)) - raise AgentNetworkError("TransactionID in dhcp respones " - "doesn't match the request") + raise DhcpError("TransactionID in dhcp respones " + "doesn't match the request") if not compare_bytes(request, response, 0x1C, 6): logger.verb("Mac Address not match:\nsend={0},\nreceive={1}", hex_dump3(request, 0x1C, 6), hex_dump3(response, 0x1C, 6)) - raise AgentNetworkError("Mac Addr in dhcp respones " - "doesn't match the request") + raise DhcpError("Mac Addr in dhcp respones " + "doesn't match the request") def parse_route(response, option, i, length, bytes_recv): # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx - logger.verb("Routes at offset: {0} with length:{1}", - hex(i), - hex(length)) + logger.verb("Routes at offset: {0} with length:{1}", hex(i), hex(length)) routes = [] if length < 5: logger.error("Data too small for option:{0}", option) @@ -169,9 +211,7 @@ def parse_dhcp_resp(response): if (i + 1) < bytes_recv: length = str_to_ord(response[i + 1]) logger.verb("DHCP option {0} at offset:{1} with length:{2}", - hex(option), - hex(i), - hex(length)) + hex(option), hex(i), hex(length)) if option == 255: logger.verb("DHCP packet ended at offset:{0}", hex(i)) break @@ -179,69 +219,17 @@ def parse_dhcp_resp(response): routes = parse_route(response, option, i, length, bytes_recv) elif option == 3: gateway = parse_ip_addr(response, option, i, length, bytes_recv) - logger.verb("Default gateway:{0}, at {1}", - gateway, - hex(i)) + logger.verb("Default gateway:{0}, at {1}", gateway, hex(i)) elif option == 245: endpoint = parse_ip_addr(response, option, i, length, bytes_recv) - logger.verb("Azure wire protocol endpoint:{0}, at {1}", - gateway, - hex(i)) + logger.verb("Azure wire protocol endpoint:{0}, at {1}", gateway, + hex(i)) else: logger.verb("Skipping DHCP option:{0} at {1} with length {2}", - hex(option), - hex(i), - hex(length)) + hex(option), hex(i), hex(length)) i += length + 2 return endpoint, gateway, routes - -def allow_dhcp_broadcast(func): - """ - Temporary allow broadcase for dhcp. Remove the route when done. - """ - def wrapper(*args, **kwargs): - missing_default_route = OSUTIL.is_missing_default_route() - ifname = OSUTIL.get_if_name() - if missing_default_route: - OSUTIL.set_route_for_dhcp_broadcast(ifname) - result = func(*args, **kwargs) - if missing_default_route: - OSUTIL.remove_route_for_dhcp_broadcast(ifname) - return result - return wrapper - -def disable_dhcp_service(func): - """ - In some distros, dhcp service needs to be shutdown before agent probe - endpoint through dhcp. - """ - def wrapper(*args, **kwargs): - if OSUTIL.is_dhcp_enabled(): - OSUTIL.stop_dhcp_service() - result = func(*args, **kwargs) - OSUTIL.start_dhcp_service() - return result - else: - return func(*args, **kwargs) - return wrapper - - -@allow_dhcp_broadcast -@disable_dhcp_service -def send_dhcp_request(request): - __waiting_duration__ = [0, 10, 30, 60, 60] - for duration in __waiting_duration__: - try: - OSUTIL.allow_dhcp_broadcast() - response = socket_send(request) - validate_dhcp_resp(request, response) - return response - except AgentNetworkError as e: - logger.warn("Failed to send DHCP request: {0}", e) - time.sleep(duration) - return None - def socket_send(request): sock = None try: @@ -257,7 +245,7 @@ def socket_send(request): response = sock.recv(1024) return response except IOError as e: - raise AgentNetworkError("{0}".format(e)) + raise DhcpError("{0}".format(e)) finally: if sock is not None: sock.close() diff --git a/azurelinuxagent/distro/default/distro.py b/azurelinuxagent/distro/default/distro.py new file mode 100644 index 0000000..ca0d77e --- /dev/null +++ b/azurelinuxagent/distro/default/distro.py @@ -0,0 +1,51 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.conf import ConfigurationProvider +from azurelinuxagent.distro.default.osutil import DefaultOSUtil +from azurelinuxagent.distro.default.daemon import DaemonHandler +from azurelinuxagent.distro.default.init import InitHandler +from azurelinuxagent.distro.default.monitor import MonitorHandler +from azurelinuxagent.distro.default.dhcp import DhcpHandler +from azurelinuxagent.distro.default.protocolUtil import ProtocolUtil +from azurelinuxagent.distro.default.scvmm import ScvmmHandler +from azurelinuxagent.distro.default.env import EnvHandler +from azurelinuxagent.distro.default.provision import ProvisionHandler +from azurelinuxagent.distro.default.resourceDisk import ResourceDiskHandler +from azurelinuxagent.distro.default.extension import ExtHandlersHandler +from azurelinuxagent.distro.default.deprovision import DeprovisionHandler + +class DefaultDistro(object): + """ + """ + def __init__(self): + self.osutil = DefaultOSUtil() + self.protocol_util = ProtocolUtil(self) + + self.init_handler = InitHandler(self) + self.daemon_handler = DaemonHandler(self) + self.event_handler = MonitorHandler(self) + self.dhcp_handler = DhcpHandler(self) + self.scvmm_handler = ScvmmHandler(self) + self.env_handler = EnvHandler(self) + self.provision_handler = ProvisionHandler(self) + self.resource_disk_handler = ResourceDiskHandler(self) + self.ext_handlers_handler = ExtHandlersHandler(self) + self.deprovision_handler = DeprovisionHandler(self) + diff --git a/azurelinuxagent/distro/default/env.py b/azurelinuxagent/distro/default/env.py index 28bf718..7878cff 100644 --- a/azurelinuxagent/distro/default/env.py +++ b/azurelinuxagent/distro/default/env.py @@ -23,7 +23,6 @@ import threading import time import azurelinuxagent.logger as logger import azurelinuxagent.conf as conf -from azurelinuxagent.utils.osutil import OSUTIL class EnvHandler(object): """ @@ -31,35 +30,25 @@ class EnvHandler(object): If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric. Monitor scsi disk. - If new scsi disk found, set + If new scsi disk found, set timeout """ - def __init__(self, handlers): - self.monitor = EnvMonitor(handlers.dhcp_handler) - - def start(self): - self.monitor.start() - - def stop(self): - self.monitor.stop() - -class EnvMonitor(object): - - def __init__(self, dhcp_handler): - self.dhcp_handler = dhcp_handler + def __init__(self, distro): + self.distro = distro self.stopped = True self.hostname = None self.dhcpid = None self.server_thread=None - def start(self): + def run(self): if not self.stopped: logger.info("Stop existing env monitor service.") self.stop() self.stopped = False logger.info("Start env monitor service.") + self.distro.dhcp_handler.conf_routes() self.hostname = socket.gethostname() - self.dhcpid = OSUTIL.get_dhcp_pid() + self.dhcpid = self.distro.osutil.get_dhcp_pid() self.server_thread = threading.Thread(target = self.monitor) self.server_thread.setDaemon(True) self.server_thread.start() @@ -70,11 +59,11 @@ class EnvMonitor(object): If dhcp clinet process re-start has occurred, reset routes. """ while not self.stopped: - OSUTIL.remove_rules_files() - timeout = conf.get("OS.RootDeviceScsiTimeout", None) + self.distro.osutil.remove_rules_files() + timeout = conf.get_root_device_scsi_timeout() if timeout is not None: - OSUTIL.set_scsi_disks_timeout(timeout) - if conf.get_switch("Provisioning.MonitorHostName", False): + self.distro.osutil.set_scsi_disks_timeout(timeout) + if conf.get_monitor_hostname(): self.handle_hostname_update() self.handle_dhclient_restart() time.sleep(5) @@ -84,25 +73,25 @@ class EnvMonitor(object): if curr_hostname != self.hostname: logger.info("EnvMonitor: Detected host name change: {0} -> {1}", self.hostname, curr_hostname) - OSUTIL.set_hostname(curr_hostname) - OSUTIL.publish_hostname(curr_hostname) + self.distro.osutil.set_hostname(curr_hostname) + self.distro.osutil.publish_hostname(curr_hostname) self.hostname = curr_hostname def handle_dhclient_restart(self): if self.dhcpid is None: logger.warn("Dhcp client is not running. ") - self.dhcpid = OSUTIL.get_dhcp_pid() + self.dhcpid = self.distro.osutil.get_dhcp_pid() return #The dhcp process hasn't changed since last check if os.path.isdir(os.path.join('/proc', self.dhcpid.strip())): return - newpid = OSUTIL.get_dhcp_pid() + newpid = self.distro.osutil.get_dhcp_pid() if newpid is not None and newpid != self.dhcpid: logger.info("EnvMonitor: Detected dhcp client restart. " "Restoring routing table.") - self.dhcp_handler.conf_routes() + self.distro.dhcp_handler.conf_routes() self.dhcpid = newpid def stop(self): diff --git a/azurelinuxagent/distro/default/extension.py b/azurelinuxagent/distro/default/extension.py index f6c02aa..82cdfed 100644 --- a/azurelinuxagent/distro/default/extension.py +++ b/azurelinuxagent/distro/default/extension.py @@ -22,13 +22,16 @@ import time import json import subprocess import shutil +import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.protocol as prot -from azurelinuxagent.metadata import AGENT_VERSION from azurelinuxagent.event import add_event, WALAEventOperation -from azurelinuxagent.exception import ExtensionError +from azurelinuxagent.exception import ExtensionError, ProtocolError, HttpError +from azurelinuxagent.future import ustr +from azurelinuxagent.metadata import AGENT_VERSION +from azurelinuxagent.protocol.restapi import ExtHandlerStatus, ExtensionStatus, \ + ExtensionSubStatus, Extension, \ + VMStatus, ExtHandler, \ + get_properties, set_properties import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.restutil as restutil import azurelinuxagent.utils.shellutil as shellutil @@ -41,15 +44,6 @@ VALID_EXTENSION_STATUS = ['transitioning', 'error', 'success', 'warning'] VALID_HANDLER_STATUS = ['Ready', 'NotReady', "Installing", "Unresponsive"] -def handler_state_to_status(handler_state): - if handler_state == "Enabled": - return "Ready" - elif handler_state in VALID_HANDLER_STATUS: - return handler_state - else: - return "NotReady" - - def validate_has_key(obj, key, fullname): if key not in obj: raise ExtensionError("Missing: {0}".format(fullname)) @@ -64,14 +58,13 @@ def parse_formatted_message(formatted_message): validate_has_key(formatted_message, 'lang', 'formattedMessage/lang') validate_has_key(formatted_message, 'message', 'formattedMessage/message') return formatted_message.get('message') - def parse_ext_substatus(substatus): #Check extension sub status format validate_has_key(substatus, 'status', 'substatus/status') validate_in_range(substatus['status'], VALID_EXTENSION_STATUS, 'substatus/status') - status = prot.ExtensionSubStatus() + status = ExtensionSubStatus() status.name = substatus.get('name') status.status = substatus.get('status') status.code = substatus.get('code', 0) @@ -105,333 +98,330 @@ def parse_ext_status(ext_status, data): for substatus in substatus_list: ext_status.substatusList.append(parse_ext_substatus(substatus)) -def parse_extension_dirname(dirname): - """ - Parse installed extension dir name. Sample: ExtensionName-Version/ - """ - seprator = dirname.rfind('-') - if seprator < 0: - raise ExtensionError("Invalid extenation dir name") - return dirname[0:seprator], dirname[seprator + 1:] - -def get_installed_version(target_name): - """ - Return the highest version instance with the same name - """ - installed_version = None - lib_dir = OSUTIL.get_lib_dir() - for dir_name in os.listdir(lib_dir): - path = os.path.join(lib_dir, dir_name) - if os.path.isdir(path) and dir_name.startswith(target_name): - name, version = parse_extension_dirname(dir_name) - #Here we need to ensure names are exactly the same. - if name == target_name: - if installed_version is None or \ - Version(installed_version) < Version(version): - installed_version = version - return installed_version - class ExtHandlerState(object): + NotInstalled = "NotInstalled" + Installed = "Installed" Enabled = "Enabled" - Disabled = "Disabled" - Failed = "Failed" - class ExtHandlersHandler(object): - - def process(self): + def __init__(self, distro): + self.distro = distro + self.ext_handlers = None + self.last_etag = None + self.log_report = False + + def run(self): + ext_handlers, etag = None, None try: - protocol = prot.FACTORY.get_default_protocol() - ext_handlers = protocol.get_ext_handlers() - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message = text(e)) + self.protocol = self.distro.protocol_util.get_protocol() + ext_handlers, etag = self.protocol.get_ext_handlers() + except ProtocolError as e: + add_event(name="WALA", is_success=False, message=ustr(e)) return - - vm_status = prot.VMStatus() + if self.last_etag is not None and self.last_etag == etag: + logger.verb("No change to ext handler config:{0}, skip", etag) + self.log_report = False + else: + logger.info("Handle new ext handler config") + self.log_report = True #Log status report success on new config + self.handle_ext_handlers(ext_handlers) + self.last_etag = etag + + self.report_ext_handlers_status(ext_handlers) + + def handle_ext_handlers(self, ext_handlers): + if ext_handlers.extHandlers is None or \ + len(ext_handlers.extHandlers) == 0: + logger.info("No ext handler config found") + return + + for ext_handler in ext_handlers.extHandlers: + #TODO handle install in sequence, enable in parallel + self.handle_ext_handler(ext_handler) + + def handle_ext_handler(self, ext_handler): + ext_handler_i = ExtHandlerInstance(ext_handler, self.protocol) + try: + state = ext_handler.properties.state + ext_handler_i.logger.info("Expected handler state: {0}", state) + if state == "enabled": + self.handle_enable(ext_handler_i) + elif state == u"disabled": + self.handle_disable(ext_handler_i) + elif state == u"uninstall": + self.handle_uninstall(ext_handler_i) + else: + message = u"Unknown ext handler state:{0}".format(state) + raise ExtensionError(message) + except ExtensionError as e: + ext_handler_i.set_handler_status(message=ustr(e), code=-1) + ext_handler_i.report_event(message=ustr(e), is_success=False) + + def handle_enable(self, ext_handler_i): + + ext_handler_i.decide_version() + + old_ext_handler_i = ext_handler_i.get_installed_ext_handler() + if old_ext_handler_i is not None and \ + old_ext_handler_i.version_gt(ext_handler_i): + raise ExtensionError(u"Downgrade not allowed") + + handler_state = ext_handler_i.get_handler_state() + ext_handler_i.logger.info("Current handler state is: {0}", handler_state) + if handler_state == ExtHandlerState.NotInstalled: + ext_handler_i.set_handler_state(ExtHandlerState.NotInstalled) + + ext_handler_i.download() + + ext_handler_i.update_settings() + + if old_ext_handler_i is None: + ext_handler_i.install() + elif ext_handler_i.version_gt(old_ext_handler_i): + old_ext_handler_i.disable() + ext_handler_i.copy_status_files(old_ext_handler_i) + ext_handler_i.update() + old_ext_handler_i.uninstall() + old_ext_handler_i.rm_ext_handler_dir() + ext_handler_i.update_with_install() + else: + ext_handler_i.update_settings() + + ext_handler_i.enable() + + def handle_disable(self, ext_handler_i): + handler_state = ext_handler_i.get_handler_state() + ext_handler_i.logger.info("Current handler state is: {0}", handler_state) + if handler_state == ExtHandlerState.Enabled: + ext_handler_i.disable() + + def handle_uninstall(self, ext_handler_i): + handler_state = ext_handler_i.get_handler_state() + ext_handler_i.logger.info("Current handler state is: {0}", handler_state) + if handler_state != ExtHandlerState.NotInstalled: + if handler_state == ExtHandlerState.Enabled: + ext_handler_i.disable() + ext_handler_i.uninstall() + ext_handler_i.rm_ext_handler_dir() + + def report_ext_handlers_status(self, ext_handlers): + """Go thru handler_state dir, collect and report status""" + vm_status = VMStatus() vm_status.vmAgent.version = AGENT_VERSION vm_status.vmAgent.status = "Ready" vm_status.vmAgent.message = "Guest Agent is running" - if ext_handlers.extHandlers is None or \ - len(ext_handlers.extHandlers) == 0: - logger.verb("No extensions to handle") - else: + if ext_handlers is not None: for ext_handler in ext_handlers.extHandlers: - #TODO handle extension in parallel try: - pkg_list = protocol.get_ext_handler_pkgs(ext_handler) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message=text(e)) - continue - - handler_status = self.process_extension(ext_handler, pkg_list) - if handler_status is not None: - vm_status.vmAgent.extensionHandlers.append(handler_status) - + self.report_ext_handler_status(vm_status, ext_handler) + except ExtensionError as e: + add_event(name="WALA", is_success=False, message=ustr(e)) + + logger.verb("Report vm agent status") + try: - logger.verb("Report vm agent status") - protocol.report_vm_status(vm_status) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message = text(e)) - - def process_extension(self, ext_handler, pkg_list): - installed_version = get_installed_version(ext_handler.name) - if installed_version is not None: - handler = ExtHandlerInstance(ext_handler, pkg_list, - installed_version, installed=True) - else: - handler = ExtHandlerInstance(ext_handler, pkg_list, - ext_handler.properties.version) - handler.handle() + self.protocol.report_vm_status(vm_status) + except ProtocolError as e: + message = "Failed to report vm agent status: {0}".format(e) + add_event(name="WALA", is_success=False, message=message) + + if self.log_report: + logger.info("Successfully reported vm agent status") + + + def report_ext_handler_status(self, vm_status, ext_handler): + ext_handler_i = ExtHandlerInstance(ext_handler, self.protocol) - if handler.ext_status is not None: + handler_status = ext_handler_i.get_handler_status() + if handler_status is None: + return + + handler_state = ext_handler_i.get_handler_state() + if handler_state != ExtHandlerState.NotInstalled: try: - protocol = prot.FACTORY.get_default_protocol() - protocol.report_ext_status(handler.name, handler.ext.name, - handler.ext_status) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message=text(e)) - - return handler.handler_status + active_exts = ext_handler_i.report_ext_status() + handler_status.extensions.extend(active_exts) + except ExtensionError as e: + ext_handler_i.set_handler_status(message=ustr(e), code=-1) + + try: + heartbeat = ext_handler_i.collect_heartbeat() + if heartbeat is not None: + handler_status.status = heartbeat.get('status') + except ExtensionError as e: + ext_handler_i.set_handler_status(message=ustr(e), code=-1) + vm_status.vmAgent.extensionHandlers.append(handler_status) + class ExtHandlerInstance(object): - def __init__(self, ext_handler, pkg_list, curr_version, installed=False): + def __init__(self, ext_handler, protocol): self.ext_handler = ext_handler - self.name = ext_handler.name - self.version = ext_handler.properties.version - self.pkg_list = pkg_list - self.state = ext_handler.properties.state - self.update_policy = ext_handler.properties.upgradePolicy - - self.curr_version = curr_version - self.installed = installed - self.handler_state = None - self.lib_dir = OSUTIL.get_lib_dir() - - self.ext_status = prot.ExtensionStatus() - self.handler_status = prot.ExtHandlerStatus() - self.handler_status.name = self.name - self.handler_status.version = self.curr_version - - #Currently, extension settings will have no more than 1 instance - if len(ext_handler.properties.extensions) > 0: - self.ext = ext_handler.properties.extensions[0] - self.handler_status.extensions = [self.ext.name] - else: - #When no extension settings, set sequenceNumber to 0 - self.ext = prot.Extension(sequenceNumber=0) - self.ext_status.sequenceNumber = self.ext.sequenceNumber + self.protocol = protocol + self.operation = None + self.pkg = None prefix = "[{0}]".format(self.get_full_name()) self.logger = logger.Logger(logger.DEFAULT_LOGGER, prefix) + + try: + fileutil.mkdir(self.get_log_dir(), mode=0o744) + except IOError as e: + self.logger.error(u"Failed to create extension log dir: {0}", e) - def init_logger(self): - #Init logger appender for extension - fileutil.mkdir(self.get_log_dir(), mode=0o644) log_file = os.path.join(self.get_log_dir(), "CommandExecution.log") self.logger.add_appender(logger.AppenderType.FILE, logger.LogLevel.INFO, log_file) - def handle(self): - self.init_logger() - self.logger.verb("Start processing extension handler") - - try: - self.handle_state() - except ExtensionError as e: - self.set_state_err(text(e)) - self.report_event(is_success=False, message=text(e)) - self.logger.error("Failed to process extension handler") - return - - try: - if self.installed: - self.collect_ext_status() - self.collect_handler_status() - except ExtensionError as e: - self.report_event(is_success=False, message=text(e)) - self.logger.error("Failed to get extension handler status") - return - - self.logger.verb("Finished processing extension handler") - - def handle_state(self): - if self.installed: - self.handler_state = self.get_state() - - self.handler_status.status = handler_state_to_status(self.handler_state) - self.logger.verb("Handler state: {0}", self.handler_state) - self.logger.verb("Sequence number: {0}", self.ext.sequenceNumber) - - if self.state == 'enabled': - if self.handler_state == ExtHandlerState.Failed: - self.logger.verb("Found previous failure, quit handle_enable") - return - - if self.handler_state == ExtHandlerState.Enabled: - self.logger.verb("Already enabled with sequenceNumber: {0}", - self.ext.sequenceNumber) - self.logger.verb("Quit handle_enable") - return + def decide_version(self): + """ + If auto-upgrade, get the largest public extension version under + the requested major version family of currently installed plugin version - try: - new = self.handle_enable() - if new is not None: - #Upgrade happened - new.set_state(ExtHandlerState.Enabled) - else: - self.set_state(ExtHandlerState.Enabled) + Else, get the highest hot-fix for requested version, + """ + self.logger.info("Decide which version to use") + try: + pkg_list = self.protocol.get_ext_handler_pkgs(self.ext_handler) + except ProtocolError as e: + raise ExtensionError("Failed to get ext handler pkgs", e) - except ExtensionError as e: - self.set_state(ExtHandlerState.Failed) - raise e - elif self.state == 'disabled': - if self.handler_state == ExtHandlerState.Failed: - self.logger.verb("Found previous failure, quit handle_disable") - return - - if self.handler_state == ExtHandlerState.Disabled: - self.logger.verb("Already disabled with sequenceNumber: {0}", - self.ext.sequenceNumber) - self.logger.verb("Quit handle_disable") - return + version = self.ext_handler.properties.version + update_policy = self.ext_handler.properties.upgradePolicy + + version_frag = version.split('.') + if len(version_frag) < 2: + raise ExtensionError("Wrong version format: {0}".format(version)) - try: - self.handle_disable() - self.set_state(ExtHandlerState.Disabled) - except ExtensionError as e: - self.set_state(ExtHandlerState.Failed) - raise e - elif self.state == 'uninstall': - try: - self.handle_uninstall() - except ExtensionError as e: - self.set_state(ExtHandlerState.Failed) - raise e + version_prefix = None + if update_policy is not None and update_policy == 'auto': + version_prefix = "{0}.".format(version_frag[0]) else: - raise ExtensionError("Unknown state:{0}".format(self.state)) - - def handle_enable(self): - target_version = self.get_target_version() - self.logger.info("Target version: {0}", target_version) - if self.installed: - if Version(target_version) > Version(self.curr_version): - return self.upgrade(target_version) - elif Version(target_version) == Version(self.curr_version): - self.enable() - else: - raise ExtensionError("A newer version is already installed") - else: - if Version(target_version) > Version(self.version): - #This will happen when auto upgrade policy is enabled - self.logger.info("Auto upgrade to new version:{0}", - target_version) - self.curr_version = target_version - self.download() - self.init_dir() - self.install() - self.enable() + version_prefix = "{0}.{1}.".format(version_frag[0], version_frag[1]) + + packages = [x for x in pkg_list.versions \ + if x.version.startswith(version_prefix) or \ + x.version == version] + + packages = sorted(packages, key=lambda x: Version(x.version), + reverse=True) - def handle_disable(self): - if not self.installed: - self.logger.verb("Not installed, quit disable") - return + if len(packages) <= 0: + raise ExtensionError("Failed to find and valid extension package") + self.pkg = packages[0] + self.ext_handler.properties.version = packages[0].version + self.logger.info("Use version: {0}", self.pkg.version) + + def version_gt(self, other): + self_version = self.ext_handler.properties.version + other_version = other.ext_handler.properties.version + return Version(self_version) > Version(other_version) + + def get_installed_ext_handler(self): + lastest_version = None + ext_handler_name = self.ext_handler.name + + for dir_name in os.listdir(conf.get_lib_dir()): + path = os.path.join(conf.get_lib_dir(), dir_name) + if os.path.isdir(path) and dir_name.startswith(ext_handler_name): + seperator = dir_name.rfind('-') + if seperator < 0: + continue + installed_name = dir_name[0: seperator] + installed_version = dir_name[seperator + 1:] + if installed_name != ext_handler_name: + continue + if lastest_version is None or \ + Version(lastest_version) < Version(installed_version): + lastest_version = installed_version - self.disable() + if lastest_version is None: + return None + + data = get_properties(self.ext_handler) + old_ext_handler = ExtHandler() + set_properties("ExtHandler", old_ext_handler, data) + old_ext_handler.properties.version = lastest_version + return ExtHandlerInstance(old_ext_handler, self.protocol) + + def copy_status_files(self, old_ext_handler_i): + self.logger.info("Copy status files from old plugin to new") + old_ext_dir = old_ext_handler_i.get_base_dir() + new_ext_dir = self.get_base_dir() + + old_ext_mrseq_file = os.path.join(old_ext_dir, "mrseq") + if os.path.isfile(old_ext_mrseq_file): + shutil.copy2(old_ext_mrseq_file, new_ext_dir) + + old_ext_status_dir = old_ext_handler_i.get_status_dir() + new_ext_status_dir = self.get_status_dir() + + if os.path.isdir(old_ext_status_dir): + for status_file in os.listdir(old_ext_status_dir): + status_file = os.path.join(old_ext_status_dir, status_file) + if os.path.isfile(status_file): + shutil.copy2(status_file, new_ext_status_dir) + + def set_operation(self, op): + self.operation = op - def handle_uninstall(self): - if not self.installed: - self.logger.verb("Not installed, quit unistall") - self.handler_status = None - self.ext_status = None - return - self.disable() - self.uninstall() - - def report_event(self, is_success=True, message=""): - if self.ext_status is not None: - if not is_success: - self.ext_status.status = "error" - self.ext_status.code = -1 - if self.handler_status is not None: - self.handler_status.message = message - if not is_success: - self.handler_status.status = "NotReady" - add_event(name=self.name, op=self.ext_status.operation, - is_success=is_success, message=message) - - def set_operation(self, operation): - if self.ext_status.operation != WALAEventOperation.Upgrade: - self.ext_status.operation = operation - - def upgrade(self, target_version): - self.logger.info("Upgrade from: {0} to {1}", self.curr_version, - target_version) - self.set_operation(WALAEventOperation.Upgrade) - - old = self - new = ExtHandlerInstance(self.ext_handler, self.pkg_list, - target_version) - self.logger.info("Download new extension package") - new.init_logger() - new.download() - self.logger.info("Initialize new extension directory") - new.init_dir() - - old.disable() - self.logger.info("Update new extension") - new.update() - old.uninstall() - man = new.load_manifest() - if man.is_update_with_install(): - self.logger.info("Install new extension") - new.install() - self.logger.info("Enable new extension") - new.enable() - return new + def report_event(self, message="", is_success=True): + version = self.ext_handler.properties.version + add_event(name=self.ext_handler.name, version=version, message=message, + op=self.operation, is_success=is_success) def download(self): self.logger.info("Download extension package") self.set_operation(WALAEventOperation.Download) - - uris = self.get_package_uris() + if self.pkg is None: + raise ExtensionError("No package uri found") + package = None - for uri in uris: + for uri in self.pkg.uris: try: - resp = restutil.http_get(uri.uri, chk_proxy=True) - if resp.status == restutil.httpclient.OK: - package = resp.read() - break - except restutil.HttpError as e: - self.logger.warn("Failed download extension from: {0}", uri.uri) - + package = self.protocol.download_ext_handler_pkg(uri.uri) + except ProtocolError as e: + logger.warn("Failed download extension: {0}", e) + if package is None: - raise ExtensionError("Download extension failed") + raise ExtensionError("Failed to download extension") self.logger.info("Unpack extension package") - pkg_file = os.path.join(self.lib_dir, os.path.basename(uri.uri) + ".zip") - fileutil.write_file(pkg_file, bytearray(package), asbin=True) - zipfile.ZipFile(pkg_file).extractall(self.get_base_dir()) + pkg_file = os.path.join(conf.get_lib_dir(), + os.path.basename(uri.uri) + ".zip") + try: + fileutil.write_file(pkg_file, bytearray(package), asbin=True) + zipfile.ZipFile(pkg_file).extractall(self.get_base_dir()) + except IOError as e: + raise ExtensionError(u"Failed to write and unzip plugin", e) + chmod = "find {0} -type f | xargs chmod u+x".format(self.get_base_dir()) shellutil.run(chmod) self.report_event(message="Download succeeded") - def init_dir(self): self.logger.info("Initialize extension directory") #Save HandlerManifest.json man_file = fileutil.search_file(self.get_base_dir(), 'HandlerManifest.json') - man = fileutil.read_file(man_file, remove_bom=True) - fileutil.write_file(self.get_manifest_file(), man) - #Create status and config dir - status_dir = self.get_status_dir() - fileutil.mkdir(status_dir, mode=0o700) - conf_dir = self.get_conf_dir() - fileutil.mkdir(conf_dir, mode=0o700) + if man_file is None: + raise ExtensionError("HandlerManifest.json not found") - self.make_handler_state_dir() + try: + man = fileutil.read_file(man_file, remove_bom=True) + fileutil.write_file(self.get_manifest_file(), man) + except IOError as e: + raise ExtensionError(u"Failed to save HandlerManifest.json", e) + + #Create status and config dir + try: + status_dir = self.get_status_dir() + fileutil.mkdir(status_dir, mode=0o700) + conf_dir = self.get_conf_dir() + fileutil.mkdir(conf_dir, mode=0o700) + except IOError as e: + raise ExtensionError(u"Failed to create status or config dir", e) #Save HandlerEnvironment.json self.create_handler_env() @@ -442,6 +432,8 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_enable_command()) + self.set_handler_state(ExtHandlerState.Enabled) + self.set_handler_status(status="Ready", message="Plugin enabled") def disable(self): self.logger.info("Disable extension.") @@ -449,6 +441,8 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_disable_command(), timeout=900) + self.set_handler_state(ExtHandlerState.Installed) + self.set_handler_status(status="NotReady", message="Plugin disabled") def install(self): self.logger.info("Install extension.") @@ -456,24 +450,31 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_install_command(), timeout=900) - self.installed = True + self.set_handler_state(ExtHandlerState.Installed) def uninstall(self): self.logger.info("Uninstall extension.") self.set_operation(WALAEventOperation.UnInstall) - man = self.load_manifest() - self.launch_command(man.get_uninstall_command()) - - self.logger.info("Remove ext handler dir: {0}", self.get_base_dir()) try: - shutil.rmtree(self.get_base_dir()) + man = self.load_manifest() + self.launch_command(man.get_uninstall_command()) + except ExtensionError as e: + self.report_event(message=ustr(e), is_success=False) + + def rm_ext_handler_dir(self): + try: + handler_state_dir = self.get_handler_state_dir() + if os.path.isdir(handler_state_dir): + self.logger.info("Remove ext handler dir: {0}", handler_state_dir) + shutil.rmtree(handler_state_dir) + base_dir = self.get_base_dir() + if os.path.isdir(base_dir): + self.logger.info("Remove ext handler dir: {0}", base_dir) + shutil.rmtree(base_dir) except IOError as e: - raise ExtensionError("Failed to rm ext handler dir: {0}".format(e)) - - self.installed = False - self.handler_status = None - self.ext_status = None + message = "Failed to rm ext handler dir: {0}".format(e) + self.report_event(message=message, is_success=False) def update(self): self.logger.info("Update extension.") @@ -481,95 +482,82 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_update_command(), timeout=900) - - def collect_handler_status(self): - self.logger.verb("Collect extension handler status") - if self.handler_status is None: - return - - handler_state = self.get_state() - self.handler_status.status = handler_state_to_status(handler_state) - self.handler_status.message = self.get_state_err() + + def update_with_install(self): man = self.load_manifest() - if man.is_report_heartbeat(): - heartbeat = self.collect_heartbeat() - if heartbeat is not None: - self.handler_status.status = heartbeat['status'] + if man.is_update_with_install(): + self.install() + else: + self.logger.info("UpdateWithInstall not set. " + "Skip install during upgrade.") + self.set_handler_state(ExtHandlerState.Installed) - def collect_ext_status(self): + def get_largest_seq_no(self): + seq_no = -1 + conf_dir = self.get_conf_dir() + for item in os.listdir(conf_dir): + item_path = os.path.join(conf_dir, item) + if os.path.isfile(item_path): + try: + seperator = item.rfind(".") + if seperator > 0 and item[seperator + 1:] == 'settings': + curr_seq_no = int(item.split('.')[0]) + if curr_seq_no > seq_no: + seq_no = curr_seq_no + except Exception as e: + self.logger.verb("Failed to parse file name: {0}", item) + continue + return seq_no + + def collect_ext_status(self, ext): self.logger.verb("Collect extension status") - if self.handler_status is None: - return - if self.ext is None: - return + seq_no = self.get_largest_seq_no() + if seq_no == -1: + return None + + status_dir = self.get_status_dir() + ext_status_file = "{0}.status".format(seq_no) + ext_status_file = os.path.join(status_dir, ext_status_file) - ext_status_file = self.get_status_file() + ext_status = ExtensionStatus(seq_no=seq_no) try: data_str = fileutil.read_file(ext_status_file) data = json.loads(data_str) - parse_ext_status(self.ext_status, data) + parse_ext_status(ext_status, data) except IOError as e: - raise ExtensionError("Failed to get status file: {0}".format(e)) + ext_status.message = u"Failed to get status file {0}".format(e) + ext_status.code = -1 + ext_status.status = "error" except ValueError as e: - raise ExtensionError("Malformed status file: {0}".format(e)) - - def make_handler_state_dir(self): - handler_state_dir = self.get_handler_state_dir() - fileutil.mkdir(handler_state_dir, 0o600) - if not os.path.exists(handler_state_dir): - os.makedirs(handler_state_dir) - - def get_state(self): - handler_state_file = self.get_handler_state_file() - if not os.path.isfile(handler_state_file): - return None - try: - handler_state = fileutil.read_file(handler_state_file) - if handler_state is not None: - handler_state = handler_state.rstrip() - return handler_state - except IOError as e: - err = "Failed to get handler state: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) - - def set_state(self, state): - handler_state_file = self.get_handler_state_file() - if not os.path.isfile(handler_state_file): - self.make_handler_state_dir() - try: - fileutil.write_file(handler_state_file, state) - except IOError as e: - err = "Failed to set handler state: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) - - def get_state_err(self): - """Get handler error message""" - handler_state_err_file= self.get_handler_state_err_file() - if not os.path.isfile(handler_state_err_file): - return None - try: - message = fileutil.read_file(handler_state_err_file) - return message - except IOError as e: - err = "Failed to get handler state message: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) - - def set_state_err(self, message): - """Set handler error message""" - handler_state_err_file = self.get_handler_state_err_file() - if not os.path.isfile(handler_state_err_file): - self.make_handler_state_dir() - try: - fileutil.write_file(handler_state_err_file, message) - except IOError as e: - err = "Failed to set handler state message: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) + ext_status.message = u"Malformed status file {0}".format(e) + ext_status.code = -1 + ext_status.status = "error" + return ext_status + + def report_ext_status(self): + active_exts = [] + for ext in self.ext_handler.properties.extensions: + ext_status = self.collect_ext_status(ext) + if ext_status is None: + continue + try: + self.protocol.report_ext_status(self.ext_handler.name, ext.name, + ext_status) + active_exts.append(ext.name) + except ProtocolError as e: + self.logger.error(u"Failed to report extension status: {0}", e) + return active_exts + def collect_heartbeat(self): - self.logger.info("Collect heart beat") - heartbeat_file = os.path.join(OSUTIL.get_lib_dir(), + man = self.load_manifest() + if not man.is_report_heartbeat(): + return + heartbeat_file = os.path.join(conf.get_lib_dir(), self.get_heartbeat_file()) + + self.logger.info("Collect heart beat") if not os.path.isfile(heartbeat_file): raise ExtensionError("Failed to get heart beat file") if not self.is_responsive(heartbeat_file): @@ -586,15 +574,14 @@ class ExtHandlerInstance(object): except ValueError as e: raise ExtensionError("Malformed heartbeat file: {0}".format(e)) return heartbeat - + def is_responsive(self, heartbeat_file): last_update=int(time.time() - os.stat(heartbeat_file).st_mtime) return last_update > 600 # not updated for more than 10 min - + def launch_command(self, cmd, timeout=300): self.logger.info("Launch command:{0}", cmd) base_dir = self.get_base_dir() - self.update_settings() try: devnull = open(os.devnull, 'w') child = subprocess.Popen(base_dir + "/" + cmd, shell=True, @@ -614,6 +601,7 @@ class ExtHandlerInstance(object): ret = child.wait() if ret == None or ret != 0: raise ExtensionError("Non-zero exit code: {0}, {1}".format(ret, cmd)) + self.report_event(message="Launch command succeeded: {0}".format(cmd)) def load_manifest(self): @@ -627,26 +615,40 @@ class ExtHandlerInstance(object): return HandlerManifest(data[0]) + def update_settings_file(self, settings_file, settings): + settings_file = os.path.join(self.get_conf_dir(), settings_file) + try: + fileutil.write_file(settings_file, settings) + except IOError as e: + raise ExtensionError(u"Failed to update settings file", e) + def update_settings(self): - if self.ext is None: - self.logger.verb("Extension has no settings") + if self.ext_handler.properties.extensions is None or \ + len(self.ext_handler.properties.extensions) == 0: + #This is the behavior of waagent 2.0.x + #The new agent has to be consistent with the old one. + self.logger.info("Extension has no settings, write empty 0.settings") + self.update_settings_file("0.settings", "") return - - settings = { - 'publicSettings': self.ext.publicSettings, - 'protectedSettings': self.ext.privateSettings, - 'protectedSettingsCertThumbprint': self.ext.certificateThumbprint - } - ext_settings = { - "runtimeSettings":[{ - "handlerSettings": settings - }] - } - fileutil.write_file(self.get_settings_file(), json.dumps(ext_settings)) + + for ext in self.ext_handler.properties.extensions: + settings = { + 'publicSettings': ext.publicSettings, + 'protectedSettings': ext.protectedSettings, + 'protectedSettingsCertThumbprint': ext.certificateThumbprint + } + ext_settings = { + "runtimeSettings":[{ + "handlerSettings": settings + }] + } + settings_file = "{0}.settings".format(ext.sequenceNumber) + self.logger.info("Update settings file: {0}", settings_file) + self.update_settings_file(settings_file, json.dumps(ext_settings)) def create_handler_env(self): env = [{ - "name": self.name, + "name": self.ext_handler.name, "version" : HANDLER_ENVIRONMENT_VERSION, "handlerEnvironment" : { "logFolder" : self.get_log_dir(), @@ -655,73 +657,91 @@ class ExtHandlerInstance(object): "heartbeatFile" : self.get_heartbeat_file() } }] - fileutil.write_file(self.get_env_file(), - json.dumps(env)) - - def get_target_version(self): - version = self.version - update_policy = self.update_policy - if update_policy is None or update_policy.lower() != 'auto': - return version - - major = version.split('.')[0] - if major is None: - raise ExtensionError("Wrong version format: {0}".format(version)) - - packages = [x for x in self.pkg_list.versions \ - if x.version.startswith(major + ".")] - packages = sorted(packages, key=lambda x: Version(x.version), - reverse=True) - if len(packages) <= 0: - raise ExtensionError("Can't find version: {0}.*".format(major)) + try: + fileutil.write_file(self.get_env_file(), json.dumps(env)) + except IOError as e: + raise ExtensionError(u"Failed to save handler environment", e) + + def get_handler_state_dir(self): + return os.path.join(conf.get_lib_dir(), "handler_state", + self.get_full_name()) - return packages[0].version + def set_handler_state(self, handler_state): + state_dir = self.get_handler_state_dir() + if not os.path.exists(state_dir): + try: + fileutil.mkdir(state_dir, 0o700) + except IOError as e: + self.logger.error("Failed to create state dir: {0}", e) + + try: + state_file = os.path.join(state_dir, "state") + fileutil.write_file(state_file, handler_state) + except IOError as e: + self.logger.error("Failed to set state: {0}", e) + + def get_handler_state(self): + state_dir = self.get_handler_state_dir() + state_file = os.path.join(state_dir, "state") + if not os.path.isfile(state_file): + return ExtHandlerState.NotInstalled - def get_package_uris(self): - version = self.curr_version - packages = self.pkg_list.versions - if packages is None: - raise ExtensionError("Package uris is None.") + try: + return fileutil.read_file(state_file) + except IOError as e: + self.logger.error("Failed to get state: {0}", e) + return ExtHandlerState.NotInstalled + + def set_handler_status(self, status="NotReady", message="", + code=0): + state_dir = self.get_handler_state_dir() + if not os.path.exists(state_dir): + try: + fileutil.mkdir(state_dir, 0o700) + except IOError as e: + self.logger.error("Failed to create state dir: {0}", e) + + handler_status = ExtHandlerStatus() + handler_status.name = self.ext_handler.name + handler_status.version = self.ext_handler.properties.version + handler_status.message = message + handler_status.code = code + handler_status.status = status + status_file = os.path.join(state_dir, "status") - for package in packages: - if Version(package.version) == Version(version): - return package.uris + try: + fileutil.write_file(status_file, + json.dumps(get_properties(handler_status))) + except (IOError, ValueError, ProtocolError) as e: + self.logger.error("Failed to save handler status: {0}", e) + + def get_handler_status(self): + state_dir = self.get_handler_state_dir() + status_file = os.path.join(state_dir, "status") + if not os.path.isfile(status_file): + return None + + try: + data = json.loads(fileutil.read_file(status_file)) + handler_status = ExtHandlerStatus() + set_properties("ExtHandlerStatus", handler_status, data) + return handler_status + except (IOError, ValueError) as e: + self.logger.error("Failed to get handler status: {0}", e) - raise ExtensionError("Can't get package uris for {0}.".format(version)) - def get_full_name(self): - return "{0}-{1}".format(self.name, self.curr_version) - + return "{0}-{1}".format(self.ext_handler.name, + self.ext_handler.properties.version) + def get_base_dir(self): - return os.path.join(OSUTIL.get_lib_dir(), self.get_full_name()) + return os.path.join(conf.get_lib_dir(), self.get_full_name()) def get_status_dir(self): return os.path.join(self.get_base_dir(), "status") - def get_status_file(self): - return os.path.join(self.get_status_dir(), - "{0}.status".format(self.ext.sequenceNumber)) - def get_conf_dir(self): return os.path.join(self.get_base_dir(), 'config') - def get_settings_file(self): - return os.path.join(self.get_conf_dir(), - "{0}.settings".format(self.ext.sequenceNumber)) - - def get_handler_state_dir(self): - return os.path.join(OSUTIL.get_lib_dir(), "handler_state", - self.get_full_name()) - - def get_handler_state_file(self): - return os.path.join(self.get_handler_state_dir(), - '{0}.state'.format(self.ext.sequenceNumber)) - - def get_handler_state_err_file(self): - return os.path.join(self.get_handler_state_dir(), - '{0}.error'.format(self.ext.sequenceNumber)) - - def get_heartbeat_file(self): return os.path.join(self.get_base_dir(), 'heartbeat.log') @@ -732,8 +752,8 @@ class ExtHandlerInstance(object): return os.path.join(self.get_base_dir(), 'HandlerEnvironment.json') def get_log_dir(self): - return os.path.join(OSUTIL.get_ext_log_dir(), self.name, - self.curr_version) + return os.path.join(conf.get_ext_log_dir(), self.ext_handler.name, + self.ext_handler.properties.version) class HandlerEnvironment(object): def __init__(self, data): @@ -782,19 +802,16 @@ class HandlerManifest(object): return self.data['handlerManifest']["disableCommand"] def is_reboot_after_install(self): - #TODO handle reboot after install - if "rebootAfterInstall" not in self.data['handlerManifest']: - return False - return self.data['handlerManifest']["rebootAfterInstall"] + """ + Deprecated + """ + return False def is_report_heartbeat(self): - if "reportHeartbeat" not in self.data['handlerManifest']: - return False - return self.data['handlerManifest']["reportHeartbeat"] + return self.data['handlerManifest'].get('reportHeartbeat', False) def is_update_with_install(self): - if "updateMode" not in self.data['handlerManifest']: - return False - if "updateMode" in self.data: - return self.data['handlerManifest']["updateMode"].lower() == "updatewithinstall" - return False + update_mode = self.data['handlerManifest'].get('updateMode') + if update_mode is None: + return True + return update_mode.low() == "updatewithinstall" diff --git a/azurelinuxagent/distro/default/handlerFactory.py b/azurelinuxagent/distro/default/handlerFactory.py deleted file mode 100644 index dceb2a3..0000000 --- a/azurelinuxagent/distro/default/handlerFactory.py +++ /dev/null @@ -1,40 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# -from .init import InitHandler -from .run import MainHandler -from .scvmm import ScvmmHandler -from .dhcp import DhcpHandler -from .env import EnvHandler -from .provision import ProvisionHandler -from .resourceDisk import ResourceDiskHandler -from .extension import ExtHandlersHandler -from .deprovision import DeprovisionHandler - -class DefaultHandlerFactory(object): - def __init__(self): - self.init_handler = InitHandler() - self.main_handler = MainHandler(self) - self.scvmm_handler = ScvmmHandler() - self.dhcp_handler = DhcpHandler() - self.env_handler = EnvHandler(self) - self.provision_handler = ProvisionHandler() - self.resource_disk_handler = ResourceDiskHandler() - self.ext_handlers_handler = ExtHandlersHandler() - self.deprovision_handler = DeprovisionHandler() - diff --git a/azurelinuxagent/distro/default/init.py b/azurelinuxagent/distro/default/init.py index db74fef..c703e87 100644 --- a/azurelinuxagent/distro/default/init.py +++ b/azurelinuxagent/distro/default/init.py @@ -20,30 +20,34 @@ import os import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.utils.fileutil as fileutil +import azurelinuxagent.event as event class InitHandler(object): - def init(self, verbose): + def __init__(self, distro): + self.distro = distro + + def run(self, verbose): #Init stdout log level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO logger.add_logger_appender(logger.AppenderType.STDOUT, level) #Init config - conf_file_path = OSUTIL.get_conf_file_path() - conf.load_conf(conf_file_path) + conf_file_path = self.distro.osutil.get_agent_conf_file_path() + conf.load_conf_from_file(conf_file_path) #Init log - verbose = verbose or conf.get_switch("Logs.Verbose", False) + verbose = verbose or conf.get_logs_verbose() level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO logger.add_logger_appender(logger.AppenderType.FILE, level, path="/var/log/waagent.log") logger.add_logger_appender(logger.AppenderType.CONSOLE, level, path="/dev/console") - #Create lib dir - fileutil.mkdir(OSUTIL.get_lib_dir(), mode=0o700) - os.chdir(OSUTIL.get_lib_dir()) + #Init event reporter + event_dir = os.path.join(conf.get_lib_dir(), "events") + event.init_event_logger(event_dir) + event.enable_unhandled_err_dump("WALA") + diff --git a/azurelinuxagent/distro/default/loader.py b/azurelinuxagent/distro/default/loader.py deleted file mode 100644 index 55a51e0..0000000 --- a/azurelinuxagent/distro/default/loader.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -def get_osutil(): - from azurelinuxagent.distro.default.osutil import DefaultOSUtil - return DefaultOSUtil() - -def get_handlers(): - from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory - return DefaultHandlerFactory() - - diff --git a/azurelinuxagent/distro/default/monitor.py b/azurelinuxagent/distro/default/monitor.py new file mode 100644 index 0000000..3b26c9a --- /dev/null +++ b/azurelinuxagent/distro/default/monitor.py @@ -0,0 +1,182 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import sys +import traceback +import atexit +import json +import time +import datetime +import threading +import platform +import azurelinuxagent.logger as logger +import azurelinuxagent.conf as conf +from azurelinuxagent.event import WALAEventOperation, add_event +from azurelinuxagent.exception import EventError, ProtocolError, OSUtilError +from azurelinuxagent.future import ustr +from azurelinuxagent.utils.textutil import parse_doc, findall, find, getattrib +from azurelinuxagent.protocol.restapi import TelemetryEventParam, \ + TelemetryEventList, \ + TelemetryEvent, \ + set_properties, get_properties +from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_CODE_NAME, AGENT_LONG_VERSION + + +def parse_event(data_str): + try: + return parse_json_event(data_str) + except ValueError: + return parse_xml_event(data_str) + +def parse_xml_param(param_node): + name = getattrib(param_node, "Name") + value_str = getattrib(param_node, "Value") + attr_type = getattrib(param_node, "T") + value = value_str + if attr_type == 'mt:uint64': + value = int(value_str) + elif attr_type == 'mt:bool': + value = bool(value_str) + elif attr_type == 'mt:float64': + value = float(value_str) + return TelemetryEventParam(name, value) + +def parse_xml_event(data_str): + try: + xml_doc = parse_doc(data_str) + event_id = getattrib(find(xml_doc, "Event"), 'id') + provider_id = getattrib(find(xml_doc, "Provider"), 'id') + event = TelemetryEvent(event_id, provider_id) + param_nodes = findall(xml_doc, 'Param') + for param_node in param_nodes: + event.parameters.append(parse_xml_param(param_node)) + return event + except Exception as e: + raise ValueError(ustr(e)) + +def parse_json_event(data_str): + data = json.loads(data_str) + event = TelemetryEvent() + set_properties("TelemetryEvent", event, data) + return event + + +class MonitorHandler(object): + def __init__(self, distro): + self.distro = distro + self.sysinfo = [] + + def run(self): + event_thread = threading.Thread(target = self.daemon) + event_thread.setDaemon(True) + event_thread.start() + + def init_sysinfo(self): + osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(), + DISTRO_NAME, + DISTRO_VERSION, + DISTRO_CODE_NAME, + platform.release()) + + + self.sysinfo.append(TelemetryEventParam("OSVersion", osversion)) + self.sysinfo.append(TelemetryEventParam("GAVersion", AGENT_LONG_VERSION)) + + try: + ram = self.distro.osutil.get_total_mem() + processors = self.distro.osutil.get_processor_cores() + self.sysinfo.append(TelemetryEventParam("RAM", ram)) + self.sysinfo.append(TelemetryEventParam("Processors", processors)) + except OSUtilError as e: + logger.warn("Failed to get system info: {0}", e) + + try: + protocol = self.distro.protocol_util.get_protocol() + vminfo = protocol.get_vminfo() + self.sysinfo.append(TelemetryEventParam("VMName", + vminfo.vmName)) + self.sysinfo.append(TelemetryEventParam("TenantName", + vminfo.tenantName)) + self.sysinfo.append(TelemetryEventParam("RoleName", + vminfo.roleName)) + self.sysinfo.append(TelemetryEventParam("RoleInstanceName", + vminfo.roleInstanceName)) + self.sysinfo.append(TelemetryEventParam("ContainerId", + vminfo.containerId)) + except ProtocolError as e: + logger.warn("Failed to get system info: {0}", e) + + def collect_event(self, evt_file_name): + try: + logger.verb("Found event file: {0}", evt_file_name) + with open(evt_file_name, "rb") as evt_file: + #if fail to open or delete the file, throw exception + data_str = evt_file.read().decode("utf-8",'ignore') + logger.verb("Processed event file: {0}", evt_file_name) + os.remove(evt_file_name) + return data_str + except IOError as e: + msg = "Failed to process {0}, {1}".format(evt_file_name, e) + raise EventError(msg) + + def collect_and_send_events(self): + event_list = TelemetryEventList() + event_dir = os.path.join(conf.get_lib_dir(), "events") + event_files = os.listdir(event_dir) + for event_file in event_files: + if not event_file.endswith(".tld"): + continue + event_file_path = os.path.join(event_dir, event_file) + try: + data_str = self.collect_event(event_file_path) + except EventError as e: + logger.error("{0}", e) + continue + + try: + event = parse_event(data_str) + event.parameters.extend(self.sysinfo) + event_list.events.append(event) + except (ValueError, ProtocolError) as e: + logger.warn("Failed to decode event file: {0}", e) + continue + + if len(event_list.events) == 0: + return + + try: + protocol = self.distro.protocol_util.get_protocol() + protocol.report_event(event_list) + except ProtocolError as e: + logger.error("{0}", e) + + def daemon(self): + self.init_sysinfo() + last_heartbeat = datetime.datetime.min + period = datetime.timedelta(hours = 12) + while(True): + if (datetime.datetime.now()-last_heartbeat) > period: + last_heartbeat = datetime.datetime.now() + add_event(op=WALAEventOperation.HeartBeat, name="WALA", + is_success=True) + try: + self.collect_and_send_events() + except Exception as e: + logger.warn("Failed to send events: {0}", e) + time.sleep(60) diff --git a/azurelinuxagent/distro/default/osutil.py b/azurelinuxagent/distro/default/osutil.py index 00a57cc..18ab2ba 100644 --- a/azurelinuxagent/distro/default/osutil.py +++ b/azurelinuxagent/distro/default/osutil.py @@ -25,11 +25,15 @@ import struct import time import pwd import fcntl +import base64 import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +import azurelinuxagent.conf as conf +from azurelinuxagent.exception import OSUtilError +from azurelinuxagent.future import ustr import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.textutil as textutil +from azurelinuxagent.utils.cryptutil import CryptUtil __RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules" ] @@ -40,44 +44,14 @@ for all distros. Each concrete distro classes could overwrite default behavior if needed. """ -class OSUtilError(Exception): - pass - class DefaultOSUtil(object): def __init__(self): - self.lib_dir = "/var/lib/waagent" - self.ext_log_dir = "/var/log/azure" - self.dvd_mount_point = "/mnt/cdrom/secure" - self.ovf_env_file_path = "/mnt/cdrom/secure/ovf-env.xml" - self.agent_pid_file_path = "/var/run/waagent.pid" - self.passwd_file_path = "/etc/shadow" - self.home = '/home' - self.sshd_conf_file_path = '/etc/ssh/sshd_config' - self.openssl_cmd = '/usr/bin/openssl' - self.conf_file_path = '/etc/waagent.conf' + self.agent_conf_file_path = '/etc/waagent.conf' self.selinux=None - def get_lib_dir(self): - return self.lib_dir - - def get_ext_log_dir(self): - return self.ext_log_dir - - def get_dvd_mount_point(self): - return self.dvd_mount_point - - def get_conf_file_path(self): - return self.conf_file_path - - def get_ovf_env_file_path_on_dvd(self): - return self.ovf_env_file_path - - def get_agent_pid_file_path(self): - return self.agent_pid_file_path - - def get_openssl_cmd(self): - return self.openssl_cmd + def get_agent_conf_file_path(self): + return self.agent_conf_file_path def get_userentry(self, username): try: @@ -86,6 +60,14 @@ class DefaultOSUtil(object): return None def is_sys_user(self, username): + """ + Check whether use is a system user. + If reset sys user is allowed in conf, return False + Otherwise, check whether UID is less than UID_MIN + """ + if conf.get_allow_reset_sys_user(): + return False + userentry = self.get_userentry(username) uidmin = None try: @@ -104,9 +86,13 @@ class DefaultOSUtil(object): def useradd(self, username, expiration=None): """ - Update password and ssh key for user account. - New account will be created if not exists. + Create user account with 'username' """ + userentry = self.get_userentry(username) + if userentry is not None: + logger.info("User {0} already exists, skip useradd", username) + return + if expiration is not None: cmd = "useradd -m {0} -e {1}".format(username, expiration) else: @@ -146,42 +132,21 @@ class DefaultOSUtil(object): def del_root_password(self): try: - passwd_content = fileutil.read_file(self.passwd_file_path) + passwd_file_path = conf.get_passwd_file_path() + passwd_content = fileutil.read_file(passwd_file_path) passwd = passwd_content.split('\n') new_passwd = [x for x in passwd if not x.startswith("root:")] new_passwd.insert(0, "root:*LOCK*:14600::::::") - fileutil.write_file(self.passwd_file_path, "\n".join(new_passwd)) + fileutil.write_file(passwd_file_path, "\n".join(new_passwd)) except IOError as e: raise OSUtilError("Failed to delete root password:{0}".format(e)) - def get_home(self): - return self.home - - def get_pubkey_from_prv(self, file_name): - cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, - file_name) - pub = shellutil.run_get_output(cmd)[1] - return pub - - def get_pubkey_from_crt(self, file_name): - cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd, - file_name) - pub = shellutil.run_get_output(cmd)[1] - return pub - def _norm_path(self, filepath): - home = self.get_home() + home = conf.get_home_dir() # Expand HOME variable if present in path path = os.path.normpath(filepath.replace("$HOME", home)) return path - def get_thumbprint_from_crt(self, file_name): - cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd, - file_name) - thumbprint = shellutil.run_get_output(cmd)[1] - thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper() - return thumbprint - def deploy_ssh_keypair(self, username, keypair): """ Deploy id_rsa and id_rsa.pub @@ -190,13 +155,14 @@ class DefaultOSUtil(object): path = self._norm_path(path) dir_path = os.path.dirname(path) fileutil.mkdir(dir_path, mode=0o700, owner=username) - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() prv_path = os.path.join(lib_dir, thumbprint + '.prv') if not os.path.isfile(prv_path): raise OSUtilError("Can't find {0}.prv".format(thumbprint)) shutil.copyfile(prv_path, path) pub_path = path + '.pub' - pub = self.get_pubkey_from_prv(prv_path) + crytputil = CryptUtil(conf.get_openssl_cmd()) + pub = crytputil.get_pubkey_from_prv(prv_path) fileutil.write_file(pub_path, pub) self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0') self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0') @@ -204,8 +170,8 @@ class DefaultOSUtil(object): os.chmod(pub_path, 0o600) def openssl_to_openssh(self, input_file, output_file): - shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file, - output_file)) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.crt_to_ssh(input_file, output_file) def deploy_ssh_pubkey(self, username, pubkey): """ @@ -215,6 +181,8 @@ class DefaultOSUtil(object): if path is None: raise OSUtilError("Publich key path is None") + crytputil = CryptUtil(conf.get_openssl_cmd()) + path = self._norm_path(path) dir_path = os.path.dirname(path) fileutil.mkdir(dir_path, mode=0o700, owner=username) @@ -223,12 +191,12 @@ class DefaultOSUtil(object): raise OSUtilError("Bad public key: {0}".format(value)) fileutil.write_file(path, value) elif thumbprint is not None: - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() crt_path = os.path.join(lib_dir, thumbprint + '.crt') if not os.path.isfile(crt_path): raise OSUtilError("Can't find {0}.crt".format(thumbprint)) pub_path = os.path.join(lib_dir, thumbprint + '.pub') - pub = self.get_pubkey_from_crt(crt_path) + pub = crytputil.get_pubkey_from_crt(crt_path) fileutil.write_file(pub_path, pub) self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0') @@ -280,23 +248,21 @@ class DefaultOSUtil(object): if self.is_selinux_system(): return shellutil.run('chcon ' + con + ' ' + path) - def get_sshd_conf_file_path(self): - return self.sshd_conf_file_path - def set_ssh_client_alive_interval(self): - conf_file_path = self.get_sshd_conf_file_path() - conf = fileutil.read_file(conf_file_path).split("\n") - textutil.set_ssh_config(conf, "ClientAliveInterval", "180") - fileutil.write_file(conf_file_path, '\n'.join(conf)) + conf_file_path = conf.get_sshd_conf_file_path() + conf_file = fileutil.read_file(conf_file_path).split("\n") + textutil.set_ssh_config(conf_file, "ClientAliveInterval", "180") + fileutil.write_file(conf_file_path, '\n'.join(conf_file)) logger.info("Configured SSH client probing to keep connections alive.") def conf_sshd(self, disable_password): option = "no" if disable_password else "yes" - conf_file_path = self.get_sshd_conf_file_path() - conf = fileutil.read_file(conf_file_path).split("\n") - textutil.set_ssh_config(conf, "PasswordAuthentication", option) - textutil.set_ssh_config(conf, "ChallengeResponseAuthentication", option) - fileutil.write_file(conf_file_path, "\n".join(conf)) + conf_file_path = conf.get_sshd_conf_file_path() + conf_file = fileutil.read_file(conf_file_path).split("\n") + textutil.set_ssh_config(conf_file, "PasswordAuthentication", option) + textutil.set_ssh_config(conf_file, "ChallengeResponseAuthentication", + option) + fileutil.write_file(conf_file_path, "\n".join(conf_file)) logger.info("Disabled SSH password-based authentication methods.") @@ -309,7 +275,7 @@ class DefaultOSUtil(object): def mount_dvd(self, max_retry=6, chk_err=True): dvd = self.get_dvd_device() - mount_point = self.get_dvd_mount_point() + mount_point = conf.get_dvd_mount_point() mountlist = shellutil.run_get_output("mount")[1] existing = self.get_mount_point(mountlist, dvd) if existing is not None: #Already mounted @@ -332,7 +298,7 @@ class DefaultOSUtil(object): raise OSUtilError("Failed to mount dvd.") def umount_dvd(self, chk_err=True): - mount_point = self.get_dvd_mount_point() + mount_point = conf.get_dvd_mount_point() retcode = self.umount(mount_point, chk_err=chk_err) if chk_err and retcode != 0: raise OSUtilError("Failed to umount dvd.") @@ -386,17 +352,9 @@ class DefaultOSUtil(object): shellutil.run("iptables -I INPUT -p udp --dport 68 -j ACCEPT", chk_err=False) - def gen_transport_cert(self): - """ - Create ssl certificate for https communication with endpoint server. - """ - cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 " - "-newkey rsa:2048 -keyout TransportPrivate.pem " - "-out TransportCert.pem").format(self.openssl_cmd) - shellutil.run(cmd) def remove_rules_files(self, rules_files=__RULES_FILES__): - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() for src in rules_files: file_name = fileutil.base_name(src) dest = os.path.join(lib_dir, file_name) @@ -407,7 +365,7 @@ class DefaultOSUtil(object): shutil.move(src, dest) def restore_rules_files(self, rules_files=__RULES_FILES__): - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() for dest in rules_files: filename = fileutil.base_name(dest) src = os.path.join(lib_dir, filename) @@ -603,7 +561,7 @@ class DefaultOSUtil(object): for vmbus in os.listdir(path): deviceid = fileutil.read_file(os.path.join(path, vmbus, "device_id")) guid = deviceid.lstrip('{').split('-') - if guid[0] == g0 and guid[1] == "000" + text(port_id): + if guid[0] == g0 and guid[1] == "000" + ustr(port_id): for root, dirs, files in os.walk(path + vmbus): if root.endswith("/block"): device = dirs[0] @@ -633,7 +591,7 @@ class DefaultOSUtil(object): raise OSUtilError("Failed to remove sudoer: {0}".format(e)) def decode_customdata(self, data): - return data + return base64.b64decode(data) def get_total_mem(self): cmd = "grep MemTotal /proc/meminfo |awk '{print $2}'" @@ -649,4 +607,17 @@ class DefaultOSUtil(object): return int(ret[1]) else: raise OSUtilError("Failed to get procerssor cores") + + def set_admin_access_to_ip(self, dest_ip): + #This allows root to access dest_ip + rm_old= "iptables -D OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0" + rule = "iptables -A OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0" + shellutil.run(rm_old.format(dest_ip), chk_err=False) + shellutil.run(rule.format(dest_ip)) + + #This blocks all other users to access dest_ip + rm_old = "iptables -D OUTPUT -d {0} -j DROP" + rule = "iptables -A OUTPUT -d {0} -j DROP" + shellutil.run(rm_old.format(dest_ip), chk_err=False) + shellutil.run(rule.format(dest_ip)) diff --git a/azurelinuxagent/distro/default/protocolUtil.py b/azurelinuxagent/distro/default/protocolUtil.py new file mode 100644 index 0000000..34466cf --- /dev/null +++ b/azurelinuxagent/distro/default/protocolUtil.py @@ -0,0 +1,243 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import re +import shutil +import time +import threading +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.exception import ProtocolError, OSUtilError, \ + ProtocolNotFoundError, DhcpError +from azurelinuxagent.future import ustr +import azurelinuxagent.utils.fileutil as fileutil +from azurelinuxagent.protocol.ovfenv import OvfEnv +from azurelinuxagent.protocol.wire import WireProtocol +from azurelinuxagent.protocol.metadata import MetadataProtocol, METADATA_ENDPOINT +import azurelinuxagent.utils.shellutil as shellutil + +OVF_FILE_NAME = "ovf-env.xml" + +#Tag file to indicate usage of metadata protocol +TAG_FILE_NAME = "useMetadataEndpoint.tag" + +PROTOCOL_FILE_NAME = "Protocol" + +#MAX retry times for protocol probing +MAX_RETRY = 360 + +PROBE_INTERVAL = 10 + +ENDPOINT_FILE_NAME = "WireServerEndpoint" + +class ProtocolUtil(object): + """ + ProtocolUtil handles initialization for protocol instance. 2 protocol types + are invoked, wire protocol and metadata protocols. + """ + def __init__(self, distro): + self.distro = distro + self.protocol = None + self.lock = threading.Lock() + + def copy_ovf_env(self): + """ + Copy ovf env file from dvd to hard disk. + Remove password before save it to the disk + """ + dvd_mount_point = conf.get_dvd_mount_point() + ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME) + tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME) + try: + self.distro.osutil.mount_dvd() + ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) + ovfenv = OvfEnv(ovfxml) + ovfxml = re.sub(".*?<", "*<", ovfxml) + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + fileutil.write_file(ovf_file_path, ovfxml) + + if os.path.isfile(tag_file_path_on_dvd): + logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME) + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + shutil.copyfile(tag_file_path_on_dvd, tag_file_path) + + except (OSUtilError, IOError) as e: + raise ProtocolError(ustr(e)) + + try: + self.distro.osutil.umount_dvd() + self.distro.osutil.eject_dvd() + except OSUtilError as e: + logger.warn(ustr(e)) + + return ovfenv + + def get_ovf_env(self): + """ + Load saved ovf-env.xml + """ + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + if os.path.isfile(ovf_file_path): + xml_text = fileutil.read_file(ovf_file_path) + return OvfEnv(xml_text) + else: + raise ProtocolError("ovf-env.xml is missing.") + + def _get_wireserver_endpoint(self): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + return fileutil.read_file(file_path) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _set_wireserver_endpoint(self, endpoint): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + fileutil.write_file(file_path, endpoint) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _detect_wire_protocol(self): + endpoint = self.distro.dhcp_handler.endpoint + if endpoint is None: + logger.info("WireServer endpoint is not found. Rerun dhcp handler") + try: + self.distro.dhcp_handler.run() + except DhcpError as e: + raise ProtocolError(ustr(e)) + endpoint = self.distro.dhcp_handler.endpoint + + try: + protocol = WireProtocol(endpoint) + protocol.detect() + self._set_wireserver_endpoint(endpoint) + return protocol + except ProtocolError as e: + logger.info("WireServer is not responding. Reset endpoint") + self.distro.dhcp_handler.endpoint = None + raise e + + def _detect_metadata_protocol(self): + protocol = MetadataProtocol() + protocol.detect() + + #Only allow root access METADATA_ENDPOINT + self.distro.osutil.set_admin_access_to_ip(METADATA_ENDPOINT) + + return protocol + + def _detect_protocol(self, protocols): + """ + Probe protocol endpoints in turn. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + if os.path.isfile(protocol_file_path): + os.remove(protocol_file_path) + for retry in range(0, MAX_RETRY): + for protocol in protocols: + try: + if protocol == "WireProtocol": + return self._detect_wire_protocol() + + if protocol == "MetadataProtocol": + return self._detect_metadata_protocol() + + except ProtocolError as e: + logger.info("Protocol endpoint not found: {0}, {1}", + protocol, e) + + if retry < MAX_RETRY -1: + logger.info("Retry detect protocols: retry={0}", retry) + time.sleep(PROBE_INTERVAL) + raise ProtocolNotFoundError("No protocol found.") + + def _get_protocol(self): + """ + Get protocol instance based on previous detecting result. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), + PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + raise ProtocolError("No protocl found") + + protocol_name = fileutil.read_file(protocol_file_path) + if protocol_name == "WireProtocol": + endpoint = self._get_wireserver_endpoint() + return WireProtocol(endpoint) + elif protocol_name == "MetadataProtocol": + return MetadataProtocol() + else: + raise ProtocolNotFoundError(("Unknown protocol: {0}" + "").format(protocol_name)) + + def detect_protocol(self): + """ + Detect protocol by endpoints + + :returns: protocol instance + """ + logger.info("Detect protocol endpoints") + protocols = ["WireProtocol", "MetadataProtocol"] + self.lock.acquire() + try: + if self.protocol is None: + self.protocol = self._detect_protocol(protocols) + return self.protocol + finally: + self.lock.release() + + def detect_protocol_by_file(self): + """ + Detect protocol by tag file. + + If a file "useMetadataEndpoint.tag" is found on provision iso, + metedata protocol will be used. No need to probe for wire protocol + + :returns: protocol instance + """ + logger.info("Detect protocol by file") + self.lock.acquire() + try: + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + if self.protocol is None: + protocols = [] + if os.path.isfile(tag_file_path): + protocols.append("MetadataProtocol") + else: + protocols.append("WireProtocol") + self.protocol = self._detect_protocol(protocols) + finally: + self.lock.release() + return self.protocol + + def get_protocol(self): + """ + Get protocol instance based on previous detecting result. + + :returns protocol instance + """ + self.lock.acquire() + try: + if self.protocol is None: + self.protocol = self._get_protocol() + return self.protocol + finally: + self.lock.release() + return self.protocol + diff --git a/azurelinuxagent/distro/default/provision.py b/azurelinuxagent/distro/default/provision.py index 424f083..695b82a 100644 --- a/azurelinuxagent/distro/default/provision.py +++ b/azurelinuxagent/distro/default/provision.py @@ -21,13 +21,11 @@ Provision handler import os import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.conf as conf from azurelinuxagent.event import add_event, WALAEventOperation -from azurelinuxagent.exception import * -from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError -import azurelinuxagent.protocol as prot -import azurelinuxagent.protocol.ovfenv as ovf +from azurelinuxagent.exception import ProvisionError, ProtocolError, OSUtilError +from azurelinuxagent.protocol.restapi import ProvisionStatus import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.fileutil as fileutil @@ -35,61 +33,49 @@ CUSTOM_DATA_FILE="CustomData" class ProvisionHandler(object): - def process(self): + def __init__(self, distro): + self.distro = distro + + def run(self): #If provision is not enabled, return - if not conf.get_switch("Provisioning.Enabled", True): + if not conf.get_provision_enabled(): logger.info("Provisioning is disabled. Skip.") - return + return - provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned") + provisioned = os.path.join(conf.get_lib_dir(), "provisioned") if os.path.isfile(provisioned): return - logger.info("run provision handler.") - protocol = prot.FACTORY.get_default_protocol() + logger.info("Run provision handler.") + logger.info("Copy ovf-env.xml.") + try: + ovfenv = self.distro.protocol_util.copy_ovf_env() + except ProtocolError as e: + self.report_event("Failed to copy ovf-env.xml: {0}".format(e)) + return + + self.distro.protocol_util.detect_protocol_by_file() + + self.report_not_ready("Provisioning", "Starting") + try: - status = prot.ProvisionStatus(status="NotReady", - subStatus="Provisioning", - description="Starting") - try: - protocol.report_provision_status(status) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message=text(e), - op=WALAEventOperation.Provision) - - self.provision() + logger.info("Start provisioning") + self.provision(ovfenv) fileutil.write_file(provisioned, "") thumbprint = self.reg_ssh_host_key() - logger.info("Finished provisioning") - status = prot.ProvisionStatus(status="Ready") - status.properties.certificateThumbprint = thumbprint - - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) - - add_event(name="WALA", is_success=True, message="", - op=WALAEventOperation.Provision) except ProvisionError as e: logger.error("Provision failed: {0}", e) - status = prot.ProvisionStatus(status="NotReady", - subStatus="ProvisioningFailed", - description= text(e)) - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) - - add_event(name="WALA", is_success=False, message=text(e), - op=WALAEventOperation.Provision) + self.report_not_ready("ProvisioningFailed", ustr(e)) + self.report_event(ustr(e)) + return + self.report_ready(thumbprint) + self.report_event("Provision succeed", is_success=True) + def reg_ssh_host_key(self): - keypair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa") - if conf.get_switch("Provisioning.RegenerateSshHostKeyPair"): + keypair_type = conf.get_ssh_host_keypair_type() + if conf.get_regenerate_ssh_host_key(): shellutil.run("rm -f /etc/ssh/ssh_host_*key*") shellutil.run(("ssh-keygen -N '' -t {0} -f /etc/ssh/ssh_host_{1}_key" "").format(keypair_type, keypair_type)) @@ -105,77 +91,101 @@ class ProvisionHandler(object): raise ProvisionError(("Failed to generate ssh host key: " "ret={0}, out= {1}").format(ret[0], ret[1])) - - def provision(self): - logger.info("Copy ovf-env.xml.") - try: - ovfenv = ovf.copy_ovf_env() - except prot.ProtocolError as e: - raise ProvisionError("Failed to copy ovf-env.xml: {0}".format(e)) - + def provision(self, ovfenv): logger.info("Handle ovf-env.xml.") try: logger.info("Set host name.") - OSUTIL.set_hostname(ovfenv.hostname) + self.distro.osutil.set_hostname(ovfenv.hostname) logger.info("Publish host name.") - OSUTIL.publish_hostname(ovfenv.hostname) + self.distro.osutil.publish_hostname(ovfenv.hostname) self.config_user_account(ovfenv) self.save_customdata(ovfenv) + + if conf.get_delete_root_password(): + self.distro.osutil.del_root_password() - if conf.get_switch("Provisioning.DeleteRootPassword"): - OSUTIL.del_root_password() except OSUtilError as e: raise ProvisionError("Failed to handle ovf-env.xml: {0}".format(e)) def config_user_account(self, ovfenv): logger.info("Create user account if not exists") - OSUTIL.useradd(ovfenv.username) + self.distro.osutil.useradd(ovfenv.username) if ovfenv.user_password is not None: logger.info("Set user password.") - crypt_id = conf.get("Provision.PasswordCryptId", "6") - salt_len = conf.get_int("Provision.PasswordCryptSaltLength", 10) - OSUTIL.chpasswd(ovfenv.username, ovfenv.user_password, + crypt_id = conf.get_password_cryptid() + salt_len = conf.get_password_crypt_salt_len() + self.distro.osutil.chpasswd(ovfenv.username, ovfenv.user_password, crypt_id=crypt_id, salt_len=salt_len) logger.info("Configure sudoer") - OSUTIL.conf_sudoer(ovfenv.username, ovfenv.user_password is None) + self.distro.osutil.conf_sudoer(ovfenv.username, ovfenv.user_password is None) logger.info("Configure sshd") - OSUTIL.conf_sshd(ovfenv.disable_ssh_password_auth) + self.distro.osutil.conf_sshd(ovfenv.disable_ssh_password_auth) #Disable selinux temporary - sel = OSUTIL.is_selinux_enforcing() + sel = self.distro.osutil.is_selinux_enforcing() if sel: - OSUTIL.set_selinux_enforce(0) + self.distro.osutil.set_selinux_enforce(0) self.deploy_ssh_pubkeys(ovfenv) self.deploy_ssh_keypairs(ovfenv) if sel: - OSUTIL.set_selinux_enforce(1) + self.distro.osutil.set_selinux_enforce(1) - OSUTIL.restart_ssh_service() + self.distro.osutil.restart_ssh_service() def save_customdata(self, ovfenv): - logger.info("Save custom data") customdata = ovfenv.customdata if customdata is None: return - lib_dir = OSUTIL.get_lib_dir() - fileutil.write_file(os.path.join(lib_dir, CUSTOM_DATA_FILE), - OSUTIL.decode_customdata(customdata)) + + logger.info("Save custom data") + lib_dir = conf.get_lib_dir() + if conf.get_decode_customdata(): + customdata= self.distro.osutil.decode_customdata(customdata) + customdata_file = os.path.join(lib_dir, CUSTOM_DATA_FILE) + fileutil.write_file(customdata_file, customdata) + + if conf.get_execute_customdata(): + logger.info("Execute custom data") + os.chmod(customdata_file, 0o700) + shellutil.run(customdata_file) def deploy_ssh_pubkeys(self, ovfenv): for pubkey in ovfenv.ssh_pubkeys: logger.info("Deploy ssh public key.") - OSUTIL.deploy_ssh_pubkey(ovfenv.username, pubkey) + self.distro.osutil.deploy_ssh_pubkey(ovfenv.username, pubkey) def deploy_ssh_keypairs(self, ovfenv): for keypair in ovfenv.ssh_keypairs: logger.info("Deploy ssh key pairs.") - OSUTIL.deploy_ssh_keypair(ovfenv.username, keypair) + self.distro.osutil.deploy_ssh_keypair(ovfenv.username, keypair) + + def report_event(self, message, is_success=False): + add_event(name="WALA", message=message, is_success=is_success, + op=WALAEventOperation.Provision) + + def report_not_ready(self, sub_status, description): + status = ProvisionStatus(status="NotReady", subStatus=sub_status, + description=description) + try: + protocol = self.distro.protocol_util.get_protocol() + protocol.report_provision_status(status) + except ProtocolError as e: + self.report_event(ustr(e)) + + def report_ready(self, thumbprint=None): + status = ProvisionStatus(status="Ready") + status.properties.certificateThumbprint = thumbprint + try: + protocol = self.distro.protocol_util.get_protocol() + protocol.report_provision_status(status) + except ProtocolError as e: + self.report_event(ustr(e)) diff --git a/azurelinuxagent/distro/default/resourceDisk.py b/azurelinuxagent/distro/default/resourceDisk.py index 734863c..a6c5232 100644 --- a/azurelinuxagent/distro/default/resourceDisk.py +++ b/azurelinuxagent/distro/default/resourceDisk.py @@ -21,9 +21,8 @@ import os import re import threading import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.conf as conf -from azurelinuxagent.utils.osutil import OSUTIL from azurelinuxagent.event import add_event, WALAEventOperation import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil @@ -41,6 +40,8 @@ For additional details to please refer to the MSDN documentation at : http://msd """ class ResourceDiskHandler(object): + def __init__(self, distro): + self.distro = distro def start_activate_resource_disk(self): disk_thread = threading.Thread(target = self.run) @@ -48,17 +49,17 @@ class ResourceDiskHandler(object): def run(self): mount_point = None - if conf.get_switch("ResourceDisk.Format", False): + if conf.get_resourcedisk_format(): mount_point = self.activate_resource_disk() if mount_point is not None and \ - conf.get_switch("ResourceDisk.EnableSwap", False): + conf.get_resourcedisk_enable_swap(): self.enable_swap(mount_point) def activate_resource_disk(self): logger.info("Activate resource disk") try: - mount_point = conf.get("ResourceDisk.MountPoint", "/mnt/resource") - fs = conf.get("ResourceDisk.Filesystem", "ext3") + mount_point = conf.get_resourcedisk_mountpoint() + fs = conf.get_resourcedisk_filesystem() mount_point = self.mount_resource_disk(mount_point, fs) warning_file = os.path.join(mount_point, DATALOSS_WARNING_FILE_NAME) try: @@ -68,25 +69,25 @@ class ResourceDiskHandler(object): return mount_point except ResourceDiskError as e: logger.error("Failed to mount resource disk {0}", e) - add_event(name="WALA", is_success=False, message=text(e), + add_event(name="WALA", is_success=False, message=ustr(e), op=WALAEventOperation.ActivateResourceDisk) def enable_swap(self, mount_point): logger.info("Enable swap") try: - size_mb = conf.get_int("ResourceDisk.SwapSizeMB", 0) + size_mb = conf.get_resourcedisk_swap_size_mb() self.create_swap_space(mount_point, size_mb) except ResourceDiskError as e: logger.error("Failed to enable swap {0}", e) def mount_resource_disk(self, mount_point, fs): - device = OSUTIL.device_for_ide_port(1) + device = self.distro.osutil.device_for_ide_port(1) if device is None: raise ResourceDiskError("unable to detect disk topology") device = "/dev/" + device mountlist = shellutil.run_get_output("mount")[1] - existing = OSUTIL.get_mount_point(mountlist, device) + existing = self.distro.osutil.get_mount_point(mountlist, device) if(existing): logger.info("Resource disk {0}1 is already mounted", device) diff --git a/azurelinuxagent/distro/default/run.py b/azurelinuxagent/distro/default/run.py deleted file mode 100644 index dfd3b03..0000000 --- a/azurelinuxagent/distro/default/run.py +++ /dev/null @@ -1,71 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -import os -import time -import sys -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.conf as conf -from azurelinuxagent.metadata import AGENT_LONG_NAME, AGENT_VERSION, \ - DISTRO_NAME, DISTRO_VERSION, \ - DISTRO_FULL_NAME, PY_VERSION_MAJOR, \ - PY_VERSION_MINOR, PY_VERSION_MICRO -import azurelinuxagent.event as event -import azurelinuxagent.protocol as prot -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.utils.fileutil as fileutil - - -class MainHandler(object): - def __init__(self, handlers): - self.handlers = handlers - - def run(self): - logger.info("{0} Version:{1}", AGENT_LONG_NAME, AGENT_VERSION) - logger.info("OS: {0} {1}", DISTRO_NAME, DISTRO_VERSION) - logger.info("Python: {0}.{1}.{2}", PY_VERSION_MAJOR, PY_VERSION_MINOR, - PY_VERSION_MICRO) - - event.enable_unhandled_err_dump("Azure Linux Agent") - fileutil.write_file(OSUTIL.get_agent_pid_file_path(), text(os.getpid())) - - if conf.get_switch("DetectScvmmEnv", False): - if self.handlers.scvmm_handler.detect_scvmm_env(): - return - - self.handlers.dhcp_handler.probe() - - prot.detect_default_protocol() - - event.EventMonitor().start() - - self.handlers.provision_handler.process() - - if conf.get_switch("ResourceDisk.Format", False): - self.handlers.resource_disk_handler.start_activate_resource_disk() - - self.handlers.env_handler.start() - - protocol = prot.FACTORY.get_default_protocol() - while True: - #Handle extensions - self.handlers.ext_handlers_handler.process() - time.sleep(25) - diff --git a/azurelinuxagent/distro/default/scvmm.py b/azurelinuxagent/distro/default/scvmm.py index 680c04b..4d083b4 100644 --- a/azurelinuxagent/distro/default/scvmm.py +++ b/azurelinuxagent/distro/default/scvmm.py @@ -20,28 +20,29 @@ import os import subprocess import azurelinuxagent.logger as logger -from azurelinuxagent.utils.osutil import OSUTIL VMM_CONF_FILE_NAME = "linuxosconfiguration.xml" VMM_STARTUP_SCRIPT_NAME= "install" class ScvmmHandler(object): + def __init__(self, distro): + self.distro = distro def detect_scvmm_env(self): logger.info("Detecting Microsoft System Center VMM Environment") - OSUTIL.mount_dvd(max_retry=1, chk_err=False) - mount_point = OSUTIL.get_dvd_mount_point() + self.distro.osutil.mount_dvd(max_retry=1, chk_err=False) + mount_point = self.distro.osutil.get_dvd_mount_point() found = os.path.isfile(os.path.join(mount_point, VMM_CONF_FILE_NAME)) if found: self.start_scvmm_agent() else: - OSUTIL.umount_dvd(chk_err=False) + self.distro.osutil.umount_dvd(chk_err=False) return found def start_scvmm_agent(self): logger.info("Starting Microsoft System Center VMM Initialization " "Process") - mount_point = OSUTIL.get_dvd_mount_point() + mount_point = self.distro.osutil.get_dvd_mount_point() startup_script = os.path.join(mount_point, VMM_STARTUP_SCRIPT_NAME) subprocess.Popen(["/bin/bash", startup_script, "-p " + mount_point]) diff --git a/azurelinuxagent/distro/loader.py b/azurelinuxagent/distro/loader.py index 375abd2..74ea9e7 100644 --- a/azurelinuxagent/distro/loader.py +++ b/azurelinuxagent/distro/loader.py @@ -16,31 +16,52 @@ # import azurelinuxagent.logger as logger -from azurelinuxagent.metadata import DISTRO_NAME -import azurelinuxagent.distro.default.loader as default_loader +from azurelinuxagent.utils.textutil import Version +from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_FULL_NAME +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.ubuntu.distro import UbuntuDistro, \ + Ubuntu14Distro, \ + Ubuntu12Distro, \ + UbuntuSnappyDistro +from azurelinuxagent.distro.redhat.distro import RedhatDistro, Redhat6xDistro +from azurelinuxagent.distro.coreos.distro import CoreOSDistro +from azurelinuxagent.distro.suse.distro import SUSE11Distro, SUSEDistro +from azurelinuxagent.distro.debian.distro import DebianDistro - -def get_distro_loader(): - try: - logger.verb("Loading distro implemetation from: {0}", DISTRO_NAME) - pkg_name = "azurelinuxagent.distro.{0}.loader".format(DISTRO_NAME) - return __import__(pkg_name, fromlist="loader") - except (ImportError, ValueError): - logger.warn("Unable to load distro implemetation for {0}.", DISTRO_NAME) +def get_distro(distro_name=DISTRO_NAME, distro_version=DISTRO_VERSION, + distro_full_name=DISTRO_FULL_NAME): + if distro_name == "ubuntu": + if Version(distro_version) == Version("12.04") or \ + Version(distro_version) == Version("12.10"): + return Ubuntu12Distro() + elif Version(distro_version) == Version("14.04") or \ + Version(distro_version) == Version("14.10"): + return Ubuntu14Distro() + elif distro_full_name == "Snappy Ubuntu Core": + return UbuntuSnappyDistro() + else: + return UbuntuDistro() + if distro_name == "coreos": + return CoreOSDistro() + if distro_name == "suse": + if distro_full_name=='SUSE Linux Enterprise Server' and \ + Version(distro_version) < Version('12') or \ + distro_full_name == 'openSUSE' and \ + Version(distro_version) < Version('13.2'): + return SUSE11Distro() + else: + return SUSEDistro() + elif distro_name == "debian": + return DebianDistro() + elif distro_name == "redhat" or distro_name == "centos" or \ + distro_name == "oracle": + if Version(distro_version) < Version("7"): + return Redhat6xDistro() + else: + return RedhatDistro() + else: + logger.warn("Unable to load distro implemetation for {0}.", distro_name) logger.warn("Use default distro implemetation instead.") - return default_loader - -DISTRO_LOADER = get_distro_loader() - -def get_osutil(): - try: - return DISTRO_LOADER.get_osutil() - except AttributeError: - return default_loader.get_osutil() - -def get_handlers(): - try: - return DISTRO_LOADER.get_handlers() - except AttributeError: - return default_loader.get_handlers() + return DefaultDistro() diff --git a/azurelinuxagent/distro/oracle/__init__.py b/azurelinuxagent/distro/oracle/__init__.py deleted file mode 100644 index d9b82f5..0000000 --- a/azurelinuxagent/distro/oracle/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - diff --git a/azurelinuxagent/distro/oracle/loader.py b/azurelinuxagent/distro/oracle/loader.py deleted file mode 100644 index 9dc428f..0000000 --- a/azurelinuxagent/distro/oracle/loader.py +++ /dev/null @@ -1,25 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION -import azurelinuxagent.distro.redhat.loader as redhat - -def get_osutil(): - return redhat.get_osutil() - diff --git a/azurelinuxagent/distro/redhat/distro.py b/azurelinuxagent/distro/redhat/distro.py new file mode 100644 index 0000000..2f128d7 --- /dev/null +++ b/azurelinuxagent/distro/redhat/distro.py @@ -0,0 +1,32 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.redhat.osutil import RedhatOSUtil, Redhat6xOSUtil +from azurelinuxagent.distro.coreos.deprovision import CoreOSDeprovisionHandler + +class Redhat6xDistro(DefaultDistro): + def __init__(self): + super(Redhat6xDistro, self).__init__() + self.osutil = Redhat6xOSUtil() + +class RedhatDistro(DefaultDistro): + def __init__(self): + super(RedhatDistro, self).__init__() + self.osutil = RedhatOSUtil() diff --git a/azurelinuxagent/distro/redhat/loader.py b/azurelinuxagent/distro/redhat/loader.py deleted file mode 100644 index 8d3c75b..0000000 --- a/azurelinuxagent/distro/redhat/loader.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION - -def get_osutil(): - from azurelinuxagent.distro.redhat.osutil import Redhat6xOSUtil, RedhatOSUtil - if DISTRO_VERSION < "7": - return Redhat6xOSUtil() - else: - return RedhatOSUtil() - diff --git a/azurelinuxagent/distro/redhat/osutil.py b/azurelinuxagent/distro/redhat/osutil.py index 7478867..7f769a5 100644 --- a/azurelinuxagent/distro/redhat/osutil.py +++ b/azurelinuxagent/distro/redhat/osutil.py @@ -26,20 +26,19 @@ import struct import fcntl import time import base64 +import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.future import text, bytebuffer +from azurelinuxagent.future import ustr, bytebuffer +from azurelinuxagent.exception import OSUtilError, CryptError import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.textutil as textutil -from azurelinuxagent.distro.default.osutil import DefaultOSUtil, OSUtilError +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.distro.default.osutil import DefaultOSUtil class Redhat6xOSUtil(DefaultOSUtil): def __init__(self): super(Redhat6xOSUtil, self).__init__() - self.sshd_conf_file_path = '/etc/ssh/sshd_config' - self.openssl_cmd = '/usr/bin/openssl' - self.conf_file_path = '/etc/waagent.conf' - self.selinux=None def start_network(self): return shellutil.run("/sbin/service networking start", chk_err=False) @@ -58,63 +57,14 @@ class Redhat6xOSUtil(DefaultOSUtil): def unregister_agent_service(self): return shellutil.run("chkconfig --del waagent", chk_err=False) - - def asn1_to_ssh_rsa(self, pubkey): - lines = pubkey.split("\n") - lines = [x for x in lines if not x.startswith("----")] - base64_encoded = "".join(lines) - try: - #TODO remove pyasn1 dependency - from pyasn1.codec.der import decoder as der_decoder - der_encoded = base64.b64decode(base64_encoded) - der_encoded = der_decoder.decode(der_encoded)[0][1] - key = der_decoder.decode(self.bits_to_bytes(der_encoded))[0] - n=key[0] - e=key[1] - keydata = bytearray() - keydata.extend(struct.pack('>I', len("ssh-rsa"))) - keydata.extend(b"ssh-rsa") - keydata.extend(struct.pack('>I', len(self.num_to_bytes(e)))) - keydata.extend(self.num_to_bytes(e)) - keydata.extend(struct.pack('>I', len(self.num_to_bytes(n)) + 1)) - keydata.extend(b"\0") - keydata.extend(self.num_to_bytes(n)) - keydata_base64 = base64.b64encode(bytebuffer(keydata)) - return text(b"ssh-rsa " + keydata_base64 + b"\n", - encoding='utf-8') - except ImportError as e: - raise OSUtilError("Failed to load pyasn1.codec.der") - - def num_to_bytes(self, num): - """ - Pack number into bytes. Retun as string. - """ - result = bytearray() - while num: - result.append(num & 0xFF) - num >>= 8 - result.reverse() - return result - - def bits_to_bytes(self, bits): - """ - Convert an array contains bits, [0,1] to a byte array - """ - index = 7 - byte_array = bytearray() - curr = 0 - for bit in bits: - curr = curr | (bit << index) - index = index - 1 - if index == -1: - byte_array.append(curr) - curr = 0 - index = 7 - return bytes(byte_array) - + def openssl_to_openssh(self, input_file, output_file): pubkey = fileutil.read_file(input_file) - ssh_rsa_pubkey = self.asn1_to_ssh_rsa(pubkey) + try: + cryptutil = CryptUtil(conf.get_openssl_cmd()) + ssh_rsa_pubkey = cryptutil.asn1_to_ssh(pubkey) + except CryptError as e: + raise OSUtilError(ustr(e)) fileutil.write_file(output_file, ssh_rsa_pubkey) #Override @@ -134,8 +84,7 @@ class Redhat6xOSUtil(DefaultOSUtil): def set_dhcp_hostname(self, hostname): ifname = self.get_if_name() filepath = "/etc/sysconfig/network-scripts/ifcfg-{0}".format(ifname) - fileutil.update_conf_file(filepath, - 'DHCP_HOSTNAME', + fileutil.update_conf_file(filepath, 'DHCP_HOSTNAME', 'DHCP_HOSTNAME={0}'.format(hostname)) class RedhatOSUtil(Redhat6xOSUtil): @@ -162,4 +111,5 @@ class RedhatOSUtil(Redhat6xOSUtil): def unregister_agent_service(self): return shellutil.run("systemctl disable waagent", chk_err=False) - + def openssl_to_openssh(self, input_file, output_file): + DefaultOSUtil.openssl_to_openssh(self, input_file, output_file) diff --git a/azurelinuxagent/distro/suse/distro.py b/azurelinuxagent/distro/suse/distro.py new file mode 100644 index 0000000..5b39369 --- /dev/null +++ b/azurelinuxagent/distro/suse/distro.py @@ -0,0 +1,32 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.suse.osutil import SUSE11OSUtil, SUSEOSUtil + +class SUSE11Distro(DefaultDistro): + def __init__(self): + super(SUSE11Distro, self).__init__() + self.osutil = SUSE11OSUtil() + +class SUSEDistro(DefaultDistro): + def __init__(self): + super(SUSEDistro, self).__init__() + self.osutil = SUSEOSUtil() + diff --git a/azurelinuxagent/distro/suse/loader.py b/azurelinuxagent/distro/suse/loader.py deleted file mode 100644 index b01384b..0000000 --- a/azurelinuxagent/distro/suse/loader.py +++ /dev/null @@ -1,29 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME - -def get_osutil(): - from azurelinuxagent.distro.suse.osutil import SUSE11OSUtil, SUSEOSUtil - if DISTRO_FULL_NAME=='SUSE Linux Enterprise Server' and DISTRO_VERSION < '12' \ - or DISTRO_FULL_NAME == 'openSUSE' and DISTRO_VERSION < '13.2': - return SUSE11OSUtil() - else: - return SUSEOSUtil() - diff --git a/azurelinuxagent/distro/ubuntu/deprovision.py b/azurelinuxagent/distro/ubuntu/deprovision.py index 0c3c4e5..da6e834 100644 --- a/azurelinuxagent/distro/ubuntu/deprovision.py +++ b/azurelinuxagent/distro/ubuntu/deprovision.py @@ -33,6 +33,9 @@ def del_resolv(): class UbuntuDeprovisionHandler(DeprovisionHandler): + def __init__(self, distro): + super(UbuntuDeprovisionHandler, self).__init__(distro) + def setup(self, deluser): warnings, actions = super(UbuntuDeprovisionHandler, self).setup(deluser) warnings.append("WARNING! Nameserver configuration in " diff --git a/azurelinuxagent/distro/ubuntu/distro.py b/azurelinuxagent/distro/ubuntu/distro.py new file mode 100644 index 0000000..f380f6c --- /dev/null +++ b/azurelinuxagent/distro/ubuntu/distro.py @@ -0,0 +1,55 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.ubuntu.osutil import Ubuntu14OSUtil, \ + Ubuntu12OSUtil, \ + UbuntuOSUtil, \ + UbuntuSnappyOSUtil + +from azurelinuxagent.distro.ubuntu.provision import UbuntuProvisionHandler +from azurelinuxagent.distro.ubuntu.deprovision import UbuntuDeprovisionHandler + +class UbuntuDistro(DefaultDistro): + def __init__(self): + super(UbuntuDistro, self).__init__() + self.osutil = UbuntuOSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) + +class Ubuntu12Distro(DefaultDistro): + def __init__(self): + super(Ubuntu12Distro, self).__init__() + self.osutil = Ubuntu12OSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) + +class Ubuntu14Distro(DefaultDistro): + def __init__(self): + super(Ubuntu14Distro, self).__init__() + self.osutil = Ubuntu14OSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) + +class UbuntuSnappyDistro(DefaultDistro): + def __init__(self): + super(UbuntuSnappyDistro, self).__init__() + self.osutil = UbuntuSnappyOSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) diff --git a/azurelinuxagent/distro/ubuntu/handlerFactory.py b/azurelinuxagent/distro/ubuntu/handlerFactory.py deleted file mode 100644 index 11f7f04..0000000 --- a/azurelinuxagent/distro/ubuntu/handlerFactory.py +++ /dev/null @@ -1,29 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.distro.ubuntu.provision import UbuntuProvisionHandler -from azurelinuxagent.distro.ubuntu.deprovision import UbuntuDeprovisionHandler -from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory - -class UbuntuHandlerFactory(DefaultHandlerFactory): - def __init__(self): - super(UbuntuHandlerFactory, self).__init__() - self.provision_handler = UbuntuProvisionHandler() - self.deprovision_handler = UbuntuDeprovisionHandler() - diff --git a/azurelinuxagent/distro/ubuntu/loader.py b/azurelinuxagent/distro/ubuntu/loader.py deleted file mode 100644 index 3fe2239..0000000 --- a/azurelinuxagent/distro/ubuntu/loader.py +++ /dev/null @@ -1,40 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME - -def get_osutil(): - from azurelinuxagent.distro.ubuntu.osutil import Ubuntu1204OSUtil, \ - UbuntuOSUtil, \ - Ubuntu14xOSUtil, \ - UbuntuSnappyOSUtil - - if DISTRO_VERSION == "12.04": - return Ubuntu1204OSUtil() - elif DISTRO_VERSION == "14.04" or DISTRO_VERSION == "14.10": - return Ubuntu14xOSUtil() - elif DISTRO_FULL_NAME == "Snappy Ubuntu Core": - return UbuntuSnappyOSUtil() - else: - return UbuntuOSUtil() - -def get_handlers(): - from azurelinuxagent.distro.ubuntu.handlerFactory import UbuntuHandlerFactory - return UbuntuHandlerFactory() - diff --git a/azurelinuxagent/distro/ubuntu/osutil.py b/azurelinuxagent/distro/ubuntu/osutil.py index adf7660..cc4b8ef 100644 --- a/azurelinuxagent/distro/ubuntu/osutil.py +++ b/azurelinuxagent/distro/ubuntu/osutil.py @@ -31,9 +31,9 @@ import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.textutil as textutil from azurelinuxagent.distro.default.osutil import DefaultOSUtil -class Ubuntu14xOSUtil(DefaultOSUtil): +class Ubuntu14OSUtil(DefaultOSUtil): def __init__(self): - super(Ubuntu14xOSUtil, self).__init__() + super(Ubuntu14OSUtil, self).__init__() def start_network(self): return shellutil.run("service networking start", chk_err=False) @@ -44,16 +44,16 @@ class Ubuntu14xOSUtil(DefaultOSUtil): def start_agent_service(self): return shellutil.run("service walinuxagent start", chk_err=False) -class Ubuntu1204OSUtil(Ubuntu14xOSUtil): +class Ubuntu12OSUtil(Ubuntu14OSUtil): def __init__(self): - super(Ubuntu1204OSUtil, self).__init__() + super(Ubuntu12OSUtil, self).__init__() #Override def get_dhcp_pid(self): ret= shellutil.run_get_output("pidof dhclient3") return ret[1] if ret[0] == 0 else None -class UbuntuOSUtil(Ubuntu14xOSUtil): +class UbuntuOSUtil(Ubuntu14OSUtil): def __init__(self): super(UbuntuOSUtil, self).__init__() @@ -63,7 +63,7 @@ class UbuntuOSUtil(Ubuntu14xOSUtil): def unregister_agent_service(self): return shellutil.run("systemctl mask walinuxagent", chk_err=False) -class UbuntuSnappyOSUtil(Ubuntu14xOSUtil): +class UbuntuSnappyOSUtil(Ubuntu14OSUtil): def __init__(self): super(UbuntuSnappyOSUtil, self).__init__() self.conf_file_path = '/apps/walinuxagent/current/waagent.conf' diff --git a/azurelinuxagent/distro/ubuntu/provision.py b/azurelinuxagent/distro/ubuntu/provision.py index a68fe4d..330e057 100644 --- a/azurelinuxagent/distro/ubuntu/provision.py +++ b/azurelinuxagent/distro/ubuntu/provision.py @@ -20,12 +20,11 @@ import os import time import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.conf as conf -import azurelinuxagent.protocol as prot +import azurelinuxagent.protocol.ovfenv as ovfenv from azurelinuxagent.event import add_event, WALAEventOperation -from azurelinuxagent.exception import * -from azurelinuxagent.utils.osutil import OSUTIL +from azurelinuxagent.exception import ProvisionError, ProtocolError import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.distro.default.provision import ProvisionHandler @@ -34,49 +33,61 @@ from azurelinuxagent.distro.default.provision import ProvisionHandler On ubuntu image, provision could be disabled. """ class UbuntuProvisionHandler(ProvisionHandler): - def process(self): + def __init__(self, distro): + self.distro = distro + + def run(self): #If provision is enabled, run default provision handler - if conf.get_switch("Provisioning.Enabled", False): - super(UbuntuProvisionHandler, self).process() + if conf.get_provision_enabled(): + super(UbuntuProvisionHandler, self).run() return logger.info("run Ubuntu provision handler") - provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned") + provisioned = os.path.join(conf.get_lib_dir(), "provisioned") if os.path.isfile(provisioned): return - logger.info("Waiting cloud-init to finish provisioning.") - protocol = prot.FACTORY.get_default_protocol() + logger.info("Waiting cloud-init to copy ovf-env.xml.") + self.wait_for_ovfenv() + + protocol = self.distro.protocol_util.detect_protocol() + self.report_not_ready("Provisioning", "Starting") + logger.info("Sleep 15 seconds to prevent throttling") + time.sleep(15) #Sleep to prevent throttling try: logger.info("Wait for ssh host key to be generated.") thumbprint = self.wait_for_ssh_host_key() fileutil.write_file(provisioned, "") - logger.info("Finished provisioning") - status = prot.ProvisionStatus(status="Ready") - status.properties.certificateThumbprint = thumbprint - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) - + except ProvisionError as e: logger.error("Provision failed: {0}", e) - status = prot.ProvisionStatus(status="NotReady", - subStatus="ProvisioningFailed", - description= text(e)) - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) + self.report_not_ready("ProvisioningFailed", ustr(e)) + self.report_event(ustr(e)) + return + + self.report_ready(thumbprint) + self.report_event("Provision succeed", is_success=True) - add_event(name="WALA", is_success=False, message=text(e), - op=WALAEventOperation.Provision) + def wait_for_ovfenv(self, max_retry=60): + """ + Wait for cloud-init to copy ovf-env.xml file from provision ISO + """ + for retry in range(0, max_retry): + try: + self.distro.protocol_util.get_ovf_env() + return + except ProtocolError: + if retry < max_retry - 1: + logger.info("Wait for cloud-init to copy ovf-env.xml") + time.sleep(5) + raise ProvisionError("ovf-env.xml is not copied") def wait_for_ssh_host_key(self, max_retry=60): - kepair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa") + """ + Wait for cloud-init to generate ssh host key + """ + kepair_type = conf.get_ssh_host_keypair_type() path = '/etc/ssh/ssh_host_{0}_key'.format(kepair_type) for retry in range(0, max_retry): if os.path.isfile(path): diff --git a/azurelinuxagent/event.py b/azurelinuxagent/event.py index 02e8017..f38b242 100644 --- a/azurelinuxagent/event.py +++ b/azurelinuxagent/event.py @@ -25,14 +25,15 @@ import datetime import threading import platform import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.protocol as prot +from azurelinuxagent.exception import EventError, ProtocolError +from azurelinuxagent.future import ustr +from azurelinuxagent.protocol.restapi import TelemetryEventParam, \ + TelemetryEventList, \ + TelemetryEvent, \ + set_properties, get_properties from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, \ DISTRO_CODE_NAME, AGENT_VERSION -from azurelinuxagent.utils.osutil import OSUTIL -class EventError(Exception): - pass class WALAEventOperation: HeartBeat="HeartBeat" @@ -47,132 +48,65 @@ class WALAEventOperation: ActivateResourceDisk="ActivateResourceDisk" UnhandledError="UnhandledError" -class EventMonitor(object): +class EventLogger(object): def __init__(self): - self.sysinfo = [] - self.event_dir = os.path.join(OSUTIL.get_lib_dir(), "events") + self.event_dir = None - def init_sysinfo(self): - osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(), - DISTRO_NAME, - DISTRO_VERSION, - DISTRO_CODE_NAME, - platform.release()) + def save_event(self, data): + if self.event_dir is None: + logger.warn("Event reporter is not initialized.") + return - self.sysinfo.append(prot.TelemetryEventParam("OSVersion", osversion)) - self.sysinfo.append(prot.TelemetryEventParam("GAVersion", - AGENT_VERSION)) - self.sysinfo.append(prot.TelemetryEventParam("RAM", - OSUTIL.get_total_mem())) - self.sysinfo.append(prot.TelemetryEventParam("Processors", - OSUTIL.get_processor_cores())) - try: - protocol = prot.FACTORY.get_default_protocol() - vminfo = protocol.get_vminfo() - self.sysinfo.append(prot.TelemetryEventParam("VMName", - vminfo.vmName)) - #TODO add other system info like, subscription id, etc. - except prot.ProtocolError as e: - logger.warn("Failed to get vm info: {0}", e) - - def start(self): - event_thread = threading.Thread(target = self.run) - event_thread.setDaemon(True) - event_thread.start() + if not os.path.exists(self.event_dir): + os.mkdir(self.event_dir) + os.chmod(self.event_dir, 0o700) + if len(os.listdir(self.event_dir)) > 1000: + raise EventError("Too many files under: {0}".format(self.event_dir)) - def collect_event(self, evt_file_name): + filename = os.path.join(self.event_dir, ustr(int(time.time()*1000000))) try: - logger.verb("Found event file: {0}", evt_file_name) - with open(evt_file_name, "rb") as evt_file: - #if fail to open or delete the file, throw exception - json_str = evt_file.read().decode("utf-8",'ignore') - logger.verb("Processed event file: {0}", evt_file_name) - os.remove(evt_file_name) - return json_str + with open(filename+".tmp",'wb+') as hfile: + hfile.write(data.encode("utf-8")) + os.rename(filename+".tmp", filename+".tld") except IOError as e: - msg = "Failed to process {0}, {1}".format(evt_file_name, e) - raise EventError(msg) - - def collect_and_send_events(self): - event_list = prot.TelemetryEventList() - event_files = os.listdir(self.event_dir) - for event_file in event_files: - if not event_file.endswith(".tld"): - continue - event_file_path = os.path.join(self.event_dir, event_file) - try: - data_str = self.collect_event(event_file_path) - except EventError as e: - logger.error("{0}", e) - continue - try: - data = json.loads(data_str) - except ValueError as e: - logger.verb(data_str) - logger.error("Failed to decode json event file: {0}", e) - continue - - event = prot.TelemetryEvent() - prot.set_properties("event", event, data) - event.parameters.extend(self.sysinfo) - event_list.events.append(event) - if len(event_list.events) == 0: - return - + raise EventError("Failed to write events to file:{0}", e) + + def add_event(self, name, op="", is_success=True, duration=0, version="1.0", + message="", evt_type="", is_internal=False): + event = TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") + event.parameters.append(TelemetryEventParam('Name', name)) + event.parameters.append(TelemetryEventParam('Version', version)) + event.parameters.append(TelemetryEventParam('IsInternal', is_internal)) + event.parameters.append(TelemetryEventParam('Operation', op)) + event.parameters.append(TelemetryEventParam('OperationSuccess', + is_success)) + event.parameters.append(TelemetryEventParam('Message', message)) + event.parameters.append(TelemetryEventParam('Duration', duration)) + event.parameters.append(TelemetryEventParam('ExtensionType', evt_type)) + + data = get_properties(event) try: - protocol = prot.FACTORY.get_default_protocol() - protocol.report_event(event_list) - except prot.ProtocolError as e: + self.save_event(json.dumps(data)) + except EventError as e: logger.error("{0}", e) - def run(self): - self.init_sysinfo() - last_heartbeat = datetime.datetime.min - period = datetime.timedelta(hours = 12) - while(True): - if (datetime.datetime.now()-last_heartbeat) > period: - last_heartbeat = datetime.datetime.now() - add_event(op=WALAEventOperation.HeartBeat, - name="WALA",is_success=True) - self.collect_and_send_events() - time.sleep(60) - -def save_event(data): - event_dir = os.path.join(OSUTIL.get_lib_dir(), 'events') - if not os.path.exists(event_dir): - os.mkdir(event_dir) - os.chmod(event_dir,0o700) - if len(os.listdir(event_dir)) > 1000: - raise EventError("Too many files under: {0}", event_dir) - - filename = os.path.join(event_dir, text(int(time.time()*1000000))) - try: - with open(filename+".tmp",'wb+') as hfile: - hfile.write(data.encode("utf-8")) - os.rename(filename+".tmp", filename+".tld") - except IOError as e: - raise EventError("Failed to write events to file:{0}", e) +__event_logger__ = EventLogger() def add_event(name, op="", is_success=True, duration=0, version="1.0", - message="", evt_type="", is_internal=False): + message="", evt_type="", is_internal=False, + reporter=__event_logger__): log = logger.info if is_success else logger.error log("Event: name={0}, op={1}, message={2}", name, op, message) - event = prot.TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") - event.parameters.append(prot.TelemetryEventParam('Name', name)) - event.parameters.append(prot.TelemetryEventParam('Version', version)) - event.parameters.append(prot.TelemetryEventParam('IsInternal', is_internal)) - event.parameters.append(prot.TelemetryEventParam('Operation', op)) - event.parameters.append(prot.TelemetryEventParam('OperationSuccess', - is_success)) - event.parameters.append(prot.TelemetryEventParam('Message', message)) - event.parameters.append(prot.TelemetryEventParam('Duration', duration)) - event.parameters.append(prot.TelemetryEventParam('ExtensionType', evt_type)) - data = prot.get_properties(event) - try: - save_event(json.dumps(data)) - except EventError as e: - logger.error("{0}", e) + if reporter.event_dir is None: + logger.warn("Event reporter is not initialized.") + return + reporter.add_event(name, op=op, is_success=is_success, duration=duration, + version=version, message=message, evt_type=evt_type, + is_internal=is_internal) + +def init_event_logger(event_dir, reporter=__event_logger__): + reporter.event_dir = event_dir def dump_unhandled_err(name): if hasattr(sys, 'last_type') and hasattr(sys, 'last_value') and \ @@ -184,8 +118,7 @@ def dump_unhandled_err(name): last_traceback) message= "".join(error) add_event(name, is_success=False, message=message, - op=WALAEventOperation.UnhandledError) + op=WALAEventOperation.UnhandledError) def enable_unhandled_err_dump(name): atexit.register(dump_unhandled_err, name) - diff --git a/azurelinuxagent/exception.py b/azurelinuxagent/exception.py index d7d9b0a..7fa5cff 100644 --- a/azurelinuxagent/exception.py +++ b/azurelinuxagent/exception.py @@ -24,42 +24,92 @@ class AgentError(Exception): """ Base class of agent error. """ - def __init__(self, errno, msg): - msg = "({0}){1}".format(errno, msg) + def __init__(self, errno, msg, inner=None): + msg = u"({0}){1}".format(errno, msg) + if inner is not None: + msg = u"{0} \n inner error: {1}".format(msg, inner) super(AgentError, self).__init__(msg) class AgentConfigError(AgentError): """ When configure file is not found or malformed. """ - def __init__(self, msg): - super(AgentConfigError, self).__init__('000001', msg) + def __init__(self, msg=None, inner=None): + super(AgentConfigError, self).__init__('000001', msg, inner) class AgentNetworkError(AgentError): """ When network is not avaiable. """ - def __init__(self, msg): - super(AgentNetworkError, self).__init__('000002', msg) + def __init__(self, msg=None, inner=None): + super(AgentNetworkError, self).__init__('000002', msg, inner) class ExtensionError(AgentError): """ When failed to execute an extension """ - def __init__(self, msg): - super(ExtensionError, self).__init__('000003', msg) + def __init__(self, msg=None, inner=None): + super(ExtensionError, self).__init__('000003', msg, inner) class ProvisionError(AgentError): """ When provision failed """ - def __init__(self, msg): - super(ProvisionError, self).__init__('000004', msg) + def __init__(self, msg=None, inner=None): + super(ProvisionError, self).__init__('000004', msg, inner) class ResourceDiskError(AgentError): """ Mount resource disk failed """ - def __init__(self, msg): - super(ResourceDiskError, self).__init__('000005', msg) + def __init__(self, msg=None, inner=None): + super(ResourceDiskError, self).__init__('000005', msg, inner) +class DhcpError(AgentError): + """ + Failed to handle dhcp response + """ + def __init__(self, msg=None, inner=None): + super(DhcpError, self).__init__('000006', msg, inner) + +class OSUtilError(AgentError): + """ + Failed to perform operation to OS configuration + """ + def __init__(self, msg=None, inner=None): + super(OSUtilError, self).__init__('000007', msg, inner) + +class ProtocolError(AgentError): + """ + Azure protocol error + """ + def __init__(self, msg=None, inner=None): + super(ProtocolError, self).__init__('000008', msg, inner) + +class ProtocolNotFoundError(ProtocolError): + """ + Azure protocol endpoint not found + """ + def __init__(self, msg=None, inner=None): + super(ProtocolNotFoundError, self).__init__(msg, inner) + +class HttpError(AgentError): + """ + Http request failure + """ + def __init__(self, msg=None, inner=None): + super(HttpError, self).__init__('000009', msg, inner) + +class EventError(AgentError): + """ + Event reporting error + """ + def __init__(self, msg=None, inner=None): + super(EventError, self).__init__('000010', msg, inner) + +class CryptError(AgentError): + """ + Encrypt/Decrypt error + """ + def __init__(self, msg=None, inner=None): + super(CryptError, self).__init__('000011', msg, inner) diff --git a/azurelinuxagent/future.py b/azurelinuxagent/future.py index 8451345..8509732 100644 --- a/azurelinuxagent/future.py +++ b/azurelinuxagent/future.py @@ -7,15 +7,25 @@ Add alies for python2 and python3 libs and fucntions. if sys.version_info[0]== 3: import http.client as httpclient from urllib.parse import urlparse - text = str + + """Rename Python3 str to ustr""" + ustr = str + bytebuffer = memoryview + read_input = input + elif sys.version_info[0] == 2: import httplib as httpclient from urlparse import urlparse - text = unicode + + """Rename Python2 unicode to ustr""" + ustr = unicode + bytebuffer = buffer + read_input = raw_input + else: raise ImportError("Unknown python version:{0}".format(sys.version_info)) diff --git a/azurelinuxagent/handler.py b/azurelinuxagent/handler.py deleted file mode 100644 index 538ee30..0000000 --- a/azurelinuxagent/handler.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -""" -Handler handles different tasks like, provisioning, deprovisioning etc. -The handlers could be extended for different distros. The default -implementation is under azurelinuxagent.distros.default -""" -import azurelinuxagent.distro.loader as loader - -HANDLERS = loader.get_handlers() - diff --git a/azurelinuxagent/logger.py b/azurelinuxagent/logger.py index 21c02a6..6c6b406 100644 --- a/azurelinuxagent/logger.py +++ b/azurelinuxagent/logger.py @@ -22,7 +22,7 @@ Log utils """ import os import sys -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr from datetime import datetime class Logger(object): @@ -49,8 +49,8 @@ class Logger(object): def log(self, level, msg_format, *args): #if msg_format is not unicode convert it to unicode - if type(msg_format) is not text: - msg_format = text(msg_format, errors="backslashreplace") + if type(msg_format) is not ustr: + msg_format = ustr(msg_format, errors="backslashreplace") if len(args) > 0: msg = msg_format.format(*args) else: @@ -63,7 +63,7 @@ class Logger(object): else: log_item = u"{0} {1} {2}\n".format(time, level_str, msg) - log_item = text(log_item.encode('ascii', "backslashreplace"), + log_item = ustr(log_item.encode('ascii', "backslashreplace"), encoding="ascii") for appender in self.appenders: appender.write(level, log_item) diff --git a/azurelinuxagent/metadata.py b/azurelinuxagent/metadata.py index 5cf4902..34fdcf9 100644 --- a/azurelinuxagent/metadata.py +++ b/azurelinuxagent/metadata.py @@ -22,11 +22,11 @@ import re import platform import sys import azurelinuxagent.utils.fileutil as fileutil -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr def get_distro(): if 'FreeBSD' in platform.system(): - release = re.sub('\-.*\Z', '', text(platform.release())) + release = re.sub('\-.*\Z', '', ustr(platform.release())) osinfo = ['freebsd', release, '', 'freebsd'] if 'linux_distribution' in dir(platform): osinfo = list(platform.linux_distribution(full_distribution_name=0)) @@ -47,7 +47,7 @@ def get_distro(): AGENT_NAME = "WALinuxAgent" AGENT_LONG_NAME = "Azure Linux Agent" -AGENT_VERSION = '2.1.2' +AGENT_VERSION = '2.1.3' AGENT_LONG_VERSION = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION) AGENT_DESCRIPTION = """\ The Azure Linux Agent supports the provisioning and running of Linux diff --git a/azurelinuxagent/protocol/__init__.py b/azurelinuxagent/protocol/__init__.py index a4572e6..8c1bbdb 100644 --- a/azurelinuxagent/protocol/__init__.py +++ b/azurelinuxagent/protocol/__init__.py @@ -16,8 +16,3 @@ # # Requires Python 2.4+ and Openssl 1.0+ # - -from azurelinuxagent.protocol.common import * -from azurelinuxagent.protocol.protocolFactory import FACTORY, \ - detect_default_protocol - diff --git a/azurelinuxagent/protocol/common.py b/azurelinuxagent/protocol/common.py deleted file mode 100644 index 367794f..0000000 --- a/azurelinuxagent/protocol/common.py +++ /dev/null @@ -1,240 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# -import os -import copy -import re -import json -import xml.dom.minidom -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.utils.fileutil as fileutil - -class ProtocolError(Exception): - pass - -class ProtocolNotFound(Exception): - pass - -def validata_param(name, val, expected_type): - if val is None: - raise ProtocolError("{0} is None".format(name)) - if not isinstance(val, expected_type): - raise ProtocolError(("{0} type should be {1} not {2}" - "").format(name, expected_type, type(val))) - -def set_properties(name, obj, data): - if isinstance(obj, DataContract): - validata_param("Property '{0}'".format(name), data, dict) - for prob_name, prob_val in data.items(): - prob_full_name = "{0}.{1}".format(name, prob_name) - try: - prob = getattr(obj, prob_name) - except AttributeError: - logger.warn("Unknown property: {0}", prob_full_name) - continue - prob = set_properties(prob_full_name, prob, prob_val) - setattr(obj, prob_name, prob) - return obj - elif isinstance(obj, DataContractList): - validata_param("List '{0}'".format(name), data, list) - for item_data in data: - item = obj.item_cls() - item = set_properties(name, item, item_data) - obj.append(item) - return obj - else: - return data - -def get_properties(obj): - if isinstance(obj, DataContract): - data = {} - props = vars(obj) - for prob_name, prob in list(props.items()): - data[prob_name] = get_properties(prob) - return data - elif isinstance(obj, DataContractList): - data = [] - for item in obj: - item_data = get_properties(item) - data.append(item_data) - return data - else: - return obj - -class DataContract(object): - pass - -class DataContractList(list): - def __init__(self, item_cls): - self.item_cls = item_cls - -""" -Data contract between guest and host -""" -class VMInfo(DataContract): - def __init__(self, subscriptionId=None, vmName=None): - self.subscriptionId = subscriptionId - self.vmName = vmName - -class Cert(DataContract): - def __init__(self, name=None, thumbprint=None, certificateDataUri=None): - self.name = name - self.thumbprint = thumbprint - self.certificateDataUri = certificateDataUri - -class CertList(DataContract): - def __init__(self): - self.certificates = DataContractList(Cert) - -class Extension(DataContract): - def __init__(self, name=None, sequenceNumber=None, publicSettings=None, - privateSettings=None, certificateThumbprint=None): - self.name = name - self.sequenceNumber = sequenceNumber - self.publicSettings = publicSettings - self.privateSettings = privateSettings - self.certificateThumbprint = certificateThumbprint - -class ExtHandlerProperties(DataContract): - def __init__(self): - self.version = None - self.upgradePolicy = None - self.state = None - self.extensions = DataContractList(Extension) - -class ExtHandlerVersionUri(DataContract): - def __init__(self): - self.uri = None - -class ExtHandler(DataContract): - def __init__(self, name=None): - self.name = name - self.properties = ExtHandlerProperties() - self.versionUris = DataContractList(ExtHandlerVersionUri) - -class ExtHandlerList(DataContract): - def __init__(self): - self.extHandlers = DataContractList(ExtHandler) - -class ExtHandlerPackageUri(DataContract): - def __init__(self, uri=None): - self.uri = uri - -class ExtHandlerPackage(DataContract): - def __init__(self, version = None): - self.version = version - self.uris = DataContractList(ExtHandlerPackageUri) - -class ExtHandlerPackageList(DataContract): - def __init__(self): - self.versions = DataContractList(ExtHandlerPackage) - -class VMProperties(DataContract): - def __init__(self, certificateThumbprint=None): - #TODO need to confirm the property name - self.certificateThumbprint = certificateThumbprint - -class ProvisionStatus(DataContract): - def __init__(self, status=None, subStatus=None, description=None): - self.status = status - self.subStatus = subStatus - self.description = description - self.properties = VMProperties() - -class ExtensionSubStatus(DataContract): - def __init__(self, name=None, status=None, code=None, message=None): - self.name = name - self.status = status - self.code = code - self.message = message - -class ExtensionStatus(DataContract): - def __init__(self, configurationAppliedTime=None, operation=None, - status=None, seq_no=None, code=None, message=None): - self.configurationAppliedTime = configurationAppliedTime - self.operation = operation - self.status = status - self.sequenceNumber = seq_no - self.code = code - self.message = message - self.substatusList = DataContractList(ExtensionSubStatus) - -class ExtHandlerStatus(DataContract): - def __init__(self, name=None, version=None, status=None, message=None): - self.name = name - self.version = version - self.status = status - self.message = message - self.extensions = DataContractList(text) - -class VMAgentStatus(DataContract): - def __init__(self, version=None, status=None, message=None): - self.version = version - self.status = status - self.message = message - self.extensionHandlers = DataContractList(ExtHandlerStatus) - -class VMStatus(DataContract): - def __init__(self): - self.vmAgent = VMAgentStatus() - -class TelemetryEventParam(DataContract): - def __init__(self, name=None, value=None): - self.name = name - self.value = value - -class TelemetryEvent(DataContract): - def __init__(self, eventId=None, providerId=None): - self.eventId = eventId - self.providerId = providerId - self.parameters = DataContractList(TelemetryEventParam) - -class TelemetryEventList(DataContract): - def __init__(self): - self.events = DataContractList(TelemetryEvent) - -class Protocol(DataContract): - - def initialize(self): - raise NotImplementedError() - - def get_vminfo(self): - raise NotImplementedError() - - def get_certs(self): - raise NotImplementedError() - - def get_ext_handlers(self): - raise NotImplementedError() - - def get_ext_handler_pkgs(self, extension): - raise NotImplementedError() - - def report_provision_status(self, provision_status): - raise NotImplementedError() - - def report_vm_status(self, vm_status): - raise NotImplementedError() - - def report_ext_status(self, ext_handler_name, ext_name, ext_status): - raise NotImplementedError() - - def report_event(self, event): - raise NotImplementedError() - diff --git a/azurelinuxagent/protocol/metadata.py b/azurelinuxagent/protocol/metadata.py new file mode 100644 index 0000000..8a1656f --- /dev/null +++ b/azurelinuxagent/protocol/metadata.py @@ -0,0 +1,195 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import json +import shutil +import os +import time +from azurelinuxagent.exception import ProtocolError, HttpError +from azurelinuxagent.future import httpclient, ustr +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +import azurelinuxagent.utils.restutil as restutil +import azurelinuxagent.utils.textutil as textutil +import azurelinuxagent.utils.fileutil as fileutil +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.protocol.restapi import * + +METADATA_ENDPOINT='169.254.169.254' +APIVERSION='2015-05-01-preview' +BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" + +TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem" +TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem" + +#TODO remote workarround for azure stack +MAX_PING = 30 +RETRY_PING_INTERVAL = 10 + +def _add_content_type(headers): + if headers is None: + headers = {} + headers["content-type"] = "application/json" + return headers + +class MetadataProtocol(Protocol): + + def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT): + self.apiversion = apiversion + self.endpoint = endpoint + self.identity_uri = BASE_URI.format(self.endpoint, "identity", + self.apiversion, "&$expand=*") + self.cert_uri = BASE_URI.format(self.endpoint, "certificates", + self.apiversion, "&$expand=*") + self.ext_uri = BASE_URI.format(self.endpoint, "extensionHandlers", + self.apiversion, "&$expand=*") + self.provision_status_uri = BASE_URI.format(self.endpoint, + "provisioningStatus", + self.apiversion, "") + self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent", + self.apiversion, "") + self.ext_status_uri = BASE_URI.format(self.endpoint, + "status/extensions/{0}", + self.apiversion, "") + self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry", + self.apiversion, "") + + def _get_data(self, url, headers=None): + try: + resp = restutil.http_get(url, headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if resp.status != httpclient.OK: + raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) + + data = resp.read() + etag = resp.getheader('ETag') + if data is None: + return None + data = json.loads(ustr(data, encoding="utf-8")) + return data, etag + + def _put_data(self, url, data, headers=None): + headers = _add_content_type(headers) + try: + resp = restutil.http_put(url, json.dumps(data), headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + if resp.status != httpclient.OK: + raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) + + def _post_data(self, url, data, headers=None): + headers = _add_content_type(headers) + try: + resp = restutil.http_post(url, json.dumps(data), headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + if resp.status != httpclient.CREATED: + raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) + + def _get_trans_cert(self): + trans_crt_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + if not os.path.isfile(trans_crt_file): + raise ProtocolError("{0} is missing.".format(trans_crt_file)) + content = fileutil.read_file(trans_crt_file) + return textutil.get_bytes_from_pem(content) + + def detect(self): + self.get_vminfo() + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + #"Install" the cert and private key to /var/lib/waagent + thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file) + prv_file = os.path.join(conf.get_lib_dir(), + "{0}.prv".format(thumbprint)) + crt_file = os.path.join(conf.get_lib_dir(), + "{0}.crt".format(thumbprint)) + shutil.copyfile(trans_prv_file, prv_file) + shutil.copyfile(trans_cert_file, crt_file) + + + def get_vminfo(self): + vminfo = VMInfo() + data, etag = self._get_data(self.identity_uri) + set_properties("vminfo", vminfo, data) + return vminfo + + def get_certs(self): + #TODO download and save certs + return CertList() + + def get_ext_handlers(self): + headers = { + "x-ms-vmagent-public-x509-cert": self._get_trans_cert() + } + ext_list = ExtHandlerList() + data, etag = self._get_data(self.ext_uri, headers=headers) + set_properties("extensionHandlers", ext_list.extHandlers, data) + return ext_list, etag + + def get_ext_handler_pkgs(self, ext_handler): + ext_handler_pkgs = ExtHandlerPackageList() + data = None + for version_uri in ext_handler.versionUris: + try: + data, etag = self._get_data(version_uri.uri) + break + except ProtocolError as e: + logger.warn("Failed to get version uris: {0}", e) + logger.info("Retry getting version uris") + set_properties("extensionPackages", ext_handler_pkgs, data) + return ext_handler_pkgs + + def report_provision_status(self, provision_status): + validata_param('provisionStatus', provision_status, ProvisionStatus) + data = get_properties(provision_status) + self._put_data(self.provision_status_uri, data) + + def report_vm_status(self, vm_status): + validata_param('vmStatus', vm_status, VMStatus) + data = get_properties(vm_status) + #TODO code field is not implemented for metadata protocol yet. Remove it + handler_statuses = data['vmAgent']['extensionHandlers'] + for handler_status in handler_statuses: + try: + handler_status.pop('code', None) + except KeyError: + pass + + self._put_data(self.vm_status_uri, data) + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validata_param('extensionStatus', ext_status, ExtensionStatus) + data = get_properties(ext_status) + uri = self.ext_status_uri.format(ext_name) + self._put_data(uri, data) + + def report_event(self, events): + #TODO disable telemetry for azure stack test + #validata_param('events', events, TelemetryEventList) + #data = get_properties(events) + #self._post_data(self.event_uri, data) + pass + diff --git a/azurelinuxagent/protocol/ovfenv.py b/azurelinuxagent/protocol/ovfenv.py index 9c845ee..de6791c 100644 --- a/azurelinuxagent/protocol/ovfenv.py +++ b/azurelinuxagent/protocol/ovfenv.py @@ -17,60 +17,22 @@ # Requires Python 2.4+ and Openssl 1.0+ # """ -Copy and parse ovf-env.xml from provisiong ISO and local cache +Copy and parse ovf-env.xml from provisioning ISO and local cache """ import os import re +import shutil import xml.dom.minidom as minidom import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.exception import ProtocolError +from azurelinuxagent.future import ustr import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext -from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError -from azurelinuxagent.protocol import ProtocolError -OVF_FILE_NAME = "ovf-env.xml" OVF_VERSION = "1.0" OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1" WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure" -def get_ovf_env(): - """ - Load saved ovf-env.xml - """ - ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME) - if os.path.isfile(ovf_file_path): - xml_text = fileutil.read_file(ovf_file_path) - return OvfEnv(xml_text) - else: - raise ProtocolError("ovf-env.xml is missing.") - -def copy_ovf_env(): - """ - Copy ovf env file from dvd to hard disk. - Remove password before save it to the disk - """ - try: - OSUTIL.mount_dvd() - ovf_file_path_on_dvd = OSUTIL.get_ovf_env_file_path_on_dvd() - ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) - ovfenv = OvfEnv(ovfxml) - ovfxml = re.sub(".*?<", "*<", ovfxml) - ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME) - fileutil.write_file(ovf_file_path, ovfxml) - except IOError as e: - raise ProtocolError(text(e)) - except OSUtilError as e: - raise ProtocolError(text(e)) - - try: - OSUTIL.umount_dvd() - OSUTIL.eject_dvd() - except OSUtilError as e: - logger.warn(text(e)) - - return ovfenv - def _validate_ovf(val, msg): if val is None: raise ProtocolError("Failed to parse OVF XML: {0}".format(msg)) diff --git a/azurelinuxagent/protocol/protocolFactory.py b/azurelinuxagent/protocol/protocolFactory.py deleted file mode 100644 index 0bf6e52..0000000 --- a/azurelinuxagent/protocol/protocolFactory.py +++ /dev/null @@ -1,114 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# -import os -import traceback -import threading -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.utils.fileutil as fileutil -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.protocol.common import * -from azurelinuxagent.protocol.v1 import WireProtocol -from azurelinuxagent.protocol.v2 import MetadataProtocol - -WIRE_SERVER_ADDR_FILE_NAME = "WireServer" - -def get_wire_protocol_endpoint(): - path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME) - try: - endpoint = fileutil.read_file(path) - except IOError as e: - raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e)) - - if endpoint is None: - raise ProtocolNotFound("Wire server endpoint is None") - - return endpoint - -def detect_wire_protocol(): - endpoint = get_wire_protocol_endpoint() - - OSUTIL.gen_transport_cert() - protocol = WireProtocol(endpoint) - protocol.initialize() - logger.info("Protocol V1 found.") - return protocol - -def detect_metadata_protocol(): - protocol = MetadataProtocol() - protocol.initialize() - - logger.info("Protocol V2 found.") - return protocol - -def detect_available_protocols(prob_funcs=[detect_wire_protocol, - detect_metadata_protocol]): - available_protocols = [] - for probe_func in prob_funcs: - try: - protocol = probe_func() - available_protocols.append(protocol) - except ProtocolNotFound as e: - logger.info(text(e)) - return available_protocols - -def detect_default_protocol(): - logger.info("Detect default protocol.") - available_protocols = detect_available_protocols() - return choose_default_protocol(available_protocols) - -def choose_default_protocol(protocols): - if len(protocols) > 0: - return protocols[0] - else: - raise ProtocolNotFound("No available protocol detected.") - -def get_wire_protocol(): - endpoint = get_wire_protocol_endpoint() - return WireProtocol(endpoint) - -def get_metadata_protocol(): - return MetadataProtocol() - -def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]): - available_protocols = [] - for getter in getters: - try: - protocol = getter() - available_protocols.append(protocol) - except ProtocolNotFound as e: - logger.info(text(e)) - return available_protocols - -class ProtocolFactory(object): - def __init__(self): - self._protocol = None - self._lock = threading.Lock() - - def get_default_protocol(self): - if self._protocol is None: - self._lock.acquire() - if self._protocol is None: - available_protocols = get_available_protocols() - self._protocol = choose_default_protocol(available_protocols) - self._lock.release() - - return self._protocol - -FACTORY = ProtocolFactory() diff --git a/azurelinuxagent/protocol/restapi.py b/azurelinuxagent/protocol/restapi.py new file mode 100644 index 0000000..fbd29ed --- /dev/null +++ b/azurelinuxagent/protocol/restapi.py @@ -0,0 +1,250 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import copy +import re +import json +import xml.dom.minidom +import azurelinuxagent.logger as logger +from azurelinuxagent.exception import ProtocolError, HttpError +from azurelinuxagent.future import ustr +import azurelinuxagent.utils.restutil as restutil + +def validata_param(name, val, expected_type): + if val is None: + raise ProtocolError("{0} is None".format(name)) + if not isinstance(val, expected_type): + raise ProtocolError(("{0} type should be {1} not {2}" + "").format(name, expected_type, type(val))) + +def set_properties(name, obj, data): + if isinstance(obj, DataContract): + validata_param("Property '{0}'".format(name), data, dict) + for prob_name, prob_val in data.items(): + prob_full_name = "{0}.{1}".format(name, prob_name) + try: + prob = getattr(obj, prob_name) + except AttributeError: + logger.warn("Unknown property: {0}", prob_full_name) + continue + prob = set_properties(prob_full_name, prob, prob_val) + setattr(obj, prob_name, prob) + return obj + elif isinstance(obj, DataContractList): + validata_param("List '{0}'".format(name), data, list) + for item_data in data: + item = obj.item_cls() + item = set_properties(name, item, item_data) + obj.append(item) + return obj + else: + return data + +def get_properties(obj): + if isinstance(obj, DataContract): + data = {} + props = vars(obj) + for prob_name, prob in list(props.items()): + data[prob_name] = get_properties(prob) + return data + elif isinstance(obj, DataContractList): + data = [] + for item in obj: + item_data = get_properties(item) + data.append(item_data) + return data + else: + return obj + +class DataContract(object): + pass + +class DataContractList(list): + def __init__(self, item_cls): + self.item_cls = item_cls + +""" +Data contract between guest and host +""" +class VMInfo(DataContract): + def __init__(self, subscriptionId=None, vmName=None, containerId=None, + roleName=None, roleInstanceName=None, tenantName=None): + self.subscriptionId = subscriptionId + self.vmName = vmName + self.containerId = containerId + self.roleName = roleName + self.roleInstanceName = roleInstanceName + self.tenantName = tenantName + +class Cert(DataContract): + def __init__(self, name=None, thumbprint=None, certificateDataUri=None): + self.name = name + self.thumbprint = thumbprint + self.certificateDataUri = certificateDataUri + +class CertList(DataContract): + def __init__(self): + self.certificates = DataContractList(Cert) + +class Extension(DataContract): + def __init__(self, name=None, sequenceNumber=None, publicSettings=None, + protectedSettings=None, certificateThumbprint=None): + self.name = name + self.sequenceNumber = sequenceNumber + self.publicSettings = publicSettings + self.protectedSettings = protectedSettings + self.certificateThumbprint = certificateThumbprint + +class ExtHandlerProperties(DataContract): + def __init__(self): + self.version = None + self.upgradePolicy = None + self.state = None + self.extensions = DataContractList(Extension) + +class ExtHandlerVersionUri(DataContract): + def __init__(self): + self.uri = None + +class ExtHandler(DataContract): + def __init__(self, name=None): + self.name = name + self.properties = ExtHandlerProperties() + self.versionUris = DataContractList(ExtHandlerVersionUri) + +class ExtHandlerList(DataContract): + def __init__(self): + self.extHandlers = DataContractList(ExtHandler) + +class ExtHandlerPackageUri(DataContract): + def __init__(self, uri=None): + self.uri = uri + +class ExtHandlerPackage(DataContract): + def __init__(self, version = None): + self.version = version + self.uris = DataContractList(ExtHandlerPackageUri) + +class ExtHandlerPackageList(DataContract): + def __init__(self): + self.versions = DataContractList(ExtHandlerPackage) + +class VMProperties(DataContract): + def __init__(self, certificateThumbprint=None): + #TODO need to confirm the property name + self.certificateThumbprint = certificateThumbprint + +class ProvisionStatus(DataContract): + def __init__(self, status=None, subStatus=None, description=None): + self.status = status + self.subStatus = subStatus + self.description = description + self.properties = VMProperties() + +class ExtensionSubStatus(DataContract): + def __init__(self, name=None, status=None, code=None, message=None): + self.name = name + self.status = status + self.code = code + self.message = message + +class ExtensionStatus(DataContract): + def __init__(self, configurationAppliedTime=None, operation=None, + status=None, seq_no=None, code=None, message=None): + self.configurationAppliedTime = configurationAppliedTime + self.operation = operation + self.status = status + self.sequenceNumber = seq_no + self.code = code + self.message = message + self.substatusList = DataContractList(ExtensionSubStatus) + +class ExtHandlerStatus(DataContract): + def __init__(self, name=None, version=None, status=None, code=0, + message=None): + self.name = name + self.version = version + self.status = status + self.code = code + self.message = message + self.extensions = DataContractList(ustr) + +class VMAgentStatus(DataContract): + def __init__(self, version=None, status=None, message=None): + self.version = version + self.status = status + self.message = message + self.extensionHandlers = DataContractList(ExtHandlerStatus) + +class VMStatus(DataContract): + def __init__(self): + self.vmAgent = VMAgentStatus() + +class TelemetryEventParam(DataContract): + def __init__(self, name=None, value=None): + self.name = name + self.value = value + +class TelemetryEvent(DataContract): + def __init__(self, eventId=None, providerId=None): + self.eventId = eventId + self.providerId = providerId + self.parameters = DataContractList(TelemetryEventParam) + +class TelemetryEventList(DataContract): + def __init__(self): + self.events = DataContractList(TelemetryEvent) + +class Protocol(DataContract): + + def detect(self): + raise NotImplementedError() + + def get_vminfo(self): + raise NotImplementedError() + + def get_certs(self): + raise NotImplementedError() + + def get_ext_handlers(self): + raise NotImplementedError() + + def get_ext_handler_pkgs(self, extension): + raise NotImplementedError() + + def download_ext_handler_pkg(self, uri): + try: + resp = restutil.http_get(uri, chk_proxy=True) + if resp.status == restutil.httpclient.OK: + return resp.read() + except HttpError as e: + raise ProtocolError("Failed to download from: {0}".format(uri), e) + + def report_provision_status(self, provision_status): + raise NotImplementedError() + + def report_vm_status(self, vm_status): + raise NotImplementedError() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + raise NotImplementedError() + + def report_event(self, event): + raise NotImplementedError() + diff --git a/azurelinuxagent/protocol/v1.py b/azurelinuxagent/protocol/v1.py deleted file mode 100644 index 92fcc06..0000000 --- a/azurelinuxagent/protocol/v1.py +++ /dev/null @@ -1,1121 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ - -import os -import json -import re -import time -import traceback -import xml.sax.saxutils as saxutils -import xml.etree.ElementTree as ET -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text, httpclient, bytebuffer -import azurelinuxagent.utils.restutil as restutil -from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext, \ - getattrib, gettext, remove_bom -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.utils.fileutil as fileutil -import azurelinuxagent.utils.shellutil as shellutil -from azurelinuxagent.protocol.common import * - -VERSION_INFO_URI = "http://{0}/?comp=versions" -GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate" -HEALTH_REPORT_URI = "http://{0}/machine?comp=health" -ROLE_PROP_URI = "http://{0}/machine?comp=roleProperties" -TELEMETRY_URI = "http://{0}/machine?comp=telemetrydata" - -WIRE_SERVER_ADDR_FILE_NAME = "WireServer" -INCARNATION_FILE_NAME = "Incarnation" -GOAL_STATE_FILE_NAME = "GoalState.{0}.xml" -HOSTING_ENV_FILE_NAME = "HostingEnvironmentConfig.xml" -SHARED_CONF_FILE_NAME = "SharedConfig.xml" -CERTS_FILE_NAME = "Certificates.xml" -P7M_FILE_NAME = "Certificates.p7m" -PEM_FILE_NAME = "Certificates.pem" -EXT_CONF_FILE_NAME = "ExtensionsConfig.{0}.xml" -MANIFEST_FILE_NAME = "{0}.{1}.manifest.xml" -TRANSPORT_CERT_FILE_NAME = "TransportCert.pem" -TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" - -PROTOCOL_VERSION = "2012-11-30" - -SHORT_WAITING_INTERVAL = 1 # 1 second -LONG_WAITING_INTERVAL = 15 # 15 seconds - -class WireProtocolResourceGone(ProtocolError): - pass - -class WireProtocol(Protocol): - - def __init__(self, endpoint): - self.client = WireClient(endpoint) - - def initialize(self): - self.client.check_wire_protocol_version() - self.client.update_goal_state(forced=True) - - def get_vminfo(self): - hosting_env = self.client.get_hosting_env() - vminfo = VMInfo() - vminfo.subscriptionId = None - vminfo.vmName = hosting_env.vm_name - return vminfo - - def get_certs(self): - certificates = self.client.get_certs() - return certificates.cert_list - - def get_ext_handlers(self): - #Update goal state to get latest extensions config - self.client.update_goal_state() - ext_conf = self.client.get_ext_conf() - return ext_conf.ext_handlers - - def get_ext_handler_pkgs(self, ext_handler): - goal_state = self.client.get_goal_state() - man = self.client.get_ext_manifest(ext_handler, goal_state) - return man.pkg_list - - def report_provision_status(self, provision_status): - validata_param("provision_status", provision_status, ProvisionStatus) - - if provision_status.status is not None: - self.client.report_health(provision_status.status, - provision_status.subStatus, - provision_status.description) - if provision_status.properties.certificateThumbprint is not None: - thumbprint = provision_status.properties.certificateThumbprint - self.client.report_role_prop(thumbprint) - - def report_vm_status(self, vm_status): - validata_param("vm_status", vm_status, VMStatus) - self.client.status_blob.set_vm_status(vm_status) - self.client.upload_status_blob() - - def report_ext_status(self, ext_handler_name, ext_name, ext_status): - validata_param("ext_status", ext_status, ExtensionStatus) - self.client.status_blob.set_ext_status(ext_handler_name, ext_status) - - def report_event(self, events): - validata_param("events", events, TelemetryEventList) - self.client.report_event(events) - -def _build_role_properties(container_id, role_instance_id, thumbprint): - xml = (u"" - u"" - u"" - u"{0}" - u"" - u"" - u"{1}" - u"" - u"" - u"" - u"" - u"" - u"" - u"" - u"").format(container_id, role_instance_id, thumbprint) - return xml - -def _build_health_report(incarnation, container_id, role_instance_id, - status, substatus, description): - #Escape '&', '<' and '>' - description = saxutils.escape(text(description)) - detail = u'' - if substatus is not None: - substatus = saxutils.escape(text(substatus)) - detail = (u"
" - u"{0}" - u"{1}" - u"
").format(substatus, description) - xml = (u"" - u"" - u"{0}" - u"" - u"{1}" - u"" - u"" - u"{2}" - u"" - u"{3}" - u"{4}" - u"" - u"" - u"" - u"" - u"" - u"").format(incarnation, - container_id, - role_instance_id, - status, - detail) - return xml - -""" -Convert VMStatus object to status blob format -""" -def ga_status_to_v1(ga_status): - formatted_msg = { - 'lang' : 'en-US', - 'message' : ga_status.message - } - v1_ga_status = { - 'version' : ga_status.version, - 'status' : ga_status.status, - 'formattedMessage' : formatted_msg - } - return v1_ga_status - -def ext_substatus_to_v1(sub_status_list): - status_list = [] - for substatus in sub_status_list: - status = { - "name": substatus.name, - "status": substatus.status, - "code": substatus.code, - "formattedMessage":{ - "lang": "en-US", - "message": substatus.message - } - } - status_list.append(status) - return status_list - -def ext_status_to_v1(ext_name, ext_status): - if ext_status is None: - return None - timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - v1_sub_status = ext_substatus_to_v1(ext_status.substatusList) - v1_ext_status = { - "status":{ - "name": ext_name, - "configurationAppliedTime": ext_status.configurationAppliedTime, - "operation": ext_status.operation, - "status": ext_status.status, - "code": ext_status.code, - "formattedMessage": { - "lang":"en-US", - "message": ext_status.message - } - }, - "version": 1.0, - "timestampUTC": timestamp - } - if len(v1_sub_status) != 0: - v1_ext_status['substatus'] = v1_sub_status - return v1_ext_status - -def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): - v1_handler_status = { - 'handlerVersion' : handler_status.version, - 'handlerName' : handler_status.name, - 'status' : handler_status.status, - } - if handler_status.message is not None: - v1_handler_status["formattedMessage"] = { - "lang":"en-US", - "message": handler_status.message - } - - if len(handler_status.extensions) > 0: - #Currently, no more than one extension per handler - ext_name = handler_status.extensions[0] - ext_status = ext_statuses.get(ext_name) - v1_ext_status = ext_status_to_v1(ext_name, ext_status) - if ext_status is not None and v1_ext_status is not None: - v1_handler_status["runtimeSettingsStatus"] = { - 'settingsStatus' : v1_ext_status, - 'sequenceNumber' : ext_status.sequenceNumber - } - return v1_handler_status - -def vm_status_to_v1(vm_status, ext_statuses): - timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - - v1_ga_status = ga_status_to_v1(vm_status.vmAgent) - v1_handler_status_list = [] - for handler_status in vm_status.vmAgent.extensionHandlers: - v1_handler_status = ext_handler_status_to_v1(handler_status, - ext_statuses, timestamp) - if v1_handler_status is not None: - v1_handler_status_list.append(v1_handler_status) - - v1_agg_status = { - 'guestAgentStatus': v1_ga_status, - 'handlerAggregateStatus' : v1_handler_status_list - } - v1_vm_status = { - 'version' : '1.0', - 'timestampUTC' : timestamp, - 'aggregateStatus' : v1_agg_status - } - return v1_vm_status - - -class StatusBlob(object): - def __init__(self, client): - self.vm_status = None - self.ext_statuses = {} - self.client = client - - def set_vm_status(self, vm_status): - validata_param("vmAgent", vm_status, VMStatus) - self.vm_status = vm_status - - def set_ext_status(self, ext_handler_name, ext_status): - validata_param("extensionStatus", ext_status, ExtensionStatus) - self.ext_statuses[ext_handler_name]= ext_status - - def to_json(self): - report = vm_status_to_v1(self.vm_status, self.ext_statuses) - return json.dumps(report) - - __storage_version__ = "2014-02-14" - - def upload(self, url): - #TODO upload extension only if content has changed - logger.verb("Upload status blob") - blob_type = self.get_blob_type(url) - - data = self.to_json() - try: - if blob_type == "BlockBlob": - self.put_block_blob(url, data) - elif blob_type == "PageBlob": - self.put_page_blob(url, data) - else: - raise ProtocolError("Unknown blob type: {0}".format(blob_type)) - except restutil.HttpError as e: - raise ProtocolError("Failed to upload status blob: {0}".format(e)) - - def get_blob_type(self, url): - #Check blob type - logger.verb("Check blob type.") - timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - try: - resp = self.client.call_storage_service(restutil.http_head, url, { - "x-ms-date" : timestamp, - 'x-ms-version' : self.__class__.__storage_version__ - }) - except restutil.HttpError as e: - raise ProtocolError((u"Failed to get status blob type: {0}" - u"").format(e)) - if resp is None or resp.status != httpclient.OK: - raise ProtocolError(("Failed to get status blob type: {0}" - "").format(resp.status)) - - blob_type = resp.getheader("x-ms-blob-type") - logger.verb("Blob type={0}".format(blob_type)) - return blob_type - - def put_block_blob(self, url, data): - logger.verb("Upload block blob") - timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - try: - resp = self.client.call_storage_service(restutil.http_put, url, - data, { - "x-ms-date" : timestamp, - "x-ms-blob-type" : "BlockBlob", - "Content-Length": text(len(data)), - "x-ms-version" : self.__class__.__storage_version__ - }) - except restutil.HttpError as e: - raise ProtocolError((u"Failed to upload block blob: {0}" - u"").format(e)) - if resp.status != httpclient.CREATED: - raise ProtocolError(("Failed to upload block blob: {0}" - "").format(resp.status)) - - def put_page_blob(self, url, data): - logger.verb("Replace old page blob") - - #Convert string into bytes - data=bytearray(data, encoding='utf-8') - timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) - - #Align to 512 bytes - page_blob_size = int((len(data) + 511) / 512) * 512 - try: - resp = self.client.call_storage_service(restutil.http_put, url, - "", { - "x-ms-date" : timestamp, - "x-ms-blob-type" : "PageBlob", - "Content-Length": "0", - "x-ms-blob-content-length" : text(page_blob_size), - "x-ms-version" : self.__class__.__storage_version__ - }) - except restutil.HttpError as e: - raise ProtocolError((u"Failed to clean up page blob: {0}" - u"").format(e)) - if resp.status != httpclient.CREATED: - raise ProtocolError(("Failed to clean up page blob: {0}" - "").format(resp.status)) - - if url.count("?") < 0: - url = "{0}?comp=page".format(url) - else: - url = "{0}&comp=page".format(url) - - logger.verb("Upload page blob") - page_max = 4 * 1024 * 1024 #Max page size: 4MB - start = 0 - end = 0 - while end < len(data): - end = min(len(data), start + page_max) - content_size = end - start - #Align to 512 bytes - page_end = int((end + 511) / 512) * 512 - buf_size = page_end - start - buf = bytearray(buf_size) - buf[0: content_size] = data[start: end] - try: - resp = self.client.call_storage_service(restutil.http_put, url, - bytebuffer(buf), { - "x-ms-date" : timestamp, - "x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1), - "x-ms-page-write" : "update", - "x-ms-version" : self.__class__.__storage_version__, - "Content-Length": text(page_end - start) - }) - except restutil.HttpError as e: - raise ProtocolError((u"Failed to upload page blob: {0}" - u"").format(e)) - if resp is None or resp.status != httpclient.CREATED: - raise ProtocolError(("Failed to upload page blob: {0}" - "").format(resp.status)) - start = end - -def event_param_to_v1(param): - param_format = '' - param_type = type(param.value) - attr_type = "" - if param_type is int: - attr_type = 'mt:uint64' - elif param_type is str: - attr_type = 'mt:wstr' - elif text(param_type).count("'unicode'") > 0: - attr_type = 'mt:wstr' - elif param_type is bool: - attr_type = 'mt:bool' - elif param_type is float: - attr_type = 'mt:float64' - return param_format.format(param.name, saxutils.quoteattr(text(param.value)), - attr_type) - -def event_to_v1(event): - params = "" - for param in event.parameters: - params += event_param_to_v1(param) - event_str = ('' - '' - '').format(event.eventId, params) - return event_str - -class WireClient(object): - def __init__(self, endpoint): - self.endpoint = endpoint - self.goal_state = None - self.updated = None - self.hosting_env = None - self.shared_conf = None - self.certs = None - self.ext_conf = None - self.last_request = 0 - self.req_count = 0 - self.status_blob = StatusBlob(self) - - def prevent_throttling(self): - """ - Try to avoid throttling of wire server - """ - now = time.time() - if now - self.last_request < 1: - logger.info("Last request issued less than 1 second ago") - logger.info("Sleep {0} second to avoid throttling.", - SHORT_WAITING_INTERVAL) - time.sleep(SHORT_WAITING_INTERVAL) - self.last_request = now - - self.req_count += 1 - if self.req_count % 3 == 0: - logger.info("Sleep {0} second to avoid throttling.", - SHORT_WAITING_INTERVAL) - time.sleep(SHORT_WAITING_INTERVAL) - self.req_count = 0 - - def call_wireserver(self, http_req, *args, **kwargs): - """ - Call wire server. Handle throttling(403) and Resource Gone(410) - """ - self.prevent_throttling() - for retry in range(0, 3): - resp = http_req(*args, **kwargs) - if resp.status == httpclient.FORBIDDEN: - logger.warn("Sending too much request to wire server") - logger.info("Sleep {0} second to avoid throttling.", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - elif resp.status == httpclient.GONE: - msg = args[0] if len(args) > 0 else "" - raise WireProtocolResourceGone(msg) - else: - return resp - raise ProtocolError(("Calling wire server failed: {0}" - "").format(resp.status)) - - def decode_config(self, data): - if data is None: - return None - data = remove_bom(data) - xml_text = text(data, encoding='utf-8') - return xml_text - - def fetch_config(self, uri, headers): - try: - resp = self.call_wireserver(restutil.http_get, uri, - headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) - - if(resp.status != httpclient.OK): - raise ProtocolError("{0} - {1}".format(resp.status, uri)) - - return self.decode_config(resp.read()) - - def fetch_cache(self, local_file): - if not os.path.isfile(local_file): - raise ProtocolError("{0} is missing.".format(local_file)) - try: - return fileutil.read_file(local_file) - except IOError as e: - raise ProtocolError("Failed to read cache: {0}".format(e)) - - def save_cache(self, local_file, data): - try: - fileutil.write_file(local_file, data) - except IOError as e: - raise ProtocolError("Failed to write cache: {0}".format(e)) - - def call_storage_service(self, http_req, *args, **kwargs): - """ - Call storage service, handle SERVICE_UNAVAILABLE(503) - """ - for retry in range(0, 3): - resp = http_req(*args, **kwargs) - if resp.status == httpclient.SERVICE_UNAVAILABLE: - logger.warn("Storage service is not avaible temporaryly") - logger.info("Will retry later, in {0} seconds", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - else: - return resp - raise ProtocolError(("Calling storage endpoint failed: {0}" - "").format(resp.status)) - - def fetch_manifest(self, version_uris): - for version_uri in version_uris: - try: - resp = self.call_storage_service(restutil.http_get, - version_uri.uri, None, - chk_proxy=True) - except restutil.HttpError as e: - raise ProtocolError(text(e)) - - if resp.status == httpclient.OK: - return self.decode_config(resp.read()) - logger.warn("Failed to fetch ExtensionManifest: {0}, {1}", - resp.status, version_uri.uri) - logger.info("Will retry later, in {0} seconds", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - raise ProtocolError(("Failed to fetch ExtensionManifest from " - "all sources")) - - - def update_hosting_env(self, goal_state): - if goal_state.hosting_env_uri is None: - raise ProtocolError("HostingEnvironmentConfig uri is empty") - local_file = HOSTING_ENV_FILE_NAME - xml_text = self.fetch_config(goal_state.hosting_env_uri, - self.get_header()) - self.save_cache(local_file, xml_text) - self.hosting_env = HostingEnv(xml_text) - - def update_shared_conf(self, goal_state): - if goal_state.shared_conf_uri is None: - raise ProtocolError("SharedConfig uri is empty") - local_file = SHARED_CONF_FILE_NAME - xml_text = self.fetch_config(goal_state.shared_conf_uri, - self.get_header()) - self.save_cache(local_file, xml_text) - self.shared_conf = SharedConfig(xml_text) - - def update_certs(self, goal_state): - if goal_state.certs_uri is None: - return - local_file = CERTS_FILE_NAME - xml_text = self.fetch_config(goal_state.certs_uri, - self.get_header_for_cert()) - self.save_cache(local_file, xml_text) - self.certs = Certificates(self, xml_text) - - def update_ext_conf(self, goal_state): - if goal_state.ext_uri is None: - logger.info("ExtensionsConfig.xml uri is empty") - self.ext_conf = ExtensionsConfig(None) - return - incarnation = goal_state.incarnation - local_file = EXT_CONF_FILE_NAME.format(incarnation) - xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) - self.save_cache(local_file, xml_text) - self.ext_conf = ExtensionsConfig(xml_text) - for ext_handler in self.ext_conf.ext_handlers.extHandlers: - self.update_ext_handler_manifest(ext_handler, goal_state) - - def update_ext_handler_manifest(self, ext_handler, goal_state): - local_file = MANIFEST_FILE_NAME.format(ext_handler.name, - goal_state.incarnation) - xml_text = self.fetch_manifest(ext_handler.versionUris) - self.save_cache(local_file, xml_text) - - def update_goal_state(self, forced=False, max_retry=3): - uri = GOAL_STATE_URI.format(self.endpoint) - xml_text = self.fetch_config(uri, self.get_header()) - goal_state = GoalState(xml_text) - - incarnation_file = os.path.join(OSUTIL.get_lib_dir(), - INCARNATION_FILE_NAME) - - if not forced: - last_incarnation = None - if(os.path.isfile(incarnation_file)): - last_incarnation = fileutil.read_file(incarnation_file) - new_incarnation = goal_state.incarnation - if last_incarnation is not None and \ - last_incarnation == new_incarnation: - #Goalstate is not updated. - return - - #Start updating goalstate, retry on 410 - for retry in range(0, max_retry): - try: - self.goal_state = goal_state - file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) - goal_state_file = os.path.join(OSUTIL.get_lib_dir(), file_name) - self.save_cache(goal_state_file, xml_text) - self.save_cache(incarnation_file, goal_state.incarnation) - self.update_hosting_env(goal_state) - self.update_shared_conf(goal_state) - self.update_certs(goal_state) - self.update_ext_conf(goal_state) - return - except WireProtocolResourceGone: - logger.info("Incarnation is out of date. Update goalstate.") - xml_text = self.fetch_config(uri, self.get_header()) - goal_state = GoalState(xml_text) - - raise ProtocolError("Exceeded max retry updating goal state") - - def get_goal_state(self): - if(self.goal_state is None): - incarnation = self.fetch_cache(INCARNATION_FILE_NAME) - goal_state_file = GOAL_STATE_FILE_NAME.format(incarnation) - xml_text = self.fetch_cache(goal_state_file) - self.goal_state = GoalState(xml_text) - return self.goal_state - - def get_hosting_env(self): - if(self.hosting_env is None): - xml_text = self.fetch_cache(HOSTING_ENV_FILE_NAME) - self.hosting_env = HostingEnv(xml_text) - return self.hosting_env - - def get_shared_conf(self): - if(self.shared_conf is None): - xml_text = self.fetch_cache(SHARED_CONF_FILE_NAME) - self.shared_conf = SharedConfig(xml_text) - return self.shared_conf - - def get_certs(self): - if(self.certs is None): - xml_text = self.fetch_cache(CERTS_FILE_NAME) - self.certs = Certificates(self, xml_text) - if self.certs is None: - return None - return self.certs - - def get_ext_conf(self): - if(self.ext_conf is None): - goal_state = self.get_goal_state() - if goal_state.ext_uri is None: - self.ext_conf = ExtensionsConfig(None) - else: - local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) - xml_text = self.fetch_cache(local_file) - self.ext_conf = ExtensionsConfig(xml_text) - return self.ext_conf - - def get_ext_manifest(self, extension, goal_state): - local_file = MANIFEST_FILE_NAME.format(extension.name, - goal_state.incarnation) - xml_text = self.fetch_cache(local_file) - return ExtensionManifest(xml_text) - - def check_wire_protocol_version(self): - uri = VERSION_INFO_URI.format(self.endpoint) - version_info_xml = self.fetch_config(uri, None) - version_info = VersionInfo(version_info_xml) - - preferred = version_info.get_preferred() - if PROTOCOL_VERSION == preferred: - logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) - elif PROTOCOL_VERSION in version_info.get_supported(): - logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) - logger.warn("Server prefered version:{0}", preferred) - else: - error = ("Agent supported wire protocol version: {0} was not " - "advised by Fabric.").format(PROTOCOL_VERSION) - raise ProtocolNotFound(error) - - def upload_status_blob(self): - ext_conf = self.get_ext_conf() - if ext_conf.status_upload_blob is not None: - self.status_blob.upload(ext_conf.status_upload_blob) - - def report_role_prop(self, thumbprint): - goal_state = self.get_goal_state() - role_prop = _build_role_properties(goal_state.container_id, - goal_state.role_instance_id, - thumbprint) - role_prop = role_prop.encode("utf-8") - role_prop_uri = ROLE_PROP_URI.format(self.endpoint) - headers = self.get_header_for_xml_content() - try: - resp = self.call_wireserver(restutil.http_post, role_prop_uri, - role_prop, headers = headers) - except restutil.HttpError as e: - raise ProtocolError((u"Failed to send role properties: {0}" - u"").format(e)) - if resp.status != httpclient.ACCEPTED: - raise ProtocolError((u"Failed to send role properties: {0}" - u", {1}").format(resp.status, resp.read())) - - def report_health(self, status, substatus, description): - goal_state = self.get_goal_state() - health_report = _build_health_report(goal_state.incarnation, - goal_state.container_id, - goal_state.role_instance_id, - status, - substatus, - description) - health_report = health_report.encode("utf-8") - health_report_uri = HEALTH_REPORT_URI.format(self.endpoint) - headers = self.get_header_for_xml_content() - try: - resp = self.call_wireserver(restutil.http_post, health_report_uri, - health_report, headers = headers) - except restutil.HttpError as e: - raise ProtocolError((u"Failed to send provision status: {0}" - u"").format(e)) - if resp.status != httpclient.OK: - raise ProtocolError((u"Failed to send provision status: {0}" - u", {1}").format(resp.status, resp.read())) - - def send_event(self, provider_id, event_str): - uri = TELEMETRY_URI.format(self.endpoint) - data_format = ('' - '' - '{1}' - '' - '') - data = data_format.format(provider_id, event_str) - try: - header = self.get_header_for_xml_content() - resp = self.call_wireserver(restutil.http_post, uri, data, header) - except restutil.HttpError as e: - raise ProtocolError("Failed to send events:{0}".format(e)) - - if resp.status != httpclient.OK: - logger.verb(resp.read()) - raise ProtocolError("Failed to send events:{0}".format(resp.status)) - - def report_event(self, event_list): - buf = {} - #Group events by providerId - for event in event_list.events: - if event.providerId not in buf: - buf[event.providerId] = "" - event_str = event_to_v1(event) - if len(event_str) >= 63 * 1024: - logger.warn("Single event too large: {0}", event_str[300:]) - continue - if len(buf[event.providerId] + event_str) >= 63 * 1024: - self.send_event(event.providerId, buf[event.providerId]) - buf[event.providerId] = "" - buf[event.providerId] = buf[event.providerId] + event_str - - #Send out all events left in buffer. - for provider_id in list(buf.keys()): - if len(buf[provider_id]) > 0: - self.send_event(provider_id, buf[provider_id]) - - def get_header(self): - return { - "x-ms-agent-name":"WALinuxAgent", - "x-ms-version":PROTOCOL_VERSION - } - - def get_header_for_xml_content(self): - return { - "x-ms-agent-name":"WALinuxAgent", - "x-ms-version":PROTOCOL_VERSION, - "Content-Type":"text/xml;charset=utf-8" - } - - def get_header_for_cert(self): - cert = "" - content = self.fetch_cache(TRANSPORT_CERT_FILE_NAME) - for line in content.split('\n'): - if "CERTIFICATE" not in line: - cert += line.rstrip() - return { - "x-ms-agent-name":"WALinuxAgent", - "x-ms-version":PROTOCOL_VERSION, - "x-ms-cipher-name": "DES_EDE3_CBC", - "x-ms-guest-agent-public-x509-cert":cert - } - -class VersionInfo(object): - def __init__(self, xml_text): - """ - Query endpoint server for wire protocol version. - Fail if our desired protocol version is not seen. - """ - logger.verb("Load Version.xml") - self.parse(xml_text) - - def parse(self, xml_text): - xml_doc = parse_doc(xml_text) - preferred = find(xml_doc, "Preferred") - self.preferred = findtext(preferred, "Version") - logger.info("Fabric preferred wire protocol version:{0}", self.preferred) - - self.supported = [] - supported = find(xml_doc, "Supported") - supported_version = findall(supported, "Version") - for node in supported_version: - version = gettext(node) - logger.verb("Fabric supported wire protocol version:{0}", version) - self.supported.append(version) - - def get_preferred(self): - return self.preferred - - def get_supported(self): - return self.supported - - -class GoalState(object): - - def __init__(self, xml_text): - if xml_text is None: - raise ValueError("GoalState.xml is None") - logger.verb("Load GoalState.xml") - self.incarnation = None - self.expected_state = None - self.hosting_env_uri = None - self.shared_conf_uri = None - self.certs_uri = None - self.ext_uri = None - self.role_instance_id = None - self.container_id = None - self.load_balancer_probe_port = None - self.parse(xml_text) - - def parse(self, xml_text): - """ - Request configuration data from endpoint server. - """ - self.xml_text = xml_text - xml_doc = parse_doc(xml_text) - self.incarnation = findtext(xml_doc, "Incarnation") - self.expected_state = findtext(xml_doc, "ExpectedState") - self.hosting_env_uri = findtext(xml_doc, "HostingEnvironmentConfig") - self.shared_conf_uri = findtext(xml_doc, "SharedConfig") - self.certs_uri = findtext(xml_doc, "Certificates") - self.ext_uri = findtext(xml_doc, "ExtensionsConfig") - role_instance = find(xml_doc, "RoleInstance") - self.role_instance_id = findtext(role_instance, "InstanceId") - container = find(xml_doc, "Container") - self.container_id = findtext(container, "ContainerId") - lbprobe_ports = find(xml_doc, "LBProbePorts") - self.load_balancer_probe_port = findtext(lbprobe_ports, "Port") - return self - - -class HostingEnv(object): - """ - parse Hosting enviromnet config and store in - HostingEnvironmentConfig.xml - """ - def __init__(self, xml_text): - if xml_text is None: - raise ValueError("HostingEnvironmentConfig.xml is None") - logger.verb("Load HostingEnvironmentConfig.xml") - self.vm_name = None - self.role_name = None - self.deployment_name = None - self.parse(xml_text) - - def parse(self, xml_text): - """ - parse and create HostingEnvironmentConfig.xml. - """ - self.xml_text = xml_text - xml_doc = parse_doc(xml_text) - incarnation = find(xml_doc, "Incarnation") - self.vm_name = getattrib(incarnation, "instance") - role = find(xml_doc, "Role") - self.role_name = getattrib(role, "name") - deployment = find(xml_doc, "Deployment") - self.deployment_name = getattrib(deployment, "name") - return self - -class SharedConfig(object): - """ - parse role endpoint server and goal state config. - """ - def __init__(self, xml_text): - logger.verb("Load SharedConfig.xml") - self.parse(xml_text) - - def parse(self, xml_text): - """ - parse and write configuration to file SharedConfig.xml. - """ - #Not used currently - return self - -class Certificates(object): - - """ - Object containing certificates of host and provisioned user. - """ - def __init__(self, client, xml_text): - logger.verb("Load Certificates.xml") - self.client = client - self.lib_dir = OSUTIL.get_lib_dir() - self.openssl_cmd = OSUTIL.get_openssl_cmd() - self.cert_list = CertList() - self.parse(xml_text) - - def parse(self, xml_text): - """ - Parse multiple certificates into seperate files. - """ - xml_doc = parse_doc(xml_text) - data = findtext(xml_doc, "Data") - if data is None: - return - - p7m = ("MIME-Version:1.0\n" - "Content-Disposition: attachment; filename=\"{0}\"\n" - "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" - "Content-Transfer-Encoding: base64\n" - "\n" - "{2}").format(P7M_FILE_NAME, P7M_FILE_NAME, data) - - self.client.save_cache(os.path.join(self.lib_dir, P7M_FILE_NAME), p7m) - #decrypt certificates - cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3}" - "| {4} pkcs12 -nodes -password pass: -out {5}" - "").format(self.openssl_cmd, P7M_FILE_NAME, - TRANSPORT_PRV_FILE_NAME, TRANSPORT_CERT_FILE_NAME, - self.openssl_cmd, PEM_FILE_NAME) - shellutil.run(cmd) - - #The parsing process use public key to match prv and crt. - buf = [] - begin_crt = False - begin_prv = False - prvs = {} - thumbprints = {} - index = 0 - v1_cert_list = [] - with open(PEM_FILE_NAME) as pem: - for line in pem.readlines(): - buf.append(line) - if re.match(r'[-]+BEGIN.*KEY[-]+', line): - begin_prv = True - elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line): - begin_crt = True - elif re.match(r'[-]+END.*KEY[-]+', line): - tmp_file = self.write_to_tmp_file(index, 'prv', buf) - pub = OSUTIL.get_pubkey_from_prv(tmp_file) - prvs[pub] = tmp_file - buf = [] - index += 1 - begin_prv = False - elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): - tmp_file = self.write_to_tmp_file(index, 'crt', buf) - pub = OSUTIL.get_pubkey_from_crt(tmp_file) - thumbprint = OSUTIL.get_thumbprint_from_crt(tmp_file) - thumbprints[pub] = thumbprint - #Rename crt with thumbprint as the file name - crt = "{0}.crt".format(thumbprint) - v1_cert_list.append({ - "name":None, - "thumbprint":thumbprint - }) - os.rename(tmp_file, os.path.join(self.lib_dir, crt)) - buf = [] - index += 1 - begin_crt = False - - #Rename prv key with thumbprint as the file name - for pubkey in prvs: - thumbprint = thumbprints[pubkey] - if thumbprint: - tmp_file = prvs[pubkey] - prv = "{0}.prv".format(thumbprint) - os.rename(tmp_file, os.path.join(self.lib_dir, prv)) - - for v1_cert in v1_cert_list: - cert = Cert() - set_properties("certs", cert, v1_cert) - self.cert_list.certificates.append(cert) - - def write_to_tmp_file(self, index, suffix, buf): - file_name = os.path.join(self.lib_dir, "{0}.{1}".format(index, suffix)) - self.client.save_cache(file_name, "".join(buf)) - return file_name - - -class ExtensionsConfig(object): - """ - parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent. - Install if true, remove if it is set to false. - """ - - def __init__(self, xml_text): - logger.verb("Load ExtensionsConfig.xml") - self.ext_handlers = ExtHandlerList() - self.status_upload_blob = None - if xml_text is not None: - self.parse(xml_text) - - def parse(self, xml_text): - """ - Write configuration to file ExtensionsConfig.xml. - """ - xml_doc = parse_doc(xml_text) - plugins_list = find(xml_doc, "Plugins") - plugins = findall(plugins_list, "Plugin") - plugin_settings_list = find(xml_doc, "PluginSettings") - plugin_settings = findall(plugin_settings_list, "Plugin") - - for plugin in plugins: - ext_handler = self.parse_plugin(plugin) - self.ext_handlers.extHandlers.append(ext_handler) - self.parse_plugin_settings(ext_handler, plugin_settings) - - self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob") - - def parse_plugin(self, plugin): - ext_handler = ExtHandler() - ext_handler.name = getattrib(plugin, "name") - ext_handler.properties.version = getattrib(plugin, "version") - ext_handler.properties.state = getattrib(plugin, "state") - - auto_upgrade = getattrib(plugin, "autoUpgrade") - if auto_upgrade is not None and auto_upgrade.lower() == "true": - ext_handler.properties.upgradePolicy = "auto" - else: - ext_handler.properties.upgradePolicy = "manual" - - location = getattrib(plugin, "location") - failover_location = getattrib(plugin, "failoverlocation") - for uri in [location, failover_location]: - version_uri = ExtHandlerVersionUri() - version_uri.uri = uri - ext_handler.versionUris.append(version_uri) - return ext_handler - - def parse_plugin_settings(self, ext_handler, plugin_settings): - if plugin_settings is None: - return - - name = ext_handler.name - version = ext_handler.properties.version - settings = [x for x in plugin_settings \ - if getattrib(x, "name") == name and \ - getattrib(x ,"version") == version] - - if settings is None or len(settings) == 0: - return - - runtime_settings = None - runtime_settings_node = find(settings[0], "RuntimeSettings") - seqNo = getattrib(runtime_settings_node, "seqNo") - runtime_settings_str = gettext(runtime_settings_node) - try: - runtime_settings = json.loads(runtime_settings_str) - except ValueError as e: - logger.error("Invalid extension settings") - return - - for plugin_settings_list in runtime_settings["runtimeSettings"]: - handler_settings = plugin_settings_list["handlerSettings"] - ext = Extension() - #There is no "extension name" in wire protocol. - #Put - ext.name = ext_handler.name - ext.sequenceNumber = seqNo - ext.publicSettings = handler_settings.get("publicSettings") - ext.privateSettings = handler_settings.get("protectedSettings") - thumbprint = handler_settings.get("protectedSettingsCertThumbprint") - ext.certificateThumbprint = thumbprint - ext_handler.properties.extensions.append(ext) - -class ExtensionManifest(object): - def __init__(self, xml_text): - if xml_text is None: - raise ValueError("ExtensionManifest is None") - logger.verb("Load ExtensionManifest.xml") - self.pkg_list = ExtHandlerPackageList() - self.parse(xml_text) - - def parse(self, xml_text): - xml_doc = parse_doc(xml_text) - packages = findall(xml_doc, "Plugin") - for package in packages: - version = findtext(package, "Version") - uris = find(package, "Uris") - uri_list = findall(uris, "Uri") - uri_list = [gettext(x) for x in uri_list] - package = ExtHandlerPackage() - package.version = version - for uri in uri_list: - pkg_uri = ExtHandlerVersionUri() - pkg_uri.uri = uri - package.uris.append(pkg_uri) - self.pkg_list.versions.append(package) - diff --git a/azurelinuxagent/protocol/v2.py b/azurelinuxagent/protocol/v2.py deleted file mode 100644 index 34102b7..0000000 --- a/azurelinuxagent/protocol/v2.py +++ /dev/null @@ -1,145 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ - -import json -from azurelinuxagent.future import httpclient, text -import azurelinuxagent.utils.restutil as restutil -from azurelinuxagent.protocol.common import * - -ENDPOINT='169.254.169.254' -#TODO use http for azure pack test -#ENDPOINT='localhost' -APIVERSION='2015-05-01-preview' -BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" - -def _add_content_type(headers): - if headers is None: - headers = {} - headers["content-type"] = "application/json" - return headers - -class MetadataProtocol(Protocol): - - def __init__(self, apiversion=APIVERSION, endpoint=ENDPOINT): - self.apiversion = apiversion - self.endpoint = endpoint - self.identity_uri = BASE_URI.format(self.endpoint, "identity", - self.apiversion, "&$expand=*") - self.cert_uri = BASE_URI.format(self.endpoint, "certificates", - self.apiversion, "&$expand=*") - self.ext_uri = BASE_URI.format(self.endpoint, "extensionHandlers", - self.apiversion, "&$expand=*") - self.provision_status_uri = BASE_URI.format(self.endpoint, - "provisioningStatus", - self.apiversion, "") - self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent", - self.apiversion, "") - self.ext_status_uri = BASE_URI.format(self.endpoint, - "status/extensions/{0}", - self.apiversion, "") - self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry", - self.apiversion, "") - - def _get_data(self, url, headers=None): - try: - resp = restutil.http_get(url, headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) - - if resp.status != httpclient.OK: - raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) - - data = resp.read() - if data is None: - return None - data = json.loads(text(data, encoding="utf-8")) - return data - - def _put_data(self, url, data, headers=None): - headers = _add_content_type(headers) - try: - resp = restutil.http_put(url, json.dumps(data), headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) - if resp.status != httpclient.OK: - raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) - - def _post_data(self, url, data, headers=None): - headers = _add_content_type(headers) - try: - resp = restutil.http_post(url, json.dumps(data), headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) - if resp.status != httpclient.CREATED: - raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) - - def initialize(self): - pass - - def get_vminfo(self): - vminfo = VMInfo() - data = self._get_data(self.identity_uri) - set_properties("vminfo", vminfo, data) - return vminfo - - def get_certs(self): - #TODO download and save certs - return CertList() - - def get_ext_handlers(self): - ext_list = ExtHandlerList() - data = self._get_data(self.ext_uri) - set_properties("extensionHandlers", ext_list.extHandlers, data) - return ext_list - - def get_ext_handler_pkgs(self, ext_handler): - ext_handler_pkgs = ExtHandlerPackageList() - data = None - for version_uri in ext_handler.versionUris: - try: - data = self._get_data(version_uri.uri) - break - except ProtocolError as e: - logger.warn("Failed to get version uris: {0}", e) - logger.info("Retry getting version uris") - set_properties("extensionPackages", ext_handler_pkgs, data) - return ext_handler_pkgs - - def report_provision_status(self, provision_status): - validata_param('provisionStatus', provision_status, ProvisionStatus) - data = get_properties(provision_status) - self._put_data(self.provision_status_uri, data) - - def report_vm_status(self, vm_status): - validata_param('vmStatus', vm_status, VMStatus) - data = get_properties(vm_status) - self._put_data(self.vm_status_uri, data) - - def report_ext_status(self, ext_handler_name, ext_name, ext_status): - validata_param('extensionStatus', ext_status, ExtensionStatus) - data = get_properties(ext_status) - uri = self.ext_status_uri.format(ext_name) - self._put_data(uri, data) - - def report_event(self, events): - #TODO disable telemetry for azure stack test - #validata_param('events', events, TelemetryEventList) - #data = get_properties(events) - #self._post_data(self.event_uri, data) - pass - diff --git a/azurelinuxagent/protocol/wire.py b/azurelinuxagent/protocol/wire.py new file mode 100644 index 0000000..7b5ffe8 --- /dev/null +++ b/azurelinuxagent/protocol/wire.py @@ -0,0 +1,1155 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import os +import json +import re +import time +import traceback +import xml.sax.saxutils as saxutils +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.exception import ProtocolError, HttpError, \ + ProtocolNotFoundError +from azurelinuxagent.future import ustr, httpclient, bytebuffer +import azurelinuxagent.utils.restutil as restutil +from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext, \ + getattrib, gettext, remove_bom, \ + get_bytes_from_pem +import azurelinuxagent.utils.fileutil as fileutil +import azurelinuxagent.utils.shellutil as shellutil +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.protocol.restapi import * + +VERSION_INFO_URI = "http://{0}/?comp=versions" +GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate" +HEALTH_REPORT_URI = "http://{0}/machine?comp=health" +ROLE_PROP_URI = "http://{0}/machine?comp=roleProperties" +TELEMETRY_URI = "http://{0}/machine?comp=telemetrydata" + +WIRE_SERVER_ADDR_FILE_NAME = "WireServer" +INCARNATION_FILE_NAME = "Incarnation" +GOAL_STATE_FILE_NAME = "GoalState.{0}.xml" +HOSTING_ENV_FILE_NAME = "HostingEnvironmentConfig.xml" +SHARED_CONF_FILE_NAME = "SharedConfig.xml" +CERTS_FILE_NAME = "Certificates.xml" +P7M_FILE_NAME = "Certificates.p7m" +PEM_FILE_NAME = "Certificates.pem" +EXT_CONF_FILE_NAME = "ExtensionsConfig.{0}.xml" +MANIFEST_FILE_NAME = "{0}.{1}.manifest.xml" +TRANSPORT_CERT_FILE_NAME = "TransportCert.pem" +TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" + +PROTOCOL_VERSION = "2012-11-30" +ENDPOINT_FINE_NAME = "WireServer" + +SHORT_WAITING_INTERVAL = 1 # 1 second +LONG_WAITING_INTERVAL = 15 # 15 seconds + +class WireProtocolResourceGone(ProtocolError): + pass + +class WireProtocol(Protocol): + """Slim layer to adapte wire protocol data to metadata protocol interface""" + + def __init__(self, endpoint): + if endpoint is None: + raise ProtocolError("WireProtocl endpoint is None") + self.endpoint = endpoint + self.client = WireClient(self.endpoint) + + def detect(self): + self.client.check_wire_protocol_version() + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + self.client.update_goal_state(forced=True) + + def get_vminfo(self): + goal_state = self.client.get_goal_state() + hosting_env = self.client.get_hosting_env() + + vminfo = VMInfo() + vminfo.subscriptionId = None + vminfo.vmName = hosting_env.vm_name + vminfo.tenantName = hosting_env.deployment_name + vminfo.roleName = hosting_env.role_name + vminfo.roleInstanceName = goal_state.role_instance_id + vminfo.containerId = goal_state.container_id + return vminfo + + def get_certs(self): + certificates = self.client.get_certs() + return certificates.cert_list + + def get_ext_handlers(self): + logger.verb("Get extension handler config") + #Update goal state to get latest extensions config + self.client.update_goal_state() + goal_state = self.client.get_goal_state() + ext_conf = self.client.get_ext_conf() + #In wire protocol, incarnation is equivalent to ETag + return ext_conf.ext_handlers, goal_state.incarnation + + def get_ext_handler_pkgs(self, ext_handler): + logger.verb("Get extension handler package") + goal_state = self.client.get_goal_state() + man = self.client.get_ext_manifest(ext_handler, goal_state) + return man.pkg_list + + def report_provision_status(self, provision_status): + validata_param("provision_status", provision_status, ProvisionStatus) + + if provision_status.status is not None: + self.client.report_health(provision_status.status, + provision_status.subStatus, + provision_status.description) + if provision_status.properties.certificateThumbprint is not None: + thumbprint = provision_status.properties.certificateThumbprint + self.client.report_role_prop(thumbprint) + + def report_vm_status(self, vm_status): + validata_param("vm_status", vm_status, VMStatus) + self.client.status_blob.set_vm_status(vm_status) + self.client.upload_status_blob() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validata_param("ext_status", ext_status, ExtensionStatus) + self.client.status_blob.set_ext_status(ext_handler_name, ext_status) + + def report_event(self, events): + validata_param("events", events, TelemetryEventList) + self.client.report_event(events) + +def _build_role_properties(container_id, role_instance_id, thumbprint): + xml = (u"" + u"" + u"" + u"{0}" + u"" + u"" + u"{1}" + u"" + u"" + u"" + u"" + u"" + u"" + u"" + u"").format(container_id, role_instance_id, thumbprint) + return xml + +def _build_health_report(incarnation, container_id, role_instance_id, + status, substatus, description): + #Escape '&', '<' and '>' + description = saxutils.escape(ustr(description)) + detail = u'' + if substatus is not None: + substatus = saxutils.escape(ustr(substatus)) + detail = (u"
" + u"{0}" + u"{1}" + u"
").format(substatus, description) + xml = (u"" + u"" + u"{0}" + u"" + u"{1}" + u"" + u"" + u"{2}" + u"" + u"{3}" + u"{4}" + u"" + u"" + u"" + u"" + u"" + u"").format(incarnation, + container_id, + role_instance_id, + status, + detail) + return xml + +""" +Convert VMStatus object to status blob format +""" +def ga_status_to_v1(ga_status): + formatted_msg = { + 'lang' : 'en-US', + 'message' : ga_status.message + } + v1_ga_status = { + 'version' : ga_status.version, + 'status' : ga_status.status, + 'formattedMessage' : formatted_msg + } + return v1_ga_status + +def ext_substatus_to_v1(sub_status_list): + status_list = [] + for substatus in sub_status_list: + status = { + "name": substatus.name, + "status": substatus.status, + "code": substatus.code, + "formattedMessage":{ + "lang": "en-US", + "message": substatus.message + } + } + status_list.append(status) + return status_list + +def ext_status_to_v1(ext_name, ext_status): + if ext_status is None: + return None + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + v1_sub_status = ext_substatus_to_v1(ext_status.substatusList) + v1_ext_status = { + "status":{ + "name": ext_name, + "configurationAppliedTime": ext_status.configurationAppliedTime, + "operation": ext_status.operation, + "status": ext_status.status, + "code": ext_status.code, + "formattedMessage": { + "lang":"en-US", + "message": ext_status.message + } + }, + "version": 1.0, + "timestampUTC": timestamp + } + if len(v1_sub_status) != 0: + v1_ext_status['substatus'] = v1_sub_status + return v1_ext_status + +def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): + v1_handler_status = { + 'handlerVersion' : handler_status.version, + 'handlerName' : handler_status.name, + 'status' : handler_status.status, + 'code': handler_status.code + } + if handler_status.message is not None: + v1_handler_status["formattedMessage"] = { + "lang":"en-US", + "message": handler_status.message + } + + if len(handler_status.extensions) > 0: + #Currently, no more than one extension per handler + ext_name = handler_status.extensions[0] + ext_status = ext_statuses.get(ext_name) + v1_ext_status = ext_status_to_v1(ext_name, ext_status) + if ext_status is not None and v1_ext_status is not None: + v1_handler_status["runtimeSettingsStatus"] = { + 'settingsStatus' : v1_ext_status, + 'sequenceNumber' : ext_status.sequenceNumber + } + return v1_handler_status + +def vm_status_to_v1(vm_status, ext_statuses): + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + v1_ga_status = ga_status_to_v1(vm_status.vmAgent) + v1_handler_status_list = [] + for handler_status in vm_status.vmAgent.extensionHandlers: + v1_handler_status = ext_handler_status_to_v1(handler_status, + ext_statuses, timestamp) + if v1_handler_status is not None: + v1_handler_status_list.append(v1_handler_status) + + v1_agg_status = { + 'guestAgentStatus': v1_ga_status, + 'handlerAggregateStatus' : v1_handler_status_list + } + v1_vm_status = { + 'version' : '1.0', + 'timestampUTC' : timestamp, + 'aggregateStatus' : v1_agg_status + } + return v1_vm_status + + +class StatusBlob(object): + def __init__(self, client): + self.vm_status = None + self.ext_statuses = {} + self.client = client + + def set_vm_status(self, vm_status): + validata_param("vmAgent", vm_status, VMStatus) + self.vm_status = vm_status + + def set_ext_status(self, ext_handler_name, ext_status): + validata_param("extensionStatus", ext_status, ExtensionStatus) + self.ext_statuses[ext_handler_name]= ext_status + + def to_json(self): + report = vm_status_to_v1(self.vm_status, self.ext_statuses) + return json.dumps(report) + + __storage_version__ = "2014-02-14" + + def upload(self, url): + #TODO upload extension only if content has changed + logger.verb("Upload status blob") + blob_type = self.get_blob_type(url) + + data = self.to_json() + try: + if blob_type == "BlockBlob": + self.put_block_blob(url, data) + elif blob_type == "PageBlob": + self.put_page_blob(url, data) + else: + raise ProtocolError("Unknown blob type: {0}".format(blob_type)) + except HttpError as e: + raise ProtocolError("Failed to upload status blob: {0}".format(e)) + + def get_blob_type(self, url): + #Check blob type + logger.verb("Check blob type.") + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + try: + resp = self.client.call_storage_service(restutil.http_head, url, { + "x-ms-date" : timestamp, + 'x-ms-version' : self.__class__.__storage_version__ + }) + except HttpError as e: + raise ProtocolError((u"Failed to get status blob type: {0}" + u"").format(e)) + if resp is None or resp.status != httpclient.OK: + raise ProtocolError(("Failed to get status blob type: {0}" + "").format(resp.status)) + + blob_type = resp.getheader("x-ms-blob-type") + logger.verb("Blob type={0}".format(blob_type)) + return blob_type + + def put_block_blob(self, url, data): + logger.verb("Upload block blob") + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + try: + resp = self.client.call_storage_service(restutil.http_put, url, + data, { + "x-ms-date" : timestamp, + "x-ms-blob-type" : "BlockBlob", + "Content-Length": ustr(len(data)), + "x-ms-version" : self.__class__.__storage_version__ + }) + except HttpError as e: + raise ProtocolError((u"Failed to upload block blob: {0}" + u"").format(e)) + if resp.status != httpclient.CREATED: + raise ProtocolError(("Failed to upload block blob: {0}" + "").format(resp.status)) + + def put_page_blob(self, url, data): + logger.verb("Replace old page blob") + + #Convert string into bytes + data=bytearray(data, encoding='utf-8') + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + #Align to 512 bytes + page_blob_size = int((len(data) + 511) / 512) * 512 + try: + resp = self.client.call_storage_service(restutil.http_put, url, + "", { + "x-ms-date" : timestamp, + "x-ms-blob-type" : "PageBlob", + "Content-Length": "0", + "x-ms-blob-content-length" : ustr(page_blob_size), + "x-ms-version" : self.__class__.__storage_version__ + }) + except HttpError as e: + raise ProtocolError((u"Failed to clean up page blob: {0}" + u"").format(e)) + if resp.status != httpclient.CREATED: + raise ProtocolError(("Failed to clean up page blob: {0}" + "").format(resp.status)) + + if url.count("?") < 0: + url = "{0}?comp=page".format(url) + else: + url = "{0}&comp=page".format(url) + + logger.verb("Upload page blob") + page_max = 4 * 1024 * 1024 #Max page size: 4MB + start = 0 + end = 0 + while end < len(data): + end = min(len(data), start + page_max) + content_size = end - start + #Align to 512 bytes + page_end = int((end + 511) / 512) * 512 + buf_size = page_end - start + buf = bytearray(buf_size) + buf[0: content_size] = data[start: end] + try: + resp = self.client.call_storage_service(restutil.http_put, url, + bytebuffer(buf), { + "x-ms-date" : timestamp, + "x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1), + "x-ms-page-write" : "update", + "x-ms-version" : self.__class__.__storage_version__, + "Content-Length": ustr(page_end - start) + }) + except HttpError as e: + raise ProtocolError((u"Failed to upload page blob: {0}" + u"").format(e)) + if resp is None or resp.status != httpclient.CREATED: + raise ProtocolError(("Failed to upload page blob: {0}" + "").format(resp.status)) + start = end + +def event_param_to_v1(param): + param_format = '' + param_type = type(param.value) + attr_type = "" + if param_type is int: + attr_type = 'mt:uint64' + elif param_type is str: + attr_type = 'mt:wstr' + elif ustr(param_type).count("'unicode'") > 0: + attr_type = 'mt:wstr' + elif param_type is bool: + attr_type = 'mt:bool' + elif param_type is float: + attr_type = 'mt:float64' + return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)), + attr_type) + +def event_to_v1(event): + params = "" + for param in event.parameters: + params += event_param_to_v1(param) + event_str = ('' + '' + '').format(event.eventId, params) + return event_str + +class WireClient(object): + def __init__(self, endpoint): + logger.info("Wire server endpoint:{0}", endpoint) + self.endpoint = endpoint + self.goal_state = None + self.updated = None + self.hosting_env = None + self.shared_conf = None + self.certs = None + self.ext_conf = None + self.last_request = 0 + self.req_count = 0 + self.status_blob = StatusBlob(self) + + def prevent_throttling(self): + """ + Try to avoid throttling of wire server + """ + now = time.time() + if now - self.last_request < 1: + logger.verb("Last request issued less than 1 second ago") + logger.verb("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.last_request = now + + self.req_count += 1 + if self.req_count % 3 == 0: + logger.verb("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.req_count = 0 + + def call_wireserver(self, http_req, *args, **kwargs): + """ + Call wire server. Handle throttling(403) and Resource Gone(410) + """ + self.prevent_throttling() + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.FORBIDDEN: + logger.warn("Sending too much request to wire server") + logger.info("Sleep {0} second to avoid throttling.", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + elif resp.status == httpclient.GONE: + msg = args[0] if len(args) > 0 else "" + raise WireProtocolResourceGone(msg) + else: + return resp + raise ProtocolError(("Calling wire server failed: {0}" + "").format(resp.status)) + + def decode_config(self, data): + if data is None: + return None + data = remove_bom(data) + xml_text = ustr(data, encoding='utf-8') + return xml_text + + def fetch_config(self, uri, headers): + try: + resp = self.call_wireserver(restutil.http_get, uri, + headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if(resp.status != httpclient.OK): + raise ProtocolError("{0} - {1}".format(resp.status, uri)) + + return self.decode_config(resp.read()) + + def fetch_cache(self, local_file): + if not os.path.isfile(local_file): + raise ProtocolError("{0} is missing.".format(local_file)) + try: + return fileutil.read_file(local_file) + except IOError as e: + raise ProtocolError("Failed to read cache: {0}".format(e)) + + def save_cache(self, local_file, data): + try: + fileutil.write_file(local_file, data) + except IOError as e: + raise ProtocolError("Failed to write cache: {0}".format(e)) + + def call_storage_service(self, http_req, *args, **kwargs): + """ + Call storage service, handle SERVICE_UNAVAILABLE(503) + """ + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.SERVICE_UNAVAILABLE: + logger.warn("Storage service is not avaible temporaryly") + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + else: + return resp + raise ProtocolError(("Calling storage endpoint failed: {0}" + "").format(resp.status)) + + def fetch_manifest(self, version_uris): + for version_uri in version_uris: + logger.verb("Fetch ext handler manifest: {0}", version_uri.uri) + try: + resp = self.call_storage_service(restutil.http_get, + version_uri.uri, None, + chk_proxy=True) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if resp.status == httpclient.OK: + return self.decode_config(resp.read()) + logger.warn("Failed to fetch ExtensionManifest: {0}, {1}", + resp.status, version_uri.uri) + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + raise ProtocolError(("Failed to fetch ExtensionManifest from " + "all sources")) + + + def update_hosting_env(self, goal_state): + if goal_state.hosting_env_uri is None: + raise ProtocolError("HostingEnvironmentConfig uri is empty") + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_config(goal_state.hosting_env_uri, + self.get_header()) + self.save_cache(local_file, xml_text) + self.hosting_env = HostingEnv(xml_text) + + def update_shared_conf(self, goal_state): + if goal_state.shared_conf_uri is None: + raise ProtocolError("SharedConfig uri is empty") + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_config(goal_state.shared_conf_uri, + self.get_header()) + self.save_cache(local_file, xml_text) + self.shared_conf = SharedConfig(xml_text) + + def update_certs(self, goal_state): + if goal_state.certs_uri is None: + return + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_config(goal_state.certs_uri, + self.get_header_for_cert()) + self.save_cache(local_file, xml_text) + self.certs = Certificates(self, xml_text) + + def update_ext_conf(self, goal_state): + if goal_state.ext_uri is None: + logger.info("ExtensionsConfig.xml uri is empty") + self.ext_conf = ExtensionsConfig(None) + return + incarnation = goal_state.incarnation + local_file = os.path.join(conf.get_lib_dir(), + EXT_CONF_FILE_NAME.format(incarnation)) + xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) + self.save_cache(local_file, xml_text) + self.ext_conf = ExtensionsConfig(xml_text) + + def update_goal_state(self, forced=False, max_retry=3): + uri = GOAL_STATE_URI.format(self.endpoint) + xml_text = self.fetch_config(uri, self.get_header()) + goal_state = GoalState(xml_text) + + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + + if not forced: + last_incarnation = None + if(os.path.isfile(incarnation_file)): + last_incarnation = fileutil.read_file(incarnation_file) + new_incarnation = goal_state.incarnation + if last_incarnation is not None and \ + last_incarnation == new_incarnation: + #Goalstate is not updated. + return + + #Start updating goalstate, retry on 410 + for retry in range(0, max_retry): + try: + self.goal_state = goal_state + file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) + self.save_cache(goal_state_file, xml_text) + self.save_cache(incarnation_file, goal_state.incarnation) + self.update_hosting_env(goal_state) + self.update_shared_conf(goal_state) + self.update_certs(goal_state) + self.update_ext_conf(goal_state) + return + except WireProtocolResourceGone: + logger.info("Incarnation is out of date. Update goalstate.") + xml_text = self.fetch_config(uri, self.get_header()) + goal_state = GoalState(xml_text) + + raise ProtocolError("Exceeded max retry updating goal state") + + def get_goal_state(self): + if(self.goal_state is None): + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + incarnation = self.fetch_cache(incarnation_file) + + file_name = GOAL_STATE_FILE_NAME.format(incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) + xml_text = self.fetch_cache(goal_state_file) + self.goal_state = GoalState(xml_text) + return self.goal_state + + def get_hosting_env(self): + if(self.hosting_env is None): + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.hosting_env = HostingEnv(xml_text) + return self.hosting_env + + def get_shared_conf(self): + if(self.shared_conf is None): + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.shared_conf = SharedConfig(xml_text) + return self.shared_conf + + def get_certs(self): + if(self.certs is None): + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.certs = Certificates(self, xml_text) + if self.certs is None: + return None + return self.certs + + def get_ext_conf(self): + if(self.ext_conf is None): + goal_state = self.get_goal_state() + if goal_state.ext_uri is None: + self.ext_conf = ExtensionsConfig(None) + else: + local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_cache(local_file) + self.ext_conf = ExtensionsConfig(xml_text) + return self.ext_conf + + def get_ext_manifest(self, ext_handler, goal_state): + local_file = MANIFEST_FILE_NAME.format(ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) + return ExtensionManifest(xml_text) + + def check_wire_protocol_version(self): + uri = VERSION_INFO_URI.format(self.endpoint) + version_info_xml = self.fetch_config(uri, None) + version_info = VersionInfo(version_info_xml) + + preferred = version_info.get_preferred() + if PROTOCOL_VERSION == preferred: + logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) + elif PROTOCOL_VERSION in version_info.get_supported(): + logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) + logger.warn("Server prefered version:{0}", preferred) + else: + error = ("Agent supported wire protocol version: {0} was not " + "advised by Fabric.").format(PROTOCOL_VERSION) + raise ProtocolNotFoundError(error) + + def upload_status_blob(self): + ext_conf = self.get_ext_conf() + if ext_conf.status_upload_blob is not None: + self.status_blob.upload(ext_conf.status_upload_blob) + + def report_role_prop(self, thumbprint): + goal_state = self.get_goal_state() + role_prop = _build_role_properties(goal_state.container_id, + goal_state.role_instance_id, + thumbprint) + role_prop = role_prop.encode("utf-8") + role_prop_uri = ROLE_PROP_URI.format(self.endpoint) + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, role_prop_uri, + role_prop, headers = headers) + except HttpError as e: + raise ProtocolError((u"Failed to send role properties: {0}" + u"").format(e)) + if resp.status != httpclient.ACCEPTED: + raise ProtocolError((u"Failed to send role properties: {0}" + u", {1}").format(resp.status, resp.read())) + + def report_health(self, status, substatus, description): + goal_state = self.get_goal_state() + health_report = _build_health_report(goal_state.incarnation, + goal_state.container_id, + goal_state.role_instance_id, + status, + substatus, + description) + health_report = health_report.encode("utf-8") + health_report_uri = HEALTH_REPORT_URI.format(self.endpoint) + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, health_report_uri, + health_report, headers = headers) + except HttpError as e: + raise ProtocolError((u"Failed to send provision status: {0}" + u"").format(e)) + if resp.status != httpclient.OK: + raise ProtocolError((u"Failed to send provision status: {0}" + u", {1}").format(resp.status, resp.read())) + + def send_event(self, provider_id, event_str): + uri = TELEMETRY_URI.format(self.endpoint) + data_format = ('' + '' + '{1}' + '' + '') + data = data_format.format(provider_id, event_str) + try: + header = self.get_header_for_xml_content() + resp = self.call_wireserver(restutil.http_post, uri, data, header) + except HttpError as e: + raise ProtocolError("Failed to send events:{0}".format(e)) + + if resp.status != httpclient.OK: + logger.verb(resp.read()) + raise ProtocolError("Failed to send events:{0}".format(resp.status)) + + def report_event(self, event_list): + buf = {} + #Group events by providerId + for event in event_list.events: + if event.providerId not in buf: + buf[event.providerId] = "" + event_str = event_to_v1(event) + if len(event_str) >= 63 * 1024: + logger.warn("Single event too large: {0}", event_str[300:]) + continue + if len(buf[event.providerId] + event_str) >= 63 * 1024: + self.send_event(event.providerId, buf[event.providerId]) + buf[event.providerId] = "" + buf[event.providerId] = buf[event.providerId] + event_str + + #Send out all events left in buffer. + for provider_id in list(buf.keys()): + if len(buf[provider_id]) > 0: + self.send_event(provider_id, buf[provider_id]) + + def get_header(self): + return { + "x-ms-agent-name":"WALinuxAgent", + "x-ms-version":PROTOCOL_VERSION + } + + def get_header_for_xml_content(self): + return { + "x-ms-agent-name":"WALinuxAgent", + "x-ms-version":PROTOCOL_VERSION, + "Content-Type":"text/xml;charset=utf-8" + } + + def get_header_for_cert(self): + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + content = self.fetch_cache(trans_cert_file) + cert = get_bytes_from_pem(content) + return { + "x-ms-agent-name":"WALinuxAgent", + "x-ms-version":PROTOCOL_VERSION, + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert":cert + } + +class VersionInfo(object): + def __init__(self, xml_text): + """ + Query endpoint server for wire protocol version. + Fail if our desired protocol version is not seen. + """ + logger.verb("Load Version.xml") + self.parse(xml_text) + + def parse(self, xml_text): + xml_doc = parse_doc(xml_text) + preferred = find(xml_doc, "Preferred") + self.preferred = findtext(preferred, "Version") + logger.info("Fabric preferred wire protocol version:{0}", self.preferred) + + self.supported = [] + supported = find(xml_doc, "Supported") + supported_version = findall(supported, "Version") + for node in supported_version: + version = gettext(node) + logger.verb("Fabric supported wire protocol version:{0}", version) + self.supported.append(version) + + def get_preferred(self): + return self.preferred + + def get_supported(self): + return self.supported + + +class GoalState(object): + + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("GoalState.xml is None") + logger.verb("Load GoalState.xml") + self.incarnation = None + self.expected_state = None + self.hosting_env_uri = None + self.shared_conf_uri = None + self.certs_uri = None + self.ext_uri = None + self.role_instance_id = None + self.container_id = None + self.load_balancer_probe_port = None + self.parse(xml_text) + + def parse(self, xml_text): + """ + Request configuration data from endpoint server. + """ + self.xml_text = xml_text + xml_doc = parse_doc(xml_text) + self.incarnation = findtext(xml_doc, "Incarnation") + self.expected_state = findtext(xml_doc, "ExpectedState") + self.hosting_env_uri = findtext(xml_doc, "HostingEnvironmentConfig") + self.shared_conf_uri = findtext(xml_doc, "SharedConfig") + self.certs_uri = findtext(xml_doc, "Certificates") + self.ext_uri = findtext(xml_doc, "ExtensionsConfig") + role_instance = find(xml_doc, "RoleInstance") + self.role_instance_id = findtext(role_instance, "InstanceId") + container = find(xml_doc, "Container") + self.container_id = findtext(container, "ContainerId") + lbprobe_ports = find(xml_doc, "LBProbePorts") + self.load_balancer_probe_port = findtext(lbprobe_ports, "Port") + return self + + +class HostingEnv(object): + """ + parse Hosting enviromnet config and store in + HostingEnvironmentConfig.xml + """ + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("HostingEnvironmentConfig.xml is None") + logger.verb("Load HostingEnvironmentConfig.xml") + self.vm_name = None + self.role_name = None + self.deployment_name = None + self.parse(xml_text) + + def parse(self, xml_text): + """ + parse and create HostingEnvironmentConfig.xml. + """ + self.xml_text = xml_text + xml_doc = parse_doc(xml_text) + incarnation = find(xml_doc, "Incarnation") + self.vm_name = getattrib(incarnation, "instance") + role = find(xml_doc, "Role") + self.role_name = getattrib(role, "name") + deployment = find(xml_doc, "Deployment") + self.deployment_name = getattrib(deployment, "name") + return self + +class SharedConfig(object): + """ + parse role endpoint server and goal state config. + """ + def __init__(self, xml_text): + logger.verb("Load SharedConfig.xml") + self.parse(xml_text) + + def parse(self, xml_text): + """ + parse and write configuration to file SharedConfig.xml. + """ + #Not used currently + return self + +class Certificates(object): + + """ + Object containing certificates of host and provisioned user. + """ + def __init__(self, client, xml_text): + logger.verb("Load Certificates.xml") + self.client = client + self.cert_list = CertList() + self.parse(xml_text) + + def parse(self, xml_text): + """ + Parse multiple certificates into seperate files. + """ + xml_doc = parse_doc(xml_text) + data = findtext(xml_doc, "Data") + if data is None: + return + + cryptutil = CryptUtil(conf.get_openssl_cmd()) + p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME) + p7m = ("MIME-Version:1.0\n" + "Content-Disposition: attachment; filename=\"{0}\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" + "Content-Transfer-Encoding: base64\n" + "\n" + "{2}").format(p7m_file, p7m_file, data) + + self.client.save_cache(p7m_file, p7m) + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME) + #decrypt certificates + cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, + pem_file) + + #The parsing process use public key to match prv and crt. + buf = [] + begin_crt = False + begin_prv = False + prvs = {} + thumbprints = {} + index = 0 + v1_cert_list = [] + with open(pem_file) as pem: + for line in pem.readlines(): + buf.append(line) + if re.match(r'[-]+BEGIN.*KEY[-]+', line): + begin_prv = True + elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line): + begin_crt = True + elif re.match(r'[-]+END.*KEY[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'prv', buf) + pub = cryptutil.get_pubkey_from_prv(tmp_file) + prvs[pub] = tmp_file + buf = [] + index += 1 + begin_prv = False + elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'crt', buf) + pub = cryptutil.get_pubkey_from_crt(tmp_file) + thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file) + thumbprints[pub] = thumbprint + #Rename crt with thumbprint as the file name + crt = "{0}.crt".format(thumbprint) + v1_cert_list.append({ + "name":None, + "thumbprint":thumbprint + }) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt)) + buf = [] + index += 1 + begin_crt = False + + #Rename prv key with thumbprint as the file name + for pubkey in prvs: + thumbprint = thumbprints[pubkey] + if thumbprint: + tmp_file = prvs[pubkey] + prv = "{0}.prv".format(thumbprint) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv)) + + for v1_cert in v1_cert_list: + cert = Cert() + set_properties("certs", cert, v1_cert) + self.cert_list.certificates.append(cert) + + def write_to_tmp_file(self, index, suffix, buf): + file_name = os.path.join(conf.get_lib_dir(), + "{0}.{1}".format(index, suffix)) + self.client.save_cache(file_name, "".join(buf)) + return file_name + + +class ExtensionsConfig(object): + """ + parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent. + Install if true, remove if it is set to false. + """ + + def __init__(self, xml_text): + logger.verb("Load ExtensionsConfig.xml") + self.ext_handlers = ExtHandlerList() + self.status_upload_blob = None + if xml_text is not None: + self.parse(xml_text) + + def parse(self, xml_text): + """ + Write configuration to file ExtensionsConfig.xml. + """ + xml_doc = parse_doc(xml_text) + plugins_list = find(xml_doc, "Plugins") + plugins = findall(plugins_list, "Plugin") + plugin_settings_list = find(xml_doc, "PluginSettings") + plugin_settings = findall(plugin_settings_list, "Plugin") + + for plugin in plugins: + ext_handler = self.parse_plugin(plugin) + self.ext_handlers.extHandlers.append(ext_handler) + self.parse_plugin_settings(ext_handler, plugin_settings) + + self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob") + + def parse_plugin(self, plugin): + ext_handler = ExtHandler() + ext_handler.name = getattrib(plugin, "name") + ext_handler.properties.version = getattrib(plugin, "version") + ext_handler.properties.state = getattrib(plugin, "state") + + auto_upgrade = getattrib(plugin, "autoUpgrade") + if auto_upgrade is not None and auto_upgrade.lower() == "true": + ext_handler.properties.upgradePolicy = "auto" + else: + ext_handler.properties.upgradePolicy = "manual" + + location = getattrib(plugin, "location") + failover_location = getattrib(plugin, "failoverlocation") + for uri in [location, failover_location]: + version_uri = ExtHandlerVersionUri() + version_uri.uri = uri + ext_handler.versionUris.append(version_uri) + return ext_handler + + def parse_plugin_settings(self, ext_handler, plugin_settings): + if plugin_settings is None: + return + + name = ext_handler.name + version = ext_handler.properties.version + settings = [x for x in plugin_settings \ + if getattrib(x, "name") == name and \ + getattrib(x ,"version") == version] + + if settings is None or len(settings) == 0: + return + + runtime_settings = None + runtime_settings_node = find(settings[0], "RuntimeSettings") + seqNo = getattrib(runtime_settings_node, "seqNo") + runtime_settings_str = gettext(runtime_settings_node) + try: + runtime_settings = json.loads(runtime_settings_str) + except ValueError as e: + logger.error("Invalid extension settings") + return + + for plugin_settings_list in runtime_settings["runtimeSettings"]: + handler_settings = plugin_settings_list["handlerSettings"] + ext = Extension() + #There is no "extension name" in wire protocol. + #Put + ext.name = ext_handler.name + ext.sequenceNumber = seqNo + ext.publicSettings = handler_settings.get("publicSettings") + ext.protectedSettings = handler_settings.get("protectedSettings") + thumbprint = handler_settings.get("protectedSettingsCertThumbprint") + ext.certificateThumbprint = thumbprint + ext_handler.properties.extensions.append(ext) + +class ExtensionManifest(object): + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("ExtensionManifest is None") + logger.verb("Load ExtensionManifest.xml") + self.pkg_list = ExtHandlerPackageList() + self.parse(xml_text) + + def parse(self, xml_text): + xml_doc = parse_doc(xml_text) + packages = findall(xml_doc, "Plugin") + for package in packages: + version = findtext(package, "Version") + uris = find(package, "Uris") + uri_list = findall(uris, "Uri") + uri_list = [gettext(x) for x in uri_list] + package = ExtHandlerPackage() + package.version = version + for uri in uri_list: + pkg_uri = ExtHandlerVersionUri() + pkg_uri.uri = uri + package.uris.append(pkg_uri) + self.pkg_list.versions.append(package) + diff --git a/azurelinuxagent/utils/cryptutil.py b/azurelinuxagent/utils/cryptutil.py new file mode 100644 index 0000000..5ee5637 --- /dev/null +++ b/azurelinuxagent/utils/cryptutil.py @@ -0,0 +1,121 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import base64 +import struct +from azurelinuxagent.future import ustr, bytebuffer +from azurelinuxagent.exception import CryptError +import azurelinuxagent.utils.shellutil as shellutil + +class CryptUtil(object): + def __init__(self, openssl_cmd): + self.openssl_cmd = openssl_cmd + + def gen_transport_cert(self, prv_file, crt_file): + """ + Create ssl certificate for https communication with endpoint server. + """ + cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 " + "-newkey rsa:2048 -keyout {1} " + "-out {2}").format(self.openssl_cmd, prv_file, crt_file) + shellutil.run(cmd) + + def get_pubkey_from_prv(self, file_name): + cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, + file_name) + pub = shellutil.run_get_output(cmd)[1] + return pub + + def get_pubkey_from_crt(self, file_name): + cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd, + file_name) + pub = shellutil.run_get_output(cmd)[1] + return pub + + def get_thumbprint_from_crt(self, file_name): + cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd, + file_name) + thumbprint = shellutil.run_get_output(cmd)[1] + thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper() + return thumbprint + + def decrypt_p7m(self, p7m_file, trans_prv_file, trans_cert_file, pem_file): + cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3} " + "| {4} pkcs12 -nodes -password pass: -out {5}" + "").format(self.openssl_cmd, p7m_file, trans_prv_file, + trans_cert_file, self.openssl_cmd, pem_file) + shellutil.run(cmd) + + def crt_to_ssh(self, input_file, output_file): + shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file, + output_file)) + + def asn1_to_ssh(self, pubkey): + lines = pubkey.split("\n") + lines = [x for x in lines if not x.startswith("----")] + base64_encoded = "".join(lines) + try: + #TODO remove pyasn1 dependency + from pyasn1.codec.der import decoder as der_decoder + der_encoded = base64.b64decode(base64_encoded) + der_encoded = der_decoder.decode(der_encoded)[0][1] + key = der_decoder.decode(self.bits_to_bytes(der_encoded))[0] + n=key[0] + e=key[1] + keydata = bytearray() + keydata.extend(struct.pack('>I', len("ssh-rsa"))) + keydata.extend(b"ssh-rsa") + keydata.extend(struct.pack('>I', len(self.num_to_bytes(e)))) + keydata.extend(self.num_to_bytes(e)) + keydata.extend(struct.pack('>I', len(self.num_to_bytes(n)) + 1)) + keydata.extend(b"\0") + keydata.extend(self.num_to_bytes(n)) + keydata_base64 = base64.b64encode(bytebuffer(keydata)) + return ustr(b"ssh-rsa " + keydata_base64 + b"\n", + encoding='utf-8') + except ImportError as e: + raise CryptError("Failed to load pyasn1.codec.der") + + def num_to_bytes(self, num): + """ + Pack number into bytes. Retun as string. + """ + result = bytearray() + while num: + result.append(num & 0xFF) + num >>= 8 + result.reverse() + return result + + def bits_to_bytes(self, bits): + """ + Convert an array contains bits, [0,1] to a byte array + """ + index = 7 + byte_array = bytearray() + curr = 0 + for bit in bits: + curr = curr | (bit << index) + index = index - 1 + if index == -1: + byte_array.append(curr) + curr = 0 + index = 7 + return bytes(byte_array) + diff --git a/azurelinuxagent/utils/fileutil.py b/azurelinuxagent/utils/fileutil.py index 08592bc..5369a7c 100644 --- a/azurelinuxagent/utils/fileutil.py +++ b/azurelinuxagent/utils/fileutil.py @@ -27,7 +27,7 @@ import shutil import pwd import tempfile import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.utils.textutil as textutil def read_file(filepath, asbin=False, remove_bom=False, encoding='utf-8'): @@ -46,7 +46,7 @@ def read_file(filepath, asbin=False, remove_bom=False, encoding='utf-8'): if remove_bom: #Remove bom on bytes data before it is converted into string. data = textutil.remove_bom(data) - data = text(data, encoding=encoding) + data = ustr(data, encoding=encoding) return data def write_file(filepath, contents, asbin=False, encoding='utf-8', append=False): @@ -100,6 +100,7 @@ def replace_file(filepath, contents): return 1 return 0 + def base_name(path): head, tail = os.path.split(path) return tail @@ -151,7 +152,7 @@ def rm_dirs(*args): def update_conf_file(path, line_start, val, chk_err=False): conf = [] if not os.path.isfile(path) and chk_err: - raise Exception("Can't find config file:{0}".format(path)) + raise IOError("Can't find config file:{0}".format(path)) conf = read_file(path).split('\n') conf = [x for x in conf if not x.startswith(line_start)] conf.append(val) diff --git a/azurelinuxagent/utils/osutil.py b/azurelinuxagent/utils/osutil.py deleted file mode 100644 index 9de47e7..0000000 --- a/azurelinuxagent/utils/osutil.py +++ /dev/null @@ -1,27 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -""" -Load OSUtil implementation from azurelinuxagent.distro -""" -from azurelinuxagent.distro.default.osutil import OSUtilError -import azurelinuxagent.distro.loader as loader - -OSUTIL = loader.get_osutil() - diff --git a/azurelinuxagent/utils/restutil.py b/azurelinuxagent/utils/restutil.py index 2acfa57..2e8b0be 100644 --- a/azurelinuxagent/utils/restutil.py +++ b/azurelinuxagent/utils/restutil.py @@ -21,8 +21,9 @@ import time import platform import os import subprocess -import azurelinuxagent.logger as logger import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.exception import HttpError from azurelinuxagent.future import httpclient, urlparse """ @@ -31,9 +32,6 @@ REST api util functions RETRY_WAITING_INTERVAL = 10 -class HttpError(Exception): - pass - def _parse_url(url): o = urlparse(url) rel_uri = o.path @@ -51,8 +49,8 @@ def get_http_proxy(): Get http_proxy and https_proxy from environment variables. Username and password is not supported now. """ - host = conf.get("HttpProxy.Host", None) - port = conf.get("HttpProxy.Port", None) + host = conf.get_httpproxy_host() + port = conf.get_httpproxy_port() return (host, port) def _http_request(method, host, rel_uri, port=None, data=None, secure=False, @@ -61,7 +59,7 @@ def _http_request(method, host, rel_uri, port=None, data=None, secure=False, if secure: port = 443 if port is None else port if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPSConnection(proxy_host, proxy_port) + conn = httpclient.HTTPSConnection(proxy_host, proxy_port, timeout=10) conn.set_tunnel(host, port) #If proxy is used, full url is needed. url = "https://{0}:{1}{2}".format(host, port, rel_uri) @@ -71,7 +69,7 @@ def _http_request(method, host, rel_uri, port=None, data=None, secure=False, else: port = 80 if port is None else port if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPConnection(proxy_host, proxy_port) + conn = httpclient.HTTPConnection(proxy_host, proxy_port, timeout=10) #If proxy is used, full url is needed. url = "http://{0}:{1}{2}".format(host, port, rel_uri) else: @@ -128,8 +126,12 @@ def http_request(method, url, data, headers=None, max_retry=3, chk_proxy=False): if retry < max_retry - 1: logger.info("Retry={0}, {1} {2}", retry, method, url) time.sleep(RETRY_WAITING_INTERVAL) - - raise HttpError("HTTP Err: {0} {1}".format(method, url)) + + if url is not None and len(url) > 100: + url_log = url[0: 100] #In case the url is too long + else: + url_log = url + raise HttpError("HTTP Err: {0} {1}".format(method, url_log)) def http_get(url, headers=None, max_retry=3, chk_proxy=False): return http_request("GET", url, data=None, headers=headers, diff --git a/azurelinuxagent/utils/shellutil.py b/azurelinuxagent/utils/shellutil.py index 372c78a..98871a1 100644 --- a/azurelinuxagent/utils/shellutil.py +++ b/azurelinuxagent/utils/shellutil.py @@ -20,7 +20,7 @@ import platform import os import subprocess -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.logger as logger if not hasattr(subprocess,'check_output'): @@ -75,9 +75,9 @@ def run_get_output(cmd, chk_err=True, log_cmd=True): logger.verb(u"run cmd '{0}'", cmd) try: output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True) - output = text(output, encoding='utf-8', errors="backslashreplace") + output = ustr(output, encoding='utf-8', errors="backslashreplace") except subprocess.CalledProcessError as e : - output = text(e.output, encoding='utf-8', errors="backslashreplace") + output = ustr(e.output, encoding='utf-8', errors="backslashreplace") if chk_err: if log_cmd: logger.error(u"run cmd '{0}' failed", e.cmd) diff --git a/azurelinuxagent/utils/textutil.py b/azurelinuxagent/utils/textutil.py index e0f1395..851f98a 100644 --- a/azurelinuxagent/utils/textutil.py +++ b/azurelinuxagent/utils/textutil.py @@ -224,5 +224,13 @@ def gen_password_hash(password, crypt_id, salt_len): salt = "${0}${1}".format(crypt_id, salt) return crypt.crypt(password, salt) +def get_bytes_from_pem(pem_str): + base64_bytes = "" + for line in pem_str.split('\n'): + if "----" not in line: + base64_bytes += line + return base64_bytes + + Version = LooseVersion -- cgit v1.2.3