Updated examples in readme

This commit is contained in:
Alejandro Saucedo 2021-02-28 15:53:09 +00:00
parent 38f356fdae
commit ddb77702ee
3 changed files with 221 additions and 36 deletions

View file

@ -3,6 +3,66 @@
#include "kompute/Kompute.hpp"
TEST(TestMultipleAlgoExecutions, TestEndToEndFunctionality) {
kp::Manager mgr;
auto tensorInA = mgr.tensor({ 2., 2., 2. });
auto tensorInB = mgr.tensor({ 1., 2., 3. });
auto tensorOutA = mgr.tensor({ 0., 0., 0. });
auto tensorOutB = mgr.tensor({ 0., 0., 0. });
std::string shader = (R"(
#version 450
layout (local_size_x = 1) in;
// The input tensors bind index is relative to index in parameter passed
layout(set = 0, binding = 0) buffer buf_in_a { float in_a[]; };
layout(set = 0, binding = 1) buffer buf_in_b { float in_b[]; };
layout(set = 0, binding = 2) buffer buf_out_a { float out_a[]; };
layout(set = 0, binding = 3) buffer buf_out_b { float out_b[]; };
// Kompute supports push constants updated on dispatch
layout(push_constant) uniform PushConstants {
float val;
} push_const;
// Kompute also supports spec constants on initalization
layout(constant_id = 0) const float const_one = 0;
void main() {
uint index = gl_GlobalInvocationID.x;
out_a[index] += in_a[index] * in_b[index];
out_b[index] += const_one * push_const.val;
}
)");
std::vector<std::shared_ptr<kp::Tensor>> params = {tensorInA, tensorInB, tensorOutA, tensorOutB};
kp::Workgroup workgroup({3, 1, 1});
kp::Constants specConsts({ 2 });
kp::Constants pushConstsA({ 2.0 });
kp::Constants pushConstsB({ 3.0 });
auto algorithm = mgr.algorithm(params, kp::Shader::compile_source(shader), workgroup, specConsts);
// 3. Run operation with string shader synchronously
mgr.sequence()
->record<kp::OpTensorSyncDevice>(params)
->record<kp::OpAlgoDispatch>(algorithm, pushConstsA)
->record<kp::OpAlgoDispatch>(algorithm, pushConstsB)
->eval();
auto sq = mgr.sequence();
sq->evalAsync<kp::OpTensorSyncLocal>(params);
sq->evalAwait();
EXPECT_EQ(tensorOutA->data(), std::vector<float>({ 4, 8, 12 }));
EXPECT_EQ(tensorOutB->data(), std::vector<float>({ 10, 10, 10 }));
}
TEST(TestMultipleAlgoExecutions, SingleSequenceRecord)
{