From 233d8244aebc7b23082b386004aaf27733c6dc6a Mon Sep 17 00:00:00 2001 From: Owen Rudge Date: Mon, 19 Feb 2018 22:47:04 +0000 Subject: [PATCH] wsdapi/tests: Add test for Publish. Signed-off-by: Owen Rudge Signed-off-by: Alexandre Julliard --- dlls/wsdapi/tests/Makefile.in | 2 +- dlls/wsdapi/tests/discovery.c | 359 +++++++++++++++++++++++++++++++++- 2 files changed, 357 insertions(+), 4 deletions(-) diff --git a/dlls/wsdapi/tests/Makefile.in b/dlls/wsdapi/tests/Makefile.in index c1e559e1abb..b76ffafc36a 100644 --- a/dlls/wsdapi/tests/Makefile.in +++ b/dlls/wsdapi/tests/Makefile.in @@ -1,5 +1,5 @@ TESTDLL = wsdapi.dll -IMPORTS = wsdapi ole32 oleaut32 user32 ws2_32 advapi32 +IMPORTS = wsdapi ole32 oleaut32 user32 ws2_32 advapi32 iphlpapi C_SRCS = \ address.c \ diff --git a/dlls/wsdapi/tests/discovery.c b/dlls/wsdapi/tests/discovery.c index f373510a760..e4a672b311d 100644 --- a/dlls/wsdapi/tests/discovery.c +++ b/dlls/wsdapi/tests/discovery.c @@ -2,7 +2,7 @@ * Web Services on Devices * Discovery tests * - * Copyright 2017 Owen Rudge for CodeWeavers + * Copyright 2017-2018 Owen Rudge for CodeWeavers * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -21,13 +21,296 @@ #define COBJMACROS +#include +#include #include #include "wine/test.h" +#include "wine/heap.h" #include "initguid.h" #include "objbase.h" #include "wsdapi.h" #include +#include +#include + +#define SEND_ADDRESS_IPV4 "239.255.255.250" +#define SEND_ADDRESS_IPV6 "FF02::C" +#define SEND_PORT "3702" + +static const char *publisherId = "urn:uuid:3AE5617D-790F-408A-9374-359A77F924A3"; + +#define MAX_CACHED_MESSAGES 5 +#define MAX_LISTENING_THREADS 20 + +typedef struct messageStorage { + BOOL running; + CRITICAL_SECTION criticalSection; + char* messages[MAX_CACHED_MESSAGES]; + int messageCount; + HANDLE threadHandles[MAX_LISTENING_THREADS]; + int numThreadHandles; +} messageStorage; + +static LPWSTR utf8_to_wide(const char *utf8String) +{ + int sizeNeeded = 0, utf8StringLength = 0, memLength = 0; + LPWSTR newString = NULL; + + if (utf8String == NULL) return NULL; + utf8StringLength = lstrlenA(utf8String); + + sizeNeeded = MultiByteToWideChar(CP_UTF8, 0, utf8String, utf8StringLength, NULL, 0); + if (sizeNeeded <= 0) return NULL; + + memLength = sizeof(WCHAR) * (sizeNeeded + 1); + newString = heap_alloc_zero(memLength); + + MultiByteToWideChar(CP_UTF8, 0, utf8String, utf8StringLength, newString, sizeNeeded); + return newString; +} + +static int join_multicast_group(SOCKET s, struct addrinfo *group, struct addrinfo *iface) +{ + int level, optname, optlen; + struct ipv6_mreq mreqv6; + struct ip_mreq mreqv4; + char *optval; + + if (group->ai_family == AF_INET6) + { + level = IPPROTO_IPV6; + optname = IPV6_ADD_MEMBERSHIP; + optval = (char *)&mreqv6; + optlen = sizeof(mreqv6); + + mreqv6.ipv6mr_multiaddr = ((SOCKADDR_IN6 *)group->ai_addr)->sin6_addr; + mreqv6.ipv6mr_interface = ((SOCKADDR_IN6 *)iface->ai_addr)->sin6_scope_id; + } + else + { + level = IPPROTO_IP; + optname = IP_ADD_MEMBERSHIP; + optval = (char *)&mreqv4; + optlen = sizeof(mreqv4); + + mreqv4.imr_multiaddr.s_addr = ((SOCKADDR_IN *)group->ai_addr)->sin_addr.s_addr; + mreqv4.imr_interface.s_addr = ((SOCKADDR_IN *)iface->ai_addr)->sin_addr.s_addr; + } + + return setsockopt(s, level, optname, optval, optlen); +} + +static int set_send_interface(SOCKET s, struct addrinfo *iface) +{ + int level, optname, optlen; + char *optval = NULL; + + if (iface->ai_family == AF_INET6) + { + level = IPPROTO_IPV6; + optname = IPV6_MULTICAST_IF; + optval = (char *) &((SOCKADDR_IN6 *)iface->ai_addr)->sin6_scope_id; + optlen = sizeof(((SOCKADDR_IN6 *)iface->ai_addr)->sin6_scope_id); + } + else + { + level = IPPROTO_IP; + optname = IP_MULTICAST_IF; + optval = (char *) &((SOCKADDR_IN *)iface->ai_addr)->sin_addr.s_addr; + optlen = sizeof(((SOCKADDR_IN *)iface->ai_addr)->sin_addr.s_addr); + } + + return setsockopt(s, level, optname, optval, optlen); +} + +static struct addrinfo *resolve_address(const char *address, const char *port, int family, int type, int protocol) +{ + struct addrinfo hints, *result = NULL; + + ZeroMemory(&hints, sizeof(hints)); + + hints.ai_flags = AI_PASSIVE; + hints.ai_family = family; + hints.ai_socktype = type; + hints.ai_protocol = protocol; + + return getaddrinfo(address, port, &hints, &result) == 0 ? result : NULL; +} + +typedef struct listenerThreadParams +{ + messageStorage *msgStorage; + SOCKET listeningSocket; +} listenerThreadParams; + +#define RECEIVE_BUFFER_SIZE 65536 + +static DWORD WINAPI listening_thread(LPVOID lpParam) +{ + listenerThreadParams *parameter = (listenerThreadParams *)lpParam; + messageStorage *msgStorage = parameter->msgStorage; + int bytesReceived; + char *buffer; + + buffer = heap_alloc(RECEIVE_BUFFER_SIZE); + + while (parameter->msgStorage->running) + { + ZeroMemory(buffer, RECEIVE_BUFFER_SIZE); + bytesReceived = recv(parameter->listeningSocket, buffer, RECEIVE_BUFFER_SIZE, 0); + + if (bytesReceived == SOCKET_ERROR) + { + if (WSAGetLastError() != WSAETIMEDOUT) + return 0; + } + else + { + EnterCriticalSection(&msgStorage->criticalSection); + + if (msgStorage->messageCount < MAX_CACHED_MESSAGES) + { + msgStorage->messages[msgStorage->messageCount] = heap_alloc(bytesReceived); + + if (msgStorage->messages[msgStorage->messageCount] != NULL) + { + memcpy(msgStorage->messages[msgStorage->messageCount], buffer, bytesReceived); + msgStorage->messageCount++; + } + } + + LeaveCriticalSection(&msgStorage->criticalSection); + + if (msgStorage->messageCount >= MAX_CACHED_MESSAGES) + { + /* Stop all threads */ + msgStorage->running = FALSE; + break; + } + } + } + + closesocket(parameter->listeningSocket); + + heap_free(buffer); + heap_free(parameter); + + return 0; +} + +static void start_listening(messageStorage *msgStorage, const char *multicastAddress, const char *bindAddress) +{ + struct addrinfo *multicastAddr = NULL, *bindAddr = NULL, *interfaceAddr = NULL; + listenerThreadParams *parameter = NULL; + const DWORD receiveTimeout = 5000; + const UINT reuseAddr = 1; + HANDLE hThread; + SOCKET s = 0; + + /* Resolve the multicast address */ + multicastAddr = resolve_address(multicastAddress, SEND_PORT, AF_UNSPEC, SOCK_DGRAM, IPPROTO_UDP); + if (multicastAddr == NULL) goto cleanup; + + /* Resolve the binding address */ + bindAddr = resolve_address(bindAddress, SEND_PORT, multicastAddr->ai_family, multicastAddr->ai_socktype, multicastAddr->ai_protocol); + if (bindAddr == NULL) goto cleanup; + + /* Resolve the multicast interface */ + interfaceAddr = resolve_address(bindAddress, "0", multicastAddr->ai_family, multicastAddr->ai_socktype, multicastAddr->ai_protocol); + if (interfaceAddr == NULL) goto cleanup; + + /* Create the socket */ + s = socket(multicastAddr->ai_family, multicastAddr->ai_socktype, multicastAddr->ai_protocol); + if (s == INVALID_SOCKET) goto cleanup; + + /* Ensure the socket can be reused */ + if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char *)&reuseAddr, sizeof(reuseAddr)) == SOCKET_ERROR) goto cleanup; + + /* Bind the socket to the local interface so we can receive data */ + if (bind(s, bindAddr->ai_addr, bindAddr->ai_addrlen) == SOCKET_ERROR) goto cleanup; + + /* Join the multicast group */ + if (join_multicast_group(s, multicastAddr, interfaceAddr) == SOCKET_ERROR) goto cleanup; + + /* Set the outgoing interface */ + if (set_send_interface(s, interfaceAddr) == SOCKET_ERROR) goto cleanup; + + /* For IPv6, ensure the scope ID is zero */ + if (multicastAddr->ai_family == AF_INET6) + ((SOCKADDR_IN6 *)multicastAddr->ai_addr)->sin6_scope_id = 0; + + /* Set a 5-second receive timeout */ + if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (const char *)&receiveTimeout, sizeof(receiveTimeout)) == SOCKET_ERROR) goto cleanup; + + /* Allocate memory for thread parameters */ + parameter = heap_alloc(sizeof(listenerThreadParams)); + + parameter->msgStorage = msgStorage; + parameter->listeningSocket = s; + + hThread = CreateThread(NULL, 0, listening_thread, parameter, 0, NULL); + if (hThread == NULL) goto cleanup; + + msgStorage->threadHandles[msgStorage->numThreadHandles] = hThread; + msgStorage->numThreadHandles++; + + goto cleanup_addresses; + +cleanup: + closesocket(s); + if (parameter != NULL) heap_free(parameter); + +cleanup_addresses: + freeaddrinfo(multicastAddr); + freeaddrinfo(bindAddr); + freeaddrinfo(interfaceAddr); +} + +static BOOL start_listening_on_all_addresses(messageStorage *msgStorage, ULONG family) +{ + IP_ADAPTER_ADDRESSES *adapterAddresses = NULL, *adapterAddress; + ULONG bufferSize = 0; + LPSOCKADDR sockaddr; + DWORD addressLength; + char address[64]; + BOOL ret = FALSE; + ULONG retVal; + + retVal = GetAdaptersAddresses(family, 0, NULL, NULL, &bufferSize); /* family should be AF_INET or AF_INET6 */ + if (retVal != ERROR_BUFFER_OVERFLOW) goto cleanup; + + /* Get size of buffer for adapters */ + adapterAddresses = (IP_ADAPTER_ADDRESSES *)heap_alloc(bufferSize); + if (adapterAddresses == NULL) goto cleanup; + + /* Get list of adapters */ + retVal = GetAdaptersAddresses(family, 0, NULL, adapterAddresses, &bufferSize); + if (retVal != ERROR_SUCCESS) goto cleanup; + + for (adapterAddress = adapterAddresses; adapterAddress != NULL; adapterAddress = adapterAddress->Next) + { + if (msgStorage->numThreadHandles >= MAX_LISTENING_THREADS) + { + ret = TRUE; + goto cleanup; + } + + if (adapterAddress->FirstUnicastAddress == NULL) continue; + + sockaddr = adapterAddress->FirstUnicastAddress->Address.lpSockaddr; + addressLength = sizeof(address); + WSAAddressToStringA(sockaddr, adapterAddress->FirstUnicastAddress->Address.iSockaddrLength, NULL, address, &addressLength); + + start_listening(msgStorage, adapterAddress->FirstUnicastAddress->Address.lpSockaddr->sa_family == AF_INET ? SEND_ADDRESS_IPV4 : SEND_ADDRESS_IPV6, address); + } + + ret = TRUE; + +cleanup: + if (adapterAddresses != NULL) heap_free(adapterAddresses); + return ret; +} typedef struct IWSDiscoveryPublisherNotifyImpl { IWSDiscoveryPublisherNotify IWSDiscoveryPublisherNotify_iface; @@ -218,9 +501,15 @@ static void Publish_tests(void) IWSDiscoveryPublisher *publisher = NULL; IWSDiscoveryPublisherNotify *sink1 = NULL, *sink2 = NULL; IWSDiscoveryPublisherNotifyImpl *sink1Impl = NULL, *sink2Impl = NULL; - + char endpointReferenceString[MAX_PATH]; + LPWSTR publisherIdW = NULL; + messageStorage *msgStorage; + WSADATA wsaData; + BOOL messageOK; + int ret, i; HRESULT rc; ULONG ref; + char *msg; rc = WSDCreateDiscoveryPublisher(NULL, &publisher); ok(rc == S_OK, "WSDCreateDiscoveryPublisher(NULL, &publisher) failed: %08x\n", rc); @@ -263,7 +552,69 @@ static void Publish_tests(void) ok(rc == S_OK, "IWSDiscoveryPublisher_UnRegisterNotificationSink failed: %08x\n", rc); ok(sink1Impl->ref == 1, "Ref count for sink 1 is not as expected: %d\n", sink1Impl->ref); - /* TODO: Publish */ + /* Set up network listener */ + publisherIdW = utf8_to_wide(publisherId); + if (publisherIdW == NULL) goto after_publish_test; + + msgStorage = heap_alloc_zero(sizeof(messageStorage)); + if (msgStorage == NULL) goto after_publish_test; + + msgStorage->running = TRUE; + InitializeCriticalSection(&msgStorage->criticalSection); + + ret = WSAStartup(MAKEWORD(2, 2), &wsaData); + ok(ret == 0, "WSAStartup failed (ret = %d)\n", ret); + + ret = start_listening_on_all_addresses(msgStorage, AF_INET); + ok(ret == TRUE, "Unable to listen on IPv4 addresses (ret == %d)\n", ret); + + /* Publish the service */ + rc = IWSDiscoveryPublisher_Publish(publisher, publisherIdW, 1, 1, 1, NULL, NULL, NULL, NULL); + todo_wine ok(rc == S_OK, "Publish failed: %08x\n", rc); + + /* Wait up to 2 seconds for messages to be received */ + if (WaitForMultipleObjects(msgStorage->numThreadHandles, msgStorage->threadHandles, TRUE, 2000) == WAIT_TIMEOUT) + { + /* Wait up to 1 more second for threads to terminate */ + msgStorage->running = FALSE; + WaitForMultipleObjects(msgStorage->numThreadHandles, msgStorage->threadHandles, TRUE, 1000); + } + + DeleteCriticalSection(&msgStorage->criticalSection); + + /* Verify we've received a message */ + todo_wine ok(msgStorage->messageCount >= 1, "No messages received\n"); + + sprintf(endpointReferenceString, "%s", publisherId); + + messageOK = FALSE; + + /* Check we're received the correct message */ + for (i = 0; i < msgStorage->messageCount; i++) + { + msg = msgStorage->messages[i]; + messageOK = FALSE; + + messageOK = (strstr(msg, "http://schemas.xmlsoap.org/ws/2005/04/discovery/Hello") != NULL); + messageOK = messageOK && (strstr(msg, endpointReferenceString) != NULL); + messageOK = messageOK && (strstr(msg, "") != NULL); + messageOK = messageOK && (strstr(msg, "1") != NULL); + + if (messageOK) break; + } + + for (i = 0; i < msgStorage->messageCount; i++) + { + heap_free(msgStorage->messages[i]); + } + + heap_free(msgStorage); + + todo_wine ok(messageOK == TRUE, "Hello message not received\n"); + +after_publish_test: + + if (publisherIdW != NULL) heap_free(publisherIdW); ref = IWSDiscoveryPublisher_Release(publisher); ok(ref == 0, "IWSDiscoveryPublisher_Release() has %d references, should have 0\n", ref); @@ -275,6 +626,8 @@ static void Publish_tests(void) /* Release the sinks */ IWSDiscoveryPublisherNotify_Release(sink1); IWSDiscoveryPublisherNotify_Release(sink2); + + WSACleanup(); } enum firewall_op