ntdll: Add an assembly wrapper to return correct values for the current thread in NtGetContextThread.

Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Alexandre Julliard 2017-07-24 10:33:53 +02:00
parent 998fe046b5
commit 1c49905182
3 changed files with 160 additions and 63 deletions

View File

@ -185,7 +185,7 @@
# @ stub NtFreeUserPhysicalPages
@ stdcall NtFreeVirtualMemory(long ptr ptr long)
@ stdcall NtFsControlFile(long long ptr ptr ptr long ptr long ptr long)
@ stdcall NtGetContextThread(long ptr)
@ stdcall -norelay NtGetContextThread(long ptr)
@ stdcall NtGetCurrentProcessorNumber()
# @ stub NtGetDevicePowerState
@ stub NtGetPlugPlayEvent
@ -1111,7 +1111,7 @@
# @ stub ZwFreeUserPhysicalPages
@ stdcall -private ZwFreeVirtualMemory(long ptr ptr long) NtFreeVirtualMemory
@ stdcall -private ZwFsControlFile(long long ptr ptr ptr long ptr long ptr long) NtFsControlFile
@ stdcall -private ZwGetContextThread(long ptr) NtGetContextThread
@ stdcall -private -norelay ZwGetContextThread(long ptr) NtGetContextThread
@ stdcall -private ZwGetCurrentProcessorNumber() NtGetCurrentProcessorNumber
# @ stub ZwGetDevicePowerState
@ stub ZwGetPlugPlayEvent

View File

@ -1292,59 +1292,6 @@ static void set_cpu_context( const CONTEXT *context )
}
/***********************************************************************
* copy_context
*
* Copy a register context according to the flags.
*/
static void copy_context( CONTEXT *to, const CONTEXT *from, DWORD flags )
{
flags &= ~CONTEXT_i386; /* get rid of CPU id */
if (flags & CONTEXT_INTEGER)
{
to->Eax = from->Eax;
to->Ebx = from->Ebx;
to->Ecx = from->Ecx;
to->Edx = from->Edx;
to->Esi = from->Esi;
to->Edi = from->Edi;
}
if (flags & CONTEXT_CONTROL)
{
to->Ebp = from->Ebp;
to->Esp = from->Esp;
to->Eip = from->Eip;
to->SegCs = from->SegCs;
to->SegSs = from->SegSs;
to->EFlags = from->EFlags;
}
if (flags & CONTEXT_SEGMENTS)
{
to->SegDs = from->SegDs;
to->SegEs = from->SegEs;
to->SegFs = from->SegFs;
to->SegGs = from->SegGs;
}
if (flags & CONTEXT_DEBUG_REGISTERS)
{
to->Dr0 = from->Dr0;
to->Dr1 = from->Dr1;
to->Dr2 = from->Dr2;
to->Dr3 = from->Dr3;
to->Dr6 = from->Dr6;
to->Dr7 = from->Dr7;
}
if (flags & CONTEXT_FLOATING_POINT)
{
to->FloatSave = from->FloatSave;
}
if (flags & CONTEXT_EXTENDED_REGISTERS)
{
memcpy( to->ExtendedRegisters, from->ExtendedRegisters, sizeof(to->ExtendedRegisters) );
}
}
/***********************************************************************
* context_to_server
*
@ -1515,15 +1462,19 @@ NTSTATUS WINAPI NtSetContextThread( HANDLE handle, const CONTEXT *context )
/***********************************************************************
* NtGetContextThread (NTDLL.@)
* ZwGetContextThread (NTDLL.@)
*
* Note: we use a small assembly wrapper to save the necessary registers
* in case we are fetching the context of the current thread.
*/
NTSTATUS WINAPI NtGetContextThread( HANDLE handle, CONTEXT *context )
NTSTATUS CDECL __regs_NtGetContextThread( DWORD edi, DWORD esi, DWORD ebx, DWORD eflags,
DWORD ebp, DWORD retaddr, HANDLE handle, CONTEXT *context )
{
NTSTATUS ret;
DWORD needed_flags = context->ContextFlags;
DWORD needed_flags = context->ContextFlags & ~CONTEXT_i386;
BOOL self = (handle == GetCurrentThread());
/* debug registers require a server call */
if (context->ContextFlags & (CONTEXT_DEBUG_REGISTERS & ~CONTEXT_i386)) self = FALSE;
if (needed_flags & CONTEXT_DEBUG_REGISTERS) self = FALSE;
if (!self)
{
@ -1533,13 +1484,36 @@ NTSTATUS WINAPI NtGetContextThread( HANDLE handle, CONTEXT *context )
if (self)
{
if (needed_flags)
if (needed_flags & CONTEXT_INTEGER)
{
CONTEXT ctx;
RtlCaptureContext( &ctx );
copy_context( context, &ctx, ctx.ContextFlags & needed_flags );
context->ContextFlags |= ctx.ContextFlags & needed_flags;
context->Eax = 0;
context->Ebx = ebx;
context->Ecx = 0;
context->Edx = 0;
context->Esi = esi;
context->Edi = edi;
context->ContextFlags |= CONTEXT_INTEGER;
}
if (needed_flags & CONTEXT_CONTROL)
{
context->Ebp = ebp;
context->Esp = (DWORD)&retaddr;
context->Eip = *(&edi - 1);
context->SegCs = wine_get_cs();
context->SegSs = wine_get_ss();
context->EFlags = eflags;
context->ContextFlags |= CONTEXT_CONTROL;
}
if (needed_flags & CONTEXT_SEGMENTS)
{
context->SegDs = wine_get_ds();
context->SegEs = wine_get_es();
context->SegFs = wine_get_fs();
context->SegGs = wine_get_gs();
context->ContextFlags |= CONTEXT_SEGMENTS;
}
if (needed_flags & CONTEXT_FLOATING_POINT) save_fpu( context );
/* FIXME: extended floating point */
/* update the cached version of the debug registers */
if (context->ContextFlags & (CONTEXT_DEBUG_REGISTERS & ~CONTEXT_i386))
{
@ -1551,8 +1525,40 @@ NTSTATUS WINAPI NtGetContextThread( HANDLE handle, CONTEXT *context )
x86_thread_data()->dr7 = context->Dr7;
}
}
if (context->ContextFlags & (CONTEXT_INTEGER & ~CONTEXT_i386))
TRACE( "%p: eax=%08x ebx=%08x ecx=%08x edx=%08x esi=%08x edi=%08x\n", handle,
context->Eax, context->Ebx, context->Ecx, context->Edx, context->Esi, context->Edi );
if (context->ContextFlags & (CONTEXT_CONTROL & ~CONTEXT_i386))
TRACE( "%p: ebp=%08x esp=%08x eip=%08x cs=%04x ss=%04x flags=%08x\n", handle,
context->Ebp, context->Esp, context->Eip, context->SegCs, context->SegSs, context->EFlags );
if (context->ContextFlags & (CONTEXT_SEGMENTS & ~CONTEXT_i386))
TRACE( "%p: ds=%04x es=%04x fs=%04x gs=%04x\n", handle,
context->SegCs, context->SegDs, context->SegEs, context->SegFs );
if (context->ContextFlags & (CONTEXT_DEBUG_REGISTERS & ~CONTEXT_i386))
TRACE( "%p: dr0=%08x dr1=%08x dr2=%08x dr3=%08x dr6=%08x dr7=%08x\n", handle,
context->Dr0, context->Dr1, context->Dr2, context->Dr3, context->Dr6, context->Dr7 );
return STATUS_SUCCESS;
}
__ASM_STDCALL_FUNC( NtGetContextThread, 8,
"pushl %ebp\n\t"
__ASM_CFI(".cfi_adjust_cfa_offset 4\n\t")
__ASM_CFI(".cfi_rel_offset %ebp,0\n\t")
"movl %esp,%ebp\n\t"
__ASM_CFI(".cfi_def_cfa_register %ebp\n\t")
"pushfl\n\t"
"pushl %ebx\n\t"
__ASM_CFI(".cfi_rel_offset %ebx,-8\n\t")
"pushl %esi\n\t"
__ASM_CFI(".cfi_rel_offset %esi,-12\n\t")
"pushl %edi\n\t"
__ASM_CFI(".cfi_rel_offset %edi,-16\n\t")
"call " __ASM_NAME("__regs_NtGetContextThread") "\n\t"
"leave\n\t"
__ASM_CFI(".cfi_def_cfa %esp,4\n\t")
__ASM_CFI(".cfi_same_value %ebp\n\t")
"ret $8" )
/***********************************************************************

View File

@ -1416,6 +1416,96 @@ static void test_dpe_exceptions(void)
ok(stat == STATUS_ACCESS_DENIED, "enabling DEP while permanent: status %08x\n", stat);
}
static void test_thread_context(void)
{
CONTEXT context;
NTSTATUS status;
struct expected
{
DWORD Eax, Ebx, Ecx, Edx, Esi, Edi, Ebp, Esp, Eip,
SegCs, SegDs, SegEs, SegFs, SegGs, SegSs, EFlags, prev_frame;
} expect;
NTSTATUS (*func_ptr)( struct expected *res, void *func, void *arg1, void *arg2 ) = (void *)code_mem;
static const BYTE call_func[] =
{
0x55, /* pushl %ebp */
0x89, 0xe5, /* mov %esp,%ebp */
0x50, /* pushl %eax ; add a bit of offset to the stack */
0x50, /* pushl %eax */
0x50, /* pushl %eax */
0x50, /* pushl %eax */
0x8b, 0x45, 0x08, /* mov 0x8(%ebp),%eax */
0x8f, 0x00, /* popl (%eax) */
0x89, 0x58, 0x04, /* mov %ebx,0x4(%eax) */
0x89, 0x48, 0x08, /* mov %ecx,0x8(%eax) */
0x89, 0x50, 0x0c, /* mov %edx,0xc(%eax) */
0x89, 0x70, 0x10, /* mov %esi,0x10(%eax) */
0x89, 0x78, 0x14, /* mov %edi,0x14(%eax) */
0x89, 0x68, 0x18, /* mov %ebp,0x18(%eax) */
0x89, 0x60, 0x1c, /* mov %esp,0x1c(%eax) */
0xff, 0x75, 0x04, /* pushl 0x4(%ebp) */
0x8f, 0x40, 0x20, /* popl 0x20(%eax) */
0x8c, 0x48, 0x24, /* mov %cs,0x24(%eax) */
0x8c, 0x58, 0x28, /* mov %ds,0x28(%eax) */
0x8c, 0x40, 0x2c, /* mov %es,0x2c(%eax) */
0x8c, 0x60, 0x30, /* mov %fs,0x30(%eax) */
0x8c, 0x68, 0x34, /* mov %gs,0x34(%eax) */
0x8c, 0x50, 0x38, /* mov %ss,0x38(%eax) */
0x9c, /* pushf */
0x8f, 0x40, 0x3c, /* popl 0x3c(%eax) */
0xff, 0x75, 0x00, /* pushl 0x0(%ebp) ; previous stack frame */
0x8f, 0x40, 0x40, /* popl 0x40(%eax) */
0x8b, 0x00, /* mov (%eax),%eax */
0xff, 0x75, 0x14, /* pushl 0x14(%ebp) */
0xff, 0x75, 0x10, /* pushl 0x10(%ebp) */
0xff, 0x55, 0x0c, /* call *0xc(%ebp) */
0xc9, /* leave */
0xc3, /* ret */
};
memcpy( func_ptr, call_func, sizeof(call_func) );
#define COMPARE(reg) \
ok( context.reg == expect.reg, "wrong " #reg " %08x/%08x\n", context.reg, expect.reg )
memset( &context, 0xcc, sizeof(context) );
memset( &expect, 0xcc, sizeof(expect) );
context.ContextFlags = CONTEXT_CONTROL | CONTEXT_INTEGER | CONTEXT_SEGMENTS;
status = func_ptr( &expect, pNtGetContextThread, (void *)GetCurrentThread(), &context );
ok( status == STATUS_SUCCESS, "NtGetContextThread failed %08x\n", status );
trace( "expect: eax=%08x ebx=%08x ecx=%08x edx=%08x esi=%08x edi=%08x ebp=%08x esp=%08x "
"eip=%08x cs=%04x ds=%04x es=%04x fs=%04x gs=%04x ss=%04x flags=%08x prev=%08x\n",
expect.Eax, expect.Ebx, expect.Ecx, expect.Edx, expect.Esi, expect.Edi,
expect.Ebp, expect.Esp, expect.Eip, expect.SegCs, expect.SegDs, expect.SegEs,
expect.SegFs, expect.SegGs, expect.SegSs, expect.EFlags, expect.prev_frame );
trace( "actual: eax=%08x ebx=%08x ecx=%08x edx=%08x esi=%08x edi=%08x ebp=%08x esp=%08x "
"eip=%08x cs=%04x ds=%04x es=%04x fs=%04x gs=%04x ss=%04x flags=%08x\n",
context.Eax, context.Ebx, context.Ecx, context.Edx, context.Esi, context.Edi,
context.Ebp, context.Esp, context.Eip, context.SegCs, context.SegDs, context.SegEs,
context.SegFs, context.SegGs, context.SegSs, context.EFlags );
/* Eax, Ecx, Edx, EFlags are not preserved */
COMPARE( Ebx );
COMPARE( Esi );
COMPARE( Edi );
COMPARE( Ebp );
/* Esp is the stack upon entry to NtGetContextThread */
ok( context.Esp == expect.Esp - 12 || context.Esp == expect.Esp - 16,
"wrong Esp %08x/%08x\n", context.Esp, expect.Esp );
/* Eip is somewhere close to the NtGetContextThread implementation */
ok( (char *)context.Eip >= (char *)pNtGetContextThread - 0x10000 &&
(char *)context.Eip <= (char *)pNtGetContextThread + 0x10000,
"wrong Eip %08x/%08x\n", context.Eip, (DWORD)pNtGetContextThread );
/* segment registers clear the high word */
ok( context.SegCs == LOWORD(expect.SegCs), "wrong SegCs %08x/%08x\n", context.SegCs, expect.SegCs );
ok( context.SegDs == LOWORD(expect.SegDs), "wrong SegDs %08x/%08x\n", context.SegDs, expect.SegDs );
ok( context.SegEs == LOWORD(expect.SegEs), "wrong SegEs %08x/%08x\n", context.SegEs, expect.SegEs );
ok( context.SegFs == LOWORD(expect.SegFs), "wrong SegFs %08x/%08x\n", context.SegFs, expect.SegFs );
ok( context.SegGs == LOWORD(expect.SegGs), "wrong SegGs %08x/%08x\n", context.SegGs, expect.SegGs );
ok( context.SegSs == LOWORD(expect.SegSs), "wrong SegSs %08x/%08x\n", context.SegSs, expect.SegGs );
#undef COMPARE
}
#elif defined(__x86_64__)
#define is_wow64 0
@ -2530,6 +2620,7 @@ START_TEST(exception)
test_fpu_exceptions();
test_dpe_exceptions();
test_prot_fault();
test_thread_context();
#elif defined(__x86_64__)
pRtlAddFunctionTable = (void *)GetProcAddress( hntdll,