diff --git a/python/src/main.cpp b/python/src/main.cpp index 43c369555..b0ef31191 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -178,8 +178,8 @@ PYBIND11_MODULE(kp, m) { const std::vector>& tensors, const py::bytes& spirv, const kp::Workgroup& workgroup, - const kp::Constants& spec_consts, - const kp::Constants& push_consts) { + const std::vector& spec_consts, + const std::vector& push_consts) { py::buffer_info info(py::buffer(spirv).request()); const char *data = reinterpret_cast(info.ptr); size_t length = static_cast(info.size); @@ -190,8 +190,100 @@ PYBIND11_MODULE(kp, m) { py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), - py::arg("spec_consts") = kp::Constants(), - py::arg("push_consts") = kp::Constants()) + py::arg("spec_consts") = std::vector(), + py::arg("push_consts") = std::vector()) + .def("algorithm_t", [np](kp::Manager& self, + const std::vector>& tensors, + const py::bytes& spirv, + const kp::Workgroup& workgroup, + const py::array& spec_consts, + const py::array& push_consts) { + + py::buffer_info info(py::buffer(spirv).request()); + const char *data = reinterpret_cast(info.ptr); + size_t length = static_cast(info.size); + std::vector spirvVec((uint32_t*)data, (uint32_t*)(data + length)); + + const py::buffer_info pushInfo = push_consts.request(); + const py::buffer_info specInfo = spec_consts.request(); + + KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with " + "push consts data size {} dtype {} and spec const data size {} dtype {}", + push_consts.size(), std::string(py::str(push_consts.dtype())), + spec_consts.size(), std::string(py::str(spec_consts.dtype()))); + + // We have to iterate across a combination of parameters due to the lack of support for templating + if (spec_consts.dtype() == py::dtype::of()) { + std::vector specConstsVec((float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size); + if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushConstsVec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushConstsVec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushConstsVec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushConstsVec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec); + } + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector specconstsvec((int32_t*)specInfo.ptr, ((int32_t*)specInfo.ptr) + specInfo.size); + if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector specconstsvec((uint32_t*)specInfo.ptr, ((uint32_t*)specInfo.ptr) + specInfo.size); + if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector specconstsvec((double*)specInfo.ptr, ((double*)specInfo.ptr) + specInfo.size); + if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } else if (spec_consts.dtype() == py::dtype::of()) { + std::vector pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size); + return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); + } + } else { + // If reach then no valid dtype supported + throw std::runtime_error("Kompute Python no valid dtype supported"); + } + }, + DOC(kp, Manager, algorithm), + py::arg("tensors"), + py::arg("spirv"), + py::arg("workgroup") = kp::Workgroup(), + py::arg("spec_consts") = std::vector(), + py::arg("push_consts") = std::vector()) .def("list_devices", [](kp::Manager& self){ const std::vector devices = self.listDevices(); py::list list;