Updated constants to be set as part of descriptor layout
This commit is contained in:
parent
9edbac4b94
commit
c211c22a78
7 changed files with 106 additions and 24 deletions
|
|
@ -1135,7 +1135,8 @@ class Algorithm
|
|||
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
|
||||
const std::vector<uint32_t>& spirv = {},
|
||||
const Workgroup& workgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
const Constants& specializationConstants = {},
|
||||
const Constants& pushConstants = {});
|
||||
|
||||
/**
|
||||
* Initialiser for the shader data provided to the algorithm as well as
|
||||
|
|
@ -1149,7 +1150,8 @@ class Algorithm
|
|||
void rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
const Constants& specializationConstants = {},
|
||||
const Constants& pushConstants = {});
|
||||
|
||||
/**
|
||||
* Destructor for Algorithm which is responsible for freeing and desroying
|
||||
|
|
@ -1169,15 +1171,16 @@ class Algorithm
|
|||
|
||||
void bindCore(const vk::CommandBuffer& commandBuffer);
|
||||
|
||||
void bindPush(const vk::CommandBuffer& commandBuffer,
|
||||
const Constants& pushConstants);
|
||||
void bindPush(const vk::CommandBuffer& commandBuffer);
|
||||
|
||||
bool isInit();
|
||||
|
||||
void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1);
|
||||
void setPush(const Constants& pushConstants);
|
||||
|
||||
const Workgroup& getWorkgroup();
|
||||
const Constants& getSpecializationConstants();
|
||||
const Constants& getPush();
|
||||
const std::vector<std::shared_ptr<Tensor>>& getTensors();
|
||||
|
||||
void destroy();
|
||||
|
|
@ -1206,6 +1209,7 @@ class Algorithm
|
|||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::vector<uint32_t> mSpirv;
|
||||
Constants mSpecializationConstants;
|
||||
Constants mPushConstants;
|
||||
Workgroup mWorkgroup;
|
||||
|
||||
bool mIsInit;
|
||||
|
|
@ -1801,7 +1805,8 @@ class Manager
|
|||
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
|
||||
const std::vector<uint32_t>& spirv = {},
|
||||
const Workgroup& workgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
const Constants& specializationConstants = {},
|
||||
const Constants& pushConstants = {});
|
||||
|
||||
void destroy();
|
||||
void clear();
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
|
|||
const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup,
|
||||
const Constants& specializationConstants)
|
||||
const Constants& specializationConstants,
|
||||
const Constants& pushConstants)
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
|
||||
|
||||
|
|
@ -19,7 +20,7 @@ Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
|
|||
"spirv size: {}",
|
||||
tensors.size(),
|
||||
spirv.size());
|
||||
this->rebuild(tensors, spirv, workgroup, specializationConstants);
|
||||
this->rebuild(tensors, spirv, workgroup, specializationConstants, pushConstants);
|
||||
} else {
|
||||
KP_LOG_INFO("Kompute Algorithm constructor with empty tensors and or "
|
||||
"spirv so not rebuilding vulkan components");
|
||||
|
|
@ -37,13 +38,15 @@ void
|
|||
Algorithm::rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup,
|
||||
const Constants& specializationConstants)
|
||||
const Constants& specializationConstants,
|
||||
const Constants& pushConstants)
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute Algorithm rebuild started");
|
||||
|
||||
this->mTensors = tensors;
|
||||
this->mSpirv = spirv;
|
||||
this->mSpecializationConstants = specializationConstants;
|
||||
this->mPushConstants = pushConstants;
|
||||
this->setWorkgroup(workgroup,
|
||||
this->mTensors.size() ? this->mTensors[0]->size() : 1);
|
||||
|
||||
|
|
@ -273,6 +276,16 @@ Algorithm::createPipeline()
|
|||
1, // Set layout count
|
||||
this->mDescriptorSetLayout.get());
|
||||
|
||||
vk::PushConstantRange pushConstantRange;
|
||||
if (this->mPushConstants.size()) {
|
||||
pushConstantRange.setStageFlags(vk::ShaderStageFlagBits::eCompute);
|
||||
pushConstantRange.setOffset(0);
|
||||
pushConstantRange.setSize(sizeof(float) * this->mPushConstants.size());
|
||||
|
||||
pipelineLayoutInfo.setPushConstantRangeCount(1);
|
||||
pipelineLayoutInfo.setPPushConstantRanges(&pushConstantRange);
|
||||
}
|
||||
|
||||
this->mPipelineLayout = std::make_shared<vk::PipelineLayout>();
|
||||
this->mDevice->createPipelineLayout(
|
||||
&pipelineLayoutInfo, nullptr, this->mPipelineLayout.get());
|
||||
|
|
@ -364,18 +377,17 @@ Algorithm::bindCore(const vk::CommandBuffer& commandBuffer)
|
|||
}
|
||||
|
||||
void
|
||||
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer,
|
||||
const Constants& pushConstants)
|
||||
Algorithm::bindPush(const vk::CommandBuffer& commandBuffer)
|
||||
{
|
||||
if (pushConstants.size()) {
|
||||
if (this->mPushConstants.size()) {
|
||||
KP_LOG_DEBUG("Kompute Algorithm binding push constants size: {}",
|
||||
pushConstants.size());
|
||||
this->mPushConstants.size());
|
||||
|
||||
commandBuffer.pushConstants(*this->mPipelineLayout,
|
||||
vk::ShaderStageFlagBits::eCompute,
|
||||
0,
|
||||
pushConstants.size() * sizeof(float),
|
||||
pushConstants.data());
|
||||
this->mPushConstants.size() * sizeof(float),
|
||||
this->mPushConstants.data());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -412,6 +424,18 @@ Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
|
|||
this->mWorkgroup[2]);
|
||||
}
|
||||
|
||||
void
|
||||
Algorithm::setPush(const Constants& pushConstants) {
|
||||
|
||||
if (pushConstants.size() != this->mPushConstants.size()) {
|
||||
throw std::runtime_error(fmt::format("Kompute Algorithm push "
|
||||
"constant provided is size {} but expected size {}",
|
||||
pushConstants.size(), this->mPushConstants.size()));
|
||||
}
|
||||
|
||||
this->mPushConstants = pushConstants;
|
||||
}
|
||||
|
||||
const Workgroup&
|
||||
Algorithm::getWorkgroup()
|
||||
{
|
||||
|
|
@ -424,6 +448,11 @@ Algorithm::getSpecializationConstants()
|
|||
return this->mSpecializationConstants;
|
||||
}
|
||||
|
||||
const Constants&
|
||||
Algorithm::getPush() {
|
||||
return this->mPushConstants;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<Tensor>>&
|
||||
Algorithm::getTensors()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -361,13 +361,14 @@ std::shared_ptr<Algorithm>
|
|||
Manager::algorithm(const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup,
|
||||
const Constants& specializationConstants)
|
||||
const Constants& specializationConstants,
|
||||
const Constants& pushConstants)
|
||||
{
|
||||
|
||||
KP_LOG_DEBUG("Kompute Manager algorithm creation triggered");
|
||||
|
||||
std::shared_ptr<Algorithm> algorithm{ new kp::Algorithm(
|
||||
this->mDevice, tensors, spirv, workgroup, specializationConstants) };
|
||||
this->mDevice, tensors, spirv, workgroup, specializationConstants, pushConstants) };
|
||||
|
||||
if (this->mManageResources) {
|
||||
this->mManagedAlgorithms.push_back(algorithm);
|
||||
|
|
|
|||
|
|
@ -34,8 +34,12 @@ OpAlgoDispatch::record(const vk::CommandBuffer& commandBuffer)
|
|||
vk::PipelineStageFlagBits::eComputeShader);
|
||||
}
|
||||
|
||||
if (this->mPushConstants.size()) {
|
||||
this->mAlgorithm->setPush(this->mPushConstants);
|
||||
}
|
||||
|
||||
this->mAlgorithm->bindCore(commandBuffer);
|
||||
this->mAlgorithm->bindPush(commandBuffer, this->mPushConstants);
|
||||
this->mAlgorithm->bindPush(commandBuffer);
|
||||
this->mAlgorithm->recordDispatch(commandBuffer);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ class Algorithm
|
|||
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
|
||||
const std::vector<uint32_t>& spirv = {},
|
||||
const Workgroup& workgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
const Constants& specializationConstants = {},
|
||||
const Constants& pushConstants = {});
|
||||
|
||||
/**
|
||||
* Initialiser for the shader data provided to the algorithm as well as
|
||||
|
|
@ -38,7 +39,8 @@ class Algorithm
|
|||
void rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<uint32_t>& spirv,
|
||||
const Workgroup& workgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
const Constants& specializationConstants = {},
|
||||
const Constants& pushConstants = {});
|
||||
|
||||
/**
|
||||
* Destructor for Algorithm which is responsible for freeing and desroying
|
||||
|
|
@ -58,15 +60,16 @@ class Algorithm
|
|||
|
||||
void bindCore(const vk::CommandBuffer& commandBuffer);
|
||||
|
||||
void bindPush(const vk::CommandBuffer& commandBuffer,
|
||||
const Constants& pushConstants);
|
||||
void bindPush(const vk::CommandBuffer& commandBuffer);
|
||||
|
||||
bool isInit();
|
||||
|
||||
void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1);
|
||||
void setPush(const Constants& pushConstants);
|
||||
|
||||
const Workgroup& getWorkgroup();
|
||||
const Constants& getSpecializationConstants();
|
||||
const Constants& getPush();
|
||||
const std::vector<std::shared_ptr<Tensor>>& getTensors();
|
||||
|
||||
void destroy();
|
||||
|
|
@ -95,6 +98,7 @@ class Algorithm
|
|||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::vector<uint32_t> mSpirv;
|
||||
Constants mSpecializationConstants;
|
||||
Constants mPushConstants;
|
||||
Workgroup mWorkgroup;
|
||||
|
||||
bool mIsInit;
|
||||
|
|
|
|||
|
|
@ -87,7 +87,8 @@ class Manager
|
|||
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
|
||||
const std::vector<uint32_t>& spirv = {},
|
||||
const Workgroup& workgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
const Constants& specializationConstants = {},
|
||||
const Constants& pushConstants = {});
|
||||
|
||||
void destroy();
|
||||
void clear();
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
#include "fmt/ranges.h"
|
||||
|
||||
TEST(TestPushConstants, TestTwoConstants)
|
||||
TEST(TestPushConstants, TestConstants)
|
||||
{
|
||||
{
|
||||
std::string shader(R"(
|
||||
|
|
@ -32,7 +32,7 @@ TEST(TestPushConstants, TestTwoConstants)
|
|||
std::shared_ptr<kp::Tensor> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
|
||||
std::shared_ptr<kp::Algorithm> algo =
|
||||
mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }));
|
||||
mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0, 0.0, 0.0 });
|
||||
|
||||
sq = mgr.sequence()
|
||||
->record<kp::OpTensorSyncDevice>({ tensor })
|
||||
|
|
@ -47,3 +47,41 @@ TEST(TestPushConstants, TestTwoConstants)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TestPushConstants, TestConstantsWrongSize)
|
||||
{
|
||||
{
|
||||
std::string shader(R"(
|
||||
#version 450
|
||||
layout(push_constant) uniform PushConstants {
|
||||
float x;
|
||||
float y;
|
||||
float z;
|
||||
} pcs;
|
||||
layout (local_size_x = 1) in;
|
||||
layout(set = 0, binding = 0) buffer a { float pa[]; };
|
||||
void main() {
|
||||
pa[0] += pcs.x;
|
||||
pa[1] += pcs.y;
|
||||
pa[2] += pcs.z;
|
||||
})");
|
||||
|
||||
std::vector<uint32_t> spirv = kp::Shader::compile_source(shader);
|
||||
|
||||
std::shared_ptr<kp::Sequence> sq = nullptr;
|
||||
|
||||
{
|
||||
kp::Manager mgr;
|
||||
|
||||
std::shared_ptr<kp::Tensor> tensor = mgr.tensor({ 0, 0, 0 });
|
||||
|
||||
std::shared_ptr<kp::Algorithm> algo =
|
||||
mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0 });
|
||||
|
||||
sq = mgr.sequence()
|
||||
->record<kp::OpTensorSyncDevice>({ tensor });
|
||||
|
||||
EXPECT_THROW(sq->record<kp::OpAlgoDispatch>(algo, kp::Constants{ 0.1, 0.2, 0.3 }), std::runtime_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue