Working end to end example

This commit is contained in:
Alejandro Saucedo 2020-08-19 21:10:53 +01:00
parent 7c3af1189f
commit b11a54911d
5 changed files with 31 additions and 10 deletions

View file

@ -37,8 +37,13 @@ class OpBase
SPDLOG_DEBUG("Compute OpBase destructor started");
}
virtual void init(std::shared_ptr<Tensor> tensor, ...) = 0;
virtual void record() = 0;
virtual void init(std::shared_ptr<Tensor> tensor, ...) {
SPDLOG_DEBUG("Kompute OpBase init called");
}
virtual void record() {
SPDLOG_DEBUG("Kompute OpBase record called");
}
protected:
std::shared_ptr<vk::PhysicalDevice> mPhysicalDevice;

View file

@ -16,7 +16,7 @@
namespace kp {
class OpCreateTensor : OpBase
class OpCreateTensor : public OpBase
{
public:
OpCreateTensor();

View file

@ -67,7 +67,7 @@ Sequence::begin()
}
if (!this->mRecording) {
spdlog::info("Kompute Sequence starting command recording");
spdlog::info("Kompute Sequence command recording BEGIN");
this->mCommandBuffer->begin(vk::CommandBufferBeginInfo());
this->mRecording = true;
} else {
@ -84,7 +84,7 @@ Sequence::end()
}
if (this->mRecording) {
spdlog::info("Kompute Sequence ending command recording");
spdlog::info("Kompute Sequence command recording END");
this->mCommandBuffer->end();
this->mRecording = false;
} else {
@ -96,6 +96,8 @@ Sequence::end()
void
Sequence::eval()
{
SPDLOG_DEBUG("Kompute sequence compute recording EVAL");
bool toggleSingleRecording = !this->mRecording;
if (toggleSingleRecording) {
this->begin();
@ -107,6 +109,9 @@ Sequence::eval()
0, nullptr, &waitStageMask, 1, this->mCommandBuffer.get());
vk::Fence fence = this->mDevice->createFence(vk::FenceCreateInfo());
SPDLOG_DEBUG("Kompute sequence submitting command buffer into compute queue");
this->mComputeQueue->submit(1, &submitInfo, fence);
this->mDevice->waitForFences(1, &fence, VK_TRUE, UINT64_MAX);
this->mDevice->destroy(fence);
@ -114,12 +119,15 @@ Sequence::eval()
if (toggleSingleRecording) {
this->end();
}
SPDLOG_DEBUG("Kompute sequence EVAL success");
}
void
Sequence::createCommandPool()
{
SPDLOG_DEBUG("Kompute Sequence creating command pool");
if (this->mDevice == nullptr) {
throw std::runtime_error("Kompute Sequence device is null");
}

View file

@ -36,9 +36,16 @@ class Sequence
static_assert(std::is_base_of<OpBase, T>::value, "Template only valid with OpBase derived classes");
SPDLOG_DEBUG("Kompute Sequence record");
T op(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
op.init(std::forward<TArgs>(args)...);
op.record();
T* op = new T(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
OpBase* baseOp = dynamic_cast<OpBase*>(op);
std::unique_ptr<OpBase> baseOpPtr{baseOp};
baseOpPtr->init(std::forward<TArgs>(args)...);
baseOpPtr->record();
operations.push_back(std::move(baseOpPtr));
}
private:
@ -52,7 +59,7 @@ class Sequence
bool mFreeCommandBuffer = false;
// Base op objects
std::vector<OpBase> operations;
std::vector<std::unique_ptr<OpBase>> operations;
// Record state
bool mRecording = false;

View file

@ -621,8 +621,9 @@ main()
spdlog::info("Creating manager");
kp::Manager mgr;
kp::Tensor tensor({0.0, 1.0, 2.0});
std::shared_ptr<kp::Tensor> tensorPtr{&tensor};
spdlog::info("Calling manager eval w opcreatetensor");
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>(&tensor));
mgr.evalOp<kp::OpCreateTensor>(tensorPtr);
spdlog::info("Called manager eval success");
std::vector<uint32_t> outData = tensor.data();
spdlog::info("Output data: {}", outData);