/*
 * HTTP server driver
 *
 * Copyright 2019 Zebediah Figura
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
 */

#include <assert.h>
#include "ntstatus.h"
#define WIN32_NO_STATUS
#include "wine/http.h"
#include "winternl.h"
#include "ddk/wdm.h"
#include "wine/debug.h"
#include "wine/heap.h"
#include "wine/list.h"

static HANDLE directory_obj;
static DEVICE_OBJECT *device_obj;

WINE_DEFAULT_DEBUG_CHANNEL(http);

#define DECLARE_CRITICAL_SECTION(cs) \
    static CRITICAL_SECTION cs; \
    static CRITICAL_SECTION_DEBUG cs##_debug = \
    { 0, 0, &cs, { &cs##_debug.ProcessLocksList, &cs##_debug.ProcessLocksList }, \
      0, 0, { (DWORD_PTR)(__FILE__ ": " # cs) }}; \
    static CRITICAL_SECTION cs = { &cs##_debug, -1, 0, 0, 0, 0 };

DECLARE_CRITICAL_SECTION(http_cs);

static HANDLE request_thread, request_event;
static BOOL thread_stop;

static HTTP_REQUEST_ID req_id_counter;

struct connection
{
    struct list entry; /* in "connections" below */

    SOCKET socket;

    char *buffer;
    unsigned int len, size;

    /* If there is a request fully received and waiting to be read, the
     * "available" parameter will be TRUE. Either there is no queue matching
     * the URL of this request yet ("queue" is NULL), there is a queue but no
     * IRPs have arrived for this request yet ("queue" is non-NULL and "req_id"
     * is HTTP_NULL_ID), or an IRP has arrived but did not provide a large
     * enough buffer to read the whole request ("queue" is non-NULL and
     * "req_id" is not HTTP_NULL_ID).
     *
     * If "available" is FALSE, either we are waiting for a new request
     * ("req_id" is HTTP_NULL_ID), or we are waiting for the user to send a
     * response ("req_id" is not HTTP_NULL_ID). */
    BOOL available;
    struct request_queue *queue;
    HTTP_REQUEST_ID req_id;

    /* Things we already parsed out of the request header in parse_request().
     * These are valid only if "available" is TRUE. */
    unsigned int req_len;
    HTTP_VERB verb;
    HTTP_VERSION version;
    const char *url, *host;
    ULONG unk_verb_len, url_len, content_len;
};

static struct list connections = LIST_INIT(connections);

struct request_queue
{
    struct list entry;
    LIST_ENTRY irp_queue;
    HTTP_URL_CONTEXT context;
    char *url;
    SOCKET socket;
};

static struct list request_queues = LIST_INIT(request_queues);

static void accept_connection(SOCKET socket)
{
    struct connection *conn;
    ULONG true = 1;
    SOCKET peer;

    if ((peer = accept(socket, NULL, NULL)) == INVALID_SOCKET)
        return;

    if (!(conn = heap_alloc_zero(sizeof(*conn))))
    {
        ERR("Failed to allocate memory.\n");
        shutdown(peer, SD_BOTH);
        closesocket(peer);
        return;
    }
    if (!(conn->buffer = heap_alloc(8192)))
    {
        ERR("Failed to allocate buffer memory.\n");
        heap_free(conn);
        shutdown(peer, SD_BOTH);
        closesocket(peer);
        return;
    }
    conn->size = 8192;
    WSAEventSelect(peer, request_event, FD_READ | FD_CLOSE);
    ioctlsocket(peer, FIONBIO, &true);
    conn->socket = peer;
    list_add_head(&connections, &conn->entry);
}

static void close_connection(struct connection *conn)
{
    heap_free(conn->buffer);
    shutdown(conn->socket, SD_BOTH);
    closesocket(conn->socket);
    list_remove(&conn->entry);
    heap_free(conn);
}

static HTTP_VERB parse_verb(const char *verb, int len)
{
    static const char *const verbs[] =
    {
        "OPTIONS",
        "GET",
        "HEAD",
        "POST",
        "PUT",
        "DELETE",
        "TRACE",
        "CONNECT",
        "TRACK",
        "MOVE",
        "COPY",
        "PROPFIND",
        "PROPPATCH",
        "MKCOL",
        "LOCK",
        "UNLOCK",
        "SEARCH",
    };
    unsigned int i;

    for (i = 0; i < ARRAY_SIZE(verbs); ++i)
    {
        if (!strncmp(verb, verbs[i], len))
            return HttpVerbOPTIONS + i;
    }
    return HttpVerbUnknown;
}

/* Return the length of a token, as defined in RFC 2616 section 2.2. */
static int parse_token(const char *str, const char *end)
{
    const char *p;
    for (p = str; !end || p < end; ++p)
    {
        if (!isgraph(*p) || strchr("()<>@,;:\\\"/[]?={}", *p))
            break;
    }
    return p - str;
}

static HTTP_HEADER_ID parse_header_name(const char *header, int len)
{
    static const char *const headers[] =
    {
        "Cache-Control",
        "Connection",
        "Date",
        "Keep-Alive",
        "Pragma",
        "Trailer",
        "Transfer-Encoding",
        "Upgrade",
        "Via",
        "Warning",
        "Allow",
        "Content-Length",
        "Content-Type",
        "Content-Encoding",
        "Content-Language",
        "Content-Location",
        "Content-MD5",
        "Content-Range",
        "Expires",
        "Last-Modified",
        "Accept",
        "Accept-Charset",
        "Accept-Encoding",
        "Accept-Language",
        "Authorization",
        "Cookie",
        "Expect",
        "From",
        "Host",
        "If-Match",
        "If-Modified-Since",
        "If-None-Match",
        "If-Range",
        "If-Unmodified-Since",
        "Max-Forwards",
        "Proxy-Authorization",
        "Referer",
        "Range",
        "TE",
        "Translate",
        "User-Agent",
    };
    unsigned int i;

    for (i = 0; i < ARRAY_SIZE(headers); ++i)
    {
        if (!strncmp(header, headers[i], len))
            return i;
    }
    return HttpHeaderRequestMaximum;
}

static void parse_header(const char *name, int *name_len, const char **value, int *value_len)
{
    const char *p = name;
    *name_len = parse_token(name, NULL);
    p += *name_len;
    while (*p == ' ' || *p == '\t') ++p;
    ++p; /* skip colon */
    while (*p == ' ' || *p == '\t') ++p;
    *value = p;
    while (isprint(*p) || *p == '\t') ++p;
    while (isspace(*p)) --p; /* strip trailing LWS */
    *value_len = p - *value + 1;
}

#define http_unknown_header http_unknown_header_64
#define http_data_chunk http_data_chunk_64
#define http_request http_request_64
#define complete_irp complete_irp_64
#define POINTER ULONGLONG
#include "request.h"
#undef http_unknown_header
#undef http_data_chunk
#undef http_request
#undef complete_irp
#undef POINTER

#define http_unknown_header http_unknown_header_32
#define http_data_chunk http_data_chunk_32
#define http_request http_request_32
#define complete_irp complete_irp_32
#define POINTER ULONG
#include "request.h"
#undef http_unknown_header
#undef http_data_chunk
#undef http_request
#undef complete_irp
#undef POINTER

static NTSTATUS complete_irp(struct connection *conn, IRP *irp)
{
    const struct http_receive_request_params params
            = *(struct http_receive_request_params *)irp->AssociatedIrp.SystemBuffer;

    TRACE("Completing IRP %p.\n", irp);

    if (!conn->req_id)
        conn->req_id = ++req_id_counter;

    if (params.bits == 32)
        return complete_irp_32(conn, irp);
    else
        return complete_irp_64(conn, irp);
}

/* Complete an IOCTL_HTTP_RECEIVE_REQUEST IRP if there is one to complete. */
static void try_complete_irp(struct connection *conn)
{
    LIST_ENTRY *entry;
    if (conn->queue && (entry = RemoveHeadList(&conn->queue->irp_queue)) != &conn->queue->irp_queue)
    {
        IRP *irp = CONTAINING_RECORD(entry, IRP, Tail.Overlay.ListEntry);
        irp->IoStatus.Status = complete_irp(conn, irp);
        IoCompleteRequest(irp, IO_NO_INCREMENT);
    }
}

/* Return 1 if str matches expect, 0 if str is incomplete, -1 if they don't match. */
static int compare_exact(const char *str, const char *expect, const char *end)
{
    while (*expect)
    {
        if (str >= end) return 0;
        if (*str++ != *expect++) return -1;
    }
    return 1;
}

static int parse_number(const char *str, const char **endptr, const char *end)
{
    int n = 0;
    while (str < end && isdigit(*str))
        n = n * 10 + (*str++ - '0');
    *endptr = str;
    return n;
}

static BOOL host_matches(const struct connection *conn, const struct request_queue *queue)
{
    const char *conn_host = (conn->url[0] == '/') ? conn->host : conn->url + 7;

    if (queue->url[7] == '+')
    {
        const char *queue_port = strchr(queue->url + 7, ':');
        return !strncmp(queue_port, strchr(conn_host, ':'), strlen(queue_port) - 1 /* strip final slash */);
    }

    return !memicmp(queue->url + 7, conn_host, strlen(queue->url) - 8 /* strip final slash */);
}

/* Upon receiving a request, parse it to ensure that it is a valid HTTP request,
 * and mark down some information that we will use later. Returns 1 if we parsed
 * a complete request, 0 if incomplete, -1 if invalid. */
static int parse_request(struct connection *conn)
{
    const char *const req = conn->buffer, *const end = conn->buffer + conn->len;
    struct request_queue *queue;
    const char *p = req, *q;
    int len, ret;

    if (!conn->len) return 0;

    TRACE("%s\n", wine_dbgstr_an(conn->buffer, conn->len));

    len = parse_token(p, end);
    if (p + len >= end) return 0;
    if (!len || p[len] != ' ') return -1;

    /* verb */
    if ((conn->verb = parse_verb(p, len)) == HttpVerbUnknown)
        conn->unk_verb_len = len;
    p += len + 1;

    TRACE("Got verb %u (%s).\n", conn->verb, debugstr_an(req, len));

    /* URL */
    conn->url = p;
    while (p < end && isgraph(*p)) ++p;
    conn->url_len = p - conn->url;
    if (p >= end) return 0;
    if (!conn->url_len) return -1;

    TRACE("Got URI %s.\n", debugstr_an(conn->url, conn->url_len));

    /* version */
    if ((ret = compare_exact(p, " HTTP/", end)) <= 0) return ret;
    p += 6;
    conn->version.MajorVersion = parse_number(p, &q, end);
    if (q >= end) return 0;
    if (q == p || *q != '.') return -1;
    p = q + 1;
    if (p >= end) return 0;
    conn->version.MinorVersion = parse_number(p, &q, end);
    if (q >= end) return 0;
    if (q == p) return -1;
    p = q;
    if ((ret = compare_exact(p, "\r\n", end)) <= 0) return ret;
    p += 2;

    TRACE("Got version %hu.%hu.\n", conn->version.MajorVersion, conn->version.MinorVersion);

    /* headers */
    conn->host = NULL;
    conn->content_len = 0;
    for (;;)
    {
        const char *name = p;

        if (!(ret = compare_exact(p, "\r\n", end))) return 0;
        else if (ret > 0) break;

        len = parse_token(p, end);
        if (p + len >= end) return 0;
        if (!len) return -1;
        p += len;
        while (p < end && (*p == ' ' || *p == '\t')) ++p;
        if (p >= end) return 0;
        if (*p != ':') return -1;
        ++p;
        while (p < end && (*p == ' ' || *p == '\t')) ++p;

        TRACE("Got %s header.\n", debugstr_an(name, len));

        if (!strncmp(name, "Host", len))
            conn->host = p;
        else if (!strncmp(name, "Content-Length", len))
        {
            conn->content_len = parse_number(p, &q, end);
            if (q >= end) return 0;
            if (q == p) return -1;
        }
        else if (!strncmp(name, "Transfer-Encoding", len))
            FIXME("Unhandled Transfer-Encoding header.\n");
        while (p < end && (isprint(*p) || *p == '\t')) ++p;
        if ((ret = compare_exact(p, "\r\n", end)) <= 0) return ret;
        p += 2;
    }
    p += 2;
    if (conn->url[0] == '/' && !conn->host) return -1;

    if (end - p < conn->content_len) return 0;

    conn->req_len = (p - req) + conn->content_len;

    TRACE("Received a full request, length %u bytes.\n", conn->req_len);

    conn->queue = NULL;
    /* Find a queue which can receive this request. */
    LIST_FOR_EACH_ENTRY(queue, &request_queues, struct request_queue, entry)
    {
        if (host_matches(conn, queue))
        {
            TRACE("Assigning request to queue %p.\n", queue);
            conn->queue = queue;
            break;
        }
    }

    /* Stop selecting on incoming data until a response is queued. */
    WSAEventSelect(conn->socket, request_event, FD_CLOSE);

    conn->available = TRUE;
    try_complete_irp(conn);

    return 1;
}

static void format_date(char *buffer)
{
    static const char day_names[7][4] = {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"};
    static const char month_names[12][4] =
            {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"};
    SYSTEMTIME date;
    GetSystemTime(&date);
    sprintf(buffer + strlen(buffer), "Date: %s, %02u %s %u %02u:%02u:%02u GMT\r\n",
            day_names[date.wDayOfWeek], date.wDay, month_names[date.wMonth - 1],
            date.wYear, date.wHour, date.wMinute, date.wSecond);
}

/* Send a 400 Bad Request response. */
static void send_400(struct connection *conn)
{
    static const char response_header[] = "HTTP/1.1 400 Bad Request\r\n";
    static const char response_body[] =
        "Content-Type: text/html; charset=utf-8\r\n"
        "Content-Language: en\r\n"
        "Connection: close\r\n";
    char buffer[sizeof(response_header) + sizeof(response_body) + 37];

    strcpy(buffer, response_header);
    format_date(buffer + strlen(buffer));
    strcat(buffer, response_body);
    if (send(conn->socket, buffer, strlen(buffer), 0) < 0)
        ERR("Failed to send 400 response, error %u.\n", WSAGetLastError());
}

static void receive_data(struct connection *conn)
{
    int len, ret;

    /* We might be waiting for an IRP, but always call recv() anyway, since we
     * might have been woken up by the socket closing. */
    if ((len = recv(conn->socket, conn->buffer + conn->len, conn->size - conn->len, 0)) <= 0)
    {
        if (WSAGetLastError() == WSAEWOULDBLOCK)
            return; /* nothing to receive */
        else if (!len)
            TRACE("Connection was shut down by peer.\n");
        else
            ERR("Got error %u; shutting down connection.\n", WSAGetLastError());
        close_connection(conn);
        return;
    }
    conn->len += len;

    if (conn->available)
        return; /* waiting for an HttpReceiveHttpRequest() call */
    if (conn->req_id != HTTP_NULL_ID)
        return; /* waiting for an HttpSendHttpResponse() call */

    TRACE("Received %u bytes of data.\n", len);

    if (!(ret = parse_request(conn)))
    {
        ULONG available;
        ioctlsocket(conn->socket, FIONREAD, &available);
        if (available)
        {
            TRACE("%u more bytes of data available, trying with larger buffer.\n", available);
            if (!(conn->buffer = heap_realloc(conn->buffer, conn->len + available)))
            {
                ERR("Failed to allocate %u bytes of memory.\n", conn->len + available);
                close_connection(conn);
                return;
            }
            conn->size = conn->len + available;

            if ((len = recv(conn->socket, conn->buffer + conn->len, conn->size - conn->len, 0)) < 0)
            {
                ERR("Got error %u; shutting down connection.\n", WSAGetLastError());
                close_connection(conn);
                return;
            }
            TRACE("Received %u bytes of data.\n", len);
            conn->len += len;
            ret = parse_request(conn);
        }
    }
    if (!ret)
        TRACE("Request is incomplete, waiting for more data.\n");
    else if (ret < 0)
    {
        WARN("Failed to parse request; shutting down connection.\n");
        send_400(conn);
        close_connection(conn);
    }
}

static DWORD WINAPI request_thread_proc(void *arg)
{
    struct connection *conn, *cursor;
    struct request_queue *queue;

    TRACE("Starting request thread.\n");

    while (!WaitForSingleObject(request_event, INFINITE))
    {
        EnterCriticalSection(&http_cs);

        LIST_FOR_EACH_ENTRY(queue, &request_queues, struct request_queue, entry)
        {
            if (queue->socket != -1)
                accept_connection(queue->socket);
        }

        LIST_FOR_EACH_ENTRY_SAFE(conn, cursor, &connections, struct connection, entry)
        {
            receive_data(conn);
        }

        LeaveCriticalSection(&http_cs);
    }

    TRACE("Stopping request thread.\n");

    return 0;
}

static NTSTATUS http_add_url(struct request_queue *queue, IRP *irp)
{
    const struct http_add_url_params *params = irp->AssociatedIrp.SystemBuffer;
    struct sockaddr_in addr;
    struct connection *conn;
    unsigned int count = 0;
    char *url, *endptr;
    ULONG true = 1;
    const char *p;
    SOCKET s;

    TRACE("host %s, context %s.\n", debugstr_a(params->url), wine_dbgstr_longlong(params->context));

    if (!strncmp(params->url, "https://", 8))
    {
        FIXME("HTTPS is not implemented.\n");
        return STATUS_NOT_IMPLEMENTED;
    }
    else if (strncmp(params->url, "http://", 7) || !strchr(params->url + 7, ':')
            || params->url[strlen(params->url) - 1] != '/')
        return STATUS_INVALID_PARAMETER;
    if (!(addr.sin_port = htons(strtol(strchr(params->url + 7, ':') + 1, &endptr, 10))) || *endptr != '/')
        return STATUS_INVALID_PARAMETER;

    if (!(url = heap_alloc(strlen(params->url)+1)))
        return STATUS_NO_MEMORY;
    strcpy(url, params->url);

    for (p = url; *p; ++p)
        if (*p == '/') ++count;
    if (count > 3)
        FIXME("Binding to relative URIs is not implemented; binding to all URIs instead.\n");

    EnterCriticalSection(&http_cs);

    if (queue->url && !strcmp(queue->url, url))
    {
        LeaveCriticalSection(&http_cs);
        heap_free(url);
        return STATUS_OBJECT_NAME_COLLISION;
    }
    else if (queue->url)
    {
        FIXME("Binding to multiple URLs is not implemented.\n");
        LeaveCriticalSection(&http_cs);
        heap_free(url);
        return STATUS_NOT_IMPLEMENTED;
    }

    if ((s = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET)
    {
        ERR("Failed to create socket, error %u.\n", WSAGetLastError());
        LeaveCriticalSection(&http_cs);
        heap_free(url);
        return STATUS_UNSUCCESSFUL;
    }

    addr.sin_family = AF_INET;
    addr.sin_addr.S_un.S_addr = INADDR_ANY;
    if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == -1)
    {
        LeaveCriticalSection(&http_cs);
        closesocket(s);
        heap_free(url);
        if (WSAGetLastError() == WSAEADDRINUSE)
        {
            WARN("Address %s is already in use.\n", debugstr_a(params->url));
            return STATUS_SHARING_VIOLATION;
        }
        else if (WSAGetLastError() == WSAEACCES)
        {
            WARN("Not enough permissions to bind to address %s.\n", debugstr_a(params->url));
            return STATUS_ACCESS_DENIED;
        }
        ERR("Failed to bind socket, error %u.\n", WSAGetLastError());
        return STATUS_UNSUCCESSFUL;
    }

    if (listen(s, SOMAXCONN) == -1)
    {
        ERR("Failed to listen to port %u, error %u.\n", addr.sin_port, WSAGetLastError());
        LeaveCriticalSection(&http_cs);
        closesocket(s);
        heap_free(url);
        return STATUS_OBJECT_NAME_COLLISION;
    }

    ioctlsocket(s, FIONBIO, &true);
    WSAEventSelect(s, request_event, FD_ACCEPT);
    queue->socket = s;
    queue->url = url;
    queue->context = params->context;

    /* See if any pending requests now match this queue. */
    LIST_FOR_EACH_ENTRY(conn, &connections, struct connection, entry)
    {
        if (conn->available && !conn->queue && host_matches(conn, queue))
        {
            conn->queue = queue;
            try_complete_irp(conn);
        }
    }

    LeaveCriticalSection(&http_cs);

    return STATUS_SUCCESS;
}

static NTSTATUS http_remove_url(struct request_queue *queue, IRP *irp)
{
    const char *url = irp->AssociatedIrp.SystemBuffer;

    TRACE("host %s.\n", debugstr_a(url));

    EnterCriticalSection(&http_cs);

    if (!queue->url || strcmp(url, queue->url))
    {
        LeaveCriticalSection(&http_cs);
        return STATUS_OBJECT_NAME_NOT_FOUND;
    }
    heap_free(queue->url);
    queue->url = NULL;

    LeaveCriticalSection(&http_cs);
    return STATUS_SUCCESS;
}

static struct connection *get_connection(HTTP_REQUEST_ID req_id)
{
    struct connection *conn;

    LIST_FOR_EACH_ENTRY(conn, &connections, struct connection, entry)
    {
        if (conn->req_id == req_id)
            return conn;
    }
    return NULL;
}

static void WINAPI http_receive_request_cancel(DEVICE_OBJECT *device, IRP *irp)
{
    TRACE("device %p, irp %p.\n", device, irp);

    IoReleaseCancelSpinLock(irp->CancelIrql);

    EnterCriticalSection(&http_cs);
    RemoveEntryList(&irp->Tail.Overlay.ListEntry);
    LeaveCriticalSection(&http_cs);

    irp->IoStatus.Status = STATUS_CANCELLED;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
}

static NTSTATUS http_receive_request(struct request_queue *queue, IRP *irp)
{
    const struct http_receive_request_params *params = irp->AssociatedIrp.SystemBuffer;
    struct connection *conn;
    NTSTATUS ret;

    TRACE("addr %s, id %s, flags %#x, bits %u.\n", wine_dbgstr_longlong(params->addr),
            wine_dbgstr_longlong(params->id), params->flags, params->bits);

    EnterCriticalSection(&http_cs);

    if ((conn = get_connection(params->id)) && conn->available && conn->queue == queue)
    {
        ret = complete_irp(conn, irp);
        LeaveCriticalSection(&http_cs);
        return ret;
    }

    if (params->id == HTTP_NULL_ID)
    {
        TRACE("Queuing IRP %p.\n", irp);

        IoSetCancelRoutine(irp, http_receive_request_cancel);
        if (irp->Cancel && !IoSetCancelRoutine(irp, NULL))
        {
            /* The IRP was canceled before we set the cancel routine. */
            ret = STATUS_CANCELLED;
        }
        else
        {
            IoMarkIrpPending(irp);
            InsertTailList(&queue->irp_queue, &irp->Tail.Overlay.ListEntry);
            ret = STATUS_PENDING;
        }
    }
    else
        ret = STATUS_CONNECTION_INVALID;

    LeaveCriticalSection(&http_cs);

    return ret;
}

static NTSTATUS http_send_response(struct request_queue *queue, IRP *irp)
{
    const struct http_response *response = irp->AssociatedIrp.SystemBuffer;
    struct connection *conn;

    TRACE("id %s, len %d.\n", wine_dbgstr_longlong(response->id), response->len);

    EnterCriticalSection(&http_cs);

    if ((conn = get_connection(response->id)))
    {
        if (send(conn->socket, response->buffer, response->len, 0) >= 0)
        {
            if (conn->content_len)
            {
                /* Discard whatever entity body is left. */
                memmove(conn->buffer, conn->buffer + conn->content_len, conn->len - conn->content_len);
                conn->len -= conn->content_len;
            }

            conn->queue = NULL;
            conn->req_id = HTTP_NULL_ID;
            WSAEventSelect(conn->socket, request_event, FD_READ | FD_CLOSE);
            irp->IoStatus.Information = response->len;
            /* We might have another request already in the buffer. */
            if (parse_request(conn) < 0)
            {
                WARN("Failed to parse request; shutting down connection.\n");
                send_400(conn);
                close_connection(conn);
            }
        }
        else
        {
            ERR("Got error %u; shutting down connection.\n", WSAGetLastError());
            close_connection(conn);
        }

        LeaveCriticalSection(&http_cs);
        return STATUS_SUCCESS;
    }

    LeaveCriticalSection(&http_cs);
    return STATUS_CONNECTION_INVALID;
}

static NTSTATUS http_receive_body(struct request_queue *queue, IRP *irp)
{
    const struct http_receive_body_params *params = irp->AssociatedIrp.SystemBuffer;
    IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp);
    const DWORD output_len = stack->Parameters.DeviceIoControl.OutputBufferLength;
    struct connection *conn;
    NTSTATUS ret;

    TRACE("id %s, bits %u.\n", wine_dbgstr_longlong(params->id), params->bits);

    EnterCriticalSection(&http_cs);

    if ((conn = get_connection(params->id)))
    {
        TRACE("%u bits remaining.\n", conn->content_len);

        if (conn->content_len)
        {
            ULONG len = min(conn->content_len, output_len);
            memcpy(irp->AssociatedIrp.SystemBuffer, conn->buffer, len);
            memmove(conn->buffer, conn->buffer + len, conn->len - len);
            conn->content_len -= len;
            conn->len -= len;

            irp->IoStatus.Information = len;
            ret = STATUS_SUCCESS;
        }
        else
            ret = STATUS_END_OF_FILE;
    }
    else
        ret = STATUS_CONNECTION_INVALID;

    LeaveCriticalSection(&http_cs);

    return ret;
}

static NTSTATUS WINAPI dispatch_ioctl(DEVICE_OBJECT *device, IRP *irp)
{
    IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp);
    struct request_queue *queue = stack->FileObject->FsContext;
    NTSTATUS ret;

    switch (stack->Parameters.DeviceIoControl.IoControlCode)
    {
    case IOCTL_HTTP_ADD_URL:
        ret = http_add_url(queue, irp);
        break;
    case IOCTL_HTTP_REMOVE_URL:
        ret = http_remove_url(queue, irp);
        break;
    case IOCTL_HTTP_RECEIVE_REQUEST:
        ret = http_receive_request(queue, irp);
        break;
    case IOCTL_HTTP_SEND_RESPONSE:
        ret = http_send_response(queue, irp);
        break;
    case IOCTL_HTTP_RECEIVE_BODY:
        ret = http_receive_body(queue, irp);
        break;
    default:
        FIXME("Unhandled ioctl %#x.\n", stack->Parameters.DeviceIoControl.IoControlCode);
        ret = STATUS_NOT_IMPLEMENTED;
    }

    if (ret != STATUS_PENDING)
    {
        irp->IoStatus.Status = ret;
        IoCompleteRequest(irp, IO_NO_INCREMENT);
    }
    return ret;
}

static NTSTATUS WINAPI dispatch_create(DEVICE_OBJECT *device, IRP *irp)
{
    IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp);
    struct request_queue *queue;

    if (!(queue = heap_alloc_zero(sizeof(*queue))))
        return STATUS_NO_MEMORY;
    stack->FileObject->FsContext = queue;
    InitializeListHead(&queue->irp_queue);

    EnterCriticalSection(&http_cs);
    list_add_head(&request_queues, &queue->entry);
    LeaveCriticalSection(&http_cs);

    TRACE("Created queue %p.\n", queue);

    irp->IoStatus.Status = STATUS_SUCCESS;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
    return STATUS_SUCCESS;
}

static void close_queue(struct request_queue *queue)
{
    EnterCriticalSection(&http_cs);
    list_remove(&queue->entry);
    if (queue->socket != -1)
    {
        shutdown(queue->socket, SD_BOTH);
        closesocket(queue->socket);
    }
    LeaveCriticalSection(&http_cs);

    heap_free(queue->url);
    heap_free(queue);
}

static NTSTATUS WINAPI dispatch_close(DEVICE_OBJECT *device, IRP *irp)
{
    IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp);
    struct request_queue *queue = stack->FileObject->FsContext;
    LIST_ENTRY *entry;

    TRACE("Closing queue %p.\n", queue);

    EnterCriticalSection(&http_cs);

    while ((entry = queue->irp_queue.Flink) != &queue->irp_queue)
    {
        IRP *queued_irp = CONTAINING_RECORD(entry, IRP, Tail.Overlay.ListEntry);
        IoCancelIrp(queued_irp);
    }

    LeaveCriticalSection(&http_cs);

    close_queue(queue);

    irp->IoStatus.Status = STATUS_SUCCESS;
    IoCompleteRequest(irp, IO_NO_INCREMENT);
    return STATUS_SUCCESS;
}

static void WINAPI unload(DRIVER_OBJECT *driver)
{
    struct request_queue *queue, *queue_next;
    struct connection *conn, *conn_next;

    thread_stop = TRUE;
    SetEvent(request_event);
    WaitForSingleObject(request_thread, INFINITE);
    CloseHandle(request_thread);
    CloseHandle(request_event);

    LIST_FOR_EACH_ENTRY_SAFE(conn, conn_next, &connections, struct connection, entry)
    {
        close_connection(conn);
    }

    LIST_FOR_EACH_ENTRY_SAFE(queue, queue_next, &request_queues, struct request_queue, entry)
    {
        close_queue(queue);
    }

    WSACleanup();

    IoDeleteDevice(device_obj);
    NtClose(directory_obj);
}

NTSTATUS WINAPI DriverEntry(DRIVER_OBJECT *driver, UNICODE_STRING *path)
{
    OBJECT_ATTRIBUTES attr = {sizeof(attr)};
    UNICODE_STRING string;
    WSADATA wsadata;
    NTSTATUS ret;

    TRACE("driver %p, path %s.\n", driver, debugstr_w(path->Buffer));

    RtlInitUnicodeString(&string, L"\\Device\\Http");
    attr.ObjectName = &string;
    if ((ret = NtCreateDirectoryObject(&directory_obj, 0, &attr)) && ret != STATUS_OBJECT_NAME_COLLISION)
        ERR("Failed to create \\Device\\Http directory, status %#x.\n", ret);

    RtlInitUnicodeString(&string, L"\\Device\\Http\\ReqQueue");
    if ((ret = IoCreateDevice(driver, 0, &string, FILE_DEVICE_UNKNOWN, 0, FALSE, &device_obj)))
    {
        ERR("Failed to create request queue device, status %#x.\n", ret);
        NtClose(directory_obj);
        return ret;
    }

    driver->MajorFunction[IRP_MJ_CREATE] = dispatch_create;
    driver->MajorFunction[IRP_MJ_CLOSE] = dispatch_close;
    driver->MajorFunction[IRP_MJ_DEVICE_CONTROL] = dispatch_ioctl;
    driver->DriverUnload = unload;

    WSAStartup(MAKEWORD(1,1), &wsadata);

    request_event = CreateEventW(NULL, FALSE, FALSE, NULL);
    request_thread = CreateThread(NULL, 0, request_thread_proc, NULL, 0, NULL);

    return STATUS_SUCCESS;
}