Base working compilation
This commit is contained in:
parent
5596b6f029
commit
7c3af1189f
8 changed files with 40 additions and 37 deletions
|
|
@ -12,43 +12,33 @@
|
|||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include "Tensor.hpp"
|
||||
|
||||
namespace kp {
|
||||
|
||||
template<class T>
|
||||
class BaseOp
|
||||
class OpBase
|
||||
{
|
||||
private:
|
||||
public:
|
||||
BaseOp() {}
|
||||
OpBase() {}
|
||||
|
||||
BaseOp(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
OpBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer) {
|
||||
SPDLOG_DEBUG("Compute BaseOp constructor started");
|
||||
SPDLOG_DEBUG("Compute OpBase constructor started");
|
||||
|
||||
this->mPhysicalDevice = physicalDevice;
|
||||
this->mDevice = device;
|
||||
this->mCommandBuffer = commandBuffer;
|
||||
}
|
||||
|
||||
~BaseOp()
|
||||
~OpBase()
|
||||
{
|
||||
SPDLOG_DEBUG("Compute BaseOp destructor started");
|
||||
SPDLOG_DEBUG("Compute OpBase destructor started");
|
||||
}
|
||||
|
||||
template<typename... TArgs>
|
||||
void init(TArgs&&... args)
|
||||
{
|
||||
SPDLOG_DEBUG("Compute BaseOp init started");
|
||||
static_cast<T*>(this)->init(std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
template<typename... TArgs>
|
||||
void record(TArgs&&... args)
|
||||
{
|
||||
SPDLOG_DEBUG("Compute BaseOp record started");
|
||||
static_cast<T*>(this)->record(std::forward<TArgs>(args)...);
|
||||
}
|
||||
virtual void init(std::shared_ptr<Tensor> tensor, ...) = 0;
|
||||
virtual void record() = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<vk::PhysicalDevice> mPhysicalDevice;
|
||||
|
|
@ -13,7 +13,7 @@ OpCreateTensor::OpCreateTensor() {
|
|||
OpCreateTensor::OpCreateTensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer)
|
||||
: BaseOp(physicalDevice, device, commandBuffer)
|
||||
: OpBase(physicalDevice, device, commandBuffer)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpCreateTensor constructor with params");
|
||||
}
|
||||
|
|
@ -23,15 +23,17 @@ OpCreateTensor::~OpCreateTensor() {
|
|||
}
|
||||
|
||||
void
|
||||
OpCreateTensor::init(std::shared_ptr<Tensor> tensor, std::vector<uint32_t> data)
|
||||
OpCreateTensor::init(std::shared_ptr<Tensor> tensor, ...)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpCreateTensor init called");
|
||||
|
||||
this->mPrimaryTensor = tensor;
|
||||
std::vector<uint32_t> data = this->mPrimaryTensor->data();
|
||||
|
||||
if (tensor->tensorType() == Tensor::TensorTypes::eDevice) {
|
||||
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
|
||||
|
||||
this->mStagingTensor = std::make_shared<Tensor>(tensor->shape(), Tensor::TensorTypes::eStaging);
|
||||
this->mStagingTensor = std::make_shared<Tensor>(tensor->data(), Tensor::TensorTypes::eStaging);
|
||||
|
||||
this->mStagingTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@
|
|||
|
||||
#include "Tensor.hpp"
|
||||
|
||||
#include "BaseOp.hpp"
|
||||
#include "OpBase.hpp"
|
||||
|
||||
namespace kp {
|
||||
|
||||
class OpCreateTensor : BaseOp<OpCreateTensor>
|
||||
class OpCreateTensor : OpBase
|
||||
{
|
||||
public:
|
||||
OpCreateTensor();
|
||||
|
|
@ -27,9 +27,9 @@ class OpCreateTensor : BaseOp<OpCreateTensor>
|
|||
|
||||
~OpCreateTensor();
|
||||
|
||||
void init(std::shared_ptr<Tensor> tensor, std::vector<uint32_t> data);
|
||||
void init(std::shared_ptr<Tensor> tensor, ...) override;
|
||||
|
||||
void record();
|
||||
void record() override;
|
||||
|
||||
private:
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include "OpBase.hpp"
|
||||
|
||||
namespace kp {
|
||||
|
||||
class Sequence
|
||||
|
|
@ -31,6 +33,8 @@ class Sequence
|
|||
template<typename T, typename... TArgs>
|
||||
void record(TArgs&&... args)
|
||||
{
|
||||
static_assert(std::is_base_of<OpBase, T>::value, "Template only valid with OpBase derived classes");
|
||||
|
||||
SPDLOG_DEBUG("Kompute Sequence record");
|
||||
T op(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
|
||||
op.init(std::forward<TArgs>(args)...);
|
||||
|
|
@ -47,6 +51,9 @@ class Sequence
|
|||
std::shared_ptr<vk::CommandBuffer> mCommandBuffer = nullptr;
|
||||
bool mFreeCommandBuffer = false;
|
||||
|
||||
// Base op objects
|
||||
std::vector<OpBase> operations;
|
||||
|
||||
// Record state
|
||||
bool mRecording = false;
|
||||
|
||||
|
|
|
|||
|
|
@ -8,11 +8,12 @@ Tensor::Tensor() {
|
|||
this->mTensorType = TensorTypes::eDevice;
|
||||
}
|
||||
|
||||
Tensor::Tensor(std::array<uint32_t, KP_MAX_DIM_SIZE> shape, TensorTypes tensorType)
|
||||
Tensor::Tensor(std::vector<uint32_t> data, TensorTypes tensorType)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute Tensor constructor shape and type");
|
||||
|
||||
this->mShape = shape;
|
||||
this->mData = data;
|
||||
this->mShape = {data.size()};
|
||||
this->mTensorType = tensorType;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class Tensor
|
|||
|
||||
Tensor();
|
||||
|
||||
Tensor(std::array<uint32_t, KP_MAX_DIM_SIZE> shape, TensorTypes tensorType = TensorTypes::eDevice);
|
||||
Tensor(std::vector<uint32_t> data, TensorTypes tensorType = TensorTypes::eDevice);
|
||||
|
||||
~Tensor();
|
||||
|
||||
|
|
|
|||
|
|
@ -620,10 +620,9 @@ main()
|
|||
// Run Kompute
|
||||
spdlog::info("Creating manager");
|
||||
kp::Manager mgr;
|
||||
std::vector<uint32_t> data = {0.0, 1.0, 2.0};
|
||||
kp::Tensor tensor({data.size()});
|
||||
kp::Tensor tensor({0.0, 1.0, 2.0});
|
||||
spdlog::info("Calling manager eval w opcreatetensor");
|
||||
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>(&tensor), data);
|
||||
mgr.evalOp<kp::OpCreateTensor>(std::shared_ptr<kp::Tensor>(&tensor));
|
||||
spdlog::info("Called manager eval success");
|
||||
std::vector<uint32_t> outData = tensor.data();
|
||||
spdlog::info("Output data: {}", outData);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue