From d74a999e121dfbcf470f14aa8250ddd1b08a1adf Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Sat, 22 Aug 2020 13:33:21 +0100 Subject: [PATCH] Added template parameters to opmult class --- shaders/glsl/opmult.comp | 4 +++- shaders/glsl/opmult.comp.spv | Bin 1192 -> 1404 bytes src/OpMult.cpp | 44 ++++++++++++++++++++++++++++------- src/OpMult.hpp | 5 ++++ 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/shaders/glsl/opmult.comp b/shaders/glsl/opmult.comp index c60d67696..2fa804a99 100644 --- a/shaders/glsl/opmult.comp +++ b/shaders/glsl/opmult.comp @@ -21,7 +21,9 @@ void main() //valuesOutput[index] = valuesLhs[index] * valuesRhs[index]; // FOR TESTING - valuesOutput[index] = index; + valuesOutput[index] = 100 + index; + valuesRhs[index] = 100 + index; + valuesLhs[index] = 100 + index; } diff --git a/shaders/glsl/opmult.comp.spv b/shaders/glsl/opmult.comp.spv index e8e5b653a116e88357eb29ef19e12337a6b8b31b..c446f2d71b9ed6e4ac8727662af7dce9f14c63c2 100755 GIT binary patch literal 1404 zcmYk5+iFu$5QcY??zCF%!CLF7HEBGz1}{_uQ78zp5D@GGNHjqMF%gp#FMT$j${WG& z+iOqV6NZ`pUo-z&Yt~M^)?Uf7rEHa7Crh@K)nKx0Eo&69+dJ)bX7fSk+4HAXtY`J2 z(3%Zua^f=ntv{N$u!e8K%WdJ-h1cn=fqz5fR8;%KtE`PCgW*qN4pQsO_+xiG{nQ_y zOwOjC`}5Isa`LuVZxs6T;bb=b@qRJ?zL-O^^Gd;+pY_L!;q3of&i>mjU^%}%V(S0! zbw=LcC#R*}#M^f#(M8`q^3>Nm?~~I~U*he1km#(-{k)HU8y9a98@O|LhqCk1+9^;| zFE}S|y=Ol-UrV?JcBZbYTh)uZRQ2RS&pTE1#OnE8>fI$)_2Q0IJ-N{Hu2ns;dYXDq zh~EDOzRjsUcZ+WO&3CAAzb39V!F-^F%kP3O5Y6pz{UUJ*mw#1a_mDdTTYCxDtiMcj zO@0rRc!Pa%=P$PI3ehvksVjE=an7scOdILgx(~^#x{pd6x{tx?%7^YDxwBW-es$%n K{mToy#Qy=l@=UG( delta 457 zcmYk0%}T>i5QXO^ZEb}tB($~0HccXKv^y7tqToswcGEX-5fBtb}`D94(`tZenJyXKPyM8o8>1%RQ z?_4}c*M_~{KD?FHV%6WY&@x}T`@IAHwx9vm%p`v4;&bJPaD}hL9Q4&XTg#OmQEw-t zN|@-?Izipr3DUxW5)1zve1DFq_HhDG-jMlvyAic^l2+d#xBBGGHd=ei(_Foi=Cp~l z=EmIWk*7K74DgQa1Vaz!;NTS1;Vw3w{M<*4kF!2i8!qY4$^(d@3q82tS$XRFBWnu( D>ysfb diff --git a/src/OpMult.cpp b/src/OpMult.cpp index b32c9524f..873d72797 100644 --- a/src/OpMult.cpp +++ b/src/OpMult.cpp @@ -45,6 +45,21 @@ OpMult::init(std::vector> tensors) this->mTensorRHS = tensors[1]; this->mTensorOutput = tensors[2]; + // The dispatch size is set up based on either explicitly provided template parameters or by default it would take the shape and size of the tensors + if (tX > 0) { + // If at least the x value is provided we use mainly the parameters provided + this->mX = tX; + this->mY = tY > 0 ? tY : 1; + this->mZ = tZ > 0 ? tZ : 1; + } + else { + // TODO: Fully support the full size dispatch using size for the shape + this->mX = this->mTensorLHS->size(); + this->mY = 1; + this->mZ = 1; + } + spdlog::info("Kompute OpMult dispatch size X: {}, Y: {}, Z: {}", this->mX, this->mY, this->mZ); + // TODO: Explore adding a validate function if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() && this->mTensorOutput->isInit())) { @@ -83,21 +98,32 @@ OpMult::record() { SPDLOG_DEBUG("Kompute OpMult record called"); - this->mAlgorithm->recordDispatch(1, 1, 1); + this->mTensorLHS->recordBufferMemoryBarrier( + vk::AccessFlagBits::eHostWrite, + vk::AccessFlagBits::eShaderRead, + vk::PipelineStageFlagBits::eHost, + vk::PipelineStageFlagBits::eComputeShader); + this->mTensorRHS->recordBufferMemoryBarrier( + vk::AccessFlagBits::eHostWrite, + vk::AccessFlagBits::eShaderRead, + vk::PipelineStageFlagBits::eHost, + vk::PipelineStageFlagBits::eComputeShader); + + this->mAlgorithm->recordDispatch(this->mX, this->mY, this->mZ); this->mTensorOutput->recordBufferMemoryBarrier( - vk::AccessFlagBits::eShaderWrite, - vk::AccessFlagBits::eTransferRead, - vk::PipelineStageFlagBits::eComputeShader, - vk::PipelineStageFlagBits::eTransfer); + vk::AccessFlagBits::eShaderWrite, + vk::AccessFlagBits::eTransferRead, + vk::PipelineStageFlagBits::eComputeShader, + vk::PipelineStageFlagBits::eTransfer); this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput); this->mTensorOutput->recordBufferMemoryBarrier( - vk::AccessFlagBits::eTransferWrite, - vk::AccessFlagBits::eHostRead, - vk::PipelineStageFlagBits::eTransfer, - vk::PipelineStageFlagBits::eHost); + vk::AccessFlagBits::eTransferWrite, + vk::AccessFlagBits::eHostRead, + vk::PipelineStageFlagBits::eTransfer, + vk::PipelineStageFlagBits::eHost); } void diff --git a/src/OpMult.hpp b/src/OpMult.hpp index fd1635fdf..51012d937 100644 --- a/src/OpMult.hpp +++ b/src/OpMult.hpp @@ -17,6 +17,7 @@ namespace kp { +template class OpMult : public OpBase { public: @@ -40,6 +41,10 @@ class OpMult : public OpBase std::shared_ptr mTensorRHS; std::shared_ptr mTensorOutput; std::shared_ptr mTensorOutputStaging; + + uint32_t mX; + uint32_t mY; + uint32_t mZ; }; } // End namespace kp