Updated init parameter to be tensor vector
This commit is contained in:
parent
90ea083cba
commit
8aa7843f0e
7 changed files with 108 additions and 16 deletions
|
|
@ -28,7 +28,7 @@ class Manager
|
|||
|
||||
// Evaluate actions
|
||||
template<typename T, typename... TArgs>
|
||||
void evalOp(TArgs&&... args)
|
||||
void evalOp(std::vector<std::shared_ptr<Tensor>> tensors)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute Manager eval triggered");
|
||||
Sequence sq(
|
||||
|
|
@ -36,7 +36,7 @@ class Manager
|
|||
SPDLOG_DEBUG("Kompute Manager created sequence");
|
||||
sq.begin();
|
||||
SPDLOG_DEBUG("Kompute Manager sequence begin");
|
||||
sq.record<T>(std::forward<TArgs>(args)...);
|
||||
sq.record<T>(tensors);
|
||||
SPDLOG_DEBUG("Kompute Manager sequence end");
|
||||
sq.end();
|
||||
SPDLOG_DEBUG("Kompute Manager sequence eval");
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class OpBase
|
|||
SPDLOG_DEBUG("Compute OpBase destructor started");
|
||||
}
|
||||
|
||||
virtual void init(std::shared_ptr<Tensor> tensor, ...) {
|
||||
virtual void init(std::vector<std::shared_ptr<Tensor>> tensors) {
|
||||
SPDLOG_DEBUG("Kompute OpBase init called");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,23 +23,29 @@ OpCreateTensor::~OpCreateTensor() {
|
|||
}
|
||||
|
||||
void
|
||||
OpCreateTensor::init(std::shared_ptr<Tensor> tensor, ...)
|
||||
OpCreateTensor::init(std::vector<std::shared_ptr<Tensor>> tensors)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpCreateTensor init called");
|
||||
|
||||
this->mPrimaryTensor = tensor;
|
||||
if (tensors.size() < 1) {
|
||||
throw std::runtime_error("Kompute OpCreateTensor called with less than 1 tensor");
|
||||
} else if (tensors.size() > 1) {
|
||||
spdlog::warn("Kompute OpCreateTensor called with more than 1 tensor");
|
||||
}
|
||||
|
||||
this->mPrimaryTensor = tensors[0];
|
||||
std::vector<uint32_t> data = this->mPrimaryTensor->data();
|
||||
|
||||
if (tensor->tensorType() == Tensor::TensorTypes::eDevice) {
|
||||
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
|
||||
if (this->mPrimaryTensor->tensorType() == Tensor::TensorTypes::eDevice) {
|
||||
this->mPrimaryTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
|
||||
|
||||
this->mStagingTensor = std::make_shared<Tensor>(tensor->data(), Tensor::TensorTypes::eStaging);
|
||||
this->mStagingTensor = std::make_shared<Tensor>(this->mPrimaryTensor->data(), Tensor::TensorTypes::eStaging);
|
||||
|
||||
this->mStagingTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
|
||||
|
||||
}
|
||||
else {
|
||||
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
|
||||
this->mPrimaryTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class OpCreateTensor : public OpBase
|
|||
|
||||
~OpCreateTensor();
|
||||
|
||||
void init(std::shared_ptr<Tensor> tensor, ...) override;
|
||||
void init(std::vector<std::shared_ptr<Tensor>> tensors) override;
|
||||
|
||||
void record() override;
|
||||
|
||||
|
|
|
|||
46
src/OpMult.cpp
Normal file
46
src/OpMult.cpp
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
|
||||
|
||||
#include "Tensor.hpp"
|
||||
|
||||
#include "OpMult.hpp"
|
||||
|
||||
namespace kp {
|
||||
|
||||
OpMult::OpMult() {
|
||||
SPDLOG_DEBUG("Kompute OpMult constructor base");
|
||||
|
||||
}
|
||||
|
||||
OpMult::OpMult(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer)
|
||||
: OpBase(physicalDevice, device, commandBuffer)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpMult constructor with params");
|
||||
}
|
||||
|
||||
OpMult::~OpMult() {
|
||||
SPDLOG_DEBUG("Kompute OpMult destructor started");
|
||||
}
|
||||
|
||||
void
|
||||
OpMult::init(std::vector<std::shared_ptr<Tensor>> tensors)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpMult init called");
|
||||
|
||||
if (tensors.size() < 2) {
|
||||
throw std::runtime_error("Kompute OpMult called with less than 1 tensor");
|
||||
} else if (tensors.size() > 2) {
|
||||
spdlog::warn("Kompute OpMult called with more than 2 tensor");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
OpMult::record()
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpMult record called");
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
40
src/OpMult.hpp
Normal file
40
src/OpMult.hpp
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
#pragma once
|
||||
|
||||
#include <vulkan/vulkan.h>
|
||||
#include <vulkan/vulkan.hpp>
|
||||
|
||||
// SPDLOG_ACTIVE_LEVEL must be defined before spdlog.h import
|
||||
#if DEBUG
|
||||
#define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG
|
||||
#endif
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include "Tensor.hpp"
|
||||
|
||||
#include "OpBase.hpp"
|
||||
|
||||
namespace kp {
|
||||
|
||||
class OpMult : public OpBase
|
||||
{
|
||||
public:
|
||||
OpMult();
|
||||
|
||||
OpMult(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer);
|
||||
|
||||
~OpMult();
|
||||
|
||||
void init(std::vector<std::shared_ptr<Tensor>> tensors) override;
|
||||
|
||||
void record() override;
|
||||
|
||||
private:
|
||||
|
||||
std::shared_ptr<Tensor> mPrimaryTensor;
|
||||
std::shared_ptr<Tensor> mStagingTensor;
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
12
src/main.cpp
12
src/main.cpp
|
|
@ -623,16 +623,16 @@ main()
|
|||
kp::Manager mgr;
|
||||
|
||||
spdlog::info("Creating first tensor");
|
||||
kp::Tensor tensorOne({0.0, 1.0, 2.0});
|
||||
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>{&tensorOne});
|
||||
std::shared_ptr<kp::Tensor> tensorOne{new kp::Tensor({0.0, 1.0, 2.0})};
|
||||
mgr.evalOp<kp::OpCreateTensor>({tensorOne});
|
||||
|
||||
spdlog::info("Creating second tensor");
|
||||
kp::Tensor tensorTwo({1.0, 2.0, 3.0});
|
||||
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>{&tensorTwo});
|
||||
std::shared_ptr<kp::Tensor> tensorTwo{new kp::Tensor({0.0, 1.0, 2.0})};
|
||||
mgr.evalOp<kp::OpCreateTensor>({tensorTwo});
|
||||
|
||||
spdlog::info("Called manager eval success");
|
||||
spdlog::info("Tensor one: {}", tensorOne.data());
|
||||
spdlog::info("Tensor two: {}", tensorTwo.data());
|
||||
spdlog::info("Tensor one: {}", tensorOne->data());
|
||||
spdlog::info("Tensor two: {}", tensorTwo->data());
|
||||
|
||||
return 0;
|
||||
} catch (const std::exception& exc) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue