From 2ec35acba518dd49b020706b68a5149efb58b53b Mon Sep 17 00:00:00 2001 From: Fabian Sauter Date: Mon, 13 Jun 2022 14:57:04 +0200 Subject: [PATCH] Updated pybind11 bindings Signed-off-by: Fabian Sauter --- python/src/main.cpp | 120 ++++++++++++++++++++++---------------------- src/CMakeLists.txt | 14 ++++-- 2 files changed, 69 insertions(+), 65 deletions(-) diff --git a/python/src/main.cpp b/python/src/main.cpp index 0ea2793cb..3f3f19d10 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -22,22 +22,22 @@ opAlgoDispatchPyInit(std::shared_ptr& algorithm, push_consts.size(), std::string(py::str(push_consts.dtype()))); - if (push_consts.dtype() == py::dtype::of()) { + if (push_consts.dtype().is(py::dtype::of())) { std::vector dataVec((float*)info.ptr, ((float*)info.ptr) + info.size); return std::unique_ptr{ new kp::OpAlgoDispatch( algorithm, dataVec) }; - } else if (push_consts.dtype() == py::dtype::of()) { + } else if (push_consts.dtype().is(py::dtype::of())) { std::vector dataVec((uint32_t*)info.ptr, ((uint32_t*)info.ptr) + info.size); return std::unique_ptr{ new kp::OpAlgoDispatch( algorithm, dataVec) }; - } else if (push_consts.dtype() == py::dtype::of()) { + } else if (push_consts.dtype().is(py::dtype::of())) { std::vector dataVec((int32_t*)info.ptr, ((int32_t*)info.ptr) + info.size); return std::unique_ptr{ new kp::OpAlgoDispatch( algorithm, dataVec) }; - } else if (push_consts.dtype() == py::dtype::of()) { + } else if (push_consts.dtype().is(py::dtype::of())) { std::vector dataVec((double*)info.ptr, ((double*)info.ptr) + info.size); return std::unique_ptr{ new kp::OpAlgoDispatch( @@ -76,29 +76,29 @@ PYBIND11_MODULE(kp, m) py::class_>( m, "OpBase", DOC(kp, OpBase)); - py::class_>( - m, - "OpTensorSyncDevice", - py::base(), - DOC(kp, OpTensorSyncDevice)) + py::class_>( + m, "OpTensorSyncDevice", DOC(kp, OpTensorSyncDevice)) .def(py::init>&>(), DOC(kp, OpTensorSyncDevice, OpTensorSyncDevice)); - py::class_>( - m, - "OpTensorSyncLocal", - py::base(), - DOC(kp, OpTensorSyncLocal)) + py::class_>( + m, "OpTensorSyncLocal", DOC(kp, OpTensorSyncLocal)) .def(py::init>&>(), DOC(kp, OpTensorSyncLocal, OpTensorSyncLocal)); - py::class_>( - m, "OpTensorCopy", py::base(), DOC(kp, OpTensorCopy)) + py::class_>( + m, "OpTensorCopy", DOC(kp, OpTensorCopy)) .def(py::init>&>(), DOC(kp, OpTensorCopy, OpTensorCopy)); - py::class_>( - m, "OpAlgoDispatch", py::base(), DOC(kp, OpAlgoDispatch)) + py::class_>( + m, "OpAlgoDispatch", DOC(kp, OpAlgoDispatch)) .def(py::init&, const std::vector&>(), DOC(kp, OpAlgoDispatch, OpAlgoDispatch), @@ -109,8 +109,8 @@ PYBIND11_MODULE(kp, m) py::arg("algorithm"), py::arg("push_consts")); - py::class_>( - m, "OpMult", py::base(), DOC(kp, OpMult)) + py::class_>( + m, "OpMult", DOC(kp, OpMult)) .def(py::init>&, const std::shared_ptr&>(), DOC(kp, OpMult, OpMult)); @@ -253,31 +253,31 @@ PYBIND11_MODULE(kp, m) "size {} dtype {}", flatdata.size(), std::string(py::str(flatdata.dtype()))); - if (flatdata.dtype() == py::dtype::of()) { + if (flatdata.dtype().is(py::dtype::of())) { return self.tensor(info.ptr, flatdata.size(), sizeof(float), kp::Tensor::TensorDataTypes::eFloat, tensor_type); - } else if (flatdata.dtype() == py::dtype::of()) { + } else if (flatdata.dtype().is(py::dtype::of())) { return self.tensor(info.ptr, flatdata.size(), sizeof(uint32_t), kp::Tensor::TensorDataTypes::eUnsignedInt, tensor_type); - } else if (flatdata.dtype() == py::dtype::of()) { + } else if (flatdata.dtype().is(py::dtype::of())) { return self.tensor(info.ptr, flatdata.size(), sizeof(int32_t), kp::Tensor::TensorDataTypes::eInt, tensor_type); - } else if (flatdata.dtype() == py::dtype::of()) { + } else if (flatdata.dtype().is(py::dtype::of())) { return self.tensor(info.ptr, flatdata.size(), sizeof(double), kp::Tensor::TensorDataTypes::eDouble, tensor_type); - } else if (flatdata.dtype() == py::dtype::of()) { + } else if (flatdata.dtype().is(py::dtype::of())) { return self.tensor(info.ptr, flatdata.size(), sizeof(bool), @@ -340,10 +340,10 @@ PYBIND11_MODULE(kp, m) // We have to iterate across a combination of parameters due to the // lack of support for templating - if (spec_consts.dtype() == py::dtype::of()) { + if (spec_consts.dtype().is(py::dtype::of())) { std::vector specConstsVec( (float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype() == py::dtype::of()) { + if (spec_consts.dtype().is(py::dtype::of())) { std::vector pushConstsVec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -352,8 +352,8 @@ PYBIND11_MODULE(kp, m) workgroup, specConstsVec, pushConstsVec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushConstsVec( (int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); @@ -362,8 +362,8 @@ PYBIND11_MODULE(kp, m) workgroup, specConstsVec, pushConstsVec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushConstsVec( (uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); @@ -372,8 +372,8 @@ PYBIND11_MODULE(kp, m) workgroup, specConstsVec, pushConstsVec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushConstsVec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); @@ -383,11 +383,11 @@ PYBIND11_MODULE(kp, m) specConstsVec, pushConstsVec); } - } else if (spec_consts.dtype() == py::dtype::of()) { + } else if (spec_consts.dtype().is(py::dtype::of())) { std::vector specconstsvec((int32_t*)specInfo.ptr, ((int32_t*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype() == py::dtype::of()) { + if (spec_consts.dtype().is(py::dtype::of())) { std::vector pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -396,8 +396,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec( (int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); @@ -406,8 +406,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec( (uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); @@ -416,8 +416,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); @@ -427,11 +427,11 @@ PYBIND11_MODULE(kp, m) specconstsvec, pushconstsvec); } - } else if (spec_consts.dtype() == py::dtype::of()) { + } else if (spec_consts.dtype().is(py::dtype::of())) { std::vector specconstsvec((uint32_t*)specInfo.ptr, ((uint32_t*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype() == py::dtype::of()) { + if (spec_consts.dtype().is(py::dtype::of())) { std::vector pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -440,8 +440,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec( (int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); @@ -450,8 +450,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec( (uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); @@ -460,8 +460,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); @@ -471,11 +471,11 @@ PYBIND11_MODULE(kp, m) specconstsvec, pushconstsvec); } - } else if (spec_consts.dtype() == py::dtype::of()) { + } else if (spec_consts.dtype().is(py::dtype::of())) { std::vector specconstsvec((double*)specInfo.ptr, ((double*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype() == py::dtype::of()) { + if (spec_consts.dtype().is(py::dtype::of())) { std::vector pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -484,8 +484,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); @@ -494,8 +494,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); @@ -504,8 +504,8 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype() == - py::dtype::of()) { + } else if (spec_consts.dtype().is( + py::dtype::of())) { std::vector pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); @@ -515,11 +515,9 @@ PYBIND11_MODULE(kp, m) specconstsvec, pushconstsvec); } - } else { - // If reach then no valid dtype supported - throw std::runtime_error( - "Kompute Python no valid dtype supported"); } + // If reach then no valid dtype supported + throw std::runtime_error("Kompute Python no valid dtype supported"); }, DOC(kp, Manager, algorithm), py::arg("tensors"), diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 68dd739b1..cd1787dd3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -82,12 +82,18 @@ endif() if(KOMPUTE_OPT_ANDROID_BUILD) target_link_libraries(kompute PUBLIC kompute_vk_ndk_wrapper android - PRIVATE kp_logger - fmt::fmt) + kp_logger + PRIVATE fmt::fmt) else() target_link_libraries(kompute PUBLIC Vulkan::Vulkan - PRIVATE fmt::fmt - kp_logger) + kp_logger + PRIVATE fmt::fmt) +endif() + +if(KOMPUTE_OPT_BUILD_PYTHON) + include_directories(${PYTHON_INCLUDE_DIRS}) + + target_link_libraries(kompute PRIVATE pybind11::headers ${PYTHON_LIBRARIES}) endif() if(KOMPUTE_OPT_USE_BUILD_IN_VULKAN_HEADER)