diff --git a/python/src/main.cpp b/python/src/main.cpp index 36be7ac7a..8aac68c98 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -4,6 +4,8 @@ #include +#include "fmt/ranges.h" + #include "docstrings.hpp" namespace py = pybind11; @@ -64,7 +66,8 @@ PYBIND11_MODULE(kp, m) { .def(py::init>&>()); py::class_>(m, "OpAlgoDispatch", py::base()) - .def(py::init&>()); + .def(py::init&,const kp::Constants&>(), + py::arg("algorithm"), py::arg("push_consts") = kp::Constants()); py::class_>(m, "OpMult", py::base()) .def(py::init>&,const std::shared_ptr&>()); @@ -73,12 +76,10 @@ PYBIND11_MODULE(kp, m) { .def("get_tensors", &kp::Algorithm::getTensors) .def("destroy", &kp::Algorithm::destroy) .def("get_spec_consts", &kp::Algorithm::getSpecializationConstants) - .def("get_push_consts", &kp::Algorithm::getPushConstants) .def("is_init", &kp::Algorithm::isInit); py::class_>(m, "Tensor", DOC(kp, Tensor)) - .def("data", &kp::Tensor::data, DOC(kp, Tensor, data)) - .def("numpy", [](kp::Tensor& self) { + .def("data", [](kp::Tensor& self) { return py::array(self.data().size(), self.data().data()); }, "Returns stored data as a new numpy array.") .def("__getitem__", [](kp::Tensor &self, size_t index) -> float { return self.data()[index]; }, @@ -150,16 +151,15 @@ PYBIND11_MODULE(kp, m) { const std::vector>& tensors, const py::bytes& spirv, const kp::Workgroup& workgroup, - const kp::Constants& spec_consts, - const kp::Constants& push_consts) { + const kp::Constants& spec_consts) { py::buffer_info info(py::buffer(spirv).request()); const char *data = reinterpret_cast(info.ptr); size_t length = static_cast(info.size); std::vector spirvVec((uint32_t*)data, (uint32_t*)(data + length)); - return self.algorithm(tensors, spirvVec, workgroup, spec_consts, push_consts); + return self.algorithm(tensors, spirvVec, workgroup, spec_consts); }, "Algorithm initialisation function", - py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), py::arg("spec_consts") = kp::Constants(), py::arg("push_consts") = kp::Constants()); + py::arg("tensors"), py::arg("spirv"), py::arg("workgroup") = kp::Workgroup(), py::arg("spec_consts") = kp::Constants()); #ifdef VERSION_INFO m.attr("__version__") = VERSION_INFO; diff --git a/python/test/test_array_multiplication.py b/python/test/test_array_multiplication.py index 55d764805..0dab581c6 100644 --- a/python/test/test_array_multiplication.py +++ b/python/test/test_array_multiplication.py @@ -30,5 +30,5 @@ def test_array_multiplication(): .record(kp.OpTensorSyncLocal([tensor_out])) .eval()) - assert tensor_out.data() == [2.0, 4.0, 6.0] - assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0]) + assert tensor_out.data().tolist() == [2.0, 4.0, 6.0] + assert np.all(tensor_out.data() == [2.0, 4.0, 6.0]) diff --git a/python/test/test_kompute.py b/python/test/test_kompute.py index ad4b77391..4514e2dd2 100644 --- a/python/test/test_kompute.py +++ b/python/test/test_kompute.py @@ -69,7 +69,7 @@ void main() .record(kp.OpTensorSyncLocal(params)) .eval()) - assert tensor_out.data() == [2.0, 4.0, 6.0] + assert tensor_out.data().tolist() == [2.0, 4.0, 6.0] def test_sequence(): """ @@ -116,8 +116,8 @@ def test_sequence(): assert sq.is_init() == False - assert tensor_out.data() == [2.0, 4.0, 6.0] - assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0]) + assert tensor_out.data().tolist() == [2.0, 4.0, 6.0] + assert np.all(tensor_out.data() == [2.0, 4.0, 6.0]) tensor_in_a.destroy() tensor_in_b.destroy() @@ -127,6 +127,39 @@ def test_sequence(): assert tensor_in_b.is_init() == False assert tensor_out.is_init() == False +def test_pushconsts(): + + spirv = kp.Shader.compile_source(""" + #version 450 + layout(push_constant) uniform PushConstants { + float x; + float y; + float z; + } pcs; + layout (local_size_x = 1) in; + layout(set = 0, binding = 0) buffer a { float pa[]; }; + void main() { + pa[0] += pcs.x; + pa[1] += pcs.y; + pa[2] += pcs.z; + } + """) + + mgr = kp.Manager() + + tensor = mgr.tensor([0, 0, 0]) + + algo = mgr.algorithm([tensor], spirv, (1, 1, 1)) + + (mgr.sequence() + .record(kp.OpTensorSyncDevice([tensor])) + .record(kp.OpAlgoDispatch(algo, [0.1, 0.2, 0.3])) + .record(kp.OpAlgoDispatch(algo, [0.3, 0.2, 0.1])) + .record(kp.OpTensorSyncLocal([tensor])) + .eval()) + + assert np.all(tensor.data() == np.array([0.4, 0.4, 0.4], dtype=np.float32)) + def test_workgroup(): mgr = kp.Manager(0) @@ -151,9 +184,9 @@ def test_workgroup(): .record(kp.OpTensorSyncLocal([tensor_a, tensor_b])) .eval()) - print(tensor_a.numpy()) - print(tensor_b.numpy()) + print(tensor_a.data()) + print(tensor_b.data()) - assert np.all(tensor_a.numpy() == np.stack([np.arange(16)]*8, axis=1).ravel()) - assert np.all(tensor_b.numpy() == np.stack([np.arange(8)]*16, axis=0).ravel()) + assert np.all(tensor_a.data() == np.stack([np.arange(16)]*8, axis=1).ravel()) + assert np.all(tensor_b.data() == np.stack([np.arange(8)]*16, axis=0).ravel()) diff --git a/src/Algorithm.cpp b/src/Algorithm.cpp index 174f78d9b..cfae65643 100644 --- a/src/Algorithm.cpp +++ b/src/Algorithm.cpp @@ -2,8 +2,6 @@ #include "kompute/Algorithm.hpp" -#include "fmt/ranges.h" - namespace kp { Algorithm::Algorithm(