From 8b6c30ab4b48a5f4c9a32920dd389d218b459540 Mon Sep 17 00:00:00 2001
From: Rob Shearman <rob@codeweavers.com>
Date: Wed, 8 Nov 2006 20:45:19 +0000
Subject: [PATCH] rpcrt4: Open the endpoint from the caller of
 RpcServerUseProtseq* instead of the protseq server thread.

This allows errors to be returned to the caller and to create more than
one connection for an endpoint.
---
 dlls/rpcrt4/rpc_binding.c   |   2 +-
 dlls/rpcrt4/rpc_binding.h   |   4 +-
 dlls/rpcrt4/rpc_server.c    |   3 +-
 dlls/rpcrt4/rpc_server.h    |   2 +
 dlls/rpcrt4/rpc_transport.c | 304 +++++++++++++++++++++++++++---------
 5 files changed, 237 insertions(+), 78 deletions(-)

diff --git a/dlls/rpcrt4/rpc_binding.c b/dlls/rpcrt4/rpc_binding.c
index eda2322be38..d233f034cef 100644
--- a/dlls/rpcrt4/rpc_binding.c
+++ b/dlls/rpcrt4/rpc_binding.c
@@ -258,7 +258,7 @@ RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection,
   RPCRT4_CreateConnection(&NewConnection, Binding->server, Binding->Protseq,
                           Binding->NetworkAddr, Binding->Endpoint, NULL,
                           Binding->AuthInfo, Binding);
-  status = RPCRT4_OpenConnection(NewConnection);
+  status = RPCRT4_OpenClientConnection(NewConnection);
   if (status != RPC_S_OK)
   {
     RPCRT4_DestroyConnection(NewConnection);
diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h
index e6b097c3a33..4df3d7326dc 100644
--- a/dlls/rpcrt4/rpc_binding.h
+++ b/dlls/rpcrt4/rpc_binding.h
@@ -65,7 +65,7 @@ struct connection_ops {
   const char *name;
   unsigned char epm_protocols[2]; /* only floors 3 and 4. see http://www.opengroup.org/onlinepubs/9629399/apdxl.htm */
   RpcConnection *(*alloc)(void);
-  RPC_STATUS (*open_connection)(RpcConnection *conn);
+  RPC_STATUS (*open_connection_client)(RpcConnection *conn);
   RPC_STATUS (*handoff)(RpcConnection *old_conn, RpcConnection *new_conn);
   int (*read)(RpcConnection *conn, void *buffer, unsigned int len);
   int (*write)(RpcConnection *conn, const void *buffer, unsigned int len);
@@ -108,7 +108,7 @@ RpcConnection *RPCRT4_GetIdleConnection(const RPC_SYNTAX_IDENTIFIER *InterfaceId
 void RPCRT4_ReleaseIdleConnection(RpcConnection *Connection);
 RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCSTR NetworkOptions, RpcAuthInfo* AuthInfo, RpcBinding* Binding);
 RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection);
-RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection);
+RPC_STATUS RPCRT4_OpenClientConnection(RpcConnection* Connection);
 RPC_STATUS RPCRT4_CloseConnection(RpcConnection* Connection);
 RPC_STATUS RPCRT4_SpawnConnection(RpcConnection** Connection, RpcConnection* OldConnection);
 
diff --git a/dlls/rpcrt4/rpc_server.c b/dlls/rpcrt4/rpc_server.c
index 1c9dcc035e2..e716b53f973 100644
--- a/dlls/rpcrt4/rpc_server.c
+++ b/dlls/rpcrt4/rpc_server.c
@@ -496,8 +496,7 @@ static RPC_STATUS RPCRT4_use_protseq(RpcServerProtseq* ps)
 {
   RPC_STATUS status;
 
-  status = RPCRT4_CreateConnection(&ps->conn, TRUE, ps->Protseq, NULL,
-                                   ps->Endpoint, NULL, NULL, NULL);
+  status = ps->ops->open_endpoint(ps, ps->Endpoint);
   if (status != RPC_S_OK)
     return status;
 
diff --git a/dlls/rpcrt4/rpc_server.h b/dlls/rpcrt4/rpc_server.h
index 77dae4e924d..c385bed6b03 100644
--- a/dlls/rpcrt4/rpc_server.h
+++ b/dlls/rpcrt4/rpc_server.h
@@ -56,6 +56,8 @@ struct protseq_ops
     /* returns -1 for failure, 0 for server state changed and 1 to indicate a
      * new connection was established */
     int (*wait_for_new_connection)(RpcServerProtseq *protseq, unsigned int count, void *wait_array);
+    /* opens the endpoint and optionally begins listening */
+    RPC_STATUS (*open_endpoint)(RpcServerProtseq *protseq, LPSTR endpoint);
 };
 
 typedef struct _RpcServerInterface
diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c
index a264b9bbeee..4679e715d8d 100644
--- a/dlls/rpcrt4/rpc_transport.c
+++ b/dlls/rpcrt4/rpc_transport.c
@@ -89,8 +89,9 @@ static struct list connection_pool = LIST_INIT(connection_pool);
 typedef struct _RpcConnection_np
 {
   RpcConnection common;
-  HANDLE pipe, thread;
+  HANDLE pipe;
   OVERLAPPED ovl;
+  BOOL listening;
 } RpcConnection_np;
 
 static RpcConnection *rpcrt4_conn_np_alloc(void)
@@ -99,13 +100,35 @@ static RpcConnection *rpcrt4_conn_np_alloc(void)
   if (npc)
   {
     npc->pipe = NULL;
-    npc->thread = NULL;
     memset(&npc->ovl, 0, sizeof(npc->ovl));
+    npc->listening = FALSE;
   }
   return &npc->common;
 }
 
-static RPC_STATUS rpcrt4_connect_pipe(RpcConnection *Connection, LPCSTR pname)
+static RPC_STATUS rpcrt4_conn_listen_pipe(RpcConnection_np *npc)
+{
+  if (npc->listening)
+    return RPC_S_OK;
+
+  npc->listening = TRUE;
+  if (ConnectNamedPipe(npc->pipe, &npc->ovl))
+    return RPC_S_OK;
+
+  WARN("Couldn't ConnectNamedPipe (error was %ld)\n", GetLastError());
+  if (GetLastError() == ERROR_PIPE_CONNECTED) {
+    SetEvent(npc->ovl.hEvent);
+    return RPC_S_OK;
+  }
+  if (GetLastError() == ERROR_IO_PENDING) {
+    /* FIXME: looks like we need to GetOverlappedResult here? */
+    return RPC_S_OK;
+  }
+  npc->listening = FALSE;
+  return RPC_S_SERVER_UNAVAILABLE;
+}
+
+static RPC_STATUS rpcrt4_conn_create_pipe(RpcConnection *Connection, LPCSTR pname)
 {
   RpcConnection_np *npc = (RpcConnection_np *) Connection;
   TRACE("listening on %s\n", pname);
@@ -121,22 +144,13 @@ static RPC_STATUS rpcrt4_connect_pipe(RpcConnection *Connection, LPCSTR pname)
 
   memset(&npc->ovl, 0, sizeof(npc->ovl));
   npc->ovl.hEvent = CreateEventW(NULL, TRUE, FALSE, NULL);
-  if (ConnectNamedPipe(npc->pipe, &npc->ovl))
-     return RPC_S_OK;
 
-  WARN("Couldn't ConnectNamedPipe (error was %ld)\n", GetLastError());
-  if (GetLastError() == ERROR_PIPE_CONNECTED) {
-    SetEvent(npc->ovl.hEvent);
-    return RPC_S_OK;
-  }
-  if (GetLastError() == ERROR_IO_PENDING) {
-    /* FIXME: looks like we need to GetOverlappedResult here? */
-    return RPC_S_OK;
-  }
-  return RPC_S_SERVER_UNAVAILABLE;
+  /* Note: we don't call ConnectNamedPipe here because it must be done in the
+   * server thread as the thread must be alertable */
+  return RPC_S_OK;
 }
 
-static RPC_STATUS rpcrt4_open_pipe(RpcConnection *Connection, LPCSTR pname, BOOL wait)
+static RPC_STATUS rpcrt4_conn_open_pipe(RpcConnection *Connection, LPCSTR pname, BOOL wait)
 {
   RpcConnection_np *npc = (RpcConnection_np *) Connection;
   HANDLE pipe;
@@ -188,16 +202,39 @@ static RPC_STATUS rpcrt4_ncalrpc_open(RpcConnection* Connection)
    * but we'll implement it with named pipes for now */
   pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
   strcat(strcpy(pname, prefix), Connection->Endpoint);
-
-  if (Connection->server)
-    r = rpcrt4_connect_pipe(Connection, pname);
-  else
-    r = rpcrt4_open_pipe(Connection, pname, TRUE);
+  r = rpcrt4_conn_open_pipe(Connection, pname, TRUE);
   I_RpcFree(pname);
 
   return r;
 }
 
+static RPC_STATUS rpcrt4_protseq_ncalrpc_open_endpoint(RpcServerProtseq* protseq, LPSTR endpoint)
+{
+  static LPCSTR prefix = "\\\\.\\pipe\\lrpc\\";
+  RPC_STATUS r;
+  LPSTR pname;
+  RpcConnection *Connection;
+
+  r = RPCRT4_CreateConnection(&Connection, TRUE, protseq->Protseq, NULL,
+                              endpoint, NULL, NULL, NULL);
+  if (r != RPC_S_OK)
+      return r;
+
+  /* protseq=ncalrpc: supposed to use NT LPC ports,
+   * but we'll implement it with named pipes for now */
+  pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
+  strcat(strcpy(pname, prefix), Connection->Endpoint);
+  r = rpcrt4_conn_create_pipe(Connection, pname);
+  I_RpcFree(pname);
+
+  EnterCriticalSection(&protseq->cs);
+  Connection->Next = protseq->conn;
+  protseq->conn = Connection;
+  LeaveCriticalSection(&protseq->cs);
+
+  return r;
+}
+
 static RPC_STATUS rpcrt4_ncacn_np_open(RpcConnection* Connection)
 {
   RpcConnection_np *npc = (RpcConnection_np *) Connection;
@@ -212,19 +249,35 @@ static RPC_STATUS rpcrt4_ncacn_np_open(RpcConnection* Connection)
   /* protseq=ncacn_np: named pipes */
   pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
   strcat(strcpy(pname, prefix), Connection->Endpoint);
-  if (Connection->server)
-    r = rpcrt4_connect_pipe(Connection, pname);
-  else
-    r = rpcrt4_open_pipe(Connection, pname, FALSE);
+  r = rpcrt4_conn_open_pipe(Connection, pname, FALSE);
   I_RpcFree(pname);
 
   return r;
 }
 
-static RPC_STATUS rpcrt4_conn_np_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
+static RPC_STATUS rpcrt4_protseq_ncacn_np_open_endpoint(RpcServerProtseq *protseq, LPSTR endpoint)
 {
-  RpcConnection_np *old_npc = (RpcConnection_np *) old_conn;
-  RpcConnection_np *new_npc = (RpcConnection_np *) new_conn;
+  static LPCSTR prefix = "\\\\.";
+  RPC_STATUS r;
+  LPSTR pname;
+  RpcConnection *Connection;
+
+  r = RPCRT4_CreateConnection(&Connection, TRUE, protseq->Protseq, NULL,
+                              endpoint, NULL, NULL, NULL);
+  if (r != RPC_S_OK)
+    return r;
+
+  /* protseq=ncacn_np: named pipes */
+  pname = I_RpcAllocate(strlen(prefix) + strlen(Connection->Endpoint) + 1);
+  strcat(strcpy(pname, prefix), Connection->Endpoint);
+  r = rpcrt4_conn_create_pipe(Connection, pname);
+  I_RpcFree(pname);
+    
+  return r;
+}
+
+static void rpcrt4_conn_np_handoff(RpcConnection_np *old_npc, RpcConnection_np *new_npc)
+{    
   /* because of the way named pipes work, we'll transfer the connected pipe
    * to the child, then reopen the server binding to continue listening */
 
@@ -232,7 +285,41 @@ static RPC_STATUS rpcrt4_conn_np_handoff(RpcConnection *old_conn, RpcConnection
   new_npc->ovl = old_npc->ovl;
   old_npc->pipe = 0;
   memset(&old_npc->ovl, 0, sizeof(old_npc->ovl));
-  return RPCRT4_OpenConnection(old_conn);
+  old_npc->listening = FALSE;
+}
+
+static RPC_STATUS rpcrt4_ncacn_np_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
+{
+  RPC_STATUS status;
+  LPSTR pname;
+  static LPCSTR prefix = "\\\\.";
+
+  rpcrt4_conn_np_handoff((RpcConnection_np *)old_conn, (RpcConnection_np *)new_conn);
+
+  pname = I_RpcAllocate(strlen(prefix) + strlen(old_conn->Endpoint) + 1);
+  strcat(strcpy(pname, prefix), old_conn->Endpoint);
+  status = rpcrt4_conn_create_pipe(old_conn, pname);
+  I_RpcFree(pname);
+
+  return status;
+}
+
+static RPC_STATUS rpcrt4_ncalrpc_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
+{
+  RPC_STATUS status;
+  LPSTR pname;
+  static LPCSTR prefix = "\\\\.\\pipe\\lrpc\\";
+
+  TRACE("%s\n", old_conn->Endpoint);
+
+  rpcrt4_conn_np_handoff((RpcConnection_np *)old_conn, (RpcConnection_np *)new_conn);
+
+  pname = I_RpcAllocate(strlen(prefix) + strlen(old_conn->Endpoint) + 1);
+  strcat(strcpy(pname, prefix), old_conn->Endpoint);
+  status = rpcrt4_conn_create_pipe(old_conn, pname);
+  I_RpcFree(pname);
+    
+  return status;
 }
 
 static int rpcrt4_conn_np_read(RpcConnection *Connection,
@@ -409,7 +496,7 @@ static void *rpcrt4_protseq_np_get_wait_array(RpcServerProtseq *protseq, void *p
     *count = 1;
     conn = CONTAINING_RECORD(protseq->conn, RpcConnection_np, common);
     while (conn) {
-        RPCRT4_OpenConnection(&conn->common);
+        rpcrt4_conn_listen_pipe(conn);
         if (conn->ovl.hEvent)
             (*count)++;
         conn = CONTAINING_RECORD(conn->common.Next, RpcConnection_np, common);
@@ -584,7 +671,7 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
   if (tcpc->sock != -1)
     return RPC_S_OK;
 
-  hints.ai_flags          = Connection->server ? AI_PASSIVE : 0;
+  hints.ai_flags          = 0;
   hints.ai_family         = PF_UNSPEC;
   hints.ai_socktype       = SOCK_STREAM;
   hints.ai_protocol       = IPPROTO_TCP;
@@ -620,45 +707,13 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
       continue;
     }
 
-    if (Connection->server)
+    if (0>connect(sock, ai_cur->ai_addr, ai_cur->ai_addrlen))
     {
-      ret = bind(sock, ai_cur->ai_addr, ai_cur->ai_addrlen);
-      if (ret < 0)
-      {
-        WARN("bind failed: %s\n", strerror(errno));
-        close(sock);
-        continue;
-      }
-      ret = listen(sock, 10);
-      if (ret < 0)
-      {
-        WARN("listen failed: %s\n", strerror(errno));
-        close(sock);
-        continue;
-      }
-      /* need a non-blocking socket, otherwise accept() has a potential
-       * race-condition (poll() says it is readable, connection drops,
-       * and accept() blocks until the next connection comes...)
-       */
-      ret = fcntl(sock, F_SETFL, O_NONBLOCK);
-      if (ret < 0)
-      {
-        WARN("couldn't make socket non-blocking, error %d\n", ret);
-        close(sock);
-        continue;
-      }
-      tcpc->sock = sock;
-    }
-    else /* it's a client */
-    {
-      if (0>connect(sock, ai_cur->ai_addr, ai_cur->ai_addrlen))
-      {
-        WARN("connect() failed: %s\n", strerror(errno));
-        close(sock);
-        continue;
-      }
-      tcpc->sock = sock;
+      WARN("connect() failed: %s\n", strerror(errno));
+      close(sock);
+      continue;
     }
+    tcpc->sock = sock;
 
     freeaddrinfo(ai);
     TRACE("connected\n");
@@ -670,6 +725,106 @@ static RPC_STATUS rpcrt4_ncacn_ip_tcp_open(RpcConnection* Connection)
   return RPC_S_SERVER_UNAVAILABLE;
 }
 
+static RPC_STATUS rpcrt4_protseq_ncacn_ip_tcp_open_endpoint(RpcServerProtseq *protseq, LPSTR endpoint)
+{
+    RPC_STATUS status;
+    int sock;
+    int ret;
+    struct addrinfo *ai;
+    struct addrinfo *ai_cur;
+    struct addrinfo hints;
+
+    TRACE("(%p, %s)\n", protseq, endpoint);
+
+    hints.ai_flags          = AI_PASSIVE /* for non-localhost addresses */;
+    hints.ai_family         = PF_UNSPEC;
+    hints.ai_socktype       = SOCK_STREAM;
+    hints.ai_protocol       = IPPROTO_TCP;
+    hints.ai_addrlen        = 0;
+    hints.ai_addr           = NULL;
+    hints.ai_canonname      = NULL;
+    hints.ai_next           = NULL;
+
+    ret = getaddrinfo(NULL, endpoint, &hints, &ai);
+    if (ret)
+    {
+        ERR("getaddrinfo for port %s failed: %s\n", endpoint,
+            gai_strerror(ret));
+        return RPC_S_SERVER_UNAVAILABLE;
+    }
+
+    for (ai_cur = ai; ai_cur; ai_cur = ai_cur->ai_next)
+    {
+        RpcConnection_tcp *tcpc;
+        if (TRACE_ON(rpc))
+        {
+            char host[256];
+            char service[256];
+            getnameinfo(ai_cur->ai_addr, ai_cur->ai_addrlen,
+                        host, sizeof(host), service, sizeof(service),
+                        NI_NUMERICHOST | NI_NUMERICSERV);
+            TRACE("trying %s:%s\n", host, service);
+        }
+
+        sock = socket(ai_cur->ai_family, ai_cur->ai_socktype, ai_cur->ai_protocol);
+        if (sock < 0)
+        {
+            WARN("socket() failed: %s\n", strerror(errno));
+            continue;
+        }
+
+        ret = bind(sock, ai_cur->ai_addr, ai_cur->ai_addrlen);
+        if (ret < 0)
+        {
+            WARN("bind failed: %s\n", strerror(errno));
+            close(sock);
+            continue;
+        }
+        status = RPCRT4_CreateConnection((RpcConnection **)&tcpc, TRUE,
+                                         protseq->Protseq, NULL, endpoint, NULL,
+                                         NULL, NULL);
+        if (status != RPC_S_OK)
+        {
+            close(sock);
+            continue;
+        }
+
+        ret = listen(sock, 10);
+        if (ret < 0)
+        {
+            WARN("listen failed: %s\n", strerror(errno));
+            close(sock);
+            continue;
+        }
+        /* need a non-blocking socket, otherwise accept() has a potential
+         * race-condition (poll() says it is readable, connection drops,
+         * and accept() blocks until the next connection comes...)
+         */
+        ret = fcntl(sock, F_SETFL, O_NONBLOCK);
+        if (ret < 0)
+        {
+            WARN("couldn't make socket non-blocking, error %d\n", ret);
+            close(sock);
+            continue;
+        }
+        tcpc->sock = sock;
+
+        freeaddrinfo(ai);
+
+        EnterCriticalSection(&protseq->cs);
+        tcpc->common.Next = protseq->conn;
+        protseq->conn = &tcpc->common;
+        LeaveCriticalSection(&protseq->cs);
+
+        TRACE("listening on %s\n", endpoint);
+        return RPC_S_OK;
+    }
+
+    freeaddrinfo(ai);
+    ERR("couldn't listen on port %s\n", endpoint);
+    return RPC_S_SERVER_UNAVAILABLE;
+}
+
 static RPC_STATUS rpcrt4_conn_tcp_handoff(RpcConnection *old_conn, RpcConnection *new_conn)
 {
   int ret;
@@ -906,7 +1061,6 @@ static void *rpcrt4_protseq_sock_get_wait_array(RpcServerProtseq *protseq, void
     *count = 1;
     conn = (RpcConnection_tcp *)protseq->conn;
     while (conn) {
-        RPCRT4_OpenConnection(&conn->common);
         if (conn->sock != -1)
             (*count)++;
         conn = (RpcConnection_tcp *)conn->common.Next;
@@ -1001,7 +1155,7 @@ static const struct connection_ops conn_protseq_list[] = {
     { EPM_PROTOCOL_NCACN, EPM_PROTOCOL_SMB },
     rpcrt4_conn_np_alloc,
     rpcrt4_ncacn_np_open,
-    rpcrt4_conn_np_handoff,
+    rpcrt4_ncacn_np_handoff,
     rpcrt4_conn_np_read,
     rpcrt4_conn_np_write,
     rpcrt4_conn_np_close,
@@ -1012,7 +1166,7 @@ static const struct connection_ops conn_protseq_list[] = {
     { EPM_PROTOCOL_NCALRPC, EPM_PROTOCOL_PIPE },
     rpcrt4_conn_np_alloc,
     rpcrt4_ncalrpc_open,
-    rpcrt4_conn_np_handoff,
+    rpcrt4_ncalrpc_handoff,
     rpcrt4_conn_np_read,
     rpcrt4_conn_np_write,
     rpcrt4_conn_np_close,
@@ -1042,6 +1196,7 @@ static const struct protseq_ops protseq_list[] =
         rpcrt4_protseq_np_get_wait_array,
         rpcrt4_protseq_np_free_wait_array,
         rpcrt4_protseq_np_wait_for_new_connection,
+        rpcrt4_protseq_ncacn_np_open_endpoint,
     },
     {
         "ncalrpc",
@@ -1050,6 +1205,7 @@ static const struct protseq_ops protseq_list[] =
         rpcrt4_protseq_np_get_wait_array,
         rpcrt4_protseq_np_free_wait_array,
         rpcrt4_protseq_np_wait_for_new_connection,
+        rpcrt4_protseq_ncalrpc_open_endpoint,
     },
     {
         "ncacn_ip_tcp",
@@ -1058,6 +1214,7 @@ static const struct protseq_ops protseq_list[] =
         rpcrt4_protseq_sock_get_wait_array,
         rpcrt4_protseq_sock_free_wait_array,
         rpcrt4_protseq_sock_wait_for_new_connection,
+        rpcrt4_protseq_ncacn_ip_tcp_open_endpoint,
     },
 };
 
@@ -1083,11 +1240,12 @@ static const struct connection_ops *rpcrt4_get_conn_protseq_ops(const char *prot
 
 /**** interface to rest of code ****/
 
-RPC_STATUS RPCRT4_OpenConnection(RpcConnection* Connection)
+RPC_STATUS RPCRT4_OpenClientConnection(RpcConnection* Connection)
 {
   TRACE("(Connection == ^%p)\n", Connection);
 
-  return Connection->ops->open_connection(Connection);
+  assert(!Connection->server);
+  return Connection->ops->open_connection_client(Connection);
 }
 
 RPC_STATUS RPCRT4_CloseConnection(RpcConnection* Connection)