Merge pull request #219 from EthicalML/tensor_python_numpy_ownership

[PYTHON] Ensure numpy array increments refcount of tensor to keep valid
This commit is contained in:
Alejandro Saucedo 2021-05-15 17:47:13 +01:00 committed by GitHub
commit a3d8b78ff8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 6 deletions

View file

@ -95,18 +95,17 @@ PYBIND11_MODULE(kp, m) {
py::class_<kp::Tensor, std::shared_ptr<kp::Tensor>>(m, "Tensor", DOC(kp, Tensor))
.def("data", [](kp::Tensor& self) {
// Non-owning container exposing the underlying pointer
py::str dummyDataOwner; // Explicitly request data to not be owned by np
switch (self.dataType()) {
case kp::Tensor::TensorDataTypes::eFloat:
return py::array(self.size(), self.data<float>(), dummyDataOwner);
return py::array(self.size(), self.data<float>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eUnsignedInt:
return py::array(self.size(), self.data<uint32_t>(), dummyDataOwner);
return py::array(self.size(), self.data<uint32_t>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eInt:
return py::array(self.size(), self.data<int32_t>(), dummyDataOwner);
return py::array(self.size(), self.data<int32_t>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eDouble:
return py::array(self.size(), self.data<double>(), dummyDataOwner);
return py::array(self.size(), self.data<double>(), py::cast(&self));
case kp::Tensor::TensorDataTypes::eBool:
return py::array(self.size(), self.data<bool>(), dummyDataOwner);
return py::array(self.size(), self.data<bool>(), py::cast(&self));
default:
throw std::runtime_error("Kompute Python data type not supported");
}

View file

@ -207,3 +207,26 @@ def test_type_unsigned_int():
assert np.all(tensor_out.data() == arr_in_a * arr_in_b)
def test_tensor_numpy_ownership():
arr_in = np.array([1, 2, 3])
m = kp.Manager()
t = m.tensor(arr_in)
# This should increment refcount for tensor sharedptr
td = t.data()
assert td.base.is_init() == True
assert np.all(td == arr_in)
del t
assert td.base.is_init() == True
assert np.all(td == arr_in)
m.destroy()
assert td.base.is_init() == False