Reformatted
This commit is contained in:
parent
3036cbd95f
commit
1f614a87e4
11 changed files with 125 additions and 96 deletions
|
|
@ -9,13 +9,14 @@ OpAlgoLhsRhsOut::OpAlgoLhsRhsOut()
|
|||
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut constructor base");
|
||||
}
|
||||
|
||||
OpAlgoLhsRhsOut::OpAlgoLhsRhsOut(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>> tensors,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
OpAlgoLhsRhsOut::OpAlgoLhsRhsOut(
|
||||
std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>> tensors,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
// The inheritance is initialised with the copyOutputData to false given that
|
||||
// this dependent class handles the transfer of data via staging buffers in
|
||||
// this dependent class handles the transfer of data via staging buffers in
|
||||
// a granular way.
|
||||
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup)
|
||||
{
|
||||
|
|
@ -36,18 +37,19 @@ OpAlgoLhsRhsOut::init()
|
|||
throw std::runtime_error(
|
||||
"Kompute OpAlgoLhsRhsOut called with less than 1 tensor");
|
||||
} else if (this->mTensors.size() > 3) {
|
||||
SPDLOG_WARN("Kompute OpAlgoLhsRhsOut called with more than 3 this->mTensors");
|
||||
SPDLOG_WARN(
|
||||
"Kompute OpAlgoLhsRhsOut called with more than 3 this->mTensors");
|
||||
}
|
||||
|
||||
this->mTensorLHS = this->mTensors[0];
|
||||
this->mTensorRHS = this->mTensors[1];
|
||||
this->mTensorOutput = this->mTensors[2];
|
||||
|
||||
|
||||
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
|
||||
this->mTensorOutput->isInit())) {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpAlgoLhsRhsOut all tensor parameters must be initialised. LHS: " +
|
||||
"Kompute OpAlgoLhsRhsOut all tensor parameters must be initialised. "
|
||||
"LHS: " +
|
||||
std::to_string(this->mTensorLHS->isInit()) +
|
||||
" RHS: " + std::to_string(this->mTensorRHS->isInit()) +
|
||||
" Output: " + std::to_string(this->mTensorOutput->isInit()));
|
||||
|
|
@ -56,7 +58,8 @@ OpAlgoLhsRhsOut::init()
|
|||
if (!(this->mTensorLHS->size() == this->mTensorRHS->size() &&
|
||||
this->mTensorRHS->size() == this->mTensorOutput->size())) {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpAlgoLhsRhsOut all tensor parameters must be the same size LHS: " +
|
||||
"Kompute OpAlgoLhsRhsOut all tensor parameters must be the same size "
|
||||
"LHS: " +
|
||||
std::to_string(this->mTensorLHS->size()) +
|
||||
" RHS: " + std::to_string(this->mTensorRHS->size()) +
|
||||
" Output: " + std::to_string(this->mTensorOutput->size()));
|
||||
|
|
@ -65,8 +68,7 @@ OpAlgoLhsRhsOut::init()
|
|||
this->mTensorOutputStaging = std::make_shared<Tensor>(
|
||||
this->mTensorOutput->data(), Tensor::TensorTypes::eStaging);
|
||||
|
||||
this->mTensorOutputStaging->init(
|
||||
this->mPhysicalDevice, this->mDevice);
|
||||
this->mTensorOutputStaging->init(this->mPhysicalDevice, this->mDevice);
|
||||
|
||||
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut fetching spirv data");
|
||||
|
||||
|
|
@ -96,10 +98,9 @@ OpAlgoLhsRhsOut::record()
|
|||
vk::PipelineStageFlagBits::eHost,
|
||||
vk::PipelineStageFlagBits::eComputeShader);
|
||||
|
||||
this->mAlgorithm->recordDispatch(
|
||||
this->mKomputeWorkgroup.x,
|
||||
this->mKomputeWorkgroup.y,
|
||||
this->mKomputeWorkgroup.z);
|
||||
this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup.x,
|
||||
this->mKomputeWorkgroup.y,
|
||||
this->mKomputeWorkgroup.z);
|
||||
|
||||
// Barrier to ensure the shader code is executed before buffer read
|
||||
this->mTensorOutput->recordBufferMemoryBarrier(
|
||||
|
|
@ -110,9 +111,7 @@ OpAlgoLhsRhsOut::record()
|
|||
vk::PipelineStageFlagBits::eTransfer);
|
||||
|
||||
this->mTensorOutputStaging->recordCopyFrom(
|
||||
this->mCommandBuffer,
|
||||
this->mTensorOutput,
|
||||
true);
|
||||
this->mCommandBuffer, this->mTensorOutput, true);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -126,4 +125,3 @@ OpAlgoLhsRhsOut::postEval()
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue