diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c index 134ea4c582a..d3829c02156 100644 --- a/dlls/rpcrt4/rpc_transport.c +++ b/dlls/rpcrt4/rpc_transport.c @@ -564,42 +564,8 @@ typedef struct _RpcConnection_tcp { RpcConnection common; int sock; - HANDLE onEventAvailable; - HANDLE onEventHandled; - BOOL quit; } RpcConnection_tcp; -static DWORD WINAPI rpcrt4_tcp_poll_thread(LPVOID arg) -{ - RpcConnection_tcp *tcpc; - int ret; - struct pollfd pollInfo; - - tcpc = (RpcConnection_tcp*) arg; - pollInfo.fd = tcpc->sock; - pollInfo.events = POLLIN; - - while (!tcpc->quit) - { - ret = poll(&pollInfo, 1, 1000); - if (ret < 0) - ERR("poll failed with error %d\n", ret); - else - { - if (pollInfo.revents & POLLIN) - { - SignalObjectAndWait(tcpc->onEventAvailable, - tcpc->onEventHandled, INFINITE, FALSE); - } - } - } - - /* This avoids the tcpc being destroyed before we are done with it */ - SetEvent(tcpc->onEventAvailable); - - return 0; -} - static RpcConnection *rpcrt4_conn_tcp_alloc(void) { RpcConnection_tcp *tcpc; @@ -607,9 +573,6 @@ static RpcConnection *rpcrt4_conn_tcp_alloc(void) if (tcpc == NULL) return NULL; tcpc->sock = -1; - tcpc->onEventAvailable = NULL; - tcpc->onEventHandled = NULL; - tcpc->quit = FALSE; return &tcpc->common; } @@ -664,18 +627,19 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection) if (Connection->server) { - HANDLE thread = NULL; ret = bind(sock, ai_cur->ai_addr, ai_cur->ai_addrlen); if (ret < 0) { WARN("bind failed, error %d\n", ret); - goto done; + close(sock); + continue; } ret = listen(sock, 10); if (ret < 0) { WARN("listen failed, error %d\n", ret); - goto done; + close(sock); + continue; } /* need a non-blocking socket, otherwise accept() has a potential * race-condition (poll() says it is readable, connection drops, @@ -685,47 +649,10 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection) if (ret < 0) { WARN("couldn't make socket non-blocking, error %d\n", ret); - goto done; - } - tcpc->onEventAvailable = CreateEventW(NULL, FALSE, FALSE, NULL); - if (tcpc->onEventAvailable == NULL) - { - WARN("creating available event failed, error %lu\n", GetLastError()); - goto done; - } - tcpc->onEventHandled = CreateEventW(NULL, FALSE, FALSE, NULL); - if (tcpc->onEventHandled == NULL) - { - WARN("creating handled event failed, error %lu\n", GetLastError()); - goto done; - } - tcpc->sock = sock; - thread = CreateThread(NULL, 0, rpcrt4_tcp_poll_thread, tcpc, 0, NULL); - if (thread == NULL) - { - WARN("creating server polling thread failed, error %lu\n", - GetLastError()); - tcpc->sock = -1; - goto done; - } - CloseHandle(thread); - - done: - if (thread == NULL) /* ie. we failed somewhere */ - { close(sock); - if (tcpc->onEventAvailable != NULL) - { - CloseHandle(tcpc->onEventAvailable); - tcpc->onEventAvailable = NULL; - } - if (tcpc->onEventHandled != NULL) - { - CloseHandle(tcpc->onEventHandled); - tcpc->onEventHandled = NULL; - } continue; } + tcpc->sock = sock; } else /* it's a client */ { @@ -748,12 +675,6 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection) return RPC_S_SERVER_UNAVAILABLE; } -static HANDLE rpcrt4_conn_tcp_get_wait_handle(RpcConnection *Connection) -{ - RpcConnection_tcp *tcpc = (RpcConnection_tcp*) Connection; - return tcpc->onEventAvailable; -} - static RPC_STATUS rpcrt4_conn_tcp_handoff(RpcConnection *old_conn, RpcConnection *new_conn) { int ret; @@ -764,7 +685,6 @@ static RPC_STATUS rpcrt4_conn_tcp_handoff(RpcConnection *old_conn, RpcConnection addrsize = sizeof(address); ret = accept(server->sock, (struct sockaddr*) &address, &addrsize); - SetEvent(server->onEventHandled); if (ret < 0) { ERR("Failed to accept a TCP connection: error %d\n", ret); @@ -799,15 +719,6 @@ static int rpcrt4_conn_tcp_close(RpcConnection *Connection) TRACE("%d\n", tcpc->sock); - if (tcpc->onEventAvailable != NULL) - { - /* it's a server connection */ - tcpc->quit = TRUE; - WaitForSingleObject(tcpc->onEventAvailable, INFINITE); - CloseHandle(tcpc->onEventAvailable); - CloseHandle(tcpc->onEventHandled); - } - if (tcpc->sock != -1) close(tcpc->sock); tcpc->sock = -1; @@ -949,6 +860,145 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_parse_top_of_tower(const unsigned char *to return RPC_S_OK; } +typedef struct _RpcServerProtseq_sock +{ + RpcServerProtseq common; + int mgr_event_rcv; + int mgr_event_snd; +} RpcServerProtseq_sock; + +static RpcServerProtseq *rpcrt4_protseq_sock_alloc(void) +{ + RpcServerProtseq_sock *ps = HeapAlloc(GetProcessHeap(), 0, sizeof(*ps)); + if (ps) + { + int fds[2]; + if (!socketpair(PF_UNIX, SOCK_DGRAM, 0, fds)) + { + fcntl(fds[0], F_SETFL, O_NONBLOCK); + fcntl(fds[1], F_SETFL, O_NONBLOCK); + ps->mgr_event_rcv = fds[0]; + ps->mgr_event_snd = fds[1]; + } + else + { + ERR("socketpair failed with error %s\n", strerror(errno)); + HeapFree(GetProcessHeap(), 0, ps); + return NULL; + } + } + return &ps->common; +} + +static void rpcrt4_protseq_sock_signal_state_changed(RpcServerProtseq *protseq) +{ + RpcServerProtseq_sock *sockps = CONTAINING_RECORD(protseq, RpcServerProtseq_sock, common); + char dummy = 1; + write(sockps->mgr_event_snd, &dummy, sizeof(dummy)); +} + +static void *rpcrt4_protseq_sock_get_wait_array(RpcServerProtseq *protseq, void *prev_array, unsigned int *count) +{ + struct pollfd *poll_info = prev_array; + RpcConnection_tcp *conn; + RpcServerProtseq_sock *sockps = CONTAINING_RECORD(protseq, RpcServerProtseq_sock, common); + + EnterCriticalSection(&protseq->cs); + + /* open and count connections */ + *count = 1; + conn = (RpcConnection_tcp *)protseq->conn; + while (conn) { + RPCRT4_OpenConnection(&conn->common); + if (conn->sock != -1) + (*count)++; + conn = (RpcConnection_tcp *)conn->common.Next; + } + + /* make array of connections */ + if (poll_info) + poll_info = HeapReAlloc(GetProcessHeap(), 0, poll_info, *count*sizeof(*poll_info)); + else + poll_info = HeapAlloc(GetProcessHeap(), 0, *count*sizeof(*poll_info)); + if (!poll_info) + { + ERR("couldn't allocate poll_info\n"); + LeaveCriticalSection(&protseq->cs); + return NULL; + } + + poll_info[0].fd = sockps->mgr_event_rcv; + poll_info[0].events = POLLIN; + *count = 1; + conn = CONTAINING_RECORD(protseq->conn, RpcConnection_tcp, common); + while (conn) { + if (conn->sock != -1) + { + poll_info[*count].fd = conn->sock; + poll_info[*count].events = POLLIN; + (*count)++; + } + conn = CONTAINING_RECORD(conn->common.Next, RpcConnection_tcp, common); + } + LeaveCriticalSection(&protseq->cs); + return poll_info; +} + +static void rpcrt4_protseq_sock_free_wait_array(RpcServerProtseq *protseq, void *array) +{ + HeapFree(GetProcessHeap(), 0, array); +} + +static int rpcrt4_protseq_sock_wait_for_new_connection(RpcServerProtseq *protseq, unsigned int count, void *wait_array) +{ + struct pollfd *poll_info = wait_array; + int ret, i; + RpcConnection *cconn; + RpcConnection_tcp *conn; + + if (!poll_info) + return -1; + + ret = poll(poll_info, count, -1); + if (ret < 0) + { + ERR("poll failed with error %d\n", ret); + return -1; + } + + for (i = 0; i < count; i++) + if (poll_info[i].revents & POLLIN) + { + /* RPC server event */ + if (i == 0) + { + char dummy; + read(poll_info[0].fd, &dummy, sizeof(dummy)); + return 0; + } + + /* find which connection got a RPC */ + EnterCriticalSection(&protseq->cs); + conn = CONTAINING_RECORD(protseq->conn, RpcConnection_tcp, common); + while (conn) { + if (poll_info[i].fd == conn->sock) break; + conn = CONTAINING_RECORD(conn->common.Next, RpcConnection_tcp, common); + } + cconn = NULL; + if (conn) + RPCRT4_SpawnConnection(&cconn, &conn->common); + else + ERR("failed to locate connection for fd %d\n", poll_info[i].fd); + LeaveCriticalSection(&protseq->cs); + if (cconn) + RPCRT4_new_client(cconn); + else + return -1; + } + + return 1; +} + static const struct connection_ops conn_protseq_list[] = { { "ncacn_np", { EPM_PROTOCOL_NCACN, EPM_PROTOCOL_SMB }, @@ -978,7 +1028,7 @@ static const struct connection_ops conn_protseq_list[] = { { EPM_PROTOCOL_NCACN, EPM_PROTOCOL_TCP }, rpcrt4_conn_tcp_alloc, rpcrt4_ncacn_ip_tcp_open, - rpcrt4_conn_tcp_get_wait_handle, + NULL, rpcrt4_conn_tcp_handoff, rpcrt4_conn_tcp_read, rpcrt4_conn_tcp_write, @@ -1009,11 +1059,11 @@ static const struct protseq_ops protseq_list[] = }, { "ncacn_ip_tcp", - rpcrt4_protseq_np_alloc, - rpcrt4_protseq_np_signal_state_changed, - rpcrt4_protseq_np_get_wait_array, - rpcrt4_protseq_np_free_wait_array, - rpcrt4_protseq_np_wait_for_new_connection, + rpcrt4_protseq_sock_alloc, + rpcrt4_protseq_sock_signal_state_changed, + rpcrt4_protseq_sock_get_wait_array, + rpcrt4_protseq_sock_free_wait_array, + rpcrt4_protseq_sock_wait_for_new_connection, }, };