From 325984ded50b354686e5a454aa5aca3aafa81432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Bernon?= Date: Tue, 10 Aug 2021 11:31:17 +0200 Subject: [PATCH] hidclass.sys: Use a simpler ring buffer with ref-counted reports. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: RĂ©mi Bernon Signed-off-by: Alexandre Julliard --- dlls/hidclass.sys/Makefile.in | 1 - dlls/hidclass.sys/buffer.c | 204 ---------------------------------- dlls/hidclass.sys/device.c | 162 ++++++++++++++++++++------- dlls/hidclass.sys/hid.h | 25 +++-- 4 files changed, 136 insertions(+), 256 deletions(-) delete mode 100644 dlls/hidclass.sys/buffer.c diff --git a/dlls/hidclass.sys/Makefile.in b/dlls/hidclass.sys/Makefile.in index 09281c118b4..344fd10bc86 100644 --- a/dlls/hidclass.sys/Makefile.in +++ b/dlls/hidclass.sys/Makefile.in @@ -5,7 +5,6 @@ IMPORTS = hal ntoskrnl user32 EXTRADLLFLAGS = -mno-cygwin C_SRCS = \ - buffer.c \ descriptor.c \ device.c \ pnp.c diff --git a/dlls/hidclass.sys/buffer.c b/dlls/hidclass.sys/buffer.c deleted file mode 100644 index 59c0edf29f9..00000000000 --- a/dlls/hidclass.sys/buffer.c +++ /dev/null @@ -1,204 +0,0 @@ -/* Implementation of a ring buffer for reports - * - * Copyright 2015 CodeWeavers, Aric Stewart - * - * This library is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 2.1 of the License, or (at your option) any later version. - * - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this library; if not, write to the Free Software - * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA - */ - -#include -#include -#include "hid.h" - -#include "wine/debug.h" - -WINE_DEFAULT_DEBUG_CHANNEL(hid); - -#define POINTER_UNUSED 0xffffffff -#define BASE_BUFFER_SIZE 32 -#define MIN_BUFFER_SIZE 2 -#define MAX_BUFFER_SIZE 512 - -struct ReportRingBuffer -{ - UINT start, end, size; - - UINT *pointers; - UINT pointer_alloc; - UINT buffer_size; - - CRITICAL_SECTION lock; - - BYTE *buffer; -}; - -struct ReportRingBuffer* RingBuffer_Create(UINT buffer_size) -{ - struct ReportRingBuffer *ring; - int i; - - TRACE("Create Ring Buffer with buffer size %i\n",buffer_size); - - ring = malloc(sizeof(*ring)); - if (!ring) - return NULL; - ring->start = ring->end = 0; - ring->size = BASE_BUFFER_SIZE; - ring->buffer_size = buffer_size; - ring->pointer_alloc = 2; - ring->pointers = malloc(sizeof(UINT) * ring->pointer_alloc); - if (!ring->pointers) - { - free(ring); - return NULL; - } - for (i = 0; i < ring->pointer_alloc; i++) - ring->pointers[i] = POINTER_UNUSED; - ring->buffer = malloc(buffer_size * ring->size); - if (!ring->buffer) - { - free(ring->pointers); - free(ring); - return NULL; - } - InitializeCriticalSection(&ring->lock); - ring->lock.DebugInfo->Spare[0] = (DWORD_PTR)(__FILE__ ": RingBuffer.lock"); - return ring; -} - -void RingBuffer_Destroy(struct ReportRingBuffer *ring) -{ - free(ring->buffer); - free(ring->pointers); - ring->lock.DebugInfo->Spare[0] = 0; - DeleteCriticalSection(&ring->lock); - free(ring); -} - -UINT RingBuffer_GetBufferSize(struct ReportRingBuffer *ring) -{ - return ring->buffer_size; -} - -UINT RingBuffer_GetSize(struct ReportRingBuffer *ring) -{ - return ring->size; -} - -NTSTATUS RingBuffer_SetSize(struct ReportRingBuffer *ring, UINT size) -{ - BYTE* new_buffer; - int i; - - if (size < MIN_BUFFER_SIZE || size > MAX_BUFFER_SIZE) - return STATUS_INVALID_PARAMETER; - if (size == ring->size) - return STATUS_SUCCESS; - - EnterCriticalSection(&ring->lock); - ring->start = ring->end = 0; - for (i = 0; i < ring->pointer_alloc; i++) - { - if (ring->pointers[i] != POINTER_UNUSED) - ring->pointers[i] = 0; - } - new_buffer = malloc(ring->buffer_size * size); - if (!new_buffer) - { - LeaveCriticalSection(&ring->lock); - return STATUS_NO_MEMORY; - } - free(ring->buffer); - ring->buffer = new_buffer; - ring->size = size; - LeaveCriticalSection(&ring->lock); - return STATUS_SUCCESS; -} - -void RingBuffer_ReadNew(struct ReportRingBuffer *ring, UINT index, void *output, UINT *size) -{ - void *ret = NULL; - - EnterCriticalSection(&ring->lock); - if (index >= ring->pointer_alloc || ring->pointers[index] == POINTER_UNUSED) - { - LeaveCriticalSection(&ring->lock); - *size = 0; - return; - } - if (ring->pointers[index] == ring->end) - { - LeaveCriticalSection(&ring->lock); - *size = 0; - } - else - { - ret = &ring->buffer[ring->pointers[index] * ring->buffer_size]; - memcpy(output, ret, ring->buffer_size); - ring->pointers[index]++; - if (ring->pointers[index] == ring->size) - ring->pointers[index] = 0; - LeaveCriticalSection(&ring->lock); - *size = ring->buffer_size; - } -} - -UINT RingBuffer_AddPointer(struct ReportRingBuffer *ring) -{ - UINT idx; - EnterCriticalSection(&ring->lock); - for (idx = 0; idx < ring->pointer_alloc; idx++) - if (ring->pointers[idx] == POINTER_UNUSED) - break; - if (idx >= ring->pointer_alloc) - { - int count = idx = ring->pointer_alloc; - ring->pointer_alloc *= 2; - ring->pointers = realloc(ring->pointers, sizeof(UINT) * ring->pointer_alloc); - for( ;count < ring->pointer_alloc; count++) - ring->pointers[count] = POINTER_UNUSED; - } - ring->pointers[idx] = ring->end; - LeaveCriticalSection(&ring->lock); - return idx; -} - -void RingBuffer_RemovePointer(struct ReportRingBuffer *ring, UINT index) -{ - EnterCriticalSection(&ring->lock); - if (index < ring->pointer_alloc) - ring->pointers[index] = POINTER_UNUSED; - LeaveCriticalSection(&ring->lock); -} - -void RingBuffer_Write(struct ReportRingBuffer *ring, void *data) -{ - UINT i; - - EnterCriticalSection(&ring->lock); - memcpy(&ring->buffer[ring->end * ring->buffer_size], data, ring->buffer_size); - ring->end++; - if (ring->end == ring->size) - ring->end = 0; - if (ring->start == ring->end) - { - ring->start++; - if (ring->start == ring->size) - ring->start = 0; - } - for (i = 0; i < ring->pointer_alloc; i++) - if (ring->pointers[i] == ring->end) - ring->pointers[i] = ring->start; - LeaveCriticalSection(&ring->lock); -} diff --git a/dlls/hidclass.sys/device.c b/dlls/hidclass.sys/device.c index 8510f25ae41..6e82608aa2e 100644 --- a/dlls/hidclass.sys/device.c +++ b/dlls/hidclass.sys/device.c @@ -79,11 +79,111 @@ static void WINAPI read_cancel_routine(DEVICE_OBJECT *device, IRP *irp) IoCompleteRequest(irp, IO_NO_INCREMENT); } +static struct hid_report *hid_report_create( HID_XFER_PACKET *packet ) +{ + struct hid_report *report; + + if (!(report = malloc( offsetof( struct hid_report, buffer[packet->reportBufferLen] ) ))) + return NULL; + report->ref = 1; + report->length = packet->reportBufferLen; + memcpy( report->buffer, packet->reportBuffer, report->length ); + + return report; +} + +static void hid_report_incref( struct hid_report *report ) +{ + InterlockedIncrement( &report->ref ); +} + +static void hid_report_decref( struct hid_report *report ) +{ + if (!report) return; + if (InterlockedDecrement( &report->ref ) == 0) free( report ); +} + +static struct hid_report_queue *hid_report_queue_create( void ) +{ + struct hid_report_queue *queue; + + if (!(queue = calloc( 1, sizeof(struct hid_report_queue) ))) return NULL; + KeInitializeSpinLock( &queue->lock ); + list_init( &queue->entry ); + queue->length = 32; + queue->read_idx = 0; + queue->write_idx = 0; + + return queue; +} + +static void hid_report_queue_destroy( struct hid_report_queue *queue ) +{ + while (queue->length--) hid_report_decref( queue->reports[queue->length] ); + free( queue ); +} + +static NTSTATUS hid_report_queue_resize( struct hid_report_queue *queue, ULONG length ) +{ + struct hid_report *old_reports[512]; + LONG old_length = queue->length; + KIRQL irql; + + if (length < 2 || length > 512) return STATUS_INVALID_PARAMETER; + if (length == queue->length) return STATUS_SUCCESS; + + KeAcquireSpinLock( &queue->lock, &irql ); + memcpy( old_reports, queue->reports, old_length * sizeof(void *) ); + memset( queue->reports, 0, old_length * sizeof(void *) ); + queue->length = length; + queue->write_idx = 0; + queue->read_idx = 0; + KeReleaseSpinLock( &queue->lock, irql ); + + while (old_length--) hid_report_decref( old_reports[old_length] ); + return STATUS_SUCCESS; +} + +static void hid_report_queue_push( struct hid_report_queue *queue, struct hid_report *report ) +{ + ULONG i = queue->write_idx, next = i + 1; + struct hid_report *prev; + KIRQL irql; + + if (next >= queue->length) next = 0; + hid_report_incref( report ); + + KeAcquireSpinLock( &queue->lock, &irql ); + prev = queue->reports[i]; + queue->reports[i] = report; + if (next != queue->read_idx) queue->write_idx = next; + KeReleaseSpinLock( &queue->lock, irql ); + + hid_report_decref( prev ); +} + +static struct hid_report *hid_report_queue_pop( struct hid_report_queue *queue ) +{ + ULONG i = queue->read_idx, next = i + 1; + struct hid_report *report; + KIRQL irql; + + if (next >= queue->length) next = 0; + + KeAcquireSpinLock( &queue->lock, &irql ); + report = queue->reports[i]; + queue->reports[i] = NULL; + if (i != queue->write_idx) queue->read_idx = next; + KeReleaseSpinLock( &queue->lock, irql ); + + return report; +} + static void hid_device_queue_input( DEVICE_OBJECT *device, HID_XFER_PACKET *packet ) { BASE_DEVICE_EXTENSION *ext = device->DeviceExtension; struct hid_preparsed_data *preparsed = ext->u.pdo.preparsed_data; - HID_XFER_PACKET *read_packet, *last_packet = packet; + struct hid_report *last_report, *report; struct hid_report_queue *queue; RAWINPUT *rawinput; ULONG size; @@ -113,31 +213,31 @@ static void hid_device_queue_input( DEVICE_OBJECT *device, HID_XFER_PACKET *pack free( rawinput ); } - KeAcquireSpinLock( &ext->u.pdo.report_queues_lock, &irql ); - LIST_FOR_EACH_ENTRY( queue, &ext->u.pdo.report_queues, struct hid_report_queue, entry ) - RingBuffer_Write( queue->buffer, packet ); - KeReleaseSpinLock( &ext->u.pdo.report_queues_lock, irql ); - - if (!(read_packet = malloc( sizeof(*packet) + preparsed->caps.InputReportByteLength ))) + if (!(last_report = hid_report_create( packet ))) { - ERR( "Failed to allocate read_packet!\n" ); + ERR( "Failed to allocate hid_report!\n" ); return; } + KeAcquireSpinLock( &ext->u.pdo.report_queues_lock, &irql ); + LIST_FOR_EACH_ENTRY( queue, &ext->u.pdo.report_queues, struct hid_report_queue, entry ) + hid_report_queue_push( queue, last_report ); + KeReleaseSpinLock( &ext->u.pdo.report_queues_lock, irql ); + while ((irp = pop_irp_from_queue( ext ))) { queue = irp->Tail.Overlay.OriginalFileObject->FsContext; - RingBuffer_ReadNew( queue->buffer, 0, read_packet, &size ); - if (!size) packet = last_packet; - else packet = read_packet; - memcpy( irp->AssociatedIrp.SystemBuffer, packet + 1, preparsed->caps.InputReportByteLength ); - irp->IoStatus.Information = packet->reportBufferLen; + if (!(report = hid_report_queue_pop( queue ))) hid_report_incref( (report = last_report) ); + memcpy( irp->AssociatedIrp.SystemBuffer, report->buffer, preparsed->caps.InputReportByteLength ); + irp->IoStatus.Information = report->length; irp->IoStatus.Status = STATUS_SUCCESS; + hid_report_decref( report ); + IoCompleteRequest( irp, IO_NO_INCREMENT ); } - free( read_packet ); + hid_report_decref( last_report ); } static DWORD CALLBACK hid_device_thread(void *args) @@ -476,7 +576,7 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp) if (irpsp->Parameters.DeviceIoControl.InputBufferLength != sizeof(ULONG)) irp->IoStatus.Status = STATUS_BUFFER_OVERFLOW; else - irp->IoStatus.Status = RingBuffer_SetSize( queue->buffer, *(ULONG *)irp->AssociatedIrp.SystemBuffer ); + irp->IoStatus.Status = hid_report_queue_resize( queue, *(ULONG *)irp->AssociatedIrp.SystemBuffer ); break; } case IOCTL_GET_NUM_DEVICE_INPUT_BUFFERS: @@ -488,7 +588,7 @@ NTSTATUS WINAPI pdo_ioctl(DEVICE_OBJECT *device, IRP *irp) } else { - *(ULONG *)irp->AssociatedIrp.SystemBuffer = RingBuffer_GetSize( queue->buffer ); + *(ULONG *)irp->AssociatedIrp.SystemBuffer = queue->length; irp->IoStatus.Information = sizeof(ULONG); irp->IoStatus.Status = STATUS_SUCCESS; } @@ -522,10 +622,8 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp) struct hid_preparsed_data *preparsed = ext->u.pdo.preparsed_data; IO_STACK_LOCATION *irpsp = IoGetCurrentIrpStackLocation(irp); BYTE report_id = HID_INPUT_VALUE_CAPS( preparsed )->report_id; - HID_XFER_PACKET *packet; - UINT buffer_size; + struct hid_report *report; NTSTATUS status; - int ptr = -1; BOOL removed; KIRQL irql; @@ -547,17 +645,13 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp) return STATUS_INVALID_BUFFER_SIZE; } - packet = malloc( sizeof(*packet) + preparsed->caps.InputReportByteLength ); - ptr = PtrToUlong( irp->Tail.Overlay.OriginalFileObject->FsContext ); - irp->IoStatus.Information = 0; - RingBuffer_ReadNew( queue->buffer, ptr, packet, &buffer_size ); - - if (buffer_size) + if ((report = hid_report_queue_pop( queue ))) { - memcpy( irp->AssociatedIrp.SystemBuffer, packet + 1, preparsed->caps.InputReportByteLength ); - irp->IoStatus.Information = packet->reportBufferLen; + memcpy( irp->AssociatedIrp.SystemBuffer, report->buffer, preparsed->caps.InputReportByteLength ); + irp->IoStatus.Information = report->length; irp->IoStatus.Status = STATUS_SUCCESS; + hid_report_decref( report ); } else { @@ -606,7 +700,6 @@ NTSTATUS WINAPI pdo_read(DEVICE_OBJECT *device, IRP *irp) sizeof(packet), &irp->IoStatus ); } } - free(packet); status = irp->IoStatus.Status; if (status != STATUS_PENDING) IoCompleteRequest( irp, IO_NO_INCREMENT ); @@ -628,21 +721,14 @@ NTSTATUS WINAPI pdo_write(DEVICE_OBJECT *device, IRP *irp) NTSTATUS WINAPI pdo_create(DEVICE_OBJECT *device, IRP *irp) { BASE_DEVICE_EXTENSION *ext = device->DeviceExtension; - struct hid_preparsed_data *preparsed = ext->u.pdo.preparsed_data; struct hid_report_queue *queue; KIRQL irql; TRACE("Open handle on device %p\n", device); - if (!(queue = malloc( sizeof(*queue) )) || - !(queue->buffer = RingBuffer_Create( sizeof(HID_XFER_PACKET) + preparsed->caps.InputReportByteLength ))) - { - free( queue ); - irp->IoStatus.Status = STATUS_NO_MEMORY; - } + if (!(queue = hid_report_queue_create())) irp->IoStatus.Status = STATUS_NO_MEMORY; else { - RingBuffer_AddPointer( queue->buffer ); KeAcquireSpinLock( &ext->u.pdo.report_queues_lock, &irql ); list_add_tail( &ext->u.pdo.report_queues, &queue->entry ); KeReleaseSpinLock( &ext->u.pdo.report_queues_lock, irql ); @@ -668,9 +754,7 @@ NTSTATUS WINAPI pdo_close(DEVICE_OBJECT *device, IRP *irp) KeAcquireSpinLock( &ext->u.pdo.report_queues_lock, &irql ); list_remove( &queue->entry ); KeReleaseSpinLock( &ext->u.pdo.report_queues_lock, irql ); - - RingBuffer_Destroy( queue->buffer ); - free( queue ); + hid_report_queue_destroy( queue ); } irp->IoStatus.Status = STATUS_SUCCESS; diff --git a/dlls/hidclass.sys/hid.h b/dlls/hidclass.sys/hid.h index 6bfba100007..2dfdfe4d0f7 100644 --- a/dlls/hidclass.sys/hid.h +++ b/dlls/hidclass.sys/hid.h @@ -87,21 +87,22 @@ typedef struct _BASE_DEVICE_EXTENSION BOOL is_fdo; } BASE_DEVICE_EXTENSION; -struct hid_report_queue +struct hid_report { - struct list entry; - struct ReportRingBuffer *buffer; + LONG ref; + ULONG length; + BYTE buffer[1]; }; -void RingBuffer_Write(struct ReportRingBuffer *buffer, void *data) DECLSPEC_HIDDEN; -UINT RingBuffer_AddPointer(struct ReportRingBuffer *buffer) DECLSPEC_HIDDEN; -void RingBuffer_RemovePointer(struct ReportRingBuffer *ring, UINT index) DECLSPEC_HIDDEN; -void RingBuffer_ReadNew(struct ReportRingBuffer *buffer, UINT index, void *output, UINT *size) DECLSPEC_HIDDEN; -UINT RingBuffer_GetBufferSize(struct ReportRingBuffer *buffer) DECLSPEC_HIDDEN; -UINT RingBuffer_GetSize(struct ReportRingBuffer *buffer) DECLSPEC_HIDDEN; -void RingBuffer_Destroy(struct ReportRingBuffer *buffer) DECLSPEC_HIDDEN; -struct ReportRingBuffer* RingBuffer_Create(UINT buffer_size) DECLSPEC_HIDDEN; -NTSTATUS RingBuffer_SetSize(struct ReportRingBuffer *buffer, UINT size) DECLSPEC_HIDDEN; +struct hid_report_queue +{ + struct list entry; + KSPIN_LOCK lock; + ULONG length; + ULONG read_idx; + ULONG write_idx; + struct hid_report *reports[512]; +}; typedef struct _minidriver {