Updated constants to be set as part of descriptor layout
This commit is contained in:
parent
9edbac4b94
commit
c211c22a78
7 changed files with 106 additions and 24 deletions
|
|
@ -8,7 +8,8 @@ Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
|
|||
const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup,
|
||||
const Constants& specializationConstants)
|
||||
const Constants& specializationConstants,
|
||||
const Constants& pushConstants)
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
|
|||
"spirv size: {}",
|
||||
tensors.size(),
|
||||
spirv.size());
|
||||
this->rebuild(tensors, spirv, workgroup, specializationConstants);
|
||||
this->rebuild(tensors, spirv, workgroup, specializationConstants, pushConstants);
|
||||
} else {
|
||||
KP_LOG_INFO("Kompute Algorithm constructor with empty tensors and or "
|
||||
"spirv so not rebuilding vulkan components");
|
||||
|
|
@ -37,13 +38,15 @@ void
|
|||
Algorithm::rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup,
|
||||
const Constants& specializationConstants)
|
||||
const Constants& specializationConstants,
|
||||
const Constants& pushConstants)
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute Algorithm rebuild started");
|
||||
|
||||
this->mTensors = tensors;
|
||||
this->mSpirv = spirv;
|
||||
this->mSpecializationConstants = specializationConstants;
|
||||
this->mPushConstants = pushConstants;
|
||||
this->setWorkgroup(workgroup,
|
||||
this->mTensors.size() ? this->mTensors[0]->size() : 1);
|
||||
|
||||
|
|
@ -273,6 +276,16 @@ Algorithm::createPipeline()
|
|||
1, // Set layout count
|
||||
this->mDescriptorSetLayout.get());
|
||||
|
||||
vk::PushConstantRange pushConstantRange;
|
||||
if (this->mPushConstants.size()) {
|
||||
pushConstantRange.setStageFlags(vk::ShaderStageFlagBits::eCompute);
|
||||
pushConstantRange.setOffset(0);
|
||||
pushConstantRange.setSize(sizeof(float) * this->mPushConstants.size());
|
||||
|
||||
pipelineLayoutInfo.setPushConstantRangeCount(1);
|
||||
pipelineLayoutInfo.setPPushConstantRanges(&pushConstantRange);
|
||||
}
|
||||
|
||||
this->mPipelineLayout = std::make_shared<vk::PipelineLayout>();
|
||||
this->mDevice->createPipelineLayout(
|
||||
&pipelineLayoutInfo, nullptr, this->mPipelineLayout.get());
|
||||
|
|
@ -364,18 +377,17 @@ Algorithm::bindCore(const vk::CommandBuffer& commandBuffer)
|
|||
}
|
||||
|
||||
void
|
||||
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer,
|
||||
const Constants& pushConstants)
|
||||
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer)
|
||||
{
|
||||
if (pushConstants.size()) {
|
||||
if (this->mPushConstants.size()) {
|
||||
KP_LOG_DEBUG("Kompute Algorithm binding push constants size: {}",
|
||||
pushConstants.size());
|
||||
this->mPushConstants.size());
|
||||
|
||||
commandBuffer.pushConstants(*this->mPipelineLayout,
|
||||
vk::ShaderStageFlagBits::eCompute,
|
||||
0,
|
||||
pushConstants.size() * sizeof(float),
|
||||
pushConstants.data());
|
||||
this->mPushConstants.size() * sizeof(float),
|
||||
this->mPushConstants.data());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -412,6 +424,18 @@ Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
|
|||
this->mWorkgroup[2]);
|
||||
}
|
||||
|
||||
void
|
||||
Algorithm::setPush(const Constants& pushConstants) {
|
||||
|
||||
if (pushConstants.size() != this->mPushConstants.size()) {
|
||||
throw std::runtime_error(fmt::format("Kompute Algorithm push "
|
||||
"constant provided is size {} but expected size {}",
|
||||
pushConstants.size(), this->mPushConstants.size()));
|
||||
}
|
||||
|
||||
this->mPushConstants = pushConstants;
|
||||
}
|
||||
|
||||
const Workgroup&
|
||||
Algorithm::getWorkgroup()
|
||||
{
|
||||
|
|
@ -424,6 +448,11 @@ Algorithm::getSpecializationConstants()
|
|||
return this->mSpecializationConstants;
|
||||
}
|
||||
|
||||
const Constants&
|
||||
Algorithm::getPush() {
|
||||
return this->mPushConstants;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>>&
|
||||
Algorithm::getTensors()
|
||||
{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue