summaryrefslogtreecommitdiff
path: root/azurelinuxagent/common
diff options
context:
space:
mode:
Diffstat (limited to 'azurelinuxagent/common')
-rw-r--r--azurelinuxagent/common/__init__.py17
-rw-r--r--azurelinuxagent/common/conf.py181
-rw-r--r--azurelinuxagent/common/dhcp.py400
-rw-r--r--azurelinuxagent/common/event.py124
-rw-r--r--azurelinuxagent/common/exception.py123
-rw-r--r--azurelinuxagent/common/future.py31
-rw-r--r--azurelinuxagent/common/logger.py156
-rw-r--r--azurelinuxagent/common/osutil/__init__.py18
-rw-r--r--azurelinuxagent/common/osutil/coreos.py92
-rw-r--r--azurelinuxagent/common/osutil/debian.py47
-rw-r--r--azurelinuxagent/common/osutil/default.py792
-rw-r--r--azurelinuxagent/common/osutil/factory.py69
-rw-r--r--azurelinuxagent/common/osutil/freebsd.py198
-rw-r--r--azurelinuxagent/common/osutil/redhat.py122
-rw-r--r--azurelinuxagent/common/osutil/suse.py108
-rw-r--r--azurelinuxagent/common/osutil/ubuntu.py66
-rw-r--r--azurelinuxagent/common/protocol/__init__.py21
-rw-r--r--azurelinuxagent/common/protocol/hostplugin.py124
-rw-r--r--azurelinuxagent/common/protocol/metadata.py223
-rw-r--r--azurelinuxagent/common/protocol/ovfenv.py113
-rw-r--r--azurelinuxagent/common/protocol/restapi.py272
-rw-r--r--azurelinuxagent/common/protocol/util.py285
-rw-r--r--azurelinuxagent/common/protocol/wire.py1218
-rw-r--r--azurelinuxagent/common/rdma.py280
-rw-r--r--azurelinuxagent/common/utils/__init__.py17
-rw-r--r--azurelinuxagent/common/utils/cryptutil.py121
-rw-r--r--azurelinuxagent/common/utils/fileutil.py171
-rw-r--r--azurelinuxagent/common/utils/flexible_version.py199
-rw-r--r--azurelinuxagent/common/utils/restutil.py156
-rw-r--r--azurelinuxagent/common/utils/shellutil.py107
-rw-r--r--azurelinuxagent/common/utils/textutil.py279
-rw-r--r--azurelinuxagent/common/version.py116
32 files changed, 6246 insertions, 0 deletions
diff --git a/azurelinuxagent/common/__init__.py b/azurelinuxagent/common/__init__.py
new file mode 100644
index 0000000..1ea2f38
--- /dev/null
+++ b/azurelinuxagent/common/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
diff --git a/azurelinuxagent/common/conf.py b/azurelinuxagent/common/conf.py
new file mode 100644
index 0000000..1a3b0da
--- /dev/null
+++ b/azurelinuxagent/common/conf.py
@@ -0,0 +1,181 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+"""
+Module conf loads and parses configuration file
+"""
+import os
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.exception import AgentConfigError
+
class ConfigurationProvider(object):
    """
    Parse and store key=value pairs from /etc/waagent.conf.
    """
    def __init__(self):
        # Maps configuration key -> string value, or None when the file
        # explicitly sets the value to "None".
        self.values = dict()

    def load(self, content):
        """
        Parse configuration text into self.values.

        Lines starting with '#' are comments; only the first
        whitespace-delimited token of a line is considered, and the value
        is split from the key on the FIRST '=' only, so values containing
        '=' (e.g. base64 data) survive intact. Surrounding quotes and
        spaces are stripped from values.

        Raises AgentConfigError when content is empty.
        """
        if not content:
            raise AgentConfigError("Cannot parse empty configuration")
        for line in content.split('\n'):
            if not line.startswith("#") and "=" in line:
                # BUG FIX: split on the first '=' only (values may contain
                # '='), and skip malformed entries such as "Key = value"
                # where the first token has no '=' (previously IndexError).
                parts = line.split()[0].split('=', 1)
                if len(parts) != 2:
                    continue
                value = parts[1].strip("\" ")
                if value != "None":
                    self.values[parts[0]] = value
                else:
                    self.values[parts[0]] = None

    def get(self, key, default_val):
        """Return the string value for key, or default_val if unset/None."""
        val = self.values.get(key)
        return val if val is not None else default_val

    def get_switch(self, key, default_val):
        """Return True for 'y', False for 'n' (case-insensitive), else default_val."""
        val = self.values.get(key)
        if val is not None and val.lower() == 'y':
            return True
        elif val is not None and val.lower() == 'n':
            return False
        return default_val

    def get_int(self, key, default_val):
        """Return the value for key as int, or default_val if unset or not numeric."""
        try:
            return int(self.values.get(key))
        except (TypeError, ValueError):
            # TypeError: key unset (int(None)); ValueError: non-numeric text
            return default_val
+
+
# Module-level provider used by all the accessor functions below.
__conf__ = ConfigurationProvider()

def load_conf_from_file(conf_file_path, conf=__conf__):
    """
    Load and parse the configuration file at conf_file_path into conf.

    Raises AgentConfigError if the file is missing or cannot be read.
    """
    # Idiom fix: "not os.path.isfile(...)" instead of "== False"
    if not os.path.isfile(conf_file_path):
        raise AgentConfigError(("Missing configuration in {0}"
                                "").format(conf_file_path))
    try:
        content = fileutil.read_file(conf_file_path)
        conf.load(content)
    except IOError as err:
        raise AgentConfigError(("Failed to load conf file:{0}, {1}"
                                "").format(conf_file_path, err))
+
# Accessors for individual settings. Each takes an optional provider so
# callers/tests can supply their own; by default they read the
# module-level __conf__ instance populated by load_conf_from_file().

def enable_rdma(conf=__conf__):
    return conf.get_switch("OS.EnableRDMA", False)

def get_logs_verbose(conf=__conf__):
    return conf.get_switch("Logs.Verbose", False)

def get_lib_dir(conf=__conf__):
    # Agent state directory (certificates, goal state, downloaded agents)
    return conf.get("Lib.Dir", "/var/lib/waagent")

def get_dvd_mount_point(conf=__conf__):
    return conf.get("DVD.MountPoint", "/mnt/cdrom/secure")

def get_agent_pid_file_path(conf=__conf__):
    return conf.get("Pid.File", "/var/run/waagent.pid")

def get_ext_log_dir(conf=__conf__):
    return conf.get("Extension.LogDir", "/var/log/azure")

def get_openssl_cmd(conf=__conf__):
    return conf.get("OS.OpensslPath", "/usr/bin/openssl")

def get_home_dir(conf=__conf__):
    return conf.get("OS.HomeDir", "/home")

def get_passwd_file_path(conf=__conf__):
    # NOTE: despite the name, the default points at the shadow file
    return conf.get("OS.PasswordPath", "/etc/shadow")

def get_sudoers_dir(conf=__conf__):
    return conf.get("OS.SudoersDir", "/etc/sudoers.d")

def get_sshd_conf_file_path(conf=__conf__):
    return conf.get("OS.SshdConfigPath", "/etc/ssh/sshd_config")

def get_root_device_scsi_timeout(conf=__conf__):
    # None (the default) means "leave the scsi timeout alone"
    return conf.get("OS.RootDeviceScsiTimeout", None)

def get_ssh_host_keypair_type(conf=__conf__):
    return conf.get("Provisioning.SshHostKeyPairType", "rsa")

def get_provision_enabled(conf=__conf__):
    return conf.get_switch("Provisioning.Enabled", True)

def get_allow_reset_sys_user(conf=__conf__):
    return conf.get_switch("Provisioning.AllowResetSysUser", False)

def get_regenerate_ssh_host_key(conf=__conf__):
    return conf.get_switch("Provisioning.RegenerateSshHostKeyPair", False)

def get_delete_root_password(conf=__conf__):
    return conf.get_switch("Provisioning.DeleteRootPassword", False)

def get_decode_customdata(conf=__conf__):
    return conf.get_switch("Provisioning.DecodeCustomData", False)

def get_execute_customdata(conf=__conf__):
    return conf.get_switch("Provisioning.ExecuteCustomData", False)

def get_password_cryptid(conf=__conf__):
    # crypt(3) scheme id; "6" presumably selects SHA-512 — confirm against
    # the consumer before changing
    return conf.get("Provisioning.PasswordCryptId", "6")

def get_password_crypt_salt_len(conf=__conf__):
    return conf.get_int("Provisioning.PasswordCryptSaltLength", 10)

def get_monitor_hostname(conf=__conf__):
    return conf.get_switch("Provisioning.MonitorHostName", False)

def get_httpproxy_host(conf=__conf__):
    return conf.get("HttpProxy.Host", None)

def get_httpproxy_port(conf=__conf__):
    return conf.get("HttpProxy.Port", None)

def get_detect_scvmm_env(conf=__conf__):
    return conf.get_switch("DetectScvmmEnv", False)

def get_resourcedisk_format(conf=__conf__):
    return conf.get_switch("ResourceDisk.Format", False)

def get_resourcedisk_enable_swap(conf=__conf__):
    return conf.get_switch("ResourceDisk.EnableSwap", False)

def get_resourcedisk_mountpoint(conf=__conf__):
    return conf.get("ResourceDisk.MountPoint", "/mnt/resource")

def get_resourcedisk_filesystem(conf=__conf__):
    return conf.get("ResourceDisk.Filesystem", "ext3")

def get_resourcedisk_swap_size_mb(conf=__conf__):
    return conf.get_int("ResourceDisk.SwapSizeMB", 0)

def get_autoupdate_gafamily(conf=__conf__):
    return conf.get("AutoUpdate.GAFamily", "Prod")

def get_autoupdate_enabled(conf=__conf__):
    return conf.get_switch("AutoUpdate.Enabled", True)
+
def get_autoupdate_frequency(conf=__conf__):
    """
    Interval in seconds between agent auto-update checks (default 3600).
    """
    # BUG FIX: the key was "Autoupdate.Frequency", which does not match the
    # "AutoUpdate.*" casing used by every other auto-update setting here
    # (GAFamily, Enabled), so a configured value was never picked up.
    return conf.get_int("AutoUpdate.Frequency", 3600)
+
diff --git a/azurelinuxagent/common/dhcp.py b/azurelinuxagent/common/dhcp.py
new file mode 100644
index 0000000..d5c90cb
--- /dev/null
+++ b/azurelinuxagent/common/dhcp.py
@@ -0,0 +1,400 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import os
+import socket
+import array
+import time
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.shellutil as shellutil
+from azurelinuxagent.common.utils import fileutil
+from azurelinuxagent.common.utils.textutil import hex_dump, hex_dump2, \
+ hex_dump3, \
+ compare_bytes, str_to_ord, \
+ unpack_big_endian, \
+ int_to_ip4_addr
+from azurelinuxagent.common.exception import DhcpError
+from azurelinuxagent.common.osutil import get_osutil
+
+# the kernel routing table representation of 168.63.129.16
+KNOWN_WIRESERVER_IP_ENTRY = '10813FA8'
+KNOWN_WIRESERVER_IP = '168.63.129.16'
+
+
+def get_dhcp_handler():
+ return DhcpHandler()
+
+
class DhcpHandler(object):
    """
    Azure use DHCP option 245 to pass endpoint ip to VMs.
    """

    def __init__(self):
        self.osutil = get_osutil()
        # Wire server endpoint ip, filled in by run()
        self.endpoint = None
        # Default gateway from DHCP option 3 (None when nothing to configure)
        self.gateway = None
        # Static routes from DHCP option 249 (None when nothing to configure)
        self.routes = None
        # First request goes unicast; flipped to True so a retry broadcasts
        self._request_broadcast = False
        # When True, ignore any cached dhcp lease and send a fresh request
        self.skip_cache = False

    def run(self):
        """
        Send dhcp request
        Configure default gateway and routes
        Save wire server endpoint if found
        """
        # A pre-existing route to the wireserver ip, or a cached lease,
        # already provides the endpoint; skip the dhcp probe entirely.
        if self.wireserver_route_exists or self.dhcp_cache_exists:
            return

        self.send_dhcp_req()
        self.conf_routes()

    def wait_for_network(self):
        """
        Wait for network stack to be initialized.
        """
        # Poll every 10s until the interface reports a usable IPv4 address,
        # kicking the network stack each iteration in case it never started.
        ipv4 = self.osutil.get_ip4_addr()
        while ipv4 == '' or ipv4 == '0.0.0.0':
            logger.info("Waiting for network.")
            time.sleep(10)
            logger.info("Try to start network interface.")
            self.osutil.start_network()
            ipv4 = self.osutil.get_ip4_addr()

    @property
    def wireserver_route_exists(self):
        """
        Determine whether a route to the known wireserver
        ip already exists, and if so use that as the endpoint.
        This is true when running in a virtual network.
        :return: True if a route to KNOWN_WIRESERVER_IP exists.
        """
        route_exists = False
        logger.info("test for route to {0}".format(KNOWN_WIRESERVER_IP))
        try:
            # /proc/net/route encodes destinations as hex, hence the
            # KNOWN_WIRESERVER_IP_ENTRY constant (Linux-specific path).
            # NOTE(review): open() handle is never closed explicitly;
            # left to garbage collection.
            route_file = '/proc/net/route'
            if os.path.exists(route_file) and \
                    KNOWN_WIRESERVER_IP_ENTRY in open(route_file).read():
                # reset self.gateway and self.routes
                # we do not need to alter the routing table
                self.endpoint = KNOWN_WIRESERVER_IP
                self.gateway = None
                self.routes = None
                route_exists = True
                logger.info("route to {0} exists".format(KNOWN_WIRESERVER_IP))
            else:
                logger.warn(
                    "no route exists to {0}".format(KNOWN_WIRESERVER_IP))
        except Exception as e:
            logger.error(
                "could not determine whether route exists to {0}: {1}".format(
                    KNOWN_WIRESERVER_IP, e))

        return route_exists

    @property
    def dhcp_cache_exists(self):
        """
        Check whether the dhcp options cache exists and contains the
        wireserver endpoint, unless skip_cache is True.
        :return: True if the cached endpoint was found in the dhcp lease
        """
        if self.skip_cache:
            return False

        exists = False

        logger.info("checking for dhcp lease cache")
        cached_endpoint = self.osutil.get_dhcp_lease_endpoint()
        if cached_endpoint is not None:
            self.endpoint = cached_endpoint
            exists = True
        logger.info("cache exists [{0}]".format(exists))
        return exists

    def conf_routes(self):
        """
        Apply the gateway and static routes collected from the DHCP
        response to the system routing table.
        """
        logger.info("Configure routes")
        logger.info("Gateway:{0}", self.gateway)
        logger.info("Routes:{0}", self.routes)
        # Add default gateway
        if self.gateway is not None:
            self.osutil.route_add(0, 0, self.gateway)
        if self.routes is not None:
            for route in self.routes:
                self.osutil.route_add(route[0], route[1], route[2])

    def _send_dhcp_req(self, request):
        """
        Send the dhcp request with retries (delays of 0/10/30/60/60s).
        Returns the raw response, or None if every attempt failed.
        """
        __waiting_duration__ = [0, 10, 30, 60, 60]
        for duration in __waiting_duration__:
            try:
                self.osutil.allow_dhcp_broadcast()
                response = socket_send(request)
                validate_dhcp_resp(request, response)
                return response
            except DhcpError as e:
                logger.warn("Failed to send DHCP request: {0}", e)
                time.sleep(duration)
        return None

    def send_dhcp_req(self):
        """
        Build dhcp request with mac addr
        Configure route to allow dhcp traffic
        Stop dhcp service if necessary
        """
        logger.info("Send dhcp request")
        mac_addr = self.osutil.get_mac_addr()

        # Do unicast first, then fallback to broadcast if fails.
        req = build_dhcp_request(mac_addr, self._request_broadcast)
        if not self._request_broadcast:
            self._request_broadcast = True

        # Temporary allow broadcast for dhcp. Remove the route when done.
        missing_default_route = self.osutil.is_missing_default_route()
        ifname = self.osutil.get_if_name()
        if missing_default_route:
            self.osutil.set_route_for_dhcp_broadcast(ifname)

        # In some distros, dhcp service needs to be shutdown before agent probe
        # endpoint through dhcp.
        if self.osutil.is_dhcp_enabled():
            self.osutil.stop_dhcp_service()

        resp = self._send_dhcp_req(req)

        # Restore the dhcp service and routing table before checking the
        # result, so failures don't leave the system in the probing state.
        if self.osutil.is_dhcp_enabled():
            self.osutil.start_dhcp_service()

        if missing_default_route:
            self.osutil.remove_route_for_dhcp_broadcast(ifname)

        if resp is None:
            raise DhcpError("Failed to receive dhcp response.")
        self.endpoint, self.gateway, self.routes = parse_dhcp_resp(resp)
+
+
def validate_dhcp_resp(request, response):
    """
    Validate that 'response' is a plausible reply to 'request' by checking
    minimum length, magic cookie, transaction id and MAC address.

    Raises DhcpError when any check fails.
    """
    bytes_recv = len(response)
    if bytes_recv < 0xF6:
        logger.error("HandleDhcpResponse: Too few bytes received:{0}",
                     bytes_recv)
        # BUG FIX: this used to 'return False', but the only caller ignores
        # the return value and relies on DhcpError being raised — a
        # truncated packet was silently treated as valid. Raise instead,
        # consistent with the checks below.
        raise DhcpError("Too few bytes received in dhcp response")

    logger.verbose("BytesReceived:{0}", hex(bytes_recv))
    logger.verbose("DHCP response:{0}", hex_dump(response, bytes_recv))

    # check transactionId, cookie, MAC address cookie should never mismatch
    # transactionId and MAC address may mismatch if we see a response
    # meant from another machine
    if not compare_bytes(request, response, 0xEC, 4):
        logger.verbose("Cookie not match:\nsend={0},\nreceive={1}",
                       hex_dump3(request, 0xEC, 4),
                       hex_dump3(response, 0xEC, 4))
        raise DhcpError("Cookie in dhcp response doesn't match the request")

    if not compare_bytes(request, response, 4, 4):
        logger.verbose("TransactionID not match:\nsend={0},\nreceive={1}",
                       hex_dump3(request, 4, 4),
                       hex_dump3(response, 4, 4))
        raise DhcpError("TransactionID in dhcp response "
                        "doesn't match the request")

    if not compare_bytes(request, response, 0x1C, 6):
        logger.verbose("Mac Address not match:\nsend={0},\nreceive={1}",
                       hex_dump3(request, 0x1C, 6),
                       hex_dump3(response, 0x1C, 6))
        raise DhcpError("Mac Addr in dhcp response "
                        "doesn't match the request")
+
+
def parse_route(response, option, i, length, bytes_recv):
    """
    Parse the classless static route option (249) located at offset i
    with payload length 'length'.

    Each entry is: 1 byte prefix length in bits, then only as many bytes
    as needed to hold the prefix, then a 4-byte gateway.
    Returns a list of (net, mask, gateway) tuples as 32-bit integers.
    """
    # Option format reference:
    # http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
    logger.verbose("Routes at offset: {0} with length:{1}", hex(i),
                   hex(length))
    routes = []
    if length < 5:
        logger.error("Data too small for option:{0}", option)
    j = i + 2
    while j < (i + length + 2):
        # prefix length in bits; round up to whole bytes for the stored net
        mask_len_bits = str_to_ord(response[j])
        mask_len_bytes = (((mask_len_bits + 7) & ~7) >> 3)
        mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - mask_len_bits))
        j += 1
        net = unpack_big_endian(response, j, mask_len_bytes)
        # left-align the truncated prefix into a full 32-bit address
        net <<= (32 - mask_len_bytes * 8)
        net &= mask
        j += mask_len_bytes
        # 4-byte gateway for this route
        gateway = unpack_big_endian(response, j, 4)
        j += 4
        routes.append((net, mask, gateway))
    # j should land exactly at the end of the option payload
    if j != (i + length + 2):
        logger.error("Unable to parse routes")
    return routes
+
+
def parse_ip_addr(response, option, i, length, bytes_recv):
    """
    Extract the 4-byte IP address payload of the DHCP option at offset i.

    Returns the dotted-quad string, or None when the payload is missing
    or not exactly 4 bytes (an error is logged in either case).
    """
    # Guard: option code + length byte + 4 address bytes must fit
    if i + 5 >= bytes_recv:
        logger.error("Data too small for option:{0}", option)
        return None
    if length != 4:
        logger.error("Endpoint or Default Gateway not 4 bytes")
        return None
    return int_to_ip4_addr(unpack_big_endian(response, i + 2, 4))
+
+
def parse_dhcp_resp(response):
    """
    Parse DHCP response.

    Returns an (endpoint, gateway, routes) tuple; each element is None
    when the corresponding option was absent from the response.
    """
    logger.verbose("parse Dhcp Response")
    bytes_recv = len(response)
    endpoint = None
    gateway = None
    routes = None

    # Walk all the returned options, parsing out what we need, ignoring the
    # others. We need the custom option 245 to find the endpoint we talk to
    # as well as to handle some Linux DHCP client incompatibilities;
    # options 3 for default gateway and 249 for routes; 255 is end.

    i = 0xF0  # offset to first option
    while i < bytes_recv:
        option = str_to_ord(response[i])
        length = 0
        if (i + 1) < bytes_recv:
            length = str_to_ord(response[i + 1])
        logger.verbose("DHCP option {0} at offset:{1} with length:{2}",
                       hex(option), hex(i), hex(length))
        if option == 255:
            logger.verbose("DHCP packet ended at offset:{0}", hex(i))
            break
        elif option == 249:
            # Microsoft classless static routes option
            routes = parse_route(response, option, i, length, bytes_recv)
        elif option == 3:
            gateway = parse_ip_addr(response, option, i, length, bytes_recv)
            logger.verbose("Default gateway:{0}, at {1}", gateway, hex(i))
        elif option == 245:
            # Azure-specific option carrying the wire server endpoint
            endpoint = parse_ip_addr(response, option, i, length, bytes_recv)
            logger.verbose("Azure wire protocol endpoint:{0}, at {1}",
                           endpoint,
                           hex(i))
        else:
            logger.verbose("Skipping DHCP option:{0} at {1} with length {2}",
                           hex(option), hex(i), hex(length))
        # advance past option code byte, length byte and payload
        i += length + 2
    return endpoint, gateway, routes
+
+
def socket_send(request):
    """
    Broadcast 'request' from UDP port 68 to port 67 and return the raw
    DHCP response bytes (10 second receive timeout).

    Raises DhcpError when any socket operation fails.
    """
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
                             socket.IPPROTO_UDP)
        for option in (socket.SO_BROADCAST, socket.SO_REUSEADDR):
            sock.setsockopt(socket.SOL_SOCKET, option, 1)
        sock.bind(("0.0.0.0", 68))
        sock.sendto(request, ("<broadcast>", 67))
        sock.settimeout(10)
        logger.verbose("Send DHCP request: Setting socket.timeout=10, "
                       "entering recv")
        return sock.recv(1024)
    except IOError as e:
        raise DhcpError("{0}".format(e))
    finally:
        # Always release the socket, whether recv succeeded or not
        if sock is not None:
            sock.close()
+
+
def build_dhcp_request(mac_addr, request_broadcast):
    """
    Build a DHCP DISCOVER request for the given MAC address.

    When request_broadcast is True, the broadcast flag is set so the
    server replies to the broadcast address (useful when unicast fails).
    Returns an array.array of unsigned bytes.
    """
    #
    # typedef struct _DHCP {
    #     UINT8 Opcode;                    /* op:    BOOTREQUEST or BOOTREPLY */
    #     UINT8 HardwareAddressType;       /* htype: ethernet */
    #     UINT8 HardwareAddressLength;     /* hlen:  6 (48 bit mac address) */
    #     UINT8 Hops;                      /* hops:  0 */
    #     UINT8 TransactionID[4];          /* xid:   random */
    #     UINT8 Seconds[2];                /* secs:  0 */
    #     UINT8 Flags[2];                  /* flags: 0 or 0x8000 for broadcast*/
    #     UINT8 ClientIpAddress[4];        /* ciaddr: 0 */
    #     UINT8 YourIpAddress[4];          /* yiaddr: 0 */
    #     UINT8 ServerIpAddress[4];        /* siaddr: 0 */
    #     UINT8 RelayAgentIpAddress[4];    /* giaddr: 0 */
    #     UINT8 ClientHardwareAddress[16]; /* chaddr: 6 byte eth MAC address */
    #     UINT8 ServerName[64];            /* sname:  0 */
    #     UINT8 BootFileName[128];         /* file:   0 */
    #     UINT8 MagicCookie[4];            /*   99  130   83   99 */
    #                                      /* 0x63 0x82 0x53 0x63 */
    #     /* options -- hard code ours */
    #
    #     UINT8 MessageTypeCode;           /* 53 */
    #     UINT8 MessageTypeLength;         /* 1 */
    #     UINT8 MessageType;               /* 1 for DISCOVER */
    #     UINT8 End;                       /* 255 */
    # } DHCP;
    #

    # list of 244 zeros
    # (struct.pack_into would be good here, but requires Python 2.5)
    request = [0] * 244

    trans_id = gen_trans_id()

    # Opcode = 1 (BOOTREQUEST)
    # HardwareAddressType = 1 (ethernet/MAC)
    # HardwareAddressLength = 6 (ethernet/MAC/48 bits)
    request[0:3] = [1, 1, 6]

    # fill in transaction id (random number to ensure response matches request)
    request[4:8] = [str_to_ord(trans_id[a]) for a in range(0, 4)]

    logger.verbose("BuildDhcpRequest: transactionId:%s,%04X" % (
        hex_dump2(trans_id),
        unpack_big_endian(request, 4, 4)))

    if request_broadcast:
        # set broadcast flag so the dhcp server responds to the
        # broadcast address; useful when the unicast attempt fails
        request[0x0A] = 0x80

    # fill in ClientHardwareAddress
    request[0x1C:0x22] = [str_to_ord(mac_addr[a]) for a in range(0, 6)]

    # DHCP Magic Cookie: 99, 130, 83, 99
    # MessageTypeCode = 53 DHCP Message Type
    # MessageTypeLength = 1
    # MessageType = DHCPDISCOVER
    # End = 255 DHCP_END
    request[0xEC:0xF4] = [99, 130, 83, 99, 53, 1, 1, 255]
    return array.array("B", request)
+
+
def gen_trans_id():
    """Return 4 cryptographically-random bytes used as a DHCP transaction id."""
    return os.urandom(4)
diff --git a/azurelinuxagent/common/event.py b/azurelinuxagent/common/event.py
new file mode 100644
index 0000000..374b0e7
--- /dev/null
+++ b/azurelinuxagent/common/event.py
@@ -0,0 +1,124 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import sys
+import traceback
+import atexit
+import json
+import time
+import datetime
+import threading
+import platform
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import EventError, ProtocolError
+from azurelinuxagent.common.future import ustr
+from azurelinuxagent.common.protocol.restapi import TelemetryEventParam, \
+ TelemetryEventList, \
+ TelemetryEvent, \
+ set_properties, get_properties
+from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \
+ DISTRO_CODE_NAME, AGENT_VERSION
+
+
class WALAEventOperation:
    """
    Operation names reported in telemetry events (alphabetical order).
    """
    ActivateResourceDisk = "ActivateResourceDisk"
    Disable = "Disable"
    Download = "Download"
    Enable = "Enable"
    HeartBeat = "HeartBeat"
    Install = "Install"
    Provision = "Provision"
    UnhandledError = "UnhandledError"
    UnInstall = "UnInstall"
    Update = "Update"
    Upgrade = "Upgrade"
+
class EventLogger(object):
    """
    Persists telemetry events to disk as JSON files so a separate
    consumer can later upload them.
    """
    def __init__(self):
        # Directory where event files are written; must be set via
        # init_event_logger() before events can be saved.
        self.event_dir = None

    def save_event(self, data):
        """
        Write one serialized event to a uniquely-named ".tld" file in
        event_dir, creating the directory (mode 0700) if needed.

        Raises EventError when the directory already holds more than
        1000 files or the file cannot be written.
        """
        if self.event_dir is None:
            logger.warn("Event reporter is not initialized.")
            return

        if not os.path.exists(self.event_dir):
            os.mkdir(self.event_dir)
            os.chmod(self.event_dir, 0o700)
        # Guard against unbounded growth if the consumer stops draining
        if len(os.listdir(self.event_dir)) > 1000:
            raise EventError("Too many files under: {0}".format(self.event_dir))

        # Microsecond timestamp gives a unique, sortable name; write to a
        # .tmp file and rename so readers never see a partial event.
        filename = os.path.join(self.event_dir, ustr(int(time.time()*1000000)))
        try:
            with open(filename+".tmp", 'wb+') as hfile:
                hfile.write(data.encode("utf-8"))
            os.rename(filename+".tmp", filename+".tld")
        except IOError as e:
            # BUG FIX: was EventError("...file:{0}", e), which passed e as
            # the 'inner' argument and left the "{0}" placeholder
            # unformatted in the message.
            raise EventError("Failed to write events to file:{0}".format(e))

    def add_event(self, name, op="", is_success=True, duration=0,
                  version=AGENT_VERSION, message="", evt_type="",
                  is_internal=False):
        """
        Build a telemetry event for provider
        69B669B9-4AF8-4C50-BDC4-6006FA76E975 and persist it via
        save_event(). Persistence errors are logged, never raised.
        """
        event = TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975")
        event.parameters.append(TelemetryEventParam('Name', name))
        event.parameters.append(TelemetryEventParam('Version', str(version)))
        event.parameters.append(TelemetryEventParam('IsInternal', is_internal))
        event.parameters.append(TelemetryEventParam('Operation', op))
        event.parameters.append(TelemetryEventParam('OperationSuccess',
                                                    is_success))
        event.parameters.append(TelemetryEventParam('Message', message))
        event.parameters.append(TelemetryEventParam('Duration', duration))
        event.parameters.append(TelemetryEventParam('ExtensionType', evt_type))

        data = get_properties(event)
        try:
            self.save_event(json.dumps(data))
        except EventError as e:
            logger.error("{0}", e)
+
# Default, module-level event reporter used when callers don't supply one.
__event_logger__ = EventLogger()

