From 83b8aebb19fe6e49e13a05d4e8f5ab9a06177642 Mon Sep 17 00:00:00 2001 From: Yves-Alexis Perez Date: Sat, 11 Apr 2015 22:03:59 +0200 Subject: Imported Upstream version 5.3.0 --- .../plugins/kernel_netlink/kernel_netlink_shared.c | 477 +++++++++++++++++---- 1 file changed, 405 insertions(+), 72 deletions(-) (limited to 'src/libhydra/plugins/kernel_netlink/kernel_netlink_shared.c') diff --git a/src/libhydra/plugins/kernel_netlink/kernel_netlink_shared.c b/src/libhydra/plugins/kernel_netlink/kernel_netlink_shared.c index b4cece720..a9adfe091 100644 --- a/src/libhydra/plugins/kernel_netlink/kernel_netlink_shared.c +++ b/src/libhydra/plugins/kernel_netlink/kernel_netlink_shared.c @@ -1,4 +1,6 @@ /* + * Copyright (C) 2014 Martin Willi + * Copyright (C) 2014 revosec AG * Copyright (C) 2008 Tobias Brunner * Hochschule fuer Technik Rapperswil * @@ -16,6 +18,7 @@ #include #include #include +#include #include #include @@ -23,6 +26,9 @@ #include #include +#include +#include +#include typedef struct private_netlink_socket_t private_netlink_socket_t; @@ -30,140 +36,447 @@ typedef struct private_netlink_socket_t private_netlink_socket_t; * Private variables and functions of netlink_socket_t class. */ struct private_netlink_socket_t { + /** * public part of the netlink_socket_t object. */ netlink_socket_t public; /** - * mutex to lock access to netlink socket + * mutex to lock access entries */ mutex_t *mutex; /** - * current sequence number for netlink request + * Netlink request entries currently active, uintptr_t seq => entry_t + */ + hashtable_t *entries; + + /** + * Current sequence number for Netlink requests */ - int seq; + refcount_t seq; /** * netlink socket */ int socket; + /** + * Netlink protocol + */ + int protocol; + /** * Enum names for Netlink messages */ enum_name_t *names; + + /** + * Timeout for Netlink replies, in ms + */ + u_int timeout; + + /** + * Number of times to repeat timed out queries + */ + u_int retries; + + /** + * Use parallel netlink queries + */ + bool parallel; + + /** + * Ignore errors potentially resulting from a retransmission + */ + bool ignore_retransmit_errors; }; /** - * Imported from kernel_netlink_ipsec.c + * #definable hook to simulate request message loss */ -extern enum_name_t *xfrm_msg_names; - -METHOD(netlink_socket_t, netlink_send, status_t, - private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out, - size_t *out_len) -{ - union { - struct nlmsghdr hdr; - u_char bytes[4096]; - } response; - struct sockaddr_nl addr; - chunk_t result = chunk_empty; - int len; +#ifdef NETLINK_MSG_LOSS_HOOK +bool NETLINK_MSG_LOSS_HOOK(struct nlmsghdr *msg); +#define msg_loss_hook(msg) NETLINK_MSG_LOSS_HOOK(msg) +#else +#define msg_loss_hook(msg) FALSE +#endif - this->mutex->lock(this->mutex); +/** + * Request entry the answer for a waiting thread is collected in + */ +typedef struct { + /** Condition variable thread is waiting */ + condvar_t *condvar; + /** Array of hdrs in a multi-message response, as struct nlmsghdr* */ + array_t *hdrs; + /** All response messages received? */ + bool complete; +} entry_t; - in->nlmsg_seq = ++this->seq; - in->nlmsg_pid = getpid(); +/** + * Clean up a thread waiting entry + */ +static void destroy_entry(entry_t *entry) +{ + entry->condvar->destroy(entry->condvar); + array_destroy_function(entry->hdrs, (void*)free, NULL); + free(entry); +} - memset(&addr, 0, sizeof(addr)); - addr.nl_family = AF_NETLINK; - addr.nl_pid = 0; - addr.nl_groups = 0; +/** + * Write a Netlink message to socket + */ +static bool write_msg(private_netlink_socket_t *this, struct nlmsghdr *msg) +{ + struct sockaddr_nl addr = { + .nl_family = AF_NETLINK, + }; + int len; - if (this->names) + if (msg_loss_hook(msg)) { - DBG3(DBG_KNL, "sending %N: %b", - this->names, in->nlmsg_type, in, in->nlmsg_len); + return TRUE; } + while (TRUE) { - len = sendto(this->socket, in, in->nlmsg_len, 0, + len = sendto(this->socket, msg, msg->nlmsg_len, 0, (struct sockaddr*)&addr, sizeof(addr)); - - if (len != in->nlmsg_len) + if (len != msg->nlmsg_len) { if (errno == EINTR) { - /* interrupted, try again */ continue; } - this->mutex->unlock(this->mutex); - DBG1(DBG_KNL, "error sending to netlink socket: %s", strerror(errno)); - return FAILED; + DBG1(DBG_KNL, "netlink write error: %s", strerror(errno)); + return FALSE; } - break; + return TRUE; } +} - while (TRUE) +/** + * Read a single Netlink message from socket, return 0 on error, -1 on timeout + */ +static ssize_t read_msg(private_netlink_socket_t *this, + char buf[4096], size_t buflen, bool block) +{ + ssize_t len; + + if (block) { - len = recv(this->socket, &response, sizeof(response), 0); - if (len < 0) + fd_set set; + timeval_t tv = {}; + + FD_ZERO(&set); + FD_SET(this->socket, &set); + timeval_add_ms(&tv, this->timeout); + + if (select(this->socket + 1, &set, NULL, NULL, + this->timeout ? &tv : NULL) <= 0) { - if (errno == EINTR) + return -1; + } + } + len = recv(this->socket, buf, buflen, block ? 0 : MSG_DONTWAIT); + if (len == buflen) + { + DBG1(DBG_KNL, "netlink response exceeds buffer size"); + return 0; + } + if (len < 0) + { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) + { + DBG1(DBG_KNL, "netlink read error: %s", strerror(errno)); + } + return 0; + } + return len; +} + +/** + * Queue received response message + */ +static bool queue(private_netlink_socket_t *this, struct nlmsghdr *buf) +{ + struct nlmsghdr *hdr; + entry_t *entry; + uintptr_t seq; + + seq = (uintptr_t)buf->nlmsg_seq; + + this->mutex->lock(this->mutex); + entry = this->entries->get(this->entries, (void*)seq); + if (entry) + { + hdr = malloc(buf->nlmsg_len); + memcpy(hdr, buf, buf->nlmsg_len); + array_insert(entry->hdrs, ARRAY_TAIL, hdr); + if (hdr->nlmsg_type == NLMSG_DONE || !(hdr->nlmsg_flags & NLM_F_MULTI)) + { + entry->complete = TRUE; + entry->condvar->signal(entry->condvar); + } + } + else + { + DBG1(DBG_KNL, "received unknown netlink seq %u, ignored", seq); + } + this->mutex->unlock(this->mutex); + + return entry != NULL; +} + +/** + * Read and queue response message, optionally blocking, returns TRUE on timeout + */ +static bool read_and_queue(private_netlink_socket_t *this, bool block) +{ + struct nlmsghdr *hdr; + union { + struct nlmsghdr hdr; + char bytes[4096]; + } buf; + ssize_t len; + + len = read_msg(this, buf.bytes, sizeof(buf.bytes), block); + if (len == -1) + { + return TRUE; + } + if (len) + { + hdr = &buf.hdr; + while (NLMSG_OK(hdr, len)) + { + if (!queue(this, hdr)) { - DBG1(DBG_KNL, "got interrupted"); - /* interrupted, try again */ - continue; + break; } - DBG1(DBG_KNL, "error reading from netlink socket: %s", strerror(errno)); - this->mutex->unlock(this->mutex); - free(result.ptr); - return FAILED; + hdr = NLMSG_NEXT(hdr, len); } - if (!NLMSG_OK(&response.hdr, len)) + } + return FALSE; +} + +CALLBACK(watch, bool, + private_netlink_socket_t *this, int fd, watcher_event_t event) +{ + if (event == WATCHER_READ) + { + read_and_queue(this, FALSE); + } + return TRUE; +} + +/** + * Send a netlink request, try once + */ +static status_t send_once(private_netlink_socket_t *this, struct nlmsghdr *in, + uintptr_t seq, struct nlmsghdr **out, size_t *out_len) +{ + struct nlmsghdr *hdr; + chunk_t result = {}; + entry_t *entry; + + in->nlmsg_seq = seq; + in->nlmsg_pid = getpid(); + + if (this->names) + { + DBG3(DBG_KNL, "sending %N %u: %b", this->names, in->nlmsg_type, + (u_int)seq, in, in->nlmsg_len); + } + + this->mutex->lock(this->mutex); + if (!write_msg(this, in)) + { + this->mutex->unlock(this->mutex); + return FAILED; + } + + INIT(entry, + .condvar = condvar_create(CONDVAR_TYPE_DEFAULT), + .hdrs = array_create(0, 0), + ); + this->entries->put(this->entries, (void*)seq, entry); + + while (!entry->complete) + { + if (this->parallel && + lib->watcher->get_state(lib->watcher) == WATCHER_RUNNING) { - DBG1(DBG_KNL, "received corrupted netlink message"); - this->mutex->unlock(this->mutex); - free(result.ptr); - return FAILED; + if (this->timeout) + { + if (entry->condvar->timed_wait(entry->condvar, this->mutex, + this->timeout)) + { + break; + } + } + else + { + entry->condvar->wait(entry->condvar, this->mutex); + } } - if (response.hdr.nlmsg_seq != this->seq) - { - DBG1(DBG_KNL, "received invalid netlink sequence number"); - if (response.hdr.nlmsg_seq < this->seq) + else + { /* During (de-)initialization, no watcher thread is active. + * collect responses ourselves. */ + if (read_and_queue(this, TRUE)) { - continue; + break; } - this->mutex->unlock(this->mutex); - free(result.ptr); - return FAILED; } + } + this->entries->remove(this->entries, (void*)seq); - result = chunk_cat("mc", result, chunk_create(response.bytes, len)); + this->mutex->unlock(this->mutex); - /* NLM_F_MULTI flag does not seem to be set correctly, we use sequence - * numbers to detect multi header messages */ - len = recv(this->socket, &response.hdr, sizeof(response.hdr), - MSG_PEEK | MSG_DONTWAIT); - if (len == sizeof(response.hdr) && response.hdr.nlmsg_seq == this->seq) + if (!entry->complete) + { /* timeout */ + destroy_entry(entry); + return OUT_OF_RES; + } + + while (array_remove(entry->hdrs, ARRAY_HEAD, &hdr)) + { + if (this->names) { - /* seems to be multipart */ - continue; + DBG3(DBG_KNL, "received %N %u: %b", this->names, hdr->nlmsg_type, + hdr->nlmsg_seq, hdr, hdr->nlmsg_len); } - break; + result = chunk_cat("mm", result, + chunk_create((char*)hdr, hdr->nlmsg_len)); } + destroy_entry(entry); *out_len = result.len; *out = (struct nlmsghdr*)result.ptr; - this->mutex->unlock(this->mutex); - return SUCCESS; } +/** + * Ignore errors for message types that might have completed previously + */ +static void ignore_retransmit_error(private_netlink_socket_t *this, + struct nlmsgerr *err, int type) +{ + switch (err->error) + { + case -EEXIST: + switch (this->protocol) + { + case NETLINK_XFRM: + switch (type) + { + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_NEWSA: + err->error = 0; + break; + } + break; + case NETLINK_ROUTE: + switch (type) + { + case RTM_NEWADDR: + case RTM_NEWLINK: + case RTM_NEWNEIGH: + case RTM_NEWROUTE: + case RTM_NEWRULE: + err->error = 0; + break; + } + break; + } + break; + case -ENOENT: + switch (this->protocol) + { + case NETLINK_XFRM: + switch (type) + { + case XFRM_MSG_DELPOLICY: + case XFRM_MSG_DELSA: + err->error = 0; + break; + } + break; + case NETLINK_ROUTE: + switch (type) + { + case RTM_DELADDR: + case RTM_DELLINK: + case RTM_DELNEIGH: + case RTM_DELROUTE: + case RTM_DELRULE: + err->error = 0; + break; + } + break; + } + break; + } +} + +METHOD(netlink_socket_t, netlink_send, status_t, + private_netlink_socket_t *this, struct nlmsghdr *in, struct nlmsghdr **out, + size_t *out_len) +{ + uintptr_t seq; + u_int try; + + seq = ref_get(&this->seq); + + for (try = 0; try <= this->retries; ++try) + { + struct nlmsghdr *hdr; + status_t status; + size_t len; + + if (try > 0) + { + DBG1(DBG_KNL, "retransmitting Netlink request (%u/%u)", + try, this->retries); + } + status = send_once(this, in, seq, &hdr, &len); + switch (status) + { + case SUCCESS: + break; + case OUT_OF_RES: + continue; + default: + return status; + } + if (hdr->nlmsg_type == NLMSG_ERROR) + { + struct nlmsgerr* err; + + err = NLMSG_DATA(hdr); + if (err->error == -EBUSY) + { + free(hdr); + try--; + continue; + } + if (this->ignore_retransmit_errors && try > 0) + { + ignore_retransmit_error(this, err, in->nlmsg_type); + } + } + *out = hdr; + *out_len = len; + return SUCCESS; + } + DBG1(DBG_KNL, "Netlink request timed out after %u retransmits", + this->retries); + return OUT_OF_RES; +} + METHOD(netlink_socket_t, netlink_send_ack, status_t, private_netlink_socket_t *this, struct nlmsghdr *in) { @@ -221,8 +534,13 @@ METHOD(netlink_socket_t, destroy, void, { if (this->socket != -1) { + if (this->parallel) + { + lib->watcher->remove(lib->watcher, this->socket); + } close(this->socket); } + this->entries->destroy(this->entries); this->mutex->destroy(this->mutex); free(this); } @@ -230,7 +548,8 @@ METHOD(netlink_socket_t, destroy, void, /** * Described in header. */ -netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names) +netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names, + bool parallel) { private_netlink_socket_t *this; struct sockaddr_nl addr = { @@ -244,9 +563,19 @@ netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names) .destroy = _destroy, }, .seq = 200, - .mutex = mutex_create(MUTEX_TYPE_DEFAULT), + .mutex = mutex_create(MUTEX_TYPE_RECURSIVE), .socket = socket(AF_NETLINK, SOCK_RAW, protocol), + .entries = hashtable_create(hashtable_hash_ptr, hashtable_equals_ptr, 4), + .protocol = protocol, .names = names, + .timeout = lib->settings->get_int(lib->settings, + "%s.plugins.kernel-netlink.timeout", 0, lib->ns), + .retries = lib->settings->get_int(lib->settings, + "%s.plugins.kernel-netlink.retries", 0, lib->ns), + .ignore_retransmit_errors = lib->settings->get_bool(lib->settings, + "%s.plugins.kernel-netlink.ignore_retransmit_errors", + FALSE, lib->ns), + .parallel = parallel, ); if (this->socket == -1) @@ -261,6 +590,10 @@ netlink_socket_t *netlink_socket_create(int protocol, enum_name_t *names) destroy(this); return NULL; } + if (this->parallel) + { + lib->watcher->add(lib->watcher, this->socket, WATCHER_READ, watch, this); + } return &this->public; } -- cgit v1.2.3