From e790647a577fb1402b5a25a479a03d4045f5aa44 Mon Sep 17 00:00:00 2001 From: Cameron Gutman Date: Sun, 12 Feb 2023 03:40:00 -0600 Subject: [PATCH] Correctly support hosts with multiple possible source addresses --- host.c | 3 +- include/enet/enet.h | 8 ++- protocol.c | 13 ++-- unix.c | 143 +++++++++++++++++++++++++++++++++---- win32.c | 167 ++++++++++++++++++++++++++++++++++++-------- 5 files changed, 283 insertions(+), 51 deletions(-) diff --git a/host.c b/host.c index 6b6827a..8487e92 100644 --- a/host.c +++ b/host.c @@ -88,7 +88,8 @@ enet_host_create (int addressFamily, const ENetAddress * address, size_t peerCou host -> commandCount = 0; host -> bufferCount = 0; host -> checksum = NULL; - memset(& host -> receivedAddress, 0, sizeof (host -> receivedAddress)); + memset(& host -> receivedPeerAddress, 0, sizeof (host -> receivedPeerAddress)); + memset(& host -> receivedLocalAddress, 0, sizeof (host -> receivedLocalAddress)); host -> receivedData = NULL; host -> receivedDataLength = 0; diff --git a/include/enet/enet.h b/include/enet/enet.h index dfe5475..4da0b7e 100644 --- a/include/enet/enet.h +++ b/include/enet/enet.h @@ -259,6 +259,7 @@ typedef struct _ENetPeer enet_uint8 outgoingSessionID; enet_uint8 incomingSessionID; ENetAddress address; /**< Internet address of the peer */ + ENetAddress localAddress; void * data; /**< Application private data, may be freely modified */ ENetPeerState state; ENetChannel * channels; @@ -373,7 +374,8 @@ typedef struct _ENetHost ENetChecksumCallback checksum; /**< callback the user can set to enable packet checksums for this host */ ENetCompressor compressor; enet_uint8 packetData [2][ENET_PROTOCOL_MAXIMUM_MTU]; - ENetAddress receivedAddress; + ENetAddress receivedPeerAddress; + ENetAddress receivedLocalAddress; enet_uint8 * receivedData; size_t receivedDataLength; enet_uint32 totalSentData; /**< total data sent, user should reset to 0 as needed to prevent overflow */ @@ -488,8 +490,8 @@ ENET_API int enet_socket_get_address (ENetSocket, ENetAddress *); ENET_API int enet_socket_listen (ENetSocket, int); ENET_API ENetSocket enet_socket_accept (ENetSocket, ENetAddress *); ENET_API int enet_socket_connect (ENetSocket, const ENetAddress *); -ENET_API int enet_socket_send (ENetSocket, const ENetAddress *, const ENetBuffer *, size_t); -ENET_API int enet_socket_receive (ENetSocket, ENetAddress *, ENetBuffer *, size_t); +ENET_API int enet_socket_send (ENetSocket, const ENetAddress *, const ENetAddress *, const ENetBuffer *, size_t); +ENET_API int enet_socket_receive (ENetSocket, ENetAddress *, ENetAddress *, ENetBuffer *, size_t); ENET_API int enet_socket_wait (ENetSocket, enet_uint32 *, enet_uint32); ENET_API int enet_socket_set_option (ENetSocket, ENetSocketOption, int); ENET_API int enet_socket_get_option (ENetSocket, ENetSocketOption, int *); diff --git a/protocol.c b/protocol.c index 26ffc9d..f937cc6 100644 --- a/protocol.c +++ b/protocol.c @@ -309,7 +309,7 @@ enet_protocol_handle_connect (ENetHost * host, ENetProtocolHeader * header, ENet } else if (currentPeer -> state != ENET_PEER_STATE_CONNECTING && - enet_address_equal (& currentPeer -> address, & host -> receivedAddress)) + enet_address_equal (& currentPeer -> address, & host -> receivedPeerAddress)) { if (currentPeer -> connectID == command -> connect.connectID) return NULL; @@ -329,7 +329,8 @@ enet_protocol_handle_connect (ENetHost * host, ENetProtocolHeader * header, ENet peer -> channelCount = channelCount; peer -> state = ENET_PEER_STATE_ACKNOWLEDGING_CONNECT; peer -> connectID = command -> connect.connectID; - peer -> address = host -> receivedAddress; + peer -> address = host -> receivedPeerAddress; + peer -> localAddress = host -> receivedLocalAddress; peer -> outgoingPeerID = ENET_NET_TO_HOST_16 (command -> connect.outgoingPeerID); peer -> incomingBandwidth = ENET_NET_TO_HOST_32 (command -> connect.incomingBandwidth); peer -> outgoingBandwidth = ENET_NET_TO_HOST_32 (command -> connect.outgoingBandwidth); @@ -1072,7 +1073,8 @@ enet_protocol_handle_incoming_commands (ENetHost * host, ENetEvent * event) if (peer != NULL) { - memcpy(& peer -> address, & host -> receivedAddress, sizeof (host -> receivedAddress)); + memcpy(& peer -> address, & host -> receivedPeerAddress, sizeof (host -> receivedPeerAddress)); + memcpy(& peer -> localAddress, & host -> receivedLocalAddress, sizeof (host -> receivedLocalAddress)); peer -> incomingDataTotal += host -> receivedDataLength; } @@ -1223,7 +1225,8 @@ enet_protocol_receive_incoming_commands (ENetHost * host, ENetEvent * event) buffer.dataLength = sizeof (host -> packetData [0]); receivedLength = enet_socket_receive (host -> socket, - & host -> receivedAddress, + & host -> receivedPeerAddress, + & host -> receivedLocalAddress, & buffer, 1); @@ -1690,7 +1693,7 @@ enet_protocol_send_outgoing_commands (ENetHost * host, ENetEvent * event, int ch enet_socket_set_option (host -> socket, ENET_SOCKOPT_QOS, 0); } - sentLength = enet_socket_send (host -> socket, & currentPeer -> address, host -> buffers, host -> bufferCount); + sentLength = enet_socket_send (host -> socket, & currentPeer -> address, & currentPeer -> localAddress, host -> buffers, host -> bufferCount); enet_protocol_remove_sent_unreliable_commands (currentPeer); diff --git a/unix.c b/unix.c index 73f8ad9..7a3fa4d 100644 --- a/unix.c +++ b/unix.c @@ -4,6 +4,16 @@ */ #ifndef _WIN32 +// Required for IPV6_PKTINFO with Darwin headers +#ifndef __APPLE_USE_RFC_3542 +#define __APPLE_USE_RFC_3542 1 +#endif + +// Required for in6_pktinfo with glibc headers +#ifndef _GNU_SOURCE +#define _GNU_SOURCE 1 +#endif + #include #include #include @@ -298,7 +308,37 @@ enet_socket_listen (ENetSocket socket, int backlog) ENetSocket enet_socket_create (int af, ENetSocketType type) { - return socket (af, type == ENET_SOCKET_TYPE_DATAGRAM ? SOCK_DGRAM : SOCK_STREAM, 0); + ENetSocket sock = socket (af, type == ENET_SOCKET_TYPE_DATAGRAM ? SOCK_DGRAM : SOCK_STREAM, 0); + if (sock < 0) { + return sock; + } + +#ifdef IPV6_V6ONLY + if (af == AF_INET6) { + int off = 0; + + // Some OSes don't support dual-stack sockets, so ignore failures + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (char *)&off, sizeof(off)); + } +#endif + +#ifdef IP_PKTINFO + { + // We turn this on for all sockets because it may be required for IPv4 + // traffic on dual-stack sockets on some OSes. + int on = 1; + setsockopt(sock, IPPROTO_IP, IP_PKTINFO, (char *)&on, sizeof(on)); + } +#endif + +#ifdef IPV6_RECVPKTINFO + if (af == AF_INET6) { + int on = 1; + setsockopt(sock, IPPROTO_IPV6, IPV6_RECVPKTINFO, (char *)&on, sizeof(on)); + } +#endif + + return sock; } int @@ -448,7 +488,8 @@ enet_socket_destroy (ENetSocket socket) int enet_socket_send (ENetSocket socket, - const ENetAddress * address, + const ENetAddress * peerAddress, + const ENetAddress * localAddress, const ENetBuffer * buffers, size_t bufferCount) { @@ -486,24 +527,65 @@ enet_socket_send (ENetSocket socket, } sentLength = sendto (socket, sendBuffer, sendLength, MSG_NOSIGNAL, - (struct sockaddr *) & address -> address, address -> addressLength); + (struct sockaddr *) & peerAddress -> address, peerAddress -> addressLength); if (bufferCount > 1) free(sendBuffer); #else struct msghdr msgHdr; + char controlBufData[1024]; memset (& msgHdr, 0, sizeof (struct msghdr)); - if (address != NULL) + if (peerAddress != NULL) { - msgHdr.msg_name = (void*) & address -> address; - msgHdr.msg_namelen = address -> addressLength; + msgHdr.msg_name = (void*) & peerAddress -> address; + msgHdr.msg_namelen = peerAddress -> addressLength; } msgHdr.msg_iov = (struct iovec *) buffers; msgHdr.msg_iovlen = bufferCount; + // We always send traffic from the same local address as we last received + // from this peer to ensure it correctly recognizes our responses as + // coming from the expected host. + if (localAddress != NULL) { +#ifdef IP_PKTINFO + if (localAddress->address.ss_family == AF_INET) { + struct in_pktinfo pktInfo; + + pktInfo.ipi_spec_dst = ((struct sockaddr_in*)&localAddress->address)->sin_addr; + pktInfo.ipi_ifindex = 0; // Unspecified + + msgHdr.msg_control = controlBufData; + msgHdr.msg_controllen = CMSG_SPACE(sizeof(pktInfo)); + + struct cmsghdr *chdr = CMSG_FIRSTHDR(&msgHdr); + chdr->cmsg_level = IPPROTO_IP; + chdr->cmsg_type = IP_PKTINFO; + chdr->cmsg_len = CMSG_LEN(sizeof(pktInfo)); + memcpy(CMSG_DATA(chdr), &pktInfo, sizeof(pktInfo)); + } +#endif +#ifdef IPV6_PKTINFO + if (localAddress->address.ss_family == AF_INET6) { + struct in6_pktinfo pktInfo; + + pktInfo.ipi6_addr = ((struct sockaddr_in6*)&localAddress->address)->sin6_addr; + pktInfo.ipi6_ifindex = 0; // Unspecified + + msgHdr.msg_control = controlBufData; + msgHdr.msg_controllen = CMSG_SPACE(sizeof(pktInfo)); + + struct cmsghdr *chdr = CMSG_FIRSTHDR(&msgHdr); + chdr->cmsg_level = IPPROTO_IPV6; + chdr->cmsg_type = IPV6_PKTINFO; + chdr->cmsg_len = CMSG_LEN(sizeof(pktInfo)); + memcpy(CMSG_DATA(chdr), &pktInfo, sizeof(pktInfo)); + } + #endif + } + sentLength = sendmsg (socket, & msgHdr, MSG_NOSIGNAL); #endif @@ -520,7 +602,8 @@ enet_socket_send (ENetSocket socket, int enet_socket_receive (ENetSocket socket, - ENetAddress * address, + ENetAddress * peerAddress, + ENetAddress * localAddress, ENetBuffer * buffers, size_t bufferCount) { @@ -531,7 +614,7 @@ enet_socket_receive (ENetSocket socket, address -> addressLength = sizeof (address -> address); recvLength = recvfrom (socket, buffers[0].data, buffers[0].dataLength, MSG_NOSIGNAL, - (struct sockaddr *) & address -> address, & address -> addressLength); + (struct sockaddr *) & peerAddress -> address, & peerAddress -> addressLength); if (recvLength == -1) { @@ -544,17 +627,20 @@ enet_socket_receive (ENetSocket socket, return recvLength; #else struct msghdr msgHdr; + char controlBufData[1024]; memset (& msgHdr, 0, sizeof (struct msghdr)); - if (address != NULL) + if (peerAddress != NULL) { - msgHdr.msg_name = & address -> address; - msgHdr.msg_namelen = sizeof (address -> address); + msgHdr.msg_name = & peerAddress -> address; + msgHdr.msg_namelen = sizeof (peerAddress -> address); } msgHdr.msg_iov = (struct iovec *) buffers; msgHdr.msg_iovlen = bufferCount; + msgHdr.msg_control = controlBufData; + msgHdr.msg_controllen = sizeof(controlBufData); recvLength = recvmsg (socket, & msgHdr, MSG_NOSIGNAL); @@ -565,15 +651,44 @@ enet_socket_receive (ENetSocket socket, return -1; } - - if (address != NULL) - address -> addressLength = msgHdr.msg_namelen; #ifdef HAS_MSGHDR_FLAGS if (msgHdr.msg_flags & MSG_TRUNC) return -1; #endif + // Retrieve the local address that this traffic was received on + // to ensure we respond from the correct address/interface. + if (localAddress != NULL) { + for (struct cmsghdr *chdr = CMSG_FIRSTHDR(&msgHdr); chdr != NULL; chdr = CMSG_NXTHDR(&msgHdr, chdr)) { +#ifdef IP_PKTINFO + if (chdr->cmsg_level == IPPROTO_IP && chdr->cmsg_type == IP_PKTINFO) { + struct sockaddr_in *localAddr = (struct sockaddr_in*)&localAddress->address; + + localAddr->sin_family = AF_INET; + localAddr->sin_addr = ((struct in_pktinfo*)CMSG_DATA(chdr))->ipi_addr; + + localAddress->addressLength = sizeof(*localAddr); + break; + } +#endif +#ifdef IPV6_PKTINFO + if (chdr->cmsg_level == IPPROTO_IPV6 && chdr->cmsg_type == IPV6_PKTINFO) { + struct sockaddr_in6 *localAddr = (struct sockaddr_in6*)&localAddress->address; + + localAddr->sin6_family = AF_INET6; + localAddr->sin6_addr = ((struct in6_pktinfo*)CMSG_DATA(chdr))->ipi6_addr; + + localAddress->addressLength = sizeof(*localAddr); + break; + } + #endif + } + } + + if (peerAddress != NULL) + peerAddress -> addressLength = msgHdr.msg_namelen; + return recvLength; #endif } diff --git a/win32.c b/win32.c index 0793454..83c7483 100644 --- a/win32.c +++ b/win32.c @@ -7,6 +7,7 @@ #define ENET_BUILDING_LIB 1 #include "enet/enet.h" #include +#include #ifndef HAS_QOS_FLOWID typedef UINT32 QOS_FLOWID; #endif @@ -30,6 +31,8 @@ BOOL (WINAPI *pfnQOSCreateHandle)(PQOS_VERSION Version, PHANDLE QOSHandle); BOOL (WINAPI *pfnQOSCloseHandle)(HANDLE QOSHandle); BOOL (WINAPI *pfnQOSAddSocketToFlow)(HANDLE QOSHandle, SOCKET Socket, PSOCKADDR DestAddr, QOS_TRAFFIC_TYPE TrafficType, DWORD Flags, PQOS_FLOWID FlowId); +LPFN_WSARECVMSG pfnWSARecvMsg; + int enet_initialize (void) { @@ -231,7 +234,46 @@ enet_socket_listen (ENetSocket socket, int backlog) ENetSocket enet_socket_create (int af, ENetSocketType type) { - return socket (af, type == ENET_SOCKET_TYPE_DATAGRAM ? SOCK_DGRAM : SOCK_STREAM, 0); + SOCKET sock = socket (af, type == ENET_SOCKET_TYPE_DATAGRAM ? SOCK_DGRAM : SOCK_STREAM, 0); + if (sock == INVALID_SOCKET) + return INVALID_SOCKET; + + DWORD bytesReturned; + GUID wsaRecvMsgGuid = WSAID_WSARECVMSG; + if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &wsaRecvMsgGuid, sizeof(wsaRecvMsgGuid), + &pfnWSARecvMsg, sizeof(pfnWSARecvMsg), &bytesReturned, NULL, NULL) == SOCKET_ERROR) { + closesocket(sock); + return INVALID_SOCKET; + } + + BOOL val; + + // Enable dual-stack operation for IPv6 sockets + if (af == AF_INET6) { + val = FALSE; + if (setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (char*)&val, sizeof(val)) == SOCKET_ERROR) { + closesocket(sock); + return INVALID_SOCKET; + } + } + + // Enable returning local address info for IPv4 and dual-stack sockets + val = TRUE; + if (setsockopt(sock, IPPROTO_IP, IP_PKTINFO, (char*)&val, sizeof(val)) == SOCKET_ERROR) { + closesocket(sock); + return INVALID_SOCKET; + } + + // Enable returning local address info for IPv6 and dual-stack sockets + if (af == AF_INET6) { + val = TRUE; + if (setsockopt(sock, IPPROTO_IPV6, IPV6_PKTINFO, (char*)&val, sizeof(val)) == SOCKET_ERROR) { + closesocket(sock); + return INVALID_SOCKET; + } + } + + return sock; } int @@ -365,18 +407,21 @@ enet_socket_destroy (ENetSocket socket) int enet_socket_send (ENetSocket socket, - const ENetAddress * address, + const ENetAddress * peerAddress, + const ENetAddress * localAddress, const ENetBuffer * buffers, size_t bufferCount) { DWORD sentLength; + WSAMSG msg = { 0 }; + char controlBufData[1024]; if (!qosAddedFlow && qosHandle != INVALID_HANDLE_VALUE) { qosFlowId = 0; // Must be initialized to 0 pfnQOSAddSocketToFlow(qosHandle, socket, - (struct sockaddr *)&address->address, + (struct sockaddr *)&peerAddress->address, QOSTrafficTypeControl, QOS_NON_ADAPTIVE_FLOW, &qosFlowId); @@ -385,15 +430,53 @@ enet_socket_send (ENetSocket socket, qosAddedFlow = TRUE; } - if (WSASendTo (socket, - (LPWSABUF) buffers, - (DWORD) bufferCount, - & sentLength, - 0, - address != NULL ? (struct sockaddr *) & address -> address : NULL, - address != NULL ? address -> addressLength : 0, - NULL, - NULL) == SOCKET_ERROR) + msg.name = peerAddress != NULL ? (struct sockaddr *) & peerAddress -> address : NULL; + msg.namelen = peerAddress != NULL ? peerAddress -> addressLength : 0; + msg.lpBuffers = (LPWSABUF) buffers; + msg.dwBufferCount = (DWORD) bufferCount; + + // We always send traffic from the same local address as we last received + // from this peer to ensure it correctly recognizes our responses as + // coming from the expected host. + if (localAddress != NULL) { + if (localAddress->address.ss_family == AF_INET) { + IN_PKTINFO pktInfo; + + pktInfo.ipi_addr = ((PSOCKADDR_IN)&localAddress->address)->sin_addr; + pktInfo.ipi_ifindex = 0; // Unspecified + + msg.Control.buf = controlBufData; + msg.Control.len = WSA_CMSG_SPACE(sizeof(pktInfo)); + + PWSACMSGHDR chdr = WSA_CMSG_FIRSTHDR(&msg); + chdr->cmsg_level = IPPROTO_IP; + chdr->cmsg_type = IP_PKTINFO; + chdr->cmsg_len = WSA_CMSG_LEN(sizeof(pktInfo)); + memcpy(WSA_CMSG_DATA(chdr), &pktInfo, sizeof(pktInfo)); + } + else if (localAddress->address.ss_family == AF_INET6) { + IN6_PKTINFO pktInfo; + + pktInfo.ipi6_addr = ((PSOCKADDR_IN6)&localAddress->address)->sin6_addr; + pktInfo.ipi6_ifindex = 0; // Unspecified + + msg.Control.buf = controlBufData; + msg.Control.len = WSA_CMSG_SPACE(sizeof(pktInfo)); + + PWSACMSGHDR chdr = WSA_CMSG_FIRSTHDR(&msg); + chdr->cmsg_level = IPPROTO_IPV6; + chdr->cmsg_type = IPV6_PKTINFO; + chdr->cmsg_len = WSA_CMSG_LEN(sizeof(pktInfo)); + memcpy(WSA_CMSG_DATA(chdr), &pktInfo, sizeof(pktInfo)); + } + } + + if (WSASendMsg (socket, + & msg, + 0, + & sentLength, + NULL, + NULL) == SOCKET_ERROR) { if (WSAGetLastError () == WSAEWOULDBLOCK) return 0; @@ -406,25 +489,27 @@ enet_socket_send (ENetSocket socket, int enet_socket_receive (ENetSocket socket, - ENetAddress * address, + ENetAddress * peerAddress, + ENetAddress * localAddress, ENetBuffer * buffers, size_t bufferCount) { - DWORD flags = 0, - recvLength; - - if (address != NULL) - address -> addressLength = sizeof (address -> address); + DWORD recvLength; + WSAMSG msg = { 0 }; + char controlBufData[1024]; - if (WSARecvFrom (socket, - (LPWSABUF) buffers, - (DWORD) bufferCount, - & recvLength, - & flags, - address != NULL ? (struct sockaddr *) & address -> address : NULL, - address != NULL ? & address -> addressLength : NULL, - NULL, - NULL) == SOCKET_ERROR) + msg.name = peerAddress != NULL ? (struct sockaddr *) & peerAddress -> address : NULL; + msg.namelen = peerAddress != NULL ? sizeof (peerAddress -> address) : 0; + msg.lpBuffers = (LPWSABUF) buffers; + msg.dwBufferCount = (DWORD) bufferCount; + msg.Control.buf = controlBufData; + msg.Control.len = sizeof(controlBufData); + + if (pfnWSARecvMsg (socket, + & msg, + & recvLength, + NULL, + NULL) == SOCKET_ERROR) { switch (WSAGetLastError ()) { @@ -436,9 +521,35 @@ enet_socket_receive (ENetSocket socket, return -1; } - if (flags & MSG_PARTIAL) + if (msg.dwFlags & MSG_PARTIAL) return -1; + // Retrieve the local address that this traffic was received on + // to ensure we respond from the correct address/interface. + if (localAddress != NULL) { + for (PWSACMSGHDR chdr = WSA_CMSG_FIRSTHDR(&msg); chdr != NULL; chdr = WSA_CMSG_NXTHDR(&msg, chdr)) { + if (chdr->cmsg_level == IPPROTO_IP && chdr->cmsg_type == IP_PKTINFO) { + PSOCKADDR_IN localAddr = (PSOCKADDR_IN)&localAddress->address; + + localAddr->sin_family = AF_INET; + localAddr->sin_addr = ((IN_PKTINFO*)WSA_CMSG_DATA(chdr))->ipi_addr; + + localAddress->addressLength = sizeof(*localAddr); + break; + } + else if (chdr->cmsg_level == IPPROTO_IPV6 && chdr->cmsg_type == IPV6_PKTINFO) { + PSOCKADDR_IN6 localAddr = (PSOCKADDR_IN6)&localAddress->address; + + localAddr->sin6_family = AF_INET6; + localAddr->sin6_addr = ((IN6_PKTINFO*)WSA_CMSG_DATA(chdr))->ipi6_addr; + + localAddress->addressLength = sizeof(*localAddr); + break; + } + } + } + + peerAddress->addressLength = msg.namelen; return (int) recvLength; }