From f1332fef45f572a6b5968ee464aaaf7b12ffc149 Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Wed, 11 Nov 2020 07:17:33 +0000 Subject: [PATCH] Added individual arary mult file --- python/test/test_array_multiplication.py | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 python/test/test_array_multiplication.py diff --git a/python/test/test_array_multiplication.py b/python/test/test_array_multiplication.py new file mode 100644 index 000000000..4698a573c --- /dev/null +++ b/python/test/test_array_multiplication.py @@ -0,0 +1,25 @@ +import pyshader as ps +import kp + + +def test_array_multiplication(): + + @ps.python2shader + def compute_shader_multiply(index=("input", "GlobalInvocationId", ps.ivec3), + data1=("buffer", 0, ps.Array(ps.f32)), + data2=("buffer", 1, ps.Array(ps.f32)), + data3=("buffer", 2, ps.Array(ps.f32))): + i = index.x + data3[i] = data1[i] * data2[i] + + tensor_in_a = kp.Tensor([2, 2, 2]) + tensor_in_b = kp.Tensor([1, 2, 3]) + tensor_out = kp.Tensor([0, 0, 0]) + + mgr = kp.Manager() + + mgr.eval_tensor_create_def([tensor_in_a, tensor_in_b, tensor_out]) + mgr.eval_algo_data_def([tensor_in_a, tensor_in_b, tensor_out], compute_shader_multiply.to_spirv()) + mgr.eval_tensor_sync_local_def([tensor_out]) + + assert tensor_out.data() == [2.0, 4.0, 6.0]