Updated Tensor Memory to hold staging within class

This commit is contained in:
Alejandro Saucedo 2021-02-08 07:17:54 +00:00
parent b61f3f2297
commit 04853df469
11 changed files with 97 additions and 115 deletions

View file

@ -65,11 +65,6 @@ OpAlgoLhsRhsOut::init()
" Output: " + std::to_string(this->mTensorOutput->size()));
}
this->mTensorOutputStaging = std::make_shared<Tensor>(
this->mTensorOutput->data(), Tensor::TensorTypes::eStaging);
this->mTensorOutputStaging->init(this->mPhysicalDevice, this->mDevice);
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut fetching spirv data");
std::vector<char> shaderFileData = this->fetchSpirvBinaryData();
@ -110,8 +105,10 @@ OpAlgoLhsRhsOut::record()
vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eTransfer);
this->mTensorOutputStaging->recordCopyFrom(
this->mCommandBuffer, this->mTensorOutput, true);
if (this->mTensorOutput->tensorType() == Tensor::TensorTypes::eDevice) {
this->mTensorOutput->recordCopyFromDeviceToStaging(
this->mCommandBuffer, true);
}
}
void
@ -119,9 +116,7 @@ OpAlgoLhsRhsOut::postEval()
{
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut postSubmit called");
this->mTensorOutputStaging->mapDataFromHostMemory();
this->mTensorOutput->setData(this->mTensorOutputStaging->data());
this->mTensorOutput->mapDataFromHostMemory();
}
}