def add_event(name, op="", is_success=True, duration=0, version=AGENT_VERSION,
              message="", evt_type="", is_internal=False,
              reporter=__event_logger__):
    """
    Log the event to the agent log and, when the reporter has been
    initialized, persist it for later upload.
    """
    # Mirror the event into the log at a severity matching its outcome
    if is_success:
        logger.info("Event: name={0}, op={1}, message={2}", name, op, message)
    else:
        logger.error("Event: name={0}, op={1}, message={2}", name, op, message)

    if reporter.event_dir is None:
        logger.warn("Event reporter is not initialized.")
        return
    reporter.add_event(name, op=op, is_success=is_success, duration=duration,
                       version=str(version), message=message,
                       evt_type=evt_type, is_internal=is_internal)
+
def init_event_logger(event_dir, reporter=__event_logger__):
    # Point the (module-level by default) event reporter at the directory
    # where event files are written; until this is called, add_event()
    # only warns and drops events.
    reporter.event_dir = event_dir
+
def dump_unhandled_err(name):
    """
    Report the interpreter's last uncaught exception (sys.last_type /
    last_value / last_traceback) as an UnhandledError telemetry event.
    No-op when no uncaught exception has been recorded.
    """
    if not (hasattr(sys, 'last_type') and
            hasattr(sys, 'last_value') and
            hasattr(sys, 'last_traceback')):
        return
    # Attributes are guaranteed present here, so access them directly
    error = traceback.format_exception(sys.last_type,
                                       sys.last_value,
                                       sys.last_traceback)
    add_event(name, is_success=False, message="".join(error),
              op=WALAEventOperation.UnhandledError)
+
def enable_unhandled_err_dump(name):
    # Register an atexit hook so any uncaught exception recorded by the
    # interpreter is reported as telemetry when the process exits.
    atexit.register(dump_unhandled_err, name)
diff --git a/azurelinuxagent/common/exception.py b/azurelinuxagent/common/exception.py
new file mode 100644
index 0000000..457490c
--- /dev/null
+++ b/azurelinuxagent/common/exception.py
@@ -0,0 +1,123 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+"""
+Defines all exceptions
+"""
+
class AgentError(Exception):
    """
    Base class of agent error.

    The message is prefixed with a six-digit error number; when an inner
    error is supplied, its text is appended on a new line.
    """
    def __init__(self, errno, msg, inner=None):
        detail = u"({0}){1}".format(errno, msg)
        if inner is not None:
            detail = u"{0} \n inner error: {1}".format(detail, inner)
        super(AgentError, self).__init__(detail)
+
class AgentConfigError(AgentError):
    """
    When configure file is not found or malformed. Error number 000001.
    """
    def __init__(self, msg=None, inner=None):
        super(AgentConfigError, self).__init__('000001', msg, inner)

class AgentNetworkError(AgentError):
    """
    When network is not available. Error number 000002.
    """
    def __init__(self, msg=None, inner=None):
        super(AgentNetworkError, self).__init__('000002', msg, inner)

class ExtensionError(AgentError):
    """
    When failed to execute an extension. Error number 000003.
    """
    def __init__(self, msg=None, inner=None):
        super(ExtensionError, self).__init__('000003', msg, inner)

class ProvisionError(AgentError):
    """
    When provision failed. Error number 000004.
    """
    def __init__(self, msg=None, inner=None):
        super(ProvisionError, self).__init__('000004', msg, inner)

class ResourceDiskError(AgentError):
    """
    Mount resource disk failed. Error number 000005.
    """
    def __init__(self, msg=None, inner=None):
        super(ResourceDiskError, self).__init__('000005', msg, inner)

class DhcpError(AgentError):
    """
    Failed to handle dhcp response. Error number 000006.
    """
    def __init__(self, msg=None, inner=None):
        super(DhcpError, self).__init__('000006', msg, inner)

class OSUtilError(AgentError):
    """
    Failed to perform operation to OS configuration. Error number 000007.
    """
    def __init__(self, msg=None, inner=None):
        super(OSUtilError, self).__init__('000007', msg, inner)

class ProtocolError(AgentError):
    """
    Azure protocol error. Error number 000008.
    """
    def __init__(self, msg=None, inner=None):
        super(ProtocolError, self).__init__('000008', msg, inner)

class ProtocolNotFoundError(ProtocolError):
    """
    Azure protocol endpoint not found.

    Note: deliberately reuses ProtocolError's error number (000008)
    rather than defining its own.
    """
    def __init__(self, msg=None, inner=None):
        super(ProtocolNotFoundError, self).__init__(msg, inner)

class HttpError(AgentError):
    """
    Http request failure. Error number 000009.
    """
    def __init__(self, msg=None, inner=None):
        super(HttpError, self).__init__('000009', msg, inner)

class EventError(AgentError):
    """
    Event reporting error. Error number 000010.
    """
    def __init__(self, msg=None, inner=None):
        super(EventError, self).__init__('000010', msg, inner)

class CryptError(AgentError):
    """
    Encrypt/Decrypt error. Error number 000011.
    """
    def __init__(self, msg=None, inner=None):
        super(CryptError, self).__init__('000011', msg, inner)

class UpdateError(AgentError):
    """
    Update Guest Agent error. Error number 000012.
    """
    def __init__(self, msg=None, inner=None):
        super(UpdateError, self).__init__('000012', msg, inner)
+
diff --git a/azurelinuxagent/common/future.py b/azurelinuxagent/common/future.py
new file mode 100644
index 0000000..8509732
--- /dev/null
+++ b/azurelinuxagent/common/future.py
@@ -0,0 +1,31 @@
import sys

# Aliases so the rest of the agent can use one name for APIs that differ
# between Python 2 and Python 3.

if sys.version_info[0] == 3:
    import http.client as httpclient
    from urllib.parse import urlparse

    # Python 3 text type
    ustr = str
    bytebuffer = memoryview
    read_input = input

elif sys.version_info[0] == 2:
    import httplib as httpclient
    from urlparse import urlparse

    # Python 2 text type
    ustr = unicode
    bytebuffer = buffer
    read_input = raw_input

else:
    raise ImportError("Unknown python version:{0}".format(sys.version_info))
+
diff --git a/azurelinuxagent/common/logger.py b/azurelinuxagent/common/logger.py
new file mode 100644
index 0000000..c1eb18f
--- /dev/null
+++ b/azurelinuxagent/common/logger.py
@@ -0,0 +1,156 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and openssl_bin 1.0+
+#
+"""
+Log utils
+"""
+import os
+import sys
+from azurelinuxagent.common.future import ustr
+from datetime import datetime
+
class Logger(object):
    """
    Logger holding a list of appenders; each record is fanned out to
    every appender, which applies its own level threshold.
    """
    def __init__(self, logger=None, prefix=None):
        # When a parent logger is given, share its appenders so child
        # loggers write to the same destinations.
        self.appenders = []
        if logger is not None:
            self.appenders.extend(logger.appenders)
        # Optional string inserted between the level and the message.
        self.prefix = prefix

    def verbose(self, msg_format, *args):
        self.log(LogLevel.VERBOSE, msg_format, *args)

    def info(self, msg_format, *args):
        self.log(LogLevel.INFO, msg_format, *args)

    def warn(self, msg_format, *args):
        self.log(LogLevel.WARNING, msg_format, *args)

    def error(self, msg_format, *args):
        self.log(LogLevel.ERROR, msg_format, *args)

    def log(self, level, msg_format, *args):
        """
        Format msg_format with *args (str.format style), prepend a
        timestamp/level/prefix header, and emit to every appender.
        """
        #if msg_format is not unicode convert it to unicode
        # NOTE(review): with ustr == str (Python 3) this conversion path
        # expects a bytes-like msg_format; a non-string object would
        # raise here -- confirm callers only pass str/bytes.
        if type(msg_format) is not ustr:
            msg_format = ustr(msg_format, errors="backslashreplace")
        if len(args) > 0:
            msg = msg_format.format(*args)
        else:
            msg = msg_format
        time = datetime.now().strftime(u'%Y/%m/%d %H:%M:%S.%f')
        level_str = LogLevel.STRINGS[level]
        if self.prefix is not None:
            log_item = u"{0} {1} {2} {3}\n".format(time, level_str, self.prefix,
                                                   msg)
        else:
            log_item = u"{0} {1} {2}\n".format(time, level_str, msg)

        # Round-trip through ASCII so any non-ASCII character becomes a
        # backslash escape and can never break an appender's write().
        log_item = ustr(log_item.encode('ascii', "backslashreplace"),
                        encoding="ascii")
        for appender in self.appenders:
            appender.write(level, log_item)

    def add_appender(self, appender_type, level, path):
        # See AppenderType for valid appender_type values.
        appender = _create_logger_appender(appender_type, level, path)
        self.appenders.append(appender)
+
class ConsoleAppender(object):
    """Appender that writes each record to a console device path."""

    def __init__(self, level, path):
        # Minimum level this appender accepts.
        self.level = level
        # Console device to open per record (e.g. /dev/console).
        self.path = path

    def write(self, level, msg):
        """Write msg to the console when level meets the threshold."""
        if level < self.level:
            return
        try:
            with open(self.path, "w") as device:
                device.write(msg)
        except IOError:
            # Best effort: a missing/unwritable console is ignored.
            pass
+
class FileAppender(object):
    """Appender that appends each record to a log file."""

    def __init__(self, level, path):
        # Minimum level this appender accepts.
        self.level = level
        # Log file path, opened in append mode per record.
        self.path = path

    def write(self, level, msg):
        """Append msg to the log file when level meets the threshold."""
        if level < self.level:
            return
        try:
            with open(self.path, "a+") as destination:
                destination.write(msg)
        except IOError:
            # Best effort: an unwritable log file is ignored.
            pass
+
class StdoutAppender(object):
    """Appender that writes each record to the current sys.stdout."""

    def __init__(self, level):
        # Minimum level this appender accepts.
        self.level = level

    def write(self, level, msg):
        """Write msg to stdout when level meets the threshold."""
        if level < self.level:
            return
        try:
            sys.stdout.write(msg)
        except IOError:
            # Best effort: a closed/broken stdout is ignored.
            pass
+
# Module-wide singleton; the module-level helpers below
# (add_logger_appender, verbose, info, warn, error, log) delegate to it.
DEFAULT_LOGGER = Logger()
+
class LogLevel(object):
    """Numeric log levels; higher value means higher severity."""
    VERBOSE = 0
    INFO = 1
    WARNING = 2
    ERROR = 3
    # Display names, indexed by the level values above.
    STRINGS = [
        "VERBOSE",
        "INFO",
        "WARNING",
        "ERROR"
    ]
+
class AppenderType(object):
    """Appender kinds accepted by add_logger_appender()."""
    FILE = 0
    CONSOLE = 1
    STDOUT = 2
+
def add_logger_appender(appender_type, level=LogLevel.INFO, path=None):
    """Attach a new appender to the module-wide DEFAULT_LOGGER."""
    DEFAULT_LOGGER.add_appender(appender_type, level, path)

def verbose(msg_format, *args):
    """Log at VERBOSE level via DEFAULT_LOGGER."""
    DEFAULT_LOGGER.verbose(msg_format, *args)

def info(msg_format, *args):
    """Log at INFO level via DEFAULT_LOGGER."""
    DEFAULT_LOGGER.info(msg_format, *args)

def warn(msg_format, *args):
    """Log at WARNING level via DEFAULT_LOGGER."""
    DEFAULT_LOGGER.warn(msg_format, *args)

def error(msg_format, *args):
    """Log at ERROR level via DEFAULT_LOGGER."""
    DEFAULT_LOGGER.error(msg_format, *args)
+
def log(level, msg_format, *args):
    """Log msg_format (lazily formatted with *args) at level via DEFAULT_LOGGER."""
    # Forward the varargs unpacked.  The original passed the `args`
    # tuple as a single positional argument, which made Logger.log()
    # call msg_format.format(<tuple>) and produce wrong output for any
    # message with more than one placeholder.
    DEFAULT_LOGGER.log(level, msg_format, *args)
+
def _create_logger_appender(appender_type, level=LogLevel.INFO, path=None):
    """
    Instantiate the appender class matching appender_type.

    Raises ValueError for an unrecognized appender_type.
    """
    if appender_type == AppenderType.FILE:
        return FileAppender(level, path)
    if appender_type == AppenderType.CONSOLE:
        return ConsoleAppender(level, path)
    if appender_type == AppenderType.STDOUT:
        # Stdout needs no path.
        return StdoutAppender(level)
    raise ValueError("Unknown appender type")
+
diff --git a/azurelinuxagent/common/osutil/__init__.py b/azurelinuxagent/common/osutil/__init__.py
new file mode 100644
index 0000000..3b5ba3b
--- /dev/null
+++ b/azurelinuxagent/common/osutil/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+from azurelinuxagent.common.osutil.factory import get_osutil
diff --git a/azurelinuxagent/common/osutil/coreos.py b/azurelinuxagent/common/osutil/coreos.py
new file mode 100644
index 0000000..e26fd97
--- /dev/null
+++ b/azurelinuxagent/common/osutil/coreos.py
@@ -0,0 +1,92 @@
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import re
+import pwd
+import shutil
+import socket
+import array
+import struct
+import fcntl
+import time
+import base64
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+from azurelinuxagent.common.osutil.default import DefaultOSUtil
+
class CoreOSUtil(DefaultOSUtil):
    """
    CoreOS specializations: the agent lives on the read-only OEM
    partition and networking is managed by systemd-networkd.
    """
    def __init__(self):
        super(CoreOSUtil, self).__init__()
        self.agent_conf_file_path = '/usr/share/oem/waagent.conf'
        self.waagent_path='/usr/share/oem/bin/waagent'
        self.python_path='/usr/share/oem/python/bin'
        # Expose the OEM-partition python and agent modules to child
        # processes via PATH / PYTHONPATH.
        if 'PATH' in os.environ:
            path = "{0}:{1}".format(os.environ['PATH'], self.python_path)
        else:
            path = self.python_path
        os.environ['PATH'] = path

        if 'PYTHONPATH' in os.environ:
            py_path = os.environ['PYTHONPATH']
            py_path = "{0}:{1}".format(py_path, self.waagent_path)
        else:
            py_path = self.waagent_path
        os.environ['PYTHONPATH'] = py_path

    def is_sys_user(self, username):
        #User 'core' is not a sysuser
        if username == 'core':
            return False
        return super(CoreOSUtil, self).is_sys_user(username)

    def is_dhcp_enabled(self):
        # systemd-networkd always handles DHCP on CoreOS.
        return True

    def start_network(self) :
        return shellutil.run("systemctl start systemd-networkd", chk_err=False)

    def restart_if(self, iface):
        # Restarting networkd re-configures all interfaces; the iface
        # argument is intentionally unused.
        shellutil.run("systemctl restart systemd-networkd")

    def restart_ssh_service(self):
        # SSH is socket activated on CoreOS. No need to restart it.
        pass

    def stop_dhcp_service(self):
        return shellutil.run("systemctl stop systemd-networkd", chk_err=False)

    def start_dhcp_service(self):
        return shellutil.run("systemctl start systemd-networkd", chk_err=False)

    def start_agent_service(self):
        # NOTE(review): unit name 'wagent' looks like a typo for
        # 'waagent' -- confirm against the systemd unit shipped on CoreOS.
        return shellutil.run("systemctl start wagent", chk_err=False)

    def stop_agent_service(self):
        # NOTE(review): see start_agent_service about the 'wagent' name.
        return shellutil.run("systemctl stop wagent", chk_err=False)

    def get_dhcp_pid(self):
        # DHCP is served by networkd, so return its pid instead of dhclient's.
        ret= shellutil.run_get_output("pidof systemd-networkd")
        return ret[1] if ret[0] == 0 else None

    def conf_sshd(self, disable_password):
        #In CoreOS, /etc/sshd_config is mount readonly. Skip the setting
        pass
+
diff --git a/azurelinuxagent/common/osutil/debian.py b/azurelinuxagent/common/osutil/debian.py
new file mode 100644
index 0000000..f455572
--- /dev/null
+++ b/azurelinuxagent/common/osutil/debian.py
@@ -0,0 +1,47 @@
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import re
+import pwd
+import shutil
+import socket
+import array
+import struct
+import fcntl
+import time
+import base64
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+from azurelinuxagent.common.osutil.default import DefaultOSUtil
+
class DebianOSUtil(DefaultOSUtil):
    """
    Debian specializations: service management through the `service`
    wrapper instead of the default no-op implementations.
    """
    def __init__(self):
        super(DebianOSUtil, self).__init__()

    def restart_ssh_service(self):
        # NOTE(review): on Debian the ssh service is usually named
        # 'ssh', not 'sshd' -- confirm this unit name.
        return shellutil.run("service sshd restart", chk_err=False)

    def stop_agent_service(self):
        return shellutil.run("service azurelinuxagent stop", chk_err=False)

    def start_agent_service(self):
        return shellutil.run("service azurelinuxagent start", chk_err=False)
+
diff --git a/azurelinuxagent/common/osutil/default.py b/azurelinuxagent/common/osutil/default.py
new file mode 100644
index 0000000..c243c85
--- /dev/null
+++ b/azurelinuxagent/common/osutil/default.py
@@ -0,0 +1,792 @@
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import re
+import shutil
+import socket
+import array
+import struct
+import time
+import pwd
+import fcntl
+import base64
+import glob
+import datetime
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.conf as conf
+from azurelinuxagent.common.exception import OSUtilError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
+
# udev rules files that generate/persist NIC names; moved out of the way
# by remove_rules_files() (e.g. before capturing an image) and put back
# by restore_rules_files().
__RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules",
                    "/etc/udev/rules.d/70-persistent-net.rules" ]
+
+"""
+Define distro specific behavior. OSUtil class defines default behavior
+for all distros. Each concrete distro classes could overwrite default behavior
+if needed.
+"""
+
class DefaultOSUtil(object):
    """
    Default distro behavior; concrete distro classes override
    individual operations as needed.
    """

    def __init__(self):
        # Subclasses may point this at a distro-specific location.
        self.agent_conf_file_path = '/etc/waagent.conf'
        # Cached result of is_selinux_system(); None means "not probed yet".
        self.selinux=None

    def get_agent_conf_file_path(self):
        """Return the path of the agent configuration file."""
        return self.agent_conf_file_path
+
+ def get_userentry(self, username):
+ try:
+ return pwd.getpwnam(username)
+ except KeyError:
+ return None
+
+ def is_sys_user(self, username):
+ """
+ Check whether use is a system user.
+ If reset sys user is allowed in conf, return False
+ Otherwise, check whether UID is less than UID_MIN
+ """
+ if conf.get_allow_reset_sys_user():
+ return False
+
+ userentry = self.get_userentry(username)
+ uidmin = None
+ try:
+ uidmin_def = fileutil.get_line_startingwith("UID_MIN",
+ "/etc/login.defs")
+ if uidmin_def is not None:
+ uidmin = int(uidmin_def.split()[1])
+ except IOError as e:
+ pass
+ if uidmin == None:
+ uidmin = 100
+ if userentry != None and userentry[2] < uidmin:
+ return True
+ else:
+ return False
+
    def useradd(self, username, expiration=None):
        """
        Create user account 'username' (useradd -m), optionally with an
        account expiration date (-e).  No-op when it already exists.

        Raises OSUtilError when useradd fails.
        """
        userentry = self.get_userentry(username)
        if userentry is not None:
            logger.info("User {0} already exists, skip useradd", username)
            return

        if expiration is not None:
            cmd = "useradd -m {0} -e {1}".format(username, expiration)
        else:
            cmd = "useradd -m {0}".format(username)
        retcode, out = shellutil.run_get_output(cmd)
        if retcode != 0:
            raise OSUtilError(("Failed to create user account:{0}, "
                               "retcode:{1}, "
                               "output:{2}").format(username, retcode, out))

    def chpasswd(self, username, password, crypt_id=6, salt_len=10):
        """
        Set the password of 'username' to the salted hash of password.

        Refuses to change system accounts.  Raises OSUtilError on
        failure.
        """
        if self.is_sys_user(username):
            raise OSUtilError(("User {0} is a system user. "
                               "Will not set passwd.").format(username))
        passwd_hash = textutil.gen_password_hash(password, crypt_id, salt_len)
        cmd = "usermod -p '{0}' {1}".format(passwd_hash, username)
        # log_cmd=False keeps the password hash out of the log.
        ret, output = shellutil.run_get_output(cmd, log_cmd=False)
        if ret != 0:
            raise OSUtilError(("Failed to set password for {0}: {1}"
                               "").format(username, output))

    def conf_sudoer(self, username, nopasswd=False, remove=False):
        """
        Add username to (or, with remove=True, strip it from) the
        agent-managed sudoers file <sudoers_dir>/waagent.
        """
        sudoers_dir = conf.get_sudoers_dir()
        sudoers_wagent = os.path.join(sudoers_dir, 'waagent')

        if not remove:
            # for older distros create sudoers.d
            if not os.path.isdir(sudoers_dir):
                sudoers_file = os.path.join(sudoers_dir, '../sudoers')
                # create the sudoers.d directory
                os.mkdir(sudoers_dir)
                # add the include of sudoers.d to the /etc/sudoers
                sudoers = '\n#includedir ' + sudoers_dir + '\n'
                fileutil.append_file(sudoers_file, sudoers)
            sudoer = None
            if nopasswd:
                sudoer = "{0} ALL=(ALL) NOPASSWD: ALL\n".format(username)
            else:
                sudoer = "{0} ALL=(ALL) ALL\n".format(username)
            fileutil.append_file(sudoers_wagent, sudoer)
            # 0440 is the mode sudo expects for sudoers.d fragments.
            fileutil.chmod(sudoers_wagent, 0o440)
        else:
            #Remove user from sudoers
            if os.path.isfile(sudoers_wagent):
                try:
                    content = fileutil.read_file(sudoers_wagent)
                    sudoers = content.split("\n")
                    # Drop every line that mentions the user.
                    sudoers = [x for x in sudoers if username not in x]
                    fileutil.write_file(sudoers_wagent, "\n".join(sudoers))
                except IOError as e:
                    raise OSUtilError("Failed to remove sudoer: {0}".format(e))

    def del_root_password(self):
        """
        Lock the root account by replacing its passwd entry with a
        locked placeholder.  Raises OSUtilError on I/O failure.
        """
        try:
            passwd_file_path = conf.get_passwd_file_path()
            passwd_content = fileutil.read_file(passwd_file_path)
            passwd = passwd_content.split('\n')
            new_passwd = [x for x in passwd if not x.startswith("root:")]
            new_passwd.insert(0, "root:*LOCK*:14600::::::")
            fileutil.write_file(passwd_file_path, "\n".join(new_passwd))
        except IOError as e:
            raise OSUtilError("Failed to delete root password:{0}".format(e))
+
+ def _norm_path(self, filepath):
+ home = conf.get_home_dir()
+ # Expand HOME variable if present in path
+ path = os.path.normpath(filepath.replace("$HOME", home))
+ return path
+
+ def deploy_ssh_keypair(self, username, keypair):
+ """
+ Deploy id_rsa and id_rsa.pub
+ """
+ path, thumbprint = keypair
+ path = self._norm_path(path)
+ dir_path = os.path.dirname(path)
+ fileutil.mkdir(dir_path, mode=0o700, owner=username)
+ lib_dir = conf.get_lib_dir()
+ prv_path = os.path.join(lib_dir, thumbprint + '.prv')
+ if not os.path.isfile(prv_path):
+ raise OSUtilError("Can't find {0}.prv".format(thumbprint))
+ shutil.copyfile(prv_path, path)
+ pub_path = path + '.pub'
+ crytputil = CryptUtil(conf.get_openssl_cmd())
+ pub = crytputil.get_pubkey_from_prv(prv_path)
+ fileutil.write_file(pub_path, pub)
+ self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
+ self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
+ os.chmod(path, 0o644)
+ os.chmod(pub_path, 0o600)
+
    def openssl_to_openssh(self, input_file, output_file):
        """Convert an X509 certificate to an OpenSSH public key file."""
        cryptutil = CryptUtil(conf.get_openssl_cmd())
        cryptutil.crt_to_ssh(input_file, output_file)

    def deploy_ssh_pubkey(self, username, pubkey):
        """
        Deploy an authorized_keys entry for username.

        pubkey is a (path, thumbprint, value) triple: when value is
        given it is written verbatim; otherwise the key is derived from
        the <lib_dir>/<thumbprint>.crt certificate.

        Raises OSUtilError on a missing path, a malformed value, or a
        missing certificate.
        """
        path, thumbprint, value = pubkey
        if path is None:
            raise OSUtilError("Public key path is None")

        crytputil = CryptUtil(conf.get_openssl_cmd())

        path = self._norm_path(path)
        dir_path = os.path.dirname(path)
        fileutil.mkdir(dir_path, mode=0o700, owner=username)
        if value is not None:
            # Sanity check: OpenSSH public keys start with "ssh-".
            if not value.startswith("ssh-"):
                raise OSUtilError("Bad public key: {0}".format(value))
            fileutil.write_file(path, value)
        elif thumbprint is not None:
            lib_dir = conf.get_lib_dir()
            crt_path = os.path.join(lib_dir, thumbprint + '.crt')
            if not os.path.isfile(crt_path):
                raise OSUtilError("Can't find {0}.crt".format(thumbprint))
            pub_path = os.path.join(lib_dir, thumbprint + '.pub')
            pub = crytputil.get_pubkey_from_crt(crt_path)
            fileutil.write_file(pub_path, pub)
            self.set_selinux_context(pub_path,
                                     'unconfined_u:object_r:ssh_home_t:s0')
            self.openssl_to_openssh(pub_path, path)
            fileutil.chmod(pub_path, 0o600)
        else:
            raise OSUtilError("SSH public key Fingerprint and Value are None")

        self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
        fileutil.chowner(path, username)
        fileutil.chmod(path, 0o644)
+
+ def is_selinux_system(self):
+ """
+ Checks and sets self.selinux = True if SELinux is available on system.
+ """
+ if self.selinux == None:
+ if shellutil.run("which getenforce", chk_err=False) == 0:
+ self.selinux = True
+ else:
+ self.selinux = False
+ return self.selinux
+
    def is_selinux_enforcing(self):
        """
        Calls shell command 'getenforce' and returns True if 'Enforcing'.
        Returns False when SELinux is not available at all.
        """
        if self.is_selinux_system():
            output = shellutil.run_get_output("getenforce")[1]
            return output.startswith("Enforcing")
        else:
            return False

    def set_selinux_enforce(self, state):
        """
        Calls shell command 'setenforce' with 'state'
        and returns resulting exit code.
        No-op (returns None) when SELinux is not available.
        """
        if self.is_selinux_system():
            if state: s = '1'
            else: s='0'
            return shellutil.run("setenforce "+s)

    def set_selinux_context(self, path, con):
        """
        Calls shell 'chcon' with 'path' and 'con' context.
        Returns exit result (1 when path does not exist).
        """
        if self.is_selinux_system():
            if not os.path.exists(path):
                logger.error("Path does not exist: {0}".format(path))
                return 1
            return shellutil.run('chcon ' + con + ' ' + path)

    def conf_sshd(self, disable_password):
        """
        Configure sshd: enable/disable password-based authentication
        per disable_password, and set a keep-alive probe interval.
        """
        option = "no" if disable_password else "yes"
        conf_file_path = conf.get_sshd_conf_file_path()
        conf_file = fileutil.read_file(conf_file_path).split("\n")
        textutil.set_ssh_config(conf_file, "PasswordAuthentication", option)
        textutil.set_ssh_config(conf_file, "ChallengeResponseAuthentication", option)
        textutil.set_ssh_config(conf_file, "ClientAliveInterval", "180")
        fileutil.write_file(conf_file_path, "\n".join(conf_file))
        logger.info("{0} SSH password-based authentication methods."
                    .format("Disabled" if disable_password else "Enabled"))
        logger.info("Configured SSH client probing to keep connections alive.")
+
+
+ def get_dvd_device(self, dev_dir='/dev'):
+ pattern=r'(sr[0-9]|hd[c-z]|cdrom[0-9]|cd[0-9])'
+ for dvd in [re.match(pattern, dev) for dev in os.listdir(dev_dir)]:
+ if dvd is not None:
+ return "/dev/{0}".format(dvd.group(0))
+ raise OSUtilError("Failed to get dvd device")
+
    def mount_dvd(self, max_retry=6, chk_err=True, dvd_device=None, mount_point=None):
        """
        Mount the provisioning DVD read-only at mount_point, retrying
        up to max_retry times with a 5s pause.  No-op when it is
        already mounted.  Raises OSUtilError on failure if chk_err.
        """
        if dvd_device is None:
            dvd_device = self.get_dvd_device()
        if mount_point is None:
            mount_point = conf.get_dvd_mount_point()
        mountlist = shellutil.run_get_output("mount")[1]
        existing = self.get_mount_point(mountlist, dvd_device)
        if existing is not None: #Already mounted
            logger.info("{0} is already mounted at {1}", dvd_device, existing)
            return
        if not os.path.isdir(mount_point):
            os.makedirs(mount_point)

        for retry in range(0, max_retry):
            retcode = self.mount(dvd_device, mount_point, option="-o ro -t udf,iso9660",
                                 chk_err=chk_err)
            if retcode == 0:
                logger.info("Successfully mounted dvd")
                return
            if retry < max_retry - 1:
                logger.warn("Mount dvd failed: retry={0}, ret={1}", retry,
                            retcode)
                time.sleep(5)
        if chk_err:
            raise OSUtilError("Failed to mount dvd.")

    def umount_dvd(self, chk_err=True, mount_point=None):
        """Unmount the DVD; raises OSUtilError on failure if chk_err."""
        if mount_point is None:
            mount_point = conf.get_dvd_mount_point()
        retcode = self.umount(mount_point, chk_err=chk_err)
        if chk_err and retcode != 0:
            raise OSUtilError("Failed to umount dvd.")

    def eject_dvd(self, chk_err=True):
        """Eject the DVD device; raises OSUtilError on failure if chk_err."""
        dvd = self.get_dvd_device()
        retcode = shellutil.run("eject {0}".format(dvd))
        if chk_err and retcode != 0:
            raise OSUtilError("Failed to eject dvd: ret={0}".format(retcode))

    def try_load_atapiix_mod(self):
        """Best-effort load of the ATAPI CD-ROM driver; failures are logged."""
        try:
            self.load_atapiix_mod()
        except Exception as e:
            logger.warn("Could not load ATAPI driver: {0}".format(e))

    def load_atapiix_mod(self):
        """
        insmod the ata_piix kernel module for the running kernel.
        Raises Exception when the module file is missing or loading fails.
        """
        if self.is_atapiix_mod_loaded():
            return
        ret, kern_version = shellutil.run_get_output("uname -r")
        if ret != 0:
            raise Exception("Failed to call uname -r")
        mod_path = os.path.join('/lib/modules',
                                kern_version.strip('\n'),
                                'kernel/drivers/ata/ata_piix.ko')
        if not os.path.isfile(mod_path):
            raise Exception("Can't find module file:{0}".format(mod_path))

        ret, output = shellutil.run_get_output("insmod " + mod_path)
        if ret != 0:
            raise Exception("Error calling insmod for ATAPI CD-ROM driver")
        if not self.is_atapiix_mod_loaded(max_retry=3):
            raise Exception("Failed to load ATAPI CD-ROM driver")

    def is_atapiix_mod_loaded(self, max_retry=1):
        """Check (polling up to max_retry times) whether ata_piix is loaded."""
        for retry in range(0, max_retry):
            ret = shellutil.run("lsmod | grep ata_piix", chk_err=False)
            if ret == 0:
                logger.info("Module driver for ATAPI CD-ROM is already present.")
                return True
            if retry < max_retry - 1:
                time.sleep(1)
        return False

    def mount(self, dvd, mount_point, option="", chk_err=True):
        """Run mount with the given option string; return its exit code."""
        cmd = "mount {0} {1} {2}".format(option, dvd, mount_point)
        return shellutil.run_get_output(cmd, chk_err)[0]

    def umount(self, mount_point, chk_err=True):
        """Run umount on mount_point; return its exit code."""
        return shellutil.run("umount {0}".format(mount_point), chk_err=chk_err)
