Added option for creating barrier on copyfrom tensor
This commit is contained in:
parent
a2efc441db
commit
2298159586
4 changed files with 20 additions and 38 deletions
|
|
@ -96,8 +96,12 @@ class Tensor
|
|||
* Records a copy from the memory of the tensor provided to the current
|
||||
* tensor. This is intended to pass memory into a processing step, to perform
|
||||
* a staging buffer transfer, or to gather output (among others).
|
||||
*
|
||||
* @param copyFromTensor Tensor to copy the data from
|
||||
* @param createBarrier Whether to create a barrier that ensures the data is copied before further operations. Default is true.
|
||||
*/
|
||||
void recordCopyFrom(std::shared_ptr<Tensor> copyFromTensor);
|
||||
void recordCopyFrom(std::shared_ptr<Tensor> copyFromTensor,
|
||||
bool createBarrier = true);
|
||||
|
||||
/**
|
||||
* Records the buffer memory barrier into the command buffer which
|
||||
|
|
|
|||
|
|
@ -198,11 +198,6 @@ OpMult<tX, tY, tZ>::record()
|
|||
this->mAlgorithm->recordDispatch(this->mX, this->mY, this->mZ);
|
||||
|
||||
// Barrier to ensure the shader code is executed before buffer read
|
||||
this->mTensorLHS->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eShaderWrite,
|
||||
vk::AccessFlagBits::eTransferRead,
|
||||
vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eTransfer);
|
||||
this->mTensorOutput->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eShaderWrite,
|
||||
vk::AccessFlagBits::eTransferRead,
|
||||
|
|
@ -210,18 +205,6 @@ OpMult<tX, tY, tZ>::record()
|
|||
vk::PipelineStageFlagBits::eTransfer);
|
||||
|
||||
this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput);
|
||||
|
||||
// Barrier to ensure we wait until data is copied to the staging buffer
|
||||
this->mTensorLHS->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eTransferWrite,
|
||||
vk::AccessFlagBits::eHostRead,
|
||||
vk::PipelineStageFlagBits::eTransfer,
|
||||
vk::PipelineStageFlagBits::eHost);
|
||||
this->mTensorOutput->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eTransferWrite,
|
||||
vk::AccessFlagBits::eHostRead,
|
||||
vk::PipelineStageFlagBits::eTransfer,
|
||||
vk::PipelineStageFlagBits::eHost);
|
||||
}
|
||||
|
||||
template<uint32_t tX, uint32_t tY, uint32_t tZ>
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue