Extended algorithm to add spec consts for int and float
This commit is contained in:
parent
9cb4c2f1e1
commit
0b84876c95
4 changed files with 145 additions and 26 deletions
|
|
@ -12,7 +12,115 @@ namespace kp {
|
|||
*/
|
||||
class Algorithm
|
||||
{
|
||||
public:
|
||||
public:
|
||||
// TODO: Move as internal struct of speccontainer
|
||||
class SpecializationConstant {
|
||||
public:
|
||||
SpecializationConstant(const SpecializationConstant& specializationConstant) {
|
||||
SPDLOG_DEBUG("Kompute SpecializationConstant copy constructor: {}", *((uint32_t*)specializationConstant.mInstanceData));
|
||||
this->mInstanceData = (char*)malloc(sizeof(uint32_t));
|
||||
memcpy(this->mInstanceData, specializationConstant.mInstanceData, sizeof(uint32_t));
|
||||
}
|
||||
// This class is required in absence of std::variant to ensure C++11 support
|
||||
SpecializationConstant(uint32_t val) {
|
||||
SPDLOG_DEBUG("Kompute SpecializationConstant uint32_t constructor: {}", val);
|
||||
this->mInstanceData = (char*)malloc(sizeof(uint32_t));
|
||||
memcpy(this->mInstanceData, &val, sizeof(uint32_t));
|
||||
}
|
||||
SpecializationConstant(float val) {
|
||||
SPDLOG_DEBUG("Kompute SpecializationConstant float constructor: {}", val);
|
||||
this->mInstanceData = (char*)malloc(sizeof(uint32_t));
|
||||
memcpy(this->mInstanceData, &val, sizeof(uint32_t));
|
||||
}
|
||||
~SpecializationConstant() {
|
||||
free(this->mInstanceData);
|
||||
}
|
||||
void *data() {
|
||||
return this->mInstanceData;
|
||||
}
|
||||
private:
|
||||
// We use char pointer to enable for pointer arithmetic
|
||||
char *mInstanceData = nullptr;
|
||||
};
|
||||
|
||||
class SpecializationContainer {
|
||||
public:
|
||||
SpecializationContainer() {
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer default initialiser");
|
||||
this->mFreeData = false;
|
||||
}
|
||||
|
||||
SpecializationContainer(const SpecializationContainer& specializationContainer)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer copy constructor, size: {}", specializationContainer.mSpecializationConstants.size());
|
||||
SpecializationContainer(specializationContainer.mSpecializationConstants);
|
||||
}
|
||||
|
||||
SpecializationContainer(std::vector<SpecializationConstant> instances) {
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer initialiser with instances size {}", instances.size());
|
||||
|
||||
static_assert(sizeof(uint32_t) == sizeof(float) && sizeof(uint32_t) == sizeof(char) * 4,
|
||||
"Kompute requires uint32_t and float to be of same size. Please report this to github.");
|
||||
|
||||
// totalMemorySize depends on instances being set so this needs to be set before
|
||||
this->mSpecializationConstants = instances;
|
||||
|
||||
// Data has then to be allocated in order to copy memory into it
|
||||
this->mData = (char*)malloc(this->totalMemorySize());
|
||||
|
||||
this->mFreeData = true;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
|
||||
memcpy(this->mData + (i * sizeof(uint32_t)), instances[i].data(), sizeof(uint32_t));
|
||||
}
|
||||
}
|
||||
|
||||
~SpecializationContainer() {
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer destructor started");
|
||||
|
||||
this->mSpecializationConstants.clear();
|
||||
|
||||
if (this->mFreeData) {
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer freeing data");
|
||||
this->mFreeData = false;
|
||||
free(this->mData);
|
||||
} else {
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer no data was freed");
|
||||
}
|
||||
|
||||
SPDLOG_DEBUG("kompute SpecializationContainer freed data");
|
||||
}
|
||||
|
||||
void *transferDataOwnership() {
|
||||
SPDLOG_DEBUG("Kompute SpecializationContainer data transfer ownership requested");
|
||||
this->mFreeData = false;
|
||||
return (void*)this->mData;
|
||||
}
|
||||
|
||||
uint32_t size() {
|
||||
return this->mSpecializationConstants.size();
|
||||
}
|
||||
|
||||
uint32_t totalMemorySize() {
|
||||
return this->instanceMemorySize() * this->size();
|
||||
}
|
||||
|
||||
uint32_t instanceMemorySize() {
|
||||
// At this point only variables accepted are uint32_t and float which are same size
|
||||
return sizeof(uint32_t);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
std::vector<SpecializationConstant> mSpecializationConstants;
|
||||
bool mFreeData = false;
|
||||
// We use char pointer to enable for pointer arithmetic
|
||||
char *mData = nullptr;
|
||||
};
|
||||
private:
|
||||
// Private struct template which is then
|
||||
public:
|
||||
/**
|
||||
Base constructor for Algorithm. Should not be used unless explicit
|
||||
intended.
|
||||
|
|
@ -27,7 +135,8 @@ class Algorithm
|
|||
* shaders
|
||||
*/
|
||||
Algorithm(std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer);
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
const SpecializationContainer& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Initialiser for the shader data provided to the algorithm as well as
|
||||
|
|
@ -35,6 +144,7 @@ class Algorithm
|
|||
*
|
||||
* @param shaderFileData The bytes in spir-v format of the shader
|
||||
* @tensorParams The Tensors to be used in the Algorithm / shader for
|
||||
* @specalizationInstalces The specialization parameters to pass to the function
|
||||
* processing
|
||||
*/
|
||||
void init(const std::vector<char>& shaderFileData,
|
||||
|
|
@ -56,7 +166,7 @@ class Algorithm
|
|||
*/
|
||||
void recordDispatch(uint32_t x = 1, uint32_t y = 1, uint32_t z = 1);
|
||||
|
||||
private:
|
||||
private:
|
||||
// -------------- NEVER OWNED RESOURCES
|
||||
std::shared_ptr<vk::Device> mDevice;
|
||||
std::shared_ptr<vk::CommandBuffer> mCommandBuffer;
|
||||
|
|
@ -77,9 +187,12 @@ class Algorithm
|
|||
std::shared_ptr<vk::Pipeline> mPipeline;
|
||||
bool mFreePipeline = false;
|
||||
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
SpecializationContainer mSpecializationConstants;
|
||||
|
||||
// Create util functions
|
||||
void createShaderModule(const std::vector<char>& shaderFileData);
|
||||
void createPipeline(std::vector<uint32_t> specializationData = {});
|
||||
void createPipeline();
|
||||
|
||||
// Parameters
|
||||
void createParameters(std::vector<std::shared_ptr<Tensor>>& tensorParams);
|
||||
|
|
|
|||
|
|
@ -49,7 +49,8 @@ class OpAlgoBase : public OpBase
|
|||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup());
|
||||
KomputeWorkgroup komputeWorkgroup = {},
|
||||
const Algorithm::SpecializationContainer& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Constructor that enables a file to be passed to the operation with
|
||||
|
|
@ -68,7 +69,8 @@ class OpAlgoBase : public OpBase
|
|||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
std::string shaderFilePath,
|
||||
KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup());
|
||||
KomputeWorkgroup komputeWorkgroup = {},
|
||||
const Algorithm::SpecializationContainer& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Constructor that enables raw shader data to be passed to the main operation
|
||||
|
|
@ -86,7 +88,8 @@ class OpAlgoBase : public OpBase
|
|||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<char>& shaderDataRaw,
|
||||
KomputeWorkgroup komputeWorkgroup = KomputeWorkgroup());
|
||||
KomputeWorkgroup komputeWorkgroup = {},
|
||||
const Algorithm::SpecializationContainer& specializationConstants = {});
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue