Merge pull request #219 from EthicalML/tensor_python_numpy_ownership

[PYTHON] Ensure numpy array increments refcount of tensor to keep valid
This commit is contained in:
Alejandro Saucedo 2021-05-15 17:47:13 +01:00 committed by GitHub
commit a3d8b78ff8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 6 deletions

View file

@ -95,18 +95,17 @@ PYBIND11_MODULE(kp, m) {
py::class_<kp::Tensor, std::shared_ptr<kp::Tensor>>(m, "Tensor", DOC(kp, Tensor))
.def("data", [](kp::Tensor& self) {
// Non-owning container exposing the underlying pointer
py::str dummyDataOwner; // Explicitly request data to not be owned by np
switch (self.dataType()) {
case kp::Tensor::TensorDataTypes::eFloat:
return py::array(self.size(), self.data<float>(), dummyDataOwner);
return py::array(self.size(), self.data<float>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eUnsignedInt:
return py::array(self.size(), self.data<uint32_t>(), dummyDataOwner);
return py::array(self.size(), self.data<uint32_t>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eInt:
return py::array(self.size(), self.data<int32_t>(), dummyDataOwner);
return py::array(self.size(), self.data<int32_t>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eDouble:
return py::array(self.size(), self.data<double>(), dummyDataOwner);
return py::array(self.size(), self.data<double>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eBool:
return py::array(self.size(), self.data<bool>(), dummyDataOwner);
return py::array(self.size(), self.data<bool>(), py::cast(&self));
default:
throw std::runtime_error("Kompute Python data type not supported");
}