Updated to add typedef on Constants and Workgroup
This commit is contained in:
parent
bf86daa3ef
commit
9adfa34fd3
8 changed files with 25 additions and 19 deletions
|
|
@ -28,7 +28,7 @@ public:
|
|||
*/
|
||||
Algorithm(std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
const std::vector<float>& specializationConstants = {});
|
||||
const Constants& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Initialiser for the shader data provided to the algorithm as well as
|
||||
|
|
@ -80,7 +80,7 @@ private:
|
|||
bool mFreePipeline = false;
|
||||
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::vector<float> mSpecializationConstants;
|
||||
Constants mSpecializationConstants;
|
||||
|
||||
// Create util functions
|
||||
void createShaderModule(const std::vector<char>& shaderFileData);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,12 @@ static const char* KOMPUTE_LOG_TAG = "KomputeLog";
|
|||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
|
||||
// Typedefs to simplify interaction with core types
|
||||
namespace kp {
|
||||
typedef std::array<uint32_t, 3> Workgroup;
|
||||
typedef std::vector<float> Constants;
|
||||
}
|
||||
|
||||
// Must be after vulkan is included
|
||||
#ifndef KOMPUTE_VK_API_VERSION
|
||||
#ifndef KOMPUTE_VK_API_MAJOR_VERSION
|
||||
|
|
|
|||
|
|
@ -44,8 +44,8 @@ class OpAlgoBase : public OpBase
|
|||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::array<uint32_t, 3>& komputeWorkgroup = {},
|
||||
const std::vector<float>& specializationConstants = {});
|
||||
const Workgroup& komputeWorkgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Constructor that enables a file to be passed to the operation with
|
||||
|
|
@ -64,8 +64,8 @@ class OpAlgoBase : public OpBase
|
|||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
std::string shaderFilePath,
|
||||
const std::array<uint32_t, 3>& komputeWorkgroup = {},
|
||||
const std::vector<float>& specializationConstants = {});
|
||||
const Workgroup& komputeWorkgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Constructor that enables raw shader data to be passed to the main operation
|
||||
|
|
@ -83,8 +83,8 @@ class OpAlgoBase : public OpBase
|
|||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<char>& shaderDataRaw,
|
||||
const std::array<uint32_t, 3>& komputeWorkgroup = {},
|
||||
const std::vector<float>& specializationConstants = {});
|
||||
const Workgroup& komputeWorkgroup = {},
|
||||
const Constants& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
|
|
@ -132,7 +132,7 @@ class OpAlgoBase : public OpBase
|
|||
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
|
||||
std::array<uint32_t, 3> mKomputeWorkgroup;
|
||||
Workgroup mKomputeWorkgroup;
|
||||
|
||||
std::string mShaderFilePath; ///< Optional member variable which can be provided for the OpAlgoBase to find the data automatically and load for processing
|
||||
std::vector<char> mShaderDataRaw; ///< Optional member variable which can be provided to contain either the raw shader content or the spirv binary content
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class OpAlgoLhsRhsOut : public OpAlgoBase
|
|||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>> tensors,
|
||||
const std::array<uint32_t, 3>& komputeWorkgroup = {});
|
||||
const Workgroup& komputeWorkgroup = {});
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class OpMult : public OpAlgoBase
|
|||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>> tensors,
|
||||
const std::array<uint32_t, 3>& komputeWorkgroup = {})
|
||||
const Workgroup& komputeWorkgroup = {})
|
||||
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, "", komputeWorkgroup)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpMult constructor with params");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue