Added set and get functions

This commit is contained in:
Alejandro Saucedo 2020-11-10 09:07:22 +00:00
parent 752c7b74f3
commit 90bc86d0eb
2 changed files with 10 additions and 14 deletions

View file

@ -214,18 +214,10 @@ def test_logistic_regression_pyshader():
sq.eval()
# Calculate the parameters based on the respective derivatives calculated
w_in_i_val = tensor_w_in.data()[0]
w_in_j_val = tensor_w_in.data()[1]
b_in_val = tensor_b_in.data()[0]
for j_iter in range(tensor_b_out.size()):
w_in_i_val -= learning_rate * tensor_w_out_i.data()[j_iter]
w_in_j_val -= learning_rate * tensor_w_out_j.data()[j_iter]
b_in_val -= learning_rate * tensor_b_out.data()[j_iter]
# Update the parameters to process inference again
tensor_w_in.set_data([w_in_i_val, w_in_j_val])
tensor_b_in.set_data([b_in_val])
tensor_w_in.set(0, tensor_w_in.get(0) - learning_rate * tensor_w_out_i.data()[j_iter])
tensor_w_in.set(1, tensor_w_in.get(1) - learning_rate * tensor_w_out_j.data()[j_iter])
tensor_b_in.set(0, tensor_b_in.get(0) - learning_rate * tensor_b_out.data()[j_iter])
assert tensor_w_in.data()[0] < 0.01
assert tensor_w_in.data()[0] > 0.0