Reformatted

This commit is contained in:
Alejandro Saucedo 2020-11-01 20:56:03 +00:00
parent 3036cbd95f
commit 1f614a87e4
11 changed files with 125 additions and 96 deletions

View file

@ -9,13 +9,14 @@ OpAlgoLhsRhsOut::OpAlgoLhsRhsOut()
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut constructor base");
}
OpAlgoLhsRhsOut::OpAlgoLhsRhsOut(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>> tensors,
KomputeWorkgroup komputeWorkgroup)
OpAlgoLhsRhsOut::OpAlgoLhsRhsOut(
std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>> tensors,
KomputeWorkgroup komputeWorkgroup)
// The inheritance is initialised with the copyOutputData to false given that
// this depencendant class handles the transfer of data via staging buffers in
// this depencendant class handles the transfer of data via staging buffers in
// a granular way.
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup)
{
@ -36,18 +37,19 @@ OpAlgoLhsRhsOut::init()
throw std::runtime_error(
"Kompute OpAlgoLhsRhsOut called with less than 1 tensor");
} else if (this->mTensors.size() > 3) {
SPDLOG_WARN("Kompute OpAlgoLhsRhsOut called with more than 3 this->mTensors");
SPDLOG_WARN(
"Kompute OpAlgoLhsRhsOut called with more than 3 this->mTensors");
}
this->mTensorLHS = this->mTensors[0];
this->mTensorRHS = this->mTensors[1];
this->mTensorOutput = this->mTensors[2];
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
this->mTensorOutput->isInit())) {
throw std::runtime_error(
"Kompute OpAlgoLhsRhsOut all tensor parameters must be initialised. LHS: " +
"Kompute OpAlgoLhsRhsOut all tensor parameters must be initialised. "
"LHS: " +
std::to_string(this->mTensorLHS->isInit()) +
" RHS: " + std::to_string(this->mTensorRHS->isInit()) +
" Output: " + std::to_string(this->mTensorOutput->isInit()));
@ -56,7 +58,8 @@ OpAlgoLhsRhsOut::init()
if (!(this->mTensorLHS->size() == this->mTensorRHS->size() &&
this->mTensorRHS->size() == this->mTensorOutput->size())) {
throw std::runtime_error(
"Kompute OpAlgoLhsRhsOut all tensor parameters must be the same size LHS: " +
"Kompute OpAlgoLhsRhsOut all tensor parameters must be the same size "
"LHS: " +
std::to_string(this->mTensorLHS->size()) +
" RHS: " + std::to_string(this->mTensorRHS->size()) +
" Output: " + std::to_string(this->mTensorOutput->size()));
@ -65,8 +68,7 @@ OpAlgoLhsRhsOut::init()
this->mTensorOutputStaging = std::make_shared<Tensor>(
this->mTensorOutput->data(), Tensor::TensorTypes::eStaging);
this->mTensorOutputStaging->init(
this->mPhysicalDevice, this->mDevice);
this->mTensorOutputStaging->init(this->mPhysicalDevice, this->mDevice);
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut fetching spirv data");
@ -96,10 +98,9 @@ OpAlgoLhsRhsOut::record()
vk::PipelineStageFlagBits::eHost,
vk::PipelineStageFlagBits::eComputeShader);
this->mAlgorithm->recordDispatch(
this->mKomputeWorkgroup.x,
this->mKomputeWorkgroup.y,
this->mKomputeWorkgroup.z);
this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup.x,
this->mKomputeWorkgroup.y,
this->mKomputeWorkgroup.z);
// Barrier to ensure the shader code is executed before buffer read
this->mTensorOutput->recordBufferMemoryBarrier(
@ -110,9 +111,7 @@ OpAlgoLhsRhsOut::record()
vk::PipelineStageFlagBits::eTransfer);
this->mTensorOutputStaging->recordCopyFrom(
this->mCommandBuffer,
this->mTensorOutput,
true);
this->mCommandBuffer, this->mTensorOutput, true);
}
void
@ -126,4 +125,3 @@ OpAlgoLhsRhsOut::postEval()
}
}