diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h index d7b2495dcfb..c08d5e8be17 100644 --- a/dlls/rpcrt4/rpc_binding.h +++ b/dlls/rpcrt4/rpc_binding.h @@ -63,23 +63,20 @@ typedef struct _RpcAssoc ULONG assoc_group_id; CRITICAL_SECTION cs; - struct list connection_pool; + /* connections available to be used */ + struct list free_connection_pool; } RpcAssoc; struct connection_ops; typedef struct _RpcConnection { - struct _RpcConnection* Next; BOOL server; LPSTR NetworkAddr; LPSTR Endpoint; LPWSTR NetworkOptions; const struct connection_ops *ops; USHORT MaxTransmissionSize; - /* The active interface bound to server. */ - RPC_SYNTAX_IDENTIFIER ActiveInterface; - USHORT NextCallId; /* authentication */ CtxtHandle ctx; @@ -93,6 +90,13 @@ typedef struct _RpcConnection /* client-only */ struct list conn_pool_entry; ULONG assoc_group_id; /* association group returned during binding */ + + /* server-only */ + /* The active interface bound to server. */ + RPC_SYNTAX_IDENTIFIER ActiveInterface; + USHORT NextCallId; + struct _RpcConnection* Next; + struct _RpcBinding *server_binding; } RpcConnection; struct connection_ops { @@ -150,6 +154,7 @@ RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endp RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo, RpcQualityOfService *QOS, RpcConnection **Connection); void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection); ULONG RpcAssoc_Release(RpcAssoc *assoc); +RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, unsigned long assoc_gid, RpcAssoc **assoc_out); RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAuthInfo* AuthInfo, RpcQualityOfService *QOS); RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection); diff --git a/dlls/rpcrt4/rpc_message.c b/dlls/rpcrt4/rpc_message.c index 5b4d30e58a4..1248e4605fa 100644 --- a/dlls/rpcrt4/rpc_message.c +++ b/dlls/rpcrt4/rpc_message.c @@ -241,6 +241,7 @@ RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation, RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, + unsigned long AssocGroupId, LPCSTR ServerAddress, unsigned long Result, unsigned long Reason, @@ -266,6 +267,7 @@ RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, header->common.frag_len = header_size; header->bind_ack.max_tsize = MaxTransmissionSize; header->bind_ack.max_rsize = MaxReceiveSize; + header->bind_ack.assoc_gid = AssocGroupId; server_address = (RpcAddressString*)(&header->bind_ack + 1); server_address->length = strlen(ServerAddress) + 1; strcpy(server_address->string, ServerAddress); diff --git a/dlls/rpcrt4/rpc_message.h b/dlls/rpcrt4/rpc_message.h index 8815d8d2dec..cbf83d30ce0 100644 --- a/dlls/rpcrt4/rpc_message.h +++ b/dlls/rpcrt4/rpc_message.h @@ -30,7 +30,7 @@ RpcPktHdr *RPCRT4_BuildFaultHeader(unsigned long DataRepresentation, RPC_STATUS RpcPktHdr *RPCRT4_BuildResponseHeader(unsigned long DataRepresentation, unsigned long BufferLength); RpcPktHdr *RPCRT4_BuildBindHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, unsigned long AssocGroupId, const RPC_SYNTAX_IDENTIFIER *AbstractId, const RPC_SYNTAX_IDENTIFIER *TransferId); RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation, unsigned char RpcVersion, unsigned char RpcVersionMinor); -RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, LPCSTR ServerAddress, unsigned long Result, unsigned long Reason, const RPC_SYNTAX_IDENTIFIER *TransferId); +RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, unsigned long AssocGroupId, LPCSTR ServerAddress, unsigned long Result, unsigned long Reason, const RPC_SYNTAX_IDENTIFIER *TransferId); VOID RPCRT4_FreeHeader(RpcPktHdr *Header); RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header, void *Buffer, unsigned int BufferLength); RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, PRPC_MESSAGE pMsg); diff --git a/dlls/rpcrt4/rpc_server.c b/dlls/rpcrt4/rpc_server.c index b8fc580b59f..8576f5c84a3 100644 --- a/dlls/rpcrt4/rpc_server.c +++ b/dlls/rpcrt4/rpc_server.c @@ -172,17 +172,29 @@ static void RPCRT4_process_packet(RpcConnection* conn, RpcPktHdr* hdr, RPC_MESSA RPC_STATUS status; BOOL exception; + msg->Handle = (RPC_BINDING_HANDLE)conn->server_binding; + switch (hdr->common.ptype) { case PKT_BIND: TRACE("got bind packet\n"); /* FIXME: do more checks! */ if (hdr->bind.max_tsize < RPC_MIN_PACKET_SIZE || - !UuidIsNil(&conn->ActiveInterface.SyntaxGUID, &status)) { + !UuidIsNil(&conn->ActiveInterface.SyntaxGUID, &status) || + conn->server_binding) { TRACE("packet size less than min size, or active interface syntax guid non-null\n"); sif = NULL; } else { - sif = RPCRT4_find_interface(NULL, &hdr->bind.abstract, FALSE); + /* create temporary binding */ + if (RPCRT4_MakeBinding(&conn->server_binding, conn) == RPC_S_OK && + RpcServerAssoc_GetAssociation(rpcrt4_conn_get_name(conn), + conn->NetworkAddr, conn->Endpoint, + conn->NetworkOptions, + hdr->bind.assoc_gid, + &conn->server_binding->Assoc) == RPC_S_OK) + sif = RPCRT4_find_interface(NULL, &hdr->bind.abstract, FALSE); + else + sif = NULL; } if (sif == NULL) { TRACE("rejecting bind request on connection %p\n", conn); @@ -197,6 +209,7 @@ static void RPCRT4_process_packet(RpcConnection* conn, RpcPktHdr* hdr, RPC_MESSA response = RPCRT4_BuildBindAckHeader(NDR_LOCAL_DATA_REPRESENTATION, RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE, + conn->server_binding->Assoc->assoc_group_id, conn->Endpoint, RESULT_ACCEPT, REASON_NONE, &sif->If->TransferSyntax); @@ -318,7 +331,6 @@ fail: if (msg->Buffer == buf) msg->Buffer = NULL; TRACE("freeing Buffer=%p\n", buf); HeapFree(GetProcessHeap(), 0, buf); - RPCRT4_DestroyBinding(msg->Handle); msg->Handle = 0; I_RpcFreeBuffer(msg); msg->Buffer = NULL; @@ -338,7 +350,6 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg) { RpcConnection* conn = (RpcConnection*)the_arg; RpcPktHdr *hdr; - RpcBinding *pbind; RPC_MESSAGE *msg; RPC_STATUS status; RpcPacket *packet; @@ -348,11 +359,6 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg) for (;;) { msg = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RPC_MESSAGE)); - /* create temporary binding for dispatch, it will be freed in - * RPCRT4_process_packet */ - RPCRT4_MakeBinding(&pbind, conn); - msg->Handle = (RPC_BINDING_HANDLE)pbind; - status = RPCRT4_Receive(conn, &hdr, msg); if (status != RPC_S_OK) { WARN("receive failed with error %lx\n", status); diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c index 1a78a9d6bbd..da2db6049fa 100644 --- a/dlls/rpcrt4/rpc_transport.c +++ b/dlls/rpcrt4/rpc_transport.c @@ -88,7 +88,10 @@ static CRITICAL_SECTION_DEBUG assoc_list_cs_debug = }; static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 }; -static struct list assoc_list = LIST_INIT(assoc_list); +static struct list client_assoc_list = LIST_INIT(client_assoc_list); +static struct list server_assoc_list = LIST_INIT(server_assoc_list); + +static LONG last_assoc_group_id; /**** ncacn_np support ****/ @@ -1454,6 +1457,7 @@ RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, NewConnection = ops->alloc(); NewConnection->Next = NULL; + NewConnection->server_binding = NULL; NewConnection->server = server; NewConnection->ops = ops; NewConnection->NetworkAddr = RPCRT4_strdupA(NetworkAddr); @@ -1481,14 +1485,36 @@ RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, return RPC_S_OK; } +static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr, + LPCSTR Endpoint, LPCWSTR NetworkOptions, + RpcAssoc **assoc_out) +{ + RpcAssoc *assoc; + assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc)); + if (!assoc) + return RPC_S_OUT_OF_RESOURCES; + assoc->refs = 1; + list_init(&assoc->free_connection_pool); + InitializeCriticalSection(&assoc->cs); + assoc->Protseq = RPCRT4_strdupA(Protseq); + assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr); + assoc->Endpoint = RPCRT4_strdupA(Endpoint); + assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL; + assoc->assoc_group_id = 0; + list_init(&assoc->entry); + *assoc_out = assoc; + return RPC_S_OK; +} + RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc_out) { RpcAssoc *assoc; + RPC_STATUS status; EnterCriticalSection(&assoc_list_cs); - LIST_FOR_EACH_ENTRY(assoc, &assoc_list, RpcAssoc, entry) + LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry) { if (!strcmp(Protseq, assoc->Protseq) && !strcmp(NetworkAddr, assoc->NetworkAddr) && @@ -1503,21 +1529,62 @@ RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, } } - assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc)); - if (!assoc) + status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc); + if (status != RPC_S_OK) { - LeaveCriticalSection(&assoc_list_cs); - return RPC_S_OUT_OF_RESOURCES; + LeaveCriticalSection(&assoc_list_cs); + return status; } - assoc->refs = 1; - list_init(&assoc->connection_pool); - InitializeCriticalSection(&assoc->cs); - assoc->Protseq = RPCRT4_strdupA(Protseq); - assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr); - assoc->Endpoint = RPCRT4_strdupA(Endpoint); - assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL; - assoc->assoc_group_id = 0; - list_add_head(&assoc_list, &assoc->entry); + list_add_head(&client_assoc_list, &assoc->entry); + *assoc_out = assoc; + + LeaveCriticalSection(&assoc_list_cs); + + TRACE("new assoc %p\n", assoc); + + return RPC_S_OK; +} + +RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, + LPCSTR Endpoint, LPCWSTR NetworkOptions, + unsigned long assoc_gid, + RpcAssoc **assoc_out) +{ + RpcAssoc *assoc; + RPC_STATUS status; + + EnterCriticalSection(&assoc_list_cs); + if (assoc_gid) + { + LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry) + { + /* FIXME: NetworkAddr shouldn't be NULL */ + if (assoc->assoc_group_id == assoc_gid && + !strcmp(Protseq, assoc->Protseq) && + (!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) && + !strcmp(Endpoint, assoc->Endpoint) && + ((!assoc->NetworkOptions == !NetworkOptions) && + (!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions)))) + { + assoc->refs++; + *assoc_out = assoc; + LeaveCriticalSection(&assoc_list_cs); + TRACE("using existing assoc %p\n", assoc); + return RPC_S_OK; + } + } + *assoc_out = NULL; + return RPC_S_NO_CONTEXT_AVAILABLE; + } + + status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc); + if (status != RPC_S_OK) + { + LeaveCriticalSection(&assoc_list_cs); + return status; + } + assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id); + list_add_head(&server_assoc_list, &assoc->entry); *assoc_out = assoc; LeaveCriticalSection(&assoc_list_cs); @@ -1543,7 +1610,7 @@ ULONG RpcAssoc_Release(RpcAssoc *assoc) TRACE("destroying assoc %p\n", assoc); - LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->connection_pool, RpcConnection, conn_pool_entry) + LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry) { list_remove(&Connection->conn_pool_entry); RPCRT4_DestroyConnection(Connection); @@ -1695,7 +1762,7 @@ static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc, RpcConnection *Connection; EnterCriticalSection(&assoc->cs); /* try to find a compatible connection from the connection pool */ - LIST_FOR_EACH_ENTRY(Connection, &assoc->connection_pool, RpcConnection, conn_pool_entry) + LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry) { if (!memcmp(&Connection->ActiveInterface, InterfaceId, sizeof(RPC_SYNTAX_IDENTIFIER)) && @@ -1757,7 +1824,7 @@ void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection) assert(!Connection->server); EnterCriticalSection(&assoc->cs); if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id; - list_add_head(&assoc->connection_pool, &Connection->conn_pool_entry); + list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry); LeaveCriticalSection(&assoc->cs); } @@ -1786,6 +1853,10 @@ RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection) HeapFree(GetProcessHeap(), 0, Connection->NetworkOptions); if (Connection->AuthInfo) RpcAuthInfo_Release(Connection->AuthInfo); if (Connection->QOS) RpcQualityOfService_Release(Connection->QOS); + + /* server-only */ + if (Connection->server_binding) RPCRT4_DestroyBinding(Connection->server_binding); + HeapFree(GetProcessHeap(), 0, Connection); return RPC_S_OK; }