From 1c4990518269a570476b76c02d11c98808f61b93 Mon Sep 17 00:00:00 2001 From: Alexandre Julliard Date: Mon, 24 Jul 2017 10:33:53 +0200 Subject: [PATCH] ntdll: Add an assembly wrapper to return correct values for the current thread in NtGetContextThread. Signed-off-by: Alexandre Julliard --- dlls/ntdll/ntdll.spec | 4 +- dlls/ntdll/signal_i386.c | 128 ++++++++++++++++++----------------- dlls/ntdll/tests/exception.c | 91 +++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 63 deletions(-) diff --git a/dlls/ntdll/ntdll.spec b/dlls/ntdll/ntdll.spec index d7fdced48ce..141f50d6d02 100644 --- a/dlls/ntdll/ntdll.spec +++ b/dlls/ntdll/ntdll.spec @@ -185,7 +185,7 @@ # @ stub NtFreeUserPhysicalPages @ stdcall NtFreeVirtualMemory(long ptr ptr long) @ stdcall NtFsControlFile(long long ptr ptr ptr long ptr long ptr long) -@ stdcall NtGetContextThread(long ptr) +@ stdcall -norelay NtGetContextThread(long ptr) @ stdcall NtGetCurrentProcessorNumber() # @ stub NtGetDevicePowerState @ stub NtGetPlugPlayEvent @@ -1111,7 +1111,7 @@ # @ stub ZwFreeUserPhysicalPages @ stdcall -private ZwFreeVirtualMemory(long ptr ptr long) NtFreeVirtualMemory @ stdcall -private ZwFsControlFile(long long ptr ptr ptr long ptr long ptr long) NtFsControlFile -@ stdcall -private ZwGetContextThread(long ptr) NtGetContextThread +@ stdcall -private -norelay ZwGetContextThread(long ptr) NtGetContextThread @ stdcall -private ZwGetCurrentProcessorNumber() NtGetCurrentProcessorNumber # @ stub ZwGetDevicePowerState @ stub ZwGetPlugPlayEvent diff --git a/dlls/ntdll/signal_i386.c b/dlls/ntdll/signal_i386.c index 184e6dafb1a..9275ad9cd25 100644 --- a/dlls/ntdll/signal_i386.c +++ b/dlls/ntdll/signal_i386.c @@ -1292,59 +1292,6 @@ static void set_cpu_context( const CONTEXT *context ) } -/*********************************************************************** - * copy_context - * - * Copy a register context according to the flags. - */ -static void copy_context( CONTEXT *to, const CONTEXT *from, DWORD flags ) -{ - flags &= ~CONTEXT_i386; /* get rid of CPU id */ - if (flags & CONTEXT_INTEGER) - { - to->Eax = from->Eax; - to->Ebx = from->Ebx; - to->Ecx = from->Ecx; - to->Edx = from->Edx; - to->Esi = from->Esi; - to->Edi = from->Edi; - } - if (flags & CONTEXT_CONTROL) - { - to->Ebp = from->Ebp; - to->Esp = from->Esp; - to->Eip = from->Eip; - to->SegCs = from->SegCs; - to->SegSs = from->SegSs; - to->EFlags = from->EFlags; - } - if (flags & CONTEXT_SEGMENTS) - { - to->SegDs = from->SegDs; - to->SegEs = from->SegEs; - to->SegFs = from->SegFs; - to->SegGs = from->SegGs; - } - if (flags & CONTEXT_DEBUG_REGISTERS) - { - to->Dr0 = from->Dr0; - to->Dr1 = from->Dr1; - to->Dr2 = from->Dr2; - to->Dr3 = from->Dr3; - to->Dr6 = from->Dr6; - to->Dr7 = from->Dr7; - } - if (flags & CONTEXT_FLOATING_POINT) - { - to->FloatSave = from->FloatSave; - } - if (flags & CONTEXT_EXTENDED_REGISTERS) - { - memcpy( to->ExtendedRegisters, from->ExtendedRegisters, sizeof(to->ExtendedRegisters) ); - } -} - - /*********************************************************************** * context_to_server * @@ -1515,15 +1462,19 @@ NTSTATUS WINAPI NtSetContextThread( HANDLE handle, const CONTEXT *context ) /*********************************************************************** * NtGetContextThread (NTDLL.@) * ZwGetContextThread (NTDLL.@) + * + * Note: we use a small assembly wrapper to save the necessary registers + * in case we are fetching the context of the current thread. */ -NTSTATUS WINAPI NtGetContextThread( HANDLE handle, CONTEXT *context ) +NTSTATUS CDECL __regs_NtGetContextThread( DWORD edi, DWORD esi, DWORD ebx, DWORD eflags, + DWORD ebp, DWORD retaddr, HANDLE handle, CONTEXT *context ) { NTSTATUS ret; - DWORD needed_flags = context->ContextFlags; + DWORD needed_flags = context->ContextFlags & ~CONTEXT_i386; BOOL self = (handle == GetCurrentThread()); /* debug registers require a server call */ - if (context->ContextFlags & (CONTEXT_DEBUG_REGISTERS & ~CONTEXT_i386)) self = FALSE; + if (needed_flags & CONTEXT_DEBUG_REGISTERS) self = FALSE; if (!self) { @@ -1533,13 +1484,36 @@ NTSTATUS WINAPI NtGetContextThread( HANDLE handle, CONTEXT *context ) if (self) { - if (needed_flags) + if (needed_flags & CONTEXT_INTEGER) { - CONTEXT ctx; - RtlCaptureContext( &ctx ); - copy_context( context, &ctx, ctx.ContextFlags & needed_flags ); - context->ContextFlags |= ctx.ContextFlags & needed_flags; + context->Eax = 0; + context->Ebx = ebx; + context->Ecx = 0; + context->Edx = 0; + context->Esi = esi; + context->Edi = edi; + context->ContextFlags |= CONTEXT_INTEGER; } + if (needed_flags & CONTEXT_CONTROL) + { + context->Ebp = ebp; + context->Esp = (DWORD)&retaddr; + context->Eip = *(&edi - 1); + context->SegCs = wine_get_cs(); + context->SegSs = wine_get_ss(); + context->EFlags = eflags; + context->ContextFlags |= CONTEXT_CONTROL; + } + if (needed_flags & CONTEXT_SEGMENTS) + { + context->SegDs = wine_get_ds(); + context->SegEs = wine_get_es(); + context->SegFs = wine_get_fs(); + context->SegGs = wine_get_gs(); + context->ContextFlags |= CONTEXT_SEGMENTS; + } + if (needed_flags & CONTEXT_FLOATING_POINT) save_fpu( context ); + /* FIXME: extended floating point */ /* update the cached version of the debug registers */ if (context->ContextFlags & (CONTEXT_DEBUG_REGISTERS & ~CONTEXT_i386)) { @@ -1551,8 +1525,40 @@ NTSTATUS WINAPI NtGetContextThread( HANDLE handle, CONTEXT *context ) x86_thread_data()->dr7 = context->Dr7; } } + + if (context->ContextFlags & (CONTEXT_INTEGER & ~CONTEXT_i386)) + TRACE( "%p: eax=%08x ebx=%08x ecx=%08x edx=%08x esi=%08x edi=%08x\n", handle, + context->Eax, context->Ebx, context->Ecx, context->Edx, context->Esi, context->Edi ); + if (context->ContextFlags & (CONTEXT_CONTROL & ~CONTEXT_i386)) + TRACE( "%p: ebp=%08x esp=%08x eip=%08x cs=%04x ss=%04x flags=%08x\n", handle, + context->Ebp, context->Esp, context->Eip, context->SegCs, context->SegSs, context->EFlags ); + if (context->ContextFlags & (CONTEXT_SEGMENTS & ~CONTEXT_i386)) + TRACE( "%p: ds=%04x es=%04x fs=%04x gs=%04x\n", handle, + context->SegCs, context->SegDs, context->SegEs, context->SegFs ); + if (context->ContextFlags & (CONTEXT_DEBUG_REGISTERS & ~CONTEXT_i386)) + TRACE( "%p: dr0=%08x dr1=%08x dr2=%08x dr3=%08x dr6=%08x dr7=%08x\n", handle, + context->Dr0, context->Dr1, context->Dr2, context->Dr3, context->Dr6, context->Dr7 ); + return STATUS_SUCCESS; } +__ASM_STDCALL_FUNC( NtGetContextThread, 8, + "pushl %ebp\n\t" + __ASM_CFI(".cfi_adjust_cfa_offset 4\n\t") + __ASM_CFI(".cfi_rel_offset %ebp,0\n\t") + "movl %esp,%ebp\n\t" + __ASM_CFI(".cfi_def_cfa_register %ebp\n\t") + "pushfl\n\t" + "pushl %ebx\n\t" + __ASM_CFI(".cfi_rel_offset %ebx,-8\n\t") + "pushl %esi\n\t" + __ASM_CFI(".cfi_rel_offset %esi,-12\n\t") + "pushl %edi\n\t" + __ASM_CFI(".cfi_rel_offset %edi,-16\n\t") + "call " __ASM_NAME("__regs_NtGetContextThread") "\n\t" + "leave\n\t" + __ASM_CFI(".cfi_def_cfa %esp,4\n\t") + __ASM_CFI(".cfi_same_value %ebp\n\t") + "ret $8" ) /*********************************************************************** diff --git a/dlls/ntdll/tests/exception.c b/dlls/ntdll/tests/exception.c index 27eac2748cb..14e6da32a4b 100644 --- a/dlls/ntdll/tests/exception.c +++ b/dlls/ntdll/tests/exception.c @@ -1416,6 +1416,96 @@ static void test_dpe_exceptions(void) ok(stat == STATUS_ACCESS_DENIED, "enabling DEP while permanent: status %08x\n", stat); } +static void test_thread_context(void) +{ + CONTEXT context; + NTSTATUS status; + struct expected + { + DWORD Eax, Ebx, Ecx, Edx, Esi, Edi, Ebp, Esp, Eip, + SegCs, SegDs, SegEs, SegFs, SegGs, SegSs, EFlags, prev_frame; + } expect; + NTSTATUS (*func_ptr)( struct expected *res, void *func, void *arg1, void *arg2 ) = (void *)code_mem; + + static const BYTE call_func[] = + { + 0x55, /* pushl %ebp */ + 0x89, 0xe5, /* mov %esp,%ebp */ + 0x50, /* pushl %eax ; add a bit of offset to the stack */ + 0x50, /* pushl %eax */ + 0x50, /* pushl %eax */ + 0x50, /* pushl %eax */ + 0x8b, 0x45, 0x08, /* mov 0x8(%ebp),%eax */ + 0x8f, 0x00, /* popl (%eax) */ + 0x89, 0x58, 0x04, /* mov %ebx,0x4(%eax) */ + 0x89, 0x48, 0x08, /* mov %ecx,0x8(%eax) */ + 0x89, 0x50, 0x0c, /* mov %edx,0xc(%eax) */ + 0x89, 0x70, 0x10, /* mov %esi,0x10(%eax) */ + 0x89, 0x78, 0x14, /* mov %edi,0x14(%eax) */ + 0x89, 0x68, 0x18, /* mov %ebp,0x18(%eax) */ + 0x89, 0x60, 0x1c, /* mov %esp,0x1c(%eax) */ + 0xff, 0x75, 0x04, /* pushl 0x4(%ebp) */ + 0x8f, 0x40, 0x20, /* popl 0x20(%eax) */ + 0x8c, 0x48, 0x24, /* mov %cs,0x24(%eax) */ + 0x8c, 0x58, 0x28, /* mov %ds,0x28(%eax) */ + 0x8c, 0x40, 0x2c, /* mov %es,0x2c(%eax) */ + 0x8c, 0x60, 0x30, /* mov %fs,0x30(%eax) */ + 0x8c, 0x68, 0x34, /* mov %gs,0x34(%eax) */ + 0x8c, 0x50, 0x38, /* mov %ss,0x38(%eax) */ + 0x9c, /* pushf */ + 0x8f, 0x40, 0x3c, /* popl 0x3c(%eax) */ + 0xff, 0x75, 0x00, /* pushl 0x0(%ebp) ; previous stack frame */ + 0x8f, 0x40, 0x40, /* popl 0x40(%eax) */ + 0x8b, 0x00, /* mov (%eax),%eax */ + 0xff, 0x75, 0x14, /* pushl 0x14(%ebp) */ + 0xff, 0x75, 0x10, /* pushl 0x10(%ebp) */ + 0xff, 0x55, 0x0c, /* call *0xc(%ebp) */ + 0xc9, /* leave */ + 0xc3, /* ret */ + }; + + memcpy( func_ptr, call_func, sizeof(call_func) ); + +#define COMPARE(reg) \ + ok( context.reg == expect.reg, "wrong " #reg " %08x/%08x\n", context.reg, expect.reg ) + + memset( &context, 0xcc, sizeof(context) ); + memset( &expect, 0xcc, sizeof(expect) ); + context.ContextFlags = CONTEXT_CONTROL | CONTEXT_INTEGER | CONTEXT_SEGMENTS; + status = func_ptr( &expect, pNtGetContextThread, (void *)GetCurrentThread(), &context ); + ok( status == STATUS_SUCCESS, "NtGetContextThread failed %08x\n", status ); + trace( "expect: eax=%08x ebx=%08x ecx=%08x edx=%08x esi=%08x edi=%08x ebp=%08x esp=%08x " + "eip=%08x cs=%04x ds=%04x es=%04x fs=%04x gs=%04x ss=%04x flags=%08x prev=%08x\n", + expect.Eax, expect.Ebx, expect.Ecx, expect.Edx, expect.Esi, expect.Edi, + expect.Ebp, expect.Esp, expect.Eip, expect.SegCs, expect.SegDs, expect.SegEs, + expect.SegFs, expect.SegGs, expect.SegSs, expect.EFlags, expect.prev_frame ); + trace( "actual: eax=%08x ebx=%08x ecx=%08x edx=%08x esi=%08x edi=%08x ebp=%08x esp=%08x " + "eip=%08x cs=%04x ds=%04x es=%04x fs=%04x gs=%04x ss=%04x flags=%08x\n", + context.Eax, context.Ebx, context.Ecx, context.Edx, context.Esi, context.Edi, + context.Ebp, context.Esp, context.Eip, context.SegCs, context.SegDs, context.SegEs, + context.SegFs, context.SegGs, context.SegSs, context.EFlags ); + /* Eax, Ecx, Edx, EFlags are not preserved */ + COMPARE( Ebx ); + COMPARE( Esi ); + COMPARE( Edi ); + COMPARE( Ebp ); + /* Esp is the stack upon entry to NtGetContextThread */ + ok( context.Esp == expect.Esp - 12 || context.Esp == expect.Esp - 16, + "wrong Esp %08x/%08x\n", context.Esp, expect.Esp ); + /* Eip is somewhere close to the NtGetContextThread implementation */ + ok( (char *)context.Eip >= (char *)pNtGetContextThread - 0x10000 && + (char *)context.Eip <= (char *)pNtGetContextThread + 0x10000, + "wrong Eip %08x/%08x\n", context.Eip, (DWORD)pNtGetContextThread ); + /* segment registers clear the high word */ + ok( context.SegCs == LOWORD(expect.SegCs), "wrong SegCs %08x/%08x\n", context.SegCs, expect.SegCs ); + ok( context.SegDs == LOWORD(expect.SegDs), "wrong SegDs %08x/%08x\n", context.SegDs, expect.SegDs ); + ok( context.SegEs == LOWORD(expect.SegEs), "wrong SegEs %08x/%08x\n", context.SegEs, expect.SegEs ); + ok( context.SegFs == LOWORD(expect.SegFs), "wrong SegFs %08x/%08x\n", context.SegFs, expect.SegFs ); + ok( context.SegGs == LOWORD(expect.SegGs), "wrong SegGs %08x/%08x\n", context.SegGs, expect.SegGs ); + ok( context.SegSs == LOWORD(expect.SegSs), "wrong SegSs %08x/%08x\n", context.SegSs, expect.SegGs ); +#undef COMPARE +} + #elif defined(__x86_64__) #define is_wow64 0 @@ -2530,6 +2620,7 @@ START_TEST(exception) test_fpu_exceptions(); test_dpe_exceptions(); test_prot_fault(); + test_thread_context(); #elif defined(__x86_64__) pRtlAddFunctionTable = (void *)GetProcAddress( hntdll,