Initil implementation

Signed-off-by: Alejandro Saucedo <axsauze@gmail.com>
This commit is contained in:
Alejandro Saucedo 2021-09-12 09:37:50 +01:00
parent 44e4ff6978
commit 860fda9fb5
6 changed files with 247 additions and 110 deletions

View file

@ -5,30 +5,6 @@
namespace kp {
Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
const std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<uint32_t>& spirv,
const Workgroup& workgroup,
const Constants& specializationConstants,
const Constants& pushConstants)
{
KP_LOG_DEBUG("Kompute Algorithm Constructor with device");
this->mDevice = device;
if (tensors.size() && spirv.size()) {
KP_LOG_INFO("Kompute Algorithm initialising with tensor size: {} and "
"spirv size: {}",
tensors.size(),
spirv.size());
this->rebuild(
tensors, spirv, workgroup, specializationConstants, pushConstants);
} else {
KP_LOG_INFO("Kompute Algorithm constructor with empty tensors and or "
"spirv so not rebuilding vulkan components");
}
}
Algorithm::~Algorithm()
{
KP_LOG_DEBUG("Kompute Algorithm Destructor started");
@ -36,33 +12,6 @@ Algorithm::~Algorithm()
this->destroy();
}
void
Algorithm::rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<uint32_t>& spirv,
const Workgroup& workgroup,
const Constants& specializationConstants,
const Constants& pushConstants)
{
KP_LOG_DEBUG("Kompute Algorithm rebuild started");
this->mTensors = tensors;
this->mSpirv = spirv;
this->mSpecializationConstants = specializationConstants;
this->mPushConstants = pushConstants;
this->setWorkgroup(workgroup,
this->mTensors.size() ? this->mTensors[0]->size() : 1);
// Descriptor pool is created first so if available then destroy all before
// rebuild
if (this->isInit()) {
this->destroy();
}
this->createParameters();
this->createShaderModule();
this->createPipeline();
}
bool
Algorithm::isInit()
{
@ -74,6 +23,13 @@ Algorithm::isInit()
void
Algorithm::destroy()
{
if (this->mPushConstantsData) {
free(this->mPushConstantsData);
}
if (this->mSpecializationConstantsData) {
free(this->mSpecializationConstantsData);
}
if (!this->mDevice) {
KP_LOG_WARN("Kompute Algorithm destroy function reached with null "
@ -279,10 +235,10 @@ Algorithm::createPipeline()
this->mDescriptorSetLayout.get());
vk::PushConstantRange pushConstantRange;
if (this->mPushConstants.size()) {
if (this->mPushConstantsSize) {
pushConstantRange.setStageFlags(vk::ShaderStageFlagBits::eCompute);
pushConstantRange.setOffset(0);
pushConstantRange.setSize(sizeof(float) * this->mPushConstants.size());
pushConstantRange.setSize(this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize);
pipelineLayoutInfo.setPushConstantRangeCount(1);
pipelineLayoutInfo.setPPushConstantRanges(&pushConstantRange);
@ -295,11 +251,11 @@ Algorithm::createPipeline()
std::vector<vk::SpecializationMapEntry> specializationEntries;
for (uint32_t i = 0; i < this->mSpecializationConstants.size(); i++) {
for (uint32_t i = 0; i < this->mSpecializationConstantsSize; i++) {
vk::SpecializationMapEntry specializationEntry(
static_cast<uint32_t>(i),
static_cast<uint32_t>(sizeof(float) * i),
sizeof(float));
static_cast<uint32_t>(this->mSpecializationConstantsDataTypeMemorySize * i),
this->mSpecializationConstantsDataTypeMemorySize);
specializationEntries.push_back(specializationEntry);
}
@ -309,8 +265,8 @@ Algorithm::createPipeline()
vk::SpecializationInfo specializationInfo(
static_cast<uint32_t>(specializationEntries.size()),
specializationEntries.data(),
sizeof(float) * this->mSpecializationConstants.size(),
this->mSpecializationConstants.data());
this->mSpecializationConstantsDataTypeMemorySize * this->mSpecializationConstantsSize,
this->mSpecializationConstantsData);
vk::PipelineShaderStageCreateInfo shaderStage(
vk::PipelineShaderStageCreateFlags(),
@ -381,15 +337,22 @@ Algorithm::recordBindCore(const vk::CommandBuffer& commandBuffer)
void
Algorithm::recordBindPush(const vk::CommandBuffer& commandBuffer)
{
if (this->mPushConstants.size()) {
if (this->mPushConstantsSize) {
KP_LOG_DEBUG("Kompute Algorithm binding push constants size: {}",
this->mPushConstants.size());
this->mPushConstantsSize);
KP_LOG_DEBUG("{} {}",
this->mPushConstantsDataTypeMemorySize,
this->mPushConstantsData == nullptr);
KP_LOG_DEBUG("{}",
((float*)this->mPushConstantsData)[0]);
commandBuffer.pushConstants(*this->mPipelineLayout,
vk::ShaderStageFlagBits::eCompute,
0,
this->mPushConstants.size() * sizeof(float),
this->mPushConstants.data());
this->mPushConstantsSize * this->mPushConstantsDataTypeMemorySize,
this->mPushConstantsData);
KP_LOG_DEBUG("Constants bound: {}",
this->mPushConstantsSize);
}
}
@ -426,39 +389,12 @@ Algorithm::setWorkgroup(const Workgroup& workgroup, uint32_t minSize)
this->mWorkgroup[2]);
}
void
Algorithm::setPush(const Constants& pushConstants)
{
if (pushConstants.size() != this->mPushConstants.size()) {
throw std::runtime_error(
fmt::format("Kompute Algorithm push "
"constant provided is size {} but expected size {}",
pushConstants.size(),
this->mPushConstants.size()));
}
this->mPushConstants = pushConstants;
}
const Workgroup&
Algorithm::getWorkgroup()
{
return this->mWorkgroup;
}
const Constants&
Algorithm::getSpecializationConstants()
{
return this->mSpecializationConstants;
}
const Constants&
Algorithm::getPush()
{
return this->mPushConstants;
}
const std::vector<std::shared_ptr<Tensor>>&
Algorithm::getTensors()
{