Updated init parameter to be tensor vector

This commit is contained in:
Alejandro Saucedo 2020-08-20 05:27:42 +01:00
parent 90ea083cba
commit 8aa7843f0e
7 changed files with 108 additions and 16 deletions

View file

@ -28,7 +28,7 @@ class Manager
// Evaluate actions
template<typename T, typename... TArgs>
void evalOp(TArgs&&... args)
void evalOp(std::vector<std::shared_ptr<Tensor>> tensors)
{
SPDLOG_DEBUG("Kompute Manager eval triggered");
Sequence sq(
@ -36,7 +36,7 @@ class Manager
SPDLOG_DEBUG("Kompute Manager created sequence");
sq.begin();
SPDLOG_DEBUG("Kompute Manager sequence begin");
sq.record<T>(std::forward<TArgs>(args)...);
sq.record<T>(tensors);
SPDLOG_DEBUG("Kompute Manager sequence end");
sq.end();
SPDLOG_DEBUG("Kompute Manager sequence eval");

View file

@ -37,7 +37,7 @@ class OpBase
SPDLOG_DEBUG("Compute OpBase destructor started");
}
virtual void init(std::shared_ptr<Tensor> tensor, ...) {
virtual void init(std::vector<std::shared_ptr<Tensor>> tensors) {
SPDLOG_DEBUG("Kompute OpBase init called");
}

View file

@ -23,23 +23,29 @@ OpCreateTensor::~OpCreateTensor() {
}
void
OpCreateTensor::init(std::shared_ptr<Tensor> tensor, ...)
OpCreateTensor::init(std::vector<std::shared_ptr<Tensor>> tensors)
{
SPDLOG_DEBUG("Kompute OpCreateTensor init called");
this->mPrimaryTensor = tensor;
if (tensors.size() < 1) {
throw std::runtime_error("Kompute OpCreateTensor called with less than 1 tensor");
} else if (tensors.size() > 1) {
spdlog::warn("Kompute OpCreateTensor called with more than 1 tensor");
}
this->mPrimaryTensor = tensors[0];
std::vector<uint32_t> data = this->mPrimaryTensor->data();
if (tensor->tensorType() == Tensor::TensorTypes::eDevice) {
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
if (this->mPrimaryTensor->tensorType() == Tensor::TensorTypes::eDevice) {
this->mPrimaryTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
this->mStagingTensor = std::make_shared<Tensor>(tensor->data(), Tensor::TensorTypes::eStaging);
this->mStagingTensor = std::make_shared<Tensor>(this->mPrimaryTensor->data(), Tensor::TensorTypes::eStaging);
this->mStagingTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
}
else {
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
this->mPrimaryTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
}
}

View file

@ -27,7 +27,7 @@ class OpCreateTensor : public OpBase
~OpCreateTensor();
void init(std::shared_ptr<Tensor> tensor, ...) override;
void init(std::vector<std::shared_ptr<Tensor>> tensors) override;
void record() override;

46
src/OpMult.cpp Normal file
View file

@ -0,0 +1,46 @@
#include "Tensor.hpp"
#include "OpMult.hpp"
namespace kp {
OpMult::OpMult() {
SPDLOG_DEBUG("Kompute OpMult constructor base");
}
OpMult::OpMult(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer)
: OpBase(physicalDevice, device, commandBuffer)
{
SPDLOG_DEBUG("Kompute OpMult constructor with params");
}
OpMult::~OpMult() {
SPDLOG_DEBUG("Kompute OpMult destructor started");
}
void
OpMult::init(std::vector<std::shared_ptr<Tensor>> tensors)
{
SPDLOG_DEBUG("Kompute OpMult init called");
if (tensors.size() < 2) {
throw std::runtime_error("Kompute OpMult called with less than 1 tensor");
} else if (tensors.size() > 2) {
spdlog::warn("Kompute OpMult called with more than 2 tensor");
}
}
void
OpMult::record()
{
SPDLOG_DEBUG("Kompute OpMult record called");
}
}

40
src/OpMult.hpp Normal file
View file

@ -0,0 +1,40 @@
#pragma once
#include <vulkan/vulkan.h>
#include <vulkan/vulkan.hpp>
// SPDLOG_ACTIVE_LEVEL must be defined before spdlog.h import
#if DEBUG
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG
#endif
#include <spdlog/spdlog.h>
#include "Tensor.hpp"
#include "OpBase.hpp"
namespace kp {
class OpMult : public OpBase
{
public:
OpMult();
OpMult(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer);
~OpMult();
void init(std::vector<std::shared_ptr<Tensor>> tensors) override;
void record() override;
private:
std::shared_ptr<Tensor> mPrimaryTensor;
std::shared_ptr<Tensor> mStagingTensor;
};
} // End namespace kp

View file

@ -623,16 +623,16 @@ main()
kp::Manager mgr;
spdlog::info("Creating first tensor");
kp::Tensor tensorOne({0.0, 1.0, 2.0});
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>{&tensorOne});
std::shared_ptr<kp::Tensor> tensorOne{new kp::Tensor({0.0, 1.0, 2.0})};
mgr.evalOp<kp::OpCreateTensor>({tensorOne});
spdlog::info("Creating second tensor");
kp::Tensor tensorTwo({1.0, 2.0, 3.0});
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>{&tensorTwo});
std::shared_ptr<kp::Tensor> tensorTwo{new kp::Tensor({0.0, 1.0, 2.0})};
mgr.evalOp<kp::OpCreateTensor>({tensorTwo});
spdlog::info("Called manager eval success");
spdlog::info("Tensor one: {}", tensorOne.data());
spdlog::info("Tensor two: {}", tensorTwo.data());
spdlog::info("Tensor one: {}", tensorOne->data());
spdlog::info("Tensor two: {}", tensorTwo->data());
return 0;
} catch (const std::exception& exc) {