+
    def allow_dhcp_broadcast(self):
        """Open the DHCP client port (udp/68) if iptables is enabled."""
        #Open DHCP port if iptables is enabled.
        # We suppress error logging on error: delete-then-insert keeps
        # the rule set idempotent.
        shellutil.run("iptables -D INPUT -p udp --dport 68 -j ACCEPT",
                      chk_err=False)
        shellutil.run("iptables -I INPUT -p udp --dport 68 -j ACCEPT",
                      chk_err=False)


    def remove_rules_files(self, rules_files=__RULES_FILES__):
        """Move persistent-net udev rules files aside into the lib dir."""
        lib_dir = conf.get_lib_dir()
        for src in rules_files:
            file_name = fileutil.base_name(src)
            dest = os.path.join(lib_dir, file_name)
            if os.path.isfile(dest):
                os.remove(dest)
            if os.path.isfile(src):
                logger.warn("Move rules file {0} to {1}", file_name, dest)
                shutil.move(src, dest)

    def restore_rules_files(self, rules_files=__RULES_FILES__):
        """Move previously saved udev rules files back into place."""
        lib_dir = conf.get_lib_dir()
        for dest in rules_files:
            filename = fileutil.base_name(dest)
            src = os.path.join(lib_dir, filename)
            if os.path.isfile(dest):
                continue
            if os.path.isfile(src):
                logger.warn("Move rules file {0} to {1}", filename, dest)
                shutil.move(src, dest)

    def get_mac_addr(self):
        """
        Convienience function, returns mac addr bound to
        first non-loopback interface.
        """
        ifname=''
        # get_first_if may transiently return an empty name; retry
        # until a real interface name comes back.
        while len(ifname) < 2 :
            ifname=self.get_first_if()[0]
        addr = self.get_if_mac(ifname)
        return textutil.hexstr_to_bytearray(addr)

    def get_if_mac(self, ifname):
        """
        Return the mac-address bound to the socket.
        """
        sock = socket.socket(socket.AF_INET,
                             socket.SOCK_DGRAM,
                             socket.IPPROTO_UDP)
        # 0x8927 is SIOCGIFHWADDR; the hardware address lives at bytes
        # 18..24 of the returned ifreq structure.
        param = struct.pack('256s', (ifname[:15]+('\0'*241)).encode('latin-1'))
        info = fcntl.ioctl(sock.fileno(), 0x8927, param)
        return ''.join(['%02X' % textutil.str_to_ord(char) for char in info[18:24]])
+
    def get_first_if(self):
        """
        Return the interface name, and ip addr of the
        first active non-loopback interface.
        """
        iface=''
        expected=16 # how many devices should I expect...
        struct_size=40 # for 64bit the size is 40 bytes
        sock = socket.socket(socket.AF_INET,
                             socket.SOCK_DGRAM,
                             socket.IPPROTO_UDP)
        buff=array.array('B', b'\0' * (expected * struct_size))
        # 0x8912 is SIOCGIFCONF: fills buff with ifreq records.
        param = struct.pack('iL',
                            expected*struct_size,
                            buff.buffer_info()[0])
        ret = fcntl.ioctl(sock.fileno(), 0x8912, param)
        retsize=(struct.unpack('iL', ret)[0])
        if retsize == (expected * struct_size):
            logger.warn(('SIOCGIFCONF returned more than {0} up '
                         'network interfaces.'), expected)
        # NOTE(review): array.tostring() was removed in Python 3.9;
        # tobytes() is the modern spelling -- confirm the supported
        # interpreter range.  Also note this rebinds `sock` from the
        # socket to the raw bytes.
        sock = buff.tostring()
        primary = bytearray(self.get_primary_interface(), encoding='utf-8')
        for i in range(0, struct_size * expected, struct_size):
            # First 16 bytes of each ifreq record hold the NUL-padded name.
            iface=sock[i:i+16].split(b'\0', 1)[0]
            if len(iface) == 0 or self.is_loopback(iface) or iface != primary:
                # test the next one
                logger.info('interface [{0}] skipped'.format(iface))
                continue
            else:
                # use this one
                logger.info('interface [{0}] selected'.format(iface))
                break

        # Bytes 20..24 of the ifreq record hold the IPv4 address.
        return iface.decode('latin-1'), socket.inet_ntoa(sock[i+20:i+24])

    def get_primary_interface(self):
        """
        Get the name of the primary interface, which is the one with the
        default route attached to it; if there are multiple default routes,
        the primary has the lowest Metric.
        :return: the interface which has the default route
        """
        # from linux/route.h
        RTF_GATEWAY = 0x02
        DEFAULT_DEST = "00000000"

        hdr_iface = "Iface"
        hdr_dest = "Destination"
        hdr_flags = "Flags"
        hdr_metric = "Metric"

        idx_iface = -1
        idx_dest = -1
        idx_flags = -1
        idx_metric = -1
        primary = None
        primary_metric = None

        logger.info("examine /proc/net/route for primary interface")
        with open('/proc/net/route') as routing_table:
            # Locate the column indexes from the header line first;
            # the column order is not assumed.
            idx = 0
            for header in filter(lambda h: len(h) > 0, routing_table.readline().strip(" \n").split("\t")):
                if header == hdr_iface:
                    idx_iface = idx
                elif header == hdr_dest:
                    idx_dest = idx
                elif header == hdr_flags:
                    idx_flags = idx
                elif header == hdr_metric:
                    idx_metric = idx
                idx = idx + 1
            for entry in routing_table.readlines():
                route = entry.strip(" \n").split("\t")
                if route[idx_dest] == DEFAULT_DEST and int(route[idx_flags]) & RTF_GATEWAY == RTF_GATEWAY:
                    metric = int(route[idx_metric])
                    iface = route[idx_iface]
                    if primary is None or metric < primary_metric:
                        primary = iface
                        primary_metric = metric

        if primary is None:
            primary = ''

        logger.info('primary interface is [{0}]'.format(primary))
        return primary


    def is_primary_interface(self, ifname):
        """
        Indicate whether the specified interface is the primary.
        :param ifname: the name of the interface - eth0, lo, etc.
        :return: True if this interface binds the default route
        """
        return self.get_primary_interface() == ifname


    def is_loopback(self, ifname):
        """Return True when the interface has the IFF_LOOPBACK flag set."""
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        # 0x8913 is SIOCGIFFLAGS; flag bit 8 is IFF_LOOPBACK.
        # NOTE(review): struct.pack('256s', ...) needs a bytes ifname on
        # Python 3; the only in-file caller (get_first_if) passes bytes,
        # but a str argument would raise -- confirm external callers.
        result = fcntl.ioctl(s.fileno(), 0x8913, struct.pack('256s', ifname[:15]))
        flags, = struct.unpack('H', result[16:18])
        isloopback = flags & 8 == 8
        logger.info('interface [{0}] has flags [{1}], is loopback [{2}]'.format(ifname, flags, isloopback))
        return isloopback

    def get_dhcp_lease_endpoint(self):
        """
        OS specific, this should return the decoded endpoint of
        the wireserver from option 245 in the dhcp leases file
        if it exists on disk.
        :return: The endpoint if available, or None
        """
        return None
+
+ @staticmethod
+ def get_endpoint_from_leases_path(pathglob):
+ """
+ Try to discover and decode the wireserver endpoint in the
+ specified dhcp leases path.
+ :param pathglob: The path containing dhcp lease files
+ :return: The endpoint if available, otherwise None
+ """
+ endpoint = None
+
+ HEADER_LEASE = "lease"
+ HEADER_OPTION = "option unknown-245"
+ HEADER_DNS = "option domain-name-servers"
+ HEADER_EXPIRE = "expire"
+ FOOTER_LEASE = "}"
+ FORMAT_DATETIME = "%Y/%m/%d %H:%M:%S"
+
+ logger.info("looking for leases in path [{0}]".format(pathglob))
+ for lease_file in glob.glob(pathglob):
+ leases = open(lease_file).read()
+ if HEADER_OPTION in leases:
+ cached_endpoint = None
+ has_option_245 = False
+ expired = True # assume expired
+ for line in leases.splitlines():
+ if line.startswith(HEADER_LEASE):
+ cached_endpoint = None
+ has_option_245 = False
+ expired = True
+ elif HEADER_DNS in line:
+ cached_endpoint = line.replace(HEADER_DNS, '').strip(" ;")
+ elif HEADER_OPTION in line:
+ has_option_245 = True
+ elif HEADER_EXPIRE in line:
+ if "never" in line:
+ expired = False
+ else:
+ try:
+ expire_string = line.split(" ", 4)[-1].strip(";")
+ expire_date = datetime.datetime.strptime(expire_string, FORMAT_DATETIME)
+ if expire_date > datetime.datetime.utcnow():
+ expired = False
+ except:
+ logger.error("could not parse expiry token '{0}'".format(line))
+ elif FOOTER_LEASE in line:
+ logger.info("dhcp entry:{0}, 245:{1}, expired:{2}".format(
+ cached_endpoint, has_option_245, expired))
+ if not expired and cached_endpoint is not None and has_option_245:
+ endpoint = cached_endpoint
+ logger.info("found endpoint [{0}]".format(endpoint))
+ # we want to return the last valid entry, so
+ # keep searching
+ if endpoint is not None:
+ logger.info("cached endpoint found [{0}]".format(endpoint))
+ else:
+ logger.info("cached endpoint not found")
+ return endpoint
+
+ def is_missing_default_route(self):
+ routes = shellutil.run_get_output("route -n")[1]
+ for route in routes.split("\n"):
+ if route.startswith("0.0.0.0 ") or route.startswith("default "):
+ return False
+ return True
+
    def get_if_name(self):
        """Return the name of the first active non-loopback interface."""
        return self.get_first_if()[0]

    def get_ip4_addr(self):
        """Return the IPv4 address of the first active non-loopback interface."""
        return self.get_first_if()[1]

    def set_route_for_dhcp_broadcast(self, ifname):
        """Route the DHCP broadcast address through ifname; return exit code."""
        return shellutil.run("route add 255.255.255.255 dev {0}".format(ifname),
                             chk_err=False)

    def remove_route_for_dhcp_broadcast(self, ifname):
        """Remove the DHCP broadcast route added above (best effort)."""
        shellutil.run("route del 255.255.255.255 dev {0}".format(ifname),
                      chk_err=False)

    # The following service hooks are no-ops by default; distro
    # subclasses override them with the distro's service commands.

    def is_dhcp_enabled(self):
        return False

    def stop_dhcp_service(self):
        pass

    def start_dhcp_service(self):
        pass

    def start_network(self):
        pass

    def start_agent_service(self):
        pass

    def stop_agent_service(self):
        pass

    def register_agent_service(self):
        pass

    def unregister_agent_service(self):
        pass

    def restart_ssh_service(self):
        pass

    def route_add(self, net, mask, gateway):
        """
        Add specified route using /sbin/route add -net.
        """
        cmd = ("/sbin/route add -net "
               "{0} netmask {1} gw {2}").format(net, mask, gateway)
        return shellutil.run(cmd, chk_err=False)

    def get_dhcp_pid(self):
        """Return the dhclient pid as a string, or None when not running."""
        ret= shellutil.run_get_output("pidof dhclient")
        return ret[1] if ret[0] == 0 else None

    def set_hostname(self, hostname):
        """Persist hostname to /etc/hostname and apply it immediately."""
        fileutil.write_file('/etc/hostname', hostname)
        shellutil.run("hostname {0}".format(hostname), chk_err=False)

    def set_dhcp_hostname(self, hostname):
        """
        Configure dhclient to send the hostname with DHCP requests,
        unless a send host-name directive is already in effect.
        """
        autosend = r'^[^#]*?send\s*host-name.*?(<hostname>|gethostname[(,)])'
        dhclient_files = ['/etc/dhcp/dhclient.conf', '/etc/dhcp3/dhclient.conf', '/etc/dhclient.conf']
        for conf_file in dhclient_files:
            if not os.path.isfile(conf_file):
                continue
            if fileutil.findstr_in_file(conf_file, autosend):
                #Return if auto send host-name is configured
                return
            fileutil.update_conf_file(conf_file,
                                      'send host-name',
                                      'send host-name "{0}";'.format(hostname))
+
+ def restart_if(self, ifname, retries=3, wait=5):
+ retry_limit=retries+1
+ for attempt in range(1, retry_limit):
+ return_code=shellutil.run("ifdown {0} && ifup {0}".format(ifname))
+ if return_code == 0:
+ return
+ logger.warn("failed to restart {0}: return code {1}".format(ifname, return_code))
+ if attempt < retry_limit:
+ logger.info("retrying in {0} seconds".format(wait))
+ time.sleep(wait)
+ else:
+ logger.warn("exceeded restart retries")
+
    def publish_hostname(self, hostname):
        """
        Publish the hostname to DHCP: configure dhclient to send it,
        then bounce the primary interface to trigger a new request.
        """
        self.set_dhcp_hostname(hostname)
        ifname = self.get_if_name()
        self.restart_if(ifname)

    def set_scsi_disks_timeout(self, timeout):
        """Apply timeout to every SCSI block device (sd*)."""
        for dev in os.listdir("/sys/block"):
            if dev.startswith('sd'):
                self.set_block_device_timeout(dev, timeout)

    def set_block_device_timeout(self, dev, timeout):
        """
        Write timeout (a string, in seconds) to the device's sysfs
        timeout file, skipping the write when already set.
        """
        if dev is not None and timeout is not None:
            file_path = "/sys/block/{0}/device/timeout".format(dev)
            content = fileutil.read_file(file_path)
            original = content.splitlines()[0].rstrip()
            if original != timeout:
                fileutil.write_file(file_path, timeout)
                logger.info("Set block dev timeout: {0} with timeout: {1}",
                            dev, timeout)
+
+ def get_mount_point(self, mountlist, device):
+ """
+ Example of mountlist:
+ /dev/sda1 on / type ext4 (rw)
+ proc on /proc type proc (rw)
+ sysfs on /sys type sysfs (rw)
+ devpts on /dev/pts type devpts (rw,gid=5,mode=620)
+ tmpfs on /dev/shm type tmpfs
+ (rw,rootcontext="system_u:object_r:tmpfs_t:s0")
+ none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw)
+ /dev/sdb1 on /mnt/resource type ext4 (rw)
+ """
+ if (mountlist and device):
+ for entry in mountlist.split('\n'):
+ if(re.search(device, entry)):
+ tokens = entry.split()
+ #Return the 3rd column of this line
+ return tokens[2] if len(tokens) > 2 else None
+ return None
+
    def device_for_ide_port(self, port_id):
        """
        Return device name attached to ide port 'n'.

        Ports 0-1 map to vmbus controller GUID prefix 00000000, ports
        2-3 to 00000001; returns None for port_id > 3 or when no
        matching block device is found.
        """
        if port_id > 3:
            return None
        g0 = "00000000"
        if port_id > 1:
            g0 = "00000001"
            port_id = port_id - 2
        device = None
        path = "/sys/bus/vmbus/devices/"
        for vmbus in os.listdir(path):
            deviceid = fileutil.read_file(os.path.join(path, vmbus, "device_id"))
            # Match on the first two GUID groups of the vmbus device id.
            guid = deviceid.lstrip('{').split('-')
            if guid[0] == g0 and guid[1] == "000" + ustr(port_id):
                for root, dirs, files in os.walk(path + vmbus):
                    if root.endswith("/block"):
                        # Modern sysfs layout: .../block/<device>
                        device = dirs[0]
                        break
                    else : #older distros
                        for d in dirs:
                            if ':' in d and "block" == d.split(':')[0]:
                                device = d.split(':')[1]
                                break
                break
        return device
+
+ def del_account(self, username):
+ if self.is_sys_user(username):
+ logger.error("{0} is a system user. Will not delete it.", username)
+ shellutil.run("> /var/run/utmp")
+ shellutil.run("userdel -f -r " + username)
+ self.conf_sudoer(username, remove=True)
+
+ def decode_customdata(self, data):
+ return base64.b64decode(data)
+
    def get_total_mem(self):
        """
        Return total memory in MB as reported by /proc/meminfo.
        Raises OSUtilError on failure.

        NOTE(review): `int(...)/1024` yields an int on Python 2 but a
        float on Python 3 -- confirm callers tolerate both.
        """
        cmd = "grep MemTotal /proc/meminfo |awk '{print $2}'"
        ret = shellutil.run_get_output(cmd)
        if ret[0] == 0:
            return int(ret[1])/1024
        else:
            raise OSUtilError("Failed to get total memory: {0}".format(ret[1]))

    def get_processor_cores(self):
        """Return the processor count from /proc/cpuinfo; raises OSUtilError."""
        ret = shellutil.run_get_output("grep 'processor.*:' /proc/cpuinfo |wc -l")
        if ret[0] == 0:
            return int(ret[1])
        else:
            raise OSUtilError("Failed to get processor cores")

    def set_admin_access_to_ip(self, dest_ip):
        """
        Restrict outbound access to dest_ip to uid 0 via iptables.
        Rules are deleted first so repeated calls stay idempotent.
        """
        #This allows root to access dest_ip
        rm_old= "iptables -D OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0"
        rule = "iptables -A OUTPUT -d {0} -j ACCEPT -m owner --uid-owner 0"
        shellutil.run(rm_old.format(dest_ip), chk_err=False)
        shellutil.run(rule.format(dest_ip))

        #This blocks all other users to access dest_ip
        rm_old = "iptables -D OUTPUT -d {0} -j DROP"
        rule = "iptables -A OUTPUT -d {0} -j DROP"
        shellutil.run(rm_old.format(dest_ip), chk_err=False)
        shellutil.run(rule.format(dest_ip))
+
+ def check_pid_alive(self, pid):
+ return pid is not None and os.path.isdir(os.path.join('/proc', pid))
diff --git a/azurelinuxagent/common/osutil/factory.py b/azurelinuxagent/common/osutil/factory.py
new file mode 100644
index 0000000..5e8ae6e
--- /dev/null
+++ b/azurelinuxagent/common/osutil/factory.py
@@ -0,0 +1,69 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.utils.textutil import Version
+from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, \
+ DISTRO_FULL_NAME
+
+from .default import DefaultOSUtil
+from .coreos import CoreOSUtil
+from .debian import DebianOSUtil
+from .freebsd import FreeBSDOSUtil
+from .redhat import RedhatOSUtil, Redhat6xOSUtil
+from .suse import SUSEOSUtil, SUSE11OSUtil
+from .ubuntu import UbuntuOSUtil, Ubuntu12OSUtil, Ubuntu14OSUtil, \
+ UbuntuSnappyOSUtil
+
def get_osutil(distro_name=DISTRO_NAME, distro_version=DISTRO_VERSION,
               distro_full_name=DISTRO_FULL_NAME):
    """Return the OSUtil implementation matching the given distro.

    Unrecognized distros fall back to DefaultOSUtil with a warning.
    Fixes: misspelled log messages ("implemetation"), explicit
    parenthesization of the SUSE condition, and a uniform elif chain.
    """
    if distro_name == "ubuntu":
        if Version(distro_version) == Version("12.04") or \
           Version(distro_version) == Version("12.10"):
            return Ubuntu12OSUtil()
        elif Version(distro_version) == Version("14.04") or \
             Version(distro_version) == Version("14.10"):
            return Ubuntu14OSUtil()
        elif distro_full_name == "Snappy Ubuntu Core":
            return UbuntuSnappyOSUtil()
        else:
            return UbuntuOSUtil()
    elif distro_name == "coreos":
        return CoreOSUtil()
    elif distro_name == "suse":
        # SLES before 12 and openSUSE before 13.2 use the sysvinit tooling.
        if (distro_full_name == 'SUSE Linux Enterprise Server' and
                Version(distro_version) < Version('12')) or \
           (distro_full_name == 'openSUSE' and
                Version(distro_version) < Version('13.2')):
            return SUSE11OSUtil()
        else:
            return SUSEOSUtil()
    elif distro_name == "debian":
        return DebianOSUtil()
    elif distro_name in ("redhat", "centos", "oracle"):
        if Version(distro_version) < Version("7"):
            return Redhat6xOSUtil()
        else:
            return RedhatOSUtil()
    elif distro_name == "freebsd":
        return FreeBSDOSUtil()
    else:
        logger.warn("Unable to load distro implementation for {0}.",
                    distro_name)
        logger.warn("Use default distro implementation instead.")
        return DefaultOSUtil()
+
diff --git a/azurelinuxagent/common/osutil/freebsd.py b/azurelinuxagent/common/osutil/freebsd.py
new file mode 100644
index 0000000..ddf8db6
--- /dev/null
+++ b/azurelinuxagent/common/osutil/freebsd.py
@@ -0,0 +1,198 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import OSUtilError
+from azurelinuxagent.common.osutil.default import DefaultOSUtil
+
+
class FreeBSDOSUtil(DefaultOSUtil):
    """OSUtil implementation for FreeBSD hosts.

    Replaces the Linux defaults with FreeBSD equivalents: pw(8)/rmuser(8)
    for account management, rc.d scripts for services, sysctl(8) for
    system facts, and ifconfig/route parsing for network information.
    """

    def __init__(self):
        super(FreeBSDOSUtil, self).__init__()
        # Set once the CAM disk timeout sysctl has been applied.
        self._scsi_disks_timeout_set = False

    def set_hostname(self, hostname):
        """Persist the hostname in /etc/rc.conf and apply it immediately."""
        rc_file_path = '/etc/rc.conf'
        conf_file = fileutil.read_file(rc_file_path).split("\n")
        textutil.set_ini_config(conf_file, "hostname", hostname)
        fileutil.write_file(rc_file_path, "\n".join(conf_file))
        shellutil.run("hostname {0}".format(hostname), chk_err=False)

    def restart_ssh_service(self):
        return shellutil.run('service sshd restart', chk_err=False)

    def useradd(self, username, expiration=None):
        """Create user account *username*, optionally with an expiry date."""
        userentry = self.get_userentry(username)
        if userentry is not None:
            logger.warn("User {0} already exists, skip useradd", username)
            return

        if expiration is not None:
            cmd = "pw useradd {0} -e {1} -m".format(username, expiration)
        else:
            cmd = "pw useradd {0} -m".format(username)
        retcode, out = shellutil.run_get_output(cmd)
        if retcode != 0:
            raise OSUtilError(("Failed to create user account:{0}, "
                               "retcode:{1}, "
                               "output:{2}").format(username, retcode, out))

    def del_account(self, username):
        """Delete *username*; system accounts are refused (fix: now returns
        instead of falling through and deleting anyway)."""
        if self.is_sys_user(username):
            logger.error("{0} is a system user. Will not delete it.", username)
            return
        shellutil.run('> /var/run/utx.active')
        shellutil.run('rmuser -y ' + username)
        self.conf_sudoer(username, remove=True)

    def chpasswd(self, username, password, crypt_id=6, salt_len=10):
        """Set the password hash for *username* via pw usermod -H 0 (stdin)."""
        if self.is_sys_user(username):
            raise OSUtilError(("User {0} is a system user. "
                               "Will not set passwd.").format(username))
        passwd_hash = textutil.gen_password_hash(password, crypt_id, salt_len)
        cmd = "echo '{0}'|pw usermod {1} -H 0 ".format(passwd_hash, username)
        # log_cmd=False keeps the hash out of the logs.
        ret, output = shellutil.run_get_output(cmd, log_cmd=False)
        if ret != 0:
            raise OSUtilError(("Failed to set password for {0}: {1}"
                               "").format(username, output))

    def del_root_password(self):
        err = shellutil.run('pw mod user root -w no')
        if err:
            raise OSUtilError("Failed to delete root password: Failed to update password database.")

    def get_if_mac(self, ifname):
        """Return the MAC of *ifname* (uppercase, no colons), or None."""
        data = self._get_net_info()
        if data[0] == ifname:
            return data[2].replace(':', '').upper()
        return None

    def get_first_if(self):
        """Return (iface, inet4_addr) for the first ethernet interface."""
        return self._get_net_info()[:2]

    def route_add(self, net, mask, gateway):
        # BSD route(8) argument order differs from the method signature;
        # NOTE(review): mask/gateway placement looks intentional -- confirm.
        cmd = 'route add {0} {1} {2}'.format(net, gateway, mask)
        return shellutil.run(cmd, chk_err=False)

    def is_missing_default_route(self):
        """
        For FreeBSD, the default broadcast goes to current default gw, not a all-ones broadcast address, need to
        specify the route manually to get it work in a VNET environment.
        SEE ALSO: man ip(4) IP_ONESBCAST,
        """
        return True

    def is_dhcp_enabled(self):
        return True

    def start_dhcp_service(self):
        shellutil.run("/etc/rc.d/dhclient start {0}".format(self.get_if_name()), chk_err=False)

    def allow_dhcp_broadcast(self):
        # Not needed on FreeBSD; see is_missing_default_route above.
        pass

    def set_route_for_dhcp_broadcast(self, ifname):
        return shellutil.run("route add 255.255.255.255 -iface {0}".format(ifname), chk_err=False)

    def remove_route_for_dhcp_broadcast(self, ifname):
        shellutil.run("route delete 255.255.255.255 -iface {0}".format(ifname), chk_err=False)

    def get_dhcp_pid(self):
        """Return the newest dhclient pid as a string, or None."""
        ret = shellutil.run_get_output("pgrep -n dhclient")
        return ret[1] if ret[0] == 0 else None

    def eject_dvd(self, chk_err=True):
        dvd = self.get_dvd_device()
        retcode = shellutil.run("cdcontrol -f {0} eject".format(dvd))
        if chk_err and retcode != 0:
            raise OSUtilError("Failed to eject dvd: ret={0}".format(retcode))

    def restart_if(self, ifname):
        # Restart dhclient only to publish hostname
        shellutil.run("/etc/rc.d/dhclient restart {0}".format(ifname), chk_err=False)

    def get_total_mem(self):
        """Return total system memory in MB from hw.physmem (bytes)."""
        cmd = "sysctl hw.physmem |awk '{print $2}'"
        ret, output = shellutil.run_get_output(cmd)
        if ret:
            raise OSUtilError("Failed to get total memory: {0}".format(output))
        try:
            return int(output)/1024/1024
        except ValueError:
            raise OSUtilError("Failed to get total memory: {0}".format(output))

    def get_processor_cores(self):
        """Return hw.ncpu; fixed copy-pasted 'total memory' error message."""
        ret, output = shellutil.run_get_output("sysctl hw.ncpu |awk '{print $2}'")
        if ret:
            raise OSUtilError("Failed to get processor cores.")

        try:
            return int(output)
        except ValueError:
            raise OSUtilError("Failed to get processor cores: {0}".format(output))

    def set_scsi_disks_timeout(self, timeout):
        """Apply the CAM 'da' disk timeout once per process."""
        if self._scsi_disks_timeout_set:
            return

        ret, output = shellutil.run_get_output('sysctl kern.cam.da.default_timeout={0}'.format(timeout))
        if ret:
            raise OSUtilError("Failed set SCSI disks timeout: {0}".format(output))
        self._scsi_disks_timeout_set = True

    def check_pid_alive(self, pid):
        return shellutil.run('ps -p {0}'.format(pid), chk_err=False) == 0

    @staticmethod
    def _get_net_info():
        """
        There is no SIOCGIFCONF on FreeBSD - just parse ifconfig.
        Returns strings: iface, inet4_addr, and mac
        Raises OSUtilError when no ethernet interface can be found.
        """
        iface = ''
        inet = ''
        mac = ''

        err, output = shellutil.run_get_output('ifconfig -l ether', chk_err=False)
        if err:
            raise OSUtilError("Can't find ether interface:{0}".format(output))
        ifaces = output.split()
        if not ifaces:
            raise OSUtilError("Can't find ether interface.")
        iface = ifaces[0]

        err, output = shellutil.run_get_output('ifconfig ' + iface, chk_err=False)
        if err:
            raise OSUtilError("Can't get info for interface:{0}".format(iface))

        for line in output.split('\n'):
            if line.find('inet ') != -1:
                inet = line.split()[1]
            elif line.find('ether ') != -1:
                mac = line.split()[1]
        logger.verbose("Interface info: ({0},{1},{2})", iface, inet, mac)

        return iface, inet, mac
