diff options
| author | Tamilmani Manoharan <tamanoha@microsoft.com> | 2018-11-29 21:53:18 +0000 | 
|---|---|---|
| committer | Server Team CI Bot <josh.powers+server-team-bot@canonical.com> | 2018-11-29 21:53:18 +0000 | 
| commit | bf7917159dbb292c9fcdef82b004e0f5ecb32c16 (patch) | |
| tree | df2d64a2856949cb43b6fdab1adf52880a883a13 /cloudinit/sources/helpers | |
| parent | c7c395ce0f3d024243192947fee32d7fc6c063f5 (diff) | |
| download | vyos-cloud-init-bf7917159dbb292c9fcdef82b004e0f5ecb32c16.tar.gz vyos-cloud-init-bf7917159dbb292c9fcdef82b004e0f5ecb32c16.zip | |
azure: detect vnet migration via netlink media change event
Replace Azure pre-provision polling on IMDS with a blocking call
which watches for netlink link state change messages.  The media
change event happens when a pre-provisioned VM has been activated
and is connected to the users virtual network and cloud-init can
then resume operation to complete image instantiation.
Diffstat (limited to 'cloudinit/sources/helpers')
| -rw-r--r-- | cloudinit/sources/helpers/netlink.py | 250 | ||||
| -rw-r--r-- | cloudinit/sources/helpers/tests/test_netlink.py | 373 | 
2 files changed, 623 insertions, 0 deletions
| diff --git a/cloudinit/sources/helpers/netlink.py b/cloudinit/sources/helpers/netlink.py new file mode 100644 index 00000000..d377ae3d --- /dev/null +++ b/cloudinit/sources/helpers/netlink.py @@ -0,0 +1,250 @@ +# Author: Tamilmani Manoharan <tamanoha@microsoft.com> +# +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit import log as logging +from cloudinit import util +from collections import namedtuple + +import os +import select +import socket +import struct + +LOG = logging.getLogger(__name__) + +# http://man7.org/linux/man-pages/man7/netlink.7.html +RTMGRP_LINK = 1 +NLMSG_NOOP = 1 +NLMSG_ERROR = 2 +NLMSG_DONE = 3 +RTM_NEWLINK = 16 +RTM_DELLINK = 17 +RTM_GETLINK = 18 +RTM_SETLINK = 19 +MAX_SIZE = 65535 +RTA_DATA_OFFSET = 32 +MSG_TYPE_OFFSET = 16 +SELECT_TIMEOUT = 60 + +NLMSGHDR_FMT = "IHHII" +IFINFOMSG_FMT = "BHiII" +NLMSGHDR_SIZE = struct.calcsize(NLMSGHDR_FMT) +IFINFOMSG_SIZE = struct.calcsize(IFINFOMSG_FMT) +RTATTR_START_OFFSET = NLMSGHDR_SIZE + IFINFOMSG_SIZE +RTA_DATA_START_OFFSET = 4 +PAD_ALIGNMENT = 4 + +IFLA_IFNAME = 3 +IFLA_OPERSTATE = 16 + +# https://www.kernel.org/doc/Documentation/networking/operstates.txt +OPER_UNKNOWN = 0 +OPER_NOTPRESENT = 1 +OPER_DOWN = 2 +OPER_LOWERLAYERDOWN = 3 +OPER_TESTING = 4 +OPER_DORMANT = 5 +OPER_UP = 6 + +RTAAttr = namedtuple('RTAAttr', ['length', 'rta_type', 'data']) +InterfaceOperstate = namedtuple('InterfaceOperstate', ['ifname', 'operstate']) +NetlinkHeader = namedtuple('NetlinkHeader', ['length', 'type', 'flags', 'seq', +                                             'pid']) + + +class NetlinkCreateSocketError(RuntimeError): +    '''Raised if netlink socket fails during create or bind.''' +    pass + + +def create_bound_netlink_socket(): +    '''Creates netlink socket and bind on netlink group to catch interface +    down/up events. The socket will bound only on RTMGRP_LINK (which only +    includes RTM_NEWLINK/RTM_DELLINK/RTM_GETLINK events). The socket is set to +    non-blocking mode since we're only receiving messages. + +    :returns: netlink socket in non-blocking mode +    :raises: NetlinkCreateSocketError +    ''' +    try: +        netlink_socket = socket.socket(socket.AF_NETLINK, +                                       socket.SOCK_RAW, +                                       socket.NETLINK_ROUTE) +        netlink_socket.bind((os.getpid(), RTMGRP_LINK)) +        netlink_socket.setblocking(0) +    except socket.error as e: +        msg = "Exception during netlink socket create: %s" % e +        raise NetlinkCreateSocketError(msg) +    LOG.debug("Created netlink socket") +    return netlink_socket + + +def get_netlink_msg_header(data): +    '''Gets netlink message type and length + +    :param: data read from netlink socket +    :returns: netlink message type +    :raises: AssertionError if data is None or data is not >= NLMSGHDR_SIZE +    struct nlmsghdr { +               __u32 nlmsg_len;    /* Length of message including header */ +               __u16 nlmsg_type;   /* Type of message content */ +               __u16 nlmsg_flags;  /* Additional flags */ +               __u32 nlmsg_seq;    /* Sequence number */ +               __u32 nlmsg_pid;    /* Sender port ID */ +    }; +    ''' +    assert (data is not None), ("data is none") +    assert (len(data) >= NLMSGHDR_SIZE), ( +        "data is smaller than netlink message header") +    msg_len, msg_type, flags, seq, pid = struct.unpack(NLMSGHDR_FMT, +                                                       data[:MSG_TYPE_OFFSET]) +    LOG.debug("Got netlink msg of type %d", msg_type) +    return NetlinkHeader(msg_len, msg_type, flags, seq, pid) + + +def read_netlink_socket(netlink_socket, timeout=None): +    '''Select and read from the netlink socket if ready. + +    :param: netlink_socket: specify which socket object to read from +    :param: timeout: specify a timeout value (integer) to wait while reading, +            if none, it will block indefinitely until socket ready for read +    :returns: string of data read (max length = <MAX_SIZE>) from socket, +              if no data read, returns None +    :raises: AssertionError if netlink_socket is None +    ''' +    assert (netlink_socket is not None), ("netlink socket is none") +    read_set, _, _ = select.select([netlink_socket], [], [], timeout) +    # Incase of timeout,read_set doesn't contain netlink socket. +    # just return from this function +    if netlink_socket not in read_set: +        return None +    LOG.debug("netlink socket ready for read") +    data = netlink_socket.recv(MAX_SIZE) +    if data is None: +        LOG.error("Reading from Netlink socket returned no data") +    return data + + +def unpack_rta_attr(data, offset): +    '''Unpack a single rta attribute. + +    :param: data: string of data read from netlink socket +    :param: offset: starting offset of RTA Attribute +    :return: RTAAttr object with length, type and data. On error, return None. +    :raises: AssertionError if data is None or offset is not integer. +    ''' +    assert (data is not None), ("data is none") +    assert (type(offset) == int), ("offset is not integer") +    assert (offset >= RTATTR_START_OFFSET), ( +        "rta offset is less than expected length") +    length = rta_type = 0 +    attr_data = None +    try: +        length = struct.unpack_from("H", data, offset=offset)[0] +        rta_type = struct.unpack_from("H", data, offset=offset+2)[0] +    except struct.error: +        return None  # Should mean our offset is >= remaining data + +    # Unpack just the attribute's data. Offset by 4 to skip length/type header +    attr_data = data[offset+RTA_DATA_START_OFFSET:offset+length] +    return RTAAttr(length, rta_type, attr_data) + + +def read_rta_oper_state(data): +    '''Reads Interface name and operational state from RTA Data. + +    :param: data: string of data read from netlink socket +    :returns: InterfaceOperstate object containing if_name and oper_state. +              None if data does not contain valid IFLA_OPERSTATE and +              IFLA_IFNAME messages. +    :raises: AssertionError if data is None or length of data is +             smaller than RTATTR_START_OFFSET. +    ''' +    assert (data is not None), ("data is none") +    assert (len(data) > RTATTR_START_OFFSET), ( +        "length of data is smaller than RTATTR_START_OFFSET") +    ifname = operstate = None +    offset = RTATTR_START_OFFSET +    while offset <= len(data): +        attr = unpack_rta_attr(data, offset) +        if not attr or attr.length == 0: +            break +        # Each attribute is 4-byte aligned. Determine pad length. +        padlen = (PAD_ALIGNMENT - +                  (attr.length % PAD_ALIGNMENT)) % PAD_ALIGNMENT +        offset += attr.length + padlen + +        if attr.rta_type == IFLA_OPERSTATE: +            operstate = ord(attr.data) +        elif attr.rta_type == IFLA_IFNAME: +            interface_name = util.decode_binary(attr.data, 'utf-8') +            ifname = interface_name.strip('\0') +    if not ifname or operstate is None: +        return None +    LOG.debug("rta attrs: ifname %s operstate %d", ifname, operstate) +    return InterfaceOperstate(ifname, operstate) + + +def wait_for_media_disconnect_connect(netlink_socket, ifname): +    '''Block until media disconnect and connect has happened on an interface. +    Listens on netlink socket to receive netlink events and when the carrier +    changes from 0 to 1, it considers event has happened and +    return from this function + +    :param: netlink_socket: netlink_socket to receive events +    :param: ifname: Interface name to lookout for netlink events +    :raises: AssertionError if netlink_socket is None or ifname is None. +    ''' +    assert (netlink_socket is not None), ("netlink socket is none") +    assert (ifname is not None), ("interface name is none") +    assert (len(ifname) > 0), ("interface name cannot be empty") +    carrier = OPER_UP +    prevCarrier = OPER_UP +    data = bytes() +    LOG.debug("Wait for media disconnect and reconnect to happen") +    while True: +        recv_data = read_netlink_socket(netlink_socket, SELECT_TIMEOUT) +        if recv_data is None: +            continue +        LOG.debug('read %d bytes from socket', len(recv_data)) +        data += recv_data +        LOG.debug('Length of data after concat %d', len(data)) +        offset = 0 +        datalen = len(data) +        while offset < datalen: +            nl_msg = data[offset:] +            if len(nl_msg) < NLMSGHDR_SIZE: +                LOG.debug("Data is smaller than netlink header") +                break +            nlheader = get_netlink_msg_header(nl_msg) +            if len(nl_msg) < nlheader.length: +                LOG.debug("Partial data. Smaller than netlink message") +                break +            padlen = (nlheader.length+PAD_ALIGNMENT-1) & ~(PAD_ALIGNMENT-1) +            offset = offset + padlen +            LOG.debug('offset to next netlink message: %d', offset) +            # Ignore any messages not new link or del link +            if nlheader.type not in [RTM_NEWLINK, RTM_DELLINK]: +                continue +            interface_state = read_rta_oper_state(nl_msg) +            if interface_state is None: +                LOG.debug('Failed to read rta attributes: %s', interface_state) +                continue +            if interface_state.ifname != ifname: +                LOG.debug( +                    "Ignored netlink event on interface %s. Waiting for %s.", +                    interface_state.ifname, ifname) +                continue +            if interface_state.operstate not in [OPER_UP, OPER_DOWN]: +                continue +            prevCarrier = carrier +            carrier = interface_state.operstate +            # check for carrier down, up sequence +            isVnetSwitch = (prevCarrier == OPER_DOWN) and (carrier == OPER_UP) +            if isVnetSwitch: +                LOG.debug("Media switch happened on %s.", ifname) +                return +        data = data[offset:] + +# vi: ts=4 expandtab diff --git a/cloudinit/sources/helpers/tests/test_netlink.py b/cloudinit/sources/helpers/tests/test_netlink.py new file mode 100644 index 00000000..c2898a16 --- /dev/null +++ b/cloudinit/sources/helpers/tests/test_netlink.py @@ -0,0 +1,373 @@ +# Author: Tamilmani Manoharan <tamanoha@microsoft.com> +# +# This file is part of cloud-init. See LICENSE file for license information. + +from cloudinit.tests.helpers import CiTestCase, mock +import socket +import struct +import codecs +from cloudinit.sources.helpers.netlink import ( +    NetlinkCreateSocketError, create_bound_netlink_socket, read_netlink_socket, +    read_rta_oper_state, unpack_rta_attr, wait_for_media_disconnect_connect, +    OPER_DOWN, OPER_UP, OPER_DORMANT, OPER_LOWERLAYERDOWN, OPER_NOTPRESENT, +    OPER_TESTING, OPER_UNKNOWN, RTATTR_START_OFFSET, RTM_NEWLINK, RTM_SETLINK, +    RTM_GETLINK, MAX_SIZE) + + +def int_to_bytes(i): +    '''convert integer to binary: eg: 1 to \x01''' +    hex_value = '{0:x}'.format(i) +    hex_value = '0' * (len(hex_value) % 2) + hex_value +    return codecs.decode(hex_value, 'hex_codec') + + +class TestCreateBoundNetlinkSocket(CiTestCase): + +    @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +    def test_socket_error_on_create(self, m_socket): +        '''create_bound_netlink_socket catches socket creation exception''' + +        """NetlinkCreateSocketError is raised when socket creation errors.""" +        m_socket.side_effect = socket.error("Fake socket failure") +        with self.assertRaises(NetlinkCreateSocketError) as ctx_mgr: +            create_bound_netlink_socket() +        self.assertEqual( +            'Exception during netlink socket create: Fake socket failure', +            str(ctx_mgr.exception)) + + +class TestReadNetlinkSocket(CiTestCase): + +    @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +    @mock.patch('cloudinit.sources.helpers.netlink.select.select') +    def test_read_netlink_socket(self, m_select, m_socket): +        '''read_netlink_socket able to receive data''' +        data = 'netlinktest' +        m_select.return_value = [m_socket], None, None +        m_socket.recv.return_value = data +        recv_data = read_netlink_socket(m_socket, 2) +        m_select.assert_called_with([m_socket], [], [], 2) +        m_socket.recv.assert_called_with(MAX_SIZE) +        self.assertIsNotNone(recv_data) +        self.assertEqual(recv_data, data) + +    @mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +    @mock.patch('cloudinit.sources.helpers.netlink.select.select') +    def test_netlink_read_timeout(self, m_select, m_socket): +        '''read_netlink_socket should timeout if nothing to read''' +        m_select.return_value = [], None, None +        data = read_netlink_socket(m_socket, 1) +        m_select.assert_called_with([m_socket], [], [], 1) +        self.assertEqual(m_socket.recv.call_count, 0) +        self.assertIsNone(data) + +    def test_read_invalid_socket(self): +        '''read_netlink_socket raises assert error if socket is invalid''' +        socket = None +        with self.assertRaises(AssertionError) as context: +            read_netlink_socket(socket, 1) +        self.assertTrue('netlink socket is none' in str(context.exception)) + + +class TestParseNetlinkMessage(CiTestCase): + +    def test_read_rta_oper_state(self): +        '''read_rta_oper_state could parse netlink message and extract data''' +        ifname = "eth0" +        bytes = ifname.encode("utf-8") +        buf = bytearray(48) +        struct.pack_into("HH4sHHc", buf, RTATTR_START_OFFSET, 8, 3, bytes, 5, +                         16, int_to_bytes(OPER_DOWN)) +        interface_state = read_rta_oper_state(buf) +        self.assertEqual(interface_state.ifname, ifname) +        self.assertEqual(interface_state.operstate, OPER_DOWN) + +    def test_read_none_data(self): +        '''read_rta_oper_state raises assert error if data is none''' +        data = None +        with self.assertRaises(AssertionError) as context: +            read_rta_oper_state(data) +        self.assertTrue('data is none', str(context.exception)) + +    def test_read_invalid_rta_operstate_none(self): +        '''read_rta_oper_state returns none if operstate is none''' +        ifname = "eth0" +        buf = bytearray(40) +        bytes = ifname.encode("utf-8") +        struct.pack_into("HH4s", buf, RTATTR_START_OFFSET, 8, 3, bytes) +        interface_state = read_rta_oper_state(buf) +        self.assertIsNone(interface_state) + +    def test_read_invalid_rta_ifname_none(self): +        '''read_rta_oper_state returns none if ifname is none''' +        buf = bytearray(40) +        struct.pack_into("HHc", buf, RTATTR_START_OFFSET, 5, 16, +                         int_to_bytes(OPER_DOWN)) +        interface_state = read_rta_oper_state(buf) +        self.assertIsNone(interface_state) + +    def test_read_invalid_data_len(self): +        '''raise assert error if data size is smaller than required size''' +        buf = bytearray(32) +        with self.assertRaises(AssertionError) as context: +            read_rta_oper_state(buf) +        self.assertTrue('length of data is smaller than RTATTR_START_OFFSET' in +                        str(context.exception)) + +    def test_unpack_rta_attr_none_data(self): +        '''unpack_rta_attr raises assert error if data is none''' +        data = None +        with self.assertRaises(AssertionError) as context: +            unpack_rta_attr(data, RTATTR_START_OFFSET) +        self.assertTrue('data is none' in str(context.exception)) + +    def test_unpack_rta_attr_invalid_offset(self): +        '''unpack_rta_attr raises assert error if offset is invalid''' +        data = bytearray(48) +        with self.assertRaises(AssertionError) as context: +            unpack_rta_attr(data, "offset") +        self.assertTrue('offset is not integer' in str(context.exception)) +        with self.assertRaises(AssertionError) as context: +            unpack_rta_attr(data, 31) +        self.assertTrue('rta offset is less than expected length' in +                        str(context.exception)) + + +@mock.patch('cloudinit.sources.helpers.netlink.socket.socket') +@mock.patch('cloudinit.sources.helpers.netlink.read_netlink_socket') +class TestWaitForMediaDisconnectConnect(CiTestCase): +    with_logs = True + +    def _media_switch_data(self, ifname, msg_type, operstate): +        '''construct netlink data with specified fields''' +        if ifname and operstate is not None: +            data = bytearray(48) +            bytes = ifname.encode("utf-8") +            struct.pack_into("HH4sHHc", data, RTATTR_START_OFFSET, 8, 3, +                             bytes, 5, 16, int_to_bytes(operstate)) +        elif ifname: +            data = bytearray(40) +            bytes = ifname.encode("utf-8") +            struct.pack_into("HH4s", data, RTATTR_START_OFFSET, 8, 3, bytes) +        elif operstate: +            data = bytearray(40) +            struct.pack_into("HHc", data, RTATTR_START_OFFSET, 5, 16, +                             int_to_bytes(operstate)) +        struct.pack_into("=LHHLL", data, 0, len(data), msg_type, 0, 0, 0) +        return data + +    def test_media_down_up_scenario(self, m_read_netlink_socket, +                                    m_socket): +        '''Test for media down up sequence for required interface name''' +        ifname = "eth0" +        # construct data for Oper State down +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        # construct data for Oper State up +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_op_down, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 2) + +    def test_wait_for_media_switch_diff_interface(self, m_read_netlink_socket, +                                                  m_socket): +        '''wait_for_media_disconnect_connect ignores unexpected interfaces. + +        The first two messages are for other interfaces and last two are for +        expected interface. So the function exit only after receiving last +        2 messages and therefore the call count for m_read_netlink_socket +        has to be 4 +        ''' +        other_ifname = "eth1" +        expected_ifname = "eth0" +        data_op_down_eth1 = self._media_switch_data( +                                other_ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up_eth1 = self._media_switch_data( +                                other_ifname, RTM_NEWLINK, OPER_UP) +        data_op_down_eth0 = self._media_switch_data( +                                expected_ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up_eth0 = self._media_switch_data( +                                expected_ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_op_down_eth1, +                                             data_op_up_eth1, +                                             data_op_down_eth0, +                                             data_op_up_eth0] +        wait_for_media_disconnect_connect(m_socket, expected_ifname) +        self.assertIn('Ignored netlink event on interface %s' % other_ifname, +                      self.logs.getvalue()) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_invalid_msgtype_getlink(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect ignores GETLINK events. + +        The first two messages are for oper down and up for RTM_GETLINK type +        which netlink module will ignore. The last 2 messages are RTM_NEWLINK +        with oper state down and up messages. Therefore the call count for +        m_read_netlink_socket has to be 4 ignoring first 2 messages +        of RTM_GETLINK +        ''' +        ifname = "eth0" +        data_getlink_down = self._media_switch_data( +                                    ifname, RTM_GETLINK, OPER_DOWN) +        data_getlink_up = self._media_switch_data( +                                    ifname, RTM_GETLINK, OPER_UP) +        data_newlink_down = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_DOWN) +        data_newlink_up = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_getlink_down, +                                             data_getlink_up, +                                             data_newlink_down, +                                             data_newlink_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_invalid_msgtype_setlink(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect ignores SETLINK events. + +        The first two messages are for oper down and up for RTM_GETLINK type +        which it will ignore. 3rd and 4th messages are RTM_NEWLINK with down +        and up messages. This function should exit after 4th messages since it +        sees down->up scenario. So the call count for m_read_netlink_socket +        has to be 4 ignoring first 2 messages of RTM_GETLINK and +        last 2 messages of RTM_NEWLINK +        ''' +        ifname = "eth0" +        data_setlink_down = self._media_switch_data( +                                    ifname, RTM_SETLINK, OPER_DOWN) +        data_setlink_up = self._media_switch_data( +                                    ifname, RTM_SETLINK, OPER_UP) +        data_newlink_down = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_DOWN) +        data_newlink_up = self._media_switch_data( +                                    ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_setlink_down, +                                             data_setlink_up, +                                             data_newlink_down, +                                             data_newlink_up, +                                             data_newlink_down, +                                             data_newlink_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_netlink_invalid_switch_scenario(self, m_read_netlink_socket, +                                             m_socket): +        '''returns only if it receives UP event after a DOWN event''' +        ifname = "eth0" +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        data_op_dormant = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_DORMANT) +        data_op_notpresent = self._media_switch_data(ifname, RTM_NEWLINK, +                                                     OPER_NOTPRESENT) +        data_op_lowerdown = self._media_switch_data(ifname, RTM_NEWLINK, +                                                    OPER_LOWERLAYERDOWN) +        data_op_testing = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_TESTING) +        data_op_unknown = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_UNKNOWN) +        m_read_netlink_socket.side_effect = [data_op_up, data_op_up, +                                             data_op_dormant, data_op_up, +                                             data_op_notpresent, data_op_up, +                                             data_op_lowerdown, data_op_up, +                                             data_op_testing, data_op_up, +                                             data_op_unknown, data_op_up, +                                             data_op_down, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 14) + +    def test_netlink_valid_inbetween_transitions(self, m_read_netlink_socket, +                                                 m_socket): +        '''wait_for_media_disconnect_connect handles in between transitions''' +        ifname = "eth0" +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        data_op_dormant = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_DORMANT) +        data_op_unknown = self._media_switch_data(ifname, RTM_NEWLINK, +                                                  OPER_UNKNOWN) +        m_read_netlink_socket.side_effect = [data_op_down, data_op_dormant, +                                             data_op_unknown, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_netlink_invalid_operstate(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect should handle invalid operstates. + +        The function should not fail and return even if it receives invalid +        operstates. It always should wait for down up sequence. +        ''' +        ifname = "eth0" +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        data_op_invalid = self._media_switch_data(ifname, RTM_NEWLINK, 7) +        m_read_netlink_socket.side_effect = [data_op_invalid, data_op_up, +                                             data_op_down, data_op_invalid, +                                             data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 5) + +    def test_wait_invalid_socket(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect handle none netlink socket.''' +        socket = None +        ifname = "eth0" +        with self.assertRaises(AssertionError) as context: +            wait_for_media_disconnect_connect(socket, ifname) +        self.assertTrue('netlink socket is none' in str(context.exception)) + +    def test_wait_invalid_ifname(self, m_read_netlink_socket, m_socket): +        '''wait_for_media_disconnect_connect handle none interface name''' +        ifname = None +        with self.assertRaises(AssertionError) as context: +            wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertTrue('interface name is none' in str(context.exception)) +        ifname = "" +        with self.assertRaises(AssertionError) as context: +            wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertTrue('interface name cannot be empty' in +                        str(context.exception)) + +    def test_wait_invalid_rta_attr(self, m_read_netlink_socket, m_socket): +        ''' wait_for_media_disconnect_connect handles invalid rta data''' +        ifname = "eth0" +        data_invalid1 = self._media_switch_data(None, RTM_NEWLINK, OPER_DOWN) +        data_invalid2 = self._media_switch_data(ifname, RTM_NEWLINK, None) +        data_op_down = self._media_switch_data(ifname, RTM_NEWLINK, OPER_DOWN) +        data_op_up = self._media_switch_data(ifname, RTM_NEWLINK, OPER_UP) +        m_read_netlink_socket.side_effect = [data_invalid1, data_invalid2, +                                             data_op_down, data_op_up] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 4) + +    def test_read_multiple_netlink_msgs(self, m_read_netlink_socket, m_socket): +        '''Read multiple messages in single receive call''' +        ifname = "eth0" +        bytes = ifname.encode("utf-8") +        data = bytearray(96) +        struct.pack_into("=LHHLL", data, 0, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data, RTATTR_START_OFFSET, 8, 3, +                         bytes, 5, 16, int_to_bytes(OPER_DOWN)) +        struct.pack_into("=LHHLL", data, 48, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data, 48 + RTATTR_START_OFFSET, 8, +                         3, bytes, 5, 16, int_to_bytes(OPER_UP)) +        m_read_netlink_socket.return_value = data +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 1) + +    def test_read_partial_netlink_msgs(self, m_read_netlink_socket, m_socket): +        '''Read partial messages in receive call''' +        ifname = "eth0" +        bytes = ifname.encode("utf-8") +        data1 = bytearray(112) +        data2 = bytearray(32) +        struct.pack_into("=LHHLL", data1, 0, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data1, RTATTR_START_OFFSET, 8, 3, +                         bytes, 5, 16, int_to_bytes(OPER_DOWN)) +        struct.pack_into("=LHHLL", data1, 48, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data1, 80, 8, 3, bytes, 5, 16, +                         int_to_bytes(OPER_DOWN)) +        struct.pack_into("=LHHLL", data1, 96, 48, RTM_NEWLINK, 0, 0, 0) +        struct.pack_into("HH4sHHc", data2, 16, 8, 3, bytes, 5, 16, +                         int_to_bytes(OPER_UP)) +        m_read_netlink_socket.side_effect = [data1, data2] +        wait_for_media_disconnect_connect(m_socket, ifname) +        self.assertEqual(m_read_netlink_socket.call_count, 2) | 
