Updated init parameter to be tensor vector

This commit is contained in:
Alejandro Saucedo 2020-08-20 05:27:42 +01:00
parent 90ea083cba
commit 8aa7843f0e
7 changed files with 108 additions and 16 deletions

View file

@ -23,23 +23,29 @@ OpCreateTensor::~OpCreateTensor() {
}
void
OpCreateTensor::init(std::shared_ptr<Tensor> tensor, ...)
OpCreateTensor::init(std::vector<std::shared_ptr<Tensor>> tensors)
{
SPDLOG_DEBUG("Kompute OpCreateTensor init called");
this->mPrimaryTensor = tensor;
if (tensors.size() < 1) {
throw std::runtime_error("Kompute OpCreateTensor called with less than 1 tensor");
} else if (tensors.size() > 1) {
spdlog::warn("Kompute OpCreateTensor called with more than 1 tensor");
}
this->mPrimaryTensor = tensors[0];
std::vector<uint32_t> data = this->mPrimaryTensor->data();
if (tensor->tensorType() == Tensor::TensorTypes::eDevice) {
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
if (this->mPrimaryTensor->tensorType() == Tensor::TensorTypes::eDevice) {
this->mPrimaryTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer);
this->mStagingTensor = std::make_shared<Tensor>(tensor->data(), Tensor::TensorTypes::eStaging);
this->mStagingTensor = std::make_shared<Tensor>(this->mPrimaryTensor->data(), Tensor::TensorTypes::eStaging);
this->mStagingTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
}
else {
tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
this->mPrimaryTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data);
}
}