diff options
Diffstat (limited to 'azurelinuxagent/common')
32 files changed, 6246 insertions, 0 deletions
diff --git a/azurelinuxagent/common/__init__.py b/azurelinuxagent/common/__init__.py new file mode 100644 index 0000000..1ea2f38 --- /dev/null +++ b/azurelinuxagent/common/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + diff --git a/azurelinuxagent/common/conf.py b/azurelinuxagent/common/conf.py new file mode 100644 index 0000000..1a3b0da --- /dev/null +++ b/azurelinuxagent/common/conf.py @@ -0,0 +1,181 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +""" +Module conf loads and parses configuration file +""" +import os +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.exception import AgentConfigError + +class ConfigurationProvider(object): + """ + Parse amd store key:values in /etc/waagent.conf. + """ + def __init__(self): + self.values = dict() + + def load(self, content): + if not content: + raise AgentConfigError("Can't not parse empty configuration") + for line in content.split('\n'): + if not line.startswith("#") and "=" in line: + parts = line.split()[0].split('=') + value = parts[1].strip("\" ") + if value != "None": + self.values[parts[0]] = value + else: + self.values[parts[0]] = None + + def get(self, key, default_val): + val = self.values.get(key) + return val if val is not None else default_val + + def get_switch(self, key, default_val): + val = self.values.get(key) + if val is not None and val.lower() == 'y': + return True + elif val is not None and val.lower() == 'n': + return False + return default_val + + def get_int(self, key, default_val): + try: + return int(self.values.get(key)) + except TypeError: + return default_val + except ValueError: + return default_val + + +__conf__ = ConfigurationProvider() + +def load_conf_from_file(conf_file_path, conf=__conf__): + """ + Load conf file from: conf_file_path + """ + if os.path.isfile(conf_file_path) == False: + raise AgentConfigError(("Missing configuration in {0}" + "").format(conf_file_path)) + try: + content = fileutil.read_file(conf_file_path) + conf.load(content) + except IOError as err: + raise AgentConfigError(("Failed to load conf file:{0}, {1}" + "").format(conf_file_path, err)) + +def enable_rdma(conf=__conf__): + return conf.get_switch("OS.EnableRDMA", False) + +def get_logs_verbose(conf=__conf__): + return conf.get_switch("Logs.Verbose", False) + +def get_lib_dir(conf=__conf__): + return conf.get("Lib.Dir", "/var/lib/waagent") + +def get_dvd_mount_point(conf=__conf__): + return conf.get("DVD.MountPoint", "/mnt/cdrom/secure") + +def get_agent_pid_file_path(conf=__conf__): + return conf.get("Pid.File", "/var/run/waagent.pid") + +def get_ext_log_dir(conf=__conf__): + return conf.get("Extension.LogDir", "/var/log/azure") + +def get_openssl_cmd(conf=__conf__): + return conf.get("OS.OpensslPath", "/usr/bin/openssl") + +def get_home_dir(conf=__conf__): + return conf.get("OS.HomeDir", "/home") + +def get_passwd_file_path(conf=__conf__): + return conf.get("OS.PasswordPath", "/etc/shadow") + +def get_sudoers_dir(conf=__conf__): + return conf.get("OS.SudoersDir", "/etc/sudoers.d") + +def get_sshd_conf_file_path(conf=__conf__): + return conf.get("OS.SshdConfigPath", "/etc/ssh/sshd_config") + +def get_root_device_scsi_timeout(conf=__conf__): + return conf.get("OS.RootDeviceScsiTimeout", None) + +def get_ssh_host_keypair_type(conf=__conf__): + return conf.get("Provisioning.SshHostKeyPairType", "rsa") + +def get_provision_enabled(conf=__conf__): + return conf.get_switch("Provisioning.Enabled", True) + +def get_allow_reset_sys_user(conf=__conf__): + return conf.get_switch("Provisioning.AllowResetSysUser", False) + +def get_regenerate_ssh_host_key(conf=__conf__): + return conf.get_switch("Provisioning.RegenerateSshHostKeyPair", False) + +def get_delete_root_password(conf=__conf__): + return conf.get_switch("Provisioning.DeleteRootPassword", False) + +def get_decode_customdata(conf=__conf__): + return conf.get_switch("Provisioning.DecodeCustomData", False) + +def get_execute_customdata(conf=__conf__): + return conf.get_switch("Provisioning.ExecuteCustomData", False) + +def get_password_cryptid(conf=__conf__): + return conf.get("Provisioning.PasswordCryptId", "6") + +def get_password_crypt_salt_len(conf=__conf__): + return conf.get_int("Provisioning.PasswordCryptSaltLength", 10) + +def get_monitor_hostname(conf=__conf__): + return conf.get_switch("Provisioning.MonitorHostName", False) + +def get_httpproxy_host(conf=__conf__): + return conf.get("HttpProxy.Host", None) + +def get_httpproxy_port(conf=__conf__): + return conf.get("HttpProxy.Port", None) + +def get_detect_scvmm_env(conf=__conf__): + return conf.get_switch("DetectScvmmEnv", False) + +def get_resourcedisk_format(conf=__conf__): + return conf.get_switch("ResourceDisk.Format", False) + +def get_resourcedisk_enable_swap(conf=__conf__): + return conf.get_switch("ResourceDisk.EnableSwap", False) + +def get_resourcedisk_mountpoint(conf=__conf__): + return conf.get("ResourceDisk.MountPoint", "/mnt/resource") + +def get_resourcedisk_filesystem(conf=__conf__): + return conf.get("ResourceDisk.Filesystem", "ext3") + +def get_resourcedisk_swap_size_mb(conf=__conf__): + return conf.get_int("ResourceDisk.SwapSizeMB", 0) + +def get_autoupdate_gafamily(conf=__conf__): + return conf.get("AutoUpdate.GAFamily", "Prod") + +def get_autoupdate_enabled(conf=__conf__): + return conf.get_switch("AutoUpdate.Enabled", True) + +def get_autoupdate_frequency(conf=__conf__): + return conf.get_int("Autoupdate.Frequency", 3600) + diff --git a/azurelinuxagent/common/dhcp.py b/azurelinuxagent/common/dhcp.py new file mode 100644 index 0000000..d5c90cb --- /dev/null +++ b/azurelinuxagent/common/dhcp.py @@ -0,0 +1,400 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import os +import socket +import array +import time +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.shellutil as shellutil +from azurelinuxagent.common.utils import fileutil +from azurelinuxagent.common.utils.textutil import hex_dump, hex_dump2, \ + hex_dump3, \ + compare_bytes, str_to_ord, \ + unpack_big_endian, \ + int_to_ip4_addr +from azurelinuxagent.common.exception import DhcpError +from azurelinuxagent.common.osutil import get_osutil + +# the kernel routing table representation of 168.63.129.16 +KNOWN_WIRESERVER_IP_ENTRY = '10813FA8' +KNOWN_WIRESERVER_IP = '168.63.129.16' + + +def get_dhcp_handler(): + return DhcpHandler() + + +class DhcpHandler(object): + """ + Azure use DHCP option 245 to pass endpoint ip to VMs. + """ + + def __init__(self): + self.osutil = get_osutil() + self.endpoint = None + self.gateway = None + self.routes = None + self._request_broadcast = False + self.skip_cache = False + + def run(self): + """ + Send dhcp request + Configure default gateway and routes + Save wire server endpoint if found + """ + if self.wireserver_route_exists or self.dhcp_cache_exists: + return + + self.send_dhcp_req() + self.conf_routes() + + def wait_for_network(self): + """ + Wait for network stack to be initialized. + """ + ipv4 = self.osutil.get_ip4_addr() + while ipv4 == '' or ipv4 == '0.0.0.0': + logger.info("Waiting for network.") + time.sleep(10) + logger.info("Try to start network interface.") + self.osutil.start_network() + ipv4 = self.osutil.get_ip4_addr() + + @property + def wireserver_route_exists(self): + """ + Determine whether a route to the known wireserver + ip already exists, and if so use that as the endpoint. + This is true when running in a virtual network. + :return: True if a route to KNOWN_WIRESERVER_IP exists. + """ + route_exists = False + logger.info("test for route to {0}".format(KNOWN_WIRESERVER_IP)) + try: + route_file = '/proc/net/route' + if os.path.exists(route_file) and \ + KNOWN_WIRESERVER_IP_ENTRY in open(route_file).read(): + # reset self.gateway and self.routes + # we do not need to alter the routing table + self.endpoint = KNOWN_WIRESERVER_IP + self.gateway = None + self.routes = None + route_exists = True + logger.info("route to {0} exists".format(KNOWN_WIRESERVER_IP)) + else: + logger.warn( + "no route exists to {0}".format(KNOWN_WIRESERVER_IP)) + except Exception as e: + logger.error( + "could not determine whether route exists to {0}: {1}".format( + KNOWN_WIRESERVER_IP, e)) + + return route_exists + + @property + def dhcp_cache_exists(self): + """ + Check whether the dhcp options cache exists and contains the + wireserver endpoint, unless skip_cache is True. + :return: True if the cached endpoint was found in the dhcp lease + """ + if self.skip_cache: + return False + + exists = False + + logger.info("checking for dhcp lease cache") + cached_endpoint = self.osutil.get_dhcp_lease_endpoint() + if cached_endpoint is not None: + self.endpoint = cached_endpoint + exists = True + logger.info("cache exists [{0}]".format(exists)) + return exists + + def conf_routes(self): + logger.info("Configure routes") + logger.info("Gateway:{0}", self.gateway) + logger.info("Routes:{0}", self.routes) + # Add default gateway + if self.gateway is not None: + self.osutil.route_add(0, 0, self.gateway) + if self.routes is not None: + for route in self.routes: + self.osutil.route_add(route[0], route[1], route[2]) + + def _send_dhcp_req(self, request): + __waiting_duration__ = [0, 10, 30, 60, 60] + for duration in __waiting_duration__: + try: + self.osutil.allow_dhcp_broadcast() + response = socket_send(request) + validate_dhcp_resp(request, response) + return response + except DhcpError as e: + logger.warn("Failed to send DHCP request: {0}", e) + time.sleep(duration) + return None + + def send_dhcp_req(self): + """ + Build dhcp request with mac addr + Configure route to allow dhcp traffic + Stop dhcp service if necessary + """ + logger.info("Send dhcp request") + mac_addr = self.osutil.get_mac_addr() + + # Do unicast first, then fallback to broadcast if fails. + req = build_dhcp_request(mac_addr, self._request_broadcast) + if not self._request_broadcast: + self._request_broadcast = True + + # Temporary allow broadcast for dhcp. Remove the route when done. + missing_default_route = self.osutil.is_missing_default_route() + ifname = self.osutil.get_if_name() + if missing_default_route: + self.osutil.set_route_for_dhcp_broadcast(ifname) + + # In some distros, dhcp service needs to be shutdown before agent probe + # endpoint through dhcp. + if self.osutil.is_dhcp_enabled(): + self.osutil.stop_dhcp_service() + + resp = self._send_dhcp_req(req) + + if self.osutil.is_dhcp_enabled(): + self.osutil.start_dhcp_service() + + if missing_default_route: + self.osutil.remove_route_for_dhcp_broadcast(ifname) + + if resp is None: + raise DhcpError("Failed to receive dhcp response.") + self.endpoint, self.gateway, self.routes = parse_dhcp_resp(resp) + + +def validate_dhcp_resp(request, response): + bytes_recv = len(response) + if bytes_recv < 0xF6: + logger.error("HandleDhcpResponse: Too few bytes received:{0}", + bytes_recv) + return False + + logger.verbose("BytesReceived:{0}", hex(bytes_recv)) + logger.verbose("DHCP response:{0}", hex_dump(response, bytes_recv)) + + # check transactionId, cookie, MAC address cookie should never mismatch + # transactionId and MAC address may mismatch if we see a response + # meant from another machine + if not compare_bytes(request, response, 0xEC, 4): + logger.verbose("Cookie not match:\nsend={0},\nreceive={1}", + hex_dump3(request, 0xEC, 4), + hex_dump3(response, 0xEC, 4)) + raise DhcpError("Cookie in dhcp respones doesn't match the request") + + if not compare_bytes(request, response, 4, 4): + logger.verbose("TransactionID not match:\nsend={0},\nreceive={1}", + hex_dump3(request, 4, 4), + hex_dump3(response, 4, 4)) + raise DhcpError("TransactionID in dhcp respones " + "doesn't match the request") + + if not compare_bytes(request, response, 0x1C, 6): + logger.verbose("Mac Address not match:\nsend={0},\nreceive={1}", + hex_dump3(request, 0x1C, 6), + hex_dump3(response, 0x1C, 6)) + raise DhcpError("Mac Addr in dhcp respones " + "doesn't match the request") + + +def parse_route(response, option, i, length, bytes_recv): + # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx + logger.verbose("Routes at offset: {0} with length:{1}", hex(i), + hex(length)) + routes = [] + if length < 5: + logger.error("Data too small for option:{0}", option) + j = i + 2 + while j < (i + length + 2): + mask_len_bits = str_to_ord(response[j]) + mask_len_bytes = (((mask_len_bits + 7) & ~7) >> 3) + mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - mask_len_bits)) + j += 1 + net = unpack_big_endian(response, j, mask_len_bytes) + net <<= (32 - mask_len_bytes * 8) + net &= mask + j += mask_len_bytes + gateway = unpack_big_endian(response, j, 4) + j += 4 + routes.append((net, mask, gateway)) + if j != (i + length + 2): + logger.error("Unable to parse routes") + return routes + + +def parse_ip_addr(response, option, i, length, bytes_recv): + if i + 5 < bytes_recv: + if length != 4: + logger.error("Endpoint or Default Gateway not 4 bytes") + return None + addr = unpack_big_endian(response, i + 2, 4) + ip_addr = int_to_ip4_addr(addr) + return ip_addr + else: + logger.error("Data too small for option:{0}", option) + return None + + +def parse_dhcp_resp(response): + """ + Parse DHCP response: + Returns endpoint server or None on error. + """ + logger.verbose("parse Dhcp Response") + bytes_recv = len(response) + endpoint = None + gateway = None + routes = None + + # Walk all the returned options, parsing out what we need, ignoring the + # others. We need the custom option 245 to find the the endpoint we talk to + # as well as to handle some Linux DHCP client incompatibilities; + # options 3 for default gateway and 249 for routes; 255 is end. + + i = 0xF0 # offset to first option + while i < bytes_recv: + option = str_to_ord(response[i]) + length = 0 + if (i + 1) < bytes_recv: + length = str_to_ord(response[i + 1]) + logger.verbose("DHCP option {0} at offset:{1} with length:{2}", + hex(option), hex(i), hex(length)) + if option == 255: + logger.verbose("DHCP packet ended at offset:{0}", hex(i)) + break + elif option == 249: + routes = parse_route(response, option, i, length, bytes_recv) + elif option == 3: + gateway = parse_ip_addr(response, option, i, length, bytes_recv) + logger.verbose("Default gateway:{0}, at {1}", gateway, hex(i)) + elif option == 245: + endpoint = parse_ip_addr(response, option, i, length, bytes_recv) + logger.verbose("Azure wire protocol endpoint:{0}, at {1}", + endpoint, + hex(i)) + else: + logger.verbose("Skipping DHCP option:{0} at {1} with length {2}", + hex(option), hex(i), hex(length)) + i += length + 2 + return endpoint, gateway, routes + + +def socket_send(request): + sock = None + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.IPPROTO_UDP) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("0.0.0.0", 68)) + sock.sendto(request, ("<broadcast>", 67)) + sock.settimeout(10) + logger.verbose("Send DHCP request: Setting socket.timeout=10, " + "entering recv") + response = sock.recv(1024) + return response + except IOError as e: + raise DhcpError("{0}".format(e)) + finally: + if sock is not None: + sock.close() + + +def build_dhcp_request(mac_addr, request_broadcast): + """ + Build DHCP request string. + """ + # + # typedef struct _DHCP { + # UINT8 Opcode; /* op: BOOTREQUEST or BOOTREPLY */ + # UINT8 HardwareAddressType; /* htype: ethernet */ + # UINT8 HardwareAddressLength; /* hlen: 6 (48 bit mac address) */ + # UINT8 Hops; /* hops: 0 */ + # UINT8 TransactionID[4]; /* xid: random */ + # UINT8 Seconds[2]; /* secs: 0 */ + # UINT8 Flags[2]; /* flags: 0 or 0x8000 for broadcast*/ + # UINT8 ClientIpAddress[4]; /* ciaddr: 0 */ + # UINT8 YourIpAddress[4]; /* yiaddr: 0 */ + # UINT8 ServerIpAddress[4]; /* siaddr: 0 */ + # UINT8 RelayAgentIpAddress[4]; /* giaddr: 0 */ + # UINT8 ClientHardwareAddress[16]; /* chaddr: 6 byte eth MAC address */ + # UINT8 ServerName[64]; /* sname: 0 */ + # UINT8 BootFileName[128]; /* file: 0 */ + # UINT8 MagicCookie[4]; /* 99 130 83 99 */ + # /* 0x63 0x82 0x53 0x63 */ + # /* options -- hard code ours */ + # + # UINT8 MessageTypeCode; /* 53 */ + # UINT8 MessageTypeLength; /* 1 */ + # UINT8 MessageType; /* 1 for DISCOVER */ + # UINT8 End; /* 255 */ + # } DHCP; + # + + # tuple of 244 zeros + # (struct.pack_into would be good here, but requires Python 2.5) + request = [0] * 244 + + trans_id = gen_trans_id() + + # Opcode = 1 + # HardwareAddressType = 1 (ethernet/MAC) + # HardwareAddressLength = 6 (ethernet/MAC/48 bits) + for a in range(0, 3): + request[a] = [1, 1, 6][a] + + # fill in transaction id (random number to ensure response matches request) + for a in range(0, 4): + request[4 + a] = str_to_ord(trans_id[a]) + + logger.verbose("BuildDhcpRequest: transactionId:%s,%04X" % ( + hex_dump2(trans_id), + unpack_big_endian(request, 4, 4))) + + if request_broadcast: + # set broadcast flag to true to request the dhcp sever + # to respond to a boradcast address, + # this is useful when user dhclient fails. + request[0x0A] = 0x80; + + # fill in ClientHardwareAddress + for a in range(0, 6): + request[0x1C + a] = str_to_ord(mac_addr[a]) + + # DHCP Magic Cookie: 99, 130, 83, 99 + # MessageTypeCode = 53 DHCP Message Type + # MessageTypeLength = 1 + # MessageType = DHCPDISCOVER + # End = 255 DHCP_END + for a in range(0, 8): + request[0xEC + a] = [99, 130, 83, 99, 53, 1, 1, 255][a] + return array.array("B", request) + + +def gen_trans_id(): + return os.urandom(4) diff --git a/azurelinuxagent/common/event.py b/azurelinuxagent/common/event.py new file mode 100644 index 0000000..374b0e7 --- /dev/null +++ b/azurelinuxagent/common/event.py @@ -0,0 +1,124 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import sys +import traceback +import atexit +import json +import time +import datetime +import threading +import platform +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import EventError, ProtocolError +from azurelinuxagent.common.future import ustr +from azurelinuxagent.common.protocol.restapi import TelemetryEventParam, \ + TelemetryEventList, \ + TelemetryEvent, \ + set_properties, get_properties +from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_CODE_NAME, AGENT_VERSION + + +class WALAEventOperation: + HeartBeat="HeartBeat" + Provision = "Provision" + Install = "Install" + UnInstall = "UnInstall" + Disable = "Disable" + Enable = "Enable" + Download = "Download" + Upgrade = "Upgrade" + Update = "Update" + ActivateResourceDisk="ActivateResourceDisk" + UnhandledError="UnhandledError" + +class EventLogger(object): + def __init__(self): + self.event_dir = None + + def save_event(self, data): + if self.event_dir is None: + logger.warn("Event reporter is not initialized.") + return + + if not os.path.exists(self.event_dir): + os.mkdir(self.event_dir) + os.chmod(self.event_dir, 0o700) + if len(os.listdir(self.event_dir)) > 1000: + raise EventError("Too many files under: {0}".format(self.event_dir)) + + filename = os.path.join(self.event_dir, ustr(int(time.time()*1000000))) + try: + with open(filename+".tmp",'wb+') as hfile: + hfile.write(data.encode("utf-8")) + os.rename(filename+".tmp", filename+".tld") + except IOError as e: + raise EventError("Failed to write events to file:{0}", e) + + def add_event(self, name, op="", is_success=True, duration=0, version=AGENT_VERSION, + message="", evt_type="", is_internal=False): + event = TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975") + event.parameters.append(TelemetryEventParam('Name', name)) + event.parameters.append(TelemetryEventParam('Version', str(version))) + event.parameters.append(TelemetryEventParam('IsInternal', is_internal)) + event.parameters.append(TelemetryEventParam('Operation', op)) + event.parameters.append(TelemetryEventParam('OperationSuccess', + is_success)) + event.parameters.append(TelemetryEventParam('Message', message)) + event.parameters.append(TelemetryEventParam('Duration', duration)) + event.parameters.append(TelemetryEventParam('ExtensionType', evt_type)) + + data = get_properties(event) + try: + self.save_event(json.dumps(data)) + except EventError as e: + logger.error("{0}", e) + +__event_logger__ = EventLogger() + +def add_event(name, op="", is_success=True, duration=0, version=AGENT_VERSION, + message="", evt_type="", is_internal=False, + reporter=__event_logger__): + log = logger.info if is_success else logger.error + log("Event: name={0}, op={1}, message={2}", name, op, message) + + if reporter.event_dir is None: + logger.warn("Event reporter is not initialized.") + return + reporter.add_event(name, op=op, is_success=is_success, duration=duration, + version=str(version), message=message, evt_type=evt_type, + is_internal=is_internal) + +def init_event_logger(event_dir, reporter=__event_logger__): + reporter.event_dir = event_dir + +def dump_unhandled_err(name): + if hasattr(sys, 'last_type') and hasattr(sys, 'last_value') and \ + hasattr(sys, 'last_traceback'): + last_type = getattr(sys, 'last_type') + last_value = getattr(sys, 'last_value') + last_traceback = getattr(sys, 'last_traceback') + error = traceback.format_exception(last_type, last_value, + last_traceback) + message= "".join(error) + add_event(name, is_success=False, message=message, + op=WALAEventOperation.UnhandledError) + +def enable_unhandled_err_dump(name): + atexit.register(dump_unhandled_err, name) diff --git a/azurelinuxagent/common/exception.py b/azurelinuxagent/common/exception.py new file mode 100644 index 0000000..457490c --- /dev/null +++ b/azurelinuxagent/common/exception.py @@ -0,0 +1,123 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +""" +Defines all exceptions +""" + +class AgentError(Exception): + """ + Base class of agent error. + """ + def __init__(self, errno, msg, inner=None): + msg = u"({0}){1}".format(errno, msg) + if inner is not None: + msg = u"{0} \n inner error: {1}".format(msg, inner) + super(AgentError, self).__init__(msg) + +class AgentConfigError(AgentError): + """ + When configure file is not found or malformed. + """ + def __init__(self, msg=None, inner=None): + super(AgentConfigError, self).__init__('000001', msg, inner) + +class AgentNetworkError(AgentError): + """ + When network is not avaiable. + """ + def __init__(self, msg=None, inner=None): + super(AgentNetworkError, self).__init__('000002', msg, inner) + +class ExtensionError(AgentError): + """ + When failed to execute an extension + """ + def __init__(self, msg=None, inner=None): + super(ExtensionError, self).__init__('000003', msg, inner) + +class ProvisionError(AgentError): + """ + When provision failed + """ + def __init__(self, msg=None, inner=None): + super(ProvisionError, self).__init__('000004', msg, inner) + +class ResourceDiskError(AgentError): + """ + Mount resource disk failed + """ + def __init__(self, msg=None, inner=None): + super(ResourceDiskError, self).__init__('000005', msg, inner) + +class DhcpError(AgentError): + """ + Failed to handle dhcp response + """ + def __init__(self, msg=None, inner=None): + super(DhcpError, self).__init__('000006', msg, inner) + +class OSUtilError(AgentError): + """ + Failed to perform operation to OS configuration + """ + def __init__(self, msg=None, inner=None): + super(OSUtilError, self).__init__('000007', msg, inner) + +class ProtocolError(AgentError): + """ + Azure protocol error + """ + def __init__(self, msg=None, inner=None): + super(ProtocolError, self).__init__('000008', msg, inner) + +class ProtocolNotFoundError(ProtocolError): + """ + Azure protocol endpoint not found + """ + def __init__(self, msg=None, inner=None): + super(ProtocolNotFoundError, self).__init__(msg, inner) + +class HttpError(AgentError): + """ + Http request failure + """ + def __init__(self, msg=None, inner=None): + super(HttpError, self).__init__('000009', msg, inner) + +class EventError(AgentError): + """ + Event reporting error + """ + def __init__(self, msg=None, inner=None): + super(EventError, self).__init__('000010', msg, inner) + +class CryptError(AgentError): + """ + Encrypt/Decrypt error + """ + def __init__(self, msg=None, inner=None): + super(CryptError, self).__init__('000011', msg, inner) + +class UpdateError(AgentError): + """ + Update Guest Agent error + """ + def __init__(self, msg=None, inner=None): + super(UpdateError, self).__init__('000012', msg, inner) + diff --git a/azurelinuxagent/common/future.py b/azurelinuxagent/common/future.py new file mode 100644 index 0000000..8509732 --- /dev/null +++ b/azurelinuxagent/common/future.py @@ -0,0 +1,31 @@ +import sys + +""" +Add alies for python2 and python3 libs and fucntions. +""" + +if sys.version_info[0]== 3: + import http.client as httpclient + from urllib.parse import urlparse + + """Rename Python3 str to ustr""" + ustr = str + + bytebuffer = memoryview + + read_input = input + +elif sys.version_info[0] == 2: + import httplib as httpclient + from urlparse import urlparse + + """Rename Python2 unicode to ustr""" + ustr = unicode + + bytebuffer = buffer + + read_input = raw_input + +else: + raise ImportError("Unknown python version:{0}".format(sys.version_info)) + diff --git a/azurelinuxagent/common/logger.py b/azurelinuxagent/common/logger.py new file mode 100644 index 0000000..c1eb18f --- /dev/null +++ b/azurelinuxagent/common/logger.py @@ -0,0 +1,156 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and openssl_bin 1.0+ +# +""" +Log utils +""" +import os +import sys +from azurelinuxagent.common.future import ustr +from datetime import datetime + +class Logger(object): + """ + Logger class + """ + def __init__(self, logger=None, prefix=None): + self.appenders = [] + if logger is not None: + self.appenders.extend(logger.appenders) + self.prefix = prefix + + def verbose(self, msg_format, *args): + self.log(LogLevel.VERBOSE, msg_format, *args) + + def info(self, msg_format, *args): + self.log(LogLevel.INFO, msg_format, *args) + + def warn(self, msg_format, *args): + self.log(LogLevel.WARNING, msg_format, *args) + + def error(self, msg_format, *args): + self.log(LogLevel.ERROR, msg_format, *args) + + def log(self, level, msg_format, *args): + #if msg_format is not unicode convert it to unicode + if type(msg_format) is not ustr: + msg_format = ustr(msg_format, errors="backslashreplace") + if len(args) > 0: + msg = msg_format.format(*args) + else: + msg = msg_format + time = datetime.now().strftime(u'%Y/%m/%d %H:%M:%S.%f') + level_str = LogLevel.STRINGS[level] + if self.prefix is not None: + log_item = u"{0} {1} {2} {3}\n".format(time, level_str, self.prefix, + msg) + else: + log_item = u"{0} {1} {2}\n".format(time, level_str, msg) + + log_item = ustr(log_item.encode('ascii', "backslashreplace"), + encoding="ascii") + for appender in self.appenders: + appender.write(level, log_item) + + def add_appender(self, appender_type, level, path): + appender = _create_logger_appender(appender_type, level, path) + self.appenders.append(appender) + +class ConsoleAppender(object): + def __init__(self, level, path): + self.level = level + self.path = path + + def write(self, level, msg): + if self.level <= level: + try: + with open(self.path, "w") as console: + console.write(msg) + except IOError: + pass + +class FileAppender(object): + def __init__(self, level, path): + self.level = level + self.path = path + + def write(self, level, msg): + if self.level <= level: + try: + with open(self.path, "a+") as log_file: + log_file.write(msg) + except IOError: + pass + +class StdoutAppender(object): + def __init__(self, level): + self.level = level + + def write(self, level, msg): + if self.level <= level: + try: + sys.stdout.write(msg) + except IOError: + pass + +#Initialize logger instance +DEFAULT_LOGGER = Logger() + +class LogLevel(object): + VERBOSE = 0 + INFO = 1 + WARNING = 2 + ERROR = 3 + STRINGS = [ + "VERBOSE", + "INFO", + "WARNING", + "ERROR" + ] + +class AppenderType(object): + FILE = 0 + CONSOLE = 1 + STDOUT = 2 + +def add_logger_appender(appender_type, level=LogLevel.INFO, path=None): + DEFAULT_LOGGER.add_appender(appender_type, level, path) + +def verbose(msg_format, *args): + DEFAULT_LOGGER.verbose(msg_format, *args) + +def info(msg_format, *args): + DEFAULT_LOGGER.info(msg_format, *args) + +def warn(msg_format, *args): + DEFAULT_LOGGER.warn(msg_format, *args) + +def error(msg_format, *args): + DEFAULT_LOGGER.error(msg_format, *args) + +def log(level, msg_format, *args): + DEFAULT_LOGGER.log(level, msg_format, args) + +def _create_logger_appender(appender_type, level=LogLevel.INFO, path=None): + if appender_type == AppenderType.CONSOLE: + return ConsoleAppender(level, path) + elif appender_type == AppenderType.FILE: + return FileAppender(level, path) + elif appender_type == AppenderType.STDOUT: + return StdoutAppender(level) + else: + raise ValueError("Unknown appender type") + diff --git a/azurelinuxagent/common/osutil/__init__.py b/azurelinuxagent/common/osutil/__init__.py new file mode 100644 index 0000000..3b5ba3b --- /dev/null +++ b/azurelinuxagent/common/osutil/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.common.osutil.factory import get_osutil diff --git a/azurelinuxagent/common/osutil/coreos.py b/azurelinuxagent/common/osutil/coreos.py new file mode 100644 index 0000000..e26fd97 --- /dev/null +++ b/azurelinuxagent/common/osutil/coreos.py @@ -0,0 +1,92 @@ +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import re +import pwd +import shutil +import socket +import array +import struct +import fcntl +import time +import base64 +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.osutil.default import DefaultOSUtil + +class CoreOSUtil(DefaultOSUtil): + def __init__(self): + super(CoreOSUtil, self).__init__() + self.agent_conf_file_path = '/usr/share/oem/waagent.conf' + self.waagent_path='/usr/share/oem/bin/waagent' + self.python_path='/usr/share/oem/python/bin' + if 'PATH' in os.environ: + path = "{0}:{1}".format(os.environ['PATH'], self.python_path) + else: + path = self.python_path + os.environ['PATH'] = path + + if 'PYTHONPATH' in os.environ: + py_path = os.environ['PYTHONPATH'] + py_path = "{0}:{1}".format(py_path, self.waagent_path) + else: + py_path = self.waagent_path + os.environ['PYTHONPATH'] = py_path + + def is_sys_user(self, username): + #User 'core' is not a sysuser + if username == 'core': + return False + return super(CoreOSUtil, self).is_sys_user(username) + + def is_dhcp_enabled(self): + return True + + def start_network(self) : + return shellutil.run("systemctl start systemd-networkd", chk_err=False) + + def restart_if(self, iface): + shellutil.run("systemctl restart systemd-networkd") + + def restart_ssh_service(self): + # SSH is socket activated on CoreOS. No need to restart it. + pass + + def stop_dhcp_service(self): + return shellutil.run("systemctl stop systemd-networkd", chk_err=False) + + def start_dhcp_service(self): + return shellutil.run("systemctl start systemd-networkd", chk_err=False) + + def start_agent_service(self): + return shellutil.run("systemctl start wagent", chk_err=False) + + def stop_agent_service(self): + return shellutil.run("systemctl stop wagent", chk_err=False) + + def get_dhcp_pid(self): + ret= shellutil.run_get_output("pidof systemd-networkd") + return ret[1] if ret[0] == 0 else None + + def conf_sshd(self, disable_password): + #In CoreOS, /etc/sshd_config is mount readonly. Skip the setting + pass + diff --git a/azurelinuxagent/common/osutil/debian.py b/azurelinuxagent/common/osutil/debian.py new file mode 100644 index 0000000..f455572 --- /dev/null +++ b/azurelinuxagent/common/osutil/debian.py @@ -0,0 +1,47 @@ +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import re +import pwd +import shutil +import socket +import array +import struct +import fcntl +import time +import base64 +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.osutil.default import DefaultOSUtil + +class DebianOSUtil(DefaultOSUtil): + def __init__(self): + super(DebianOSUtil, self).__init__() + + def restart_ssh_service(self): + return shellutil.run("service sshd restart", chk_err=False) + + def stop_agent_service(self): + return shellutil.run("service azurelinuxagent stop", chk_err=False) + + def start_agent_service(self): + return shellutil.run("service azurelinuxagent start", chk_err=False) + diff --git a/azurelinuxagent/common/osutil/default.py b/azurelinuxagent/common/osutil/default.py new file mode 100644 index 0000000..c243c85 --- /dev/null +++ b/azurelinuxagent/common/osutil/default.py @@ -0,0 +1,792 @@ +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import re +import shutil +import socket +import array +import struct +import time +import pwd +import fcntl +import base64 +import glob +import datetime +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.conf as conf +from azurelinuxagent.common.exception import OSUtilError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.utils.cryptutil import CryptUtil + +__RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules", + "/etc/udev/rules.d/70-persistent-net.rules" ] + +""" +Define distro specific behavior. OSUtil class defines default behavior +for all distros. Each concrete distro classes could overwrite default behavior +if needed. +""" + +class DefaultOSUtil(object): + + def __init__(self): + self.agent_conf_file_path = '/etc/waagent.conf' + self.selinux=None + + def get_agent_conf_file_path(self): + return self.agent_conf_file_path + + def get_userentry(self, username): + try: + return pwd.getpwnam(username) + except KeyError: + return None + + def is_sys_user(self, username): + """ + Check whether use is a system user. + If reset sys user is allowed in conf, return False + Otherwise, check whether UID is less than UID_MIN + """ + if conf.get_allow_reset_sys_user(): + return False + + userentry = self.get_userentry(username) + uidmin = None + try: + uidmin_def = fileutil.get_line_startingwith("UID_MIN", + "/etc/login.defs") + if uidmin_def is not None: + uidmin = int(uidmin_def.split()[1]) + except IOError as e: + pass + if uidmin == None: + uidmin = 100 + if userentry != None and userentry[2] < uidmin: + return True + else: + return False + + def useradd(self, username, expiration=None): + """ + Create user account with 'username' + """ + userentry = self.get_userentry(username) + if userentry is not None: + logger.info("User {0} already exists, skip useradd", username) + return + + if expiration is not None: + cmd = "useradd -m {0} -e {1}".format(username, expiration) + else: + cmd = "useradd -m {0}".format(username) + retcode, out = shellutil.run_get_output(cmd) + if retcode != 0: + raise OSUtilError(("Failed to create user account:{0}, " + "retcode:{1}, " + "output:{2}").format(username, retcode, out)) + + def chpasswd(self, username, password, crypt_id=6, salt_len=10): + if self.is_sys_user(username): + raise OSUtilError(("User {0} is a system user. " + "Will not set passwd.").format(username)) + passwd_hash = textutil.gen_password_hash(password, crypt_id, salt_len) + cmd = "usermod -p '{0}' {1}".format(passwd_hash, username) + ret, output = shellutil.run_get_output(cmd, log_cmd=False) + if ret != 0: + raise OSUtilError(("Failed to set password for {0}: {1}" + "").format(username, output)) + + def conf_sudoer(self, username, nopasswd=False, remove=False): + sudoers_dir = conf.get_sudoers_dir() + sudoers_wagent = os.path.join(sudoers_dir, 'waagent') + + if not remove: + # for older distros create sudoers.d + if not os.path.isdir(sudoers_dir): + sudoers_file = os.path.join(sudoers_dir, '../sudoers') + # create the sudoers.d directory + os.mkdir(sudoers_dir) + # add the include of sudoers.d to the /etc/sudoers + sudoers = '\n#includedir ' + sudoers_dir + '\n' + fileutil.append_file(sudoers_file, sudoers) + sudoer = None + if nopasswd: + sudoer = "{0} ALL=(ALL) NOPASSWD: ALL\n".format(username) + else: + sudoer = "{0} ALL=(ALL) ALL\n".format(username) + fileutil.append_file(sudoers_wagent, sudoer) + fileutil.chmod(sudoers_wagent, 0o440) + else: + #Remove user from sudoers + if os.path.isfile(sudoers_wagent): + try: + content = fileutil.read_file(sudoers_wagent) + sudoers = content.split("\n") + sudoers = [x for x in sudoers if username not in x] + fileutil.write_file(sudoers_wagent, "\n".join(sudoers)) + except IOError as e: + raise OSUtilError("Failed to remove sudoer: {0}".format(e)) + + def del_root_password(self): + try: + passwd_file_path = conf.get_passwd_file_path() + passwd_content = fileutil.read_file(passwd_file_path) + passwd = passwd_content.split('\n') + new_passwd = [x for x in passwd if not x.startswith("root:")] + new_passwd.insert(0, "root:*LOCK*:14600::::::") + fileutil.write_file(passwd_file_path, "\n".join(new_passwd)) + except IOError as e: + raise OSUtilError("Failed to delete root password:{0}".format(e)) + + def _norm_path(self, filepath): + home = conf.get_home_dir() + # Expand HOME variable if present in path + path = os.path.normpath(filepath.replace("$HOME", home)) + return path + + def deploy_ssh_keypair(self, username, keypair): + """ + Deploy id_rsa and id_rsa.pub + """ + path, thumbprint = keypair + path = self._norm_path(path) + dir_path = os.path.dirname(path) + fileutil.mkdir(dir_path, mode=0o700, owner=username) + lib_dir = conf.get_lib_dir() + prv_path = os.path.join(lib_dir, thumbprint + '.prv') + if not os.path.isfile(prv_path): + raise OSUtilError("Can't find {0}.prv".format(thumbprint)) + shutil.copyfile(prv_path, path) + pub_path = path + '.pub' + crytputil = CryptUtil(conf.get_openssl_cmd()) + pub = crytputil.get_pubkey_from_prv(prv_path) + fileutil.write_file(pub_path, pub) + self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0') + self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0') + os.chmod(path, 0o644) + os.chmod(pub_path, 0o600) + + def openssl_to_openssh(self, input_file, output_file): + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.crt_to_ssh(input_file, output_file) + + def deploy_ssh_pubkey(self, username, pubkey): + """ + Deploy authorized_key + """ + path, thumbprint, value = pubkey + if path is None: + raise OSUtilError("Public key path is None") + + crytputil = CryptUtil(conf.get_openssl_cmd()) + + path = self._norm_path(path) + dir_path = os.path.dirname(path) + fileutil.mkdir(dir_path, mode=0o700, owner=username) + if value is not None: + if not value.startswith("ssh-"): + raise OSUtilError("Bad public key: {0}".format(value)) + fileutil.write_file(path, value) + elif thumbprint is not None: + lib_dir = conf.get_lib_dir() + crt_path = os.path.join(lib_dir, thumbprint + '.crt') + if not os.path.isfile(crt_path): + raise OSUtilError("Can't find {0}.crt".format(thumbprint)) + pub_path = os.path.join(lib_dir, thumbprint + '.pub') + pub = crytputil.get_pubkey_from_crt(crt_path) + fileutil.write_file(pub_path, pub) + self.set_selinux_context(pub_path, + 'unconfined_u:object_r:ssh_home_t:s0') + self.openssl_to_openssh(pub_path, path) + fileutil.chmod(pub_path, 0o600) + else: + raise OSUtilError("SSH public key Fingerprint and Value are None") + + self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0') + fileutil.chowner(path, username) + fileutil.chmod(path, 0o644) + + def is_selinux_system(self): + """ + Checks and sets self.selinux = True if SELinux is available on system. + """ + if self.selinux == None: + if shellutil.run("which getenforce", chk_err=False) == 0: + self.selinux = True + else: + self.selinux = False + return self.selinux + + def is_selinux_enforcing(self): + """ + Calls shell command 'getenforce' and returns True if 'Enforcing'. + """ + if self.is_selinux_system(): + output = shellutil.run_get_output("getenforce")[1] + return output.startswith("Enforcing") + else: + return False + + def set_selinux_enforce(self, state): + """ + Calls shell command 'setenforce' with 'state' + and returns resulting exit code. + """ + if self.is_selinux_system(): + if state: s = '1' + else: s='0' + return shellutil.run("setenforce "+s) + + def set_selinux_context(self, path, con): + """ + Calls shell 'chcon' with 'path' and 'con' context. + Returns exit result. + """ + if self.is_selinux_system(): + if not os.path.exists(path): + logger.error("Path does not exist: {0}".format(path)) + return 1 + return shellutil.run('chcon ' + con + ' ' + path) + + def conf_sshd(self, disable_password): + option = "no" if disable_password else "yes" + conf_file_path = conf.get_sshd_conf_file_path() + conf_file = fileutil.read_file(conf_file_path).split("\n") + textutil.set_ssh_config(conf_file, "PasswordAuthentication", option) + textutil.set_ssh_config(conf_file, "ChallengeResponseAuthentication", option) + textutil.set_ssh_config(conf_file, "ClientAliveInterval", "180") + fileutil.write_file(conf_file_path, "\n".join(conf_file)) + logger.info("{0} SSH password-based authentication methods." + .format("Disabled" if disable_password else "Enabled")) + logger.info("Configured SSH client probing to keep connections alive.") + + + def get_dvd_device(self, dev_dir='/dev'): + pattern=r'(sr[0-9]|hd[c-z]|cdrom[0-9]|cd[0-9])' + for dvd in [re.match(pattern, dev) for dev in os.listdir(dev_dir)]: + if dvd is not None: + return "/dev/{0}".format(dvd.group(0)) + raise OSUtilError("Failed to get dvd device") + + def mount_dvd(self, max_retry=6, chk_err=True, dvd_device=None, mount_point=None): + if dvd_device is None: + dvd_device = self.get_dvd_device() + if mount_point is None: + mount_point = conf.get_dvd_mount_point() + mountlist = shellutil.run_get_output("mount")[1] + existing = self.get_mount_point(mountlist, dvd_device) + if existing is not None: #Already mounted + logger.info("{0} is already mounted at {1}", dvd_device, existing) + return + if not os.path.isdir(mount_point): + os.makedirs(mount_point) + + for retry in range(0, max_retry): + retcode = self.mount(dvd_device, mount_point, option="-o ro -t udf,iso9660", + chk_err=chk_err) + if retcode == 0: + logger.info("Successfully mounted dvd") + return + if retry < max_retry - 1: + logger.warn("Mount dvd failed: retry={0}, ret={1}", retry, + retcode) + time.sleep(5) + if chk_err: + raise OSUtilError("Failed to mount dvd.") + + def umount_dvd(self, chk_err=True, mount_point=None): + if mount_point is None: + mount_point = conf.get_dvd_mount_point() + retcode = self.umount(mount_point, chk_err=chk_err) + if chk_err and retcode != 0: + raise OSUtilError("Failed to umount dvd.") + + def eject_dvd(self, chk_err=True): + dvd = self.get_dvd_device() + retcode = shellutil.run("eject {0}".format(dvd)) + if chk_err and retcode != 0: + raise OSUtilError("Failed to eject dvd: ret={0}".format(retcode)) + + def try_load_atapiix_mod(self): + try: + self.load_atapiix_mod() + except Exception as e: + logger.warn("Could not load ATAPI driver: {0}".format(e)) + + def load_atapiix_mod(self): + if self.is_atapiix_mod_loaded(): + return + ret, kern_version = shellutil.run_get_output("uname -r") + if ret != 0: + raise Exception("Failed to call uname -r") + mod_path = os.path.join('/lib/modules', + kern_version.strip('\n'), + 'kernel/drivers/ata/ata_piix.ko') + if not os.path.isfile(mod_path): + raise Exception("Can't find module file:{0}".format(mod_path)) + + ret, output = shellutil.run_get_output("insmod " + mod_path) + if ret != 0: + raise Exception("Error calling insmod for ATAPI CD-ROM driver") + if not self.is_atapiix_mod_loaded(max_retry=3): + raise Exception("Failed to load ATAPI CD-ROM driver") + + def is_atapiix_mod_loaded(self, max_retry=1): + for retry in range(0, max_retry): + ret = shellutil.run("lsmod | grep ata_piix", chk_err=False) + if ret == 0: + logger.info("Module driver for ATAPI CD-ROM is already present.") + return True + if retry < max_retry - 1: + time.sleep(1) + return False + + def mount(self, dvd, mount_point, option="", chk_err=True): + cmd = "mount {0} {1} {2}".format(option, dvd, mount_point) + return shellutil.run_get_output(cmd, chk_err)[0] + + def umount(self, mount_point, chk_err=True): + return shellutil.run("umount {0}".format(mount_point), chk_err=chk_err) + + def allow_dhcp_broadcast(self): + #Open DHCP port if iptables is enabled. + # We supress error logging on error. + shellutil.run("iptables -D INPUT -p udp --dport 68 -j ACCEPT", + chk_err=False) + shellutil.run("iptables -I INPUT -p udp --dport 68 -j ACCEPT", + chk_err=False) + + + def remove_rules_files(self, rules_files=__RULES_FILES__): + lib_dir = conf.get_lib_dir() + for src in rules_files: + file_name = fileutil.base_name(src) + dest = os.path.join(lib_dir, file_name) + if os.path.isfile(dest): + os.remove(dest) + if os.path.isfile(src): + logger.warn("Move rules file {0} to {1}", file_name, dest) + shutil.move(src, dest) + + def restore_rules_files(self, rules_files=__RULES_FILES__): + lib_dir = conf.get_lib_dir() + for dest in rules_files: + filename = fileutil.base_name(dest) + src = os.path.join(lib_dir, filename) + if os.path.isfile(dest): + continue + if os.path.isfile(src): + logger.warn("Move rules file {0} to {1}", filename, dest) + shutil.move(src, dest) + + def get_mac_addr(self): + """ + Convienience function, returns mac addr bound to + first non-loopback interface. + """ + ifname='' + while len(ifname) < 2 : + ifname=self.get_first_if()[0] + addr = self.get_if_mac(ifname) + return textutil.hexstr_to_bytearray(addr) + + def get_if_mac(self, ifname): + """ + Return the mac-address bound to the socket. + """ + sock = socket.socket(socket.AF_INET, + socket.SOCK_DGRAM, + socket.IPPROTO_UDP) + param = struct.pack('256s', (ifname[:15]+('\0'*241)).encode('latin-1')) + info = fcntl.ioctl(sock.fileno(), 0x8927, param) + return ''.join(['%02X' % textutil.str_to_ord(char) for char in info[18:24]]) + + def get_first_if(self): + """ + Return the interface name, and ip addr of the + first active non-loopback interface. + """ + iface='' + expected=16 # how many devices should I expect... + struct_size=40 # for 64bit the size is 40 bytes + sock = socket.socket(socket.AF_INET, + socket.SOCK_DGRAM, + socket.IPPROTO_UDP) + buff=array.array('B', b'\0' * (expected * struct_size)) + param = struct.pack('iL', + expected*struct_size, + buff.buffer_info()[0]) + ret = fcntl.ioctl(sock.fileno(), 0x8912, param) + retsize=(struct.unpack('iL', ret)[0]) + if retsize == (expected * struct_size): + logger.warn(('SIOCGIFCONF returned more than {0} up ' + 'network interfaces.'), expected) + sock = buff.tostring() + primary = bytearray(self.get_primary_interface(), encoding='utf-8') + for i in range(0, struct_size * expected, struct_size): + iface=sock[i:i+16].split(b'\0', 1)[0] + if len(iface) == 0 or self.is_loopback(iface) or iface != primary: + # test the next one + logger.info('interface [{0}] skipped'.format(iface)) + continue + else: + # use this one + logger.info('interface [{0}] selected'.format(iface)) + break + + return iface.decode('latin-1'), socket.inet_ntoa(sock[i+20:i+24]) + + def get_primary_interface(self): + """ + Get the name of the primary interface, which is the one with the + default route attached to it; if there are multiple default routes, + the primary has the lowest Metric. + :return: the interface which has the default route + """ + # from linux/route.h + RTF_GATEWAY = 0x02 + DEFAULT_DEST = "00000000" + + hdr_iface = "Iface" + hdr_dest = "Destination" + hdr_flags = "Flags" + hdr_metric = "Metric" + + idx_iface = -1 + idx_dest = -1 + idx_flags = -1 + idx_metric = -1 + primary = None + primary_metric = None + + logger.info("examine /proc/net/route for primary interface") + with open('/proc/net/route') as routing_table: + idx = 0 + for header in filter(lambda h: len(h) > 0, routing_table.readline().strip(" \n").split("\t")): + if header == hdr_iface: + idx_iface = idx + elif header == hdr_dest: + idx_dest = idx + elif header == hdr_flags: + idx_flags = idx + elif header == hdr_metric: + idx_metric = idx + idx = idx + 1 + for entry in routing_table.readlines(): + route = entry.strip(" \n").split("\t") + if route[idx_dest] == DEFAULT_DEST and int(route[idx_flags]) & RTF_GATEWAY == RTF_GATEWAY: + metric = int(route[idx_metric]) + iface = route[idx_iface] + if primary is None or metric < primary_metric: + primary = iface + primary_metric = metric + + if primary is None: + primary = '' + + logger.info('primary interface is [{0}]'.format(primary)) + return primary + + + def is_primary_interface(self, ifname): + """ + Indicate whether the specified interface is the primary. + :param ifname: the name of the interface - eth0, lo, etc. + :return: True if this interface binds the default route + """ + return self.get_primary_interface() == ifname + + + def is_loopback(self, ifname): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + result = fcntl.ioctl(s.fileno(), 0x8913, struct.pack('256s', ifname[:15])) + flags, = struct.unpack('H', result[16:18]) + isloopback = flags & 8 == 8 + logger.info('interface [{0}] has flags [{1}], is loopback [{2}]'.format(ifname, flags, isloopback)) + return isloopback + + def get_dhcp_lease_endpoint(self): + """ + OS specific, this should return the decoded endpoint of + the wireserver from option 245 in the dhcp leases file + if it exists on disk. + :return: The endpoint if available, or None + """ + return None + + @staticmethod + def get_endpoint_from_leases_path(pathglob): + """ + Try to discover and decode the wireserver endpoint in the + specified dhcp leases path. + :param pathglob: The path containing dhcp lease files + :return: The endpoint if available, otherwise None + """ + endpoint = None + + HEADER_LEASE = "lease" + HEADER_OPTION = "option unknown-245" + HEADER_DNS = "option domain-name-servers" + HEADER_EXPIRE = "expire" + FOOTER_LEASE = "}" + FORMAT_DATETIME = "%Y/%m/%d %H:%M:%S" + + logger.info("looking for leases in path [{0}]".format(pathglob)) + for lease_file in glob.glob(pathglob): + leases = open(lease_file).read() + if HEADER_OPTION in leases: + cached_endpoint = None + has_option_245 = False + expired = True # assume expired + for line in leases.splitlines(): + if line.startswith(HEADER_LEASE): + cached_endpoint = None + has_option_245 = False + expired = True + elif HEADER_DNS in line: + cached_endpoint = line.replace(HEADER_DNS, '').strip(" ;") + elif HEADER_OPTION in line: + has_option_245 = True + elif HEADER_EXPIRE in line: + if "never" in line: + expired = False + else: + try: + expire_string = line.split(" ", 4)[-1].strip(";") + expire_date = datetime.datetime.strptime(expire_string, FORMAT_DATETIME) + if expire_date > datetime.datetime.utcnow(): + expired = False + except: + logger.error("could not parse expiry token '{0}'".format(line)) + elif FOOTER_LEASE in line: + logger.info("dhcp entry:{0}, 245:{1}, expired:{2}".format( + cached_endpoint, has_option_245, expired)) + if not expired and cached_endpoint is not None and has_option_245: + endpoint = cached_endpoint + logger.info("found endpoint [{0}]".format(endpoint)) + # we want to return the last valid entry, so + # keep searching + if endpoint is not None: + logger.info("cached endpoint found [{0}]".format(endpoint)) + else: + logger.info("cached endpoint not found") + return endpoint + + def is_missing_default_route(self): + routes = shellutil.run_get_output("route -n")[1] + for route in routes.split("\n"): + if route.startswith("0.0.0.0 ") or route.startswith("default "): + return False + return True + + def get_if_name(self): + return self.get_first_if()[0] + + def get_ip4_addr(self): + return self.get_first_if()[1] + + def set_route_for_dhcp_broadcast(self, ifname): + return shellutil.run("route add 255.255.255.255 dev {0}".format(ifname), + chk_err=False) + + def remove_route_for_dhcp_broadcast(self, ifname): + shellutil.run("route del 255.255.255.255 dev {0}".format(ifname), + chk_err=False) + + def is_dhcp_enabled(self): + return False + + def stop_dhcp_service(self): + pass + + def start_dhcp_service(self): + pass + + def start_network(self): + pass + + def start_agent_service(self): + pass + + def stop_agent_service(self): + pass + + def register_agent_service(self): + pass + + def unregister_agent_service(self): + pass + + def restart_ssh_service(self): + pass + + def route_add(self, net, mask, gateway): + """ + Add specified route using /sbin/route add -net. + """ + cmd = ("/sbin/route add -net " + "{0} netmask {1} gw {2}").format(net, mask, gateway) + return shellutil.run(cmd, chk_err=False) + + def get_dhcp_pid(self): + ret= shellutil.run_get_output("pidof dhclient") + return ret[1] if ret[0] == 0 else None + + def set_hostname(self, hostname): + fileutil.write_file('/etc/hostname', hostname) + shellutil.run("hostname {0}".format(hostname), chk_err=False) + + def set_dhcp_hostname(self, hostname): + autosend = r'^[^#]*?send\s*host-name.*?(<hostname>|gethostname[(,)])' + dhclient_files = ['/etc/dhcp/dhclient.conf', '/etc/dhcp3/dhclient.conf', '/etc/dhclient.conf'] + for conf_file in dhclient_files: + if not os.path.isfile(conf_file): + continue + if fileutil.findstr_in_file(conf_file, autosend): + #Return if auto send host-name is configured + return + fileutil.update_conf_file(conf_file, + 'send host-name', + 'send host-name "{0}";'.format(hostname)) + + def restart_if(self, ifname, retries=3, wait=5): + retry_limit=retries+1 + for attempt in range(1, retry_limit): + return_code=shellutil.run("ifdown {0} && ifup {0}".format(ifname)) + if return_code == 0: + return + logger.warn("failed to restart {0}: return code {1}".format(ifname, return_code)) + if attempt < retry_limit: + logger.info("retrying in {0} seconds".format(wait)) + time.sleep(wait) + else: + logger.warn("exceeded restart retries") + + def publish_hostname(self, hostname): + self.set_dhcp_hostname(hostname) + ifname = self.get_if_name() + self.restart_if(ifname) + + def set_scsi_disks_timeout(self, timeout): + for dev in os.listdir("/sys/block"): + if dev.startswith('sd'): + self.set_block_device_timeout(dev, timeout) + + def set_block_device_timeout(self, dev, timeout): + if dev is not None and timeout is not None: + file_path = "/sys/block/{0}/device/timeout".format(dev) + content = fileutil.read_file(file_path) + original = content.splitlines()[0].rstrip() + if original != timeout: + fileutil.write_file(file_path, timeout) + logger.info("Set block dev timeout: {0} with timeout: {1}", + dev, timeout) + + def get_mount_point(self, mountlist, device): + """ + Example of mountlist: + /dev/sda1 on / type ext4 (rw) + proc on /proc type proc (rw) + sysfs on /sys type sysfs (rw) + devpts on /dev/pts type devpts (rw,gid=5,mode=620) + tmpfs on /dev/shm type tmpfs + (rw,rootcontext="system_u:object_r:tmpfs_t:s0") + none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw) + /dev/sdb1 on /mnt/resource type ext4 (rw) + """ + if (mountlist and device): + for entry in mountlist.split('\n'): + if(re.search(device, entry)): + tokens = entry.split() + #Return the 3rd column of this line + return tokens[2] if len(tokens) > 2 else None + return None + + def device_for_ide_port(self, port_id): + """ + Return device name attached to ide port 'n'. + """ + if port_id > 3: + return None + g0 = "00000000" + if port_id > 1: + g0 = "00000001" + port_id = port_id - 2 + device = None + path = "/sys/bus/vmbus/devices/" + for vmbus in os.listdir(path): + deviceid = fileutil.read_file(os.path.join(path, vmbus, "device_id")) + guid = deviceid.lstrip('{').split('-') + if guid[0] == g0 and guid[1] == "000" + ustr(port_id): + for root, dirs, files in os.walk(path + vmbus): + if root.endswith("/block"): + device = dirs[0] + break + else : #older distros + for d in dirs: + if ':' in d and "block" == d.split(':')[0]: + device = d.split(':')[1] + break + break + return device + + def del_account(self, username): + if self.is_sys_user(username): + logger.error("{0} is a system user. Will not delete it.", username) + shellutil.run("> /var/run/utmp") + shellutil.run("userdel -f -r " + username) + self.conf_sudoer(username, remove=True) + + def decode_customdata(self, data): + return base64.b64decode(data) + + def get_total_mem(self): + cmd = "grep MemTotal /proc/meminfo |awk '{print $2}'" + ret = shellutil.run_get_output(cmd) + if ret[0] == 0: + return int(ret[1])/1024 + else: + raise OSUtilError("Failed to get total memory: {0}".format(ret[1])) + + def get_processor_cores(self): + ret = shellutil.run_get_output("grep 'processor.*:' /proc/cpuinfo |wc -l") + if ret[0] == 0: + return int(ret[1]) + else: + raise OSUtilError("Failed to get processor cores") + + def set_admin_access_to_ip(self, dest_ip): + #This allows root to access dest_ip + rm_old= "iptables -D OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0" + rule = "iptables -A OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0" + shellutil.run(rm_old.format(dest_ip), chk_err=False) + shellutil.run(rule.format(dest_ip)) + + #This blocks all other users to access dest_ip + rm_old = "iptables -D OUTPUT -d {0} -j DROP" + rule = "iptables -A OUTPUT -d {0} -j DROP" + shellutil.run(rm_old.format(dest_ip), chk_err=False) + shellutil.run(rule.format(dest_ip)) + + def check_pid_alive(self, pid): + return pid is not None and os.path.isdir(os.path.join('/proc', pid)) diff --git a/azurelinuxagent/common/osutil/factory.py b/azurelinuxagent/common/osutil/factory.py new file mode 100644 index 0000000..5e8ae6e --- /dev/null +++ b/azurelinuxagent/common/osutil/factory.py @@ -0,0 +1,69 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.utils.textutil import Version +from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \ + DISTRO_FULL_NAME + +from .default import DefaultOSUtil +from .coreos import CoreOSUtil +from .debian import DebianOSUtil +from .freebsd import FreeBSDOSUtil +from .redhat import RedhatOSUtil, Redhat6xOSUtil +from .suse import SUSEOSUtil, SUSE11OSUtil +from .ubuntu import UbuntuOSUtil, Ubuntu12OSUtil, Ubuntu14OSUtil, \ + UbuntuSnappyOSUtil + +def get_osutil(distro_name=DISTRO_NAME, distro_version=DISTRO_VERSION, + distro_full_name=DISTRO_FULL_NAME): + if distro_name == "ubuntu": + if Version(distro_version) == Version("12.04") or \ + Version(distro_version) == Version("12.10"): + return Ubuntu12OSUtil() + elif Version(distro_version) == Version("14.04") or \ + Version(distro_version) == Version("14.10"): + return Ubuntu14OSUtil() + elif distro_full_name == "Snappy Ubuntu Core": + return UbuntuSnappyOSUtil() + else: + return UbuntuOSUtil() + if distro_name == "coreos": + return CoreOSUtil() + if distro_name == "suse": + if distro_full_name=='SUSE Linux Enterprise Server' and \ + Version(distro_version) < Version('12') or \ + distro_full_name == 'openSUSE' and \ + Version(distro_version) < Version('13.2'): + return SUSE11OSUtil() + else: + return SUSEOSUtil() + elif distro_name == "debian": + return DebianOSUtil() + elif distro_name == "redhat" or distro_name == "centos" or \ + distro_name == "oracle": + if Version(distro_version) < Version("7"): + return Redhat6xOSUtil() + else: + return RedhatOSUtil() + elif distro_name == "freebsd": + return FreeBSDOSUtil() + else: + logger.warn("Unable to load distro implemetation for {0}.", distro_name) + logger.warn("Use default distro implemetation instead.") + return DefaultOSUtil() + diff --git a/azurelinuxagent/common/osutil/freebsd.py b/azurelinuxagent/common/osutil/freebsd.py new file mode 100644 index 0000000..ddf8db6 --- /dev/null +++ b/azurelinuxagent/common/osutil/freebsd.py @@ -0,0 +1,198 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import OSUtilError +from azurelinuxagent.common.osutil.default import DefaultOSUtil + + +class FreeBSDOSUtil(DefaultOSUtil): + def __init__(self): + super(FreeBSDOSUtil, self).__init__() + self._scsi_disks_timeout_set = False + + def set_hostname(self, hostname): + rc_file_path = '/etc/rc.conf' + conf_file = fileutil.read_file(rc_file_path).split("\n") + textutil.set_ini_config(conf_file, "hostname", hostname) + fileutil.write_file(rc_file_path, "\n".join(conf_file)) + shellutil.run("hostname {0}".format(hostname), chk_err=False) + + def restart_ssh_service(self): + return shellutil.run('service sshd restart', chk_err=False) + + def useradd(self, username, expiration=None): + """ + Create user account with 'username' + """ + userentry = self.get_userentry(username) + if userentry is not None: + logger.warn("User {0} already exists, skip useradd", username) + return + + if expiration is not None: + cmd = "pw useradd {0} -e {1} -m".format(username, expiration) + else: + cmd = "pw useradd {0} -m".format(username) + retcode, out = shellutil.run_get_output(cmd) + if retcode != 0: + raise OSUtilError(("Failed to create user account:{0}, " + "retcode:{1}, " + "output:{2}").format(username, retcode, out)) + + def del_account(self, username): + if self.is_sys_user(username): + logger.error("{0} is a system user. Will not delete it.", username) + shellutil.run('> /var/run/utx.active') + shellutil.run('rmuser -y ' + username) + self.conf_sudoer(username, remove=True) + + def chpasswd(self, username, password, crypt_id=6, salt_len=10): + if self.is_sys_user(username): + raise OSUtilError(("User {0} is a system user. " + "Will not set passwd.").format(username)) + passwd_hash = textutil.gen_password_hash(password, crypt_id, salt_len) + cmd = "echo '{0}'|pw usermod {1} -H 0 ".format(passwd_hash, username) + ret, output = shellutil.run_get_output(cmd, log_cmd=False) + if ret != 0: + raise OSUtilError(("Failed to set password for {0}: {1}" + "").format(username, output)) + + def del_root_password(self): + err = shellutil.run('pw mod user root -w no') + if err: + raise OSUtilError("Failed to delete root password: Failed to update password database.") + + def get_if_mac(self, ifname): + data = self._get_net_info() + if data[0] == ifname: + return data[2].replace(':', '').upper() + return None + + def get_first_if(self): + return self._get_net_info()[:2] + + def route_add(self, net, mask, gateway): + cmd = 'route add {0} {1} {2}'.format(net, gateway, mask) + return shellutil.run(cmd, chk_err=False) + + def is_missing_default_route(self): + """ + For FreeBSD, the default broadcast goes to current default gw, not a all-ones broadcast address, need to + specify the route manually to get it work in a VNET environment. + SEE ALSO: man ip(4) IP_ONESBCAST, + """ + return True + + def is_dhcp_enabled(self): + return True + + def start_dhcp_service(self): + shellutil.run("/etc/rc.d/dhclient start {0}".format(self.get_if_name()), chk_err=False) + + def allow_dhcp_broadcast(self): + pass + + def set_route_for_dhcp_broadcast(self, ifname): + return shellutil.run("route add 255.255.255.255 -iface {0}".format(ifname), chk_err=False) + + def remove_route_for_dhcp_broadcast(self, ifname): + shellutil.run("route delete 255.255.255.255 -iface {0}".format(ifname), chk_err=False) + + def get_dhcp_pid(self): + ret = shellutil.run_get_output("pgrep -n dhclient") + return ret[1] if ret[0] == 0 else None + + def eject_dvd(self, chk_err=True): + dvd = self.get_dvd_device() + retcode = shellutil.run("cdcontrol -f {0} eject".format(dvd)) + if chk_err and retcode != 0: + raise OSUtilError("Failed to eject dvd: ret={0}".format(retcode)) + + def restart_if(self, ifname): + # Restart dhclient only to publish hostname + shellutil.run("/etc/rc.d/dhclient restart {0}".format(ifname), chk_err=False) + + def get_total_mem(self): + cmd = "sysctl hw.physmem |awk '{print $2}'" + ret, output = shellutil.run_get_output(cmd) + if ret: + raise OSUtilError("Failed to get total memory: {0}".format(output)) + try: + return int(output)/1024/1024 + except ValueError: + raise OSUtilError("Failed to get total memory: {0}".format(output)) + + def get_processor_cores(self): + ret, output = shellutil.run_get_output("sysctl hw.ncpu |awk '{print $2}'") + if ret: + raise OSUtilError("Failed to get processor cores.") + + try: + return int(output) + except ValueError: + raise OSUtilError("Failed to get total memory: {0}".format(output)) + + def set_scsi_disks_timeout(self, timeout): + if self._scsi_disks_timeout_set: + return + + ret, output = shellutil.run_get_output('sysctl kern.cam.da.default_timeout={0}'.format(timeout)) + if ret: + raise OSUtilError("Failed set SCSI disks timeout: {0}".format(output)) + self._scsi_disks_timeout_set = True + + def check_pid_alive(self, pid): + return shellutil.run('ps -p {0}'.format(pid), chk_err=False) == 0 + + @staticmethod + def _get_net_info(): + """ + There is no SIOCGIFCONF + on freeBSD - just parse ifconfig. + Returns strings: iface, inet4_addr, and mac + or 'None,None,None' if unable to parse. + We will sleep and retry as the network must be up. + """ + iface = '' + inet = '' + mac = '' + + err, output = shellutil.run_get_output('ifconfig -l ether', chk_err=False) + if err: + raise OSUtilError("Can't find ether interface:{0}".format(output)) + ifaces = output.split() + if not ifaces: + raise OSUtilError("Can't find ether interface.") + iface = ifaces[0] + + err, output = shellutil.run_get_output('ifconfig ' + iface, chk_err=False) + if err: + raise OSUtilError("Can't get info for interface:{0}".format(iface)) + + for line in output.split('\n'): + if line.find('inet ') != -1: + inet = line.split()[1] + elif line.find('ether ') != -1: + mac = line.split()[1] + logger.verbose("Interface info: ({0},{1},{2})", iface, inet, mac) + + return iface, inet, mac diff --git a/azurelinuxagent/common/osutil/redhat.py b/azurelinuxagent/common/osutil/redhat.py new file mode 100644 index 0000000..03084b6 --- /dev/null +++ b/azurelinuxagent/common/osutil/redhat.py @@ -0,0 +1,122 @@ +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import re +import pwd +import shutil +import socket +import array +import struct +import fcntl +import time +import base64 +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.future import ustr, bytebuffer +from azurelinuxagent.common.exception import OSUtilError, CryptError +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.osutil.default import DefaultOSUtil + +class Redhat6xOSUtil(DefaultOSUtil): + def __init__(self): + super(Redhat6xOSUtil, self).__init__() + + def start_network(self): + return shellutil.run("/sbin/service networking start", chk_err=False) + + def restart_ssh_service(self): + return shellutil.run("/sbin/service sshd condrestart", chk_err=False) + + def stop_agent_service(self): + return shellutil.run("/sbin/service waagent stop", chk_err=False) + + def start_agent_service(self): + return shellutil.run("/sbin/service waagent start", chk_err=False) + + def register_agent_service(self): + return shellutil.run("chkconfig --add waagent", chk_err=False) + + def unregister_agent_service(self): + return shellutil.run("chkconfig --del waagent", chk_err=False) + + def openssl_to_openssh(self, input_file, output_file): + pubkey = fileutil.read_file(input_file) + try: + cryptutil = CryptUtil(conf.get_openssl_cmd()) + ssh_rsa_pubkey = cryptutil.asn1_to_ssh(pubkey) + except CryptError as e: + raise OSUtilError(ustr(e)) + fileutil.write_file(output_file, ssh_rsa_pubkey) + + #Override + def get_dhcp_pid(self): + ret= shellutil.run_get_output("pidof dhclient") + return ret[1] if ret[0] == 0 else None + + def set_hostname(self, hostname): + """ + Set /etc/sysconfig/network + """ + fileutil.update_conf_file('/etc/sysconfig/network', + 'HOSTNAME', + 'HOSTNAME={0}'.format(hostname)) + shellutil.run("hostname {0}".format(hostname), chk_err=False) + + def set_dhcp_hostname(self, hostname): + ifname = self.get_if_name() + filepath = "/etc/sysconfig/network-scripts/ifcfg-{0}".format(ifname) + fileutil.update_conf_file(filepath, 'DHCP_HOSTNAME', + 'DHCP_HOSTNAME={0}'.format(hostname)) + + def get_dhcp_lease_endpoint(self): + return self.get_endpoint_from_leases_path('/var/lib/dhclient/dhclient-*.leases') + +class RedhatOSUtil(Redhat6xOSUtil): + def __init__(self): + super(RedhatOSUtil, self).__init__() + + def set_hostname(self, hostname): + """ + Set /etc/hostname + Unlike redhat 6.x, redhat 7.x will set hostname to /etc/hostname + """ + DefaultOSUtil.set_hostname(self, hostname) + + def publish_hostname(self, hostname): + """ + Restart NetworkManager first before publishing hostname + """ + shellutil.run("service NetworkManager restart") + super(RedhatOSUtil, self).publish_hostname(hostname) + + def register_agent_service(self): + return shellutil.run("systemctl enable waagent", chk_err=False) + + def unregister_agent_service(self): + return shellutil.run("systemctl disable waagent", chk_err=False) + + def openssl_to_openssh(self, input_file, output_file): + DefaultOSUtil.openssl_to_openssh(self, input_file, output_file) + + def get_dhcp_lease_endpoint(self): + # centos7 has this weird naming with double hyphen like /var/lib/dhclient/dhclient--eth0.lease + return self.get_endpoint_from_leases_path('/var/lib/dhclient/dhclient-*.lease') diff --git a/azurelinuxagent/common/osutil/suse.py b/azurelinuxagent/common/osutil/suse.py new file mode 100644 index 0000000..f0ed0c0 --- /dev/null +++ b/azurelinuxagent/common/osutil/suse.py @@ -0,0 +1,108 @@ +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import re +import pwd +import shutil +import socket +import array +import struct +import fcntl +import time +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +import azurelinuxagent.common.utils.textutil as textutil +from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME +from azurelinuxagent.common.osutil.default import DefaultOSUtil + +class SUSE11OSUtil(DefaultOSUtil): + def __init__(self): + super(SUSE11OSUtil, self).__init__() + self.dhclient_name='dhcpcd' + + def set_hostname(self, hostname): + fileutil.write_file('/etc/HOSTNAME', hostname) + shellutil.run("hostname {0}".format(hostname), chk_err=False) + + def get_dhcp_pid(self): + ret= shellutil.run_get_output("pidof {0}".format(self.dhclient_name)) + return ret[1] if ret[0] == 0 else None + + def is_dhcp_enabled(self): + return True + + def stop_dhcp_service(self): + cmd = "/sbin/service {0} stop".format(self.dhclient_name) + return shellutil.run(cmd, chk_err=False) + + def start_dhcp_service(self): + cmd = "/sbin/service {0} start".format(self.dhclient_name) + return shellutil.run(cmd, chk_err=False) + + def start_network(self) : + return shellutil.run("/sbin/service start network", chk_err=False) + + def restart_ssh_service(self): + return shellutil.run("/sbin/service sshd restart", chk_err=False) + + def stop_agent_service(self): + return shellutil.run("/sbin/service waagent stop", chk_err=False) + + def start_agent_service(self): + return shellutil.run("/sbin/service waagent start", chk_err=False) + + def register_agent_service(self): + return shellutil.run("/sbin/insserv waagent", chk_err=False) + + def unregister_agent_service(self): + return shellutil.run("/sbin/insserv -r waagent", chk_err=False) + +class SUSEOSUtil(SUSE11OSUtil): + def __init__(self): + super(SUSEOSUtil, self).__init__() + self.dhclient_name = 'wickedd-dhcp4' + + def stop_dhcp_service(self): + cmd = "systemctl stop {0}".format(self.dhclient_name) + return shellutil.run(cmd, chk_err=False) + + def start_dhcp_service(self): + cmd = "systemctl start {0}".format(self.dhclient_name) + return shellutil.run(cmd, chk_err=False) + + def start_network(self) : + return shellutil.run("systemctl start network", chk_err=False) + + def restart_ssh_service(self): + return shellutil.run("systemctl restart sshd", chk_err=False) + + def stop_agent_service(self): + return shellutil.run("systemctl stop waagent", chk_err=False) + + def start_agent_service(self): + return shellutil.run("systemctl start waagent", chk_err=False) + + def register_agent_service(self): + return shellutil.run("systemctl enable waagent", chk_err=False) + + def unregister_agent_service(self): + return shellutil.run("systemctl disable waagent", chk_err=False) + + diff --git a/azurelinuxagent/common/osutil/ubuntu.py b/azurelinuxagent/common/osutil/ubuntu.py new file mode 100644 index 0000000..4032cf4 --- /dev/null +++ b/azurelinuxagent/common/osutil/ubuntu.py @@ -0,0 +1,66 @@ +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import azurelinuxagent.common.utils.shellutil as shellutil +from azurelinuxagent.common.osutil.default import DefaultOSUtil + +class Ubuntu14OSUtil(DefaultOSUtil): + def __init__(self): + super(Ubuntu14OSUtil, self).__init__() + + def start_network(self): + return shellutil.run("service networking start", chk_err=False) + + def stop_agent_service(self): + return shellutil.run("service walinuxagent stop", chk_err=False) + + def start_agent_service(self): + return shellutil.run("service walinuxagent start", chk_err=False) + + def remove_rules_files(self, rules_files=""): + pass + + def restore_rules_files(self, rules_files=""): + pass + + def get_dhcp_lease_endpoint(self): + return self.get_endpoint_from_leases_path('/var/lib/dhcp/dhclient.*.leases') + +class Ubuntu12OSUtil(Ubuntu14OSUtil): + def __init__(self): + super(Ubuntu12OSUtil, self).__init__() + + #Override + def get_dhcp_pid(self): + ret= shellutil.run_get_output("pidof dhclient3") + return ret[1] if ret[0] == 0 else None + +class UbuntuOSUtil(Ubuntu14OSUtil): + def __init__(self): + super(UbuntuOSUtil, self).__init__() + + def register_agent_service(self): + return shellutil.run("systemctl unmask walinuxagent", chk_err=False) + + def unregister_agent_service(self): + return shellutil.run("systemctl mask walinuxagent", chk_err=False) + +class UbuntuSnappyOSUtil(Ubuntu14OSUtil): + def __init__(self): + super(UbuntuSnappyOSUtil, self).__init__() + self.conf_file_path = '/apps/walinuxagent/current/waagent.conf' diff --git a/azurelinuxagent/common/protocol/__init__.py b/azurelinuxagent/common/protocol/__init__.py new file mode 100644 index 0000000..fb7c273 --- /dev/null +++ b/azurelinuxagent/common/protocol/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.common.protocol.util import get_protocol_util, \ + OVF_FILE_NAME, \ + TAG_FILE_NAME + diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py new file mode 100644 index 0000000..6569604 --- /dev/null +++ b/azurelinuxagent/common/protocol/hostplugin.py @@ -0,0 +1,124 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +from azurelinuxagent.common.protocol.wire import * +from azurelinuxagent.common.utils import textutil + +HOST_PLUGIN_PORT = 32526 +URI_FORMAT_GET_API_VERSIONS = "http://{0}:{1}/versions" +URI_FORMAT_PUT_VM_STATUS = "http://{0}:{1}/status" +URI_FORMAT_PUT_LOG = "http://{0}:{1}/vmAgentLog" +API_VERSION = "2015-09-01" + + +class HostPluginProtocol(object): + def __init__(self, endpoint): + if endpoint is None: + raise ProtocolError("Host plugin endpoint not provided") + self.is_initialized = False + self.is_available = False + self.api_versions = None + self.endpoint = endpoint + + def ensure_initialized(self): + if not self.is_initialized: + self.api_versions = self.get_api_versions() + self.is_available = API_VERSION in self.api_versions + self.is_initialized = True + return self.is_available + + def get_api_versions(self): + url = URI_FORMAT_GET_API_VERSIONS.format(self.endpoint, + HOST_PLUGIN_PORT) + logger.info("getting API versions at [{0}]".format(url)) + try: + response = restutil.http_get(url) + if response.status != httpclient.OK: + logger.error( + "get API versions returned status code [{0}]".format( + response.status)) + return [] + return response.read() + except HttpError as e: + logger.error("get API versions failed with [{0}]".format(e)) + return [] + + def put_vm_status(self, status_blob, sas_url): + """ + Try to upload the VM status via the host plugin /status channel + :param sas_url: the blob SAS url to pass to the host plugin + :type status_blob: StatusBlob + """ + if not self.ensure_initialized(): + logger.error("host plugin channel is not available") + return + if status_blob is None or status_blob.vm_status is None: + logger.error("no status data was provided") + return + url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT) + status = textutil.b64encode(status_blob.vm_status) + headers = {"x-ms-version": API_VERSION} + blob_headers = [{'headerName': 'x-ms-version', + 'headerValue': status_blob.__storage_version__}, + {'headerName': 'x-ms-blob-type', + 'headerValue': status_blob.type}] + data = json.dumps({'requestUri': sas_url, 'headers': blob_headers, + 'content': status}, sort_keys=True) + logger.info("put VM status at [{0}]".format(url)) + try: + response = restutil.http_put(url, data, headers) + if response.status != httpclient.OK: + logger.error("put VM status returned status code [{0}]".format( + response.status)) + except HttpError as e: + logger.error("put VM status failed with [{0}]".format(e)) + + def put_vm_log(self, content, container_id, deployment_id): + """ + Try to upload the given content to the host plugin + :param deployment_id: the deployment id, which is obtained from the + goal state (tenant name) + :param container_id: the container id, which is obtained from the + goal state + :param content: the binary content of the zip file to upload + :return: + """ + if not self.ensure_initialized(): + logger.error("host plugin channel is not available") + return + if content is None or container_id is None or deployment_id is None: + logger.error( + "invalid arguments passed: " + "[{0}], [{1}], [{2}]".format( + content, + container_id, + deployment_id)) + return + url = URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT) + + headers = {"x-ms-vmagentlog-deploymentid": deployment_id, + "x-ms-vmagentlog-containerid": container_id} + logger.info("put VM log at [{0}]".format(url)) + try: + response = restutil.http_put(url, content, headers) + if response.status != httpclient.OK: + logger.error("put log returned status code [{0}]".format( + response.status)) + except HttpError as e: + logger.error("put log failed with [{0}]".format(e)) diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py new file mode 100644 index 0000000..f86f72f --- /dev/null +++ b/azurelinuxagent/common/protocol/metadata.py @@ -0,0 +1,223 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import json +import shutil +import os +import time +from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.future import httpclient, ustr +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.restutil as restutil +import azurelinuxagent.common.utils.textutil as textutil +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.protocol.restapi import * + +METADATA_ENDPOINT='169.254.169.254' +APIVERSION='2015-05-01-preview' +BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}" + +TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem" +TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem" + +#TODO remote workarround for azure stack +MAX_PING = 30 +RETRY_PING_INTERVAL = 10 + +def _add_content_type(headers): + if headers is None: + headers = {} + headers["content-type"] = "application/json" + return headers + +class MetadataProtocol(Protocol): + + def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT): + self.apiversion = apiversion + self.endpoint = endpoint + self.identity_uri = BASE_URI.format(self.endpoint, "identity", + self.apiversion, "&$expand=*") + self.cert_uri = BASE_URI.format(self.endpoint, "certificates", + self.apiversion, "&$expand=*") + self.ext_uri = BASE_URI.format(self.endpoint, "extensionHandlers", + self.apiversion, "&$expand=*") + self.vmagent_uri = BASE_URI.format(self.endpoint, "vmAgentVersions", + self.apiversion, "&$expand=*") + self.provision_status_uri = BASE_URI.format(self.endpoint, + "provisioningStatus", + self.apiversion, "") + self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent", + self.apiversion, "") + self.ext_status_uri = BASE_URI.format(self.endpoint, + "status/extensions/{0}", + self.apiversion, "") + self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry", + self.apiversion, "") + + def _get_data(self, url, headers=None): + try: + resp = restutil.http_get(url, headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if resp.status != httpclient.OK: + raise ProtocolError("{0} - GET: {1}".format(resp.status, url)) + + data = resp.read() + etag = resp.getheader('ETag') + if data is None: + return None + data = json.loads(ustr(data, encoding="utf-8")) + return data, etag + + def _put_data(self, url, data, headers=None): + headers = _add_content_type(headers) + try: + resp = restutil.http_put(url, json.dumps(data), headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + if resp.status != httpclient.OK: + raise ProtocolError("{0} - PUT: {1}".format(resp.status, url)) + + def _post_data(self, url, data, headers=None): + headers = _add_content_type(headers) + try: + resp = restutil.http_post(url, json.dumps(data), headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + if resp.status != httpclient.CREATED: + raise ProtocolError("{0} - POST: {1}".format(resp.status, url)) + + def _get_trans_cert(self): + trans_crt_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + if not os.path.isfile(trans_crt_file): + raise ProtocolError("{0} is missing.".format(trans_crt_file)) + content = fileutil.read_file(trans_crt_file) + return textutil.get_bytes_from_pem(content) + + def detect(self): + self.get_vminfo() + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + #"Install" the cert and private key to /var/lib/waagent + thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file) + prv_file = os.path.join(conf.get_lib_dir(), + "{0}.prv".format(thumbprint)) + crt_file = os.path.join(conf.get_lib_dir(), + "{0}.crt".format(thumbprint)) + shutil.copyfile(trans_prv_file, prv_file) + shutil.copyfile(trans_cert_file, crt_file) + + + def get_vminfo(self): + vminfo = VMInfo() + data, etag = self._get_data(self.identity_uri) + set_properties("vminfo", vminfo, data) + return vminfo + + def get_certs(self): + #TODO download and save certs + return CertList() + + def get_vmagent_manifests(self, last_etag=None): + manifests = VMAgentManifestList() + data, etag = self._get_data(self.vmagent_uri) + if last_etag == None or last_etag < etag: + set_properties("vmAgentManifests", manifests.vmAgentManifests, data) + return manifests, etag + + def get_vmagent_pkgs(self, vmagent_manifest): + #Agent package is the same with extension handler + vmagent_pkgs = ExtHandlerPackageList() + data = None + for manifest_uri in vmagent_manifest.versionsManifestUris: + try: + data = self._get_data(manifest_uri.uri) + break + except ProtocolError as e: + logger.warn("Failed to get vmagent versions: {0}", e) + logger.info("Retry getting vmagent versions") + if data is None: + raise ProtocolError(("Failed to get versions for vm agent: {0}" + "").format(vmagent_manifest.family)) + set_properties("vmAgentVersions", vmagent_pkgs, data) + # TODO: What etag should this return? + return vmagent_pkgs + + def get_ext_handlers(self, last_etag=None): + headers = { + "x-ms-vmagent-public-x509-cert": self._get_trans_cert() + } + ext_list = ExtHandlerList() + data, etag = self._get_data(self.ext_uri, headers=headers) + if last_etag == None or last_etag < etag: + set_properties("extensionHandlers", ext_list.extHandlers, data) + return ext_list, etag + + def get_ext_handler_pkgs(self, ext_handler): + ext_handler_pkgs = ExtHandlerPackageList() + data = None + for version_uri in ext_handler.versionUris: + try: + data, etag = self._get_data(version_uri.uri) + break + except ProtocolError as e: + logger.warn("Failed to get version uris: {0}", e) + logger.info("Retry getting version uris") + set_properties("extensionPackages", ext_handler_pkgs, data) + return ext_handler_pkgs + + def report_provision_status(self, provision_status): + validate_param('provisionStatus', provision_status, ProvisionStatus) + data = get_properties(provision_status) + self._put_data(self.provision_status_uri, data) + + def report_vm_status(self, vm_status): + validate_param('vmStatus', vm_status, VMStatus) + data = get_properties(vm_status) + #TODO code field is not implemented for metadata protocol yet. Remove it + handler_statuses = data['vmAgent']['extensionHandlers'] + for handler_status in handler_statuses: + try: + handler_status.pop('code', None) + except KeyError: + pass + + self._put_data(self.vm_status_uri, data) + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validate_param('extensionStatus', ext_status, ExtensionStatus) + data = get_properties(ext_status) + uri = self.ext_status_uri.format(ext_name) + self._put_data(uri, data) + + def report_event(self, events): + #TODO disable telemetry for azure stack test + #validate_param('events', events, TelemetryEventList) + #data = get_properties(events) + #self._post_data(self.event_uri, data) + pass + diff --git a/azurelinuxagent/common/protocol/ovfenv.py b/azurelinuxagent/common/protocol/ovfenv.py new file mode 100644 index 0000000..4901871 --- /dev/null +++ b/azurelinuxagent/common/protocol/ovfenv.py @@ -0,0 +1,113 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +""" +Copy and parse ovf-env.xml from provisioning ISO and local cache +""" +import os +import re +import shutil +import xml.dom.minidom as minidom +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext + +OVF_VERSION = "1.0" +OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1" +WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure" + +def _validate_ovf(val, msg): + if val is None: + raise ProtocolError("Failed to parse OVF XML: {0}".format(msg)) + +class OvfEnv(object): + """ + Read, and process provisioning info from provisioning file OvfEnv.xml + """ + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("ovf-env is None") + logger.verbose("Load ovf-env.xml") + self.hostname = None + self.username = None + self.user_password = None + self.customdata = None + self.disable_ssh_password_auth = True + self.ssh_pubkeys = [] + self.ssh_keypairs = [] + self.parse(xml_text) + + def parse(self, xml_text): + """ + Parse xml tree, retreiving user and ssh key information. + Return self. + """ + wans = WA_NAME_SPACE + ovfns = OVF_NAME_SPACE + + xml_doc = parse_doc(xml_text) + + environment = find(xml_doc, "Environment", namespace=ovfns) + _validate_ovf(environment, "Environment not found") + + section = find(environment, "ProvisioningSection", namespace=wans) + _validate_ovf(section, "ProvisioningSection not found") + + version = findtext(environment, "Version", namespace=wans) + _validate_ovf(version, "Version not found") + + if version > OVF_VERSION: + logger.warn("Newer provisioning configuration detected. " + "Please consider updating waagent") + + conf_set = find(section, "LinuxProvisioningConfigurationSet", + namespace=wans) + _validate_ovf(conf_set, "LinuxProvisioningConfigurationSet not found") + + self.hostname = findtext(conf_set, "HostName", namespace=wans) + _validate_ovf(self.hostname, "HostName not found") + + self.username = findtext(conf_set, "UserName", namespace=wans) + _validate_ovf(self.username, "UserName not found") + + self.user_password = findtext(conf_set, "UserPassword", namespace=wans) + + self.customdata = findtext(conf_set, "CustomData", namespace=wans) + + auth_option = findtext(conf_set, "DisableSshPasswordAuthentication", + namespace=wans) + if auth_option is not None and auth_option.lower() == "true": + self.disable_ssh_password_auth = True + else: + self.disable_ssh_password_auth = False + + public_keys = findall(conf_set, "PublicKey", namespace=wans) + for public_key in public_keys: + path = findtext(public_key, "Path", namespace=wans) + fingerprint = findtext(public_key, "Fingerprint", namespace=wans) + value = findtext(public_key, "Value", namespace=wans) + self.ssh_pubkeys.append((path, fingerprint, value)) + + keypairs = findall(conf_set, "KeyPair", namespace=wans) + for keypair in keypairs: + path = findtext(keypair, "Path", namespace=wans) + fingerprint = findtext(keypair, "Fingerprint", namespace=wans) + self.ssh_keypairs.append((path, fingerprint)) + diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py new file mode 100644 index 0000000..7f00488 --- /dev/null +++ b/azurelinuxagent/common/protocol/restapi.py @@ -0,0 +1,272 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import copy +import re +import json +import xml.dom.minidom +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError, HttpError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.restutil as restutil + +def validate_param(name, val, expected_type): + if val is None: + raise ProtocolError("{0} is None".format(name)) + if not isinstance(val, expected_type): + raise ProtocolError(("{0} type should be {1} not {2}" + "").format(name, expected_type, type(val))) + +def set_properties(name, obj, data): + if isinstance(obj, DataContract): + validate_param("Property '{0}'".format(name), data, dict) + for prob_name, prob_val in data.items(): + prob_full_name = "{0}.{1}".format(name, prob_name) + try: + prob = getattr(obj, prob_name) + except AttributeError: + logger.warn("Unknown property: {0}", prob_full_name) + continue + prob = set_properties(prob_full_name, prob, prob_val) + setattr(obj, prob_name, prob) + return obj + elif isinstance(obj, DataContractList): + validate_param("List '{0}'".format(name), data, list) + for item_data in data: + item = obj.item_cls() + item = set_properties(name, item, item_data) + obj.append(item) + return obj + else: + return data + +def get_properties(obj): + if isinstance(obj, DataContract): + data = {} + props = vars(obj) + for prob_name, prob in list(props.items()): + data[prob_name] = get_properties(prob) + return data + elif isinstance(obj, DataContractList): + data = [] + for item in obj: + item_data = get_properties(item) + data.append(item_data) + return data + else: + return obj + +class DataContract(object): + pass + +class DataContractList(list): + def __init__(self, item_cls): + self.item_cls = item_cls + +""" +Data contract between guest and host +""" +class VMInfo(DataContract): + def __init__(self, subscriptionId=None, vmName=None, containerId=None, + roleName=None, roleInstanceName=None, tenantName=None): + self.subscriptionId = subscriptionId + self.vmName = vmName + self.containerId = containerId + self.roleName = roleName + self.roleInstanceName = roleInstanceName + self.tenantName = tenantName + +class Cert(DataContract): + def __init__(self, name=None, thumbprint=None, certificateDataUri=None): + self.name = name + self.thumbprint = thumbprint + self.certificateDataUri = certificateDataUri + +class CertList(DataContract): + def __init__(self): + self.certificates = DataContractList(Cert) + +#TODO: confirm vmagent manifest schema +class VMAgentManifestUri(DataContract): + def __init__(self, uri=None): + self.uri = uri + +class VMAgentManifest(DataContract): + def __init__(self, family=None): + self.family = family + self.versionsManifestUris = DataContractList(VMAgentManifestUri) + +class VMAgentManifestList(DataContract): + def __init__(self): + self.vmAgentManifests = DataContractList(VMAgentManifest) + +class Extension(DataContract): + def __init__(self, name=None, sequenceNumber=None, publicSettings=None, + protectedSettings=None, certificateThumbprint=None): + self.name = name + self.sequenceNumber = sequenceNumber + self.publicSettings = publicSettings + self.protectedSettings = protectedSettings + self.certificateThumbprint = certificateThumbprint + +class ExtHandlerProperties(DataContract): + def __init__(self): + self.version = None + self.upgradePolicy = None + self.state = None + self.extensions = DataContractList(Extension) + +class ExtHandlerVersionUri(DataContract): + def __init__(self): + self.uri = None + +class ExtHandler(DataContract): + def __init__(self, name=None): + self.name = name + self.properties = ExtHandlerProperties() + self.versionUris = DataContractList(ExtHandlerVersionUri) + +class ExtHandlerList(DataContract): + def __init__(self): + self.extHandlers = DataContractList(ExtHandler) + +class ExtHandlerPackageUri(DataContract): + def __init__(self, uri=None): + self.uri = uri + +class ExtHandlerPackage(DataContract): + def __init__(self, version = None): + self.version = version + self.uris = DataContractList(ExtHandlerPackageUri) + # TODO update the naming to align with metadata protocol + self.isinternal = False + +class ExtHandlerPackageList(DataContract): + def __init__(self): + self.versions = DataContractList(ExtHandlerPackage) + +class VMProperties(DataContract): + def __init__(self, certificateThumbprint=None): + #TODO need to confirm the property name + self.certificateThumbprint = certificateThumbprint + +class ProvisionStatus(DataContract): + def __init__(self, status=None, subStatus=None, description=None): + self.status = status + self.subStatus = subStatus + self.description = description + self.properties = VMProperties() + +class ExtensionSubStatus(DataContract): + def __init__(self, name=None, status=None, code=None, message=None): + self.name = name + self.status = status + self.code = code + self.message = message + +class ExtensionStatus(DataContract): + def __init__(self, configurationAppliedTime=None, operation=None, + status=None, seq_no=None, code=None, message=None): + self.configurationAppliedTime = configurationAppliedTime + self.operation = operation + self.status = status + self.sequenceNumber = seq_no + self.code = code + self.message = message + self.substatusList = DataContractList(ExtensionSubStatus) + +class ExtHandlerStatus(DataContract): + def __init__(self, name=None, version=None, status=None, code=0, + message=None): + self.name = name + self.version = version + self.status = status + self.code = code + self.message = message + self.extensions = DataContractList(ustr) + +class VMAgentStatus(DataContract): + def __init__(self, version=None, status=None, message=None): + self.version = version + self.status = status + self.message = message + self.extensionHandlers = DataContractList(ExtHandlerStatus) + +class VMStatus(DataContract): + def __init__(self): + self.vmAgent = VMAgentStatus() + +class TelemetryEventParam(DataContract): + def __init__(self, name=None, value=None): + self.name = name + self.value = value + +class TelemetryEvent(DataContract): + def __init__(self, eventId=None, providerId=None): + self.eventId = eventId + self.providerId = providerId + self.parameters = DataContractList(TelemetryEventParam) + +class TelemetryEventList(DataContract): + def __init__(self): + self.events = DataContractList(TelemetryEvent) + +class Protocol(DataContract): + + def detect(self): + raise NotImplementedError() + + def get_vminfo(self): + raise NotImplementedError() + + def get_certs(self): + raise NotImplementedError() + + def get_vmagent_manifests(self): + raise NotImplementedError() + + def get_vmagent_pkgs(self): + raise NotImplementedError() + + def get_ext_handlers(self): + raise NotImplementedError() + + def get_ext_handler_pkgs(self, extension): + raise NotImplementedError() + + def download_ext_handler_pkg(self, uri): + try: + resp = restutil.http_get(uri, chk_proxy=True) + if resp.status == restutil.httpclient.OK: + return resp.read() + except HttpError as e: + raise ProtocolError("Failed to download from: {0}".format(uri), e) + + def report_provision_status(self, provision_status): + raise NotImplementedError() + + def report_vm_status(self, vm_status): + raise NotImplementedError() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + raise NotImplementedError() + + def report_event(self, event): + raise NotImplementedError() + diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py new file mode 100644 index 0000000..7e7a74f --- /dev/null +++ b/azurelinuxagent/common/protocol/util.py @@ -0,0 +1,285 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# +import os +import re +import shutil +import time +import threading +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \ + ProtocolNotFoundError, DhcpError +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.osutil import get_osutil +from azurelinuxagent.common.dhcp import get_dhcp_handler +from azurelinuxagent.common.protocol.ovfenv import OvfEnv +from azurelinuxagent.common.protocol.wire import WireProtocol +from azurelinuxagent.common.protocol.metadata import MetadataProtocol, \ + METADATA_ENDPOINT +import azurelinuxagent.common.utils.shellutil as shellutil + +OVF_FILE_NAME = "ovf-env.xml" + +#Tag file to indicate usage of metadata protocol +TAG_FILE_NAME = "useMetadataEndpoint.tag" + +PROTOCOL_FILE_NAME = "Protocol" + +#MAX retry times for protocol probing +MAX_RETRY = 360 + +PROBE_INTERVAL = 10 + +ENDPOINT_FILE_NAME = "WireServerEndpoint" + +def get_protocol_util(): + return ProtocolUtil() + +class ProtocolUtil(object): + """ + ProtocolUtil handles initialization for protocol instance. 2 protocol types + are invoked, wire protocol and metadata protocols. + """ + def __init__(self): + self.lock = threading.Lock() + self.protocol = None + self.osutil = get_osutil() + self.dhcp_handler = get_dhcp_handler() + + def copy_ovf_env(self): + """ + Copy ovf env file from dvd to hard disk. + Remove password before save it to the disk + """ + dvd_mount_point = conf.get_dvd_mount_point() + ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME) + tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME) + try: + self.osutil.mount_dvd() + ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True) + ovfenv = OvfEnv(ovfxml) + ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml) + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + fileutil.write_file(ovf_file_path, ovfxml) + + if os.path.isfile(tag_file_path_on_dvd): + logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME) + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + shutil.copyfile(tag_file_path_on_dvd, tag_file_path) + + except (OSUtilError, IOError) as e: + raise ProtocolError(ustr(e)) + + try: + self.osutil.umount_dvd() + self.osutil.eject_dvd() + except OSUtilError as e: + logger.warn(ustr(e)) + + return ovfenv + + def get_ovf_env(self): + """ + Load saved ovf-env.xml + """ + ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME) + if os.path.isfile(ovf_file_path): + xml_text = fileutil.read_file(ovf_file_path) + return OvfEnv(xml_text) + else: + raise ProtocolError("ovf-env.xml is missing.") + + def _get_wireserver_endpoint(self): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + return fileutil.read_file(file_path) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _set_wireserver_endpoint(self, endpoint): + try: + file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME) + fileutil.write_file(file_path, endpoint) + except IOError as e: + raise OSUtilError(ustr(e)) + + def _detect_wire_protocol(self): + endpoint = self.dhcp_handler.endpoint + if endpoint is None: + logger.info("WireServer endpoint is not found. Rerun dhcp handler") + try: + self.dhcp_handler.run() + except DhcpError as e: + raise ProtocolError(ustr(e)) + endpoint = self.dhcp_handler.endpoint + + try: + protocol = WireProtocol(endpoint) + protocol.detect() + self._set_wireserver_endpoint(endpoint) + self.save_protocol("WireProtocol") + return protocol + except ProtocolError as e: + logger.info("WireServer is not responding. Reset endpoint") + self.dhcp_handler.endpoint = None + self.dhcp_handler.skip_cache = True + raise e + + def _detect_metadata_protocol(self): + protocol = MetadataProtocol() + protocol.detect() + + #Only allow root access METADATA_ENDPOINT + self.osutil.set_admin_access_to_ip(METADATA_ENDPOINT) + + self.save_protocol("MetadataProtocol") + + return protocol + + def _detect_protocol(self, protocols): + """ + Probe protocol endpoints in turn. + """ + self.clear_protocol() + + for retry in range(0, MAX_RETRY): + for protocol in protocols: + try: + if protocol == "WireProtocol": + return self._detect_wire_protocol() + + if protocol == "MetadataProtocol": + return self._detect_metadata_protocol() + + except ProtocolError as e: + logger.info("Protocol endpoint not found: {0}, {1}", + protocol, e) + + if retry < MAX_RETRY -1: + logger.info("Retry detect protocols: retry={0}", retry) + time.sleep(PROBE_INTERVAL) + raise ProtocolNotFoundError("No protocol found.") + + def _get_protocol(self): + """ + Get protocol instance based on previous detecting result. + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), + PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + raise ProtocolNotFoundError("No protocol found") + + protocol_name = fileutil.read_file(protocol_file_path) + if protocol_name == "WireProtocol": + endpoint = self._get_wireserver_endpoint() + return WireProtocol(endpoint) + elif protocol_name == "MetadataProtocol": + return MetadataProtocol() + else: + raise ProtocolNotFoundError(("Unknown protocol: {0}" + "").format(protocol_name)) + + def save_protocol(self, protocol_name): + """ + Save protocol endpoint + """ + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + try: + fileutil.write_file(protocol_file_path, protocol_name) + except IOError as e: + logger.error("Failed to save protocol endpoint: {0}", e) + + + def clear_protocol(self): + """ + Cleanup previous saved endpoint. + """ + logger.info("Clean protocol") + self.protocol = None + protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME) + if not os.path.isfile(protocol_file_path): + return + + try: + os.remove(protocol_file_path) + except IOError as e: + logger.error("Failed to clear protocol endpoint: {0}", e) + + def get_protocol(self): + """ + Detect protocol by endpoints + + :returns: protocol instance + """ + self.lock.acquire() + + try: + if self.protocol is not None: + return self.protocol + + try: + self.protocol = self._get_protocol() + return self.protocol + except ProtocolNotFoundError: + pass + + logger.info("Detect protocol endpoints") + protocols = ["WireProtocol", "MetadataProtocol"] + self.protocol = self._detect_protocol(protocols) + + return self.protocol + + finally: + self.lock.release() + + + def get_protocol_by_file(self): + """ + Detect protocol by tag file. + + If a file "useMetadataEndpoint.tag" is found on provision iso, + metedata protocol will be used. No need to probe for wire protocol + + :returns: protocol instance + """ + self.lock.acquire() + + try: + if self.protocol is not None: + return self.protocol + + try: + self.protocol = self._get_protocol() + return self.protocol + except ProtocolNotFoundError: + pass + + logger.info("Detect protocol by file") + tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME) + protocols = [] + if os.path.isfile(tag_file_path): + protocols.append("MetadataProtocol") + else: + protocols.append("WireProtocol") + self.protocol = self._detect_protocol(protocols) + return self.protocol + + finally: + self.lock.release() diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py new file mode 100644 index 0000000..29a1663 --- /dev/null +++ b/azurelinuxagent/common/protocol/wire.py @@ -0,0 +1,1218 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import time +import xml.sax.saxutils as saxutils +import azurelinuxagent.common.conf as conf +from azurelinuxagent.common.exception import ProtocolNotFoundError +from azurelinuxagent.common.future import httpclient, bytebuffer +from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext, \ + getattrib, gettext, remove_bom, get_bytes_from_pem +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.cryptutil import CryptUtil +from azurelinuxagent.common.protocol.restapi import * +from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol + +VERSION_INFO_URI = "http://{0}/?comp=versions" +GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate" +HEALTH_REPORT_URI = "http://{0}/machine?comp=health" +ROLE_PROP_URI = "http://{0}/machine?comp=roleProperties" +TELEMETRY_URI = "http://{0}/machine?comp=telemetrydata" + +WIRE_SERVER_ADDR_FILE_NAME = "WireServer" +INCARNATION_FILE_NAME = "Incarnation" +GOAL_STATE_FILE_NAME = "GoalState.{0}.xml" +HOSTING_ENV_FILE_NAME = "HostingEnvironmentConfig.xml" +SHARED_CONF_FILE_NAME = "SharedConfig.xml" +CERTS_FILE_NAME = "Certificates.xml" +P7M_FILE_NAME = "Certificates.p7m" +PEM_FILE_NAME = "Certificates.pem" +EXT_CONF_FILE_NAME = "ExtensionsConfig.{0}.xml" +MANIFEST_FILE_NAME = "{0}.{1}.manifest.xml" +TRANSPORT_CERT_FILE_NAME = "TransportCert.pem" +TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem" + +PROTOCOL_VERSION = "2012-11-30" +ENDPOINT_FINE_NAME = "WireServer" + +SHORT_WAITING_INTERVAL = 1 # 1 second +LONG_WAITING_INTERVAL = 15 # 15 seconds + + +class UploadError(HttpError): + pass + + +class WireProtocolResourceGone(ProtocolError): + pass + + +class WireProtocol(Protocol): + """Slim layer to adapt wire protocol data to metadata protocol interface""" + + # TODO: Clean-up goal state processing + # At present, some methods magically update GoalState (e.g., get_vmagent_manifests), others (e.g., get_vmagent_pkgs) + # assume its presence. A better approach would make an explicit update call that returns the incarnation number and + # establishes that number the "context" for all other calls (either by updating the internal state of the protocol or + # by having callers pass the incarnation number to the method). + + def __init__(self, endpoint): + if endpoint is None: + raise ProtocolError("WireProtocol endpoint is None") + self.endpoint = endpoint + self.client = WireClient(self.endpoint) + + def detect(self): + self.client.check_wire_protocol_version() + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + cryptutil = CryptUtil(conf.get_openssl_cmd()) + cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file) + + self.client.update_goal_state(forced=True) + + def get_vminfo(self): + goal_state = self.client.get_goal_state() + hosting_env = self.client.get_hosting_env() + + vminfo = VMInfo() + vminfo.subscriptionId = None + vminfo.vmName = hosting_env.vm_name + vminfo.tenantName = hosting_env.deployment_name + vminfo.roleName = hosting_env.role_name + vminfo.roleInstanceName = goal_state.role_instance_id + vminfo.containerId = goal_state.container_id + return vminfo + + def get_certs(self): + certificates = self.client.get_certs() + return certificates.cert_list + + def get_vmagent_manifests(self): + # Update goal state to get latest extensions config + self.client.update_goal_state() + goal_state = self.client.get_goal_state() + ext_conf = self.client.get_ext_conf() + return ext_conf.vmagent_manifests, goal_state.incarnation + + def get_vmagent_pkgs(self, vmagent_manifest): + goal_state = self.client.get_goal_state() + man = self.client.get_gafamily_manifest(vmagent_manifest, goal_state) + return man.pkg_list + + def get_ext_handlers(self): + logger.verbose("Get extension handler config") + # Update goal state to get latest extensions config + self.client.update_goal_state() + goal_state = self.client.get_goal_state() + ext_conf = self.client.get_ext_conf() + # In wire protocol, incarnation is equivalent to ETag + return ext_conf.ext_handlers, goal_state.incarnation + + def get_ext_handler_pkgs(self, ext_handler): + logger.verbose("Get extension handler package") + goal_state = self.client.get_goal_state() + man = self.client.get_ext_manifest(ext_handler, goal_state) + return man.pkg_list + + def report_provision_status(self, provision_status): + validate_param("provision_status", provision_status, ProvisionStatus) + + if provision_status.status is not None: + self.client.report_health(provision_status.status, + provision_status.subStatus, + provision_status.description) + if provision_status.properties.certificateThumbprint is not None: + thumbprint = provision_status.properties.certificateThumbprint + self.client.report_role_prop(thumbprint) + + def report_vm_status(self, vm_status): + validate_param("vm_status", vm_status, VMStatus) + self.client.status_blob.set_vm_status(vm_status) + self.client.upload_status_blob() + + def report_ext_status(self, ext_handler_name, ext_name, ext_status): + validate_param("ext_status", ext_status, ExtensionStatus) + self.client.status_blob.set_ext_status(ext_handler_name, ext_status) + + def report_event(self, events): + validate_param("events", events, TelemetryEventList) + self.client.report_event(events) + + +def _build_role_properties(container_id, role_instance_id, thumbprint): + xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>" + u"<RoleProperties>" + u"<Container>" + u"<ContainerId>{0}</ContainerId>" + u"<RoleInstances>" + u"<RoleInstance>" + u"<Id>{1}</Id>" + u"<Properties>" + u"<Property name=\"CertificateThumbprint\" value=\"{2}\" />" + u"</Properties>" + u"</RoleInstance>" + u"</RoleInstances>" + u"</Container>" + u"</RoleProperties>" + u"").format(container_id, role_instance_id, thumbprint) + return xml + + +def _build_health_report(incarnation, container_id, role_instance_id, + status, substatus, description): + # Escape '&', '<' and '>' + description = saxutils.escape(ustr(description)) + detail = u'' + if substatus is not None: + substatus = saxutils.escape(ustr(substatus)) + detail = (u"<Details>" + u"<SubStatus>{0}</SubStatus>" + u"<Description>{1}</Description>" + u"</Details>").format(substatus, description) + xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>" + u"<Health " + u"xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"" + u" xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\">" + u"<GoalStateIncarnation>{0}</GoalStateIncarnation>" + u"<Container>" + u"<ContainerId>{1}</ContainerId>" + u"<RoleInstanceList>" + u"<Role>" + u"<InstanceId>{2}</InstanceId>" + u"<Health>" + u"<State>{3}</State>" + u"{4}" + u"</Health>" + u"</Role>" + u"</RoleInstanceList>" + u"</Container>" + u"</Health>" + u"").format(incarnation, + container_id, + role_instance_id, + status, + detail) + return xml + + +""" +Convert VMStatus object to status blob format +""" + + +def ga_status_to_v1(ga_status): + formatted_msg = { + 'lang': 'en-US', + 'message': ga_status.message + } + v1_ga_status = { + 'version': ga_status.version, + 'status': ga_status.status, + 'formattedMessage': formatted_msg + } + return v1_ga_status + + +def ext_substatus_to_v1(sub_status_list): + status_list = [] + for substatus in sub_status_list: + status = { + "name": substatus.name, + "status": substatus.status, + "code": substatus.code, + "formattedMessage": { + "lang": "en-US", + "message": substatus.message + } + } + status_list.append(status) + return status_list + + +def ext_status_to_v1(ext_name, ext_status): + if ext_status is None: + return None + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + v1_sub_status = ext_substatus_to_v1(ext_status.substatusList) + v1_ext_status = { + "status": { + "name": ext_name, + "configurationAppliedTime": ext_status.configurationAppliedTime, + "operation": ext_status.operation, + "status": ext_status.status, + "code": ext_status.code, + "formattedMessage": { + "lang": "en-US", + "message": ext_status.message + } + }, + "version": 1.0, + "timestampUTC": timestamp + } + if len(v1_sub_status) != 0: + v1_ext_status['substatus'] = v1_sub_status + return v1_ext_status + + +def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp): + v1_handler_status = { + 'handlerVersion': handler_status.version, + 'handlerName': handler_status.name, + 'status': handler_status.status, + 'code': handler_status.code + } + if handler_status.message is not None: + v1_handler_status["formattedMessage"] = { + "lang": "en-US", + "message": handler_status.message + } + + if len(handler_status.extensions) > 0: + # Currently, no more than one extension per handler + ext_name = handler_status.extensions[0] + ext_status = ext_statuses.get(ext_name) + v1_ext_status = ext_status_to_v1(ext_name, ext_status) + if ext_status is not None and v1_ext_status is not None: + v1_handler_status["runtimeSettingsStatus"] = { + 'settingsStatus': v1_ext_status, + 'sequenceNumber': ext_status.sequenceNumber + } + return v1_handler_status + + +def vm_status_to_v1(vm_status, ext_statuses): + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + v1_ga_status = ga_status_to_v1(vm_status.vmAgent) + v1_handler_status_list = [] + for handler_status in vm_status.vmAgent.extensionHandlers: + v1_handler_status = ext_handler_status_to_v1(handler_status, + ext_statuses, timestamp) + if v1_handler_status is not None: + v1_handler_status_list.append(v1_handler_status) + + v1_agg_status = { + 'guestAgentStatus': v1_ga_status, + 'handlerAggregateStatus': v1_handler_status_list + } + v1_vm_status = { + 'version': '1.0', + 'timestampUTC': timestamp, + 'aggregateStatus': v1_agg_status + } + return v1_vm_status + + +class StatusBlob(object): + def __init__(self, client): + self.vm_status = None + self.ext_statuses = {} + self.client = client + self.type = None + self.data = None + + def set_vm_status(self, vm_status): + validate_param("vmAgent", vm_status, VMStatus) + self.vm_status = vm_status + + def set_ext_status(self, ext_handler_name, ext_status): + validate_param("extensionStatus", ext_status, ExtensionStatus) + self.ext_statuses[ext_handler_name] = ext_status + + def to_json(self): + report = vm_status_to_v1(self.vm_status, self.ext_statuses) + return json.dumps(report) + + __storage_version__ = "2014-02-14" + + def upload(self, url): + # TODO upload extension only if content has changed + logger.verbose("Upload status blob") + upload_successful = False + self.type = self.get_blob_type(url) + self.data = self.to_json() + try: + if self.type == "BlockBlob": + self.put_block_blob(url, self.data) + elif self.type == "PageBlob": + self.put_page_blob(url, self.data) + else: + raise ProtocolError("Unknown blob type: {0}".format(self.type)) + except HttpError as e: + logger.warn("Initial upload failed [{0}]".format(e)) + else: + logger.verbose("Uploading status blob succeeded") + upload_successful = True + return upload_successful + + def get_blob_type(self, url): + # Check blob type + logger.verbose("Check blob type.") + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + try: + resp = self.client.call_storage_service(restutil.http_head, url, { + "x-ms-date": timestamp, + 'x-ms-version': self.__class__.__storage_version__ + }) + except HttpError as e: + raise ProtocolError((u"Failed to get status blob type: {0}" + u"").format(e)) + if resp is None or resp.status != httpclient.OK: + raise ProtocolError(("Failed to get status blob type: {0}" + "").format(resp.status)) + + blob_type = resp.getheader("x-ms-blob-type") + logger.verbose("Blob type={0}".format(blob_type)) + return blob_type + + def put_block_blob(self, url, data): + logger.verbose("Upload block blob") + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + resp = self.client.call_storage_service(restutil.http_put, url, data, + { + "x-ms-date": timestamp, + "x-ms-blob-type": "BlockBlob", + "Content-Length": ustr(len(data)), + "x-ms-version": self.__class__.__storage_version__ + }) + if resp.status != httpclient.CREATED: + raise UploadError( + "Failed to upload block blob: {0}".format(resp.status)) + + def put_page_blob(self, url, data): + logger.verbose("Replace old page blob") + + # Convert string into bytes + data = bytearray(data, encoding='utf-8') + timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + + # Align to 512 bytes + page_blob_size = int((len(data) + 511) / 512) * 512 + resp = self.client.call_storage_service(restutil.http_put, url, "", + { + "x-ms-date": timestamp, + "x-ms-blob-type": "PageBlob", + "Content-Length": "0", + "x-ms-blob-content-length": ustr(page_blob_size), + "x-ms-version": self.__class__.__storage_version__ + }) + if resp.status != httpclient.CREATED: + raise UploadError( + "Failed to clean up page blob: {0}".format(resp.status)) + + if url.count("?") < 0: + url = "{0}?comp=page".format(url) + else: + url = "{0}&comp=page".format(url) + + logger.verbose("Upload page blob") + page_max = 4 * 1024 * 1024 # Max page size: 4MB + start = 0 + end = 0 + while end < len(data): + end = min(len(data), start + page_max) + content_size = end - start + # Align to 512 bytes + page_end = int((end + 511) / 512) * 512 + buf_size = page_end - start + buf = bytearray(buf_size) + buf[0: content_size] = data[start: end] + resp = self.client.call_storage_service( + restutil.http_put, url, bytebuffer(buf), + { + "x-ms-date": timestamp, + "x-ms-range": "bytes={0}-{1}".format(start, page_end - 1), + "x-ms-page-write": "update", + "x-ms-version": self.__class__.__storage_version__, + "Content-Length": ustr(page_end - start) + }) + if resp is None or resp.status != httpclient.CREATED: + raise UploadError( + "Failed to upload page blob: {0}".format(resp.status)) + start = end + + +def event_param_to_v1(param): + param_format = '<Param Name="{0}" Value={1} T="{2}" />' + param_type = type(param.value) + attr_type = "" + if param_type is int: + attr_type = 'mt:uint64' + elif param_type is str: + attr_type = 'mt:wstr' + elif ustr(param_type).count("'unicode'") > 0: + attr_type = 'mt:wstr' + elif param_type is bool: + attr_type = 'mt:bool' + elif param_type is float: + attr_type = 'mt:float64' + return param_format.format(param.name, saxutils.quoteattr(ustr(param.value)), + attr_type) + + +def event_to_v1(event): + params = "" + for param in event.parameters: + params += event_param_to_v1(param) + event_str = ('<Event id="{0}">' + '<![CDATA[{1}]]>' + '</Event>').format(event.eventId, params) + return event_str + + +class WireClient(object): + def __init__(self, endpoint): + logger.info("Wire server endpoint:{0}", endpoint) + self.endpoint = endpoint + self.goal_state = None + self.updated = None + self.hosting_env = None + self.shared_conf = None + self.certs = None + self.ext_conf = None + self.last_request = 0 + self.req_count = 0 + self.status_blob = StatusBlob(self) + self.host_plugin = HostPluginProtocol(self.endpoint) + + def prevent_throttling(self): + """ + Try to avoid throttling of wire server + """ + now = time.time() + if now - self.last_request < 1: + logger.verbose("Last request issued less than 1 second ago") + logger.verbose("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.last_request = now + + self.req_count += 1 + if self.req_count % 3 == 0: + logger.verbose("Sleep {0} second to avoid throttling.", + SHORT_WAITING_INTERVAL) + time.sleep(SHORT_WAITING_INTERVAL) + self.req_count = 0 + + def call_wireserver(self, http_req, *args, **kwargs): + """ + Call wire server. Handle throttling(403) and Resource Gone(410) + """ + self.prevent_throttling() + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.FORBIDDEN: + logger.warn("Sending too much request to wire server") + logger.info("Sleep {0} second to avoid throttling.", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + elif resp.status == httpclient.GONE: + msg = args[0] if len(args) > 0 else "" + raise WireProtocolResourceGone(msg) + else: + return resp + raise ProtocolError(("Calling wire server failed: {0}" + "").format(resp.status)) + + def decode_config(self, data): + if data is None: + return None + data = remove_bom(data) + xml_text = ustr(data, encoding='utf-8') + return xml_text + + def fetch_config(self, uri, headers): + try: + resp = self.call_wireserver(restutil.http_get, uri, + headers=headers) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if (resp.status != httpclient.OK): + raise ProtocolError("{0} - {1}".format(resp.status, uri)) + + return self.decode_config(resp.read()) + + def fetch_cache(self, local_file): + if not os.path.isfile(local_file): + raise ProtocolError("{0} is missing.".format(local_file)) + try: + return fileutil.read_file(local_file) + except IOError as e: + raise ProtocolError("Failed to read cache: {0}".format(e)) + + def save_cache(self, local_file, data): + try: + fileutil.write_file(local_file, data) + except IOError as e: + raise ProtocolError("Failed to write cache: {0}".format(e)) + + def call_storage_service(self, http_req, *args, **kwargs): + """ + Call storage service, handle SERVICE_UNAVAILABLE(503) + """ + for retry in range(0, 3): + resp = http_req(*args, **kwargs) + if resp.status == httpclient.SERVICE_UNAVAILABLE: + logger.warn("Storage service is not avaible temporaryly") + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + else: + return resp + raise ProtocolError(("Calling storage endpoint failed: {0}" + "").format(resp.status)) + + def fetch_manifest(self, version_uris): + for version_uri in version_uris: + logger.verbose("Fetch ext handler manifest: {0}", version_uri.uri) + try: + resp = self.call_storage_service(restutil.http_get, + version_uri.uri, None, + chk_proxy=True) + except HttpError as e: + raise ProtocolError(ustr(e)) + + if resp.status == httpclient.OK: + return self.decode_config(resp.read()) + logger.warn("Failed to fetch ExtensionManifest: {0}, {1}", + resp.status, version_uri.uri) + logger.info("Will retry later, in {0} seconds", + LONG_WAITING_INTERVAL) + time.sleep(LONG_WAITING_INTERVAL) + raise ProtocolError(("Failed to fetch ExtensionManifest from " + "all sources")) + + def update_hosting_env(self, goal_state): + if goal_state.hosting_env_uri is None: + raise ProtocolError("HostingEnvironmentConfig uri is empty") + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_config(goal_state.hosting_env_uri, + self.get_header()) + self.save_cache(local_file, xml_text) + self.hosting_env = HostingEnv(xml_text) + + def update_shared_conf(self, goal_state): + if goal_state.shared_conf_uri is None: + raise ProtocolError("SharedConfig uri is empty") + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_config(goal_state.shared_conf_uri, + self.get_header()) + self.save_cache(local_file, xml_text) + self.shared_conf = SharedConfig(xml_text) + + def update_certs(self, goal_state): + if goal_state.certs_uri is None: + return + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_config(goal_state.certs_uri, + self.get_header_for_cert()) + self.save_cache(local_file, xml_text) + self.certs = Certificates(self, xml_text) + + def update_ext_conf(self, goal_state): + if goal_state.ext_uri is None: + logger.info("ExtensionsConfig.xml uri is empty") + self.ext_conf = ExtensionsConfig(None) + return + incarnation = goal_state.incarnation + local_file = os.path.join(conf.get_lib_dir(), + EXT_CONF_FILE_NAME.format(incarnation)) + xml_text = self.fetch_config(goal_state.ext_uri, self.get_header()) + self.save_cache(local_file, xml_text) + self.ext_conf = ExtensionsConfig(xml_text) + + def update_goal_state(self, forced=False, max_retry=3): + uri = GOAL_STATE_URI.format(self.endpoint) + xml_text = self.fetch_config(uri, self.get_header()) + goal_state = GoalState(xml_text) + + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + + if not forced: + last_incarnation = None + if (os.path.isfile(incarnation_file)): + last_incarnation = fileutil.read_file(incarnation_file) + new_incarnation = goal_state.incarnation + if last_incarnation is not None and \ + last_incarnation == new_incarnation: + # Goalstate is not updated. + return + + # Start updating goalstate, retry on 410 + for retry in range(0, max_retry): + try: + self.goal_state = goal_state + file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) + self.save_cache(goal_state_file, xml_text) + self.save_cache(incarnation_file, goal_state.incarnation) + self.update_hosting_env(goal_state) + self.update_shared_conf(goal_state) + self.update_certs(goal_state) + self.update_ext_conf(goal_state) + return + except WireProtocolResourceGone: + logger.info("Incarnation is out of date. Update goalstate.") + xml_text = self.fetch_config(uri, self.get_header()) + goal_state = GoalState(xml_text) + + raise ProtocolError("Exceeded max retry updating goal state") + + def get_goal_state(self): + if (self.goal_state is None): + incarnation_file = os.path.join(conf.get_lib_dir(), + INCARNATION_FILE_NAME) + incarnation = self.fetch_cache(incarnation_file) + + file_name = GOAL_STATE_FILE_NAME.format(incarnation) + goal_state_file = os.path.join(conf.get_lib_dir(), file_name) + xml_text = self.fetch_cache(goal_state_file) + self.goal_state = GoalState(xml_text) + return self.goal_state + + def get_hosting_env(self): + if (self.hosting_env is None): + local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.hosting_env = HostingEnv(xml_text) + return self.hosting_env + + def get_shared_conf(self): + if (self.shared_conf is None): + local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.shared_conf = SharedConfig(xml_text) + return self.shared_conf + + def get_certs(self): + if (self.certs is None): + local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME) + xml_text = self.fetch_cache(local_file) + self.certs = Certificates(self, xml_text) + if self.certs is None: + return None + return self.certs + + def get_ext_conf(self): + if (self.ext_conf is None): + goal_state = self.get_goal_state() + if goal_state.ext_uri is None: + self.ext_conf = ExtensionsConfig(None) + else: + local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_cache(local_file) + self.ext_conf = ExtensionsConfig(xml_text) + return self.ext_conf + + def get_ext_manifest(self, ext_handler, goal_state): + local_file = MANIFEST_FILE_NAME.format(ext_handler.name, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(ext_handler.versionUris) + self.save_cache(local_file, xml_text) + return ExtensionManifest(xml_text) + + def get_gafamily_manifest(self, vmagent_manifest, goal_state): + local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family, + goal_state.incarnation) + local_file = os.path.join(conf.get_lib_dir(), local_file) + xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris) + fileutil.write_file(local_file, xml_text) + return ExtensionManifest(xml_text) + + def check_wire_protocol_version(self): + uri = VERSION_INFO_URI.format(self.endpoint) + version_info_xml = self.fetch_config(uri, None) + version_info = VersionInfo(version_info_xml) + + preferred = version_info.get_preferred() + if PROTOCOL_VERSION == preferred: + logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) + elif PROTOCOL_VERSION in version_info.get_supported(): + logger.info("Wire protocol version:{0}", PROTOCOL_VERSION) + logger.warn("Server prefered version:{0}", preferred) + else: + error = ("Agent supported wire protocol version: {0} was not " + "advised by Fabric.").format(PROTOCOL_VERSION) + raise ProtocolNotFoundError(error) + + def upload_status_blob(self): + ext_conf = self.get_ext_conf() + if ext_conf.status_upload_blob is not None: + if not self.status_blob.upload(ext_conf.status_upload_blob): + self.host_plugin.put_vm_status(self.status_blob, + ext_conf.status_upload_blob) + + def report_role_prop(self, thumbprint): + goal_state = self.get_goal_state() + role_prop = _build_role_properties(goal_state.container_id, + goal_state.role_instance_id, + thumbprint) + role_prop = role_prop.encode("utf-8") + role_prop_uri = ROLE_PROP_URI.format(self.endpoint) + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, role_prop_uri, + role_prop, headers=headers) + except HttpError as e: + raise ProtocolError((u"Failed to send role properties: {0}" + u"").format(e)) + if resp.status != httpclient.ACCEPTED: + raise ProtocolError((u"Failed to send role properties: {0}" + u", {1}").format(resp.status, resp.read())) + + def report_health(self, status, substatus, description): + goal_state = self.get_goal_state() + health_report = _build_health_report(goal_state.incarnation, + goal_state.container_id, + goal_state.role_instance_id, + status, + substatus, + description) + health_report = health_report.encode("utf-8") + health_report_uri = HEALTH_REPORT_URI.format(self.endpoint) + headers = self.get_header_for_xml_content() + try: + resp = self.call_wireserver(restutil.http_post, health_report_uri, + health_report, headers=headers, max_retry=8) + except HttpError as e: + raise ProtocolError((u"Failed to send provision status: {0}" + u"").format(e)) + if resp.status != httpclient.OK: + raise ProtocolError((u"Failed to send provision status: {0}" + u", {1}").format(resp.status, resp.read())) + + def send_event(self, provider_id, event_str): + uri = TELEMETRY_URI.format(self.endpoint) + data_format = ('<?xml version="1.0"?>' + '<TelemetryData version="1.0">' + '<Provider id="{0}">{1}' + '</Provider>' + '</TelemetryData>') + data = data_format.format(provider_id, event_str) + try: + header = self.get_header_for_xml_content() + resp = self.call_wireserver(restutil.http_post, uri, data, header) + except HttpError as e: + raise ProtocolError("Failed to send events:{0}".format(e)) + + if resp.status != httpclient.OK: + logger.verbose(resp.read()) + raise ProtocolError("Failed to send events:{0}".format(resp.status)) + + def report_event(self, event_list): + buf = {} + # Group events by providerId + for event in event_list.events: + if event.providerId not in buf: + buf[event.providerId] = "" + event_str = event_to_v1(event) + if len(event_str) >= 63 * 1024: + logger.warn("Single event too large: {0}", event_str[300:]) + continue + if len(buf[event.providerId] + event_str) >= 63 * 1024: + self.send_event(event.providerId, buf[event.providerId]) + buf[event.providerId] = "" + buf[event.providerId] = buf[event.providerId] + event_str + + # Send out all events left in buffer. + for provider_id in list(buf.keys()): + if len(buf[provider_id]) > 0: + self.send_event(provider_id, buf[provider_id]) + + def get_header(self): + return { + "x-ms-agent-name": "WALinuxAgent", + "x-ms-version": PROTOCOL_VERSION + } + + def get_header_for_xml_content(self): + return { + "x-ms-agent-name": "WALinuxAgent", + "x-ms-version": PROTOCOL_VERSION, + "Content-Type": "text/xml;charset=utf-8" + } + + def get_header_for_cert(self): + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + content = self.fetch_cache(trans_cert_file) + cert = get_bytes_from_pem(content) + return { + "x-ms-agent-name": "WALinuxAgent", + "x-ms-version": PROTOCOL_VERSION, + "x-ms-cipher-name": "DES_EDE3_CBC", + "x-ms-guest-agent-public-x509-cert": cert + } + +class VersionInfo(object): + def __init__(self, xml_text): + """ + Query endpoint server for wire protocol version. + Fail if our desired protocol version is not seen. + """ + logger.verbose("Load Version.xml") + self.parse(xml_text) + + def parse(self, xml_text): + xml_doc = parse_doc(xml_text) + preferred = find(xml_doc, "Preferred") + self.preferred = findtext(preferred, "Version") + logger.info("Fabric preferred wire protocol version:{0}", self.preferred) + + self.supported = [] + supported = find(xml_doc, "Supported") + supported_version = findall(supported, "Version") + for node in supported_version: + version = gettext(node) + logger.verbose("Fabric supported wire protocol version:{0}", version) + self.supported.append(version) + + def get_preferred(self): + return self.preferred + + def get_supported(self): + return self.supported + + +class GoalState(object): + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("GoalState.xml is None") + logger.verbose("Load GoalState.xml") + self.incarnation = None + self.expected_state = None + self.hosting_env_uri = None + self.shared_conf_uri = None + self.certs_uri = None + self.ext_uri = None + self.role_instance_id = None + self.container_id = None + self.load_balancer_probe_port = None + self.parse(xml_text) + + def parse(self, xml_text): + """ + Request configuration data from endpoint server. + """ + self.xml_text = xml_text + xml_doc = parse_doc(xml_text) + self.incarnation = findtext(xml_doc, "Incarnation") + self.expected_state = findtext(xml_doc, "ExpectedState") + self.hosting_env_uri = findtext(xml_doc, "HostingEnvironmentConfig") + self.shared_conf_uri = findtext(xml_doc, "SharedConfig") + self.certs_uri = findtext(xml_doc, "Certificates") + self.ext_uri = findtext(xml_doc, "ExtensionsConfig") + role_instance = find(xml_doc, "RoleInstance") + self.role_instance_id = findtext(role_instance, "InstanceId") + container = find(xml_doc, "Container") + self.container_id = findtext(container, "ContainerId") + lbprobe_ports = find(xml_doc, "LBProbePorts") + self.load_balancer_probe_port = findtext(lbprobe_ports, "Port") + return self + + +class HostingEnv(object): + """ + parse Hosting enviromnet config and store in + HostingEnvironmentConfig.xml + """ + + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("HostingEnvironmentConfig.xml is None") + logger.verbose("Load HostingEnvironmentConfig.xml") + self.vm_name = None + self.role_name = None + self.deployment_name = None + self.parse(xml_text) + + def parse(self, xml_text): + """ + parse and create HostingEnvironmentConfig.xml. + """ + self.xml_text = xml_text + xml_doc = parse_doc(xml_text) + incarnation = find(xml_doc, "Incarnation") + self.vm_name = getattrib(incarnation, "instance") + role = find(xml_doc, "Role") + self.role_name = getattrib(role, "name") + deployment = find(xml_doc, "Deployment") + self.deployment_name = getattrib(deployment, "name") + return self + + +class SharedConfig(object): + """ + parse role endpoint server and goal state config. + """ + + def __init__(self, xml_text): + logger.verbose("Load SharedConfig.xml") + self.parse(xml_text) + + def parse(self, xml_text): + """ + parse and write configuration to file SharedConfig.xml. + """ + # Not used currently + return self + +class Certificates(object): + """ + Object containing certificates of host and provisioned user. + """ + + def __init__(self, client, xml_text): + logger.verbose("Load Certificates.xml") + self.client = client + self.cert_list = CertList() + self.parse(xml_text) + + def parse(self, xml_text): + """ + Parse multiple certificates into seperate files. + """ + xml_doc = parse_doc(xml_text) + data = findtext(xml_doc, "Data") + if data is None: + return + + cryptutil = CryptUtil(conf.get_openssl_cmd()) + p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME) + p7m = ("MIME-Version:1.0\n" + "Content-Disposition: attachment; filename=\"{0}\"\n" + "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n" + "Content-Transfer-Encoding: base64\n" + "\n" + "{2}").format(p7m_file, p7m_file, data) + + self.client.save_cache(p7m_file, p7m) + + trans_prv_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_PRV_FILE_NAME) + trans_cert_file = os.path.join(conf.get_lib_dir(), + TRANSPORT_CERT_FILE_NAME) + pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME) + # decrypt certificates + cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file, + pem_file) + + # The parsing process use public key to match prv and crt. + buf = [] + begin_crt = False + begin_prv = False + prvs = {} + thumbprints = {} + index = 0 + v1_cert_list = [] + with open(pem_file) as pem: + for line in pem.readlines(): + buf.append(line) + if re.match(r'[-]+BEGIN.*KEY[-]+', line): + begin_prv = True + elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line): + begin_crt = True + elif re.match(r'[-]+END.*KEY[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'prv', buf) + pub = cryptutil.get_pubkey_from_prv(tmp_file) + prvs[pub] = tmp_file + buf = [] + index += 1 + begin_prv = False + elif re.match(r'[-]+END.*CERTIFICATE[-]+', line): + tmp_file = self.write_to_tmp_file(index, 'crt', buf) + pub = cryptutil.get_pubkey_from_crt(tmp_file) + thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file) + thumbprints[pub] = thumbprint + # Rename crt with thumbprint as the file name + crt = "{0}.crt".format(thumbprint) + v1_cert_list.append({ + "name": None, + "thumbprint": thumbprint + }) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt)) + buf = [] + index += 1 + begin_crt = False + + # Rename prv key with thumbprint as the file name + for pubkey in prvs: + thumbprint = thumbprints[pubkey] + if thumbprint: + tmp_file = prvs[pubkey] + prv = "{0}.prv".format(thumbprint) + os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv)) + + for v1_cert in v1_cert_list: + cert = Cert() + set_properties("certs", cert, v1_cert) + self.cert_list.certificates.append(cert) + + def write_to_tmp_file(self, index, suffix, buf): + file_name = os.path.join(conf.get_lib_dir(), + "{0}.{1}".format(index, suffix)) + self.client.save_cache(file_name, "".join(buf)) + return file_name + + +class ExtensionsConfig(object): + """ + parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent. + Install if <enabled>true</enabled>, remove if it is set to false. + """ + + def __init__(self, xml_text): + logger.verbose("Load ExtensionsConfig.xml") + self.ext_handlers = ExtHandlerList() + self.vmagent_manifests = VMAgentManifestList() + self.status_upload_blob = None + if xml_text is not None: + self.parse(xml_text) + + def parse(self, xml_text): + """ + Write configuration to file ExtensionsConfig.xml. + """ + xml_doc = parse_doc(xml_text) + + ga_families_list = find(xml_doc, "GAFamilies") + ga_families = findall(ga_families_list, "GAFamily") + + for ga_family in ga_families: + family = findtext(ga_family, "Name") + uris_list = find(ga_family, "Uris") + uris = findall(uris_list, "Uri") + manifest = VMAgentManifest() + manifest.family = family + for uri in uris: + manifestUri = VMAgentManifestUri(uri=gettext(uri)) + manifest.versionsManifestUris.append(manifestUri) + self.vmagent_manifests.vmAgentManifests.append(manifest) + + plugins_list = find(xml_doc, "Plugins") + plugins = findall(plugins_list, "Plugin") + plugin_settings_list = find(xml_doc, "PluginSettings") + plugin_settings = findall(plugin_settings_list, "Plugin") + + for plugin in plugins: + ext_handler = self.parse_plugin(plugin) + self.ext_handlers.extHandlers.append(ext_handler) + self.parse_plugin_settings(ext_handler, plugin_settings) + + self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob") + + def parse_plugin(self, plugin): + ext_handler = ExtHandler() + ext_handler.name = getattrib(plugin, "name") + ext_handler.properties.version = getattrib(plugin, "version") + ext_handler.properties.state = getattrib(plugin, "state") + + auto_upgrade = getattrib(plugin, "autoUpgrade") + if auto_upgrade is not None and auto_upgrade.lower() == "true": + ext_handler.properties.upgradePolicy = "auto" + else: + ext_handler.properties.upgradePolicy = "manual" + + location = getattrib(plugin, "location") + failover_location = getattrib(plugin, "failoverlocation") + for uri in [location, failover_location]: + version_uri = ExtHandlerVersionUri() + version_uri.uri = uri + ext_handler.versionUris.append(version_uri) + return ext_handler + + def parse_plugin_settings(self, ext_handler, plugin_settings): + if plugin_settings is None: + return + + name = ext_handler.name + version = ext_handler.properties.version + settings = [x for x in plugin_settings \ + if getattrib(x, "name") == name and \ + getattrib(x, "version") == version] + + if settings is None or len(settings) == 0: + return + + runtime_settings = None + runtime_settings_node = find(settings[0], "RuntimeSettings") + seqNo = getattrib(runtime_settings_node, "seqNo") + runtime_settings_str = gettext(runtime_settings_node) + try: + runtime_settings = json.loads(runtime_settings_str) + except ValueError as e: + logger.error("Invalid extension settings") + return + + for plugin_settings_list in runtime_settings["runtimeSettings"]: + handler_settings = plugin_settings_list["handlerSettings"] + ext = Extension() + # There is no "extension name" in wire protocol. + # Put + ext.name = ext_handler.name + ext.sequenceNumber = seqNo + ext.publicSettings = handler_settings.get("publicSettings") + ext.protectedSettings = handler_settings.get("protectedSettings") + thumbprint = handler_settings.get("protectedSettingsCertThumbprint") + ext.certificateThumbprint = thumbprint + ext_handler.properties.extensions.append(ext) + + +class ExtensionManifest(object): + def __init__(self, xml_text): + if xml_text is None: + raise ValueError("ExtensionManifest is None") + logger.verbose("Load ExtensionManifest.xml") + self.pkg_list = ExtHandlerPackageList() + self.parse(xml_text) + + def parse(self, xml_text): + xml_doc = parse_doc(xml_text) + self._handle_packages(findall(find(xml_doc, "Plugins"), "Plugin"), False) + self._handle_packages(findall(find(xml_doc, "InternalPlugins"), "Plugin"), True) + + def _handle_packages(self, packages, isinternal): + for package in packages: + version = findtext(package, "Version") + + disallow_major_upgrade = findtext(package, "DisallowMajorVersionUpgrade") + if disallow_major_upgrade is None: + disallow_major_upgrade = '' + disallow_major_upgrade = disallow_major_upgrade.lower() == "true" + + uris = find(package, "Uris") + uri_list = findall(uris, "Uri") + uri_list = [gettext(x) for x in uri_list] + pkg = ExtHandlerPackage() + pkg.version = version + pkg.disallow_major_upgrade = disallow_major_upgrade + for uri in uri_list: + pkg_uri = ExtHandlerVersionUri() + pkg_uri.uri = uri + pkg.uris.append(pkg_uri) + + pkg.isinternal = isinternal + self.pkg_list.versions.append(pkg) diff --git a/azurelinuxagent/common/rdma.py b/azurelinuxagent/common/rdma.py new file mode 100644 index 0000000..0c17e38 --- /dev/null +++ b/azurelinuxagent/common/rdma.py @@ -0,0 +1,280 @@ +# Windows Azure Linux Agent +# +# Copyright 2016 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Handle packages and modules to enable RDMA for IB networking +""" + +import os +import re +import time +import threading + +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +import azurelinuxagent.common.utils.fileutil as fileutil +import azurelinuxagent.common.utils.shellutil as shellutil +from azurelinuxagent.common.utils.textutil import parse_doc, find, getattrib + + +from azurelinuxagent.common.protocol.wire import SHARED_CONF_FILE_NAME + +dapl_config_paths = [ + '/etc/dat.conf', + '/etc/rdma/dat.conf', + '/usr/local/etc/dat.conf' +] + +def setup_rdma_device(): + logger.verbose("Parsing SharedConfig XML contents for RDMA details") + xml_doc = parse_doc( + fileutil.read_file(os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME))) + if xml_doc is None: + logger.error("Could not parse SharedConfig XML document") + return + instance_elem = find(xml_doc, "Instance") + if not instance_elem: + logger.error("Could not find <Instance> in SharedConfig document") + return + + rdma_ipv4_addr = getattrib(instance_elem, "rdmaIPv4Address") + if not rdma_ipv4_addr: + logger.error( + "Could not find rdmaIPv4Address attribute on Instance element of SharedConfig.xml document") + return + + rdma_mac_addr = getattrib(instance_elem, "rdmaMacAddress") + if not rdma_mac_addr: + logger.error( + "Could not find rdmaMacAddress attribute on Instance element of SharedConfig.xml document") + return + + # add colons to the MAC address (e.g. 00155D33FF1D -> + # 00:15:5D:33:FF:1D) + rdma_mac_addr = ':'.join([rdma_mac_addr[i:i+2] + for i in range(0, len(rdma_mac_addr), 2)]) + logger.info("Found RDMA details. IPv4={0} MAC={1}".format( + rdma_ipv4_addr, rdma_mac_addr)) + + # Set up the RDMA device with collected informatino + RDMADeviceHandler(rdma_ipv4_addr, rdma_mac_addr).start() + logger.info("RDMA: device is set up") + return + +class RDMAHandler(object): + + driver_module_name = 'hv_network_direct' + + @staticmethod + def get_rdma_version(): + """Retrieve the firmware version information from the system. + This depends on information provided by the Linux kernel.""" + + driver_info_source = '/var/lib/hyperv/.kvp_pool_0' + base_kernel_err_msg = 'Kernel does not provide the necessary ' + base_kernel_err_msg += 'information or the hv_kvp_daemon is not ' + base_kernel_err_msg += 'running.' + if not os.path.isfile(driver_info_source): + error_msg = 'RDMA: Source file "%s" does not exist. ' + error_msg += base_kernel_err_msg + logger.error(error_msg % driver_info_source) + return + + lines = open(driver_info_source).read() + if not lines: + error_msg = 'RDMA: Source file "%s" is empty. ' + error_msg += base_kernel_err_msg + logger.error(error_msg % driver_info_source) + return + + r = re.search("NdDriverVersion\0+(\d\d\d\.\d)", lines) + if r: + NdDriverVersion = r.groups()[0] + return NdDriverVersion + else: + error_msg = 'RDMA: NdDriverVersion not found in "%s"' + logger.error(error_msg % driver_info_source) + return + + def load_driver_module(self): + """Load the kernel driver, this depends on the proper driver + to be installed with the install_driver() method""" + logger.info("RDMA: probing module '%s'" % self.driver_module_name) + result = shellutil.run('modprobe %s' % self.driver_module_name) + if result != 0: + error_msg = 'Could not load "%s" kernel module. ' + error_msg += 'Run "modprobe %s" as root for more details' + logger.error( + error_msg % (self.driver_module_name, self.driver_module_name) + ) + return + logger.info('RDMA: Loaded the kernel driver successfully.') + return True + + def install_driver(self): + """Install the driver. This is distribution specific and must + be overwritten in the child implementation.""" + logger.error('RDMAHandler.install_driver not implemented') + + def is_driver_loaded(self): + """Check if the network module is loaded in kernel space""" + cmd = 'lsmod | grep ^%s' % self.driver_module_name + status, loaded_modules = shellutil.run_get_output(cmd) + logger.info('RDMA: Checking if the module loaded.') + if loaded_modules: + logger.info('RDMA: module loaded.') + return True + logger.info('RDMA: module not loaded.') + + def reboot_system(self): + """Reboot the system. This is required as the kernel module for + the rdma driver cannot be unloaded with rmmod""" + logger.info('RDMA: Rebooting system.') + ret = shellutil.run('shutdown -r now') + if ret != 0: + logger.error('RDMA: Failed to reboot the system') + + +dapl_config_paths = [ + '/etc/dat.conf', '/etc/rdma/dat.conf', '/usr/local/etc/dat.conf'] + +class RDMADeviceHandler(object): + + """ + Responsible for writing RDMA IP and MAC address to the /dev/hvnd_rdma + interface. + """ + + rdma_dev = '/dev/hvnd_rdma' + device_check_timeout_sec = 120 + device_check_interval_sec = 1 + + ipv4_addr = None + mac_adr = None + + def __init__(self, ipv4_addr, mac_addr): + self.ipv4_addr = ipv4_addr + self.mac_addr = mac_addr + + def start(self): + """ + Start a thread in the background to process the RDMA tasks and returns. + """ + logger.info("RDMA: starting device processing in the background.") + threading.Thread(target=self.process).start() + + def process(self): + RDMADeviceHandler.wait_rdma_device( + self.rdma_dev, self.device_check_timeout_sec, self.device_check_interval_sec) + RDMADeviceHandler.update_dat_conf(dapl_config_paths, self.ipv4_addr) + RDMADeviceHandler.write_rdma_config_to_device( + self.rdma_dev, self.ipv4_addr, self.mac_addr) + RDMADeviceHandler.update_network_interface(self.mac_addr, self.ipv4_addr) + + @staticmethod + def update_dat_conf(paths, ipv4_addr): + """ + Looks at paths for dat.conf file and updates the ip address for the + infiniband interface. + """ + logger.info("Updating DAPL configuration file") + for f in paths: + logger.info("RDMA: trying {0}".format(f)) + if not os.path.isfile(f): + logger.info( + "RDMA: DAPL config not found at {0}".format(f)) + continue + logger.info("RDMA: DAPL config is at: {0}".format(f)) + cfg = fileutil.read_file(f) + new_cfg = RDMADeviceHandler.replace_dat_conf_contents( + cfg, ipv4_addr) + fileutil.write_file(f, new_cfg) + logger.info("RDMA: DAPL configuration is updated") + return + + raise Exception("RDMA: DAPL configuration file not found at predefined paths") + + @staticmethod + def replace_dat_conf_contents(cfg, ipv4_addr): + old = "ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 dapl.2.0 \"\S+ 0\"" + new = "ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 dapl.2.0 \"{0} 0\"".format( + ipv4_addr) + return re.sub(old, new, cfg) + + @staticmethod + def write_rdma_config_to_device(path, ipv4_addr, mac_addr): + data = RDMADeviceHandler.generate_rdma_config(ipv4_addr, mac_addr) + logger.info( + "RDMA: Updating device with configuration: {0}".format(data)) + with open(path, "w") as f: + f.write(data) + logger.info("RDMA: Updated device with IPv4/MAC addr successfully") + + @staticmethod + def generate_rdma_config(ipv4_addr, mac_addr): + return 'rdmaMacAddress="{0}" rdmaIPv4Address="{1}"'.format(mac_addr, ipv4_addr) + + @staticmethod + def wait_rdma_device(path, timeout_sec, check_interval_sec): + logger.info("RDMA: waiting for device={0} timeout={1}s".format(path, timeout_sec)) + total_retries = timeout_sec/check_interval_sec + n = 0 + while n < total_retries: + if os.path.exists(path): + logger.info("RDMA: device ready") + return + logger.verbose( + "RDMA: device not ready, sleep {0}s".format(check_interval_sec)) + time.sleep(check_interval_sec) + n += 1 + logger.error("RDMA device wait timed out") + raise Exception("The device did not show up in {0} seconds ({1} retries)".format( + timeout_sec, total_retries)) + + @staticmethod + def update_network_interface(mac_addr, ipv4_addr): + netmask=16 + + logger.info("RDMA: will update the network interface with IPv4/MAC") + + if_name=RDMADeviceHandler.get_interface_by_mac(mac_addr) + logger.info("RDMA: network interface found: {0}", if_name) + logger.info("RDMA: bringing network interface up") + if shellutil.run("ifconfig {0} up".format(if_name)) != 0: + raise Exception("Could not bring up RMDA interface: {0}".format(if_name)) + + logger.info("RDMA: configuring IPv4 addr and netmask on interface") + addr = '{0}/{1}'.format(ipv4_addr, netmask) + if shellutil.run("ifconfig {0} {1}".format(if_name, addr)) != 0: + raise Exception("Could set addr to {1} on {0}".format(if_name, addr)) + logger.info("RDMA: network address and netmask configured on interface") + + @staticmethod + def get_interface_by_mac(mac): + ret, output = shellutil.run_get_output("ifconfig -a") + if ret != 0: + raise Exception("Failed to list network interfaces") + output = output.replace('\n', '') + match = re.search(r"(eth\d).*(HWaddr|ether) {0}".format(mac), + output, re.IGNORECASE) + if match is None: + raise Exception("Failed to get ifname with mac: {0}".format(mac)) + output = match.group(0) + eths = re.findall(r"eth\d", output) + if eths is None or len(eths) == 0: + raise Exception("ifname with mac: {0} not found".format(mac)) + return eths[-1] diff --git a/azurelinuxagent/common/utils/__init__.py b/azurelinuxagent/common/utils/__init__.py new file mode 100644 index 0000000..1ea2f38 --- /dev/null +++ b/azurelinuxagent/common/utils/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + diff --git a/azurelinuxagent/common/utils/cryptutil.py b/azurelinuxagent/common/utils/cryptutil.py new file mode 100644 index 0000000..b35bda0 --- /dev/null +++ b/azurelinuxagent/common/utils/cryptutil.py @@ -0,0 +1,121 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import base64 +import struct +from azurelinuxagent.common.future import ustr, bytebuffer +from azurelinuxagent.common.exception import CryptError +import azurelinuxagent.common.utils.shellutil as shellutil + +class CryptUtil(object): + def __init__(self, openssl_cmd): + self.openssl_cmd = openssl_cmd + + def gen_transport_cert(self, prv_file, crt_file): + """ + Create ssl certificate for https communication with endpoint server. + """ + cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 " + "-newkey rsa:2048 -keyout {1} " + "-out {2}").format(self.openssl_cmd, prv_file, crt_file) + shellutil.run(cmd) + + def get_pubkey_from_prv(self, file_name): + cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, + file_name) + pub = shellutil.run_get_output(cmd)[1] + return pub + + def get_pubkey_from_crt(self, file_name): + cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd, + file_name) + pub = shellutil.run_get_output(cmd)[1] + return pub + + def get_thumbprint_from_crt(self, file_name): + cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd, + file_name) + thumbprint = shellutil.run_get_output(cmd)[1] + thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper() + return thumbprint + + def decrypt_p7m(self, p7m_file, trans_prv_file, trans_cert_file, pem_file): + cmd = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3} " + "| {4} pkcs12 -nodes -password pass: -out {5}" + "").format(self.openssl_cmd, p7m_file, trans_prv_file, + trans_cert_file, self.openssl_cmd, pem_file) + shellutil.run(cmd) + + def crt_to_ssh(self, input_file, output_file): + shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file, + output_file)) + + def asn1_to_ssh(self, pubkey): + lines = pubkey.split("\n") + lines = [x for x in lines if not x.startswith("----")] + base64_encoded = "".join(lines) + try: + #TODO remove pyasn1 dependency + from pyasn1.codec.der import decoder as der_decoder + der_encoded = base64.b64decode(base64_encoded) + der_encoded = der_decoder.decode(der_encoded)[0][1] + key = der_decoder.decode(self.bits_to_bytes(der_encoded))[0] + n=key[0] + e=key[1] + keydata = bytearray() + keydata.extend(struct.pack('>I', len("ssh-rsa"))) + keydata.extend(b"ssh-rsa") + keydata.extend(struct.pack('>I', len(self.num_to_bytes(e)))) + keydata.extend(self.num_to_bytes(e)) + keydata.extend(struct.pack('>I', len(self.num_to_bytes(n)) + 1)) + keydata.extend(b"\0") + keydata.extend(self.num_to_bytes(n)) + keydata_base64 = base64.b64encode(bytebuffer(keydata)) + return ustr(b"ssh-rsa " + keydata_base64 + b"\n", + encoding='utf-8') + except ImportError as e: + raise CryptError("Failed to load pyasn1.codec.der") + + def num_to_bytes(self, num): + """ + Pack number into bytes. Retun as string. + """ + result = bytearray() + while num: + result.append(num & 0xFF) + num >>= 8 + result.reverse() + return result + + def bits_to_bytes(self, bits): + """ + Convert an array contains bits, [0,1] to a byte array + """ + index = 7 + byte_array = bytearray() + curr = 0 + for bit in bits: + curr = curr | (bit << index) + index = index - 1 + if index == -1: + byte_array.append(curr) + curr = 0 + index = 7 + return bytes(byte_array) + diff --git a/azurelinuxagent/common/utils/fileutil.py b/azurelinuxagent/common/utils/fileutil.py new file mode 100644 index 0000000..7ef4fef --- /dev/null +++ b/azurelinuxagent/common/utils/fileutil.py @@ -0,0 +1,171 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +""" +File operation util functions +""" + +import os +import re +import shutil +import pwd +import tempfile +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.utils.textutil as textutil + +def copy_file(from_path, to_path=None, to_dir=None): + if to_path is None: + to_path = os.path.join(to_dir, os.path.basename(from_path)) + shutil.copyfile(from_path, to_path) + return to_path + + +def read_file(filepath, asbin=False, remove_bom=False, encoding='utf-8'): + """ + Read and return contents of 'filepath'. + """ + mode = 'rb' + with open(filepath, mode) as in_file: + data = in_file.read() + if data is None: + return None + + if asbin: + return data + + if remove_bom: + #Remove bom on bytes data before it is converted into string. + data = textutil.remove_bom(data) + data = ustr(data, encoding=encoding) + return data + +def write_file(filepath, contents, asbin=False, encoding='utf-8', append=False): + """ + Write 'contents' to 'filepath'. + """ + mode = "ab" if append else "wb" + data = contents + if not asbin: + data = contents.encode(encoding) + with open(filepath, mode) as out_file: + out_file.write(data) + +def append_file(filepath, contents, asbin=False, encoding='utf-8'): + """ + Append 'contents' to 'filepath'. + """ + write_file(filepath, contents, asbin=asbin, encoding=encoding, append=True) + + +def base_name(path): + head, tail = os.path.split(path) + return tail + +def get_line_startingwith(prefix, filepath): + """ + Return line from 'filepath' if the line startswith 'prefix' + """ + for line in read_file(filepath).split('\n'): + if line.startswith(prefix): + return line + return None + +#End File operation util functions + +def mkdir(dirpath, mode=None, owner=None): + if not os.path.isdir(dirpath): + os.makedirs(dirpath) + if mode is not None: + chmod(dirpath, mode) + if owner is not None: + chowner(dirpath, owner) + +def chowner(path, owner): + if not os.path.exists(path): + logger.error("Path does not exist: {0}".format(path)) + else: + owner_info = pwd.getpwnam(owner) + os.chown(path, owner_info[2], owner_info[3]) + +def chmod(path, mode): + if not os.path.exists(path): + logger.error("Path does not exist: {0}".format(path)) + else: + os.chmod(path, mode) + +def rm_files(*args): + for path in args: + if os.path.isfile(path): + os.remove(path) + +def rm_dirs(*args): + """ + Remove all the contents under the directry + """ + for dir_name in args: + if os.path.isdir(dir_name): + for item in os.listdir(dir_name): + path = os.path.join(dir_name, item) + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + +def trim_ext(path, ext): + if not ext.startswith("."): + ext = "." + ext + return path.split(ext)[0] if path.endswith(ext) else path + +def update_conf_file(path, line_start, val, chk_err=False): + conf = [] + if not os.path.isfile(path) and chk_err: + raise IOError("Can't find config file:{0}".format(path)) + conf = read_file(path).split('\n') + conf = [x for x in conf if not x.startswith(line_start)] + conf.append(val) + write_file(path, '\n'.join(conf)) + +def search_file(target_dir_name, target_file_name): + for root, dirs, files in os.walk(target_dir_name): + for file_name in files: + if file_name == target_file_name: + return os.path.join(root, file_name) + return None + +def chmod_tree(path, mode): + for root, dirs, files in os.walk(path): + for file_name in files: + os.chmod(os.path.join(root, file_name), mode) + +def findstr_in_file(file_path, pattern_str): + """ + Return match object if found in file. + """ + try: + pattern = re.compile(pattern_str) + for line in (open(file_path, 'r')).readlines(): + match = re.search(pattern, line) + if match: + return match + except: + raise + + return None + diff --git a/azurelinuxagent/common/utils/flexible_version.py b/azurelinuxagent/common/utils/flexible_version.py new file mode 100644 index 0000000..2fce88d --- /dev/null +++ b/azurelinuxagent/common/utils/flexible_version.py @@ -0,0 +1,199 @@ +from distutils import version +import re + +class FlexibleVersion(version.Version): + """ + A more flexible implementation of distutils.version.StrictVersion + + The implementation allows to specify: + - an arbitrary number of version numbers: + not only '1.2.3' , but also '1.2.3.4.5' + - the separator between version numbers: + '1-2-3' is allowed when '-' is specified as separator + - a flexible pre-release separator: + '1.2.3.alpha1', '1.2.3-alpha1', and '1.2.3alpha1' are considered equivalent + - an arbitrary ordering of pre-release tags: + 1.1alpha3 < 1.1beta2 < 1.1rc1 < 1.1 + when ["alpha", "beta", "rc"] is specified as pre-release tag list + + Inspiration from this discussion at StackOverflow: + http://stackoverflow.com/questions/12255554/sort-versions-in-python + """ + + def __init__(self, vstring=None, sep='.', prerel_tags=('alpha', 'beta', 'rc')): + version.Version.__init__(self) + + if sep is None: + sep = '.' + if prerel_tags is None: + prerel_tags = () + + self.sep = sep + self.prerel_sep = '' + self.prerel_tags = tuple(prerel_tags) if prerel_tags is not None else () + + self._compile_pattern() + + self.prerelease = None + self.version = () + if vstring: + self._parse(vstring) + return + + _nn_version = 'version' + _nn_prerel_sep = 'prerel_sep' + _nn_prerel_tag = 'tag' + _nn_prerel_num = 'tag_num' + + _re_prerel_sep = r'(?P<{pn}>{sep})?'.format( + pn=_nn_prerel_sep, + sep='|'.join(map(re.escape, ('.', '-')))) + + @property + def major(self): + return self.version[0] if len(self.version) > 0 else 0 + + @property + def minor(self): + return self.version[1] if len(self.version) > 1 else 0 + + @property + def patch(self): + return self.version[2] if len(self.version) > 2 else 0 + + def _parse(self, vstring): + m = self.version_re.match(vstring) + if not m: + raise ValueError("Invalid version number '{0}'".format(vstring)) + + self.prerelease = None + self.version = () + + self.prerel_sep = m.group(self._nn_prerel_sep) + tag = m.group(self._nn_prerel_tag) + tag_num = m.group(self._nn_prerel_num) + + if tag is not None and tag_num is not None: + self.prerelease = (tag, int(tag_num) if len(tag_num) else None) + + self.version = tuple(map(int, self.sep_re.split(m.group(self._nn_version)))) + return + + def __add__(self, increment): + version = list(self.version) + version[-1] += increment + vstring = self._assemble(version, self.sep, self.prerel_sep, self.prerelease) + return FlexibleVersion(vstring=vstring, sep=self.sep, prerel_tags=self.prerel_tags) + + def __sub__(self, decrement): + version = list(self.version) + if version[-1] <= 0: + raise ArithmeticError("Cannot decrement final numeric component of {0} below zero" \ + .format(self)) + version[-1] -= decrement + vstring = self._assemble(version, self.sep, self.prerel_sep, self.prerelease) + return FlexibleVersion(vstring=vstring, sep=self.sep, prerel_tags=self.prerel_tags) + + def __repr__(self): + return "{cls} ('{vstring}', '{sep}', {prerel_tags})"\ + .format( + cls=self.__class__.__name__, + vstring=str(self), + sep=self.sep, + prerel_tags=self.prerel_tags) + + def __str__(self): + return self._assemble(self.version, self.sep, self.prerel_sep, self.prerelease) + + def __ge__(self, that): + return not self.__lt__(that) + + def __gt__(self, that): + return (not self.__lt__(that)) and (not self.__eq__(that)) + + def __le__(self, that): + return (self.__lt__(that)) or (self.__eq__(that)) + + def __lt__(self, that): + this_version, that_version = self._ensure_compatible(that) + + if this_version != that_version \ + or self.prerelease is None and that.prerelease is None: + return this_version < that_version + + if self.prerelease is not None and that.prerelease is None: + return True + if self.prerelease is None and that.prerelease is not None: + return False + + this_index = self.prerel_tags_set[self.prerelease[0]] + that_index = self.prerel_tags_set[that.prerelease[0]] + if this_index == that_index: + return self.prerelease[1] < that.prerelease[1] + + return this_index < that_index + + def __ne__(self, that): + return not self.__eq__(that) + + def __eq__(self, that): + this_version, that_version = self._ensure_compatible(that) + + if this_version != that_version: + return False + + if self.prerelease != that.prerelease: + return False + + return True + + def _assemble(self, version, sep, prerel_sep, prerelease): + s = sep.join(map(str, version)) + if prerelease is not None: + if prerel_sep is not None: + s += prerel_sep + s += prerelease[0] + if prerelease[1] is not None: + s += str(prerelease[1]) + return s + + def _compile_pattern(self): + sep, self.sep_re = self._compile_separator(self.sep) + + if self.prerel_tags: + tags = '|'.join(re.escape(tag) for tag in self.prerel_tags) + self.prerel_tags_set = dict(zip(self.prerel_tags, range(len(self.prerel_tags)))) + release_re = '(?:{prerel_sep}(?P<{tn}>{tags})(?P<{nn}>\d*))?'.format( + prerel_sep=self._re_prerel_sep, + tags=tags, + tn=self._nn_prerel_tag, + nn=self._nn_prerel_num) + else: + release_re = '' + + version_re = r'^(?P<{vn}>\d+(?:(?:{sep}\d+)*)?){rel}$'.format( + vn=self._nn_version, + sep=sep, + rel=release_re) + self.version_re = re.compile(version_re) + return + + def _compile_separator(self, sep): + if sep is None: + return '', re.compile('') + return re.escape(sep), re.compile(re.escape(sep)) + + def _ensure_compatible(self, that): + """ + Ensures the instances have the same structure and, if so, returns length compatible + version lists (so that x.y.0.0 is equivalent to x.y). + """ + if self.prerel_tags != that.prerel_tags or self.sep != that.sep: + raise ValueError("Unable to compare: versions have different structures") + + this_version = list(self.version[:]) + that_version = list(that.version[:]) + while len(this_version) < len(that_version): this_version.append(0) + while len(that_version) < len(this_version): that_version.append(0) + + return this_version, that_version diff --git a/azurelinuxagent/common/utils/restutil.py b/azurelinuxagent/common/utils/restutil.py new file mode 100644 index 0000000..a789650 --- /dev/null +++ b/azurelinuxagent/common/utils/restutil.py @@ -0,0 +1,156 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import time +import platform +import os +import subprocess +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.logger as logger +from azurelinuxagent.common.exception import HttpError +from azurelinuxagent.common.future import httpclient, urlparse + +""" +REST api util functions +""" + +RETRY_WAITING_INTERVAL = 10 + +def _parse_url(url): + o = urlparse(url) + rel_uri = o.path + if o.fragment: + rel_uri = "{0}#{1}".format(rel_uri, o.fragment) + if o.query: + rel_uri = "{0}?{1}".format(rel_uri, o.query) + secure = False + if o.scheme.lower() == "https": + secure = True + return o.hostname, o.port, secure, rel_uri + +def get_http_proxy(): + """ + Get http_proxy and https_proxy from environment variables. + Username and password is not supported now. + """ + host = conf.get_httpproxy_host() + port = conf.get_httpproxy_port() + return (host, port) + +def _http_request(method, host, rel_uri, port=None, data=None, secure=False, + headers=None, proxy_host=None, proxy_port=None): + url, conn = None, None + if secure: + port = 443 if port is None else port + if proxy_host is not None and proxy_port is not None: + conn = httpclient.HTTPSConnection(proxy_host, proxy_port, timeout=10) + conn.set_tunnel(host, port) + #If proxy is used, full url is needed. + url = "https://{0}:{1}{2}".format(host, port, rel_uri) + else: + conn = httpclient.HTTPSConnection(host, port, timeout=10) + url = rel_uri + else: + port = 80 if port is None else port + if proxy_host is not None and proxy_port is not None: + conn = httpclient.HTTPConnection(proxy_host, proxy_port, timeout=10) + #If proxy is used, full url is needed. + url = "http://{0}:{1}{2}".format(host, port, rel_uri) + else: + conn = httpclient.HTTPConnection(host, port, timeout=10) + url = rel_uri + if headers == None: + conn.request(method, url, data) + else: + conn.request(method, url, data, headers) + resp = conn.getresponse() + return resp + +def http_request(method, url, data, headers=None, max_retry=3, chk_proxy=False): + """ + Sending http request to server + On error, sleep 10 and retry max_retry times. + """ + logger.verbose("HTTP Req: {0} {1}", method, url) + logger.verbose(" Data={0}", data) + logger.verbose(" Header={0}", headers) + host, port, secure, rel_uri = _parse_url(url) + + #Check proxy + proxy_host, proxy_port = (None, None) + if chk_proxy: + proxy_host, proxy_port = get_http_proxy() + + #If httplib module is not built with ssl support. Fallback to http + if secure and not hasattr(httpclient, "HTTPSConnection"): + logger.warn("httplib is not built with ssl support") + secure = False + + #If httplib module doesn't support https tunnelling. Fallback to http + if secure and \ + proxy_host is not None and \ + proxy_port is not None and \ + not hasattr(httpclient.HTTPSConnection, "set_tunnel"): + logger.warn("httplib doesn't support https tunnelling(new in python 2.7)") + secure = False + + for retry in range(0, max_retry): + try: + resp = _http_request(method, host, rel_uri, port=port, data=data, + secure=secure, headers=headers, + proxy_host=proxy_host, proxy_port=proxy_port) + logger.verbose("HTTP Resp: Status={0}", resp.status) + logger.verbose(" Header={0}", resp.getheaders()) + return resp + except httpclient.HTTPException as e: + logger.warn('HTTPException {0}, args:{1}', e, repr(e.args)) + except IOError as e: + logger.warn('Socket IOError {0}, args:{1}', e, repr(e.args)) + + if retry < max_retry - 1: + logger.info("Retry={0}, {1} {2}", retry, method, url) + time.sleep(RETRY_WAITING_INTERVAL) + + if url is not None and len(url) > 100: + url_log = url[0: 100] #In case the url is too long + else: + url_log = url + raise HttpError("HTTP Err: {0} {1}".format(method, url_log)) + +def http_get(url, headers=None, max_retry=3, chk_proxy=False): + return http_request("GET", url, data=None, headers=headers, + max_retry=max_retry, chk_proxy=chk_proxy) + +def http_head(url, headers=None, max_retry=3, chk_proxy=False): + return http_request("HEAD", url, None, headers=headers, + max_retry=max_retry, chk_proxy=chk_proxy) + +def http_post(url, data, headers=None, max_retry=3, chk_proxy=False): + return http_request("POST", url, data, headers=headers, + max_retry=max_retry, chk_proxy=chk_proxy) + +def http_put(url, data, headers=None, max_retry=3, chk_proxy=False): + return http_request("PUT", url, data, headers=headers, + max_retry=max_retry, chk_proxy=chk_proxy) + +def http_delete(url, headers=None, max_retry=3, chk_proxy=False): + return http_request("DELETE", url, None, headers=headers, + max_retry=max_retry, chk_proxy=chk_proxy) + +#End REST api util functions diff --git a/azurelinuxagent/common/utils/shellutil.py b/azurelinuxagent/common/utils/shellutil.py new file mode 100644 index 0000000..d273c92 --- /dev/null +++ b/azurelinuxagent/common/utils/shellutil.py @@ -0,0 +1,107 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import platform +import os +import subprocess +from azurelinuxagent.common.future import ustr +import azurelinuxagent.common.logger as logger + +if not hasattr(subprocess,'check_output'): + def check_output(*popenargs, **kwargs): + r"""Backport from subprocess module from python 2.7""" + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, ' + 'it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output + + # Exception classes used by this module. + class CalledProcessError(Exception): + def __init__(self, returncode, cmd, output=None): + self.returncode = returncode + self.cmd = cmd + self.output = output + def __str__(self): + return ("Command '{0}' returned non-zero exit status {1}" + "").format(self.cmd, self.returncode) + + subprocess.check_output=check_output + subprocess.CalledProcessError=CalledProcessError + + +""" +Shell command util functions +""" +def run(cmd, chk_err=True): + """ + Calls run_get_output on 'cmd', returning only the return code. + If chk_err=True then errors will be reported in the log. + If chk_err=False then errors will be suppressed from the log. + """ + retcode,out=run_get_output(cmd,chk_err) + return retcode + +def run_get_output(cmd, chk_err=True, log_cmd=True): + """ + Wrapper for subprocess.check_output. + Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions. + Reports exceptions to Error if chk_err parameter is True + """ + if log_cmd: + logger.verbose(u"run cmd '{0}'", cmd) + try: + output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True) + output = ustr(output, encoding='utf-8', errors="backslashreplace") + except subprocess.CalledProcessError as e : + output = ustr(e.output, encoding='utf-8', errors="backslashreplace") + if chk_err: + if log_cmd: + logger.error(u"run cmd '{0}' failed", e.cmd) + logger.error(u"Error Code:{0}", e.returncode) + logger.error(u"Result:{0}", output) + return e.returncode, output + return 0, output + + +def quote(word_list): + """ + Quote a list or tuple of strings for Unix Shell as words, using the + byte-literal single quote. + + The resulting string is safe for use with ``shell=True`` in ``subprocess``, + and in ``os.system``. ``assert shlex.split(ShellQuote(wordList)) == wordList``. + + See POSIX.1:2013 Vol 3, Chap 2, Sec 2.2.2: + http://pubs.opengroup.org/onlinepubs/9699919799/utilities/V3_chap02.html#tag_18_02_02 + """ + if not isinstance(word_list, (tuple, list)): + word_list = (word_list,) + + return " ".join(list("'{0}'".format(s.replace("'", "'\\''")) for s in word_list)) + + +# End shell command util functions diff --git a/azurelinuxagent/common/utils/textutil.py b/azurelinuxagent/common/utils/textutil.py new file mode 100644 index 0000000..f03c7e6 --- /dev/null +++ b/azurelinuxagent/common/utils/textutil.py @@ -0,0 +1,279 @@ +# Microsoft Azure Linux Agent +# +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ + +import base64 +import crypt +import random +import string +import struct +import sys +import xml.dom.minidom as minidom + +from distutils.version import LooseVersion as Version + + +def parse_doc(xml_text): + """ + Parse xml document from string + """ + # The minidom lib has some issue with unicode in python2. + # Encode the string into utf-8 first + xml_text = xml_text.encode('utf-8') + return minidom.parseString(xml_text) + + +def findall(root, tag, namespace=None): + """ + Get all nodes by tag and namespace under Node root. + """ + if root is None: + return [] + + if namespace is None: + return root.getElementsByTagName(tag) + else: + return root.getElementsByTagNameNS(namespace, tag) + + +def find(root, tag, namespace=None): + """ + Get first node by tag and namespace under Node root. + """ + nodes = findall(root, tag, namespace=namespace) + if nodes is not None and len(nodes) >= 1: + return nodes[0] + else: + return None + + +def gettext(node): + """ + Get node text + """ + if node is None: + return None + + for child in node.childNodes: + if child.nodeType == child.TEXT_NODE: + return child.data + return None + + +def findtext(root, tag, namespace=None): + """ + Get text of node by tag and namespace under Node root. + """ + node = find(root, tag, namespace=namespace) + return gettext(node) + + +def getattrib(node, attr_name): + """ + Get attribute of xml node + """ + if node is not None: + return node.getAttribute(attr_name) + else: + return None + + +def unpack(buf, offset, range): + """ + Unpack bytes into python values. + """ + result = 0 + for i in range: + result = (result << 8) | str_to_ord(buf[offset + i]) + return result + + +def unpack_little_endian(buf, offset, length): + """ + Unpack little endian bytes into python values. + """ + return unpack(buf, offset, list(range(length - 1, -1, -1))) + + +def unpack_big_endian(buf, offset, length): + """ + Unpack big endian bytes into python values. + """ + return unpack(buf, offset, list(range(0, length))) + + +def hex_dump3(buf, offset, length): + """ + Dump range of buf in formatted hex. + """ + return ''.join(['%02X' % str_to_ord(char) for char in buf[offset:offset + length]]) + + +def hex_dump2(buf): + """ + Dump buf in formatted hex. + """ + return hex_dump3(buf, 0, len(buf)) + + +def is_in_range(a, low, high): + """ + Return True if 'a' in 'low' <= a >= 'high' + """ + return (a >= low and a <= high) + + +def is_printable(ch): + """ + Return True if character is displayable. + """ + return (is_in_range(ch, str_to_ord('A'), str_to_ord('Z')) + or is_in_range(ch, str_to_ord('a'), str_to_ord('z')) + or is_in_range(ch, str_to_ord('0'), str_to_ord('9'))) + + +def hex_dump(buffer, size): + """ + Return Hex formated dump of a 'buffer' of 'size'. + """ + if size < 0: + size = len(buffer) + result = "" + for i in range(0, size): + if (i % 16) == 0: + result += "%06X: " % i + byte = buffer[i] + if type(byte) == str: + byte = ord(byte.decode('latin1')) + result += "%02X " % byte + if (i & 15) == 7: + result += " " + if ((i + 1) % 16) == 0 or (i + 1) == size: + j = i + while ((j + 1) % 16) != 0: + result += " " + if (j & 7) == 7: + result += " " + j += 1 + result += " " + for j in range(i - (i % 16), i + 1): + byte = buffer[j] + if type(byte) == str: + byte = str_to_ord(byte.decode('latin1')) + k = '.' + if is_printable(byte): + k = chr(byte) + result += k + if (i + 1) != size: + result += "\n" + return result + + +def str_to_ord(a): + """ + Allows indexing into a string or an array of integers transparently. + Generic utility function. + """ + if type(a) == type(b'') or type(a) == type(u''): + a = ord(a) + return a + + +def compare_bytes(a, b, start, length): + for offset in range(start, start + length): + if str_to_ord(a[offset]) != str_to_ord(b[offset]): + return False + return True + + +def int_to_ip4_addr(a): + """ + Build DHCP request string. + """ + return "%u.%u.%u.%u" % ((a >> 24) & 0xFF, + (a >> 16) & 0xFF, + (a >> 8) & 0xFF, + (a) & 0xFF) + + +def hexstr_to_bytearray(a): + """ + Return hex string packed into a binary struct. + """ + b = b"" + for c in range(0, len(a) // 2): + b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16)) + return b + + +def set_ssh_config(config, name, val): + notfound = True + for i in range(0, len(config)): + if config[i].startswith(name): + config[i] = "{0} {1}".format(name, val) + notfound = False + elif config[i].startswith("Match"): + # Match block must be put in the end of sshd config + break + if notfound: + config.insert(i, "{0} {1}".format(name, val)) + return config + + +def set_ini_config(config, name, val): + notfound = True + nameEqual = name + '=' + length = len(config) + text = "{0}=\"{1}\"".format(name, val) + + for i in reversed(range(0, length)): + if config[i].startswith(nameEqual): + config[i] = text + notfound = False + break + + if notfound: + config.insert(length - 1, text) + + +def remove_bom(c): + if str_to_ord(c[0]) > 128 and str_to_ord(c[1]) > 128 and \ + str_to_ord(c[2]) > 128: + c = c[3:] + return c + + +def gen_password_hash(password, crypt_id, salt_len): + collection = string.ascii_letters + string.digits + salt = ''.join(random.choice(collection) for _ in range(salt_len)) + salt = "${0}${1}".format(crypt_id, salt) + return crypt.crypt(password, salt) + + +def get_bytes_from_pem(pem_str): + base64_bytes = "" + for line in pem_str.split('\n'): + if "----" not in line: + base64_bytes += line + return base64_bytes + + +def b64encode(s): + from azurelinuxagent.common.version import PY_VERSION_MAJOR + if PY_VERSION_MAJOR > 2: + return base64.b64encode(bytes(s, 'utf-8')).decode('utf-8') + return base64.b64encode(s) diff --git a/azurelinuxagent/common/version.py b/azurelinuxagent/common/version.py new file mode 100644 index 0000000..6c4b475 --- /dev/null +++ b/azurelinuxagent/common/version.py @@ -0,0 +1,116 @@ +# Copyright 2014 Microsoft Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Requires Python 2.4+ and Openssl 1.0+ +# + +import os +import re +import platform +import sys + +import azurelinuxagent.common.conf as conf +import azurelinuxagent.common.utils.fileutil as fileutil +from azurelinuxagent.common.utils.flexible_version import FlexibleVersion +from azurelinuxagent.common.future import ustr + + +def get_distro(): + if 'FreeBSD' in platform.system(): + release = re.sub('\-.*\Z', '', ustr(platform.release())) + osinfo = ['freebsd', release, '', 'freebsd'] + elif 'linux_distribution' in dir(platform): + osinfo = list(platform.linux_distribution(full_distribution_name=0)) + full_name = platform.linux_distribution()[0].strip() + osinfo.append(full_name) + else: + osinfo = platform.dist() + + # The platform.py lib has issue with detecting oracle linux distribution. + # Merge the following patch provided by oracle as a temparory fix. + if os.path.exists("/etc/oracle-release"): + osinfo[2] = "oracle" + osinfo[3] = "Oracle Linux" + + # Remove trailing whitespace and quote in distro name + osinfo[0] = osinfo[0].strip('"').strip(' ').lower() + return osinfo + + +AGENT_NAME = "WALinuxAgent" +AGENT_LONG_NAME = "Azure Linux Agent" +AGENT_VERSION = '2.1.5' +AGENT_LONG_VERSION = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION) +AGENT_DESCRIPTION = """\ +The Azure Linux Agent supports the provisioning and running of Linux +VMs in the Azure cloud. This package should be installed on Linux disk +images that are built to run in the Azure environment. +""" + +AGENT_DIR_GLOB = "{0}-*".format(AGENT_NAME) +AGENT_PKG_GLOB = "{0}-*.zip".format(AGENT_NAME) + +AGENT_PATTERN = "{0}-(.*)".format(AGENT_NAME) +AGENT_NAME_PATTERN = re.compile(AGENT_PATTERN) +AGENT_DIR_PATTERN = re.compile(".*/{0}".format(AGENT_PATTERN)) + + +# Set the CURRENT_AGENT and CURRENT_VERSION to match the agent directory name +# - This ensures the agent will "see itself" using the same name and version +# as the code that downloads agents. +def set_current_agent(): + path = os.getcwd() + lib_dir = conf.get_lib_dir() + if lib_dir[-1] != os.path.sep: + lib_dir += os.path.sep + if path[:len(lib_dir)] != lib_dir: + agent = AGENT_LONG_VERSION + version = AGENT_VERSION + else: + agent = path[len(lib_dir):].split(os.path.sep)[0] + version = AGENT_NAME_PATTERN.match(agent).group(1) + return agent, FlexibleVersion(version) +CURRENT_AGENT, CURRENT_VERSION = set_current_agent() + +def is_current_agent_installed(): + return CURRENT_AGENT == AGENT_LONG_VERSION + + +__distro__ = get_distro() +DISTRO_NAME = __distro__[0] +DISTRO_VERSION = __distro__[1] +DISTRO_CODE_NAME = __distro__[2] +DISTRO_FULL_NAME = __distro__[3] + +PY_VERSION = sys.version_info +PY_VERSION_MAJOR = sys.version_info[0] +PY_VERSION_MINOR = sys.version_info[1] +PY_VERSION_MICRO = sys.version_info[2] + +""" +Add this workaround for detecting Snappy Ubuntu Core temporarily, until ubuntu +fixed this bug: https://bugs.launchpad.net/snappy/+bug/1481086 +""" + + +def is_snappy(): + if os.path.exists("/etc/motd"): + motd = fileutil.read_file("/etc/motd") + if "snappy" in motd: + return True + return False + + +if is_snappy(): + DISTRO_FULL_NAME = "Snappy Ubuntu Core" |