From 6113d286a9177b41e012f3a291651ff06a0f2607 Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Sun, 12 Sep 2021 13:13:20 +0100 Subject: [PATCH] Updated python to build Signed-off-by: Alejandro Saucedo --- python/src/main.cpp | 1 - test/TestSpecializationConstant.cpp | 48 +++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/python/src/main.cpp b/python/src/main.cpp index 846576adb..43c369555 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -61,7 +61,6 @@ PYBIND11_MODULE(kp, m) { py::class_>(m, "Algorithm", DOC(kp, Algorithm, Algorithm)) .def("get_tensors", &kp::Algorithm::getTensors, DOC(kp, Algorithm, getTensors)) .def("destroy", &kp::Algorithm::destroy, DOC(kp, Algorithm, destroy)) - .def("get_spec_consts", &kp::Algorithm::getSpecializationConstants, DOC(kp, Algorithm, getSpecializationConstants)) .def("is_init", &kp::Algorithm::isInit, DOC(kp, Algorithm, isInit)); py::class_>(m, "Tensor", DOC(kp, Tensor)) diff --git a/test/TestSpecializationConstant.cpp b/test/TestSpecializationConstant.cpp index 15da143a0..f57c221ab 100644 --- a/test/TestSpecializationConstant.cpp +++ b/test/TestSpecializationConstant.cpp @@ -53,3 +53,51 @@ TEST(TestSpecializationConstants, TestTwoConstants) } } } + +TEST(TestSpecializationConstants, TestConstantsInt) +{ + { + std::string shader(R"( + #version 450 + layout (constant_id = 0) const float cOne = 1; + layout (constant_id = 1) const float cTwo = 1; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + layout(set = 0, binding = 1) buffer b { float pb[]; }; + void main() { + uint index = gl_GlobalInvocationID.x; + pa[index] = cOne; + pb[index] = cTwo; + })"); + + std::vector spirv = compileSource(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr> tensorA = + mgr.tensor({ 0, 0, 0 }); + std::shared_ptr> tensorB = + mgr.tensor({ 0, 0, 0 }); + + std::vector> params = { tensorA, + tensorB }; + + kp::Constants spec = kp::Constants({ 5.0, 0.3 }); + + std::shared_ptr algo = + mgr.algorithm(params, spirv, {}, spec); + + sq = mgr.sequence() + ->record(params) + ->record(algo) + ->record(params) + ->eval(); + + EXPECT_EQ(tensorA->vector(), std::vector({ 5, 5, 5 })); + EXPECT_EQ(tensorB->vector(), std::vector({ 0.3, 0.3, 0.3 })); + } + } +}