Working implementation with tests
This commit is contained in:
parent
cf7d46cd23
commit
f02b9d6915
21 changed files with 297 additions and 216 deletions
|
|
@ -29,7 +29,7 @@ TEST(TestPushConstants, TestConstantsAlgoDispatchOverride)
|
|||
{
|
||||
kp::Manager mgr;
|
||||
|
||||
std::shared_ptr<kp::Tensor> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
std::shared_ptr<kp::TensorT<float>> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
|
||||
std::shared_ptr<kp::Algorithm> algo =
|
||||
mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0, 0.0, 0.0 });
|
||||
|
|
@ -42,7 +42,7 @@ TEST(TestPushConstants, TestConstantsAlgoDispatchOverride)
|
|||
sq->eval<kp::OpAlgoDispatch>(algo, kp::Constants{ 0.3, 0.2, 0.1 });
|
||||
sq->eval<kp::OpTensorSyncLocal>({ tensor });
|
||||
|
||||
EXPECT_EQ(tensor->data(), kp::Constants({ 0.4, 0.4, 0.4 }));
|
||||
EXPECT_EQ(tensor->vector(), kp::Constants({ 0.4, 0.4, 0.4 }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -72,7 +72,7 @@ TEST(TestPushConstants, TestConstantsAlgoDispatchNoOverride)
|
|||
{
|
||||
kp::Manager mgr;
|
||||
|
||||
std::shared_ptr<kp::Tensor> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
std::shared_ptr<kp::TensorT<float>> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
|
||||
std::shared_ptr<kp::Algorithm> algo =
|
||||
mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.1, 0.2, 0.3 });
|
||||
|
|
@ -85,7 +85,7 @@ TEST(TestPushConstants, TestConstantsAlgoDispatchNoOverride)
|
|||
sq->eval<kp::OpAlgoDispatch>(algo, kp::Constants{ 0.3, 0.2, 0.1 });
|
||||
sq->eval<kp::OpTensorSyncLocal>({ tensor });
|
||||
|
||||
EXPECT_EQ(tensor->data(), kp::Constants({ 0.4, 0.4, 0.4 }));
|
||||
EXPECT_EQ(tensor->vector(), kp::Constants({ 0.4, 0.4, 0.4 }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -115,7 +115,7 @@ TEST(TestPushConstants, TestConstantsWrongSize)
|
|||
{
|
||||
kp::Manager mgr;
|
||||
|
||||
std::shared_ptr<kp::Tensor> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
std::shared_ptr<kp::TensorT<float>> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
|
||||
std::shared_ptr<kp::Algorithm> algo =
|
||||
mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0 });
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue