diff options
Diffstat (limited to 'ext/librethinkdbxx/src/connection.cc')
-rw-r--r-- | ext/librethinkdbxx/src/connection.cc | 434 |
1 files changed, 434 insertions, 0 deletions
diff --git a/ext/librethinkdbxx/src/connection.cc b/ext/librethinkdbxx/src/connection.cc new file mode 100644 index 00000000..53d106ec --- /dev/null +++ b/ext/librethinkdbxx/src/connection.cc @@ -0,0 +1,434 @@ +#include <sys/types.h> +#include <sys/socket.h> +#include <sys/select.h> + +#include <netdb.h> +#include <unistd.h> + +#include <algorithm> +#include <cstring> +#include <cinttypes> +#include <memory> + +#include "connection.h" +#include "connection_p.h" +#include "json_p.h" +#include "exceptions.h" +#include "term.h" +#include "cursor_p.h" + +#include "rapidjson-config.h" +#include "rapidjson/rapidjson.h" +#include "rapidjson/encodedstream.h" +#include "rapidjson/document.h" + +namespace RethinkDB { + +using QueryType = Protocol::Query::QueryType; + +// constants +const int debug_net = 0; +const uint32_t version_magic = + static_cast<uint32_t>(Protocol::VersionDummy::Version::V0_4); +const uint32_t json_magic = + static_cast<uint32_t>(Protocol::VersionDummy::Protocol::JSON); + +std::unique_ptr<Connection> connect(std::string host, int port, std::string auth_key) { + struct addrinfo hints; + memset(&hints, 0, sizeof hints); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + char port_str[16]; + snprintf(port_str, 16, "%d", port); + struct addrinfo *servinfo; + int ret = getaddrinfo(host.c_str(), port_str, &hints, &servinfo); + if (ret) throw Error("getaddrinfo: %s\n", gai_strerror(ret)); + + struct addrinfo *p; + Error error; + int sockfd; + for (p = servinfo; p != NULL; p = p->ai_next) { + sockfd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); + if (sockfd == -1) { + error = Error::from_errno("socket"); + continue; + } + + if (connect(sockfd, p->ai_addr, p->ai_addrlen) == -1) { + ::close(sockfd); + error = Error::from_errno("connect"); + continue; + } + + break; + } + + if (p == NULL) { + throw error; + } + + freeaddrinfo(servinfo); + + std::unique_ptr<ConnectionPrivate> conn_private(new ConnectionPrivate(sockfd)); + WriteLock writer(conn_private.get()); + { + size_t size = auth_key.size(); + char buf[12 + size]; + memcpy(buf, &version_magic, 4); + uint32_t n = size; + memcpy(buf + 4, &n, 4); + memcpy(buf + 8, auth_key.data(), size); + memcpy(buf + 8 + size, &json_magic, 4); + writer.send(buf, sizeof buf); + } + + ReadLock reader(conn_private.get()); + { + const size_t max_response_length = 1024; + char buf[max_response_length + 1]; + size_t len = reader.recv_cstring(buf, max_response_length); + if (len == max_response_length || strcmp(buf, "SUCCESS")) { + buf[len] = 0; + ::close(sockfd); + throw Error("Server rejected connection with message: %s", buf); + } + } + + return std::unique_ptr<Connection>(new Connection(conn_private.release())); +} + +Connection::Connection(ConnectionPrivate *dd) : d(dd) { } +Connection::~Connection() { + // close(); + if (d->guarded_sockfd >= 0) + ::close(d->guarded_sockfd); +} + +size_t ReadLock::recv_some(char* buf, size_t size, double wait) { + if (wait != FOREVER) { + while (true) { + fd_set readfds; + struct timeval tv; + + FD_ZERO(&readfds); + FD_SET(conn->guarded_sockfd, &readfds); + + tv.tv_sec = (int)wait; + tv.tv_usec = (int)((wait - (int)wait) / MICROSECOND); + int rv = select(conn->guarded_sockfd + 1, &readfds, NULL, NULL, &tv); + if (rv == -1) { + throw Error::from_errno("select"); + } else if (rv == 0) { + throw TimeoutException(); + } + + if (FD_ISSET(conn->guarded_sockfd, &readfds)) { + break; + } + } + } + + ssize_t numbytes = ::recv(conn->guarded_sockfd, buf, size, 0); + if (numbytes <= 0) throw Error::from_errno("recv"); + if (debug_net > 1) { + fprintf(stderr, "<< %s\n", write_datum(std::string(buf, numbytes)).c_str()); + } + + return numbytes; +} + +void ReadLock::recv(char* buf, size_t size, double wait) { + while (size) { + size_t numbytes = recv_some(buf, size, wait); + + buf += numbytes; + size -= numbytes; + } +} + +size_t ReadLock::recv_cstring(char* buf, size_t max_size){ + size_t size = 0; + for (; size < max_size; size++) { + recv(buf, 1, FOREVER); + if (*buf == 0) { + break; + } + buf++; + } + return size; +} + +void WriteLock::send(const char* buf, size_t size) { + while (size) { + ssize_t numbytes = ::write(conn->guarded_sockfd, buf, size); + if (numbytes == -1) throw Error::from_errno("write"); + if (debug_net > 1) { + fprintf(stderr, ">> %s\n", write_datum(std::string(buf, numbytes)).c_str()); + } + + buf += numbytes; + size -= numbytes; + } +} + +void WriteLock::send(const std::string data) { + send(data.data(), data.size()); +} + +std::string ReadLock::recv(size_t size) { + char buf[size]; + recv(buf, size, FOREVER); + return buf; +} + +void Connection::close() { + CacheLock guard(d.get()); + for (auto& it : d->guarded_cache) { + stop_query(it.first); + } + + int ret = ::close(d->guarded_sockfd); + if (ret == -1) { + throw Error::from_errno("close"); + } + d->guarded_sockfd = -1; +} + +Response ConnectionPrivate::wait_for_response(uint64_t token_want, double wait) { + CacheLock guard(this); + ConnectionPrivate::TokenCache& cache = guarded_cache[token_want]; + + while (true) { + if (!cache.responses.empty()) { + Response response(std::move(cache.responses.front())); + cache.responses.pop(); + if (cache.closed && cache.responses.empty()) { + guarded_cache.erase(token_want); + } + + return response; + } + + if (cache.closed) { + throw Error("Trying to read from a closed token"); + } + + if (guarded_loop_active) { + cache.cond.wait(guard.inner_lock); + } else { + break; + } + } + + ReadLock reader(this); + return reader.read_loop(token_want, std::move(guard), wait); +} + +Response ReadLock::read_loop(uint64_t token_want, CacheLock&& guard, double wait) { + if (!guard.inner_lock) { + guard.lock(); + } + if (conn->guarded_loop_active) { + throw Error("Cannot run more than one read loop on the same connection"); + } + conn->guarded_loop_active = true; + guard.unlock(); + + try { + while (true) { + char buf[12]; + bzero(buf, sizeof(buf)); + recv(buf, 12, wait); + uint64_t token_got; + memcpy(&token_got, buf, 8); + uint32_t length; + memcpy(&length, buf + 8, 4); + + std::unique_ptr<char[]> bufmem(new char[length + 1]); + char *buffer = bufmem.get(); + bzero(buffer, length + 1); + recv(buffer, length, wait); + buffer[length] = '\0'; + + rapidjson::Document json; + json.ParseInsitu(buffer); + if (json.HasParseError()) { + fprintf(stderr, "json parse error, code: %d, position: %d\n", + (int)json.GetParseError(), (int)json.GetErrorOffset()); + } else if (json.IsNull()) { + fprintf(stderr, "null value, read: %s\n", buffer); + } + + Datum datum = read_datum(json); + if (debug_net > 0) { + fprintf(stderr, "[%" PRIu64 "] << %s\n", token_got, write_datum(datum).c_str()); + } + + Response response(std::move(datum)); + + if (token_got == token_want) { + guard.lock(); + if (response.type != Protocol::Response::ResponseType::SUCCESS_PARTIAL) { + auto it = conn->guarded_cache.find(token_got); + if (it != conn->guarded_cache.end()) { + it->second.closed = true; + it->second.cond.notify_all(); + } + conn->guarded_cache.erase(it); + } + conn->guarded_loop_active = false; + for (auto& it : conn->guarded_cache) { + it.second.cond.notify_all(); + } + return response; + } else { + guard.lock(); + auto it = conn->guarded_cache.find(token_got); + if (it == conn->guarded_cache.end()) { + // drop the response + } else if (!it->second.closed) { + it->second.responses.emplace(std::move(response)); + if (response.type != Protocol::Response::ResponseType::SUCCESS_PARTIAL) { + it->second.closed = true; + } + } + it->second.cond.notify_all(); + guard.unlock(); + } + } + } catch (const TimeoutException &e) { + if (!guard.inner_lock){ + guard.lock(); + } + conn->guarded_loop_active = false; + throw e; + } +} + +void ConnectionPrivate::run_query(Query query, bool no_reply) { + WriteLock writer(this); + writer.send(query.serialize()); +} + +Cursor Connection::start_query(Term *term, OptArgs&& opts) { + bool no_reply = false; + auto it = opts.find("noreply"); + if (it != opts.end()) { + no_reply = *(it->second.datum.get_boolean()); + } + + uint64_t token = d->new_token(); + { + CacheLock guard(d.get()); + d->guarded_cache[token]; + } + + d->run_query(Query{QueryType::START, token, term->datum, std::move(opts)}); + if (no_reply) { + return Cursor(new CursorPrivate(token, this, Nil())); + } + + Cursor cursor(new CursorPrivate(token, this)); + Response response = d->wait_for_response(token, FOREVER); + cursor.d->add_response(std::move(response)); + return cursor; +} + +void Connection::stop_query(uint64_t token) { + const auto& it = d->guarded_cache.find(token); + if (it != d->guarded_cache.end() && !it->second.closed) { + d->run_query(Query{QueryType::STOP, token}, true); + } +} + +void Connection::continue_query(uint64_t token) { + d->run_query(Query{QueryType::CONTINUE, token}, true); +} + +Error Response::as_error() { + std::string repr; + if (result.size() == 1) { + std::string* string = result[0].get_string(); + if (string) { + repr = *string; + } else { + repr = write_datum(result[0]); + } + } else { + repr = write_datum(Datum(result)); + } + std::string err; + using RT = Protocol::Response::ResponseType; + using ET = Protocol::Response::ErrorType; + switch (type) { + case RT::SUCCESS_SEQUENCE: err = "unexpected response: SUCCESS_SEQUENCE"; break; + case RT::SUCCESS_PARTIAL: err = "unexpected response: SUCCESS_PARTIAL"; break; + case RT::SUCCESS_ATOM: err = "unexpected response: SUCCESS_ATOM"; break; + case RT::WAIT_COMPLETE: err = "unexpected response: WAIT_COMPLETE"; break; + case RT::SERVER_INFO: err = "unexpected response: SERVER_INFO"; break; + case RT::CLIENT_ERROR: err = "ReqlDriverError"; break; + case RT::COMPILE_ERROR: err = "ReqlCompileError"; break; + case RT::RUNTIME_ERROR: + switch (error_type) { + case ET::INTERNAL: err = "ReqlInternalError"; break; + case ET::RESOURCE_LIMIT: err = "ReqlResourceLimitError"; break; + case ET::QUERY_LOGIC: err = "ReqlQueryLogicError"; break; + case ET::NON_EXISTENCE: err = "ReqlNonExistenceError"; break; + case ET::OP_FAILED: err = "ReqlOpFailedError"; break; + case ET::OP_INDETERMINATE: err = "ReqlOpIndeterminateError"; break; + case ET::USER: err = "ReqlUserError"; break; + case ET::PERMISSION_ERROR: err = "ReqlPermissionError"; break; + default: err = "ReqlRuntimeError"; break; + } + } + throw Error("%s: %s", err.c_str(), repr.c_str()); +} + +Protocol::Response::ResponseType response_type(double t) { + int n = static_cast<int>(t); + using RT = Protocol::Response::ResponseType; + switch (n) { + case static_cast<int>(RT::SUCCESS_ATOM): + return RT::SUCCESS_ATOM; + case static_cast<int>(RT::SUCCESS_SEQUENCE): + return RT::SUCCESS_SEQUENCE; + case static_cast<int>(RT::SUCCESS_PARTIAL): + return RT::SUCCESS_PARTIAL; + case static_cast<int>(RT::WAIT_COMPLETE): + return RT::WAIT_COMPLETE; + case static_cast<int>(RT::CLIENT_ERROR): + return RT::CLIENT_ERROR; + case static_cast<int>(RT::COMPILE_ERROR): + return RT::COMPILE_ERROR; + case static_cast<int>(RT::RUNTIME_ERROR): + return RT::RUNTIME_ERROR; + default: + throw Error("Unknown response type"); + } +} + +Protocol::Response::ErrorType runtime_error_type(double t) { + int n = static_cast<int>(t); + using ET = Protocol::Response::ErrorType; + switch (n) { + case static_cast<int>(ET::INTERNAL): + return ET::INTERNAL; + case static_cast<int>(ET::RESOURCE_LIMIT): + return ET::RESOURCE_LIMIT; + case static_cast<int>(ET::QUERY_LOGIC): + return ET::QUERY_LOGIC; + case static_cast<int>(ET::NON_EXISTENCE): + return ET::NON_EXISTENCE; + case static_cast<int>(ET::OP_FAILED): + return ET::OP_FAILED; + case static_cast<int>(ET::OP_INDETERMINATE): + return ET::OP_INDETERMINATE; + case static_cast<int>(ET::USER): + return ET::USER; + default: + throw Error("Unknown error type"); + } +} + +} |