diff options
Diffstat (limited to 'src/libtls/tls_socket.c')
-rw-r--r-- | src/libtls/tls_socket.c | 272 |
1 files changed, 199 insertions, 73 deletions
diff --git a/src/libtls/tls_socket.c b/src/libtls/tls_socket.c index 75b714e30..4ba964000 100644 --- a/src/libtls/tls_socket.c +++ b/src/libtls/tls_socket.c @@ -42,14 +42,39 @@ struct private_tls_application_t { tls_application_t application; /** - * Chunk of data to send + * Output buffer to write to */ chunk_t out; /** - * Chunk of data received + * Number of bytes written to out + */ + size_t out_done; + + /** + * Input buffer to read to */ chunk_t in; + + /** + * Number of bytes read to in + */ + size_t in_done; + + /** + * Cached input data + */ + chunk_t cache; + + /** + * Bytes consumed in cache + */ + size_t cache_done; + + /** + * Close TLS connection? + */ + bool close; }; /** @@ -82,22 +107,44 @@ METHOD(tls_application_t, process, status_t, private_tls_application_t *this, bio_reader_t *reader) { chunk_t data; + size_t len; - if (!reader->read_data(reader, reader->remaining(reader), &data)) + if (this->close) { - return FAILED; + return SUCCESS; + } + len = min(reader->remaining(reader), this->in.len - this->in_done); + if (len) + { /* copy to read buffer as much as fits in */ + if (!reader->read_data(reader, len, &data)) + { + return FAILED; + } + memcpy(this->in.ptr + this->in_done, data.ptr, data.len); + this->in_done += data.len; + } + else + { /* read buffer is full, cache for next read */ + if (!reader->read_data(reader, reader->remaining(reader), &data)) + { + return FAILED; + } + this->cache = chunk_cat("mc", this->cache, data); } - this->in = chunk_cat("mc", this->in, data); return NEED_MORE; } METHOD(tls_application_t, build, status_t, private_tls_application_t *this, bio_writer_t *writer) { - if (this->out.len) + if (this->close) + { + return SUCCESS; + } + if (this->out.len > this->out_done) { writer->write_data(writer, this->out); - this->out = chunk_empty; + this->out_done = this->out.len; return NEED_MORE; } return INVALID_STATE; @@ -106,11 +153,12 @@ METHOD(tls_application_t, build, status_t, /** * TLS data exchange loop */ -static bool exchange(private_tls_socket_t *this, bool wr) +static bool exchange(private_tls_socket_t *this, bool wr, bool block) { char buf[CRYPTO_BUF_SIZE], *pos; - ssize_t len, out; - int round = 0; + ssize_t in, out; + size_t len; + int round = 0, flags; for (round = 0; TRUE; round++) { @@ -137,6 +185,8 @@ static bool exchange(private_tls_socket_t *this, bool wr) continue; case INVALID_STATE: break; + case SUCCESS: + return TRUE; default: return FALSE; } @@ -144,55 +194,97 @@ static bool exchange(private_tls_socket_t *this, bool wr) } if (wr) { - if (this->app.out.len == 0) + if (this->app.out_done == this->app.out.len) { /* all data written */ return TRUE; } } else { - if (this->app.in.len) - { /* some data received */ + if (this->app.in_done == this->app.in.len) + { /* buffer fully received */ return TRUE; } - if (round > 0) - { /* did some handshaking, return empty chunk to not block */ - return TRUE; + } + + flags = 0; + if (this->app.out_done == this->app.out.len) + { + if (!block || this->app.in_done) + { + flags |= MSG_DONTWAIT; } } - len = read(this->fd, buf, sizeof(buf)); - if (len <= 0) + in = recv(this->fd, buf, sizeof(buf), flags); + if (in < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) + { + if (this->app.in_done == 0) + { + /* reading, nothing got yet, and call would block */ + errno = EWOULDBLOCK; + this->app.in_done = -1; + } + return TRUE; + } return FALSE; } - if (this->tls->process(this->tls, buf, len) != NEED_MORE) + if (in == 0) + { /* EOF */ + return TRUE; + } + switch (this->tls->process(this->tls, buf, in)) { - return FALSE; + case NEED_MORE: + break; + case SUCCESS: + return TRUE; + default: + return FALSE; } } } -METHOD(tls_socket_t, read_, bool, - private_tls_socket_t *this, chunk_t *buf) +METHOD(tls_socket_t, read_, ssize_t, + private_tls_socket_t *this, void *buf, size_t len, bool block) { - if (exchange(this, FALSE)) + if (this->app.cache.len) { - *buf = this->app.in; - this->app.in = chunk_empty; - return TRUE; + size_t cache; + + cache = min(len, this->app.cache.len - this->app.cache_done); + memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache); + + this->app.cache_done += cache; + if (this->app.cache_done == this->app.cache.len) + { + chunk_free(&this->app.cache); + this->app.cache_done = 0; + } + return cache; } - return FALSE; + this->app.in.ptr = buf; + this->app.in.len = len; + this->app.in_done = 0; + if (exchange(this, FALSE, block)) + { + return this->app.in_done; + } + return -1; } -METHOD(tls_socket_t, write_, bool, - private_tls_socket_t *this, chunk_t buf) +METHOD(tls_socket_t, write_, ssize_t, + private_tls_socket_t *this, void *buf, size_t len) { - this->app.out = buf; - if (exchange(this, TRUE)) + this->app.out.ptr = buf; + this->app.out.len = len; + this->app.out_done = 0; + if (exchange(this, TRUE, FALSE)) { - return TRUE; + return this->app.out_done; } - return FALSE; + return -1; } METHOD(tls_socket_t, splice, bool, @@ -200,68 +292,85 @@ METHOD(tls_socket_t, splice, bool, { char buf[PLAIN_BUF_SIZE], *pos; fd_set set; - chunk_t data; - ssize_t len; - bool old; + ssize_t in, out; + bool old, plain_eof = FALSE, crypto_eof = FALSE; - while (TRUE) + while (!plain_eof && !crypto_eof) { FD_ZERO(&set); FD_SET(rfd, &set); FD_SET(this->fd, &set); old = thread_cancelability(TRUE); - len = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL); + in = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL); thread_cancelability(old); - if (len == -1) + if (in == -1) { DBG1(DBG_TLS, "TLS select error: %s", strerror(errno)); return FALSE; } - if (FD_ISSET(this->fd, &set)) + while (!plain_eof && FD_ISSET(this->fd, &set)) { - if (!read_(this, &data)) - { - DBG2(DBG_TLS, "TLS read error/disconnect"); - return TRUE; - } - pos = data.ptr; - while (data.len) + in = read_(this, buf, sizeof(buf), FALSE); + switch (in) { - len = write(wfd, pos, data.len); - if (len == -1) - { - free(data.ptr); - DBG1(DBG_TLS, "TLS plain write error: %s", strerror(errno)); - return FALSE; - } - data.len -= len; - pos += len; + case 0: + plain_eof = TRUE; + break; + case -1: + if (errno != EWOULDBLOCK) + { + DBG1(DBG_TLS, "TLS read error: %s", strerror(errno)); + return FALSE; + } + break; + default: + pos = buf; + while (in) + { + out = write(wfd, pos, in); + if (out == -1) + { + DBG1(DBG_TLS, "TLS plain write error: %s", + strerror(errno)); + return FALSE; + } + in -= out; + pos += out; + } + continue; } - free(data.ptr); + break; } - if (FD_ISSET(rfd, &set)) + if (!crypto_eof && FD_ISSET(rfd, &set)) { - len = read(rfd, buf, sizeof(buf)); - if (len > 0) - { - if (!write_(this, chunk_create(buf, len))) - { - DBG1(DBG_TLS, "TLS write error"); - return FALSE; - } - } - else + in = read(rfd, buf, sizeof(buf)); + switch (in) { - if (len < 0) - { + case 0: + crypto_eof = TRUE; + break; + case -1: DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno)); return FALSE; - } - return TRUE; + default: + pos = buf; + while (in) + { + out = write_(this, pos, in); + if (out == -1) + { + DBG1(DBG_TLS, "TLS write error"); + return FALSE; + } + in -= out; + pos += out; + } + break; } } } + return TRUE; } METHOD(tls_socket_t, get_fd, int, @@ -270,11 +379,26 @@ METHOD(tls_socket_t, get_fd, int, return this->fd; } +METHOD(tls_socket_t, get_server_id, identification_t*, + private_tls_socket_t *this) +{ + return this->tls->get_server_id(this->tls); +} + +METHOD(tls_socket_t, get_peer_id, identification_t*, + private_tls_socket_t *this) +{ + return this->tls->get_peer_id(this->tls); +} + METHOD(tls_socket_t, destroy, void, private_tls_socket_t *this) { + /* send a TLS close notify if not done yet */ + this->app.close = TRUE; + write_(this, NULL, 0); + free(this->app.cache.ptr); this->tls->destroy(this->tls); - free(this->app.in.ptr); free(this); } @@ -292,6 +416,8 @@ tls_socket_t *tls_socket_create(bool is_server, identification_t *server, .write = _write_, .splice = _splice, .get_fd = _get_fd, + .get_server_id = _get_server_id, + .get_peer_id = _get_peer_id, .destroy = _destroy, }, .app = { |