Migrated default base type of tensor to float
This commit is contained in:
parent
00c66f347b
commit
08a1300450
13 changed files with 255 additions and 249 deletions
|
|
@ -42,7 +42,7 @@ class Tensor
|
|||
* @param data Vector of data that will be used by the tensor
|
||||
* @param tensorType Type for the tensor which is of type TensorTypes
|
||||
*/
|
||||
Tensor(std::vector<uint32_t> data,
|
||||
Tensor(std::vector<float> data,
|
||||
TensorTypes tensorType = TensorTypes::eDevice);
|
||||
|
||||
/**
|
||||
|
|
@ -70,7 +70,7 @@ class Tensor
|
|||
*
|
||||
* @return Vector of elements representing the data in the tensor.
|
||||
*/
|
||||
std::vector<uint32_t> data();
|
||||
std::vector<float> data();
|
||||
/**
|
||||
* Returns the size/magnitude of the Tensor, which will be the total number
|
||||
* of elements across all dimensions
|
||||
|
|
@ -103,7 +103,7 @@ class Tensor
|
|||
* Sets / resets the vector data of the tensor. This function does not
|
||||
* perform any copies into GPU memory and is only performed on the host.
|
||||
*/
|
||||
void setData(const std::vector<uint32_t>& data);
|
||||
void setData(const std::vector<float>& data);
|
||||
|
||||
/**
|
||||
* Records a copy from the memory of the tensor provided to the current
|
||||
|
|
@ -163,7 +163,7 @@ class Tensor
|
|||
bool mFreeMemory;
|
||||
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::vector<uint32_t> mData;
|
||||
std::vector<float> mData;
|
||||
|
||||
TensorTypes mTensorType = TensorTypes::eDevice;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue