diff --git a/src/OpMult.cpp b/src/OpMult.cpp index 3f6ec93f1..b32c9524f 100644 --- a/src/OpMult.cpp +++ b/src/OpMult.cpp @@ -86,18 +86,18 @@ OpMult::record() this->mAlgorithm->recordDispatch(1, 1, 1); this->mTensorOutput->recordBufferMemoryBarrier( - vk::AccessFlagBits::eShaderWrite, - vk::AccessFlagBits::eTransferRead, - vk::PipelineStageFlagBits::eComputeShader, - vk::PipelineStageFlagBits::eTransfer); + vk::AccessFlagBits::eShaderWrite, + vk::AccessFlagBits::eTransferRead, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eTransfer); this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput); this->mTensorOutput->recordBufferMemoryBarrier( - vk::AccessFlagBits::eTransferWrite, - vk::AccessFlagBits::eHostRead, - vk::PipelineStageFlagBits::eTransfer, - vk::PipelineStageFlagBits::eHost); + vk::AccessFlagBits::eTransferWrite, + vk::AccessFlagBits::eHostRead, + vk::PipelineStageFlagBits::eTransfer, + vk::PipelineStageFlagBits::eHost); } void diff --git a/src/Tensor.cpp b/src/Tensor.cpp index 841434391..d64f7241a 100644 --- a/src/Tensor.cpp +++ b/src/Tensor.cpp @@ -136,7 +136,12 @@ Tensor::recordCopyFrom(std::shared_ptr copyFromTensor) this->mData = copyFromTensor->mData; } -void Tensor::recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask, vk::AccessFlagBits dstAccessMask, vk::PipelineStageFlagBits srcStageMask, vk::PipelineStageFlagBits dstStageMask) { +void +Tensor::recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask) +{ SPDLOG_DEBUG("Kompute Tensor recording buffer memory barrier"); vk::DeviceSize bufferSize = this->memorySize(); @@ -146,19 +151,15 @@ void Tensor::recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask, vk::Acc bufferMemoryBarrier.size = bufferSize; bufferMemoryBarrier.srcAccessMask = srcAccessMask; bufferMemoryBarrier.dstAccessMask = dstAccessMask; - 
bufferMemoryBarrier.srcQueueFamilyIndex = - VK_QUEUE_FAMILY_IGNORED; - bufferMemoryBarrier.dstQueueFamilyIndex = - VK_QUEUE_FAMILY_IGNORED; - - this->mCommandBuffer->pipelineBarrier( - srcStageMask, - dstStageMask, - vk::DependencyFlags(), - nullptr, - bufferMemoryBarrier, - nullptr); + bufferMemoryBarrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + bufferMemoryBarrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + this->mCommandBuffer->pipelineBarrier(srcStageMask, + dstStageMask, + vk::DependencyFlags(), + nullptr, + bufferMemoryBarrier, + nullptr); } // TODO: Explore if this function should be here or expose buffer diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 13b122b9e..8c5c2e95b 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -51,8 +51,12 @@ class Tensor // Record functions void recordCopyFrom(std::shared_ptr copyFromTensor); - // TODO: Explore simplifying by infering pipeline stage flag bits from access flag bits (as seems to be superset) - void recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask, vk::AccessFlagBits dstAccessMask, vk::PipelineStageFlagBits srcStageMask, vk::PipelineStageFlagBits dstStageMask); + // TODO: Explore simplifying by inferring pipeline stage flag bits from + // access flag bits (as seems to be superset) + void recordBufferMemoryBarrier(vk::AccessFlagBits srcAccessMask, + vk::AccessFlagBits dstAccessMask, + vk::PipelineStageFlagBits srcStageMask, + vk::PipelineStageFlagBits dstStageMask); // Util functions vk::DescriptorBufferInfo constructDescriptorBufferInfo();