From bf86daa3efe9c80bf97ca206893b8c7ca2d10e82 Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Sun, 14 Feb 2021 07:18:54 +0000 Subject: [PATCH] Updated single include --- single_include/kompute/Kompute.hpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/single_include/kompute/Kompute.hpp b/single_include/kompute/Kompute.hpp index 637931bd1..a4c6fa7d7 100755 --- a/single_include/kompute/Kompute.hpp +++ b/single_include/kompute/Kompute.hpp @@ -1717,11 +1717,6 @@ namespace kp { class OpAlgoBase : public OpBase { public: - struct KomputeWorkgroup { - uint32_t x; - uint32_t y; - uint32_t z; - }; /** * Base constructor, should not be used unless explicitly intended. @@ -1744,7 +1739,7 @@ class OpAlgoBase : public OpBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, - KomputeWorkgroup komputeWorkgroup = {}, + const std::array& komputeWorkgroup = {}, const std::vector& specializationConstants = {}); /** @@ -1764,7 +1759,7 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, std::string shaderFilePath, - KomputeWorkgroup komputeWorkgroup = {}, + const std::array& komputeWorkgroup = {}, const std::vector& specializationConstants = {}); /** @@ -1783,7 +1778,7 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, const std::vector& shaderDataRaw, - KomputeWorkgroup komputeWorkgroup = {}, + const std::array& komputeWorkgroup = {}, const std::vector& specializationConstants = {}); /** @@ -1831,7 +1826,7 @@ class OpAlgoBase : public OpBase // -------------- ALWAYS OWNED RESOURCES - KomputeWorkgroup mKomputeWorkgroup; + std::array mKomputeWorkgroup; std::string mShaderFilePath; ///< Optional member variable which can be provided for the OpAlgoBase to find the data automatically and load for processing std::vector mShaderDataRaw; ///< Optional member variable which can be provided to contain either the raw shader content or the spirv binary content @@ -1874,7 +1869,7 @@ class OpAlgoLhsRhsOut : public OpAlgoBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup()); + const std::array& komputeWorkgroup = {}); /** * Default destructor, which is in charge of destroying the algorithm @@ -1953,7 +1948,7 @@ class OpMult : public OpAlgoBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector> tensors, - KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup()) + const std::array& komputeWorkgroup = {}) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, "", komputeWorkgroup) { SPDLOG_DEBUG("Kompute OpMult constructor with params");