added tests
This commit is contained in:
parent
46278eb0a9
commit
695fb08c80
3 changed files with 7 additions and 0 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
|
||||
import kp
|
||||
import numpy as np
|
||||
|
||||
DIRNAME = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
|
@ -22,6 +23,7 @@ def test_opmult():
|
|||
mgr.eval_tensor_sync_local_def([tensor_out])
|
||||
|
||||
assert tensor_out.data() == [2.0, 4.0, 6.0]
|
||||
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])
|
||||
|
||||
def test_opalgobase_data():
|
||||
"""
|
||||
|
|
@ -57,6 +59,7 @@ def test_opalgobase_data():
|
|||
mgr.eval_tensor_sync_local_def([tensor_out])
|
||||
|
||||
assert tensor_out.data() == [2.0, 4.0, 6.0]
|
||||
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])
|
||||
|
||||
|
||||
def test_opalgobase_file():
|
||||
|
|
@ -106,3 +109,4 @@ def test_sequence():
|
|||
seq.eval()
|
||||
|
||||
assert tensor_out.data() == [2.0, 4.0, 6.0]
|
||||
assert np.all(tensor_out.numpy() == [2.0, 4.0, 6.0])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue