Extended algorithm to add spec consts for int and float

This commit is contained in:
Alejandro Saucedo 2021-02-13 19:38:02 +00:00
parent 9cb4c2f1e1
commit 0b84876c95
4 changed files with 145 additions and 26 deletions

View file

@ -10,12 +10,14 @@ Algorithm::Algorithm()
}
Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer)
std::shared_ptr<vk::CommandBuffer> commandBuffer,
const SpecializationContainer& specializationConstants)
{
SPDLOG_DEBUG("Kompute Algorithm Constructor with device");
this->mDevice = device;
this->mCommandBuffer = commandBuffer;
this->mSpecializationConstants = specializationConstants;
}
Algorithm::~Algorithm()
@ -114,11 +116,7 @@ Algorithm::init(const std::vector<char>& shaderFileData,
this->createParameters(tensorParams);
this->createShaderModule(shaderFileData);
std::vector<uint32_t> sizes;
for (std::shared_ptr<Tensor> tensor : tensorParams) {
sizes.push_back(tensor->size());
}
this->createPipeline(sizes);
this->createPipeline();
}
void
@ -225,7 +223,7 @@ Algorithm::createShaderModule(const std::vector<char>& shaderFileData)
}
void
Algorithm::createPipeline(std::vector<uint32_t> specializationData)
Algorithm::createPipeline()
{
SPDLOG_DEBUG("Kompute Algorithm calling create Pipeline");
@ -241,20 +239,22 @@ Algorithm::createPipeline(std::vector<uint32_t> specializationData)
std::vector<vk::SpecializationMapEntry> specializationEntries;
for (size_t i = 0; i < specializationData.size(); i++) {
for (uint32_t i = 0; i < this->mSpecializationConstants.size(); i++) {
vk::SpecializationMapEntry specializationEntry(
static_cast<uint32_t>(i),
static_cast<uint32_t>(sizeof(uint32_t) * i),
sizeof(uint32_t));
static_cast<uint32_t>(i),
static_cast<uint32_t>(this->mSpecializationConstants.instanceMemorySize() * i),
this->mSpecializationConstants.instanceMemorySize());
specializationEntries.push_back(specializationEntry);
}
// This passes ownership of the memory so we remove ownership from
// specialization container by using "transferDataOwnership"
vk::SpecializationInfo specializationInfo(
static_cast<uint32_t>(specializationEntries.size()),
specializationEntries.data(),
sizeof(uint32_t) * specializationEntries.size(),
specializationData.data());
this->mSpecializationConstants.totalMemorySize(),
this->mSpecializationConstants.transferDataOwnership());
vk::PipelineShaderStageCreateInfo shaderStage(
vk::PipelineShaderStageCreateFlags(),