Initial simpification of interface implementation

This commit is contained in:
Alejandro Saucedo 2021-03-06 19:42:41 +00:00
parent 956883e0cd
commit cf7d46cd23
6 changed files with 135 additions and 367 deletions

View file

@ -15,11 +15,17 @@ OpTensorCopy::OpTensorCopy(const std::vector<std::shared_ptr<Tensor>>& tensors)
}
kp::Tensor::TensorDataTypes dataType = this->mTensors[0]->dataType();
uint32_t size = this->mTensors[0]->size();
for (const std::shared_ptr<Tensor>& tensor : tensors) {
if (tensor->dataType() != dataType) {
throw std::runtime_error(fmt::format("Attempting to copy tensors of different types from {} to {}",
dataType, tensor->dataType()));
}
if (tensor->size() != size) {
throw std::runtime_error(fmt::format("Attempting to copy tensors of different sizes from {} to {}",
size, tensor->size()));
}
}
}
@ -55,12 +61,11 @@ OpTensorCopy::postEval(const vk::CommandBuffer& commandBuffer)
uint32_t size = this->mTensors[0]->size();
uint32_t dataTypeMemSize = this->mTensors[0]->dataTypeMemorySize();
uint32_t memSize = size * dataTypeMemSize;
void* data = operator new(memSize);
this->mTensors[0]->getRawData(data);
const void* data = this->mTensors[0]->getRawData();
// Copy the data from the first tensor into all the tensors
for (size_t i = 1; i < this->mTensors.size(); i++) {
this->mTensors[i]->setRawData(data, size, dataTypeMemSize);
this->mTensors[i]->setRawData(data);
}
}