Initial implementation of tensor working compiling

This commit is contained in:
Alejandro Saucedo 2021-03-06 17:25:35 +00:00
parent b81896a780
commit ad18c2e546
12 changed files with 417 additions and 210 deletions

View file

@ -13,6 +13,14 @@ OpTensorCopy::OpTensorCopy(const std::vector<std::shared_ptr<Tensor>>& tensors)
throw std::runtime_error(
"Kompute OpTensorCopy called with less than 2 tensor");
}
kp::Tensor::TensorDataTypes dataType = this->mTensors[0]->dataType();
for (const std::shared_ptr<Tensor>& tensor : tensors) {
if (tensor->dataType() != dataType) {
throw std::runtime_error(fmt::format("Attempting to copy tensors of different types from {} to {}",
dataType, tensor->dataType()));
}
}
}
OpTensorCopy::~OpTensorCopy()
@ -43,9 +51,16 @@ OpTensorCopy::postEval(const vk::CommandBuffer& commandBuffer)
{
KP_LOG_DEBUG("Kompute OpTensorCopy postEval called");
// TODO: Simplify with a copyRawData
uint32_t size = this->mTensors[0]->size();
uint32_t dataTypeMemSize = this->mTensors[0]->dataTypeMemorySize();
uint32_t memSize = size * dataTypeMemSize;
void* data = operator new(memSize);
this->mTensors[0]->getRawData(data);
// Copy the data from the first tensor into all the tensors
for (size_t i = 1; i < this->mTensors.size(); i++) {
this->mTensors[i]->setData(this->mTensors[0]->data());
this->mTensors[i]->setRawData(data, size, dataTypeMemSize);
}
}