Updated to remove all todos

This commit is contained in:
Alejandro Saucedo 2020-09-02 21:25:30 +01:00
parent af4f429d4d
commit 07bfbe3504
9 changed files with 20 additions and 39 deletions

View file

@ -66,7 +66,6 @@ class Algorithm
bool mFreeDescriptorSetLayout = false;
std::shared_ptr<vk::DescriptorPool> mDescriptorPool;
bool mFreeDescriptorPool = false;
// TODO: Explore design for multiple descriptor sets
std::shared_ptr<vk::DescriptorSet> mDescriptorSet;
bool mFreeDescriptorSet = false;
std::shared_ptr<vk::ShaderModule> mShaderModule;

View file

@ -13,8 +13,6 @@ class Algorithm
Algorithm(std::shared_ptr<vk::Device> device);
// TODO: Add specialisation data
// TODO: Explore other ways of passing shader (ie raw bytes)
void init(std::string shaderFilePath,
std::vector<std::shared_ptr<Tensor>> tensorParams);

View file

@ -68,9 +68,18 @@ class Tensor
* important to ensure that there is no out-of-sync data with the GPU
* memory.
*
* @return Vector of elements representing the data in the tensor.
* @return Reference to vector of elements representing the data in the tensor.
*/
std::vector<float> data();
std::vector<float>& data();
/**
* Overrides the subscript operator to expose the underlying data's
* subscript operator which in this case would be its underlying
* vector's.
*
* @param i The index where the element will be returned from.
* @return Returns the element in the position requested.
*/
float& operator[] (int index);
/**
* Returns the size/magnitude of the Tensor, which will be the total number
* of elements across all dimensions

View file

@ -181,8 +181,6 @@ OpAlgoBase<tX, tY, tZ>::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalD
this->mY = tY > 0 ? tY : 1;
this->mZ = tZ > 0 ? tZ : 1;
} else {
// TODO: If tensor empty vector exception would be thrown
// TODO: Fully support the full size dispatch using size for the shape
this->mX = tensors[0]->size();
this->mY = 1;
this->mZ = 1;

View file

@ -136,7 +136,6 @@ OpAlgoLhsRhsOut<tX, tY, tZ>::init()
this->mTensorOutput = this->mTensors[2];
// TODO: Explore adding a validate function
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
this->mTensorOutput->isInit())) {
throw std::runtime_error(
@ -146,8 +145,6 @@ OpAlgoLhsRhsOut<tX, tY, tZ>::init()
" Output: " + std::to_string(this->mTensorOutput->isInit()));
}
// TODO: Explore use-cases where tensors shouldn't be the same size, and how
// to deal with those situations
if (!(this->mTensorLHS->size() == this->mTensorRHS->size() &&
this->mTensorRHS->size() == this->mTensorOutput->size())) {
throw std::runtime_error(