diff --git a/src/OpCreateTensor.cpp b/src/OpCreateTensor.cpp index 7a1247272..4fc687528 100644 --- a/src/OpCreateTensor.cpp +++ b/src/OpCreateTensor.cpp @@ -27,7 +27,7 @@ OpCreateTensor::init(std::shared_ptr tensor, std::vector data) if (tensor->tensorType() == Tensor::TensorTypes::eDevice) { tensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer); - this->mStagingTensor = std::make_unique(tensor->shape(), Tensor::TensorTypes::eStaging); + this->mStagingTensor = std::make_shared(tensor->shape(), Tensor::TensorTypes::eStaging); this->mStagingTensor->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, data); diff --git a/src/OpMult.hpp b/src/OpMult.hpp deleted file mode 100644 index 87c6b55b9..000000000 --- a/src/OpMult.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "BaseOp.hpp" - -namespace kp { - -class OpMult : BaseOp -{ - private: - public: - OpMult(); - virtual ~OpMult(); -}; - -} // End namespace kp diff --git a/src/Tensor.cpp b/src/Tensor.cpp index 3b8fd070d..1c788f38c 100644 --- a/src/Tensor.cpp +++ b/src/Tensor.cpp @@ -7,15 +7,11 @@ Tensor::Tensor() { this->mTensorType = TensorTypes::eDevice; } -Tensor::Tensor(std::vector shape, TensorTypes tensorType) +Tensor::Tensor(std::array shape, TensorTypes tensorType) { SPDLOG_DEBUG("Kompute Tensor init with data"); - if (shape.size() > KP_MAX_DIM_SIZE) { - spdlog::warn("Kompute Tensor created with more dimensions than supported. Max: {}, Provided: {}.", KP_MAX_DIM_SIZE, shape.size()); - } - - std::copy_n(shape.begin(), KP_MAX_DIM_SIZE, this->mShape.begin()); + this->mShape = shape; this->mTensorType = tensorType; } diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 1e51a6019..1c7ca44b4 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -27,7 +27,7 @@ class Tensor Tensor(); - Tensor(std::vector shape, TensorTypes tensorType = TensorTypes::eDevice); + Tensor(std::array shape, TensorTypes tensorType = TensorTypes::eDevice); ~Tensor(); @@ -39,7 +39,7 @@ class Tensor // Getter functions std::vector data(); uint32_t size(); - std::array shape(); + std::vector shape(); TensorTypes tensorType(); bool isInit();