Added support for push constants

This commit is contained in:
Alejandro Saucedo 2021-02-28 13:59:01 +00:00
parent 91d3b9a223
commit 7dc1f35206
28 changed files with 3151 additions and 3090 deletions

View file

@ -2,6 +2,8 @@
#include "kompute/Algorithm.hpp"
#include "fmt/ranges.h"
namespace kp {
Algorithm::Algorithm(
@ -9,8 +11,7 @@ Algorithm::Algorithm(
const std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<uint32_t>& spirv,
const Workgroup& workgroup,
const Constants& specializationConstants,
const Constants& pushConstants)
const Constants& specializationConstants)
{
KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
@ -18,7 +19,7 @@ Algorithm::Algorithm(
if (tensors.size() && spirv.size()) {
KP_LOG_INFO("Kompute Algorithm initialising with tensor size: {} and spirv size: {}", tensors.size(), spirv.size());
this->rebuild(tensors, spirv, workgroup, specializationConstants, pushConstants);
this->rebuild(tensors, spirv, workgroup, specializationConstants);
}
else {
KP_LOG_INFO("Kompute Algorithm constructor with empty tensors and or spirv so not rebuilding vulkan components");
@ -37,15 +38,13 @@ Algorithm::rebuild(
const std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<uint32_t>& spirv,
const Workgroup& workgroup,
const Constants& specializationConstants,
const Constants& pushConstants)
const Constants& specializationConstants)
{
KP_LOG_DEBUG("Kompute Algorithm rebuild started");
this->mTensors = tensors;
this->mSpirv = spirv;
this->mSpecializationConstants = specializationConstants;
this->mPushConstants = pushConstants;
this->setWorkgroup(workgroup, this->mTensors.size() ? this->mTensors[0]->size() : 1);
// Descriptor pool is created first so if available then destroy all before rebuild
@ -347,27 +346,43 @@ Algorithm::createPipeline()
}
void
Algorithm::recordDispatch(std::shared_ptr<vk::CommandBuffer> commandBuffer)
Algorithm::bindCore(const vk::CommandBuffer& commandBuffer)
{
KP_LOG_DEBUG("Kompute Algorithm calling record dispatch");
KP_LOG_DEBUG("Kompute Algorithm binding pipeline");
commandBuffer->bindPipeline(vk::PipelineBindPoint::eCompute,
commandBuffer.bindPipeline(vk::PipelineBindPoint::eCompute,
*this->mPipeline);
KP_LOG_DEBUG("Kompute Algorithm pipeline bound");
KP_LOG_DEBUG("Kompute Algorithm binding descriptor sets");
commandBuffer->bindDescriptorSets(vk::PipelineBindPoint::eCompute,
commandBuffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
*this->mPipelineLayout,
0, // First set
*this->mDescriptorSet,
nullptr // Dispatcher
);
}
KP_LOG_DEBUG("Kompute Algorithm descriptor sets bound");
void
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer, const Constants& pushConstants)
{
if (pushConstants.size()) {
KP_LOG_DEBUG("Kompute Algorithm binding push constants size: {}", pushConstants.size());
commandBuffer->dispatch(this->mWorkgroup[0], this->mWorkgroup[1], this->mWorkgroup[2]);
commandBuffer.pushConstants(*this->mPipelineLayout,
vk::ShaderStageFlagBits::eCompute,
0,
pushConstants.size() * sizeof(float),
pushConstants.data());
}
}
KP_LOG_DEBUG("Kompute Algorithm dispatch success");
void
Algorithm::recordDispatch(const vk::CommandBuffer& commandBuffer)
{
KP_LOG_DEBUG("Kompute Algorithm recording dispatch");
commandBuffer.dispatch(this->mWorkgroup[0], this->mWorkgroup[1], this->mWorkgroup[2]);
}
void
@ -405,11 +420,6 @@ Algorithm::getSpecializationConstants() {
return this->mSpecializationConstants;
}
const Constants&
Algorithm::getPushConstants() {
return this->mPushConstants;
}
const std::vector<std::shared_ptr<Tensor>>&
Algorithm::getTensors() {
return this->mTensors;