Diffstat (limited to 'azurelinuxagent')
25 files changed, 1375 insertions, 602 deletions
diff --git a/azurelinuxagent/agent.py b/azurelinuxagent/agent.py index d1ac354..e99f7be 100644 --- a/azurelinuxagent/agent.py +++ b/azurelinuxagent/agent.py @@ -64,7 +64,19 @@ class Agent(object): logger.add_logger_appender(logger.AppenderType.CONSOLE, level, path="/dev/console") + ext_log_dir = conf.get_ext_log_dir() + try: + if os.path.isfile(ext_log_dir): + raise Exception("{0} is a file".format(ext_log_dir)) + if not os.path.isdir(ext_log_dir): + os.makedirs(ext_log_dir) + except Exception as e: + logger.error( + "Exception occurred while creating extension " + "log directory {0}: {1}".format(ext_log_dir, e)) + #Init event reporter + event.init_event_status(conf.get_lib_dir()) event_dir = os.path.join(conf.get_lib_dir(), "events") event.init_event_logger(event_dir) event.enable_unhandled_err_dump("WALA") @@ -116,6 +128,11 @@ class Agent(object): update_handler = get_update_handler() update_handler.run() + def show_configuration(self): + configuration = conf.get_configuration() + for k in sorted(configuration.keys()): + print("{0} = {1}".format(k, configuration[k])) + def main(args=[]): """ Parse command line arguments, exit with usage() on error. @@ -145,6 +162,8 @@ def main(args=[]): agent.daemon() elif command == "run-exthandlers": agent.run_exthandlers() + elif command == "show-configuration": + agent.show_configuration() except Exception: logger.error(u"Failed to run '{0}': {1}", command, @@ -186,6 +205,8 @@ def parse_args(sys_args): verbose = True elif re.match("^([-/]*)force", a): force = True + elif re.match("^([-/]*)show-configuration", a): + cmd = "show-configuration" elif re.match("^([-/]*)(help|usage|\\?)", a): cmd = "help" else: diff --git a/azurelinuxagent/common/conf.py b/azurelinuxagent/common/conf.py index 5422784..75a0248 100644 --- a/azurelinuxagent/common/conf.py +++ b/azurelinuxagent/common/conf.py @@ -85,12 +85,81 @@ def load_conf_from_file(conf_file_path, conf=__conf__): raise AgentConfigError(("Failed to load conf file:{0}, {1}" "").format(conf_file_path, err)) +__SWITCH_OPTIONS__ = { + "OS.AllowHTTP" : False, + "OS.EnableFirewall" : False, + "OS.EnableFIPS" : False, + "OS.EnableRDMA" : False, + "OS.UpdateRdmaDriver" : False, + "OS.CheckRdmaDriver" : False, + "Logs.Verbose" : False, + "Provisioning.Enabled" : True, + "Provisioning.UseCloudInit" : False, + "Provisioning.AllowResetSysUser" : False, + "Provisioning.RegenerateSshHostKeyPair" : False, + "Provisioning.DeleteRootPassword" : False, + "Provisioning.DecodeCustomData" : False, + "Provisioning.ExecuteCustomData" : False, + "Provisioning.MonitorHostName" : False, + "DetectScvmmEnv" : False, + "ResourceDisk.Format" : False, + "DetectScvmmEnv" : False, + "ResourceDisk.Format" : False, + "ResourceDisk.EnableSwap" : False, + "AutoUpdate.Enabled" : True, + "EnableOverProvisioning" : False +} + +__STRING_OPTIONS__ = { + "Lib.Dir" : "/var/lib/waagent", + "DVD.MountPoint" : "/mnt/cdrom/secure", + "Pid.File" : "/var/run/waagent.pid", + "Extension.LogDir" : "/var/log/azure", + "OS.OpensslPath" : "/usr/bin/openssl", + "OS.SshDir" : "/etc/ssh", + "OS.HomeDir" : "/home", + "OS.PasswordPath" : "/etc/shadow", + "OS.SudoersDir" : "/etc/sudoers.d", + "OS.RootDeviceScsiTimeout" : None, + "Provisioning.SshHostKeyPairType" : "rsa", + "Provisioning.PasswordCryptId" : "6", + "HttpProxy.Host" : None, + "ResourceDisk.MountPoint" : "/mnt/resource", + "ResourceDisk.MountOptions" : None, + "ResourceDisk.Filesystem" : "ext3", + "AutoUpdate.GAFamily" : "Prod" +} + +__INTEGER_OPTIONS__ = { + "Provisioning.PasswordCryptSaltLength" : 10, + 
"HttpProxy.Port" : None, + "ResourceDisk.SwapSizeMB" : 0, + "Autoupdate.Frequency" : 3600 +} + +def get_configuration(conf=__conf__): + options = {} + for option in __SWITCH_OPTIONS__: + options[option] = conf.get_switch(option, __SWITCH_OPTIONS__[option]) + + for option in __STRING_OPTIONS__: + options[option] = conf.get(option, __STRING_OPTIONS__[option]) + + for option in __INTEGER_OPTIONS__: + options[option] = conf.get_int(option, __INTEGER_OPTIONS__[option]) + + return options + +def enable_firewall(conf=__conf__): + return conf.get_switch("OS.EnableFirewall", False) def enable_rdma(conf=__conf__): return conf.get_switch("OS.EnableRDMA", False) or \ conf.get_switch("OS.UpdateRdmaDriver", False) or \ conf.get_switch("OS.CheckRdmaDriver", False) +def enable_rdma_update(conf=__conf__): + return conf.get_switch("OS.UpdateRdmaDriver", False) def get_logs_verbose(conf=__conf__): return conf.get_switch("Logs.Verbose", False) @@ -151,6 +220,16 @@ def get_root_device_scsi_timeout(conf=__conf__): return conf.get("OS.RootDeviceScsiTimeout", None) def get_ssh_host_keypair_type(conf=__conf__): + keypair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa") + if keypair_type == "auto": + ''' + auto generates all supported key types and returns the + rsa thumbprint as the default. + ''' + return "rsa" + return keypair_type + +def get_ssh_host_keypair_mode(conf=__conf__): return conf.get("Provisioning.SshHostKeyPairType", "rsa") def get_provision_enabled(conf=__conf__): @@ -239,4 +318,7 @@ def get_autoupdate_frequency(conf=__conf__): return conf.get_int("Autoupdate.Frequency", 3600) def get_enable_overprovisioning(conf=__conf__): - return conf.get_switch("EnableOverProvisioning", False)
\ No newline at end of file + return conf.get_switch("EnableOverProvisioning", False) + +def get_allow_http(conf=__conf__): + return conf.get_switch("OS.AllowHTTP", False) diff --git a/azurelinuxagent/common/event.py b/azurelinuxagent/common/event.py index 723b8bf..e62a925 100644 --- a/azurelinuxagent/common/event.py +++ b/azurelinuxagent/common/event.py @@ -27,6 +27,7 @@ import platform from datetime import datetime, timedelta +import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger from azurelinuxagent.common.exception import EventError, ProtocolError @@ -39,13 +40,15 @@ from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \ DISTRO_CODE_NAME, AGENT_VERSION, \ CURRENT_AGENT, CURRENT_VERSION -_EVENT_MSG = "Event: name={0}, op={1}, message={2}" +_EVENT_MSG = "Event: name={0}, op={1}, message={2}, duration={3}" class WALAEventOperation: ActivateResourceDisk = "ActivateResourceDisk" + AutoUpdate = "AutoUpdate" Disable = "Disable" Download = "Download" Enable = "Enable" + Firewall = "Firewall" HealthCheck = "HealthCheck" HeartBeat = "HeartBeat" HostPlugin = "HostPlugin" @@ -60,13 +63,70 @@ class WALAEventOperation: Upgrade = "Upgrade" Update = "Update" -def _log_event(name, op, message, is_success=True): + +class EventStatus(object): + EVENT_STATUS_FILE = "event_status.json" + + def __init__(self, status_dir=conf.get_lib_dir()): + self._path = None + self._status = {} + + def clear(self): + self._status = {} + self._save() + + def event_marked(self, name, version, op): + return self._event_name(name, version, op) in self._status + + def event_succeeded(self, name, version, op): + event = self._event_name(name, version, op) + if event not in self._status: + return True + return self._status[event] == True + + def initialize(self, status_dir=conf.get_lib_dir()): + self._path = os.path.join(status_dir, EventStatus.EVENT_STATUS_FILE) + self._load() + + def mark_event_status(self, name, version, op, status): + event = self._event_name(name, version, op) + self._status[event] = (status == True) + self._save() + + def _event_name(self, name, version, op): + return "{0}-{1}-{2}".format(name, version, op) + + def _load(self): + try: + self._status = {} + if os.path.isfile(self._path): + with open(self._path, 'r') as f: + self._status = json.load(f) + except Exception as e: + logger.warn("Exception occurred loading event status: {0}".format(e)) + self._status = {} + + def _save(self): + try: + with open(self._path, 'w') as f: + json.dump(self._status, f) + except Exception as e: + logger.warn("Exception occurred saving event status: {0}".format(e)) + +__event_status__ = EventStatus() +__event_status_operations__ = [ + WALAEventOperation.AutoUpdate, + WALAEventOperation.ReportStatus + ] + + +def _log_event(name, op, message, duration, is_success=True): global _EVENT_MSG if not is_success: - logger.error(_EVENT_MSG, name, op, message) + logger.error(_EVENT_MSG, name, op, message, duration) else: - logger.info(_EVENT_MSG, name, op, message) + logger.info(_EVENT_MSG, name, op, message, duration) class EventLogger(object): @@ -76,7 +136,7 @@ class EventLogger(object): def save_event(self, data): if self.event_dir is None: - logger.warn("Event reporter is not initialized.") + logger.warn("Cannot save event -- Event reporter is not initialized.") return if not os.path.exists(self.event_dir): @@ -104,11 +164,11 @@ class EventLogger(object): raise EventError("Failed to write events to file:{0}", e) def reset_periodic(self): - self.periodic_messages = {} + 
self.periodic_events = {} def is_period_elapsed(self, delta, h): - return h not in self.periodic_messages or \ - (self.periodic_messages[h] + delta) <= datetime.now() + return h not in self.periodic_events or \ + (self.periodic_events[h] + delta) <= datetime.now() def add_periodic(self, delta, name, op="", is_success=True, duration=0, @@ -122,13 +182,21 @@ class EventLogger(object): op=op, is_success=is_success, duration=duration, version=version, message=message, evt_type=evt_type, is_internal=is_internal, log_event=log_event) - self.periodic_messages[h] = datetime.now() + self.periodic_events[h] = datetime.now() - def add_event(self, name, op="", is_success=True, duration=0, + def add_event(self, + name, + op="", + is_success=True, + duration=0, version=CURRENT_VERSION, - message="", evt_type="", is_internal=False, log_event=True): + message="", + evt_type="", + is_internal=False, + log_event=True): + if not is_success or log_event: - _log_event(name, op, message, is_success=is_success) + _log_event(name, op, message, duration, is_success=is_success) event = TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") event.parameters.append(TelemetryEventParam('Name', name)) @@ -176,22 +244,24 @@ def add_event(name, op="", is_success=True, duration=0, version=CURRENT_VERSION, message="", evt_type="", is_internal=False, log_event=True, reporter=__event_logger__): if reporter.event_dir is None: - logger.warn("Event reporter is not initialized.") - _log_event(name, op, message, is_success=is_success) + logger.warn("Cannot add event -- Event reporter is not initialized.") + _log_event(name, op, message, duration, is_success=is_success) return - reporter.add_event( - name, op=op, is_success=is_success, duration=duration, - version=str(version), message=message, evt_type=evt_type, - is_internal=is_internal, log_event=log_event) + if should_emit_event(name, version, op, is_success): + mark_event_status(name, version, op, is_success) + reporter.add_event( + name, op=op, is_success=is_success, duration=duration, + version=str(version), message=message, evt_type=evt_type, + is_internal=is_internal, log_event=log_event) def add_periodic( delta, name, op="", is_success=True, duration=0, version=CURRENT_VERSION, message="", evt_type="", is_internal=False, log_event=True, force=False, reporter=__event_logger__): if reporter.event_dir is None: - logger.warn("Event reporter is not initialized.") - _log_event(name, op, message, is_success=is_success) + logger.warn("Cannot add periodic event -- Event reporter is not initialized.") + _log_event(name, op, message, duration, is_success=is_success) return reporter.add_periodic( @@ -199,9 +269,22 @@ def add_periodic( version=str(version), message=message, evt_type=evt_type, is_internal=is_internal, log_event=log_event, force=force) -def init_event_logger(event_dir, reporter=__event_logger__): - reporter.event_dir = event_dir +def mark_event_status(name, version, op, status): + if op in __event_status_operations__: + __event_status__.mark_event_status(name, version, op, status) + +def should_emit_event(name, version, op, status): + return \ + op not in __event_status_operations__ or \ + __event_status__ is None or \ + not __event_status__.event_marked(name, version, op) or \ + __event_status__.event_succeeded(name, version, op) != status + +def init_event_logger(event_dir): + __event_logger__.event_dir = event_dir +def init_event_status(status_dir): + __event_status__.initialize(status_dir) def dump_unhandled_err(name): if hasattr(sys, 'last_type') and 
hasattr(sys, 'last_value') and \ diff --git a/azurelinuxagent/common/exception.py b/azurelinuxagent/common/exception.py index 7a0c75e..17c6ce0 100644 --- a/azurelinuxagent/common/exception.py +++ b/azurelinuxagent/common/exception.py @@ -86,7 +86,6 @@ class DhcpError(AgentError): def __init__(self, msg=None, inner=None): super(DhcpError, self).__init__('000006', msg, inner) - class OSUtilError(AgentError): """ Failed to perform operation to OS configuration @@ -148,3 +147,12 @@ class UpdateError(AgentError): def __init__(self, msg=None, inner=None): super(UpdateError, self).__init__('000012', msg, inner) + + +class ResourceGoneError(HttpError): + """ + The requested resource no longer exists (i.e., status code 410) + """ + + def __init__(self, msg=None, inner=None): + super(ResourceGoneError, self).__init__(msg, inner) diff --git a/azurelinuxagent/common/osutil/default.py b/azurelinuxagent/common/osutil/default.py index 58c0ef8..dc1c11a 100644 --- a/azurelinuxagent/common/osutil/default.py +++ b/azurelinuxagent/common/osutil/default.py @@ -40,6 +40,7 @@ import azurelinuxagent.common.utils.textutil as textutil from azurelinuxagent.common.exception import OSUtilError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.utils.flexible_version import FlexibleVersion __RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules" ] @@ -50,10 +51,20 @@ for all distros. Each concrete distro classes could overwrite default behavior if needed. """ +IPTABLES_VERSION_PATTERN = re.compile("^[^\d\.]*([\d\.]+).*$") +IPTABLES_VERSION = "iptables --version" +IPTABLES_LOCKING_VERSION = FlexibleVersion('1.4.21') + +FIREWALL_ACCEPT = "iptables {0} -t security -{1} OUTPUT -d {2} -p tcp -m owner --uid-owner {3} -j ACCEPT" +FIREWALL_DROP = "iptables {0} -t security -{1} OUTPUT -d {2} -p tcp -j DROP" +FIREWALL_LIST = "iptables {0} -t security -L" + +_enable_firewall = True + DMIDECODE_CMD = 'dmidecode --string system-uuid' PRODUCT_ID_FILE = '/sys/class/dmi/id/product_uuid' UUID_PATTERN = re.compile( - '^\s*[A-F0-9]{8}(?:\-[A-F0-9]{4}){3}\-[A-F0-9]{12}\s*$', + r'^\s*[A-F0-9]{8}(?:\-[A-F0-9]{4}){3}\-[A-F0-9]{12}\s*$', re.IGNORECASE) class DefaultOSUtil(object): @@ -63,6 +74,113 @@ class DefaultOSUtil(object): self.selinux = None self.disable_route_warning = False + def enable_firewall(self, dst_ip=None, uid=None): + + # If a previous attempt threw an exception, do not retry + global _enable_firewall + if not _enable_firewall: + return False + + try: + if dst_ip is None or uid is None: + msg = "Missing arguments to enable_firewall" + logger.warn(msg) + raise Exception(msg) + + # Determine if iptables will serialize access + rc, output = shellutil.run_get_output(IPTABLES_VERSION) + if rc != 0: + msg = "Unable to determine version of iptables" + logger.warn(msg) + raise Exception(msg) + + m = IPTABLES_VERSION_PATTERN.match(output) + if m is None: + msg = "iptables did not return version information" + logger.warn(msg) + raise Exception(msg) + + wait = "-w" \ + if FlexibleVersion(m.group(1)) >= IPTABLES_LOCKING_VERSION \ + else "" + + # If the DROP rule exists, make no changes + drop_rule = FIREWALL_DROP.format(wait, "C", dst_ip) + + if shellutil.run(drop_rule, chk_err=False) == 0: + logger.verbose("Firewall appears established") + return True + + # Otherwise, append both rules + accept_rule = FIREWALL_ACCEPT.format(wait, "A", dst_ip, uid) + drop_rule = FIREWALL_DROP.format(wait, 
"A", dst_ip) + + if shellutil.run(accept_rule) != 0: + msg = "Unable to add ACCEPT firewall rule '{0}'".format( + accept_rule) + logger.warn(msg) + raise Exception(msg) + + if shellutil.run(drop_rule) != 0: + msg = "Unable to add DROP firewall rule '{0}'".format( + drop_rule) + logger.warn(msg) + raise Exception(msg) + + logger.info("Successfully added Azure fabric firewall rules") + + rc, output = shellutil.run_get_output(FIREWALL_LIST.format(wait)) + if rc == 0: + logger.info("Firewall rules:\n{0}".format(output)) + else: + logger.warn("Listing firewall rules failed: {0}".format(output)) + + return True + + except Exception as e: + _enable_firewall = False + logger.info("Unable to establish firewall -- " + "no further attempts will be made: " + "{0}".format(ustr(e))) + return False + + def _correct_instance_id(self, id): + ''' + Azure stores the instance ID with an incorrect byte ordering for the + first parts. For example, the ID returned by the metadata service: + + D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8 + + will be found as: + + 544CDFD0-CB4E-4B4A-9954-5BDF3ED5C3B8 + + This code corrects the byte order such that it is consistent with + that returned by the metadata service. + ''' + + if not UUID_PATTERN.match(id): + return id + + parts = id.split('-') + return '-'.join([ + textutil.swap_hexstring(parts[0], width=2), + textutil.swap_hexstring(parts[1], width=2), + textutil.swap_hexstring(parts[2], width=2), + parts[3], + parts[4] + ]) + + def is_current_instance_id(self, id_that): + ''' + Compare two instance IDs for equality, but allow that some IDs + may have been persisted using the incorrect byte ordering. + ''' + id_this = self.get_instance_id() + return id_that == id_this or \ + id_that == self._correct_instance_id(id_this) + + def get_agent_conf_file_path(self): return self.agent_conf_file_path @@ -74,13 +192,14 @@ class DefaultOSUtil(object): If nothing works (for old VMs), return the empty string ''' if os.path.isfile(PRODUCT_ID_FILE): - return fileutil.read_file(PRODUCT_ID_FILE).strip() + s = fileutil.read_file(PRODUCT_ID_FILE).strip() - rc, s = shellutil.run_get_output(DMIDECODE_CMD) - if rc != 0 or UUID_PATTERN.match(s) is None: - return "" + else: + rc, s = shellutil.run_get_output(DMIDECODE_CMD) + if rc != 0 or UUID_PATTERN.match(s) is None: + return "" - return s.strip() + return self._correct_instance_id(s.strip()) def get_userentry(self, username): try: @@ -158,10 +277,12 @@ class DefaultOSUtil(object): fileutil.append_file(sudoers_file, sudoers) sudoer = None if nopasswd: - sudoer = "{0} ALL=(ALL) NOPASSWD: ALL\n".format(username) + sudoer = "{0} ALL=(ALL) NOPASSWD: ALL".format(username) else: - sudoer = "{0} ALL=(ALL) ALL\n".format(username) - fileutil.append_file(sudoers_wagent, sudoer) + sudoer = "{0} ALL=(ALL) ALL".format(username) + if not os.path.isfile(sudoers_wagent) or \ + fileutil.findstr_in_file(sudoers_wagent, sudoer) is None: + fileutil.append_file(sudoers_wagent, "{0}\n".format(sudoer)) fileutil.chmod(sudoers_wagent, 0o440) else: #Remove user from sudoers @@ -334,7 +455,7 @@ class DefaultOSUtil(object): return_code, err = self.mount(dvd_device, mount_point, option="-o ro -t udf,iso9660", - chk_err=chk_err) + chk_err=False) if return_code == 0: logger.info("Successfully mounted dvd") return @@ -718,7 +839,7 @@ class DefaultOSUtil(object): for conf_file in dhclient_files: if not os.path.isfile(conf_file): continue - if fileutil.findstr_in_file(conf_file, autosend): + if fileutil.findre_in_file(conf_file, autosend): #Return if auto send host-name is 
configured return fileutil.update_conf_file(conf_file, diff --git a/azurelinuxagent/common/osutil/factory.py b/azurelinuxagent/common/osutil/factory.py index 2be90ab..43aa6a7 100644 --- a/azurelinuxagent/common/osutil/factory.py +++ b/azurelinuxagent/common/osutil/factory.py @@ -41,7 +41,8 @@ def get_osutil(distro_name=DISTRO_NAME, if distro_name == "arch": return ArchUtil() - if distro_name == "clear linux software for intel architecture": + if distro_name == "clear linux os for intel architecture" \ + or distro_name == "clear linux software for intel architecture": return ClearLinuxUtil() if distro_name == "ubuntu": diff --git a/azurelinuxagent/common/osutil/openbsd.py b/azurelinuxagent/common/osutil/openbsd.py index 9bfe6de..a022c59 100644 --- a/azurelinuxagent/common/osutil/openbsd.py +++ b/azurelinuxagent/common/osutil/openbsd.py @@ -248,8 +248,10 @@ class OpenBSDOSUtil(DefaultOSUtil): os.makedirs(mount_point) for retry in range(0, max_retry): - retcode = self.mount(dvd_device, mount_point, option="-o ro -t udf", - chk_err=chk_err) + retcode = self.mount(dvd_device, + mount_point, + option="-o ro -t udf", + chk_err=False) if retcode == 0: logger.info("Successfully mounted DVD") return diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py index 9af8a97..729d8fb 100644 --- a/azurelinuxagent/common/protocol/hostplugin.py +++ b/azurelinuxagent/common/protocol/hostplugin.py @@ -22,7 +22,8 @@ import json import traceback from azurelinuxagent.common import logger -from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.exception import HttpError, ProtocolError, \ + ResourceGoneError from azurelinuxagent.common.future import ustr, httpclient from azurelinuxagent.common.utils import restutil from azurelinuxagent.common.utils import textutil @@ -85,10 +86,10 @@ class HostPluginProtocol(object): try: headers = {HEADER_CONTAINER_ID: self.container_id} response = restutil.http_get(url, headers) - if response.status != httpclient.OK: + if restutil.request_failed(response): logger.error( "HostGAPlugin: Failed Get API versions: {0}".format( - self.read_response_error(response))) + restutil.read_response_error(response))) else: return_val = ustr(remove_bom(response.read()), encoding='utf-8') @@ -117,42 +118,7 @@ class HostPluginProtocol(object): return url, headers def put_vm_log(self, content): - """ - Try to upload the given content to the host plugin - :param deployment_id: the deployment id, which is obtained from the - goal state (tenant name) - :param container_id: the container id, which is obtained from the - goal state - :param content: the binary content of the zip file to upload - :return: - """ - if not self.ensure_initialized(): - raise ProtocolError("HostGAPlugin: Host plugin channel is not available") - - if content is None \ - or self.container_id is None \ - or self.deployment_id is None: - logger.error( - "HostGAPlugin: Invalid arguments passed: " - "[{0}], [{1}], [{2}]".format( - content, - self.container_id, - self.deployment_id)) - return - url = URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT) - - headers = {"x-ms-vmagentlog-deploymentid": self.deployment_id, - "x-ms-vmagentlog-containerid": self.container_id} - logger.periodic( - logger.EVERY_FIFTEEN_MINUTES, - "HostGAPlugin: Put VM log to [{0}]".format(url)) - try: - response = restutil.http_put(url, content, headers) - if response.status != httpclient.OK: - logger.error("HostGAPlugin: Put log failed: Code {0}".format( - 
response.status)) - except HttpError as e: - logger.error("HostGAPlugin: Put log exception: {0}".format(e)) + raise NotImplementedError("Unimplemented") def put_vm_status(self, status_blob, sas_url, config_blob_type=None): """ @@ -169,6 +135,7 @@ class HostPluginProtocol(object): logger.verbose("HostGAPlugin: Posting VM status") try: + blob_type = status_blob.type if status_blob.type else config_blob_type if blob_type == "BlockBlob": @@ -176,17 +143,14 @@ class HostPluginProtocol(object): else: self._put_page_blob_status(sas_url, status_blob) - if not HostPluginProtocol.is_default_channel(): + except Exception as e: + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + if isinstance(e, ResourceGoneError): logger.verbose("HostGAPlugin: Setting host plugin as default channel") HostPluginProtocol.set_default_channel(True) - except Exception as e: - message = "HostGAPlugin: Exception Put VM status: {0}, {1}".format(e, traceback.format_exc()) - from azurelinuxagent.common.event import WALAEventOperation, report_event - report_event(op=WALAEventOperation.ReportStatus, - is_success=False, - message=message) - logger.warn("HostGAPlugin: resetting default channel") - HostPluginProtocol.set_default_channel(False) + + raise def _put_block_blob_status(self, sas_url, status_blob): url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT) @@ -198,9 +162,9 @@ class HostPluginProtocol(object): bytearray(status_blob.data, encoding='utf-8')), headers=self._build_status_headers()) - if response.status != httpclient.OK: + if restutil.request_failed(response): raise HttpError("HostGAPlugin: Put BlockBlob failed: {0}".format( - self.read_response_error(response))) + restutil.read_response_error(response))) else: logger.verbose("HostGAPlugin: Put BlockBlob status succeeded") @@ -219,10 +183,10 @@ class HostPluginProtocol(object): status_blob.get_page_blob_create_headers(status_size)), headers=self._build_status_headers()) - if response.status != httpclient.OK: + if restutil.request_failed(response): raise HttpError( "HostGAPlugin: Failed PageBlob clean-up: {0}".format( - self.read_response_error(response))) + restutil.read_response_error(response))) else: logger.verbose("HostGAPlugin: PageBlob clean-up succeeded") @@ -249,11 +213,11 @@ class HostPluginProtocol(object): buf), headers=self._build_status_headers()) - if response.status != httpclient.OK: + if restutil.request_failed(response): raise HttpError( "HostGAPlugin Error: Put PageBlob bytes [{0},{1}]: " \ "{2}".format( - start, end, self.read_response_error(response))) + start, end, restutil.read_response_error(response))) # Advance to the next page (if any) start = end @@ -287,26 +251,3 @@ class HostPluginProtocol(object): if PY_VERSION_MAJOR > 2: return s.decode('utf-8') return s - - @staticmethod - def read_response_error(response): - result = '' - if response is not None: - try: - body = remove_bom(response.read()) - result = "[{0}: {1}] {2}".format(response.status, - response.reason, - body) - - # this result string is passed upstream to several methods - # which do a raise HttpError() or a format() of some kind; - # as a result it cannot have any unicode characters - if PY_VERSION_MAJOR < 3: - result = ustr(result, encoding='ascii', errors='ignore') - else: - result = result\ - .encode(encoding='ascii', errors='ignore')\ - .decode(encoding='ascii', errors='ignore') - except Exception: - logger.warn(traceback.format_exc()) - return result diff --git 
a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py index b0b6f67..4de7ecf 100644 --- a/azurelinuxagent/common/protocol/metadata.py +++ b/azurelinuxagent/common/protocol/metadata.py @@ -88,7 +88,7 @@ class MetadataProtocol(Protocol): except HttpError as e: raise ProtocolError(ustr(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) data = resp.read() @@ -103,7 +103,7 @@ class MetadataProtocol(Protocol): resp = restutil.http_put(url, json.dumps(data), headers=headers) except HttpError as e: raise ProtocolError(ustr(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) def _post_data(self, url, data, headers=None): diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py index a42db37..1ec3e21 100644 --- a/azurelinuxagent/common/protocol/restapi.py +++ b/azurelinuxagent/common/protocol/restapi.py @@ -317,8 +317,8 @@ class Protocol(DataContract): def download_ext_handler_pkg(self, uri, headers=None): try: - resp = restutil.http_get(uri, chk_proxy=True, headers=headers) - if resp.status == restutil.httpclient.OK: + resp = restutil.http_get(uri, use_proxy=True, headers=headers) + if restutil.request_succeeded(resp): return resp.read() except Exception as e: logger.warn("Failed to download from: {0}".format(uri), e) diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py index bb3500a..3071d7a 100644 --- a/azurelinuxagent/common/protocol/util.py +++ b/azurelinuxagent/common/protocol/util.py @@ -16,11 +16,14 @@ # # Requires Python 2.4+ and Openssl 1.0+ # + +import errno import os import re import shutil import time import threading + import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \ @@ -231,6 +234,9 @@ class ProtocolUtil(object): try: os.remove(protocol_file_path) except IOError as e: + # Ignore file-not-found errors (since the file is being removed) + if e.errno == errno.ENOENT: + return logger.error("Failed to clear protocol endpoint: {0}", e) def get_protocol(self): diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py index d731e11..4f3b7e0 100644 --- a/azurelinuxagent/common/protocol/wire.py +++ b/azurelinuxagent/common/protocol/wire.py @@ -26,7 +26,8 @@ import azurelinuxagent.common.conf as conf import azurelinuxagent.common.utils.fileutil as fileutil import azurelinuxagent.common.utils.textutil as textutil -from azurelinuxagent.common.exception import ProtocolNotFoundError +from azurelinuxagent.common.exception import ProtocolNotFoundError, \ + ResourceGoneError from azurelinuxagent.common.future import httpclient, bytebuffer from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol from azurelinuxagent.common.protocol.restapi import * @@ -96,7 +97,10 @@ class WireProtocol(Protocol): cryptutil = CryptUtil(conf.get_openssl_cmd()) cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) - self.client.update_goal_state(forced=True) + self.update_goal_state(forced=True) + + def update_goal_state(self, forced=False, max_retry=3): + self.client.update_goal_state(forced=forced, max_retry=max_retry) def get_vminfo(self): goal_state = self.client.get_goal_state() @@ -117,7 +121,7 @@ class WireProtocol(Protocol): def 
get_vmagent_manifests(self): # Update goal state to get latest extensions config - self.client.update_goal_state() + self.update_goal_state() goal_state = self.client.get_goal_state() ext_conf = self.client.get_ext_conf() return ext_conf.vmagent_manifests, goal_state.incarnation @@ -130,7 +134,7 @@ class WireProtocol(Protocol): def get_ext_handlers(self): logger.verbose("Get extension handler config") # Update goal state to get latest extensions config - self.client.update_goal_state() + self.update_goal_state() goal_state = self.client.get_goal_state() ext_conf = self.client.get_ext_conf() # In wire protocol, incarnation is equivalent to ETag @@ -533,29 +537,27 @@ class WireClient(object): self.req_count = 0 def call_wireserver(self, http_req, *args, **kwargs): - """ - Call wire server; handle throttling (403), resource gone (410) and - service unavailable (503). - """ self.prevent_throttling() - for retry in range(0, 3): + + try: + # Never use the HTTP proxy for wireserver + kwargs['use_proxy'] = False resp = http_req(*args, **kwargs) - if resp.status == httpclient.FORBIDDEN: - logger.warn("Sending too many requests to wire server. ") - logger.info("Sleeping {0}s to avoid throttling.", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - elif resp.status == httpclient.SERVICE_UNAVAILABLE: - logger.warn("Service temporarily unavailable, sleeping {0}s " - "before retrying.", LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - elif resp.status == httpclient.GONE: - msg = args[0] if len(args) > 0 else "" - raise WireProtocolResourceGone(msg) - else: - return resp - raise ProtocolError(("Calling wire server failed: " - "{0}").format(resp.status)) + except Exception as e: + raise ProtocolError("[Wireserver Exception] {0}".format( + ustr(e))) + + if resp is not None and resp.status == httpclient.GONE: + msg = args[0] if len(args) > 0 else "" + raise WireProtocolResourceGone(msg) + + elif restutil.request_failed(resp): + msg = "[Wireserver Failed] URI {0} ".format(args[0]) + if resp is not None: + msg += " [HTTP Failed] Status Code {0}".format(resp.status) + raise ProtocolError(msg) + + return resp def decode_config(self, data): if data is None: @@ -565,16 +567,9 @@ class WireClient(object): return xml_text def fetch_config(self, uri, headers): - try: - resp = self.call_wireserver(restutil.http_get, - uri, - headers=headers) - except HttpError as e: - raise ProtocolError(ustr(e)) - - if resp.status != httpclient.OK: - raise ProtocolError("{0} - {1}".format(resp.status, uri)) - + resp = self.call_wireserver(restutil.http_get, + uri, + headers=headers) return self.decode_config(resp.read()) def fetch_cache(self, local_file): @@ -589,29 +584,17 @@ class WireClient(object): try: fileutil.write_file(local_file, data) except IOError as e: + fileutil.clean_ioerror(e, + paths=[local_file]) raise ProtocolError("Failed to write cache: {0}".format(e)) @staticmethod def call_storage_service(http_req, *args, **kwargs): - """ - Call storage service, handle SERVICE_UNAVAILABLE(503) - """ - # Default to use the configured HTTP proxy - if not 'chk_proxy' in kwargs or kwargs['chk_proxy'] is None: - kwargs['chk_proxy'] = True + if not 'use_proxy' in kwargs or kwargs['use_proxy'] is None: + kwargs['use_proxy'] = True - for retry in range(0, 3): - resp = http_req(*args, **kwargs) - if resp.status == httpclient.SERVICE_UNAVAILABLE: - logger.warn("Storage service is temporarily unavailable. ") - logger.info("Will retry in {0} seconds. 
", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - else: - return resp - raise ProtocolError(("Calling storage endpoint failed: " - "{0}").format(resp.status)) + return http_req(*args, **kwargs) def fetch_manifest(self, version_uris): logger.verbose("Fetch manifest") @@ -619,47 +602,61 @@ class WireClient(object): response = None if not HostPluginProtocol.is_default_channel(): response = self.fetch(version.uri) + if not response: if HostPluginProtocol.is_default_channel(): logger.verbose("Using host plugin as default channel") else: - logger.verbose("Manifest could not be downloaded, falling back to host plugin") - host = self.get_host_plugin() - uri, headers = host.get_artifact_request(version.uri) - response = self.fetch(uri, headers, chk_proxy=False) - if not response: - host = self.get_host_plugin(force_update=True) - logger.info("Retry fetch in {0} seconds", - SHORT_WAITING_INTERVAL) - time.sleep(SHORT_WAITING_INTERVAL) - else: - host.manifest_uri = version.uri - logger.verbose("Manifest downloaded successfully from host plugin") - if not HostPluginProtocol.is_default_channel(): - logger.info("Setting host plugin as default channel") - HostPluginProtocol.set_default_channel(True) + logger.verbose("Failed to download manifest, " + "switching to host plugin") + + try: + host = self.get_host_plugin() + uri, headers = host.get_artifact_request(version.uri) + response = self.fetch(uri, headers, use_proxy=False) + + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + except ResourceGoneError: + HostPluginProtocol.set_default_channel(True) + raise + + host.manifest_uri = version.uri + logger.verbose("Manifest downloaded successfully from host plugin") + if not HostPluginProtocol.is_default_channel(): + logger.info("Setting host plugin as default channel") + HostPluginProtocol.set_default_channel(True) + if response: return response + raise ProtocolError("Failed to fetch manifest from all sources") - def fetch(self, uri, headers=None, chk_proxy=None): + def fetch(self, uri, headers=None, use_proxy=None): logger.verbose("Fetch [{0}] with headers [{1}]", uri, headers) - return_value = None try: resp = self.call_storage_service( - restutil.http_get, - uri, - headers, - chk_proxy=chk_proxy) - if resp.status == httpclient.OK: - return_value = self.decode_config(resp.read()) - else: - logger.warn("Could not fetch {0} [{1}]", - uri, - HostPluginProtocol.read_response_error(resp)) + restutil.http_get, + uri, + headers=headers, + use_proxy=use_proxy) + + if restutil.request_failed(resp): + msg = "[Storage Failed] URI {0} ".format(uri) + if resp is not None: + msg += restutil.read_response_error(resp) + logger.warn(msg) + raise ProtocolError(msg) + + return self.decode_config(resp.read()) + except (HttpError, ProtocolError) as e: logger.verbose("Fetch failed from [{0}]: {1}", uri, e) - return return_value + + if isinstance(e, ResourceGoneError): + raise + + return None def update_hosting_env(self, goal_state): if goal_state.hosting_env_uri is None: @@ -734,6 +731,12 @@ class WireClient(object): self.host_plugin.container_id = goal_state.container_id self.host_plugin.role_config_name = goal_state.role_config_name return + + except ProtocolError: + if retry < max_retry-1: + continue + raise + except WireProtocolResourceGone: logger.info("Incarnation is out of date. 
Update goalstate.") xml_text = self.fetch_config(uri, self.get_header()) @@ -791,20 +794,45 @@ class WireClient(object): return self.ext_conf def get_ext_manifest(self, ext_handler, goal_state): - local_file = MANIFEST_FILE_NAME.format(ext_handler.name, - goal_state.incarnation) - local_file = os.path.join(conf.get_lib_dir(), local_file) - xml_text = self.fetch_manifest(ext_handler.versionUris) - self.save_cache(local_file, xml_text) - return ExtensionManifest(xml_text) + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) + goal_state = self.get_goal_state() + + local_file = MANIFEST_FILE_NAME.format( + ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) + return ExtensionManifest(xml_text) + + except ResourceGoneError: + continue + + raise ProtocolError("Failed to retrieve extension manifest") def get_gafamily_manifest(self, vmagent_manifest, goal_state): - local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family, - goal_state.incarnation) - local_file = os.path.join(conf.get_lib_dir(), local_file) - xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris) - fileutil.write_file(local_file, xml_text) - return ExtensionManifest(xml_text) + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) + goal_state = self.get_goal_state() + + local_file = MANIFEST_FILE_NAME.format( + vmagent_manifest.family, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest( + vmagent_manifest.versionsManifestUris) + fileutil.write_file(local_file, xml_text) + return ExtensionManifest(xml_text) + + except ResourceGoneError: + continue + + raise ProtocolError("Failed to retrieve GAFamily manifest") def check_wire_protocol_version(self): uri = VERSION_INFO_URI.format(self.endpoint) @@ -823,39 +851,55 @@ class WireClient(object): raise ProtocolNotFoundError(error) def upload_status_blob(self): - ext_conf = self.get_ext_conf() + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) - blob_uri = ext_conf.status_upload_blob - blob_type = ext_conf.status_upload_blob_type + ext_conf = self.get_ext_conf() - if blob_uri is not None: + blob_uri = ext_conf.status_upload_blob + blob_type = ext_conf.status_upload_blob_type - if not blob_type in ["BlockBlob", "PageBlob"]: - blob_type = "BlockBlob" - logger.verbose("Status Blob type is unspecified " - "-- assuming it is a BlockBlob") + if blob_uri is not None: + + if not blob_type in ["BlockBlob", "PageBlob"]: + blob_type = "BlockBlob" + logger.verbose("Status Blob type is unspecified " + "-- assuming it is a BlockBlob") + + try: + self.status_blob.prepare(blob_type) + except Exception as e: + self.report_status_event( + "Exception creating status blob: {0}", ustr(e)) + return + + if not HostPluginProtocol.is_default_channel(): + try: + if self.status_blob.upload(blob_uri): + return + except HttpError as e: + pass + + host = self.get_host_plugin() + host.put_vm_status(self.status_blob, + ext_conf.status_upload_blob, + ext_conf.status_upload_blob_type) + HostPluginProtocol.set_default_channel(True) + return - try: - self.status_blob.prepare(blob_type) except Exception as e: + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + if 
isinstance(e, ResourceGoneError): + HostPluginProtocol.set_default_channel(True) + continue + self.report_status_event( - "Exception creating status blob: {0}", - e) + "Exception uploading status blob: {0}", ustr(e)) return - uploaded = False - if not HostPluginProtocol.is_default_channel(): - try: - uploaded = self.status_blob.upload(blob_uri) - except HttpError as e: - pass - - if not uploaded: - host = self.get_host_plugin() - host.put_vm_status(self.status_blob, - ext_conf.status_upload_blob, - ext_conf.status_upload_blob_type) - def report_role_prop(self, thumbprint): goal_state = self.get_goal_state() role_prop = _build_role_properties(goal_state.container_id, @@ -896,11 +940,12 @@ class WireClient(object): health_report_uri, health_report, headers=headers, - max_retry=30) + max_retry=30, + retry_delay=15) except HttpError as e: raise ProtocolError((u"Failed to send provision status: " u"{0}").format(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): raise ProtocolError((u"Failed to send provision status: " u",{0}: {1}").format(resp.status, resp.read())) @@ -919,7 +964,7 @@ class WireClient(object): except HttpError as e: raise ProtocolError("Failed to send events:{0}".format(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): logger.verbose(resp.read()) raise ProtocolError( "Failed to send events:{0}".format(resp.status)) @@ -979,12 +1024,8 @@ class WireClient(object): "x-ms-guest-agent-public-x509-cert": cert } - def get_host_plugin(self, force_update=False): - if self.host_plugin is None or force_update: - if force_update: - logger.warn("Forcing update of goal state") - self.goal_state = None - self.update_goal_state(forced=True) + def get_host_plugin(self): + if self.host_plugin is None: goal_state = self.get_goal_state() self.host_plugin = HostPluginProtocol(self.endpoint, goal_state.container_id, @@ -997,23 +1038,47 @@ class WireClient(object): def get_artifacts_profile(self): artifacts_profile = None - if self.has_artifacts_profile_blob(): - blob = self.ext_conf.artifacts_profile_blob - logger.verbose("Getting the artifacts profile") - profile = self.fetch(blob) + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) - if profile is None: - logger.warn("Download failed, falling back to host plugin") - host = self.get_host_plugin() - uri, headers = host.get_artifact_request(blob) - profile = self.decode_config(self.fetch(uri, headers, chk_proxy=False)) + if self.has_artifacts_profile_blob(): + blob = self.ext_conf.artifacts_profile_blob - if not textutil.is_str_none_or_whitespace(profile): - logger.verbose("Artifacts profile downloaded successfully") - artifacts_profile = InVMArtifactsProfile(profile) + profile = None + if not HostPluginProtocol.is_default_channel(): + logger.verbose("Retrieving the artifacts profile") + profile = self.fetch(blob) - return artifacts_profile + if profile is None: + if HostPluginProtocol.is_default_channel(): + logger.verbose("Using host plugin as default channel") + else: + logger.verbose("Failed to download artifacts profile, " + "switching to host plugin") + host = self.get_host_plugin() + uri, headers = host.get_artifact_request(blob) + config = self.fetch(uri, headers, use_proxy=False) + profile = self.decode_config(config) + + if not textutil.is_str_none_or_whitespace(profile): + logger.verbose("Artifacts profile downloaded") + artifacts_profile = InVMArtifactsProfile(profile) + + return artifacts_profile + + except ResourceGoneError: + 
HostPluginProtocol.set_default_channel(True) + continue + + except Exception as e: + logger.warn( + "Exception retrieving artifacts profile: {0}".format( + ustr(e))) + + return None class VersionInfo(object): def __init__(self, xml_text): diff --git a/azurelinuxagent/common/rdma.py b/azurelinuxagent/common/rdma.py index 226482d..3c01e77 100644 --- a/azurelinuxagent/common/rdma.py +++ b/azurelinuxagent/common/rdma.py @@ -202,7 +202,17 @@ class RDMADeviceHandler(object): RDMADeviceHandler.update_dat_conf(dapl_config_paths, self.ipv4_addr) skip_rdma_device = False - retcode,out = shellutil.run_get_output("modinfo hv_network_direct") + module_name = "hv_network_direct" + retcode,out = shellutil.run_get_output("modprobe -R %s" % module_name, chk_err=False) + if retcode == 0: + module_name = out.strip() + else: + logger.info("RDMA: failed to resolve module name. Use original name") + retcode,out = shellutil.run_get_output("modprobe %s" % module_name) + if retcode != 0: + logger.error("RDMA: failed to load module %s" % module_name) + return + retcode,out = shellutil.run_get_output("modinfo %s" % module_name) if retcode == 0: version = re.search("version:\s+(\d+)\.(\d+)\.(\d+)\D", out, re.IGNORECASE) if version: diff --git a/azurelinuxagent/common/utils/fileutil.py b/azurelinuxagent/common/utils/fileutil.py index bae1957..96b5b82 100644 --- a/azurelinuxagent/common/utils/fileutil.py +++ b/azurelinuxagent/common/utils/fileutil.py @@ -21,15 +21,30 @@ File operation util functions """ +import errno as errno import glob import os +import pwd import re import shutil -import pwd +import string + import azurelinuxagent.common.logger as logger -from azurelinuxagent.common.future import ustr import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.future import ustr + +KNOWN_IOERRORS = [ + errno.EIO, # I/O error + errno.ENOMEM, # Out of memory + errno.ENFILE, # File table overflow + errno.EMFILE, # Too many open files + errno.ENOSPC, # Out of space + errno.ENAMETOOLONG, # Name too long + errno.ELOOP, # Too many symbolic links encountered + errno.EREMOTEIO # Remote I/O error +] + def copy_file(from_path, to_path=None, to_dir=None): if to_path is None: to_path = os.path.join(to_dir, os.path.basename(from_path)) @@ -160,18 +175,31 @@ def chmod_tree(path, mode): for file_name in files: os.chmod(os.path.join(root, file_name), mode) -def findstr_in_file(file_path, pattern_str): +def findstr_in_file(file_path, line_str): + """ + Return True if the line is in the file; False otherwise. + (Trailing whitespace is ignore.) + """ + try: + for line in (open(file_path, 'r')).readlines(): + if line_str == line.rstrip(): + return True + except Exception as e: + pass + return False + +def findre_in_file(file_path, line_re): """ Return match object if found in file. """ try: - pattern = re.compile(pattern_str) + pattern = re.compile(line_re) for line in (open(file_path, 'r')).readlines(): match = re.search(pattern, line) if match: return match except: - raise + pass return None @@ -184,3 +212,21 @@ def get_all_files(root_path): result.extend([os.path.join(root, file) for file in files]) return result + +def clean_ioerror(e, paths=[]): + """ + Clean-up possibly bad files and directories after an IO error. + The code ignores *all* errors since disk state may be unhealthy. 
+ """ + if isinstance(e, IOError) and e.errno in KNOWN_IOERRORS: + for path in paths: + if path is None: + continue + + try: + if os.path.isdir(path): + shutil.rmtree(path, ignore_errors=True) + else: + os.remove(path) + except Exception as e: + pass diff --git a/azurelinuxagent/common/utils/restutil.py b/azurelinuxagent/common/utils/restutil.py index 49d2d68..ddd930b 100644 --- a/azurelinuxagent/common/utils/restutil.py +++ b/azurelinuxagent/common/utils/restutil.py @@ -17,20 +17,74 @@ # Requires Python 2.4+ and Openssl 1.0+ # +import os import time +import traceback import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger -from azurelinuxagent.common.exception import HttpError -from azurelinuxagent.common.future import httpclient, urlparse +import azurelinuxagent.common.utils.textutil as textutil -""" -REST api util functions -""" +from azurelinuxagent.common.exception import HttpError, ResourceGoneError +from azurelinuxagent.common.future import httpclient, urlparse, ustr +from azurelinuxagent.common.version import PY_VERSION_MAJOR -RETRY_WAITING_INTERVAL = 3 -secure_warning = True +SECURE_WARNING_EMITTED = False + +DEFAULT_RETRIES = 3 + +SHORT_DELAY_IN_SECONDS = 5 +LONG_DELAY_IN_SECONDS = 15 + +RETRY_CODES = [ + httpclient.RESET_CONTENT, + httpclient.PARTIAL_CONTENT, + httpclient.FORBIDDEN, + httpclient.INTERNAL_SERVER_ERROR, + httpclient.NOT_IMPLEMENTED, + httpclient.BAD_GATEWAY, + httpclient.SERVICE_UNAVAILABLE, + httpclient.GATEWAY_TIMEOUT, + httpclient.INSUFFICIENT_STORAGE, + 429, # Request Rate Limit Exceeded +] + +RESOURCE_GONE_CODES = [ + httpclient.BAD_REQUEST, + httpclient.GONE +] + +OK_CODES = [ + httpclient.OK, + httpclient.CREATED, + httpclient.ACCEPTED +] + +THROTTLE_CODES = [ + httpclient.FORBIDDEN, + httpclient.SERVICE_UNAVAILABLE +] + +RETRY_EXCEPTIONS = [ + httpclient.NotConnected, + httpclient.IncompleteRead, + httpclient.ImproperConnectionState, + httpclient.BadStatusLine +] + +HTTP_PROXY_ENV = "http_proxy" +HTTPS_PROXY_ENV = "https_proxy" + + +def _is_retry_status(status, retry_codes=RETRY_CODES): + return status in retry_codes + +def _is_retry_exception(e): + return len([x for x in RETRY_EXCEPTIONS if isinstance(e, x)]) > 0 + +def _is_throttle_status(status): + return status in THROTTLE_CODES def _parse_url(url): o = urlparse(url) @@ -45,46 +99,57 @@ def _parse_url(url): return o.hostname, o.port, secure, rel_uri -def get_http_proxy(): - """ - Get http_proxy and https_proxy from environment variables. - Username and password is not supported now. 
- """ +def _get_http_proxy(secure=False): + # Prefer the configuration settings over environment variables host = conf.get_httpproxy_host() - port = conf.get_httpproxy_port() + port = None + + if not host is None: + port = conf.get_httpproxy_port() + + else: + http_proxy_env = HTTPS_PROXY_ENV if secure else HTTP_PROXY_ENV + http_proxy_url = None + for v in [http_proxy_env, http_proxy_env.upper()]: + if v in os.environ: + http_proxy_url = os.environ[v] + break + + if not http_proxy_url is None: + host, port, _, _ = _parse_url(http_proxy_url) + return host, port def _http_request(method, host, rel_uri, port=None, data=None, secure=False, headers=None, proxy_host=None, proxy_port=None): - url, conn = None, None + + headers = {} if headers is None else headers + use_proxy = proxy_host is not None and proxy_port is not None + + if port is None: + port = 443 if secure else 80 + + if use_proxy: + conn_host, conn_port = proxy_host, proxy_port + scheme = "https" if secure else "http" + url = "{0}://{1}:{2}{3}".format(scheme, host, port, rel_uri) + + else: + conn_host, conn_port = host, port + url = rel_uri + if secure: - port = 443 if port is None else port - if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPSConnection(proxy_host, - proxy_port, - timeout=10) + conn = httpclient.HTTPSConnection(conn_host, + conn_port, + timeout=10) + if use_proxy: conn.set_tunnel(host, port) - # If proxy is used, full url is needed. - url = "https://{0}:{1}{2}".format(host, port, rel_uri) - else: - conn = httpclient.HTTPSConnection(host, - port, - timeout=10) - url = rel_uri + else: - port = 80 if port is None else port - if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPConnection(proxy_host, - proxy_port, - timeout=10) - # If proxy is used, full url is needed. - url = "http://{0}:{1}{2}".format(host, port, rel_uri) - else: - conn = httpclient.HTTPConnection(host, - port, - timeout=10) - url = rel_uri + conn = httpclient.HTTPConnection(conn_host, + conn_port, + timeout=10) logger.verbose("HTTP connection [{0}] [{1}] [{2}] [{3}]", method, @@ -92,49 +157,70 @@ def _http_request(method, host, rel_uri, port=None, data=None, secure=False, data, headers) - headers = {} if headers is None else headers conn.request(method=method, url=url, body=data, headers=headers) - resp = conn.getresponse() - return resp + return conn.getresponse() -def http_request(method, url, data, headers=None, max_retry=3, - chk_proxy=False): - """ - Sending http request to server - On error, sleep 10 and retry max_retry times. - """ +def http_request(method, + url, data, headers=None, + use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + + global SECURE_WARNING_EMITTED + host, port, secure, rel_uri = _parse_url(url) - global secure_warning - # Check proxy + # Use the HTTP(S) proxy proxy_host, proxy_port = (None, None) - if chk_proxy: - proxy_host, proxy_port = get_http_proxy() + if use_proxy: + proxy_host, proxy_port = _get_http_proxy(secure=secure) - # If httplib module is not built with ssl support. 
Fallback to http + if proxy_host or proxy_port: + logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port) + + # If httplib module is not built with ssl support, + # fallback to HTTP if allowed if secure and not hasattr(httpclient, "HTTPSConnection"): + if not conf.get_allow_http(): + raise HttpError("HTTPS is unavailable and required") + secure = False - if secure_warning: - logger.warn("httplib is not built with ssl support") - secure_warning = False + if not SECURE_WARNING_EMITTED: + logger.warn("Python does not include SSL support") + SECURE_WARNING_EMITTED = True + + # If httplib module doesn't support HTTPS tunnelling, + # fallback to HTTP if allowed + if secure and \ + proxy_host is not None and \ + proxy_port is not None \ + and not hasattr(httpclient.HTTPSConnection, "set_tunnel"): + + if not conf.get_allow_http(): + raise HttpError("HTTPS tunnelling is unavailable and required") - # If httplib module doesn't support https tunnelling. Fallback to http - if secure and proxy_host is not None and proxy_port is not None \ - and not hasattr(httpclient.HTTPSConnection, "set_tunnel"): secure = False - if secure_warning: - logger.warn("httplib does not support https tunnelling " - "(new in python 2.7)") - secure_warning = False - - if proxy_host or proxy_port: - logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port) - - retry_msg = '' - log_msg = "HTTP {0}".format(method) - for retry in range(0, max_retry): - retry_interval = RETRY_WAITING_INTERVAL + if not SECURE_WARNING_EMITTED: + logger.warn("Python does not support HTTPS tunnelling") + SECURE_WARNING_EMITTED = True + + msg = '' + attempt = 0 + delay = retry_delay + + while attempt < max_retry: + if attempt > 0: + logger.info("[HTTP Retry] Attempt {0} of {1}: {2}", + attempt+1, + max_retry, + msg) + time.sleep(delay) + + attempt += 1 + delay = retry_delay + try: resp = _http_request(method, host, @@ -145,55 +231,123 @@ def http_request(method, url, data, headers=None, max_retry=3, headers=headers, proxy_host=proxy_host, proxy_port=proxy_port) - logger.verbose("HTTP response status: [{0}]", resp.status) + logger.verbose("[HTTP Response] Status Code {0}", resp.status) + + if request_failed(resp): + if _is_retry_status(resp.status, retry_codes=retry_codes): + msg = '[HTTP Retry] HTTP {0} Status Code {1}'.format( + method, resp.status) + if _is_throttle_status(resp.status): + delay = LONG_DELAY_IN_SECONDS + logger.info("[HTTP Delay] Delay {0} seconds for " \ + "Status Code {1}".format( + delay, resp.status)) + continue + + if resp.status in RESOURCE_GONE_CODES: + raise ResourceGoneError() + return resp + except httpclient.HTTPException as e: - retry_msg = 'HTTP exception: {0} {1}'.format(log_msg, e) - retry_interval = 5 + msg = '[HTTP Failed] HTTP {0} HttpException {1}'.format(method, e) + if _is_retry_exception(e): + continue + break + except IOError as e: - retry_msg = 'IO error: {0} {1}'.format(log_msg, e) - # error 101: network unreachable; when the adapter resets we may - # see this transient error for a short time, retry once. 
- if e.errno == 101: - retry_interval = RETRY_WAITING_INTERVAL - max_retry = 1 + msg = '[HTTP Failed] HTTP {0} IOError {1}'.format(method, e) + continue + + raise HttpError(msg) + + +def http_get(url, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("GET", + url, None, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_head(url, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("HEAD", + url, None, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_post(url, data, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("POST", + url, data, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_put(url, data, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("PUT", + url, data, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_delete(url, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("DELETE", + url, None, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + +def request_failed(resp, ok_codes=OK_CODES): + return not request_succeeded(resp, ok_codes=ok_codes) + +def request_succeeded(resp, ok_codes=OK_CODES): + return resp is not None and resp.status in ok_codes + +def read_response_error(resp): + result = '' + if resp is not None: + try: + result = "[HTTP Failed] [{0}: {1}] {2}".format( + resp.status, + resp.reason, + resp.read()) + + # this result string is passed upstream to several methods + # which do a raise HttpError() or a format() of some kind; + # as a result it cannot have any unicode characters + if PY_VERSION_MAJOR < 3: + result = ustr(result, encoding='ascii', errors='ignore') else: - retry_interval = 0 - max_retry = 0 - - if retry < max_retry: - logger.info("Retry [{0}/{1} - {3}]", - retry+1, - max_retry, - retry_interval, - retry_msg) - time.sleep(retry_interval) - - raise HttpError("{0} failed".format(log_msg)) - - -def http_get(url, headers=None, max_retry=3, chk_proxy=False): - return http_request("GET", url, data=None, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - - -def http_head(url, headers=None, max_retry=3, chk_proxy=False): - return http_request("HEAD", url, None, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - - -def http_post(url, data, headers=None, max_retry=3, chk_proxy=False): - return http_request("POST", url, data, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - - -def http_put(url, data, headers=None, max_retry=3, chk_proxy=False): - return http_request("PUT", url, data, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - + result = result\ + .encode(encoding='ascii', errors='ignore')\ + .decode(encoding='ascii', errors='ignore') -def http_delete(url, headers=None, max_retry=3, chk_proxy=False): - return http_request("DELETE", url, None, 
headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) + result = textutil.replace_non_ascii(result) -# End REST api util functions + except Exception: + logger.warn(traceback.format_exc()) + return result diff --git a/azurelinuxagent/common/utils/textutil.py b/azurelinuxagent/common/utils/textutil.py index 2d99f6f..7e244fc 100644 --- a/azurelinuxagent/common/utils/textutil.py +++ b/azurelinuxagent/common/utils/textutil.py @@ -19,6 +19,7 @@ import base64 import crypt import random +import re import string import struct import sys @@ -259,6 +260,17 @@ def set_ini_config(config, name, val): config.insert(length - 1, text) +def replace_non_ascii(incoming, replace_char=''): + outgoing = '' + if incoming is not None: + for c in incoming: + if str_to_ord(c) > 128: + outgoing += replace_char + else: + outgoing += c + return outgoing + + def remove_bom(c): ''' bom is comprised of a sequence of three chars,0xef, 0xbb, 0xbf, in case of utf-8. @@ -311,6 +323,16 @@ def safe_shlex_split(s): return shlex.split(s.encode('utf-8')) return shlex.split(s) +def swap_hexstring(s, width=2): + r = len(s) % width + if r != 0: + s = ('0' * (width - (len(s) % width))) + s + + return ''.join(reversed( + re.findall( + r'[a-f0-9]{{{0}}}'.format(width), + s, + re.IGNORECASE))) def parse_json(json_str): """ diff --git a/azurelinuxagent/common/version.py b/azurelinuxagent/common/version.py index d1d4c62..f27db38 100644 --- a/azurelinuxagent/common/version.py +++ b/azurelinuxagent/common/version.py @@ -113,7 +113,7 @@ def get_distro(): AGENT_NAME = "WALinuxAgent" AGENT_LONG_NAME = "Azure Linux Agent" -AGENT_VERSION = '2.2.14' +AGENT_VERSION = '2.2.16' AGENT_LONG_VERSION = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION) AGENT_DESCRIPTION = """ The Azure Linux Agent supports the provisioning and running of Linux @@ -129,9 +129,20 @@ AGENT_NAME_PATTERN = re.compile(AGENT_PATTERN) AGENT_PKG_PATTERN = re.compile(AGENT_PATTERN+"\.zip") AGENT_DIR_PATTERN = re.compile(".*/{0}".format(AGENT_PATTERN)) -EXT_HANDLER_PATTERN = b".*/WALinuxAgent-(\w.\w.\w[.\w]*)-.*-run-exthandlers" +EXT_HANDLER_PATTERN = b".*/WALinuxAgent-(\d+.\d+.\d+[.\d+]*).*-run-exthandlers" EXT_HANDLER_REGEX = re.compile(EXT_HANDLER_PATTERN) +__distro__ = get_distro() +DISTRO_NAME = __distro__[0] +DISTRO_VERSION = __distro__[1] +DISTRO_CODE_NAME = __distro__[2] +DISTRO_FULL_NAME = __distro__[3] + +PY_VERSION = sys.version_info +PY_VERSION_MAJOR = sys.version_info[0] +PY_VERSION_MINOR = sys.version_info[1] +PY_VERSION_MICRO = sys.version_info[2] + # Set the CURRENT_AGENT and CURRENT_VERSION to match the agent directory name # - This ensures the agent will "see itself" using the same name and version @@ -173,6 +184,8 @@ def set_goal_state_agent(): match = EXT_HANDLER_REGEX.match(pname) if match: agent = match.group(1) + if PY_VERSION_MAJOR > 2: + agent = agent.decode('UTF-8') break except IOError: continue @@ -188,18 +201,6 @@ def is_current_agent_installed(): return CURRENT_AGENT == AGENT_LONG_VERSION -__distro__ = get_distro() -DISTRO_NAME = __distro__[0] -DISTRO_VERSION = __distro__[1] -DISTRO_CODE_NAME = __distro__[2] -DISTRO_FULL_NAME = __distro__[3] - -PY_VERSION = sys.version_info -PY_VERSION_MAJOR = sys.version_info[0] -PY_VERSION_MINOR = sys.version_info[1] -PY_VERSION_MICRO = sys.version_info[2] - - def is_snappy(): """ Add this workaround for detecting Snappy Ubuntu Core temporarily, diff --git a/azurelinuxagent/ga/env.py b/azurelinuxagent/ga/env.py index c81eed7..0456cb0 100644 --- a/azurelinuxagent/ga/env.py +++ b/azurelinuxagent/ga/env.py @@ -26,7 
+26,10 @@ import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger from azurelinuxagent.common.dhcp import get_dhcp_handler +from azurelinuxagent.common.event import add_periodic, WALAEventOperation from azurelinuxagent.common.osutil import get_osutil +from azurelinuxagent.common.protocol import get_protocol_util +from azurelinuxagent.common.version import AGENT_NAME, CURRENT_VERSION def get_env_handler(): return EnvHandler() @@ -42,6 +45,7 @@ class EnvHandler(object): def __init__(self): self.osutil = get_osutil() self.dhcp_handler = get_dhcp_handler() + self.protocol_util = get_protocol_util() self.stopped = True self.hostname = None self.dhcpid = None @@ -64,17 +68,35 @@ class EnvHandler(object): def monitor(self): """ + Monitor firewall rules Monitor dhcp client pid and hostname. If dhcp clinet process re-start has occurred, reset routes. """ + protocol = self.protocol_util.get_protocol() while not self.stopped: self.osutil.remove_rules_files() + + if conf.enable_firewall(): + success = self.osutil.enable_firewall( + dst_ip=protocol.endpoint, + uid=os.getuid()) + add_periodic( + logger.EVERY_HOUR, + AGENT_NAME, + version=CURRENT_VERSION, + op=WALAEventOperation.Firewall, + is_success=success, + log_event=True) + timeout = conf.get_root_device_scsi_timeout() if timeout is not None: self.osutil.set_scsi_disks_timeout(timeout) + if conf.get_monitor_hostname(): self.handle_hostname_update() + self.handle_dhclient_restart() + time.sleep(5) def handle_hostname_update(self): diff --git a/azurelinuxagent/ga/exthandlers.py b/azurelinuxagent/ga/exthandlers.py index 4324d92..f0a3b09 100644 --- a/azurelinuxagent/ga/exthandlers.py +++ b/azurelinuxagent/ga/exthandlers.py @@ -411,6 +411,7 @@ class ExtHandlerInstance(object): self.protocol = protocol self.operation = None self.pkg = None + self.pkg_file = None self.is_upgrade = False prefix = "[{0}]".format(self.get_full_name()) @@ -612,12 +613,14 @@ class ExtHandlerInstance(object): raise ExtensionError("Failed to download extension") self.logger.verbose("Unpack extension package") - pkg_file = os.path.join(conf.get_lib_dir(), + self.pkg_file = os.path.join(conf.get_lib_dir(), os.path.basename(uri.uri) + ".zip") try: - fileutil.write_file(pkg_file, bytearray(package), asbin=True) - zipfile.ZipFile(pkg_file).extractall(self.get_base_dir()) + fileutil.write_file(self.pkg_file, bytearray(package), asbin=True) + zipfile.ZipFile(self.pkg_file).extractall(self.get_base_dir()) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to write and unzip plugin", e) #Add user execute permission to all files under the base dir @@ -638,6 +641,8 @@ class ExtHandlerInstance(object): man = fileutil.read_file(man_file, remove_bom=True) fileutil.write_file(self.get_manifest_file(), man) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to save HandlerManifest.json", e) #Create status and config dir @@ -647,6 +652,8 @@ class ExtHandlerInstance(object): conf_dir = self.get_conf_dir() fileutil.mkdir(conf_dir, mode=0o700) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to create status or config dir", e) #Save HandlerEnvironment.json @@ -846,6 +853,8 @@ class ExtHandlerInstance(object): try: fileutil.write_file(settings_file, settings) except IOError as e: + fileutil.clean_ioerror(e, + paths=[settings_file]) raise 
ExtensionError(u"Failed to update settings file", e) def update_settings(self): @@ -886,6 +895,8 @@ class ExtHandlerInstance(object): try: fileutil.write_file(self.get_env_file(), json.dumps(env)) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to save handler environment", e) def set_handler_state(self, handler_state): @@ -897,6 +908,8 @@ class ExtHandlerInstance(object): state_file = os.path.join(state_dir, "HandlerState") fileutil.write_file(state_file, handler_state) except IOError as e: + fileutil.clean_ioerror(e, + paths=[state_file]) self.logger.error("Failed to set state: {0}", e) def get_handler_state(self): @@ -925,6 +938,8 @@ class ExtHandlerInstance(object): try: fileutil.write_file(status_file, json.dumps(get_properties(handler_status))) except (IOError, ValueError, ProtocolError) as e: + fileutil.clean_ioerror(e, + paths=[status_file]) self.logger.error("Failed to save handler status: {0}", e) def get_handler_status(self): diff --git a/azurelinuxagent/ga/monitor.py b/azurelinuxagent/ga/monitor.py index dcfd6a4..307a514 100644 --- a/azurelinuxagent/ga/monitor.py +++ b/azurelinuxagent/ga/monitor.py @@ -36,8 +36,8 @@ from azurelinuxagent.common.protocol.restapi import TelemetryEventParam, \ set_properties from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, getattrib from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \ - DISTRO_CODE_NAME, AGENT_LONG_VERSION, \ - CURRENT_AGENT, CURRENT_VERSION + DISTRO_CODE_NAME, AGENT_LONG_VERSION, \ + AGENT_NAME, CURRENT_AGENT, CURRENT_VERSION def parse_event(data_str): @@ -184,9 +184,9 @@ class MonitorHandler(object): if datetime.datetime.utcnow() >= (last_heartbeat + period): last_heartbeat = datetime.datetime.utcnow() add_event( - op=WALAEventOperation.HeartBeat, - name=CURRENT_AGENT, + name=AGENT_NAME, version=CURRENT_VERSION, + op=WALAEventOperation.HeartBeat, is_success=True) try: self.collect_and_send_events() diff --git a/azurelinuxagent/ga/update.py b/azurelinuxagent/ga/update.py index 10eac82..b7ee96a 100644 --- a/azurelinuxagent/ga/update.py +++ b/azurelinuxagent/ga/update.py @@ -41,7 +41,9 @@ import azurelinuxagent.common.utils.textutil as textutil from azurelinuxagent.common.event import add_event, add_periodic, \ elapsed_milliseconds, \ WALAEventOperation -from azurelinuxagent.common.exception import UpdateError, ProtocolError +from azurelinuxagent.common.exception import ProtocolError, \ + ResourceGoneError, \ + UpdateError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.osutil import get_osutil from azurelinuxagent.common.protocol import get_protocol_util @@ -231,7 +233,8 @@ class UpdateHandler(object): This is the main loop which watches for agent and extension updates. 
""" - logger.info(u"Agent {0} is running as the goal state agent", CURRENT_AGENT) + logger.info(u"Agent {0} is running as the goal state agent", + CURRENT_AGENT) # Launch monitoring threads from azurelinuxagent.ga.monitor import get_monitor_handler @@ -245,14 +248,13 @@ class UpdateHandler(object): migrate_handler_state() try: - send_event_time = datetime.utcnow() - self._ensure_no_orphans() self._emit_restart_event() while self.running: if self._is_orphaned: - logger.info("Goal state agent {0} was orphaned -- exiting", CURRENT_AGENT) + logger.info("Goal state agent {0} was orphaned -- exiting", + CURRENT_AGENT) break if self._upgrade_available(): @@ -277,23 +279,19 @@ class UpdateHandler(object): duration=elapsed_milliseconds(utc_start), log_event=True) - test_agent = self.get_test_agent() - if test_agent is not None and test_agent.in_slice: - test_agent.enable() - logger.info(u"Enabled Agent {0} as test agent", test_agent.name) - break - time.sleep(GOAL_STATE_INTERVAL) except Exception as e: - logger.warn(u"Agent {0} failed with exception: {1}", CURRENT_AGENT, ustr(e)) + logger.warn(u"Agent {0} failed with exception: {1}", + CURRENT_AGENT, + ustr(e)) logger.warn(traceback.format_exc()) sys.exit(1) + # additional return here because sys.exit is mocked in unit tests return self._shutdown() sys.exit(0) - return def forward_signal(self, signum, frame): # Note: @@ -339,14 +337,6 @@ class UpdateHandler(object): return available_agents[0] if len(available_agents) >= 1 else None - def get_test_agent(self): - agent = None - agents = [agent for agent in self._load_agents() if agent.is_test] - if len(agents) > 0: - agents.sort(key=lambda agent: agent.version, reverse=True) - agent = agents[0] - return agent - def _emit_restart_event(self): if not self._is_clean_start: msg = u"{0} did not terminate cleanly".format(CURRENT_AGENT) @@ -361,81 +351,13 @@ class UpdateHandler(object): self._set_sentinal() return - def _upgrade_available(self, base_version=CURRENT_VERSION): - # Ignore new agents if updating is disabled - if not conf.get_autoupdate_enabled(): - return False - - now = time.time() - if self.last_attempt_time is not None: - next_attempt_time = self.last_attempt_time + conf.get_autoupdate_frequency() - else: - next_attempt_time = now - if next_attempt_time > now: - return False - - family = conf.get_autoupdate_gafamily() - logger.verbose("Checking for agent family {0} updates", family) - - self.last_attempt_time = now - try: - protocol = self.protocol_util.get_protocol() - manifest_list, etag = protocol.get_vmagent_manifests() - except Exception as e: - msg = u"Exception retrieving agent manifests: {0}".format(ustr(e)) - logger.warn(msg) - add_event( - AGENT_NAME, - op=WALAEventOperation.Download, - version=CURRENT_VERSION, - is_success=False, - message=msg) - return False - - manifests = [m for m in manifest_list.vmAgentManifests \ - if m.family == family and len(m.versionsManifestUris) > 0] - if len(manifests) == 0: - logger.verbose(u"Incarnation {0} has no agent family {1} updates", etag, family) - return False - - try: - pkg_list = protocol.get_vmagent_pkgs(manifests[0]) - except ProtocolError as e: - msg = u"Incarnation {0} failed to get {1} package list: " \ - u"{2}".format( - etag, - family, - ustr(e)) - logger.warn(msg) - add_event( - AGENT_NAME, - op=WALAEventOperation.Download, - version=CURRENT_VERSION, - is_success=False, - message=msg) - return False - - # Set the agents to those available for download at least as current - # as the existing agent and remove from disk any agent no 
longer - # reported to the VM. - # Note: - # The code leaves on disk available, but blacklisted, agents so as to - # preserve the state. Otherwise, those agents could be again - # downloaded and inappropriately retried. - host = self._get_host_plugin(protocol=protocol) - self._set_agents([GuestAgent(pkg=pkg, host=host) for pkg in pkg_list.versions]) - self._purge_agents() - self._filter_blacklisted_agents() - - # Return True if agents more recent than the current are available - return len(self.agents) > 0 and self.agents[0].version > base_version - def _ensure_no_orphans(self, orphan_wait_interval=ORPHAN_WAIT_INTERVAL): - previous_pid_file, pid_file = self._write_pid_file() - if previous_pid_file is not None: + pid_files, ignored = self._write_pid_file() + for pid_file in pid_files: try: - pid = fileutil.read_file(previous_pid_file) + pid = fileutil.read_file(pid_file) wait_interval = orphan_wait_interval + while self.osutil.check_pid_alive(pid): wait_interval -= GOAL_STATE_INTERVAL if wait_interval <= 0: @@ -452,6 +374,8 @@ class UpdateHandler(object): pid) time.sleep(GOAL_STATE_INTERVAL) + os.remove(pid_file) + except Exception as e: logger.warn( u"Exception occurred waiting for orphan agent to terminate: {0}", @@ -508,22 +432,18 @@ class UpdateHandler(object): protocol.client \ else None - def _get_pid_files(self): + def _get_pid_parts(self): pid_file = conf.get_agent_pid_file_path() - pid_dir = os.path.dirname(pid_file) pid_name = os.path.basename(pid_file) - pid_re = re.compile("(\d+)_{0}".format(re.escape(pid_name))) - pid_files = [int(pid_re.match(f).group(1)) for f in os.listdir(pid_dir) if pid_re.match(f)] - pid_files.sort() + return pid_dir, pid_name, pid_re - pid_index = -1 if len(pid_files) <= 0 else pid_files[-1] - previous_pid_file = None \ - if pid_index < 0 \ - else os.path.join(pid_dir, "{0}_{1}".format(pid_index, pid_name)) - pid_file = os.path.join(pid_dir, "{0}_{1}".format(pid_index+1, pid_name)) - return previous_pid_file, pid_file + def _get_pid_files(self): + pid_dir, pid_name, pid_re = self._get_pid_parts() + pid_files = [os.path.join(pid_dir, f) for f in os.listdir(pid_dir) if pid_re.match(f)] + pid_files.sort(key=lambda f: int(pid_re.match(os.path.basename(f)).group(1))) + return pid_files @property def _is_clean_start(self): @@ -619,8 +539,98 @@ class UpdateHandler(object): str(e)) return + def _upgrade_available(self, base_version=CURRENT_VERSION): + # Emit an event expressing the state of AutoUpdate + # Note: + # - Duplicate events get suppressed; state transitions always emit + add_event( + AGENT_NAME, + version=CURRENT_VERSION, + op=WALAEventOperation.AutoUpdate, + is_success=conf.get_autoupdate_enabled()) + + # Ignore new agents if updating is disabled + if not conf.get_autoupdate_enabled(): + return False + + now = time.time() + if self.last_attempt_time is not None: + next_attempt_time = self.last_attempt_time + \ + conf.get_autoupdate_frequency() + else: + next_attempt_time = now + if next_attempt_time > now: + return False + + family = conf.get_autoupdate_gafamily() + logger.verbose("Checking for agent family {0} updates", family) + + self.last_attempt_time = now + protocol = self.protocol_util.get_protocol() + + for update_goal_state in [False, True]: + try: + if update_goal_state: + protocol.update_goal_state(forced=True) + + manifest_list, etag = protocol.get_vmagent_manifests() + + manifests = [m for m in manifest_list.vmAgentManifests \ + if m.family == family and \ + len(m.versionsManifestUris) > 0] + if len(manifests) == 0: + 
logger.verbose(u"Incarnation {0} has no {1} agent updates", + etag, family) + return False + + pkg_list = protocol.get_vmagent_pkgs(manifests[0]) + + # Set the agents to those available for download at least as + # current as the existing agent and remove from disk any agent + # no longer reported to the VM. + # Note: + # The code leaves on disk available, but blacklisted, agents + # so as to preserve the state. Otherwise, those agents could be + # again downloaded and inappropriately retried. + host = self._get_host_plugin(protocol=protocol) + self._set_agents([GuestAgent(pkg=pkg, host=host) \ + for pkg in pkg_list.versions]) + + self._purge_agents() + self._filter_blacklisted_agents() + + # Return True if more recent agents are available + return len(self.agents) > 0 and \ + self.agents[0].version > base_version + + except Exception as e: + if isinstance(e, ResourceGoneError): + continue + + msg = u"Exception retrieving agent manifests: {0}".format( + ustr(e)) + logger.warn(msg) + add_event( + AGENT_NAME, + op=WALAEventOperation.Download, + version=CURRENT_VERSION, + is_success=False, + message=msg) + return False + def _write_pid_file(self): - previous_pid_file, pid_file = self._get_pid_files() + pid_files = self._get_pid_files() + + pid_dir, pid_name, pid_re = self._get_pid_parts() + + previous_pid_file = None \ + if len(pid_files) <= 0 \ + else pid_files[-1] + pid_index = -1 \ + if previous_pid_file is None \ + else int(pid_re.match(os.path.basename(previous_pid_file)).group(1)) + pid_file = os.path.join(pid_dir, "{0}_{1}".format(pid_index+1, pid_name)) + try: fileutil.write_file(pid_file, ustr(os.getpid())) logger.info(u"{0} running as process {1}", CURRENT_AGENT, ustr(os.getpid())) @@ -631,7 +641,8 @@ class UpdateHandler(object): CURRENT_AGENT, pid_file, ustr(e)) - return previous_pid_file, pid_file + + return pid_files, pid_file class GuestAgent(object): @@ -652,15 +663,34 @@ class GuestAgent(object): self.version = FlexibleVersion(version) location = u"disk" if path is not None else u"package" - logger.verbose(u"Loading Agent {0} from package {1}", self.name, location) + logger.verbose(u"Loading Agent {0} from {1}", self.name, location) - self.error = None - self.supported = None + self.error = GuestAgentError(self.get_agent_error_file()) + self.error.load() + self.supported = Supported(self.get_agent_supported_file()) + self.supported.load() - self._load_error() - self._load_supported() + try: + self._ensure_downloaded() + self._ensure_loaded() + except Exception as e: + if isinstance(e, ResourceGoneError): + raise - self._ensure_downloaded() + # Note the failure, blacklist the agent if the package downloaded + # - An exception with a downloaded package indicates the package + # is corrupt (e.g., missing the HandlerManifest.json file) + self.mark_failure(is_fatal=os.path.isfile(self.get_agent_pkg_path())) + + msg = u"Agent {0} install failed with exception: {1}".format( + self.name, ustr(e)) + logger.warn(msg) + add_event( + AGENT_NAME, + version=self.version, + op=WALAEventOperation.Install, + is_success=False, + message=msg) return @property @@ -687,12 +717,7 @@ class GuestAgent(object): def clear_error(self): self.error.clear() - return - - def enable(self): - if self.error.is_sentinel: - self.error.clear() - self.error.save() + self.error.save() return @property @@ -708,12 +733,12 @@ class GuestAgent(object): return self.is_blacklisted or os.path.isfile(self.get_agent_manifest_path()) @property - def is_test(self): - return self.error.is_sentinel and 
self.supported.is_supported + def _is_optional(self): + return self.error is not None and self.error.is_sentinel and self.supported.is_supported @property - def in_slice(self): - return self.is_test and self.supported.in_slice + def _in_slice(self): + return self.supported.is_supported and self.supported.in_slice def mark_failure(self, is_fatal=False): try: @@ -727,72 +752,79 @@ class GuestAgent(object): logger.warn(u"Agent {0} failed recording error state: {1}", self.name, ustr(e)) return + def _enable(self): + # Enable optional agents if within the "slice" + # - The "slice" is a percentage of the agent to execute + # - Blacklist out-of-slice agents to prevent reconsideration + if self._is_optional: + if self._in_slice: + self.error.clear() + self.error.save() + logger.info(u"Enabled optional Agent {0}", self.name) + else: + self.mark_failure(is_fatal=True) + logger.info(u"Optional Agent {0} not in slice", self.name) + return + def _ensure_downloaded(self): - try: - logger.verbose(u"Ensuring Agent {0} is downloaded", self.name) - - if self.is_blacklisted: - logger.verbose(u"Agent {0} is blacklisted - skipping download", self.name) - return - - if self.is_downloaded: - logger.verbose(u"Agent {0} was previously downloaded - skipping download", self.name) - self._load_manifest() - return - - if self.pkg is None: - raise UpdateError(u"Agent {0} is missing package and download URIs".format( - self.name)) - - self._download() - self._unpack() - self._load_manifest() - self._load_error() - self._load_supported() - - msg = u"Agent {0} downloaded successfully".format(self.name) - logger.verbose(msg) - add_event( - AGENT_NAME, - version=self.version, - op=WALAEventOperation.Install, - is_success=True, - message=msg) + logger.verbose(u"Ensuring Agent {0} is downloaded", self.name) - except Exception as e: - # Note the failure, blacklist the agent if the package downloaded - # - An exception with a downloaded package indicates the package - # is corrupt (e.g., missing the HandlerManifest.json file) - self.mark_failure(is_fatal=os.path.isfile(self.get_agent_pkg_path())) + if self.is_downloaded: + logger.verbose(u"Agent {0} was previously downloaded - skipping download", self.name) + return - msg = u"Agent {0} download failed with exception: {1}".format(self.name, ustr(e)) - logger.warn(msg) - add_event( - AGENT_NAME, - version=self.version, - op=WALAEventOperation.Install, - is_success=False, - message=msg) + if self.pkg is None: + raise UpdateError(u"Agent {0} is missing package and download URIs".format( + self.name)) + + self._download() + self._unpack() + + msg = u"Agent {0} downloaded successfully".format(self.name) + logger.verbose(msg) + add_event( + AGENT_NAME, + version=self.version, + op=WALAEventOperation.Install, + is_success=True, + message=msg) + return + + def _ensure_loaded(self): + self._load_manifest() + self._load_error() + self._load_supported() + + self._enable() return def _download(self): for uri in self.pkg.uris: if not HostPluginProtocol.is_default_channel() and self._fetch(uri.uri): break + elif self.host is not None and self.host.ensure_initialized(): if not HostPluginProtocol.is_default_channel(): - logger.warn("Download unsuccessful, falling back to host plugin") + logger.warn("Download failed, switching to host plugin") else: logger.verbose("Using host plugin as default channel") uri, headers = self.host.get_artifact_request(uri.uri, self.host.manifest_uri) - if self._fetch(uri, headers=headers): - if not HostPluginProtocol.is_default_channel(): - logger.verbose("Setting 
host plugin as default channel") - HostPluginProtocol.set_default_channel(True) - break - else: - logger.warn("Host plugin download unsuccessful") + try: + if self._fetch(uri, headers=headers, use_proxy=False): + if not HostPluginProtocol.is_default_channel(): + logger.verbose("Setting host plugin as default channel") + HostPluginProtocol.set_default_channel(True) + break + else: + logger.warn("Host plugin download failed") + + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + except ResourceGoneError: + HostPluginProtocol.set_default_channel(True) + raise + else: logger.error("No download channels available") @@ -805,13 +837,14 @@ class GuestAgent(object): is_success=False, message=msg) raise UpdateError(msg) + return - def _fetch(self, uri, headers=None): + def _fetch(self, uri, headers=None, use_proxy=True): package = None try: - resp = restutil.http_get(uri, chk_proxy=True, headers=headers) - if resp.status == restutil.httpclient.OK: + resp = restutil.http_get(uri, use_proxy=use_proxy, headers=headers) + if restutil.request_succeeded(resp): package = resp.read() fileutil.write_file(self.get_agent_pkg_path(), bytearray(package), @@ -819,18 +852,21 @@ class GuestAgent(object): logger.verbose(u"Agent {0} downloaded from {1}", self.name, uri) else: logger.verbose("Fetch was unsuccessful [{0}]", - HostPluginProtocol.read_response_error(resp)) + restutil.read_response_error(resp)) except restutil.HttpError as http_error: + if isinstance(http_error, ResourceGoneError): + raise + logger.verbose(u"Agent {0} download from {1} failed [{2}]", self.name, uri, http_error) + return package is not None def _load_error(self): try: - if self.error is None: - self.error = GuestAgentError(self.get_agent_error_file()) + self.error = GuestAgentError(self.get_agent_error_file()) self.error.load() logger.verbose(u"Agent {0} error state: {1}", self.name, ustr(self.error)) except Exception as e: @@ -840,6 +876,7 @@ class GuestAgent(object): def _load_supported(self): try: self.supported = Supported(self.get_agent_supported_file()) + self.supported.load() except Exception as e: self.supported = Supported() @@ -892,6 +929,9 @@ class GuestAgent(object): zipfile.ZipFile(self.get_agent_pkg_path()).extractall(self.get_agent_dir()) except Exception as e: + fileutil.clean_ioerror(e, + paths=[self.get_agent_dir(), self.get_agent_pkg_path()]) + msg = u"Exception unpacking Agent {0} from {1}: {2}".format( self.name, self.get_agent_pkg_path(), @@ -918,7 +958,6 @@ class GuestAgentError(object): self.path = path self.clear() - self.load() return def mark_failure(self, is_fatal=False): @@ -982,8 +1021,7 @@ class Supported(object): if path is None: raise UpdateError(u"Supported requires a path") self.path = path - - self._load() + self.distributions = {} return @property @@ -995,15 +1033,7 @@ class Supported(object): d = self._supported_distribution return d is not None and d.in_slice - @property - def _supported_distribution(self): - for d in self.distributions: - dd = self.distributions[d] - if dd.is_supported: - return dd - return None - - def _load(self): + def load(self): self.distributions = {} try: if self.path is not None and os.path.isfile(self.path): @@ -1014,6 +1044,14 @@ class Supported(object): logger.warn("Failed JSON parse of {0}: {1}".format(self.path, e)) return + @property + def _supported_distribution(self): + for d in self.distributions: + dd = self.distributions[d] + if dd.is_supported: + return dd + return None + class SupportedDistribution(object): def 
__init__(self, s): if s is None or not isinstance(s, dict): diff --git a/azurelinuxagent/pa/provision/default.py b/azurelinuxagent/pa/provision/default.py index 959a2fe..2f7ec18 100644 --- a/azurelinuxagent/pa/provision/default.py +++ b/azurelinuxagent/pa/provision/default.py @@ -53,16 +53,16 @@ class ProvisionHandler(object): self.protocol_util = get_protocol_util() def run(self): - # If provision is not enabled, report ready and then return if not conf.get_provision_enabled(): logger.info("Provisioning is disabled, skipping.") + self.write_provisioned() + self.report_ready() return try: utc_start = datetime.utcnow() thumbprint = None - # if provisioning is already done, return if self.is_provisioned(): logger.info("Provisioning already completed, skipping.") return @@ -85,7 +85,6 @@ class ProvisionHandler(object): thumbprint = self.reg_ssh_host_key() self.osutil.restart_ssh_service() - # write out provisioned file and report Ready self.write_provisioned() self.report_event("Provision succeed", @@ -128,9 +127,17 @@ class ProvisionHandler(object): keypair_type = conf.get_ssh_host_keypair_type() if conf.get_regenerate_ssh_host_key(): fileutil.rm_files(conf.get_ssh_key_glob()) - keygen_cmd = "ssh-keygen -N '' -t {0} -f {1}" - shellutil.run(keygen_cmd.format(keypair_type, - conf.get_ssh_key_private_path())) + if conf.get_ssh_host_keypair_mode() == "auto": + ''' + The -A option generates all supported key types. + This is supported since OpenSSH 5.9 (2011). + ''' + shellutil.run("ssh-keygen -A") + else: + keygen_cmd = "ssh-keygen -N '' -t {0} -f {1}" + shellutil.run(keygen_cmd. + format(keypair_type, + conf.get_ssh_key_private_path())) return self.get_ssh_host_key_thumbprint() def get_ssh_host_key_thumbprint(self, chk_err=True): @@ -162,7 +169,7 @@ class ProvisionHandler(object): return False s = fileutil.read_file(self.provisioned_file_path()).strip() - if s != self.osutil.get_instance_id(): + if not self.osutil.is_current_instance_id(s): if len(s) > 0: logger.warn("VM is provisioned, " "but the VM unique identifier has changed -- " @@ -173,6 +180,7 @@ class ProvisionHandler(object): deprovision_handler.run_changed_unique_id() self.write_provisioned() + self.report_ready() return True diff --git a/azurelinuxagent/pa/rdma/factory.py b/azurelinuxagent/pa/rdma/factory.py index 535b3d3..92bd2e0 100644 --- a/azurelinuxagent/pa/rdma/factory.py +++ b/azurelinuxagent/pa/rdma/factory.py @@ -21,6 +21,7 @@ from azurelinuxagent.common.version import DISTRO_FULL_NAME, DISTRO_VERSION from azurelinuxagent.common.rdma import RDMAHandler from .suse import SUSERDMAHandler from .centos import CentOSRDMAHandler +from .ubuntu import UbuntuRDMAHandler def get_rdma_handler( @@ -37,5 +38,8 @@ def get_rdma_handler( if distro_full_name == 'CentOS Linux' or distro_full_name == 'CentOS': return CentOSRDMAHandler(distro_version) + if distro_full_name == 'Ubuntu': + return UbuntuRDMAHandler() + logger.info("No RDMA handler exists for distro='{0}' version='{1}'", distro_full_name, distro_version) return RDMAHandler() diff --git a/azurelinuxagent/pa/rdma/suse.py b/azurelinuxagent/pa/rdma/suse.py index d31b2b0..20f06cd 100644 --- a/azurelinuxagent/pa/rdma/suse.py +++ b/azurelinuxagent/pa/rdma/suse.py @@ -36,8 +36,9 @@ class SUSERDMAHandler(RDMAHandler): logger.error(error_msg) return zypper_install = 'zypper -n in %s' + zypper_install_noref = 'zypper -n --no-refresh in %s' zypper_remove = 'zypper -n rm %s' - zypper_search = 'zypper se -s %s' + zypper_search = 'zypper -n se -s %s' package_name = 'msft-rdma-kmp-default' cmd = 
zypper_search % package_name status, repo_package_info = shellutil.run_get_output(cmd) @@ -108,9 +109,9 @@ class SUSERDMAHandler(RDMAHandler): fw_version in local_package ): logger.info("RDMA: Installing: %s" % local_package) - cmd = zypper_install % local_package + cmd = zypper_install_noref % local_package result = shellutil.run(cmd) - if result: + if result and result != 106: error_msg = 'RDMA: Failed install of package "%s" ' error_msg += 'from local package cache' logger.error(error_msg % local_package) diff --git a/azurelinuxagent/pa/rdma/ubuntu.py b/azurelinuxagent/pa/rdma/ubuntu.py new file mode 100644 index 0000000..050797d --- /dev/null +++ b/azurelinuxagent/pa/rdma/ubuntu.py @@ -0,0 +1,122 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import glob +import os +import re +import time +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.shellutil as shellutil +from azurelinuxagent.common.rdma import RDMAHandler + + +class UbuntuRDMAHandler(RDMAHandler): + + def install_driver(self): + #Install the appropriate driver package for the RDMA firmware + + nd_version = RDMAHandler.get_rdma_version() + if not nd_version: + logger.error("RDMA: Could not determine firmware version. No driver will be installed") + return + #replace . with _, we are looking for number like 144_0 + nd_version = re.sub('\.', '_', nd_version) + + #Check to see if we need to reconfigure driver + status,module_name = shellutil.run_get_output('modprobe -R hv_network_direct', chk_err=False) + if status != 0: + logger.info("RDMA: modprobe -R hv_network_direct failed. Use module name hv_network_direct") + module_name = "hv_network_direct" + else: + module_name = module_name.strip() + logger.info("RDMA: current RDMA driver %s nd_version %s" % (module_name, nd_version)) + if module_name == 'hv_network_direct_%s' % nd_version: + logger.info("RDMA: driver is installed and ND version matched. Skip reconfiguring driver") + return + + #Reconfigure driver if one is available + status,output = shellutil.run_get_output('modinfo hv_network_direct_%s' % nd_version); + if status == 0: + logger.info("RDMA: driver with ND version is installed. Link to module name") + self.update_modprobed_conf(nd_version) + return + + #Driver not found. We need to check to see if we need to update kernel + if not conf.enable_rdma_update(): + logger.info("RDMA: driver update is disabled. 
Skip kernel update") + return + + status,output = shellutil.run_get_output('uname -r') + if status != 0: + return + if not re.search('-azure$', output): + logger.error("RDMA: skip driver update on non-Azure kernel") + return + kernel_version = re.sub('-azure$', '', output) + kernel_version = re.sub('-', '.', kernel_version) + + #Find the new kernel package version + status,output = shellutil.run_get_output('apt-get update') + if status != 0: + return + status,output = shellutil.run_get_output('apt-cache show --no-all-versions linux-azure') + if status != 0: + return + r = re.search('Version: (\S+)', output) + if not r: + logger.error("RDMA: version not found in package linux-azure.") + return + package_version = r.groups()[0] + #Remove the ending .<upload number> after <ABI number> + package_version = re.sub("\.\d+$", "", package_version) + + logger.info('RDMA: kernel_version=%s package_version=%s' % (kernel_version, package_version)) + kernel_version_array = [ int(x) for x in kernel_version.split('.') ] + package_version_array = [ int(x) for x in package_version.split('.') ] + if kernel_version_array < package_version_array: + logger.info("RDMA: newer version available, update kernel and reboot") + status,output = shellutil.run_get_output('apt-get -y install linux-azure') + if status: + logger.error("RDMA: kernel update failed") + return + self.reboot_system() + else: + logger.error("RDMA: no kernel update is avaiable for ND version %s" % nd_version) + + def update_modprobed_conf(self, nd_version): + #Update /etc/modprobe.d/vmbus-rdma.conf to point to the correct driver + + modprobed_file = '/etc/modprobe.d/vmbus-rdma.conf' + lines = '' + if not os.path.isfile(modprobed_file): + logger.info("RDMA: %s not found, it will be created" % modprobed_file) + else: + f = open(modprobed_file, 'r') + lines = f.read() + f.close() + r = re.search('alias hv_network_direct hv_network_direct_\S+', lines) + if r: + lines = re.sub('alias hv_network_direct hv_network_direct_\S+', 'alias hv_network_direct hv_network_direct_%s' % nd_version, lines) + else: + lines += '\nalias hv_network_direct hv_network_direct_%s\n' % nd_version + f = open('/etc/modprobe.d/vmbus-rdma.conf', 'w') + f.write(lines) + f.close() + logger.info("RDMA: hv_network_direct alias updated to ND %s" % nd_version) |