diff --git a/src/Algorithm.cpp b/src/Algorithm.cpp index 972c50f74..227a9b6bc 100644 --- a/src/Algorithm.cpp +++ b/src/Algorithm.cpp @@ -11,7 +11,7 @@ Algorithm::Algorithm() Algorithm::Algorithm(std::shared_ptr device, std::shared_ptr commandBuffer, - const std::vector& specializationConstants) + const Constants& specializationConstants) { SPDLOG_DEBUG("Kompute Algorithm Constructor with device"); diff --git a/src/OpAlgoBase.cpp b/src/OpAlgoBase.cpp index 55b7dfec5..04c4cdebe 100644 --- a/src/OpAlgoBase.cpp +++ b/src/OpAlgoBase.cpp @@ -13,8 +13,8 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, - const std::array& komputeWorkgroup, - const std::vector& specializationConstants) + const Workgroup& komputeWorkgroup, + const Constants& specializationConstants) : OpBase(physicalDevice, device, commandBuffer, tensors) { SPDLOG_DEBUG("Kompute OpAlgoBase constructor with params numTensors: {}", @@ -46,8 +46,8 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr physicalDevice, std::shared_ptr commandBuffer, std::vector>& tensors, std::string shaderFilePath, - const std::array& komputeWorkgroup, - const std::vector& specializationConstants) + const Workgroup& komputeWorkgroup, + const Constants& specializationConstants) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup, specializationConstants) { SPDLOG_DEBUG( @@ -62,8 +62,8 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr physicalDevice, std::shared_ptr commandBuffer, std::vector>& tensors, const std::vector& shaderDataRaw, - const std::array& komputeWorkgroup, - const std::vector& specializationConstants) + const Workgroup& komputeWorkgroup, + const Constants& specializationConstants) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup, specializationConstants) { SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shader raw " diff --git a/src/OpAlgoLhsRhsOut.cpp b/src/OpAlgoLhsRhsOut.cpp index 6798a009f..51a1d0fb9 100644 --- a/src/OpAlgoLhsRhsOut.cpp +++ b/src/OpAlgoLhsRhsOut.cpp @@ -14,7 +14,7 @@ OpAlgoLhsRhsOut::OpAlgoLhsRhsOut( std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - const std::array& komputeWorkgroup) + const Workgroup& komputeWorkgroup) // The inheritance is initialised with the copyOutputData to false given that // this depencendant class handles the transfer of data via staging buffers in // a granular way. diff --git a/src/include/kompute/Algorithm.hpp b/src/include/kompute/Algorithm.hpp index ac631a8cd..361ebe4e5 100644 --- a/src/include/kompute/Algorithm.hpp +++ b/src/include/kompute/Algorithm.hpp @@ -28,7 +28,7 @@ public: */ Algorithm(std::shared_ptr device, std::shared_ptr commandBuffer, - const std::vector& specializationConstants = {}); + const Constants& specializationConstants = {}); /** * Initialiser for the shader data provided to the algorithm as well as @@ -80,7 +80,7 @@ private: bool mFreePipeline = false; // -------------- ALWAYS OWNED RESOURCES - std::vector mSpecializationConstants; + Constants mSpecializationConstants; // Create util functions void createShaderModule(const std::vector& shaderFileData); diff --git a/src/include/kompute/Core.hpp b/src/include/kompute/Core.hpp index e1940e756..809cf5322 100644 --- a/src/include/kompute/Core.hpp +++ b/src/include/kompute/Core.hpp @@ -10,6 +10,12 @@ static const char* KOMPUTE_LOG_TAG = "KomputeLog"; #include +// Typedefs to simplify interaction with core types +namespace kp { +typedef std::array Workgroup; +typedef std::vector Constants; +} + // Must be after vulkan is included #ifndef KOMPUTE_VK_API_VERSION #ifndef KOMPUTE_VK_API_MAJOR_VERSION diff --git a/src/include/kompute/operations/OpAlgoBase.hpp b/src/include/kompute/operations/OpAlgoBase.hpp index 9ac298c34..95c5853f0 100644 --- a/src/include/kompute/operations/OpAlgoBase.hpp +++ b/src/include/kompute/operations/OpAlgoBase.hpp @@ -44,8 +44,8 @@ class OpAlgoBase : public OpBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, - const std::array& komputeWorkgroup = {}, - const std::vector& specializationConstants = {}); + const Workgroup& komputeWorkgroup = {}, + const Constants& specializationConstants = {}); /** * Constructor that enables a file to be passed to the operation with @@ -64,8 +64,8 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, std::string shaderFilePath, - const std::array& komputeWorkgroup = {}, - const std::vector& specializationConstants = {}); + const Workgroup& komputeWorkgroup = {}, + const Constants& specializationConstants = {}); /** * Constructor that enables raw shader data to be passed to the main operation @@ -83,8 +83,8 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, const std::vector& shaderDataRaw, - const std::array& komputeWorkgroup = {}, - const std::vector& specializationConstants = {}); + const Workgroup& komputeWorkgroup = {}, + const Constants& specializationConstants = {}); /** * Default destructor, which is in charge of destroying the algorithm @@ -132,7 +132,7 @@ class OpAlgoBase : public OpBase // -------------- ALWAYS OWNED RESOURCES - std::array mKomputeWorkgroup; + Workgroup mKomputeWorkgroup; std::string mShaderFilePath; ///< Optional member variable which can be provided for the OpAlgoBase to find the data automatically and load for processing std::vector mShaderDataRaw; ///< Optional member variable which can be provided to contain either the raw shader content or the spirv binary content diff --git a/src/include/kompute/operations/OpAlgoLhsRhsOut.hpp b/src/include/kompute/operations/OpAlgoLhsRhsOut.hpp index 70c01d929..c1223e738 100644 --- a/src/include/kompute/operations/OpAlgoLhsRhsOut.hpp +++ b/src/include/kompute/operations/OpAlgoLhsRhsOut.hpp @@ -40,7 +40,7 @@ class OpAlgoLhsRhsOut : public OpAlgoBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - const std::array& komputeWorkgroup = {}); + const Workgroup& komputeWorkgroup = {}); /** * Default destructor, which is in charge of destroying the algorithm diff --git a/src/include/kompute/operations/OpMult.hpp b/src/include/kompute/operations/OpMult.hpp index 69953afe1..8d3263fef 100644 --- a/src/include/kompute/operations/OpMult.hpp +++ b/src/include/kompute/operations/OpMult.hpp @@ -44,7 +44,7 @@ class OpMult : public OpAlgoBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - const std::array& komputeWorkgroup = {}) + const Workgroup& komputeWorkgroup = {}) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, "", komputeWorkgroup) { SPDLOG_DEBUG("Kompute OpMult constructor with params");