diff --git a/azurelinuxagent/common/osutil/redhat.py b/azurelinuxagent/common/osutil/redhat.py
new file mode 100644
index 0000000..03084b6
--- /dev/null
+++ b/azurelinuxagent/common/osutil/redhat.py
@@ -0,0 +1,122 @@
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import re
+import pwd
+import shutil
+import socket
+import array
+import struct
+import fcntl
+import time
+import base64
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.future import ustr, bytebuffer
+from azurelinuxagent.common.exception import OSUtilError, CryptError
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
+from azurelinuxagent.common.osutil.default import DefaultOSUtil
+
class Redhat6xOSUtil(DefaultOSUtil):
    """OSUtil for RHEL/CentOS/Oracle 6.x (sysvinit + chkconfig based)."""

    def __init__(self):
        super(Redhat6xOSUtil, self).__init__()

    def start_network(self):
        # NOTE(review): the RHEL sysvinit script is normally named
        # "network", not "networking" -- confirm this command is intended.
        return shellutil.run("/sbin/service networking start", chk_err=False)

    def restart_ssh_service(self):
        return shellutil.run("/sbin/service sshd condrestart", chk_err=False)

    def stop_agent_service(self):
        return shellutil.run("/sbin/service waagent stop", chk_err=False)

    def start_agent_service(self):
        return shellutil.run("/sbin/service waagent start", chk_err=False)

    def register_agent_service(self):
        return shellutil.run("chkconfig --add waagent", chk_err=False)

    def unregister_agent_service(self):
        return shellutil.run("chkconfig --del waagent", chk_err=False)

    def openssl_to_openssh(self, input_file, output_file):
        """Convert an OpenSSL public key file to OpenSSH format."""
        pubkey = fileutil.read_file(input_file)
        try:
            cryptutil = CryptUtil(conf.get_openssl_cmd())
            ssh_rsa_pubkey = cryptutil.asn1_to_ssh(pubkey)
        except CryptError as e:
            raise OSUtilError(ustr(e))
        fileutil.write_file(output_file, ssh_rsa_pubkey)

    # Override: 6.x runs plain dhclient.
    def get_dhcp_pid(self):
        exit_code, output = shellutil.run_get_output("pidof dhclient")
        if exit_code != 0:
            return None
        return output

    def set_hostname(self, hostname):
        """Persist the hostname in /etc/sysconfig/network and apply it now."""
        fileutil.update_conf_file('/etc/sysconfig/network',
                                  'HOSTNAME',
                                  'HOSTNAME={0}'.format(hostname))
        shellutil.run("hostname {0}".format(hostname), chk_err=False)

    def set_dhcp_hostname(self, hostname):
        """Record the DHCP hostname in the interface's ifcfg file."""
        ifname = self.get_if_name()
        filepath = "/etc/sysconfig/network-scripts/ifcfg-{0}".format(ifname)
        fileutil.update_conf_file(filepath, 'DHCP_HOSTNAME',
                                  'DHCP_HOSTNAME={0}'.format(hostname))

    def get_dhcp_lease_endpoint(self):
        lease_glob = '/var/lib/dhclient/dhclient-*.leases'
        return self.get_endpoint_from_leases_path(lease_glob)
+
class RedhatOSUtil(Redhat6xOSUtil):
    """OSUtil for RHEL/CentOS/Oracle 7.x (systemd based)."""

    def __init__(self):
        super(RedhatOSUtil, self).__init__()

    def set_hostname(self, hostname):
        """
        Set /etc/hostname.
        Unlike redhat 6.x, redhat 7.x stores the hostname in /etc/hostname,
        so deliberately call DefaultOSUtil's version to bypass the 6.x
        sysconfig logic.
        """
        DefaultOSUtil.set_hostname(self, hostname)

    def publish_hostname(self, hostname):
        """Restart NetworkManager first before publishing hostname."""
        shellutil.run("service NetworkManager restart")
        super(RedhatOSUtil, self).publish_hostname(hostname)

    def register_agent_service(self):
        return shellutil.run("systemctl enable waagent", chk_err=False)

    def unregister_agent_service(self):
        return shellutil.run("systemctl disable waagent", chk_err=False)

    def openssl_to_openssh(self, input_file, output_file):
        # Deliberately bypass the 6.x openssl-based override and use
        # DefaultOSUtil's implementation.
        DefaultOSUtil.openssl_to_openssh(self, input_file, output_file)

    def get_dhcp_lease_endpoint(self):
        # centos7 names leases with a double hyphen, e.g.
        # /var/lib/dhclient/dhclient--eth0.lease
        return self.get_endpoint_from_leases_path('/var/lib/dhclient/dhclient-*.lease')
diff --git a/azurelinuxagent/common/osutil/suse.py b/azurelinuxagent/common/osutil/suse.py
new file mode 100644
index 0000000..f0ed0c0
--- /dev/null
+++ b/azurelinuxagent/common/osutil/suse.py
@@ -0,0 +1,108 @@
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import re
+import pwd
+import shutil
+import socket
+import array
+import struct
+import fcntl
+import time
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+import azurelinuxagent.common.utils.textutil as textutil
+from azurelinuxagent.common.version import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
+from azurelinuxagent.common.osutil.default import DefaultOSUtil
+
class SUSE11OSUtil(DefaultOSUtil):
    """OSUtil for SLES 11 and older openSUSE (sysvinit + dhcpcd).

    Fix: start_network previously ran "/sbin/service start network",
    which invokes a non-existent init script named "start"; sysvinit
    syntax is "service <script> <action>".
    """

    def __init__(self):
        super(SUSE11OSUtil, self).__init__()
        # SUSE 11 ships dhcpcd rather than dhclient.
        self.dhclient_name = 'dhcpcd'

    def set_hostname(self, hostname):
        """Persist the hostname in /etc/HOSTNAME and apply it immediately."""
        fileutil.write_file('/etc/HOSTNAME', hostname)
        shellutil.run("hostname {0}".format(hostname), chk_err=False)

    def get_dhcp_pid(self):
        """Return the DHCP client's pid as a string, or None."""
        ret = shellutil.run_get_output("pidof {0}".format(self.dhclient_name))
        return ret[1] if ret[0] == 0 else None

    def is_dhcp_enabled(self):
        return True

    def stop_dhcp_service(self):
        cmd = "/sbin/service {0} stop".format(self.dhclient_name)
        return shellutil.run(cmd, chk_err=False)

    def start_dhcp_service(self):
        cmd = "/sbin/service {0} start".format(self.dhclient_name)
        return shellutil.run(cmd, chk_err=False)

    def start_network(self):
        # Fixed argument order: "service <script> <action>".
        return shellutil.run("/sbin/service network start", chk_err=False)

    def restart_ssh_service(self):
        return shellutil.run("/sbin/service sshd restart", chk_err=False)

    def stop_agent_service(self):
        return shellutil.run("/sbin/service waagent stop", chk_err=False)

    def start_agent_service(self):
        return shellutil.run("/sbin/service waagent start", chk_err=False)

    def register_agent_service(self):
        return shellutil.run("/sbin/insserv waagent", chk_err=False)

    def unregister_agent_service(self):
        return shellutil.run("/sbin/insserv -r waagent", chk_err=False)
+
class SUSEOSUtil(SUSE11OSUtil):
    """OSUtil for SLES 12+ / newer openSUSE (systemd + wicked DHCP)."""

    def __init__(self):
        super(SUSEOSUtil, self).__init__()
        # Newer SUSE uses wicked for DHCP instead of dhcpcd.
        self.dhclient_name = 'wickedd-dhcp4'

    def stop_dhcp_service(self):
        return shellutil.run("systemctl stop {0}".format(self.dhclient_name),
                             chk_err=False)

    def start_dhcp_service(self):
        return shellutil.run("systemctl start {0}".format(self.dhclient_name),
                             chk_err=False)

    def start_network(self):
        return shellutil.run("systemctl start network", chk_err=False)

    def restart_ssh_service(self):
        return shellutil.run("systemctl restart sshd", chk_err=False)

    def stop_agent_service(self):
        return shellutil.run("systemctl stop waagent", chk_err=False)

    def start_agent_service(self):
        return shellutil.run("systemctl start waagent", chk_err=False)

    def register_agent_service(self):
        return shellutil.run("systemctl enable waagent", chk_err=False)

    def unregister_agent_service(self):
        return shellutil.run("systemctl disable waagent", chk_err=False)
+
+
diff --git a/azurelinuxagent/common/osutil/ubuntu.py b/azurelinuxagent/common/osutil/ubuntu.py
new file mode 100644
index 0000000..4032cf4
--- /dev/null
+++ b/azurelinuxagent/common/osutil/ubuntu.py
@@ -0,0 +1,66 @@
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import azurelinuxagent.common.utils.shellutil as shellutil
+from azurelinuxagent.common.osutil.default import DefaultOSUtil
+
class Ubuntu14OSUtil(DefaultOSUtil):
    """OSUtil for Ubuntu 14.04/14.10 ('walinuxagent' upstart service)."""

    def __init__(self):
        super(Ubuntu14OSUtil, self).__init__()

    def start_network(self):
        return shellutil.run("service networking start", chk_err=False)

    def stop_agent_service(self):
        return shellutil.run("service walinuxagent stop", chk_err=False)

    def start_agent_service(self):
        return shellutil.run("service walinuxagent start", chk_err=False)

    def remove_rules_files(self, rules_files=""):
        # Intentionally a no-op on Ubuntu.
        pass

    def restore_rules_files(self, rules_files=""):
        # No-op counterpart of remove_rules_files.
        pass

    def get_dhcp_lease_endpoint(self):
        lease_glob = '/var/lib/dhcp/dhclient.*.leases'
        return self.get_endpoint_from_leases_path(lease_glob)
+
class Ubuntu12OSUtil(Ubuntu14OSUtil):
    """OSUtil for Ubuntu 12.04/12.10, which ships dhclient3."""

    def __init__(self):
        super(Ubuntu12OSUtil, self).__init__()

    # Override: the DHCP client binary is dhclient3 on 12.x.
    def get_dhcp_pid(self):
        exit_code, output = shellutil.run_get_output("pidof dhclient3")
        if exit_code != 0:
            return None
        return output
+
class UbuntuOSUtil(Ubuntu14OSUtil):
    """OSUtil for current (systemd) Ubuntu releases."""

    def __init__(self):
        super(UbuntuOSUtil, self).__init__()

    def register_agent_service(self):
        # NOTE(review): registration only unmasks the unit; presumably
        # enablement is handled elsewhere (packaging) -- confirm.
        return shellutil.run("systemctl unmask walinuxagent", chk_err=False)

    def unregister_agent_service(self):
        return shellutil.run("systemctl mask walinuxagent", chk_err=False)
+
class UbuntuSnappyOSUtil(Ubuntu14OSUtil):
    """OSUtil for Snappy Ubuntu Core, where the agent lives under /apps."""

    def __init__(self):
        super(UbuntuSnappyOSUtil, self).__init__()
        # Snappy keeps the agent config inside the app directory.
        self.conf_file_path = '/apps/walinuxagent/current/waagent.conf'
diff --git a/azurelinuxagent/common/protocol/__init__.py b/azurelinuxagent/common/protocol/__init__.py
new file mode 100644
index 0000000..fb7c273
--- /dev/null
+++ b/azurelinuxagent/common/protocol/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+from azurelinuxagent.common.protocol.util import get_protocol_util, \
+ OVF_FILE_NAME, \
+ TAG_FILE_NAME
+
diff --git a/azurelinuxagent/common/protocol/hostplugin.py b/azurelinuxagent/common/protocol/hostplugin.py
new file mode 100644
index 0000000..6569604
--- /dev/null
+++ b/azurelinuxagent/common/protocol/hostplugin.py
@@ -0,0 +1,124 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+from azurelinuxagent.common.protocol.wire import *
+from azurelinuxagent.common.utils import textutil
+
+HOST_PLUGIN_PORT = 32526
+URI_FORMAT_GET_API_VERSIONS = "http://{0}:{1}/versions"
+URI_FORMAT_PUT_VM_STATUS = "http://{0}:{1}/status"
+URI_FORMAT_PUT_LOG = "http://{0}:{1}/vmAgentLog"
+API_VERSION = "2015-09-01"
+
+
class HostPluginProtocol(object):
    """Client for the host plugin (HostGAPlugin) HTTP endpoint on port 32526.

    Provides an alternate channel through the host for uploading VM status
    blobs and agent logs, guarded by a lazy API-version availability probe.
    """

    def __init__(self, endpoint):
        # endpoint: IP address of the host serving the plugin; required.
        if endpoint is None:
            raise ProtocolError("Host plugin endpoint not provided")
        self.is_initialized = False  # set True after the first version probe
        self.is_available = False    # True when API_VERSION is supported
        self.api_versions = None
        self.endpoint = endpoint

    def ensure_initialized(self):
        """Probe supported API versions once (lazily); return availability."""
        if not self.is_initialized:
            self.api_versions = self.get_api_versions()
            # NOTE(review): get_api_versions returns the raw response body
            # (not parsed JSON), so this membership test is a substring
            # check on the body text -- confirm this is intended.
            self.is_available = API_VERSION in self.api_versions
            self.is_initialized = True
        return self.is_available

    def get_api_versions(self):
        """GET the plugin's /versions resource.

        Returns the raw response body on success, or [] on any failure
        (non-OK status or transport error) -- errors are logged, not raised.
        """
        url = URI_FORMAT_GET_API_VERSIONS.format(self.endpoint,
                                                 HOST_PLUGIN_PORT)
        logger.info("getting API versions at [{0}]".format(url))
        try:
            response = restutil.http_get(url)
            if response.status != httpclient.OK:
                logger.error(
                    "get API versions returned status code [{0}]".format(
                        response.status))
                return []
            return response.read()
        except HttpError as e:
            logger.error("get API versions failed with [{0}]".format(e))
            return []

    def put_vm_status(self, status_blob, sas_url):
        """
        Try to upload the VM status via the host plugin /status channel
        :param sas_url: the blob SAS url to pass to the host plugin
        :type status_blob: StatusBlob

        Best-effort: all failures are logged and swallowed.
        """
        if not self.ensure_initialized():
            logger.error("host plugin channel is not available")
            return
        if status_blob is None or status_blob.vm_status is None:
            logger.error("no status data was provided")
            return
        url = URI_FORMAT_PUT_VM_STATUS.format(self.endpoint, HOST_PLUGIN_PORT)
        # The plugin forwards the blob to sas_url with the headers below.
        status = textutil.b64encode(status_blob.vm_status)
        headers = {"x-ms-version": API_VERSION}
        blob_headers = [{'headerName': 'x-ms-version',
                         'headerValue': status_blob.__storage_version__},
                        {'headerName': 'x-ms-blob-type',
                         'headerValue': status_blob.type}]
        data = json.dumps({'requestUri': sas_url, 'headers': blob_headers,
                           'content': status}, sort_keys=True)
        logger.info("put VM status at [{0}]".format(url))
        try:
            response = restutil.http_put(url, data, headers)
            if response.status != httpclient.OK:
                logger.error("put VM status returned status code [{0}]".format(
                    response.status))
        except HttpError as e:
            logger.error("put VM status failed with [{0}]".format(e))

    def put_vm_log(self, content, container_id, deployment_id):
        """
        Try to upload the given content to the host plugin
        :param deployment_id: the deployment id, which is obtained from the
        goal state (tenant name)
        :param container_id: the container id, which is obtained from the
        goal state
        :param content: the binary content of the zip file to upload
        :return:

        Best-effort: all failures are logged and swallowed.
        """
        if not self.ensure_initialized():
            logger.error("host plugin channel is not available")
            return
        if content is None or container_id is None or deployment_id is None:
            logger.error(
                "invalid arguments passed: "
                "[{0}], [{1}], [{2}]".format(
                    content,
                    container_id,
                    deployment_id))
            return
        url = URI_FORMAT_PUT_LOG.format(self.endpoint, HOST_PLUGIN_PORT)

        headers = {"x-ms-vmagentlog-deploymentid": deployment_id,
                   "x-ms-vmagentlog-containerid": container_id}
        logger.info("put VM log at [{0}]".format(url))
        try:
            response = restutil.http_put(url, content, headers)
            if response.status != httpclient.OK:
                logger.error("put log returned status code [{0}]".format(
                    response.status))
        except HttpError as e:
            logger.error("put log failed with [{0}]".format(e))
diff --git a/azurelinuxagent/common/protocol/metadata.py b/azurelinuxagent/common/protocol/metadata.py
new file mode 100644
index 0000000..f86f72f
--- /dev/null
+++ b/azurelinuxagent/common/protocol/metadata.py
@@ -0,0 +1,223 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import json
+import shutil
+import os
+import time
+from azurelinuxagent.common.exception import ProtocolError, HttpError
+from azurelinuxagent.common.future import httpclient, ustr
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.restutil as restutil
+import azurelinuxagent.common.utils.textutil as textutil
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
+from azurelinuxagent.common.protocol.restapi import *
+
+METADATA_ENDPOINT='169.254.169.254'
+APIVERSION='2015-05-01-preview'
+BASE_URI = "http://{0}/Microsoft.Compute/{1}?api-version={2}{3}"
+
+TRANSPORT_PRV_FILE_NAME = "V2TransportPrivate.pem"
+TRANSPORT_CERT_FILE_NAME = "V2TransportCert.pem"
+
+#TODO remote workarround for azure stack
+MAX_PING = 30
+RETRY_PING_INTERVAL = 10
+
+def _add_content_type(headers):
+ if headers is None:
+ headers = {}
+ headers["content-type"] = "application/json"
+ return headers
+
class MetadataProtocol(Protocol):
    """Protocol implementation backed by the metadata server (169.254.169.254).

    Resources are fetched and reported via JSON REST calls.
    Fixes: get_vmagent_pkgs stored the (data, etag) tuple returned by
    _get_data and passed the whole tuple to set_properties; identity
    comparisons now use "is None"; a dead try/except around dict.pop
    with a default was removed.
    """

    def __init__(self, apiversion=APIVERSION, endpoint=METADATA_ENDPOINT):
        self.apiversion = apiversion
        self.endpoint = endpoint
        self.identity_uri = BASE_URI.format(self.endpoint, "identity",
                                            self.apiversion, "&$expand=*")
        self.cert_uri = BASE_URI.format(self.endpoint, "certificates",
                                        self.apiversion, "&$expand=*")
        self.ext_uri = BASE_URI.format(self.endpoint, "extensionHandlers",
                                       self.apiversion, "&$expand=*")
        self.vmagent_uri = BASE_URI.format(self.endpoint, "vmAgentVersions",
                                           self.apiversion, "&$expand=*")
        self.provision_status_uri = BASE_URI.format(self.endpoint,
                                                    "provisioningStatus",
                                                    self.apiversion, "")
        self.vm_status_uri = BASE_URI.format(self.endpoint, "status/vmagent",
                                             self.apiversion, "")
        self.ext_status_uri = BASE_URI.format(self.endpoint,
                                              "status/extensions/{0}",
                                              self.apiversion, "")
        self.event_uri = BASE_URI.format(self.endpoint, "status/telemetry",
                                         self.apiversion, "")

    def _get_data(self, url, headers=None):
        """GET *url*; return (parsed_json, etag). Raises ProtocolError."""
        try:
            resp = restutil.http_get(url, headers=headers)
        except HttpError as e:
            raise ProtocolError(ustr(e))

        if resp.status != httpclient.OK:
            raise ProtocolError("{0} - GET: {1}".format(resp.status, url))

        data = resp.read()
        etag = resp.getheader('ETag')
        # NOTE(review): returning a bare None breaks callers that unpack
        # "data, etag"; kept as-is for compatibility -- confirm intent.
        if data is None:
            return None
        data = json.loads(ustr(data, encoding="utf-8"))
        return data, etag

    def _put_data(self, url, data, headers=None):
        """PUT *data* as JSON; raises ProtocolError on failure/non-OK."""
        headers = _add_content_type(headers)
        try:
            resp = restutil.http_put(url, json.dumps(data), headers=headers)
        except HttpError as e:
            raise ProtocolError(ustr(e))
        if resp.status != httpclient.OK:
            raise ProtocolError("{0} - PUT: {1}".format(resp.status, url))

    def _post_data(self, url, data, headers=None):
        """POST *data* as JSON; raises ProtocolError unless 201 CREATED."""
        headers = _add_content_type(headers)
        try:
            resp = restutil.http_post(url, json.dumps(data), headers=headers)
        except HttpError as e:
            raise ProtocolError(ustr(e))
        if resp.status != httpclient.CREATED:
            raise ProtocolError("{0} - POST: {1}".format(resp.status, url))

    def _get_trans_cert(self):
        """Return the transport certificate body (PEM armor stripped)."""
        trans_crt_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_CERT_FILE_NAME)
        if not os.path.isfile(trans_crt_file):
            raise ProtocolError("{0} is missing.".format(trans_crt_file))
        content = fileutil.read_file(trans_crt_file)
        return textutil.get_bytes_from_pem(content)

    def detect(self):
        """Verify the endpoint responds and generate the transport cert."""
        self.get_vminfo()
        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        cryptutil = CryptUtil(conf.get_openssl_cmd())
        cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)

        # "Install" the cert and private key to /var/lib/waagent under
        # their thumbprint-based names.
        thumbprint = cryptutil.get_thumbprint_from_crt(trans_cert_file)
        prv_file = os.path.join(conf.get_lib_dir(),
                                "{0}.prv".format(thumbprint))
        crt_file = os.path.join(conf.get_lib_dir(),
                                "{0}.crt".format(thumbprint))
        shutil.copyfile(trans_prv_file, prv_file)
        shutil.copyfile(trans_cert_file, crt_file)

    def get_vminfo(self):
        vminfo = VMInfo()
        data, etag = self._get_data(self.identity_uri)
        set_properties("vminfo", vminfo, data)
        return vminfo

    def get_certs(self):
        # TODO: download and save certs
        return CertList()

    def get_vmagent_manifests(self, last_etag=None):
        """Return (manifests, etag); skip deserialization when unchanged."""
        manifests = VMAgentManifestList()
        data, etag = self._get_data(self.vmagent_uri)
        if last_etag is None or last_etag < etag:
            set_properties("vmAgentManifests", manifests.vmAgentManifests, data)
        return manifests, etag

    def get_vmagent_pkgs(self, vmagent_manifest):
        """Return the agent package list, trying each manifest URI in turn."""
        # Agent packages share the extension handler package format.
        vmagent_pkgs = ExtHandlerPackageList()
        data = None
        for manifest_uri in vmagent_manifest.versionsManifestUris:
            try:
                # Fix: _get_data returns (data, etag) -- unpack it instead
                # of handing the tuple to set_properties.
                data, etag = self._get_data(manifest_uri.uri)
                break
            except ProtocolError as e:
                logger.warn("Failed to get vmagent versions: {0}", e)
                logger.info("Retry getting vmagent versions")
        if data is None:
            raise ProtocolError(("Failed to get versions for vm agent: {0}"
                                 "").format(vmagent_manifest.family))
        set_properties("vmAgentVersions", vmagent_pkgs, data)
        return vmagent_pkgs

    def get_ext_handlers(self, last_etag=None):
        """Return (ext_list, etag); skip deserialization when unchanged."""
        headers = {
            "x-ms-vmagent-public-x509-cert": self._get_trans_cert()
        }
        ext_list = ExtHandlerList()
        data, etag = self._get_data(self.ext_uri, headers=headers)
        if last_etag is None or last_etag < etag:
            set_properties("extensionHandlers", ext_list.extHandlers, data)
        return ext_list, etag

    def get_ext_handler_pkgs(self, ext_handler):
        """Return handler packages, trying each version URI in turn."""
        ext_handler_pkgs = ExtHandlerPackageList()
        data = None
        for version_uri in ext_handler.versionUris:
            try:
                data, etag = self._get_data(version_uri.uri)
                break
            except ProtocolError as e:
                logger.warn("Failed to get version uris: {0}", e)
                logger.info("Retry getting version uris")
        set_properties("extensionPackages", ext_handler_pkgs, data)
        return ext_handler_pkgs

    def report_provision_status(self, provision_status):
        validate_param('provisionStatus', provision_status, ProvisionStatus)
        data = get_properties(provision_status)
        self._put_data(self.provision_status_uri, data)

    def report_vm_status(self, vm_status):
        validate_param('vmStatus', vm_status, VMStatus)
        data = get_properties(vm_status)
        # TODO: the 'code' field is not implemented for the metadata
        # protocol yet; strip it from every handler status.  dict.pop
        # with a default never raises, so no exception guard is needed.
        for handler_status in data['vmAgent']['extensionHandlers']:
            handler_status.pop('code', None)
        self._put_data(self.vm_status_uri, data)

    def report_ext_status(self, ext_handler_name, ext_name, ext_status):
        validate_param('extensionStatus', ext_status, ExtensionStatus)
        data = get_properties(ext_status)
        uri = self.ext_status_uri.format(ext_name)
        self._put_data(uri, data)

    def report_event(self, events):
        # TODO: telemetry is disabled for azure stack testing.
        # validate_param('events', events, TelemetryEventList)
        # data = get_properties(events)
        # self._post_data(self.event_uri, data)
        pass
+
diff --git a/azurelinuxagent/common/protocol/ovfenv.py b/azurelinuxagent/common/protocol/ovfenv.py
new file mode 100644
index 0000000..4901871
--- /dev/null
+++ b/azurelinuxagent/common/protocol/ovfenv.py
@@ -0,0 +1,113 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+"""
+Copy and parse ovf-env.xml from provisioning ISO and local cache
+"""
+import os
+import re
+import shutil
+import xml.dom.minidom as minidom
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext
+
# Highest ovf-env.xml schema version this parser understands.
OVF_VERSION = "1.0"
# XML namespace of the OVF environment document.
OVF_NAME_SPACE = "http://schemas.dmtf.org/ovf/environment/1"
# XML namespace of the Windows Azure provisioning elements.
WA_NAME_SPACE = "http://schemas.microsoft.com/windowsazure"
+
+def _validate_ovf(val, msg):
+ if val is None:
+ raise ProtocolError("Failed to parse OVF XML: {0}".format(msg))
+
class OvfEnv(object):
    """
    Read, and process provisioning info from provisioning file OvfEnv.xml
    """
    def __init__(self, xml_text):
        """Parse *xml_text* (the ovf-env.xml content) immediately;
        raises ValueError on None input and ProtocolError on malformed XML."""
        if xml_text is None:
            raise ValueError("ovf-env is None")
        logger.verbose("Load ovf-env.xml")
        self.hostname = None                     # <HostName>, required
        self.username = None                     # <UserName>, required
        self.user_password = None                # <UserPassword>, optional
        self.customdata = None                   # <CustomData>, optional
        self.disable_ssh_password_auth = True    # overwritten by parse()
        self.ssh_pubkeys = []                    # (path, fingerprint, value)
        self.ssh_keypairs = []                   # (path, fingerprint)
        self.parse(xml_text)

    def parse(self, xml_text):
        """
        Parse xml tree, retreiving user and ssh key information.
        Return self.
        """
        wans = WA_NAME_SPACE
        ovfns = OVF_NAME_SPACE

        xml_doc = parse_doc(xml_text)

        environment = find(xml_doc, "Environment", namespace=ovfns)
        _validate_ovf(environment, "Environment not found")

        section = find(environment, "ProvisioningSection", namespace=wans)
        _validate_ovf(section, "ProvisioningSection not found")

        version = findtext(environment, "Version", namespace=wans)
        _validate_ovf(version, "Version not found")

        # NOTE(review): this is a lexicographic string comparison against
        # "1.0", not a numeric version compare — confirm acceptable for the
        # version strings actually emitted by the platform.
        if version > OVF_VERSION:
            logger.warn("Newer provisioning configuration detected. "
                        "Please consider updating waagent")

        conf_set = find(section, "LinuxProvisioningConfigurationSet",
                        namespace=wans)
        _validate_ovf(conf_set, "LinuxProvisioningConfigurationSet not found")

        self.hostname = findtext(conf_set, "HostName", namespace=wans)
        _validate_ovf(self.hostname, "HostName not found")

        self.username = findtext(conf_set, "UserName", namespace=wans)
        _validate_ovf(self.username, "UserName not found")

        # Optional elements: absent -> None.
        self.user_password = findtext(conf_set, "UserPassword", namespace=wans)

        self.customdata = findtext(conf_set, "CustomData", namespace=wans)

        # Password auth stays enabled unless the element is literally "true"
        # (case-insensitive); a missing element means enabled.
        auth_option = findtext(conf_set, "DisableSshPasswordAuthentication",
                               namespace=wans)
        if auth_option is not None and auth_option.lower() == "true":
            self.disable_ssh_password_auth = True
        else:
            self.disable_ssh_password_auth = False

        # Collect SSH public keys as (path, fingerprint, value) triples.
        public_keys = findall(conf_set, "PublicKey", namespace=wans)
        for public_key in public_keys:
            path = findtext(public_key, "Path", namespace=wans)
            fingerprint = findtext(public_key, "Fingerprint", namespace=wans)
            value = findtext(public_key, "Value", namespace=wans)
            self.ssh_pubkeys.append((path, fingerprint, value))

        # Collect SSH key pairs as (path, fingerprint) tuples.
        keypairs = findall(conf_set, "KeyPair", namespace=wans)
        for keypair in keypairs:
            path = findtext(keypair, "Path", namespace=wans)
            fingerprint = findtext(keypair, "Fingerprint", namespace=wans)
            self.ssh_keypairs.append((path, fingerprint))
+
diff --git a/azurelinuxagent/common/protocol/restapi.py b/azurelinuxagent/common/protocol/restapi.py
new file mode 100644
index 0000000..7f00488
--- /dev/null
+++ b/azurelinuxagent/common/protocol/restapi.py
@@ -0,0 +1,272 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+import os
+import copy
+import re
+import json
+import xml.dom.minidom
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError, HttpError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.restutil as restutil
+
def validate_param(name, val, expected_type):
    """Guard helper: raise ProtocolError unless *val* is a non-None
    instance of *expected_type*; *name* is used in the error text."""
    if val is None:
        raise ProtocolError("{0} is None".format(name))
    if isinstance(val, expected_type):
        return
    raise ProtocolError(("{0} type should be {1} not {2}"
                         "").format(name, expected_type, type(val)))
+
def set_properties(name, obj, data):
    """Recursively hydrate *obj* from *data*.

    DataContractList targets consume a JSON list, instantiating item_cls
    per element; DataContract targets consume a dict, assigning each key
    onto a matching attribute (unknown keys are logged and skipped).
    Plain values are returned as-is.
    """
    if isinstance(obj, DataContractList):
        validate_param("List '{0}'".format(name), data, list)
        for entry in data:
            member = set_properties(name, obj.item_cls(), entry)
            obj.append(member)
        return obj

    if not isinstance(obj, DataContract):
        # Leaf value: nothing to recurse into.
        return data

    validate_param("Property '{0}'".format(name), data, dict)
    for prob_name, prob_val in data.items():
        prob_full_name = "{0}.{1}".format(name, prob_name)
        try:
            current = getattr(obj, prob_name)
        except AttributeError:
            # Payload carries a field the contract does not declare.
            logger.warn("Unknown property: {0}", prob_full_name)
            continue
        setattr(obj, prob_name,
                set_properties(prob_full_name, current, prob_val))
    return obj
+
def get_properties(obj):
    """Inverse of set_properties: flatten a DataContract tree into plain
    dicts/lists suitable for JSON serialization; leaves pass through."""
    if isinstance(obj, DataContractList):
        return [get_properties(entry) for entry in obj]
    if isinstance(obj, DataContract):
        return dict((attr_name, get_properties(attr_val))
                    for attr_name, attr_val in vars(obj).items())
    return obj
+
class DataContract(object):
    """Marker base class for serializable guest/host data contracts;
    set_properties/get_properties dispatch on this type."""
    pass
+
class DataContractList(list):
    """A list whose elements are instances of one contract class;
    item_cls is used by set_properties to build elements on the fly."""

    def __init__(self, item_cls):
        super(DataContractList, self).__init__()
        self.item_cls = item_cls
+
+"""
+Data contract between guest and host
+"""
class VMInfo(DataContract):
    """Identity of the VM / role instance as reported by the protocol."""
    def __init__(self, subscriptionId=None, vmName=None, containerId=None,
                 roleName=None, roleInstanceName=None, tenantName=None):
        self.subscriptionId = subscriptionId
        self.vmName = vmName
        self.containerId = containerId
        self.roleName = roleName
        self.roleInstanceName = roleInstanceName
        self.tenantName = tenantName

class Cert(DataContract):
    """One certificate: display name, thumbprint, and data URI."""
    def __init__(self, name=None, thumbprint=None, certificateDataUri=None):
        self.name = name
        self.thumbprint = thumbprint
        self.certificateDataUri = certificateDataUri

class CertList(DataContract):
    """Container for Cert entries."""
    def __init__(self):
        self.certificates = DataContractList(Cert)

#TODO: confirm vmagent manifest schema
class VMAgentManifestUri(DataContract):
    """A single URI pointing at a VM agent versions manifest."""
    def __init__(self, uri=None):
        self.uri = uri

class VMAgentManifest(DataContract):
    """Agent family plus the manifest URIs for its versions."""
    def __init__(self, family=None):
        self.family = family
        self.versionsManifestUris = DataContractList(VMAgentManifestUri)

class VMAgentManifestList(DataContract):
    """Container for VMAgentManifest entries."""
    def __init__(self):
        self.vmAgentManifests = DataContractList(VMAgentManifest)
+
class Extension(DataContract):
    """Per-extension settings delivered in the goal state."""
    def __init__(self, name=None, sequenceNumber=None, publicSettings=None,
                 protectedSettings=None, certificateThumbprint=None):
        self.name = name
        self.sequenceNumber = sequenceNumber
        self.publicSettings = publicSettings
        # protectedSettings are encrypted; certificateThumbprint names the
        # cert used to decrypt them.
        self.protectedSettings = protectedSettings
        self.certificateThumbprint = certificateThumbprint

class ExtHandlerProperties(DataContract):
    """Handler-level config: version, upgrade policy, desired state,
    and the extensions the handler manages."""
    def __init__(self):
        self.version = None
        self.upgradePolicy = None
        self.state = None
        self.extensions = DataContractList(Extension)

class ExtHandlerVersionUri(DataContract):
    """A single URI for a handler's version manifest."""
    def __init__(self):
        self.uri = None

class ExtHandler(DataContract):
    """An extension handler: name, properties, and manifest URIs."""
    def __init__(self, name=None):
        self.name = name
        self.properties = ExtHandlerProperties()
        self.versionUris = DataContractList(ExtHandlerVersionUri)

class ExtHandlerList(DataContract):
    """Container for ExtHandler entries."""
    def __init__(self):
        self.extHandlers = DataContractList(ExtHandler)

class ExtHandlerPackageUri(DataContract):
    """Download URI for one handler package."""
    def __init__(self, uri=None):
        self.uri = uri

class ExtHandlerPackage(DataContract):
    """One downloadable handler version with its mirror URIs."""
    def __init__(self, version=None):
        self.version = version
        self.uris = DataContractList(ExtHandlerPackageUri)
        # TODO update the naming to align with metadata protocol
        self.isinternal = False

class ExtHandlerPackageList(DataContract):
    """Container for ExtHandlerPackage entries."""
    def __init__(self):
        self.versions = DataContractList(ExtHandlerPackage)
+
class VMProperties(DataContract):
    """Properties reported with provisioning status."""
    def __init__(self, certificateThumbprint=None):
        #TODO need to confirm the property name
        self.certificateThumbprint = certificateThumbprint

class ProvisionStatus(DataContract):
    """Provisioning outcome reported back to the platform."""
    def __init__(self, status=None, subStatus=None, description=None):
        self.status = status
        self.subStatus = subStatus
        self.description = description
        self.properties = VMProperties()

class ExtensionSubStatus(DataContract):
    """A single sub-status entry within an extension status."""
    def __init__(self, name=None, status=None, code=None, message=None):
        self.name = name
        self.status = status
        self.code = code
        self.message = message

class ExtensionStatus(DataContract):
    """Status of one extension instance; note the seq_no constructor
    argument is stored as sequenceNumber."""
    def __init__(self, configurationAppliedTime=None, operation=None,
                 status=None, seq_no=None, code=None, message=None):
        self.configurationAppliedTime = configurationAppliedTime
        self.operation = operation
        self.status = status
        self.sequenceNumber = seq_no
        self.code = code
        self.message = message
        self.substatusList = DataContractList(ExtensionSubStatus)

class ExtHandlerStatus(DataContract):
    """Status of one extension handler; extensions holds the names of
    the extensions it manages (strings)."""
    def __init__(self, name=None, version=None, status=None, code=0,
                 message=None):
        self.name = name
        self.version = version
        self.status = status
        self.code = code
        self.message = message
        self.extensions = DataContractList(ustr)

class VMAgentStatus(DataContract):
    """Guest agent status plus the statuses of all handlers."""
    def __init__(self, version=None, status=None, message=None):
        self.version = version
        self.status = status
        self.message = message
        self.extensionHandlers = DataContractList(ExtHandlerStatus)

class VMStatus(DataContract):
    """Top-level VM status wrapper around VMAgentStatus."""
    def __init__(self):
        self.vmAgent = VMAgentStatus()

class TelemetryEventParam(DataContract):
    """One name/value parameter of a telemetry event."""
    def __init__(self, name=None, value=None):
        self.name = name
        self.value = value

class TelemetryEvent(DataContract):
    """A telemetry event with its parameter list."""
    def __init__(self, eventId=None, providerId=None):
        self.eventId = eventId
        self.providerId = providerId
        self.parameters = DataContractList(TelemetryEventParam)

class TelemetryEventList(DataContract):
    """Container for TelemetryEvent entries."""
    def __init__(self):
        self.events = DataContractList(TelemetryEvent)
+
class Protocol(DataContract):
    """Abstract interface shared by the wire and metadata protocols.
    Subclasses implement detection, goal-state queries and status/event
    reporting; only download_ext_handler_pkg has a default implementation."""

    def detect(self):
        """Probe the protocol endpoint; raise on failure."""
        raise NotImplementedError()

    def get_vminfo(self):
        """Return a VMInfo describing this VM."""
        raise NotImplementedError()

    def get_certs(self):
        """Return the certificate list for this VM."""
        raise NotImplementedError()

    def get_vmagent_manifests(self):
        """Return (VMAgentManifestList, etag/incarnation)."""
        raise NotImplementedError()

    def get_vmagent_pkgs(self):
        """Return the agent package list for a manifest."""
        raise NotImplementedError()

    def get_ext_handlers(self):
        """Return (ExtHandlerList, etag/incarnation)."""
        raise NotImplementedError()

    def get_ext_handler_pkgs(self, extension):
        """Return the ExtHandlerPackageList for a handler."""
        raise NotImplementedError()

    def download_ext_handler_pkg(self, uri):
        """Download a handler package; wraps HttpError in ProtocolError.

        NOTE(review): implicitly returns None when the response status is
        not 200 OK — callers must handle a None result.
        """
        try:
            resp = restutil.http_get(uri, chk_proxy=True)
            if resp.status == restutil.httpclient.OK:
                return resp.read()
        except HttpError as e:
            raise ProtocolError("Failed to download from: {0}".format(uri), e)

    def report_provision_status(self, provision_status):
        """Upload a ProvisionStatus."""
        raise NotImplementedError()

    def report_vm_status(self, vm_status):
        """Upload a VMStatus."""
        raise NotImplementedError()

    def report_ext_status(self, ext_handler_name, ext_name, ext_status):
        """Upload one ExtensionStatus."""
        raise NotImplementedError()

    def report_event(self, event):
        """Upload telemetry events."""
        raise NotImplementedError()
+
diff --git a/azurelinuxagent/common/protocol/util.py b/azurelinuxagent/common/protocol/util.py
new file mode 100644
index 0000000..7e7a74f
--- /dev/null
+++ b/azurelinuxagent/common/protocol/util.py
@@ -0,0 +1,285 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+import os
+import re
+import shutil
+import time
+import threading
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import ProtocolError, OSUtilError, \
+ ProtocolNotFoundError, DhcpError
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.osutil import get_osutil
+from azurelinuxagent.common.dhcp import get_dhcp_handler
+from azurelinuxagent.common.protocol.ovfenv import OvfEnv
+from azurelinuxagent.common.protocol.wire import WireProtocol
+from azurelinuxagent.common.protocol.metadata import MetadataProtocol, \
+ METADATA_ENDPOINT
+import azurelinuxagent.common.utils.shellutil as shellutil
+
# Name of the provisioning configuration file (on the ISO and in lib dir).
OVF_FILE_NAME = "ovf-env.xml"

#Tag file to indicate usage of metadata protocol
TAG_FILE_NAME = "useMetadataEndpoint.tag"

# File in lib dir recording which protocol was detected last.
PROTOCOL_FILE_NAME = "Protocol"

#MAX retry times for protocol probing
MAX_RETRY = 360

# Seconds to sleep between protocol probe rounds.
PROBE_INTERVAL = 10

# File in lib dir caching the wire server endpoint address.
ENDPOINT_FILE_NAME = "WireServerEndpoint"
+
def get_protocol_util():
    """Factory: return a fresh ProtocolUtil instance."""
    return ProtocolUtil()
+
class ProtocolUtil(object):
    """
    ProtocolUtil handles initialization for protocol instance. 2 protocol types
    are invoked, wire protocol and metadata protocols.
    """
    def __init__(self):
        # Serializes detection/caching across threads (see get_protocol*).
        self.lock = threading.Lock()
        # Cached protocol instance, set on first successful detection.
        self.protocol = None
        self.osutil = get_osutil()
        self.dhcp_handler = get_dhcp_handler()

    def copy_ovf_env(self):
        """
        Copy ovf env file from dvd to hard disk.
        Remove password before save it to the disk
        """
        dvd_mount_point = conf.get_dvd_mount_point()
        ovf_file_path_on_dvd = os.path.join(dvd_mount_point, OVF_FILE_NAME)
        tag_file_path_on_dvd = os.path.join(dvd_mount_point, TAG_FILE_NAME)
        try:
            self.osutil.mount_dvd()
            ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
            ovfenv = OvfEnv(ovfxml)
            # Mask the user password before persisting the XML to disk.
            ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml)
            ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
            fileutil.write_file(ovf_file_path, ovfxml)

            # Presence of the tag file selects the metadata protocol later
            # (see get_protocol_by_file).
            if os.path.isfile(tag_file_path_on_dvd):
                logger.info("Found {0} in provisioning ISO", TAG_FILE_NAME)
                tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME)
                shutil.copyfile(tag_file_path_on_dvd, tag_file_path)

        except (OSUtilError, IOError) as e:
            raise ProtocolError(ustr(e))

        # Unmount/eject are best-effort; failures are only logged.
        try:
            self.osutil.umount_dvd()
            self.osutil.eject_dvd()
        except OSUtilError as e:
            logger.warn(ustr(e))

        return ovfenv

    def get_ovf_env(self):
        """
        Load saved ovf-env.xml
        """
        ovf_file_path = os.path.join(conf.get_lib_dir(), OVF_FILE_NAME)
        if os.path.isfile(ovf_file_path):
            xml_text = fileutil.read_file(ovf_file_path)
            return OvfEnv(xml_text)
        else:
            raise ProtocolError("ovf-env.xml is missing.")

    def _get_wireserver_endpoint(self):
        """Read the cached wire server address from lib dir."""
        try:
            file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME)
            return fileutil.read_file(file_path)
        except IOError as e:
            raise OSUtilError(ustr(e))

    def _set_wireserver_endpoint(self, endpoint):
        """Persist the wire server address to lib dir."""
        try:
            file_path = os.path.join(conf.get_lib_dir(), ENDPOINT_FILE_NAME)
            fileutil.write_file(file_path, endpoint)
        except IOError as e:
            raise OSUtilError(ustr(e))

    def _detect_wire_protocol(self):
        """Probe the wire server (re-running DHCP if needed); on success
        cache the endpoint and record 'WireProtocol' as the active protocol."""
        endpoint = self.dhcp_handler.endpoint
        if endpoint is None:
            logger.info("WireServer endpoint is not found. Rerun dhcp handler")
            try:
                self.dhcp_handler.run()
            except DhcpError as e:
                raise ProtocolError(ustr(e))
            endpoint = self.dhcp_handler.endpoint

        try:
            protocol = WireProtocol(endpoint)
            protocol.detect()
            self._set_wireserver_endpoint(endpoint)
            self.save_protocol("WireProtocol")
            return protocol
        except ProtocolError as e:
            # Force a fresh DHCP discovery on the next attempt.
            logger.info("WireServer is not responding. Reset endpoint")
            self.dhcp_handler.endpoint = None
            self.dhcp_handler.skip_cache = True
            raise e

    def _detect_metadata_protocol(self):
        """Probe the metadata endpoint and record it as active on success."""
        protocol = MetadataProtocol()
        protocol.detect()

        #Only allow root access METADATA_ENDPOINT
        self.osutil.set_admin_access_to_ip(METADATA_ENDPOINT)

        self.save_protocol("MetadataProtocol")

        return protocol

    def _detect_protocol(self, protocols):
        """
        Probe protocol endpoints in turn.
        """
        self.clear_protocol()

        # Up to MAX_RETRY rounds over all candidates, sleeping
        # PROBE_INTERVAL seconds between rounds.
        for retry in range(0, MAX_RETRY):
            for protocol in protocols:
                try:
                    if protocol == "WireProtocol":
                        return self._detect_wire_protocol()

                    if protocol == "MetadataProtocol":
                        return self._detect_metadata_protocol()

                except ProtocolError as e:
                    logger.info("Protocol endpoint not found: {0}, {1}",
                                protocol, e)

            if retry < MAX_RETRY -1:
                logger.info("Retry detect protocols: retry={0}", retry)
                time.sleep(PROBE_INTERVAL)
        raise ProtocolNotFoundError("No protocol found.")

    def _get_protocol(self):
        """
        Get protocol instance based on previous detecting result.
        """
        protocol_file_path = os.path.join(conf.get_lib_dir(),
                                          PROTOCOL_FILE_NAME)
        if not os.path.isfile(protocol_file_path):
            raise ProtocolNotFoundError("No protocol found")

        protocol_name = fileutil.read_file(protocol_file_path)
        if protocol_name == "WireProtocol":
            endpoint = self._get_wireserver_endpoint()
            return WireProtocol(endpoint)
        elif protocol_name == "MetadataProtocol":
            return MetadataProtocol()
        else:
            raise ProtocolNotFoundError(("Unknown protocol: {0}"
                                         "").format(protocol_name))

    def save_protocol(self, protocol_name):
        """
        Save protocol endpoint
        """
        protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME)
        try:
            fileutil.write_file(protocol_file_path, protocol_name)
        except IOError as e:
            logger.error("Failed to save protocol endpoint: {0}", e)


    def clear_protocol(self):
        """
        Cleanup previous saved endpoint.
        """
        logger.info("Clean protocol")
        self.protocol = None
        protocol_file_path = os.path.join(conf.get_lib_dir(), PROTOCOL_FILE_NAME)
        if not os.path.isfile(protocol_file_path):
            return

        # NOTE(review): os.remove raises OSError; on Python 2 an OSError is
        # not an IOError, so this handler may not catch it — verify intent.
        try:
            os.remove(protocol_file_path)
        except IOError as e:
            logger.error("Failed to clear protocol endpoint: {0}", e)

    def get_protocol(self):
        """
        Detect protocol by endpoints

        :returns: protocol instance
        """
        self.lock.acquire()

        try:
            # Fast path: already detected in this process.
            if self.protocol is not None:
                return self.protocol

            # Next: reuse a previously persisted detection result.
            try:
                self.protocol = self._get_protocol()
                return self.protocol
            except ProtocolNotFoundError:
                pass

            # Finally: probe both endpoints, wire protocol first.
            logger.info("Detect protocol endpoints")
            protocols = ["WireProtocol", "MetadataProtocol"]
            self.protocol = self._detect_protocol(protocols)

            return self.protocol

        finally:
            self.lock.release()


    def get_protocol_by_file(self):
        """
        Detect protocol by tag file.

        If a file "useMetadataEndpoint.tag" is found on provision iso,
        metedata protocol will be used. No need to probe for wire protocol

        :returns: protocol instance
        """
        self.lock.acquire()

        try:
            if self.protocol is not None:
                return self.protocol

            try:
                self.protocol = self._get_protocol()
                return self.protocol
            except ProtocolNotFoundError:
                pass

            # The tag file (copied from the ISO by copy_ovf_env) selects
            # exactly one candidate protocol to probe.
            logger.info("Detect protocol by file")
            tag_file_path = os.path.join(conf.get_lib_dir(), TAG_FILE_NAME)
            protocols = []
            if os.path.isfile(tag_file_path):
                protocols.append("MetadataProtocol")
            else:
                protocols.append("WireProtocol")
            self.protocol = self._detect_protocol(protocols)
            return self.protocol

        finally:
            self.lock.release()
diff --git a/azurelinuxagent/common/protocol/wire.py b/azurelinuxagent/common/protocol/wire.py
new file mode 100644
index 0000000..29a1663
--- /dev/null
+++ b/azurelinuxagent/common/protocol/wire.py
@@ -0,0 +1,1218 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import time
+import xml.sax.saxutils as saxutils
+import azurelinuxagent.common.conf as conf
+from azurelinuxagent.common.exception import ProtocolNotFoundError
+from azurelinuxagent.common.future import httpclient, bytebuffer
+from azurelinuxagent.common.utils.textutil import parse_doc, findall, find, findtext, \
+ getattrib, gettext, remove_bom, get_bytes_from_pem
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.cryptutil import CryptUtil
+from azurelinuxagent.common.protocol.restapi import *
+from azurelinuxagent.common.protocol.hostplugin import HostPluginProtocol
+
# Wire server endpoint URI templates; {0} is the endpoint IP address.
VERSION_INFO_URI = "http://{0}/?comp=versions"
GOAL_STATE_URI = "http://{0}/machine/?comp=goalstate"
HEALTH_REPORT_URI = "http://{0}/machine?comp=health"
ROLE_PROP_URI = "http://{0}/machine?comp=roleProperties"
TELEMETRY_URI = "http://{0}/machine?comp=telemetrydata"

# File names used to cache goal-state artifacts in the lib dir;
# {0}/{1} placeholders are filled with incarnation / names.
WIRE_SERVER_ADDR_FILE_NAME = "WireServer"
INCARNATION_FILE_NAME = "Incarnation"
GOAL_STATE_FILE_NAME = "GoalState.{0}.xml"
HOSTING_ENV_FILE_NAME = "HostingEnvironmentConfig.xml"
SHARED_CONF_FILE_NAME = "SharedConfig.xml"
CERTS_FILE_NAME = "Certificates.xml"
P7M_FILE_NAME = "Certificates.p7m"
PEM_FILE_NAME = "Certificates.pem"
EXT_CONF_FILE_NAME = "ExtensionsConfig.{0}.xml"
MANIFEST_FILE_NAME = "{0}.{1}.manifest.xml"
# Transport certificate/key used to receive encrypted certificates.
TRANSPORT_CERT_FILE_NAME = "TransportCert.pem"
TRANSPORT_PRV_FILE_NAME = "TransportPrivate.pem"

# Wire protocol version this client speaks.
PROTOCOL_VERSION = "2012-11-30"
ENDPOINT_FINE_NAME = "WireServer"

SHORT_WAITING_INTERVAL = 1  # 1 second
LONG_WAITING_INTERVAL = 15  # 15 seconds
+
+
class UploadError(HttpError):
    """Raised when uploading a status blob to storage fails."""
    pass


class WireProtocolResourceGone(ProtocolError):
    """Raised when the wire server reports a resource as gone (410)."""
    pass
+
+
class WireProtocol(Protocol):
    """Slim layer to adapt wire protocol data to metadata protocol interface"""

    # TODO: Clean-up goal state processing
    # At present, some methods magically update GoalState (e.g., get_vmagent_manifests), others (e.g., get_vmagent_pkgs)
    # assume its presence. A better approach would make an explicit update call that returns the incarnation number and
    # establishes that number the "context" for all other calls (either by updating the internal state of the protocol or
    # by having callers pass the incarnation number to the method).

    def __init__(self, endpoint):
        """Bind to a wire server *endpoint* (IP string); all work is
        delegated to an internal WireClient."""
        if endpoint is None:
            raise ProtocolError("WireProtocol endpoint is None")
        self.endpoint = endpoint
        self.client = WireClient(self.endpoint)

    def detect(self):
        """Verify the endpoint speaks our protocol version, generate the
        transport certificate pair, and fetch the initial goal state."""
        self.client.check_wire_protocol_version()

        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        cryptutil = CryptUtil(conf.get_openssl_cmd())
        cryptutil.gen_transport_cert(trans_prv_file, trans_cert_file)

        self.client.update_goal_state(forced=True)

    def get_vminfo(self):
        """Assemble a VMInfo from the goal state and hosting environment."""
        goal_state = self.client.get_goal_state()
        hosting_env = self.client.get_hosting_env()

        vminfo = VMInfo()
        # Subscription id is not available through the wire server.
        vminfo.subscriptionId = None
        vminfo.vmName = hosting_env.vm_name
        vminfo.tenantName = hosting_env.deployment_name
        vminfo.roleName = hosting_env.role_name
        vminfo.roleInstanceName = goal_state.role_instance_id
        vminfo.containerId = goal_state.container_id
        return vminfo

    def get_certs(self):
        """Return the certificate list from the current goal state."""
        certificates = self.client.get_certs()
        return certificates.cert_list

    def get_vmagent_manifests(self):
        """Return (agent manifests, incarnation) from a fresh goal state."""
        # Update goal state to get latest extensions config
        self.client.update_goal_state()
        goal_state = self.client.get_goal_state()
        ext_conf = self.client.get_ext_conf()
        return ext_conf.vmagent_manifests, goal_state.incarnation

    def get_vmagent_pkgs(self, vmagent_manifest):
        """Return the agent package list for *vmagent_manifest*; assumes
        the goal state was already fetched (see class TODO)."""
        goal_state = self.client.get_goal_state()
        man = self.client.get_gafamily_manifest(vmagent_manifest, goal_state)
        return man.pkg_list

    def get_ext_handlers(self):
        """Return (extension handlers, incarnation) from a fresh goal state."""
        logger.verbose("Get extension handler config")
        # Update goal state to get latest extensions config
        self.client.update_goal_state()
        goal_state = self.client.get_goal_state()
        ext_conf = self.client.get_ext_conf()
        # In wire protocol, incarnation is equivalent to ETag
        return ext_conf.ext_handlers, goal_state.incarnation

    def get_ext_handler_pkgs(self, ext_handler):
        """Return the package list for *ext_handler*."""
        logger.verbose("Get extension handler package")
        goal_state = self.client.get_goal_state()
        man = self.client.get_ext_manifest(ext_handler, goal_state)
        return man.pkg_list

    def report_provision_status(self, provision_status):
        """Report provisioning health and, if present, the certificate
        thumbprint as a role property."""
        validate_param("provision_status", provision_status, ProvisionStatus)

        if provision_status.status is not None:
            self.client.report_health(provision_status.status,
                                      provision_status.subStatus,
                                      provision_status.description)
        if provision_status.properties.certificateThumbprint is not None:
            thumbprint = provision_status.properties.certificateThumbprint
            self.client.report_role_prop(thumbprint)

    def report_vm_status(self, vm_status):
        """Record *vm_status* in the status blob and upload it."""
        validate_param("vm_status", vm_status, VMStatus)
        self.client.status_blob.set_vm_status(vm_status)
        self.client.upload_status_blob()

    def report_ext_status(self, ext_handler_name, ext_name, ext_status):
        """Record an extension status; it is uploaded with the next
        report_vm_status call (no immediate upload here)."""
        validate_param("ext_status", ext_status, ExtensionStatus)
        self.client.status_blob.set_ext_status(ext_handler_name, ext_status)

    def report_event(self, events):
        """Upload telemetry events to the wire server."""
        validate_param("events", events, TelemetryEventList)
        self.client.report_event(events)
+
+
+def _build_role_properties(container_id, role_instance_id, thumbprint):
+ xml = (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>"
+ u"<RoleProperties>"
+ u"<Container>"
+ u"<ContainerId>{0}</ContainerId>"
+ u"<RoleInstances>"
+ u"<RoleInstance>"
+ u"<Id>{1}</Id>"
+ u"<Properties>"
+ u"<Property name=\"CertificateThumbprint\" value=\"{2}\" />"
+ u"</Properties>"
+ u"</RoleInstance>"
+ u"</RoleInstances>"
+ u"</Container>"
+ u"</RoleProperties>"
+ u"").format(container_id, role_instance_id, thumbprint)
+ return xml
+
+
def _build_health_report(incarnation, container_id, role_instance_id,
                         status, substatus, description):
    """Render the Health XML document posted to the wire server health
    endpoint; substatus/description are XML-escaped, and the Details
    element is omitted entirely when substatus is None."""
    # Escape '&', '<' and '>'
    description = saxutils.escape(ustr(description))

    if substatus is None:
        detail = u''
    else:
        substatus = saxutils.escape(ustr(substatus))
        detail = (u"<Details>"
                  u"<SubStatus>{0}</SubStatus>"
                  u"<Description>{1}</Description>"
                  u"</Details>").format(substatus, description)

    return (u"<?xml version=\"1.0\" encoding=\"utf-8\"?>"
            u"<Health "
            u"xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\""
            u" xmlns:xsd=\"http://www.w3.org/2001/XMLSchema\">"
            u"<GoalStateIncarnation>{0}</GoalStateIncarnation>"
            u"<Container>"
            u"<ContainerId>{1}</ContainerId>"
            u"<RoleInstanceList>"
            u"<Role>"
            u"<InstanceId>{2}</InstanceId>"
            u"<Health>"
            u"<State>{3}</State>"
            u"{4}"
            u"</Health>"
            u"</Role>"
            u"</RoleInstanceList>"
            u"</Container>"
            u"</Health>"
            u"").format(incarnation, container_id, role_instance_id,
                        status, detail)
+
+
+"""
+Convert VMStatus object to status blob format
+"""
+
+
def ga_status_to_v1(ga_status):
    """Map a VMAgentStatus-like object onto the v1 status-blob schema."""
    return {
        'version': ga_status.version,
        'status': ga_status.status,
        'formattedMessage': {
            'lang': 'en-US',
            'message': ga_status.message
        }
    }
+
+
def ext_substatus_to_v1(sub_status_list):
    """Convert a list of ExtensionSubStatus-like objects to the list of
    v1 schema dicts embedded in an extension status."""
    return [
        {
            "name": entry.name,
            "status": entry.status,
            "code": entry.code,
            "formattedMessage": {
                "lang": "en-US",
                "message": entry.message
            }
        }
        for entry in sub_status_list
    ]
+
+
def ext_status_to_v1(ext_name, ext_status):
    """Convert one ExtensionStatus-like object to the v1 settingsStatus
    dict; returns None for a None input. The 'substatus' key is only
    present when there are sub-statuses."""
    if ext_status is None:
        return None
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    v1_ext_status = {
        "status": {
            "name": ext_name,
            "configurationAppliedTime": ext_status.configurationAppliedTime,
            "operation": ext_status.operation,
            "status": ext_status.status,
            "code": ext_status.code,
            "formattedMessage": {
                "lang": "en-US",
                "message": ext_status.message
            }
        },
        "version": 1.0,
        "timestampUTC": timestamp
    }
    v1_sub_status = ext_substatus_to_v1(ext_status.substatusList)
    if v1_sub_status:
        v1_ext_status['substatus'] = v1_sub_status
    return v1_ext_status
+
+
def ext_handler_status_to_v1(handler_status, ext_statuses, timestamp):
    """Convert one ExtHandlerStatus-like object plus the recorded
    extension statuses to the v1 handler aggregate dict.

    NOTE: *timestamp* is accepted but unused; it is kept for signature
    compatibility with callers.
    """
    v1_handler_status = {
        'handlerVersion': handler_status.version,
        'handlerName': handler_status.name,
        'status': handler_status.status,
        'code': handler_status.code
    }
    if handler_status.message is not None:
        v1_handler_status["formattedMessage"] = {
            "lang": "en-US",
            "message": handler_status.message
        }

    if not handler_status.extensions:
        return v1_handler_status

    # Currently, no more than one extension per handler
    ext_name = handler_status.extensions[0]
    ext_status = ext_statuses.get(ext_name)
    v1_ext_status = ext_status_to_v1(ext_name, ext_status)
    if ext_status is not None and v1_ext_status is not None:
        v1_handler_status["runtimeSettingsStatus"] = {
            'settingsStatus': v1_ext_status,
            'sequenceNumber': ext_status.sequenceNumber
        }
    return v1_handler_status
