rpcrt4: Make a server association when a bind packet is received in the server.

Support handing out association group IDs.
This commit is contained in:
Rob Shearman 2007-12-16 12:07:30 +00:00 committed by Alexandre Julliard
parent 2bda19c6b0
commit 22f530c835
5 changed files with 117 additions and 33 deletions

View File

@ -63,23 +63,20 @@ typedef struct _RpcAssoc
ULONG assoc_group_id; ULONG assoc_group_id;
CRITICAL_SECTION cs; CRITICAL_SECTION cs;
struct list connection_pool; /* connections available to be used */
struct list free_connection_pool;
} RpcAssoc; } RpcAssoc;
struct connection_ops; struct connection_ops;
typedef struct _RpcConnection typedef struct _RpcConnection
{ {
struct _RpcConnection* Next;
BOOL server; BOOL server;
LPSTR NetworkAddr; LPSTR NetworkAddr;
LPSTR Endpoint; LPSTR Endpoint;
LPWSTR NetworkOptions; LPWSTR NetworkOptions;
const struct connection_ops *ops; const struct connection_ops *ops;
USHORT MaxTransmissionSize; USHORT MaxTransmissionSize;
/* The active interface bound to server. */
RPC_SYNTAX_IDENTIFIER ActiveInterface;
USHORT NextCallId;
/* authentication */ /* authentication */
CtxtHandle ctx; CtxtHandle ctx;
@ -93,6 +90,13 @@ typedef struct _RpcConnection
/* client-only */ /* client-only */
struct list conn_pool_entry; struct list conn_pool_entry;
ULONG assoc_group_id; /* association group returned during binding */ ULONG assoc_group_id; /* association group returned during binding */
/* server-only */
/* The active interface bound to server. */
RPC_SYNTAX_IDENTIFIER ActiveInterface;
USHORT NextCallId;
struct _RpcConnection* Next;
struct _RpcBinding *server_binding;
} RpcConnection; } RpcConnection;
struct connection_ops { struct connection_ops {
@ -150,6 +154,7 @@ RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endp
RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo, RpcQualityOfService *QOS, RpcConnection **Connection); RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, RpcAuthInfo *AuthInfo, RpcQualityOfService *QOS, RpcConnection **Connection);
void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection); void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection);
ULONG RpcAssoc_Release(RpcAssoc *assoc); ULONG RpcAssoc_Release(RpcAssoc *assoc);
RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, unsigned long assoc_gid, RpcAssoc **assoc_out);
RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAuthInfo* AuthInfo, RpcQualityOfService *QOS); RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAuthInfo* AuthInfo, RpcQualityOfService *QOS);
RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection); RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection);

View File

@ -241,6 +241,7 @@ RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation,
RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation,
unsigned short MaxTransmissionSize, unsigned short MaxTransmissionSize,
unsigned short MaxReceiveSize, unsigned short MaxReceiveSize,
unsigned long AssocGroupId,
LPCSTR ServerAddress, LPCSTR ServerAddress,
unsigned long Result, unsigned long Result,
unsigned long Reason, unsigned long Reason,
@ -266,6 +267,7 @@ RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation,
header->common.frag_len = header_size; header->common.frag_len = header_size;
header->bind_ack.max_tsize = MaxTransmissionSize; header->bind_ack.max_tsize = MaxTransmissionSize;
header->bind_ack.max_rsize = MaxReceiveSize; header->bind_ack.max_rsize = MaxReceiveSize;
header->bind_ack.assoc_gid = AssocGroupId;
server_address = (RpcAddressString*)(&header->bind_ack + 1); server_address = (RpcAddressString*)(&header->bind_ack + 1);
server_address->length = strlen(ServerAddress) + 1; server_address->length = strlen(ServerAddress) + 1;
strcpy(server_address->string, ServerAddress); strcpy(server_address->string, ServerAddress);

View File

@ -30,7 +30,7 @@ RpcPktHdr *RPCRT4_BuildFaultHeader(unsigned long DataRepresentation, RPC_STATUS
RpcPktHdr *RPCRT4_BuildResponseHeader(unsigned long DataRepresentation, unsigned long BufferLength); RpcPktHdr *RPCRT4_BuildResponseHeader(unsigned long DataRepresentation, unsigned long BufferLength);
RpcPktHdr *RPCRT4_BuildBindHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, unsigned long AssocGroupId, const RPC_SYNTAX_IDENTIFIER *AbstractId, const RPC_SYNTAX_IDENTIFIER *TransferId); RpcPktHdr *RPCRT4_BuildBindHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, unsigned long AssocGroupId, const RPC_SYNTAX_IDENTIFIER *AbstractId, const RPC_SYNTAX_IDENTIFIER *TransferId);
RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation, unsigned char RpcVersion, unsigned char RpcVersionMinor); RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation, unsigned char RpcVersion, unsigned char RpcVersionMinor);
RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, LPCSTR ServerAddress, unsigned long Result, unsigned long Reason, const RPC_SYNTAX_IDENTIFIER *TransferId); RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, unsigned long AssocGroupId, LPCSTR ServerAddress, unsigned long Result, unsigned long Reason, const RPC_SYNTAX_IDENTIFIER *TransferId);
VOID RPCRT4_FreeHeader(RpcPktHdr *Header); VOID RPCRT4_FreeHeader(RpcPktHdr *Header);
RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header, void *Buffer, unsigned int BufferLength); RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header, void *Buffer, unsigned int BufferLength);
RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, PRPC_MESSAGE pMsg); RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, PRPC_MESSAGE pMsg);

View File

@ -172,17 +172,29 @@ static void RPCRT4_process_packet(RpcConnection* conn, RpcPktHdr* hdr, RPC_MESSA
RPC_STATUS status; RPC_STATUS status;
BOOL exception; BOOL exception;
msg->Handle = (RPC_BINDING_HANDLE)conn->server_binding;
switch (hdr->common.ptype) { switch (hdr->common.ptype) {
case PKT_BIND: case PKT_BIND:
TRACE("got bind packet\n"); TRACE("got bind packet\n");
/* FIXME: do more checks! */ /* FIXME: do more checks! */
if (hdr->bind.max_tsize < RPC_MIN_PACKET_SIZE || if (hdr->bind.max_tsize < RPC_MIN_PACKET_SIZE ||
!UuidIsNil(&conn->ActiveInterface.SyntaxGUID, &status)) { !UuidIsNil(&conn->ActiveInterface.SyntaxGUID, &status) ||
conn->server_binding) {
TRACE("packet size less than min size, or active interface syntax guid non-null\n"); TRACE("packet size less than min size, or active interface syntax guid non-null\n");
sif = NULL; sif = NULL;
} else { } else {
sif = RPCRT4_find_interface(NULL, &hdr->bind.abstract, FALSE); /* create temporary binding */
if (RPCRT4_MakeBinding(&conn->server_binding, conn) == RPC_S_OK &&
RpcServerAssoc_GetAssociation(rpcrt4_conn_get_name(conn),
conn->NetworkAddr, conn->Endpoint,
conn->NetworkOptions,
hdr->bind.assoc_gid,
&conn->server_binding->Assoc) == RPC_S_OK)
sif = RPCRT4_find_interface(NULL, &hdr->bind.abstract, FALSE);
else
sif = NULL;
} }
if (sif == NULL) { if (sif == NULL) {
TRACE("rejecting bind request on connection %p\n", conn); TRACE("rejecting bind request on connection %p\n", conn);
@ -197,6 +209,7 @@ static void RPCRT4_process_packet(RpcConnection* conn, RpcPktHdr* hdr, RPC_MESSA
response = RPCRT4_BuildBindAckHeader(NDR_LOCAL_DATA_REPRESENTATION, response = RPCRT4_BuildBindAckHeader(NDR_LOCAL_DATA_REPRESENTATION,
RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE,
conn->server_binding->Assoc->assoc_group_id,
conn->Endpoint, conn->Endpoint,
RESULT_ACCEPT, REASON_NONE, RESULT_ACCEPT, REASON_NONE,
&sif->If->TransferSyntax); &sif->If->TransferSyntax);
@ -318,7 +331,6 @@ fail:
if (msg->Buffer == buf) msg->Buffer = NULL; if (msg->Buffer == buf) msg->Buffer = NULL;
TRACE("freeing Buffer=%p\n", buf); TRACE("freeing Buffer=%p\n", buf);
HeapFree(GetProcessHeap(), 0, buf); HeapFree(GetProcessHeap(), 0, buf);
RPCRT4_DestroyBinding(msg->Handle);
msg->Handle = 0; msg->Handle = 0;
I_RpcFreeBuffer(msg); I_RpcFreeBuffer(msg);
msg->Buffer = NULL; msg->Buffer = NULL;
@ -338,7 +350,6 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg)
{ {
RpcConnection* conn = (RpcConnection*)the_arg; RpcConnection* conn = (RpcConnection*)the_arg;
RpcPktHdr *hdr; RpcPktHdr *hdr;
RpcBinding *pbind;
RPC_MESSAGE *msg; RPC_MESSAGE *msg;
RPC_STATUS status; RPC_STATUS status;
RpcPacket *packet; RpcPacket *packet;
@ -348,11 +359,6 @@ static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg)
for (;;) { for (;;) {
msg = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RPC_MESSAGE)); msg = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RPC_MESSAGE));
/* create temporary binding for dispatch, it will be freed in
* RPCRT4_process_packet */
RPCRT4_MakeBinding(&pbind, conn);
msg->Handle = (RPC_BINDING_HANDLE)pbind;
status = RPCRT4_Receive(conn, &hdr, msg); status = RPCRT4_Receive(conn, &hdr, msg);
if (status != RPC_S_OK) { if (status != RPC_S_OK) {
WARN("receive failed with error %lx\n", status); WARN("receive failed with error %lx\n", status);

View File

@ -88,7 +88,10 @@ static CRITICAL_SECTION_DEBUG assoc_list_cs_debug =
}; };
static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 }; static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 };
static struct list assoc_list = LIST_INIT(assoc_list); static struct list client_assoc_list = LIST_INIT(client_assoc_list);
static struct list server_assoc_list = LIST_INIT(server_assoc_list);
static LONG last_assoc_group_id;
/**** ncacn_np support ****/ /**** ncacn_np support ****/
@ -1454,6 +1457,7 @@ RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server,
NewConnection = ops->alloc(); NewConnection = ops->alloc();
NewConnection->Next = NULL; NewConnection->Next = NULL;
NewConnection->server_binding = NULL;
NewConnection->server = server; NewConnection->server = server;
NewConnection->ops = ops; NewConnection->ops = ops;
NewConnection->NetworkAddr = RPCRT4_strdupA(NetworkAddr); NewConnection->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
@ -1481,14 +1485,36 @@ RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server,
return RPC_S_OK; return RPC_S_OK;
} }
static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr,
LPCSTR Endpoint, LPCWSTR NetworkOptions,
RpcAssoc **assoc_out)
{
RpcAssoc *assoc;
assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc));
if (!assoc)
return RPC_S_OUT_OF_RESOURCES;
assoc->refs = 1;
list_init(&assoc->free_connection_pool);
InitializeCriticalSection(&assoc->cs);
assoc->Protseq = RPCRT4_strdupA(Protseq);
assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
assoc->Endpoint = RPCRT4_strdupA(Endpoint);
assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
assoc->assoc_group_id = 0;
list_init(&assoc->entry);
*assoc_out = assoc;
return RPC_S_OK;
}
RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
LPCSTR Endpoint, LPCWSTR NetworkOptions, LPCSTR Endpoint, LPCWSTR NetworkOptions,
RpcAssoc **assoc_out) RpcAssoc **assoc_out)
{ {
RpcAssoc *assoc; RpcAssoc *assoc;
RPC_STATUS status;
EnterCriticalSection(&assoc_list_cs); EnterCriticalSection(&assoc_list_cs);
LIST_FOR_EACH_ENTRY(assoc, &assoc_list, RpcAssoc, entry) LIST_FOR_EACH_ENTRY(assoc, &client_assoc_list, RpcAssoc, entry)
{ {
if (!strcmp(Protseq, assoc->Protseq) && if (!strcmp(Protseq, assoc->Protseq) &&
!strcmp(NetworkAddr, assoc->NetworkAddr) && !strcmp(NetworkAddr, assoc->NetworkAddr) &&
@ -1503,21 +1529,62 @@ RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
} }
} }
assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc)); status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
if (!assoc) if (status != RPC_S_OK)
{ {
LeaveCriticalSection(&assoc_list_cs); LeaveCriticalSection(&assoc_list_cs);
return RPC_S_OUT_OF_RESOURCES; return status;
} }
assoc->refs = 1; list_add_head(&client_assoc_list, &assoc->entry);
list_init(&assoc->connection_pool); *assoc_out = assoc;
InitializeCriticalSection(&assoc->cs);
assoc->Protseq = RPCRT4_strdupA(Protseq); LeaveCriticalSection(&assoc_list_cs);
assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr);
assoc->Endpoint = RPCRT4_strdupA(Endpoint); TRACE("new assoc %p\n", assoc);
assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL;
assoc->assoc_group_id = 0; return RPC_S_OK;
list_add_head(&assoc_list, &assoc->entry); }
RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr,
LPCSTR Endpoint, LPCWSTR NetworkOptions,
unsigned long assoc_gid,
RpcAssoc **assoc_out)
{
RpcAssoc *assoc;
RPC_STATUS status;
EnterCriticalSection(&assoc_list_cs);
if (assoc_gid)
{
LIST_FOR_EACH_ENTRY(assoc, &server_assoc_list, RpcAssoc, entry)
{
/* FIXME: NetworkAddr shouldn't be NULL */
if (assoc->assoc_group_id == assoc_gid &&
!strcmp(Protseq, assoc->Protseq) &&
(!NetworkAddr || !assoc->NetworkAddr || !strcmp(NetworkAddr, assoc->NetworkAddr)) &&
!strcmp(Endpoint, assoc->Endpoint) &&
((!assoc->NetworkOptions == !NetworkOptions) &&
(!NetworkOptions || !strcmpW(NetworkOptions, assoc->NetworkOptions))))
{
assoc->refs++;
*assoc_out = assoc;
LeaveCriticalSection(&assoc_list_cs);
TRACE("using existing assoc %p\n", assoc);
return RPC_S_OK;
}
}
*assoc_out = NULL;
return RPC_S_NO_CONTEXT_AVAILABLE;
}
status = RpcAssoc_Alloc(Protseq, NetworkAddr, Endpoint, NetworkOptions, &assoc);
if (status != RPC_S_OK)
{
LeaveCriticalSection(&assoc_list_cs);
return status;
}
assoc->assoc_group_id = InterlockedIncrement(&last_assoc_group_id);
list_add_head(&server_assoc_list, &assoc->entry);
*assoc_out = assoc; *assoc_out = assoc;
LeaveCriticalSection(&assoc_list_cs); LeaveCriticalSection(&assoc_list_cs);
@ -1543,7 +1610,7 @@ ULONG RpcAssoc_Release(RpcAssoc *assoc)
TRACE("destroying assoc %p\n", assoc); TRACE("destroying assoc %p\n", assoc);
LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->connection_pool, RpcConnection, conn_pool_entry) LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
{ {
list_remove(&Connection->conn_pool_entry); list_remove(&Connection->conn_pool_entry);
RPCRT4_DestroyConnection(Connection); RPCRT4_DestroyConnection(Connection);
@ -1695,7 +1762,7 @@ static RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
RpcConnection *Connection; RpcConnection *Connection;
EnterCriticalSection(&assoc->cs); EnterCriticalSection(&assoc->cs);
/* try to find a compatible connection from the connection pool */ /* try to find a compatible connection from the connection pool */
LIST_FOR_EACH_ENTRY(Connection, &assoc->connection_pool, RpcConnection, conn_pool_entry) LIST_FOR_EACH_ENTRY(Connection, &assoc->free_connection_pool, RpcConnection, conn_pool_entry)
{ {
if (!memcmp(&Connection->ActiveInterface, InterfaceId, if (!memcmp(&Connection->ActiveInterface, InterfaceId,
sizeof(RPC_SYNTAX_IDENTIFIER)) && sizeof(RPC_SYNTAX_IDENTIFIER)) &&
@ -1757,7 +1824,7 @@ void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection)
assert(!Connection->server); assert(!Connection->server);
EnterCriticalSection(&assoc->cs); EnterCriticalSection(&assoc->cs);
if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id; if (!assoc->assoc_group_id) assoc->assoc_group_id = Connection->assoc_group_id;
list_add_head(&assoc->connection_pool, &Connection->conn_pool_entry); list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry);
LeaveCriticalSection(&assoc->cs); LeaveCriticalSection(&assoc->cs);
} }
@ -1786,6 +1853,10 @@ RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection)
HeapFree(GetProcessHeap(), 0, Connection->NetworkOptions); HeapFree(GetProcessHeap(), 0, Connection->NetworkOptions);
if (Connection->AuthInfo) RpcAuthInfo_Release(Connection->AuthInfo); if (Connection->AuthInfo) RpcAuthInfo_Release(Connection->AuthInfo);
if (Connection->QOS) RpcQualityOfService_Release(Connection->QOS); if (Connection->QOS) RpcQualityOfService_Release(Connection->QOS);
/* server-only */
if (Connection->server_binding) RPCRT4_DestroyBinding(Connection->server_binding);
HeapFree(GetProcessHeap(), 0, Connection); HeapFree(GetProcessHeap(), 0, Connection);
return RPC_S_OK; return RPC_S_OK;
} }