From a51486c657c4c2c4df9b94446822c234c34510cf Mon Sep 17 00:00:00 2001 From: Rob Shearman Date: Mon, 12 Nov 2007 20:10:19 +0000 Subject: [PATCH] rpcrt4: Implement RpcCancelThread for the ncacn_ip_tcp protocol sequence. --- dlls/rpcrt4/rpc_binding.h | 8 ++++ dlls/rpcrt4/rpc_message.c | 8 ++++ dlls/rpcrt4/rpc_transport.c | 86 ++++++++++++++++++++++++++++++++++--- dlls/rpcrt4/rpcrt4_main.c | 82 ++++++++++++++++++++++++++++++++++- 4 files changed, 176 insertions(+), 8 deletions(-) diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h index f01ae6e84a6..d7b2495dcfb 100644 --- a/dlls/rpcrt4/rpc_binding.h +++ b/dlls/rpcrt4/rpc_binding.h @@ -104,6 +104,7 @@ struct connection_ops { int (*read)(RpcConnection *conn, void *buffer, unsigned int len); int (*write)(RpcConnection *conn, const void *buffer, unsigned int len); int (*close)(RpcConnection *conn); + void (*cancel_call)(RpcConnection *conn); size_t (*get_top_of_tower)(unsigned char *tower_data, const char *networkaddr, const char *endpoint); RPC_STATUS (*parse_top_of_tower)(const unsigned char *tower_data, size_t tower_size, char **networkaddr, char **endpoint); }; @@ -190,6 +191,11 @@ static inline int rpcrt4_conn_close(RpcConnection *Connection) return Connection->ops->close(Connection); } +static inline void rpcrt4_conn_cancel_call(RpcConnection *Connection) +{ + Connection->ops->cancel_call(Connection); +} + static inline RPC_STATUS rpcrt4_conn_handoff(RpcConnection *old_conn, RpcConnection *new_conn) { return old_conn->ops->handoff(old_conn, new_conn); @@ -199,4 +205,6 @@ static inline RPC_STATUS rpcrt4_conn_handoff(RpcConnection *old_conn, RpcConnect RPC_STATUS RpcTransport_GetTopOfTower(unsigned char *tower_data, size_t *tower_size, const char *protseq, const char *networkaddr, const char *endpoint); RPC_STATUS RpcTransport_ParseTopOfTower(const unsigned char *tower_data, size_t tower_size, char **protseq, char **networkaddr, char **endpoint); +void RPCRT4_SetThreadCurrentConnection(RpcConnection *Connection); + #endif diff --git a/dlls/rpcrt4/rpc_message.c b/dlls/rpcrt4/rpc_message.c index ce8d1e7ac9f..91572b17326 100644 --- a/dlls/rpcrt4/rpc_message.c +++ b/dlls/rpcrt4/rpc_message.c @@ -449,6 +449,8 @@ static RPC_STATUS RPCRT4_SendAuth(RpcConnection *Connection, RpcPktHdr *Header, LONG alen; RPC_STATUS status; + RPCRT4_SetThreadCurrentConnection(Connection); + buffer_pos = Buffer; /* The packet building functions save the packet header size, so we can use it. */ hdr_size = Header->common.frag_len; @@ -518,6 +520,7 @@ static RPC_STATUS RPCRT4_SendAuth(RpcConnection *Connection, RpcPktHdr *Header, if (status != RPC_S_OK) { HeapFree(GetProcessHeap(), 0, pkt); + RPCRT4_SetThreadCurrentConnection(NULL); return status; } } @@ -528,6 +531,7 @@ write: HeapFree(GetProcessHeap(), 0, pkt); if (count<0) { WARN("rpcrt4_conn_write failed (auth)\n"); + RPCRT4_SetThreadCurrentConnection(NULL); return RPC_S_PROTOCOL_ERROR; } @@ -536,6 +540,7 @@ write: Header->common.flags &= ~RPC_FLG_FIRST; } + RPCRT4_SetThreadCurrentConnection(NULL); return RPC_S_OK; } @@ -697,6 +702,8 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, TRACE("(%p, %p, %p)\n", Connection, Header, pMsg); + RPCRT4_SetThreadCurrentConnection(Connection); + /* read packet common header */ dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr)); if (dwRead != sizeof(common_hdr)) { @@ -872,6 +879,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, status = RPC_S_OK; fail: + RPCRT4_SetThreadCurrentConnection(NULL); if (status != RPC_S_OK) { RPCRT4_FreeHeader(*Header); *Header = NULL; diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c index 53876a928ac..dfb7eceebe4 100644 --- a/dlls/rpcrt4/rpc_transport.c +++ b/dlls/rpcrt4/rpc_transport.c @@ -416,6 +416,11 @@ static int rpcrt4_conn_np_close(RpcConnection *Connection) return 0; } +static void rpcrt4_conn_np_cancel_call(RpcConnection *Connection) +{ + /* FIXME: implement when named pipe writes use overlapped I/O */ +} + static size_t rpcrt4_ncacn_np_get_top_of_tower(unsigned char *tower_data, const char *networkaddr, const char *endpoint) @@ -703,6 +708,7 @@ typedef struct _RpcConnection_tcp { RpcConnection common; int sock; + int cancel_fds[2]; } RpcConnection_tcp; static RpcConnection *rpcrt4_conn_tcp_alloc(void) @@ -712,6 +718,12 @@ static RpcConnection *rpcrt4_conn_tcp_alloc(void) if (tcpc == NULL) return NULL; tcpc->sock = -1; + if (socketpair(PF_UNIX, SOCK_STREAM, 0, tcpc->cancel_fds) < 0) + { + ERR("socketpair() failed: %s\n", strerror(errno)); + HeapFree(GetProcessHeap(), 0, tcpc); + return NULL; + } return &tcpc->common; } @@ -777,6 +789,7 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection) /* RPC depends on having minimal latency so disable the Nagle algorithm */ val = 1; setsockopt(sock, SOL_TCP, TCP_NODELAY, &val, sizeof(val)); + fcntl(sock, F_SETFL, O_NONBLOCK); /* make socket nonblocking */ tcpc->sock = sock; @@ -942,18 +955,64 @@ static int rpcrt4_conn_tcp_read(RpcConnection *Connection, void *buffer, unsigned int count) { RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection; - int r = recv(tcpc->sock, buffer, count, MSG_WAITALL); - TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r); - return r; + int bytes_read = 0; + do + { + int r = recv(tcpc->sock, (char *)buffer + bytes_read, count - bytes_read, 0); + if (r >= 0) + bytes_read += r; + else if (errno != EAGAIN) + return -1; + else + { + struct pollfd pfds[2]; + pfds[0].fd = tcpc->sock; + pfds[0].events = POLLIN; + pfds[1].fd = tcpc->cancel_fds[0]; + pfds[1].fd = POLLIN; + if (poll(pfds, 2, -1 /* infinite */) == -1 && errno != EINTR) + { + ERR("poll() failed: %s\n", strerror(errno)); + return -1; + } + if (pfds[1].revents & POLLIN) /* canceled */ + { + char dummy; + read(pfds[1].fd, &dummy, sizeof(dummy)); + return -1; + } + } + } while (bytes_read != count); + TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, bytes_read); + return bytes_read; } static int rpcrt4_conn_tcp_write(RpcConnection *Connection, const void *buffer, unsigned int count) { RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection; - int r = write(tcpc->sock, buffer, count); - TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r); - return r; + int bytes_written = 0; + do + { + int r = write(tcpc->sock, (const char *)buffer + bytes_written, count - bytes_written); + if (r >= 0) + bytes_written += r; + else if (errno != EAGAIN) + return -1; + else + { + struct pollfd pfd; + pfd.fd = tcpc->sock; + pfd.events = POLLOUT; + if (poll(&pfd, 1, -1 /* infinite */) == -1 && errno != EINTR) + { + ERR("poll() failed: %s\n", strerror(errno)); + return -1; + } + } + } while (bytes_written != count); + TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, bytes_written); + return bytes_written; } static int rpcrt4_conn_tcp_close(RpcConnection *Connection) @@ -965,9 +1024,21 @@ static int rpcrt4_conn_tcp_close(RpcConnection *Connection) if (tcpc->sock != -1) close(tcpc->sock); tcpc->sock = -1; + close(tcpc->cancel_fds[0]); + close(tcpc->cancel_fds[1]); return 0; } +static void rpcrt4_conn_tcp_cancel_call(RpcConnection *Connection) +{ + RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection; + char dummy = 1; + + TRACE("%p\n", Connection); + + write(tcpc->cancel_fds[1], &dummy, 1); +} + static size_t rpcrt4_ncacn_ip_tcp_get_top_of_tower(unsigned char *tower_data, const char *networkaddr, const char *endpoint) @@ -1250,6 +1321,7 @@ static const struct connection_ops conn_protseq_list[] = { rpcrt4_conn_np_read, rpcrt4_conn_np_write, rpcrt4_conn_np_close, + rpcrt4_conn_np_cancel_call, rpcrt4_ncacn_np_get_top_of_tower, rpcrt4_ncacn_np_parse_top_of_tower, }, @@ -1261,6 +1333,7 @@ static const struct connection_ops conn_protseq_list[] = { rpcrt4_conn_np_read, rpcrt4_conn_np_write, rpcrt4_conn_np_close, + rpcrt4_conn_np_cancel_call, rpcrt4_ncalrpc_get_top_of_tower, rpcrt4_ncalrpc_parse_top_of_tower, }, @@ -1272,6 +1345,7 @@ static const struct connection_ops conn_protseq_list[] = { rpcrt4_conn_tcp_read, rpcrt4_conn_tcp_write, rpcrt4_conn_tcp_close, + rpcrt4_conn_tcp_cancel_call, rpcrt4_ncacn_ip_tcp_get_top_of_tower, rpcrt4_ncacn_ip_tcp_parse_top_of_tower, } diff --git a/dlls/rpcrt4/rpcrt4_main.c b/dlls/rpcrt4/rpcrt4_main.c index e067d0e2ec5..a3caa177102 100644 --- a/dlls/rpcrt4/rpcrt4_main.c +++ b/dlls/rpcrt4/rpcrt4_main.c @@ -100,6 +100,8 @@ #include "winerror.h" #include "winbase.h" #include "winuser.h" +#include "winnt.h" +#include "winternl.h" #include "iptypes.h" #include "iphlpapi.h" #include "wine/unicode.h" @@ -133,6 +135,25 @@ static CRITICAL_SECTION_DEBUG critsect_debug = }; static CRITICAL_SECTION uuid_cs = { &critsect_debug, -1, 0, 0, 0, 0 }; +static CRITICAL_SECTION threaddata_cs; +static CRITICAL_SECTION_DEBUG threaddata_cs_debug = +{ + 0, 0, &uuid_cs, + { &threaddata_cs_debug.ProcessLocksList, &threaddata_cs_debug.ProcessLocksList }, + 0, 0, { (DWORD_PTR)(__FILE__ ": threaddata_cs") } +}; +static CRITICAL_SECTION threaddata_cs = { &threaddata_cs_debug, -1, 0, 0, 0, 0 }; + +struct list threaddata_list = LIST_INIT(threaddata_list); + +struct threaddata +{ + struct list entry; + CRITICAL_SECTION cs; + DWORD thread_id; + RpcConnection *connection; +}; + /*********************************************************************** * DllMain * @@ -148,14 +169,29 @@ static CRITICAL_SECTION uuid_cs = { &critsect_debug, -1, 0, 0, 0, 0 }; BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) { + struct threaddata *tdata; + switch (fdwReason) { case DLL_PROCESS_ATTACH: - DisableThreadLibraryCalls(hinstDLL); master_mutex = CreateMutexA( NULL, FALSE, RPCSS_MASTER_MUTEX_NAME); if (!master_mutex) ERR("Failed to create master mutex\n"); break; + case DLL_THREAD_DETACH: + tdata = NtCurrentTeb()->ReservedForNtRpc; + if (tdata) + { + EnterCriticalSection(&threaddata_cs); + list_remove(&tdata->entry); + LeaveCriticalSection(&threaddata_cs); + + DeleteCriticalSection(&tdata->cs); + if (tdata->connection) + ERR("tdata->connection should be NULL but is still set to %p\n", tdata); + HeapFree(GetProcessHeap(), 0, tdata); + } + case DLL_PROCESS_DETACH: CloseHandle(master_mutex); master_mutex = NULL; @@ -847,11 +883,53 @@ RPC_STATUS RPC_ENTRY RpcMgmtSetCancelTimeout(LONG Timeout) return RPC_S_OK; } +void RPCRT4_SetThreadCurrentConnection(RpcConnection *Connection) +{ + struct threaddata *tdata = NtCurrentTeb()->ReservedForNtRpc; + if (!tdata) + { + tdata = HeapAlloc(GetProcessHeap(), 0, sizeof(*tdata)); + if (!tdata) return; + + InitializeCriticalSection(&tdata->cs); + tdata->thread_id = GetCurrentThreadId(); + tdata->connection = Connection; + + EnterCriticalSection(&threaddata_cs); + list_add_tail(&threaddata_list, &tdata->entry); + LeaveCriticalSection(&threaddata_cs); + + NtCurrentTeb()->ReservedForNtRpc = tdata; + return; + } + + EnterCriticalSection(&tdata->cs); + tdata->connection = Connection; + LeaveCriticalSection(&tdata->cs); +} + /****************************************************************************** * RpcCancelThread (rpcrt4.@) */ RPC_STATUS RPC_ENTRY RpcCancelThread(HANDLE ThreadHandle) { - FIXME("(%p): stub\n", ThreadHandle); + DWORD target_tid; + struct threaddata *tdata; + + TRACE("(%p)\n", ThreadHandle); + + target_tid = GetThreadId(ThreadHandle); + if (!target_tid) + return RPC_S_INVALID_ARG; + + EnterCriticalSection(&threaddata_cs); + LIST_FOR_EACH_ENTRY(tdata, &threaddata_list, struct threaddata, entry) + if (tdata->thread_id == target_tid) + { + rpcrt4_conn_cancel_call(tdata->connection); + break; + } + LeaveCriticalSection(&threaddata_cs); + return RPC_S_OK; }