From 6192dda520cf8ea22e09650fc8daca1c5ffce7e9 Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Mon, 1 Mar 2021 22:08:05 +0000 Subject: [PATCH] Added rerecord functionality and tests --- single_include/kompute/Kompute.hpp | 61 ++-------------------- src/Sequence.cpp | 18 ++++++- src/include/kompute/Algorithm.hpp | 4 +- src/include/kompute/Sequence.hpp | 57 ++------------------ test/TestSequence.cpp | 83 ++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 113 deletions(-) diff --git a/single_include/kompute/Kompute.hpp b/single_include/kompute/Kompute.hpp index 52d574ad3..7b67e2024 100755 --- a/single_include/kompute/Kompute.hpp +++ b/single_include/kompute/Kompute.hpp @@ -1146,8 +1146,8 @@ class Algorithm * @specalizationInstalces The specialization parameters to pass to the * function processing */ - void rebuild(const std::vector>& tensors = {}, - const std::vector& spirv = {}, + void rebuild(const std::vector>& tensors, + const std::vector& spirv, const Workgroup& workgroup = {}, const Constants& specializationConstants = {}); @@ -1554,34 +1554,17 @@ class Sequence : public std::enable_shared_from_this */ template std::shared_ptr record( - std::vector> tensors, - TArgs&&... params) + std::vector> tensors, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->record(op); } template std::shared_ptr record(std::shared_ptr algorithm, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(algorithm, std::forward(params)...) }; - return this->record(op); } @@ -1606,34 +1589,15 @@ class Sequence : public std::enable_shared_from_this std::shared_ptr eval(std::vector> tensors, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - - // TODO: Aim to be able to handle errors when returning without throw - // except return this->eval(op); } - // Needded as otherise can't use initialiser list template std::shared_ptr eval(std::shared_ptr algorithm, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(algorithm, std::forward(params)...) }; - return this->eval(op); } @@ -1658,32 +1622,15 @@ class Sequence : public std::enable_shared_from_this std::vector> tensors, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->evalAsync(op); } - // Needed as otherwise it's not possible to use initializer lists template std::shared_ptr evalAsync(std::shared_ptr algorithm, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(algorithm, std::forward(params)...) }; - return this->evalAsync(op); } @@ -1727,6 +1674,8 @@ class Sequence : public std::enable_shared_from_this bool isInit(); + void rerecord(); + /** * Returns true if the sequence is currently running - mostly used for async * workloads. diff --git a/src/Sequence.cpp b/src/Sequence.cpp index 68ff082ce..fa715cefc 100644 --- a/src/Sequence.cpp +++ b/src/Sequence.cpp @@ -51,6 +51,11 @@ Sequence::end() { KP_LOG_DEBUG("Kompute Sequence calling END"); + if (this->isRunning()) { + throw std::runtime_error( + "Kompute Sequence begin called when sequence still running"); + } + if (!this->isRecording()) { KP_LOG_WARN("Kompute Sequence end called when not recording"); return; @@ -64,7 +69,7 @@ Sequence::end() void Sequence::clear() { - KP_LOG_DEBUG("Kompute Sequence calling clear"); + KP_LOG_DEBUG("Kompute Sequence calling clear"); this->end(); } @@ -171,6 +176,17 @@ Sequence::isInit() this->mComputeQueue; } +void +Sequence::rerecord() +{ + this->end(); + std::vector> ops = this->mOperations; + this->mOperations.clear(); + for (const std::shared_ptr& op : ops) { + this->record(op); + } +} + void Sequence::destroy() { diff --git a/src/include/kompute/Algorithm.hpp b/src/include/kompute/Algorithm.hpp index e5fd1287e..32e5d9bdf 100644 --- a/src/include/kompute/Algorithm.hpp +++ b/src/include/kompute/Algorithm.hpp @@ -35,8 +35,8 @@ class Algorithm * @specalizationInstalces The specialization parameters to pass to the * function processing */ - void rebuild(const std::vector>& tensors = {}, - const std::vector& spirv = {}, + void rebuild(const std::vector>& tensors, + const std::vector& spirv, const Workgroup& workgroup = {}, const Constants& specializationConstants = {}); diff --git a/src/include/kompute/Sequence.hpp b/src/include/kompute/Sequence.hpp index 29c6a0c3b..5741fb4e6 100644 --- a/src/include/kompute/Sequence.hpp +++ b/src/include/kompute/Sequence.hpp @@ -47,34 +47,17 @@ class Sequence : public std::enable_shared_from_this */ template std::shared_ptr record( - std::vector> tensors, - TArgs&&... params) + std::vector> tensors, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->record(op); } template std::shared_ptr record(std::shared_ptr algorithm, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(algorithm, std::forward(params)...) }; - return this->record(op); } @@ -99,34 +82,15 @@ class Sequence : public std::enable_shared_from_this std::shared_ptr eval(std::vector> tensors, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - - // TODO: Aim to be able to handle errors when returning without throw - // except return this->eval(op); } - // Needded as otherise can't use initialiser list template std::shared_ptr eval(std::shared_ptr algorithm, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(algorithm, std::forward(params)...) }; - return this->eval(op); } @@ -151,32 +115,15 @@ class Sequence : public std::enable_shared_from_this std::vector> tensors, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->evalAsync(op); } - // Needed as otherwise it's not possible to use initializer lists template std::shared_ptr evalAsync(std::shared_ptr algorithm, TArgs&&... params) { - KP_LOG_DEBUG("Kompute Sequence record function started"); - - static_assert(std::is_base_of::value, - "Kompute Sequence record(...) template only valid with " - "OpBase derived classes"); - - KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); std::shared_ptr op{ new T(algorithm, std::forward(params)...) }; - return this->evalAsync(op); } @@ -220,6 +167,8 @@ class Sequence : public std::enable_shared_from_this bool isInit(); + void rerecord(); + /** * Returns true if the sequence is currently running - mostly used for async * workloads. diff --git a/test/TestSequence.cpp b/test/TestSequence.cpp index 4d0233694..482868a88 100644 --- a/test/TestSequence.cpp +++ b/test/TestSequence.cpp @@ -17,3 +17,86 @@ TEST(TestSequence, SequenceDestructorViaManager) EXPECT_FALSE(sq->isInit()); } + +TEST(TestSequence, SequenceDestructorOutsideManagerExplicit) +{ + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + sq = mgr.sequence(); + + EXPECT_TRUE(sq->isInit()); + + sq->destroy(); + + EXPECT_FALSE(sq->isInit()); + } + + EXPECT_FALSE(sq->isInit()); +} + +TEST(TestSequence, SequenceDestructorOutsideManagerImplicit) +{ + kp::Manager mgr; + + std::weak_ptr sqWeak; + + { + std::shared_ptr sq = mgr.sequence(); + + sqWeak = sq; + + EXPECT_TRUE(sq->isInit()); + } + + EXPECT_FALSE(sqWeak.lock()); +} + +TEST(TestSequence, RerecordSequence) +{ + kp::Manager mgr; + + std::shared_ptr sq = mgr.sequence(); + + std::shared_ptr tensorA = mgr.tensor({1, 2, 3}); + std::shared_ptr tensorB = mgr.tensor({2, 2, 2}); + std::shared_ptr tensorOut = mgr.tensor({0, 0, 0}); + + sq->eval({ tensorA, tensorB, tensorOut }); + + std::vector spirv = kp::Shader::compile_source(R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer bina { float tina[]; }; + layout(set = 0, binding = 1) buffer binb { float tinb[]; }; + layout(set = 0, binding = 2) buffer bout { float tout[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + tout[index] = tina[index] * tinb[index]; + } + )"); + + std::shared_ptr algo = + mgr.algorithm({tensorA, tensorB, tensorOut}, spirv); + + sq->record(algo) + ->record({tensorA, tensorB, tensorOut}); + + sq->eval(); + + EXPECT_EQ(tensorOut->data(), std::vector({2, 4, 6})); + + algo->rebuild({tensorOut, tensorA, tensorB}, spirv); + + // Refresh and trigger a rerecord + sq->rerecord(); + sq->eval(); + + EXPECT_EQ(tensorB->data(), std::vector({2, 8, 18})); +}