Updated to add typedef on Constants and Workgroup

This commit is contained in:
Alejandro Saucedo 2021-02-14 07:29:50 +00:00
parent bf86daa3ef
commit 9adfa34fd3
8 changed files with 25 additions and 19 deletions

View file

@ -11,7 +11,7 @@ Algorithm::Algorithm()
Algorithm::Algorithm(std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
const std::vector<float>& specializationConstants)
const Constants& specializationConstants)
{
SPDLOG_DEBUG("Kompute Algorithm Constructor with device");

View file

@ -13,8 +13,8 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
const std::array<uint32_t, 3>& komputeWorkgroup,
const std::vector<float>& specializationConstants)
const Workgroup& komputeWorkgroup,
const Constants& specializationConstants)
: OpBase(physicalDevice, device, commandBuffer, tensors)
{
SPDLOG_DEBUG("Kompute OpAlgoBase constructor with params numTensors: {}",
@ -46,8 +46,8 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::string shaderFilePath,
const std::array<uint32_t, 3>& komputeWorkgroup,
const std::vector<float>& specializationConstants)
const Workgroup& komputeWorkgroup,
const Constants& specializationConstants)
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup, specializationConstants)
{
SPDLOG_DEBUG(
@ -62,8 +62,8 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<char>& shaderDataRaw,
const std::array<uint32_t, 3>& komputeWorkgroup,
const std::vector<float>& specializationConstants)
const Workgroup& komputeWorkgroup,
const Constants& specializationConstants)
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup, specializationConstants)
{
SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shader raw "

View file

@ -14,7 +14,7 @@ OpAlgoLhsRhsOut::OpAlgoLhsRhsOut(
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>> tensors,
const std::array<uint32_t, 3>& komputeWorkgroup)
const Workgroup& komputeWorkgroup)
// The inheritance is initialised with the copyOutputData to false given that
// this depencendant class handles the transfer of data via staging buffers in
// a granular way.

View file

@ -28,7 +28,7 @@ public:
*/
Algorithm(std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
const std::vector<float>& specializationConstants = {});
const Constants& specializationConstants = {});
/**
* Initialiser for the shader data provided to the algorithm as well as
@ -80,7 +80,7 @@ private:
bool mFreePipeline = false;
// -------------- ALWAYS OWNED RESOURCES
std::vector<float> mSpecializationConstants;
Constants mSpecializationConstants;
// Create util functions
void createShaderModule(const std::vector<char>& shaderFileData);

View file

@ -10,6 +10,12 @@ static const char* KOMPUTE_LOG_TAG = "KomputeLog";
#include <vulkan/vulkan.hpp>
// Typedefs to simplify interaction with core types
namespace kp {
typedef std::array<uint32_t, 3> Workgroup;
typedef std::vector<float> Constants;
}
// Must be after vulkan is included
#ifndef KOMPUTE_VK_API_VERSION
#ifndef KOMPUTE_VK_API_MAJOR_VERSION

View file

@ -44,8 +44,8 @@ class OpAlgoBase : public OpBase
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
const std::array<uint32_t, 3>& komputeWorkgroup = {},
const std::vector<float>& specializationConstants = {});
const Workgroup& komputeWorkgroup = {},
const Constants& specializationConstants = {});
/**
* Constructor that enables a file to be passed to the operation with
@ -64,8 +64,8 @@ class OpAlgoBase : public OpBase
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::string shaderFilePath,
const std::array<uint32_t, 3>& komputeWorkgroup = {},
const std::vector<float>& specializationConstants = {});
const Workgroup& komputeWorkgroup = {},
const Constants& specializationConstants = {});
/**
* Constructor that enables raw shader data to be passed to the main operation
@ -83,8 +83,8 @@ class OpAlgoBase : public OpBase
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<char>& shaderDataRaw,
const std::array<uint32_t, 3>& komputeWorkgroup = {},
const std::vector<float>& specializationConstants = {});
const Workgroup& komputeWorkgroup = {},
const Constants& specializationConstants = {});
/**
* Default destructor, which is in charge of destroying the algorithm
@ -132,7 +132,7 @@ class OpAlgoBase : public OpBase
// -------------- ALWAYS OWNED RESOURCES
std::array<uint32_t, 3> mKomputeWorkgroup;
Workgroup mKomputeWorkgroup;
std::string mShaderFilePath; ///< Optional member variable which can be provided for the OpAlgoBase to find the data automatically and load for processing
std::vector<char> mShaderDataRaw; ///< Optional member variable which can be provided to contain either the raw shader content or the spirv binary content

View file

@ -40,7 +40,7 @@ class OpAlgoLhsRhsOut : public OpAlgoBase
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>> tensors,
const std::array<uint32_t, 3>& komputeWorkgroup = {});
const Workgroup& komputeWorkgroup = {});
/**
* Default destructor, which is in charge of destroying the algorithm

View file

@ -44,7 +44,7 @@ class OpMult : public OpAlgoBase
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>> tensors,
const std::array<uint32_t, 3>& komputeWorkgroup = {})
const Workgroup& komputeWorkgroup = {})
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, "", komputeWorkgroup)
{
SPDLOG_DEBUG("Kompute OpMult constructor with params");