Updated python build

Signed-off-by: Alejandro Saucedo <axsauze@gmail.com>
This commit is contained in:
Alejandro Saucedo 2021-09-12 13:57:54 +01:00
parent 6113d286a9
commit 1972f2c8f8

View file

@ -178,8 +178,8 @@ PYBIND11_MODULE(kp, m) {
const std::vector<std::shared_ptr<kp::Tensor>>& tensors,
const py::bytes& spirv,
const kp::Workgroup& workgroup,
const kp::Constants& spec_consts,
const kp::Constants& push_consts) {
const std::vector<float>& spec_consts,
const std::vector<float>& push_consts) {
py::buffer_info info(py::buffer(spirv).request());
const char *data = reinterpret_cast<const char *>(info.ptr);
size_t length = static_cast<size_t>(info.size);
@ -190,8 +190,100 @@ PYBIND11_MODULE(kp, m) {
py::arg("tensors"),
py::arg("spirv"),
py::arg("workgroup") = kp::Workgroup(),
py::arg("spec_consts") = kp::Constants(),
py::arg("push_consts") = kp::Constants())
py::arg("spec_consts") = std::vector<float>(),
py::arg("push_consts") = std::vector<float>())
.def("algorithm_t", [np](kp::Manager& self,
const std::vector<std::shared_ptr<kp::Tensor>>& tensors,
const py::bytes& spirv,
const kp::Workgroup& workgroup,
const py::array& spec_consts,
const py::array& push_consts) {
py::buffer_info info(py::buffer(spirv).request());
const char *data = reinterpret_cast<const char *>(info.ptr);
size_t length = static_cast<size_t>(info.size);
std::vector<uint32_t> spirvVec((uint32_t*)data, (uint32_t*)(data + length));
const py::buffer_info pushInfo = push_consts.request();
const py::buffer_info specInfo = spec_consts.request();
KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with "
"push consts data size {} dtype {} and spec const data size {} dtype {}",
push_consts.size(), std::string(py::str(push_consts.dtype())),
spec_consts.size(), std::string(py::str(spec_consts.dtype())));
// We have to iterate across a combination of parameters due to the lack of support for templating
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
std::vector<float> specConstsVec((float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
std::vector<float> pushConstsVec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
std::vector<int32_t> pushConstsVec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
std::vector<uint32_t> pushConstsVec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
std::vector<double> pushConstsVec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr, ((int32_t*)specInfo.ptr) + specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
std::vector<double> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr, ((uint32_t*)specInfo.ptr) + specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
std::vector<double> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
std::vector<double> specconstsvec((double*)specInfo.ptr, ((double*)specInfo.ptr) + specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
std::vector<float> pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
std::vector<float> pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
std::vector<float> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
}
} else {
// If reach then no valid dtype supported
throw std::runtime_error("Kompute Python no valid dtype supported");
}
},
DOC(kp, Manager, algorithm),
py::arg("tensors"),
py::arg("spirv"),
py::arg("workgroup") = kp::Workgroup(),
py::arg("spec_consts") = std::vector<float>(),
py::arg("push_consts") = std::vector<float>())
.def("list_devices", [](kp::Manager& self){
const std::vector<vk::PhysicalDevice> devices = self.listDevices();
py::list list;