Extended algorithm to add spec consts for int and float
This commit is contained in:
parent
9cb4c2f1e1
commit
0b84876c95
4 changed files with 145 additions and 26 deletions
|
|
@ -10,12 +10,14 @@ Algorithm::Algorithm()
|
|||
}
|
||||
|
||||
Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer)
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
const SpecializationContainer& specializationConstants)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute Algorithm Constructor with device");
|
||||
|
||||
this->mDevice = device;
|
||||
this->mCommandBuffer = commandBuffer;
|
||||
this->mSpecializationConstants = specializationConstants;
|
||||
}
|
||||
|
||||
Algorithm::~Algorithm()
|
||||
|
|
@ -114,11 +116,7 @@ Algorithm::init(const std::vector<char>& shaderFileData,
|
|||
this->createParameters(tensorParams);
|
||||
this->createShaderModule(shaderFileData);
|
||||
|
||||
std::vector<uint32_t> sizes;
|
||||
for (std::shared_ptr<Tensor> tensor : tensorParams) {
|
||||
sizes.push_back(tensor->size());
|
||||
}
|
||||
this->createPipeline(sizes);
|
||||
this->createPipeline();
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -225,7 +223,7 @@ Algorithm::createShaderModule(const std::vector<char>& shaderFileData)
|
|||
}
|
||||
|
||||
void
|
||||
Algorithm::createPipeline(std::vector<uint32_t> specializationData)
|
||||
Algorithm::createPipeline()
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute Algorithm calling create Pipeline");
|
||||
|
||||
|
|
@ -241,20 +239,22 @@ Algorithm::createPipeline(std::vector<uint32_t> specializationData)
|
|||
|
||||
std::vector<vk::SpecializationMapEntry> specializationEntries;
|
||||
|
||||
for (size_t i = 0; i < specializationData.size(); i++) {
|
||||
for (uint32_t i = 0; i < this->mSpecializationConstants.size(); i++) {
|
||||
vk::SpecializationMapEntry specializationEntry(
|
||||
static_cast<uint32_t>(i),
|
||||
static_cast<uint32_t>(sizeof(uint32_t) * i),
|
||||
sizeof(uint32_t));
|
||||
static_cast<uint32_t>(i),
|
||||
static_cast<uint32_t>(this->mSpecializationConstants.instanceMemorySize() * i),
|
||||
this->mSpecializationConstants.instanceMemorySize());
|
||||
|
||||
specializationEntries.push_back(specializationEntry);
|
||||
}
|
||||
|
||||
// This passes ownership of the memory so we remove ownership from
|
||||
// specialization container by using "transferDataOwnership"
|
||||
vk::SpecializationInfo specializationInfo(
|
||||
static_cast<uint32_t>(specializationEntries.size()),
|
||||
specializationEntries.data(),
|
||||
sizeof(uint32_t) * specializationEntries.size(),
|
||||
specializationData.data());
|
||||
this->mSpecializationConstants.totalMemorySize(),
|
||||
this->mSpecializationConstants.transferDataOwnership());
|
||||
|
||||
vk::PipelineShaderStageCreateInfo shaderStage(
|
||||
vk::PipelineShaderStageCreateFlags(),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue