diff --git a/test/TestMultipleAlgoExecutions.cpp b/test/TestMultipleAlgoExecutions.cpp index 45466d4a1..7f63c208f 100644 --- a/test/TestMultipleAlgoExecutions.cpp +++ b/test/TestMultipleAlgoExecutions.cpp @@ -219,3 +219,60 @@ TEST(TestMultipleAlgoExecutions, SingleRecordMultipleEval) EXPECT_EQ(tensorA->vector(), std::vector({ 3, 3, 3 })); } + +TEST(TestAlgoUtils, TestAlgorithmUtilFunctions) +{ + + kp::Manager mgr; + + // Default tensor constructor simplifies creation of float values + auto tensorInA = mgr.tensor({ 2., 2., 2. }); + auto tensorInB = mgr.tensor({ 1., 2., 3. }); + // Explicit type constructor supports int, in32, double, float and int + auto tensorOutA = mgr.tensorT({ 0, 0, 0 }); + auto tensorOutB = mgr.tensorT({ 0, 0, 0 }); + + std::string shader = (R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer buf_in_a { float in_a[]; }; + layout(set = 0, binding = 1) buffer buf_in_b { float in_b[]; }; + layout(set = 0, binding = 2) buffer buf_out_a { uint out_a[]; }; + layout(set = 0, binding = 3) buffer buf_out_b { uint out_b[]; }; + + // Kompute supports push constants updated on dispatch + layout(push_constant) uniform PushConstants { + float val; + } push_const; + + // Kompute also supports spec constants on initalization + layout(constant_id = 0) const float const_one = 0; + + void main() { + uint index = gl_GlobalInvocationID.x; + out_a[index] += uint( in_a[index] * in_b[index] ); + out_b[index] += uint( const_one * push_const.val ); + } + )"); + + std::vector> params = { + tensorInA, tensorInB, tensorOutA, tensorOutB + }; + + kp::Workgroup workgroup({ 3, 1, 1 }); + kp::Constants specConsts({ 2 }); + kp::Constants pushConsts({ 2.0 }); + + auto algorithm = mgr.algorithm(params, + compileSource(shader), + workgroup, + specConsts, + pushConsts); + + EXPECT_EQ(algorithm->getWorkgroup(), workgroup); + EXPECT_EQ(algorithm->getPush(), pushConsts); + EXPECT_EQ(algorithm->getSpecializationConstants(), specConsts); +}