+
+
def vm_status_to_v1(vm_status, ext_statuses):
    """Assemble the full v1 status-blob document from a VMStatus-like
    object and the recorded per-extension statuses."""
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    handler_entries = []
    for handler_status in vm_status.vmAgent.extensionHandlers:
        entry = ext_handler_status_to_v1(handler_status,
                                         ext_statuses, timestamp)
        if entry is not None:
            handler_entries.append(entry)

    return {
        'version': '1.0',
        'timestampUTC': timestamp,
        'aggregateStatus': {
            'guestAgentStatus': ga_status_to_v1(vm_status.vmAgent),
            'handlerAggregateStatus': handler_entries
        }
    }
+
+
+class StatusBlob(object):
    def __init__(self, client):
        """Accumulates VM/extension statuses and uploads them as a blob
        via *client* (a WireClient)."""
        self.vm_status = None      # VMStatus recorded by set_vm_status
        self.ext_statuses = {}     # handler name -> ExtensionStatus
        self.client = client       # used for storage service calls
        self.type = None           # blob type discovered during upload
        self.data = None           # serialized JSON payload of last upload
+
    def set_vm_status(self, vm_status):
        """Record the VMStatus to serialize on the next upload."""
        validate_param("vmAgent", vm_status, VMStatus)
        self.vm_status = vm_status
+
    def set_ext_status(self, ext_handler_name, ext_status):
        """Record one extension's status, keyed by handler name."""
        validate_param("extensionStatus", ext_status, ExtensionStatus)
        self.ext_statuses[ext_handler_name] = ext_status
+
    def to_json(self):
        """Serialize the recorded statuses into v1 status-blob JSON."""
        report = vm_status_to_v1(self.vm_status, self.ext_statuses)
        return json.dumps(report)

    # Azure storage service API version sent with blob operations.
    __storage_version__ = "2014-02-14"
+
    def upload(self, url):
        """Serialize and upload the status blob to *url*.

        Detects the blob type first and dispatches to the matching PUT;
        returns True on success, False when the upload raised HttpError.
        """
        # TODO upload extension only if content has changed
        logger.verbose("Upload status blob")
        upload_successful = False
        self.type = self.get_blob_type(url)
        self.data = self.to_json()
        try:
            if self.type == "BlockBlob":
                self.put_block_blob(url, self.data)
            elif self.type == "PageBlob":
                self.put_page_blob(url, self.data)
            else:
                raise ProtocolError("Unknown blob type: {0}".format(self.type))
        except HttpError as e:
            logger.warn("Initial upload failed [{0}]".format(e))
        else:
            logger.verbose("Uploading status blob succeeded")
            upload_successful = True
        return upload_successful
+
+ def get_blob_type(self, url):
+ # Check blob type
+ logger.verbose("Check blob type.")
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ try:
+ resp = self.client.call_storage_service(restutil.http_head, url, {
+ "x-ms-date": timestamp,
+ 'x-ms-version': self.__class__.__storage_version__
+ })
+ except HttpError as e:
+ raise ProtocolError((u"Failed to get status blob type: {0}"
+ u"").format(e))
+ if resp is None or resp.status != httpclient.OK:
+ raise ProtocolError(("Failed to get status blob type: {0}"
+ "").format(resp.status))
+
+ blob_type = resp.getheader("x-ms-blob-type")
+ logger.verbose("Blob type={0}".format(blob_type))
+ return blob_type
+
+ def put_block_blob(self, url, data):
+ logger.verbose("Upload block blob")
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+ resp = self.client.call_storage_service(restutil.http_put, url, data,
+ {
+ "x-ms-date": timestamp,
+ "x-ms-blob-type": "BlockBlob",
+ "Content-Length": ustr(len(data)),
+ "x-ms-version": self.__class__.__storage_version__
+ })
+ if resp.status != httpclient.CREATED:
+ raise UploadError(
+ "Failed to upload block blob: {0}".format(resp.status))
+
+ def put_page_blob(self, url, data):
+ logger.verbose("Replace old page blob")
+
+ # Convert string into bytes
+ data = bytearray(data, encoding='utf-8')
+ timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
+
+ # Align to 512 bytes
+ page_blob_size = int((len(data) + 511) / 512) * 512
+ resp = self.client.call_storage_service(restutil.http_put, url, "",
+ {
+ "x-ms-date": timestamp,
+ "x-ms-blob-type": "PageBlob",
+ "Content-Length": "0",
+ "x-ms-blob-content-length": ustr(page_blob_size),
+ "x-ms-version": self.__class__.__storage_version__
+ })
+ if resp.status != httpclient.CREATED:
+ raise UploadError(
+ "Failed to clean up page blob: {0}".format(resp.status))
+
+ if url.count("?") < 0:
+ url = "{0}?comp=page".format(url)
+ else:
+ url = "{0}&comp=page".format(url)
+
+ logger.verbose("Upload page blob")
+ page_max = 4 * 1024 * 1024 # Max page size: 4MB
+ start = 0
+ end = 0
+ while end < len(data):
+ end = min(len(data), start + page_max)
+ content_size = end - start
+ # Align to 512 bytes
+ page_end = int((end + 511) / 512) * 512
+ buf_size = page_end - start
+ buf = bytearray(buf_size)
+ buf[0: content_size] = data[start: end]
+ resp = self.client.call_storage_service(
+ restutil.http_put, url, bytebuffer(buf),
+ {
+ "x-ms-date": timestamp,
+ "x-ms-range": "bytes={0}-{1}".format(start, page_end - 1),
+ "x-ms-page-write": "update",
+ "x-ms-version": self.__class__.__storage_version__,
+ "Content-Length": ustr(page_end - start)
+ })
+ if resp is None or resp.status != httpclient.CREATED:
+ raise UploadError(
+ "Failed to upload page blob: {0}".format(resp.status))
+ start = end
+
+
def event_param_to_v1(param):
    """Render a single telemetry parameter as a v1 <Param> XML element,
    tagging it with the wire type derived from the Python value type."""
    value_type = type(param.value)
    if value_type is int:
        type_attr = 'mt:uint64'
    elif value_type is str or ustr(value_type).count("'unicode'") > 0:
        # Python 2 'unicode' values are reported as wide strings too.
        type_attr = 'mt:wstr'
    elif value_type is bool:
        type_attr = 'mt:bool'
    elif value_type is float:
        type_attr = 'mt:float64'
    else:
        type_attr = ""
    return '<Param Name="{0}" Value={1} T="{2}" />'.format(
        param.name, saxutils.quoteattr(ustr(param.value)), type_attr)
+
+
def event_to_v1(event):
    """Serialize a telemetry event into its v1 <Event> XML element,
    with the parameter list wrapped in a CDATA section."""
    body = "".join(event_param_to_v1(p) for p in event.parameters)
    return ('<Event id="{0}">'
            '<![CDATA[{1}]]>'
            '</Event>').format(event.eventId, body)
+
+
class WireClient(object):
    """Client for the Azure WireServer (fabric) protocol endpoint.

    Fetches and caches goal state artifacts (hosting environment, shared
    config, certificates, extensions config) under the agent lib
    directory, and provides status upload, health reporting and
    telemetry transport.
    """

    def __init__(self, endpoint):
        logger.info("Wire server endpoint:{0}", endpoint)
        self.endpoint = endpoint
        self.goal_state = None
        self.updated = None
        self.hosting_env = None
        self.shared_conf = None
        self.certs = None
        self.ext_conf = None
        self.last_request = 0
        self.req_count = 0
        self.status_blob = StatusBlob(self)
        self.host_plugin = HostPluginProtocol(self.endpoint)

    def prevent_throttling(self):
        """
        Try to avoid throttling of wire server
        """
        now = time.time()
        if now - self.last_request < 1:
            logger.verbose("Last request issued less than 1 second ago")
            logger.verbose("Sleep {0} second to avoid throttling.",
                           SHORT_WAITING_INTERVAL)
            time.sleep(SHORT_WAITING_INTERVAL)
        self.last_request = now

        # Additionally pause after every third request.
        self.req_count += 1
        if self.req_count % 3 == 0:
            logger.verbose("Sleep {0} second to avoid throttling.",
                           SHORT_WAITING_INTERVAL)
            time.sleep(SHORT_WAITING_INTERVAL)
            self.req_count = 0

    def call_wireserver(self, http_req, *args, **kwargs):
        """
        Call wire server. Handle throttling(403) and Resource Gone(410)
        """
        self.prevent_throttling()
        for retry in range(0, 3):
            resp = http_req(*args, **kwargs)
            if resp.status == httpclient.FORBIDDEN:
                logger.warn("Sending too many requests to wire server")
                logger.info("Sleep {0} second to avoid throttling.",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
            elif resp.status == httpclient.GONE:
                msg = args[0] if len(args) > 0 else ""
                raise WireProtocolResourceGone(msg)
            else:
                return resp
        raise ProtocolError(("Calling wire server failed: {0}"
                             "").format(resp.status))

    def decode_config(self, data):
        """Strip the BOM (if any) and decode config bytes as UTF-8 text."""
        if data is None:
            return None
        data = remove_bom(data)
        xml_text = ustr(data, encoding='utf-8')
        return xml_text

    def fetch_config(self, uri, headers):
        """GET 'uri' from the wire server and return the decoded body."""
        try:
            resp = self.call_wireserver(restutil.http_get, uri,
                                        headers=headers)
        except HttpError as e:
            raise ProtocolError(ustr(e))

        if (resp.status != httpclient.OK):
            raise ProtocolError("{0} - {1}".format(resp.status, uri))

        return self.decode_config(resp.read())

    def fetch_cache(self, local_file):
        """Read a previously cached artifact; ProtocolError if missing."""
        if not os.path.isfile(local_file):
            raise ProtocolError("{0} is missing.".format(local_file))
        try:
            return fileutil.read_file(local_file)
        except IOError as e:
            raise ProtocolError("Failed to read cache: {0}".format(e))

    def save_cache(self, local_file, data):
        """Write 'data' to the cache file, wrapping IOError."""
        try:
            fileutil.write_file(local_file, data)
        except IOError as e:
            raise ProtocolError("Failed to write cache: {0}".format(e))

    def call_storage_service(self, http_req, *args, **kwargs):
        """
        Call storage service, handle SERVICE_UNAVAILABLE(503)
        """
        for retry in range(0, 3):
            resp = http_req(*args, **kwargs)
            if resp.status == httpclient.SERVICE_UNAVAILABLE:
                logger.warn("Storage service is temporarily unavailable")
                logger.info("Will retry later, in {0} seconds",
                            LONG_WAITING_INTERVAL)
                time.sleep(LONG_WAITING_INTERVAL)
            else:
                return resp
        raise ProtocolError(("Calling storage endpoint failed: {0}"
                             "").format(resp.status))

    def fetch_manifest(self, version_uris):
        """Fetch a manifest, trying each source URI until one returns OK."""
        for version_uri in version_uris:
            logger.verbose("Fetch ext handler manifest: {0}", version_uri.uri)
            try:
                resp = self.call_storage_service(restutil.http_get,
                                                 version_uri.uri, None,
                                                 chk_proxy=True)
            except HttpError as e:
                raise ProtocolError(ustr(e))

            if resp.status == httpclient.OK:
                return self.decode_config(resp.read())
            logger.warn("Failed to fetch ExtensionManifest: {0}, {1}",
                        resp.status, version_uri.uri)
            logger.info("Will retry later, in {0} seconds",
                        LONG_WAITING_INTERVAL)
            time.sleep(LONG_WAITING_INTERVAL)
        raise ProtocolError(("Failed to fetch ExtensionManifest from "
                             "all sources"))

    def update_hosting_env(self, goal_state):
        """Fetch, cache and parse HostingEnvironmentConfig.xml."""
        if goal_state.hosting_env_uri is None:
            raise ProtocolError("HostingEnvironmentConfig uri is empty")
        local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
        xml_text = self.fetch_config(goal_state.hosting_env_uri,
                                     self.get_header())
        self.save_cache(local_file, xml_text)
        self.hosting_env = HostingEnv(xml_text)

    def update_shared_conf(self, goal_state):
        """Fetch, cache and parse SharedConfig.xml."""
        if goal_state.shared_conf_uri is None:
            raise ProtocolError("SharedConfig uri is empty")
        local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
        xml_text = self.fetch_config(goal_state.shared_conf_uri,
                                     self.get_header())
        self.save_cache(local_file, xml_text)
        self.shared_conf = SharedConfig(xml_text)

    def update_certs(self, goal_state):
        """Fetch, cache and parse Certificates.xml (no-op if no URI)."""
        if goal_state.certs_uri is None:
            return
        local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
        xml_text = self.fetch_config(goal_state.certs_uri,
                                     self.get_header_for_cert())
        self.save_cache(local_file, xml_text)
        self.certs = Certificates(self, xml_text)

    def update_ext_conf(self, goal_state):
        """Fetch, cache and parse the incarnation's ExtensionsConfig.xml."""
        if goal_state.ext_uri is None:
            logger.info("ExtensionsConfig.xml uri is empty")
            self.ext_conf = ExtensionsConfig(None)
            return
        incarnation = goal_state.incarnation
        local_file = os.path.join(conf.get_lib_dir(),
                                  EXT_CONF_FILE_NAME.format(incarnation))
        xml_text = self.fetch_config(goal_state.ext_uri, self.get_header())
        self.save_cache(local_file, xml_text)
        self.ext_conf = ExtensionsConfig(xml_text)

    def update_goal_state(self, forced=False, max_retry=3):
        """Refresh the goal state and all dependent artifacts.

        Skips the refresh when the incarnation is unchanged (unless
        'forced'); retries up to 'max_retry' times on 410 (Gone).
        """
        uri = GOAL_STATE_URI.format(self.endpoint)
        xml_text = self.fetch_config(uri, self.get_header())
        goal_state = GoalState(xml_text)

        incarnation_file = os.path.join(conf.get_lib_dir(),
                                        INCARNATION_FILE_NAME)

        if not forced:
            last_incarnation = None
            if (os.path.isfile(incarnation_file)):
                last_incarnation = fileutil.read_file(incarnation_file)
            new_incarnation = goal_state.incarnation
            if last_incarnation is not None and \
                            last_incarnation == new_incarnation:
                # Goalstate is not updated.
                return

        # Start updating goalstate, retry on 410
        for retry in range(0, max_retry):
            try:
                self.goal_state = goal_state
                file_name = GOAL_STATE_FILE_NAME.format(goal_state.incarnation)
                goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
                self.save_cache(goal_state_file, xml_text)
                self.save_cache(incarnation_file, goal_state.incarnation)
                self.update_hosting_env(goal_state)
                self.update_shared_conf(goal_state)
                self.update_certs(goal_state)
                self.update_ext_conf(goal_state)
                return
            except WireProtocolResourceGone:
                logger.info("Incarnation is out of date. Update goalstate.")
                xml_text = self.fetch_config(uri, self.get_header())
                goal_state = GoalState(xml_text)

        raise ProtocolError("Exceeded max retry updating goal state")

    def get_goal_state(self):
        """Return the cached goal state, loading it from disk if needed."""
        if (self.goal_state is None):
            incarnation_file = os.path.join(conf.get_lib_dir(),
                                            INCARNATION_FILE_NAME)
            incarnation = self.fetch_cache(incarnation_file)

            file_name = GOAL_STATE_FILE_NAME.format(incarnation)
            goal_state_file = os.path.join(conf.get_lib_dir(), file_name)
            xml_text = self.fetch_cache(goal_state_file)
            self.goal_state = GoalState(xml_text)
        return self.goal_state

    def get_hosting_env(self):
        """Return the cached hosting env, loading from disk if needed."""
        if (self.hosting_env is None):
            local_file = os.path.join(conf.get_lib_dir(), HOSTING_ENV_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.hosting_env = HostingEnv(xml_text)
        return self.hosting_env

    def get_shared_conf(self):
        """Return the cached shared config, loading from disk if needed."""
        if (self.shared_conf is None):
            local_file = os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.shared_conf = SharedConfig(xml_text)
        return self.shared_conf

    def get_certs(self):
        """Return the cached certificates, loading from disk if needed."""
        if (self.certs is None):
            local_file = os.path.join(conf.get_lib_dir(), CERTS_FILE_NAME)
            xml_text = self.fetch_cache(local_file)
            self.certs = Certificates(self, xml_text)
        if self.certs is None:
            return None
        return self.certs

    def get_ext_conf(self):
        """Return the cached extensions config, loading it if needed."""
        if (self.ext_conf is None):
            goal_state = self.get_goal_state()
            if goal_state.ext_uri is None:
                self.ext_conf = ExtensionsConfig(None)
            else:
                local_file = EXT_CONF_FILE_NAME.format(goal_state.incarnation)
                local_file = os.path.join(conf.get_lib_dir(), local_file)
                xml_text = self.fetch_cache(local_file)
                self.ext_conf = ExtensionsConfig(xml_text)
        return self.ext_conf

    def get_ext_manifest(self, ext_handler, goal_state):
        """Fetch, cache and parse an extension handler's manifest."""
        local_file = MANIFEST_FILE_NAME.format(ext_handler.name,
                                               goal_state.incarnation)
        local_file = os.path.join(conf.get_lib_dir(), local_file)
        xml_text = self.fetch_manifest(ext_handler.versionUris)
        self.save_cache(local_file, xml_text)
        return ExtensionManifest(xml_text)

    def get_gafamily_manifest(self, vmagent_manifest, goal_state):
        """Fetch, cache and parse a GA family's version manifest."""
        local_file = MANIFEST_FILE_NAME.format(vmagent_manifest.family,
                                               goal_state.incarnation)
        local_file = os.path.join(conf.get_lib_dir(), local_file)
        xml_text = self.fetch_manifest(vmagent_manifest.versionsManifestUris)
        # Use save_cache (as get_ext_manifest does) so an IOError is
        # wrapped in ProtocolError instead of escaping to the caller.
        self.save_cache(local_file, xml_text)
        return ExtensionManifest(xml_text)

    def check_wire_protocol_version(self):
        """Verify the endpoint supports our protocol version.

        Raises ProtocolNotFoundError if PROTOCOL_VERSION is neither the
        preferred nor a supported version.
        """
        uri = VERSION_INFO_URI.format(self.endpoint)
        version_info_xml = self.fetch_config(uri, None)
        version_info = VersionInfo(version_info_xml)

        preferred = version_info.get_preferred()
        if PROTOCOL_VERSION == preferred:
            logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
        elif PROTOCOL_VERSION in version_info.get_supported():
            logger.info("Wire protocol version:{0}", PROTOCOL_VERSION)
            logger.warn("Server preferred version:{0}", preferred)
        else:
            error = ("Agent supported wire protocol version: {0} was not "
                     "advised by Fabric.").format(PROTOCOL_VERSION)
            raise ProtocolNotFoundError(error)

    def upload_status_blob(self):
        """Upload the status blob, falling back to the host plugin."""
        ext_conf = self.get_ext_conf()
        if ext_conf.status_upload_blob is not None:
            if not self.status_blob.upload(ext_conf.status_upload_blob):
                self.host_plugin.put_vm_status(self.status_blob,
                                               ext_conf.status_upload_blob)

    def report_role_prop(self, thumbprint):
        """POST the role properties (cert thumbprint) to the wire server."""
        goal_state = self.get_goal_state()
        role_prop = _build_role_properties(goal_state.container_id,
                                           goal_state.role_instance_id,
                                           thumbprint)
        role_prop = role_prop.encode("utf-8")
        role_prop_uri = ROLE_PROP_URI.format(self.endpoint)
        headers = self.get_header_for_xml_content()
        try:
            resp = self.call_wireserver(restutil.http_post, role_prop_uri,
                                        role_prop, headers=headers)
        except HttpError as e:
            raise ProtocolError((u"Failed to send role properties: {0}"
                                 u"").format(e))
        if resp.status != httpclient.ACCEPTED:
            raise ProtocolError((u"Failed to send role properties: {0}"
                                 u", {1}").format(resp.status, resp.read()))

    def report_health(self, status, substatus, description):
        """POST a health (provisioning) report to the wire server."""
        goal_state = self.get_goal_state()
        health_report = _build_health_report(goal_state.incarnation,
                                             goal_state.container_id,
                                             goal_state.role_instance_id,
                                             status,
                                             substatus,
                                             description)
        health_report = health_report.encode("utf-8")
        health_report_uri = HEALTH_REPORT_URI.format(self.endpoint)
        headers = self.get_header_for_xml_content()
        try:
            resp = self.call_wireserver(restutil.http_post, health_report_uri,
                                        health_report, headers=headers, max_retry=8)
        except HttpError as e:
            raise ProtocolError((u"Failed to send provision status: {0}"
                                 u"").format(e))
        if resp.status != httpclient.OK:
            raise ProtocolError((u"Failed to send provision status: {0}"
                                 u", {1}").format(resp.status, resp.read()))

    def send_event(self, provider_id, event_str):
        """POST one batch of telemetry events for 'provider_id'."""
        uri = TELEMETRY_URI.format(self.endpoint)
        data_format = ('<?xml version="1.0"?>'
                       '<TelemetryData version="1.0">'
                       '<Provider id="{0}">{1}'
                       '</Provider>'
                       '</TelemetryData>')
        data = data_format.format(provider_id, event_str)
        try:
            header = self.get_header_for_xml_content()
            resp = self.call_wireserver(restutil.http_post, uri, data, header)
        except HttpError as e:
            raise ProtocolError("Failed to send events:{0}".format(e))

        if resp.status != httpclient.OK:
            logger.verbose(resp.read())
            raise ProtocolError("Failed to send events:{0}".format(resp.status))

    def report_event(self, event_list):
        """Send telemetry events grouped by provider, in <=63KB batches."""
        buf = {}
        # Group events by providerId
        for event in event_list.events:
            if event.providerId not in buf:
                buf[event.providerId] = ""
            event_str = event_to_v1(event)
            if len(event_str) >= 63 * 1024:
                # A single oversized event cannot be sent at all.
                logger.warn("Single event too large: {0}", event_str[300:])
                continue
            if len(buf[event.providerId] + event_str) >= 63 * 1024:
                self.send_event(event.providerId, buf[event.providerId])
                buf[event.providerId] = ""
            buf[event.providerId] = buf[event.providerId] + event_str

        # Send out all events left in buffer.
        for provider_id in list(buf.keys()):
            if len(buf[provider_id]) > 0:
                self.send_event(provider_id, buf[provider_id])

    def get_header(self):
        """Headers for plain wire server GETs."""
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION
        }

    def get_header_for_xml_content(self):
        """Headers for wire server POSTs with an XML body."""
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION,
            "Content-Type": "text/xml;charset=utf-8"
        }

    def get_header_for_cert(self):
        """Headers for the certificates request, including our transport
        certificate so the server can encrypt the response to us."""
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        content = self.fetch_cache(trans_cert_file)
        cert = get_bytes_from_pem(content)
        return {
            "x-ms-agent-name": "WALinuxAgent",
            "x-ms-version": PROTOCOL_VERSION,
            "x-ms-cipher-name": "DES_EDE3_CBC",
            "x-ms-guest-agent-public-x509-cert": cert
        }
+
class VersionInfo(object):
    """Wire protocol version information (Versions.xml) from the endpoint."""

    def __init__(self, xml_text):
        """
        Query endpoint server for wire protocol version.
        Fail if our desired protocol version is not seen.
        """
        logger.verbose("Load Version.xml")
        self.parse(xml_text)

    def parse(self, xml_text):
        """Extract the preferred and supported protocol versions."""
        xml_doc = parse_doc(xml_text)
        self.preferred = findtext(find(xml_doc, "Preferred"), "Version")
        logger.info("Fabric preferred wire protocol version:{0}", self.preferred)

        self.supported = []
        for version_node in findall(find(xml_doc, "Supported"), "Version"):
            supported_version = gettext(version_node)
            logger.verbose("Fabric supported wire protocol version:{0}", supported_version)
            self.supported.append(supported_version)

    def get_preferred(self):
        return self.preferred

    def get_supported(self):
        return self.supported
+
+
class GoalState(object):
    """In-memory representation of the GoalState.xml document."""

    def __init__(self, xml_text):
        if xml_text is None:
            raise ValueError("GoalState.xml is None")
        logger.verbose("Load GoalState.xml")
        self.incarnation = None
        self.expected_state = None
        self.hosting_env_uri = None
        self.shared_conf_uri = None
        self.certs_uri = None
        self.ext_uri = None
        self.role_instance_id = None
        self.container_id = None
        self.load_balancer_probe_port = None
        self.parse(xml_text)

    def parse(self, xml_text):
        """Populate this object's fields from the goal state XML."""
        self.xml_text = xml_text
        doc = parse_doc(xml_text)
        self.incarnation = findtext(doc, "Incarnation")
        self.expected_state = findtext(doc, "ExpectedState")
        self.hosting_env_uri = findtext(doc, "HostingEnvironmentConfig")
        self.shared_conf_uri = findtext(doc, "SharedConfig")
        self.certs_uri = findtext(doc, "Certificates")
        self.ext_uri = findtext(doc, "ExtensionsConfig")
        self.role_instance_id = findtext(find(doc, "RoleInstance"),
                                         "InstanceId")
        self.container_id = findtext(find(doc, "Container"),
                                     "ContainerId")
        self.load_balancer_probe_port = findtext(find(doc, "LBProbePorts"),
                                                 "Port")
        return self
+
+
class HostingEnv(object):
    """Parsed view of the HostingEnvironmentConfig.xml document,
    exposing the VM, role and deployment names."""

    def __init__(self, xml_text):
        if xml_text is None:
            raise ValueError("HostingEnvironmentConfig.xml is None")
        logger.verbose("Load HostingEnvironmentConfig.xml")
        self.vm_name = None
        self.role_name = None
        self.deployment_name = None
        self.parse(xml_text)

    def parse(self, xml_text):
        """Extract names from the hosting environment XML attributes."""
        self.xml_text = xml_text
        doc = parse_doc(xml_text)
        self.vm_name = getattrib(find(doc, "Incarnation"), "instance")
        self.role_name = getattrib(find(doc, "Role"), "name")
        self.deployment_name = getattrib(find(doc, "Deployment"), "name")
        return self
+
+
class SharedConfig(object):
    """Placeholder wrapper for the SharedConfig.xml document."""

    def __init__(self, xml_text):
        logger.verbose("Load SharedConfig.xml")
        self.parse(xml_text)

    def parse(self, xml_text):
        """No-op: the agent does not currently consume SharedConfig."""
        # Not used currently
        return self
+
class Certificates(object):
    """
    Object containing certificates of host and provisioned user.

    Decrypts the PKCS#7 payload from the wire server into PEM files in
    the agent lib directory, splitting out each certificate and private
    key into its own file named by the certificate thumbprint.
    """

    def __init__(self, client, xml_text):
        logger.verbose("Load Certificates.xml")
        self.client = client        # WireClient, used for save_cache()
        self.cert_list = CertList()
        self.parse(xml_text)

    def parse(self, xml_text):
        """
        Parse multiple certificates into separate files.

        Steps: wrap the base64 <Data> payload into a MIME p7m file,
        decrypt it with the transport key into a single PEM, then split
        the PEM into per-cert / per-key files, matching each private key
        to its certificate via the public key, and renaming both to
        <thumbprint>.crt / <thumbprint>.prv.
        """
        xml_doc = parse_doc(xml_text)
        data = findtext(xml_doc, "Data")
        if data is None:
            # No certificates in this goal state; nothing to do.
            return

        cryptutil = CryptUtil(conf.get_openssl_cmd())
        p7m_file = os.path.join(conf.get_lib_dir(), P7M_FILE_NAME)
        # Wrap the raw base64 payload in MIME headers so openssl can
        # consume it as an S/MIME (PKCS#7) message.
        p7m = ("MIME-Version:1.0\n"
               "Content-Disposition: attachment; filename=\"{0}\"\n"
               "Content-Type: application/x-pkcs7-mime; name=\"{1}\"\n"
               "Content-Transfer-Encoding: base64\n"
               "\n"
               "{2}").format(p7m_file, p7m_file, data)

        self.client.save_cache(p7m_file, p7m)

        trans_prv_file = os.path.join(conf.get_lib_dir(),
                                      TRANSPORT_PRV_FILE_NAME)
        trans_cert_file = os.path.join(conf.get_lib_dir(),
                                       TRANSPORT_CERT_FILE_NAME)
        pem_file = os.path.join(conf.get_lib_dir(), PEM_FILE_NAME)
        # decrypt certificates into a single PEM file
        cryptutil.decrypt_p7m(p7m_file, trans_prv_file, trans_cert_file,
                              pem_file)

        # The parsing process uses the public key to match each private
        # key with its certificate.
        buf = []
        # NOTE(review): begin_crt/begin_prv are set but never read --
        # the END markers alone drive the state machine below.
        begin_crt = False
        begin_prv = False
        prvs = {}           # public key -> temp private-key file path
        thumbprints = {}    # public key -> certificate thumbprint
        index = 0           # counter used to name temp files uniquely
        v1_cert_list = []
        with open(pem_file) as pem:
            for line in pem.readlines():
                buf.append(line)
                if re.match(r'[-]+BEGIN.*KEY[-]+', line):
                    begin_prv = True
                elif re.match(r'[-]+BEGIN.*CERTIFICATE[-]+', line):
                    begin_crt = True
                elif re.match(r'[-]+END.*KEY[-]+', line):
                    # Flush the accumulated key block to a temp file and
                    # index it by its public key for later matching.
                    tmp_file = self.write_to_tmp_file(index, 'prv', buf)
                    pub = cryptutil.get_pubkey_from_prv(tmp_file)
                    prvs[pub] = tmp_file
                    buf = []
                    index += 1
                    begin_prv = False
                elif re.match(r'[-]+END.*CERTIFICATE[-]+', line):
                    tmp_file = self.write_to_tmp_file(index, 'crt', buf)
                    pub = cryptutil.get_pubkey_from_crt(tmp_file)
                    thumbprint = cryptutil.get_thumbprint_from_crt(tmp_file)
                    thumbprints[pub] = thumbprint
                    # Rename crt with thumbprint as the file name
                    crt = "{0}.crt".format(thumbprint)
                    v1_cert_list.append({
                        "name": None,
                        "thumbprint": thumbprint
                    })
                    os.rename(tmp_file, os.path.join(conf.get_lib_dir(), crt))
                    buf = []
                    index += 1
                    begin_crt = False

        # Rename prv key with thumbprint as the file name
        for pubkey in prvs:
            thumbprint = thumbprints[pubkey]
            if thumbprint:
                tmp_file = prvs[pubkey]
                prv = "{0}.prv".format(thumbprint)
                os.rename(tmp_file, os.path.join(conf.get_lib_dir(), prv))

        for v1_cert in v1_cert_list:
            cert = Cert()
            set_properties("certs", cert, v1_cert)
            self.cert_list.certificates.append(cert)

    def write_to_tmp_file(self, index, suffix, buf):
        """Write accumulated PEM lines to '<index>.<suffix>' in the lib
        dir and return the file path."""
        file_name = os.path.join(conf.get_lib_dir(),
                                 "{0}.{1}".format(index, suffix))
        self.client.save_cache(file_name, "".join(buf))
        return file_name
+
+
class ExtensionsConfig(object):
    """
    parse ExtensionsConfig, downloading and unpacking them to /var/lib/waagent.
    Install if <enabled>true</enabled>, remove if it is set to false.
    """

    def __init__(self, xml_text):
        logger.verbose("Load ExtensionsConfig.xml")
        self.ext_handlers = ExtHandlerList()
        self.vmagent_manifests = VMAgentManifestList()
        self.status_upload_blob = None
        if xml_text is not None:
            self.parse(xml_text)

    def parse(self, xml_text):
        """Parse GA family manifests, plugin handlers (with their
        settings) and the status upload blob URL from the XML."""
        xml_doc = parse_doc(xml_text)

        ga_families_list = find(xml_doc, "GAFamilies")
        ga_families = findall(ga_families_list, "GAFamily")

        for ga_family in ga_families:
            family = findtext(ga_family, "Name")
            uris_list = find(ga_family, "Uris")
            uris = findall(uris_list, "Uri")
            manifest = VMAgentManifest()
            manifest.family = family
            for uri in uris:
                manifestUri = VMAgentManifestUri(uri=gettext(uri))
                manifest.versionsManifestUris.append(manifestUri)
            self.vmagent_manifests.vmAgentManifests.append(manifest)

        plugins_list = find(xml_doc, "Plugins")
        plugins = findall(plugins_list, "Plugin")
        plugin_settings_list = find(xml_doc, "PluginSettings")
        plugin_settings = findall(plugin_settings_list, "Plugin")

        for plugin in plugins:
            ext_handler = self.parse_plugin(plugin)
            self.ext_handlers.extHandlers.append(ext_handler)
            self.parse_plugin_settings(ext_handler, plugin_settings)

        self.status_upload_blob = findtext(xml_doc, "StatusUploadBlob")

    def parse_plugin(self, plugin):
        """Build an ExtHandler from one <Plugin> node (name, version,
        state, upgrade policy and download URIs)."""
        ext_handler = ExtHandler()
        ext_handler.name = getattrib(plugin, "name")
        ext_handler.properties.version = getattrib(plugin, "version")
        ext_handler.properties.state = getattrib(plugin, "state")

        auto_upgrade = getattrib(plugin, "autoUpgrade")
        if auto_upgrade is not None and auto_upgrade.lower() == "true":
            ext_handler.properties.upgradePolicy = "auto"
        else:
            ext_handler.properties.upgradePolicy = "manual"

        location = getattrib(plugin, "location")
        failover_location = getattrib(plugin, "failoverlocation")
        for uri in [location, failover_location]:
            version_uri = ExtHandlerVersionUri()
            version_uri.uri = uri
            ext_handler.versionUris.append(version_uri)
        return ext_handler

    def parse_plugin_settings(self, ext_handler, plugin_settings):
        """Attach the matching <PluginSettings> (by name and version)
        to 'ext_handler' as Extension objects."""
        if plugin_settings is None:
            return

        name = ext_handler.name
        version = ext_handler.properties.version
        settings = [x for x in plugin_settings \
                    if getattrib(x, "name") == name and \
                       getattrib(x, "version") == version]

        if settings is None or len(settings) == 0:
            return

        runtime_settings_node = find(settings[0], "RuntimeSettings")
        seqNo = getattrib(runtime_settings_node, "seqNo")
        runtime_settings_str = gettext(runtime_settings_node)
        try:
            runtime_settings = json.loads(runtime_settings_str)
        except ValueError as e:
            # Include the parse error; previously the exception detail
            # was dropped, making malformed settings hard to diagnose.
            logger.error("Invalid extension settings: {0}", e)
            return

        for plugin_settings_list in runtime_settings["runtimeSettings"]:
            handler_settings = plugin_settings_list["handlerSettings"]
            ext = Extension()
            # There is no "extension name" in wire protocol.
            # Put
            ext.name = ext_handler.name
            ext.sequenceNumber = seqNo
            ext.publicSettings = handler_settings.get("publicSettings")
            ext.protectedSettings = handler_settings.get("protectedSettings")
            thumbprint = handler_settings.get("protectedSettingsCertThumbprint")
            ext.certificateThumbprint = thumbprint
            ext_handler.properties.extensions.append(ext)
+
+
class ExtensionManifest(object):
    """Parsed extension (or GA family) package manifest, listing one
    ExtHandlerPackage per <Plugin> node."""

    def __init__(self, xml_text):
        if xml_text is None:
            raise ValueError("ExtensionManifest is None")
        logger.verbose("Load ExtensionManifest.xml")
        self.pkg_list = ExtHandlerPackageList()
        self.parse(xml_text)

    def parse(self, xml_text):
        """Collect packages from <Plugins> and <InternalPlugins>."""
        doc = parse_doc(xml_text)
        self._handle_packages(findall(find(doc, "Plugins"), "Plugin"),
                              False)
        self._handle_packages(findall(find(doc, "InternalPlugins"),
                                      "Plugin"),
                              True)

    def _handle_packages(self, packages, isinternal):
        """Append one ExtHandlerPackage per <Plugin> node to pkg_list."""
        for package in packages:
            pkg = ExtHandlerPackage()
            pkg.version = findtext(package, "Version")

            major_upgrade_flag = findtext(package,
                                          "DisallowMajorVersionUpgrade")
            # Absent flag defaults to False.
            pkg.disallow_major_upgrade = \
                (major_upgrade_flag or '').lower() == "true"

            for uri_node in findall(find(package, "Uris"), "Uri"):
                pkg_uri = ExtHandlerVersionUri()
                pkg_uri.uri = gettext(uri_node)
                pkg.uris.append(pkg_uri)

            pkg.isinternal = isinternal
            self.pkg_list.versions.append(pkg)
diff --git a/azurelinuxagent/common/rdma.py b/azurelinuxagent/common/rdma.py
new file mode 100644
index 0000000..0c17e38
--- /dev/null
+++ b/azurelinuxagent/common/rdma.py
@@ -0,0 +1,280 @@
+# Windows Azure Linux Agent
+#
+# Copyright 2016 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Handle packages and modules to enable RDMA for IB networking
+"""
+
+import os
+import re
+import time
+import threading
+
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+import azurelinuxagent.common.utils.fileutil as fileutil
+import azurelinuxagent.common.utils.shellutil as shellutil
+from azurelinuxagent.common.utils.textutil import parse_doc, find, getattrib
+
+
+from azurelinuxagent.common.protocol.wire import SHARED_CONF_FILE_NAME
+
# Candidate locations of the DAPL (Direct Access Programming Library)
# configuration file; the first existing path is rewritten with the RDMA
# IPv4 address (see RDMADeviceHandler.update_dat_conf).
dapl_config_paths = [
    '/etc/dat.conf',
    '/etc/rdma/dat.conf',
    '/usr/local/etc/dat.conf'
]
+
def setup_rdma_device():
    """Read RDMA details from the cached SharedConfig.xml and hand them to
    RDMADeviceHandler to configure the device.

    Logs an error and returns (without raising) when the document cannot be
    parsed or the rdmaIPv4Address/rdmaMacAddress attributes are missing.
    """
    logger.verbose("Parsing SharedConfig XML contents for RDMA details")
    xml_doc = parse_doc(
        fileutil.read_file(os.path.join(conf.get_lib_dir(), SHARED_CONF_FILE_NAME)))
    if xml_doc is None:
        logger.error("Could not parse SharedConfig XML document")
        return
    instance_elem = find(xml_doc, "Instance")
    if not instance_elem:
        logger.error("Could not find <Instance> in SharedConfig document")
        return

    rdma_ipv4_addr = getattrib(instance_elem, "rdmaIPv4Address")
    if not rdma_ipv4_addr:
        logger.error(
            "Could not find rdmaIPv4Address attribute on Instance element of SharedConfig.xml document")
        return

    rdma_mac_addr = getattrib(instance_elem, "rdmaMacAddress")
    if not rdma_mac_addr:
        logger.error(
            "Could not find rdmaMacAddress attribute on Instance element of SharedConfig.xml document")
        return

    # Insert colons into the bare MAC address (e.g. 00155D33FF1D ->
    # 00:15:5D:33:FF:1D)
    rdma_mac_addr = ':'.join([rdma_mac_addr[i:i+2]
                              for i in range(0, len(rdma_mac_addr), 2)])
    logger.info("Found RDMA details. IPv4={0} MAC={1}".format(
        rdma_ipv4_addr, rdma_mac_addr))

    # Set up the RDMA device with the collected information; start() does
    # the actual work on a background thread.
    RDMADeviceHandler(rdma_ipv4_addr, rdma_mac_addr).start()
    logger.info("RDMA: device is set up")
    return
+
class RDMAHandler(object):
    """Base class for distro-specific RDMA driver management.

    Provides version discovery, kernel-module loading and reboot helpers;
    install_driver() must be overridden by the per-distribution subclass.
    """

    driver_module_name = 'hv_network_direct'

    @staticmethod
    def get_rdma_version():
        """Retrieve the firmware version information from the system.
        This depends on information provided by the Linux kernel
        (KVP pool 0, populated by the hv_kvp_daemon).

        Returns the version string, or None (after logging) on failure."""

        driver_info_source = '/var/lib/hyperv/.kvp_pool_0'
        base_kernel_err_msg = 'Kernel does not provide the necessary '
        base_kernel_err_msg += 'information or the hv_kvp_daemon is not '
        base_kernel_err_msg += 'running.'
        if not os.path.isfile(driver_info_source):
            error_msg = 'RDMA: Source file "%s" does not exist. '
            error_msg += base_kernel_err_msg
            logger.error(error_msg % driver_info_source)
            return

        # Context manager fixes the file-handle leak in the original code.
        with open(driver_info_source) as kvp_file:
            lines = kvp_file.read()
        if not lines:
            error_msg = 'RDMA: Source file "%s" is empty. '
            error_msg += base_kernel_err_msg
            logger.error(error_msg % driver_info_source)
            return

        # The KVP pool stores NUL-padded key/value records; the version
        # (e.g. "140.0") follows the "NdDriverVersion" key.  Raw string
        # fixes the invalid "\d" escapes of the original pattern.
        r = re.search(r"NdDriverVersion\0+(\d\d\d\.\d)", lines)
        if r:
            return r.groups()[0]

        error_msg = 'RDMA: NdDriverVersion not found in "%s"'
        logger.error(error_msg % driver_info_source)
        return

    def load_driver_module(self):
        """Load the kernel driver, this depends on the proper driver
        to be installed with the install_driver() method.

        Returns True on success, None (after logging) on failure."""
        logger.info("RDMA: probing module '%s'" % self.driver_module_name)
        result = shellutil.run('modprobe %s' % self.driver_module_name)
        if result != 0:
            error_msg = 'Could not load "%s" kernel module. '
            error_msg += 'Run "modprobe %s" as root for more details'
            logger.error(
                error_msg % (self.driver_module_name, self.driver_module_name)
            )
            return
        logger.info('RDMA: Loaded the kernel driver successfully.')
        return True

    def install_driver(self):
        """Install the driver. This is distribution specific and must
        be overwritten in the child implementation."""
        logger.error('RDMAHandler.install_driver not implemented')

    def is_driver_loaded(self):
        """Check if the network module is loaded in kernel space.
        Returns True when loaded, None otherwise (both falsy/truthy as
        the original did)."""
        cmd = 'lsmod | grep ^%s' % self.driver_module_name
        _, loaded_modules = shellutil.run_get_output(cmd)
        logger.info('RDMA: Checking if the module loaded.')
        if loaded_modules:
            logger.info('RDMA: module loaded.')
            return True
        logger.info('RDMA: module not loaded.')

    def reboot_system(self):
        """Reboot the system. This is required as the kernel module for
        the rdma driver cannot be unloaded with rmmod"""
        logger.info('RDMA: Rebooting system.')
        ret = shellutil.run('shutdown -r now')
        if ret != 0:
            logger.error('RDMA: Failed to reboot the system')
+
+
# NOTE(review): this duplicates the dapl_config_paths list defined earlier
# in this module; this later assignment is the one RDMADeviceHandler.process
# actually sees.  The two lists must stay in sync (or one should be removed).
dapl_config_paths = [
    '/etc/dat.conf', '/etc/rdma/dat.conf', '/usr/local/etc/dat.conf']
+
class RDMADeviceHandler(object):

    """
    Responsible for writing RDMA IP and MAC address to the /dev/hvnd_rdma
    interface, updating the DAPL configuration and bringing up the
    RDMA network interface.
    """

    rdma_dev = '/dev/hvnd_rdma'
    device_check_timeout_sec = 120
    device_check_interval_sec = 1

    ipv4_addr = None
    mac_addr = None  # fixed typo: the original declared 'mac_adr'

    def __init__(self, ipv4_addr, mac_addr):
        self.ipv4_addr = ipv4_addr
        self.mac_addr = mac_addr

    def start(self):
        """
        Start a thread in the background to process the RDMA tasks and returns.
        """
        logger.info("RDMA: starting device processing in the background.")
        threading.Thread(target=self.process).start()

    def process(self):
        """Wait for the device node, then push the IP/MAC configuration to
        dat.conf, the device node and the network interface (in that order)."""
        RDMADeviceHandler.wait_rdma_device(
            self.rdma_dev, self.device_check_timeout_sec, self.device_check_interval_sec)
        RDMADeviceHandler.update_dat_conf(dapl_config_paths, self.ipv4_addr)
        RDMADeviceHandler.write_rdma_config_to_device(
            self.rdma_dev, self.ipv4_addr, self.mac_addr)
        RDMADeviceHandler.update_network_interface(self.mac_addr, self.ipv4_addr)

    @staticmethod
    def update_dat_conf(paths, ipv4_addr):
        """
        Looks at paths for dat.conf file and updates the ip address for the
        infiniband interface.  Raises when none of the paths exists.
        """
        logger.info("Updating DAPL configuration file")
        for f in paths:
            logger.info("RDMA: trying {0}".format(f))
            if not os.path.isfile(f):
                logger.info(
                    "RDMA: DAPL config not found at {0}".format(f))
                continue
            logger.info("RDMA: DAPL config is at: {0}".format(f))
            cfg = fileutil.read_file(f)
            new_cfg = RDMADeviceHandler.replace_dat_conf_contents(
                cfg, ipv4_addr)
            fileutil.write_file(f, new_cfg)
            logger.info("RDMA: DAPL configuration is updated")
            return

        raise Exception("RDMA: DAPL configuration file not found at predefined paths")

    @staticmethod
    def replace_dat_conf_contents(cfg, ipv4_addr):
        """Rewrite the ofa-v2-ib0 provider line of dat.conf with ipv4_addr."""
        # Raw string fixes the invalid "\S" escape of the original pattern;
        # the regex matches the same text as before.
        old = r'ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 dapl.2.0 "\S+ 0"'
        new = 'ofa-v2-ib0 u2.0 nonthreadsafe default libdaplofa.so.2 dapl.2.0 "{0} 0"'.format(
            ipv4_addr)
        return re.sub(old, new, cfg)

    @staticmethod
    def write_rdma_config_to_device(path, ipv4_addr, mac_addr):
        """Write the IPv4/MAC configuration string to the hvnd_rdma device."""
        data = RDMADeviceHandler.generate_rdma_config(ipv4_addr, mac_addr)
        logger.info(
            "RDMA: Updating device with configuration: {0}".format(data))
        with open(path, "w") as f:
            f.write(data)
        logger.info("RDMA: Updated device with IPv4/MAC addr successfully")

    @staticmethod
    def generate_rdma_config(ipv4_addr, mac_addr):
        """Return the configuration line the hvnd_rdma device expects."""
        return 'rdmaMacAddress="{0}" rdmaIPv4Address="{1}"'.format(mac_addr, ipv4_addr)

    @staticmethod
    def wait_rdma_device(path, timeout_sec, check_interval_sec):
        """Poll for the device node every check_interval_sec seconds; raise
        when it has not appeared after timeout_sec seconds."""
        logger.info("RDMA: waiting for device={0} timeout={1}s".format(path, timeout_sec))
        # Integer division: the original float quotient worked but made the
        # retry counter a float.
        total_retries = timeout_sec // check_interval_sec
        n = 0
        while n < total_retries:
            if os.path.exists(path):
                logger.info("RDMA: device ready")
                return
            logger.verbose(
                "RDMA: device not ready, sleep {0}s".format(check_interval_sec))
            time.sleep(check_interval_sec)
            n += 1
        logger.error("RDMA device wait timed out")
        raise Exception("The device did not show up in {0} seconds ({1} retries)".format(
            timeout_sec, total_retries))

    @staticmethod
    def update_network_interface(mac_addr, ipv4_addr):
        """Bring up the interface with mac_addr and assign ipv4_addr/16."""
        netmask = 16

        logger.info("RDMA: will update the network interface with IPv4/MAC")

        if_name = RDMADeviceHandler.get_interface_by_mac(mac_addr)
        logger.info("RDMA: network interface found: {0}", if_name)
        logger.info("RDMA: bringing network interface up")
        # Error messages fixed: "RMDA" -> "RDMA", "Could set" -> "Could not set".
        if shellutil.run("ifconfig {0} up".format(if_name)) != 0:
            raise Exception("Could not bring up RDMA interface: {0}".format(if_name))

        logger.info("RDMA: configuring IPv4 addr and netmask on interface")
        addr = '{0}/{1}'.format(ipv4_addr, netmask)
        if shellutil.run("ifconfig {0} {1}".format(if_name, addr)) != 0:
            raise Exception("Could not set addr to {1} on {0}".format(if_name, addr))
        logger.info("RDMA: network address and netmask configured on interface")

    @staticmethod
    def get_interface_by_mac(mac):
        """Return the ethN interface name owning the given MAC address by
        scanning `ifconfig -a` output; raises when not found."""
        ret, output = shellutil.run_get_output("ifconfig -a")
        if ret != 0:
            raise Exception("Failed to list network interfaces")
        output = output.replace('\n', '')
        match = re.search(r"(eth\d).*(HWaddr|ether) {0}".format(mac),
                          output, re.IGNORECASE)
        if match is None:
            raise Exception("Failed to get ifname with mac: {0}".format(mac))
        # The greedy match may span several interface sections; the last
        # ethN before the MAC is the owning interface.
        eths = re.findall(r"eth\d", match.group(0))
        if not eths:
            raise Exception("ifname with mac: {0} not found".format(mac))
        return eths[-1]
diff --git a/azurelinuxagent/common/utils/__init__.py b/azurelinuxagent/common/utils/__init__.py
new file mode 100644
index 0000000..1ea2f38
--- /dev/null
+++ b/azurelinuxagent/common/utils/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
diff --git a/azurelinuxagent/common/utils/cryptutil.py b/azurelinuxagent/common/utils/cryptutil.py
new file mode 100644
index 0000000..b35bda0
--- /dev/null
+++ b/azurelinuxagent/common/utils/cryptutil.py
@@ -0,0 +1,121 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import base64
+import struct
+from azurelinuxagent.common.future import ustr, bytebuffer
+from azurelinuxagent.common.exception import CryptError
+import azurelinuxagent.common.utils.shellutil as shellutil
+
class CryptUtil(object):
    """Thin wrapper around the openssl command line, plus helpers for
    converting RSA public keys into the OpenSSH wire format."""

    def __init__(self, openssl_cmd):
        self.openssl_cmd = openssl_cmd

    def gen_transport_cert(self, prv_file, crt_file):
        """
        Create ssl certificate for https communication with endpoint server.
        """
        cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 "
               "-newkey rsa:2048 -keyout {1} "
               "-out {2}").format(self.openssl_cmd, prv_file, crt_file)
        shellutil.run(cmd)

    def get_pubkey_from_prv(self, file_name):
        """Return the PEM public key derived from a private key file."""
        cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd,
                                                           file_name)
        return shellutil.run_get_output(cmd)[1]

    def get_pubkey_from_crt(self, file_name):
        """Return the PEM public key extracted from a certificate file."""
        cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd,
                                                       file_name)
        return shellutil.run_get_output(cmd)[1]

    def get_thumbprint_from_crt(self, file_name):
        """Return the certificate fingerprint as uppercase hex, colons
        stripped."""
        cmd = "{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd,
                                                            file_name)
        raw = shellutil.run_get_output(cmd)[1]
        return raw.rstrip().split('=')[1].replace(':', '').upper()

    def decrypt_p7m(self, p7m_file, trans_prv_file, trans_cert_file, pem_file):
        """Decrypt a PKCS#7 blob with the transport key/cert and write the
        unpacked PKCS#12 contents to pem_file."""
        decrypt_part = ("{0} cms -decrypt -in {1} -inkey {2} -recip {3} "
                        "").format(self.openssl_cmd, p7m_file, trans_prv_file,
                                   trans_cert_file)
        unpack_part = ("| {0} pkcs12 -nodes -password pass: -out {1}"
                       "").format(self.openssl_cmd, pem_file)
        shellutil.run(decrypt_part + unpack_part)

    def crt_to_ssh(self, input_file, output_file):
        """Append the OpenSSH form of an X.509 public key to output_file."""
        cmd = "ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file,
                                                            output_file)
        shellutil.run(cmd)

    def asn1_to_ssh(self, pubkey):
        """Convert a PEM-encoded RSA public key into an "ssh-rsa ..." line."""
        body = "".join(line for line in pubkey.split("\n")
                       if not line.startswith("----"))
        try:
            #TODO remove pyasn1 dependency
            from pyasn1.codec.der import decoder as der_decoder
        except ImportError as e:
            raise CryptError("Failed to load pyasn1.codec.der")

        der_encoded = base64.b64decode(body)
        der_encoded = der_decoder.decode(der_encoded)[0][1]
        key = der_decoder.decode(self.bits_to_bytes(der_encoded))[0]
        modulus, exponent = key[0], key[1]

        blob = bytearray()
        blob.extend(struct.pack('>I', len("ssh-rsa")))
        blob.extend(b"ssh-rsa")
        e_bytes = self.num_to_bytes(exponent)
        blob.extend(struct.pack('>I', len(e_bytes)))
        blob.extend(e_bytes)
        n_bytes = self.num_to_bytes(modulus)
        # Extra leading zero byte keeps the modulus positive (mpint rules).
        blob.extend(struct.pack('>I', len(n_bytes) + 1))
        blob.extend(b"\0")
        blob.extend(n_bytes)
        return ustr(b"ssh-rsa " + base64.b64encode(bytebuffer(blob)) + b"\n",
                    encoding='utf-8')

    def num_to_bytes(self, num):
        """
        Pack number into a big-endian bytearray (empty for 0).
        """
        chunks = []
        while num:
            chunks.append(num & 0xFF)
            num >>= 8
        return bytearray(reversed(chunks))

    def bits_to_bytes(self, bits):
        """
        Convert an array containing bits ([0, 1] values) to a byte string;
        only complete groups of 8 bits are emitted.
        """
        out = bytearray()
        acc, count = 0, 0
        for bit in bits:
            acc = (acc << 1) | bit
            count += 1
            if count == 8:
                out.append(acc)
                acc, count = 0, 0
        return bytes(out)
+
diff --git a/azurelinuxagent/common/utils/fileutil.py b/azurelinuxagent/common/utils/fileutil.py
new file mode 100644
index 0000000..7ef4fef
--- /dev/null
+++ b/azurelinuxagent/common/utils/fileutil.py
@@ -0,0 +1,171 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+"""
+File operation util functions
+"""
+
+import os
+import re
+import shutil
+import pwd
+import tempfile
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.utils.textutil as textutil
+
def copy_file(from_path, to_path=None, to_dir=None):
    """Copy from_path to to_path, or into to_dir keeping the base name
    when to_path is None; return the destination path."""
    dest = to_path if to_path is not None else \
        os.path.join(to_dir, os.path.basename(from_path))
    shutil.copyfile(from_path, dest)
    return dest
+
+
def read_file(filepath, asbin=False, remove_bom=False, encoding='utf-8'):
    """
    Read 'filepath' and return its contents: raw bytes when asbin=True,
    otherwise text decoded with 'encoding' (optionally BOM-stripped first).
    """
    with open(filepath, 'rb') as in_file:
        data = in_file.read()
    if data is None:
        return None
    if asbin:
        return data
    if remove_bom:
        # Strip the BOM from the raw bytes before decoding to text.
        data = textutil.remove_bom(data)
    return ustr(data, encoding=encoding)
+
def write_file(filepath, contents, asbin=False, encoding='utf-8', append=False):
    """
    Write 'contents' to 'filepath', appending instead of truncating when
    append=True.  Text is encoded with 'encoding' unless asbin=True.
    """
    payload = contents if asbin else contents.encode(encoding)
    with open(filepath, "ab" if append else "wb") as out_file:
        out_file.write(payload)
+
def append_file(filepath, contents, asbin=False, encoding='utf-8'):
    """
    Append 'contents' to 'filepath' (thin wrapper over write_file).
    """
    write_file(filepath, contents, append=True, asbin=asbin, encoding=encoding)
+
+
def base_name(path):
    """Return the final path component of 'path'.

    Uses the idiomatic os.path.basename instead of hand-rolled
    os.path.split unpacking; behavior is identical.
    """
    return os.path.basename(path)
+
def get_line_startingwith(prefix, filepath):
    """
    Return the first line of 'filepath' that starts with 'prefix',
    or None when no line matches.
    """
    matching = (line for line in read_file(filepath).split('\n')
                if line.startswith(prefix))
    return next(matching, None)
+
+#End File operation util functions
+
+def mkdir(dirpath, mode=None, owner=None):
+ if not os.path.isdir(dirpath):
+ os.makedirs(dirpath)
+ if mode is not None:
+ chmod(dirpath, mode)
+ if owner is not None:
+ chowner(dirpath, owner)
+
def chowner(path, owner):
    """Change ownership of 'path' to the user named 'owner'; logs an error
    (does not raise) when the path does not exist."""
    if not os.path.exists(path):
        logger.error("Path does not exist: {0}".format(path))
        return
    pwd_entry = pwd.getpwnam(owner)
    os.chown(path, pwd_entry[2], pwd_entry[3])
+
def chmod(path, mode):
    """Apply 'mode' to 'path'; logs an error (does not raise) when the
    path does not exist."""
    if os.path.exists(path):
        os.chmod(path, mode)
    else:
        logger.error("Path does not exist: {0}".format(path))
+
def rm_files(*args):
    """Delete each given path that exists as a regular file; other paths
    are silently skipped."""
    for target in (p for p in args if os.path.isfile(p)):
        os.remove(target)
+
def rm_dirs(*args):
    """
    Remove all the contents under each given directory, keeping the
    directory itself; non-directories are skipped.
    """
    for dir_name in args:
        if not os.path.isdir(dir_name):
            continue
        for entry in os.listdir(dir_name):
            full = os.path.join(dir_name, entry)
            if os.path.isfile(full):
                os.remove(full)
            elif os.path.isdir(full):
                shutil.rmtree(full)
+
def trim_ext(path, ext):
    """
    Return 'path' with a trailing extension 'ext' removed; 'ext' may be
    given with or without the leading dot.

    Fixes a bug in the original str.split-based version, which truncated
    at the FIRST occurrence of the extension (e.g. "archive.txt.txt" with
    "txt" became "archive" instead of "archive.txt").
    """
    if not ext.startswith("."):
        ext = "." + ext
    return path[:-len(ext)] if path.endswith(ext) else path
+
def update_conf_file(path, line_start, val, chk_err=False):
    """
    Rewrite 'path' in place: drop every line beginning with 'line_start'
    and append the single line 'val'.

    Raises IOError when the file is missing and chk_err is True.  (When
    chk_err is False a missing file still raises from read_file; the
    flag only controls which error is produced.)
    """
    # The original initialized a dead 'conf = []' that was immediately
    # overwritten; removed.
    if not os.path.isfile(path) and chk_err:
        raise IOError("Can't find config file:{0}".format(path))
    lines = [x for x in read_file(path).split('\n')
             if not x.startswith(line_start)]
    lines.append(val)
    write_file(path, '\n'.join(lines))
+
def search_file(target_dir_name, target_file_name):
    """Walk target_dir_name and return the full path of the first file
    named target_file_name, or None when absent."""
    for root, _, files in os.walk(target_dir_name):
        if target_file_name in files:
            return os.path.join(root, target_file_name)
    return None
+
def chmod_tree(path, mode):
    """Apply 'mode' to every regular file under 'path'; directories
    themselves are left untouched."""
    for root, _, files in os.walk(path):
        for name in files:
            os.chmod(os.path.join(root, name), mode)
+
def findstr_in_file(file_path, pattern_str):
    """
    Return the first regex match of 'pattern_str' in any line of
    'file_path', or None when nothing matches.

    I/O and regex errors propagate to the caller, exactly as before
    (the original try/except merely re-raised).  Fixes the file-handle
    leak by using a context manager.
    """
    pattern = re.compile(pattern_str)
    with open(file_path, 'r') as f:
        for line in f:
            match = pattern.search(line)
            if match:
                return match
    return None
