diff options
author | Ben Howard <ben.howard@ubuntu.com> | 2016-02-08 16:33:07 -0700 |
---|---|---|
committer | usd-importer <ubuntu-server@lists.ubuntu.com> | 2016-02-09 00:59:05 +0000 |
commit | a00729ff7421b3661e8b1a1e0fa46393379f2e96 (patch) | |
tree | 4563b927e3a57446a4a928a72a92d72c9ad4f6e6 /azurelinuxagent | |
parent | 53f54030cae2de3d5fa474a61fe51f16c7a07c79 (diff) | |
download | vyos-walinuxagent-a00729ff7421b3661e8b1a1e0fa46393379f2e96.tar.gz vyos-walinuxagent-a00729ff7421b3661e8b1a1e0fa46393379f2e96.zip |
Import patches-unapplied version 2.1.3-0ubuntu1 to ubuntu/xenial-proposed
Imported using git-ubuntu import.
Changelog parent: 53f54030cae2de3d5fa474a61fe51f16c7a07c79
New changelog entries:
* New upstream release (LP: #1543359):
- Bug fixes for extension handling
- Feature enablement for AzureStack.
Diffstat (limited to 'azurelinuxagent')
54 files changed, 2190 insertions, 1809 deletions
diff --git a/azurelinuxagent/agent.py b/azurelinuxagent/agent.py index 849a192..93e9c16 100644 --- a/azurelinuxagent/agent.py +++ b/azurelinuxagent/agent.py @@ -29,27 +29,60 @@ from azurelinuxagent.metadata import AGENT_NAME, AGENT_LONG_VERSION, \ DISTRO_NAME, DISTRO_VERSION, \ PY_VERSION_MAJOR, PY_VERSION_MINOR, \ PY_VERSION_MICRO -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.handler import HANDLERS +from azurelinuxagent.distro.loader import get_distro -def init(verbose): - """ - Initialize agent running environment. - """ - HANDLERS.init_handler.init(verbose) +class Agent(object): + def __init__(self, verbose): + """ + Initialize agent running environment. + """ + self.distro = get_distro(); + self.distro.init_handler.run(verbose) -def run(): - """ - Run agent daemon - """ - HANDLERS.main_handler.run() + def daemon(self): + """ + Run agent daemon + """ + self.distro.daemon_handler.run() + + def deprovision(self, force=False, deluser=False): + """ + Run deprovision command + """ + self.distro.deprovision_handler.run(force=force, deluser=deluser) -def deprovision(force=False, deluser=False): + def register_service(self): + """ + Register agent as a service + """ + print("Register {0} service".format(AGENT_NAME)) + self.distro.osutil.register_agent_service() + print("Start {0} service".format(AGENT_NAME)) + self.distro.osutil.start_agent_service() + +def main(): """ - Run deprovision command + Parse command line arguments, exit with usage() on error. + Invoke different methods according to different command """ - HANDLERS.deprovision_handler.deprovision(force=force, deluser=deluser) + command, force, verbose = parse_args(sys.argv[1:]) + if command == "version": + version() + elif command == "help": + usage() + elif command == "start": + start() + else: + agent = Agent(verbose) + if command == "deprovision+user": + agent.deprovision(force, deluser=True) + elif command == "deprovision": + agent.deprovision(force, deluser=False) + elif command == "register-service": + agent.register_service() + elif command == "daemon": + agent.daemon() def parse_args(sys_args): """ @@ -108,34 +141,3 @@ def start(): devnull = open(os.devnull, 'w') subprocess.Popen([sys.argv[0], '-daemon'], stdout=devnull, stderr=devnull) -def register_service(): - """ - Register agent as a service - """ - print("Register {0} service".format(AGENT_NAME)) - OSUTIL.register_agent_service() - print("Start {0} service".format(AGENT_NAME)) - OSUTIL.start_agent_service() - -def main(): - """ - Parse command line arguments, exit with usage() on error. - Invoke different methods according to different command - """ - command, force, verbose = parse_args(sys.argv[1:]) - if command == "version": - version() - elif command == "help": - usage() - else: - init(verbose) - if command == "deprovision+user": - deprovision(force, deluser=True) - elif command == "deprovision": - deprovision(force, deluser=False) - elif command == "start": - start() - elif command == "register-service": - register_service() - elif command == "daemon": - run() diff --git a/azurelinuxagent/conf.py b/azurelinuxagent/conf.py index 2b0eb01..7921e79 100644 --- a/azurelinuxagent/conf.py +++ b/azurelinuxagent/conf.py @@ -43,11 +43,11 @@ class ConfigurationProvider(object): else: self.values[parts[0]] = None - def get(self, key, default_val=None): + def get(self, key, default_val): val = self.values.get(key) return val if val is not None else default_val - def get_switch(self, key, default_val=False): + def get_switch(self, key, default_val): val = self.values.get(key) if val is not None and val.lower() == 'y': return True @@ -55,7 +55,7 @@ class ConfigurationProvider(object): return False return default_val - def get_int(self, key, default_val=-1): + def get_int(self, key, default_val): try: return int(self.values.get(key)) except TypeError: @@ -64,9 +64,9 @@ class ConfigurationProvider(object): return default_val -__config__ = ConfigurationProvider() +__conf__ = ConfigurationProvider() -def load_conf(conf_file_path, conf=__config__): +def load_conf_from_file(conf_file_path, conf=__conf__): """ Load conf file from: conf_file_path """ @@ -80,30 +80,87 @@ def load_conf(conf_file_path, conf=__config__): raise AgentConfigError(("Failed to load conf file:{0}, {1}" "").format(conf_file_path, err)) -def get(key, default_val=None, conf=__config__): - """ - Get option value by key, return default_val if not found - """ - if conf is not None: - return conf.get(key, default_val) - else: - return default_val +def get_logs_verbose(conf=__conf__): + return conf.get_switch("Logs.Verbose", False) -def get_switch(key, default_val=None, conf=__config__): - """ - Get bool option value by key, return default_val if not found - """ - if conf is not None: - return conf.get_switch(key, default_val) - else: - return default_val +def get_lib_dir(conf=__conf__): + return conf.get("Lib.Dir", "/var/lib/waagent") -def get_int(key, default_val=None, conf=__config__): - """ - Get int option value by key, return default_val if not found - """ - if conf is not None: - return conf.get_int(key, default_val) - else: - return default_val +def get_dvd_mount_point(conf=__conf__): + return conf.get("DVD.MountPoint", "/mnt/cdrom/secure") + +def get_agent_pid_file_path(conf=__conf__): + return conf.get("Pid.File", "/var/run/waagent.pid") + +def get_ext_log_dir(conf=__conf__): + return conf.get("Extension.LogDir", "/var/log/azure") + +def get_openssl_cmd(conf=__conf__): + return conf.get("OS.OpensslPath", "/usr/bin/openssl") + +def get_home_dir(conf=__conf__): + return conf.get("OS.HomeDir", "/home") + +def get_passwd_file_path(conf=__conf__): + return conf.get("OS.PasswordPath", "/etc/shadow") + +def get_sshd_conf_file_path(conf=__conf__): + return conf.get("OS.SshdConfigPath", "/etc/ssh/sshd_config") + +def get_root_device_scsi_timeout(conf=__conf__): + return conf.get("OS.RootDeviceScsiTimeout", None) + +def get_ssh_host_keypair_type(conf=__conf__): + return conf.get("Provisioning.SshHostKeyPairType", "rsa") + +def get_provision_enabled(conf=__conf__): + return conf.get_switch("Provisioning.Enabled", True) + +def get_allow_reset_sys_user(conf=__conf__): + return conf.get_switch("Provisioning.AllowResetSysUser", False) + +def get_regenerate_ssh_host_key(conf=__conf__): + return conf.get_switch("Provisioning.RegenerateSshHostKeyPair", False) + +def get_delete_root_password(conf=__conf__): + return conf.get_switch("Provisioning.DeleteRootPassword", False) + +def get_decode_customdata(conf=__conf__): + return conf.get_switch("Provisioning.DecodeCustomData", False) + +def get_execute_customdata(conf=__conf__): + return conf.get_switch("Provisioning.ExecuteCustomData", False) + +def get_password_cryptid(conf=__conf__): + return conf.get("Provisioning.PasswordCryptId", "6") + +def get_password_crypt_salt_len(conf=__conf__): + return conf.get_int("Provisioning.PasswordCryptSaltLength", 10) + +def get_monitor_hostname(conf=__conf__): + return conf.get_switch("Provisioning.MonitorHostName", False) + +def get_httpproxy_host(conf=__conf__): + return conf.get("HttpProxy.Host", None) + +def get_httpproxy_port(conf=__conf__): + return conf.get("HttpProxy.Port", None) + +def get_detect_scvmm_env(conf=__conf__): + return conf.get_switch("DetectScvmmEnv", False) + +def get_resourcedisk_format(conf=__conf__): + return conf.get_switch("ResourceDisk.Format", False) + +def get_resourcedisk_enable_swap(conf=__conf__): + return conf.get_switch("ResourceDisk.EnableSwap", False) + +def get_resourcedisk_mountpoint(conf=__conf__): + return conf.get("ResourceDisk.MountPoint", "/mnt/resource") + +def get_resourcedisk_filesystem(conf=__conf__): + return conf.get("ResourceDisk.Filesystem", "ext3") + +def get_resourcedisk_swap_size_mb(conf=__conf__): + return conf.get_int("ResourceDisk.SwapSizeMB", 0) diff --git a/azurelinuxagent/distro/centos/__init__.py b/azurelinuxagent/distro/centos/__init__.py deleted file mode 100644 index d9b82f5..0000000 --- a/azurelinuxagent/distro/centos/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - diff --git a/azurelinuxagent/distro/centos/loader.py b/azurelinuxagent/distro/centos/loader.py deleted file mode 100644 index 9dc428f..0000000 --- a/azurelinuxagent/distro/centos/loader.py +++ /dev/null @@ -1,25 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION -import azurelinuxagent.distro.redhat.loader as redhat - -def get_osutil(): - return redhat.get_osutil() - diff --git a/azurelinuxagent/distro/coreos/deprovision.py b/azurelinuxagent/distro/coreos/deprovision.py index 99d3a40..9642579 100644 --- a/azurelinuxagent/distro/coreos/deprovision.py +++ b/azurelinuxagent/distro/coreos/deprovision.py @@ -21,6 +21,9 @@ import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.distro.default.deprovision import DeprovisionHandler, DeprovisionAction class CoreOSDeprovisionHandler(DeprovisionHandler): + def __init__(self, distro): + self.distro = distro + def setup(self, deluser): warnings, actions = super(CoreOSDeprovisionHandler, self).setup(deluser) warnings.append("WARNING! /etc/machine-id will be removed.") diff --git a/azurelinuxagent/distro/coreos/handlerFactory.py b/azurelinuxagent/distro/coreos/distro.py index 58f476c..04c7bff 100644 --- a/azurelinuxagent/distro/coreos/handlerFactory.py +++ b/azurelinuxagent/distro/coreos/distro.py @@ -17,11 +17,13 @@ # Requires Python 2.4+ and Openssl 1.0+ # -from .deprovision import CoreOSDeprovisionHandler -from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.coreos.osutil import CoreOSUtil +from azurelinuxagent.distro.coreos.deprovision import CoreOSDeprovisionHandler -class CoreOSHandlerFactory(DefaultHandlerFactory): +class CoreOSDistro(DefaultDistro): def __init__(self): - super(CoreOSHandlerFactory, self).__init__() - self.deprovision_handler = CoreOSDeprovisionHandler() + super(CoreOSDistro, self).__init__() + self.osutil = CoreOSUtil() + self.deprovision_handler = CoreOSDeprovisionHandler(self) diff --git a/azurelinuxagent/distro/coreos/loader.py b/azurelinuxagent/distro/coreos/loader.py deleted file mode 100644 index 802f276..0000000 --- a/azurelinuxagent/distro/coreos/loader.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - - -def get_osutil(): - from azurelinuxagent.distro.coreos.osutil import CoreOSUtil - return CoreOSUtil() - -def get_handlers(): - from azurelinuxagent.distro.coreos.handlerFactory import CoreOSHandlerFactory - return CoreOSHandlerFactory() - diff --git a/azurelinuxagent/distro/coreos/osutil.py b/azurelinuxagent/distro/coreos/osutil.py index c244311..ffc83e3 100644 --- a/azurelinuxagent/distro/coreos/osutil.py +++ b/azurelinuxagent/distro/coreos/osutil.py @@ -35,9 +35,9 @@ from azurelinuxagent.distro.default.osutil import DefaultOSUtil class CoreOSUtil(DefaultOSUtil): def __init__(self): super(CoreOSUtil, self).__init__() + self.agent_conf_file_path = '/usr/share/oem/waagent.conf' self.waagent_path='/usr/share/oem/bin/waagent' self.python_path='/usr/share/oem/python/bin' - self.conf_file_path = '/usr/share/oem/waagent.conf' if 'PATH' in os.environ: path = "{0}:{1}".format(os.environ['PATH'], self.python_path) else: @@ -85,9 +85,6 @@ class CoreOSUtil(DefaultOSUtil): ret= shellutil.run_get_output("pidof systemd-networkd") return ret[1] if ret[0] == 0 else None - def decode_customdata(self, data): - return base64.b64decode(data) - def set_ssh_client_alive_interval(self): #In CoreOS, /etc/sshd_config is mount readonly. Skip the setting pass diff --git a/azurelinuxagent/distro/oracle/loader.py b/azurelinuxagent/distro/debian/distro.py index 9dc428f..01f4e3e 100644 --- a/azurelinuxagent/distro/oracle/loader.py +++ b/azurelinuxagent/distro/debian/distro.py @@ -17,9 +17,11 @@ # Requires Python 2.4+ and Openssl 1.0+ # -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION -import azurelinuxagent.distro.redhat.loader as redhat +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.debian.osutil import DebianOSUtil -def get_osutil(): - return redhat.get_osutil() +class DebianDistro(DefaultDistro): + def __init__(self): + super(DebianDistro, self).__init__() + self.osutil = DebianOSUtil() diff --git a/azurelinuxagent/distro/default/daemon.py b/azurelinuxagent/distro/default/daemon.py new file mode 100644 index 0000000..cf9eb16 --- /dev/null +++ b/azurelinuxagent/distro/default/daemon.py @@ -0,0 +1,103 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import time +import sys +import traceback +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.future import ustr +from azurelinuxagent.event import add_event, WALAEventOperation +from azurelinuxagent.exception import ProtocolError +from azurelinuxagent.metadata import AGENT_LONG_NAME, AGENT_VERSION, \ + DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_FULL_NAME, PY_VERSION_MAJOR, \ + PY_VERSION_MINOR, PY_VERSION_MICRO +import azurelinuxagent.event as event +import azurelinuxagent.utils.fileutil as fileutil + + +class DaemonHandler(object): + def __init__(self, distro): + self.distro = distro + self.running = True + + + def run(self): + logger.info("{0} Version:{1}", AGENT_LONG_NAME, AGENT_VERSION) + logger.info("OS: {0} {1}", DISTRO_NAME, DISTRO_VERSION) + logger.info("Python: {0}.{1}.{2}", PY_VERSION_MAJOR, PY_VERSION_MINOR, + PY_VERSION_MICRO) + + self.check_pid() + + while self.running: + try: + self.daemon() + except Exception as e: + err_msg = traceback.format_exc() + add_event("WALA", is_success=False, message=ustr(err_msg), + op=WALAEventOperation.UnhandledError) + logger.info("Sleep 15 seconds and restart daemon") + time.sleep(15) + + def check_pid(self): + """Check whether daemon is already running""" + pid = None + pid_file = conf.get_agent_pid_file_path() + if os.path.isfile(pid_file): + pid = fileutil.read_file(pid_file) + + if pid is not None and os.path.isdir(os.path.join("/proc", pid)): + logger.info("Daemon is already running: {0}", pid) + sys.exit(0) + + fileutil.write_file(pid_file, ustr(os.getpid())) + + def daemon(self): + logger.info("Run daemon") + #Create lib dir + if not os.path.isdir(conf.get_lib_dir()): + fileutil.mkdir(conf.get_lib_dir(), mode=0o700) + os.chdir(conf.get_lib_dir()) + + if conf.get_detect_scvmm_env(): + if self.distro.scvmm_handler.run(): + return + + self.distro.provision_handler.run() + + if conf.get_resourcedisk_format(): + self.distro.resource_disk_handler.run() + + try: + protocol = self.distro.protocol_util.detect_protocol() + except ProtocolError as e: + logger.error("Failed to detect protocol, exit", e) + return + + self.distro.event_handler.run() + self.distro.env_handler.run() + + while self.running: + #Handle extensions + self.distro.ext_handlers_handler.run() + time.sleep(25) + diff --git a/azurelinuxagent/distro/default/deprovision.py b/azurelinuxagent/distro/default/deprovision.py index b62c5f6..4db4cdc 100644 --- a/azurelinuxagent/distro/default/deprovision.py +++ b/azurelinuxagent/distro/default/deprovision.py @@ -18,10 +18,8 @@ # import azurelinuxagent.conf as conf -from azurelinuxagent.utils.osutil import OSUTIL +from azurelinuxagent.exception import ProtocolError from azurelinuxagent.future import read_input -import azurelinuxagent.protocol as prot -import azurelinuxagent.protocol.ovfenv as ovf import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil @@ -35,18 +33,20 @@ class DeprovisionAction(object): self.func(*self.args, **self.kwargs) class DeprovisionHandler(object): + def __init__(self, distro): + self.distro = distro def del_root_password(self, warnings, actions): warnings.append("WARNING! root password will be disabled. " "You will not be able to login as root.") - actions.append(DeprovisionAction(OSUTIL.del_root_password)) + actions.append(DeprovisionAction(self.distro.osutil.del_root_password)) def del_user(self, warnings, actions): try: - ovfenv = ovf.get_ovf_env() - except prot.ProtocolError: + ovfenv = self.distro.protocol_util.get_ovf_env() + except ProtocolError: warnings.append("WARNING! ovf-env.xml is not found.") warnings.append("WARNING! Skip delete user.") return @@ -54,7 +54,8 @@ class DeprovisionHandler(object): username = ovfenv.username warnings.append(("WARNING! {0} account and entire home directory " "will be deleted.").format(username)) - actions.append(DeprovisionAction(OSUTIL.del_account, [username])) + actions.append(DeprovisionAction(self.distro.osutil.del_account, + [username])) def regen_ssh_host_key(self, warnings, actions): @@ -64,7 +65,7 @@ class DeprovisionHandler(object): def stop_agent_service(self, warnings, actions): warnings.append("WARNING! The waagent service will be stopped.") - actions.append(DeprovisionAction(OSUTIL.stop_agent_service)) + actions.append(DeprovisionAction(self.distro.osutil.stop_agent_service)) def del_files(self, warnings, actions): files_to_del = ['/root/.bash_history', '/var/log/waagent.log'] @@ -76,26 +77,28 @@ class DeprovisionHandler(object): actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del)) def del_lib_dir(self, warnings, actions): - dirs_to_del = [OSUTIL.get_lib_dir()] + dirs_to_del = [conf.get_lib_dir()] actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del)) def reset_hostname(self, warnings, actions): localhost = ["localhost.localdomain"] - actions.append(DeprovisionAction(OSUTIL.set_hostname, localhost)) - actions.append(DeprovisionAction(OSUTIL.set_dhcp_hostname, localhost)) + actions.append(DeprovisionAction(self.distro.osutil.set_hostname, + localhost)) + actions.append(DeprovisionAction(self.distro.osutil.set_dhcp_hostname, + localhost)) def setup(self, deluser): warnings = [] actions = [] self.stop_agent_service(warnings, actions) - if conf.get_switch("Provisioning.RegenerateSshHostkey", False): + if conf.get_regenerate_ssh_host_key(): self.regen_ssh_host_key(warnings, actions) self.del_dhcp_lease(warnings, actions) self.reset_hostname(warnings, actions) - if conf.get_switch("Provisioning.DeleteRootPassword", False): + if conf.get_delete_root_password(): self.del_root_password(warnings, actions) self.del_lib_dir(warnings, actions) @@ -106,7 +109,7 @@ class DeprovisionHandler(object): return warnings, actions - def deprovision(self, force=False, deluser=False): + def run(self, force=False, deluser=False): warnings, actions = self.setup(deluser) for warning in warnings: print(warning) diff --git a/azurelinuxagent/distro/default/dhcp.py b/azurelinuxagent/distro/default/dhcp.py index 4fd23ef..fc439d2 100644 --- a/azurelinuxagent/distro/default/dhcp.py +++ b/azurelinuxagent/distro/default/dhcp.py @@ -19,61 +19,106 @@ import os import socket import array import time +import threading import azurelinuxagent.logger as logger -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.exception import AgentNetworkError +import azurelinuxagent.conf as conf import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil -from azurelinuxagent.utils.textutil import * +from azurelinuxagent.utils.textutil import hex_dump, hex_dump2, hex_dump3, \ + compare_bytes, str_to_ord, \ + unpack_big_endian, \ + unpack_little_endian, \ + int_to_ip4_addr +from azurelinuxagent.exception import DhcpError -WIRE_SERVER_ADDR_FILE_NAME="WireServer" class DhcpHandler(object): - def __init__(self): + """ + Azure use DHCP option 245 to pass endpoint ip to VMs. + """ + def __init__(self, distro): + self.distro = distro self.endpoint = None self.gateway = None self.routes = None + def run(self): + """ + Send dhcp request + Configure default gateway and routes + Save wire server endpoint if found + """ + self.send_dhcp_req() + self.conf_routes() + def wait_for_network(self): - ipv4 = OSUTIL.get_ip4_addr() + """ + Wait for network stack to be initialized. + """ + ipv4 = self.distro.osutil.get_ip4_addr() while ipv4 == '' or ipv4 == '0.0.0.0': logger.info("Waiting for network.") time.sleep(10) - OSUTIL.start_network() - ipv4 = OSUTIL.get_ip4_addr() - - def probe(self): - logger.info("Send dhcp request") - self.wait_for_network() - mac_addr = OSUTIL.get_mac_addr() - req = build_dhcp_request(mac_addr) - resp = send_dhcp_request(req) - if resp is None: - logger.warn("Failed to detect wire server.") - return - endpoint, gateway, routes = parse_dhcp_resp(resp) - self.endpoint = endpoint - logger.info("Wire server endpoint:{0}", endpoint) - logger.info("Gateway:{0}", gateway) - logger.info("Routes:{0}", routes) - if endpoint is not None: - path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME) - fileutil.write_file(path, endpoint) - self.gateway = gateway - self.routes = routes - self.conf_routes() - - def get_endpoint(self): - return self.endpoint + logger.info("Try to start network interface.") + self.distro.osutil.start_network() + ipv4 = self.distro.osutil.get_ip4_addr() def conf_routes(self): logger.info("Configure routes") + logger.info("Gateway:{0}", self.gateway) + logger.info("Routes:{0}", self.routes) #Add default gateway if self.gateway is not None: - OSUTIL.route_add(0 , 0, self.gateway) + self.distro.osutil.route_add(0 , 0, self.gateway) if self.routes is not None: for route in self.routes: - OSUTIL.route_add(route[0], route[1], route[2]) + self.distro.osutil.route_add(route[0], route[1], route[2]) + + def _send_dhcp_req(self, request): + __waiting_duration__ = [0, 10, 30, 60, 60] + for duration in __waiting_duration__: + try: + self.distro.osutil.allow_dhcp_broadcast() + response = socket_send(request) + validate_dhcp_resp(request, response) + return response + except DhcpError as e: + logger.warn("Failed to send DHCP request: {0}", e) + time.sleep(duration) + return None + + def send_dhcp_req(self): + """ + Build dhcp request with mac addr + Configure route to allow dhcp traffic + Stop dhcp service if necessary + """ + logger.info("Send dhcp request") + mac_addr = self.distro.osutil.get_mac_addr() + req = build_dhcp_request(mac_addr) + + # Temporary allow broadcast for dhcp. Remove the route when done. + missing_default_route = self.distro.osutil.is_missing_default_route() + ifname = self.distro.osutil.get_if_name() + if missing_default_route: + self.distro.osutil.set_route_for_dhcp_broadcast(ifname) + + # In some distros, dhcp service needs to be shutdown before agent probe + # endpoint through dhcp. + if self.distro.osutil.is_dhcp_enabled(): + self.distro.osutil.stop_dhcp_service() + + resp = self._send_dhcp_req(req) + + if self.distro.osutil.is_dhcp_enabled(): + self.distro.osutil.start_dhcp_service() + + if missing_default_route: + self.distro.osutil.remove_route_for_dhcp_broadcast(ifname) + + if resp is None: + raise DhcpError("Failed to receive dhcp response.") + self.endpoint, self.gateway, self.routes = parse_dhcp_resp(resp) def validate_dhcp_resp(request, response): bytes_recv = len(response) @@ -92,28 +137,25 @@ def validate_dhcp_resp(request, response): logger.verb("Cookie not match:\nsend={0},\nreceive={1}", hex_dump3(request, 0xEC, 4), hex_dump3(response, 0xEC, 4)) - raise AgentNetworkError("Cookie in dhcp respones " - "doesn't match the request") + raise DhcpError("Cookie in dhcp respones doesn't match the request") if not compare_bytes(request, response, 4, 4): logger.verb("TransactionID not match:\nsend={0},\nreceive={1}", hex_dump3(request, 4, 4), hex_dump3(response, 4, 4)) - raise AgentNetworkError("TransactionID in dhcp respones " - "doesn't match the request") + raise DhcpError("TransactionID in dhcp respones " + "doesn't match the request") if not compare_bytes(request, response, 0x1C, 6): logger.verb("Mac Address not match:\nsend={0},\nreceive={1}", hex_dump3(request, 0x1C, 6), hex_dump3(response, 0x1C, 6)) - raise AgentNetworkError("Mac Addr in dhcp respones " - "doesn't match the request") + raise DhcpError("Mac Addr in dhcp respones " + "doesn't match the request") def parse_route(response, option, i, length, bytes_recv): # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx - logger.verb("Routes at offset: {0} with length:{1}", - hex(i), - hex(length)) + logger.verb("Routes at offset: {0} with length:{1}", hex(i), hex(length)) routes = [] if length < 5: logger.error("Data too small for option:{0}", option) @@ -169,9 +211,7 @@ def parse_dhcp_resp(response): if (i + 1) < bytes_recv: length = str_to_ord(response[i + 1]) logger.verb("DHCP option {0} at offset:{1} with length:{2}", - hex(option), - hex(i), - hex(length)) + hex(option), hex(i), hex(length)) if option == 255: logger.verb("DHCP packet ended at offset:{0}", hex(i)) break @@ -179,69 +219,17 @@ def parse_dhcp_resp(response): routes = parse_route(response, option, i, length, bytes_recv) elif option == 3: gateway = parse_ip_addr(response, option, i, length, bytes_recv) - logger.verb("Default gateway:{0}, at {1}", - gateway, - hex(i)) + logger.verb("Default gateway:{0}, at {1}", gateway, hex(i)) elif option == 245: endpoint = parse_ip_addr(response, option, i, length, bytes_recv) - logger.verb("Azure wire protocol endpoint:{0}, at {1}", - gateway, - hex(i)) + logger.verb("Azure wire protocol endpoint:{0}, at {1}", gateway, + hex(i)) else: logger.verb("Skipping DHCP option:{0} at {1} with length {2}", - hex(option), - hex(i), - hex(length)) + hex(option), hex(i), hex(length)) i += length + 2 return endpoint, gateway, routes - -def allow_dhcp_broadcast(func): - """ - Temporary allow broadcase for dhcp. Remove the route when done. - """ - def wrapper(*args, **kwargs): - missing_default_route = OSUTIL.is_missing_default_route() - ifname = OSUTIL.get_if_name() - if missing_default_route: - OSUTIL.set_route_for_dhcp_broadcast(ifname) - result = func(*args, **kwargs) - if missing_default_route: - OSUTIL.remove_route_for_dhcp_broadcast(ifname) - return result - return wrapper - -def disable_dhcp_service(func): - """ - In some distros, dhcp service needs to be shutdown before agent probe - endpoint through dhcp. - """ - def wrapper(*args, **kwargs): - if OSUTIL.is_dhcp_enabled(): - OSUTIL.stop_dhcp_service() - result = func(*args, **kwargs) - OSUTIL.start_dhcp_service() - return result - else: - return func(*args, **kwargs) - return wrapper - - -@allow_dhcp_broadcast -@disable_dhcp_service -def send_dhcp_request(request): - __waiting_duration__ = [0, 10, 30, 60, 60] - for duration in __waiting_duration__: - try: - OSUTIL.allow_dhcp_broadcast() - response = socket_send(request) - validate_dhcp_resp(request, response) - return response - except AgentNetworkError as e: - logger.warn("Failed to send DHCP request: {0}", e) - time.sleep(duration) - return None - def socket_send(request): sock = None try: @@ -257,7 +245,7 @@ def socket_send(request): response = sock.recv(1024) return response except IOError as e: - raise AgentNetworkError("{0}".format(e)) + raise DhcpError("{0}".format(e)) finally: if sock is not None: sock.close() diff --git a/azurelinuxagent/distro/default/distro.py b/azurelinuxagent/distro/default/distro.py new file mode 100644 index 0000000..ca0d77e --- /dev/null +++ b/azurelinuxagent/distro/default/distro.py @@ -0,0 +1,51 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.conf import ConfigurationProvider +from azurelinuxagent.distro.default.osutil import DefaultOSUtil +from azurelinuxagent.distro.default.daemon import DaemonHandler +from azurelinuxagent.distro.default.init import InitHandler +from azurelinuxagent.distro.default.monitor import MonitorHandler +from azurelinuxagent.distro.default.dhcp import DhcpHandler +from azurelinuxagent.distro.default.protocolUtil import ProtocolUtil +from azurelinuxagent.distro.default.scvmm import ScvmmHandler +from azurelinuxagent.distro.default.env import EnvHandler +from azurelinuxagent.distro.default.provision import ProvisionHandler +from azurelinuxagent.distro.default.resourceDisk import ResourceDiskHandler +from azurelinuxagent.distro.default.extension import ExtHandlersHandler +from azurelinuxagent.distro.default.deprovision import DeprovisionHandler + +class DefaultDistro(object): + """ + """ + def __init__(self): + self.osutil = DefaultOSUtil() + self.protocol_util = ProtocolUtil(self) + + self.init_handler = InitHandler(self) + self.daemon_handler = DaemonHandler(self) + self.event_handler = MonitorHandler(self) + self.dhcp_handler = DhcpHandler(self) + self.scvmm_handler = ScvmmHandler(self) + self.env_handler = EnvHandler(self) + self.provision_handler = ProvisionHandler(self) + self.resource_disk_handler = ResourceDiskHandler(self) + self.ext_handlers_handler = ExtHandlersHandler(self) + self.deprovision_handler = DeprovisionHandler(self) + diff --git a/azurelinuxagent/distro/default/env.py b/azurelinuxagent/distro/default/env.py index 28bf718..7878cff 100644 --- a/azurelinuxagent/distro/default/env.py +++ b/azurelinuxagent/distro/default/env.py @@ -23,7 +23,6 @@ import threading import time import azurelinuxagent.logger as logger import azurelinuxagent.conf as conf -from azurelinuxagent.utils.osutil import OSUTIL class EnvHandler(object): """ @@ -31,35 +30,25 @@ class EnvHandler(object): If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric. Monitor scsi disk. - If new scsi disk found, set + If new scsi disk found, set timeout """ - def __init__(self, handlers): - self.monitor = EnvMonitor(handlers.dhcp_handler) - - def start(self): - self.monitor.start() - - def stop(self): - self.monitor.stop() - -class EnvMonitor(object): - - def __init__(self, dhcp_handler): - self.dhcp_handler = dhcp_handler + def __init__(self, distro): + self.distro = distro self.stopped = True self.hostname = None self.dhcpid = None self.server_thread=None - def start(self): + def run(self): if not self.stopped: logger.info("Stop existing env monitor service.") self.stop() self.stopped = False logger.info("Start env monitor service.") + self.distro.dhcp_handler.conf_routes() self.hostname = socket.gethostname() - self.dhcpid = OSUTIL.get_dhcp_pid() + self.dhcpid = self.distro.osutil.get_dhcp_pid() self.server_thread = threading.Thread(target = self.monitor) self.server_thread.setDaemon(True) self.server_thread.start() @@ -70,11 +59,11 @@ class EnvMonitor(object): If dhcp clinet process re-start has occurred, reset routes. """ while not self.stopped: - OSUTIL.remove_rules_files() - timeout = conf.get("OS.RootDeviceScsiTimeout", None) + self.distro.osutil.remove_rules_files() + timeout = conf.get_root_device_scsi_timeout() if timeout is not None: - OSUTIL.set_scsi_disks_timeout(timeout) - if conf.get_switch("Provisioning.MonitorHostName", False): + self.distro.osutil.set_scsi_disks_timeout(timeout) + if conf.get_monitor_hostname(): self.handle_hostname_update() self.handle_dhclient_restart() time.sleep(5) @@ -84,25 +73,25 @@ class EnvMonitor(object): if curr_hostname != self.hostname: logger.info("EnvMonitor: Detected host name change: {0} -> {1}", self.hostname, curr_hostname) - OSUTIL.set_hostname(curr_hostname) - OSUTIL.publish_hostname(curr_hostname) + self.distro.osutil.set_hostname(curr_hostname) + self.distro.osutil.publish_hostname(curr_hostname) self.hostname = curr_hostname def handle_dhclient_restart(self): if self.dhcpid is None: logger.warn("Dhcp client is not running. ") - self.dhcpid = OSUTIL.get_dhcp_pid() + self.dhcpid = self.distro.osutil.get_dhcp_pid() return #The dhcp process hasn't changed since last check if os.path.isdir(os.path.join('/proc', self.dhcpid.strip())): return - newpid = OSUTIL.get_dhcp_pid() + newpid = self.distro.osutil.get_dhcp_pid() if newpid is not None and newpid != self.dhcpid: logger.info("EnvMonitor: Detected dhcp client restart. " "Restoring routing table.") - self.dhcp_handler.conf_routes() + self.distro.dhcp_handler.conf_routes() self.dhcpid = newpid def stop(self): diff --git a/azurelinuxagent/distro/default/extension.py b/azurelinuxagent/distro/default/extension.py index f6c02aa..82cdfed 100644 --- a/azurelinuxagent/distro/default/extension.py +++ b/azurelinuxagent/distro/default/extension.py @@ -22,13 +22,16 @@ import time import json import subprocess import shutil +import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.protocol as prot -from azurelinuxagent.metadata import AGENT_VERSION from azurelinuxagent.event import add_event, WALAEventOperation -from azurelinuxagent.exception import ExtensionError +from azurelinuxagent.exception import ExtensionError, ProtocolError, HttpError +from azurelinuxagent.future import ustr +from azurelinuxagent.metadata import AGENT_VERSION +from azurelinuxagent.protocol.restapi import ExtHandlerStatus, ExtensionStatus, \ + ExtensionSubStatus, Extension, \ + VMStatus, ExtHandler, \ + get_properties, set_properties import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.restutil as restutil import azurelinuxagent.utils.shellutil as shellutil @@ -41,15 +44,6 @@ VALID_EXTENSION_STATUS = ['transitioning', 'error', 'success', 'warning'] VALID_HANDLER_STATUS = ['Ready', 'NotReady', "Installing", "Unresponsive"] -def handler_state_to_status(handler_state): - if handler_state == "Enabled": - return "Ready" - elif handler_state in VALID_HANDLER_STATUS: - return handler_state - else: - return "NotReady" - - def validate_has_key(obj, key, fullname): if key not in obj: raise ExtensionError("Missing: {0}".format(fullname)) @@ -64,14 +58,13 @@ def parse_formatted_message(formatted_message): validate_has_key(formatted_message, 'lang', 'formattedMessage/lang') validate_has_key(formatted_message, 'message', 'formattedMessage/message') return formatted_message.get('message') - def parse_ext_substatus(substatus): #Check extension sub status format validate_has_key(substatus, 'status', 'substatus/status') validate_in_range(substatus['status'], VALID_EXTENSION_STATUS, 'substatus/status') - status = prot.ExtensionSubStatus() + status = ExtensionSubStatus() status.name = substatus.get('name') status.status = substatus.get('status') status.code = substatus.get('code', 0) @@ -105,333 +98,330 @@ def parse_ext_status(ext_status, data): for substatus in substatus_list: ext_status.substatusList.append(parse_ext_substatus(substatus)) -def parse_extension_dirname(dirname): - """ - Parse installed extension dir name. Sample: ExtensionName-Version/ - """ - seprator = dirname.rfind('-') - if seprator < 0: - raise ExtensionError("Invalid extenation dir name") - return dirname[0:seprator], dirname[seprator + 1:] - -def get_installed_version(target_name): - """ - Return the highest version instance with the same name - """ - installed_version = None - lib_dir = OSUTIL.get_lib_dir() - for dir_name in os.listdir(lib_dir): - path = os.path.join(lib_dir, dir_name) - if os.path.isdir(path) and dir_name.startswith(target_name): - name, version = parse_extension_dirname(dir_name) - #Here we need to ensure names are exactly the same. - if name == target_name: - if installed_version is None or \ - Version(installed_version) < Version(version): - installed_version = version - return installed_version - class ExtHandlerState(object): + NotInstalled = "NotInstalled" + Installed = "Installed" Enabled = "Enabled" - Disabled = "Disabled" - Failed = "Failed" - class ExtHandlersHandler(object): - - def process(self): + def __init__(self, distro): + self.distro = distro + self.ext_handlers = None + self.last_etag = None + self.log_report = False + + def run(self): + ext_handlers, etag = None, None try: - protocol = prot.FACTORY.get_default_protocol() - ext_handlers = protocol.get_ext_handlers() - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message = text(e)) + self.protocol = self.distro.protocol_util.get_protocol() + ext_handlers, etag = self.protocol.get_ext_handlers() + except ProtocolError as e: + add_event(name="WALA", is_success=False, message=ustr(e)) return - - vm_status = prot.VMStatus() + if self.last_etag is not None and self.last_etag == etag: + logger.verb("No change to ext handler config:{0}, skip", etag) + self.log_report = False + else: + logger.info("Handle new ext handler config") + self.log_report = True #Log status report success on new config + self.handle_ext_handlers(ext_handlers) + self.last_etag = etag + + self.report_ext_handlers_status(ext_handlers) + + def handle_ext_handlers(self, ext_handlers): + if ext_handlers.extHandlers is None or \ + len(ext_handlers.extHandlers) == 0: + logger.info("No ext handler config found") + return + + for ext_handler in ext_handlers.extHandlers: + #TODO handle install in sequence, enable in parallel + self.handle_ext_handler(ext_handler) + + def handle_ext_handler(self, ext_handler): + ext_handler_i = ExtHandlerInstance(ext_handler, self.protocol) + try: + state = ext_handler.properties.state + ext_handler_i.logger.info("Expected handler state: {0}", state) + if state == "enabled": + self.handle_enable(ext_handler_i) + elif state == u"disabled": + self.handle_disable(ext_handler_i) + elif state == u"uninstall": + self.handle_uninstall(ext_handler_i) + else: + message = u"Unknown ext handler state:{0}".format(state) + raise ExtensionError(message) + except ExtensionError as e: + ext_handler_i.set_handler_status(message=ustr(e), code=-1) + ext_handler_i.report_event(message=ustr(e), is_success=False) + + def handle_enable(self, ext_handler_i): + + ext_handler_i.decide_version() + + old_ext_handler_i = ext_handler_i.get_installed_ext_handler() + if old_ext_handler_i is not None and \ + old_ext_handler_i.version_gt(ext_handler_i): + raise ExtensionError(u"Downgrade not allowed") + + handler_state = ext_handler_i.get_handler_state() + ext_handler_i.logger.info("Current handler state is: {0}", handler_state) + if handler_state == ExtHandlerState.NotInstalled: + ext_handler_i.set_handler_state(ExtHandlerState.NotInstalled) + + ext_handler_i.download() + + ext_handler_i.update_settings() + + if old_ext_handler_i is None: + ext_handler_i.install() + elif ext_handler_i.version_gt(old_ext_handler_i): + old_ext_handler_i.disable() + ext_handler_i.copy_status_files(old_ext_handler_i) + ext_handler_i.update() + old_ext_handler_i.uninstall() + old_ext_handler_i.rm_ext_handler_dir() + ext_handler_i.update_with_install() + else: + ext_handler_i.update_settings() + + ext_handler_i.enable() + + def handle_disable(self, ext_handler_i): + handler_state = ext_handler_i.get_handler_state() + ext_handler_i.logger.info("Current handler state is: {0}", handler_state) + if handler_state == ExtHandlerState.Enabled: + ext_handler_i.disable() + + def handle_uninstall(self, ext_handler_i): + handler_state = ext_handler_i.get_handler_state() + ext_handler_i.logger.info("Current handler state is: {0}", handler_state) + if handler_state != ExtHandlerState.NotInstalled: + if handler_state == ExtHandlerState.Enabled: + ext_handler_i.disable() + ext_handler_i.uninstall() + ext_handler_i.rm_ext_handler_dir() + + def report_ext_handlers_status(self, ext_handlers): + """Go thru handler_state dir, collect and report status""" + vm_status = VMStatus() vm_status.vmAgent.version = AGENT_VERSION vm_status.vmAgent.status = "Ready" vm_status.vmAgent.message = "Guest Agent is running" - if ext_handlers.extHandlers is None or \ - len(ext_handlers.extHandlers) == 0: - logger.verb("No extensions to handle") - else: + if ext_handlers is not None: for ext_handler in ext_handlers.extHandlers: - #TODO handle extension in parallel try: - pkg_list = protocol.get_ext_handler_pkgs(ext_handler) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message=text(e)) - continue - - handler_status = self.process_extension(ext_handler, pkg_list) - if handler_status is not None: - vm_status.vmAgent.extensionHandlers.append(handler_status) - + self.report_ext_handler_status(vm_status, ext_handler) + except ExtensionError as e: + add_event(name="WALA", is_success=False, message=ustr(e)) + + logger.verb("Report vm agent status") + try: - logger.verb("Report vm agent status") - protocol.report_vm_status(vm_status) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message = text(e)) - - def process_extension(self, ext_handler, pkg_list): - installed_version = get_installed_version(ext_handler.name) - if installed_version is not None: - handler = ExtHandlerInstance(ext_handler, pkg_list, - installed_version, installed=True) - else: - handler = ExtHandlerInstance(ext_handler, pkg_list, - ext_handler.properties.version) - handler.handle() + self.protocol.report_vm_status(vm_status) + except ProtocolError as e: + message = "Failed to report vm agent status: {0}".format(e) + add_event(name="WALA", is_success=False, message=message) + + if self.log_report: + logger.info("Successfully reported vm agent status") + + + def report_ext_handler_status(self, vm_status, ext_handler): + ext_handler_i = ExtHandlerInstance(ext_handler, self.protocol) - if handler.ext_status is not None: + handler_status = ext_handler_i.get_handler_status() + if handler_status is None: + return + + handler_state = ext_handler_i.get_handler_state() + if handler_state != ExtHandlerState.NotInstalled: try: - protocol = prot.FACTORY.get_default_protocol() - protocol.report_ext_status(handler.name, handler.ext.name, - handler.ext_status) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message=text(e)) - - return handler.handler_status + active_exts = ext_handler_i.report_ext_status() + handler_status.extensions.extend(active_exts) + except ExtensionError as e: + ext_handler_i.set_handler_status(message=ustr(e), code=-1) + + try: + heartbeat = ext_handler_i.collect_heartbeat() + if heartbeat is not None: + handler_status.status = heartbeat.get('status') + except ExtensionError as e: + ext_handler_i.set_handler_status(message=ustr(e), code=-1) + vm_status.vmAgent.extensionHandlers.append(handler_status) + class ExtHandlerInstance(object): - def __init__(self, ext_handler, pkg_list, curr_version, installed=False): + def __init__(self, ext_handler, protocol): self.ext_handler = ext_handler - self.name = ext_handler.name - self.version = ext_handler.properties.version - self.pkg_list = pkg_list - self.state = ext_handler.properties.state - self.update_policy = ext_handler.properties.upgradePolicy - - self.curr_version = curr_version - self.installed = installed - self.handler_state = None - self.lib_dir = OSUTIL.get_lib_dir() - - self.ext_status = prot.ExtensionStatus() - self.handler_status = prot.ExtHandlerStatus() - self.handler_status.name = self.name - self.handler_status.version = self.curr_version - - #Currently, extension settings will have no more than 1 instance - if len(ext_handler.properties.extensions) > 0: - self.ext = ext_handler.properties.extensions[0] - self.handler_status.extensions = [self.ext.name] - else: - #When no extension settings, set sequenceNumber to 0 - self.ext = prot.Extension(sequenceNumber=0) - self.ext_status.sequenceNumber = self.ext.sequenceNumber + self.protocol = protocol + self.operation = None + self.pkg = None prefix = "[{0}]".format(self.get_full_name()) self.logger = logger.Logger(logger.DEFAULT_LOGGER, prefix) + + try: + fileutil.mkdir(self.get_log_dir(), mode=0o744) + except IOError as e: + self.logger.error(u"Failed to create extension log dir: {0}", e) - def init_logger(self): - #Init logger appender for extension - fileutil.mkdir(self.get_log_dir(), mode=0o644) log_file = os.path.join(self.get_log_dir(), "CommandExecution.log") self.logger.add_appender(logger.AppenderType.FILE, logger.LogLevel.INFO, log_file) - def handle(self): - self.init_logger() - self.logger.verb("Start processing extension handler") - - try: - self.handle_state() - except ExtensionError as e: - self.set_state_err(text(e)) - self.report_event(is_success=False, message=text(e)) - self.logger.error("Failed to process extension handler") - return - - try: - if self.installed: - self.collect_ext_status() - self.collect_handler_status() - except ExtensionError as e: - self.report_event(is_success=False, message=text(e)) - self.logger.error("Failed to get extension handler status") - return - - self.logger.verb("Finished processing extension handler") - - def handle_state(self): - if self.installed: - self.handler_state = self.get_state() - - self.handler_status.status = handler_state_to_status(self.handler_state) - self.logger.verb("Handler state: {0}", self.handler_state) - self.logger.verb("Sequence number: {0}", self.ext.sequenceNumber) - - if self.state == 'enabled': - if self.handler_state == ExtHandlerState.Failed: - self.logger.verb("Found previous failure, quit handle_enable") - return - - if self.handler_state == ExtHandlerState.Enabled: - self.logger.verb("Already enabled with sequenceNumber: {0}", - self.ext.sequenceNumber) - self.logger.verb("Quit handle_enable") - return + def decide_version(self): + """ + If auto-upgrade, get the largest public extension version under + the requested major version family of currently installed plugin version - try: - new = self.handle_enable() - if new is not None: - #Upgrade happened - new.set_state(ExtHandlerState.Enabled) - else: - self.set_state(ExtHandlerState.Enabled) + Else, get the highest hot-fix for requested version, + """ + self.logger.info("Decide which version to use") + try: + pkg_list = self.protocol.get_ext_handler_pkgs(self.ext_handler) + except ProtocolError as e: + raise ExtensionError("Failed to get ext handler pkgs", e) - except ExtensionError as e: - self.set_state(ExtHandlerState.Failed) - raise e - elif self.state == 'disabled': - if self.handler_state == ExtHandlerState.Failed: - self.logger.verb("Found previous failure, quit handle_disable") - return - - if self.handler_state == ExtHandlerState.Disabled: - self.logger.verb("Already disabled with sequenceNumber: {0}", - self.ext.sequenceNumber) - self.logger.verb("Quit handle_disable") - return + version = self.ext_handler.properties.version + update_policy = self.ext_handler.properties.upgradePolicy + + version_frag = version.split('.') + if len(version_frag) < 2: + raise ExtensionError("Wrong version format: {0}".format(version)) - try: - self.handle_disable() - self.set_state(ExtHandlerState.Disabled) - except ExtensionError as e: - self.set_state(ExtHandlerState.Failed) - raise e - elif self.state == 'uninstall': - try: - self.handle_uninstall() - except ExtensionError as e: - self.set_state(ExtHandlerState.Failed) - raise e + version_prefix = None + if update_policy is not None and update_policy == 'auto': + version_prefix = "{0}.".format(version_frag[0]) else: - raise ExtensionError("Unknown state:{0}".format(self.state)) - - def handle_enable(self): - target_version = self.get_target_version() - self.logger.info("Target version: {0}", target_version) - if self.installed: - if Version(target_version) > Version(self.curr_version): - return self.upgrade(target_version) - elif Version(target_version) == Version(self.curr_version): - self.enable() - else: - raise ExtensionError("A newer version is already installed") - else: - if Version(target_version) > Version(self.version): - #This will happen when auto upgrade policy is enabled - self.logger.info("Auto upgrade to new version:{0}", - target_version) - self.curr_version = target_version - self.download() - self.init_dir() - self.install() - self.enable() + version_prefix = "{0}.{1}.".format(version_frag[0], version_frag[1]) + + packages = [x for x in pkg_list.versions \ + if x.version.startswith(version_prefix) or \ + x.version == version] + + packages = sorted(packages, key=lambda x: Version(x.version), + reverse=True) - def handle_disable(self): - if not self.installed: - self.logger.verb("Not installed, quit disable") - return + if len(packages) <= 0: + raise ExtensionError("Failed to find and valid extension package") + self.pkg = packages[0] + self.ext_handler.properties.version = packages[0].version + self.logger.info("Use version: {0}", self.pkg.version) + + def version_gt(self, other): + self_version = self.ext_handler.properties.version + other_version = other.ext_handler.properties.version + return Version(self_version) > Version(other_version) + + def get_installed_ext_handler(self): + lastest_version = None + ext_handler_name = self.ext_handler.name + + for dir_name in os.listdir(conf.get_lib_dir()): + path = os.path.join(conf.get_lib_dir(), dir_name) + if os.path.isdir(path) and dir_name.startswith(ext_handler_name): + seperator = dir_name.rfind('-') + if seperator < 0: + continue + installed_name = dir_name[0: seperator] + installed_version = dir_name[seperator + 1:] + if installed_name != ext_handler_name: + continue + if lastest_version is None or \ + Version(lastest_version) < Version(installed_version): + lastest_version = installed_version - self.disable() + if lastest_version is None: + return None + + data = get_properties(self.ext_handler) + old_ext_handler = ExtHandler() + set_properties("ExtHandler", old_ext_handler, data) + old_ext_handler.properties.version = lastest_version + return ExtHandlerInstance(old_ext_handler, self.protocol) + + def copy_status_files(self, old_ext_handler_i): + self.logger.info("Copy status files from old plugin to new") + old_ext_dir = old_ext_handler_i.get_base_dir() + new_ext_dir = self.get_base_dir() + + old_ext_mrseq_file = os.path.join(old_ext_dir, "mrseq") + if os.path.isfile(old_ext_mrseq_file): + shutil.copy2(old_ext_mrseq_file, new_ext_dir) + + old_ext_status_dir = old_ext_handler_i.get_status_dir() + new_ext_status_dir = self.get_status_dir() + + if os.path.isdir(old_ext_status_dir): + for status_file in os.listdir(old_ext_status_dir): + status_file = os.path.join(old_ext_status_dir, status_file) + if os.path.isfile(status_file): + shutil.copy2(status_file, new_ext_status_dir) + + def set_operation(self, op): + self.operation = op - def handle_uninstall(self): - if not self.installed: - self.logger.verb("Not installed, quit unistall") - self.handler_status = None - self.ext_status = None - return - self.disable() - self.uninstall() - - def report_event(self, is_success=True, message=""): - if self.ext_status is not None: - if not is_success: - self.ext_status.status = "error" - self.ext_status.code = -1 - if self.handler_status is not None: - self.handler_status.message = message - if not is_success: - self.handler_status.status = "NotReady" - add_event(name=self.name, op=self.ext_status.operation, - is_success=is_success, message=message) - - def set_operation(self, operation): - if self.ext_status.operation != WALAEventOperation.Upgrade: - self.ext_status.operation = operation - - def upgrade(self, target_version): - self.logger.info("Upgrade from: {0} to {1}", self.curr_version, - target_version) - self.set_operation(WALAEventOperation.Upgrade) - - old = self - new = ExtHandlerInstance(self.ext_handler, self.pkg_list, - target_version) - self.logger.info("Download new extension package") - new.init_logger() - new.download() - self.logger.info("Initialize new extension directory") - new.init_dir() - - old.disable() - self.logger.info("Update new extension") - new.update() - old.uninstall() - man = new.load_manifest() - if man.is_update_with_install(): - self.logger.info("Install new extension") - new.install() - self.logger.info("Enable new extension") - new.enable() - return new + def report_event(self, message="", is_success=True): + version = self.ext_handler.properties.version + add_event(name=self.ext_handler.name, version=version, message=message, + op=self.operation, is_success=is_success) def download(self): self.logger.info("Download extension package") self.set_operation(WALAEventOperation.Download) - - uris = self.get_package_uris() + if self.pkg is None: + raise ExtensionError("No package uri found") + package = None - for uri in uris: + for uri in self.pkg.uris: try: - resp = restutil.http_get(uri.uri, chk_proxy=True) - if resp.status == restutil.httpclient.OK: - package = resp.read() - break - except restutil.HttpError as e: - self.logger.warn("Failed download extension from: {0}", uri.uri) - + package = self.protocol.download_ext_handler_pkg(uri.uri) + except ProtocolError as e: + logger.warn("Failed download extension: {0}", e) + if package is None: - raise ExtensionError("Download extension failed") + raise ExtensionError("Failed to download extension") self.logger.info("Unpack extension package") - pkg_file = os.path.join(self.lib_dir, os.path.basename(uri.uri) + ".zip") - fileutil.write_file(pkg_file, bytearray(package), asbin=True) - zipfile.ZipFile(pkg_file).extractall(self.get_base_dir()) + pkg_file = os.path.join(conf.get_lib_dir(), + os.path.basename(uri.uri) + ".zip") + try: + fileutil.write_file(pkg_file, bytearray(package), asbin=True) + zipfile.ZipFile(pkg_file).extractall(self.get_base_dir()) + except IOError as e: + raise ExtensionError(u"Failed to write and unzip plugin", e) + chmod = "find {0} -type f | xargs chmod u+x".format(self.get_base_dir()) shellutil.run(chmod) self.report_event(message="Download succeeded") - def init_dir(self): self.logger.info("Initialize extension directory") #Save HandlerManifest.json man_file = fileutil.search_file(self.get_base_dir(), 'HandlerManifest.json') - man = fileutil.read_file(man_file, remove_bom=True) - fileutil.write_file(self.get_manifest_file(), man) - #Create status and config dir - status_dir = self.get_status_dir() - fileutil.mkdir(status_dir, mode=0o700) - conf_dir = self.get_conf_dir() - fileutil.mkdir(conf_dir, mode=0o700) + if man_file is None: + raise ExtensionError("HandlerManifest.json not found") - self.make_handler_state_dir() + try: + man = fileutil.read_file(man_file, remove_bom=True) + fileutil.write_file(self.get_manifest_file(), man) + except IOError as e: + raise ExtensionError(u"Failed to save HandlerManifest.json", e) + + #Create status and config dir + try: + status_dir = self.get_status_dir() + fileutil.mkdir(status_dir, mode=0o700) + conf_dir = self.get_conf_dir() + fileutil.mkdir(conf_dir, mode=0o700) + except IOError as e: + raise ExtensionError(u"Failed to create status or config dir", e) #Save HandlerEnvironment.json self.create_handler_env() @@ -442,6 +432,8 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_enable_command()) + self.set_handler_state(ExtHandlerState.Enabled) + self.set_handler_status(status="Ready", message="Plugin enabled") def disable(self): self.logger.info("Disable extension.") @@ -449,6 +441,8 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_disable_command(), timeout=900) + self.set_handler_state(ExtHandlerState.Installed) + self.set_handler_status(status="NotReady", message="Plugin disabled") def install(self): self.logger.info("Install extension.") @@ -456,24 +450,31 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_install_command(), timeout=900) - self.installed = True + self.set_handler_state(ExtHandlerState.Installed) def uninstall(self): self.logger.info("Uninstall extension.") self.set_operation(WALAEventOperation.UnInstall) - man = self.load_manifest() - self.launch_command(man.get_uninstall_command()) - - self.logger.info("Remove ext handler dir: {0}", self.get_base_dir()) try: - shutil.rmtree(self.get_base_dir()) + man = self.load_manifest() + self.launch_command(man.get_uninstall_command()) + except ExtensionError as e: + self.report_event(message=ustr(e), is_success=False) + + def rm_ext_handler_dir(self): + try: + handler_state_dir = self.get_handler_state_dir() + if os.path.isdir(handler_state_dir): + self.logger.info("Remove ext handler dir: {0}", handler_state_dir) + shutil.rmtree(handler_state_dir) + base_dir = self.get_base_dir() + if os.path.isdir(base_dir): + self.logger.info("Remove ext handler dir: {0}", base_dir) + shutil.rmtree(base_dir) except IOError as e: - raise ExtensionError("Failed to rm ext handler dir: {0}".format(e)) - - self.installed = False - self.handler_status = None - self.ext_status = None + message = "Failed to rm ext handler dir: {0}".format(e) + self.report_event(message=message, is_success=False) def update(self): self.logger.info("Update extension.") @@ -481,95 +482,82 @@ class ExtHandlerInstance(object): man = self.load_manifest() self.launch_command(man.get_update_command(), timeout=900) - - def collect_handler_status(self): - self.logger.verb("Collect extension handler status") - if self.handler_status is None: - return - - handler_state = self.get_state() - self.handler_status.status = handler_state_to_status(handler_state) - self.handler_status.message = self.get_state_err() + + def update_with_install(self): man = self.load_manifest() - if man.is_report_heartbeat(): - heartbeat = self.collect_heartbeat() - if heartbeat is not None: - self.handler_status.status = heartbeat['status'] + if man.is_update_with_install(): + self.install() + else: + self.logger.info("UpdateWithInstall not set. " + "Skip install during upgrade.") + self.set_handler_state(ExtHandlerState.Installed) - def collect_ext_status(self): + def get_largest_seq_no(self): + seq_no = -1 + conf_dir = self.get_conf_dir() + for item in os.listdir(conf_dir): + item_path = os.path.join(conf_dir, item) + if os.path.isfile(item_path): + try: + seperator = item.rfind(".") + if seperator > 0 and item[seperator + 1:] == 'settings': + curr_seq_no = int(item.split('.')[0]) + if curr_seq_no > seq_no: + seq_no = curr_seq_no + except Exception as e: + self.logger.verb("Failed to parse file name: {0}", item) + continue + return seq_no + + def collect_ext_status(self, ext): self.logger.verb("Collect extension status") - if self.handler_status is None: - return - if self.ext is None: - return + seq_no = self.get_largest_seq_no() + if seq_no == -1: + return None + + status_dir = self.get_status_dir() + ext_status_file = "{0}.status".format(seq_no) + ext_status_file = os.path.join(status_dir, ext_status_file) - ext_status_file = self.get_status_file() + ext_status = ExtensionStatus(seq_no=seq_no) try: data_str = fileutil.read_file(ext_status_file) data = json.loads(data_str) - parse_ext_status(self.ext_status, data) + parse_ext_status(ext_status, data) except IOError as e: - raise ExtensionError("Failed to get status file: {0}".format(e)) + ext_status.message = u"Failed to get status file {0}".format(e) + ext_status.code = -1 + ext_status.status = "error" except ValueError as e: - raise ExtensionError("Malformed status file: {0}".format(e)) - - def make_handler_state_dir(self): - handler_state_dir = self.get_handler_state_dir() - fileutil.mkdir(handler_state_dir, 0o600) - if not os.path.exists(handler_state_dir): - os.makedirs(handler_state_dir) - - def get_state(self): - handler_state_file = self.get_handler_state_file() - if not os.path.isfile(handler_state_file): - return None - try: - handler_state = fileutil.read_file(handler_state_file) - if handler_state is not None: - handler_state = handler_state.rstrip() - return handler_state - except IOError as e: - err = "Failed to get handler state: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) - - def set_state(self, state): - handler_state_file = self.get_handler_state_file() - if not os.path.isfile(handler_state_file): - self.make_handler_state_dir() - try: - fileutil.write_file(handler_state_file, state) - except IOError as e: - err = "Failed to set handler state: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) - - def get_state_err(self): - """Get handler error message""" - handler_state_err_file= self.get_handler_state_err_file() - if not os.path.isfile(handler_state_err_file): - return None - try: - message = fileutil.read_file(handler_state_err_file) - return message - except IOError as e: - err = "Failed to get handler state message: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) - - def set_state_err(self, message): - """Set handler error message""" - handler_state_err_file = self.get_handler_state_err_file() - if not os.path.isfile(handler_state_err_file): - self.make_handler_state_dir() - try: - fileutil.write_file(handler_state_err_file, message) - except IOError as e: - err = "Failed to set handler state message: {0}".format(e) - add_event(name=self.name, is_success=False, message=err) + ext_status.message = u"Malformed status file {0}".format(e) + ext_status.code = -1 + ext_status.status = "error" + return ext_status + + def report_ext_status(self): + active_exts = [] + for ext in self.ext_handler.properties.extensions: + ext_status = self.collect_ext_status(ext) + if ext_status is None: + continue + try: + self.protocol.report_ext_status(self.ext_handler.name, ext.name, + ext_status) + active_exts.append(ext.name) + except ProtocolError as e: + self.logger.error(u"Failed to report extension status: {0}", e) + return active_exts + def collect_heartbeat(self): - self.logger.info("Collect heart beat") - heartbeat_file = os.path.join(OSUTIL.get_lib_dir(), + man = self.load_manifest() + if not man.is_report_heartbeat(): + return + heartbeat_file = os.path.join(conf.get_lib_dir(), self.get_heartbeat_file()) + + self.logger.info("Collect heart beat") if not os.path.isfile(heartbeat_file): raise ExtensionError("Failed to get heart beat file") if not self.is_responsive(heartbeat_file): @@ -586,15 +574,14 @@ class ExtHandlerInstance(object): except ValueError as e: raise ExtensionError("Malformed heartbeat file: {0}".format(e)) return heartbeat - + def is_responsive(self, heartbeat_file): last_update=int(time.time() - os.stat(heartbeat_file).st_mtime) return last_update > 600 # not updated for more than 10 min - + def launch_command(self, cmd, timeout=300): self.logger.info("Launch command:{0}", cmd) base_dir = self.get_base_dir() - self.update_settings() try: devnull = open(os.devnull, 'w') child = subprocess.Popen(base_dir + "/" + cmd, shell=True, @@ -614,6 +601,7 @@ class ExtHandlerInstance(object): ret = child.wait() if ret == None or ret != 0: raise ExtensionError("Non-zero exit code: {0}, {1}".format(ret, cmd)) + self.report_event(message="Launch command succeeded: {0}".format(cmd)) def load_manifest(self): @@ -627,26 +615,40 @@ class ExtHandlerInstance(object): return HandlerManifest(data[0]) + def update_settings_file(self, settings_file, settings): + settings_file = os.path.join(self.get_conf_dir(), settings_file) + try: + fileutil.write_file(settings_file, settings) + except IOError as e: + raise ExtensionError(u"Failed to update settings file", e) + def update_settings(self): - if self.ext is None: - self.logger.verb("Extension has no settings") + if self.ext_handler.properties.extensions is None or \ + len(self.ext_handler.properties.extensions) == 0: + #This is the behavior of waagent 2.0.x + #The new agent has to be consistent with the old one. + self.logger.info("Extension has no settings, write empty 0.settings") + self.update_settings_file("0.settings", "") return - - settings = { - 'publicSettings': self.ext.publicSettings, - 'protectedSettings': self.ext.privateSettings, - 'protectedSettingsCertThumbprint': self.ext.certificateThumbprint - } - ext_settings = { - "runtimeSettings":[{ - "handlerSettings": settings - }] - } - fileutil.write_file(self.get_settings_file(), json.dumps(ext_settings)) + + for ext in self.ext_handler.properties.extensions: + settings = { + 'publicSettings': ext.publicSettings, + 'protectedSettings': ext.protectedSettings, + 'protectedSettingsCertThumbprint': ext.certificateThumbprint + } + ext_settings = { + "runtimeSettings":[{ + "handlerSettings": settings + }] + } + settings_file = "{0}.settings".format(ext.sequenceNumber) + self.logger.info("Update settings file: {0}", settings_file) + self.update_settings_file(settings_file, json.dumps(ext_settings)) def create_handler_env(self): env = [{ - "name": self.name, + "name": self.ext_handler.name, "version" : HANDLER_ENVIRONMENT_VERSION, "handlerEnvironment" : { "logFolder" : self.get_log_dir(), @@ -655,73 +657,91 @@ class ExtHandlerInstance(object): "heartbeatFile" : self.get_heartbeat_file() } }] - fileutil.write_file(self.get_env_file(), - json.dumps(env)) - - def get_target_version(self): - version = self.version - update_policy = self.update_policy - if update_policy is None or update_policy.lower() != 'auto': - return version - - major = version.split('.')[0] - if major is None: - raise ExtensionError("Wrong version format: {0}".format(version)) - - packages = [x for x in self.pkg_list.versions \ - if x.version.startswith(major + ".")] - packages = sorted(packages, key=lambda x: Version(x.version), - reverse=True) - if len(packages) <= 0: - raise ExtensionError("Can't find version: {0}.*".format(major)) + try: + fileutil.write_file(self.get_env_file(), json.dumps(env)) + except IOError as e: + raise ExtensionError(u"Failed to save handler environment", e) + + def get_handler_state_dir(self): + return os.path.join(conf.get_lib_dir(), "handler_state", + self.get_full_name()) - return packages[0].version + def set_handler_state(self, handler_state): + state_dir = self.get_handler_state_dir() + if not os.path.exists(state_dir): + try: + fileutil.mkdir(state_dir, 0o700) + except IOError as e: + self.logger.error("Failed to create state dir: {0}", e) + + try: + state_file = os.path.join(state_dir, "state") + fileutil.write_file(state_file, handler_state) + except IOError as e: + self.logger.error("Failed to set state: {0}", e) + + def get_handler_state(self): + state_dir = self.get_handler_state_dir() + state_file = os.path.join(state_dir, "state") + if not os.path.isfile(state_file): + return ExtHandlerState.NotInstalled - def get_package_uris(self): - version = self.curr_version - packages = self.pkg_list.versions - if packages is None: - raise ExtensionError("Package uris is None.") + try: + return fileutil.read_file(state_file) + except IOError as e: + self.logger.error("Failed to get state: {0}", e) + return ExtHandlerState.NotInstalled + + def set_handler_status(self, status="NotReady", message="", + code=0): + state_dir = self.get_handler_state_dir() + if not os.path.exists(state_dir): + try: + fileutil.mkdir(state_dir, 0o700) + except IOError as e: + self.logger.error("Failed to create state dir: {0}", e) + + handler_status = ExtHandlerStatus() + handler_status.name = self.ext_handler.name + handler_status.version = self.ext_handler.properties.version + handler_status.message = message + handler_status.code = code + handler_status.status = status + status_file = os.path.join(state_dir, "status") - for package in packages: - if Version(package.version) == Version(version): - return package.uris + try: + fileutil.write_file(status_file, + json.dumps(get_properties(handler_status))) + except (IOError, ValueError, ProtocolError) as e: + self.logger.error("Failed to save handler status: {0}", e) + + def get_handler_status(self): + state_dir = self.get_handler_state_dir() + status_file = os.path.join(state_dir, "status") + if not os.path.isfile(status_file): + return None + + try: + data = json.loads(fileutil.read_file(status_file)) + handler_status = ExtHandlerStatus() + set_properties("ExtHandlerStatus", handler_status, data) + return handler_status + except (IOError, ValueError) as e: + self.logger.error("Failed to get handler status: {0}", e) - raise ExtensionError("Can't get package uris for {0}.".format(version)) - def get_full_name(self): - return "{0}-{1}".format(self.name, self.curr_version) - + return "{0}-{1}".format(self.ext_handler.name, + self.ext_handler.properties.version) + def get_base_dir(self): - return os.path.join(OSUTIL.get_lib_dir(), self.get_full_name()) + return os.path.join(conf.get_lib_dir(), self.get_full_name()) def get_status_dir(self): return os.path.join(self.get_base_dir(), "status") - def get_status_file(self): - return os.path.join(self.get_status_dir(), - "{0}.status".format(self.ext.sequenceNumber)) - def get_conf_dir(self): return os.path.join(self.get_base_dir(), 'config') - def get_settings_file(self): - return os.path.join(self.get_conf_dir(), - "{0}.settings".format(self.ext.sequenceNumber)) - - def get_handler_state_dir(self): - return os.path.join(OSUTIL.get_lib_dir(), "handler_state", - self.get_full_name()) - - def get_handler_state_file(self): - return os.path.join(self.get_handler_state_dir(), - '{0}.state'.format(self.ext.sequenceNumber)) - - def get_handler_state_err_file(self): - return os.path.join(self.get_handler_state_dir(), - '{0}.error'.format(self.ext.sequenceNumber)) - - def get_heartbeat_file(self): return os.path.join(self.get_base_dir(), 'heartbeat.log') @@ -732,8 +752,8 @@ class ExtHandlerInstance(object): return os.path.join(self.get_base_dir(), 'HandlerEnvironment.json') def get_log_dir(self): - return os.path.join(OSUTIL.get_ext_log_dir(), self.name, - self.curr_version) + return os.path.join(conf.get_ext_log_dir(), self.ext_handler.name, + self.ext_handler.properties.version) class HandlerEnvironment(object): def __init__(self, data): @@ -782,19 +802,16 @@ class HandlerManifest(object): return self.data['handlerManifest']["disableCommand"] def is_reboot_after_install(self): - #TODO handle reboot after install - if "rebootAfterInstall" not in self.data['handlerManifest']: - return False - return self.data['handlerManifest']["rebootAfterInstall"] + """ + Deprecated + """ + return False def is_report_heartbeat(self): - if "reportHeartbeat" not in self.data['handlerManifest']: - return False - return self.data['handlerManifest']["reportHeartbeat"] + return self.data['handlerManifest'].get('reportHeartbeat', False) def is_update_with_install(self): - if "updateMode" not in self.data['handlerManifest']: - return False - if "updateMode" in self.data: - return self.data['handlerManifest']["updateMode"].lower() == "updatewithinstall" - return False + update_mode = self.data['handlerManifest'].get('updateMode') + if update_mode is None: + return True + return update_mode.low() == "updatewithinstall" diff --git a/azurelinuxagent/distro/default/handlerFactory.py b/azurelinuxagent/distro/default/handlerFactory.py deleted file mode 100644 index dceb2a3..0000000 --- a/azurelinuxagent/distro/default/handlerFactory.py +++ /dev/null @@ -1,40 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# -from .init import InitHandler -from .run import MainHandler -from .scvmm import ScvmmHandler -from .dhcp import DhcpHandler -from .env import EnvHandler -from .provision import ProvisionHandler -from .resourceDisk import ResourceDiskHandler -from .extension import ExtHandlersHandler -from .deprovision import DeprovisionHandler - -class DefaultHandlerFactory(object): - def __init__(self): - self.init_handler = InitHandler() - self.main_handler = MainHandler(self) - self.scvmm_handler = ScvmmHandler() - self.dhcp_handler = DhcpHandler() - self.env_handler = EnvHandler(self) - self.provision_handler = ProvisionHandler() - self.resource_disk_handler = ResourceDiskHandler() - self.ext_handlers_handler = ExtHandlersHandler() - self.deprovision_handler = DeprovisionHandler() - diff --git a/azurelinuxagent/distro/default/init.py b/azurelinuxagent/distro/default/init.py index db74fef..c703e87 100644 --- a/azurelinuxagent/distro/default/init.py +++ b/azurelinuxagent/distro/default/init.py @@ -20,30 +20,34 @@ import os import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.utils.fileutil as fileutil +import azurelinuxagent.event as event class InitHandler(object): - def init(self, verbose): + def __init__(self, distro): + self.distro = distro + + def run(self, verbose): #Init stdout log level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO logger.add_logger_appender(logger.AppenderType.STDOUT, level) #Init config - conf_file_path = OSUTIL.get_conf_file_path() - conf.load_conf(conf_file_path) + conf_file_path = self.distro.osutil.get_agent_conf_file_path() + conf.load_conf_from_file(conf_file_path) #Init log - verbose = verbose or conf.get_switch("Logs.Verbose", False) + verbose = verbose or conf.get_logs_verbose() level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO logger.add_logger_appender(logger.AppenderType.FILE, level, path="/var/log/waagent.log") logger.add_logger_appender(logger.AppenderType.CONSOLE, level, path="/dev/console") - #Create lib dir - fileutil.mkdir(OSUTIL.get_lib_dir(), mode=0o700) - os.chdir(OSUTIL.get_lib_dir()) + #Init event reporter + event_dir = os.path.join(conf.get_lib_dir(), "events") + event.init_event_logger(event_dir) + event.enable_unhandled_err_dump("WALA") + diff --git a/azurelinuxagent/distro/default/monitor.py b/azurelinuxagent/distro/default/monitor.py new file mode 100644 index 0000000..3b26c9a --- /dev/null +++ b/azurelinuxagent/distro/default/monitor.py @@ -0,0 +1,182 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import sys +import traceback +import atexit +import json +import time +import datetime +import threading +import platform +import azurelinuxagent.logger as logger +import azurelinuxagent.conf as conf +from azurelinuxagent.event import WALAEventOperation, add_event +from azurelinuxagent.exception import EventError, ProtocolError, OSUtilError +from azurelinuxagent.future import ustr +from azurelinuxagent.utils.textutil import parse_doc, findall, find, getattrib +from azurelinuxagent.protocol.restapi import TelemetryEventParam, \ + TelemetryEventList, \ + TelemetryEvent, \ + set_properties, get_properties +from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_CODE_NAME, AGENT_LONG_VERSION + + +def parse_event(data_str): + try: + return parse_json_event(data_str) + except ValueError: + return parse_xml_event(data_str) + +def parse_xml_param(param_node): + name = getattrib(param_node, "Name") + value_str = getattrib(param_node, "Value") + attr_type = getattrib(param_node, "T") + value = value_str + if attr_type == 'mt:uint64': + value = int(value_str) + elif attr_type == 'mt:bool': + value = bool(value_str) + elif attr_type == 'mt:float64': + value = float(value_str) + return TelemetryEventParam(name, value) + +def parse_xml_event(data_str): + try: + xml_doc = parse_doc(data_str) + event_id = getattrib(find(xml_doc, "Event"), 'id') + provider_id = getattrib(find(xml_doc, "Provider"), 'id') + event = TelemetryEvent(event_id, provider_id) + param_nodes = findall(xml_doc, 'Param') + for param_node in param_nodes: + event.parameters.append(parse_xml_param(param_node)) + return event + except Exception as e: + raise ValueError(ustr(e)) + +def parse_json_event(data_str): + data = json.loads(data_str) + event = TelemetryEvent() + set_properties("TelemetryEvent", event, data) + return event + + +class MonitorHandler(object): + def __init__(self, distro): + self.distro = distro + self.sysinfo = [] + + def run(self): + event_thread = threading.Thread(target = self.daemon) + event_thread.setDaemon(True) + event_thread.start() + + def init_sysinfo(self): + osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(), + DISTRO_NAME, + DISTRO_VERSION, + DISTRO_CODE_NAME, + platform.release()) + + + self.sysinfo.append(TelemetryEventParam("OSVersion", osversion)) + self.sysinfo.append(TelemetryEventParam("GAVersion", AGENT_LONG_VERSION)) + + try: + ram = self.distro.osutil.get_total_mem() + processors = self.distro.osutil.get_processor_cores() + self.sysinfo.append(TelemetryEventParam("RAM", ram)) + self.sysinfo.append(TelemetryEventParam("Processors", processors)) + except OSUtilError as e: + logger.warn("Failed to get system info: {0}", e) + + try: + protocol = self.distro.protocol_util.get_protocol() + vminfo = protocol.get_vminfo() + self.sysinfo.append(TelemetryEventParam("VMName", + vminfo.vmName)) + self.sysinfo.append(TelemetryEventParam("TenantName", + vminfo.tenantName)) + self.sysinfo.append(TelemetryEventParam("RoleName", + vminfo.roleName)) + self.sysinfo.append(TelemetryEventParam("RoleInstanceName", + vminfo.roleInstanceName)) + self.sysinfo.append(TelemetryEventParam("ContainerId", + vminfo.containerId)) + except ProtocolError as e: + logger.warn("Failed to get system info: {0}", e) + + def collect_event(self, evt_file_name): + try: + logger.verb("Found event file: {0}", evt_file_name) + with open(evt_file_name, "rb") as evt_file: + #if fail to open or delete the file, throw exception + data_str = evt_file.read().decode("utf-8",'ignore') + logger.verb("Processed event file: {0}", evt_file_name) + os.remove(evt_file_name) + return data_str + except IOError as e: + msg = "Failed to process {0}, {1}".format(evt_file_name, e) + raise EventError(msg) + + def collect_and_send_events(self): + event_list = TelemetryEventList() + event_dir = os.path.join(conf.get_lib_dir(), "events") + event_files = os.listdir(event_dir) + for event_file in event_files: + if not event_file.endswith(".tld"): + continue + event_file_path = os.path.join(event_dir, event_file) + try: + data_str = self.collect_event(event_file_path) + except EventError as e: + logger.error("{0}", e) + continue + + try: + event = parse_event(data_str) + event.parameters.extend(self.sysinfo) + event_list.events.append(event) + except (ValueError, ProtocolError) as e: + logger.warn("Failed to decode event file: {0}", e) + continue + + if len(event_list.events) == 0: + return + + try: + protocol = self.distro.protocol_util.get_protocol() + protocol.report_event(event_list) + except ProtocolError as e: + logger.error("{0}", e) + + def daemon(self): + self.init_sysinfo() + last_heartbeat = datetime.datetime.min + period = datetime.timedelta(hours = 12) + while(True): + if (datetime.datetime.now()-last_heartbeat) > period: + last_heartbeat = datetime.datetime.now() + add_event(op=WALAEventOperation.HeartBeat, name="WALA", + is_success=True) + try: + self.collect_and_send_events() + except Exception as e: + logger.warn("Failed to send events: {0}", e) + time.sleep(60) diff --git a/azurelinuxagent/distro/default/osutil.py b/azurelinuxagent/distro/default/osutil.py index 00a57cc..18ab2ba 100644 --- a/azurelinuxagent/distro/default/osutil.py +++ b/azurelinuxagent/distro/default/osutil.py @@ -25,11 +25,15 @@ import struct import time import pwd import fcntl +import base64 import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +import azurelinuxagent.conf as conf +from azurelinuxagent.exception import OSUtilError +from azurelinuxagent.future import ustr import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.textutil as textutil +from azurelinuxagent.utils.cryptutil import CryptUtil __RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules" ] @@ -40,44 +44,14 @@ for all distros. Each concrete distro classes could overwrite default behavior if needed. """ -class OSUtilError(Exception): - pass - class DefaultOSUtil(object): def __init__(self): - self.lib_dir = "/var/lib/waagent" - self.ext_log_dir = "/var/log/azure" - self.dvd_mount_point = "/mnt/cdrom/secure" - self.ovf_env_file_path = "/mnt/cdrom/secure/ovf-env.xml" - self.agent_pid_file_path = "/var/run/waagent.pid" - self.passwd_file_path = "/etc/shadow" - self.home = '/home' - self.sshd_conf_file_path = '/etc/ssh/sshd_config' - self.openssl_cmd = '/usr/bin/openssl' - self.conf_file_path = '/etc/waagent.conf' + self.agent_conf_file_path = '/etc/waagent.conf' self.selinux=None - def get_lib_dir(self): - return self.lib_dir - - def get_ext_log_dir(self): - return self.ext_log_dir - - def get_dvd_mount_point(self): - return self.dvd_mount_point - - def get_conf_file_path(self): - return self.conf_file_path - - def get_ovf_env_file_path_on_dvd(self): - return self.ovf_env_file_path - - def get_agent_pid_file_path(self): - return self.agent_pid_file_path - - def get_openssl_cmd(self): - return self.openssl_cmd + def get_agent_conf_file_path(self): + return self.agent_conf_file_path def get_userentry(self, username): try: @@ -86,6 +60,14 @@ class DefaultOSUtil(object): return None def is_sys_user(self, username): + """ + Check whether use is a system user. + If reset sys user is allowed in conf, return False + Otherwise, check whether UID is less than UID_MIN + """ + if conf.get_allow_reset_sys_user(): + return False + userentry = self.get_userentry(username) uidmin = None try: @@ -104,9 +86,13 @@ class DefaultOSUtil(object): def useradd(self, username, expiration=None): """ - Update password and ssh key for user account. - New account will be created if not exists. + Create user account with 'username' """ + userentry = self.get_userentry(username) + if userentry is not None: + logger.info("User {0} already exists, skip useradd", username) + return + if expiration is not None: cmd = "useradd -m {0} -e {1}".format(username, expiration) else: @@ -146,42 +132,21 @@ class DefaultOSUtil(object): def del_root_password(self): try: - passwd_content = fileutil.read_file(self.passwd_file_path) + passwd_file_path = conf.get_passwd_file_path() + passwd_content = fileutil.read_file(passwd_file_path) passwd = passwd_content.split('\n') new_passwd = [x for x in passwd if not x.startswith("root:")] new_passwd.insert(0, "root:*LOCK*:14600::::::") - fileutil.write_file(self.passwd_file_path, "\n".join(new_passwd)) + fileutil.write_file(passwd_file_path, "\n".join(new_passwd)) except IOError as e: raise OSUtilError("Failed to delete root password:{0}".format(e)) - def get_home(self): - return self.home - - def get_pubkey_from_prv(self, file_name): - cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, - file_name) - pub = shellutil.run_get_output(cmd)[1] - return pub - - def get_pubkey_from_crt(self, file_name): - cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd, - file_name) - pub = shellutil.run_get_output(cmd)[1] - return pub - def _norm_path(self, filepath): - home = self.get_home() + home = conf.get_home_dir() # Expand HOME variable if present in path path = os.path.normpath(filepath.replace("$HOME", home)) return path - def get_thumbprint_from_crt(self, file_name): - cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd, - file_name) - thumbprint = shellutil.run_get_output(cmd)[1] - thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper() - return thumbprint - def deploy_ssh_keypair(self, username, keypair): """ Deploy id_rsa and id_rsa.pub @@ -190,13 +155,14 @@ class DefaultOSUtil(object): path = self._norm_path(path) dir_path = os.path.dirname(path) fileutil.mkdir(dir_path, mode=0o700, owner=username) - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() prv_path = os.path.join(lib_dir, thumbprint + '.prv') if not os.path.isfile(prv_path): raise OSUtilError("Can't find {0}.prv".format(thumbprint)) shutil.copyfile(prv_path, path) pub_path = path + '.pub' - pub = self.get_pubkey_from_prv(prv_path) + crytputil = CryptUtil(conf.get_openssl_cmd()) + pub = crytputil.get_pubkey_from_prv(prv_path) fileutil.write_file(pub_path, pub) self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0') self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0') @@ -204,8 +170,8 @@ class DefaultOSUtil(object): os.chmod(pub_path, 0o600) def openssl_to_openssh(self, input_file, output_file): - shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file, - output_file)) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.crt_to_ssh(input_file, output_file) def deploy_ssh_pubkey(self, username, pubkey): """ @@ -215,6 +181,8 @@ class DefaultOSUtil(object): if path is None: raise OSUtilError("Publich key path is None") + crytputil = CryptUtil(conf.get_openssl_cmd()) + path = self._norm_path(path) dir_path = os.path.dirname(path) fileutil.mkdir(dir_path, mode=0o700, owner=username) @@ -223,12 +191,12 @@ class DefaultOSUtil(object): raise OSUtilError("Bad public key: {0}".format(value)) fileutil.write_file(path, value) elif thumbprint is not None: - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() crt_path = os.path.join(lib_dir, thumbprint + '.crt') if not os.path.isfile(crt_path): raise OSUtilError("Can't find {0}.crt".format(thumbprint)) pub_path = os.path.join(lib_dir, thumbprint + '.pub') - pub = self.get_pubkey_from_crt(crt_path) + pub = crytputil.get_pubkey_from_crt(crt_path) fileutil.write_file(pub_path, pub) self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0') @@ -280,23 +248,21 @@ class DefaultOSUtil(object): if self.is_selinux_system(): return shellutil.run('chcon ' + con + ' ' + path) - def get_sshd_conf_file_path(self): - return self.sshd_conf_file_path - def set_ssh_client_alive_interval(self): - conf_file_path = self.get_sshd_conf_file_path() - conf = fileutil.read_file(conf_file_path).split("\n") - textutil.set_ssh_config(conf, "ClientAliveInterval", "180") - fileutil.write_file(conf_file_path, '\n'.join(conf)) + conf_file_path = conf.get_sshd_conf_file_path() + conf_file = fileutil.read_file(conf_file_path).split("\n") + textutil.set_ssh_config(conf_file, "ClientAliveInterval", "180") + fileutil.write_file(conf_file_path, '\n'.join(conf_file)) logger.info("Configured SSH client probing to keep connections alive.") def conf_sshd(self, disable_password): option = "no" if disable_password else "yes" - conf_file_path = self.get_sshd_conf_file_path() - conf = fileutil.read_file(conf_file_path).split("\n") - textutil.set_ssh_config(conf, "PasswordAuthentication", option) - textutil.set_ssh_config(conf, "ChallengeResponseAuthentication", option) - fileutil.write_file(conf_file_path, "\n".join(conf)) + conf_file_path = conf.get_sshd_conf_file_path() + conf_file = fileutil.read_file(conf_file_path).split("\n") + textutil.set_ssh_config(conf_file, "PasswordAuthentication", option) + textutil.set_ssh_config(conf_file, "ChallengeResponseAuthentication", + option) + fileutil.write_file(conf_file_path, "\n".join(conf_file)) logger.info("Disabled SSH password-based authentication methods.") @@ -309,7 +275,7 @@ class DefaultOSUtil(object): def mount_dvd(self, max_retry=6, chk_err=True): dvd = self.get_dvd_device() - mount_point = self.get_dvd_mount_point() + mount_point = conf.get_dvd_mount_point() mountlist = shellutil.run_get_output("mount")[1] existing = self.get_mount_point(mountlist, dvd) if existing is not None: #Already mounted @@ -332,7 +298,7 @@ class DefaultOSUtil(object): raise OSUtilError("Failed to mount dvd.") def umount_dvd(self, chk_err=True): - mount_point = self.get_dvd_mount_point() + mount_point = conf.get_dvd_mount_point() retcode = self.umount(mount_point, chk_err=chk_err) if chk_err and retcode != 0: raise OSUtilError("Failed to umount dvd.") @@ -386,17 +352,9 @@ class DefaultOSUtil(object): shellutil.run("iptables -I INPUT -p udp --dport 68 -j ACCEPT", chk_err=False) - def gen_transport_cert(self): - """ - Create ssl certificate for https communication with endpoint server. - """ - cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 " - "-newkey rsa:2048 -keyout TransportPrivate.pem " - "-out TransportCert.pem").format(self.openssl_cmd) - shellutil.run(cmd) def remove_rules_files(self, rules_files=__RULES_FILES__): - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() for src in rules_files: file_name = fileutil.base_name(src) dest = os.path.join(lib_dir, file_name) @@ -407,7 +365,7 @@ class DefaultOSUtil(object): shutil.move(src, dest) def restore_rules_files(self, rules_files=__RULES_FILES__): - lib_dir = self.get_lib_dir() + lib_dir = conf.get_lib_dir() for dest in rules_files: filename = fileutil.base_name(dest) src = os.path.join(lib_dir, filename) @@ -603,7 +561,7 @@ class DefaultOSUtil(object): for vmbus in os.listdir(path): deviceid = fileutil.read_file(os.path.join(path, vmbus, "device_id")) guid = deviceid.lstrip('{').split('-') - if guid[0] == g0 and guid[1] == "000" + text(port_id): + if guid[0] == g0 and guid[1] == "000" + ustr(port_id): for root, dirs, files in os.walk(path + vmbus): if root.endswith("/block"): device = dirs[0] @@ -633,7 +591,7 @@ class DefaultOSUtil(object): raise OSUtilError("Failed to remove sudoer: {0}".format(e)) def decode_customdata(self, data): - return data + return base64.b64decode(data) def get_total_mem(self): cmd = "grep MemTotal /proc/meminfo |awk '{print $2}'" @@ -649,4 +607,17 @@ class DefaultOSUtil(object): return int(ret[1]) else: raise OSUtilError("Failed to get procerssor cores") + + def set_admin_access_to_ip(self, dest_ip): + #This allows root to access dest_ip + rm_old= "iptables -D OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0" + rule = "iptables -A OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0" + shellutil.run(rm_old.format(dest_ip), chk_err=False) + shellutil.run(rule.format(dest_ip)) + + #This blocks all other users to access dest_ip + rm_old = "iptables -D OUTPUT -d {0} -j DROP" + rule = "iptables -A OUTPUT -d {0} -j DROP" + shellutil.run(rm_old.format(dest_ip), chk_err=False) + shellutil.run(rule.format(dest_ip)) diff --git a/azurelinuxagent/distro/default/protocolUtil.py b/azurelinuxagent/distro/default/protocolUtil.py new file mode 100644 index 0000000..34466cf --- /dev/null +++ b/azurelinuxagent/distro/default/protocolUtil.py @@ -0,0 +1,243 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import re +import shutil +import time +import threading +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.exception import ProtocolError, OSUtilError, \ + ProtocolNotFoundError, DhcpError +from azurelinuxagent.future import ustr +import azurelinuxagent.utils.fileutil as fileutil +from azurelinuxagent.protocol.ovfenv import OvfEnv +from azurelinuxagent.protocol.wire import WireProtocol +from azurelinuxagent.protocol.metadata import MetadataProtocol, METADATA_ENDPOINT +import azurelinuxagent.utils.shellutil as shellutil + +OVF_FILE_NAME = "ovf-env.xml" + +#Tag file to indicate usage of metadata protocol +TAG_FILE_NAME = "useMetadataEndpoint.tag" + +PROTOCOL_FILE_NAME = "Protocol" + +#MAX retry times for protocol probing +MAX_RETRY = 360 + +PROBE_INTERVAL = 10 + +ENDPOINT_FILE_NAME = "WireServerEndpoint" + +class ProtocolUtil(object): + """ + ProtocolUtil handles initialization for protocol instance. 2 protocol types + are invoked, wire protocol and metadata protocols. + """ + def __init__(self, distro): + self.distro = distro + self.protocol = None + self.lock = threading.Lock() + + def copy_ovf_env(self): + """ + Copy ovf env file from dvd to hard disk. + Remove password before save it to the disk + """ + dvd_mount_point = conf.get_dvd_mount_point() + ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME) + tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME) + try: + self.distro.osutil.mount_dvd() + ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) + ovfenv = OvfEnv(ovfxml) + ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml) + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + fileutil.write_file(ovf_file_path, ovfxml) + + if os.path.isfile(tag_file_path_on_dvd): + logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME) + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + shutil.copyfile(tag_file_path_on_dvd, tag_file_path) + + except (OSUtilError, IOError) as e: + raise ProtocolError(ustr(e)) + + try: + self.distro.osutil.umount_dvd() + self.distro.osutil.eject_dvd() + except OSUtilError as e: + logger.warn(ustr(e)) + + return ovfenv + + def get_ovf_env(self): + """ + Load saved ovf-env.xml + """ + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + if os.path.isfile(ovf_file_path): + xml_text = fileutil.read_file(ovf_file_path) + return OvfEnv(xml_text) + else: + raise ProtocolError("ovf-env.xml is missing.") + + def _get_wireserver_endpoint(self): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + return fileutil.read_file(file_path) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _set_wireserver_endpoint(self, endpoint): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + fileutil.write_file(file_path, endpoint) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _detect_wire_protocol(self): + endpoint = self.distro.dhcp_handler.endpoint + if endpoint is None: + logger.info("WireServer endpoint is not found. Rerun dhcp handler") + try: + self.distro.dhcp_handler.run() + except DhcpError as e: + raise ProtocolError(ustr(e)) + endpoint = self.distro.dhcp_handler.endpoint + + try: + protocol = WireProtocol(endpoint) + protocol.detect() + self._set_wireserver_endpoint(endpoint) + return protocol + except ProtocolError as e: + logger.info("WireServer is not responding. Reset endpoint") + self.distro.dhcp_handler.endpoint = None + raise e + + def _detect_metadata_protocol(self): + protocol = MetadataProtocol() + protocol.detect() + + #Only allow root access METADATA_ENDPOINT + self.distro.osutil.set_admin_access_to_ip(METADATA_ENDPOINT) + + return protocol + + def _detect_protocol(self, protocols): + """ + Probe protocol endpoints in turn. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + if os.path.isfile(protocol_file_path): + os.remove(protocol_file_path) + for retry in range(0, MAX_RETRY): + for protocol in protocols: + try: + if protocol == "WireProtocol": + return self._detect_wire_protocol() + + if protocol == "MetadataProtocol": + return self._detect_metadata_protocol() + + except ProtocolError as e: + logger.info("Protocol endpoint not found: {0}, {1}", + protocol, e) + + if retry < MAX_RETRY -1: + logger.info("Retry detect protocols: retry={0}", retry) + time.sleep(PROBE_INTERVAL) + raise ProtocolNotFoundError("No protocol found.") + + def _get_protocol(self): + """ + Get protocol instance based on previous detecting result. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), + PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + raise ProtocolError("No protocl found") + + protocol_name = fileutil.read_file(protocol_file_path) + if protocol_name == "WireProtocol": + endpoint = self._get_wireserver_endpoint() + return WireProtocol(endpoint) + elif protocol_name == "MetadataProtocol": + return MetadataProtocol() + else: + raise ProtocolNotFoundError(("Unknown protocol: {0}" + "").format(protocol_name)) + + def detect_protocol(self): + """ + Detect protocol by endpoints + + :returns: protocol instance + """ + logger.info("Detect protocol endpoints") + protocols = ["WireProtocol", "MetadataProtocol"] + self.lock.acquire() + try: + if self.protocol is None: + self.protocol = self._detect_protocol(protocols) + return self.protocol + finally: + self.lock.release() + + def detect_protocol_by_file(self): + """ + Detect protocol by tag file. + + If a file "useMetadataEndpoint.tag" is found on provision iso, + metedata protocol will be used. No need to probe for wire protocol + + :returns: protocol instance + """ + logger.info("Detect protocol by file") + self.lock.acquire() + try: + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + if self.protocol is None: + protocols = [] + if os.path.isfile(tag_file_path): + protocols.append("MetadataProtocol") + else: + protocols.append("WireProtocol") + self.protocol = self._detect_protocol(protocols) + finally: + self.lock.release() + return self.protocol + + def get_protocol(self): + """ + Get protocol instance based on previous detecting result. + + :returns protocol instance + """ + self.lock.acquire() + try: + if self.protocol is None: + self.protocol = self._get_protocol() + return self.protocol + finally: + self.lock.release() + return self.protocol + diff --git a/azurelinuxagent/distro/default/provision.py b/azurelinuxagent/distro/default/provision.py index 424f083..695b82a 100644 --- a/azurelinuxagent/distro/default/provision.py +++ b/azurelinuxagent/distro/default/provision.py @@ -21,13 +21,11 @@ Provision handler import os import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.conf as conf from azurelinuxagent.event import add_event, WALAEventOperation -from azurelinuxagent.exception import * -from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError -import azurelinuxagent.protocol as prot -import azurelinuxagent.protocol.ovfenv as ovf +from azurelinuxagent.exception import ProvisionError, ProtocolError, OSUtilError +from azurelinuxagent.protocol.restapi import ProvisionStatus import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.fileutil as fileutil @@ -35,61 +33,49 @@ CUSTOM_DATA_FILE="CustomData" class ProvisionHandler(object): - def process(self): + def __init__(self, distro): + self.distro = distro + + def run(self): #If provision is not enabled, return - if not conf.get_switch("Provisioning.Enabled", True): + if not conf.get_provision_enabled(): logger.info("Provisioning is disabled. Skip.") - return + return - provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned") + provisioned = os.path.join(conf.get_lib_dir(), "provisioned") if os.path.isfile(provisioned): return - logger.info("run provision handler.") - protocol = prot.FACTORY.get_default_protocol() + logger.info("Run provision handler.") + logger.info("Copy ovf-env.xml.") + try: + ovfenv = self.distro.protocol_util.copy_ovf_env() + except ProtocolError as e: + self.report_event("Failed to copy ovf-env.xml: {0}".format(e)) + return + + self.distro.protocol_util.detect_protocol_by_file() + + self.report_not_ready("Provisioning", "Starting") + try: - status = prot.ProvisionStatus(status="NotReady", - subStatus="Provisioning", - description="Starting") - try: - protocol.report_provision_status(status) - except prot.ProtocolError as e: - add_event(name="WALA", is_success=False, message=text(e), - op=WALAEventOperation.Provision) - - self.provision() + logger.info("Start provisioning") + self.provision(ovfenv) fileutil.write_file(provisioned, "") thumbprint = self.reg_ssh_host_key() - logger.info("Finished provisioning") - status = prot.ProvisionStatus(status="Ready") - status.properties.certificateThumbprint = thumbprint - - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) - - add_event(name="WALA", is_success=True, message="", - op=WALAEventOperation.Provision) except ProvisionError as e: logger.error("Provision failed: {0}", e) - status = prot.ProvisionStatus(status="NotReady", - subStatus="ProvisioningFailed", - description= text(e)) - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) - - add_event(name="WALA", is_success=False, message=text(e), - op=WALAEventOperation.Provision) + self.report_not_ready("ProvisioningFailed", ustr(e)) + self.report_event(ustr(e)) + return + self.report_ready(thumbprint) + self.report_event("Provision succeed", is_success=True) + def reg_ssh_host_key(self): - keypair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa") - if conf.get_switch("Provisioning.RegenerateSshHostKeyPair"): + keypair_type = conf.get_ssh_host_keypair_type() + if conf.get_regenerate_ssh_host_key(): shellutil.run("rm -f /etc/ssh/ssh_host_*key*") shellutil.run(("ssh-keygen -N '' -t {0} -f /etc/ssh/ssh_host_{1}_key" "").format(keypair_type, keypair_type)) @@ -105,77 +91,101 @@ class ProvisionHandler(object): raise ProvisionError(("Failed to generate ssh host key: " "ret={0}, out= {1}").format(ret[0], ret[1])) - - def provision(self): - logger.info("Copy ovf-env.xml.") - try: - ovfenv = ovf.copy_ovf_env() - except prot.ProtocolError as e: - raise ProvisionError("Failed to copy ovf-env.xml: {0}".format(e)) - + def provision(self, ovfenv): logger.info("Handle ovf-env.xml.") try: logger.info("Set host name.") - OSUTIL.set_hostname(ovfenv.hostname) + self.distro.osutil.set_hostname(ovfenv.hostname) logger.info("Publish host name.") - OSUTIL.publish_hostname(ovfenv.hostname) + self.distro.osutil.publish_hostname(ovfenv.hostname) self.config_user_account(ovfenv) self.save_customdata(ovfenv) + + if conf.get_delete_root_password(): + self.distro.osutil.del_root_password() - if conf.get_switch("Provisioning.DeleteRootPassword"): - OSUTIL.del_root_password() except OSUtilError as e: raise ProvisionError("Failed to handle ovf-env.xml: {0}".format(e)) def config_user_account(self, ovfenv): logger.info("Create user account if not exists") - OSUTIL.useradd(ovfenv.username) + self.distro.osutil.useradd(ovfenv.username) if ovfenv.user_password is not None: logger.info("Set user password.") - crypt_id = conf.get("Provision.PasswordCryptId", "6") - salt_len = conf.get_int("Provision.PasswordCryptSaltLength", 10) - OSUTIL.chpasswd(ovfenv.username, ovfenv.user_password, + crypt_id = conf.get_password_cryptid() + salt_len = conf.get_password_crypt_salt_len() + self.distro.osutil.chpasswd(ovfenv.username, ovfenv.user_password, crypt_id=crypt_id, salt_len=salt_len) logger.info("Configure sudoer") - OSUTIL.conf_sudoer(ovfenv.username, ovfenv.user_password is None) + self.distro.osutil.conf_sudoer(ovfenv.username, ovfenv.user_password is None) logger.info("Configure sshd") - OSUTIL.conf_sshd(ovfenv.disable_ssh_password_auth) + self.distro.osutil.conf_sshd(ovfenv.disable_ssh_password_auth) #Disable selinux temporary - sel = OSUTIL.is_selinux_enforcing() + sel = self.distro.osutil.is_selinux_enforcing() if sel: - OSUTIL.set_selinux_enforce(0) + self.distro.osutil.set_selinux_enforce(0) self.deploy_ssh_pubkeys(ovfenv) self.deploy_ssh_keypairs(ovfenv) if sel: - OSUTIL.set_selinux_enforce(1) + self.distro.osutil.set_selinux_enforce(1) - OSUTIL.restart_ssh_service() + self.distro.osutil.restart_ssh_service() def save_customdata(self, ovfenv): - logger.info("Save custom data") customdata = ovfenv.customdata if customdata is None: return - lib_dir = OSUTIL.get_lib_dir() - fileutil.write_file(os.path.join(lib_dir, CUSTOM_DATA_FILE), - OSUTIL.decode_customdata(customdata)) + + logger.info("Save custom data") + lib_dir = conf.get_lib_dir() + if conf.get_decode_customdata(): + customdata= self.distro.osutil.decode_customdata(customdata) + customdata_file = os.path.join(lib_dir, CUSTOM_DATA_FILE) + fileutil.write_file(customdata_file, customdata) + + if conf.get_execute_customdata(): + logger.info("Execute custom data") + os.chmod(customdata_file, 0o700) + shellutil.run(customdata_file) def deploy_ssh_pubkeys(self, ovfenv): for pubkey in ovfenv.ssh_pubkeys: logger.info("Deploy ssh public key.") - OSUTIL.deploy_ssh_pubkey(ovfenv.username, pubkey) + self.distro.osutil.deploy_ssh_pubkey(ovfenv.username, pubkey) def deploy_ssh_keypairs(self, ovfenv): for keypair in ovfenv.ssh_keypairs: logger.info("Deploy ssh key pairs.") - OSUTIL.deploy_ssh_keypair(ovfenv.username, keypair) + self.distro.osutil.deploy_ssh_keypair(ovfenv.username, keypair) + + def report_event(self, message, is_success=False): + add_event(name="WALA", message=message, is_success=is_success, + op=WALAEventOperation.Provision) + + def report_not_ready(self, sub_status, description): + status = ProvisionStatus(status="NotReady", subStatus=sub_status, + description=description) + try: + protocol = self.distro.protocol_util.get_protocol() + protocol.report_provision_status(status) + except ProtocolError as e: + self.report_event(ustr(e)) + + def report_ready(self, thumbprint=None): + status = ProvisionStatus(status="Ready") + status.properties.certificateThumbprint = thumbprint + try: + protocol = self.distro.protocol_util.get_protocol() + protocol.report_provision_status(status) + except ProtocolError as e: + self.report_event(ustr(e)) diff --git a/azurelinuxagent/distro/default/resourceDisk.py b/azurelinuxagent/distro/default/resourceDisk.py index 734863c..a6c5232 100644 --- a/azurelinuxagent/distro/default/resourceDisk.py +++ b/azurelinuxagent/distro/default/resourceDisk.py @@ -21,9 +21,8 @@ import os import re import threading import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.conf as conf -from azurelinuxagent.utils.osutil import OSUTIL from azurelinuxagent.event import add_event, WALAEventOperation import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil @@ -41,6 +40,8 @@ For additional details to please refer to the MSDN documentation at : http://msd """ class ResourceDiskHandler(object): + def __init__(self, distro): + self.distro = distro def start_activate_resource_disk(self): disk_thread = threading.Thread(target = self.run) @@ -48,17 +49,17 @@ class ResourceDiskHandler(object): def run(self): mount_point = None - if conf.get_switch("ResourceDisk.Format", False): + if conf.get_resourcedisk_format(): mount_point = self.activate_resource_disk() if mount_point is not None and \ - conf.get_switch("ResourceDisk.EnableSwap", False): + conf.get_resourcedisk_enable_swap(): self.enable_swap(mount_point) def activate_resource_disk(self): logger.info("Activate resource disk") try: - mount_point = conf.get("ResourceDisk.MountPoint", "/mnt/resource") - fs = conf.get("ResourceDisk.Filesystem", "ext3") + mount_point = conf.get_resourcedisk_mountpoint() + fs = conf.get_resourcedisk_filesystem() mount_point = self.mount_resource_disk(mount_point, fs) warning_file = os.path.join(mount_point, DATALOSS_WARNING_FILE_NAME) try: @@ -68,25 +69,25 @@ class ResourceDiskHandler(object): return mount_point except ResourceDiskError as e: logger.error("Failed to mount resource disk {0}", e) - add_event(name="WALA", is_success=False, message=text(e), + add_event(name="WALA", is_success=False, message=ustr(e), op=WALAEventOperation.ActivateResourceDisk) def enable_swap(self, mount_point): logger.info("Enable swap") try: - size_mb = conf.get_int("ResourceDisk.SwapSizeMB", 0) + size_mb = conf.get_resourcedisk_swap_size_mb() self.create_swap_space(mount_point, size_mb) except ResourceDiskError as e: logger.error("Failed to enable swap {0}", e) def mount_resource_disk(self, mount_point, fs): - device = OSUTIL.device_for_ide_port(1) + device = self.distro.osutil.device_for_ide_port(1) if device is None: raise ResourceDiskError("unable to detect disk topology") device = "/dev/" + device mountlist = shellutil.run_get_output("mount")[1] - existing = OSUTIL.get_mount_point(mountlist, device) + existing = self.distro.osutil.get_mount_point(mountlist, device) if(existing): logger.info("Resource disk {0}1 is already mounted", device) diff --git a/azurelinuxagent/distro/default/run.py b/azurelinuxagent/distro/default/run.py deleted file mode 100644 index dfd3b03..0000000 --- a/azurelinuxagent/distro/default/run.py +++ /dev/null @@ -1,71 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -import os -import time -import sys -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.conf as conf -from azurelinuxagent.metadata import AGENT_LONG_NAME, AGENT_VERSION, \ - DISTRO_NAME, DISTRO_VERSION, \ - DISTRO_FULL_NAME, PY_VERSION_MAJOR, \ - PY_VERSION_MINOR, PY_VERSION_MICRO -import azurelinuxagent.event as event -import azurelinuxagent.protocol as prot -from azurelinuxagent.utils.osutil import OSUTIL -import azurelinuxagent.utils.fileutil as fileutil - - -class MainHandler(object): - def __init__(self, handlers): - self.handlers = handlers - - def run(self): - logger.info("{0} Version:{1}", AGENT_LONG_NAME, AGENT_VERSION) - logger.info("OS: {0} {1}", DISTRO_NAME, DISTRO_VERSION) - logger.info("Python: {0}.{1}.{2}", PY_VERSION_MAJOR, PY_VERSION_MINOR, - PY_VERSION_MICRO) - - event.enable_unhandled_err_dump("Azure Linux Agent") - fileutil.write_file(OSUTIL.get_agent_pid_file_path(), text(os.getpid())) - - if conf.get_switch("DetectScvmmEnv", False): - if self.handlers.scvmm_handler.detect_scvmm_env(): - return - - self.handlers.dhcp_handler.probe() - - prot.detect_default_protocol() - - event.EventMonitor().start() - - self.handlers.provision_handler.process() - - if conf.get_switch("ResourceDisk.Format", False): - self.handlers.resource_disk_handler.start_activate_resource_disk() - - self.handlers.env_handler.start() - - protocol = prot.FACTORY.get_default_protocol() - while True: - #Handle extensions - self.handlers.ext_handlers_handler.process() - time.sleep(25) - diff --git a/azurelinuxagent/distro/default/scvmm.py b/azurelinuxagent/distro/default/scvmm.py index 680c04b..4d083b4 100644 --- a/azurelinuxagent/distro/default/scvmm.py +++ b/azurelinuxagent/distro/default/scvmm.py @@ -20,28 +20,29 @@ import os import subprocess import azurelinuxagent.logger as logger -from azurelinuxagent.utils.osutil import OSUTIL VMM_CONF_FILE_NAME = "linuxosconfiguration.xml" VMM_STARTUP_SCRIPT_NAME= "install" class ScvmmHandler(object): + def __init__(self, distro): + self.distro = distro def detect_scvmm_env(self): logger.info("Detecting Microsoft System Center VMM Environment") - OSUTIL.mount_dvd(max_retry=1, chk_err=False) - mount_point = OSUTIL.get_dvd_mount_point() + self.distro.osutil.mount_dvd(max_retry=1, chk_err=False) + mount_point = self.distro.osutil.get_dvd_mount_point() found = os.path.isfile(os.path.join(mount_point, VMM_CONF_FILE_NAME)) if found: self.start_scvmm_agent() else: - OSUTIL.umount_dvd(chk_err=False) + self.distro.osutil.umount_dvd(chk_err=False) return found def start_scvmm_agent(self): logger.info("Starting Microsoft System Center VMM Initialization " "Process") - mount_point = OSUTIL.get_dvd_mount_point() + mount_point = self.distro.osutil.get_dvd_mount_point() startup_script = os.path.join(mount_point, VMM_STARTUP_SCRIPT_NAME) subprocess.Popen(["/bin/bash", startup_script, "-p " + mount_point]) diff --git a/azurelinuxagent/distro/loader.py b/azurelinuxagent/distro/loader.py index 375abd2..74ea9e7 100644 --- a/azurelinuxagent/distro/loader.py +++ b/azurelinuxagent/distro/loader.py @@ -16,31 +16,52 @@ # import azurelinuxagent.logger as logger -from azurelinuxagent.metadata import DISTRO_NAME -import azurelinuxagent.distro.default.loader as default_loader +from azurelinuxagent.utils.textutil import Version +from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_FULL_NAME +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.ubuntu.distro import UbuntuDistro, \ + Ubuntu14Distro, \ + Ubuntu12Distro, \ + UbuntuSnappyDistro +from azurelinuxagent.distro.redhat.distro import RedhatDistro, Redhat6xDistro +from azurelinuxagent.distro.coreos.distro import CoreOSDistro +from azurelinuxagent.distro.suse.distro import SUSE11Distro, SUSEDistro +from azurelinuxagent.distro.debian.distro import DebianDistro - -def get_distro_loader(): - try: - logger.verb("Loading distro implemetation from: {0}", DISTRO_NAME) - pkg_name = "azurelinuxagent.distro.{0}.loader".format(DISTRO_NAME) - return __import__(pkg_name, fromlist="loader") - except (ImportError, ValueError): - logger.warn("Unable to load distro implemetation for {0}.", DISTRO_NAME) +def get_distro(distro_name=DISTRO_NAME, distro_version=DISTRO_VERSION, + distro_full_name=DISTRO_FULL_NAME): + if distro_name == "ubuntu": + if Version(distro_version) == Version("12.04") or \ + Version(distro_version) == Version("12.10"): + return Ubuntu12Distro() + elif Version(distro_version) == Version("14.04") or \ + Version(distro_version) == Version("14.10"): + return Ubuntu14Distro() + elif distro_full_name == "Snappy Ubuntu Core": + return UbuntuSnappyDistro() + else: + return UbuntuDistro() + if distro_name == "coreos": + return CoreOSDistro() + if distro_name == "suse": + if distro_full_name=='SUSE Linux Enterprise Server' and \ + Version(distro_version) < Version('12') or \ + distro_full_name == 'openSUSE' and \ + Version(distro_version) < Version('13.2'): + return SUSE11Distro() + else: + return SUSEDistro() + elif distro_name == "debian": + return DebianDistro() + elif distro_name == "redhat" or distro_name == "centos" or \ + distro_name == "oracle": + if Version(distro_version) < Version("7"): + return Redhat6xDistro() + else: + return RedhatDistro() + else: + logger.warn("Unable to load distro implemetation for {0}.", distro_name) logger.warn("Use default distro implemetation instead.") - return default_loader - -DISTRO_LOADER = get_distro_loader() - -def get_osutil(): - try: - return DISTRO_LOADER.get_osutil() - except AttributeError: - return default_loader.get_osutil() - -def get_handlers(): - try: - return DISTRO_LOADER.get_handlers() - except AttributeError: - return default_loader.get_handlers() + return DefaultDistro() diff --git a/azurelinuxagent/distro/oracle/__init__.py b/azurelinuxagent/distro/oracle/__init__.py deleted file mode 100644 index d9b82f5..0000000 --- a/azurelinuxagent/distro/oracle/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - diff --git a/azurelinuxagent/distro/ubuntu/handlerFactory.py b/azurelinuxagent/distro/redhat/distro.py index 11f7f04..2f128d7 100644 --- a/azurelinuxagent/distro/ubuntu/handlerFactory.py +++ b/azurelinuxagent/distro/redhat/distro.py @@ -17,13 +17,16 @@ # Requires Python 2.4+ and Openssl 1.0+ # -from azurelinuxagent.distro.ubuntu.provision import UbuntuProvisionHandler -from azurelinuxagent.distro.ubuntu.deprovision import UbuntuDeprovisionHandler -from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.redhat.osutil import RedhatOSUtil, Redhat6xOSUtil +from azurelinuxagent.distro.coreos.deprovision import CoreOSDeprovisionHandler -class UbuntuHandlerFactory(DefaultHandlerFactory): +class Redhat6xDistro(DefaultDistro): def __init__(self): - super(UbuntuHandlerFactory, self).__init__() - self.provision_handler = UbuntuProvisionHandler() - self.deprovision_handler = UbuntuDeprovisionHandler() + super(Redhat6xDistro, self).__init__() + self.osutil = Redhat6xOSUtil() +class RedhatDistro(DefaultDistro): + def __init__(self): + super(RedhatDistro, self).__init__() + self.osutil = RedhatOSUtil() diff --git a/azurelinuxagent/distro/redhat/loader.py b/azurelinuxagent/distro/redhat/loader.py deleted file mode 100644 index 8d3c75b..0000000 --- a/azurelinuxagent/distro/redhat/loader.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION - -def get_osutil(): - from azurelinuxagent.distro.redhat.osutil import Redhat6xOSUtil, RedhatOSUtil - if DISTRO_VERSION < "7": - return Redhat6xOSUtil() - else: - return RedhatOSUtil() - diff --git a/azurelinuxagent/distro/redhat/osutil.py b/azurelinuxagent/distro/redhat/osutil.py index 7478867..7f769a5 100644 --- a/azurelinuxagent/distro/redhat/osutil.py +++ b/azurelinuxagent/distro/redhat/osutil.py @@ -26,20 +26,19 @@ import struct import fcntl import time import base64 +import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.future import text, bytebuffer +from azurelinuxagent.future import ustr, bytebuffer +from azurelinuxagent.exception import OSUtilError, CryptError import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.textutil as textutil -from azurelinuxagent.distro.default.osutil import DefaultOSUtil, OSUtilError +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.distro.default.osutil import DefaultOSUtil class Redhat6xOSUtil(DefaultOSUtil): def __init__(self): super(Redhat6xOSUtil, self).__init__() - self.sshd_conf_file_path = '/etc/ssh/sshd_config' - self.openssl_cmd = '/usr/bin/openssl' - self.conf_file_path = '/etc/waagent.conf' - self.selinux=None def start_network(self): return shellutil.run("/sbin/service networking start", chk_err=False) @@ -58,63 +57,14 @@ class Redhat6xOSUtil(DefaultOSUtil): def unregister_agent_service(self): return shellutil.run("chkconfig --del waagent", chk_err=False) - - def asn1_to_ssh_rsa(self, pubkey): - lines = pubkey.split("\n") - lines = [x for x in lines if not x.startswith("----")] - base64_encoded = "".join(lines) - try: - #TODO remove pyasn1 dependency - from pyasn1.codec.der import decoder as der_decoder - der_encoded = base64.b64decode(base64_encoded) - der_encoded = der_decoder.decode(der_encoded)[0][1] - key = der_decoder.decode(self.bits_to_bytes(der_encoded))[0] - n=key[0] - e=key[1] - keydata = bytearray() - keydata.extend(struct.pack('>I', len("ssh-rsa"))) - keydata.extend(b"ssh-rsa") - keydata.extend(struct.pack('>I', len(self.num_to_bytes(e)))) - keydata.extend(self.num_to_bytes(e)) - keydata.extend(struct.pack('>I', len(self.num_to_bytes(n)) + 1)) - keydata.extend(b"\0") - keydata.extend(self.num_to_bytes(n)) - keydata_base64 = base64.b64encode(bytebuffer(keydata)) - return text(b"ssh-rsa " + keydata_base64 + b"\n", - encoding='utf-8') - except ImportError as e: - raise OSUtilError("Failed to load pyasn1.codec.der") - - def num_to_bytes(self, num): - """ - Pack number into bytes. Retun as string. - """ - result = bytearray() - while num: - result.append(num & 0xFF) - num >>= 8 - result.reverse() - return result - - def bits_to_bytes(self, bits): - """ - Convert an array contains bits, [0,1] to a byte array - """ - index = 7 - byte_array = bytearray() - curr = 0 - for bit in bits: - curr = curr | (bit << index) - index = index - 1 - if index == -1: - byte_array.append(curr) - curr = 0 - index = 7 - return bytes(byte_array) - + def openssl_to_openssh(self, input_file, output_file): pubkey = fileutil.read_file(input_file) - ssh_rsa_pubkey = self.asn1_to_ssh_rsa(pubkey) + try: + cryptutil = CryptUtil(conf.get_openssl_cmd()) + ssh_rsa_pubkey = cryptutil.asn1_to_ssh(pubkey) + except CryptError as e: + raise OSUtilError(ustr(e)) fileutil.write_file(output_file, ssh_rsa_pubkey) #Override @@ -134,8 +84,7 @@ class Redhat6xOSUtil(DefaultOSUtil): def set_dhcp_hostname(self, hostname): ifname = self.get_if_name() filepath = "/etc/sysconfig/network-scripts/ifcfg-{0}".format(ifname) - fileutil.update_conf_file(filepath, - 'DHCP_HOSTNAME', + fileutil.update_conf_file(filepath, 'DHCP_HOSTNAME', 'DHCP_HOSTNAME={0}'.format(hostname)) class RedhatOSUtil(Redhat6xOSUtil): @@ -162,4 +111,5 @@ class RedhatOSUtil(Redhat6xOSUtil): def unregister_agent_service(self): return shellutil.run("systemctl disable waagent", chk_err=False) - + def openssl_to_openssh(self, input_file, output_file): + DefaultOSUtil.openssl_to_openssh(self, input_file, output_file) diff --git a/azurelinuxagent/distro/default/loader.py b/azurelinuxagent/distro/suse/distro.py index 55a51e0..5b39369 100644 --- a/azurelinuxagent/distro/default/loader.py +++ b/azurelinuxagent/distro/suse/distro.py @@ -17,12 +17,16 @@ # Requires Python 2.4+ and Openssl 1.0+ # -def get_osutil(): - from azurelinuxagent.distro.default.osutil import DefaultOSUtil - return DefaultOSUtil() +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.suse.osutil import SUSE11OSUtil, SUSEOSUtil -def get_handlers(): - from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory - return DefaultHandlerFactory() +class SUSE11Distro(DefaultDistro): + def __init__(self): + super(SUSE11Distro, self).__init__() + self.osutil = SUSE11OSUtil() +class SUSEDistro(DefaultDistro): + def __init__(self): + super(SUSEDistro, self).__init__() + self.osutil = SUSEOSUtil() diff --git a/azurelinuxagent/distro/suse/loader.py b/azurelinuxagent/distro/suse/loader.py deleted file mode 100644 index b01384b..0000000 --- a/azurelinuxagent/distro/suse/loader.py +++ /dev/null @@ -1,29 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME - -def get_osutil(): - from azurelinuxagent.distro.suse.osutil import SUSE11OSUtil, SUSEOSUtil - if DISTRO_FULL_NAME=='SUSE Linux Enterprise Server' and DISTRO_VERSION < '12' \ - or DISTRO_FULL_NAME == 'openSUSE' and DISTRO_VERSION < '13.2': - return SUSE11OSUtil() - else: - return SUSEOSUtil() - diff --git a/azurelinuxagent/distro/ubuntu/deprovision.py b/azurelinuxagent/distro/ubuntu/deprovision.py index 0c3c4e5..da6e834 100644 --- a/azurelinuxagent/distro/ubuntu/deprovision.py +++ b/azurelinuxagent/distro/ubuntu/deprovision.py @@ -33,6 +33,9 @@ def del_resolv(): class UbuntuDeprovisionHandler(DeprovisionHandler): + def __init__(self, distro): + super(UbuntuDeprovisionHandler, self).__init__(distro) + def setup(self, deluser): warnings, actions = super(UbuntuDeprovisionHandler, self).setup(deluser) warnings.append("WARNING! Nameserver configuration in " diff --git a/azurelinuxagent/distro/ubuntu/distro.py b/azurelinuxagent/distro/ubuntu/distro.py new file mode 100644 index 0000000..f380f6c --- /dev/null +++ b/azurelinuxagent/distro/ubuntu/distro.py @@ -0,0 +1,55 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.distro.default.distro import DefaultDistro +from azurelinuxagent.distro.ubuntu.osutil import Ubuntu14OSUtil, \ + Ubuntu12OSUtil, \ + UbuntuOSUtil, \ + UbuntuSnappyOSUtil + +from azurelinuxagent.distro.ubuntu.provision import UbuntuProvisionHandler +from azurelinuxagent.distro.ubuntu.deprovision import UbuntuDeprovisionHandler + +class UbuntuDistro(DefaultDistro): + def __init__(self): + super(UbuntuDistro, self).__init__() + self.osutil = UbuntuOSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) + +class Ubuntu12Distro(DefaultDistro): + def __init__(self): + super(Ubuntu12Distro, self).__init__() + self.osutil = Ubuntu12OSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) + +class Ubuntu14Distro(DefaultDistro): + def __init__(self): + super(Ubuntu14Distro, self).__init__() + self.osutil = Ubuntu14OSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) + +class UbuntuSnappyDistro(DefaultDistro): + def __init__(self): + super(UbuntuSnappyDistro, self).__init__() + self.osutil = UbuntuSnappyOSUtil() + self.provision_handler = UbuntuProvisionHandler(self) + self.deprovision_handler = UbuntuDeprovisionHandler(self) diff --git a/azurelinuxagent/distro/ubuntu/loader.py b/azurelinuxagent/distro/ubuntu/loader.py deleted file mode 100644 index 3fe2239..0000000 --- a/azurelinuxagent/distro/ubuntu/loader.py +++ /dev/null @@ -1,40 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME - -def get_osutil(): - from azurelinuxagent.distro.ubuntu.osutil import Ubuntu1204OSUtil, \ - UbuntuOSUtil, \ - Ubuntu14xOSUtil, \ - UbuntuSnappyOSUtil - - if DISTRO_VERSION == "12.04": - return Ubuntu1204OSUtil() - elif DISTRO_VERSION == "14.04" or DISTRO_VERSION == "14.10": - return Ubuntu14xOSUtil() - elif DISTRO_FULL_NAME == "Snappy Ubuntu Core": - return UbuntuSnappyOSUtil() - else: - return UbuntuOSUtil() - -def get_handlers(): - from azurelinuxagent.distro.ubuntu.handlerFactory import UbuntuHandlerFactory - return UbuntuHandlerFactory() - diff --git a/azurelinuxagent/distro/ubuntu/osutil.py b/azurelinuxagent/distro/ubuntu/osutil.py index adf7660..cc4b8ef 100644 --- a/azurelinuxagent/distro/ubuntu/osutil.py +++ b/azurelinuxagent/distro/ubuntu/osutil.py @@ -31,9 +31,9 @@ import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.textutil as textutil from azurelinuxagent.distro.default.osutil import DefaultOSUtil -class Ubuntu14xOSUtil(DefaultOSUtil): +class Ubuntu14OSUtil(DefaultOSUtil): def __init__(self): - super(Ubuntu14xOSUtil, self).__init__() + super(Ubuntu14OSUtil, self).__init__() def start_network(self): return shellutil.run("service networking start", chk_err=False) @@ -44,16 +44,16 @@ class Ubuntu14xOSUtil(DefaultOSUtil): def start_agent_service(self): return shellutil.run("service walinuxagent start", chk_err=False) -class Ubuntu1204OSUtil(Ubuntu14xOSUtil): +class Ubuntu12OSUtil(Ubuntu14OSUtil): def __init__(self): - super(Ubuntu1204OSUtil, self).__init__() + super(Ubuntu12OSUtil, self).__init__() #Override def get_dhcp_pid(self): ret= shellutil.run_get_output("pidof dhclient3") return ret[1] if ret[0] == 0 else None -class UbuntuOSUtil(Ubuntu14xOSUtil): +class UbuntuOSUtil(Ubuntu14OSUtil): def __init__(self): super(UbuntuOSUtil, self).__init__() @@ -63,7 +63,7 @@ class UbuntuOSUtil(Ubuntu14xOSUtil): def unregister_agent_service(self): return shellutil.run("systemctl mask walinuxagent", chk_err=False) -class UbuntuSnappyOSUtil(Ubuntu14xOSUtil): +class UbuntuSnappyOSUtil(Ubuntu14OSUtil): def __init__(self): super(UbuntuSnappyOSUtil, self).__init__() self.conf_file_path = '/apps/walinuxagent/current/waagent.conf' diff --git a/azurelinuxagent/distro/ubuntu/provision.py b/azurelinuxagent/distro/ubuntu/provision.py index a68fe4d..330e057 100644 --- a/azurelinuxagent/distro/ubuntu/provision.py +++ b/azurelinuxagent/distro/ubuntu/provision.py @@ -20,12 +20,11 @@ import os import time import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.conf as conf -import azurelinuxagent.protocol as prot +import azurelinuxagent.protocol.ovfenv as ovfenv from azurelinuxagent.event import add_event, WALAEventOperation -from azurelinuxagent.exception import * -from azurelinuxagent.utils.osutil import OSUTIL +from azurelinuxagent.exception import ProvisionError, ProtocolError import azurelinuxagent.utils.shellutil as shellutil import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.distro.default.provision import ProvisionHandler @@ -34,49 +33,61 @@ from azurelinuxagent.distro.default.provision import ProvisionHandler On ubuntu image, provision could be disabled. """ class UbuntuProvisionHandler(ProvisionHandler): - def process(self): + def __init__(self, distro): + self.distro = distro + + def run(self): #If provision is enabled, run default provision handler - if conf.get_switch("Provisioning.Enabled", False): - super(UbuntuProvisionHandler, self).process() + if conf.get_provision_enabled(): + super(UbuntuProvisionHandler, self).run() return logger.info("run Ubuntu provision handler") - provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned") + provisioned = os.path.join(conf.get_lib_dir(), "provisioned") if os.path.isfile(provisioned): return - logger.info("Waiting cloud-init to finish provisioning.") - protocol = prot.FACTORY.get_default_protocol() + logger.info("Waiting cloud-init to copy ovf-env.xml.") + self.wait_for_ovfenv() + + protocol = self.distro.protocol_util.detect_protocol() + self.report_not_ready("Provisioning", "Starting") + logger.info("Sleep 15 seconds to prevent throttling") + time.sleep(15) #Sleep to prevent throttling try: logger.info("Wait for ssh host key to be generated.") thumbprint = self.wait_for_ssh_host_key() fileutil.write_file(provisioned, "") - logger.info("Finished provisioning") - status = prot.ProvisionStatus(status="Ready") - status.properties.certificateThumbprint = thumbprint - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) - + except ProvisionError as e: logger.error("Provision failed: {0}", e) - status = prot.ProvisionStatus(status="NotReady", - subStatus="ProvisioningFailed", - description= text(e)) - try: - protocol.report_provision_status(status) - except prot.ProtocolError as pe: - add_event(name="WALA", is_success=False, message=text(pe), - op=WALAEventOperation.Provision) + self.report_not_ready("ProvisioningFailed", ustr(e)) + self.report_event(ustr(e)) + return + + self.report_ready(thumbprint) + self.report_event("Provision succeed", is_success=True) - add_event(name="WALA", is_success=False, message=text(e), - op=WALAEventOperation.Provision) + def wait_for_ovfenv(self, max_retry=60): + """ + Wait for cloud-init to copy ovf-env.xml file from provision ISO + """ + for retry in range(0, max_retry): + try: + self.distro.protocol_util.get_ovf_env() + return + except ProtocolError: + if retry < max_retry - 1: + logger.info("Wait for cloud-init to copy ovf-env.xml") + time.sleep(5) + raise ProvisionError("ovf-env.xml is not copied") def wait_for_ssh_host_key(self, max_retry=60): - kepair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa") + """ + Wait for cloud-init to generate ssh host key + """ + kepair_type = conf.get_ssh_host_keypair_type() path = '/etc/ssh/ssh_host_{0}_key'.format(kepair_type) for retry in range(0, max_retry): if os.path.isfile(path): diff --git a/azurelinuxagent/event.py b/azurelinuxagent/event.py index 02e8017..f38b242 100644 --- a/azurelinuxagent/event.py +++ b/azurelinuxagent/event.py @@ -25,14 +25,15 @@ import datetime import threading import platform import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.protocol as prot +from azurelinuxagent.exception import EventError, ProtocolError +from azurelinuxagent.future import ustr +from azurelinuxagent.protocol.restapi import TelemetryEventParam, \ + TelemetryEventList, \ + TelemetryEvent, \ + set_properties, get_properties from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, \ DISTRO_CODE_NAME, AGENT_VERSION -from azurelinuxagent.utils.osutil import OSUTIL -class EventError(Exception): - pass class WALAEventOperation: HeartBeat="HeartBeat" @@ -47,132 +48,65 @@ class WALAEventOperation: ActivateResourceDisk="ActivateResourceDisk" UnhandledError="UnhandledError" -class EventMonitor(object): +class EventLogger(object): def __init__(self): - self.sysinfo = [] - self.event_dir = os.path.join(OSUTIL.get_lib_dir(), "events") + self.event_dir = None - def init_sysinfo(self): - osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(), - DISTRO_NAME, - DISTRO_VERSION, - DISTRO_CODE_NAME, - platform.release()) + def save_event(self, data): + if self.event_dir is None: + logger.warn("Event reporter is not initialized.") + return - self.sysinfo.append(prot.TelemetryEventParam("OSVersion", osversion)) - self.sysinfo.append(prot.TelemetryEventParam("GAVersion", - AGENT_VERSION)) - self.sysinfo.append(prot.TelemetryEventParam("RAM", - OSUTIL.get_total_mem())) - self.sysinfo.append(prot.TelemetryEventParam("Processors", - OSUTIL.get_processor_cores())) - try: - protocol = prot.FACTORY.get_default_protocol() - vminfo = protocol.get_vminfo() - self.sysinfo.append(prot.TelemetryEventParam("VMName", - vminfo.vmName)) - #TODO add other system info like, subscription id, etc. - except prot.ProtocolError as e: - logger.warn("Failed to get vm info: {0}", e) - - def start(self): - event_thread = threading.Thread(target = self.run) - event_thread.setDaemon(True) - event_thread.start() + if not os.path.exists(self.event_dir): + os.mkdir(self.event_dir) + os.chmod(self.event_dir, 0o700) + if len(os.listdir(self.event_dir)) > 1000: + raise EventError("Too many files under: {0}".format(self.event_dir)) - def collect_event(self, evt_file_name): + filename = os.path.join(self.event_dir, ustr(int(time.time()*1000000))) try: - logger.verb("Found event file: {0}", evt_file_name) - with open(evt_file_name, "rb") as evt_file: - #if fail to open or delete the file, throw exception - json_str = evt_file.read().decode("utf-8",'ignore') - logger.verb("Processed event file: {0}", evt_file_name) - os.remove(evt_file_name) - return json_str + with open(filename+".tmp",'wb+') as hfile: + hfile.write(data.encode("utf-8")) + os.rename(filename+".tmp", filename+".tld") except IOError as e: - msg = "Failed to process {0}, {1}".format(evt_file_name, e) - raise EventError(msg) - - def collect_and_send_events(self): - event_list = prot.TelemetryEventList() - event_files = os.listdir(self.event_dir) - for event_file in event_files: - if not event_file.endswith(".tld"): - continue - event_file_path = os.path.join(self.event_dir, event_file) - try: - data_str = self.collect_event(event_file_path) - except EventError as e: - logger.error("{0}", e) - continue - try: - data = json.loads(data_str) - except ValueError as e: - logger.verb(data_str) - logger.error("Failed to decode json event file: {0}", e) - continue - - event = prot.TelemetryEvent() - prot.set_properties("event", event, data) - event.parameters.extend(self.sysinfo) - event_list.events.append(event) - if len(event_list.events) == 0: - return - + raise EventError("Failed to write events to file:{0}", e) + + def add_event(self, name, op="", is_success=True, duration=0, version="1.0", + message="", evt_type="", is_internal=False): + event = TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") + event.parameters.append(TelemetryEventParam('Name', name)) + event.parameters.append(TelemetryEventParam('Version', version)) + event.parameters.append(TelemetryEventParam('IsInternal', is_internal)) + event.parameters.append(TelemetryEventParam('Operation', op)) + event.parameters.append(TelemetryEventParam('OperationSuccess', + is_success)) + event.parameters.append(TelemetryEventParam('Message', message)) + event.parameters.append(TelemetryEventParam('Duration', duration)) + event.parameters.append(TelemetryEventParam('ExtensionType', evt_type)) + + data = get_properties(event) try: - protocol = prot.FACTORY.get_default_protocol() - protocol.report_event(event_list) - except prot.ProtocolError as e: + self.save_event(json.dumps(data)) + except EventError as e: logger.error("{0}", e) - def run(self): - self.init_sysinfo() - last_heartbeat = datetime.datetime.min - period = datetime.timedelta(hours = 12) - while(True): - if (datetime.datetime.now()-last_heartbeat) > period: - last_heartbeat = datetime.datetime.now() - add_event(op=WALAEventOperation.HeartBeat, - name="WALA",is_success=True) - self.collect_and_send_events() - time.sleep(60) - -def save_event(data): - event_dir = os.path.join(OSUTIL.get_lib_dir(), 'events') - if not os.path.exists(event_dir): - os.mkdir(event_dir) - os.chmod(event_dir,0o700) - if len(os.listdir(event_dir)) > 1000: - raise EventError("Too many files under: {0}", event_dir) - - filename = os.path.join(event_dir, text(int(time.time()*1000000))) - try: - with open(filename+".tmp",'wb+') as hfile: - hfile.write(data.encode("utf-8")) - os.rename(filename+".tmp", filename+".tld") - except IOError as e: - raise EventError("Failed to write events to file:{0}", e) +__event_logger__ = EventLogger() def add_event(name, op="", is_success=True, duration=0, version="1.0", - message="", evt_type="", is_internal=False): + message="", evt_type="", is_internal=False, + reporter=__event_logger__): log = logger.info if is_success else logger.error log("Event: name={0}, op={1}, message={2}", name, op, message) - event = prot.TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") - event.parameters.append(prot.TelemetryEventParam('Name', name)) - event.parameters.append(prot.TelemetryEventParam('Version', version)) - event.parameters.append(prot.TelemetryEventParam('IsInternal', is_internal)) - event.parameters.append(prot.TelemetryEventParam('Operation', op)) - event.parameters.append(prot.TelemetryEventParam('OperationSuccess', - is_success)) - event.parameters.append(prot.TelemetryEventParam('Message', message)) - event.parameters.append(prot.TelemetryEventParam('Duration', duration)) - event.parameters.append(prot.TelemetryEventParam('ExtensionType', evt_type)) - data = prot.get_properties(event) - try: - save_event(json.dumps(data)) - except EventError as e: - logger.error("{0}", e) + if reporter.event_dir is None: + logger.warn("Event reporter is not initialized.") + return + reporter.add_event(name, op=op, is_success=is_success, duration=duration, + version=version, message=message, evt_type=evt_type, + is_internal=is_internal) + +def init_event_logger(event_dir, reporter=__event_logger__): + reporter.event_dir = event_dir def dump_unhandled_err(name): if hasattr(sys, 'last_type') and hasattr(sys, 'last_value') and \ @@ -184,8 +118,7 @@ def dump_unhandled_err(name): last_traceback) message= "".join(error) add_event(name, is_success=False, message=message, - op=WALAEventOperation.UnhandledError) + op=WALAEventOperation.UnhandledError) def enable_unhandled_err_dump(name): atexit.register(dump_unhandled_err, name) - diff --git a/azurelinuxagent/exception.py b/azurelinuxagent/exception.py index d7d9b0a..7fa5cff 100644 --- a/azurelinuxagent/exception.py +++ b/azurelinuxagent/exception.py @@ -24,42 +24,92 @@ class AgentError(Exception): """ Base class of agent error. """ - def __init__(self, errno, msg): - msg = "({0}){1}".format(errno, msg) + def __init__(self, errno, msg, inner=None): + msg = u"({0}){1}".format(errno, msg) + if inner is not None: + msg = u"{0} \n inner error: {1}".format(msg, inner) super(AgentError, self).__init__(msg) class AgentConfigError(AgentError): """ When configure file is not found or malformed. """ - def __init__(self, msg): - super(AgentConfigError, self).__init__('000001', msg) + def __init__(self, msg=None, inner=None): + super(AgentConfigError, self).__init__('000001', msg, inner) class AgentNetworkError(AgentError): """ When network is not avaiable. """ - def __init__(self, msg): - super(AgentNetworkError, self).__init__('000002', msg) + def __init__(self, msg=None, inner=None): + super(AgentNetworkError, self).__init__('000002', msg, inner) class ExtensionError(AgentError): """ When failed to execute an extension """ - def __init__(self, msg): - super(ExtensionError, self).__init__('000003', msg) + def __init__(self, msg=None, inner=None): + super(ExtensionError, self).__init__('000003', msg, inner) class ProvisionError(AgentError): """ When provision failed """ - def __init__(self, msg): - super(ProvisionError, self).__init__('000004', msg) + def __init__(self, msg=None, inner=None): + super(ProvisionError, self).__init__('000004', msg, inner) class ResourceDiskError(AgentError): """ Mount resource disk failed """ - def __init__(self, msg): - super(ResourceDiskError, self).__init__('000005', msg) + def __init__(self, msg=None, inner=None): + super(ResourceDiskError, self).__init__('000005', msg, inner) +class DhcpError(AgentError): + """ + Failed to handle dhcp response + """ + def __init__(self, msg=None, inner=None): + super(DhcpError, self).__init__('000006', msg, inner) + +class OSUtilError(AgentError): + """ + Failed to perform operation to OS configuration + """ + def __init__(self, msg=None, inner=None): + super(OSUtilError, self).__init__('000007', msg, inner) + +class ProtocolError(AgentError): + """ + Azure protocol error + """ + def __init__(self, msg=None, inner=None): + super(ProtocolError, self).__init__('000008', msg, inner) + +class ProtocolNotFoundError(ProtocolError): + """ + Azure protocol endpoint not found + """ + def __init__(self, msg=None, inner=None): + super(ProtocolNotFoundError, self).__init__(msg, inner) + +class HttpError(AgentError): + """ + Http request failure + """ + def __init__(self, msg=None, inner=None): + super(HttpError, self).__init__('000009', msg, inner) + +class EventError(AgentError): + """ + Event reporting error + """ + def __init__(self, msg=None, inner=None): + super(EventError, self).__init__('000010', msg, inner) + +class CryptError(AgentError): + """ + Encrypt/Decrypt error + """ + def __init__(self, msg=None, inner=None): + super(CryptError, self).__init__('000011', msg, inner) diff --git a/azurelinuxagent/future.py b/azurelinuxagent/future.py index 8451345..8509732 100644 --- a/azurelinuxagent/future.py +++ b/azurelinuxagent/future.py @@ -7,15 +7,25 @@ Add alies for python2 and python3 libs and fucntions. if sys.version_info[0]== 3: import http.client as httpclient from urllib.parse import urlparse - text = str + + """Rename Python3 str to ustr""" + ustr = str + bytebuffer = memoryview + read_input = input + elif sys.version_info[0] == 2: import httplib as httpclient from urlparse import urlparse - text = unicode + + """Rename Python2 unicode to ustr""" + ustr = unicode + bytebuffer = buffer + read_input = raw_input + else: raise ImportError("Unknown python version:{0}".format(sys.version_info)) diff --git a/azurelinuxagent/handler.py b/azurelinuxagent/handler.py deleted file mode 100644 index 538ee30..0000000 --- a/azurelinuxagent/handler.py +++ /dev/null @@ -1,28 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -""" -Handler handles different tasks like, provisioning, deprovisioning etc. -The handlers could be extended for different distros. The default -implementation is under azurelinuxagent.distros.default -""" -import azurelinuxagent.distro.loader as loader - -HANDLERS = loader.get_handlers() - diff --git a/azurelinuxagent/logger.py b/azurelinuxagent/logger.py index 21c02a6..6c6b406 100644 --- a/azurelinuxagent/logger.py +++ b/azurelinuxagent/logger.py @@ -22,7 +22,7 @@ Log utils """ import os import sys -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr from datetime import datetime class Logger(object): @@ -49,8 +49,8 @@ class Logger(object): def log(self, level, msg_format, *args): #if msg_format is not unicode convert it to unicode - if type(msg_format) is not text: - msg_format = text(msg_format, errors="backslashreplace") + if type(msg_format) is not ustr: + msg_format = ustr(msg_format, errors="backslashreplace") if len(args) > 0: msg = msg_format.format(*args) else: @@ -63,7 +63,7 @@ class Logger(object): else: log_item = u"{0} {1} {2}\n".format(time, level_str, msg) - log_item = text(log_item.encode('ascii', "backslashreplace"), + log_item = ustr(log_item.encode('ascii', "backslashreplace"), encoding="ascii") for appender in self.appenders: appender.write(level, log_item) diff --git a/azurelinuxagent/metadata.py b/azurelinuxagent/metadata.py index 5cf4902..34fdcf9 100644 --- a/azurelinuxagent/metadata.py +++ b/azurelinuxagent/metadata.py @@ -22,11 +22,11 @@ import re import platform import sys import azurelinuxagent.utils.fileutil as fileutil -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr def get_distro(): if 'FreeBSD' in platform.system(): - release = re.sub('\-.*\Z', '', text(platform.release())) + release = re.sub('\-.*\Z', '', ustr(platform.release())) osinfo = ['freebsd', release, '', 'freebsd'] if 'linux_distribution' in dir(platform): osinfo = list(platform.linux_distribution(full_distribution_name=0)) @@ -47,7 +47,7 @@ def get_distro(): AGENT_NAME = "WALinuxAgent" AGENT_LONG_NAME = "Azure Linux Agent" -AGENT_VERSION = '2.1.2' +AGENT_VERSION = '2.1.3' AGENT_LONG_VERSION = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION) AGENT_DESCRIPTION = """\ The Azure Linux Agent supports the provisioning and running of Linux diff --git a/azurelinuxagent/protocol/__init__.py b/azurelinuxagent/protocol/__init__.py index a4572e6..8c1bbdb 100644 --- a/azurelinuxagent/protocol/__init__.py +++ b/azurelinuxagent/protocol/__init__.py @@ -16,8 +16,3 @@ # # Requires Python 2.4+ and Openssl 1.0+ # - -from azurelinuxagent.protocol.common import * -from azurelinuxagent.protocol.protocolFactory import FACTORY, \ - detect_default_protocol - diff --git a/azurelinuxagent/protocol/v2.py b/azurelinuxagent/protocol/metadata.py index 34102b7..8a1656f 100644 --- a/azurelinuxagent/protocol/v2.py +++ b/azurelinuxagent/protocol/metadata.py @@ -17,16 +17,30 @@ # Requires Python 2.4+ and Openssl 1.0+ import json -from azurelinuxagent.future import httpclient, text +import shutil +import os +import time +from azurelinuxagent.exception import ProtocolError, HttpError +from azurelinuxagent.future import httpclient, ustr +import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger import azurelinuxagent.utils.restutil as restutil -from azurelinuxagent.protocol.common import * +import azurelinuxagent.utils.textutil as textutil +import azurelinuxagent.utils.fileutil as fileutil +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.protocol.restapi import * -ENDPOINT='169.254.169.254' -#TODO use http for azure pack test -#ENDPOINT='localhost' +METADATA_ENDPOINT='169.254.169.254' APIVERSION='2015-05-01-preview' BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" +TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem" +TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem" + +#TODO remote workarround for azure stack +MAX_PING = 30 +RETRY_PING_INTERVAL = 10 + def _add_content_type(headers): if headers is None: headers = {} @@ -35,7 +49,7 @@ def _add_content_type(headers): class MetadataProtocol(Protocol): - def __init__(self, apiversion=APIVERSION, endpoint=ENDPOINT): + def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT): self.apiversion = apiversion self.endpoint = endpoint self.identity_uri = BASE_URI.format(self.endpoint, "identity", @@ -58,24 +72,25 @@ class MetadataProtocol(Protocol): def _get_data(self, url, headers=None): try: resp = restutil.http_get(url, headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status != httpclient.OK: raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) data = resp.read() + etag = resp.getheader('ETag') if data is None: return None - data = json.loads(text(data, encoding="utf-8")) - return data + data = json.loads(ustr(data, encoding="utf-8")) + return data, etag def _put_data(self, url, data, headers=None): headers = _add_content_type(headers) try: resp = restutil.http_put(url, json.dumps(data), headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status != httpclient.OK: raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) @@ -83,17 +98,41 @@ class MetadataProtocol(Protocol): headers = _add_content_type(headers) try: resp = restutil.http_post(url, json.dumps(data), headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status != httpclient.CREATED: raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) + + def _get_trans_cert(self): + trans_crt_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + if not os.path.isfile(trans_crt_file): + raise ProtocolError("{0} is missing.".format(trans_crt_file)) + content = fileutil.read_file(trans_crt_file) + return textutil.get_bytes_from_pem(content) + + def detect(self): + self.get_vminfo() + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + #"Install" the cert and private key to /var/lib/waagent + thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file) + prv_file = os.path.join(conf.get_lib_dir(), + "{0}.prv".format(thumbprint)) + crt_file = os.path.join(conf.get_lib_dir(), + "{0}.crt".format(thumbprint)) + shutil.copyfile(trans_prv_file, prv_file) + shutil.copyfile(trans_cert_file, crt_file) - def initialize(self): - pass def get_vminfo(self): vminfo = VMInfo() - data = self._get_data(self.identity_uri) + data, etag = self._get_data(self.identity_uri) set_properties("vminfo", vminfo, data) return vminfo @@ -102,17 +141,20 @@ class MetadataProtocol(Protocol): return CertList() def get_ext_handlers(self): + headers = { + "x-ms-vmagent-public-x509-cert": self._get_trans_cert() + } ext_list = ExtHandlerList() - data = self._get_data(self.ext_uri) + data, etag = self._get_data(self.ext_uri, headers=headers) set_properties("extensionHandlers", ext_list.extHandlers, data) - return ext_list + return ext_list, etag def get_ext_handler_pkgs(self, ext_handler): ext_handler_pkgs = ExtHandlerPackageList() data = None for version_uri in ext_handler.versionUris: try: - data = self._get_data(version_uri.uri) + data, etag = self._get_data(version_uri.uri) break except ProtocolError as e: logger.warn("Failed to get version uris: {0}", e) @@ -128,6 +170,14 @@ class MetadataProtocol(Protocol): def report_vm_status(self, vm_status): validata_param('vmStatus', vm_status, VMStatus) data = get_properties(vm_status) + #TODO code field is not implemented for metadata protocol yet. Remove it + handler_statuses = data['vmAgent']['extensionHandlers'] + for handler_status in handler_statuses: + try: + handler_status.pop('code', None) + except KeyError: + pass + self._put_data(self.vm_status_uri, data) def report_ext_status(self, ext_handler_name, ext_name, ext_status): diff --git a/azurelinuxagent/protocol/ovfenv.py b/azurelinuxagent/protocol/ovfenv.py index 9c845ee..de6791c 100644 --- a/azurelinuxagent/protocol/ovfenv.py +++ b/azurelinuxagent/protocol/ovfenv.py @@ -17,60 +17,22 @@ # Requires Python 2.4+ and Openssl 1.0+ # """ -Copy and parse ovf-env.xml from provisiong ISO and local cache +Copy and parse ovf-env.xml from provisioning ISO and local cache """ import os import re +import shutil import xml.dom.minidom as minidom import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.exception import ProtocolError +from azurelinuxagent.future import ustr import azurelinuxagent.utils.fileutil as fileutil from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext -from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError -from azurelinuxagent.protocol import ProtocolError -OVF_FILE_NAME = "ovf-env.xml" OVF_VERSION = "1.0" OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1" WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure" -def get_ovf_env(): - """ - Load saved ovf-env.xml - """ - ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME) - if os.path.isfile(ovf_file_path): - xml_text = fileutil.read_file(ovf_file_path) - return OvfEnv(xml_text) - else: - raise ProtocolError("ovf-env.xml is missing.") - -def copy_ovf_env(): - """ - Copy ovf env file from dvd to hard disk. - Remove password before save it to the disk - """ - try: - OSUTIL.mount_dvd() - ovf_file_path_on_dvd = OSUTIL.get_ovf_env_file_path_on_dvd() - ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) - ovfenv = OvfEnv(ovfxml) - ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml) - ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME) - fileutil.write_file(ovf_file_path, ovfxml) - except IOError as e: - raise ProtocolError(text(e)) - except OSUtilError as e: - raise ProtocolError(text(e)) - - try: - OSUTIL.umount_dvd() - OSUTIL.eject_dvd() - except OSUtilError as e: - logger.warn(text(e)) - - return ovfenv - def _validate_ovf(val, msg): if val is None: raise ProtocolError("Failed to parse OVF XML: {0}".format(msg)) diff --git a/azurelinuxagent/protocol/protocolFactory.py b/azurelinuxagent/protocol/protocolFactory.py deleted file mode 100644 index 0bf6e52..0000000 --- a/azurelinuxagent/protocol/protocolFactory.py +++ /dev/null @@ -1,114 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# -import os -import traceback -import threading -import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.utils.fileutil as fileutil -from azurelinuxagent.utils.osutil import OSUTIL -from azurelinuxagent.protocol.common import * -from azurelinuxagent.protocol.v1 import WireProtocol -from azurelinuxagent.protocol.v2 import MetadataProtocol - -WIRE_SERVER_ADDR_FILE_NAME = "WireServer" - -def get_wire_protocol_endpoint(): - path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME) - try: - endpoint = fileutil.read_file(path) - except IOError as e: - raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e)) - - if endpoint is None: - raise ProtocolNotFound("Wire server endpoint is None") - - return endpoint - -def detect_wire_protocol(): - endpoint = get_wire_protocol_endpoint() - - OSUTIL.gen_transport_cert() - protocol = WireProtocol(endpoint) - protocol.initialize() - logger.info("Protocol V1 found.") - return protocol - -def detect_metadata_protocol(): - protocol = MetadataProtocol() - protocol.initialize() - - logger.info("Protocol V2 found.") - return protocol - -def detect_available_protocols(prob_funcs=[detect_wire_protocol, - detect_metadata_protocol]): - available_protocols = [] - for probe_func in prob_funcs: - try: - protocol = probe_func() - available_protocols.append(protocol) - except ProtocolNotFound as e: - logger.info(text(e)) - return available_protocols - -def detect_default_protocol(): - logger.info("Detect default protocol.") - available_protocols = detect_available_protocols() - return choose_default_protocol(available_protocols) - -def choose_default_protocol(protocols): - if len(protocols) > 0: - return protocols[0] - else: - raise ProtocolNotFound("No available protocol detected.") - -def get_wire_protocol(): - endpoint = get_wire_protocol_endpoint() - return WireProtocol(endpoint) - -def get_metadata_protocol(): - return MetadataProtocol() - -def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]): - available_protocols = [] - for getter in getters: - try: - protocol = getter() - available_protocols.append(protocol) - except ProtocolNotFound as e: - logger.info(text(e)) - return available_protocols - -class ProtocolFactory(object): - def __init__(self): - self._protocol = None - self._lock = threading.Lock() - - def get_default_protocol(self): - if self._protocol is None: - self._lock.acquire() - if self._protocol is None: - available_protocols = get_available_protocols() - self._protocol = choose_default_protocol(available_protocols) - self._lock.release() - - return self._protocol - -FACTORY = ProtocolFactory() diff --git a/azurelinuxagent/protocol/common.py b/azurelinuxagent/protocol/restapi.py index 367794f..fbd29ed 100644 --- a/azurelinuxagent/protocol/common.py +++ b/azurelinuxagent/protocol/restapi.py @@ -22,14 +22,9 @@ import re import json import xml.dom.minidom import azurelinuxagent.logger as logger -from azurelinuxagent.future import text -import azurelinuxagent.utils.fileutil as fileutil - -class ProtocolError(Exception): - pass - -class ProtocolNotFound(Exception): - pass +from azurelinuxagent.exception import ProtocolError, HttpError +from azurelinuxagent.future import ustr +import azurelinuxagent.utils.restutil as restutil def validata_param(name, val, expected_type): if val is None: @@ -88,9 +83,14 @@ class DataContractList(list): Data contract between guest and host """ class VMInfo(DataContract): - def __init__(self, subscriptionId=None, vmName=None): + def __init__(self, subscriptionId=None, vmName=None, containerId=None, + roleName=None, roleInstanceName=None, tenantName=None): self.subscriptionId = subscriptionId self.vmName = vmName + self.containerId = containerId + self.roleName = roleName + self.roleInstanceName = roleInstanceName + self.tenantName = tenantName class Cert(DataContract): def __init__(self, name=None, thumbprint=None, certificateDataUri=None): @@ -104,11 +104,11 @@ class CertList(DataContract): class Extension(DataContract): def __init__(self, name=None, sequenceNumber=None, publicSettings=None, - privateSettings=None, certificateThumbprint=None): + protectedSettings=None, certificateThumbprint=None): self.name = name self.sequenceNumber = sequenceNumber self.publicSettings = publicSettings - self.privateSettings = privateSettings + self.protectedSettings = protectedSettings self.certificateThumbprint = certificateThumbprint class ExtHandlerProperties(DataContract): @@ -176,12 +176,14 @@ class ExtensionStatus(DataContract): self.substatusList = DataContractList(ExtensionSubStatus) class ExtHandlerStatus(DataContract): - def __init__(self, name=None, version=None, status=None, message=None): + def __init__(self, name=None, version=None, status=None, code=0, + message=None): self.name = name self.version = version self.status = status + self.code = code self.message = message - self.extensions = DataContractList(text) + self.extensions = DataContractList(ustr) class VMAgentStatus(DataContract): def __init__(self, version=None, status=None, message=None): @@ -211,7 +213,7 @@ class TelemetryEventList(DataContract): class Protocol(DataContract): - def initialize(self): + def detect(self): raise NotImplementedError() def get_vminfo(self): @@ -226,6 +228,14 @@ class Protocol(DataContract): def get_ext_handler_pkgs(self, extension): raise NotImplementedError() + def download_ext_handler_pkg(self, uri): + try: + resp = restutil.http_get(uri, chk_proxy=True) + if resp.status == restutil.httpclient.OK: + return resp.read() + except HttpError as e: + raise ProtocolError("Failed to download from: {0}".format(uri), e) + def report_provision_status(self, provision_status): raise NotImplementedError() diff --git a/azurelinuxagent/protocol/v1.py b/azurelinuxagent/protocol/wire.py index 92fcc06..7b5ffe8 100644 --- a/azurelinuxagent/protocol/v1.py +++ b/azurelinuxagent/protocol/wire.py @@ -22,16 +22,19 @@ import re import time import traceback import xml.sax.saxutils as saxutils -import xml.etree.ElementTree as ET +import azurelinuxagent.conf as conf import azurelinuxagent.logger as logger -from azurelinuxagent.future import text, httpclient, bytebuffer +from azurelinuxagent.exception import ProtocolError, HttpError, \ + ProtocolNotFoundError +from azurelinuxagent.future import ustr, httpclient, bytebuffer import azurelinuxagent.utils.restutil as restutil from azurelinuxagent.utils.textutil import parse_doc, findall, find, findtext, \ - getattrib, gettext, remove_bom -from azurelinuxagent.utils.osutil import OSUTIL + getattrib, gettext, remove_bom, \ + get_bytes_from_pem import azurelinuxagent.utils.fileutil as fileutil import azurelinuxagent.utils.shellutil as shellutil -from azurelinuxagent.protocol.common import * +from azurelinuxagent.utils.cryptutil import CryptUtil +from azurelinuxagent.protocol.restapi import * VERSION_INFO_URI = "http://{0}/?comp=versions" GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate" @@ -53,6 +56,7 @@ TRANSPORT_CERT_FILE_NAME = "TransportCert.pem" TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" PROTOCOL_VERSION = "2012-11-30" +ENDPOINT_FINE_NAME = "WireServer" SHORT_WAITING_INTERVAL = 1 # 1 second LONG_WAITING_INTERVAL = 15 # 15 seconds @@ -61,19 +65,37 @@ class WireProtocolResourceGone(ProtocolError): pass class WireProtocol(Protocol): + """Slim layer to adapte wire protocol data to metadata protocol interface""" def __init__(self, endpoint): - self.client = WireClient(endpoint) + if endpoint is None: + raise ProtocolError("WireProtocl endpoint is None") + self.endpoint = endpoint + self.client = WireClient(self.endpoint) - def initialize(self): + def detect(self): self.client.check_wire_protocol_version() + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + self.client.update_goal_state(forced=True) def get_vminfo(self): + goal_state = self.client.get_goal_state() hosting_env = self.client.get_hosting_env() + vminfo = VMInfo() vminfo.subscriptionId = None vminfo.vmName = hosting_env.vm_name + vminfo.tenantName = hosting_env.deployment_name + vminfo.roleName = hosting_env.role_name + vminfo.roleInstanceName = goal_state.role_instance_id + vminfo.containerId = goal_state.container_id return vminfo def get_certs(self): @@ -81,12 +103,16 @@ class WireProtocol(Protocol): return certificates.cert_list def get_ext_handlers(self): + logger.verb("Get extension handler config") #Update goal state to get latest extensions config self.client.update_goal_state() + goal_state = self.client.get_goal_state() ext_conf = self.client.get_ext_conf() - return ext_conf.ext_handlers + #In wire protocol, incarnation is equivalent to ETag + return ext_conf.ext_handlers, goal_state.incarnation def get_ext_handler_pkgs(self, ext_handler): + logger.verb("Get extension handler package") goal_state = self.client.get_goal_state() man = self.client.get_ext_manifest(ext_handler, goal_state) return man.pkg_list @@ -134,12 +160,12 @@ def _build_role_properties(container_id, role_instance_id, thumbprint): return xml def _build_health_report(incarnation, container_id, role_instance_id, - status, substatus, description): + status, substatus, description): #Escape '&', '<' and '>' - description = saxutils.escape(text(description)) + description = saxutils.escape(ustr(description)) detail = u'' if substatus is not None: - substatus = saxutils.escape(text(substatus)) + substatus = saxutils.escape(ustr(substatus)) detail = (u"<Details>" u"<SubStatus>{0}</SubStatus>" u"<Description>{1}</Description>" @@ -228,6 +254,7 @@ def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): 'handlerVersion' : handler_status.version, 'handlerName' : handler_status.name, 'status' : handler_status.status, + 'code': handler_status.code } if handler_status.message is not None: v1_handler_status["formattedMessage"] = { @@ -303,7 +330,7 @@ class StatusBlob(object): self.put_page_blob(url, data) else: raise ProtocolError("Unknown blob type: {0}".format(blob_type)) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError("Failed to upload status blob: {0}".format(e)) def get_blob_type(self, url): @@ -315,7 +342,7 @@ class StatusBlob(object): "x-ms-date" : timestamp, 'x-ms-version' : self.__class__.__storage_version__ }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to get status blob type: {0}" u"").format(e)) if resp is None or resp.status != httpclient.OK: @@ -334,10 +361,10 @@ class StatusBlob(object): data, { "x-ms-date" : timestamp, "x-ms-blob-type" : "BlockBlob", - "Content-Length": text(len(data)), + "Content-Length": ustr(len(data)), "x-ms-version" : self.__class__.__storage_version__ }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to upload block blob: {0}" u"").format(e)) if resp.status != httpclient.CREATED: @@ -359,10 +386,10 @@ class StatusBlob(object): "x-ms-date" : timestamp, "x-ms-blob-type" : "PageBlob", "Content-Length": "0", - "x-ms-blob-content-length" : text(page_blob_size), + "x-ms-blob-content-length" : ustr(page_blob_size), "x-ms-version" : self.__class__.__storage_version__ }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to clean up page blob: {0}" u"").format(e)) if resp.status != httpclient.CREATED: @@ -393,9 +420,9 @@ class StatusBlob(object): "x-ms-range" : "bytes={0}-{1}".format(start, page_end - 1), "x-ms-page-write" : "update", "x-ms-version" : self.__class__.__storage_version__, - "Content-Length": text(page_end - start) + "Content-Length": ustr(page_end - start) }) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to upload page blob: {0}" u"").format(e)) if resp is None or resp.status != httpclient.CREATED: @@ -411,13 +438,13 @@ def event_param_to_v1(param): attr_type = 'mt:uint64' elif param_type is str: attr_type = 'mt:wstr' - elif text(param_type).count("'unicode'") > 0: + elif ustr(param_type).count("'unicode'") > 0: attr_type = 'mt:wstr' elif param_type is bool: attr_type = 'mt:bool' elif param_type is float: attr_type = 'mt:float64' - return param_format.format(param.name, saxutils.quoteattr(text(param.value)), + return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)), attr_type) def event_to_v1(event): @@ -431,6 +458,7 @@ def event_to_v1(event): class WireClient(object): def __init__(self, endpoint): + logger.info("Wire server endpoint:{0}", endpoint) self.endpoint = endpoint self.goal_state = None self.updated = None @@ -448,15 +476,15 @@ class WireClient(object): """ now = time.time() if now - self.last_request < 1: - logger.info("Last request issued less than 1 second ago") - logger.info("Sleep {0} second to avoid throttling.", + logger.verb("Last request issued less than 1 second ago") + logger.verb("Sleep {0} second to avoid throttling.", SHORT_WAITING_INTERVAL) time.sleep(SHORT_WAITING_INTERVAL) self.last_request = now self.req_count += 1 if self.req_count % 3 == 0: - logger.info("Sleep {0} second to avoid throttling.", + logger.verb("Sleep {0} second to avoid throttling.", SHORT_WAITING_INTERVAL) time.sleep(SHORT_WAITING_INTERVAL) self.req_count = 0 @@ -485,15 +513,15 @@ class WireClient(object): if data is None: return None data = remove_bom(data) - xml_text = text(data, encoding='utf-8') + xml_text = ustr(data, encoding='utf-8') return xml_text def fetch_config(self, uri, headers): try: resp = self.call_wireserver(restutil.http_get, uri, headers=headers) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if(resp.status != httpclient.OK): raise ProtocolError("{0} - {1}".format(resp.status, uri)) @@ -532,12 +560,13 @@ class WireClient(object): def fetch_manifest(self, version_uris): for version_uri in version_uris: + logger.verb("Fetch ext handler manifest: {0}", version_uri.uri) try: resp = self.call_storage_service(restutil.http_get, version_uri.uri, None, chk_proxy=True) - except restutil.HttpError as e: - raise ProtocolError(text(e)) + except HttpError as e: + raise ProtocolError(ustr(e)) if resp.status == httpclient.OK: return self.decode_config(resp.read()) @@ -553,7 +582,7 @@ class WireClient(object): def update_hosting_env(self, goal_state): if goal_state.hosting_env_uri is None: raise ProtocolError("HostingEnvironmentConfig uri is empty") - local_file = HOSTING_ENV_FILE_NAME + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) xml_text = self.fetch_config(goal_state.hosting_env_uri, self.get_header()) self.save_cache(local_file, xml_text) @@ -562,7 +591,7 @@ class WireClient(object): def update_shared_conf(self, goal_state): if goal_state.shared_conf_uri is None: raise ProtocolError("SharedConfig uri is empty") - local_file = SHARED_CONF_FILE_NAME + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) xml_text = self.fetch_config(goal_state.shared_conf_uri, self.get_header()) self.save_cache(local_file, xml_text) @@ -571,7 +600,7 @@ class WireClient(object): def update_certs(self, goal_state): if goal_state.certs_uri is None: return - local_file = CERTS_FILE_NAME + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) xml_text = self.fetch_config(goal_state.certs_uri, self.get_header_for_cert()) self.save_cache(local_file, xml_text) @@ -583,25 +612,18 @@ class WireClient(object): self.ext_conf = ExtensionsConfig(None) return incarnation = goal_state.incarnation - local_file = EXT_CONF_FILE_NAME.format(incarnation) + local_file = os.path.join(conf.get_lib_dir(), + EXT_CONF_FILE_NAME.format(incarnation)) xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) self.save_cache(local_file, xml_text) self.ext_conf = ExtensionsConfig(xml_text) - for ext_handler in self.ext_conf.ext_handlers.extHandlers: - self.update_ext_handler_manifest(ext_handler, goal_state) - - def update_ext_handler_manifest(self, ext_handler, goal_state): - local_file = MANIFEST_FILE_NAME.format(ext_handler.name, - goal_state.incarnation) - xml_text = self.fetch_manifest(ext_handler.versionUris) - self.save_cache(local_file, xml_text) - + def update_goal_state(self, forced=False, max_retry=3): uri = GOAL_STATE_URI.format(self.endpoint) xml_text = self.fetch_config(uri, self.get_header()) goal_state = GoalState(xml_text) - incarnation_file = os.path.join(OSUTIL.get_lib_dir(), + incarnation_file = os.path.join(conf.get_lib_dir(), INCARNATION_FILE_NAME) if not forced: @@ -619,7 +641,7 @@ class WireClient(object): try: self.goal_state = goal_state file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) - goal_state_file = os.path.join(OSUTIL.get_lib_dir(), file_name) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) self.save_cache(goal_state_file, xml_text) self.save_cache(incarnation_file, goal_state.incarnation) self.update_hosting_env(goal_state) @@ -636,27 +658,34 @@ class WireClient(object): def get_goal_state(self): if(self.goal_state is None): - incarnation = self.fetch_cache(INCARNATION_FILE_NAME) - goal_state_file = GOAL_STATE_FILE_NAME.format(incarnation) + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + incarnation = self.fetch_cache(incarnation_file) + + file_name = GOAL_STATE_FILE_NAME.format(incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) xml_text = self.fetch_cache(goal_state_file) self.goal_state = GoalState(xml_text) return self.goal_state def get_hosting_env(self): if(self.hosting_env is None): - xml_text = self.fetch_cache(HOSTING_ENV_FILE_NAME) + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_cache(local_file) self.hosting_env = HostingEnv(xml_text) return self.hosting_env def get_shared_conf(self): if(self.shared_conf is None): - xml_text = self.fetch_cache(SHARED_CONF_FILE_NAME) + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_cache(local_file) self.shared_conf = SharedConfig(xml_text) return self.shared_conf def get_certs(self): if(self.certs is None): - xml_text = self.fetch_cache(CERTS_FILE_NAME) + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_cache(local_file) self.certs = Certificates(self, xml_text) if self.certs is None: return None @@ -669,14 +698,17 @@ class WireClient(object): self.ext_conf = ExtensionsConfig(None) else: local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) xml_text = self.fetch_cache(local_file) self.ext_conf = ExtensionsConfig(xml_text) return self.ext_conf - def get_ext_manifest(self, extension, goal_state): - local_file = MANIFEST_FILE_NAME.format(extension.name, - goal_state.incarnation) - xml_text = self.fetch_cache(local_file) + def get_ext_manifest(self, ext_handler, goal_state): + local_file = MANIFEST_FILE_NAME.format(ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) return ExtensionManifest(xml_text) def check_wire_protocol_version(self): @@ -693,7 +725,7 @@ class WireClient(object): else: error = ("Agent supported wire protocol version: {0} was not " "advised by Fabric.").format(PROTOCOL_VERSION) - raise ProtocolNotFound(error) + raise ProtocolNotFoundError(error) def upload_status_blob(self): ext_conf = self.get_ext_conf() @@ -711,7 +743,7 @@ class WireClient(object): try: resp = self.call_wireserver(restutil.http_post, role_prop_uri, role_prop, headers = headers) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to send role properties: {0}" u"").format(e)) if resp.status != httpclient.ACCEPTED: @@ -732,7 +764,7 @@ class WireClient(object): try: resp = self.call_wireserver(restutil.http_post, health_report_uri, health_report, headers = headers) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError((u"Failed to send provision status: {0}" u"").format(e)) if resp.status != httpclient.OK: @@ -750,7 +782,7 @@ class WireClient(object): try: header = self.get_header_for_xml_content() resp = self.call_wireserver(restutil.http_post, uri, data, header) - except restutil.HttpError as e: + except HttpError as e: raise ProtocolError("Failed to send events:{0}".format(e)) if resp.status != httpclient.OK: @@ -791,11 +823,10 @@ class WireClient(object): } def get_header_for_cert(self): - cert = "" - content = self.fetch_cache(TRANSPORT_CERT_FILE_NAME) - for line in content.split('\n'): - if "CERTIFICATE" not in line: - cert += line.rstrip() + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + content = self.fetch_cache(trans_cert_file) + cert = get_bytes_from_pem(content) return { "x-ms-agent-name":"WALinuxAgent", "x-ms-version":PROTOCOL_VERSION, @@ -922,8 +953,6 @@ class Certificates(object): def __init__(self, client, xml_text): logger.verb("Load Certificates.xml") self.client = client - self.lib_dir = OSUTIL.get_lib_dir() - self.openssl_cmd = OSUTIL.get_openssl_cmd() self.cert_list = CertList() self.parse(xml_text) @@ -935,22 +964,26 @@ class Certificates(object): data = findtext(xml_doc, "Data") if data is None: return - + + cryptutil = CryptUtil(conf.get_openssl_cmd()) + p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME) p7m = ("MIME-Version:1.0\n" "Content-Disposition: attachment; filename=\"{0}\"\n" "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" "Content-Transfer-Encoding: base64\n" "\n" - "{2}").format(P7M_FILE_NAME, P7M_FILE_NAME, data) + "{2}").format(p7m_file, p7m_file, data) - self.client.save_cache(os.path.join(self.lib_dir, P7M_FILE_NAME), p7m) + self.client.save_cache(p7m_file, p7m) + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME) #decrypt certificates - cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3}" - "| {4} pkcs12 -nodes -password pass: -out {5}" - "").format(self.openssl_cmd, P7M_FILE_NAME, - TRANSPORT_PRV_FILE_NAME, TRANSPORT_CERT_FILE_NAME, - self.openssl_cmd, PEM_FILE_NAME) - shellutil.run(cmd) + cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, + pem_file) #The parsing process use public key to match prv and crt. buf = [] @@ -960,7 +993,7 @@ class Certificates(object): thumbprints = {} index = 0 v1_cert_list = [] - with open(PEM_FILE_NAME) as pem: + with open(pem_file) as pem: for line in pem.readlines(): buf.append(line) if re.match(r'[-]+BEGIN.*KEY[-]+', line): @@ -969,15 +1002,15 @@ class Certificates(object): begin_crt = True elif re.match(r'[-]+END.*KEY[-]+', line): tmp_file = self.write_to_tmp_file(index, 'prv', buf) - pub = OSUTIL.get_pubkey_from_prv(tmp_file) + pub = cryptutil.get_pubkey_from_prv(tmp_file) prvs[pub] = tmp_file buf = [] index += 1 begin_prv = False elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): tmp_file = self.write_to_tmp_file(index, 'crt', buf) - pub = OSUTIL.get_pubkey_from_crt(tmp_file) - thumbprint = OSUTIL.get_thumbprint_from_crt(tmp_file) + pub = cryptutil.get_pubkey_from_crt(tmp_file) + thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file) thumbprints[pub] = thumbprint #Rename crt with thumbprint as the file name crt = "{0}.crt".format(thumbprint) @@ -985,7 +1018,7 @@ class Certificates(object): "name":None, "thumbprint":thumbprint }) - os.rename(tmp_file, os.path.join(self.lib_dir, crt)) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt)) buf = [] index += 1 begin_crt = False @@ -996,7 +1029,7 @@ class Certificates(object): if thumbprint: tmp_file = prvs[pubkey] prv = "{0}.prv".format(thumbprint) - os.rename(tmp_file, os.path.join(self.lib_dir, prv)) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv)) for v1_cert in v1_cert_list: cert = Cert() @@ -1004,7 +1037,8 @@ class Certificates(object): self.cert_list.certificates.append(cert) def write_to_tmp_file(self, index, suffix, buf): - file_name = os.path.join(self.lib_dir, "{0}.{1}".format(index, suffix)) + file_name = os.path.join(conf.get_lib_dir(), + "{0}.{1}".format(index, suffix)) self.client.save_cache(file_name, "".join(buf)) return file_name @@ -1090,7 +1124,7 @@ class ExtensionsConfig(object): ext.name = ext_handler.name ext.sequenceNumber = seqNo ext.publicSettings = handler_settings.get("publicSettings") - ext.privateSettings = handler_settings.get("protectedSettings") + ext.protectedSettings = handler_settings.get("protectedSettings") thumbprint = handler_settings.get("protectedSettingsCertThumbprint") ext.certificateThumbprint = thumbprint ext_handler.properties.extensions.append(ext) diff --git a/azurelinuxagent/utils/cryptutil.py b/azurelinuxagent/utils/cryptutil.py new file mode 100644 index 0000000..5ee5637 --- /dev/null +++ b/azurelinuxagent/utils/cryptutil.py @@ -0,0 +1,121 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import base64 +import struct +from azurelinuxagent.future import ustr, bytebuffer +from azurelinuxagent.exception import CryptError +import azurelinuxagent.utils.shellutil as shellutil + +class CryptUtil(object): + def __init__(self, openssl_cmd): + self.openssl_cmd = openssl_cmd + + def gen_transport_cert(self, prv_file, crt_file): + """ + Create ssl certificate for https communication with endpoint server. + """ + cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 " + "-newkey rsa:2048 -keyout {1} " + "-out {2}").format(self.openssl_cmd, prv_file, crt_file) + shellutil.run(cmd) + + def get_pubkey_from_prv(self, file_name): + cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, + file_name) + pub = shellutil.run_get_output(cmd)[1] + return pub + + def get_pubkey_from_crt(self, file_name): + cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd, + file_name) + pub = shellutil.run_get_output(cmd)[1] + return pub + + def get_thumbprint_from_crt(self, file_name): + cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd, + file_name) + thumbprint = shellutil.run_get_output(cmd)[1] + thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper() + return thumbprint + + def decrypt_p7m(self, p7m_file, trans_prv_file, trans_cert_file, pem_file): + cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3} " + "| {4} pkcs12 -nodes -password pass: -out {5}" + "").format(self.openssl_cmd, p7m_file, trans_prv_file, + trans_cert_file, self.openssl_cmd, pem_file) + shellutil.run(cmd) + + def crt_to_ssh(self, input_file, output_file): + shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file, + output_file)) + + def asn1_to_ssh(self, pubkey): + lines = pubkey.split("\n") + lines = [x for x in lines if not x.startswith("----")] + base64_encoded = "".join(lines) + try: + #TODO remove pyasn1 dependency + from pyasn1.codec.der import decoder as der_decoder + der_encoded = base64.b64decode(base64_encoded) + der_encoded = der_decoder.decode(der_encoded)[0][1] + key = der_decoder.decode(self.bits_to_bytes(der_encoded))[0] + n=key[0] + e=key[1] + keydata = bytearray() + keydata.extend(struct.pack('>I', len("ssh-rsa"))) + keydata.extend(b"ssh-rsa") + keydata.extend(struct.pack('>I', len(self.num_to_bytes(e)))) + keydata.extend(self.num_to_bytes(e)) + keydata.extend(struct.pack('>I', len(self.num_to_bytes(n)) + 1)) + keydata.extend(b"\0") + keydata.extend(self.num_to_bytes(n)) + keydata_base64 = base64.b64encode(bytebuffer(keydata)) + return ustr(b"ssh-rsa " + keydata_base64 + b"\n", + encoding='utf-8') + except ImportError as e: + raise CryptError("Failed to load pyasn1.codec.der") + + def num_to_bytes(self, num): + """ + Pack number into bytes. Retun as string. + """ + result = bytearray() + while num: + result.append(num & 0xFF) + num >>= 8 + result.reverse() + return result + + def bits_to_bytes(self, bits): + """ + Convert an array contains bits, [0,1] to a byte array + """ + index = 7 + byte_array = bytearray() + curr = 0 + for bit in bits: + curr = curr | (bit << index) + index = index - 1 + if index == -1: + byte_array.append(curr) + curr = 0 + index = 7 + return bytes(byte_array) + diff --git a/azurelinuxagent/utils/fileutil.py b/azurelinuxagent/utils/fileutil.py index 08592bc..5369a7c 100644 --- a/azurelinuxagent/utils/fileutil.py +++ b/azurelinuxagent/utils/fileutil.py @@ -27,7 +27,7 @@ import shutil import pwd import tempfile import azurelinuxagent.logger as logger -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.utils.textutil as textutil def read_file(filepath, asbin=False, remove_bom=False, encoding='utf-8'): @@ -46,7 +46,7 @@ def read_file(filepath, asbin=False, remove_bom=False, encoding='utf-8'): if remove_bom: #Remove bom on bytes data before it is converted into string. data = textutil.remove_bom(data) - data = text(data, encoding=encoding) + data = ustr(data, encoding=encoding) return data def write_file(filepath, contents, asbin=False, encoding='utf-8', append=False): @@ -100,6 +100,7 @@ def replace_file(filepath, contents): return 1 return 0 + def base_name(path): head, tail = os.path.split(path) return tail @@ -151,7 +152,7 @@ def rm_dirs(*args): def update_conf_file(path, line_start, val, chk_err=False): conf = [] if not os.path.isfile(path) and chk_err: - raise Exception("Can't find config file:{0}".format(path)) + raise IOError("Can't find config file:{0}".format(path)) conf = read_file(path).split('\n') conf = [x for x in conf if not x.startswith(line_start)] conf.append(val) diff --git a/azurelinuxagent/utils/osutil.py b/azurelinuxagent/utils/osutil.py deleted file mode 100644 index 9de47e7..0000000 --- a/azurelinuxagent/utils/osutil.py +++ /dev/null @@ -1,27 +0,0 @@ -# Microsoft Azure Linux Agent -# -# Copyright 2014 Microsoft Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Requires Python 2.4+ and Openssl 1.0+ -# - -""" -Load OSUtil implementation from azurelinuxagent.distro -""" -from azurelinuxagent.distro.default.osutil import OSUtilError -import azurelinuxagent.distro.loader as loader - -OSUTIL = loader.get_osutil() - diff --git a/azurelinuxagent/utils/restutil.py b/azurelinuxagent/utils/restutil.py index 2acfa57..2e8b0be 100644 --- a/azurelinuxagent/utils/restutil.py +++ b/azurelinuxagent/utils/restutil.py @@ -21,8 +21,9 @@ import time import platform import os import subprocess -import azurelinuxagent.logger as logger import azurelinuxagent.conf as conf +import azurelinuxagent.logger as logger +from azurelinuxagent.exception import HttpError from azurelinuxagent.future import httpclient, urlparse """ @@ -31,9 +32,6 @@ REST api util functions RETRY_WAITING_INTERVAL = 10 -class HttpError(Exception): - pass - def _parse_url(url): o = urlparse(url) rel_uri = o.path @@ -51,8 +49,8 @@ def get_http_proxy(): Get http_proxy and https_proxy from environment variables. Username and password is not supported now. """ - host = conf.get("HttpProxy.Host", None) - port = conf.get("HttpProxy.Port", None) + host = conf.get_httpproxy_host() + port = conf.get_httpproxy_port() return (host, port) def _http_request(method, host, rel_uri, port=None, data=None, secure=False, @@ -61,7 +59,7 @@ def _http_request(method, host, rel_uri, port=None, data=None, secure=False, if secure: port = 443 if port is None else port if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPSConnection(proxy_host, proxy_port) + conn = httpclient.HTTPSConnection(proxy_host, proxy_port, timeout=10) conn.set_tunnel(host, port) #If proxy is used, full url is needed. url = "https://{0}:{1}{2}".format(host, port, rel_uri) @@ -71,7 +69,7 @@ def _http_request(method, host, rel_uri, port=None, data=None, secure=False, else: port = 80 if port is None else port if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPConnection(proxy_host, proxy_port) + conn = httpclient.HTTPConnection(proxy_host, proxy_port, timeout=10) #If proxy is used, full url is needed. url = "http://{0}:{1}{2}".format(host, port, rel_uri) else: @@ -128,8 +126,12 @@ def http_request(method, url, data, headers=None, max_retry=3, chk_proxy=False): if retry < max_retry - 1: logger.info("Retry={0}, {1} {2}", retry, method, url) time.sleep(RETRY_WAITING_INTERVAL) - - raise HttpError("HTTP Err: {0} {1}".format(method, url)) + + if url is not None and len(url) > 100: + url_log = url[0: 100] #In case the url is too long + else: + url_log = url + raise HttpError("HTTP Err: {0} {1}".format(method, url_log)) def http_get(url, headers=None, max_retry=3, chk_proxy=False): return http_request("GET", url, data=None, headers=headers, diff --git a/azurelinuxagent/utils/shellutil.py b/azurelinuxagent/utils/shellutil.py index 372c78a..98871a1 100644 --- a/azurelinuxagent/utils/shellutil.py +++ b/azurelinuxagent/utils/shellutil.py @@ -20,7 +20,7 @@ import platform import os import subprocess -from azurelinuxagent.future import text +from azurelinuxagent.future import ustr import azurelinuxagent.logger as logger if not hasattr(subprocess,'check_output'): @@ -75,9 +75,9 @@ def run_get_output(cmd, chk_err=True, log_cmd=True): logger.verb(u"run cmd '{0}'", cmd) try: output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True) - output = text(output, encoding='utf-8', errors="backslashreplace") + output = ustr(output, encoding='utf-8', errors="backslashreplace") except subprocess.CalledProcessError as e : - output = text(e.output, encoding='utf-8', errors="backslashreplace") + output = ustr(e.output, encoding='utf-8', errors="backslashreplace") if chk_err: if log_cmd: logger.error(u"run cmd '{0}' failed", e.cmd) diff --git a/azurelinuxagent/utils/textutil.py b/azurelinuxagent/utils/textutil.py index e0f1395..851f98a 100644 --- a/azurelinuxagent/utils/textutil.py +++ b/azurelinuxagent/utils/textutil.py @@ -224,5 +224,13 @@ def gen_password_hash(password, crypt_id, salt_len): salt = "${0}${1}".format(crypt_id, salt) return crypt.crypt(password, salt) +def get_bytes_from_pem(pem_str): + base64_bytes = "" + for line in pem_str.split('\n'): + if "----" not in line: + base64_bytes += line + return base64_bytes + + Version = LooseVersion |