diff --git a/dlls/d3dcompiler_43/tests/hlsl_d3d11.c b/dlls/d3dcompiler_43/tests/hlsl_d3d11.c index ed95457f1c7..8e37976d5ac 100644 --- a/dlls/d3dcompiler_43/tests/hlsl_d3d11.c +++ b/dlls/d3dcompiler_43/tests/hlsl_d3d11.c @@ -523,6 +523,168 @@ static void test_trig(void) release_test_context(&test_context); } +static void check_type_desc(const char *prefix, const D3D11_SHADER_TYPE_DESC *type, + const D3D11_SHADER_TYPE_DESC *expect) +{ + ok(type->Class == expect->Class, "%s: got class %#x.\n", prefix, type->Class); + ok(type->Type == expect->Type, "%s: got type %#x.\n", prefix, type->Type); + ok(type->Rows == expect->Rows, "%s: got %u rows.\n", prefix, type->Rows); + ok(type->Columns == expect->Columns, "%s: got %u columns.\n", prefix, type->Columns); + ok(type->Elements == expect->Elements, "%s: got %u elements.\n", prefix, type->Elements); + ok(type->Members == expect->Members, "%s: got %u members.\n", prefix, type->Members); + ok(type->Offset == expect->Offset, "%s: got %u members.\n", prefix, type->Members); + ok(!strcmp(type->Name, expect->Name), "%s: got name %s.\n", prefix, debugstr_a(type->Name)); +} + +static void test_reflection(void) +{ + ID3D11ShaderReflectionConstantBuffer *cbuffer; + ID3D11ShaderReflectionType *type, *field; + D3D11_SHADER_BUFFER_DESC buffer_desc; + ID3D11ShaderReflectionVariable *var; + D3D11_SHADER_VARIABLE_DESC var_desc; + ID3D11ShaderReflection *reflection; + D3D11_SHADER_TYPE_DESC type_desc; + ID3D10Blob *vs_code = NULL; + unsigned int i, j, k; + ULONG refcount; + HRESULT hr; + + static const char vs_source[] = + "typedef uint uint_t;\n" + "cbuffer b1\n" + "{\n" + " float a;\n" + " float2 b;\n" + " float4 c;\n" + " float d;\n" + " struct\n" + " {\n" + " float4 a;\n" + " float b;\n" + " float c;\n" + " } s;\n" + " float g;\n" + " float h[2];\n" + " int i;\n" + " uint_t j;\n" + " float3x1 k;\n" + " row_major float3x1 l;\n" + "};\n" + "\n" + "float m;\n" + "\n" + "float4 main(uniform float4 n) : SV_POSITION\n" + "{\n" + " return l._31 + m + n;\n" + "}"; + + struct shader_variable + { + D3D11_SHADER_VARIABLE_DESC var_desc; + D3D11_SHADER_TYPE_DESC type_desc; + }; + + static const D3D11_SHADER_TYPE_DESC field_types[] = + { + {D3D_SVC_VECTOR, D3D_SVT_FLOAT, 1, 4, 0, 0, 0, "float4"}, + {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 0, 0, 16, "float"}, + {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 0, 0, 20, "float"}, + }; + + static const struct shader_variable globals_vars = + {{"m", 0, 4, D3D_SVF_USED}, {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 0, 0, 0, "float"}}; + static const struct shader_variable params_vars = + {{"n", 0, 16, D3D_SVF_USED}, {D3D_SVC_VECTOR, D3D_SVT_FLOAT, 1, 4, 0, 0, 0, "float4"}}; + static const struct shader_variable buffer_vars[] = + { + {{"a", 0, 4}, {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 0, 0, 0, "float"}}, + {{"b", 4, 8}, {D3D_SVC_VECTOR, D3D_SVT_FLOAT, 1, 2, 0, 0, 0, "float2"}}, + {{"c", 16, 16}, {D3D_SVC_VECTOR, D3D_SVT_FLOAT, 1, 4, 0, 0, 0, "float4"}}, + {{"d", 32, 4}, {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 0, 0, 0, "float"}}, + {{"s", 48, 24}, {D3D_SVC_STRUCT, D3D_SVT_VOID, 1, 6, 0, 3, 0, ""}}, + {{"g", 72, 4}, {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 0, 0, 0, "float"}}, + {{"h", 80, 20}, {D3D_SVC_SCALAR, D3D_SVT_FLOAT, 1, 1, 2, 0, 0, "float"}}, + {{"i", 100, 4}, {D3D_SVC_SCALAR, D3D_SVT_INT, 1, 1, 0, 0, 0, "int"}}, + {{"j", 104, 4}, {D3D_SVC_SCALAR, D3D_SVT_UINT, 1, 1, 0, 0, 0, "uint_t"}}, + {{"k", 112, 12}, {D3D_SVC_MATRIX_COLUMNS, D3D_SVT_FLOAT, 3, 1, 0, 0, 0, "float3x1"}}, + {{"l", 128, 36, D3D_SVF_USED}, {D3D_SVC_MATRIX_ROWS, D3D_SVT_FLOAT, 3, 1, 0, 0, 0, "float3x1"}}, + }; + + static const struct + { + D3D11_SHADER_BUFFER_DESC desc; + const struct shader_variable *vars; + } + vs_buffers[] = + { + {{"$Globals", D3D_CT_CBUFFER, 1, 16}, &globals_vars}, + {{"$Params", D3D_CT_CBUFFER, 1, 16}, ¶ms_vars}, + {{"b1", D3D_CT_CBUFFER, ARRAY_SIZE(buffer_vars), 176}, buffer_vars}, + }; + + todo_wine vs_code = compile_shader(vs_source, "vs_5_0"); + if (!vs_code) + return; + + hr = pD3DReflect(ID3D10Blob_GetBufferPointer(vs_code), ID3D10Blob_GetBufferSize(vs_code), + &IID_ID3D11ShaderReflection, (void **)&reflection); + ok(hr == S_OK, "Got hr %#x.\n", hr); + + for (i = 0; i < ARRAY_SIZE(vs_buffers); ++i) + { + cbuffer = reflection->lpVtbl->GetConstantBufferByIndex(reflection, i); + hr = cbuffer->lpVtbl->GetDesc(cbuffer, &buffer_desc); + ok(hr == S_OK, "Test %u: got hr %#x.\n", i, hr); + ok(!strcmp(buffer_desc.Name, vs_buffers[i].desc.Name), + "Test %u: got name %s.\n", i, debugstr_a(buffer_desc.Name)); + ok(buffer_desc.Type == vs_buffers[i].desc.Type, "Test %u: got type %#x.\n", i, buffer_desc.Type); + ok(buffer_desc.Variables == vs_buffers[i].desc.Variables, + "Test %u: got %u variables.\n", i, buffer_desc.Variables); + ok(buffer_desc.Size == vs_buffers[i].desc.Size, "Test %u: got size %u.\n", i, buffer_desc.Size); + ok(buffer_desc.uFlags == vs_buffers[i].desc.uFlags, "Test %u: got flags %#x.\n", i, buffer_desc.uFlags); + + for (j = 0; j < buffer_desc.Variables; ++j) + { + const struct shader_variable *expect = &vs_buffers[i].vars[j]; + char prefix[40]; + + var = cbuffer->lpVtbl->GetVariableByIndex(cbuffer, j); + hr = var->lpVtbl->GetDesc(var, &var_desc); + ok(hr == S_OK, "Test %u, %u: got hr %#x.\n", i, j, hr); + ok(!strcmp(var_desc.Name, expect->var_desc.Name), + "Test %u, %u: got name %s.\n", i, j, debugstr_a(var_desc.Name)); + ok(var_desc.StartOffset == expect->var_desc.StartOffset, "Test %u, %u: got offset %u.\n", + i, j, var_desc.StartOffset); + ok(var_desc.Size == expect->var_desc.Size, "Test %u, %u: got size %u.\n", i, j, var_desc.Size); + ok(var_desc.uFlags == expect->var_desc.uFlags, "Test %u, %u: got flags %#x.\n", i, j, var_desc.uFlags); + ok(!var_desc.DefaultValue, "Test %u, %u: got default value %p.\n", i, j, var_desc.DefaultValue); + + type = var->lpVtbl->GetType(var); + hr = type->lpVtbl->GetDesc(type, &type_desc); + ok(hr == S_OK, "Test %u, %u: got hr %#x.\n", i, j, hr); + sprintf(prefix, "Test %u, %u", i, j); + check_type_desc(prefix, &type_desc, &expect->type_desc); + + if (!strcmp(type_desc.Name, "")) + { + for (k = 0; k < ARRAY_SIZE(field_types); ++k) + { + field = type->lpVtbl->GetMemberTypeByIndex(type, k); + hr = field->lpVtbl->GetDesc(field, &type_desc); + ok(hr == S_OK, "Test %u, %u, %u: got hr %#x.\n", i, j, k, hr); + sprintf(prefix, "Test %u, %u, %u", i, j, k); + check_type_desc(prefix, &type_desc, &field_types[k]); + } + } + } + } + + ID3D10Blob_Release(vs_code); + refcount = reflection->lpVtbl->Release(reflection); + ok(!refcount, "Got unexpected refcount %u.\n", refcount); +} + static BOOL load_d3dcompiler(void) { HMODULE module; @@ -548,6 +710,8 @@ START_TEST(hlsl_d3d11) return; } + test_reflection(); + if (!(mod = LoadLibraryA("d3d11.dll"))) { skip("Direct3D 11 is not available.\n");