+
diff --git a/azurelinuxagent/common/utils/flexible_version.py b/azurelinuxagent/common/utils/flexible_version.py
new file mode 100644
index 0000000..2fce88d
--- /dev/null
+++ b/azurelinuxagent/common/utils/flexible_version.py
@@ -0,0 +1,199 @@
+from distutils import version
+import re
+
class FlexibleVersion(version.Version):
    """
    A more flexible implementation of distutils.version.StrictVersion

    The implementation allows to specify:
    - an arbitrary number of version numbers:
        not only '1.2.3' , but also '1.2.3.4.5'
    - the separator between version numbers:
        '1-2-3' is allowed when '-' is specified as separator
    - a flexible pre-release separator:
        '1.2.3.alpha1', '1.2.3-alpha1', and '1.2.3alpha1' are considered equivalent
    - an arbitrary ordering of pre-release tags:
        1.1alpha3 < 1.1beta2 < 1.1rc1 < 1.1
        when ["alpha", "beta", "rc"] is specified as pre-release tag list

    Inspiration from this discussion at StackOverflow:
    http://stackoverflow.com/questions/12255554/sort-versions-in-python
    """

    def __init__(self, vstring=None, sep='.', prerel_tags=('alpha', 'beta', 'rc')):
        # Build the parsing regex from sep/prerel_tags, then parse vstring
        # (if given) into self.version (a tuple of ints) and
        # self.prerelease (None or a (tag, number-or-None) pair).
        version.Version.__init__(self)

        if sep is None:
            sep = '.'
        if prerel_tags is None:
            prerel_tags = ()

        self.sep = sep
        self.prerel_sep = ''
        # NOTE(review): prerel_tags can no longer be None here (normalized
        # above), so the conditional is redundant but harmless.
        self.prerel_tags = tuple(prerel_tags) if prerel_tags is not None else ()

        self._compile_pattern()

        self.prerelease = None
        self.version = ()
        if vstring:
            self._parse(vstring)
        return

    # Names of the regex capture groups used by version_re.
    _nn_version = 'version'
    _nn_prerel_sep = 'prerel_sep'
    _nn_prerel_tag = 'tag'
    _nn_prerel_num = 'tag_num'

    # Optional separator ('.' or '-') between the numeric version and the
    # pre-release tag, e.g. the '-' in "1.2.3-alpha1".
    _re_prerel_sep = r'(?P<{pn}>{sep})?'.format(
        pn=_nn_prerel_sep,
        sep='|'.join(map(re.escape, ('.', '-'))))

    @property
    def major(self):
        # First numeric component, 0 when absent.
        return self.version[0] if len(self.version) > 0 else 0

    @property
    def minor(self):
        # Second numeric component, 0 when absent.
        return self.version[1] if len(self.version) > 1 else 0

    @property
    def patch(self):
        # Third numeric component, 0 when absent.
        return self.version[2] if len(self.version) > 2 else 0

    def _parse(self, vstring):
        """Parse vstring into self.version/self.prerelease; raises
        ValueError when the string does not match the compiled pattern."""
        m = self.version_re.match(vstring)
        if not m:
            raise ValueError("Invalid version number '{0}'".format(vstring))

        self.prerelease = None
        self.version = ()

        self.prerel_sep = m.group(self._nn_prerel_sep)
        tag = m.group(self._nn_prerel_tag)
        tag_num = m.group(self._nn_prerel_num)

        # A tag with no digits (e.g. "1.2rc") yields (tag, None).
        if tag is not None and tag_num is not None:
            self.prerelease = (tag, int(tag_num) if len(tag_num) else None)

        self.version = tuple(map(int, self.sep_re.split(m.group(self._nn_version))))
        return

    def __add__(self, increment):
        """Return a new FlexibleVersion with the last numeric component
        increased by 'increment'; prerelease is preserved."""
        version = list(self.version)
        version[-1] += increment
        vstring = self._assemble(version, self.sep, self.prerel_sep, self.prerelease)
        return FlexibleVersion(vstring=vstring, sep=self.sep, prerel_tags=self.prerel_tags)

    def __sub__(self, decrement):
        """Return a new FlexibleVersion with the last numeric component
        decreased by 'decrement'; raises ArithmeticError when the last
        component is already <= 0."""
        version = list(self.version)
        if version[-1] <= 0:
            raise ArithmeticError("Cannot decrement final numeric component of {0} below zero" \
                .format(self))
        version[-1] -= decrement
        vstring = self._assemble(version, self.sep, self.prerel_sep, self.prerelease)
        return FlexibleVersion(vstring=vstring, sep=self.sep, prerel_tags=self.prerel_tags)

    def __repr__(self):
        return "{cls} ('{vstring}', '{sep}', {prerel_tags})"\
            .format(
                cls=self.__class__.__name__,
                vstring=str(self),
                sep=self.sep,
                prerel_tags=self.prerel_tags)

    def __str__(self):
        return self._assemble(self.version, self.sep, self.prerel_sep, self.prerelease)

    def __ge__(self, that):
        return not self.__lt__(that)

    def __gt__(self, that):
        return (not self.__lt__(that)) and (not self.__eq__(that))

    def __le__(self, that):
        return (self.__lt__(that)) or (self.__eq__(that))

    def __lt__(self, that):
        """Ordering: numeric components first (zero-padded to equal length);
        a prerelease sorts before its release; prerelease tags order by
        their index in prerel_tags, then by their number."""
        this_version, that_version = self._ensure_compatible(that)

        if this_version != that_version \
                or self.prerelease is None and that.prerelease is None:
            return this_version < that_version

        if self.prerelease is not None and that.prerelease is None:
            return True
        if self.prerelease is None and that.prerelease is not None:
            return False

        this_index = self.prerel_tags_set[self.prerelease[0]]
        that_index = self.prerel_tags_set[that.prerelease[0]]
        if this_index == that_index:
            # NOTE(review): prerelease[1] is None for an unnumbered tag
            # (e.g. "1.2rc"); comparing None with an int raises TypeError
            # on Python 3 — confirm callers always use numbered tags.
            return self.prerelease[1] < that.prerelease[1]

        return this_index < that_index

    def __ne__(self, that):
        return not self.__eq__(that)

    def __eq__(self, that):
        """Equal when zero-padded numeric components and prerelease match;
        raises ValueError for versions with a different sep/tag structure."""
        this_version, that_version = self._ensure_compatible(that)

        if this_version != that_version:
            return False

        if self.prerelease != that.prerelease:
            return False

        return True

    def _assemble(self, version, sep, prerel_sep, prerelease):
        """Render (version, prerelease) back into a version string using
        the instance's separators."""
        s = sep.join(map(str, version))
        if prerelease is not None:
            if prerel_sep is not None:
                s += prerel_sep
            s += prerelease[0]
            if prerelease[1] is not None:
                s += str(prerelease[1])
        return s

    def _compile_pattern(self):
        """Compile version_re (full-string matcher) and sep_re (component
        splitter) from self.sep and self.prerel_tags."""
        sep, self.sep_re = self._compile_separator(self.sep)

        if self.prerel_tags:
            tags = '|'.join(re.escape(tag) for tag in self.prerel_tags)
            # Map each tag to its rank for ordering comparisons.
            self.prerel_tags_set = dict(zip(self.prerel_tags, range(len(self.prerel_tags))))
            # NOTE(review): non-raw string leaves "\d" as an invalid escape
            # (DeprecationWarning on modern Python); re still interprets it
            # as a digit class, so matching behavior is unaffected.
            release_re = '(?:{prerel_sep}(?P<{tn}>{tags})(?P<{nn}>\d*))?'.format(
                prerel_sep=self._re_prerel_sep,
                tags=tags,
                tn=self._nn_prerel_tag,
                nn=self._nn_prerel_num)
        else:
            release_re = ''

        version_re = r'^(?P<{vn}>\d+(?:(?:{sep}\d+)*)?){rel}$'.format(
            vn=self._nn_version,
            sep=sep,
            rel=release_re)
        self.version_re = re.compile(version_re)
        return

    def _compile_separator(self, sep):
        """Return (escaped-separator, compiled splitter) for sep."""
        if sep is None:
            return '', re.compile('')
        return re.escape(sep), re.compile(re.escape(sep))

    def _ensure_compatible(self, that):
        """
        Ensures the instances have the same structure and, if so, returns length compatible
        version lists (so that x.y.0.0 is equivalent to x.y).
        """
        if self.prerel_tags != that.prerel_tags or self.sep != that.sep:
            raise ValueError("Unable to compare: versions have different structures")

        this_version = list(self.version[:])
        that_version = list(that.version[:])
        while len(this_version) < len(that_version): this_version.append(0)
        while len(that_version) < len(this_version): that_version.append(0)

        return this_version, that_version
diff --git a/azurelinuxagent/common/utils/restutil.py b/azurelinuxagent/common/utils/restutil.py
new file mode 100644
index 0000000..a789650
--- /dev/null
+++ b/azurelinuxagent/common/utils/restutil.py
@@ -0,0 +1,156 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import time
+import platform
+import os
+import subprocess
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.logger as logger
+from azurelinuxagent.common.exception import HttpError
+from azurelinuxagent.common.future import httpclient, urlparse
+
+"""
+REST api util functions
+"""
+
+RETRY_WAITING_INTERVAL = 10
+
+def _parse_url(url):
+ o = urlparse(url)
+ rel_uri = o.path
+ if o.fragment:
+ rel_uri = "{0}#{1}".format(rel_uri, o.fragment)
+ if o.query:
+ rel_uri = "{0}?{1}".format(rel_uri, o.query)
+ secure = False
+ if o.scheme.lower() == "https":
+ secure = True
+ return o.hostname, o.port, secure, rel_uri
+
def get_http_proxy():
    """
    Return the (host, port) proxy pair from agent configuration.
    Username and password are not supported yet.
    """
    return (conf.get_httpproxy_host(), conf.get_httpproxy_port())
+
def _http_request(method, host, rel_uri, port=None, data=None, secure=False,
                  headers=None, proxy_host=None, proxy_port=None):
    """
    Issue a single HTTP(S) request and return the response object.

    When a proxy is supplied, the connection is made to the proxy (with a
    CONNECT tunnel for https) and the absolute URL is sent instead of the
    relative URI.
    """
    use_proxy = proxy_host is not None and proxy_port is not None
    if secure:
        port = 443 if port is None else port
        if use_proxy:
            conn = httpclient.HTTPSConnection(proxy_host, proxy_port, timeout=10)
            conn.set_tunnel(host, port)
            # If proxy is used, full url is needed.
            url = "https://{0}:{1}{2}".format(host, port, rel_uri)
        else:
            conn = httpclient.HTTPSConnection(host, port, timeout=10)
            url = rel_uri
    else:
        port = 80 if port is None else port
        if use_proxy:
            conn = httpclient.HTTPConnection(proxy_host, proxy_port, timeout=10)
            # If proxy is used, full url is needed.
            url = "http://{0}:{1}{2}".format(host, port, rel_uri)
        else:
            conn = httpclient.HTTPConnection(host, port, timeout=10)
            url = rel_uri
    # 'is None' instead of the '== None' anti-idiom; passing the default
    # empty headers dict explicitly collapses the duplicated request calls.
    conn.request(method, url, data, headers if headers is not None else {})
    return conn.getresponse()
+
def http_request(method, url, data, headers=None, max_retry=3, chk_proxy=False):
    """
    Sending http request to server
    On error, sleep 10 and retry max_retry times.

    Returns the http.client response object on success; raises HttpError
    after all retries are exhausted.  When chk_proxy is True the proxy
    from agent configuration is used.
    """
    logger.verbose("HTTP Req: {0} {1}", method, url)
    logger.verbose("    Data={0}", data)
    logger.verbose("    Header={0}", headers)
    host, port, secure, rel_uri = _parse_url(url)

    # Check proxy (only when the caller asked for it).
    proxy_host, proxy_port = (None, None)
    if chk_proxy:
        proxy_host, proxy_port = get_http_proxy()

    # If httplib module is not built with ssl support, fall back to http.
    if secure and not hasattr(httpclient, "HTTPSConnection"):
        logger.warn("httplib is not built with ssl support")
        secure = False

    # If httplib module doesn't support https tunnelling (pre-2.7),
    # fall back to http.
    if secure and \
            proxy_host is not None and \
            proxy_port is not None and \
            not hasattr(httpclient.HTTPSConnection, "set_tunnel"):
        logger.warn("httplib doesn't support https tunnelling(new in python 2.7)")
        secure = False

    for retry in range(0, max_retry):
        try:
            resp = _http_request(method, host, rel_uri, port=port, data=data,
                                 secure=secure, headers=headers,
                                 proxy_host=proxy_host, proxy_port=proxy_port)
            logger.verbose("HTTP Resp: Status={0}", resp.status)
            logger.verbose("    Header={0}", resp.getheaders())
            return resp
        except httpclient.HTTPException as e:
            logger.warn('HTTPException {0}, args:{1}', e, repr(e.args))
        except IOError as e:
            logger.warn('Socket IOError {0}, args:{1}', e, repr(e.args))

        # Sleep between attempts, but not after the final failure.
        if retry < max_retry - 1:
            logger.info("Retry={0}, {1} {2}", retry, method, url)
            time.sleep(RETRY_WAITING_INTERVAL)

    if url is not None and len(url) > 100:
        url_log = url[0: 100] #In case the url is too long
    else:
        url_log = url
    raise HttpError("HTTP Err: {0} {1}".format(method, url_log))
+
def http_get(url, headers=None, max_retry=3, chk_proxy=False):
    """Issue a GET via http_request (retries up to max_retry times)."""
    return http_request("GET", url, data=None, headers=headers,
                        max_retry=max_retry, chk_proxy=chk_proxy)
+
def http_head(url, headers=None, max_retry=3, chk_proxy=False):
    """Issue a HEAD via http_request (retries up to max_retry times)."""
    return http_request("HEAD", url, None, headers=headers,
                        max_retry=max_retry, chk_proxy=chk_proxy)
+
def http_post(url, data, headers=None, max_retry=3, chk_proxy=False):
    """Issue a POST with body 'data' via http_request."""
    return http_request("POST", url, data, headers=headers,
                        max_retry=max_retry, chk_proxy=chk_proxy)
+
def http_put(url, data, headers=None, max_retry=3, chk_proxy=False):
    """Issue a PUT with body 'data' via http_request."""
    return http_request("PUT", url, data, headers=headers,
                        max_retry=max_retry, chk_proxy=chk_proxy)
+
def http_delete(url, headers=None, max_retry=3, chk_proxy=False):
    """Issue a DELETE via http_request (retries up to max_retry times)."""
    return http_request("DELETE", url, None, headers=headers,
                        max_retry=max_retry, chk_proxy=chk_proxy)
+
+#End REST api util functions
diff --git a/azurelinuxagent/common/utils/shellutil.py b/azurelinuxagent/common/utils/shellutil.py
new file mode 100644
index 0000000..d273c92
--- /dev/null
+++ b/azurelinuxagent/common/utils/shellutil.py
@@ -0,0 +1,107 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import platform
+import os
+import subprocess
+from azurelinuxagent.common.future import ustr
+import azurelinuxagent.common.logger as logger
+
+if not hasattr(subprocess,'check_output'):
+    # Running on python < 2.7: install backports of check_output and a
+    # CalledProcessError that carries 'output', so the rest of this module
+    # can use the python 2.7 subprocess API unconditionally.
+    def check_output(*popenargs, **kwargs):
+        r"""Backport from subprocess module from python 2.7"""
+        if 'stdout' in kwargs:
+            raise ValueError('stdout argument not allowed, '
+                             'it will be overridden.')
+        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
+        output, unused_err = process.communicate()
+        retcode = process.poll()
+        if retcode:
+            cmd = kwargs.get("args")
+            if cmd is None:
+                cmd = popenargs[0]
+            # Resolved at call time, so this picks up the patched class
+            # assigned onto the subprocess module below.
+            raise subprocess.CalledProcessError(retcode, cmd, output=output)
+        return output
+
+    # Exception classes used by this module.
+    class CalledProcessError(Exception):
+        # Mirrors python 2.7's class, adding the 'output' attribute.
+        def __init__(self, returncode, cmd, output=None):
+            self.returncode = returncode
+            self.cmd = cmd
+            self.output = output
+        def __str__(self):
+            return ("Command '{0}' returned non-zero exit status {1}"
+                    "").format(self.cmd, self.returncode)
+
+    # Monkey-patch the stdlib module so all callers see the backports.
+    subprocess.check_output=check_output
+    subprocess.CalledProcessError=CalledProcessError
+
+
+"""
+Shell command util functions
+"""
+def run(cmd, chk_err=True):
+ """
+ Calls run_get_output on 'cmd', returning only the return code.
+ If chk_err=True then errors will be reported in the log.
+ If chk_err=False then errors will be suppressed from the log.
+ """
+ retcode,out=run_get_output(cmd,chk_err)
+ return retcode
+
+def run_get_output(cmd, chk_err=True, log_cmd=True):
+ """
+ Wrapper for subprocess.check_output.
+ Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions.
+ Reports exceptions to Error if chk_err parameter is True
+ """
+ if log_cmd:
+ logger.verbose(u"run cmd '{0}'", cmd)
+ try:
+ output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True)
+ output = ustr(output, encoding='utf-8', errors="backslashreplace")
+ except subprocess.CalledProcessError as e :
+ output = ustr(e.output, encoding='utf-8', errors="backslashreplace")
+ if chk_err:
+ if log_cmd:
+ logger.error(u"run cmd '{0}' failed", e.cmd)
+ logger.error(u"Error Code:{0}", e.returncode)
+ logger.error(u"Result:{0}", output)
+ return e.returncode, output
+ return 0, output
+
+
+def quote(word_list):
+ """
+ Quote a list or tuple of strings for Unix Shell as words, using the
+ byte-literal single quote.
+
+ The resulting string is safe for use with ``shell=True`` in ``subprocess``,
+ and in ``os.system``. ``assert shlex.split(ShellQuote(wordList)) == wordList``.
+
+ See POSIX.1:2013 Vol 3, Chap 2, Sec 2.2.2:
+ http://pubs.opengroup.org/onlinepubs/9699919799/utilities/V3_chap02.html#tag_18_02_02
+ """
+ if not isinstance(word_list, (tuple, list)):
+ word_list = (word_list,)
+
+ return " ".join(list("'{0}'".format(s.replace("'", "'\\''")) for s in word_list))
+
+
+# End shell command util functions
diff --git a/azurelinuxagent/common/utils/textutil.py b/azurelinuxagent/common/utils/textutil.py
new file mode 100644
index 0000000..f03c7e6
--- /dev/null
+++ b/azurelinuxagent/common/utils/textutil.py
@@ -0,0 +1,279 @@
+# Microsoft Azure Linux Agent
+#
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+
+import base64
+import crypt
+import random
+import string
+import struct
+import sys
+import xml.dom.minidom as minidom
+
+from distutils.version import LooseVersion as Version
+
+
+def parse_doc(xml_text):
+ """
+ Parse xml document from string
+ """
+ # The minidom lib has some issue with unicode in python2.
+ # Encode the string into utf-8 first
+ xml_text = xml_text.encode('utf-8')
+ return minidom.parseString(xml_text)
+
+
+def findall(root, tag, namespace=None):
+ """
+ Get all nodes by tag and namespace under Node root.
+ """
+ if root is None:
+ return []
+
+ if namespace is None:
+ return root.getElementsByTagName(tag)
+ else:
+ return root.getElementsByTagNameNS(namespace, tag)
+
+
+def find(root, tag, namespace=None):
+ """
+ Get first node by tag and namespace under Node root.
+ """
+ nodes = findall(root, tag, namespace=namespace)
+ if nodes is not None and len(nodes) >= 1:
+ return nodes[0]
+ else:
+ return None
+
+
+def gettext(node):
+ """
+ Get node text
+ """
+ if node is None:
+ return None
+
+ for child in node.childNodes:
+ if child.nodeType == child.TEXT_NODE:
+ return child.data
+ return None
+
+
+def findtext(root, tag, namespace=None):
+ """
+ Get text of node by tag and namespace under Node root.
+ """
+ node = find(root, tag, namespace=namespace)
+ return gettext(node)
+
+
+def getattrib(node, attr_name):
+ """
+ Get attribute of xml node
+ """
+ if node is not None:
+ return node.getAttribute(attr_name)
+ else:
+ return None
+
+
+def unpack(buf, offset, range):
+ """
+ Unpack bytes into python values.
+ """
+ result = 0
+ for i in range:
+ result = (result << 8) | str_to_ord(buf[offset + i])
+ return result
+
+
+def unpack_little_endian(buf, offset, length):
+ """
+ Unpack little endian bytes into python values.
+ """
+ return unpack(buf, offset, list(range(length - 1, -1, -1)))
+
+
+def unpack_big_endian(buf, offset, length):
+ """
+ Unpack big endian bytes into python values.
+ """
+ return unpack(buf, offset, list(range(0, length)))
+
+
+def hex_dump3(buf, offset, length):
+ """
+ Dump range of buf in formatted hex.
+ """
+ return ''.join(['%02X' % str_to_ord(char) for char in buf[offset:offset + length]])
+
+
+def hex_dump2(buf):
+ """
+ Dump buf in formatted hex.
+ """
+ return hex_dump3(buf, 0, len(buf))
+
+
+def is_in_range(a, low, high):
+ """
+ Return True if 'a' in 'low' <= a >= 'high'
+ """
+ return (a >= low and a <= high)
+
+
+def is_printable(ch):
+ """
+ Return True if character is displayable.
+ """
+ return (is_in_range(ch, str_to_ord('A'), str_to_ord('Z'))
+ or is_in_range(ch, str_to_ord('a'), str_to_ord('z'))
+ or is_in_range(ch, str_to_ord('0'), str_to_ord('9')))
+
+
+def hex_dump(buffer, size):
+    """
+    Return a formatted hex dump of the first 'size' elements of 'buffer'
+    (the entire buffer when size < 0): 16 bytes per row, rendered as
+    "offset: hex bytes  ascii", with non-alphanumeric bytes shown as '.'.
+    """
+    # NOTE: 'buffer' shadows the python2 builtin of the same name.
+    if size < 0:
+        size = len(buffer)
+    result = ""
+    for i in range(0, size):
+        if (i % 16) == 0:
+            # Start of a row: emit the offset column.
+            result += "%06X: " % i
+        byte = buffer[i]
+        if type(byte) == str:
+            # python2 str indexing yields characters; convert to ordinal.
+            byte = ord(byte.decode('latin1'))
+        result += "%02X " % byte
+        if (i & 15) == 7:
+            # Extra gap between the two 8-byte halves of a row.
+            result += " "
+        if ((i + 1) % 16) == 0 or (i + 1) == size:
+            # End of a row (or of the data): pad a short final row so the
+            # ascii column lines up, then append the ascii rendering.
+            j = i
+            while ((j + 1) % 16) != 0:
+                result += "   "
+                if (j & 7) == 7:
+                    result += " "
+                j += 1
+            result += " "
+            for j in range(i - (i % 16), i + 1):
+                byte = buffer[j]
+                if type(byte) == str:
+                    byte = str_to_ord(byte.decode('latin1'))
+                k = '.'
+                if is_printable(byte):
+                    k = chr(byte)
+                result += k
+            if (i + 1) != size:
+                result += "\n"
+    return result
+
+
+def str_to_ord(a):
+ """
+ Allows indexing into a string or an array of integers transparently.
+ Generic utility function.
+ """
+ if type(a) == type(b'') or type(a) == type(u''):
+ a = ord(a)
+ return a
+
+
+def compare_bytes(a, b, start, length):
+ for offset in range(start, start + length):
+ if str_to_ord(a[offset]) != str_to_ord(b[offset]):
+ return False
+ return True
+
+
+def int_to_ip4_addr(a):
+ """
+ Build DHCP request string.
+ """
+ return "%u.%u.%u.%u" % ((a >> 24) & 0xFF,
+ (a >> 16) & 0xFF,
+ (a >> 8) & 0xFF,
+ (a) & 0xFF)
+
+
+def hexstr_to_bytearray(a):
+ """
+ Return hex string packed into a binary struct.
+ """
+ b = b""
+ for c in range(0, len(a) // 2):
+ b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16))
+ return b
+
+
+def set_ssh_config(config, name, val):
+ notfound = True
+ for i in range(0, len(config)):
+ if config[i].startswith(name):
+ config[i] = "{0} {1}".format(name, val)
+ notfound = False
+ elif config[i].startswith("Match"):
+ # Match block must be put in the end of sshd config
+ break
+ if notfound:
+ config.insert(i, "{0} {1}".format(name, val))
+ return config
+
+
+def set_ini_config(config, name, val):
+ notfound = True
+ nameEqual = name + '='
+ length = len(config)
+ text = "{0}=\"{1}\"".format(name, val)
+
+ for i in reversed(range(0, length)):
+ if config[i].startswith(nameEqual):
+ config[i] = text
+ notfound = False
+ break
+
+ if notfound:
+ config.insert(length - 1, text)
+
+
+def remove_bom(c):
+ if str_to_ord(c[0]) > 128 and str_to_ord(c[1]) > 128 and \
+ str_to_ord(c[2]) > 128:
+ c = c[3:]
+ return c
+
+
+def gen_password_hash(password, crypt_id, salt_len):
+ collection = string.ascii_letters + string.digits
+ salt = ''.join(random.choice(collection) for _ in range(salt_len))
+ salt = "${0}${1}".format(crypt_id, salt)
+ return crypt.crypt(password, salt)
+
+
+def get_bytes_from_pem(pem_str):
+ base64_bytes = ""
+ for line in pem_str.split('\n'):
+ if "----" not in line:
+ base64_bytes += line
+ return base64_bytes
+
+
+def b64encode(s):
+ from azurelinuxagent.common.version import PY_VERSION_MAJOR
+ if PY_VERSION_MAJOR > 2:
+ return base64.b64encode(bytes(s, 'utf-8')).decode('utf-8')
+ return base64.b64encode(s)
diff --git a/azurelinuxagent/common/version.py b/azurelinuxagent/common/version.py
new file mode 100644
index 0000000..6c4b475
--- /dev/null
+++ b/azurelinuxagent/common/version.py
@@ -0,0 +1,116 @@
+# Copyright 2014 Microsoft Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Requires Python 2.4+ and Openssl 1.0+
+#
+
+import os
+import re
+import platform
+import sys
+
+import azurelinuxagent.common.conf as conf
+import azurelinuxagent.common.utils.fileutil as fileutil
+from azurelinuxagent.common.utils.flexible_version import FlexibleVersion
+from azurelinuxagent.common.future import ustr
+
+
+def get_distro():
+ if 'FreeBSD' in platform.system():
+ release = re.sub('\-.*\Z', '', ustr(platform.release()))
+ osinfo = ['freebsd', release, '', 'freebsd']
+ elif 'linux_distribution' in dir(platform):
+ osinfo = list(platform.linux_distribution(full_distribution_name=0))
+ full_name = platform.linux_distribution()[0].strip()
+ osinfo.append(full_name)
+ else:
+ osinfo = platform.dist()
+
+ # The platform.py lib has issue with detecting oracle linux distribution.
+ # Merge the following patch provided by oracle as a temparory fix.
+ if os.path.exists("/etc/oracle-release"):
+ osinfo[2] = "oracle"
+ osinfo[3] = "Oracle Linux"
+
+ # Remove trailing whitespace and quote in distro name
+ osinfo[0] = osinfo[0].strip('"').strip(' ').lower()
+ return osinfo
+
+
+AGENT_NAME = "WALinuxAgent"
+AGENT_LONG_NAME = "Azure Linux Agent"
+AGENT_VERSION = '2.1.5'
+AGENT_LONG_VERSION = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION)
+AGENT_DESCRIPTION = """\
+The Azure Linux Agent supports the provisioning and running of Linux
+VMs in the Azure cloud. This package should be installed on Linux disk
+images that are built to run in the Azure environment.
+"""
+
+AGENT_DIR_GLOB = "{0}-*".format(AGENT_NAME)
+AGENT_PKG_GLOB = "{0}-*.zip".format(AGENT_NAME)
+
+AGENT_PATTERN = "{0}-(.*)".format(AGENT_NAME)
+AGENT_NAME_PATTERN = re.compile(AGENT_PATTERN)
+AGENT_DIR_PATTERN = re.compile(".*/{0}".format(AGENT_PATTERN))
+
+
+# Set the CURRENT_AGENT and CURRENT_VERSION to match the agent directory name
+# - This ensures the agent will "see itself" using the same name and version
+# as the code that downloads agents.
+def set_current_agent():
+ path = os.getcwd()
+ lib_dir = conf.get_lib_dir()
+ if lib_dir[-1] != os.path.sep:
+ lib_dir += os.path.sep
+ if path[:len(lib_dir)] != lib_dir:
+ agent = AGENT_LONG_VERSION
+ version = AGENT_VERSION
+ else:
+ agent = path[len(lib_dir):].split(os.path.sep)[0]
+ version = AGENT_NAME_PATTERN.match(agent).group(1)
+ return agent, FlexibleVersion(version)
+CURRENT_AGENT, CURRENT_VERSION = set_current_agent()
+
+def is_current_agent_installed():
+ return CURRENT_AGENT == AGENT_LONG_VERSION
+
+
+# Distribution facts, captured once at import time via get_distro().
+__distro__ = get_distro()
+DISTRO_NAME = __distro__[0]
+DISTRO_VERSION = __distro__[1]
+DISTRO_CODE_NAME = __distro__[2]
+DISTRO_FULL_NAME = __distro__[3]
+
+# Interpreter version, split out for cheap comparisons elsewhere.
+PY_VERSION = sys.version_info
+PY_VERSION_MAJOR = sys.version_info[0]
+PY_VERSION_MINOR = sys.version_info[1]
+PY_VERSION_MICRO = sys.version_info[2]
+
+"""
+Add this workaround for detecting Snappy Ubuntu Core temporarily, until ubuntu
+fixed this bug: https://bugs.launchpad.net/snappy/+bug/1481086
+"""
+
+
+def is_snappy():
+ if os.path.exists("/etc/motd"):
+ motd = fileutil.read_file("/etc/motd")
+ if "snappy" in motd:
+ return True
+ return False
+
+
+# Override the detected full name when running on Snappy Ubuntu Core
+# (workaround for https://bugs.launchpad.net/snappy/+bug/1481086).
+if is_snappy():
+    DISTRO_FULL_NAME = "Snappy Ubuntu Core"