Updated constants to be set as part of descriptor layout

This commit is contained in:
Alejandro Saucedo 2021-03-04 08:10:28 +00:00
parent 9edbac4b94
commit c211c22a78
7 changed files with 106 additions and 24 deletions

View file

@ -8,7 +8,8 @@ Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
const std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<uint32_t>& spirv,
const Workgroup& workgroup,
const Constants& specializationConstants)
const Constants& specializationConstants,
const Constants& pushConstants)
{
KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
@ -19,7 +20,7 @@ Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
"spirv size: {}",
tensors.size(),
spirv.size());
this->rebuild(tensors, spirv, workgroup, specializationConstants);
this->rebuild(tensors, spirv, workgroup, specializationConstants, pushConstants);
} else {
KP_LOG_INFO("Kompute Algorithm constructor with empty tensors and or "
"spirv so not rebuilding vulkan components");
@ -37,13 +38,15 @@ void
Algorithm::rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<uint32_t>& spirv,
const Workgroup& workgroup,
const Constants& specializationConstants)
const Constants& specializationConstants,
const Constants& pushConstants)
{
KP_LOG_DEBUG("Kompute Algorithm rebuild started");
this->mTensors = tensors;
this->mSpirv = spirv;
this->mSpecializationConstants = specializationConstants;
this->mPushConstants = pushConstants;
this->setWorkgroup(workgroup,
this->mTensors.size() ? this->mTensors[0]->size() : 1);
@ -273,6 +276,16 @@ Algorithm::createPipeline()
1, // Set layout count
this->mDescriptorSetLayout.get());
vk::PushConstantRange pushConstantRange;
if (this->mPushConstants.size()) {
pushConstantRange.setStageFlags(vk::ShaderStageFlagBits::eCompute);
pushConstantRange.setOffset(0);
pushConstantRange.setSize(sizeof(float) * this->mPushConstants.size());
pipelineLayoutInfo.setPushConstantRangeCount(1);
pipelineLayoutInfo.setPPushConstantRanges(&pushConstantRange);
}
this->mPipelineLayout = std::make_shared<vk::PipelineLayout>();
this->mDevice->createPipelineLayout(
&pipelineLayoutInfo, nullptr, this->mPipelineLayout.get());
@ -364,18 +377,17 @@ Algorithm::bindCore(const vk::CommandBuffer& commandBuffer)
}
void
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer,
const Constants& pushConstants)
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer)
{
if (pushConstants.size()) {
if (this->mPushConstants.size()) {
KP_LOG_DEBUG("Kompute Algorithm binding push constants size: {}",
pushConstants.size());
this->mPushConstants.size());
commandBuffer.pushConstants(*this->mPipelineLayout,
vk::ShaderStageFlagBits::eCompute,
0,
pushConstants.size() * sizeof(float),
pushConstants.data());
this->mPushConstants.size() * sizeof(float),
this->mPushConstants.data());
}
}
@ -412,6 +424,18 @@ Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
this->mWorkgroup[2]);
}
void
Algorithm::setPush(const Constants& pushConstants) {
if (pushConstants.size() != this->mPushConstants.size()) {
throw std::runtime_error(fmt::format("Kompute Algorithm push "
"constant provided is size {} but expected size {}",
pushConstants.size(), this->mPushConstants.size()));
}
this->mPushConstants = pushConstants;
}
const Workgroup&
Algorithm::getWorkgroup()
{
@ -424,6 +448,11 @@ Algorithm::getSpecializationConstants()
return this->mSpecializationConstants;
}
const Constants&
Algorithm::getPush() {
return this->mPushConstants;
}
const std::vector<std::shared_ptr<Tensor>>&
Algorithm::getTensors()
{