diff --git a/python/test/test_kompute.py b/python/test/test_kompute.py index fd6611550..559600eba 100644 --- a/python/test/test_kompute.py +++ b/python/test/test_kompute.py @@ -189,14 +189,20 @@ def test_logistic_regression_pyshader(): mgr.eval_tensor_create_def(params) + # Record commands for efficient evaluation + sq = mgr.create_sequence() + sq.begin() + sq.record_tensor_sync_device([tensor_w_in, tensor_b_in]) + sq.record_algo_data(params, compute_shader.to_spirv()) + sq.record_tensor_sync_local([tensor_w_out_i, tensor_w_out_j, tensor_b_out, tensor_l_out]) + sq.end() + ITERATIONS = 100 learning_rate = 0.1 # Perform machine learning training and inference across all input X and Y for i_iter in range(ITERATIONS): - mgr.eval_tensor_sync_device_def([tensor_w_in, tensor_b_in]) - mgr.eval_algo_data_def(params, compute_shader.to_spirv()) - mgr.eval_tensor_sync_local_def([tensor_w_out_i, tensor_w_out_j, tensor_b_out, tensor_l_out]) + sq.eval() # Calculate the parameters based on the respective derivatives calculated w_in_i_val = tensor_w_in.data()[0]