From c5580b0355513ca5566d36e58401d6c1c4695312 Mon Sep 17 00:00:00 2001 From: Filip Navara Date: Mon, 26 Apr 2004 23:33:39 +0000 Subject: [PATCH] Make RPCRT4 use Windows compatible protocol (DCE v5.0) for communication. --- dlls/rpcrt4/rpc_binding.c | 124 +++++++- dlls/rpcrt4/rpc_binding.h | 6 +- dlls/rpcrt4/rpc_defs.h | 140 +++++++-- dlls/rpcrt4/rpc_message.c | 593 +++++++++++++++++++++++++++++--------- dlls/rpcrt4/rpc_message.h | 38 +++ dlls/rpcrt4/rpc_server.c | 217 ++++++++------ 6 files changed, 853 insertions(+), 265 deletions(-) create mode 100644 dlls/rpcrt4/rpc_message.h diff --git a/dlls/rpcrt4/rpc_binding.c b/dlls/rpcrt4/rpc_binding.c index 857c921ec16..9764f204c4c 100644 --- a/dlls/rpcrt4/rpc_binding.c +++ b/dlls/rpcrt4/rpc_binding.c @@ -3,6 +3,7 @@ * * Copyright 2001 Ove Kåven, TransGaming Technologies * Copyright 2003 Mike Hearn + * Copyright 2004 Filip Navara * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -36,10 +37,12 @@ #include "wine/unicode.h" #include "rpc.h" +#include "rpcndr.h" #include "wine/debug.h" #include "rpc_binding.h" +#include "rpc_message.h" WINE_DEFAULT_DEBUG_CHANNEL(ole); @@ -117,6 +120,7 @@ RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPST NewConnection->NetworkAddr = RPCRT4_strdupA(NetworkAddr); NewConnection->Endpoint = RPCRT4_strdupA(Endpoint); NewConnection->Used = Binding; + NewConnection->MaxTransmissionSize = RPC_MAX_PACKET_SIZE; EnterCriticalSection(&conn_cache_cs); NewConnection->Next = conn_cache; @@ -206,8 +210,9 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) pname = HeapAlloc(GetProcessHeap(), 0, strlen(prefix) + strlen(Connection->Endpoint) + 1); strcat(strcpy(pname, prefix), Connection->Endpoint); TRACE("listening on %s\n", pname); - Connection->conn = CreateNamedPipeA(pname, PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, - 0, PIPE_UNLIMITED_INSTANCES, 0, 0, 5000, NULL); + Connection->conn = CreateNamedPipeA(pname, PROFILE_SERVER | PIPE_ACCESS_DUPLEX, + PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE, PIPE_UNLIMITED_INSTANCES, + RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE, 5000, NULL); HeapFree(GetProcessHeap(), 0, pname); memset(&Connection->ovl, 0, sizeof(Connection->ovl)); Connection->ovl.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); @@ -216,6 +221,8 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) if (GetLastError() == ERROR_PIPE_CONNECTED) { SetEvent(Connection->ovl.hEvent); return RPC_S_OK; + } else if (GetLastError() == ERROR_IO_PENDING) { + return RPC_S_OK; } return RPC_S_SERVER_UNAVAILABLE; } @@ -227,8 +234,9 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) pname = HeapAlloc(GetProcessHeap(), 0, strlen(prefix) + strlen(Connection->Endpoint) + 1); strcat(strcpy(pname, prefix), Connection->Endpoint); TRACE("listening on %s\n", pname); - Connection->conn = CreateNamedPipeA(pname, PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, - 0, PIPE_UNLIMITED_INSTANCES, 0, 0, 5000, NULL); + Connection->conn = CreateNamedPipeA(pname, PROFILE_SERVER | PIPE_ACCESS_DUPLEX, + PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, + RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE, 5000, NULL); HeapFree(GetProcessHeap(), 0, pname); memset(&Connection->ovl, 0, sizeof(Connection->ovl)); Connection->ovl.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); @@ -254,6 +262,7 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) LPSTR pname; HANDLE conn; DWORD err; + DWORD dwMode; pname = HeapAlloc(GetProcessHeap(), 0, strlen(prefix) + strlen(Connection->Endpoint) + 1); strcat(strcpy(pname, prefix), Connection->Endpoint); @@ -261,7 +270,7 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) while (TRUE) { if (WaitNamedPipeA(pname, NMPWAIT_WAIT_FOREVER)) { conn = CreateFileA(pname, GENERIC_READ|GENERIC_WRITE, 0, NULL, - OPEN_EXISTING, FILE_FLAG_OVERLAPPED, 0); + OPEN_EXISTING, 0, 0); if (conn != INVALID_HANDLE_VALUE) break; err = GetLastError(); if (err == ERROR_PIPE_BUSY) continue; @@ -279,6 +288,9 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) /* success */ HeapFree(GetProcessHeap(), 0, pname); memset(&Connection->ovl, 0, sizeof(Connection->ovl)); + /* pipe is connected; change to message-read mode. */ + dwMode = PIPE_READMODE_MESSAGE; + SetNamedPipeHandleState(conn, &dwMode, NULL, NULL); Connection->ovl.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); Connection->conn = conn; } @@ -288,12 +300,13 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) LPSTR pname; HANDLE conn; DWORD err; + DWORD dwMode; pname = HeapAlloc(GetProcessHeap(), 0, strlen(prefix) + strlen(Connection->Endpoint) + 1); strcat(strcpy(pname, prefix), Connection->Endpoint); TRACE("connecting to %s\n", pname); conn = CreateFileA(pname, GENERIC_READ|GENERIC_WRITE, 0, NULL, - OPEN_EXISTING, FILE_FLAG_OVERLAPPED, 0); + OPEN_EXISTING, 0, 0); if (conn == INVALID_HANDLE_VALUE) { err = GetLastError(); /* we don't need to handle ERROR_PIPE_BUSY here, @@ -309,6 +322,9 @@ RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection) /* success */ HeapFree(GetProcessHeap(), 0, pname); memset(&Connection->ovl, 0, sizeof(Connection->ovl)); + /* pipe is connected; change to message-read mode. */ + dwMode = PIPE_READMODE_MESSAGE; + SetNamedPipeHandleState(conn, &dwMode, NULL, NULL); Connection->ovl.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); Connection->conn = conn; } else { @@ -484,18 +500,100 @@ RPC_STATUS RPCRT4_DestroyBinding(RpcBinding* Binding) return RPC_S_OK; } -RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection) +RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection, + PRPC_SYNTAX_IDENTIFIER TransferSyntax, + PRPC_SYNTAX_IDENTIFIER InterfaceId) { RpcConnection* NewConnection; - TRACE("(Binding == ^%p)\n", Binding); - if (Binding->FromConn) { - *Connection = Binding->FromConn; - return RPC_S_OK; - } + RPC_STATUS status; + TRACE("(Binding == ^%p)\n", Binding); + + /* if we try to bind a new interface and the connection is already opened, + * close the current connection and create a new with the new binding. */ + if (!Binding->server && Binding->FromConn && + memcmp(&Binding->FromConn->ActiveInterface, InterfaceId, + sizeof(RPC_SYNTAX_IDENTIFIER))) { + RPCRT4_ReleaseConnection(Binding->FromConn); + Binding->FromConn = NULL; + } else { + /* we already have an connection with acceptable binding, so use it */ + if (Binding->FromConn) { + *Connection = Binding->FromConn; + return RPC_S_OK; + } + } + + /* create a new connection */ RPCRT4_GetConnection(&NewConnection, Binding->server, Binding->Protseq, Binding->NetworkAddr, Binding->Endpoint, NULL, Binding); *Connection = NewConnection; - return RPCRT4_OpenConnection(NewConnection); + status = RPCRT4_OpenConnection(NewConnection); + if (status != RPC_S_OK) { + return status; + } + + /* we need to send a binding packet if we are client. */ + if (!(*Connection)->server) { + RpcPktHdr *hdr; + DWORD count; + BYTE *response; + RpcPktHdr *response_hdr; + + TRACE("sending bind request to server\n"); + + hdr = RPCRT4_BuildBindHeader(NDR_LOCAL_DATA_REPRESENTATION, + RPC_MAX_PACKET_SIZE, RPC_MAX_PACKET_SIZE, + InterfaceId, TransferSyntax); + + status = RPCRT4_Send(*Connection, hdr, NULL, 0); + if (status != RPC_S_OK) { + RPCRT4_ReleaseConnection(*Connection); + return status; + } + + response = HeapAlloc(GetProcessHeap(), 0, RPC_MAX_PACKET_SIZE); + if (response == NULL) { + WARN("Can't allocate memory for binding response\n"); + RPCRT4_ReleaseConnection(*Connection); + return E_OUTOFMEMORY; + } + + /* get a reply */ + if (!ReadFile(NewConnection->conn, response, RPC_MAX_PACKET_SIZE, &count, NULL)) { + WARN("ReadFile failed with error %ld\n", GetLastError()); + RPCRT4_ReleaseConnection(*Connection); + return RPC_S_PROTOCOL_ERROR; + } + + if (count < sizeof(response_hdr->common)) { + WARN("received invalid header\n"); + RPCRT4_ReleaseConnection(*Connection); + return RPC_S_PROTOCOL_ERROR; + } + + response_hdr = (RpcPktHdr*)response; + + if (response_hdr->common.rpc_ver != RPC_VER_MAJOR || + response_hdr->common.rpc_ver_minor != RPC_VER_MINOR || + response_hdr->common.ptype != PKT_BIND_ACK) { + WARN("invalid protocol version or rejection packet\n"); + RPCRT4_ReleaseConnection(Binding->FromConn); + return RPC_S_PROTOCOL_ERROR; + } + + if (response_hdr->bind_ack.max_tsize < RPC_MIN_PACKET_SIZE) { + WARN("server doesn't allow large enough packets\n"); + RPCRT4_ReleaseConnection(Binding->FromConn); + return RPC_S_PROTOCOL_ERROR; + } + + /* FIXME: do more checks? */ + + (*Connection)->MaxTransmissionSize = response_hdr->bind_ack.max_tsize; + (*Connection)->ActiveInterface = *InterfaceId; + } + + return RPC_S_OK; } RPC_STATUS RPCRT4_CloseBinding(RpcBinding* Binding, RpcConnection* Connection) diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h index 54d514b1baf..c358356dd5f 100644 --- a/dlls/rpcrt4/rpc_binding.h +++ b/dlls/rpcrt4/rpc_binding.h @@ -33,6 +33,9 @@ typedef struct _RpcConnection LPSTR Endpoint; HANDLE conn, thread; OVERLAPPED ovl; + USHORT MaxTransmissionSize; + /* The active interface bound to server. */ + RPC_SYNTAX_IDENTIFIER ActiveInterface; } RpcConnection; /* don't know what MS's structure looks like */ @@ -42,7 +45,6 @@ typedef struct _RpcBinding struct _RpcBinding* Next; BOOL server; UUID ObjectUuid; - UUID ActiveUuid; LPSTR Protseq; LPSTR NetworkAddr; LPSTR Endpoint; @@ -75,7 +77,7 @@ RPC_STATUS RPCRT4_SetBindingObject(RpcBinding* Binding, UUID* ObjectUuid); RPC_STATUS RPCRT4_MakeBinding(RpcBinding** Binding, RpcConnection* Connection); RPC_STATUS RPCRT4_ExportBinding(RpcBinding** Binding, RpcBinding* OldBinding); RPC_STATUS RPCRT4_DestroyBinding(RpcBinding* Binding); -RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection); +RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection, PRPC_SYNTAX_IDENTIFIER TransferSyntax, PRPC_SYNTAX_IDENTIFIER InterfaceId); RPC_STATUS RPCRT4_CloseBinding(RpcBinding* Binding, RpcConnection* Connection); BOOL RPCRT4_RPCSSOnDemandCall(PRPCSS_NP_MESSAGE msg, char *vardata_payload, PRPCSS_NP_REPLY reply); HANDLE RPCRT4_GetMasterMutex(void); diff --git a/dlls/rpcrt4/rpc_defs.h b/dlls/rpcrt4/rpc_defs.h index 3cb688f3bc9..9e9a7359b10 100644 --- a/dlls/rpcrt4/rpc_defs.h +++ b/dlls/rpcrt4/rpc_defs.h @@ -2,6 +2,7 @@ * RPC definitions * * Copyright 2001-2002 Ove Kåven, TransGaming Technologies + * Copyright 2004 Filip Navara * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -25,27 +26,124 @@ typedef struct { - unsigned char rpc_ver; - unsigned char ptype; - unsigned char flags1; - unsigned char flags2; - unsigned char drep[3]; - unsigned char serial_hi; - GUID object; - GUID if_id; - GUID act_id; - unsigned long server_boot; - unsigned long if_vers; - unsigned long seqnum; + unsigned char rpc_ver; /* RPC major version (5) */ + unsigned char rpc_ver_minor; /* RPC minor version (0) */ + unsigned char ptype; /* Packet type (PKT_*) */ + unsigned char flags; + unsigned char drep[4]; /* Data representation */ + unsigned short frag_len; /* Data size in bytes including header and tail. */ + unsigned short auth_len; /* Authentication length */ + unsigned long call_id; /* Call identifier. */ +} RpcPktCommonHdr; + +typedef struct +{ + RpcPktCommonHdr common; + unsigned long alloc_hint; /* Data size in bytes excluding header and tail. */ + unsigned short context_id; /* Presentation context identifier */ unsigned short opnum; - unsigned short ihint; - unsigned short ahint; - unsigned short len; - unsigned short fragnum; - unsigned char auth_proto; - unsigned char serial_lo; +} RpcPktRequestHdr; + +typedef struct +{ + RpcPktCommonHdr common; + unsigned long alloc_hint; /* Data size in bytes excluding header and tail. */ + unsigned short context_id; /* Presentation context identifier */ + unsigned char cancel_count; + unsigned char reserved; +} RpcPktResponseHdr; + +typedef struct +{ + RpcPktCommonHdr common; + unsigned long alloc_hint; /* Data size in bytes excluding header and tail. */ + unsigned short context_id; /* Presentation context identifier */ + unsigned char alert_count; /* Pending alert count */ + unsigned char padding[3]; /* Force alignment! */ + unsigned long status; /* Runtime fault code (RPC_STATUS) */ + unsigned long reserved; +} RpcPktFaultHdr; + +typedef struct +{ + RpcPktCommonHdr common; + unsigned short max_tsize; /* Maximum transmission fragment size */ + unsigned short max_rsize; /* Maximum receive fragment size */ + unsigned long assoc_gid; /* Associated group id */ + unsigned char num_elements; /* Number of elements */ + unsigned char padding[3]; /* Force alignment! */ + unsigned short context_id; /* Presentation context identifier */ + unsigned char num_syntaxes; /* Number of syntaxes */ + RPC_SYNTAX_IDENTIFIER abstract; + RPC_SYNTAX_IDENTIFIER transfer; +} RpcPktBindHdr; + +#include "pshpack1.h" +typedef struct +{ + unsigned short length; /* Length of the string including null terminator */ + char string[1]; /* String data in single byte, null terminated form */ +} RpcAddressString; +#include "poppack.h" + +typedef struct +{ + unsigned char padding1[2]; /* Force alignment! */ + unsigned char num_results; /* Number of results */ + unsigned char padding2[3]; /* Force alignment! */ + struct { + unsigned short result; + unsigned short reason; + } results[1]; +} RpcResults; + +typedef struct +{ + RpcPktCommonHdr common; + unsigned short max_tsize; /* Maximum transmission fragment size */ + unsigned short max_rsize; /* Maximum receive fragment size */ + unsigned long assoc_gid; /* Associated group id */ + /* + * Following this header are these fields: + * RpcAddressString server_address; + * RpcResults results; + * RPC_SYNTAX_IDENTIFIER transfer; + */ +} RpcPktBindAckHdr; + +typedef struct +{ + RpcPktCommonHdr common; + unsigned short reject_reason; + unsigned char protocols_count; + struct { + unsigned char rpc_ver; + unsigned char rpc_ver_minor; + } protocols[1]; +} RpcPktBindNAckHdr; + +/* Union representing all possible packet headers */ +typedef union +{ + RpcPktCommonHdr common; + RpcPktRequestHdr request; + RpcPktResponseHdr response; + RpcPktFaultHdr fault; + RpcPktBindHdr bind; + RpcPktBindAckHdr bind_ack; + RpcPktBindNAckHdr bind_nack; } RpcPktHdr; +#define RPC_VER_MAJOR 5 +#define RPC_VER_MINOR 0 + +#define RPC_FLG_FIRST 1 +#define RPC_FLG_LAST 2 +#define RPC_FLG_OBJECT_UUID 0x80 + +#define RPC_MIN_PACKET_SIZE 0x1000 +#define RPC_MAX_PACKET_SIZE 0x16D0 + #define PKT_REQUEST 0 #define PKT_PING 1 #define PKT_RESPONSE 2 @@ -59,13 +157,17 @@ typedef struct #define PKT_CANCEL_ACK 10 #define PKT_BIND 11 #define PKT_BIND_ACK 12 -#define PKT_BIND_NAK 13 +#define PKT_BIND_NACK 13 #define PKT_ALTER_CONTEXT 14 #define PKT_ALTER_CONTEXT_RESP 15 #define PKT_SHUTDOWN 17 #define PKT_CO_CANCEL 18 #define PKT_ORPHANED 19 +#define RESULT_ACCEPT 0 + +#define NO_REASON 0 + #define NCADG_IP_UDP 0x08 #define NCACN_IP_TCP 0x07 #define NCADG_IPX 0x0E diff --git a/dlls/rpcrt4/rpc_message.c b/dlls/rpcrt4/rpc_message.c index 798ee2d204f..3a87fcd4e38 100644 --- a/dlls/rpcrt4/rpc_message.c +++ b/dlls/rpcrt4/rpc_message.c @@ -2,6 +2,7 @@ * RPC messages * * Copyright 2001-2002 Ove Kåven, TransGaming Technologies + * Copyright 2004 Filip Navara * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -20,7 +21,6 @@ * TODO: * - figure out whether we *really* got this right * - check for errors and throw exceptions - * - decide if OVERLAPPED_WORKS */ #include @@ -33,6 +33,7 @@ #include "winreg.h" #include "rpc.h" +#include "rpcndr.h" #include "rpcdcep.h" #include "wine/debug.h" @@ -43,6 +44,394 @@ WINE_DEFAULT_DEBUG_CHANNEL(ole); +DWORD RPCRT4_GetHeaderSize(RpcPktHdr *Header) +{ + static const DWORD header_sizes[] = { + sizeof(Header->request), 0, sizeof(Header->response), + sizeof(Header->fault), 0, 0, 0, 0, 0, 0, 0, sizeof(Header->bind), + sizeof(Header->bind_ack), sizeof(Header->bind_nack), + 0, 0, 0, 0, 0 + }; + ULONG ret = 0; + + if (Header->common.ptype < sizeof(header_sizes) / sizeof(header_sizes[0])) { + ret = header_sizes[Header->common.ptype]; + if (ret == 0) + FIXME("unhandled packet type\n"); + if (Header->common.flags & RPC_FLG_OBJECT_UUID) + ret += sizeof(UUID); + } else { + TRACE("invalid packet type\n"); + } + + return ret; +} + +VOID RPCRT4_BuildCommonHeader(RpcPktHdr *Header, unsigned char PacketType, + unsigned long DataRepresentation) +{ + Header->common.rpc_ver = RPC_VER_MAJOR; + Header->common.rpc_ver_minor = RPC_VER_MINOR; + Header->common.ptype = PacketType; + Header->common.drep[0] = LOBYTE(LOWORD(DataRepresentation)); + Header->common.drep[1] = HIBYTE(LOWORD(DataRepresentation)); + Header->common.drep[2] = LOBYTE(HIWORD(DataRepresentation)); + Header->common.drep[3] = HIBYTE(HIWORD(DataRepresentation)); + Header->common.auth_len = 0; + Header->common.call_id = 1; + Header->common.flags = 0; + /* Flags and fragment length are computed in RPCRT4_Send. */ +} + +RpcPktHdr *RPCRT4_BuildRequestHeader(unsigned long DataRepresentation, + unsigned long BufferLength, + unsigned short ProcNum, + UUID *ObjectUuid) +{ + RpcPktHdr *header; + BOOL has_object; + RPC_STATUS status; + + has_object = (ObjectUuid != NULL && !UuidIsNil(ObjectUuid, &status)); + header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, + sizeof(header->request) + (has_object ? sizeof(UUID) : 0)); + if (header == NULL) { + return NULL; + } + + RPCRT4_BuildCommonHeader(header, PKT_REQUEST, DataRepresentation); + header->common.frag_len = sizeof(header->request); + header->request.alloc_hint = BufferLength; + header->request.context_id = 0; + header->request.opnum = ProcNum; + if (has_object) { + header->common.flags |= RPC_FLG_OBJECT_UUID; + header->common.frag_len += sizeof(UUID); + memcpy(&header->request + 1, ObjectUuid, sizeof(UUID)); + } + + return header; +} + +RpcPktHdr *RPCRT4_BuildResponseHeader(unsigned long DataRepresentation, + unsigned long BufferLength) +{ + RpcPktHdr *header; + + header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->response)); + if (header == NULL) { + return NULL; + } + + RPCRT4_BuildCommonHeader(header, PKT_RESPONSE, DataRepresentation); + header->common.frag_len = sizeof(header->response); + header->response.alloc_hint = BufferLength; + + return header; +} + +RpcPktHdr *RPCRT4_BuildFaultHeader(unsigned long DataRepresentation, + RPC_STATUS Status) +{ + RpcPktHdr *header; + + header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->fault)); + if (header == NULL) { + return NULL; + } + + RPCRT4_BuildCommonHeader(header, PKT_FAULT, DataRepresentation); + header->common.frag_len = sizeof(header->fault); + header->fault.status = Status; + + return header; +} + +RpcPktHdr *RPCRT4_BuildBindHeader(unsigned long DataRepresentation, + unsigned short MaxTransmissionSize, + unsigned short MaxReceiveSize, + RPC_SYNTAX_IDENTIFIER *AbstractId, + RPC_SYNTAX_IDENTIFIER *TransferId) +{ + RpcPktHdr *header; + + header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->bind)); + if (header == NULL) { + return NULL; + } + + RPCRT4_BuildCommonHeader(header, PKT_BIND, DataRepresentation); + header->common.frag_len = sizeof(header->bind); + header->bind.max_tsize = MaxTransmissionSize; + header->bind.max_rsize = MaxReceiveSize; + header->bind.num_elements = 1; + header->bind.num_syntaxes = 1; + memcpy(&header->bind.abstract, AbstractId, sizeof(RPC_SYNTAX_IDENTIFIER)); + memcpy(&header->bind.transfer, TransferId, sizeof(RPC_SYNTAX_IDENTIFIER)); + + return header; +} + +RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation, + unsigned char RpcVersion, + unsigned char RpcVersionMinor) +{ + RpcPktHdr *header; + + header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(header->bind_nack)); + if (header == NULL) { + return NULL; + } + + RPCRT4_BuildCommonHeader(header, PKT_BIND_NACK, DataRepresentation); + header->common.frag_len = sizeof(header->bind_nack); + header->bind_nack.protocols_count = 1; + header->bind_nack.protocols[0].rpc_ver = RpcVersion; + header->bind_nack.protocols[0].rpc_ver_minor = RpcVersionMinor; + + return header; +} + +RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, + unsigned short MaxTransmissionSize, + unsigned short MaxReceiveSize, + LPSTR ServerAddress, + unsigned long Result, + unsigned long Reason, + RPC_SYNTAX_IDENTIFIER *TransferId) +{ + RpcPktHdr *header; + unsigned long header_size; + RpcAddressString *server_address; + RpcResults *results; + RPC_SYNTAX_IDENTIFIER *transfer_id; + + header_size = sizeof(header->bind_ack) + sizeof(RpcResults) + + sizeof(RPC_SYNTAX_IDENTIFIER) + sizeof(RpcAddressString) + + strlen(ServerAddress); + + header = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, header_size); + if (header == NULL) { + return NULL; + } + + RPCRT4_BuildCommonHeader(header, PKT_BIND_ACK, DataRepresentation); + header->common.frag_len = header_size; + header->bind_ack.max_tsize = MaxTransmissionSize; + header->bind_ack.max_rsize = MaxReceiveSize; + server_address = (RpcAddressString*)(&header->bind_ack + 1); + server_address->length = strlen(ServerAddress) + 1; + strcpy(server_address->string, ServerAddress); + results = (RpcResults*)((ULONG_PTR)server_address + sizeof(RpcAddressString) + server_address->length - 1); + results->num_results = 1; + results->results[0].result = Result; + results->results[0].reason = Reason; + transfer_id = (RPC_SYNTAX_IDENTIFIER*)(results + 1); + memcpy(transfer_id, TransferId, sizeof(RPC_SYNTAX_IDENTIFIER)); + + return header; +} + +VOID RPCRT4_FreeHeader(RpcPktHdr *Header) +{ + HeapFree(GetProcessHeap(), 0, Header); +} + +/*********************************************************************** + * RPCRT4_Send (internal) + * + * Transmit a packet over connection in acceptable fragments. + */ +RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header, + void *Buffer, unsigned int BufferLength) +{ + PUCHAR buffer_pos; + DWORD hdr_size, count; + + buffer_pos = Buffer; + /* The packet building functions save the packet header size, so we can use it. */ + hdr_size = Header->common.frag_len; + Header->common.flags |= RPC_FLG_FIRST; + Header->common.flags &= ~RPC_FLG_LAST; + while (!(Header->common.flags & RPC_FLG_LAST)) { + /* decide if we need to split the packet into fragments */ + if ((BufferLength + hdr_size) <= Connection->MaxTransmissionSize) { + Header->common.flags |= RPC_FLG_LAST; + Header->common.frag_len = BufferLength + hdr_size; + } else { + Header->common.frag_len = Connection->MaxTransmissionSize; + buffer_pos += Header->common.frag_len - hdr_size; + BufferLength -= Header->common.frag_len - hdr_size; + } + + /* transmit packet header */ + if (!WriteFile(Connection->conn, Header, hdr_size, &count, NULL)) { + WARN("WriteFile failed with error %ld\n", GetLastError()); + return GetLastError(); + } + + /* fragment consisted of header only and is the last one */ + if (hdr_size == Header->common.frag_len && + Header->common.flags & RPC_FLG_LAST) { + return RPC_S_OK; + } + + /* send the fragment data */ + if (!WriteFile(Connection->conn, buffer_pos, Header->common.frag_len - hdr_size, &count, NULL)) { + WARN("WriteFile failed with error %ld\n", GetLastError()); + return GetLastError(); + } + + Header->common.flags &= ~RPC_FLG_FIRST; + } + + return RPC_S_OK; +} + +/*********************************************************************** + * RPCRT4_Receive (internal) + * + * Receive a packet from connection and merge the fragments. + */ +RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, + PRPC_MESSAGE pMsg) +{ + RPC_STATUS status; + DWORD dwRead, hdr_length; + unsigned short first_flag; + unsigned long data_length; + unsigned long buffer_length; + unsigned char *buffer_ptr; + RpcPktCommonHdr common_hdr; + + *Header = NULL; + + TRACE("(%p, %p, %p)\n", Connection, Header, pMsg); + + /* read packet common header */ + if (!ReadFile(Connection->conn, &common_hdr, sizeof(common_hdr), &dwRead, NULL)) { + if (GetLastError() != ERROR_MORE_DATA) { + WARN("ReadFile failed with error %ld\n", GetLastError()); + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + } + if (dwRead != sizeof(common_hdr)) { + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + /* verify if the header really makes sense */ + if (common_hdr.rpc_ver != RPC_VER_MAJOR || + common_hdr.rpc_ver_minor != RPC_VER_MINOR) { + WARN("unhandled packet version\n"); + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + hdr_length = RPCRT4_GetHeaderSize((RpcPktHdr*)&common_hdr); + if (hdr_length == 0) { + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + *Header = HeapAlloc(GetProcessHeap(), 0, hdr_length); + memcpy(*Header, &common_hdr, sizeof(common_hdr)); + + /* read the rest of packet header */ + if (!ReadFile(Connection->conn, &(*Header)->common + 1, + hdr_length - sizeof(common_hdr), &dwRead, NULL)) { + if (GetLastError() != ERROR_MORE_DATA) { + WARN("ReadFile failed with error %ld\n", GetLastError()); + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + } + if (dwRead != hdr_length - sizeof(common_hdr)) { + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + /* read packet body */ + switch (common_hdr.ptype) { + case PKT_RESPONSE: + pMsg->BufferLength = (*Header)->response.alloc_hint; + break; + case PKT_REQUEST: + pMsg->BufferLength = (*Header)->request.alloc_hint; + break; + default: + pMsg->BufferLength = common_hdr.frag_len - hdr_length; + } + status = I_RpcGetBuffer(pMsg); + if (status != RPC_S_OK) goto fail; + + first_flag = RPC_FLG_FIRST; + buffer_length = 0; + buffer_ptr = pMsg->Buffer; + while (buffer_length < pMsg->BufferLength) + { + data_length = (*Header)->common.frag_len - hdr_length; + if (((*Header)->common.flags & RPC_FLG_FIRST) != first_flag || + data_length + buffer_length > pMsg->BufferLength) { + TRACE("invalid packet flags or buffer length\n"); + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + if (data_length == 0) dwRead = 0; else + if (!ReadFile(Connection->conn, buffer_ptr, data_length, &dwRead, NULL)) { + if (GetLastError() != ERROR_MORE_DATA) { + WARN("ReadFile failed with error %ld\n", GetLastError()); + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + } + if (dwRead != data_length) { + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + if (buffer_length == pMsg->BufferLength && + ((*Header)->common.flags & RPC_FLG_LAST) == 0) { + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + buffer_length += data_length; + if (buffer_length < pMsg->BufferLength) { + TRACE("next header\n"); + + /* read the header of next packet */ + if (!ReadFile(Connection->conn, *Header, hdr_length, &dwRead, NULL)) { + if (GetLastError() != ERROR_MORE_DATA) { + WARN("ReadFile failed with error %ld\n", GetLastError()); + status = GetLastError(); + goto fail; + } + } + if (dwRead != hdr_length) { + WARN("invalid packet header size (%ld)\n", dwRead); + status = RPC_S_PROTOCOL_ERROR; + goto fail; + } + + buffer_ptr += data_length; + first_flag = 0; + } + } + + /* success */ + status = RPC_S_OK; + +fail: + if (status != RPC_S_OK && *Header) { + RPCRT4_FreeHeader(*Header); + *Header = NULL; + } + return status; +} + /*********************************************************************** * I_RpcGetBuffer [RPCRT4.@] */ @@ -74,7 +463,9 @@ RPC_STATUS WINAPI I_RpcFreeBuffer(PRPC_MESSAGE pMsg) { TRACE("(%p) Buffer=%p\n", pMsg, pMsg->Buffer); /* FIXME: pfnFree? */ - HeapFree(GetProcessHeap(), 0, pMsg->Buffer); + if (pMsg->Buffer != NULL) { + HeapFree(GetProcessHeap(), 0, pMsg->Buffer); + } pMsg->Buffer = NULL; return S_OK; } @@ -88,68 +479,44 @@ RPC_STATUS WINAPI I_RpcSend(PRPC_MESSAGE pMsg) RpcConnection* conn; RPC_CLIENT_INTERFACE* cif = NULL; RPC_SERVER_INTERFACE* sif = NULL; - UUID* obj; - UUID* act; RPC_STATUS status; - RpcPktHdr hdr; - DWORD count; + RpcPktHdr *hdr; TRACE("(%p)\n", pMsg); if (!bind) return RPC_S_INVALID_BINDING; - status = RPCRT4_OpenBinding(bind, &conn); - if (status != RPC_S_OK) return status; - - obj = &bind->ObjectUuid; - act = &bind->ActiveUuid; - if (bind->server) { sif = pMsg->RpcInterfaceInformation; if (!sif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */ + status = RPCRT4_OpenBinding(bind, &conn, &sif->TransferSyntax, + &sif->InterfaceId); } else { cif = pMsg->RpcInterfaceInformation; if (!cif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */ + status = RPCRT4_OpenBinding(bind, &conn, &cif->TransferSyntax, + &cif->InterfaceId); } - /* initialize packet header */ - memset(&hdr, 0, sizeof(hdr)); - hdr.rpc_ver = 4; - hdr.ptype = bind->server - ? ((pMsg->RpcFlags & WINE_RPCFLAG_EXCEPTION) ? PKT_FAULT : PKT_RESPONSE) - : PKT_REQUEST; - hdr.object = *obj; /* FIXME: IIRC iff no object, the header structure excludes this elt */ - hdr.if_id = (bind->server) ? sif->InterfaceId.SyntaxGUID : cif->InterfaceId.SyntaxGUID; - hdr.if_vers = - (bind->server) ? - MAKELONG(sif->InterfaceId.SyntaxVersion.MinorVersion, sif->InterfaceId.SyntaxVersion.MajorVersion) : - MAKELONG(cif->InterfaceId.SyntaxVersion.MinorVersion, cif->InterfaceId.SyntaxVersion.MajorVersion); - hdr.act_id = *act; - hdr.opnum = pMsg->ProcNum; - /* only the low-order 3 octets of the DataRepresentation go in the header */ - hdr.drep[0] = LOBYTE(LOWORD(pMsg->DataRepresentation)); - hdr.drep[1] = HIBYTE(LOWORD(pMsg->DataRepresentation)); - hdr.drep[2] = LOBYTE(HIWORD(pMsg->DataRepresentation)); - hdr.len = pMsg->BufferLength; + if (status != RPC_S_OK) return status; - /* transmit packet */ - if (!WriteFile(conn->conn, &hdr, sizeof(hdr), &count, NULL)) { - WARN("WriteFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - - if (!pMsg->BufferLength) - { - status = RPC_S_OK; - goto fail; - } - - if (!WriteFile(conn->conn, pMsg->Buffer, pMsg->BufferLength, &count, NULL)) { - WARN("WriteFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; + if (bind->server) { + if (pMsg->RpcFlags & WINE_RPCFLAG_EXCEPTION) { + hdr = RPCRT4_BuildFaultHeader(pMsg->DataRepresentation, + RPC_S_CALL_FAILED); + } else { + hdr = RPCRT4_BuildResponseHeader(pMsg->DataRepresentation, + pMsg->BufferLength); + } + } else { + hdr = RPCRT4_BuildRequestHeader(pMsg->DataRepresentation, + pMsg->BufferLength, pMsg->ProcNum, + &bind->ObjectUuid); } + status = RPCRT4_Send(conn, hdr, pMsg->Buffer, pMsg->BufferLength); + + RPCRT4_FreeHeader(hdr); + /* success */ if (!bind->server) { /* save the connection, so the response can be read from it */ @@ -158,7 +525,6 @@ RPC_STATUS WINAPI I_RpcSend(PRPC_MESSAGE pMsg) } RPCRT4_CloseBinding(bind, conn); status = RPC_S_OK; -fail: return status; } @@ -170,10 +536,10 @@ RPC_STATUS WINAPI I_RpcReceive(PRPC_MESSAGE pMsg) { RpcBinding* bind = (RpcBinding*)pMsg->Handle; RpcConnection* conn; - UUID* act; + RPC_CLIENT_INTERFACE* cif = NULL; + RPC_SERVER_INTERFACE* sif = NULL; RPC_STATUS status; - RpcPktHdr hdr; - DWORD dwRead; + RpcPktHdr *hdr = NULL; TRACE("(%p)\n", pMsg); if (!bind) return RPC_S_INVALID_BINDING; @@ -182,94 +548,51 @@ RPC_STATUS WINAPI I_RpcReceive(PRPC_MESSAGE pMsg) conn = pMsg->ReservedForRuntime; pMsg->ReservedForRuntime = NULL; } else { - status = RPCRT4_OpenBinding(bind, &conn); + if (bind->server) { + sif = pMsg->RpcInterfaceInformation; + if (!sif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */ + status = RPCRT4_OpenBinding(bind, &conn, &sif->TransferSyntax, + &sif->InterfaceId); + } else { + cif = pMsg->RpcInterfaceInformation; + if (!cif) return RPC_S_INTERFACE_NOT_FOUND; /* ? */ + status = RPCRT4_OpenBinding(bind, &conn, &cif->TransferSyntax, + &cif->InterfaceId); + } if (status != RPC_S_OK) return status; } - act = &bind->ActiveUuid; - - for (;;) { - /* read packet header */ -#ifdef OVERLAPPED_WORKS - if (!ReadFile(conn->conn, &hdr, sizeof(hdr), &dwRead, &conn->ovl)) { - DWORD err = GetLastError(); - if (err != ERROR_IO_PENDING) { - WARN("ReadFile failed with error %ld\n", err); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - if (!GetOverlappedResult(conn->conn, &conn->ovl, &dwRead, TRUE)) { - WARN("ReadFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - } -#else - if (!ReadFile(conn->conn, &hdr, sizeof(hdr), &dwRead, NULL)) { - WARN("ReadFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } -#endif - if (dwRead != sizeof(hdr)) { - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - - /* read packet body */ - pMsg->BufferLength = hdr.len; - status = I_RpcGetBuffer(pMsg); - if (status != RPC_S_OK) goto fail; - if (!pMsg->BufferLength) dwRead = 0; else -#ifdef OVERLAPPED_WORKS - if (!ReadFile(conn->conn, pMsg->Buffer, hdr.len, &dwRead, &conn->ovl)) { - if (GetLastError() != ERROR_IO_PENDING) { - WARN("ReadFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - if (!GetOverlappedResult(conn->conn, &conn->ovl, &dwRead, TRUE)) { - WARN("ReadFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - } -#else - if (!ReadFile(conn->conn, pMsg->Buffer, hdr.len, &dwRead, NULL)) { - WARN("ReadFile failed with error %ld\n", GetLastError()); - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } -#endif - if (dwRead != hdr.len) { - status = RPC_S_PROTOCOL_ERROR; - goto fail; - } - - status = RPC_S_PROTOCOL_ERROR; - - switch (hdr.ptype) { - case PKT_RESPONSE: - if (bind->server) goto fail; - break; - case PKT_REQUEST: - if (!bind->server) goto fail; - break; - case PKT_FAULT: - pMsg->RpcFlags |= WINE_RPCFLAG_EXCEPTION; - status = RPC_S_CALL_FAILED; /* ? */ - goto fail; - default: - goto fail; - } - - /* success */ - status = RPC_S_OK; - - /* FIXME: check destination, etc? */ - break; + status = RPCRT4_Receive(conn, &hdr, pMsg); + if (status != RPC_S_OK) { + WARN("receive failed with error %lx\n", status); + goto fail; } + + status = RPC_S_PROTOCOL_ERROR; + + switch (hdr->common.ptype) { + case PKT_RESPONSE: + if (bind->server) goto fail; + break; + case PKT_REQUEST: + if (!bind->server) goto fail; + break; + case PKT_FAULT: + pMsg->RpcFlags |= WINE_RPCFLAG_EXCEPTION; + ERR ("we got fault packet with status %lx\n", hdr->fault.status); + status = RPC_S_CALL_FAILED; /* ? */ + goto fail; + default: + goto fail; + } + + /* success */ + status = RPC_S_OK; + fail: + if (hdr) { + RPCRT4_FreeHeader(hdr); + } RPCRT4_CloseBinding(bind, conn); return status; } diff --git a/dlls/rpcrt4/rpc_message.h b/dlls/rpcrt4/rpc_message.h new file mode 100644 index 00000000000..18a73a9fb0a --- /dev/null +++ b/dlls/rpcrt4/rpc_message.h @@ -0,0 +1,38 @@ +/* + * RPC message API + * + * Copyright 2004 Filip Navara + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#ifndef __WINE_RPC_MESSAGE_H +#define __WINE_RPC_MESSAGE_H + +#include "wine/rpcss_shared.h" +#include "rpc_defs.h" + +VOID RPCRT4_BuildCommonHeader(RpcPktHdr *Header, unsigned char PacketType, unsigned long DataRepresentation); +RpcPktHdr *RPCRT4_BuildRequestHeader(unsigned long DataRepresentation, unsigned long BufferLength, unsigned short ProcNum, UUID *ObjectUuid); +RpcPktHdr *RPCRT4_BuildResponseHeader(unsigned long DataRepresentation, unsigned long BufferLength); +RpcPktHdr *RPCRT4_BuildFaultHeader(unsigned long DataRepresentation, RPC_STATUS Status); +RpcPktHdr *RPCRT4_BuildBindHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, RPC_SYNTAX_IDENTIFIER *AbstractId, RPC_SYNTAX_IDENTIFIER *TransferId); +RpcPktHdr *RPCRT4_BuildBindNackHeader(unsigned long DataRepresentation, unsigned char RpcVersion, unsigned char RpcVersionMinor); +RpcPktHdr *RPCRT4_BuildBindAckHeader(unsigned long DataRepresentation, unsigned short MaxTransmissionSize, unsigned short MaxReceiveSize, LPSTR ServerAddress, unsigned long Result, unsigned long Reason, RPC_SYNTAX_IDENTIFIER *TransferId); +VOID RPCRT4_FreeHeader(RpcPktHdr *Header); +RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header, void *Buffer, unsigned int BufferLength); +RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, PRPC_MESSAGE pMsg); + +#endif diff --git a/dlls/rpcrt4/rpc_server.c b/dlls/rpcrt4/rpc_server.c index 4b45123e9d7..11956e7155e 100644 --- a/dlls/rpcrt4/rpc_server.c +++ b/dlls/rpcrt4/rpc_server.c @@ -2,6 +2,7 @@ * RPC server API * * Copyright 2001 Ove Kåven, TransGaming Technologies + * Copyright 2004 Filip Navara * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -36,6 +37,7 @@ #include "ntstatus.h" #include "rpc.h" +#include "rpcndr.h" #include "excpt.h" #include "wine/debug.h" @@ -43,6 +45,7 @@ #include "rpc_server.h" #include "rpc_misc.h" +#include "rpc_message.h" #include "rpc_defs.h" #define MAX_THREADS 128 @@ -53,8 +56,8 @@ typedef struct _RpcPacket { struct _RpcPacket* next; struct _RpcConnection* conn; - RpcPktHdr hdr; - void* buf; + RpcPktHdr* hdr; + RPC_MESSAGE* msg; } RpcPacket; typedef struct _RpcObjTypeMap @@ -131,19 +134,22 @@ inline static UUID *LookupObjType(UUID *ObjUuid) return &uuid_nil; } -static RpcServerInterface* RPCRT4_find_interface(UUID* object, UUID* if_id) +static RpcServerInterface* RPCRT4_find_interface(UUID* object, + RPC_SYNTAX_IDENTIFIER* if_id, + BOOL check_object) { UUID* MgrType = NULL; RpcServerInterface* cif = NULL; RPC_STATUS status; - MgrType = LookupObjType(object); + if (check_object) + MgrType = LookupObjType(object); EnterCriticalSection(&server_cs); cif = ifs; while (cif) { - if (UuidEqual(if_id, &cif->If->InterfaceId.SyntaxGUID, &status) && - UuidEqual(MgrType, &cif->MgrTypeUuid, &status) && - (std_listen || (cif->Flags & RPC_IF_AUTOLISTEN))) break; + if (!memcmp(if_id, &cif->If->InterfaceId, sizeof(RPC_SYNTAX_IDENTIFIER)) && + (check_object == FALSE || UuidEqual(MgrType, &cif->MgrTypeUuid, &status)) && + (std_listen || (cif->Flags & RPC_IF_AUTOLISTEN))) break; cif = cif->Next; } LeaveCriticalSection(&server_cs); @@ -199,82 +205,132 @@ static WINE_EXCEPTION_FILTER(rpc_filter) return EXCEPTION_EXECUTE_HANDLER; } -static void RPCRT4_process_packet(RpcConnection* conn, RpcPktHdr* hdr, void* buf) +static void RPCRT4_process_packet(RpcConnection* conn, RpcPktHdr* hdr, RPC_MESSAGE* msg) { - RpcBinding* pbind; - RPC_MESSAGE msg; RpcServerInterface* sif; RPC_DISPATCH_FUNCTION func; packet_state state; + UUID *object_uuid; + RpcPktHdr *response; + void *buf = msg->Buffer; + RPC_STATUS status; - state.msg = &msg; + state.msg = msg; state.buf = buf; TlsSetValue(worker_tls, &state); - memset(&msg, 0, sizeof(msg)); - msg.BufferLength = hdr->len; - msg.Buffer = buf; - sif = RPCRT4_find_interface(&hdr->object, &hdr->if_id); - if (sif) { - TRACE("packet received for interface %s\n", debugstr_guid(&hdr->if_id)); - msg.RpcInterfaceInformation = sif->If; - /* copy the endpoint vector from sif to msg so that midl-generated code will use it */ - msg.ManagerEpv = sif->MgrEpv; - /* create temporary binding for dispatch */ - RPCRT4_MakeBinding(&pbind, conn); - RPCRT4_SetBindingObject(pbind, &hdr->object); - msg.Handle = (RPC_BINDING_HANDLE)pbind; - /* process packet */ - switch (hdr->ptype) { + + switch (hdr->common.ptype) { + case PKT_BIND: + TRACE("got bind packet\n"); + + /* FIXME: do more checks! */ + if (hdr->bind.max_tsize < RPC_MIN_PACKET_SIZE || + !UuidIsNil(&conn->ActiveInterface.SyntaxGUID, &status)) { + sif = NULL; + } else { + sif = RPCRT4_find_interface(NULL, &hdr->bind.abstract, FALSE); + } + if (sif == NULL) { + TRACE("rejecting bind request\n"); + /* Report failure to client. */ + response = RPCRT4_BuildBindNackHeader(NDR_LOCAL_DATA_REPRESENTATION, + RPC_VER_MAJOR, RPC_VER_MINOR); + } else { + TRACE("accepting bind request\n"); + + /* accept. */ + response = RPCRT4_BuildBindAckHeader(NDR_LOCAL_DATA_REPRESENTATION, + RPC_MAX_PACKET_SIZE, + RPC_MAX_PACKET_SIZE, + conn->Endpoint, + RESULT_ACCEPT, NO_REASON, + &sif->If->TransferSyntax); + + /* save the interface for later use */ + conn->ActiveInterface = hdr->bind.abstract; + conn->MaxTransmissionSize = hdr->bind.max_tsize; + } + + if (RPCRT4_Send(conn, response, NULL, 0) != RPC_S_OK) + goto fail; + + break; + case PKT_REQUEST: + TRACE("got request packet\n"); + + /* fail if the connection isn't bound with an interface */ + if (UuidIsNil(&conn->ActiveInterface.SyntaxGUID, &status)) { + response = RPCRT4_BuildFaultHeader(NDR_LOCAL_DATA_REPRESENTATION, + status); + + RPCRT4_Send(conn, response, NULL, 0); + break; + } + + if (hdr->common.flags & RPC_FLG_OBJECT_UUID) { + object_uuid = (UUID*)(&hdr->request + 1); + } else { + object_uuid = NULL; + } + + sif = RPCRT4_find_interface(object_uuid, &conn->ActiveInterface, TRUE); + msg->RpcInterfaceInformation = sif->If; + /* copy the endpoint vector from sif to msg so that midl-generated code will use it */ + msg->ManagerEpv = sif->MgrEpv; + if (object_uuid != NULL) { + RPCRT4_SetBindingObject(msg->Handle, object_uuid); + } + /* find dispatch function */ - msg.ProcNum = hdr->opnum; + msg->ProcNum = hdr->request.opnum; if (sif->Flags & RPC_IF_OLE) { /* native ole32 always gives us a dispatch table with a single entry * (I assume that's a wrapper for IRpcStubBuffer::Invoke) */ func = *sif->If->DispatchTable->DispatchTable; } else { - if (msg.ProcNum >= sif->If->DispatchTable->DispatchTableCount) { + if (msg->ProcNum >= sif->If->DispatchTable->DispatchTableCount) { ERR("invalid procnum\n"); func = NULL; } - func = sif->If->DispatchTable->DispatchTable[msg.ProcNum]; + func = sif->If->DispatchTable->DispatchTable[msg->ProcNum]; } /* put in the drep. FIXME: is this more universally applicable? perhaps we should move this outward... */ - msg.DataRepresentation = - MAKELONG( MAKEWORD(hdr->drep[0], hdr->drep[1]), - MAKEWORD(hdr->drep[2], 0)); + msg->DataRepresentation = + MAKELONG( MAKEWORD(hdr->common.drep[0], hdr->common.drep[1]), + MAKEWORD(hdr->common.drep[2], hdr->common.drep[3])); /* dispatch */ __TRY { - if (func) func(&msg); + if (func) func(msg); } __EXCEPT(rpc_filter) { /* failure packet was created in rpc_filter */ } __ENDTRY /* send response packet */ - I_RpcSend(&msg); + I_RpcSend(msg); + + msg->RpcInterfaceInformation = NULL; + break; + default: - ERR("unknown packet type\n"); + FIXME("unhandled packet type\n"); break; - } - - RPCRT4_DestroyBinding(pbind); - msg.Handle = 0; - msg.RpcInterfaceInformation = NULL; - } - else { - ERR("got RPC packet to unregistered interface %s\n", debugstr_guid(&hdr->if_id)); } +fail: /* clean up */ - if (msg.Buffer == buf) msg.Buffer = NULL; + if (msg->Buffer == buf) msg->Buffer = NULL; TRACE("freeing Buffer=%p\n", buf); HeapFree(GetProcessHeap(), 0, buf); - I_RpcFreeBuffer(&msg); - msg.Buffer = NULL; + RPCRT4_DestroyBinding(msg->Handle); + msg->Handle = 0; + I_RpcFreeBuffer(msg); + msg->Buffer = NULL; + RPCRT4_FreeHeader(hdr); TlsSetValue(worker_tls, NULL); } @@ -295,7 +351,7 @@ static DWORD CALLBACK RPCRT4_worker_thread(LPVOID the_arg) if (!pkt) continue; InterlockedDecrement(&worker_free); for (;;) { - RPCRT4_process_packet(pkt->conn, &pkt->hdr, pkt->buf); + RPCRT4_process_packet(pkt->conn, pkt->hdr, pkt->msg); HeapFree(GetProcessHeap(), 0, pkt); /* try to grab another packet here without waiting * on the semaphore, in case it hits max */ @@ -329,73 +385,42 @@ static void RPCRT4_create_worker_if_needed(void) static DWORD CALLBACK RPCRT4_io_thread(LPVOID the_arg) { RpcConnection* conn = (RpcConnection*)the_arg; - RpcPktHdr hdr; - DWORD dwRead; - void* buf = NULL; - RpcPacket* packet; + RpcPktHdr *hdr; + RpcBinding *pbind; + RPC_MESSAGE *msg; + RPC_STATUS status; + RpcPacket *packet; TRACE("(%p)\n", conn); for (;;) { - /* read packet header */ -#ifdef OVERLAPPED_WORKS - if (!ReadFile(conn->conn, &hdr, sizeof(hdr), &dwRead, &conn->ovl)) { - DWORD err = GetLastError(); - if (err != ERROR_IO_PENDING) { - TRACE("connection lost, error=%08lx\n", err); - break; - } - if (!GetOverlappedResult(conn->conn, &conn->ovl, &dwRead, TRUE)) break; - } -#else - if (!ReadFile(conn->conn, &hdr, sizeof(hdr), &dwRead, NULL)) { - TRACE("connection lost, error=%08lx\n", GetLastError()); - break; - } -#endif - if (dwRead != sizeof(hdr)) { - if (dwRead) TRACE("protocol error: \n", sizeof(hdr), dwRead); - break; - } + msg = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(RPC_MESSAGE)); - /* read packet body */ - buf = HeapAlloc(GetProcessHeap(), 0, hdr.len); - TRACE("receiving payload=%d\n", hdr.len); - if (!hdr.len) dwRead = 0; else -#ifdef OVERLAPPED_WORKS - if (!ReadFile(conn->conn, buf, hdr.len, &dwRead, &conn->ovl)) { - DWORD err = GetLastError(); - if (err != ERROR_IO_PENDING) { - TRACE("connection lost, error=%08lx\n", err); - break; - } - if (!GetOverlappedResult(conn->conn, &conn->ovl, &dwRead, TRUE)) break; - } -#else - if (!ReadFile(conn->conn, buf, hdr.len, &dwRead, NULL)) { - TRACE("connection lost, error=%08lx\n", GetLastError()); - break; - } -#endif - if (dwRead != hdr.len) { - TRACE("protocol error: \n", hdr.len, dwRead); + /* create temporary binding for dispatch, it will be freed in + * RPCRT4_process_packet */ + RPCRT4_MakeBinding(&pbind, conn); + msg->Handle = (RPC_BINDING_HANDLE)pbind; + + status = RPCRT4_Receive(conn, &hdr, msg); + if (status != RPC_S_OK) { + WARN("receive failed with error %lx\n", status); break; } #if 0 - RPCRT4_process_packet(conn, &hdr, buf); + RPCRT4_process_packet(conn, hdr, msg); #else packet = HeapAlloc(GetProcessHeap(), 0, sizeof(RpcPacket)); packet->conn = conn; packet->hdr = hdr; - packet->buf = buf; + packet->msg = msg; RPCRT4_create_worker_if_needed(); RPCRT4_push_packet(packet); ReleaseSemaphore(server_sem, 1, NULL); #endif - buf = NULL; + msg = NULL; } - if (buf) HeapFree(GetProcessHeap(), 0, buf); + if (msg) HeapFree(GetProcessHeap(), 0, msg); RPCRT4_DestroyConnection(conn); return 0; }