Initial simpification of interface implementation
This commit is contained in:
parent
956883e0cd
commit
cf7d46cd23
6 changed files with 135 additions and 367 deletions
|
|
@ -15,11 +15,17 @@ OpTensorCopy::OpTensorCopy(const std::vector<std::shared_ptr<Tensor>>& tensors)
|
|||
}
|
||||
|
||||
kp::Tensor::TensorDataTypes dataType = this->mTensors[0]->dataType();
|
||||
uint32_t size = this->mTensors[0]->size();
|
||||
for (const std::shared_ptr<Tensor>& tensor : tensors) {
|
||||
if (tensor->dataType() != dataType) {
|
||||
throw std::runtime_error(fmt::format("Attempting to copy tensors of different types from {} to {}",
|
||||
dataType, tensor->dataType()));
|
||||
}
|
||||
if (tensor->size() != size) {
|
||||
throw std::runtime_error(fmt::format("Attempting to copy tensors of different sizes from {} to {}",
|
||||
size, tensor->size()));
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -55,12 +61,11 @@ OpTensorCopy::postEval(const vk::CommandBuffer& commandBuffer)
|
|||
uint32_t size = this->mTensors[0]->size();
|
||||
uint32_t dataTypeMemSize = this->mTensors[0]->dataTypeMemorySize();
|
||||
uint32_t memSize = size * dataTypeMemSize;
|
||||
void* data = operator new(memSize);
|
||||
this->mTensors[0]->getRawData(data);
|
||||
const void* data = this->mTensors[0]->getRawData();
|
||||
|
||||
// Copy the data from the first tensor into all the tensors
|
||||
for (size_t i = 1; i < this->mTensors.size(); i++) {
|
||||
this->mTensors[i]->setRawData(data, size, dataTypeMemSize);
|
||||
this->mTensors[i]->setRawData(data);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue