Added optensorcopy operation

This commit is contained in:
Alejandro Saucedo 2020-09-06 11:07:32 +01:00
parent 93c1ba126e
commit 236c349aa0
13 changed files with 238 additions and 45 deletions

View file

@ -41,8 +41,7 @@ Tensor::~Tensor()
void
Tensor::init(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer)
std::shared_ptr<vk::Device> device)
{
SPDLOG_DEBUG("Kompute Tensor running init with Vulkan params and num data "
"elementS: {}",
@ -50,7 +49,6 @@ Tensor::init(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
this->mPhysicalDevice = physicalDevice;
this->mDevice = device;
this->mCommandBuffer = commandBuffer;
this->mIsInit = true;
@ -106,8 +104,10 @@ Tensor::setData(const std::vector<float>& data)
}
void
Tensor::recordCopyFrom(std::shared_ptr<Tensor> copyFromTensor,
bool createBarrier)
Tensor::recordCopyFrom(
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::shared_ptr<Tensor> copyFromTensor,
bool createBarrier)
{
SPDLOG_DEBUG("Kompute Tensor recordCopyFrom called");
@ -121,12 +121,13 @@ Tensor::recordCopyFrom(std::shared_ptr<Tensor> copyFromTensor,
SPDLOG_DEBUG("Kompute Tensor copying data size {}.", bufferSize);
this->mCommandBuffer->copyBuffer(
commandBuffer->copyBuffer(
*copyFromTensor->mBuffer, *this->mBuffer, copyRegion);
if (createBarrier) {
// Buffer to ensure wait until data is copied to staging buffer
this->recordBufferMemoryBarrier(vk::AccessFlagBits::eTransferWrite,
this->recordBufferMemoryBarrier(commandBuffer,
vk::AccessFlagBits::eTransferWrite,
vk::AccessFlagBits::eHostRead,
vk::PipelineStageFlagBits::eTransfer,
vk::PipelineStageFlagBits::eHost);
@ -134,7 +135,8 @@ Tensor::recordCopyFrom(std::shared_ptr<Tensor> copyFromTensor,
}
void
Tensor::recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask,
Tensor::recordBufferMemoryBarrier(std::shared_ptr<vk::CommandBuffer> commandBuffer,
vk::AccessFlagBits srcAccessMask,
vk::AccessFlagBits dstAccessMask,
vk::PipelineStageFlagBits srcStageMask,
vk::PipelineStageFlagBits dstStageMask)
@ -151,7 +153,7 @@ Tensor::recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask,
bufferMemoryBarrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
bufferMemoryBarrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
this->mCommandBuffer->pipelineBarrier(srcStageMask,
commandBuffer->pipelineBarrier(srcStageMask,
dstStageMask,
vk::DependencyFlags(),
nullptr,