diff --git a/single_include/kompute/Kompute.hpp b/single_include/kompute/Kompute.hpp index a4c6fa7d7..d1360f5ad 100755 --- a/single_include/kompute/Kompute.hpp +++ b/single_include/kompute/Kompute.hpp @@ -10,6 +10,12 @@ static const char* KOMPUTE_LOG_TAG = "KomputeLog"; #include +// Typedefs to simplify interaction with core types +namespace kp { +typedef std::array Workgroup; +typedef std::vector Constants; +} + // Must be after vulkan is included #ifndef KOMPUTE_VK_API_VERSION #ifndef KOMPUTE_VK_API_MAJOR_VERSION @@ -1739,7 +1745,7 @@ class OpAlgoBase : public OpBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, - const std::array& komputeWorkgroup = {}, + const Workgroup& komputeWorkgroup = {}, const std::vector& specializationConstants = {}); /** @@ -1759,7 +1765,7 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, std::string shaderFilePath, - const std::array& komputeWorkgroup = {}, + const Workgroup& komputeWorkgroup = {}, const std::vector& specializationConstants = {}); /** @@ -1778,7 +1784,7 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, const std::vector& shaderDataRaw, - const std::array& komputeWorkgroup = {}, + const Workgroup& komputeWorkgroup = {}, const std::vector& specializationConstants = {}); /** @@ -1826,7 +1832,7 @@ class OpAlgoBase : public OpBase // -------------- ALWAYS OWNED RESOURCES - std::array mKomputeWorkgroup; + Workgroup mKomputeWorkgroup; std::string mShaderFilePath; ///< Optional member variable which can be provided for the OpAlgoBase to find the data automatically and load for processing std::vector mShaderDataRaw; ///< Optional member variable which can be provided to contain either the raw shader content or the spirv binary content @@ -1869,7 +1875,7 @@ class OpAlgoLhsRhsOut : public OpAlgoBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - const std::array& komputeWorkgroup = {}); + const Workgroup& komputeWorkgroup = {}); /** * Default destructor, which is in charge of destroying the algorithm @@ -1948,7 +1954,7 @@ class OpMult : public OpAlgoBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - const std::array& komputeWorkgroup = {}) + const Workgroup& komputeWorkgroup = {}) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, "", komputeWorkgroup) { SPDLOG_DEBUG("Kompute OpMult constructor with params");