All python tests pass
This commit is contained in:
parent
4c4d073b90
commit
91d3b9a223
11 changed files with 158 additions and 169 deletions
|
|
@ -7,6 +7,8 @@ import pyshader as ps
|
|||
|
||||
DIRNAME = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
kp_log = logging.getLogger("kp")
|
||||
|
||||
# TODO: Add example with file
|
||||
#def test_opalgobase_file():
|
||||
# """
|
||||
|
|
@ -62,9 +64,9 @@ void main()
|
|||
algo = mgr.algorithm(params, spirv)
|
||||
|
||||
(mgr.sequence()
|
||||
.record(kp.OpTensorSyncLocal(params))
|
||||
.record(kp.OpAlgoDispatch(algo))
|
||||
.record(kp.OpTensorSyncDevice(params))
|
||||
.record(kp.OpAlgoDispatch(algo))
|
||||
.record(kp.OpTensorSyncLocal(params))
|
||||
.eval())
|
||||
|
||||
assert tensor_out.data() == [2.0, 4.0, 6.0]
|
||||
|
|
@ -102,9 +104,9 @@ def test_sequence():
|
|||
|
||||
sq = mgr.sequence()
|
||||
|
||||
sq.record(kp.OpTensorSyncLocal(params))
|
||||
sq.record(kp.OpAlgoDispatch(algo))
|
||||
sq.record(kp.OpTensorSyncDevice(params))
|
||||
sq.record(kp.OpAlgoDispatch(algo))
|
||||
sq.record(kp.OpTensorSyncLocal(params))
|
||||
|
||||
sq.eval()
|
||||
|
||||
|
|
@ -141,16 +143,14 @@ def test_workgroup():
|
|||
data1[i] = f32(gl_idx.x)
|
||||
data2[i] = f32(gl_idx.y)
|
||||
|
||||
algo = mgr.algorithm([tensor_a, tensor_b], compute_shader_wg.to_spirv(), (16,8,1), [], [])
|
||||
algo = mgr.algorithm([tensor_a, tensor_b], compute_shader_wg.to_spirv(), (16,8,1))
|
||||
|
||||
(mgr.sequence()
|
||||
.record(kp.OpTensorSyncDevice([tensor_a, tensor_b]))
|
||||
.record(kp.OpAlgoDispatch(algo))
|
||||
.record(kp.OpAlgoTensorSyncLocal([tensor_a, tensor_b]))
|
||||
.record(kp.OpTensorSyncLocal([tensor_a, tensor_b]))
|
||||
.eval())
|
||||
|
||||
assert sq.is_init() == False
|
||||
|
||||
print(tensor_a.numpy())
|
||||
print(tensor_b.numpy())
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue