From b11a54911dd57075f3d5a48843a824eec11b217c Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Wed, 19 Aug 2020 21:10:53 +0100 Subject: [PATCH] Working end to end example --- src/OpBase.hpp | 9 +++++++-- src/OpCreateTensor.hpp | 2 +- src/Sequence.cpp | 12 ++++++++++-- src/Sequence.hpp | 15 +++++++++++---- src/main.cpp | 3 ++- 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/OpBase.hpp b/src/OpBase.hpp index a433c214b..2cf6213da 100644 --- a/src/OpBase.hpp +++ b/src/OpBase.hpp @@ -37,8 +37,13 @@ class OpBase SPDLOG_DEBUG("Compute OpBase destructor started"); } - virtual void init(std::shared_ptr tensor, ...) = 0; - virtual void record() = 0; + virtual void init(std::shared_ptr tensor, ...) { + SPDLOG_DEBUG("Kompute OpBase init called"); + } + + virtual void record() { + SPDLOG_DEBUG("Kompute OpBase record called"); + } protected: std::shared_ptr mPhysicalDevice; diff --git a/src/OpCreateTensor.hpp b/src/OpCreateTensor.hpp index eb7085439..aed91e8b6 100644 --- a/src/OpCreateTensor.hpp +++ b/src/OpCreateTensor.hpp @@ -16,7 +16,7 @@ namespace kp { -class OpCreateTensor : OpBase +class OpCreateTensor : public OpBase { public: OpCreateTensor(); diff --git a/src/Sequence.cpp b/src/Sequence.cpp index 44cd7ed01..a1637473d 100644 --- a/src/Sequence.cpp +++ b/src/Sequence.cpp @@ -67,7 +67,7 @@ Sequence::begin() } if (!this->mRecording) { - spdlog::info("Kompute Sequence starting command recording"); + spdlog::info("Kompute Sequence command recording BEGIN"); this->mCommandBuffer->begin(vk::CommandBufferBeginInfo()); this->mRecording = true; } else { @@ -84,7 +84,7 @@ Sequence::end() } if (this->mRecording) { - spdlog::info("Kompute Sequence ending command recording"); + spdlog::info("Kompute Sequence command recording END"); this->mCommandBuffer->end(); this->mRecording = false; } else { @@ -96,6 +96,8 @@ Sequence::end() void Sequence::eval() { + SPDLOG_DEBUG("Kompute sequence compute recording EVAL"); + bool toggleSingleRecording = !this->mRecording; if (toggleSingleRecording) { this->begin(); @@ -107,6 +109,9 @@ Sequence::eval() 0, nullptr, &waitStageMask, 1, this->mCommandBuffer.get()); vk::Fence fence = this->mDevice->createFence(vk::FenceCreateInfo()); + + SPDLOG_DEBUG("Kompute sequence submitting command buffer into compute queue"); + this->mComputeQueue->submit(1, &submitInfo, fence); this->mDevice->waitForFences(1, &fence, VK_TRUE, UINT64_MAX); this->mDevice->destroy(fence); @@ -114,12 +119,15 @@ Sequence::eval() if (toggleSingleRecording) { this->end(); } + + SPDLOG_DEBUG("Kompute sequence EVAL success"); } void Sequence::createCommandPool() { SPDLOG_DEBUG("Kompute Sequence creating command pool"); + if (this->mDevice == nullptr) { throw std::runtime_error("Kompute Sequence device is null"); } diff --git a/src/Sequence.hpp b/src/Sequence.hpp index 442ce1ba8..b1412a74a 100644 --- a/src/Sequence.hpp +++ b/src/Sequence.hpp @@ -36,9 +36,16 @@ class Sequence static_assert(std::is_base_of::value, "Template only valid with OpBase derived classes"); SPDLOG_DEBUG("Kompute Sequence record"); - T op(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer); - op.init(std::forward(args)...); - op.record(); + + T* op = new T(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer); + OpBase* baseOp = dynamic_cast(op); + + std::unique_ptr baseOpPtr{baseOp}; + + baseOpPtr->init(std::forward(args)...); + baseOpPtr->record(); + + operations.push_back(std::move(baseOpPtr)); } private: @@ -52,7 +59,7 @@ class Sequence bool mFreeCommandBuffer = false; // Base op objects - std::vector operations; + std::vector> operations; // Record state bool mRecording = false; diff --git a/src/main.cpp b/src/main.cpp index 17cc1a533..dcfab1f7d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -621,8 +621,9 @@ main() spdlog::info("Creating manager"); kp::Manager mgr; kp::Tensor tensor({0.0, 1.0, 2.0}); + std::shared_ptr tensorPtr{&tensor}; spdlog::info("Calling manager eval w opcreatetensor"); - mgr.evalOp(std::shared_ptr(&tensor)); + mgr.evalOp(tensorPtr); spdlog::info("Called manager eval success"); std::vector outData = tensor.data(); spdlog::info("Output data: {}", outData);