diff --git a/shaders/glsl/opmult.comp b/shaders/glsl/opmult.comp
index c60d67696..2fa804a99 100644
--- a/shaders/glsl/opmult.comp
+++ b/shaders/glsl/opmult.comp
@@ -21,7 +21,9 @@ void main()
     //valuesOutput[index] = valuesLhs[index] * valuesRhs[index];
 
     // FOR TESTING
-    valuesOutput[index] = index;
+    valuesOutput[index] = 100 + index;
+    valuesRhs[index] = 100 + index;
+    valuesLhs[index] = 100 + index;
 }
 
 
diff --git a/shaders/glsl/opmult.comp.spv b/shaders/glsl/opmult.comp.spv
index e8e5b653a..c446f2d71 100755
Binary files a/shaders/glsl/opmult.comp.spv and b/shaders/glsl/opmult.comp.spv differ
diff --git a/src/OpMult.cpp b/src/OpMult.cpp
index b32c9524f..873d72797 100644
--- a/src/OpMult.cpp
+++ b/src/OpMult.cpp
@@ -45,6 +45,22 @@ OpMult::init(std::vector> tensors)
     this->mTensorRHS = tensors[1];
     this->mTensorOutput = tensors[2];
 
+    // The dispatch size is taken from the template parameters when they are
+    // provided explicitly; otherwise it defaults to the size of the LHS tensor.
+    if (tX > 0) {
+        // X was provided, so use the template values (an unset Y or Z becomes 1)
+        this->mX = tX;
+        this->mY = tY > 0 ? tY : 1;
+        this->mZ = tZ > 0 ? tZ : 1;
+    }
+    else {
+        // TODO: Fully support the full size dispatch using size for the shape
+        this->mX = this->mTensorLHS->size();
+        this->mY = 1;
+        this->mZ = 1;
+    }
+    SPDLOG_DEBUG("Kompute OpMult dispatch size X: {}, Y: {}, Z: {}", this->mX, this->mY, this->mZ);
+
     // TODO: Explore adding a validate function
     if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
           this->mTensorOutput->isInit())) {
@@ -83,21 +99,36 @@ OpMult::record()
 {
     SPDLOG_DEBUG("Kompute OpMult record called");
 
-    this->mAlgorithm->recordDispatch(1, 1, 1);
+    // Barrier from host writes to compute-shader reads on both input tensors,
+    // recorded ahead of the dispatch that consumes them.
+    this->mTensorLHS->recordBufferMemoryBarrier(
+      vk::AccessFlagBits::eHostWrite,
+      vk::AccessFlagBits::eShaderRead,
+      vk::PipelineStageFlagBits::eHost,
+      vk::PipelineStageFlagBits::eComputeShader);
+    this->mTensorRHS->recordBufferMemoryBarrier(
+      vk::AccessFlagBits::eHostWrite,
+      vk::AccessFlagBits::eShaderRead,
+      vk::PipelineStageFlagBits::eHost,
+      vk::PipelineStageFlagBits::eComputeShader);
+
+    this->mAlgorithm->recordDispatch(this->mX, this->mY, this->mZ);
 
     this->mTensorOutput->recordBufferMemoryBarrier(
-            vk::AccessFlagBits::eShaderWrite,
-            vk::AccessFlagBits::eTransferRead,
-            vk::PipelineStageFlagBits::eComputeShader,
-            vk::PipelineStageFlagBits::eTransfer);
+      vk::AccessFlagBits::eShaderWrite,
+      vk::AccessFlagBits::eTransferRead,
+      vk::PipelineStageFlagBits::eComputeShader,
+      vk::PipelineStageFlagBits::eTransfer);
 
     this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput);
 
-    this->mTensorOutput->recordBufferMemoryBarrier(
-            vk::AccessFlagBits::eTransferWrite,
-            vk::AccessFlagBits::eHostRead,
-            vk::PipelineStageFlagBits::eTransfer,
-            vk::PipelineStageFlagBits::eHost);
+    // The copy above writes into the staging tensor, so that is the buffer
+    // whose transfer write has to be ordered before the host read.
+    this->mTensorOutputStaging->recordBufferMemoryBarrier(
+      vk::AccessFlagBits::eTransferWrite,
+      vk::AccessFlagBits::eHostRead,
+      vk::PipelineStageFlagBits::eTransfer,
+      vk::PipelineStageFlagBits::eHost);
 }
 
 void
diff --git a/src/OpMult.hpp b/src/OpMult.hpp
index fd1635fdf..51012d937 100644
--- a/src/OpMult.hpp
+++ b/src/OpMult.hpp
@@ -17,6 +17,7 @@
 
 namespace kp {
 
+template<uint32_t tX = 0, uint32_t tY = 0, uint32_t tZ = 0>
 class OpMult : public OpBase
 {
   public:
@@ -40,6 +41,11 @@ class OpMult : public OpBase
     std::shared_ptr mTensorRHS;
     std::shared_ptr mTensorOutput;
     std::shared_ptr mTensorOutputStaging;
+
+    // Dispatch dimensions computed in init() and passed to recordDispatch() in record()
+    uint32_t mX;
+    uint32_t mY;
+    uint32_t mZ;
 };
 
 } // End namespace kp