All python tests pass

This commit is contained in:
Alejandro Saucedo 2021-02-28 07:57:36 +00:00
parent 4c4d073b90
commit 91d3b9a223
11 changed files with 158 additions and 169 deletions

View file

@ -7,6 +7,8 @@ import pyshader as ps
DIRNAME = os.path.dirname(os.path.abspath(__file__))
kp_log = logging.getLogger("kp")
# TODO: Add example with file
#def test_opalgobase_file():
# """
@ -62,9 +64,9 @@ void main()
algo = mgr.algorithm(params, spirv)
(mgr.sequence()
.record(kp.OpTensorSyncLocal(params))
.record(kp.OpAlgoDispatch(algo))
.record(kp.OpTensorSyncDevice(params))
.record(kp.OpAlgoDispatch(algo))
.record(kp.OpTensorSyncLocal(params))
.eval())
assert tensor_out.data() == [2.0, 4.0, 6.0]
@ -102,9 +104,9 @@ def test_sequence():
sq = mgr.sequence()
sq.record(kp.OpTensorSyncLocal(params))
sq.record(kp.OpAlgoDispatch(algo))
sq.record(kp.OpTensorSyncDevice(params))
sq.record(kp.OpAlgoDispatch(algo))
sq.record(kp.OpTensorSyncLocal(params))
sq.eval()
@ -141,16 +143,14 @@ def test_workgroup():
data1[i] = f32(gl_idx.x)
data2[i] = f32(gl_idx.y)
algo = mgr.algorithm([tensor_a, tensor_b], compute_shader_wg.to_spirv(), (16,8,1), [], [])
algo = mgr.algorithm([tensor_a, tensor_b], compute_shader_wg.to_spirv(), (16,8,1))
(mgr.sequence()
.record(kp.OpTensorSyncDevice([tensor_a, tensor_b]))
.record(kp.OpAlgoDispatch(algo))
.record(kp.OpAlgoTensorSyncLocal([tensor_a, tensor_b]))
.record(kp.OpTensorSyncLocal([tensor_a, tensor_b]))
.eval())
assert sq.is_init() == False
print(tensor_a.numpy())
print(tensor_b.numpy())