diff --git a/shaders/glsl/opmult.comp b/shaders/glsl/opmult.comp index 2fa804a99..109b48cdd 100644 --- a/shaders/glsl/opmult.comp +++ b/shaders/glsl/opmult.comp @@ -12,6 +12,10 @@ layout(binding = 2) buffer tensorOutput { uint valuesOutput[ ]; }; +layout(binding = 3) buffer tensorInvalid { + uint valuesInvalid[ ]; +}; + // TODO: Explore how to make layout inside shader dynamic layout (local_size_x = 1, local_size_y = 1, local_size_z = 1) in; @@ -24,6 +28,7 @@ void main() valuesOutput[index] = 100 + index; valuesRhs[index] = 100 + index; valuesLhs[index] = 100 + index; + valuesInvalid[index] = 100 + index; } diff --git a/shaders/glsl/opmult.comp.spv b/shaders/glsl/opmult.comp.spv index c446f2d71..69ffb1c8b 100755 Binary files a/shaders/glsl/opmult.comp.spv and b/shaders/glsl/opmult.comp.spv differ diff --git a/src/OpMult.tpp b/src/OpMult.cpp similarity index 83% rename from src/OpMult.tpp rename to src/OpMult.cpp index 0980ed809..43fcb7cf5 100644 --- a/src/OpMult.tpp +++ b/src/OpMult.cpp @@ -127,14 +127,34 @@ OpMult::record() vk::AccessFlagBits::eTransferRead, vk::PipelineStageFlagBits::eComputeShader, vk::PipelineStageFlagBits::eTransfer); + this->mTensorLHS->recordBufferMemoryBarrier( + vk::AccessFlagBits::eShaderWrite, + vk::AccessFlagBits::eTransferRead, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eTransfer); + this->mTensorRHS->recordBufferMemoryBarrier( + vk::AccessFlagBits::eShaderWrite, + vk::AccessFlagBits::eTransferRead, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eTransfer); - this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput); + this->mTensorOutputStaging->recordCopyFrom(this->mTensorLHS); this->mTensorOutput->recordBufferMemoryBarrier( vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eHostRead, vk::PipelineStageFlagBits::eTransfer, vk::PipelineStageFlagBits::eHost); + this->mTensorLHS->recordBufferMemoryBarrier( + vk::AccessFlagBits::eTransferWrite, + vk::AccessFlagBits::eHostRead, + vk::PipelineStageFlagBits::eTransfer, + vk::PipelineStageFlagBits::eHost); + this->mTensorRHS->recordBufferMemoryBarrier( + vk::AccessFlagBits::eTransferWrite, + vk::AccessFlagBits::eHostRead, + vk::PipelineStageFlagBits::eTransfer, + vk::PipelineStageFlagBits::eHost); } template