Added template parameters to opmult class
This commit is contained in:
parent
c92425dd87
commit
d74a999e12
4 changed files with 43 additions and 10 deletions
|
|
@ -21,7 +21,9 @@ void main()
|
|||
|
||||
//valuesOutput[index] = valuesLhs[index] * valuesRhs[index];
|
||||
// FOR TESTING
|
||||
valuesOutput[index] = index;
|
||||
valuesOutput[index] = 100 + index;
|
||||
valuesRhs[index] = 100 + index;
|
||||
valuesLhs[index] = 100 + index;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -45,6 +45,21 @@ OpMult::init(std::vector<std::shared_ptr<Tensor>> tensors)
|
|||
this->mTensorRHS = tensors[1];
|
||||
this->mTensorOutput = tensors[2];
|
||||
|
||||
// The dispatch size is set up based on either explicitly provided template parameters or by default it would take the shape and size of the tensors
|
||||
if (tX > 0) {
|
||||
// If at least the x value is provided we use mainly the parameters provided
|
||||
this->mX = tX;
|
||||
this->mY = tY > 0 ? tY : 1;
|
||||
this->mZ = tZ > 0 ? tZ : 1;
|
||||
}
|
||||
else {
|
||||
// TODO: Fully support the full size dispatch using size for the shape
|
||||
this->mX = this->mTensorLHS->size();
|
||||
this->mY = 1;
|
||||
this->mZ = 1;
|
||||
}
|
||||
spdlog::info("Kompute OpMult dispatch size X: {}, Y: {}, Z: {}", this->mX, this->mY, this->mZ);
|
||||
|
||||
// TODO: Explore adding a validate function
|
||||
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
|
||||
this->mTensorOutput->isInit())) {
|
||||
|
|
@ -83,21 +98,32 @@ OpMult::record()
|
|||
{
|
||||
SPDLOG_DEBUG("Kompute OpMult record called");
|
||||
|
||||
this->mAlgorithm->recordDispatch(1, 1, 1);
|
||||
this->mTensorLHS->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eHostWrite,
|
||||
vk::AccessFlagBits::eShaderRead,
|
||||
vk::PipelineStageFlagBits::eHost,
|
||||
vk::PipelineStageFlagBits::eComputeShader);
|
||||
this->mTensorRHS->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eHostWrite,
|
||||
vk::AccessFlagBits::eShaderRead,
|
||||
vk::PipelineStageFlagBits::eHost,
|
||||
vk::PipelineStageFlagBits::eComputeShader);
|
||||
|
||||
this->mAlgorithm->recordDispatch(this->mX, this->mY, this->mZ);
|
||||
|
||||
this->mTensorOutput->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eShaderWrite,
|
||||
vk::AccessFlagBits::eTransferRead,
|
||||
vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eTransfer);
|
||||
vk::AccessFlagBits::eShaderWrite,
|
||||
vk::AccessFlagBits::eTransferRead,
|
||||
vk::PipelineStageFlagBits::eComputeShader,
|
||||
vk::PipelineStageFlagBits::eTransfer);
|
||||
|
||||
this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput);
|
||||
|
||||
this->mTensorOutput->recordBufferMemoryBarrier(
|
||||
vk::AccessFlagBits::eTransferWrite,
|
||||
vk::AccessFlagBits::eHostRead,
|
||||
vk::PipelineStageFlagBits::eTransfer,
|
||||
vk::PipelineStageFlagBits::eHost);
|
||||
vk::AccessFlagBits::eTransferWrite,
|
||||
vk::AccessFlagBits::eHostRead,
|
||||
vk::PipelineStageFlagBits::eTransfer,
|
||||
vk::PipelineStageFlagBits::eHost);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
namespace kp {
|
||||
|
||||
template<uint32_t tX = 0, uint32_t tY = 0, uint32_t tZ = 0>
|
||||
class OpMult : public OpBase
|
||||
{
|
||||
public:
|
||||
|
|
@ -40,6 +41,10 @@ class OpMult : public OpBase
|
|||
std::shared_ptr<Tensor> mTensorRHS;
|
||||
std::shared_ptr<Tensor> mTensorOutput;
|
||||
std::shared_ptr<Tensor> mTensorOutputStaging;
|
||||
|
||||
uint32_t mX;
|
||||
uint32_t mY;
|
||||
uint32_t mZ;
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue