Updated pybind11 bindings
Signed-off-by: Fabian Sauter <sauter.fabian@mailbox.org>
This commit is contained in:
parent
baa84827d6
commit
2ec35acba5
2 changed files with 69 additions and 65 deletions
|
|
@ -22,22 +22,22 @@ opAlgoDispatchPyInit(std::shared_ptr<kp::Algorithm>& algorithm,
|
|||
push_consts.size(),
|
||||
std::string(py::str(push_consts.dtype())));
|
||||
|
||||
if (push_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
std::vector<float> dataVec((float*)info.ptr,
|
||||
((float*)info.ptr) + info.size);
|
||||
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
|
||||
algorithm, dataVec) };
|
||||
} else if (push_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
} else if (push_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
|
||||
std::vector<uint32_t> dataVec((uint32_t*)info.ptr,
|
||||
((uint32_t*)info.ptr) + info.size);
|
||||
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
|
||||
algorithm, dataVec) };
|
||||
} else if (push_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
} else if (push_consts.dtype().is(py::dtype::of<std::int32_t>())) {
|
||||
std::vector<int32_t> dataVec((int32_t*)info.ptr,
|
||||
((int32_t*)info.ptr) + info.size);
|
||||
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
|
||||
algorithm, dataVec) };
|
||||
} else if (push_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
} else if (push_consts.dtype().is(py::dtype::of<std::double_t>())) {
|
||||
std::vector<double> dataVec((double*)info.ptr,
|
||||
((double*)info.ptr) + info.size);
|
||||
return std::unique_ptr<kp::OpAlgoDispatch>{ new kp::OpAlgoDispatch(
|
||||
|
|
@ -76,29 +76,29 @@ PYBIND11_MODULE(kp, m)
|
|||
py::class_<kp::OpBase, std::shared_ptr<kp::OpBase>>(
|
||||
m, "OpBase", DOC(kp, OpBase));
|
||||
|
||||
py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(
|
||||
m,
|
||||
"OpTensorSyncDevice",
|
||||
py::base<kp::OpBase>(),
|
||||
DOC(kp, OpTensorSyncDevice))
|
||||
py::class_<kp::OpTensorSyncDevice,
|
||||
kp::OpBase,
|
||||
std::shared_ptr<kp::OpTensorSyncDevice>>(
|
||||
m, "OpTensorSyncDevice", DOC(kp, OpTensorSyncDevice))
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
|
||||
DOC(kp, OpTensorSyncDevice, OpTensorSyncDevice));
|
||||
|
||||
py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(
|
||||
m,
|
||||
"OpTensorSyncLocal",
|
||||
py::base<kp::OpBase>(),
|
||||
DOC(kp, OpTensorSyncLocal))
|
||||
py::class_<kp::OpTensorSyncLocal,
|
||||
kp::OpBase,
|
||||
std::shared_ptr<kp::OpTensorSyncLocal>>(
|
||||
m, "OpTensorSyncLocal", DOC(kp, OpTensorSyncLocal))
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
|
||||
DOC(kp, OpTensorSyncLocal, OpTensorSyncLocal));
|
||||
|
||||
py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(
|
||||
m, "OpTensorCopy", py::base<kp::OpBase>(), DOC(kp, OpTensorCopy))
|
||||
py::class_<kp::OpTensorCopy, kp::OpBase, std::shared_ptr<kp::OpTensorCopy>>(
|
||||
m, "OpTensorCopy", DOC(kp, OpTensorCopy))
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>(),
|
||||
DOC(kp, OpTensorCopy, OpTensorCopy));
|
||||
|
||||
py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(
|
||||
m, "OpAlgoDispatch", py::base<kp::OpBase>(), DOC(kp, OpAlgoDispatch))
|
||||
py::class_<kp::OpAlgoDispatch,
|
||||
kp::OpBase,
|
||||
std::shared_ptr<kp::OpAlgoDispatch>>(
|
||||
m, "OpAlgoDispatch", DOC(kp, OpAlgoDispatch))
|
||||
.def(py::init<const std::shared_ptr<kp::Algorithm>&,
|
||||
const std::vector<float>&>(),
|
||||
DOC(kp, OpAlgoDispatch, OpAlgoDispatch),
|
||||
|
|
@ -109,8 +109,8 @@ PYBIND11_MODULE(kp, m)
|
|||
py::arg("algorithm"),
|
||||
py::arg("push_consts"));
|
||||
|
||||
py::class_<kp::OpMult, std::shared_ptr<kp::OpMult>>(
|
||||
m, "OpMult", py::base<kp::OpBase>(), DOC(kp, OpMult))
|
||||
py::class_<kp::OpMult, kp::OpBase, std::shared_ptr<kp::OpMult>>(
|
||||
m, "OpMult", DOC(kp, OpMult))
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&,
|
||||
const std::shared_ptr<kp::Algorithm>&>(),
|
||||
DOC(kp, OpMult, OpMult));
|
||||
|
|
@ -253,31 +253,31 @@ PYBIND11_MODULE(kp, m)
|
|||
"size {} dtype {}",
|
||||
flatdata.size(),
|
||||
std::string(py::str(flatdata.dtype())));
|
||||
if (flatdata.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (flatdata.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
return self.tensor(info.ptr,
|
||||
flatdata.size(),
|
||||
sizeof(float),
|
||||
kp::Tensor::TensorDataTypes::eFloat,
|
||||
tensor_type);
|
||||
} else if (flatdata.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
} else if (flatdata.dtype().is(py::dtype::of<std::uint32_t>())) {
|
||||
return self.tensor(info.ptr,
|
||||
flatdata.size(),
|
||||
sizeof(uint32_t),
|
||||
kp::Tensor::TensorDataTypes::eUnsignedInt,
|
||||
tensor_type);
|
||||
} else if (flatdata.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
} else if (flatdata.dtype().is(py::dtype::of<std::int32_t>())) {
|
||||
return self.tensor(info.ptr,
|
||||
flatdata.size(),
|
||||
sizeof(int32_t),
|
||||
kp::Tensor::TensorDataTypes::eInt,
|
||||
tensor_type);
|
||||
} else if (flatdata.dtype() == py::dtype::of<std::double_t>()) {
|
||||
} else if (flatdata.dtype().is(py::dtype::of<std::double_t>())) {
|
||||
return self.tensor(info.ptr,
|
||||
flatdata.size(),
|
||||
sizeof(double),
|
||||
kp::Tensor::TensorDataTypes::eDouble,
|
||||
tensor_type);
|
||||
} else if (flatdata.dtype() == py::dtype::of<bool>()) {
|
||||
} else if (flatdata.dtype().is(py::dtype::of<bool>())) {
|
||||
return self.tensor(info.ptr,
|
||||
flatdata.size(),
|
||||
sizeof(bool),
|
||||
|
|
@ -340,10 +340,10 @@ PYBIND11_MODULE(kp, m)
|
|||
|
||||
// We have to iterate across a combination of parameters due to the
|
||||
// lack of support for templating
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
std::vector<float> specConstsVec(
|
||||
(float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
std::vector<float> pushConstsVec((float*)pushInfo.ptr,
|
||||
((float*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -352,8 +352,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specConstsVec,
|
||||
pushConstsVec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::int32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::int32_t>())) {
|
||||
std::vector<int32_t> pushConstsVec(
|
||||
(int32_t*)pushInfo.ptr,
|
||||
((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
|
|
@ -362,8 +362,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specConstsVec,
|
||||
pushConstsVec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::uint32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::uint32_t>())) {
|
||||
std::vector<uint32_t> pushConstsVec(
|
||||
(uint32_t*)pushInfo.ptr,
|
||||
((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
|
|
@ -372,8 +372,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specConstsVec,
|
||||
pushConstsVec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::double_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::double_t>())) {
|
||||
std::vector<double> pushConstsVec((double*)pushInfo.ptr,
|
||||
((double*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -383,11 +383,11 @@ PYBIND11_MODULE(kp, m)
|
|||
specConstsVec,
|
||||
pushConstsVec);
|
||||
}
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::int32_t>()) {
|
||||
} else if (spec_consts.dtype().is(py::dtype::of<std::int32_t>())) {
|
||||
std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr,
|
||||
((int32_t*)specInfo.ptr) +
|
||||
specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
|
||||
((float*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -396,8 +396,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::int32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::int32_t>())) {
|
||||
std::vector<int32_t> pushconstsvec(
|
||||
(int32_t*)pushInfo.ptr,
|
||||
((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
|
|
@ -406,8 +406,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::uint32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::uint32_t>())) {
|
||||
std::vector<uint32_t> pushconstsvec(
|
||||
(uint32_t*)pushInfo.ptr,
|
||||
((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
|
|
@ -416,8 +416,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::double_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::double_t>())) {
|
||||
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
|
||||
((double*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -427,11 +427,11 @@ PYBIND11_MODULE(kp, m)
|
|||
specconstsvec,
|
||||
pushconstsvec);
|
||||
}
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::uint32_t>()) {
|
||||
} else if (spec_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
|
||||
std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr,
|
||||
((uint32_t*)specInfo.ptr) +
|
||||
specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
|
||||
((float*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -440,8 +440,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::int32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::int32_t>())) {
|
||||
std::vector<int32_t> pushconstsvec(
|
||||
(int32_t*)pushInfo.ptr,
|
||||
((int32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
|
|
@ -450,8 +450,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::uint32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::uint32_t>())) {
|
||||
std::vector<uint32_t> pushconstsvec(
|
||||
(uint32_t*)pushInfo.ptr,
|
||||
((uint32_t*)pushInfo.ptr) + pushInfo.size);
|
||||
|
|
@ -460,8 +460,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::double_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::double_t>())) {
|
||||
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
|
||||
((double*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -471,11 +471,11 @@ PYBIND11_MODULE(kp, m)
|
|||
specconstsvec,
|
||||
pushconstsvec);
|
||||
}
|
||||
} else if (spec_consts.dtype() == py::dtype::of<std::double_t>()) {
|
||||
} else if (spec_consts.dtype().is(py::dtype::of<std::double_t>())) {
|
||||
std::vector<double> specconstsvec((double*)specInfo.ptr,
|
||||
((double*)specInfo.ptr) +
|
||||
specInfo.size);
|
||||
if (spec_consts.dtype() == py::dtype::of<std::float_t>()) {
|
||||
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
|
||||
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
|
||||
((float*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -484,8 +484,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::int32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::int32_t>())) {
|
||||
std::vector<float> pushconstsvec((int32_t*)pushInfo.ptr,
|
||||
((int32_t*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -494,8 +494,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::uint32_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::uint32_t>())) {
|
||||
std::vector<float> pushconstsvec((uint32_t*)pushInfo.ptr,
|
||||
((uint32_t*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -504,8 +504,8 @@ PYBIND11_MODULE(kp, m)
|
|||
workgroup,
|
||||
specconstsvec,
|
||||
pushconstsvec);
|
||||
} else if (spec_consts.dtype() ==
|
||||
py::dtype::of<std::double_t>()) {
|
||||
} else if (spec_consts.dtype().is(
|
||||
py::dtype::of<std::double_t>())) {
|
||||
std::vector<float> pushconstsvec((double*)pushInfo.ptr,
|
||||
((double*)pushInfo.ptr) +
|
||||
pushInfo.size);
|
||||
|
|
@ -515,11 +515,9 @@ PYBIND11_MODULE(kp, m)
|
|||
specconstsvec,
|
||||
pushconstsvec);
|
||||
}
|
||||
} else {
|
||||
// If reach then no valid dtype supported
|
||||
throw std::runtime_error(
|
||||
"Kompute Python no valid dtype supported");
|
||||
}
|
||||
// If reach then no valid dtype supported
|
||||
throw std::runtime_error("Kompute Python no valid dtype supported");
|
||||
},
|
||||
DOC(kp, Manager, algorithm),
|
||||
py::arg("tensors"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue