Added tests for push constants of all and mixed types
Signed-off-by: Alejandro Saucedo <axsauze@gmail.com>
This commit is contained in:
parent
c23573eb47
commit
858a70d9b8
7 changed files with 304 additions and 35 deletions
|
|
@ -183,23 +183,29 @@ class Algorithm
|
|||
template<typename T>
|
||||
void setPushConstants(const std::vector<T>& pushConstants)
|
||||
{
|
||||
uint32_t memorySize = sizeof(decltype(pushConstants.back()));
|
||||
uint32_t size = pushConstants.size();
|
||||
uint32_t totalSize = memorySize * size;
|
||||
uint32_t previousTotalSize = this->mPushConstantsDataTypeMemorySize * this->mPushConstantsSize;
|
||||
|
||||
if (pushConstants.size() != this->mPushConstantsSize) {
|
||||
if (totalSize != previousTotalSize) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("Kompute Algorithm push "
|
||||
"constant provided is size {} but expected size {}",
|
||||
pushConstants.size(),
|
||||
this->mPushConstantsSize));
|
||||
"constant total memory size provided is {} but expected {} bytes",
|
||||
totalSize,
|
||||
previousTotalSize));
|
||||
}
|
||||
if (this->mPushConstantsData) {
|
||||
free(this->mPushConstantsData);
|
||||
}
|
||||
|
||||
uint32_t memorySize = sizeof(decltype(pushConstants.back()));
|
||||
uint32_t size = pushConstants.size();
|
||||
this->setPushConstants(pushConstants.data(), size, memorySize);
|
||||
}
|
||||
|
||||
void setPushConstants(void* data, uint32_t size, uint32_t memorySize) {
|
||||
uint32_t totalSize = size * memorySize;
|
||||
this->mPushConstantsData = malloc(totalSize);
|
||||
memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
|
||||
memcpy(this->mPushConstantsData, data, totalSize);
|
||||
this->mPushConstantsDataTypeMemorySize = memorySize;
|
||||
this->mPushConstantsSize = size;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,8 +25,24 @@ class OpAlgoDispatch : public OpBase
|
|||
* @param algorithm The algorithm object to use for dispatch
|
||||
* @param pushConstants The push constants to use for override
|
||||
*/
|
||||
template<typename T = float>
|
||||
OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm,
|
||||
const kp::Constants& pushConstants = {});
|
||||
const std::vector<T>& pushConstants = {})
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute OpAlgoDispatch constructor");
|
||||
|
||||
this->mAlgorithm = algorithm;
|
||||
|
||||
if (pushConstants.size()) {
|
||||
uint32_t memorySize = sizeof(decltype(pushConstants.back()));
|
||||
uint32_t size = pushConstants.size();
|
||||
uint32_t totalSize = size * memorySize;
|
||||
this->mPushConstantsData = malloc(totalSize);
|
||||
memcpy(this->mPushConstantsData, pushConstants.data(), totalSize);
|
||||
this->mPushConstantsDataTypeMemorySize = memorySize;
|
||||
this->mPushConstantsSize = size;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
|
|
@ -63,7 +79,9 @@ class OpAlgoDispatch : public OpBase
|
|||
private:
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::shared_ptr<Algorithm> mAlgorithm;
|
||||
Constants mPushConstants;
|
||||
void* mPushConstantsData = nullptr;
|
||||
uint32_t mPushConstantsDataTypeMemorySize = 0;
|
||||
uint32_t mPushConstantsSize = 0;
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue