rpcrt4: Implement RpcCancelThread for the ncacn_ip_tcp protocol sequence.
This commit is contained in:
parent
4dda7c6371
commit
a51486c657
|
@ -104,6 +104,7 @@ struct connection_ops {
|
|||
int (*read)(RpcConnection *conn, void *buffer, unsigned int len);
|
||||
int (*write)(RpcConnection *conn, const void *buffer, unsigned int len);
|
||||
int (*close)(RpcConnection *conn);
|
||||
void (*cancel_call)(RpcConnection *conn);
|
||||
size_t (*get_top_of_tower)(unsigned char *tower_data, const char *networkaddr, const char *endpoint);
|
||||
RPC_STATUS (*parse_top_of_tower)(const unsigned char *tower_data, size_t tower_size, char **networkaddr, char **endpoint);
|
||||
};
|
||||
|
@ -190,6 +191,11 @@ static inline int rpcrt4_conn_close(RpcConnection *Connection)
|
|||
return Connection->ops->close(Connection);
|
||||
}
|
||||
|
||||
static inline void rpcrt4_conn_cancel_call(RpcConnection *Connection)
|
||||
{
|
||||
Connection->ops->cancel_call(Connection);
|
||||
}
|
||||
|
||||
static inline RPC_STATUS rpcrt4_conn_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
|
||||
{
|
||||
return old_conn->ops->handoff(old_conn, new_conn);
|
||||
|
@ -199,4 +205,6 @@ static inline RPC_STATUS rpcrt4_conn_handoff(RpcConnection *old_conn, RpcConnect
|
|||
RPC_STATUS RpcTransport_GetTopOfTower(unsigned char *tower_data, size_t *tower_size, const char *protseq, const char *networkaddr, const char *endpoint);
|
||||
RPC_STATUS RpcTransport_ParseTopOfTower(const unsigned char *tower_data, size_t tower_size, char **protseq, char **networkaddr, char **endpoint);
|
||||
|
||||
void RPCRT4_SetThreadCurrentConnection(RpcConnection *Connection);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -449,6 +449,8 @@ static RPC_STATUS RPCRT4_SendAuth(RpcConnection *Connection, RpcPktHdr *Header,
|
|||
LONG alen;
|
||||
RPC_STATUS status;
|
||||
|
||||
RPCRT4_SetThreadCurrentConnection(Connection);
|
||||
|
||||
buffer_pos = Buffer;
|
||||
/* The packet building functions save the packet header size, so we can use it. */
|
||||
hdr_size = Header->common.frag_len;
|
||||
|
@ -518,6 +520,7 @@ static RPC_STATUS RPCRT4_SendAuth(RpcConnection *Connection, RpcPktHdr *Header,
|
|||
if (status != RPC_S_OK)
|
||||
{
|
||||
HeapFree(GetProcessHeap(), 0, pkt);
|
||||
RPCRT4_SetThreadCurrentConnection(NULL);
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
@ -528,6 +531,7 @@ write:
|
|||
HeapFree(GetProcessHeap(), 0, pkt);
|
||||
if (count<0) {
|
||||
WARN("rpcrt4_conn_write failed (auth)\n");
|
||||
RPCRT4_SetThreadCurrentConnection(NULL);
|
||||
return RPC_S_PROTOCOL_ERROR;
|
||||
}
|
||||
|
||||
|
@ -536,6 +540,7 @@ write:
|
|||
Header->common.flags &= ~RPC_FLG_FIRST;
|
||||
}
|
||||
|
||||
RPCRT4_SetThreadCurrentConnection(NULL);
|
||||
return RPC_S_OK;
|
||||
}
|
||||
|
||||
|
@ -697,6 +702,8 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
|
|||
|
||||
TRACE("(%p, %p, %p)\n", Connection, Header, pMsg);
|
||||
|
||||
RPCRT4_SetThreadCurrentConnection(Connection);
|
||||
|
||||
/* read packet common header */
|
||||
dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr));
|
||||
if (dwRead != sizeof(common_hdr)) {
|
||||
|
@ -872,6 +879,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
|
|||
status = RPC_S_OK;
|
||||
|
||||
fail:
|
||||
RPCRT4_SetThreadCurrentConnection(NULL);
|
||||
if (status != RPC_S_OK) {
|
||||
RPCRT4_FreeHeader(*Header);
|
||||
*Header = NULL;
|
||||
|
|
|
@ -416,6 +416,11 @@ static int rpcrt4_conn_np_close(RpcConnection *Connection)
|
|||
return 0;
|
||||
}
|
||||
|
||||
static void rpcrt4_conn_np_cancel_call(RpcConnection *Connection)
|
||||
{
|
||||
/* FIXME: implement when named pipe writes use overlapped I/O */
|
||||
}
|
||||
|
||||
static size_t rpcrt4_ncacn_np_get_top_of_tower(unsigned char *tower_data,
|
||||
const char *networkaddr,
|
||||
const char *endpoint)
|
||||
|
@ -703,6 +708,7 @@ typedef struct _RpcConnection_tcp
|
|||
{
|
||||
RpcConnection common;
|
||||
int sock;
|
||||
int cancel_fds[2];
|
||||
} RpcConnection_tcp;
|
||||
|
||||
static RpcConnection *rpcrt4_conn_tcp_alloc(void)
|
||||
|
@ -712,6 +718,12 @@ static RpcConnection *rpcrt4_conn_tcp_alloc(void)
|
|||
if (tcpc == NULL)
|
||||
return NULL;
|
||||
tcpc->sock = -1;
|
||||
if (socketpair(PF_UNIX, SOCK_STREAM, 0, tcpc->cancel_fds) < 0)
|
||||
{
|
||||
ERR("socketpair() failed: %s\n", strerror(errno));
|
||||
HeapFree(GetProcessHeap(), 0, tcpc);
|
||||
return NULL;
|
||||
}
|
||||
return &tcpc->common;
|
||||
}
|
||||
|
||||
|
@ -777,6 +789,7 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
|
|||
/* RPC depends on having minimal latency so disable the Nagle algorithm */
|
||||
val = 1;
|
||||
setsockopt(sock, SOL_TCP, TCP_NODELAY, &val, sizeof(val));
|
||||
fcntl(sock, F_SETFL, O_NONBLOCK); /* make socket nonblocking */
|
||||
|
||||
tcpc->sock = sock;
|
||||
|
||||
|
@ -942,18 +955,64 @@ static int rpcrt4_conn_tcp_read(RpcConnection *Connection,
|
|||
void *buffer, unsigned int count)
|
||||
{
|
||||
RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
|
||||
int r = recv(tcpc->sock, buffer, count, MSG_WAITALL);
|
||||
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r);
|
||||
return r;
|
||||
int bytes_read = 0;
|
||||
do
|
||||
{
|
||||
int r = recv(tcpc->sock, (char *)buffer + bytes_read, count - bytes_read, 0);
|
||||
if (r >= 0)
|
||||
bytes_read += r;
|
||||
else if (errno != EAGAIN)
|
||||
return -1;
|
||||
else
|
||||
{
|
||||
struct pollfd pfds[2];
|
||||
pfds[0].fd = tcpc->sock;
|
||||
pfds[0].events = POLLIN;
|
||||
pfds[1].fd = tcpc->cancel_fds[0];
|
||||
pfds[1].fd = POLLIN;
|
||||
if (poll(pfds, 2, -1 /* infinite */) == -1 && errno != EINTR)
|
||||
{
|
||||
ERR("poll() failed: %s\n", strerror(errno));
|
||||
return -1;
|
||||
}
|
||||
if (pfds[1].revents & POLLIN) /* canceled */
|
||||
{
|
||||
char dummy;
|
||||
read(pfds[1].fd, &dummy, sizeof(dummy));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} while (bytes_read != count);
|
||||
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, bytes_read);
|
||||
return bytes_read;
|
||||
}
|
||||
|
||||
static int rpcrt4_conn_tcp_write(RpcConnection *Connection,
|
||||
const void *buffer, unsigned int count)
|
||||
{
|
||||
RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
|
||||
int r = write(tcpc->sock, buffer, count);
|
||||
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, r);
|
||||
return r;
|
||||
int bytes_written = 0;
|
||||
do
|
||||
{
|
||||
int r = write(tcpc->sock, (const char *)buffer + bytes_written, count - bytes_written);
|
||||
if (r >= 0)
|
||||
bytes_written += r;
|
||||
else if (errno != EAGAIN)
|
||||
return -1;
|
||||
else
|
||||
{
|
||||
struct pollfd pfd;
|
||||
pfd.fd = tcpc->sock;
|
||||
pfd.events = POLLOUT;
|
||||
if (poll(&pfd, 1, -1 /* infinite */) == -1 && errno != EINTR)
|
||||
{
|
||||
ERR("poll() failed: %s\n", strerror(errno));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} while (bytes_written != count);
|
||||
TRACE("%d %p %u -> %d\n", tcpc->sock, buffer, count, bytes_written);
|
||||
return bytes_written;
|
||||
}
|
||||
|
||||
static int rpcrt4_conn_tcp_close(RpcConnection *Connection)
|
||||
|
@ -965,9 +1024,21 @@ static int rpcrt4_conn_tcp_close(RpcConnection *Connection)
|
|||
if (tcpc->sock != -1)
|
||||
close(tcpc->sock);
|
||||
tcpc->sock = -1;
|
||||
close(tcpc->cancel_fds[0]);
|
||||
close(tcpc->cancel_fds[1]);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void rpcrt4_conn_tcp_cancel_call(RpcConnection *Connection)
|
||||
{
|
||||
RpcConnection_tcp *tcpc = (RpcConnection_tcp *) Connection;
|
||||
char dummy = 1;
|
||||
|
||||
TRACE("%p\n", Connection);
|
||||
|
||||
write(tcpc->cancel_fds[1], &dummy, 1);
|
||||
}
|
||||
|
||||
static size_t rpcrt4_ncacn_ip_tcp_get_top_of_tower(unsigned char *tower_data,
|
||||
const char *networkaddr,
|
||||
const char *endpoint)
|
||||
|
@ -1250,6 +1321,7 @@ static const struct connection_ops conn_protseq_list[] = {
|
|||
rpcrt4_conn_np_read,
|
||||
rpcrt4_conn_np_write,
|
||||
rpcrt4_conn_np_close,
|
||||
rpcrt4_conn_np_cancel_call,
|
||||
rpcrt4_ncacn_np_get_top_of_tower,
|
||||
rpcrt4_ncacn_np_parse_top_of_tower,
|
||||
},
|
||||
|
@ -1261,6 +1333,7 @@ static const struct connection_ops conn_protseq_list[] = {
|
|||
rpcrt4_conn_np_read,
|
||||
rpcrt4_conn_np_write,
|
||||
rpcrt4_conn_np_close,
|
||||
rpcrt4_conn_np_cancel_call,
|
||||
rpcrt4_ncalrpc_get_top_of_tower,
|
||||
rpcrt4_ncalrpc_parse_top_of_tower,
|
||||
},
|
||||
|
@ -1272,6 +1345,7 @@ static const struct connection_ops conn_protseq_list[] = {
|
|||
rpcrt4_conn_tcp_read,
|
||||
rpcrt4_conn_tcp_write,
|
||||
rpcrt4_conn_tcp_close,
|
||||
rpcrt4_conn_tcp_cancel_call,
|
||||
rpcrt4_ncacn_ip_tcp_get_top_of_tower,
|
||||
rpcrt4_ncacn_ip_tcp_parse_top_of_tower,
|
||||
}
|
||||
|
|
|
@ -100,6 +100,8 @@
|
|||
#include "winerror.h"
|
||||
#include "winbase.h"
|
||||
#include "winuser.h"
|
||||
#include "winnt.h"
|
||||
#include "winternl.h"
|
||||
#include "iptypes.h"
|
||||
#include "iphlpapi.h"
|
||||
#include "wine/unicode.h"
|
||||
|
@ -133,6 +135,25 @@ static CRITICAL_SECTION_DEBUG critsect_debug =
|
|||
};
|
||||
static CRITICAL_SECTION uuid_cs = { &critsect_debug, -1, 0, 0, 0, 0 };
|
||||
|
||||
static CRITICAL_SECTION threaddata_cs;
|
||||
static CRITICAL_SECTION_DEBUG threaddata_cs_debug =
|
||||
{
|
||||
0, 0, &uuid_cs,
|
||||
{ &threaddata_cs_debug.ProcessLocksList, &threaddata_cs_debug.ProcessLocksList },
|
||||
0, 0, { (DWORD_PTR)(__FILE__ ": threaddata_cs") }
|
||||
};
|
||||
static CRITICAL_SECTION threaddata_cs = { &threaddata_cs_debug, -1, 0, 0, 0, 0 };
|
||||
|
||||
struct list threaddata_list = LIST_INIT(threaddata_list);
|
||||
|
||||
struct threaddata
|
||||
{
|
||||
struct list entry;
|
||||
CRITICAL_SECTION cs;
|
||||
DWORD thread_id;
|
||||
RpcConnection *connection;
|
||||
};
|
||||
|
||||
/***********************************************************************
|
||||
* DllMain
|
||||
*
|
||||
|
@ -148,14 +169,29 @@ static CRITICAL_SECTION uuid_cs = { &critsect_debug, -1, 0, 0, 0, 0 };
|
|||
|
||||
BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved)
|
||||
{
|
||||
struct threaddata *tdata;
|
||||
|
||||
switch (fdwReason) {
|
||||
case DLL_PROCESS_ATTACH:
|
||||
DisableThreadLibraryCalls(hinstDLL);
|
||||
master_mutex = CreateMutexA( NULL, FALSE, RPCSS_MASTER_MUTEX_NAME);
|
||||
if (!master_mutex)
|
||||
ERR("Failed to create master mutex\n");
|
||||
break;
|
||||
|
||||
case DLL_THREAD_DETACH:
|
||||
tdata = NtCurrentTeb()->ReservedForNtRpc;
|
||||
if (tdata)
|
||||
{
|
||||
EnterCriticalSection(&threaddata_cs);
|
||||
list_remove(&tdata->entry);
|
||||
LeaveCriticalSection(&threaddata_cs);
|
||||
|
||||
DeleteCriticalSection(&tdata->cs);
|
||||
if (tdata->connection)
|
||||
ERR("tdata->connection should be NULL but is still set to %p\n", tdata);
|
||||
HeapFree(GetProcessHeap(), 0, tdata);
|
||||
}
|
||||
|
||||
case DLL_PROCESS_DETACH:
|
||||
CloseHandle(master_mutex);
|
||||
master_mutex = NULL;
|
||||
|
@ -847,11 +883,53 @@ RPC_STATUS RPC_ENTRY RpcMgmtSetCancelTimeout(LONG Timeout)
|
|||
return RPC_S_OK;
|
||||
}
|
||||
|
||||
void RPCRT4_SetThreadCurrentConnection(RpcConnection *Connection)
|
||||
{
|
||||
struct threaddata *tdata = NtCurrentTeb()->ReservedForNtRpc;
|
||||
if (!tdata)
|
||||
{
|
||||
tdata = HeapAlloc(GetProcessHeap(), 0, sizeof(*tdata));
|
||||
if (!tdata) return;
|
||||
|
||||
InitializeCriticalSection(&tdata->cs);
|
||||
tdata->thread_id = GetCurrentThreadId();
|
||||
tdata->connection = Connection;
|
||||
|
||||
EnterCriticalSection(&threaddata_cs);
|
||||
list_add_tail(&threaddata_list, &tdata->entry);
|
||||
LeaveCriticalSection(&threaddata_cs);
|
||||
|
||||
NtCurrentTeb()->ReservedForNtRpc = tdata;
|
||||
return;
|
||||
}
|
||||
|
||||
EnterCriticalSection(&tdata->cs);
|
||||
tdata->connection = Connection;
|
||||
LeaveCriticalSection(&tdata->cs);
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* RpcCancelThread (rpcrt4.@)
|
||||
*/
|
||||
RPC_STATUS RPC_ENTRY RpcCancelThread(HANDLE ThreadHandle)
|
||||
{
|
||||
FIXME("(%p): stub\n", ThreadHandle);
|
||||
DWORD target_tid;
|
||||
struct threaddata *tdata;
|
||||
|
||||
TRACE("(%p)\n", ThreadHandle);
|
||||
|
||||
target_tid = GetThreadId(ThreadHandle);
|
||||
if (!target_tid)
|
||||
return RPC_S_INVALID_ARG;
|
||||
|
||||
EnterCriticalSection(&threaddata_cs);
|
||||
LIST_FOR_EACH_ENTRY(tdata, &threaddata_list, struct threaddata, entry)
|
||||
if (tdata->thread_id == target_tid)
|
||||
{
|
||||
rpcrt4_conn_cancel_call(tdata->connection);
|
||||
break;
|
||||
}
|
||||
LeaveCriticalSection(&threaddata_cs);
|
||||
|
||||
return RPC_S_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue