// TODO: Move as internal struct of speccontainer
/**
 * A single 32-bit specialization constant, built from either a uint32_t or a
 * float. This class is required in absence of std::variant to ensure C++11
 * support.
 *
 * The value is stored inline as raw bits, so copy construction and copy
 * assignment are plain value copies and no manual memory management is needed.
 */
class SpecializationConstant
{
  public:
    /**
     * Copy constructor.
     *
     * @param specializationConstant The constant whose 32-bit payload is copied
     */
    SpecializationConstant(const SpecializationConstant& specializationConstant)
      : mInstanceData(specializationConstant.mInstanceData)
    {
        SPDLOG_DEBUG("Kompute SpecializationConstant copy constructor: {}", specializationConstant.mInstanceData);
    }

    /**
     * Copy assignment; the payload is a plain value so the default is correct.
     */
    SpecializationConstant& operator=(const SpecializationConstant&) = default;

    /**
     * Builds a constant from an unsigned 32-bit integer. Intentionally
     * implicit so that brace-initialised lists of values convert.
     *
     * @param val The value to store
     */
    SpecializationConstant(uint32_t val)
      : mInstanceData(val)
    {
        SPDLOG_DEBUG("Kompute SpecializationConstant uint32_t constructor: {}", val);
    }

    /**
     * Builds a constant from a float by storing its bit pattern.
     * Intentionally implicit so that brace-initialised lists of values
     * convert.
     *
     * @param val The value to store
     */
    SpecializationConstant(float val)
    {
        static_assert(sizeof(uint32_t) == sizeof(float),
                      "Kompute requires uint32_t and float to be of same size. Please report this to github.");
        SPDLOG_DEBUG("Kompute SpecializationConstant float constructor: {}", val);
        // memcpy is the well-defined way to reinterpret the float's bits
        memcpy(&this->mInstanceData, &val, sizeof(uint32_t));
    }

    /**
     * @return Pointer to the sizeof(uint32_t) bytes of the stored value. The
     * pointer refers to storage inside this object and is valid only while
     * this object is alive.
     */
    void* data() { return &this->mInstanceData; }

    /**
     * Read-only variant of data().
     */
    const void* data() const { return &this->mInstanceData; }

  private:
    // Raw bits of the constant (either a uint32_t value or a float pattern)
    uint32_t mInstanceData = 0;
};

/**
 * Owns a list of specialization constants together with a contiguous,
 * malloc-allocated buffer holding their values packed back to back
 * (sizeof(uint32_t) bytes per constant, in list order).
 *
 * The buffer is released with free() on destruction unless ownership was
 * handed out through transferDataOwnership().
 */
class SpecializationContainer
{
  public:
    /**
     * Creates an empty container which holds no constants and no buffer.
     */
    SpecializationContainer()
    {
        SPDLOG_DEBUG("Kompute SpecializationContainer default initialiser");
        this->mFreeData = false;
    }

    /**
     * Copy constructor. Copies the constants and packs them into a fresh
     * buffer owned by the new object; the source buffer is never shared.
     *
     * @param specializationContainer The container to copy from
     */
    SpecializationContainer(const SpecializationContainer& specializationContainer)
      : mSpecializationConstants(specializationContainer.mSpecializationConstants)
    {
        SPDLOG_DEBUG("Kompute SpecializationContainer copy constructor, size: {}", specializationContainer.mSpecializationConstants.size());
        this->packData();
    }

    /**
     * Copy assignment. Releases any buffer currently owned, then copies the
     * constants of the other container and packs them into a fresh buffer.
     *
     * @param specializationContainer The container to copy from
     * @return Reference to this container
     */
    SpecializationContainer& operator=(const SpecializationContainer& specializationContainer)
    {
        if (this != &specializationContainer) {
            this->releaseData();
            this->mSpecializationConstants = specializationContainer.mSpecializationConstants;
            this->packData();
        }
        return *this;
    }

    /**
     * Builds a container from a list of constants and packs their values
     * into a newly allocated buffer. Intentionally implicit so a vector or
     * braced list can be passed where a container is expected.
     *
     * @param instances The constants to store, in the order they are packed
     */
    SpecializationContainer(std::vector<SpecializationConstant> instances)
    {
        SPDLOG_DEBUG("Kompute SpecializationContainer initialiser with instances size {}", instances.size());

        static_assert(sizeof(uint32_t) == sizeof(float) && sizeof(uint32_t) == sizeof(char) * 4,
                      "Kompute requires uint32_t and float to be of same size. Please report this to github.");

        // The argument is already a private copy, so take its storage
        // instead of copying every element a second time
        this->mSpecializationConstants.swap(instances);

        this->packData();
    }

    /**
     * Frees the packed buffer if this object still owns it.
     */
    ~SpecializationContainer()
    {
        SPDLOG_DEBUG("Kompute SpecializationContainer destructor started");

        this->mSpecializationConstants.clear();

        if (this->mFreeData) {
            SPDLOG_DEBUG("Kompute SpecializationContainer freeing data");
            this->mFreeData = false;
            free(this->mData);
        } else {
            SPDLOG_DEBUG("Kompute SpecializationContainer no data was freed");
        }

        SPDLOG_DEBUG("kompute SpecializationContainer freed data");
    }

    /**
     * Hands the packed buffer to the caller. After this call the container
     * no longer frees the buffer; since it was obtained with malloc, the
     * caller is expected to release it with free().
     *
     * @return Pointer to the packed buffer, or nullptr if the container
     * holds no constants
     */
    void* transferDataOwnership()
    {
        SPDLOG_DEBUG("Kompute SpecializationContainer data transfer ownership requested");
        this->mFreeData = false;
        return static_cast<void*>(this->mData);
    }

    /**
     * @return Number of constants held
     */
    uint32_t size() const
    {
        return static_cast<uint32_t>(this->mSpecializationConstants.size());
    }

    /**
     * @return Size in bytes of the packed buffer
     */
    uint32_t totalMemorySize() const
    {
        return this->instanceMemorySize() * this->size();
    }

    /**
     * @return Size in bytes of a single packed constant
     */
    uint32_t instanceMemorySize() const
    {
        // At this point only variables accepted are uint32_t and float which are same size
        return sizeof(uint32_t);
    }

  private:
    // Allocates mData for the current constants and copies each value into it.
    // Leaves mData null (and unowned) when there is nothing to pack or when
    // the allocation fails.
    void packData()
    {
        this->mData = nullptr;
        this->mFreeData = false;

        const uint32_t totalBytes = this->totalMemorySize();
        if (totalBytes == 0) {
            // Avoid malloc(0), whose result is implementation-defined
            return;
        }

        this->mData = static_cast<char*>(malloc(totalBytes));
        if (this->mData == nullptr) {
            SPDLOG_DEBUG("Kompute SpecializationContainer failed to allocate data");
            return;
        }
        this->mFreeData = true;

        const uint32_t stride = this->instanceMemorySize();
        for (size_t i = 0; i < this->mSpecializationConstants.size(); i++) {
            memcpy(this->mData + (i * stride), this->mSpecializationConstants[i].data(), stride);
        }
    }

    // Frees the packed buffer when it is still owned and resets the pointer.
    void releaseData()
    {
        if (this->mFreeData) {
            free(this->mData);
        }
        this->mData = nullptr;
        this->mFreeData = false;
    }

    std::vector<SpecializationConstant> mSpecializationConstants;
    bool mFreeData = false;
    // We use char pointer to enable for pointer arithmetic
    char* mData = nullptr;
};
Should not be used unless explicit intended. @@ -1648,7 +1756,8 @@ class Algorithm * shaders */ Algorithm(std::shared_ptr device, - std::shared_ptr commandBuffer); + std::shared_ptr commandBuffer, + const SpecializationContainer& specializationConstants = {}); /** * Initialiser for the shader data provided to the algorithm as well as @@ -1656,6 +1765,7 @@ class Algorithm * * @param shaderFileData The bytes in spir-v format of the shader * @tensorParams The Tensors to be used in the Algorithm / shader for + * @specializationInstances The specialization parameters to pass to the function * processing */ void init(const std::vector& shaderFileData, @@ -1677,7 +1787,7 @@ class Algorithm */ void recordDispatch(uint32_t x = 1, uint32_t y = 1, uint32_t z = 1); - private: +private: // -------------- NEVER OWNED RESOURCES std::shared_ptr mDevice; std::shared_ptr mCommandBuffer; @@ -1698,9 +1808,12 @@ class Algorithm std::shared_ptr mPipeline; bool mFreePipeline = false; + // -------------- ALWAYS OWNED RESOURCES + SpecializationContainer mSpecializationConstants; + // Create util functions void createShaderModule(const std::vector& shaderFileData); - void createPipeline(std::vector specializationData = {}); + void createPipeline(); // Parameters void createParameters(std::vector>& tensorParams); @@ -1747,7 +1860,8 @@ class OpAlgoBase : public OpBase std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, - KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup()); + KomputeWorkgroup komputeWorkgroup = {}, + const Algorithm::SpecializationContainer& specializationConstants = {}); /** * Constructor that enables a file to be passed to the operation with @@ -1766,7 +1880,8 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, std::string shaderFilePath, - KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup()); + KomputeWorkgroup komputeWorkgroup = {}, + const Algorithm::SpecializationContainer& specializationConstants = 
{}); /** * Constructor that enables raw shader data to be passed to the main operation @@ -1784,7 +1899,8 @@ class OpAlgoBase : public OpBase std::shared_ptr commandBuffer, std::vector>& tensors, const std::vector& shaderDataRaw, - KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup()); + KomputeWorkgroup komputeWorkgroup = {}, + const Algorithm::SpecializationContainer& specializationConstants = {}); /** * Default destructor, which is in charge of destroying the algorithm