All python tests pass

This commit is contained in:
Alejandro Saucedo 2021-02-28 07:57:36 +00:00
parent 4c4d073b90
commit 91d3b9a223
11 changed files with 158 additions and 169 deletions

View file

@ -54,17 +54,20 @@ PYBIND11_MODULE(kp, m) {
py::class_<kp::OpBase, std::shared_ptr<kp::OpBase>>(m, "OpBase");
py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(m, "OpTensorSyncDevice")
py::class_<kp::OpTensorSyncDevice, std::shared_ptr<kp::OpTensorSyncDevice>>(m, "OpTensorSyncDevice", py::base<kp::OpBase>())
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>());
py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(m, "OpTensorSyncLocal")
py::class_<kp::OpTensorSyncLocal, std::shared_ptr<kp::OpTensorSyncLocal>>(m, "OpTensorSyncLocal", py::base<kp::OpBase>())
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>());
py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(m, "OpTensorCopy")
py::class_<kp::OpTensorCopy, std::shared_ptr<kp::OpTensorCopy>>(m, "OpTensorCopy", py::base<kp::OpBase>())
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&>());
py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(m, "OpAlgoDispatch")
.def(py::init<const std::shared_ptr<kp::Algorithm>&, bool>());
py::class_<kp::OpAlgoDispatch, std::shared_ptr<kp::OpAlgoDispatch>>(m, "OpAlgoDispatch", py::base<kp::OpBase>())
.def(py::init<const std::shared_ptr<kp::Algorithm>&>());
py::class_<kp::OpMult, std::shared_ptr<kp::OpMult>>(m, "OpMult", py::base<kp::OpBase>())
.def(py::init<const std::vector<std::shared_ptr<kp::Tensor>>&,const std::shared_ptr<kp::Algorithm>&>());
py::class_<kp::Algorithm, std::shared_ptr<kp::Algorithm>>(m, "Algorithm")
.def("get_tensors", &kp::Algorithm::getTensors)
@ -112,8 +115,7 @@ PYBIND11_MODULE(kp, m) {
.def("__len__", &kp::Tensor::size, "Retrieves the size of the Tensor data as per the local Tensor memory.")
.def("tensor_type", &kp::Tensor::tensorType, "Retreves the memory type of the tensor.")
.def("is_init", &kp::Tensor::isInit, "Checks whether the tensor GPU memory has been initialised.")
.def("map_data_from_host", &kp::Tensor::mapDataFromHostMemory, "Maps data into GPU memory from tensor local data.")
.def("map_data_into_host", &kp::Tensor::mapDataIntoHostMemory, "Maps data from GPU memory into tensor local data.");
.def("destroy", &kp::Tensor::destroy, "Destroy tensor GPU resources.");
py::class_<kp::Sequence, std::shared_ptr<kp::Sequence>>(m, "Sequence")
.def("record", [](kp::Sequence& self, std::shared_ptr<kp::OpBase> op) { return self.record(op); })
@ -147,15 +149,17 @@ PYBIND11_MODULE(kp, m) {
.def("algorithm", [](kp::Manager& self,
const std::vector<std::shared_ptr<kp::Tensor>>& tensors,
const py::bytes& spirv,
const kp::Workgroup& workgroup = {},
const kp::Constants& spec_consts = {},
const kp::Constants& push_consts = {}) {
const kp::Workgroup& workgroup,
const kp::Constants& spec_consts,
const kp::Constants& push_consts) {
py::buffer_info info(py::buffer(spirv).request());
const char *data = reinterpret_cast<const char *>(info.ptr);
size_t length = static_cast<size_t>(info.size);
std::vector<uint32_t> spirvVec((uint32_t*)data, (uint32_t*)(data + length));
return self.algorithm(tensors, spirvVec, workgroup, spec_consts, push_consts);
});
},
"Algorithm initialisation function",
py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), py::arg("spec_consts") = kp::Constants(), py::arg("push_consts") = kp::Constants());
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;