Added initial base for opcreatetensor

This commit is contained in:
Alejandro Saucedo 2020-08-18 21:13:33 +01:00
parent 93041b4519
commit 014f15d552
14 changed files with 422 additions and 61 deletions

View file

@ -1,15 +1,48 @@
#include "Tensor.hpp"
#include "OpCreateTensor.hpp"
namespace kp {
OpCreateTensor::OpCreateTensor() {}
OpCreateTensor::OpCreateTensor(std::shared_ptr<vk::CommandBuffer> commandBuffer)
OpCreateTensor::OpCreateTensor(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer)
: BaseOp(physicalDevice, device, commandBuffer)
{
this->mCommandBuffer = commandBuffer;
}
OpCreateTensor::~OpCreateTensor() {}
}
OpCreateTensor::~OpCreateTensor() {
}
void
OpCreateTensor::init(std::shared_ptr<Tensor> tensor, std::vector<uint32_t> data)
{
this->mPrimaryTensor = tensor;
if (tensor->tensorType() == Tensor::TensorTypes::eDevice) {
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
this->mStagingTensor = std::make_unique<Tensor>(tensor->shape(), Tensor::TensorTypes::eStaging);
this->mStagingTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
}
else {
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
}
}
void
OpCreateTensor::record()
{
if (this->mPrimaryTensor->tensorType() == Tensor::TensorTypes::eDevice) {
this->mPrimaryTensor->recordCopyFrom(this->mStagingTensor);
}
}
}