diff options
-rw-r--r-- | osnet/Wire.hpp | 77 | ||||
-rw-r--r-- | selftest.cpp | 134 |
2 files changed, 182 insertions, 29 deletions
diff --git a/osnet/Wire.hpp b/osnet/Wire.hpp index 7b910656..d8a90826 100644 --- a/osnet/Wire.hpp +++ b/osnet/Wire.hpp @@ -178,7 +178,8 @@ public: ON_TCP_CLOSE_FUNCTION tcpCloseHandler, ON_TCP_DATA_FUNCTION tcpDataHandler, ON_TCP_WRITABLE_FUNCTION tcpWritableHandler, - bool noDelay) : + bool noDelay + ) : _datagramHandler(datagramHandler), _tcpConnectHandler(tcpConnectHandler), _tcpAcceptHandler(tcpAcceptHandler), @@ -263,11 +264,11 @@ public: * Bind a UDP socket * * @param localAddress Local endpoint address and port - * @param uptr Initial value of user pointer associated with this socket - * @param bufferSize Desired socket receive/send buffer size -- will set as close to this as possible (0 to accept default) + * @param uptr Initial value of user pointer associated with this socket (default: NULL) + * @param bufferSize Desired socket receive/send buffer size -- will set as close to this as possible (default: 0, leave alone) * @return Socket or NULL on failure to bind */ - inline WireSocket *udpBind(const struct sockaddr *localAddress,void *uptr,int bufferSize) + inline WireSocket *udpBind(const struct sockaddr *localAddress,void *uptr = (void *)0,int bufferSize = 0) { if (_socks.size() >= ZT_WIRE_MAX_SOCKETS) return (WireSocket *)0; @@ -358,26 +359,25 @@ public: * Send a UDP packet * * @param sock UDP socket - * @param addr Destination address (must be correct type for socket) - * @param addrlen Length of sockaddr_X structure + * @param remoteAddress Destination address (must be correct type for socket) * @param data Data to send * @param len Length of packet * @return True if packet appears to have been sent successfully */ - inline bool udpSend(WireSocket *sock,const struct sockaddr *addr,unsigned int addrlen,WireSocket *data,unsigned long len) + inline bool udpSend(WireSocket *sock,const struct sockaddr *remoteAddress,const void *data,unsigned long len) { WireSocketImpl &sws = *(const_cast <WireSocketImpl *>(reinterpret_cast<const WireSocketImpl *>(sock))); - return ((long)::sendto(sws.sock,data,len,0,addr,(socklen_t)addrlen) == (long)len); + return ((long)::sendto(sws.sock,data,len,0,remoteAddress,(remoteAddress->sa_family == AF_INET6) ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in)) == (long)len); } /** * Bind a local listen socket to listen for new TCP connections * * @param localAddress Local address and port - * @param uptr Initial value of uptr for new socket + * @param uptr Initial value of uptr for new socket (default: NULL) * @return Socket or NULL on failure to bind */ - inline WireSocket *tcpListen(const struct sockaddr *localAddress,void *uptr) + inline WireSocket *tcpListen(const struct sockaddr *localAddress,void *uptr = (void *)0) { if (_socks.size() >= ZT_WIRE_MAX_SOCKETS) return (WireSocket *)0; @@ -438,30 +438,35 @@ public: /** * Start a non-blocking connect; CONNECT handler is called on success or failure * - * Note that if NULL is returned here, the handler is not called. Such - * a return would indicate failure to allocate the socket, too many - * open sockets, etc. + * A return value of NULL indicates a synchronous failure such as a + * failure to open a socket. The TCP connection handler is not called + * in this case. * - * Also note that an "instant connect" may occur for e.g. loopback - * connections. If this happens the 'connected' result paramter will - * be true. If callConnectHandlerOnInstantConnect is true, the - * TCP connect handler will be called before the function returns - * as well in this case. Otherwise it will not. + * It is possible on some platforms for an "instant connect" to occur, + * such as when connecting to a loopback address. In this case, the + * 'connected' result parameter will be set to 'true' and if the + * 'callConnectHandler' flag is true (the default) the TCP connect + * handler will be called before the function returns. + * + * These semantics can be a bit confusing, but they're less so than + * the underlying semantics of asynchronous TCP connect. * * @param remoteAddress Remote address - * @param uptr Initial value of uptr for new socket - * @param callConnectHandlerOnInstantConnect If true, call TCP connect handler now if an "instant connect" occurs - * @param connected Reference to result paramter set to true if "instant connect" occurs, false otherwise + * @param connected Result parameter: set to whether an "instant connect" has occurred (true if yes) + * @param uptr Initial value of uptr for new socket (default: NULL) + * @param callConnectHandler If true, call TCP connect handler even if result is known before function exit (default: true) * @return New socket or NULL on failure */ - inline WireSocket *tcpConnect(const struct sockaddr *remoteAddress,void *uptr,bool callConnectHandlerOnInstantConnect,bool &connected) + inline WireSocket *tcpConnect(const struct sockaddr *remoteAddress,bool &connected,void *uptr = (void *)0,bool callConnectHandler = true) { if (_socks.size() >= ZT_WIRE_MAX_SOCKETS) return (WireSocket *)0; ZT_WIRE_SOCKFD_TYPE s = ::socket(remoteAddress->sa_family,SOCK_STREAM,0); - if (!ZT_WIRE_SOCKFD_VALID(s)) + if (!ZT_WIRE_SOCKFD_VALID(s)) { + connected = false; return (WireSocket *)0; + } #if defined(_WIN32) || defined(_WIN64) { @@ -484,6 +489,7 @@ public: connected = true; if (::connect(s,remoteAddress,(remoteAddress->sa_family == AF_INET6) ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in))) { + connected = false; #if defined(_WIN32) || defined(_WIN64) if (WSAGetLastError() != WSAEWOULDBLOCK) { #else @@ -491,7 +497,7 @@ public: #endif ZT_WIRE_CLOSE_SOCKET(s); return (WireSocket *)0; - } else connected = false; + } // else connection is proceeding asynchronously... } try { @@ -519,9 +525,9 @@ public: memset(&(sws.saddr),0,sizeof(struct sockaddr_storage)); memcpy(&(sws.saddr),remoteAddress,(remoteAddress->sa_family == AF_INET6) ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in)); - if ((callConnectHandlerOnInstantConnect)&&(connected)) { + if ((callConnectHandler)&&(connected)) { try { - _tcpConnectHandler((WireSocket *)&sws,uptr,true); + _tcpConnectHandler((WireSocket *)&sws,&(sws.uptr),true); } catch ( ... ) {} } @@ -541,7 +547,7 @@ public: * @param callCloseHandler If true, call close handler on socket closing failure condition * @return Number of bytes actually sent or -1 on fatal error (socket closure) */ - inline long tcpSend(WireSocket *sock,WireSocket *data,unsigned long len,bool callCloseHandler) + inline long tcpSend(WireSocket *sock,const void *data,unsigned long len,bool callCloseHandler) { WireSocketImpl &sws = *(const_cast <WireSocketImpl *>(reinterpret_cast<const WireSocketImpl *>(sock))); long n = (long)::send(sws.sock,data,len,0); @@ -706,7 +712,7 @@ public: if ((long)newSock > _nfds) _nfds = (long)newSock; sws.type = ZT_WIRE_SOCKET_TCP_IN; - sws.sock = s; + sws.sock = newSock; sws.uptr = (void *)0; memcpy(&(sws.saddr),&ss,sizeof(struct sockaddr_storage)); try { @@ -774,7 +780,7 @@ public: long oldSock = (long)sws.sock; for(typename std::list<WireSocketImpl>::iterator s(_socks.begin());s!=_socks.end();++s) { - if (&(*s) == sock) { + if (reinterpret_cast<WireSocket *>(&(*s)) == sock) { _socks.erase(s); break; } @@ -793,6 +799,19 @@ public: } }; +// Typedefs for using regular naked functions as template parameters to Wire<> +typedef void (*Wire_OnDatagramFunctionPtr)(WireSocket *sock,void **uptr,const struct sockaddr *from,void *data,unsigned long len); +typedef void (*Wire_OnTcpConnectFunction)(WireSocket *sock,void **uptr,bool success); +typedef void (*Wire_OnTcpAcceptFunction)(WireSocket *sockL,WireSocket *sockN,void **uptrL,void **uptrN,const struct sockaddr *from); +typedef void (*Wire_OnTcpCloseFunction)(WireSocket *sock,void **uptr); +typedef void (*Wire_OnTcpDataFunction)(WireSocket *sock,void **uptr,void *data,unsigned long len); +typedef void (*Wire_OnTcpWritableFunction)(WireSocket *sock,void **uptr); + +/** + * Wire<> typedef'd to use simple naked function pointers + */ +typedef Wire<Wire_OnDatagramFunctionPtr,Wire_OnTcpConnectFunction,Wire_OnTcpAcceptFunction,Wire_OnTcpCloseFunction,Wire_OnTcpDataFunction,Wire_OnTcpWritableFunction> SimpleFunctionWire; + } // namespace ZeroTier #endif diff --git a/selftest.cpp b/selftest.cpp index 8aef0c6f..d50b344b 100644 --- a/selftest.cpp +++ b/selftest.cpp @@ -646,6 +646,139 @@ static int testOther() return 0; } +#ifdef ZT_TEST_WIRE +#define ZT_TEST_WIRE_NUM_UDP_PACKETS 10000 +#define ZT_TEST_WIRE_UDP_PACKET_SIZE 1000 +#define ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS 10 +#define ZT_TEST_WIRE_NUM_INVALID_TCP_CONNECTS 2 +#define ZT_TEST_WIRE_TCP_MESSAGE_SIZE 1000000 +#define ZT_TEST_WIRE_TIMEOUT_MS 20000 +static unsigned long wireTestUdpPacketCount = 0; +static unsigned long wireTestTcpByteCount = 0; +static unsigned long wireTestTcpConnectSuccessCount = 0; +static unsigned long wireTestTcpConnectFailCount = 0; +static unsigned long wireTestTcpAcceptCount = 0; +static SimpleFunctionWire *testWireInstance = (SimpleFunctionWire *)0; +static void testWireOnDatagramFunction(WireSocket *sock,void **uptr,const struct sockaddr *from,void *data,unsigned long len) +{ + ++wireTestUdpPacketCount; +} +static void testWireOnTcpConnectFunction(WireSocket *sock,void **uptr,bool success) +{ + if (success) { + ++wireTestTcpConnectSuccessCount; + } else { + ++wireTestTcpConnectFailCount; + } +} +static void testWireOnTcpAcceptFunction(WireSocket *sockL,WireSocket *sockN,void **uptrL,void **uptrN,const struct sockaddr *from) +{ + ++wireTestTcpAcceptCount; + *uptrN = new std::string(ZT_TEST_WIRE_TCP_MESSAGE_SIZE,(char)0xff); + testWireInstance->tcpSetNotifyWritable(sockN,true); +} +static void testWireOnTcpCloseFunction(WireSocket *sock,void **uptr) +{ + delete (std::string *)*uptr; // delete testMessage if any +} +static void testWireOnTcpDataFunction(WireSocket *sock,void **uptr,void *data,unsigned long len) +{ + wireTestTcpByteCount += len; +} +static void testWireOnTcpWritableFunction(WireSocket *sock,void **uptr) +{ + std::string *testMessage = (std::string *)*uptr; + if ((testMessage)&&(testMessage->length() > 0)) { + long sent = testWireInstance->tcpSend(sock,(const void *)testMessage->data(),testMessage->length(),true); + if (sent > 0) + testMessage->erase(0,sent); + } + if ((!testMessage)||(!testMessage->length())) { + testWireInstance->close(sock,true); + } +} +#endif // ZT_TEST_WIRE + +static int testWire() +{ +#ifdef ZT_TEST_WIRE + char udpTestPayload[ZT_TEST_WIRE_UDP_PACKET_SIZE]; + memset(udpTestPayload,0xff,sizeof(udpTestPayload)); + + struct sockaddr_in bindaddr; + memset(&bindaddr,0,sizeof(bindaddr)); + bindaddr.sin_family = AF_INET; + bindaddr.sin_port = Utils::hton((uint16_t)60002); + bindaddr.sin_addr.s_addr = Utils::hton((uint32_t)0x7f000001); + struct sockaddr_in invalidAddr; + memset(&bindaddr,0,sizeof(bindaddr)); + bindaddr.sin_family = AF_INET; + bindaddr.sin_port = Utils::hton((uint16_t)60004); + bindaddr.sin_addr.s_addr = Utils::hton((uint32_t)0x7f000001); + + std::cout << "[wire] Creating wire endpoint..." << std::endl; + testWireInstance = new SimpleFunctionWire(testWireOnDatagramFunction,testWireOnTcpConnectFunction,testWireOnTcpAcceptFunction,testWireOnTcpCloseFunction,testWireOnTcpDataFunction,testWireOnTcpWritableFunction,false); + + std::cout << "[wire] Binding UDP listen socket to 127.0.0.1/60002... "; + WireSocket *udpListenSock = testWireInstance->udpBind((const struct sockaddr *)&bindaddr); + if (!udpListenSock) { + std::cout << "FAILED." << std::endl; + return -1; + } + std::cout << "OK" << std::endl; + + std::cout << "[wire] Binding TCP listen socket to 127.0.0.1/60002... "; + WireSocket *tcpListenSock = testWireInstance->tcpListen((const struct sockaddr *)&bindaddr); + if (!tcpListenSock) { + std::cout << "FAILED." << std::endl; + return -1; + } + std::cout << "OK" << std::endl; + + unsigned long wireTestUdpPacketsSent = 0; + unsigned long wireTestTcpValidConnectionsAttempted = 0; + unsigned long wireTestTcpInvalidConnectionsAttempted = 0; + + std::cout << "[wire] Testing UDP send/receive... "; std::cout.flush(); + uint64_t timeoutAt = Utils::now() + ZT_TEST_WIRE_TIMEOUT_MS; + while ((Utils::now() < timeoutAt)&&(wireTestUdpPacketCount < ZT_TEST_WIRE_NUM_UDP_PACKETS)) { + if (wireTestUdpPacketsSent < ZT_TEST_WIRE_NUM_UDP_PACKETS) { + if (!testWireInstance->udpSend(udpListenSock,(const struct sockaddr *)&bindaddr,udpTestPayload,sizeof(udpTestPayload))) { + std::cout << "FAILED." << std::endl; + return -1; + } else ++wireTestUdpPacketsSent; + } + testWireInstance->poll(100); + } + std::cout << "got " << wireTestUdpPacketCount << " packets, OK" << std::endl; + + std::cout << "[wire] Testing TCP... "; std::cout.flush(); + timeoutAt = Utils::now() + ZT_TEST_WIRE_TIMEOUT_MS; + while ((Utils::now() < timeoutAt)&&(wireTestTcpByteCount < (ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS * ZT_TEST_WIRE_TCP_MESSAGE_SIZE))) { + if (wireTestTcpValidConnectionsAttempted < ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS) { + ++wireTestTcpValidConnectionsAttempted; + bool connected = false; + if (!testWireInstance->tcpConnect((const struct sockaddr *)&bindaddr,connected,(void *)0,true)) + ++wireTestTcpConnectFailCount; + } + if (wireTestTcpInvalidConnectionsAttempted < ZT_TEST_WIRE_NUM_INVALID_TCP_CONNECTS) { + ++wireTestTcpInvalidConnectionsAttempted; + bool connected = false; + if (!testWireInstance->tcpConnect((const struct sockaddr *)&invalidAddr,connected,(void *)0,true)) + ++wireTestTcpConnectFailCount; + } + testWireInstance->poll(100); + } + if (wireTestTcpByteCount < (ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS * ZT_TEST_WIRE_TCP_MESSAGE_SIZE)) { + std::cout << "got " << wireTestTcpConnectSuccessCount << " connect successes, " << wireTestTcpConnectFailCount << " failures, and " << wireTestTcpByteCount << " bytes, FAILED." << std::endl; + return -1; + } else { + std::cout << "got " << wireTestTcpConnectSuccessCount << " connect successes, " << wireTestTcpConnectFailCount << " failures, and " << wireTestTcpByteCount << " bytes, OK" << std::endl; + } +#endif // ZT_TEST_WIRE + return 0; +} + static int testSqliteNetconfMaster() { #ifdef ZT_ENABLE_NETCONF_MASTER @@ -717,6 +850,7 @@ int main(int argc,char **argv) srand((unsigned int)time(0)); + r |= testWire(); r |= testSqliteNetconfMaster(); r |= testCrypto(); r |= testHttp(); |