diff --git a/single_include/kompute/Kompute.hpp b/single_include/kompute/Kompute.hpp index 7b67e2024..661824804 100755 --- a/single_include/kompute/Kompute.hpp +++ b/single_include/kompute/Kompute.hpp @@ -1135,7 +1135,8 @@ class Algorithm const std::vector>& tensors = {}, const std::vector& spirv = {}, const Workgroup& workgroup = {}, - const Constants& specializationConstants = {}); + const Constants& specializationConstants = {}, + const Constants& pushConstants = {}); /** * Initialiser for the shader data provided to the algorithm as well as @@ -1149,7 +1150,8 @@ class Algorithm void rebuild(const std::vector>& tensors, const std::vector& spirv, const Workgroup& workgroup = {}, - const Constants& specializationConstants = {}); + const Constants& specializationConstants = {}, + const Constants& pushConstants = {}); /** * Destructor for Algorithm which is responsible for freeing and desroying @@ -1169,15 +1171,16 @@ class Algorithm void bindCore(const vk::CommandBuffer& commandBuffer); - void bindPush(const vk::CommandBuffer& commandBuffer, - const Constants& pushConstants); + void bindPush(const vk::CommandBuffer& commandBuffer); bool isInit(); void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1); + void setPush(const Constants& pushConstants); const Workgroup& getWorkgroup(); const Constants& getSpecializationConstants(); + const Constants& getPush(); const std::vector>& getTensors(); void destroy(); @@ -1206,6 +1209,7 @@ class Algorithm // -------------- ALWAYS OWNED RESOURCES std::vector mSpirv; Constants mSpecializationConstants; + Constants mPushConstants; Workgroup mWorkgroup; bool mIsInit; @@ -1801,7 +1805,8 @@ class Manager const std::vector>& tensors = {}, const std::vector& spirv = {}, const Workgroup& workgroup = {}, - const Constants& specializationConstants = {}); + const Constants& specializationConstants = {}, + const Constants& pushConstants = {}); void destroy(); void clear(); diff --git a/src/Algorithm.cpp b/src/Algorithm.cpp index c58c5a228..50ec9ad28 100644 --- a/src/Algorithm.cpp +++ b/src/Algorithm.cpp @@ -8,7 +8,8 @@ Algorithm::Algorithm(std::shared_ptr device, const std::vector>& tensors, const std::vector& spirv, const Workgroup& workgroup, - const Constants& specializationConstants) + const Constants& specializationConstants, + const Constants& pushConstants) { KP_LOG_DEBUG("Kompute Algorithm Constructor with device"); @@ -19,7 +20,7 @@ Algorithm::Algorithm(std::shared_ptr device, "spirv size: {}", tensors.size(), spirv.size()); - this->rebuild(tensors, spirv, workgroup, specializationConstants); + this->rebuild(tensors, spirv, workgroup, specializationConstants, pushConstants); } else { KP_LOG_INFO("Kompute Algorithm constructor with empty tensors and or " "spirv so not rebuilding vulkan components"); @@ -37,13 +38,15 @@ void Algorithm::rebuild(const std::vector>& tensors, const std::vector& spirv, const Workgroup& workgroup, - const Constants& specializationConstants) + const Constants& specializationConstants, + const Constants& pushConstants) { KP_LOG_DEBUG("Kompute Algorithm rebuild started"); this->mTensors = tensors; this->mSpirv = spirv; this->mSpecializationConstants = specializationConstants; + this->mPushConstants = pushConstants; this->setWorkgroup(workgroup, this->mTensors.size() ? this->mTensors[0]->size() : 1); @@ -273,6 +276,16 @@ Algorithm::createPipeline() 1, // Set layout count this->mDescriptorSetLayout.get()); + vk::PushConstantRange pushConstantRange; + if (this->mPushConstants.size()) { + pushConstantRange.setStageFlags(vk::ShaderStageFlagBits::eCompute); + pushConstantRange.setOffset(0); + pushConstantRange.setSize(sizeof(float) * this->mPushConstants.size()); + + pipelineLayoutInfo.setPushConstantRangeCount(1); + pipelineLayoutInfo.setPPushConstantRanges(&pushConstantRange); + } + this->mPipelineLayout = std::make_shared(); this->mDevice->createPipelineLayout( &pipelineLayoutInfo, nullptr, this->mPipelineLayout.get()); @@ -364,18 +377,17 @@ Algorithm::bindCore(const vk::CommandBuffer& commandBuffer) } void -Algorithm::bindPush(const vk::CommandBuffer& commandBuffer, - const Constants& pushConstants) +Algorithm::bindPush(const vk::CommandBuffer& commandBuffer) { - if (pushConstants.size()) { + if (this->mPushConstants.size()) { KP_LOG_DEBUG("Kompute Algorithm binding push constants size: {}", - pushConstants.size()); + this->mPushConstants.size()); commandBuffer.pushConstants(*this->mPipelineLayout, vk::ShaderStageFlagBits::eCompute, 0, - pushConstants.size() * sizeof(float), - pushConstants.data()); + this->mPushConstants.size() * sizeof(float), + this->mPushConstants.data()); } } @@ -412,6 +424,18 @@ Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize) this->mWorkgroup[2]); } +void +Algorithm::setPush(const Constants& pushConstants) { + + if (pushConstants.size() != this->mPushConstants.size()) { + throw std::runtime_error(fmt::format("Kompute Algorithm push " + "constant provided is size {} but expected size {}", + pushConstants.size(), this->mPushConstants.size())); + } + + this->mPushConstants = pushConstants; +} + const Workgroup& Algorithm::getWorkgroup() { @@ -424,6 +448,11 @@ Algorithm::getSpecializationConstants() return this->mSpecializationConstants; } +const Constants& +Algorithm::getPush() { + return this->mPushConstants; +} + const std::vector>& Algorithm::getTensors() { diff --git a/src/Manager.cpp b/src/Manager.cpp index 38f67de0d..f32cda43e 100644 --- a/src/Manager.cpp +++ b/src/Manager.cpp @@ -361,13 +361,14 @@ std::shared_ptr Manager::algorithm(const std::vector>& tensors, const std::vector& spirv, const Workgroup& workgroup, - const Constants& specializationConstants) + const Constants& specializationConstants, + const Constants& pushConstants) { KP_LOG_DEBUG("Kompute Manager algorithm creation triggered"); std::shared_ptr algorithm{ new kp::Algorithm( - this->mDevice, tensors, spirv, workgroup, specializationConstants) }; + this->mDevice, tensors, spirv, workgroup, specializationConstants, pushConstants) }; if (this->mManageResources) { this->mManagedAlgorithms.push_back(algorithm); diff --git a/src/OpAlgoDispatch.cpp b/src/OpAlgoDispatch.cpp index 4a30751fb..3aef85e4f 100644 --- a/src/OpAlgoDispatch.cpp +++ b/src/OpAlgoDispatch.cpp @@ -34,8 +34,12 @@ OpAlgoDispatch::record(const vk::CommandBuffer& commandBuffer) vk::PipelineStageFlagBits::eComputeShader); } + if (this->mPushConstants.size()) { + this->mAlgorithm->setPush(this->mPushConstants); + } + this->mAlgorithm->bindCore(commandBuffer); - this->mAlgorithm->bindPush(commandBuffer, this->mPushConstants); + this->mAlgorithm->bindPush(commandBuffer); this->mAlgorithm->recordDispatch(commandBuffer); } diff --git a/src/include/kompute/Algorithm.hpp b/src/include/kompute/Algorithm.hpp index 32e5d9bdf..cabc673b6 100644 --- a/src/include/kompute/Algorithm.hpp +++ b/src/include/kompute/Algorithm.hpp @@ -24,7 +24,8 @@ class Algorithm const std::vector>& tensors = {}, const std::vector& spirv = {}, const Workgroup& workgroup = {}, - const Constants& specializationConstants = {}); + const Constants& specializationConstants = {}, + const Constants& pushConstants = {}); /** * Initialiser for the shader data provided to the algorithm as well as @@ -38,7 +39,8 @@ class Algorithm void rebuild(const std::vector>& tensors, const std::vector& spirv, const Workgroup& workgroup = {}, - const Constants& specializationConstants = {}); + const Constants& specializationConstants = {}, + const Constants& pushConstants = {}); /** * Destructor for Algorithm which is responsible for freeing and desroying @@ -58,15 +60,16 @@ class Algorithm void bindCore(const vk::CommandBuffer& commandBuffer); - void bindPush(const vk::CommandBuffer& commandBuffer, - const Constants& pushConstants); + void bindPush(const vk::CommandBuffer& commandBuffer); bool isInit(); void setWorkgroup(const Workgroup& workgroup, uint32_t minSize = 1); + void setPush(const Constants& pushConstants); const Workgroup& getWorkgroup(); const Constants& getSpecializationConstants(); + const Constants& getPush(); const std::vector>& getTensors(); void destroy(); @@ -95,6 +98,7 @@ class Algorithm // -------------- ALWAYS OWNED RESOURCES std::vector mSpirv; Constants mSpecializationConstants; + Constants mPushConstants; Workgroup mWorkgroup; bool mIsInit; diff --git a/src/include/kompute/Manager.hpp b/src/include/kompute/Manager.hpp index 61212abf2..0dbde246d 100644 --- a/src/include/kompute/Manager.hpp +++ b/src/include/kompute/Manager.hpp @@ -87,7 +87,8 @@ class Manager const std::vector>& tensors = {}, const std::vector& spirv = {}, const Workgroup& workgroup = {}, - const Constants& specializationConstants = {}); + const Constants& specializationConstants = {}, + const Constants& pushConstants = {}); void destroy(); void clear(); diff --git a/test/TestPushConstant.cpp b/test/TestPushConstant.cpp index ae8cf4a32..8c3357b33 100644 --- a/test/TestPushConstant.cpp +++ b/test/TestPushConstant.cpp @@ -4,7 +4,7 @@ #include "fmt/ranges.h" -TEST(TestPushConstants, TestTwoConstants) +TEST(TestPushConstants, TestConstants) { { std::string shader(R"( @@ -32,7 +32,7 @@ TEST(TestPushConstants, TestTwoConstants) std::shared_ptr tensor = mgr.tensor({ 0, 0, 0 }); std::shared_ptr algo = - mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 })); + mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0, 0.0, 0.0 }); sq = mgr.sequence() ->record({ tensor }) @@ -47,3 +47,41 @@ TEST(TestPushConstants, TestTwoConstants) } } } + +TEST(TestPushConstants, TestConstantsWrongSize) +{ + { + std::string shader(R"( + #version 450 + layout(push_constant) uniform PushConstants { + float x; + float y; + float z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + })"); + + std::vector spirv = kp::Shader::compile_source(shader); + + std::shared_ptr sq = nullptr; + + { + kp::Manager mgr; + + std::shared_ptr tensor = mgr.tensor({ 0, 0, 0 }); + + std::shared_ptr algo = + mgr.algorithm({ tensor }, spirv, kp::Workgroup({ 1 }), {}, { 0.0 }); + + sq = mgr.sequence() + ->record({ tensor }); + + EXPECT_THROW(sq->record(algo, kp::Constants{ 0.1, 0.2, 0.3 }), std::runtime_error); + } + } +}