Added template parameters to opmult class

This commit is contained in:
Alejandro Saucedo 2020-08-22 13:33:21 +01:00
parent c92425dd87
commit d74a999e12
4 changed files with 43 additions and 10 deletions

View file

@ -45,6 +45,21 @@ OpMult::init(std::vector<std::shared_ptr<Tensor>> tensors)
this->mTensorRHS = tensors[1];
this->mTensorOutput = tensors[2];
// The dispatch size is taken from the explicitly provided template parameters; if none are provided, it defaults to the size of the LHS tensor along X
if (tX > 0) {
// If at least the X value is provided, use the provided parameters; Y and Z default to 1 when they are not set
this->mX = tX;
this->mY = tY > 0 ? tY : 1;
this->mZ = tZ > 0 ? tZ : 1;
}
else {
// TODO: Fully support the full size dispatch using size for the shape
this->mX = this->mTensorLHS->size();
this->mY = 1;
this->mZ = 1;
}
spdlog::info("Kompute OpMult dispatch size X: {}, Y: {}, Z: {}", this->mX, this->mY, this->mZ);
// TODO: Explore adding a validate function
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
this->mTensorOutput->isInit())) {
@ -83,21 +98,32 @@ OpMult::record()
{
SPDLOG_DEBUG("Kompute OpMult record called");
this->mAlgorithm->recordDispatch(1, 1, 1);
this->mTensorLHS->recordBufferMemoryBarrier(
vk::AccessFlagBits::eHostWrite,
vk::AccessFlagBits::eShaderRead,
vk::PipelineStageFlagBits::eHost,
vk::PipelineStageFlagBits::eComputeShader);
this->mTensorRHS->recordBufferMemoryBarrier(
vk::AccessFlagBits::eHostWrite,
vk::AccessFlagBits::eShaderRead,
vk::PipelineStageFlagBits::eHost,
vk::PipelineStageFlagBits::eComputeShader);
this->mAlgorithm->recordDispatch(this->mX, this->mY, this->mZ);
this->mTensorOutput->recordBufferMemoryBarrier(
vk::AccessFlagBits::eShaderWrite,
vk::AccessFlagBits::eTransferRead,
vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eTransfer);
vk::AccessFlagBits::eShaderWrite,
vk::AccessFlagBits::eTransferRead,
vk::PipelineStageFlagBits::eComputeShader,
vk::PipelineStageFlagBits::eTransfer);
this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput);
this->mTensorOutput->recordBufferMemoryBarrier(
vk::AccessFlagBits::eTransferWrite,
vk::AccessFlagBits::eHostRead,
vk::PipelineStageFlagBits::eTransfer,
vk::PipelineStageFlagBits::eHost);
vk::AccessFlagBits::eTransferWrite,
vk::AccessFlagBits::eHostRead,
vk::PipelineStageFlagBits::eTransfer,
vk::PipelineStageFlagBits::eHost);
}
void