summaryrefslogtreecommitdiff
path: root/src/libhydra/kernel/kernel_interface.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libhydra/kernel/kernel_interface.c')
-rw-r--r--src/libhydra/kernel/kernel_interface.c276
1 files changed, 263 insertions, 13 deletions
diff --git a/src/libhydra/kernel/kernel_interface.c b/src/libhydra/kernel/kernel_interface.c
index 3fa28e054..ce31bd410 100644
--- a/src/libhydra/kernel/kernel_interface.c
+++ b/src/libhydra/kernel/kernel_interface.c
@@ -43,6 +43,8 @@
#include <utils/debug.h>
#include <threading/mutex.h>
#include <collections/linked_list.h>
+#include <collections/hashtable.h>
+#include <collections/array.h>
typedef struct private_kernel_interface_t private_kernel_interface_t;
@@ -115,6 +117,16 @@ struct private_kernel_interface_t {
linked_list_t *listeners;
/**
+ * Reqid entries indexed by reqids
+ */
+ hashtable_t *reqids;
+
+ /**
+ * Reqid entries indexed by traffic selectors
+ */
+ hashtable_t *reqids_by_ts;
+
+ /**
* mutex for algorithm mappings
*/
mutex_t *mutex_algs;
@@ -155,24 +167,252 @@ METHOD(kernel_interface_t, get_features, kernel_feature_t,
METHOD(kernel_interface_t, get_spi, status_t,
private_kernel_interface_t *this, host_t *src, host_t *dst,
- u_int8_t protocol, u_int32_t reqid, u_int32_t *spi)
+ u_int8_t protocol, u_int32_t *spi)
{
if (!this->ipsec)
{
return NOT_SUPPORTED;
}
- return this->ipsec->get_spi(this->ipsec, src, dst, protocol, reqid, spi);
+ return this->ipsec->get_spi(this->ipsec, src, dst, protocol, spi);
}
METHOD(kernel_interface_t, get_cpi, status_t,
private_kernel_interface_t *this, host_t *src, host_t *dst,
- u_int32_t reqid, u_int16_t *cpi)
+ u_int16_t *cpi)
{
if (!this->ipsec)
{
return NOT_SUPPORTED;
}
- return this->ipsec->get_cpi(this->ipsec, src, dst, reqid, cpi);
+ return this->ipsec->get_cpi(this->ipsec, src, dst, cpi);
+}
+
+/**
+ * Reqid mapping entry
+ */
+typedef struct {
+ /** allocated reqid */
+ u_int32_t reqid;
+ /** references to this entry */
+ u_int refs;
+ /** inbound mark used for SA */
+ mark_t mark_in;
+ /** outbound mark used for SA */
+ mark_t mark_out;
+ /** local traffic selectors */
+ array_t *local;
+ /** remote traffic selectors */
+ array_t *remote;
+} reqid_entry_t;
+
+/**
+ * Destroy a reqid mapping entry
+ */
+static void reqid_entry_destroy(reqid_entry_t *entry)
+{
+ array_destroy_offset(entry->local, offsetof(traffic_selector_t, destroy));
+ array_destroy_offset(entry->remote, offsetof(traffic_selector_t, destroy));
+ free(entry);
+}
+
+/**
+ * Hashtable hash function for reqid entries using reqid as key
+ */
+static u_int hash_reqid(reqid_entry_t *entry)
+{
+ return chunk_hash_inc(chunk_from_thing(entry->reqid),
+ chunk_hash_inc(chunk_from_thing(entry->mark_in),
+ chunk_hash(chunk_from_thing(entry->mark_out))));
+}
+
+/**
+ * Hashtable equals function for reqid entries using reqid as key
+ */
+static bool equals_reqid(reqid_entry_t *a, reqid_entry_t *b)
+{
+ return a->reqid == b->reqid &&
+ a->mark_in.value == b->mark_in.value &&
+ a->mark_in.mask == b->mark_in.mask &&
+ a->mark_out.value == b->mark_out.value &&
+ a->mark_out.mask == b->mark_out.mask;
+}
+
+/**
+ * Hash an array of traffic selectors
+ */
+static u_int hash_ts_array(array_t *array, u_int hash)
+{
+ enumerator_t *enumerator;
+ traffic_selector_t *ts;
+
+ enumerator = array_create_enumerator(array);
+ while (enumerator->enumerate(enumerator, &ts))
+ {
+ hash = ts->hash(ts, hash);
+ }
+ enumerator->destroy(enumerator);
+
+ return hash;
+}
+
+/**
+ * Hashtable hash function for reqid entries using traffic selectors as key
+ */
+static u_int hash_reqid_by_ts(reqid_entry_t *entry)
+{
+ return hash_ts_array(entry->local, hash_ts_array(entry->remote,
+ chunk_hash_inc(chunk_from_thing(entry->mark_in),
+ chunk_hash(chunk_from_thing(entry->mark_out)))));
+}
+
+/**
+ * Compare two array with traffic selectors for equality
+ */
+static bool ts_array_equals(array_t *a, array_t *b)
+{
+ traffic_selector_t *tsa, *tsb;
+ enumerator_t *ae, *be;
+ bool equal = TRUE;
+
+ if (array_count(a) != array_count(b))
+ {
+ return FALSE;
+ }
+
+ ae = array_create_enumerator(a);
+ be = array_create_enumerator(b);
+ while (equal && ae->enumerate(ae, &tsa) && be->enumerate(be, &tsb))
+ {
+ equal = tsa->equals(tsa, tsb);
+ }
+ ae->destroy(ae);
+ be->destroy(be);
+
+ return equal;
+}
+
+/**
+ * Hashtable equals function for reqid entries using traffic selectors as key
+ */
+static bool equals_reqid_by_ts(reqid_entry_t *a, reqid_entry_t *b)
+{
+ return ts_array_equals(a->local, b->local) &&
+ ts_array_equals(a->remote, b->remote) &&
+ a->mark_in.value == b->mark_in.value &&
+ a->mark_in.mask == b->mark_in.mask &&
+ a->mark_out.value == b->mark_out.value &&
+ a->mark_out.mask == b->mark_out.mask;
+}
+
+/**
+ * Create an array from copied traffic selector list items
+ */
+static array_t *array_from_ts_list(linked_list_t *list)
+{
+ enumerator_t *enumerator;
+ traffic_selector_t *ts;
+ array_t *array;
+
+ array = array_create(0, 0);
+
+ enumerator = list->create_enumerator(list);
+ while (enumerator->enumerate(enumerator, &ts))
+ {
+ array_insert(array, ARRAY_TAIL, ts->clone(ts));
+ }
+ enumerator->destroy(enumerator);
+
+ return array;
+}
+
+METHOD(kernel_interface_t, alloc_reqid, status_t,
+ private_kernel_interface_t *this,
+ linked_list_t *local_ts, linked_list_t *remote_ts,
+ mark_t mark_in, mark_t mark_out, u_int32_t *reqid)
+{
+ static u_int32_t counter = 0;
+ reqid_entry_t *entry = NULL, *tmpl;
+ status_t status = SUCCESS;
+
+ INIT(tmpl,
+ .local = array_from_ts_list(local_ts),
+ .remote = array_from_ts_list(remote_ts),
+ .mark_in = mark_in,
+ .mark_out = mark_out,
+ .reqid = *reqid,
+ );
+
+ this->mutex->lock(this->mutex);
+ if (tmpl->reqid)
+ {
+ /* search by reqid if given */
+ entry = this->reqids->get(this->reqids, tmpl);
+ }
+ if (entry)
+ {
+ /* we don't require a traffic selector match for explicit reqids,
+ * as we wan't to reuse a reqid for trap-triggered policies that
+ * got narrowed during negotiation. */
+ reqid_entry_destroy(tmpl);
+ }
+ else
+ {
+ /* search by traffic selectors */
+ entry = this->reqids_by_ts->get(this->reqids_by_ts, tmpl);
+ if (entry)
+ {
+ reqid_entry_destroy(tmpl);
+ }
+ else
+ {
+ /* none found, create a new entry, allocating a reqid */
+ entry = tmpl;
+ entry->reqid = ++counter;
+ this->reqids_by_ts->put(this->reqids_by_ts, entry, entry);
+ this->reqids->put(this->reqids, entry, entry);
+ }
+ *reqid = entry->reqid;
+ }
+ entry->refs++;
+ this->mutex->unlock(this->mutex);
+
+ return status;
+}
+
+METHOD(kernel_interface_t, release_reqid, status_t,
+ private_kernel_interface_t *this, u_int32_t reqid,
+ mark_t mark_in, mark_t mark_out)
+{
+ reqid_entry_t *entry, tmpl = {
+ .reqid = reqid,
+ .mark_in = mark_in,
+ .mark_out = mark_out,
+ };
+
+ this->mutex->lock(this->mutex);
+ entry = this->reqids->remove(this->reqids, &tmpl);
+ if (entry)
+ {
+ if (--entry->refs == 0)
+ {
+ entry = this->reqids_by_ts->remove(this->reqids_by_ts, entry);
+ if (entry)
+ {
+ reqid_entry_destroy(entry);
+ }
+ }
+ else
+ {
+ this->reqids->put(this->reqids, entry, entry);
+ }
+ }
+ this->mutex->unlock(this->mutex);
+
+ if (entry)
+ {
+ return SUCCESS;
+ }
+ return NOT_FOUND;
}
METHOD(kernel_interface_t, add_sa, status_t,
@@ -181,8 +421,8 @@ METHOD(kernel_interface_t, add_sa, status_t,
u_int32_t tfc, lifetime_cfg_t *lifetime, u_int16_t enc_alg, chunk_t enc_key,
u_int16_t int_alg, chunk_t int_key, ipsec_mode_t mode,
u_int16_t ipcomp, u_int16_t cpi, u_int32_t replay_window,
- bool initiator, bool encap, bool esn, bool inbound,
- traffic_selector_t *src_ts, traffic_selector_t *dst_ts)
+ bool initiator, bool encap, bool esn, bool inbound, bool update,
+ linked_list_t *src_ts, linked_list_t *dst_ts)
{
if (!this->ipsec)
{
@@ -191,7 +431,7 @@ METHOD(kernel_interface_t, add_sa, status_t,
return this->ipsec->add_sa(this->ipsec, src, dst, spi, protocol, reqid,
mark, tfc, lifetime, enc_alg, enc_key, int_alg, int_key, mode,
ipcomp, cpi, replay_window, initiator, encap, esn, inbound,
- src_ts, dst_ts);
+ update, src_ts, dst_ts);
}
METHOD(kernel_interface_t, update_sa, status_t,
@@ -575,17 +815,18 @@ METHOD(kernel_interface_t, acquire, void,
}
METHOD(kernel_interface_t, expire, void,
- private_kernel_interface_t *this, u_int32_t reqid, u_int8_t protocol,
- u_int32_t spi, bool hard)
+ private_kernel_interface_t *this, u_int8_t protocol, u_int32_t spi,
+ host_t *dst, bool hard)
{
kernel_listener_t *listener;
enumerator_t *enumerator;
+
this->mutex->lock(this->mutex);
enumerator = this->listeners->create_enumerator(this->listeners);
while (enumerator->enumerate(enumerator, &listener))
{
if (listener->expire &&
- !listener->expire(listener, reqid, protocol, spi, hard))
+ !listener->expire(listener, protocol, spi, dst, hard))
{
this->listeners->remove_at(this->listeners, enumerator);
}
@@ -595,17 +836,18 @@ METHOD(kernel_interface_t, expire, void,
}
METHOD(kernel_interface_t, mapping, void,
- private_kernel_interface_t *this, u_int32_t reqid, u_int32_t spi,
- host_t *remote)
+ private_kernel_interface_t *this, u_int8_t protocol, u_int32_t spi,
+ host_t *dst, host_t *remote)
{
kernel_listener_t *listener;
enumerator_t *enumerator;
+
this->mutex->lock(this->mutex);
enumerator = this->listeners->create_enumerator(this->listeners);
while (enumerator->enumerate(enumerator, &listener))
{
if (listener->mapping &&
- !listener->mapping(listener, reqid, spi, remote))
+ !listener->mapping(listener, protocol, spi, dst, remote))
{
this->listeners->remove_at(this->listeners, enumerator);
}
@@ -733,6 +975,8 @@ METHOD(kernel_interface_t, destroy, void,
DESTROY_IF(this->ipsec);
DESTROY_IF(this->net);
DESTROY_FUNCTION_IF(this->ifaces_filter, (void*)free);
+ this->reqids->destroy(this->reqids);
+ this->reqids_by_ts->destroy(this->reqids_by_ts);
this->listeners->destroy(this->listeners);
this->mutex->destroy(this->mutex);
free(this);
@@ -751,6 +995,8 @@ kernel_interface_t *kernel_interface_create()
.get_features = _get_features,
.get_spi = _get_spi,
.get_cpi = _get_cpi,
+ .alloc_reqid = _alloc_reqid,
+ .release_reqid = _release_reqid,
.add_sa = _add_sa,
.update_sa = _update_sa,
.query_sa = _query_sa,
@@ -795,6 +1041,10 @@ kernel_interface_t *kernel_interface_create()
.listeners = linked_list_create(),
.mutex_algs = mutex_create(MUTEX_TYPE_DEFAULT),
.algorithms = linked_list_create(),
+ .reqids = hashtable_create((hashtable_hash_t)hash_reqid,
+ (hashtable_equals_t)equals_reqid, 8),
+ .reqids_by_ts = hashtable_create((hashtable_hash_t)hash_reqid_by_ts,
+ (hashtable_equals_t)equals_reqid_by_ts, 8),
);
ifaces = lib->settings->get_str(lib->settings,