Added tests for push constants of all and mixed types

Signed-off-by: Alejandro Saucedo <axsauze@gmail.com>
This commit is contained in:
Alejandro Saucedo 2021-09-12 11:31:32 +01:00
parent c23573eb47
commit 858a70d9b8
7 changed files with 304 additions and 35 deletions

View file

@ -183,23 +183,29 @@ class Algorithm
template<typename T>
void setPushConstants(const std::vector<T>& pushConstants)
{
uint32_t memorySize = sizeof(decltype(pushConstants.back()));
uint32_t size = pushConstants.size();
uint32_t totalSize = memorySize * size;
uint32_t previousTotalSize = this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize;
if (pushConstants.size() != this->mPushConstantsSize) {
if (totalSize != previousTotalSize) {
throw std::runtime_error(
fmt::format("Kompute Algorithm push "
"constant provided is size {} but expected size {}",
pushConstants.size(),
this->mPushConstantsSize));
"constant total memory size provided is {} but expected {} bytes",
totalSize,
previousTotalSize));
}
if (this->mPushConstantsData) {
free(this->mPushConstantsData);
}
uint32_t memorySize = sizeof(decltype(pushConstants.back()));
uint32_t size = pushConstants.size();
this->setPushConstants(pushConstants.data(), size, memorySize);
}
void setPushConstants(void* data, uint32_t size, uint32_t memorySize) {
uint32_t totalSize = size * memorySize;
this->mPushConstantsData = malloc(totalSize);
memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
memcpy(this->mPushConstantsData, data, totalSize);
this->mPushConstantsDataTypeMemorySize = memorySize;
this->mPushConstantsSize = size;
}

View file

@ -25,8 +25,24 @@ class OpAlgoDispatch : public OpBase
* @param algorithm The algorithm object to use for dispatch
* @param pushConstants The push constants to use for override
*/
template<typename T = float>
OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm,
const kp::Constants& pushConstants = {});
const std::vector<T>& pushConstants = {})
{
KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor");
this->mAlgorithm = algorithm;
if (pushConstants.size()) {
uint32_t memorySize = sizeof(decltype(pushConstants.back()));
uint32_t size = pushConstants.size();
uint32_t totalSize = size * memorySize;
this->mPushConstantsData = malloc(totalSize);
memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
this->mPushConstantsDataTypeMemorySize = memorySize;
this->mPushConstantsSize = size;
}
}
/**
* Default destructor, which is in charge of destroying the algorithm
@ -63,7 +79,9 @@ class OpAlgoDispatch : public OpBase
private:
// -------------- ALWAYS OWNED RESOURCES
std::shared_ptr<Algorithm> mAlgorithm;
Constants mPushConstants;
void* mPushConstantsData = nullptr;
uint32_t mPushConstantsDataTypeMemorySize = 0;
uint32_t mPushConstantsSize = 0;
};
} // End namespace kp