Updated python build
Signed-off-by: Alejandro Saucedo <axsauze@gmail.com>
This commit is contained in:
parent
6113d286a9
commit
1972f2c8f8
1 changed files with 96 additions and 4 deletions
|
|
@ -178,8 +178,8 @@ PYBIND11_MODULE(kp, m) {
|
|||
const std::vector<std::shared_ptr<kp::Tensor>>& tensors,
|
||||
const py::bytes& spirv,
|
||||
const kp::Workgroup& workgroup,
|
||||
const kp::Constants& spec_consts,
|
||||
const kp::Constants& push_consts) {
|
||||
const std::vector<float>& spec_consts,
|
||||
const std::vector<float>& push_consts) {
|
||||
py::buffer_info info(py::buffer(spirv).request());
|
||||
const char *data = reinterpret_cast<const char *>(info.ptr);
|
||||
size_t length = static_cast<size_t>(info.size);
|
||||
|
|
@ -190,8 +190,100 @@ PYBIND11_MODULE(kp, m) {
|
|||
py::arg("tensors"),
|
||||
py::arg("spirv"),
|
||||
py::arg("workgroup") = kp::Workgroup(),
|
||||
py::arg("spec_consts") = kp::Constants(),
|
||||
py::arg("push_consts") = kp::Constants())
|
||||
py::arg("spec_consts") = std::vector<float>(),
|
||||
py::arg("push_consts") = std::vector<float>())
|
||||
.def("algorithm_t", [np](kp::Manager& self,
|
||||
const std::vector<std::shared_ptr<kp::Tensor>>& tensors,
|
||||
const py::bytes& spirv,
|
||||
const kp::Workgroup& workgroup,
|
||||
const py::array& spec_consts,
|
||||
const py::array& push_consts) {
|
||||
|
||||
py::buffer_info info(py::buffer(spirv).request());
|
||||
const char *data = reinterpret_cast<const char *>(info.ptr);
|
||||
size_t length = static_cast<size_t>(info.size);
|
||||
std::vector<uint32_t> spirvVec((uint32_t*)data, (uint32_t*)(data + length));
|
||||
|
||||
const py::buffer_info pushInfo = push_consts.request();
|
||||
const py::buffer_info specInfo = spec_consts.request();
|
||||
|
||||
KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with "
|
||||
"push consts data size {} dtype {} and spec const data size {} dtype {}",
|
||||
push_consts.size(), std::string(py::str(push_consts.dtype())),
|
||||
spec_consts.size(), std::string(py::str(spec_consts.dtype())));
|
||||
|
||||
// We have to iterate across a combination of parameters due to the lack of support for templating
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
std::vector<float> specConstsVec((float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
std::vector<float> pushConstsVec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
std::vector<int32_t> pushConstsVec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
std::vector<uint32_t> pushConstsVec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
std::vector<double> pushConstsVec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
|
||||
}
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr, ((int32_t*)specInfo.ptr) + specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
std::vector<double> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
}
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr, ((uint32_t*)specInfo.ptr) + specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
std::vector<double> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
}
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
std::vector<double> specconstsvec((double*)specInfo.ptr, ((double*)specInfo.ptr) + specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
std::vector<float> pushconstsvec((int32_t*)pushInfo.ptr, ((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
std::vector<float> pushconstsvec((uint32_t*)pushInfo.ptr, ((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
std::vector<float> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + pushInfo.size);
|
||||
return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec);
|
||||
}
|
||||
} else {
|
||||
// If reach then no valid dtype supported
|
||||
throw std::runtime_error("Kompute Python no valid dtype supported");
|
||||
}
|
||||
},
|
||||
DOC(kp, Manager, algorithm),
|
||||
py::arg("tensors"),
|
||||
py::arg("spirv"),
|
||||
py::arg("workgroup") = kp::Workgroup(),
|
||||
py::arg("spec_consts") = std::vector<float>(),
|
||||
py::arg("push_consts") = std::vector<float>())
|
||||
.def("list_devices", [](kp::Manager& self){
|
||||
const std::vector<vk::PhysicalDevice> devices = self.listDevices();
|
||||
py::list list;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue