diff options
Diffstat (limited to 'src/libhydra/kernel/kernel_interface.c')
-rw-r--r-- | src/libhydra/kernel/kernel_interface.c | 276 |
1 files changed, 263 insertions, 13 deletions
diff --git a/src/libhydra/kernel/kernel_interface.c b/src/libhydra/kernel/kernel_interface.c index 3fa28e054..ce31bd410 100644 --- a/src/libhydra/kernel/kernel_interface.c +++ b/src/libhydra/kernel/kernel_interface.c @@ -43,6 +43,8 @@ #include <utils/debug.h> #include <threading/mutex.h> #include <collections/linked_list.h> +#include <collections/hashtable.h> +#include <collections/array.h> typedef struct private_kernel_interface_t private_kernel_interface_t; @@ -115,6 +117,16 @@ struct private_kernel_interface_t { linked_list_t *listeners; /** + * Reqid entries indexed by reqids + */ + hashtable_t *reqids; + + /** + * Reqid entries indexed by traffic selectors + */ + hashtable_t *reqids_by_ts; + + /** * mutex for algorithm mappings */ mutex_t *mutex_algs; @@ -155,24 +167,252 @@ METHOD(kernel_interface_t, get_features, kernel_feature_t, METHOD(kernel_interface_t, get_spi, status_t, private_kernel_interface_t *this, host_t *src, host_t *dst, - u_int8_t protocol, u_int32_t reqid, u_int32_t *spi) + u_int8_t protocol, u_int32_t *spi) { if (!this->ipsec) { return NOT_SUPPORTED; } - return this->ipsec->get_spi(this->ipsec, src, dst, protocol, reqid, spi); + return this->ipsec->get_spi(this->ipsec, src, dst, protocol, spi); } METHOD(kernel_interface_t, get_cpi, status_t, private_kernel_interface_t *this, host_t *src, host_t *dst, - u_int32_t reqid, u_int16_t *cpi) + u_int16_t *cpi) { if (!this->ipsec) { return NOT_SUPPORTED; } - return this->ipsec->get_cpi(this->ipsec, src, dst, reqid, cpi); + return this->ipsec->get_cpi(this->ipsec, src, dst, cpi); +} + +/** + * Reqid mapping entry + */ +typedef struct { + /** allocated reqid */ + u_int32_t reqid; + /** references to this entry */ + u_int refs; + /** inbound mark used for SA */ + mark_t mark_in; + /** outbound mark used for SA */ + mark_t mark_out; + /** local traffic selectors */ + array_t *local; + /** remote traffic selectors */ + array_t *remote; +} reqid_entry_t; + +/** + * Destroy a reqid mapping entry + */ +static void reqid_entry_destroy(reqid_entry_t *entry) +{ + array_destroy_offset(entry->local, offsetof(traffic_selector_t, destroy)); + array_destroy_offset(entry->remote, offsetof(traffic_selector_t, destroy)); + free(entry); +} + +/** + * Hashtable hash function for reqid entries using reqid as key + */ +static u_int hash_reqid(reqid_entry_t *entry) +{ + return chunk_hash_inc(chunk_from_thing(entry->reqid), + chunk_hash_inc(chunk_from_thing(entry->mark_in), + chunk_hash(chunk_from_thing(entry->mark_out)))); +} + +/** + * Hashtable equals function for reqid entries using reqid as key + */ +static bool equals_reqid(reqid_entry_t *a, reqid_entry_t *b) +{ + return a->reqid == b->reqid && + a->mark_in.value == b->mark_in.value && + a->mark_in.mask == b->mark_in.mask && + a->mark_out.value == b->mark_out.value && + a->mark_out.mask == b->mark_out.mask; +} + +/** + * Hash an array of traffic selectors + */ +static u_int hash_ts_array(array_t *array, u_int hash) +{ + enumerator_t *enumerator; + traffic_selector_t *ts; + + enumerator = array_create_enumerator(array); + while (enumerator->enumerate(enumerator, &ts)) + { + hash = ts->hash(ts, hash); + } + enumerator->destroy(enumerator); + + return hash; +} + +/** + * Hashtable hash function for reqid entries using traffic selectors as key + */ +static u_int hash_reqid_by_ts(reqid_entry_t *entry) +{ + return hash_ts_array(entry->local, hash_ts_array(entry->remote, + chunk_hash_inc(chunk_from_thing(entry->mark_in), + chunk_hash(chunk_from_thing(entry->mark_out))))); +} + +/** + * Compare two array with traffic selectors for equality + */ +static bool ts_array_equals(array_t *a, array_t *b) +{ + traffic_selector_t *tsa, *tsb; + enumerator_t *ae, *be; + bool equal = TRUE; + + if (array_count(a) != array_count(b)) + { + return FALSE; + } + + ae = array_create_enumerator(a); + be = array_create_enumerator(b); + while (equal && ae->enumerate(ae, &tsa) && be->enumerate(be, &tsb)) + { + equal = tsa->equals(tsa, tsb); + } + ae->destroy(ae); + be->destroy(be); + + return equal; +} + +/** + * Hashtable equals function for reqid entries using traffic selectors as key + */ +static bool equals_reqid_by_ts(reqid_entry_t *a, reqid_entry_t *b) +{ + return ts_array_equals(a->local, b->local) && + ts_array_equals(a->remote, b->remote) && + a->mark_in.value == b->mark_in.value && + a->mark_in.mask == b->mark_in.mask && + a->mark_out.value == b->mark_out.value && + a->mark_out.mask == b->mark_out.mask; +} + +/** + * Create an array from copied traffic selector list items + */ +static array_t *array_from_ts_list(linked_list_t *list) +{ + enumerator_t *enumerator; + traffic_selector_t *ts; + array_t *array; + + array = array_create(0, 0); + + enumerator = list->create_enumerator(list); + while (enumerator->enumerate(enumerator, &ts)) + { + array_insert(array, ARRAY_TAIL, ts->clone(ts)); + } + enumerator->destroy(enumerator); + + return array; +} + +METHOD(kernel_interface_t, alloc_reqid, status_t, + private_kernel_interface_t *this, + linked_list_t *local_ts, linked_list_t *remote_ts, + mark_t mark_in, mark_t mark_out, u_int32_t *reqid) +{ + static u_int32_t counter = 0; + reqid_entry_t *entry = NULL, *tmpl; + status_t status = SUCCESS; + + INIT(tmpl, + .local = array_from_ts_list(local_ts), + .remote = array_from_ts_list(remote_ts), + .mark_in = mark_in, + .mark_out = mark_out, + .reqid = *reqid, + ); + + this->mutex->lock(this->mutex); + if (tmpl->reqid) + { + /* search by reqid if given */ + entry = this->reqids->get(this->reqids, tmpl); + } + if (entry) + { + /* we don't require a traffic selector match for explicit reqids, + * as we wan't to reuse a reqid for trap-triggered policies that + * got narrowed during negotiation. */ + reqid_entry_destroy(tmpl); + } + else + { + /* search by traffic selectors */ + entry = this->reqids_by_ts->get(this->reqids_by_ts, tmpl); + if (entry) + { + reqid_entry_destroy(tmpl); + } + else + { + /* none found, create a new entry, allocating a reqid */ + entry = tmpl; + entry->reqid = ++counter; + this->reqids_by_ts->put(this->reqids_by_ts, entry, entry); + this->reqids->put(this->reqids, entry, entry); + } + *reqid = entry->reqid; + } + entry->refs++; + this->mutex->unlock(this->mutex); + + return status; +} + +METHOD(kernel_interface_t, release_reqid, status_t, + private_kernel_interface_t *this, u_int32_t reqid, + mark_t mark_in, mark_t mark_out) +{ + reqid_entry_t *entry, tmpl = { + .reqid = reqid, + .mark_in = mark_in, + .mark_out = mark_out, + }; + + this->mutex->lock(this->mutex); + entry = this->reqids->remove(this->reqids, &tmpl); + if (entry) + { + if (--entry->refs == 0) + { + entry = this->reqids_by_ts->remove(this->reqids_by_ts, entry); + if (entry) + { + reqid_entry_destroy(entry); + } + } + else + { + this->reqids->put(this->reqids, entry, entry); + } + } + this->mutex->unlock(this->mutex); + + if (entry) + { + return SUCCESS; + } + return NOT_FOUND; } METHOD(kernel_interface_t, add_sa, status_t, @@ -181,8 +421,8 @@ METHOD(kernel_interface_t, add_sa, status_t, u_int32_t tfc, lifetime_cfg_t *lifetime, u_int16_t enc_alg, chunk_t enc_key, u_int16_t int_alg, chunk_t int_key, ipsec_mode_t mode, u_int16_t ipcomp, u_int16_t cpi, u_int32_t replay_window, - bool initiator, bool encap, bool esn, bool inbound, - traffic_selector_t *src_ts, traffic_selector_t *dst_ts) + bool initiator, bool encap, bool esn, bool inbound, bool update, + linked_list_t *src_ts, linked_list_t *dst_ts) { if (!this->ipsec) { @@ -191,7 +431,7 @@ METHOD(kernel_interface_t, add_sa, status_t, return this->ipsec->add_sa(this->ipsec, src, dst, spi, protocol, reqid, mark, tfc, lifetime, enc_alg, enc_key, int_alg, int_key, mode, ipcomp, cpi, replay_window, initiator, encap, esn, inbound, - src_ts, dst_ts); + update, src_ts, dst_ts); } METHOD(kernel_interface_t, update_sa, status_t, @@ -575,17 +815,18 @@ METHOD(kernel_interface_t, acquire, void, } METHOD(kernel_interface_t, expire, void, - private_kernel_interface_t *this, u_int32_t reqid, u_int8_t protocol, - u_int32_t spi, bool hard) + private_kernel_interface_t *this, u_int8_t protocol, u_int32_t spi, + host_t *dst, bool hard) { kernel_listener_t *listener; enumerator_t *enumerator; + this->mutex->lock(this->mutex); enumerator = this->listeners->create_enumerator(this->listeners); while (enumerator->enumerate(enumerator, &listener)) { if (listener->expire && - !listener->expire(listener, reqid, protocol, spi, hard)) + !listener->expire(listener, protocol, spi, dst, hard)) { this->listeners->remove_at(this->listeners, enumerator); } @@ -595,17 +836,18 @@ METHOD(kernel_interface_t, expire, void, } METHOD(kernel_interface_t, mapping, void, - private_kernel_interface_t *this, u_int32_t reqid, u_int32_t spi, - host_t *remote) + private_kernel_interface_t *this, u_int8_t protocol, u_int32_t spi, + host_t *dst, host_t *remote) { kernel_listener_t *listener; enumerator_t *enumerator; + this->mutex->lock(this->mutex); enumerator = this->listeners->create_enumerator(this->listeners); while (enumerator->enumerate(enumerator, &listener)) { if (listener->mapping && - !listener->mapping(listener, reqid, spi, remote)) + !listener->mapping(listener, protocol, spi, dst, remote)) { this->listeners->remove_at(this->listeners, enumerator); } @@ -733,6 +975,8 @@ METHOD(kernel_interface_t, destroy, void, DESTROY_IF(this->ipsec); DESTROY_IF(this->net); DESTROY_FUNCTION_IF(this->ifaces_filter, (void*)free); + this->reqids->destroy(this->reqids); + this->reqids_by_ts->destroy(this->reqids_by_ts); this->listeners->destroy(this->listeners); this->mutex->destroy(this->mutex); free(this); @@ -751,6 +995,8 @@ kernel_interface_t *kernel_interface_create() .get_features = _get_features, .get_spi = _get_spi, .get_cpi = _get_cpi, + .alloc_reqid = _alloc_reqid, + .release_reqid = _release_reqid, .add_sa = _add_sa, .update_sa = _update_sa, .query_sa = _query_sa, @@ -795,6 +1041,10 @@ kernel_interface_t *kernel_interface_create() .listeners = linked_list_create(), .mutex_algs = mutex_create(MUTEX_TYPE_DEFAULT), .algorithms = linked_list_create(), + .reqids = hashtable_create((hashtable_hash_t)hash_reqid, + (hashtable_equals_t)equals_reqid, 8), + .reqids_by_ts = hashtable_create((hashtable_hash_t)hash_reqid_by_ts, + (hashtable_equals_t)equals_reqid_by_ts, 8), ); ifaces = lib->settings->get_str(lib->settings, |