Migrated default base type of tensor to float
This commit is contained in:
parent
00c66f347b
commit
08a1300450
13 changed files with 255 additions and 249 deletions
|
|
@ -14,7 +14,7 @@ TEST_CASE("test_multiple_algo_exec_single_cmd_buf_record") {
|
|||
std::string shader(
|
||||
"#version 450\n"
|
||||
"layout (local_size_x = 1) in;\n"
|
||||
"layout(set = 0, binding = 0) buffer a { uint pa[]; };\n"
|
||||
"layout(set = 0, binding = 0) buffer a { float pa[]; };\n"
|
||||
"void main() {\n"
|
||||
" uint index = gl_GlobalInvocationID.x;\n"
|
||||
" pa[index] = pa[index] + 1;\n"
|
||||
|
|
@ -45,7 +45,7 @@ TEST_CASE("test_multiple_algo_exec_single_cmd_buf_record") {
|
|||
}
|
||||
sqWeakPtr.reset();
|
||||
|
||||
REQUIRE(tensorA->data() == std::vector<uint32_t>{3, 3, 3});
|
||||
REQUIRE(tensorA->data() == std::vector<float>{3, 3, 3});
|
||||
}
|
||||
|
||||
TEST_CASE("test_multiple_algo_exec_multiple_record") {
|
||||
|
|
@ -57,7 +57,7 @@ TEST_CASE("test_multiple_algo_exec_multiple_record") {
|
|||
std::string shader(
|
||||
"#version 450\n"
|
||||
"layout (local_size_x = 1) in;\n"
|
||||
"layout(set = 0, binding = 0) buffer a { uint pa[]; };\n"
|
||||
"layout(set = 0, binding = 0) buffer a { float pa[]; };\n"
|
||||
"void main() {\n"
|
||||
" uint index = gl_GlobalInvocationID.x;\n"
|
||||
" pa[index] = pa[index] + 1;\n"
|
||||
|
|
@ -100,7 +100,7 @@ TEST_CASE("test_multiple_algo_exec_multiple_record") {
|
|||
}
|
||||
sqWeakPtr.reset();
|
||||
|
||||
REQUIRE(tensorA->data() == std::vector<uint32_t>{3, 3, 3});
|
||||
REQUIRE(tensorA->data() == std::vector<float>{3, 3, 3});
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -113,7 +113,7 @@ TEST_CASE("test_multiple_algo_exec_multiple_sequence") {
|
|||
std::string shader(
|
||||
"#version 450\n"
|
||||
"layout (local_size_x = 1) in;\n"
|
||||
"layout(set = 0, binding = 0) buffer a { uint pa[]; };\n"
|
||||
"layout(set = 0, binding = 0) buffer a { float pa[]; };\n"
|
||||
"void main() {\n"
|
||||
" uint index = gl_GlobalInvocationID.x;\n"
|
||||
" pa[index] = pa[index] + 1;\n"
|
||||
|
|
@ -162,6 +162,6 @@ TEST_CASE("test_multiple_algo_exec_multiple_sequence") {
|
|||
sq->eval();
|
||||
}
|
||||
|
||||
REQUIRE(tensorA->data() == std::vector<uint32_t>{3, 3, 3});
|
||||
REQUIRE(tensorA->data() == std::vector<float>{3, 3, 3});
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue