Run OpTensorSyncDevice by default in the Manager buildTensor and rebuildTensors functions, with a parameter to disable it

This commit is contained in:
Alejandro Saucedo 2021-02-08 21:41:48 +00:00
parent 65cb1b7582
commit aa25f980d6

View file

@ -7,6 +7,8 @@
#include "kompute/Sequence.hpp"
#include "kompute/operations/OpTensorSyncDevice.hpp"
#define KP_DEFAULT_SESSION "DEFAULT"
namespace kp {
@ -229,11 +231,13 @@ class Manager
*
* @param data The data to initialize the tensor with
* @param tensorType The type of tensor to initialize
* @param syncDataToGPU Whether to sync the data to GPU memory
* @returns Initialized Tensor, with memory synced to the GPU device unless syncDataToGPU is false
*/
std::shared_ptr<Tensor> buildTensor(
const std::vector<float>& data,
Tensor::TensorTypes tensorType = Tensor::TensorTypes::eDevice)
Tensor::TensorTypes tensorType = Tensor::TensorTypes::eDevice,
bool syncDataToGPU = true)
{
SPDLOG_DEBUG("Kompute Manager buildTensor triggered");
@ -242,11 +246,13 @@ class Manager
std::make_shared<Tensor>(kp::Tensor(data, tensorType));
tensor->init(this->mPhysicalDevice, this->mDevice);
if (tensor->tensorType() != Tensor::TensorTypes::eStorage) {
tensor->mapDataIntoHostMemory();
if (syncDataToGPU) {
this->evalOpDefault<OpTensorSyncDevice>({tensor});
}
this->mManagedTensors.insert(tensor);
return tensor;
}
@ -258,9 +264,10 @@ class Manager
*
* @param tensors The tensors to rebuild
* @param syncDataToGPU Whether to sync the data to GPU memory
*/
void rebuildTensors(std::vector<std::shared_ptr<kp::Tensor>> tensors)
void rebuildTensors(std::vector<std::shared_ptr<kp::Tensor>> tensors, bool syncDataToGPU = true)
{
SPDLOG_DEBUG("Kompute Manager rebuildTensors triggered");
for (std::shared_ptr<Tensor> tensor : tensors) {
@ -270,9 +277,6 @@ class Manager
}
tensor->init(this->mPhysicalDevice, this->mDevice);
if (tensor->tensorType() != Tensor::TensorTypes::eStorage) {
tensor->mapDataIntoHostMemory();
}
std::set<std::shared_ptr<Tensor>>::iterator it =
this->mManagedTensors.find(tensor);
@ -280,6 +284,10 @@ class Manager
this->mManagedTensors.insert(tensor);
}
}
if (syncDataToGPU) {
this->evalOpDefault<OpTensorSyncDevice>(tensors);
}
}
private: