Updated memory ownership of sharedptr of Tensor to also be refcounted by numpy array returned in data

This commit is contained in:
Alejandro Saucedo 2021-05-07 16:54:08 +01:00
parent 11597af272
commit ad9c857427

View file

@ -95,18 +95,17 @@ PYBIND11_MODULE(kp, m) {
py::class_<kp::Tensor, std::shared_ptr<kp::Tensor>>(m, "Tensor", DOC(kp, Tensor))
.def("data", [](kp::Tensor& self) {
// Non-owning container exposing the underlying pointer
py::str dummyDataOwner; // Explicitly request data to not be owned by np
switch (self.dataType()) {
case kp::Tensor::TensorDataTypes::eFloat:
return py::array(self.size(), self.data<float>(), dummyDataOwner);
return py::array(self.size(), self.data<float>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eUnsignedInt:
return py::array(self.size(), self.data<uint32_t>(), dummyDataOwner);
return py::array(self.size(), self.data<uint32_t>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eInt:
return py::array(self.size(), self.data<int32_t>(), dummyDataOwner);
return py::array(self.size(), self.data<int32_t>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eDouble:
return py::array(self.size(), self.data<double>(), dummyDataOwner);
return py::array(self.size(), self.data<double>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eBool:
return py::array(self.size(), self.data<bool>(), dummyDataOwner);
return py::array(self.size(), self.data<bool>(), py::cast(&self));
default:
throw std::runtime_error("Kompute Python data type not supported");
}