From 893fd4fc7cec9674d4bdf92ec9d709d049d78569 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Sun, 10 Jan 2021 15:29:18 +0100 Subject: [PATCH] faster set_data() --- python/src/main.cpp | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/python/src/main.cpp b/python/src/main.cpp index 77924581e..83807b53a 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -10,6 +10,8 @@ namespace py = pybind11; PYBIND11_MODULE(kp, m) { + py::module_ np = py::module_::import("numpy"); + #if KOMPUTE_ENABLE_SPDLOG spdlog::set_level( static_cast(SPDLOG_ACTIVE_LEVEL)); @@ -40,25 +42,19 @@ PYBIND11_MODULE(kp, m) { return std::unique_ptr(new kp::Tensor(data, tensorTypes)); }), "Initialiser with list of data components and tensor GPU memory type.") .def("data", &kp::Tensor::data, DOC(kp, Tensor, data)) - .def("numpy", [](kp::Tensor& self){ - ssize_t ndim = 1; - std::vector shape = { self.size() }; - std::vector strides = { sizeof(float) }; - - return py::array(py::buffer_info( - self.data().data(), - sizeof(float), - py::format_descriptor::format(), - ndim, - shape, - strides - )); + .def("numpy", [](kp::Tensor& self) { + return py::array(self.data().size(), self.data().data()); }, "Returns stored data as a new numpy array.") .def("__getitem__", [](kp::Tensor &self, size_t index) -> float { return self.data()[index]; }, "When only an index is necessary") .def("__setitem__", [](kp::Tensor &self, size_t index, float value) { self.data()[index] = value; }) - .def("set_data", &kp::Tensor::setData, "Overrides the data in the local Tensor memory.") + .def("set_data", [np](kp::Tensor &self, const py::array_t data){ + const py::array_t flatdata = np.attr("ravel")(data); + const py::buffer_info info = flatdata.request(); + const float* ptr = (float*) info.ptr; + self.setData(std::vector(ptr, ptr+flatdata.size())); + }, "Overrides the data in the local Tensor memory.") .def("__iter__", [](kp::Tensor &self) { return py::make_iterator(self.data().begin(), self.data().end()); }, py::keep_alive<0, 1>(), // Required to keep alive iterator while exists