diff options
author | Ćukasz 'sil2100' Zemczak <lukasz.zemczak@ubuntu.com> | 2017-09-04 10:27:07 +0200 |
---|---|---|
committer | usd-importer <ubuntu-server@lists.ubuntu.com> | 2017-09-04 09:38:24 +0000 |
commit | 185ceb32fea5d5c2a43d7b6ee2a40228489055f4 (patch) | |
tree | 2e1c9cc42510c4a922cf63fa265ec0e1945ec14b | |
parent | 43bdf9debe5377216aed0086bff2aad864f6ba82 (diff) | |
download | vyos-walinuxagent-185ceb32fea5d5c2a43d7b6ee2a40228489055f4.tar.gz vyos-walinuxagent-185ceb32fea5d5c2a43d7b6ee2a40228489055f4.zip |
Import patches-unapplied version 2.2.16-0ubuntu1 to ubuntu/artful-proposed
Imported using git-ubuntu import.
Changelog parent: 43bdf9debe5377216aed0086bff2aad864f6ba82
New changelog entries:
* New upstream release (LP: #1714299).
58 files changed, 2773 insertions, 914 deletions
@@ -50,6 +50,12 @@ The information flow from the platform to the agent occurs via two channels: * A TCP endpoint exposing a REST API used to obtain deployment and topology configuration. +The agent will use an HTTP proxy if provided via the `http_proxy` (for `http` requests) or +`https_proxy` (for `https` requests) environment variables. The `HttpProxy.Host` and +`HttpProxy.Port` configuration variables (see below), if used, will override the environment +settings. Due to limitations of Python, the agent *does not* support HTTP proxies requiring +authentication. + ### REQUIREMENTS @@ -58,21 +64,6 @@ Linux Agent. Please note that this list may differ from the official list of supported systems on the Microsoft Azure Platform as described here: http://support.microsoft.com/kb/2805216 -Supported Linux Distributions: - * Archlinux - * CoreOS - * CentOS 6.2+ - * Red Hat Enterprise Linux 6.7+ - * Debian 7.0+ - * Ubuntu 12.04+ - * openSUSE 12.3+ - * SLES 11 SP2+ - * Oracle Linux 6.4+ - -Other Supported Systems: - * FreeBSD 10+ (Azure Linux Agent v2.0.10+) - * OpenBSD 6+ (Azure Linux Agent v2.2.11+) - Waagent depends on some system packages in order to function properly: * Python 2.6+ @@ -168,7 +159,10 @@ script. ### CONFIGURATION A configuration file (/etc/waagent.conf) controls the actions of -waagent. A sample configuration file is shown below: +waagent. Blank lines and lines whose first character is a `#` are +ignored (end-of-line comments are *not* supported). + +A sample configuration file is shown below: ``` Provisioning.Enabled=y @@ -189,6 +183,7 @@ ResourceDisk.EnableSwap=n ResourceDisk.SwapSizeMB=0 LBProbeResponder=y Logs.Verbose=n +OS.AllowHTTP=n OS.RootDeviceScsiTimeout=300 OS.EnableFIPS=n OS.OpensslPath=None @@ -213,7 +208,7 @@ agent. Valid values are "y" or "n". If provisioning is disabled, SSH host and user keys in the image are preserved and any configuration specified in the Azure provisioning API is ignored. -# __Provisioning.UseCloudInit__ +* __Provisioning.UseCloudInit__ _Type: Boolean_ _Default: n_ @@ -348,6 +343,16 @@ _Default: n_ If set, log verbosity is boosted. Waagent logs to /var/log/waagent.log and leverages the system logrotate functionality to rotate logs. +* __OS.AllowHTTP__ +_Type: Boolean_ +_Default: n_ + +If set to `y` and SSL support is not compiled into Python, the agent will fall-back to +use HTTP. Otherwise, if SSL support is not compiled into Python, the agent will fail +all HTTPS requests. + +Note: Allowing HTTP may unintentionally expose secure data. + * __OS.EnableRDMA__ _Type: Boolean_ _Default: n_ @@ -355,9 +360,9 @@ _Default: n_ If set, the agent will attempt to install and then load an RDMA kernel driver that matches the version of the firmware on the underlying hardware. -* __OS.EnableFIPS__ -_Type: Boolean_ -_Default: n_ +* __OS.EnableFIPS__ +_Type: Boolean_ +_Default: n_ If set, the agent will emit into the environment "OPENSSL_FIPS=1" when executing OpenSSL commands. This signals OpenSSL to use any installed FIPS-compliant libraries. @@ -378,9 +383,9 @@ _Default: None_ This can be used to specify an alternate path for the openssl binary to use for cryptographic operations. -* __OS.SshDir__ -_Type: String_ -_Default: "/etc/ssh"_ +* __OS.SshDir__ +_Type: String_ +_Default: `/etc/ssh`_ This option can be used to override the normal location of the SSH configuration directory. @@ -389,7 +394,9 @@ directory. _Type: String_ _Default: None_ -If set, the agent will use this proxy server to access the internet. +If set, the agent will use this proxy server to access the internet. These values +*will* override the `http_proxy` or `https_proxy` environment variables. Lastly, +`HttpProxy.Host` is required (if to be used) and `HttpProxy.Port` is optional. ### APPENDIX diff --git a/azurelinuxagent/agent.py b/azurelinuxagent/agent.py index d1ac354..e99f7be 100644 --- a/azurelinuxagent/agent.py +++ b/azurelinuxagent/agent.py @@ -64,7 +64,19 @@ class Agent(object): logger.add_logger_appender(logger.AppenderType.CONSOLE, level, path="/dev/console") + ext_log_dir = conf.get_ext_log_dir() + try: + if os.path.isfile(ext_log_dir): + raise Exception("{0} is a file".format(ext_log_dir)) + if not os.path.isdir(ext_log_dir): + os.makedirs(ext_log_dir) + except Exception as e: + logger.error( + "Exception occurred while creating extension " + "log directory {0}: {1}".format(ext_log_dir, e)) + #Init event reporter + event.init_event_status(conf.get_lib_dir()) event_dir = os.path.join(conf.get_lib_dir(), "events") event.init_event_logger(event_dir) event.enable_unhandled_err_dump("WALA") @@ -116,6 +128,11 @@ class Agent(object): update_handler = get_update_handler() update_handler.run() + def show_configuration(self): + configuration = conf.get_configuration() + for k in sorted(configuration.keys()): + print("{0} = {1}".format(k, configuration[k])) + def main(args=[]): """ Parse command line arguments, exit with usage() on error. @@ -145,6 +162,8 @@ def main(args=[]): agent.daemon() elif command == "run-exthandlers": agent.run_exthandlers() + elif command == "show-configuration": + agent.show_configuration() except Exception: logger.error(u"Failed to run '{0}': {1}", command, @@ -186,6 +205,8 @@ def parse_args(sys_args): verbose = True elif re.match("^([-/]*)force", a): force = True + elif re.match("^([-/]*)show-configuration", a): + cmd = "show-configuration" elif re.match("^([-/]*)(help|usage|\\?)", a): cmd = "help" else: diff --git a/azurelinuxagent/common/conf.py b/azurelinuxagent/common/conf.py index 5422784..75a0248 100644 --- a/azurelinuxagent/common/conf.py +++ b/azurelinuxagent/common/conf.py @@ -85,12 +85,81 @@ def load_conf_from_file(conf_file_path, conf=__conf__): raise AgentConfigError(("Failed to load conf file:{0}, {1}" "").format(conf_file_path, err)) +__SWITCH_OPTIONS__ = { + "OS.AllowHTTP" : False, + "OS.EnableFirewall" : False, + "OS.EnableFIPS" : False, + "OS.EnableRDMA" : False, + "OS.UpdateRdmaDriver" : False, + "OS.CheckRdmaDriver" : False, + "Logs.Verbose" : False, + "Provisioning.Enabled" : True, + "Provisioning.UseCloudInit" : False, + "Provisioning.AllowResetSysUser" : False, + "Provisioning.RegenerateSshHostKeyPair" : False, + "Provisioning.DeleteRootPassword" : False, + "Provisioning.DecodeCustomData" : False, + "Provisioning.ExecuteCustomData" : False, + "Provisioning.MonitorHostName" : False, + "DetectScvmmEnv" : False, + "ResourceDisk.Format" : False, + "DetectScvmmEnv" : False, + "ResourceDisk.Format" : False, + "ResourceDisk.EnableSwap" : False, + "AutoUpdate.Enabled" : True, + "EnableOverProvisioning" : False +} + +__STRING_OPTIONS__ = { + "Lib.Dir" : "/var/lib/waagent", + "DVD.MountPoint" : "/mnt/cdrom/secure", + "Pid.File" : "/var/run/waagent.pid", + "Extension.LogDir" : "/var/log/azure", + "OS.OpensslPath" : "/usr/bin/openssl", + "OS.SshDir" : "/etc/ssh", + "OS.HomeDir" : "/home", + "OS.PasswordPath" : "/etc/shadow", + "OS.SudoersDir" : "/etc/sudoers.d", + "OS.RootDeviceScsiTimeout" : None, + "Provisioning.SshHostKeyPairType" : "rsa", + "Provisioning.PasswordCryptId" : "6", + "HttpProxy.Host" : None, + "ResourceDisk.MountPoint" : "/mnt/resource", + "ResourceDisk.MountOptions" : None, + "ResourceDisk.Filesystem" : "ext3", + "AutoUpdate.GAFamily" : "Prod" +} + +__INTEGER_OPTIONS__ = { + "Provisioning.PasswordCryptSaltLength" : 10, + "HttpProxy.Port" : None, + "ResourceDisk.SwapSizeMB" : 0, + "Autoupdate.Frequency" : 3600 +} + +def get_configuration(conf=__conf__): + options = {} + for option in __SWITCH_OPTIONS__: + options[option] = conf.get_switch(option, __SWITCH_OPTIONS__[option]) + + for option in __STRING_OPTIONS__: + options[option] = conf.get(option, __STRING_OPTIONS__[option]) + + for option in __INTEGER_OPTIONS__: + options[option] = conf.get_int(option, __INTEGER_OPTIONS__[option]) + + return options + +def enable_firewall(conf=__conf__): + return conf.get_switch("OS.EnableFirewall", False) def enable_rdma(conf=__conf__): return conf.get_switch("OS.EnableRDMA", False) or \ conf.get_switch("OS.UpdateRdmaDriver", False) or \ conf.get_switch("OS.CheckRdmaDriver", False) +def enable_rdma_update(conf=__conf__): + return conf.get_switch("OS.UpdateRdmaDriver", False) def get_logs_verbose(conf=__conf__): return conf.get_switch("Logs.Verbose", False) @@ -151,6 +220,16 @@ def get_root_device_scsi_timeout(conf=__conf__): return conf.get("OS.RootDeviceScsiTimeout", None) def get_ssh_host_keypair_type(conf=__conf__): + keypair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa") + if keypair_type == "auto": + ''' + auto generates all supported key types and returns the + rsa thumbprint as the default. + ''' + return "rsa" + return keypair_type + +def get_ssh_host_keypair_mode(conf=__conf__): return conf.get("Provisioning.SshHostKeyPairType", "rsa") def get_provision_enabled(conf=__conf__): @@ -239,4 +318,7 @@ def get_autoupdate_frequency(conf=__conf__): return conf.get_int("Autoupdate.Frequency", 3600) def get_enable_overprovisioning(conf=__conf__): - return conf.get_switch("EnableOverProvisioning", False)
\ No newline at end of file + return conf.get_switch("EnableOverProvisioning", False) + +def get_allow_http(conf=__conf__): + return conf.get_switch("OS.AllowHTTP", False) diff --git a/azurelinuxagent/common/event.py b/azurelinuxagent/common/event.py index 723b8bf..e62a925 100644 --- a/azurelinuxagent/common/event.py +++ b/azurelinuxagent/common/event.py @@ -27,6 +27,7 @@ import platform from datetime import datetime, timedelta +import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger from azurelinuxagent.common.exception import EventError, ProtocolError @@ -39,13 +40,15 @@ from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \ DISTRO_CODE_NAME, AGENT_VERSION, \ CURRENT_AGENT, CURRENT_VERSION -_EVENT_MSG = "Event: name={0}, op={1}, message={2}" +_EVENT_MSG = "Event: name={0}, op={1}, message={2}, duration={3}" class WALAEventOperation: ActivateResourceDisk = "ActivateResourceDisk" + AutoUpdate = "AutoUpdate" Disable = "Disable" Download = "Download" Enable = "Enable" + Firewall = "Firewall" HealthCheck = "HealthCheck" HeartBeat = "HeartBeat" HostPlugin = "HostPlugin" @@ -60,13 +63,70 @@ class WALAEventOperation: Upgrade = "Upgrade" Update = "Update" -def _log_event(name, op, message, is_success=True): + +class EventStatus(object): + EVENT_STATUS_FILE = "event_status.json" + + def __init__(self, status_dir=conf.get_lib_dir()): + self._path = None + self._status = {} + + def clear(self): + self._status = {} + self._save() + + def event_marked(self, name, version, op): + return self._event_name(name, version, op) in self._status + + def event_succeeded(self, name, version, op): + event = self._event_name(name, version, op) + if event not in self._status: + return True + return self._status[event] == True + + def initialize(self, status_dir=conf.get_lib_dir()): + self._path = os.path.join(status_dir, EventStatus.EVENT_STATUS_FILE) + self._load() + + def mark_event_status(self, name, version, op, status): + event = self._event_name(name, version, op) + self._status[event] = (status == True) + self._save() + + def _event_name(self, name, version, op): + return "{0}-{1}-{2}".format(name, version, op) + + def _load(self): + try: + self._status = {} + if os.path.isfile(self._path): + with open(self._path, 'r') as f: + self._status = json.load(f) + except Exception as e: + logger.warn("Exception occurred loading event status: {0}".format(e)) + self._status = {} + + def _save(self): + try: + with open(self._path, 'w') as f: + json.dump(self._status, f) + except Exception as e: + logger.warn("Exception occurred saving event status: {0}".format(e)) + +__event_status__ = EventStatus() +__event_status_operations__ = [ + WALAEventOperation.AutoUpdate, + WALAEventOperation.ReportStatus + ] + + +def _log_event(name, op, message, duration, is_success=True): global _EVENT_MSG if not is_success: - logger.error(_EVENT_MSG, name, op, message) + logger.error(_EVENT_MSG, name, op, message, duration) else: - logger.info(_EVENT_MSG, name, op, message) + logger.info(_EVENT_MSG, name, op, message, duration) class EventLogger(object): @@ -76,7 +136,7 @@ class EventLogger(object): def save_event(self, data): if self.event_dir is None: - logger.warn("Event reporter is not initialized.") + logger.warn("Cannot save event -- Event reporter is not initialized.") return if not os.path.exists(self.event_dir): @@ -104,11 +164,11 @@ class EventLogger(object): raise EventError("Failed to write events to file:{0}", e) def reset_periodic(self): - self.periodic_messages = {} + self.periodic_events = {} def is_period_elapsed(self, delta, h): - return h not in self.periodic_messages or \ - (self.periodic_messages[h] + delta) <= datetime.now() + return h not in self.periodic_events or \ + (self.periodic_events[h] + delta) <= datetime.now() def add_periodic(self, delta, name, op="", is_success=True, duration=0, @@ -122,13 +182,21 @@ class EventLogger(object): op=op, is_success=is_success, duration=duration, version=version, message=message, evt_type=evt_type, is_internal=is_internal, log_event=log_event) - self.periodic_messages[h] = datetime.now() + self.periodic_events[h] = datetime.now() - def add_event(self, name, op="", is_success=True, duration=0, + def add_event(self, + name, + op="", + is_success=True, + duration=0, version=CURRENT_VERSION, - message="", evt_type="", is_internal=False, log_event=True): + message="", + evt_type="", + is_internal=False, + log_event=True): + if not is_success or log_event: - _log_event(name, op, message, is_success=is_success) + _log_event(name, op, message, duration, is_success=is_success) event = TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") event.parameters.append(TelemetryEventParam('Name', name)) @@ -176,22 +244,24 @@ def add_event(name, op="", is_success=True, duration=0, version=CURRENT_VERSION, message="", evt_type="", is_internal=False, log_event=True, reporter=__event_logger__): if reporter.event_dir is None: - logger.warn("Event reporter is not initialized.") - _log_event(name, op, message, is_success=is_success) + logger.warn("Cannot add event -- Event reporter is not initialized.") + _log_event(name, op, message, duration, is_success=is_success) return - reporter.add_event( - name, op=op, is_success=is_success, duration=duration, - version=str(version), message=message, evt_type=evt_type, - is_internal=is_internal, log_event=log_event) + if should_emit_event(name, version, op, is_success): + mark_event_status(name, version, op, is_success) + reporter.add_event( + name, op=op, is_success=is_success, duration=duration, + version=str(version), message=message, evt_type=evt_type, + is_internal=is_internal, log_event=log_event) def add_periodic( delta, name, op="", is_success=True, duration=0, version=CURRENT_VERSION, message="", evt_type="", is_internal=False, log_event=True, force=False, reporter=__event_logger__): if reporter.event_dir is None: - logger.warn("Event reporter is not initialized.") - _log_event(name, op, message, is_success=is_success) + logger.warn("Cannot add periodic event -- Event reporter is not initialized.") + _log_event(name, op, message, duration, is_success=is_success) return reporter.add_periodic( @@ -199,9 +269,22 @@ def add_periodic( version=str(version), message=message, evt_type=evt_type, is_internal=is_internal, log_event=log_event, force=force) -def init_event_logger(event_dir, reporter=__event_logger__): - reporter.event_dir = event_dir +def mark_event_status(name, version, op, status): + if op in __event_status_operations__: + __event_status__.mark_event_status(name, version, op, status) + +def should_emit_event(name, version, op, status): + return \ + op not in __event_status_operations__ or \ + __event_status__ is None or \ + not __event_status__.event_marked(name, version, op) or \ + __event_status__.event_succeeded(name, version, op) != status + +def init_event_logger(event_dir): + __event_logger__.event_dir = event_dir +def init_event_status(status_dir): + __event_status__.initialize(status_dir) def dump_unhandled_err(name): if hasattr(sys, 'last_type') and hasattr(sys, 'last_value') and \ diff --git a/azurelinuxagent/common/exception.py b/azurelinuxagent/common/exception.py index 7a0c75e..17c6ce0 100644 --- a/azurelinuxagent/common/exception.py +++ b/azurelinuxagent/common/exception.py @@ -86,7 +86,6 @@ class DhcpError(AgentError): def __init__(self, msg=None, inner=None): super(DhcpError, self).__init__('000006', msg, inner) - class OSUtilError(AgentError): """ Failed to perform operation to OS configuration @@ -148,3 +147,12 @@ class UpdateError(AgentError): def __init__(self, msg=None, inner=None): super(UpdateError, self).__init__('000012', msg, inner) + + +class ResourceGoneError(HttpError): + """ + The requested resource no longer exists (i.e., status code 410) + """ + + def __init__(self, msg=None, inner=None): + super(ResourceGoneError, self).__init__(msg, inner) diff --git a/azurelinuxagent/common/osutil/default.py b/azurelinuxagent/common/osutil/default.py index 58c0ef8..dc1c11a 100644 --- a/azurelinuxagent/common/osutil/default.py +++ b/azurelinuxagent/common/osutil/default.py @@ -40,6 +40,7 @@ import azurelinuxagent.common.utils.textutil as textutil from azurelinuxagent.common.exception import OSUtilError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.utils.flexible_version import FlexibleVersion __RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules", "/etc/udev/rules.d/70-persistent-net.rules" ] @@ -50,10 +51,20 @@ for all distros. Each concrete distro classes could overwrite default behavior if needed. """ +IPTABLES_VERSION_PATTERN = re.compile("^[^\d\.]*([\d\.]+).*$") +IPTABLES_VERSION = "iptables --version" +IPTABLES_LOCKING_VERSION = FlexibleVersion('1.4.21') + +FIREWALL_ACCEPT = "iptables {0} -t security -{1} OUTPUT -d {2} -p tcp -m owner --uid-owner {3} -j ACCEPT" +FIREWALL_DROP = "iptables {0} -t security -{1} OUTPUT -d {2} -p tcp -j DROP" +FIREWALL_LIST = "iptables {0} -t security -L" + +_enable_firewall = True + DMIDECODE_CMD = 'dmidecode --string system-uuid' PRODUCT_ID_FILE = '/sys/class/dmi/id/product_uuid' UUID_PATTERN = re.compile( - '^\s*[A-F0-9]{8}(?:\-[A-F0-9]{4}){3}\-[A-F0-9]{12}\s*$', + r'^\s*[A-F0-9]{8}(?:\-[A-F0-9]{4}){3}\-[A-F0-9]{12}\s*$', re.IGNORECASE) class DefaultOSUtil(object): @@ -63,6 +74,113 @@ class DefaultOSUtil(object): self.selinux = None self.disable_route_warning = False + def enable_firewall(self, dst_ip=None, uid=None): + + # If a previous attempt threw an exception, do not retry + global _enable_firewall + if not _enable_firewall: + return False + + try: + if dst_ip is None or uid is None: + msg = "Missing arguments to enable_firewall" + logger.warn(msg) + raise Exception(msg) + + # Determine if iptables will serialize access + rc, output = shellutil.run_get_output(IPTABLES_VERSION) + if rc != 0: + msg = "Unable to determine version of iptables" + logger.warn(msg) + raise Exception(msg) + + m = IPTABLES_VERSION_PATTERN.match(output) + if m is None: + msg = "iptables did not return version information" + logger.warn(msg) + raise Exception(msg) + + wait = "-w" \ + if FlexibleVersion(m.group(1)) >= IPTABLES_LOCKING_VERSION \ + else "" + + # If the DROP rule exists, make no changes + drop_rule = FIREWALL_DROP.format(wait, "C", dst_ip) + + if shellutil.run(drop_rule, chk_err=False) == 0: + logger.verbose("Firewall appears established") + return True + + # Otherwise, append both rules + accept_rule = FIREWALL_ACCEPT.format(wait, "A", dst_ip, uid) + drop_rule = FIREWALL_DROP.format(wait, "A", dst_ip) + + if shellutil.run(accept_rule) != 0: + msg = "Unable to add ACCEPT firewall rule '{0}'".format( + accept_rule) + logger.warn(msg) + raise Exception(msg) + + if shellutil.run(drop_rule) != 0: + msg = "Unable to add DROP firewall rule '{0}'".format( + drop_rule) + logger.warn(msg) + raise Exception(msg) + + logger.info("Successfully added Azure fabric firewall rules") + + rc, output = shellutil.run_get_output(FIREWALL_LIST.format(wait)) + if rc == 0: + logger.info("Firewall rules:\n{0}".format(output)) + else: + logger.warn("Listing firewall rules failed: {0}".format(output)) + + return True + + except Exception as e: + _enable_firewall = False + logger.info("Unable to establish firewall -- " + "no further attempts will be made: " + "{0}".format(ustr(e))) + return False + + def _correct_instance_id(self, id): + ''' + Azure stores the instance ID with an incorrect byte ordering for the + first parts. For example, the ID returned by the metadata service: + + D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8 + + will be found as: + + 544CDFD0-CB4E-4B4A-9954-5BDF3ED5C3B8 + + This code corrects the byte order such that it is consistent with + that returned by the metadata service. + ''' + + if not UUID_PATTERN.match(id): + return id + + parts = id.split('-') + return '-'.join([ + textutil.swap_hexstring(parts[0], width=2), + textutil.swap_hexstring(parts[1], width=2), + textutil.swap_hexstring(parts[2], width=2), + parts[3], + parts[4] + ]) + + def is_current_instance_id(self, id_that): + ''' + Compare two instance IDs for equality, but allow that some IDs + may have been persisted using the incorrect byte ordering. + ''' + id_this = self.get_instance_id() + return id_that == id_this or \ + id_that == self._correct_instance_id(id_this) + + def get_agent_conf_file_path(self): return self.agent_conf_file_path @@ -74,13 +192,14 @@ class DefaultOSUtil(object): If nothing works (for old VMs), return the empty string ''' if os.path.isfile(PRODUCT_ID_FILE): - return fileutil.read_file(PRODUCT_ID_FILE).strip() + s = fileutil.read_file(PRODUCT_ID_FILE).strip() - rc, s = shellutil.run_get_output(DMIDECODE_CMD) - if rc != 0 or UUID_PATTERN.match(s) is None: - return "" + else: + rc, s = shellutil.run_get_output(DMIDECODE_CMD) + if rc != 0 or UUID_PATTERN.match(s) is None: + return "" - return s.strip() + return self._correct_instance_id(s.strip()) def get_userentry(self, username): try: @@ -158,10 +277,12 @@ class DefaultOSUtil(object): fileutil.append_file(sudoers_file, sudoers) sudoer = None if nopasswd: - sudoer = "{0} ALL=(ALL) NOPASSWD: ALL\n".format(username) + sudoer = "{0} ALL=(ALL) NOPASSWD: ALL".format(username) else: - sudoer = "{0} ALL=(ALL) ALL\n".format(username) - fileutil.append_file(sudoers_wagent, sudoer) + sudoer = "{0} ALL=(ALL) ALL".format(username) + if not os.path.isfile(sudoers_wagent) or \ + fileutil.findstr_in_file(sudoers_wagent, sudoer) is None: + fileutil.append_file(sudoers_wagent, "{0}\n".format(sudoer)) fileutil.chmod(sudoers_wagent, 0o440) else: #Remove user from sudoers @@ -334,7 +455,7 @@ class DefaultOSUtil(object): return_code, err = self.mount(dvd_device, mount_point, option="-o ro -t udf,iso9660", - chk_err=chk_err) + chk_err=False) if return_code == 0: logger.info("Successfully mounted dvd") return @@ -718,7 +839,7 @@ class DefaultOSUtil(object): for conf_file in dhclient_files: if not os.path.isfile(conf_file): continue - if fileutil.findstr_in_file(conf_file, autosend): + if fileutil.findre_in_file(conf_file, autosend): #Return if auto send host-name is configured return fileutil.update_conf_file(conf_file, diff --git a/azurelinuxagent/common/osutil/factory.py b/azurelinuxagent/common/osutil/factory.py index 2be90ab..43aa6a7 100644 --- a/azurelinuxagent/common/osutil/factory.py +++ b/azurelinuxagent/common/osutil/factory.py @@ -41,7 +41,8 @@ def get_osutil(distro_name=DISTRO_NAME, if distro_name == "arch": return ArchUtil() - if distro_name == "clear linux software for intel architecture": + if distro_name == "clear linux os for intel architecture" \ + or distro_name == "clear linux software for intel architecture": return ClearLinuxUtil() if distro_name == "ubuntu": diff --git a/azurelinuxagent/common/osutil/openbsd.py b/azurelinuxagent/common/osutil/openbsd.py index 9bfe6de..a022c59 100644 --- a/azurelinuxagent/common/osutil/openbsd.py +++ b/azurelinuxagent/common/osutil/openbsd.py @@ -248,8 +248,10 @@ class OpenBSDOSUtil(DefaultOSUtil): os.makedirs(mount_point) for retry in range(0, max_retry): - retcode = self.mount(dvd_device, mount_point, option="-o ro -t udf", - chk_err=chk_err) + retcode = self.mount(dvd_device, + mount_point, + option="-o ro -t udf", + chk_err=False) if retcode == 0: logger.info("Successfully mounted DVD") return diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py index 9af8a97..729d8fb 100644 --- a/azurelinuxagent/common/protocol/hostplugin.py +++ b/azurelinuxagent/common/protocol/hostplugin.py @@ -22,7 +22,8 @@ import json import traceback from azurelinuxagent.common import logger -from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.exception import HttpError, ProtocolError, \ + ResourceGoneError from azurelinuxagent.common.future import ustr, httpclient from azurelinuxagent.common.utils import restutil from azurelinuxagent.common.utils import textutil @@ -85,10 +86,10 @@ class HostPluginProtocol(object): try: headers = {HEADER_CONTAINER_ID: self.container_id} response = restutil.http_get(url, headers) - if response.status != httpclient.OK: + if restutil.request_failed(response): logger.error( "HostGAPlugin: Failed Get API versions: {0}".format( - self.read_response_error(response))) + restutil.read_response_error(response))) else: return_val = ustr(remove_bom(response.read()), encoding='utf-8') @@ -117,42 +118,7 @@ class HostPluginProtocol(object): return url, headers def put_vm_log(self, content): - """ - Try to upload the given content to the host plugin - :param deployment_id: the deployment id, which is obtained from the - goal state (tenant name) - :param container_id: the container id, which is obtained from the - goal state - :param content: the binary content of the zip file to upload - :return: - """ - if not self.ensure_initialized(): - raise ProtocolError("HostGAPlugin: Host plugin channel is not available") - - if content is None \ - or self.container_id is None \ - or self.deployment_id is None: - logger.error( - "HostGAPlugin: Invalid arguments passed: " - "[{0}], [{1}], [{2}]".format( - content, - self.container_id, - self.deployment_id)) - return - url = URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT) - - headers = {"x-ms-vmagentlog-deploymentid": self.deployment_id, - "x-ms-vmagentlog-containerid": self.container_id} - logger.periodic( - logger.EVERY_FIFTEEN_MINUTES, - "HostGAPlugin: Put VM log to [{0}]".format(url)) - try: - response = restutil.http_put(url, content, headers) - if response.status != httpclient.OK: - logger.error("HostGAPlugin: Put log failed: Code {0}".format( - response.status)) - except HttpError as e: - logger.error("HostGAPlugin: Put log exception: {0}".format(e)) + raise NotImplementedError("Unimplemented") def put_vm_status(self, status_blob, sas_url, config_blob_type=None): """ @@ -169,6 +135,7 @@ class HostPluginProtocol(object): logger.verbose("HostGAPlugin: Posting VM status") try: + blob_type = status_blob.type if status_blob.type else config_blob_type if blob_type == "BlockBlob": @@ -176,17 +143,14 @@ class HostPluginProtocol(object): else: self._put_page_blob_status(sas_url, status_blob) - if not HostPluginProtocol.is_default_channel(): + except Exception as e: + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + if isinstance(e, ResourceGoneError): logger.verbose("HostGAPlugin: Setting host plugin as default channel") HostPluginProtocol.set_default_channel(True) - except Exception as e: - message = "HostGAPlugin: Exception Put VM status: {0}, {1}".format(e, traceback.format_exc()) - from azurelinuxagent.common.event import WALAEventOperation, report_event - report_event(op=WALAEventOperation.ReportStatus, - is_success=False, - message=message) - logger.warn("HostGAPlugin: resetting default channel") - HostPluginProtocol.set_default_channel(False) + + raise def _put_block_blob_status(self, sas_url, status_blob): url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT) @@ -198,9 +162,9 @@ class HostPluginProtocol(object): bytearray(status_blob.data, encoding='utf-8')), headers=self._build_status_headers()) - if response.status != httpclient.OK: + if restutil.request_failed(response): raise HttpError("HostGAPlugin: Put BlockBlob failed: {0}".format( - self.read_response_error(response))) + restutil.read_response_error(response))) else: logger.verbose("HostGAPlugin: Put BlockBlob status succeeded") @@ -219,10 +183,10 @@ class HostPluginProtocol(object): status_blob.get_page_blob_create_headers(status_size)), headers=self._build_status_headers()) - if response.status != httpclient.OK: + if restutil.request_failed(response): raise HttpError( "HostGAPlugin: Failed PageBlob clean-up: {0}".format( - self.read_response_error(response))) + restutil.read_response_error(response))) else: logger.verbose("HostGAPlugin: PageBlob clean-up succeeded") @@ -249,11 +213,11 @@ class HostPluginProtocol(object): buf), headers=self._build_status_headers()) - if response.status != httpclient.OK: + if restutil.request_failed(response): raise HttpError( "HostGAPlugin Error: Put PageBlob bytes [{0},{1}]: " \ "{2}".format( - start, end, self.read_response_error(response))) + start, end, restutil.read_response_error(response))) # Advance to the next page (if any) start = end @@ -287,26 +251,3 @@ class HostPluginProtocol(object): if PY_VERSION_MAJOR > 2: return s.decode('utf-8') return s - - @staticmethod - def read_response_error(response): - result = '' - if response is not None: - try: - body = remove_bom(response.read()) - result = "[{0}: {1}] {2}".format(response.status, - response.reason, - body) - - # this result string is passed upstream to several methods - # which do a raise HttpError() or a format() of some kind; - # as a result it cannot have any unicode characters - if PY_VERSION_MAJOR < 3: - result = ustr(result, encoding='ascii', errors='ignore') - else: - result = result\ - .encode(encoding='ascii', errors='ignore')\ - .decode(encoding='ascii', errors='ignore') - except Exception: - logger.warn(traceback.format_exc()) - return result diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py index b0b6f67..4de7ecf 100644 --- a/azurelinuxagent/common/protocol/metadata.py +++ b/azurelinuxagent/common/protocol/metadata.py @@ -88,7 +88,7 @@ class MetadataProtocol(Protocol): except HttpError as e: raise ProtocolError(ustr(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) data = resp.read() @@ -103,7 +103,7 @@ class MetadataProtocol(Protocol): resp = restutil.http_put(url, json.dumps(data), headers=headers) except HttpError as e: raise ProtocolError(ustr(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) def _post_data(self, url, data, headers=None): diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py index a42db37..1ec3e21 100644 --- a/azurelinuxagent/common/protocol/restapi.py +++ b/azurelinuxagent/common/protocol/restapi.py @@ -317,8 +317,8 @@ class Protocol(DataContract): def download_ext_handler_pkg(self, uri, headers=None): try: - resp = restutil.http_get(uri, chk_proxy=True, headers=headers) - if resp.status == restutil.httpclient.OK: + resp = restutil.http_get(uri, use_proxy=True, headers=headers) + if restutil.request_succeeded(resp): return resp.read() except Exception as e: logger.warn("Failed to download from: {0}".format(uri), e) diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py index bb3500a..3071d7a 100644 --- a/azurelinuxagent/common/protocol/util.py +++ b/azurelinuxagent/common/protocol/util.py @@ -16,11 +16,14 @@ # # Requires Python 2.4+ and Openssl 1.0+ # + +import errno import os import re import shutil import time import threading + import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \ @@ -231,6 +234,9 @@ class ProtocolUtil(object): try: os.remove(protocol_file_path) except IOError as e: + # Ignore file-not-found errors (since the file is being removed) + if e.errno == errno.ENOENT: + return logger.error("Failed to clear protocol endpoint: {0}", e) def get_protocol(self): diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py index d731e11..4f3b7e0 100644 --- a/azurelinuxagent/common/protocol/wire.py +++ b/azurelinuxagent/common/protocol/wire.py @@ -26,7 +26,8 @@ import azurelinuxagent.common.conf as conf import azurelinuxagent.common.utils.fileutil as fileutil import azurelinuxagent.common.utils.textutil as textutil -from azurelinuxagent.common.exception import ProtocolNotFoundError +from azurelinuxagent.common.exception import ProtocolNotFoundError, \ + ResourceGoneError from azurelinuxagent.common.future import httpclient, bytebuffer from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol from azurelinuxagent.common.protocol.restapi import * @@ -96,7 +97,10 @@ class WireProtocol(Protocol): cryptutil = CryptUtil(conf.get_openssl_cmd()) cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) - self.client.update_goal_state(forced=True) + self.update_goal_state(forced=True) + + def update_goal_state(self, forced=False, max_retry=3): + self.client.update_goal_state(forced=forced, max_retry=max_retry) def get_vminfo(self): goal_state = self.client.get_goal_state() @@ -117,7 +121,7 @@ class WireProtocol(Protocol): def get_vmagent_manifests(self): # Update goal state to get latest extensions config - self.client.update_goal_state() + self.update_goal_state() goal_state = self.client.get_goal_state() ext_conf = self.client.get_ext_conf() return ext_conf.vmagent_manifests, goal_state.incarnation @@ -130,7 +134,7 @@ class WireProtocol(Protocol): def get_ext_handlers(self): logger.verbose("Get extension handler config") # Update goal state to get latest extensions config - self.client.update_goal_state() + self.update_goal_state() goal_state = self.client.get_goal_state() ext_conf = self.client.get_ext_conf() # In wire protocol, incarnation is equivalent to ETag @@ -533,29 +537,27 @@ class WireClient(object): self.req_count = 0 def call_wireserver(self, http_req, *args, **kwargs): - """ - Call wire server; handle throttling (403), resource gone (410) and - service unavailable (503). - """ self.prevent_throttling() - for retry in range(0, 3): + + try: + # Never use the HTTP proxy for wireserver + kwargs['use_proxy'] = False resp = http_req(*args, **kwargs) - if resp.status == httpclient.FORBIDDEN: - logger.warn("Sending too many requests to wire server. ") - logger.info("Sleeping {0}s to avoid throttling.", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - elif resp.status == httpclient.SERVICE_UNAVAILABLE: - logger.warn("Service temporarily unavailable, sleeping {0}s " - "before retrying.", LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - elif resp.status == httpclient.GONE: - msg = args[0] if len(args) > 0 else "" - raise WireProtocolResourceGone(msg) - else: - return resp - raise ProtocolError(("Calling wire server failed: " - "{0}").format(resp.status)) + except Exception as e: + raise ProtocolError("[Wireserver Exception] {0}".format( + ustr(e))) + + if resp is not None and resp.status == httpclient.GONE: + msg = args[0] if len(args) > 0 else "" + raise WireProtocolResourceGone(msg) + + elif restutil.request_failed(resp): + msg = "[Wireserver Failed] URI {0} ".format(args[0]) + if resp is not None: + msg += " [HTTP Failed] Status Code {0}".format(resp.status) + raise ProtocolError(msg) + + return resp def decode_config(self, data): if data is None: @@ -565,16 +567,9 @@ class WireClient(object): return xml_text def fetch_config(self, uri, headers): - try: - resp = self.call_wireserver(restutil.http_get, - uri, - headers=headers) - except HttpError as e: - raise ProtocolError(ustr(e)) - - if resp.status != httpclient.OK: - raise ProtocolError("{0} - {1}".format(resp.status, uri)) - + resp = self.call_wireserver(restutil.http_get, + uri, + headers=headers) return self.decode_config(resp.read()) def fetch_cache(self, local_file): @@ -589,29 +584,17 @@ class WireClient(object): try: fileutil.write_file(local_file, data) except IOError as e: + fileutil.clean_ioerror(e, + paths=[local_file]) raise ProtocolError("Failed to write cache: {0}".format(e)) @staticmethod def call_storage_service(http_req, *args, **kwargs): - """ - Call storage service, handle SERVICE_UNAVAILABLE(503) - """ - # Default to use the configured HTTP proxy - if not 'chk_proxy' in kwargs or kwargs['chk_proxy'] is None: - kwargs['chk_proxy'] = True + if not 'use_proxy' in kwargs or kwargs['use_proxy'] is None: + kwargs['use_proxy'] = True - for retry in range(0, 3): - resp = http_req(*args, **kwargs) - if resp.status == httpclient.SERVICE_UNAVAILABLE: - logger.warn("Storage service is temporarily unavailable. ") - logger.info("Will retry in {0} seconds. ", - LONG_WAITING_INTERVAL) - time.sleep(LONG_WAITING_INTERVAL) - else: - return resp - raise ProtocolError(("Calling storage endpoint failed: " - "{0}").format(resp.status)) + return http_req(*args, **kwargs) def fetch_manifest(self, version_uris): logger.verbose("Fetch manifest") @@ -619,47 +602,61 @@ class WireClient(object): response = None if not HostPluginProtocol.is_default_channel(): response = self.fetch(version.uri) + if not response: if HostPluginProtocol.is_default_channel(): logger.verbose("Using host plugin as default channel") else: - logger.verbose("Manifest could not be downloaded, falling back to host plugin") - host = self.get_host_plugin() - uri, headers = host.get_artifact_request(version.uri) - response = self.fetch(uri, headers, chk_proxy=False) - if not response: - host = self.get_host_plugin(force_update=True) - logger.info("Retry fetch in {0} seconds", - SHORT_WAITING_INTERVAL) - time.sleep(SHORT_WAITING_INTERVAL) - else: - host.manifest_uri = version.uri - logger.verbose("Manifest downloaded successfully from host plugin") - if not HostPluginProtocol.is_default_channel(): - logger.info("Setting host plugin as default channel") - HostPluginProtocol.set_default_channel(True) + logger.verbose("Failed to download manifest, " + "switching to host plugin") + + try: + host = self.get_host_plugin() + uri, headers = host.get_artifact_request(version.uri) + response = self.fetch(uri, headers, use_proxy=False) + + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + except ResourceGoneError: + HostPluginProtocol.set_default_channel(True) + raise + + host.manifest_uri = version.uri + logger.verbose("Manifest downloaded successfully from host plugin") + if not HostPluginProtocol.is_default_channel(): + logger.info("Setting host plugin as default channel") + HostPluginProtocol.set_default_channel(True) + if response: return response + raise ProtocolError("Failed to fetch manifest from all sources") - def fetch(self, uri, headers=None, chk_proxy=None): + def fetch(self, uri, headers=None, use_proxy=None): logger.verbose("Fetch [{0}] with headers [{1}]", uri, headers) - return_value = None try: resp = self.call_storage_service( - restutil.http_get, - uri, - headers, - chk_proxy=chk_proxy) - if resp.status == httpclient.OK: - return_value = self.decode_config(resp.read()) - else: - logger.warn("Could not fetch {0} [{1}]", - uri, - HostPluginProtocol.read_response_error(resp)) + restutil.http_get, + uri, + headers=headers, + use_proxy=use_proxy) + + if restutil.request_failed(resp): + msg = "[Storage Failed] URI {0} ".format(uri) + if resp is not None: + msg += restutil.read_response_error(resp) + logger.warn(msg) + raise ProtocolError(msg) + + return self.decode_config(resp.read()) + except (HttpError, ProtocolError) as e: logger.verbose("Fetch failed from [{0}]: {1}", uri, e) - return return_value + + if isinstance(e, ResourceGoneError): + raise + + return None def update_hosting_env(self, goal_state): if goal_state.hosting_env_uri is None: @@ -734,6 +731,12 @@ class WireClient(object): self.host_plugin.container_id = goal_state.container_id self.host_plugin.role_config_name = goal_state.role_config_name return + + except ProtocolError: + if retry < max_retry-1: + continue + raise + except WireProtocolResourceGone: logger.info("Incarnation is out of date. Update goalstate.") xml_text = self.fetch_config(uri, self.get_header()) @@ -791,20 +794,45 @@ class WireClient(object): return self.ext_conf def get_ext_manifest(self, ext_handler, goal_state): - local_file = MANIFEST_FILE_NAME.format(ext_handler.name, - goal_state.incarnation) - local_file = os.path.join(conf.get_lib_dir(), local_file) - xml_text = self.fetch_manifest(ext_handler.versionUris) - self.save_cache(local_file, xml_text) - return ExtensionManifest(xml_text) + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) + goal_state = self.get_goal_state() + + local_file = MANIFEST_FILE_NAME.format( + ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) + return ExtensionManifest(xml_text) + + except ResourceGoneError: + continue + + raise ProtocolError("Failed to retrieve extension manifest") def get_gafamily_manifest(self, vmagent_manifest, goal_state): - local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family, - goal_state.incarnation) - local_file = os.path.join(conf.get_lib_dir(), local_file) - xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris) - fileutil.write_file(local_file, xml_text) - return ExtensionManifest(xml_text) + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) + goal_state = self.get_goal_state() + + local_file = MANIFEST_FILE_NAME.format( + vmagent_manifest.family, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest( + vmagent_manifest.versionsManifestUris) + fileutil.write_file(local_file, xml_text) + return ExtensionManifest(xml_text) + + except ResourceGoneError: + continue + + raise ProtocolError("Failed to retrieve GAFamily manifest") def check_wire_protocol_version(self): uri = VERSION_INFO_URI.format(self.endpoint) @@ -823,39 +851,55 @@ class WireClient(object): raise ProtocolNotFoundError(error) def upload_status_blob(self): - ext_conf = self.get_ext_conf() + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) - blob_uri = ext_conf.status_upload_blob - blob_type = ext_conf.status_upload_blob_type + ext_conf = self.get_ext_conf() - if blob_uri is not None: + blob_uri = ext_conf.status_upload_blob + blob_type = ext_conf.status_upload_blob_type - if not blob_type in ["BlockBlob", "PageBlob"]: - blob_type = "BlockBlob" - logger.verbose("Status Blob type is unspecified " - "-- assuming it is a BlockBlob") + if blob_uri is not None: + + if not blob_type in ["BlockBlob", "PageBlob"]: + blob_type = "BlockBlob" + logger.verbose("Status Blob type is unspecified " + "-- assuming it is a BlockBlob") + + try: + self.status_blob.prepare(blob_type) + except Exception as e: + self.report_status_event( + "Exception creating status blob: {0}", ustr(e)) + return + + if not HostPluginProtocol.is_default_channel(): + try: + if self.status_blob.upload(blob_uri): + return + except HttpError as e: + pass + + host = self.get_host_plugin() + host.put_vm_status(self.status_blob, + ext_conf.status_upload_blob, + ext_conf.status_upload_blob_type) + HostPluginProtocol.set_default_channel(True) + return - try: - self.status_blob.prepare(blob_type) except Exception as e: + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + if isinstance(e, ResourceGoneError): + HostPluginProtocol.set_default_channel(True) + continue + self.report_status_event( - "Exception creating status blob: {0}", - e) + "Exception uploading status blob: {0}", ustr(e)) return - uploaded = False - if not HostPluginProtocol.is_default_channel(): - try: - uploaded = self.status_blob.upload(blob_uri) - except HttpError as e: - pass - - if not uploaded: - host = self.get_host_plugin() - host.put_vm_status(self.status_blob, - ext_conf.status_upload_blob, - ext_conf.status_upload_blob_type) - def report_role_prop(self, thumbprint): goal_state = self.get_goal_state() role_prop = _build_role_properties(goal_state.container_id, @@ -896,11 +940,12 @@ class WireClient(object): health_report_uri, health_report, headers=headers, - max_retry=30) + max_retry=30, + retry_delay=15) except HttpError as e: raise ProtocolError((u"Failed to send provision status: " u"{0}").format(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): raise ProtocolError((u"Failed to send provision status: " u",{0}: {1}").format(resp.status, resp.read())) @@ -919,7 +964,7 @@ class WireClient(object): except HttpError as e: raise ProtocolError("Failed to send events:{0}".format(e)) - if resp.status != httpclient.OK: + if restutil.request_failed(resp): logger.verbose(resp.read()) raise ProtocolError( "Failed to send events:{0}".format(resp.status)) @@ -979,12 +1024,8 @@ class WireClient(object): "x-ms-guest-agent-public-x509-cert": cert } - def get_host_plugin(self, force_update=False): - if self.host_plugin is None or force_update: - if force_update: - logger.warn("Forcing update of goal state") - self.goal_state = None - self.update_goal_state(forced=True) + def get_host_plugin(self): + if self.host_plugin is None: goal_state = self.get_goal_state() self.host_plugin = HostPluginProtocol(self.endpoint, goal_state.container_id, @@ -997,23 +1038,47 @@ class WireClient(object): def get_artifacts_profile(self): artifacts_profile = None - if self.has_artifacts_profile_blob(): - blob = self.ext_conf.artifacts_profile_blob - logger.verbose("Getting the artifacts profile") - profile = self.fetch(blob) + for update_goal_state in [False, True]: + try: + if update_goal_state: + self.update_goal_state(forced=True) - if profile is None: - logger.warn("Download failed, falling back to host plugin") - host = self.get_host_plugin() - uri, headers = host.get_artifact_request(blob) - profile = self.decode_config(self.fetch(uri, headers, chk_proxy=False)) + if self.has_artifacts_profile_blob(): + blob = self.ext_conf.artifacts_profile_blob - if not textutil.is_str_none_or_whitespace(profile): - logger.verbose("Artifacts profile downloaded successfully") - artifacts_profile = InVMArtifactsProfile(profile) + profile = None + if not HostPluginProtocol.is_default_channel(): + logger.verbose("Retrieving the artifacts profile") + profile = self.fetch(blob) - return artifacts_profile + if profile is None: + if HostPluginProtocol.is_default_channel(): + logger.verbose("Using host plugin as default channel") + else: + logger.verbose("Failed to download artifacts profile, " + "switching to host plugin") + host = self.get_host_plugin() + uri, headers = host.get_artifact_request(blob) + config = self.fetch(uri, headers, use_proxy=False) + profile = self.decode_config(config) + + if not textutil.is_str_none_or_whitespace(profile): + logger.verbose("Artifacts profile downloaded") + artifacts_profile = InVMArtifactsProfile(profile) + + return artifacts_profile + + except ResourceGoneError: + HostPluginProtocol.set_default_channel(True) + continue + + except Exception as e: + logger.warn( + "Exception retrieving artifacts profile: {0}".format( + ustr(e))) + + return None class VersionInfo(object): def __init__(self, xml_text): diff --git a/azurelinuxagent/common/rdma.py b/azurelinuxagent/common/rdma.py index 226482d..3c01e77 100644 --- a/azurelinuxagent/common/rdma.py +++ b/azurelinuxagent/common/rdma.py @@ -202,7 +202,17 @@ class RDMADeviceHandler(object): RDMADeviceHandler.update_dat_conf(dapl_config_paths, self.ipv4_addr) skip_rdma_device = False - retcode,out = shellutil.run_get_output("modinfo hv_network_direct") + module_name = "hv_network_direct" + retcode,out = shellutil.run_get_output("modprobe -R %s" % module_name, chk_err=False) + if retcode == 0: + module_name = out.strip() + else: + logger.info("RDMA: failed to resolve module name. Use original name") + retcode,out = shellutil.run_get_output("modprobe %s" % module_name) + if retcode != 0: + logger.error("RDMA: failed to load module %s" % module_name) + return + retcode,out = shellutil.run_get_output("modinfo %s" % module_name) if retcode == 0: version = re.search("version:\s+(\d+)\.(\d+)\.(\d+)\D", out, re.IGNORECASE) if version: diff --git a/azurelinuxagent/common/utils/fileutil.py b/azurelinuxagent/common/utils/fileutil.py index bae1957..96b5b82 100644 --- a/azurelinuxagent/common/utils/fileutil.py +++ b/azurelinuxagent/common/utils/fileutil.py @@ -21,15 +21,30 @@ File operation util functions """ +import errno as errno import glob import os +import pwd import re import shutil -import pwd +import string + import azurelinuxagent.common.logger as logger -from azurelinuxagent.common.future import ustr import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.future import ustr + +KNOWN_IOERRORS = [ + errno.EIO, # I/O error + errno.ENOMEM, # Out of memory + errno.ENFILE, # File table overflow + errno.EMFILE, # Too many open files + errno.ENOSPC, # Out of space + errno.ENAMETOOLONG, # Name too long + errno.ELOOP, # Too many symbolic links encountered + errno.EREMOTEIO # Remote I/O error +] + def copy_file(from_path, to_path=None, to_dir=None): if to_path is None: to_path = os.path.join(to_dir, os.path.basename(from_path)) @@ -160,18 +175,31 @@ def chmod_tree(path, mode): for file_name in files: os.chmod(os.path.join(root, file_name), mode) -def findstr_in_file(file_path, pattern_str): +def findstr_in_file(file_path, line_str): + """ + Return True if the line is in the file; False otherwise. + (Trailing whitespace is ignore.) + """ + try: + for line in (open(file_path, 'r')).readlines(): + if line_str == line.rstrip(): + return True + except Exception as e: + pass + return False + +def findre_in_file(file_path, line_re): """ Return match object if found in file. """ try: - pattern = re.compile(pattern_str) + pattern = re.compile(line_re) for line in (open(file_path, 'r')).readlines(): match = re.search(pattern, line) if match: return match except: - raise + pass return None @@ -184,3 +212,21 @@ def get_all_files(root_path): result.extend([os.path.join(root, file) for file in files]) return result + +def clean_ioerror(e, paths=[]): + """ + Clean-up possibly bad files and directories after an IO error. + The code ignores *all* errors since disk state may be unhealthy. + """ + if isinstance(e, IOError) and e.errno in KNOWN_IOERRORS: + for path in paths: + if path is None: + continue + + try: + if os.path.isdir(path): + shutil.rmtree(path, ignore_errors=True) + else: + os.remove(path) + except Exception as e: + pass diff --git a/azurelinuxagent/common/utils/restutil.py b/azurelinuxagent/common/utils/restutil.py index 49d2d68..ddd930b 100644 --- a/azurelinuxagent/common/utils/restutil.py +++ b/azurelinuxagent/common/utils/restutil.py @@ -17,20 +17,74 @@ # Requires Python 2.4+ and Openssl 1.0+ # +import os import time +import traceback import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger -from azurelinuxagent.common.exception import HttpError -from azurelinuxagent.common.future import httpclient, urlparse +import azurelinuxagent.common.utils.textutil as textutil -""" -REST api util functions -""" +from azurelinuxagent.common.exception import HttpError, ResourceGoneError +from azurelinuxagent.common.future import httpclient, urlparse, ustr +from azurelinuxagent.common.version import PY_VERSION_MAJOR -RETRY_WAITING_INTERVAL = 3 -secure_warning = True +SECURE_WARNING_EMITTED = False + +DEFAULT_RETRIES = 3 + +SHORT_DELAY_IN_SECONDS = 5 +LONG_DELAY_IN_SECONDS = 15 + +RETRY_CODES = [ + httpclient.RESET_CONTENT, + httpclient.PARTIAL_CONTENT, + httpclient.FORBIDDEN, + httpclient.INTERNAL_SERVER_ERROR, + httpclient.NOT_IMPLEMENTED, + httpclient.BAD_GATEWAY, + httpclient.SERVICE_UNAVAILABLE, + httpclient.GATEWAY_TIMEOUT, + httpclient.INSUFFICIENT_STORAGE, + 429, # Request Rate Limit Exceeded +] + +RESOURCE_GONE_CODES = [ + httpclient.BAD_REQUEST, + httpclient.GONE +] + +OK_CODES = [ + httpclient.OK, + httpclient.CREATED, + httpclient.ACCEPTED +] + +THROTTLE_CODES = [ + httpclient.FORBIDDEN, + httpclient.SERVICE_UNAVAILABLE +] + +RETRY_EXCEPTIONS = [ + httpclient.NotConnected, + httpclient.IncompleteRead, + httpclient.ImproperConnectionState, + httpclient.BadStatusLine +] + +HTTP_PROXY_ENV = "http_proxy" +HTTPS_PROXY_ENV = "https_proxy" + + +def _is_retry_status(status, retry_codes=RETRY_CODES): + return status in retry_codes + +def _is_retry_exception(e): + return len([x for x in RETRY_EXCEPTIONS if isinstance(e, x)]) > 0 + +def _is_throttle_status(status): + return status in THROTTLE_CODES def _parse_url(url): o = urlparse(url) @@ -45,46 +99,57 @@ def _parse_url(url): return o.hostname, o.port, secure, rel_uri -def get_http_proxy(): - """ - Get http_proxy and https_proxy from environment variables. - Username and password is not supported now. - """ +def _get_http_proxy(secure=False): + # Prefer the configuration settings over environment variables host = conf.get_httpproxy_host() - port = conf.get_httpproxy_port() + port = None + + if not host is None: + port = conf.get_httpproxy_port() + + else: + http_proxy_env = HTTPS_PROXY_ENV if secure else HTTP_PROXY_ENV + http_proxy_url = None + for v in [http_proxy_env, http_proxy_env.upper()]: + if v in os.environ: + http_proxy_url = os.environ[v] + break + + if not http_proxy_url is None: + host, port, _, _ = _parse_url(http_proxy_url) + return host, port def _http_request(method, host, rel_uri, port=None, data=None, secure=False, headers=None, proxy_host=None, proxy_port=None): - url, conn = None, None + + headers = {} if headers is None else headers + use_proxy = proxy_host is not None and proxy_port is not None + + if port is None: + port = 443 if secure else 80 + + if use_proxy: + conn_host, conn_port = proxy_host, proxy_port + scheme = "https" if secure else "http" + url = "{0}://{1}:{2}{3}".format(scheme, host, port, rel_uri) + + else: + conn_host, conn_port = host, port + url = rel_uri + if secure: - port = 443 if port is None else port - if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPSConnection(proxy_host, - proxy_port, - timeout=10) + conn = httpclient.HTTPSConnection(conn_host, + conn_port, + timeout=10) + if use_proxy: conn.set_tunnel(host, port) - # If proxy is used, full url is needed. - url = "https://{0}:{1}{2}".format(host, port, rel_uri) - else: - conn = httpclient.HTTPSConnection(host, - port, - timeout=10) - url = rel_uri + else: - port = 80 if port is None else port - if proxy_host is not None and proxy_port is not None: - conn = httpclient.HTTPConnection(proxy_host, - proxy_port, - timeout=10) - # If proxy is used, full url is needed. - url = "http://{0}:{1}{2}".format(host, port, rel_uri) - else: - conn = httpclient.HTTPConnection(host, - port, - timeout=10) - url = rel_uri + conn = httpclient.HTTPConnection(conn_host, + conn_port, + timeout=10) logger.verbose("HTTP connection [{0}] [{1}] [{2}] [{3}]", method, @@ -92,49 +157,70 @@ def _http_request(method, host, rel_uri, port=None, data=None, secure=False, data, headers) - headers = {} if headers is None else headers conn.request(method=method, url=url, body=data, headers=headers) - resp = conn.getresponse() - return resp + return conn.getresponse() -def http_request(method, url, data, headers=None, max_retry=3, - chk_proxy=False): - """ - Sending http request to server - On error, sleep 10 and retry max_retry times. - """ +def http_request(method, + url, data, headers=None, + use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + + global SECURE_WARNING_EMITTED + host, port, secure, rel_uri = _parse_url(url) - global secure_warning - # Check proxy + # Use the HTTP(S) proxy proxy_host, proxy_port = (None, None) - if chk_proxy: - proxy_host, proxy_port = get_http_proxy() + if use_proxy: + proxy_host, proxy_port = _get_http_proxy(secure=secure) - # If httplib module is not built with ssl support. Fallback to http + if proxy_host or proxy_port: + logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port) + + # If httplib module is not built with ssl support, + # fallback to HTTP if allowed if secure and not hasattr(httpclient, "HTTPSConnection"): + if not conf.get_allow_http(): + raise HttpError("HTTPS is unavailable and required") + secure = False - if secure_warning: - logger.warn("httplib is not built with ssl support") - secure_warning = False + if not SECURE_WARNING_EMITTED: + logger.warn("Python does not include SSL support") + SECURE_WARNING_EMITTED = True + + # If httplib module doesn't support HTTPS tunnelling, + # fallback to HTTP if allowed + if secure and \ + proxy_host is not None and \ + proxy_port is not None \ + and not hasattr(httpclient.HTTPSConnection, "set_tunnel"): + + if not conf.get_allow_http(): + raise HttpError("HTTPS tunnelling is unavailable and required") - # If httplib module doesn't support https tunnelling. Fallback to http - if secure and proxy_host is not None and proxy_port is not None \ - and not hasattr(httpclient.HTTPSConnection, "set_tunnel"): secure = False - if secure_warning: - logger.warn("httplib does not support https tunnelling " - "(new in python 2.7)") - secure_warning = False - - if proxy_host or proxy_port: - logger.verbose("HTTP proxy: [{0}:{1}]", proxy_host, proxy_port) - - retry_msg = '' - log_msg = "HTTP {0}".format(method) - for retry in range(0, max_retry): - retry_interval = RETRY_WAITING_INTERVAL + if not SECURE_WARNING_EMITTED: + logger.warn("Python does not support HTTPS tunnelling") + SECURE_WARNING_EMITTED = True + + msg = '' + attempt = 0 + delay = retry_delay + + while attempt < max_retry: + if attempt > 0: + logger.info("[HTTP Retry] Attempt {0} of {1}: {2}", + attempt+1, + max_retry, + msg) + time.sleep(delay) + + attempt += 1 + delay = retry_delay + try: resp = _http_request(method, host, @@ -145,55 +231,123 @@ def http_request(method, url, data, headers=None, max_retry=3, headers=headers, proxy_host=proxy_host, proxy_port=proxy_port) - logger.verbose("HTTP response status: [{0}]", resp.status) + logger.verbose("[HTTP Response] Status Code {0}", resp.status) + + if request_failed(resp): + if _is_retry_status(resp.status, retry_codes=retry_codes): + msg = '[HTTP Retry] HTTP {0} Status Code {1}'.format( + method, resp.status) + if _is_throttle_status(resp.status): + delay = LONG_DELAY_IN_SECONDS + logger.info("[HTTP Delay] Delay {0} seconds for " \ + "Status Code {1}".format( + delay, resp.status)) + continue + + if resp.status in RESOURCE_GONE_CODES: + raise ResourceGoneError() + return resp + except httpclient.HTTPException as e: - retry_msg = 'HTTP exception: {0} {1}'.format(log_msg, e) - retry_interval = 5 + msg = '[HTTP Failed] HTTP {0} HttpException {1}'.format(method, e) + if _is_retry_exception(e): + continue + break + except IOError as e: - retry_msg = 'IO error: {0} {1}'.format(log_msg, e) - # error 101: network unreachable; when the adapter resets we may - # see this transient error for a short time, retry once. - if e.errno == 101: - retry_interval = RETRY_WAITING_INTERVAL - max_retry = 1 + msg = '[HTTP Failed] HTTP {0} IOError {1}'.format(method, e) + continue + + raise HttpError(msg) + + +def http_get(url, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("GET", + url, None, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_head(url, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("HEAD", + url, None, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_post(url, data, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("POST", + url, data, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_put(url, data, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("PUT", + url, data, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + + +def http_delete(url, headers=None, use_proxy=False, + max_retry=DEFAULT_RETRIES, + retry_codes=RETRY_CODES, + retry_delay=SHORT_DELAY_IN_SECONDS): + return http_request("DELETE", + url, None, headers=headers, + use_proxy=use_proxy, + max_retry=max_retry, + retry_codes=retry_codes, + retry_delay=retry_delay) + +def request_failed(resp, ok_codes=OK_CODES): + return not request_succeeded(resp, ok_codes=ok_codes) + +def request_succeeded(resp, ok_codes=OK_CODES): + return resp is not None and resp.status in ok_codes + +def read_response_error(resp): + result = '' + if resp is not None: + try: + result = "[HTTP Failed] [{0}: {1}] {2}".format( + resp.status, + resp.reason, + resp.read()) + + # this result string is passed upstream to several methods + # which do a raise HttpError() or a format() of some kind; + # as a result it cannot have any unicode characters + if PY_VERSION_MAJOR < 3: + result = ustr(result, encoding='ascii', errors='ignore') else: - retry_interval = 0 - max_retry = 0 - - if retry < max_retry: - logger.info("Retry [{0}/{1} - {3}]", - retry+1, - max_retry, - retry_interval, - retry_msg) - time.sleep(retry_interval) - - raise HttpError("{0} failed".format(log_msg)) - - -def http_get(url, headers=None, max_retry=3, chk_proxy=False): - return http_request("GET", url, data=None, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - - -def http_head(url, headers=None, max_retry=3, chk_proxy=False): - return http_request("HEAD", url, None, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - - -def http_post(url, data, headers=None, max_retry=3, chk_proxy=False): - return http_request("POST", url, data, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - - -def http_put(url, data, headers=None, max_retry=3, chk_proxy=False): - return http_request("PUT", url, data, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) - + result = result\ + .encode(encoding='ascii', errors='ignore')\ + .decode(encoding='ascii', errors='ignore') -def http_delete(url, headers=None, max_retry=3, chk_proxy=False): - return http_request("DELETE", url, None, headers=headers, - max_retry=max_retry, chk_proxy=chk_proxy) + result = textutil.replace_non_ascii(result) -# End REST api util functions + except Exception: + logger.warn(traceback.format_exc()) + return result diff --git a/azurelinuxagent/common/utils/textutil.py b/azurelinuxagent/common/utils/textutil.py index 2d99f6f..7e244fc 100644 --- a/azurelinuxagent/common/utils/textutil.py +++ b/azurelinuxagent/common/utils/textutil.py @@ -19,6 +19,7 @@ import base64 import crypt import random +import re import string import struct import sys @@ -259,6 +260,17 @@ def set_ini_config(config, name, val): config.insert(length - 1, text) +def replace_non_ascii(incoming, replace_char=''): + outgoing = '' + if incoming is not None: + for c in incoming: + if str_to_ord(c) > 128: + outgoing += replace_char + else: + outgoing += c + return outgoing + + def remove_bom(c): ''' bom is comprised of a sequence of three chars,0xef, 0xbb, 0xbf, in case of utf-8. @@ -311,6 +323,16 @@ def safe_shlex_split(s): return shlex.split(s.encode('utf-8')) return shlex.split(s) +def swap_hexstring(s, width=2): + r = len(s) % width + if r != 0: + s = ('0' * (width - (len(s) % width))) + s + + return ''.join(reversed( + re.findall( + r'[a-f0-9]{{{0}}}'.format(width), + s, + re.IGNORECASE))) def parse_json(json_str): """ diff --git a/azurelinuxagent/common/version.py b/azurelinuxagent/common/version.py index d1d4c62..f27db38 100644 --- a/azurelinuxagent/common/version.py +++ b/azurelinuxagent/common/version.py @@ -113,7 +113,7 @@ def get_distro(): AGENT_NAME = "WALinuxAgent" AGENT_LONG_NAME = "Azure Linux Agent" -AGENT_VERSION = '2.2.14' +AGENT_VERSION = '2.2.16' AGENT_LONG_VERSION = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION) AGENT_DESCRIPTION = """ The Azure Linux Agent supports the provisioning and running of Linux @@ -129,9 +129,20 @@ AGENT_NAME_PATTERN = re.compile(AGENT_PATTERN) AGENT_PKG_PATTERN = re.compile(AGENT_PATTERN+"\.zip") AGENT_DIR_PATTERN = re.compile(".*/{0}".format(AGENT_PATTERN)) -EXT_HANDLER_PATTERN = b".*/WALinuxAgent-(\w.\w.\w[.\w]*)-.*-run-exthandlers" +EXT_HANDLER_PATTERN = b".*/WALinuxAgent-(\d+.\d+.\d+[.\d+]*).*-run-exthandlers" EXT_HANDLER_REGEX = re.compile(EXT_HANDLER_PATTERN) +__distro__ = get_distro() +DISTRO_NAME = __distro__[0] +DISTRO_VERSION = __distro__[1] +DISTRO_CODE_NAME = __distro__[2] +DISTRO_FULL_NAME = __distro__[3] + +PY_VERSION = sys.version_info +PY_VERSION_MAJOR = sys.version_info[0] +PY_VERSION_MINOR = sys.version_info[1] +PY_VERSION_MICRO = sys.version_info[2] + # Set the CURRENT_AGENT and CURRENT_VERSION to match the agent directory name # - This ensures the agent will "see itself" using the same name and version @@ -173,6 +184,8 @@ def set_goal_state_agent(): match = EXT_HANDLER_REGEX.match(pname) if match: agent = match.group(1) + if PY_VERSION_MAJOR > 2: + agent = agent.decode('UTF-8') break except IOError: continue @@ -188,18 +201,6 @@ def is_current_agent_installed(): return CURRENT_AGENT == AGENT_LONG_VERSION -__distro__ = get_distro() -DISTRO_NAME = __distro__[0] -DISTRO_VERSION = __distro__[1] -DISTRO_CODE_NAME = __distro__[2] -DISTRO_FULL_NAME = __distro__[3] - -PY_VERSION = sys.version_info -PY_VERSION_MAJOR = sys.version_info[0] -PY_VERSION_MINOR = sys.version_info[1] -PY_VERSION_MICRO = sys.version_info[2] - - def is_snappy(): """ Add this workaround for detecting Snappy Ubuntu Core temporarily, diff --git a/azurelinuxagent/ga/env.py b/azurelinuxagent/ga/env.py index c81eed7..0456cb0 100644 --- a/azurelinuxagent/ga/env.py +++ b/azurelinuxagent/ga/env.py @@ -26,7 +26,10 @@ import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger from azurelinuxagent.common.dhcp import get_dhcp_handler +from azurelinuxagent.common.event import add_periodic, WALAEventOperation from azurelinuxagent.common.osutil import get_osutil +from azurelinuxagent.common.protocol import get_protocol_util +from azurelinuxagent.common.version import AGENT_NAME, CURRENT_VERSION def get_env_handler(): return EnvHandler() @@ -42,6 +45,7 @@ class EnvHandler(object): def __init__(self): self.osutil = get_osutil() self.dhcp_handler = get_dhcp_handler() + self.protocol_util = get_protocol_util() self.stopped = True self.hostname = None self.dhcpid = None @@ -64,17 +68,35 @@ class EnvHandler(object): def monitor(self): """ + Monitor firewall rules Monitor dhcp client pid and hostname. If dhcp clinet process re-start has occurred, reset routes. """ + protocol = self.protocol_util.get_protocol() while not self.stopped: self.osutil.remove_rules_files() + + if conf.enable_firewall(): + success = self.osutil.enable_firewall( + dst_ip=protocol.endpoint, + uid=os.getuid()) + add_periodic( + logger.EVERY_HOUR, + AGENT_NAME, + version=CURRENT_VERSION, + op=WALAEventOperation.Firewall, + is_success=success, + log_event=True) + timeout = conf.get_root_device_scsi_timeout() if timeout is not None: self.osutil.set_scsi_disks_timeout(timeout) + if conf.get_monitor_hostname(): self.handle_hostname_update() + self.handle_dhclient_restart() + time.sleep(5) def handle_hostname_update(self): diff --git a/azurelinuxagent/ga/exthandlers.py b/azurelinuxagent/ga/exthandlers.py index 4324d92..f0a3b09 100644 --- a/azurelinuxagent/ga/exthandlers.py +++ b/azurelinuxagent/ga/exthandlers.py @@ -411,6 +411,7 @@ class ExtHandlerInstance(object): self.protocol = protocol self.operation = None self.pkg = None + self.pkg_file = None self.is_upgrade = False prefix = "[{0}]".format(self.get_full_name()) @@ -612,12 +613,14 @@ class ExtHandlerInstance(object): raise ExtensionError("Failed to download extension") self.logger.verbose("Unpack extension package") - pkg_file = os.path.join(conf.get_lib_dir(), + self.pkg_file = os.path.join(conf.get_lib_dir(), os.path.basename(uri.uri) + ".zip") try: - fileutil.write_file(pkg_file, bytearray(package), asbin=True) - zipfile.ZipFile(pkg_file).extractall(self.get_base_dir()) + fileutil.write_file(self.pkg_file, bytearray(package), asbin=True) + zipfile.ZipFile(self.pkg_file).extractall(self.get_base_dir()) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to write and unzip plugin", e) #Add user execute permission to all files under the base dir @@ -638,6 +641,8 @@ class ExtHandlerInstance(object): man = fileutil.read_file(man_file, remove_bom=True) fileutil.write_file(self.get_manifest_file(), man) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to save HandlerManifest.json", e) #Create status and config dir @@ -647,6 +652,8 @@ class ExtHandlerInstance(object): conf_dir = self.get_conf_dir() fileutil.mkdir(conf_dir, mode=0o700) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to create status or config dir", e) #Save HandlerEnvironment.json @@ -846,6 +853,8 @@ class ExtHandlerInstance(object): try: fileutil.write_file(settings_file, settings) except IOError as e: + fileutil.clean_ioerror(e, + paths=[settings_file]) raise ExtensionError(u"Failed to update settings file", e) def update_settings(self): @@ -886,6 +895,8 @@ class ExtHandlerInstance(object): try: fileutil.write_file(self.get_env_file(), json.dumps(env)) except IOError as e: + fileutil.clean_ioerror(e, + paths=[self.get_base_dir(), self.pkg_file]) raise ExtensionError(u"Failed to save handler environment", e) def set_handler_state(self, handler_state): @@ -897,6 +908,8 @@ class ExtHandlerInstance(object): state_file = os.path.join(state_dir, "HandlerState") fileutil.write_file(state_file, handler_state) except IOError as e: + fileutil.clean_ioerror(e, + paths=[state_file]) self.logger.error("Failed to set state: {0}", e) def get_handler_state(self): @@ -925,6 +938,8 @@ class ExtHandlerInstance(object): try: fileutil.write_file(status_file, json.dumps(get_properties(handler_status))) except (IOError, ValueError, ProtocolError) as e: + fileutil.clean_ioerror(e, + paths=[status_file]) self.logger.error("Failed to save handler status: {0}", e) def get_handler_status(self): diff --git a/azurelinuxagent/ga/monitor.py b/azurelinuxagent/ga/monitor.py index dcfd6a4..307a514 100644 --- a/azurelinuxagent/ga/monitor.py +++ b/azurelinuxagent/ga/monitor.py @@ -36,8 +36,8 @@ from azurelinuxagent.common.protocol.restapi import TelemetryEventParam, \ set_properties from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, getattrib from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \ - DISTRO_CODE_NAME, AGENT_LONG_VERSION, \ - CURRENT_AGENT, CURRENT_VERSION + DISTRO_CODE_NAME, AGENT_LONG_VERSION, \ + AGENT_NAME, CURRENT_AGENT, CURRENT_VERSION def parse_event(data_str): @@ -184,9 +184,9 @@ class MonitorHandler(object): if datetime.datetime.utcnow() >= (last_heartbeat + period): last_heartbeat = datetime.datetime.utcnow() add_event( - op=WALAEventOperation.HeartBeat, - name=CURRENT_AGENT, + name=AGENT_NAME, version=CURRENT_VERSION, + op=WALAEventOperation.HeartBeat, is_success=True) try: self.collect_and_send_events() diff --git a/azurelinuxagent/ga/update.py b/azurelinuxagent/ga/update.py index 10eac82..b7ee96a 100644 --- a/azurelinuxagent/ga/update.py +++ b/azurelinuxagent/ga/update.py @@ -41,7 +41,9 @@ import azurelinuxagent.common.utils.textutil as textutil from azurelinuxagent.common.event import add_event, add_periodic, \ elapsed_milliseconds, \ WALAEventOperation -from azurelinuxagent.common.exception import UpdateError, ProtocolError +from azurelinuxagent.common.exception import ProtocolError, \ + ResourceGoneError, \ + UpdateError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.osutil import get_osutil from azurelinuxagent.common.protocol import get_protocol_util @@ -231,7 +233,8 @@ class UpdateHandler(object): This is the main loop which watches for agent and extension updates. """ - logger.info(u"Agent {0} is running as the goal state agent", CURRENT_AGENT) + logger.info(u"Agent {0} is running as the goal state agent", + CURRENT_AGENT) # Launch monitoring threads from azurelinuxagent.ga.monitor import get_monitor_handler @@ -245,14 +248,13 @@ class UpdateHandler(object): migrate_handler_state() try: - send_event_time = datetime.utcnow() - self._ensure_no_orphans() self._emit_restart_event() while self.running: if self._is_orphaned: - logger.info("Goal state agent {0} was orphaned -- exiting", CURRENT_AGENT) + logger.info("Goal state agent {0} was orphaned -- exiting", + CURRENT_AGENT) break if self._upgrade_available(): @@ -277,23 +279,19 @@ class UpdateHandler(object): duration=elapsed_milliseconds(utc_start), log_event=True) - test_agent = self.get_test_agent() - if test_agent is not None and test_agent.in_slice: - test_agent.enable() - logger.info(u"Enabled Agent {0} as test agent", test_agent.name) - break - time.sleep(GOAL_STATE_INTERVAL) except Exception as e: - logger.warn(u"Agent {0} failed with exception: {1}", CURRENT_AGENT, ustr(e)) + logger.warn(u"Agent {0} failed with exception: {1}", + CURRENT_AGENT, + ustr(e)) logger.warn(traceback.format_exc()) sys.exit(1) + # additional return here because sys.exit is mocked in unit tests return self._shutdown() sys.exit(0) - return def forward_signal(self, signum, frame): # Note: @@ -339,14 +337,6 @@ class UpdateHandler(object): return available_agents[0] if len(available_agents) >= 1 else None - def get_test_agent(self): - agent = None - agents = [agent for agent in self._load_agents() if agent.is_test] - if len(agents) > 0: - agents.sort(key=lambda agent: agent.version, reverse=True) - agent = agents[0] - return agent - def _emit_restart_event(self): if not self._is_clean_start: msg = u"{0} did not terminate cleanly".format(CURRENT_AGENT) @@ -361,81 +351,13 @@ class UpdateHandler(object): self._set_sentinal() return - def _upgrade_available(self, base_version=CURRENT_VERSION): - # Ignore new agents if updating is disabled - if not conf.get_autoupdate_enabled(): - return False - - now = time.time() - if self.last_attempt_time is not None: - next_attempt_time = self.last_attempt_time + conf.get_autoupdate_frequency() - else: - next_attempt_time = now - if next_attempt_time > now: - return False - - family = conf.get_autoupdate_gafamily() - logger.verbose("Checking for agent family {0} updates", family) - - self.last_attempt_time = now - try: - protocol = self.protocol_util.get_protocol() - manifest_list, etag = protocol.get_vmagent_manifests() - except Exception as e: - msg = u"Exception retrieving agent manifests: {0}".format(ustr(e)) - logger.warn(msg) - add_event( - AGENT_NAME, - op=WALAEventOperation.Download, - version=CURRENT_VERSION, - is_success=False, - message=msg) - return False - - manifests = [m for m in manifest_list.vmAgentManifests \ - if m.family == family and len(m.versionsManifestUris) > 0] - if len(manifests) == 0: - logger.verbose(u"Incarnation {0} has no agent family {1} updates", etag, family) - return False - - try: - pkg_list = protocol.get_vmagent_pkgs(manifests[0]) - except ProtocolError as e: - msg = u"Incarnation {0} failed to get {1} package list: " \ - u"{2}".format( - etag, - family, - ustr(e)) - logger.warn(msg) - add_event( - AGENT_NAME, - op=WALAEventOperation.Download, - version=CURRENT_VERSION, - is_success=False, - message=msg) - return False - - # Set the agents to those available for download at least as current - # as the existing agent and remove from disk any agent no longer - # reported to the VM. - # Note: - # The code leaves on disk available, but blacklisted, agents so as to - # preserve the state. Otherwise, those agents could be again - # downloaded and inappropriately retried. - host = self._get_host_plugin(protocol=protocol) - self._set_agents([GuestAgent(pkg=pkg, host=host) for pkg in pkg_list.versions]) - self._purge_agents() - self._filter_blacklisted_agents() - - # Return True if agents more recent than the current are available - return len(self.agents) > 0 and self.agents[0].version > base_version - def _ensure_no_orphans(self, orphan_wait_interval=ORPHAN_WAIT_INTERVAL): - previous_pid_file, pid_file = self._write_pid_file() - if previous_pid_file is not None: + pid_files, ignored = self._write_pid_file() + for pid_file in pid_files: try: - pid = fileutil.read_file(previous_pid_file) + pid = fileutil.read_file(pid_file) wait_interval = orphan_wait_interval + while self.osutil.check_pid_alive(pid): wait_interval -= GOAL_STATE_INTERVAL if wait_interval <= 0: @@ -452,6 +374,8 @@ class UpdateHandler(object): pid) time.sleep(GOAL_STATE_INTERVAL) + os.remove(pid_file) + except Exception as e: logger.warn( u"Exception occurred waiting for orphan agent to terminate: {0}", @@ -508,22 +432,18 @@ class UpdateHandler(object): protocol.client \ else None - def _get_pid_files(self): + def _get_pid_parts(self): pid_file = conf.get_agent_pid_file_path() - pid_dir = os.path.dirname(pid_file) pid_name = os.path.basename(pid_file) - pid_re = re.compile("(\d+)_{0}".format(re.escape(pid_name))) - pid_files = [int(pid_re.match(f).group(1)) for f in os.listdir(pid_dir) if pid_re.match(f)] - pid_files.sort() + return pid_dir, pid_name, pid_re - pid_index = -1 if len(pid_files) <= 0 else pid_files[-1] - previous_pid_file = None \ - if pid_index < 0 \ - else os.path.join(pid_dir, "{0}_{1}".format(pid_index, pid_name)) - pid_file = os.path.join(pid_dir, "{0}_{1}".format(pid_index+1, pid_name)) - return previous_pid_file, pid_file + def _get_pid_files(self): + pid_dir, pid_name, pid_re = self._get_pid_parts() + pid_files = [os.path.join(pid_dir, f) for f in os.listdir(pid_dir) if pid_re.match(f)] + pid_files.sort(key=lambda f: int(pid_re.match(os.path.basename(f)).group(1))) + return pid_files @property def _is_clean_start(self): @@ -619,8 +539,98 @@ class UpdateHandler(object): str(e)) return + def _upgrade_available(self, base_version=CURRENT_VERSION): + # Emit an event expressing the state of AutoUpdate + # Note: + # - Duplicate events get suppressed; state transitions always emit + add_event( + AGENT_NAME, + version=CURRENT_VERSION, + op=WALAEventOperation.AutoUpdate, + is_success=conf.get_autoupdate_enabled()) + + # Ignore new agents if updating is disabled + if not conf.get_autoupdate_enabled(): + return False + + now = time.time() + if self.last_attempt_time is not None: + next_attempt_time = self.last_attempt_time + \ + conf.get_autoupdate_frequency() + else: + next_attempt_time = now + if next_attempt_time > now: + return False + + family = conf.get_autoupdate_gafamily() + logger.verbose("Checking for agent family {0} updates", family) + + self.last_attempt_time = now + protocol = self.protocol_util.get_protocol() + + for update_goal_state in [False, True]: + try: + if update_goal_state: + protocol.update_goal_state(forced=True) + + manifest_list, etag = protocol.get_vmagent_manifests() + + manifests = [m for m in manifest_list.vmAgentManifests \ + if m.family == family and \ + len(m.versionsManifestUris) > 0] + if len(manifests) == 0: + logger.verbose(u"Incarnation {0} has no {1} agent updates", + etag, family) + return False + + pkg_list = protocol.get_vmagent_pkgs(manifests[0]) + + # Set the agents to those available for download at least as + # current as the existing agent and remove from disk any agent + # no longer reported to the VM. + # Note: + # The code leaves on disk available, but blacklisted, agents + # so as to preserve the state. Otherwise, those agents could be + # again downloaded and inappropriately retried. + host = self._get_host_plugin(protocol=protocol) + self._set_agents([GuestAgent(pkg=pkg, host=host) \ + for pkg in pkg_list.versions]) + + self._purge_agents() + self._filter_blacklisted_agents() + + # Return True if more recent agents are available + return len(self.agents) > 0 and \ + self.agents[0].version > base_version + + except Exception as e: + if isinstance(e, ResourceGoneError): + continue + + msg = u"Exception retrieving agent manifests: {0}".format( + ustr(e)) + logger.warn(msg) + add_event( + AGENT_NAME, + op=WALAEventOperation.Download, + version=CURRENT_VERSION, + is_success=False, + message=msg) + return False + def _write_pid_file(self): - previous_pid_file, pid_file = self._get_pid_files() + pid_files = self._get_pid_files() + + pid_dir, pid_name, pid_re = self._get_pid_parts() + + previous_pid_file = None \ + if len(pid_files) <= 0 \ + else pid_files[-1] + pid_index = -1 \ + if previous_pid_file is None \ + else int(pid_re.match(os.path.basename(previous_pid_file)).group(1)) + pid_file = os.path.join(pid_dir, "{0}_{1}".format(pid_index+1, pid_name)) + try: fileutil.write_file(pid_file, ustr(os.getpid())) logger.info(u"{0} running as process {1}", CURRENT_AGENT, ustr(os.getpid())) @@ -631,7 +641,8 @@ class UpdateHandler(object): CURRENT_AGENT, pid_file, ustr(e)) - return previous_pid_file, pid_file + + return pid_files, pid_file class GuestAgent(object): @@ -652,15 +663,34 @@ class GuestAgent(object): self.version = FlexibleVersion(version) location = u"disk" if path is not None else u"package" - logger.verbose(u"Loading Agent {0} from package {1}", self.name, location) + logger.verbose(u"Loading Agent {0} from {1}", self.name, location) - self.error = None - self.supported = None + self.error = GuestAgentError(self.get_agent_error_file()) + self.error.load() + self.supported = Supported(self.get_agent_supported_file()) + self.supported.load() - self._load_error() - self._load_supported() + try: + self._ensure_downloaded() + self._ensure_loaded() + except Exception as e: + if isinstance(e, ResourceGoneError): + raise - self._ensure_downloaded() + # Note the failure, blacklist the agent if the package downloaded + # - An exception with a downloaded package indicates the package + # is corrupt (e.g., missing the HandlerManifest.json file) + self.mark_failure(is_fatal=os.path.isfile(self.get_agent_pkg_path())) + + msg = u"Agent {0} install failed with exception: {1}".format( + self.name, ustr(e)) + logger.warn(msg) + add_event( + AGENT_NAME, + version=self.version, + op=WALAEventOperation.Install, + is_success=False, + message=msg) return @property @@ -687,12 +717,7 @@ class GuestAgent(object): def clear_error(self): self.error.clear() - return - - def enable(self): - if self.error.is_sentinel: - self.error.clear() - self.error.save() + self.error.save() return @property @@ -708,12 +733,12 @@ class GuestAgent(object): return self.is_blacklisted or os.path.isfile(self.get_agent_manifest_path()) @property - def is_test(self): - return self.error.is_sentinel and self.supported.is_supported + def _is_optional(self): + return self.error is not None and self.error.is_sentinel and self.supported.is_supported @property - def in_slice(self): - return self.is_test and self.supported.in_slice + def _in_slice(self): + return self.supported.is_supported and self.supported.in_slice def mark_failure(self, is_fatal=False): try: @@ -727,72 +752,79 @@ class GuestAgent(object): logger.warn(u"Agent {0} failed recording error state: {1}", self.name, ustr(e)) return + def _enable(self): + # Enable optional agents if within the "slice" + # - The "slice" is a percentage of the agent to execute + # - Blacklist out-of-slice agents to prevent reconsideration + if self._is_optional: + if self._in_slice: + self.error.clear() + self.error.save() + logger.info(u"Enabled optional Agent {0}", self.name) + else: + self.mark_failure(is_fatal=True) + logger.info(u"Optional Agent {0} not in slice", self.name) + return + def _ensure_downloaded(self): - try: - logger.verbose(u"Ensuring Agent {0} is downloaded", self.name) - - if self.is_blacklisted: - logger.verbose(u"Agent {0} is blacklisted - skipping download", self.name) - return - - if self.is_downloaded: - logger.verbose(u"Agent {0} was previously downloaded - skipping download", self.name) - self._load_manifest() - return - - if self.pkg is None: - raise UpdateError(u"Agent {0} is missing package and download URIs".format( - self.name)) - - self._download() - self._unpack() - self._load_manifest() - self._load_error() - self._load_supported() - - msg = u"Agent {0} downloaded successfully".format(self.name) - logger.verbose(msg) - add_event( - AGENT_NAME, - version=self.version, - op=WALAEventOperation.Install, - is_success=True, - message=msg) + logger.verbose(u"Ensuring Agent {0} is downloaded", self.name) - except Exception as e: - # Note the failure, blacklist the agent if the package downloaded - # - An exception with a downloaded package indicates the package - # is corrupt (e.g., missing the HandlerManifest.json file) - self.mark_failure(is_fatal=os.path.isfile(self.get_agent_pkg_path())) + if self.is_downloaded: + logger.verbose(u"Agent {0} was previously downloaded - skipping download", self.name) + return - msg = u"Agent {0} download failed with exception: {1}".format(self.name, ustr(e)) - logger.warn(msg) - add_event( - AGENT_NAME, - version=self.version, - op=WALAEventOperation.Install, - is_success=False, - message=msg) + if self.pkg is None: + raise UpdateError(u"Agent {0} is missing package and download URIs".format( + self.name)) + + self._download() + self._unpack() + + msg = u"Agent {0} downloaded successfully".format(self.name) + logger.verbose(msg) + add_event( + AGENT_NAME, + version=self.version, + op=WALAEventOperation.Install, + is_success=True, + message=msg) + return + + def _ensure_loaded(self): + self._load_manifest() + self._load_error() + self._load_supported() + + self._enable() return def _download(self): for uri in self.pkg.uris: if not HostPluginProtocol.is_default_channel() and self._fetch(uri.uri): break + elif self.host is not None and self.host.ensure_initialized(): if not HostPluginProtocol.is_default_channel(): - logger.warn("Download unsuccessful, falling back to host plugin") + logger.warn("Download failed, switching to host plugin") else: logger.verbose("Using host plugin as default channel") uri, headers = self.host.get_artifact_request(uri.uri, self.host.manifest_uri) - if self._fetch(uri, headers=headers): - if not HostPluginProtocol.is_default_channel(): - logger.verbose("Setting host plugin as default channel") - HostPluginProtocol.set_default_channel(True) - break - else: - logger.warn("Host plugin download unsuccessful") + try: + if self._fetch(uri, headers=headers, use_proxy=False): + if not HostPluginProtocol.is_default_channel(): + logger.verbose("Setting host plugin as default channel") + HostPluginProtocol.set_default_channel(True) + break + else: + logger.warn("Host plugin download failed") + + # If the HostPlugin rejects the request, + # let the error continue, but set to use the HostPlugin + except ResourceGoneError: + HostPluginProtocol.set_default_channel(True) + raise + else: logger.error("No download channels available") @@ -805,13 +837,14 @@ class GuestAgent(object): is_success=False, message=msg) raise UpdateError(msg) + return - def _fetch(self, uri, headers=None): + def _fetch(self, uri, headers=None, use_proxy=True): package = None try: - resp = restutil.http_get(uri, chk_proxy=True, headers=headers) - if resp.status == restutil.httpclient.OK: + resp = restutil.http_get(uri, use_proxy=use_proxy, headers=headers) + if restutil.request_succeeded(resp): package = resp.read() fileutil.write_file(self.get_agent_pkg_path(), bytearray(package), @@ -819,18 +852,21 @@ class GuestAgent(object): logger.verbose(u"Agent {0} downloaded from {1}", self.name, uri) else: logger.verbose("Fetch was unsuccessful [{0}]", - HostPluginProtocol.read_response_error(resp)) + restutil.read_response_error(resp)) except restutil.HttpError as http_error: + if isinstance(http_error, ResourceGoneError): + raise + logger.verbose(u"Agent {0} download from {1} failed [{2}]", self.name, uri, http_error) + return package is not None def _load_error(self): try: - if self.error is None: - self.error = GuestAgentError(self.get_agent_error_file()) + self.error = GuestAgentError(self.get_agent_error_file()) self.error.load() logger.verbose(u"Agent {0} error state: {1}", self.name, ustr(self.error)) except Exception as e: @@ -840,6 +876,7 @@ class GuestAgent(object): def _load_supported(self): try: self.supported = Supported(self.get_agent_supported_file()) + self.supported.load() except Exception as e: self.supported = Supported() @@ -892,6 +929,9 @@ class GuestAgent(object): zipfile.ZipFile(self.get_agent_pkg_path()).extractall(self.get_agent_dir()) except Exception as e: + fileutil.clean_ioerror(e, + paths=[self.get_agent_dir(), self.get_agent_pkg_path()]) + msg = u"Exception unpacking Agent {0} from {1}: {2}".format( self.name, self.get_agent_pkg_path(), @@ -918,7 +958,6 @@ class GuestAgentError(object): self.path = path self.clear() - self.load() return def mark_failure(self, is_fatal=False): @@ -982,8 +1021,7 @@ class Supported(object): if path is None: raise UpdateError(u"Supported requires a path") self.path = path - - self._load() + self.distributions = {} return @property @@ -995,15 +1033,7 @@ class Supported(object): d = self._supported_distribution return d is not None and d.in_slice - @property - def _supported_distribution(self): - for d in self.distributions: - dd = self.distributions[d] - if dd.is_supported: - return dd - return None - - def _load(self): + def load(self): self.distributions = {} try: if self.path is not None and os.path.isfile(self.path): @@ -1014,6 +1044,14 @@ class Supported(object): logger.warn("Failed JSON parse of {0}: {1}".format(self.path, e)) return + @property + def _supported_distribution(self): + for d in self.distributions: + dd = self.distributions[d] + if dd.is_supported: + return dd + return None + class SupportedDistribution(object): def __init__(self, s): if s is None or not isinstance(s, dict): diff --git a/azurelinuxagent/pa/provision/default.py b/azurelinuxagent/pa/provision/default.py index 959a2fe..2f7ec18 100644 --- a/azurelinuxagent/pa/provision/default.py +++ b/azurelinuxagent/pa/provision/default.py @@ -53,16 +53,16 @@ class ProvisionHandler(object): self.protocol_util = get_protocol_util() def run(self): - # If provision is not enabled, report ready and then return if not conf.get_provision_enabled(): logger.info("Provisioning is disabled, skipping.") + self.write_provisioned() + self.report_ready() return try: utc_start = datetime.utcnow() thumbprint = None - # if provisioning is already done, return if self.is_provisioned(): logger.info("Provisioning already completed, skipping.") return @@ -85,7 +85,6 @@ class ProvisionHandler(object): thumbprint = self.reg_ssh_host_key() self.osutil.restart_ssh_service() - # write out provisioned file and report Ready self.write_provisioned() self.report_event("Provision succeed", @@ -128,9 +127,17 @@ class ProvisionHandler(object): keypair_type = conf.get_ssh_host_keypair_type() if conf.get_regenerate_ssh_host_key(): fileutil.rm_files(conf.get_ssh_key_glob()) - keygen_cmd = "ssh-keygen -N '' -t {0} -f {1}" - shellutil.run(keygen_cmd.format(keypair_type, - conf.get_ssh_key_private_path())) + if conf.get_ssh_host_keypair_mode() == "auto": + ''' + The -A option generates all supported key types. + This is supported since OpenSSH 5.9 (2011). + ''' + shellutil.run("ssh-keygen -A") + else: + keygen_cmd = "ssh-keygen -N '' -t {0} -f {1}" + shellutil.run(keygen_cmd. + format(keypair_type, + conf.get_ssh_key_private_path())) return self.get_ssh_host_key_thumbprint() def get_ssh_host_key_thumbprint(self, chk_err=True): @@ -162,7 +169,7 @@ class ProvisionHandler(object): return False s = fileutil.read_file(self.provisioned_file_path()).strip() - if s != self.osutil.get_instance_id(): + if not self.osutil.is_current_instance_id(s): if len(s) > 0: logger.warn("VM is provisioned, " "but the VM unique identifier has changed -- " @@ -173,6 +180,7 @@ class ProvisionHandler(object): deprovision_handler.run_changed_unique_id() self.write_provisioned() + self.report_ready() return True diff --git a/azurelinuxagent/pa/rdma/factory.py b/azurelinuxagent/pa/rdma/factory.py index 535b3d3..92bd2e0 100644 --- a/azurelinuxagent/pa/rdma/factory.py +++ b/azurelinuxagent/pa/rdma/factory.py @@ -21,6 +21,7 @@ from azurelinuxagent.common.version import DISTRO_FULL_NAME, DISTRO_VERSION from azurelinuxagent.common.rdma import RDMAHandler from .suse import SUSERDMAHandler from .centos import CentOSRDMAHandler +from .ubuntu import UbuntuRDMAHandler def get_rdma_handler( @@ -37,5 +38,8 @@ def get_rdma_handler( if distro_full_name == 'CentOS Linux' or distro_full_name == 'CentOS': return CentOSRDMAHandler(distro_version) + if distro_full_name == 'Ubuntu': + return UbuntuRDMAHandler() + logger.info("No RDMA handler exists for distro='{0}' version='{1}'", distro_full_name, distro_version) return RDMAHandler() diff --git a/azurelinuxagent/pa/rdma/suse.py b/azurelinuxagent/pa/rdma/suse.py index d31b2b0..20f06cd 100644 --- a/azurelinuxagent/pa/rdma/suse.py +++ b/azurelinuxagent/pa/rdma/suse.py @@ -36,8 +36,9 @@ class SUSERDMAHandler(RDMAHandler): logger.error(error_msg) return zypper_install = 'zypper -n in %s' + zypper_install_noref = 'zypper -n --no-refresh in %s' zypper_remove = 'zypper -n rm %s' - zypper_search = 'zypper se -s %s' + zypper_search = 'zypper -n se -s %s' package_name = 'msft-rdma-kmp-default' cmd = zypper_search % package_name status, repo_package_info = shellutil.run_get_output(cmd) @@ -108,9 +109,9 @@ class SUSERDMAHandler(RDMAHandler): fw_version in local_package ): logger.info("RDMA: Installing: %s" % local_package) - cmd = zypper_install % local_package + cmd = zypper_install_noref % local_package result = shellutil.run(cmd) - if result: + if result and result != 106: error_msg = 'RDMA: Failed install of package "%s" ' error_msg += 'from local package cache' logger.error(error_msg % local_package) diff --git a/azurelinuxagent/pa/rdma/ubuntu.py b/azurelinuxagent/pa/rdma/ubuntu.py new file mode 100644 index 0000000..050797d --- /dev/null +++ b/azurelinuxagent/pa/rdma/ubuntu.py @@ -0,0 +1,122 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import glob +import os +import re +import time +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.shellutil as shellutil +from azurelinuxagent.common.rdma import RDMAHandler + + +class UbuntuRDMAHandler(RDMAHandler): + + def install_driver(self): + #Install the appropriate driver package for the RDMA firmware + + nd_version = RDMAHandler.get_rdma_version() + if not nd_version: + logger.error("RDMA: Could not determine firmware version. No driver will be installed") + return + #replace . with _, we are looking for number like 144_0 + nd_version = re.sub('\.', '_', nd_version) + + #Check to see if we need to reconfigure driver + status,module_name = shellutil.run_get_output('modprobe -R hv_network_direct', chk_err=False) + if status != 0: + logger.info("RDMA: modprobe -R hv_network_direct failed. Use module name hv_network_direct") + module_name = "hv_network_direct" + else: + module_name = module_name.strip() + logger.info("RDMA: current RDMA driver %s nd_version %s" % (module_name, nd_version)) + if module_name == 'hv_network_direct_%s' % nd_version: + logger.info("RDMA: driver is installed and ND version matched. Skip reconfiguring driver") + return + + #Reconfigure driver if one is available + status,output = shellutil.run_get_output('modinfo hv_network_direct_%s' % nd_version); + if status == 0: + logger.info("RDMA: driver with ND version is installed. Link to module name") + self.update_modprobed_conf(nd_version) + return + + #Driver not found. We need to check to see if we need to update kernel + if not conf.enable_rdma_update(): + logger.info("RDMA: driver update is disabled. Skip kernel update") + return + + status,output = shellutil.run_get_output('uname -r') + if status != 0: + return + if not re.search('-azure$', output): + logger.error("RDMA: skip driver update on non-Azure kernel") + return + kernel_version = re.sub('-azure$', '', output) + kernel_version = re.sub('-', '.', kernel_version) + + #Find the new kernel package version + status,output = shellutil.run_get_output('apt-get update') + if status != 0: + return + status,output = shellutil.run_get_output('apt-cache show --no-all-versions linux-azure') + if status != 0: + return + r = re.search('Version: (\S+)', output) + if not r: + logger.error("RDMA: version not found in package linux-azure.") + return + package_version = r.groups()[0] + #Remove the ending .<upload number> after <ABI number> + package_version = re.sub("\.\d+$", "", package_version) + + logger.info('RDMA: kernel_version=%s package_version=%s' % (kernel_version, package_version)) + kernel_version_array = [ int(x) for x in kernel_version.split('.') ] + package_version_array = [ int(x) for x in package_version.split('.') ] + if kernel_version_array < package_version_array: + logger.info("RDMA: newer version available, update kernel and reboot") + status,output = shellutil.run_get_output('apt-get -y install linux-azure') + if status: + logger.error("RDMA: kernel update failed") + return + self.reboot_system() + else: + logger.error("RDMA: no kernel update is avaiable for ND version %s" % nd_version) + + def update_modprobed_conf(self, nd_version): + #Update /etc/modprobe.d/vmbus-rdma.conf to point to the correct driver + + modprobed_file = '/etc/modprobe.d/vmbus-rdma.conf' + lines = '' + if not os.path.isfile(modprobed_file): + logger.info("RDMA: %s not found, it will be created" % modprobed_file) + else: + f = open(modprobed_file, 'r') + lines = f.read() + f.close() + r = re.search('alias hv_network_direct hv_network_direct_\S+', lines) + if r: + lines = re.sub('alias hv_network_direct hv_network_direct_\S+', 'alias hv_network_direct hv_network_direct_%s' % nd_version, lines) + else: + lines += '\nalias hv_network_direct hv_network_direct_%s\n' % nd_version + f = open('/etc/modprobe.d/vmbus-rdma.conf', 'w') + f.write(lines) + f.close() + logger.info("RDMA: hv_network_direct alias updated to ND %s" % nd_version) diff --git a/config/alpine/waagent.conf b/config/alpine/waagent.conf index 2e3f6a5..99495d5 100644 --- a/config/alpine/waagent.conf +++ b/config/alpine/waagent.conf @@ -81,3 +81,12 @@ OS.SshDir=/etc/ssh # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/arch/waagent.conf b/config/arch/waagent.conf index 686b90c..200a458 100644 --- a/config/arch/waagent.conf +++ b/config/arch/waagent.conf @@ -107,3 +107,12 @@ OS.SshDir=/etc/ssh # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/bigip/waagent.conf b/config/bigip/waagent.conf index a6a380b..9ff6ee1 100644 --- a/config/bigip/waagent.conf +++ b/config/bigip/waagent.conf @@ -76,4 +76,21 @@ OS.SshdConfigPath=/config/ssh/sshd_config OS.EnableRDMA=n # Enable or disable goal state processing auto-update, default is enabled -AutoUpdate.Enabled=y
\ No newline at end of file +AutoUpdate.Enabled=y + +# Determine the update family, this should not be changed +# AutoUpdate.GAFamily=Prod + +# Determine if the overprovisioning feature is enabled. If yes, hold extension +# handling until inVMArtifactsProfile.OnHold is false. +# Default is disabled +# EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/clearlinux/waagent.conf b/config/clearlinux/waagent.conf index 6606cd7..8109425 100644 --- a/config/clearlinux/waagent.conf +++ b/config/clearlinux/waagent.conf @@ -79,3 +79,12 @@ AutoUpdate.GAFamily=Prod # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/coreos/waagent.conf b/config/coreos/waagent.conf index 664d037..cbb327f 100644 --- a/config/coreos/waagent.conf +++ b/config/coreos/waagent.conf @@ -107,3 +107,12 @@ OS.OpensslPath=None # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +OS.AllowHTTP=y + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/freebsd/waagent.conf b/config/freebsd/waagent.conf index 5149573..6406c75 100644 --- a/config/freebsd/waagent.conf +++ b/config/freebsd/waagent.conf @@ -105,3 +105,12 @@ OS.SudoersDir=/usr/local/etc/sudoers.d # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/gaia/waagent.conf b/config/gaia/waagent.conf index 75550a6..9c28ba3 100644 --- a/config/gaia/waagent.conf +++ b/config/gaia/waagent.conf @@ -104,3 +104,12 @@ AutoUpdate.Enabled=n # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/openbsd/waagent.conf b/config/openbsd/waagent.conf index 09e7db7..a39a9a5 100644 --- a/config/openbsd/waagent.conf +++ b/config/openbsd/waagent.conf @@ -14,8 +14,8 @@ Provisioning.DeleteRootPassword=y # Generate fresh host key pair. Provisioning.RegenerateSshHostKeyPair=y -# Supported values are "rsa", "dsa", "ecdsa", and "ed25519". -Provisioning.SshHostKeyPairType=ed25519 +# Supported values are "rsa", "dsa", "ecdsa", "ed25519", and "auto". +Provisioning.SshHostKeyPairType=auto # Monitor host name changes and publish changes via DHCP requests. Provisioning.MonitorHostName=y @@ -103,3 +103,12 @@ OS.PasswordPath=/etc/master.passwd # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/suse/waagent.conf b/config/suse/waagent.conf index b2e90a8..ba50be6 100644 --- a/config/suse/waagent.conf +++ b/config/suse/waagent.conf @@ -107,3 +107,12 @@ OS.SshDir=/etc/ssh # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/ubuntu/waagent.conf b/config/ubuntu/waagent.conf index 734a403..71f2c04 100644 --- a/config/ubuntu/waagent.conf +++ b/config/ubuntu/waagent.conf @@ -82,6 +82,9 @@ OS.SshDir=/etc/ssh # Enable RDMA management and set up, should only be used in HPC images # OS.EnableRDMA=y +# Enable RDMA kernel update, this value is effective on Ubuntu +# OS.UpdateRdmaDriver=y + # Enable or disable goal state processing auto-update, default is enabled # AutoUpdate.Enabled=y @@ -92,3 +95,12 @@ OS.SshDir=/etc/ssh # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/config/waagent.conf b/config/waagent.conf index b1b1ba3..99f54d6 100644 --- a/config/waagent.conf +++ b/config/waagent.conf @@ -14,7 +14,7 @@ Provisioning.DeleteRootPassword=y # Generate fresh host key pair. Provisioning.RegenerateSshHostKeyPair=y -# Supported values are "rsa", "dsa" and "ecdsa". +# Supported values are "rsa", "dsa", "ecdsa", "ed25519", and "auto". Provisioning.SshHostKeyPairType=rsa # Monitor host name changes and publish changes via DHCP requests. @@ -104,3 +104,12 @@ OS.SshDir=/etc/ssh # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/debian/changelog b/debian/changelog index 453beee..cf0a1fd 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +walinuxagent (2.2.16-0ubuntu1) artful; urgency=medium + + * New upstream release (LP: #1714299). + + -- Ćukasz 'sil2100' Zemczak <lukasz.zemczak@ubuntu.com> Mon, 04 Sep 2017 10:27:07 +0200 + walinuxagent (2.2.14-0ubuntu1) artful; urgency=medium * New upstream release (LP: #1701350). diff --git a/debian/patches/disable_import_test.patch b/debian/patches/disable_import_test.patch index cc69b6b..f51d77e 100644 --- a/debian/patches/disable_import_test.patch +++ b/debian/patches/disable_import_test.patch @@ -18,7 +18,7 @@ -Provisioning.RegenerateSshHostKeyPair=y +Provisioning.RegenerateSshHostKeyPair=n - # Supported values are "rsa", "dsa" and "ecdsa". + # Supported values are "rsa", "dsa", "ecdsa", "ed25519", and "auto". Provisioning.SshHostKeyPairType=rsa @@ -36,14 +36,14 @@ Provisioning.AllowResetSysUser=n diff --git a/init/freebsd/waagent b/init/freebsd/waagent index 99ddef7..becc2a3 100755 --- a/init/freebsd/waagent +++ b/init/freebsd/waagent @@ -6,7 +6,7 @@ . /etc/rc.subr -PATH=$PATH:/usr/local/bin +PATH=$PATH:/usr/local/bin:/usr/local/sbin name="waagent" rcvar="waagent_enable" pidfile="/var/run/waagent.pid" @@ -106,7 +106,8 @@ def get_data_files(name, version, fullname): set_udev_files(data_files) set_files(data_files, dest="/usr/share/oem", src=["init/coreos/cloud-config.yml"]) - elif name == 'clear linux software for intel architecture': + elif name == 'clear linux os for intel architecture' \ + or name == 'clear linux software for intel architecture': set_bin_files(data_files, dest="/usr/bin") set_conf_files(data_files, dest="/usr/share/defaults/waagent", src=["config/clearlinux/waagent.conf"]) diff --git a/tests/common/osutil/test_default.py b/tests/common/osutil/test_default.py index 87acc60..ec4408b 100644 --- a/tests/common/osutil/test_default.py +++ b/tests/common/osutil/test_default.py @@ -25,6 +25,7 @@ from azurelinuxagent.common.exception import OSUtilError from azurelinuxagent.common.future import ustr from azurelinuxagent.common.osutil import get_osutil from azurelinuxagent.common.utils import fileutil +from azurelinuxagent.common.utils.flexible_version import FlexibleVersion from tests.tools import * @@ -112,6 +113,21 @@ class TestOSUtil(AgentTestCase): self.assertFalse(osutil.DefaultOSUtil().is_primary_interface('lo')) self.assertTrue(osutil.DefaultOSUtil().is_primary_interface('eth0')) + def test_sriov(self): + routing_table = "\ + Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT \n" \ + "bond0 00000000 0100000A 0003 0 0 0 00000000 0 0 0 \n" \ + "bond0 0000000A 00000000 0001 0 0 0 00000000 0 0 0 \n" \ + "eth0 0000000A 00000000 0001 0 0 0 00000000 0 0 0 \n" \ + "bond0 10813FA8 0100000A 0007 0 0 0 00000000 0 0 0 \n" \ + "bond0 FEA9FEA9 0100000A 0007 0 0 0 00000000 0 0 0 \n" + + mo = mock.mock_open(read_data=routing_table) + with patch(open_patch(), mo): + self.assertFalse(osutil.DefaultOSUtil().is_primary_interface('eth0')) + self.assertTrue(osutil.DefaultOSUtil().is_primary_interface('bond0')) + + def test_multiple_default_routes(self): routing_table = "\ Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT \n\ @@ -362,23 +378,50 @@ Match host 192.168.1.2\n\ conf.get_sshd_conf_file_path(), expected_output) + def test_correct_instance_id(self): + util = osutil.DefaultOSUtil() + self.assertEqual( + "12345678-1234-1234-1234-123456789012", + util._correct_instance_id("78563412-3412-3412-1234-123456789012")) + self.assertEqual( + "D0DF4C54-4ECB-4A4B-9954-5BDF3ED5C3B8", + util._correct_instance_id("544CDFD0-CB4E-4B4A-9954-5BDF3ED5C3B8")) + @patch('os.path.isfile', return_value=True) @patch('azurelinuxagent.common.utils.fileutil.read_file', - return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502") + return_value="33C2F3B9-1399-429F-8EB3-BA656DF32502") def test_get_instance_id_from_file(self, mock_read, mock_isfile): util = osutil.DefaultOSUtil() self.assertEqual( - "B9F3C233-9913-9F42-8EB3-BA656DF32502", + util.get_instance_id(), + "B9F3C233-9913-9F42-8EB3-BA656DF32502") + + @patch('os.path.isfile', return_value=True) + @patch('azurelinuxagent.common.utils.fileutil.read_file', + return_value="") + def test_get_instance_id_empty_from_file(self, mock_read, mock_isfile): + util = osutil.DefaultOSUtil() + self.assertEqual( + "", + util.get_instance_id()) + + @patch('os.path.isfile', return_value=True) + @patch('azurelinuxagent.common.utils.fileutil.read_file', + return_value="Value") + def test_get_instance_id_malformed_from_file(self, mock_read, mock_isfile): + util = osutil.DefaultOSUtil() + self.assertEqual( + "Value", util.get_instance_id()) @patch('os.path.isfile', return_value=False) @patch('azurelinuxagent.common.utils.shellutil.run_get_output', - return_value=[0, 'B9F3C233-9913-9F42-8EB3-BA656DF32502']) + return_value=[0, '33C2F3B9-1399-429F-8EB3-BA656DF32502']) def test_get_instance_id_from_dmidecode(self, mock_shell, mock_isfile): util = osutil.DefaultOSUtil() self.assertEqual( - "B9F3C233-9913-9F42-8EB3-BA656DF32502", - util.get_instance_id()) + util.get_instance_id(), + "B9F3C233-9913-9F42-8EB3-BA656DF32502") @patch('os.path.isfile', return_value=False) @patch('azurelinuxagent.common.utils.shellutil.run_get_output', @@ -394,5 +437,181 @@ Match host 192.168.1.2\n\ util = osutil.DefaultOSUtil() self.assertEqual("", util.get_instance_id()) + @patch('os.path.isfile', return_value=True) + @patch('azurelinuxagent.common.utils.fileutil.read_file') + def test_is_current_instance_id_from_file(self, mock_read, mock_isfile): + util = osutil.DefaultOSUtil() + + mock_read.return_value = "B9F3C233-9913-9F42-8EB3-BA656DF32502" + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + mock_read.return_value = "33C2F3B9-1399-429F-8EB3-BA656DF32502" + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + @patch('os.path.isfile', return_value=False) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + def test_is_current_instance_id_from_dmidecode(self, mock_shell, mock_isfile): + util = osutil.DefaultOSUtil() + + mock_shell.return_value = [0, 'B9F3C233-9913-9F42-8EB3-BA656DF32502'] + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + mock_shell.return_value = [0, '33C2F3B9-1399-429F-8EB3-BA656DF32502'] + self.assertTrue(util.is_current_instance_id( + "B9F3C233-9913-9F42-8EB3-BA656DF32502")) + + @patch('azurelinuxagent.common.conf.get_sudoers_dir') + def test_conf_sudoer(self, mock_dir): + tmp_dir = tempfile.mkdtemp() + mock_dir.return_value = tmp_dir + + util = osutil.DefaultOSUtil() + + # Assert the sudoer line is added if missing + util.conf_sudoer("FooBar") + waagent_sudoers = os.path.join(tmp_dir, 'waagent') + self.assertTrue(os.path.isfile(waagent_sudoers)) + + count = -1 + with open(waagent_sudoers, 'r') as f: + count = len(f.readlines()) + self.assertEqual(1, count) + + # Assert the line does not get added a second time + util.conf_sudoer("FooBar") + + count = -1 + with open(waagent_sudoers, 'r') as f: + count = len(f.readlines()) + print("WRITING TO {0}".format(waagent_sudoers)) + self.assertEqual(1, count) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [1, 0, 0] + mock_output.side_effect = [(0, version), (0, "Output")] + self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)), + call(osutil.FIREWALL_DROP.format(wait, "A", dst)) + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION), + call(osutil.FIREWALL_LIST.format(wait)) + ]) + self.assertTrue(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_no_wait(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION-1) + wait = "" + + mock_run.side_effect = [1, 0, 0] + mock_output.side_effect = [(0, version), (0, "Output")] + self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)), + call(osutil.FIREWALL_DROP.format(wait, "A", dst)) + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION), + call(osutil.FIREWALL_LIST.format(wait)) + ]) + self.assertTrue(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_skips_if_drop_exists(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [0, 0, 0] + mock_output.return_value = (0, version) + self.assertTrue(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION) + ]) + self.assertTrue(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_ignores_exceptions(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = True + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [1, Exception] + mock_output.return_value = (0, version) + self.assertFalse(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_has_calls([ + call(osutil.FIREWALL_DROP.format(wait, "C", dst), chk_err=False), + call(osutil.FIREWALL_ACCEPT.format(wait, "A", dst, uid)) + ]) + mock_output.assert_has_calls([ + call(osutil.IPTABLES_VERSION) + ]) + self.assertFalse(osutil._enable_firewall) + + @patch('os.getuid', return_value=42) + @patch('azurelinuxagent.common.utils.shellutil.run_get_output') + @patch('azurelinuxagent.common.utils.shellutil.run') + def test_enable_firewall_skips_if_disabled(self, mock_run, mock_output, mock_uid): + osutil._enable_firewall = False + util = osutil.DefaultOSUtil() + + dst = '1.2.3.4' + uid = 42 + version = "iptables v{0}".format(osutil.IPTABLES_LOCKING_VERSION) + wait = "-w" + + mock_run.side_effect = [1, 0, 0] + mock_output.side_effect = [(0, version), (0, "Output")] + self.assertFalse(util.enable_firewall(dst_ip=dst, uid=uid)) + + mock_run.assert_not_called() + mock_output.assert_not_called() + mock_uid.assert_not_called() + self.assertFalse(osutil._enable_firewall) + if __name__ == '__main__': unittest.main() diff --git a/tests/common/test_conf.py b/tests/common/test_conf.py index 1287b0d..93759de 100644 --- a/tests/common/test_conf.py +++ b/tests/common/test_conf.py @@ -24,6 +24,49 @@ from tests.tools import * class TestConf(AgentTestCase): + # Note: + # -- These values *MUST* match those from data/test_waagent.conf + EXPECTED_CONFIGURATION = { + "Provisioning.Enabled" : True, + "Provisioning.UseCloudInit" : True, + "Provisioning.DeleteRootPassword" : True, + "Provisioning.RegenerateSshHostKeyPair" : True, + "Provisioning.SshHostKeyPairType" : "rsa", + "Provisioning.MonitorHostName" : True, + "Provisioning.DecodeCustomData" : False, + "Provisioning.ExecuteCustomData" : False, + "Provisioning.PasswordCryptId" : '6', + "Provisioning.PasswordCryptSaltLength" : 10, + "Provisioning.AllowResetSysUser" : False, + "ResourceDisk.Format" : True, + "ResourceDisk.Filesystem" : "ext4", + "ResourceDisk.MountPoint" : "/mnt/resource", + "ResourceDisk.EnableSwap" : False, + "ResourceDisk.SwapSizeMB" : 0, + "ResourceDisk.MountOptions" : None, + "Logs.Verbose" : False, + "OS.EnableFIPS" : True, + "OS.RootDeviceScsiTimeout" : '300', + "OS.OpensslPath" : '/usr/bin/openssl', + "OS.SshDir" : "/notareal/path", + "HttpProxy.Host" : None, + "HttpProxy.Port" : None, + "DetectScvmmEnv" : False, + "Lib.Dir" : "/var/lib/waagent", + "DVD.MountPoint" : "/mnt/cdrom/secure", + "Pid.File" : "/var/run/waagent.pid", + "Extension.LogDir" : "/var/log/azure", + "OS.HomeDir" : "/home", + "OS.EnableRDMA" : False, + "OS.UpdateRdmaDriver" : False, + "OS.CheckRdmaDriver" : False, + "AutoUpdate.Enabled" : True, + "AutoUpdate.GAFamily" : "Prod", + "EnableOverProvisioning" : False, + "OS.AllowHTTP" : False, + "OS.EnableFirewall" : True + } + def setUp(self): AgentTestCase.setUp(self) self.conf = ConfigurationProvider() @@ -59,3 +102,11 @@ class TestConf(AgentTestCase): def test_get_provision_cloudinit(self): self.assertTrue(get_provision_cloudinit(self.conf)) + + def test_get_configuration(self): + configuration = conf.get_configuration(self.conf) + self.assertTrue(len(configuration.keys()) > 0) + for k in TestConf.EXPECTED_CONFIGURATION.keys(): + self.assertEqual( + TestConf.EXPECTED_CONFIGURATION[k], + configuration[k]) diff --git a/tests/common/test_event.py b/tests/common/test_event.py index a485edf..55a99c4 100644 --- a/tests/common/test_event.py +++ b/tests/common/test_event.py @@ -22,7 +22,8 @@ from datetime import datetime import azurelinuxagent.common.event as event import azurelinuxagent.common.logger as logger -from azurelinuxagent.common.event import init_event_logger, add_event +from azurelinuxagent.common.event import add_event, \ + mark_event_status, should_emit_event from azurelinuxagent.common.future import ustr from azurelinuxagent.common.version import CURRENT_VERSION @@ -30,10 +31,84 @@ from tests.tools import * class TestEvent(AgentTestCase): + def test_event_status_event_marked(self): + es = event.__event_status__ + + self.assertFalse(es.event_marked("Foo", "1.2", "FauxOperation")) + es.mark_event_status("Foo", "1.2", "FauxOperation", True) + self.assertTrue(es.event_marked("Foo", "1.2", "FauxOperation")) + + event.__event_status__ = event.EventStatus() + event.init_event_status(self.tmp_dir) + es = event.__event_status__ + self.assertTrue(es.event_marked("Foo", "1.2", "FauxOperation")) + + def test_event_status_defaults_to_success(self): + es = event.__event_status__ + self.assertTrue(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + def test_event_status_records_status(self): + d = tempfile.mkdtemp() + es = event.EventStatus(tempfile.mkdtemp()) + + es.mark_event_status("Foo", "1.2", "FauxOperation", True) + self.assertTrue(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + es.mark_event_status("Foo", "1.2", "FauxOperation", False) + self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + def test_event_status_preserves_state(self): + es = event.__event_status__ + + es.mark_event_status("Foo", "1.2", "FauxOperation", False) + self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + event.__event_status__ = event.EventStatus() + event.init_event_status(self.tmp_dir) + es = event.__event_status__ + self.assertFalse(es.event_succeeded("Foo", "1.2", "FauxOperation")) + + def test_should_emit_event_ignores_unknown_operations(self): + event.__event_status__ = event.EventStatus(tempfile.mkdtemp()) + + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", False)) + + # Marking the event has no effect + event.mark_event_status("Foo", "1.2", "FauxOperation", True) + + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", "FauxOperation", False)) + + + def test_should_emit_event_handles_known_operations(self): + event.__event_status__ = event.EventStatus(tempfile.mkdtemp()) + + # Known operations always initially "fire" + for op in event.__event_status_operations__: + self.assertTrue(event.should_emit_event("Foo", "1.2", op, True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", op, False)) + + # Note a success event... + for op in event.__event_status_operations__: + event.mark_event_status("Foo", "1.2", op, True) + + # Subsequent success events should not fire, but failures will + for op in event.__event_status_operations__: + self.assertFalse(event.should_emit_event("Foo", "1.2", op, True)) + self.assertTrue(event.should_emit_event("Foo", "1.2", op, False)) + + # Note a failure event... + for op in event.__event_status_operations__: + event.mark_event_status("Foo", "1.2", op, False) + + # Subsequent success events fire and failure do not + for op in event.__event_status_operations__: + self.assertTrue(event.should_emit_event("Foo", "1.2", op, True)) + self.assertFalse(event.should_emit_event("Foo", "1.2", op, False)) @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_emits_if_not_previously_sent(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -41,7 +116,6 @@ class TestEvent(AgentTestCase): @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_does_not_emit_if_previously_sent(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -52,7 +126,6 @@ class TestEvent(AgentTestCase): @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_emits_if_forced(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -63,7 +136,6 @@ class TestEvent(AgentTestCase): @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_emits_after_elapsed_delta(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -73,14 +145,13 @@ class TestEvent(AgentTestCase): self.assertEqual(1, mock_event.call_count) h = hash("FauxEvent"+""+ustr(True)+"") - event.__event_logger__.periodic_messages[h] = \ + event.__event_logger__.periodic_events[h] = \ datetime.now() - logger.EVERY_DAY - logger.EVERY_HOUR event.add_periodic(logger.EVERY_DAY, "FauxEvent") self.assertEqual(2, mock_event.call_count) @patch('azurelinuxagent.common.event.EventLogger.add_event') def test_periodic_forwards_args(self, mock_event): - init_event_logger(tempfile.mkdtemp()) event.__event_logger__.reset_periodic() event.add_periodic(logger.EVERY_DAY, "FauxEvent") @@ -90,68 +161,58 @@ class TestEvent(AgentTestCase): log_event=True, message='', op='', version=str(CURRENT_VERSION)) def test_save_event(self): - tmp_evt = tempfile.mkdtemp() - init_event_logger(tmp_evt) add_event('test', message='test event') - self.assertTrue(len(os.listdir(tmp_evt)) == 1) - shutil.rmtree(tmp_evt) + self.assertTrue(len(os.listdir(self.tmp_dir)) == 1) def test_save_event_rollover(self): - tmp_evt = tempfile.mkdtemp() - init_event_logger(tmp_evt) add_event('test', message='first event') for i in range(0, 999): add_event('test', message='test event {0}'.format(i)) - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) events.sort() self.assertTrue(len(events) == 1000) - first_event = os.path.join(tmp_evt, events[0]) + first_event = os.path.join(self.tmp_dir, events[0]) with open(first_event) as first_fh: first_event_text = first_fh.read() self.assertTrue('first event' in first_event_text) add_event('test', message='last event') - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) events.sort() self.assertTrue(len(events) == 1000, "{0} events found, 1000 expected".format(len(events))) - first_event = os.path.join(tmp_evt, events[0]) + first_event = os.path.join(self.tmp_dir, events[0]) with open(first_event) as first_fh: first_event_text = first_fh.read() self.assertFalse('first event' in first_event_text) self.assertTrue('test event 0' in first_event_text) - last_event = os.path.join(tmp_evt, events[-1]) + last_event = os.path.join(self.tmp_dir, events[-1]) with open(last_event) as last_fh: last_event_text = last_fh.read() self.assertTrue('last event' in last_event_text) - shutil.rmtree(tmp_evt) - def test_save_event_cleanup(self): - tmp_evt = tempfile.mkdtemp() - init_event_logger(tmp_evt) - for i in range(0, 2000): - evt = os.path.join(tmp_evt, '{0}.tld'.format(ustr(1491004920536531 + i))) + evt = os.path.join(self.tmp_dir, '{0}.tld'.format(ustr(1491004920536531 + i))) with open(evt, 'w') as fh: fh.write('test event {0}'.format(i)) - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) self.assertTrue(len(events) == 2000, "{0} events found, 2000 expected".format(len(events))) add_event('test', message='last event') - events = os.listdir(tmp_evt) + events = os.listdir(self.tmp_dir) events.sort() self.assertTrue(len(events) == 1000, "{0} events found, 1000 expected".format(len(events))) - first_event = os.path.join(tmp_evt, events[0]) + first_event = os.path.join(self.tmp_dir, events[0]) with open(first_event) as first_fh: first_event_text = first_fh.read() self.assertTrue('test event 1001' in first_event_text) - last_event = os.path.join(tmp_evt, events[-1]) + last_event = os.path.join(self.tmp_dir, events[-1]) with open(last_event) as last_fh: last_event_text = last_fh.read() self.assertTrue('last event' in last_event_text) diff --git a/tests/data/ga/WALinuxAgent-2.2.11.zip b/tests/data/ga/WALinuxAgent-2.2.11.zip Binary files differdeleted file mode 100644 index f018116..0000000 --- a/tests/data/ga/WALinuxAgent-2.2.11.zip +++ /dev/null diff --git a/tests/data/ga/WALinuxAgent-2.2.14.zip b/tests/data/ga/WALinuxAgent-2.2.14.zip Binary files differnew file mode 100644 index 0000000..a978207 --- /dev/null +++ b/tests/data/ga/WALinuxAgent-2.2.14.zip diff --git a/tests/data/test_waagent.conf b/tests/data/test_waagent.conf index 6368c39..edc3676 100644 --- a/tests/data/test_waagent.conf +++ b/tests/data/test_waagent.conf @@ -94,10 +94,13 @@ OS.SshDir=/notareal/path # Extension.LogDir=/var/log/azure # -# Home.Dir=/home +# OS.HomeDir=/home # Enable RDMA management and set up, should only be used in HPC images -# OS.EnableRDMA=y +# OS.EnableRDMA=n +# OS.UpdateRdmaDriver=n +# OS.CheckRdmaDriver=n + # Enable or disable goal state processing auto-update, default is enabled # AutoUpdate.Enabled=y @@ -109,3 +112,12 @@ OS.SshDir=/notareal/path # handling until inVMArtifactsProfile.OnHold is false. # Default is disabled # EnableOverProvisioning=n + +# Allow fallback to HTTP if HTTPS is unavailable +# Note: Allowing HTTP (vs. HTTPS) may cause security risks +# OS.AllowHTTP=n + +# Add firewall rules to protect access to Azure host node services +# Note: +# - The default is false to protect the state of exising VMs +OS.EnableFirewall=y diff --git a/tests/ga/test_update.py b/tests/ga/test_update.py index 0c8642c..59251cb 100644 --- a/tests/ga/test_update.py +++ b/tests/ga/test_update.py @@ -22,6 +22,7 @@ from datetime import datetime import json import shutil +from azurelinuxagent.common.event import * from azurelinuxagent.common.protocol.hostplugin import * from azurelinuxagent.common.protocol.metadata import * from azurelinuxagent.common.protocol.wire import * @@ -148,7 +149,9 @@ class UpdateTestCase(AgentTestCase): def create_error(self, error_data=NO_ERROR): with self.get_error_file(error_data) as path: - return GuestAgentError(path.name) + err = GuestAgentError(path.name) + err.load() + return err def copy_agents(self, *agents): if len(agents) <= 0: @@ -157,11 +160,11 @@ class UpdateTestCase(AgentTestCase): fileutil.copy_file(agent, to_dir=self.tmp_dir) return - def expand_agents(self, mark_test=False): + def expand_agents(self, mark_optional=False): for agent in self.agent_pkgs(): path = os.path.join(self.tmp_dir, fileutil.trim_ext(agent, "zip")) zipfile.ZipFile(agent).extractall(path) - if mark_test: + if mark_optional: src = os.path.join(data_dir, 'ga', 'supported.json') dst = os.path.join(path, 'supported.json') shutil.copy(src, dst) @@ -170,12 +173,12 @@ class UpdateTestCase(AgentTestCase): fileutil.write_file(dst, json.dumps(SENTINEL_ERROR)) return - def prepare_agent(self, version, mark_test=False): + def prepare_agent(self, version, mark_optional=False): """ Create a download for the current agent version, copied from test data """ self.copy_agents(get_agent_pkgs()[0]) - self.expand_agents(mark_test=mark_test) + self.expand_agents(mark_optional=mark_optional) versions = self.agent_versions() src_v = FlexibleVersion(str(versions[0])) @@ -246,7 +249,6 @@ class TestSupportedDistribution(UpdateTestCase): self.sd = SupportedDistribution({ 'slice':10, 'versions': ['^Ubuntu,16.10,yakkety$']}) - def test_creation(self): self.assertRaises(TypeError, SupportedDistribution) @@ -276,6 +278,7 @@ class TestSupported(UpdateTestCase): def setUp(self): UpdateTestCase.setUp(self) self.sp = Supported(os.path.join(data_dir, 'ga', 'supported.json')) + self.sp.load() def test_creation(self): self.assertRaises(TypeError, Supported) @@ -305,6 +308,7 @@ class TestGuestAgentError(UpdateTestCase): with self.get_error_file(error_data=WITH_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertEqual(path.name, err.path) self.assertNotEqual(None, err) @@ -316,6 +320,7 @@ class TestGuestAgentError(UpdateTestCase): def test_clear(self): with self.get_error_file(error_data=WITH_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertEqual(path.name, err.path) self.assertNotEqual(None, err) @@ -328,27 +333,16 @@ class TestGuestAgentError(UpdateTestCase): def test_is_sentinel(self): with self.get_error_file(error_data=SENTINEL_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertTrue(err.is_blacklisted) self.assertTrue(err.is_sentinel) with self.get_error_file(error_data=FATAL_ERROR) as path: err = GuestAgentError(path.name) + err.load() self.assertTrue(err.is_blacklisted) self.assertFalse(err.is_sentinel) - def test_load_preserves_error_state(self): - with self.get_error_file(error_data=WITH_ERROR) as path: - err = GuestAgentError(path.name) - self.assertEqual(path.name, err.path) - self.assertNotEqual(None, err) - - with self.get_error_file(error_data=NO_ERROR): - err.load() - self.assertEqual(WITH_ERROR["last_failure"], err.last_failure) - self.assertEqual(WITH_ERROR["failure_count"], err.failure_count) - self.assertEqual(WITH_ERROR["was_fatal"], err.was_fatal) - return - def test_save(self): err1 = self.create_error() err1.mark_failure() @@ -406,22 +400,20 @@ class TestGuestAgent(UpdateTestCase): self.agent_path = os.path.join(self.tmp_dir, get_agent_name()) return - def tearDown(self): - self.remove_agents() - return - def test_creation(self): self.assertRaises(UpdateError, GuestAgent, "A very bad file name") n = "{0}-a.bad.version".format(AGENT_NAME) self.assertRaises(UpdateError, GuestAgent, n) + self.expand_agents() + agent = GuestAgent(path=self.agent_path) self.assertNotEqual(None, agent) self.assertEqual(get_agent_name(), agent.name) self.assertEqual(get_agent_version(), agent.version) - self.assertFalse(agent.is_test) - self.assertFalse(agent.in_slice) + self.assertFalse(agent._is_optional) + self.assertFalse(agent._in_slice) self.assertEqual(self.agent_path, agent.get_agent_dir()) @@ -436,13 +428,14 @@ class TestGuestAgent(UpdateTestCase): self.assertEqual(path, agent.get_agent_pkg_path()) self.assertTrue(agent.is_downloaded) - # Note: Agent will get blacklisted since the package for this test is invalid - self.assertTrue(agent.is_blacklisted) - self.assertFalse(agent.is_available) + self.assertFalse(agent.is_blacklisted) + self.assertTrue(agent.is_available) return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_clear_error(self, mock_ensure): + def test_clear_error(self, mock_downloaded): + self.expand_agents() + agent = GuestAgent(path=self.agent_path) agent.mark_failure(is_fatal=True) @@ -459,7 +452,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_is_available(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_is_available(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_available) @@ -471,7 +465,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_is_blacklisted(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_is_blacklisted(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_blacklisted) @@ -485,7 +480,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_is_downloaded(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_is_downloaded(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_downloaded) agent._unpack() @@ -493,50 +489,47 @@ class TestGuestAgent(UpdateTestCase): return @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) - def test_is_test(self, mock_dist): - self.expand_agents(mark_test=True) + @patch('azurelinuxagent.ga.update.GuestAgent._enable') + def test_is_optional(self, mock_enable, mock_dist): + self.expand_agents(mark_optional=True) agent = GuestAgent(path=self.agent_path) self.assertTrue(agent.is_blacklisted) - self.assertTrue(agent.is_test) + self.assertTrue(agent._is_optional) @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) @patch('azurelinuxagent.ga.update.datetime') def test_in_slice(self, mock_dt, mock_dist): - self.expand_agents(mark_test=True) + self.expand_agents(mark_optional=True) agent = GuestAgent(path=self.agent_path) mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.assertTrue(agent.in_slice) + self.assertTrue(agent._in_slice) mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 42)) - self.assertFalse(agent.in_slice) + self.assertFalse(agent._in_slice) @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) @patch('azurelinuxagent.ga.update.datetime') def test_enable(self, mock_dt, mock_dist): mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.expand_agents(mark_test=True) + self.expand_agents(mark_optional=True) agent = GuestAgent(path=self.agent_path) - self.assertTrue(agent.is_blacklisted) - self.assertTrue(agent.is_test) - self.assertTrue(agent.in_slice) - - agent.enable() - self.assertFalse(agent.is_blacklisted) - self.assertFalse(agent.is_test) + self.assertFalse(agent._is_optional) # Ensure the new state is preserved to disk agent = GuestAgent(path=self.agent_path) self.assertFalse(agent.is_blacklisted) - self.assertFalse(agent.is_test) + self.assertFalse(agent._is_optional) @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_mark_failure(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_mark_failure(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) + agent.mark_failure() self.assertEqual(1, agent.error.failure_count) @@ -546,7 +539,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_unpack(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_unpack(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -555,7 +549,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_unpack_fail(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_unpack_fail(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) os.remove(agent.get_agent_pkg_path()) @@ -563,7 +558,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) agent._unpack() agent._load_manifest() @@ -572,7 +568,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest_missing(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest_missing(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -581,7 +578,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest_is_empty(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest_is_empty(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -593,7 +591,8 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") - def test_load_manifest_is_malformed(self, mock_ensure): + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") + def test_load_manifest_is_malformed(self, mock_loaded, mock_downloaded): agent = GuestAgent(path=self.agent_path) self.assertFalse(os.path.isdir(agent.get_agent_dir())) agent._unpack() @@ -613,8 +612,9 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download(self, mock_http_get, mock_ensure): + def test_download(self, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) @@ -630,8 +630,9 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download_fail(self, mock_http_get, mock_ensure): + def test_download_fail(self, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) @@ -647,8 +648,9 @@ class TestGuestAgent(UpdateTestCase): return @patch("azurelinuxagent.ga.update.GuestAgent._ensure_downloaded") + @patch("azurelinuxagent.ga.update.GuestAgent._ensure_loaded") @patch("azurelinuxagent.ga.update.restutil.http_get") - def test_download_fallback(self, mock_http_get, mock_ensure): + def test_download_fallback(self, mock_http_get, mock_loaded, mock_downloaded): self.remove_agents() self.assertFalse(os.path.isdir(self.agent_path)) @@ -681,8 +683,12 @@ class TestGuestAgent(UpdateTestCase): return_value=True): self.assertRaises(UpdateError, agent._download) self.assertEqual(mock_http_get.call_count, 4) + self.assertEqual(mock_http_get.call_args_list[2][0][0], ext_uri) + self.assertEqual(mock_http_get.call_args_list[3][0][0], art_uri) + a, k = mock_http_get.call_args_list[3] + self.assertEqual(False, k['use_proxy']) # ensure fallback works as expected with patch.object(HostPluginProtocol, @@ -690,8 +696,16 @@ class TestGuestAgent(UpdateTestCase): return_value=[art_uri, {}]): self.assertRaises(UpdateError, agent._download) self.assertEqual(mock_http_get.call_count, 6) + + a, k = mock_http_get.call_args_list[3] + self.assertEqual(False, k['use_proxy']) + self.assertEqual(mock_http_get.call_args_list[4][0][0], ext_uri) + a, k = mock_http_get.call_args_list[4] + self.assertEqual(mock_http_get.call_args_list[5][0][0], art_uri) + a, k = mock_http_get.call_args_list[5] + self.assertEqual(False, k['use_proxy']) @patch("azurelinuxagent.ga.update.restutil.http_get") def test_ensure_downloaded(self, mock_http_get): @@ -725,7 +739,7 @@ class TestGuestAgent(UpdateTestCase): @patch("azurelinuxagent.ga.update.GuestAgent._download") @patch("azurelinuxagent.ga.update.GuestAgent._unpack", side_effect=UpdateError) - def test_ensure_downloaded_unpack_fails(self, mock_download, mock_unpack): + def test_ensure_downloaded_unpack_fails(self, mock_unpack, mock_download): self.assertFalse(os.path.isdir(self.agent_path)) pkg = ExtHandlerPackage(version=str(get_agent_version())) @@ -740,7 +754,7 @@ class TestGuestAgent(UpdateTestCase): @patch("azurelinuxagent.ga.update.GuestAgent._download") @patch("azurelinuxagent.ga.update.GuestAgent._unpack") @patch("azurelinuxagent.ga.update.GuestAgent._load_manifest", side_effect=UpdateError) - def test_ensure_downloaded_load_manifest_fails(self, mock_download, mock_unpack, mock_manifest): + def test_ensure_downloaded_load_manifest_fails(self, mock_manifest, mock_unpack, mock_download): self.assertFalse(os.path.isdir(self.agent_path)) pkg = ExtHandlerPackage(version=str(get_agent_version())) @@ -755,10 +769,13 @@ class TestGuestAgent(UpdateTestCase): @patch("azurelinuxagent.ga.update.GuestAgent._download") @patch("azurelinuxagent.ga.update.GuestAgent._unpack") @patch("azurelinuxagent.ga.update.GuestAgent._load_manifest") - def test_ensure_download_skips_blacklisted(self, mock_download, mock_unpack, mock_manifest): + def test_ensure_download_skips_blacklisted(self, mock_manifest, mock_unpack, mock_download): agent = GuestAgent(path=self.agent_path) + self.assertEqual(0, mock_download.call_count) + agent.clear_error() agent.mark_failure(is_fatal=True) + self.assertTrue(agent.is_blacklisted) pkg = ExtHandlerPackage(version=str(get_agent_version())) pkg.uris.append(ExtHandlerPackageUri()) @@ -769,7 +786,6 @@ class TestGuestAgent(UpdateTestCase): self.assertTrue(agent.is_blacklisted) self.assertEqual(0, mock_download.call_count) self.assertEqual(0, mock_unpack.call_count) - self.assertEqual(0, mock_manifest.call_count) return @@ -812,6 +828,12 @@ class TestUpdate(UpdateTestCase): self.event_patch.stop() return + def _create_protocol(self, count=5, versions=None): + latest_version = self.prepare_agents(count=count) + if versions is None or len(versions) <= 0: + versions = [latest_version] + return ProtocolMock(versions=versions) + def _test_upgrade_available( self, base_version=FlexibleVersion(AGENT_VERSION), @@ -819,12 +841,9 @@ class TestUpdate(UpdateTestCase): versions=None, count=5): - latest_version = self.prepare_agents(count=count) - if versions is None or len(versions) <= 0: - versions = [latest_version] - if protocol is None: - protocol = ProtocolMock(versions=versions) + protocol = self._create_protocol(count=count, versions=versions) + self.update_handler.protocol_util = protocol conf.get_autoupdate_gafamily = Mock(return_value=protocol.family) @@ -834,6 +853,16 @@ class TestUpdate(UpdateTestCase): self.assertTrue(self._test_upgrade_available()) return + def test_upgrade_available_will_refresh_goal_state(self): + protocol = self._create_protocol() + protocol.emulate_stale_goal_state() + self.assertTrue(self._test_upgrade_available(protocol=protocol)) + self.assertEqual(2, protocol.call_counts["get_vmagent_manifests"]) + self.assertEqual(1, protocol.call_counts["get_vmagent_pkgs"]) + self.assertEqual(1, protocol.call_counts["update_goal_state"]) + self.assertTrue(protocol.goal_state_forced) + return + def test_get_latest_agent_excluded(self): self.prepare_agent(AGENT_VERSION) self.assertFalse(self._test_upgrade_available( @@ -909,7 +938,7 @@ class TestUpdate(UpdateTestCase): v = a.version return - def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL): + def _test_ensure_no_orphans(self, invocations=3, interval=ORPHAN_WAIT_INTERVAL, pid_count=0): with patch.object(self.update_handler, 'osutil') as mock_util: # Note: # - Python only allows mutations of objects to which a function has @@ -924,15 +953,20 @@ class TestUpdate(UpdateTestCase): mock_util.check_pid_alive = Mock(side_effect=iterator) + pid_files = self.update_handler._get_pid_files() + self.assertEqual(pid_count, len(pid_files)) + with patch('os.getpid', return_value=42): with patch('time.sleep', return_value=None) as mock_sleep: self.update_handler._ensure_no_orphans(orphan_wait_interval=interval) + for pid_file in pid_files: + self.assertFalse(os.path.exists(pid_file)) return mock_util.check_pid_alive.call_count, mock_sleep.call_count return def test_ensure_no_orphans(self): fileutil.write_file(os.path.join(self.tmp_dir, "0_waagent.pid"), ustr(41)) - calls, sleeps = self._test_ensure_no_orphans(invocations=3) + calls, sleeps = self._test_ensure_no_orphans(invocations=3, pid_count=1) self.assertEqual(3, calls) self.assertEqual(2, sleeps) return @@ -955,7 +989,8 @@ class TestUpdate(UpdateTestCase): with patch('os.kill') as mock_kill: calls, sleeps = self._test_ensure_no_orphans( invocations=4, - interval=3*GOAL_STATE_INTERVAL) + interval=3*GOAL_STATE_INTERVAL, + pid_count=1) self.assertEqual(3, calls) self.assertEqual(2, sleeps) self.assertEqual(1, mock_kill.call_count) @@ -1066,26 +1101,20 @@ class TestUpdate(UpdateTestCase): return def test_get_pid_files(self): - previous_pid_file, pid_file, = self.update_handler._get_pid_files() - self.assertEqual(None, previous_pid_file) - self.assertEqual("0_waagent.pid", os.path.basename(pid_file)) + pid_files = self.update_handler._get_pid_files() + self.assertEqual(0, len(pid_files)) return def test_get_pid_files_returns_previous(self): for n in range(1250): fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1)) - previous_pid_file, pid_file, = self.update_handler._get_pid_files() - self.assertEqual("1249_waagent.pid", os.path.basename(previous_pid_file)) - self.assertEqual("1250_waagent.pid", os.path.basename(pid_file)) - return - - @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) - @patch('azurelinuxagent.ga.update.datetime') - def test_get_test_agent(self, mock_dt, mock_dist): - mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.prepare_agent(AGENT_VERSION, mark_test=True) + pid_files = self.update_handler._get_pid_files() + self.assertEqual(1250, len(pid_files)) - self.assertNotEqual(None, self.update_handler.get_test_agent()) + pid_dir, pid_name, pid_re = self.update_handler._get_pid_parts() + for p in pid_files: + self.assertTrue(pid_re.match(os.path.basename(p))) + return def test_is_clean_start_returns_true_when_no_sentinal(self): self.assertFalse(os.path.isfile(self.update_handler._sentinal_file_path())) @@ -1421,23 +1450,6 @@ class TestUpdate(UpdateTestCase): self.update_handler._upgrade_available = Mock(return_value=True) self._test_run(invocations=0, calls=[], enable_updates=True) return - - @patch('platform.linux_distribution', return_value=['Ubuntu', '16.10', 'yakkety']) - @patch('azurelinuxagent.ga.update.datetime') - def test_run_stops_if_test_agent_available(self, mock_dt, mock_dist): - mock_dt.utcnow = Mock(return_value=datetime(2017, 1, 1, 0, 0, 5)) - self.prepare_agent(AGENT_VERSION, mark_test=True) - - agent = GuestAgent(path=self.agent_dir(AGENT_VERSION)) - agent.enable = Mock() - self.assertTrue(agent.is_test) - self.assertTrue(agent.in_slice) - - with patch('azurelinuxagent.ga.update.UpdateHandler.get_test_agent', - return_value=agent) as mock_test: - self._test_run(invocations=0) - self.assertEqual(mock_test.call_count, 1) - self.assertEqual(agent.enable.call_count, 1) def test_run_stops_if_orphaned(self): with patch('os.getppid', return_value=1): @@ -1521,8 +1533,9 @@ class TestUpdate(UpdateTestCase): for n in range(1112): fileutil.write_file(os.path.join(self.tmp_dir, str(n)+"_waagent.pid"), ustr(n+1)) with patch('os.getpid', return_value=1112): - previous_pid_file, pid_file = self.update_handler._write_pid_file() - self.assertEqual("1111_waagent.pid", os.path.basename(previous_pid_file)) + pid_files, pid_file = self.update_handler._write_pid_file() + self.assertEqual(1112, len(pid_files)) + self.assertEqual("1111_waagent.pid", os.path.basename(pid_files[-1])) self.assertEqual("1112_waagent.pid", os.path.basename(pid_file)) self.assertEqual(fileutil.read_file(pid_file), ustr(1112)) return @@ -1530,8 +1543,8 @@ class TestUpdate(UpdateTestCase): def test_write_pid_file_ignores_exceptions(self): with patch('azurelinuxagent.common.utils.fileutil.write_file', side_effect=Exception): with patch('os.getpid', return_value=42): - previous_pid_file, pid_file = self.update_handler._write_pid_file() - self.assertEqual(None, previous_pid_file) + pid_files, pid_file = self.update_handler._write_pid_file() + self.assertEqual(0, len(pid_files)) self.assertEqual(None, pid_file) return @@ -1549,12 +1562,22 @@ class ProtocolMock(object): def __init__(self, family="TestAgent", etag=42, versions=None, client=None): self.family = family self.client = client + self.call_counts = { + "get_vmagent_manifests" : 0, + "get_vmagent_pkgs" : 0, + "update_goal_state" : 0 + } + self.goal_state_is_stale = False + self.goal_state_forced = False self.etag = etag self.versions = versions if versions is not None else [] self.create_manifests() self.create_packages() return + def emulate_stale_goal_state(self): + self.goal_state_is_stale = True + def create_manifests(self): self.agent_manifests = VMAgentManifestList() if len(self.versions) <= 0: @@ -1585,11 +1608,23 @@ class ProtocolMock(object): return self def get_vmagent_manifests(self): + self.call_counts["get_vmagent_manifests"] += 1 + if self.goal_state_is_stale: + self.goal_state_is_stale = False + raise ResourceGoneError() return self.agent_manifests, self.etag def get_vmagent_pkgs(self, manifest): + self.call_counts["get_vmagent_pkgs"] += 1 + if self.goal_state_is_stale: + self.goal_state_is_stale = False + raise ResourceGoneError() return self.agent_packages + def update_goal_state(self, forced=False, max_retry=3): + self.call_counts["update_goal_state"] += 1 + self.goal_state_forced = self.goal_state_forced or forced + return class ResponseMock(Mock): def __init__(self, status=restutil.httpclient.OK, response=None, reason=None): diff --git a/tests/pa/test_provision.py b/tests/pa/test_provision.py index 0446442..7045fcc 100644 --- a/tests/pa/test_provision.py +++ b/tests/pa/test_provision.py @@ -53,6 +53,24 @@ class TestProvision(AgentTestCase): data = DefaultOSUtil().decode_customdata(base64data) fileutil.write_file(tempfile.mktemp(), data) + @patch('azurelinuxagent.common.conf.get_provision_enabled', + return_value=False) + def test_provisioning_is_skipped_when_not_enabled(self, mock_conf): + ph = ProvisionHandler() + ph.osutil = DefaultOSUtil() + ph.osutil.get_instance_id = Mock( + return_value='B9F3C233-9913-9F42-8EB3-BA656DF32502') + + ph.is_provisioned = Mock() + ph.report_ready = Mock() + ph.write_provisioned = Mock() + + ph.run() + + ph.is_provisioned.assert_not_called() + ph.report_ready.assert_called_once() + ph.write_provisioned.assert_called_once() + @patch('os.path.isfile', return_value=False) def test_is_provisioned_not_provisioned(self, mock_isfile): ph = ProvisionHandler() @@ -64,33 +82,37 @@ class TestProvision(AgentTestCase): @patch('azurelinuxagent.pa.deprovision.get_deprovision_handler') def test_is_provisioned_is_provisioned(self, mock_deprovision, mock_read, mock_isfile): + ph = ProvisionHandler() ph.osutil = Mock() - ph.osutil.get_instance_id = \ - Mock(return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502") + ph.osutil.is_current_instance_id = Mock(return_value=True) ph.write_provisioned = Mock() deprovision_handler = Mock() mock_deprovision.return_value = deprovision_handler self.assertTrue(ph.is_provisioned()) + ph.osutil.is_current_instance_id.assert_called_once() deprovision_handler.run_changed_unique_id.assert_not_called() @patch('os.path.isfile', return_value=True) @patch('azurelinuxagent.common.utils.fileutil.read_file', - side_effect=["Value"]) + return_value="B9F3C233-9913-9F42-8EB3-BA656DF32502") @patch('azurelinuxagent.pa.deprovision.get_deprovision_handler') def test_is_provisioned_not_deprovisioned(self, mock_deprovision, mock_read, mock_isfile): ph = ProvisionHandler() ph.osutil = Mock() + ph.osutil.is_current_instance_id = Mock(return_value=False) + ph.report_ready = Mock() ph.write_provisioned = Mock() deprovision_handler = Mock() mock_deprovision.return_value = deprovision_handler self.assertTrue(ph.is_provisioned()) + ph.osutil.is_current_instance_id.assert_called_once() deprovision_handler.run_changed_unique_id.assert_called_once() if __name__ == '__main__': diff --git a/tests/protocol/mockwiredata.py b/tests/protocol/mockwiredata.py index 4e45623..5924719 100644 --- a/tests/protocol/mockwiredata.py +++ b/tests/protocol/mockwiredata.py @@ -16,6 +16,7 @@ # from tests.tools import * +from azurelinuxagent.common.exception import HttpError, ResourceGoneError from azurelinuxagent.common.future import httpclient from azurelinuxagent.common.utils.cryptutil import CryptUtil @@ -53,6 +54,20 @@ DATA_FILE_EXT_AUTOUPGRADE_INTERNALVERSION["ext_conf"] = "wire/ext_conf_autoupgra class WireProtocolData(object): def __init__(self, data_files=DATA_FILE): + self.emulate_stale_goal_state = False + self.call_counts = { + "comp=versions" : 0, + "/versions" : 0, + "goalstate" : 0, + "hostingenvuri" : 0, + "sharedconfiguri" : 0, + "certificatesuri" : 0, + "extensionsconfiguri" : 0, + "extensionArtifact" : 0, + "manifest.xml" : 0, + "manifest_of_ga.xml" : 0, + "ExampleHandlerLinux" : 0 + } self.version_info = load_data(data_files.get("version_info")) self.goal_state = load_data(data_files.get("goal_state")) self.hosting_env = load_data(data_files.get("hosting_env")) @@ -67,32 +82,70 @@ class WireProtocolData(object): def mock_http_get(self, url, *args, **kwargs): content = None - if "versions" in url: + + resp = MagicMock() + resp.status = httpclient.OK + + # wire server versions + if "comp=versions" in url: content = self.version_info + self.call_counts["comp=versions"] += 1 + + # HostPlugin versions + elif "/versions" in url: + content = '["2015-09-01"]' + self.call_counts["/versions"] += 1 elif "goalstate" in url: content = self.goal_state + self.call_counts["goalstate"] += 1 elif "hostingenvuri" in url: content = self.hosting_env + self.call_counts["hostingenvuri"] += 1 elif "sharedconfiguri" in url: content = self.shared_config + self.call_counts["sharedconfiguri"] += 1 elif "certificatesuri" in url: content = self.certs + self.call_counts["certificatesuri"] += 1 elif "extensionsconfiguri" in url: content = self.ext_conf - elif "manifest.xml" in url: - content = self.manifest - elif "manifest_of_ga.xml" in url: - content = self.ga_manifest - elif "ExampleHandlerLinux" in url: - content = self.ext - resp = MagicMock() - resp.status = httpclient.OK - resp.read = Mock(return_value=content) - return resp + self.call_counts["extensionsconfiguri"] += 1 + else: - raise Exception("Bad url {0}".format(url)) - resp = MagicMock() - resp.status = httpclient.OK + # A stale GoalState results in a 400 from the HostPlugin + # for which the HTTP handler in restutil raises ResourceGoneError + if self.emulate_stale_goal_state: + if "extensionArtifact" in url: + self.emulate_stale_goal_state = False + self.call_counts["extensionArtifact"] += 1 + raise ResourceGoneError() + else: + raise HttpError() + + # For HostPlugin requests, replace the URL with that passed + # via the x-ms-artifact-location header + if "extensionArtifact" in url: + self.call_counts["extensionArtifact"] += 1 + if "headers" not in kwargs or \ + "x-ms-artifact-location" not in kwargs["headers"]: + raise Exception("Bad HEADERS passed to HostPlugin: {0}", + kwargs) + url = kwargs["headers"]["x-ms-artifact-location"] + + if "manifest.xml" in url: + content = self.manifest + self.call_counts["manifest.xml"] += 1 + elif "manifest_of_ga.xml" in url: + content = self.ga_manifest + self.call_counts["manifest_of_ga.xml"] += 1 + elif "ExampleHandlerLinux" in url: + content = self.ext + self.call_counts["ExampleHandlerLinux"] += 1 + resp.read = Mock(return_value=content) + return resp + else: + raise Exception("Bad url {0}".format(url)) + resp.read = Mock(return_value=content.encode("utf-8")) return resp diff --git a/tests/protocol/test_hostplugin.py b/tests/protocol/test_hostplugin.py index b18b691..74f7f24 100644 --- a/tests/protocol/test_hostplugin.py +++ b/tests/protocol/test_hostplugin.py @@ -146,6 +146,7 @@ class TestHostPlugin(AgentTestCase): test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state) status = restapi.VMStatus(status="Ready", message="Guest Agent is running") + wire.HostPluginProtocol.set_default_channel(False) with patch.object(wire.HostPluginProtocol, "ensure_initialized", return_value=True): @@ -173,6 +174,7 @@ class TestHostPlugin(AgentTestCase): test_goal_state = wire.GoalState(WireProtocolData(DATA_FILE).goal_state) status = restapi.VMStatus(status="Ready", message="Guest Agent is running") + wire.HostPluginProtocol.set_default_channel(False) with patch.object(wire.StatusBlob, "upload", return_value=False): @@ -211,6 +213,8 @@ class TestHostPlugin(AgentTestCase): bytearray(faux_status, encoding='utf-8')) with patch.object(restutil, "http_request") as patch_http: + patch_http.return_value = Mock(status=httpclient.OK) + wire_protocol_client.get_goal_state = Mock(return_value=test_goal_state) plugin = wire_protocol_client.get_host_plugin() @@ -224,61 +228,6 @@ class TestHostPlugin(AgentTestCase): test_goal_state, exp_method, exp_url, exp_data) - def test_read_response_error(self): - """ - Validate the read_response_error method handles encoding correctly - """ - responses = ['message', b'message', '\x80message\x80'] - response = MagicMock() - response.status = 'status' - response.reason = 'reason' - with patch.object(response, 'read') as patch_response: - for s in responses: - patch_response.return_value = s - result = hostplugin.HostPluginProtocol.read_response_error(response) - self.assertTrue('[status: reason]' in result) - self.assertTrue('message' in result) - - def test_read_response_bytes(self): - response_bytes = '7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:' \ - '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:' \ - '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:' \ - '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:' \ - '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:' \ - 'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:' \ - '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:' \ - '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:' \ - '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:' \ - '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:' \ - '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:' \ - '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:' \ - '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:' \ - '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:' \ - '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:' \ - '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:' \ - '6c:73:22:3a:20:22:22:0a:7d'.split(':') - expected_response = '[status: reason] {\n "errorCode": "The blob ' \ - 'type is invalid for this operation.",\n ' \ - '"message": "<?xml version="1.0" ' \ - 'encoding="utf-8"?>' \ - '<Error><Code>InvalidBlobType</Code><Message>The ' \ - 'blob type is invalid for this operation.\n' \ - 'RequestId:c74290cb-0001-00b5-06da-dd666a000",' \ - '\n "details": ""\n}' - - response_string = ''.join(chr(int(b, 16)) for b in response_bytes) - response = MagicMock() - response.status = 'status' - response.reason = 'reason' - with patch.object(response, 'read') as patch_response: - patch_response.return_value = response_string - result = hostplugin.HostPluginProtocol.read_response_error(response) - self.assertEqual(result, expected_response) - try: - raise HttpError("{0}".format(result)) - except HttpError as e: - self.assertTrue(result in ustr(e)) - def test_no_fallback(self): """ Validate fallback to upload status using HostGAPlugin is not happening @@ -318,6 +267,8 @@ class TestHostPlugin(AgentTestCase): bytearray(faux_status, encoding='utf-8')) with patch.object(restutil, "http_request") as patch_http: + patch_http.return_value = Mock(status=httpclient.OK) + with patch.object(wire.HostPluginProtocol, "get_api_versions") as patch_get: patch_get.return_value = api_versions diff --git a/tests/protocol/test_metadata.py b/tests/protocol/test_metadata.py index ee4ba3e..5047b86 100644 --- a/tests/protocol/test_metadata.py +++ b/tests/protocol/test_metadata.py @@ -31,17 +31,15 @@ class TestMetadataProtocolGetters(AgentTestCase): return json.loads(ustr(load_data(path)), encoding="utf-8") @patch("time.sleep") - @patch("azurelinuxagent.common.protocol.metadata.restutil") - def _test_getters(self, test_data, mock_restutil ,_): - mock_restutil.http_get.side_effect = test_data.mock_http_get - - protocol = MetadataProtocol() - protocol.detect() - protocol.get_vminfo() - protocol.get_certs() - ext_handlers, etag = protocol.get_ext_handlers() - for ext_handler in ext_handlers.extHandlers: - protocol.get_ext_handler_pkgs(ext_handler) + def _test_getters(self, test_data ,_): + with patch.object(restutil, 'http_get', test_data.mock_http_get): + protocol = MetadataProtocol() + protocol.detect() + protocol.get_vminfo() + protocol.get_certs() + ext_handlers, etag = protocol.get_ext_handlers() + for ext_handler in ext_handlers.extHandlers: + protocol.get_ext_handler_pkgs(ext_handler) def test_getters(self, *args): test_data = MetadataProtocolData(DATA_FILE) diff --git a/tests/protocol/test_wire.py b/tests/protocol/test_wire.py index 02976ca..d19bab1 100644 --- a/tests/protocol/test_wire.py +++ b/tests/protocol/test_wire.py @@ -25,30 +25,34 @@ wireserver_url = '168.63.129.16' @patch("time.sleep") @patch("azurelinuxagent.common.protocol.wire.CryptUtil") -@patch("azurelinuxagent.common.protocol.wire.restutil") class TestWireProtocolGetters(AgentTestCase): - def _test_getters(self, test_data, mock_restutil, MockCryptUtil, _): - mock_restutil.http_get.side_effect = test_data.mock_http_get + + def setUp(self): + super(TestWireProtocolGetters, self).setUp() + HostPluginProtocol.set_default_channel(False) + + def _test_getters(self, test_data, MockCryptUtil, _): MockCryptUtil.side_effect = test_data.mock_crypt_util - protocol = WireProtocol(wireserver_url) - protocol.detect() - protocol.get_vminfo() - protocol.get_certs() - ext_handlers, etag = protocol.get_ext_handlers() - for ext_handler in ext_handlers.extHandlers: - protocol.get_ext_handler_pkgs(ext_handler) - - crt1 = os.path.join(self.tmp_dir, - '33B0ABCE4673538650971C10F7D7397E71561F35.crt') - crt2 = os.path.join(self.tmp_dir, - '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt') - prv2 = os.path.join(self.tmp_dir, - '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv') - - self.assertTrue(os.path.isfile(crt1)) - self.assertTrue(os.path.isfile(crt2)) - self.assertTrue(os.path.isfile(prv2)) + with patch.object(restutil, 'http_get', test_data.mock_http_get): + protocol = WireProtocol(wireserver_url) + protocol.detect() + protocol.get_vminfo() + protocol.get_certs() + ext_handlers, etag = protocol.get_ext_handlers() + for ext_handler in ext_handlers.extHandlers: + protocol.get_ext_handler_pkgs(ext_handler) + + crt1 = os.path.join(self.tmp_dir, + '33B0ABCE4673538650971C10F7D7397E71561F35.crt') + crt2 = os.path.join(self.tmp_dir, + '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.crt') + prv2 = os.path.join(self.tmp_dir, + '4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3.prv') + + self.assertTrue(os.path.isfile(crt1)) + self.assertTrue(os.path.isfile(crt2)) + self.assertTrue(os.path.isfile(prv2)) def test_getters(self, *args): """Normal case""" @@ -70,8 +74,21 @@ class TestWireProtocolGetters(AgentTestCase): test_data = WireProtocolData(DATA_FILE_EXT_NO_PUBLIC) self._test_getters(test_data, *args) + def test_getters_with_stale_goal_state(self, *args): + test_data = WireProtocolData(DATA_FILE) + test_data.emulate_stale_goal_state = True + + self._test_getters(test_data, *args) + # Ensure HostPlugin was invoked + self.assertEqual(1, test_data.call_counts["/versions"]) + self.assertEqual(2, test_data.call_counts["extensionArtifact"]) + # Ensure the expected number of HTTP calls were made + # -- Tracking calls to retrieve GoalState is problematic since it is + # fetched often; however, the dependent documents, such as the + # HostingEnvironmentConfig, will be retrieved the expected number + self.assertEqual(2, test_data.call_counts["hostingenvuri"]) + def test_call_storage_kwargs(self, - mock_restutil, mock_cryptutil, mock_sleep): from azurelinuxagent.common.utils import restutil @@ -83,32 +100,32 @@ class TestWireProtocolGetters(AgentTestCase): # no kwargs -- Default to True WireClient.call_storage_service(http_req) - # kwargs, no chk_proxy -- Default to True + # kwargs, no use_proxy -- Default to True WireClient.call_storage_service(http_req, url, headers) - # kwargs, chk_proxy None -- Default to True + # kwargs, use_proxy None -- Default to True WireClient.call_storage_service(http_req, url, headers, - chk_proxy=None) + use_proxy=None) - # kwargs, chk_proxy False -- Keep False + # kwargs, use_proxy False -- Keep False WireClient.call_storage_service(http_req, url, headers, - chk_proxy=False) + use_proxy=False) - # kwargs, chk_proxy True -- Keep True + # kwargs, use_proxy True -- Keep True WireClient.call_storage_service(http_req, url, headers, - chk_proxy=True) + use_proxy=True) # assert self.assertTrue(http_patch.call_count == 5) for i in range(0,5): - c = http_patch.call_args_list[i][-1]['chk_proxy'] + c = http_patch.call_args_list[i][-1]['use_proxy'] self.assertTrue(c == (True if i != 3 else False)) def test_status_blob_parsing(self, *args): diff --git a/tests/test_agent.py b/tests/test_agent.py index 1b35933..77be07a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -17,12 +17,56 @@ import mock import os.path +import sys from azurelinuxagent.agent import * from azurelinuxagent.common.conf import * from tests.tools import * +EXPECTED_CONFIGURATION = \ +"""AutoUpdate.Enabled = True +AutoUpdate.GAFamily = Prod +Autoupdate.Frequency = 3600 +DVD.MountPoint = /mnt/cdrom/secure +DetectScvmmEnv = False +EnableOverProvisioning = False +Extension.LogDir = /var/log/azure +HttpProxy.Host = None +HttpProxy.Port = None +Lib.Dir = /var/lib/waagent +Logs.Verbose = False +OS.AllowHTTP = False +OS.CheckRdmaDriver = False +OS.EnableFIPS = True +OS.EnableFirewall = True +OS.EnableRDMA = False +OS.HomeDir = /home +OS.OpensslPath = /usr/bin/openssl +OS.PasswordPath = /etc/shadow +OS.RootDeviceScsiTimeout = 300 +OS.SshDir = /notareal/path +OS.SudoersDir = /etc/sudoers.d +OS.UpdateRdmaDriver = False +Pid.File = /var/run/waagent.pid +Provisioning.AllowResetSysUser = False +Provisioning.DecodeCustomData = False +Provisioning.DeleteRootPassword = True +Provisioning.Enabled = True +Provisioning.ExecuteCustomData = False +Provisioning.MonitorHostName = True +Provisioning.PasswordCryptId = 6 +Provisioning.PasswordCryptSaltLength = 10 +Provisioning.RegenerateSshHostKeyPair = True +Provisioning.SshHostKeyPairType = rsa +Provisioning.UseCloudInit = True +ResourceDisk.EnableSwap = False +ResourceDisk.Filesystem = ext4 +ResourceDisk.Format = True +ResourceDisk.MountOptions = None +ResourceDisk.MountPoint = /mnt/resource +ResourceDisk.SwapSizeMB = 0 +""".split('\n') class TestAgent(AgentTestCase): @@ -90,3 +134,36 @@ class TestAgent(AgentTestCase): mock_daemon.run.assert_called_once_with(child_args="-configuration-path:/foo/bar.conf") mock_load.assert_called_once() + + @patch("azurelinuxagent.common.conf.get_ext_log_dir") + def test_agent_ensures_extension_log_directory(self, mock_dir): + ext_log_dir = os.path.join(self.tmp_dir, "FauxLogDir") + mock_dir.return_value = ext_log_dir + + self.assertFalse(os.path.isdir(ext_log_dir)) + agent = Agent(False, + conf_file_path=os.path.join(data_dir, "test_waagent.conf")) + self.assertTrue(os.path.isdir(ext_log_dir)) + + @patch("azurelinuxagent.common.logger.error") + @patch("azurelinuxagent.common.conf.get_ext_log_dir") + def test_agent_logs_if_extension_log_directory_is_a_file(self, mock_dir, mock_log): + ext_log_dir = os.path.join(self.tmp_dir, "FauxLogDir") + mock_dir.return_value = ext_log_dir + fileutil.write_file(ext_log_dir, "Foo") + + self.assertTrue(os.path.isfile(ext_log_dir)) + self.assertFalse(os.path.isdir(ext_log_dir)) + agent = Agent(False, + conf_file_path=os.path.join(data_dir, "test_waagent.conf")) + self.assertTrue(os.path.isfile(ext_log_dir)) + self.assertFalse(os.path.isdir(ext_log_dir)) + mock_log.assert_called_once() + + def test_agent_show_configuration(self): + if not hasattr(sys.stdout, 'getvalue'): + self.fail('Test requires at least Python 2.7 with buffered output') + agent = Agent(False, + conf_file_path=os.path.join(data_dir, "test_waagent.conf")) + agent.show_configuration() + self.assertEqual(EXPECTED_CONFIGURATION, sys.stdout.getvalue().split('\n')) diff --git a/tests/tools.py b/tests/tools.py index a505700..94fab7f 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -26,8 +26,10 @@ import tempfile import unittest from functools import wraps +import azurelinuxagent.common.event as event import azurelinuxagent.common.conf as conf import azurelinuxagent.common.logger as logger + from azurelinuxagent.common.version import PY_VERSION_MAJOR #Import mock module for Python2 and Python3 @@ -51,14 +53,21 @@ if debug: class AgentTestCase(unittest.TestCase): def setUp(self): prefix = "{0}_".format(self.__class__.__name__) + self.tmp_dir = tempfile.mkdtemp(prefix=prefix) self.test_file = 'test_file' + conf.get_autoupdate_enabled = Mock(return_value=True) conf.get_lib_dir = Mock(return_value=self.tmp_dir) + ext_log_dir = os.path.join(self.tmp_dir, "azure") conf.get_ext_log_dir = Mock(return_value=ext_log_dir) + conf.get_agent_pid_file_path = Mock(return_value=os.path.join(self.tmp_dir, "waagent.pid")) + event.init_event_status(self.tmp_dir) + event.init_event_logger(self.tmp_dir) + def tearDown(self): if not debug and self.tmp_dir is not None: shutil.rmtree(self.tmp_dir) diff --git a/tests/utils/test_file_util.py b/tests/utils/test_file_util.py index 0b92513..87bce8c 100644 --- a/tests/utils/test_file_util.py +++ b/tests/utils/test_file_util.py @@ -15,6 +15,7 @@ # Requires Python 2.4+ and Openssl 1.0+ # +import errno as errno import glob import random import string @@ -64,6 +65,50 @@ class TestFileOperations(AgentTestCase): os.remove(test_file) + def test_findre_in_file(self): + fp = tempfile.mktemp() + with open(fp, 'w') as f: + f.write( +''' +First line +Second line +Third line with more words +''' + ) + + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, ".*rst line$")) + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, ".*ond line$")) + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, ".*with more.*")) + self.assertNotEquals( + None, + fileutil.findre_in_file(fp, "^Third.*")) + self.assertEquals( + None, + fileutil.findre_in_file(fp, "^Do not match.*")) + + def test_findstr_in_file(self): + fp = tempfile.mktemp() + with open(fp, 'w') as f: + f.write( +''' +First line +Second line +Third line with more words +''' + ) + + self.assertTrue(fileutil.findstr_in_file(fp, "First line")) + self.assertTrue(fileutil.findstr_in_file(fp, "Second line")) + self.assertTrue( + fileutil.findstr_in_file(fp, "Third line with more words")) + self.assertFalse(fileutil.findstr_in_file(fp, "Not a line")) + def test_get_last_path_element(self): filepath = '/tmp/abc.def' filename = fileutil.base_name(filepath) @@ -197,5 +242,75 @@ DHCP_HOSTNAME=test\n" fileutil.update_conf_file(path, 'DHCP_HOSTNAME', 'DHCP_HOSTNAME=test') patch_write.assert_called_once_with(path, updated_file) + def test_clean_ioerror_ignores_missing(self): + e = IOError() + e.errno = errno.ENOSPC + + # Send no paths + fileutil.clean_ioerror(e) + + # Send missing file(s) / directories + fileutil.clean_ioerror(e, paths=['/foo/not/here', None, '/bar/not/there']) + + def test_clean_ioerror_ignores_unless_ioerror(self): + try: + d = tempfile.mkdtemp() + fd, f = tempfile.mkstemp() + os.close(fd) + fileutil.write_file(f, 'Not empty') + + # Send non-IOError exception + e = Exception() + fileutil.clean_ioerror(e, paths=[d, f]) + self.assertTrue(os.path.isdir(d)) + self.assertTrue(os.path.isfile(f)) + + # Send unrecognized IOError + e = IOError() + e.errno = errno.EFAULT + self.assertFalse(e.errno in fileutil.KNOWN_IOERRORS) + fileutil.clean_ioerror(e, paths=[d, f]) + self.assertTrue(os.path.isdir(d)) + self.assertTrue(os.path.isfile(f)) + + finally: + shutil.rmtree(d) + os.remove(f) + + def test_clean_ioerror_removes_files(self): + fd, f = tempfile.mkstemp() + os.close(fd) + fileutil.write_file(f, 'Not empty') + + e = IOError() + e.errno = errno.ENOSPC + fileutil.clean_ioerror(e, paths=[f]) + self.assertFalse(os.path.isdir(f)) + self.assertFalse(os.path.isfile(f)) + + def test_clean_ioerror_removes_directories(self): + d1 = tempfile.mkdtemp() + d2 = tempfile.mkdtemp() + for n in ['foo', 'bar']: + fileutil.write_file(os.path.join(d2, n), 'Not empty') + + e = IOError() + e.errno = errno.ENOSPC + fileutil.clean_ioerror(e, paths=[d1, d2]) + self.assertFalse(os.path.isdir(d1)) + self.assertFalse(os.path.isfile(d1)) + self.assertFalse(os.path.isdir(d2)) + self.assertFalse(os.path.isfile(d2)) + + def test_clean_ioerror_handles_a_range_of_errors(self): + for err in fileutil.KNOWN_IOERRORS: + e = IOError() + e.errno = err + + d = tempfile.mkdtemp() + fileutil.clean_ioerror(e, paths=[d]) + self.assertFalse(os.path.isdir(d)) + self.assertFalse(os.path.isfile(d)) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_rest_util.py b/tests/utils/test_rest_util.py index 5f084a6..52674da 100644 --- a/tests/utils/test_rest_util.py +++ b/tests/utils/test_rest_util.py @@ -15,10 +15,16 @@ # Requires Python 2.4+ and Openssl 1.0+ # +import os import unittest + +from azurelinuxagent.common.exception import HttpError, \ + ProtocolError, \ + ResourceGoneError import azurelinuxagent.common.utils.restutil as restutil -from azurelinuxagent.common.future import httpclient -from tests.tools import AgentTestCase, patch, Mock, MagicMock + +from azurelinuxagent.common.future import httpclient, ustr +from tests.tools import * class TestHttpOperations(AgentTestCase): @@ -50,45 +56,163 @@ class TestHttpOperations(AgentTestCase): self.assertEquals(None, host) self.assertEquals(rel_uri, "None") + @patch('azurelinuxagent.common.conf.get_httpproxy_port') + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_none_is_default(self, mock_host, mock_port): + mock_host.return_value = None + mock_port.return_value = None + h, p = restutil._get_http_proxy() + self.assertEqual(None, h) + self.assertEqual(None, p) + + @patch('azurelinuxagent.common.conf.get_httpproxy_port') + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_configuration_overrides_env(self, mock_host, mock_port): + mock_host.return_value = "host" + mock_port.return_value = None + h, p = restutil._get_http_proxy() + self.assertEqual("host", h) + self.assertEqual(None, p) + mock_host.assert_called_once() + mock_port.assert_called_once() + + @patch('azurelinuxagent.common.conf.get_httpproxy_port') + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_configuration_requires_host(self, mock_host, mock_port): + mock_host.return_value = None + mock_port.return_value = None + h, p = restutil._get_http_proxy() + self.assertEqual(None, h) + self.assertEqual(None, p) + mock_host.assert_called_once() + mock_port.assert_not_called() + + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_http_uses_httpproxy(self, mock_host): + mock_host.return_value = None + with patch.dict(os.environ, { + 'http_proxy' : 'http://foo.com:80', + 'https_proxy' : 'https://bar.com:443' + }): + h, p = restutil._get_http_proxy() + self.assertEqual("foo.com", h) + self.assertEqual(80, p) + + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_https_uses_httpsproxy(self, mock_host): + mock_host.return_value = None + with patch.dict(os.environ, { + 'http_proxy' : 'http://foo.com:80', + 'https_proxy' : 'https://bar.com:443' + }): + h, p = restutil._get_http_proxy(secure=True) + self.assertEqual("bar.com", h) + self.assertEqual(443, p) + + @patch('azurelinuxagent.common.conf.get_httpproxy_host') + def test_get_http_proxy_ignores_user_in_httpproxy(self, mock_host): + mock_host.return_value = None + with patch.dict(os.environ, { + 'http_proxy' : 'http://user:pw@foo.com:80' + }): + h, p = restutil._get_http_proxy() + self.assertEqual("foo.com", h) + self.assertEqual(80, p) + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") - def test_http_request(self, HTTPConnection, HTTPSConnection): - mock_http_conn = MagicMock() - mock_http_resp = MagicMock() - mock_http_conn.getresponse = Mock(return_value=mock_http_resp) - HTTPConnection.return_value = mock_http_conn - HTTPSConnection.return_value = mock_http_conn + def test_http_request_direct(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) - mock_http_resp.read = Mock(return_value="_(:3| <)_") + HTTPConnection.return_value = mock_conn - # Test http get - resp = restutil._http_request("GET", "foo", "bar") - self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + resp = restutil._http_request("GET", "foo", "/bar") - # Test https get - resp = restutil._http_request("GET", "foo", "bar", secure=True) + HTTPConnection.assert_has_calls([ + call("foo", 80, timeout=10) + ]) + HTTPSConnection.assert_not_called() + mock_conn.request.assert_has_calls([ + call(method="GET", url="/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) + + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") + @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") + def test_http_request_direct_secure(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) + + HTTPSConnection.return_value = mock_conn + + resp = restutil._http_request("GET", "foo", "/bar", secure=True) - # Test http get with proxy - mock_http_resp.read = Mock(return_value="_(:3| <)_") - resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", - proxy_port=23333) + HTTPConnection.assert_not_called() + HTTPSConnection.assert_has_calls([ + call("foo", 443, timeout=10) + ]) + mock_conn.request.assert_has_calls([ + call(method="GET", url="/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) - # Test https get - resp = restutil._http_request("GET", "foo", "bar", secure=True) + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") + @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") + def test_http_request_proxy(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) + + HTTPConnection.return_value = mock_conn + + resp = restutil._http_request("GET", "foo", "/bar", + proxy_host="foo.bar", proxy_port=23333) + + HTTPConnection.assert_has_calls([ + call("foo.bar", 23333, timeout=10) + ]) + HTTPSConnection.assert_not_called() + mock_conn.request.assert_has_calls([ + call(method="GET", url="http://foo:80/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) + + @patch("azurelinuxagent.common.future.httpclient.HTTPSConnection") + @patch("azurelinuxagent.common.future.httpclient.HTTPConnection") + def test_http_request_proxy_secure(self, HTTPConnection, HTTPSConnection): + mock_conn = \ + MagicMock(getresponse=\ + Mock(return_value=\ + Mock(read=Mock(return_value="TheResults")))) + + HTTPSConnection.return_value = mock_conn - # Test https get with proxy - mock_http_resp.read = Mock(return_value="_(:3| <)_") - resp = restutil._http_request("GET", "foo", "bar", proxy_host="foo.bar", - proxy_port=23333, secure=True) + resp = restutil._http_request("GET", "foo", "/bar", + proxy_host="foo.bar", proxy_port=23333, + secure=True) + + HTTPConnection.assert_not_called() + HTTPSConnection.assert_has_calls([ + call("foo.bar", 23333, timeout=10) + ]) + mock_conn.request.assert_has_calls([ + call(method="GET", url="https://foo:443/bar", body=None, headers={}) + ]) + mock_conn.getresponse.assert_called_once() self.assertNotEquals(None, resp) - self.assertEquals("_(:3| <)_", resp.read()) + self.assertEquals("TheResults", resp.read()) @patch("time.sleep") @patch("azurelinuxagent.common.utils.restutil._http_request") @@ -115,6 +239,180 @@ class TestHttpOperations(AgentTestCase): self.assertRaises(restutil.HttpError, restutil.http_get, "http://foo.bar") + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_status_codes(self, _http_request, _sleep): + _http_request.side_effect = [ + Mock(status=httpclient.SERVICE_UNAVAILABLE), + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar") + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_passed_status_codes(self, _http_request, _sleep): + # Ensure the code is not part of the standard set + self.assertFalse(httpclient.UNAUTHORIZED in restutil.RETRY_CODES) + + _http_request.side_effect = [ + Mock(status=httpclient.UNAUTHORIZED), + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar", retry_codes=[httpclient.UNAUTHORIZED]) + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_raises_for_bad_request(self, _http_request, _sleep): + _http_request.side_effect = [ + Mock(status=httpclient.BAD_REQUEST) + ] + + self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar") + self.assertEqual(1, _http_request.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_raises_for_resource_gone(self, _http_request, _sleep): + _http_request.side_effect = [ + Mock(status=httpclient.GONE) + ] + + self.assertRaises(ResourceGoneError, restutil.http_get, "https://foo.bar") + self.assertEqual(1, _http_request.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_exceptions(self, _http_request, _sleep): + # Testing each exception is difficult because they have varying + # signatures; for now, test one and ensure the set is unchanged + recognized_exceptions = [ + httpclient.NotConnected, + httpclient.IncompleteRead, + httpclient.ImproperConnectionState, + httpclient.BadStatusLine + ] + self.assertEqual(recognized_exceptions, restutil.RETRY_EXCEPTIONS) + + _http_request.side_effect = [ + httpclient.IncompleteRead(''), + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar") + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + @patch("time.sleep") + @patch("azurelinuxagent.common.utils.restutil._http_request") + def test_http_request_retries_ioerrors(self, _http_request, _sleep): + ioerror = IOError() + ioerror.errno = 42 + + _http_request.side_effect = [ + ioerror, + Mock(status=httpclient.OK) + ] + + restutil.http_get("https://foo.bar") + self.assertEqual(2, _http_request.call_count) + self.assertEqual(1, _sleep.call_count) + + def test_request_failed(self): + self.assertTrue(restutil.request_failed(None)) + + resp = Mock() + for status in restutil.OK_CODES: + resp.status = status + self.assertFalse(restutil.request_failed(resp)) + + self.assertFalse(httpclient.BAD_REQUEST in restutil.OK_CODES) + resp.status = httpclient.BAD_REQUEST + self.assertTrue(restutil.request_failed(resp)) + + self.assertFalse( + restutil.request_failed( + resp, ok_codes=[httpclient.BAD_REQUEST])) + + def test_request_succeeded(self): + self.assertFalse(restutil.request_succeeded(None)) + + resp = Mock() + for status in restutil.OK_CODES: + resp.status = status + self.assertTrue(restutil.request_succeeded(resp)) + + self.assertFalse(httpclient.BAD_REQUEST in restutil.OK_CODES) + resp.status = httpclient.BAD_REQUEST + self.assertFalse(restutil.request_succeeded(resp)) + + self.assertTrue( + restutil.request_succeeded( + resp, ok_codes=[httpclient.BAD_REQUEST])) + + def test_read_response_error(self): + """ + Validate the read_response_error method handles encoding correctly + """ + responses = ['message', b'message', '\x80message\x80'] + response = MagicMock() + response.status = 'status' + response.reason = 'reason' + with patch.object(response, 'read') as patch_response: + for s in responses: + patch_response.return_value = s + result = restutil.read_response_error(response) + print("RESPONSE: {0}".format(s)) + print("RESULT: {0}".format(result)) + print("PRESENT: {0}".format('[status: reason]' in result)) + self.assertTrue('[status: reason]' in result) + self.assertTrue('message' in result) + + def test_read_response_bytes(self): + response_bytes = '7b:0a:20:20:20:20:22:65:72:72:6f:72:43:6f:64:65:22:' \ + '3a:20:22:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:' \ + '69:73:20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:' \ + '69:73:20:6f:70:65:72:61:74:69:6f:6e:2e:22:2c:0a:20:' \ + '20:20:20:22:6d:65:73:73:61:67:65:22:3a:20:22:c3:af:' \ + 'c2:bb:c2:bf:3c:3f:78:6d:6c:20:76:65:72:73:69:6f:6e:' \ + '3d:22:31:2e:30:22:20:65:6e:63:6f:64:69:6e:67:3d:22:' \ + '75:74:66:2d:38:22:3f:3e:3c:45:72:72:6f:72:3e:3c:43:' \ + '6f:64:65:3e:49:6e:76:61:6c:69:64:42:6c:6f:62:54:79:' \ + '70:65:3c:2f:43:6f:64:65:3e:3c:4d:65:73:73:61:67:65:' \ + '3e:54:68:65:20:62:6c:6f:62:20:74:79:70:65:20:69:73:' \ + '20:69:6e:76:61:6c:69:64:20:66:6f:72:20:74:68:69:73:' \ + '20:6f:70:65:72:61:74:69:6f:6e:2e:0a:52:65:71:75:65:' \ + '73:74:49:64:3a:63:37:34:32:39:30:63:62:2d:30:30:30:' \ + '31:2d:30:30:62:35:2d:30:36:64:61:2d:64:64:36:36:36:' \ + '61:30:30:30:22:2c:0a:20:20:20:20:22:64:65:74:61:69:' \ + '6c:73:22:3a:20:22:22:0a:7d'.split(':') + expected_response = '[HTTP Failed] [status: reason] {\n "errorCode": "The blob ' \ + 'type is invalid for this operation.",\n ' \ + '"message": "<?xml version="1.0" ' \ + 'encoding="utf-8"?>' \ + '<Error><Code>InvalidBlobType</Code><Message>The ' \ + 'blob type is invalid for this operation.\n' \ + 'RequestId:c74290cb-0001-00b5-06da-dd666a000",' \ + '\n "details": ""\n}' + + response_string = ''.join(chr(int(b, 16)) for b in response_bytes) + response = MagicMock() + response.status = 'status' + response.reason = 'reason' + with patch.object(response, 'read') as patch_response: + patch_response.return_value = response_string + result = restutil.read_response_error(response) + self.assertEqual(result, expected_response) + try: + raise HttpError("{0}".format(result)) + except HttpError as e: + self.assertTrue(result in ustr(e)) + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/test_text_util.py b/tests/utils/test_text_util.py index 6f204c7..d182a67 100644 --- a/tests/utils/test_text_util.py +++ b/tests/utils/test_text_util.py @@ -34,6 +34,19 @@ class TestTextUtil(AgentTestCase): password_hash = textutil.gen_password_hash(data, 6, 10) self.assertNotEquals(None, password_hash) + def test_replace_non_ascii(self): + data = ustr(b'\xef\xbb\xbfhehe', encoding='utf-8') + self.assertEqual('hehe', textutil.replace_non_ascii(data)) + + data = "abcd\xa0e\xf0fghijk\xbblm" + self.assertEqual("abcdefghijklm", textutil.replace_non_ascii(data)) + + data = "abcd\xa0e\xf0fghijk\xbblm" + self.assertEqual("abcdXeXfghijkXlm", + textutil.replace_non_ascii(data, replace_char='X')) + + self.assertEqual('', textutil.replace_non_ascii(None)) + def test_remove_bom(self): #Test bom could be removed data = ustr(b'\xef\xbb\xbfhehe', encoding='utf-8') @@ -94,6 +107,37 @@ class TestTextUtil(AgentTestCase): "-----END PRIVATE Key-----\n") base64_bytes = textutil.get_bytes_from_pem(content) self.assertEquals("private key", base64_bytes) + + def test_swap_hexstring(self): + data = [ + ['12', 1, '21'], + ['12', 2, '12'], + ['12', 3, '012'], + ['12', 4, '0012'], + + ['123', 1, '321'], + ['123', 2, '2301'], + ['123', 3, '123'], + ['123', 4, '0123'], + + ['1234', 1, '4321'], + ['1234', 2, '3412'], + ['1234', 3, '234001'], + ['1234', 4, '1234'], + + ['abcdef12', 1, '21fedcba'], + ['abcdef12', 2, '12efcdab'], + ['abcdef12', 3, 'f12cde0ab'], + ['abcdef12', 4, 'ef12abcd'], + + ['aBcdEf12', 1, '21fEdcBa'], + ['aBcdEf12', 2, '12EfcdaB'], + ['aBcdEf12', 3, 'f12cdE0aB'], + ['aBcdEf12', 4, 'Ef12aBcd'] + ] + + for t in data: + self.assertEqual(t[2], textutil.swap_hexstring(t[0], width=t[1])) if __name__ == '__main__': unittest.main() |