Updated pybind11 bindings

Signed-off-by: Fabian Sauter <sauter.fabian@mailbox.org>
This commit is contained in:
Fabian Sauter 2022-06-13 14:57:04 +02:00
parent baa84827d6
commit 2ec35acba5
2 changed files with 69 additions and 65 deletions

View file

@ -22,22 +22,22 @@ opAlgoDispatchPyInit(std::shared_ptr<kp::Algorithm>& algorithm,
push_consts.size(),
std::string(py::str(push_consts.dtype())));
if (push_consts.dtype() == py::dtype::of<std::float_t>()) {
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> dataVec((float*)info.ptr,
((float*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
algorithm, dataVec) };
} else if (push_consts.dtype() == py::dtype::of<std::uint32_t>()) {
} else if (push_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> dataVec((uint32_t*)info.ptr,
((uint32_t*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
algorithm, dataVec) };
} else if (push_consts.dtype() == py::dtype::of<std::int32_t>()) {
} else if (push_consts.dtype().is(py::dtype::of<std::int32_t>())) {
std::vector<int32_t> dataVec((int32_t*)info.ptr,
((int32_t*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
algorithm, dataVec) };
} else if (push_consts.dtype() == py::dtype::of<std::double_t>()) {
} else if (push_consts.dtype().is(py::dtype::of<std::double_t>())) {
std::vector<double> dataVec((double*)info.ptr,
((double*)info.ptr) + info.size);
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
@ -76,29 +76,29 @@ PYBIND11_MODULE(kp, m)
py::class_<kp::OpBase, std::shared_ptr<kp::OpBase>>(
m, "OpBase", DOC(kp, OpBase));
py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(
m,
"OpTensorSyncDevice",
py::base<kp::OpBase>(),
DOC(kp, OpTensorSyncDevice))
py::class_<kp::OpTensorSyncDevice,
kp::OpBase,
std::shared_ptr<kp::OpTensorSyncDevice>>(
m, "OpTensorSyncDevice", DOC(kp, OpTensorSyncDevice))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
DOC(kp, OpTensorSyncDevice, OpTensorSyncDevice));
py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(
m,
"OpTensorSyncLocal",
py::base<kp::OpBase>(),
DOC(kp, OpTensorSyncLocal))
py::class_<kp::OpTensorSyncLocal,
kp::OpBase,
std::shared_ptr<kp::OpTensorSyncLocal>>(
m, "OpTensorSyncLocal", DOC(kp, OpTensorSyncLocal))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
DOC(kp, OpTensorSyncLocal, OpTensorSyncLocal));
py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(
m, "OpTensorCopy", py::base<kp::OpBase>(), DOC(kp, OpTensorCopy))
py::class_<kp::OpTensorCopy, kp::OpBase, std::shared_ptr<kp::OpTensorCopy>>(
m, "OpTensorCopy", DOC(kp, OpTensorCopy))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
DOC(kp, OpTensorCopy, OpTensorCopy));
py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(
m, "OpAlgoDispatch", py::base<kp::OpBase>(), DOC(kp, OpAlgoDispatch))
py::class_<kp::OpAlgoDispatch,
kp::OpBase,
std::shared_ptr<kp::OpAlgoDispatch>>(
m, "OpAlgoDispatch", DOC(kp, OpAlgoDispatch))
.def(py::init<const std::shared_ptr<kp::Algorithm>&,
const std::vector<float>&>(),
DOC(kp, OpAlgoDispatch, OpAlgoDispatch),
@ -109,8 +109,8 @@ PYBIND11_MODULE(kp, m)
py::arg("algorithm"),
py::arg("push_consts"));
py::class_<kp::OpMult, std::shared_ptr<kp::OpMult>>(
m, "OpMult", py::base<kp::OpBase>(), DOC(kp, OpMult))
py::class_<kp::OpMult, kp::OpBase, std::shared_ptr<kp::OpMult>>(
m, "OpMult", DOC(kp, OpMult))
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&,
const std::shared_ptr<kp::Algorithm>&>(),
DOC(kp, OpMult, OpMult));
@ -253,31 +253,31 @@ PYBIND11_MODULE(kp, m)
"size {} dtype {}",
flatdata.size(),
std::string(py::str(flatdata.dtype())));
if (flatdata.dtype() == py::dtype::of<std::float_t>()) {
if (flatdata.dtype().is(py::dtype::of<std::float_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(float),
kp::Tensor::TensorDataTypes::eFloat,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<std::uint32_t>()) {
} else if (flatdata.dtype().is(py::dtype::of<std::uint32_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(uint32_t),
kp::Tensor::TensorDataTypes::eUnsignedInt,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<std::int32_t>()) {
} else if (flatdata.dtype().is(py::dtype::of<std::int32_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(int32_t),
kp::Tensor::TensorDataTypes::eInt,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<std::double_t>()) {
} else if (flatdata.dtype().is(py::dtype::of<std::double_t>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(double),
kp::Tensor::TensorDataTypes::eDouble,
tensor_type);
} else if (flatdata.dtype() == py::dtype::of<bool>()) {
} else if (flatdata.dtype().is(py::dtype::of<bool>())) {
return self.tensor(info.ptr,
flatdata.size(),
sizeof(bool),
@ -340,10 +340,10 @@ PYBIND11_MODULE(kp, m)
// We have to iterate across a combination of parameters due to the
// lack of support for templating
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> specConstsVec(
(float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushConstsVec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
@ -352,8 +352,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specConstsVec,
pushConstsVec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<int32_t> pushConstsVec(
(int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) + pushInfo.size);
@ -362,8 +362,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specConstsVec,
pushConstsVec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> pushConstsVec(
(uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) + pushInfo.size);
@ -372,8 +372,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specConstsVec,
pushConstsVec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<double> pushConstsVec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
@ -383,11 +383,11 @@ PYBIND11_MODULE(kp, m)
specConstsVec,
pushConstsVec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(py::dtype::of<std::int32_t>())) {
std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr,
((int32_t*)specInfo.ptr) +
specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
@ -396,8 +396,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<int32_t> pushconstsvec(
(int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) + pushInfo.size);
@ -406,8 +406,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> pushconstsvec(
(uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) + pushInfo.size);
@ -416,8 +416,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
@ -427,11 +427,11 @@ PYBIND11_MODULE(kp, m)
specconstsvec,
pushconstsvec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr,
((uint32_t*)specInfo.ptr) +
specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
@ -440,8 +440,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<int32_t> pushconstsvec(
(int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) + pushInfo.size);
@ -450,8 +450,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<uint32_t> pushconstsvec(
(uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) + pushInfo.size);
@ -460,8 +460,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
@ -471,11 +471,11 @@ PYBIND11_MODULE(kp, m)
specconstsvec,
pushconstsvec);
}
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(py::dtype::of<std::double_t>())) {
std::vector<double> specconstsvec((double*)specInfo.ptr,
((double*)specInfo.ptr) +
specInfo.size);
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
((float*)pushInfo.ptr) +
pushInfo.size);
@ -484,8 +484,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::int32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::int32_t>())) {
std::vector<float> pushconstsvec((int32_t*)pushInfo.ptr,
((int32_t*)pushInfo.ptr) +
pushInfo.size);
@ -494,8 +494,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::uint32_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::uint32_t>())) {
std::vector<float> pushconstsvec((uint32_t*)pushInfo.ptr,
((uint32_t*)pushInfo.ptr) +
pushInfo.size);
@ -504,8 +504,8 @@ PYBIND11_MODULE(kp, m)
workgroup,
specconstsvec,
pushconstsvec);
} else if (spec_consts.dtype() ==
py::dtype::of<std::double_t>()) {
} else if (spec_consts.dtype().is(
py::dtype::of<std::double_t>())) {
std::vector<float> pushconstsvec((double*)pushInfo.ptr,
((double*)pushInfo.ptr) +
pushInfo.size);
@ -515,11 +515,9 @@ PYBIND11_MODULE(kp, m)
specconstsvec,
pushconstsvec);
}
} else {
// If reach then no valid dtype supported
throw std::runtime_error(
"Kompute Python no valid dtype supported");
}
// If reach then no valid dtype supported
throw std::runtime_error("Kompute Python no valid dtype supported");
},
DOC(kp, Manager, algorithm),
py::arg("tensors"),