All python tests pass
This commit is contained in:
parent
4c4d073b90
commit
91d3b9a223
11 changed files with 158 additions and 169 deletions
|
|
@ -54,17 +54,20 @@ PYBIND11_MODULE(kp, m) {
|
|||
|
||||
py::class_<kp::OpBase, std::shared_ptr<kp::OpBase>>(m, "OpBase");
|
||||
|
||||
py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(m, "OpTensorSyncDevice")
|
||||
py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(m, "OpTensorSyncDevice", py::base<kp::OpBase>())
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>());
|
||||
|
||||
py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(m, "OpTensorSyncLocal")
|
||||
py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(m, "OpTensorSyncLocal", py::base<kp::OpBase>())
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>());
|
||||
|
||||
py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(m, "OpTensorCopy")
|
||||
py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(m, "OpTensorCopy", py::base<kp::OpBase>())
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>());
|
||||
|
||||
py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(m, "OpAlgoDispatch")
|
||||
.def(py::init<const std::shared_ptr<kp::Algorithm>&, bool>());
|
||||
py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(m, "OpAlgoDispatch", py::base<kp::OpBase>())
|
||||
.def(py::init<const std::shared_ptr<kp::Algorithm>&>());
|
||||
|
||||
py::class_<kp::OpMult, std::shared_ptr<kp::OpMult>>(m, "OpMult", py::base<kp::OpBase>())
|
||||
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&,const std::shared_ptr<kp::Algorithm>&>());
|
||||
|
||||
py::class_<kp::Algorithm, std::shared_ptr<kp::Algorithm>>(m, "Algorithm")
|
||||
.def("get_tensors", &kp::Algorithm::getTensors)
|
||||
|
|
@ -112,8 +115,7 @@ PYBIND11_MODULE(kp, m) {
|
|||
.def("__len__", &kp::Tensor::size, "Retrieves the size of the Tensor data as per the local Tensor memory.")
|
||||
.def("tensor_type", &kp::Tensor::tensorType, "Retreves the memory type of the tensor.")
|
||||
.def("is_init", &kp::Tensor::isInit, "Checks whether the tensor GPU memory has been initialised.")
|
||||
.def("map_data_from_host", &kp::Tensor::mapDataFromHostMemory, "Maps data into GPU memory from tensor local data.")
|
||||
.def("map_data_into_host", &kp::Tensor::mapDataIntoHostMemory, "Maps data from GPU memory into tensor local data.");
|
||||
.def("destroy", &kp::Tensor::destroy, "Destroy tensor GPU resources.");
|
||||
|
||||
py::class_<kp::Sequence, std::shared_ptr<kp::Sequence>>(m, "Sequence")
|
||||
.def("record", [](kp::Sequence& self, std::shared_ptr<kp::OpBase> op) { return self.record(op); })
|
||||
|
|
@ -147,15 +149,17 @@ PYBIND11_MODULE(kp, m) {
|
|||
.def("algorithm", [](kp::Manager& self,
|
||||
const std::vector<std::shared_ptr<kp::Tensor>>& tensors,
|
||||
const py::bytes& spirv,
|
||||
const kp::Workgroup& workgroup = {},
|
||||
const kp::Constants& spec_consts = {},
|
||||
const kp::Constants& push_consts = {}) {
|
||||
const kp::Workgroup& workgroup,
|
||||
const kp::Constants& spec_consts,
|
||||
const kp::Constants& push_consts) {
|
||||
py::buffer_info info(py::buffer(spirv).request());
|
||||
const char *data = reinterpret_cast<const char *>(info.ptr);
|
||||
size_t length = static_cast<size_t>(info.size);
|
||||
std::vector<uint32_t> spirvVec((uint32_t*)data, (uint32_t*)(data + length));
|
||||
return self.algorithm(tensors, spirvVec, workgroup, spec_consts, push_consts);
|
||||
});
|
||||
},
|
||||
"Algorithm initialisation function",
|
||||
py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), py::arg("spec_consts") = kp::Constants(), py::arg("push_consts") = kp::Constants());
|
||||
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue