Simplified specialization constants by limiting to floats2

This commit is contained in:
Alejandro Saucedo 2021-02-14 06:55:09 +00:00
parent 7126cc47ff
commit a7801cedd0
5 changed files with 358 additions and 470 deletions

View file

@ -11,7 +11,7 @@ Algorithm::Algorithm()
Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
const SpecializationContainer& specializationConstants)
const std::vector<float>& specializationConstants)
{
SPDLOG_DEBUG("Kompute Algorithm Constructor with device");
@ -116,6 +116,10 @@ Algorithm::init(const std::vector<char>& shaderFileData,
this->createParameters(tensorParams);
this->createShaderModule(shaderFileData);
for (std::shared_ptr<Tensor> tensor : tensorParams) {
this->mSpecializationConstants.push_back(tensor->size());
}
this->createPipeline();
}
@ -242,8 +246,8 @@ Algorithm::createPipeline()
for (uint32_t i = 0; i < this->mSpecializationConstants.size(); i++) {
vk::SpecializationMapEntry specializationEntry(
static_cast<uint32_t>(i),
static_cast<uint32_t>(this->mSpecializationConstants.instanceMemorySize() * i),
this->mSpecializationConstants.instanceMemorySize());
static_cast<uint32_t>(sizeof(float) * i),
sizeof(float));
specializationEntries.push_back(specializationEntry);
}
@ -253,8 +257,8 @@ Algorithm::createPipeline()
vk::SpecializationInfo specializationInfo(
static_cast<uint32_t>(specializationEntries.size()),
specializationEntries.data(),
this->mSpecializationConstants.totalMemorySize(),
this->mSpecializationConstants.transferDataOwnership());
sizeof(float) * this->mSpecializationConstants.size(),
this->mSpecializationConstants.data());
vk::PipelineShaderStageCreateInfo shaderStage(
vk::PipelineShaderStageCreateFlags(),