Added set and get functions
This commit is contained in:
parent
752c7b74f3
commit
90bc86d0eb
2 changed files with 10 additions and 14 deletions
|
|
@ -214,18 +214,10 @@ def test_logistic_regression_pyshader():
|
|||
sq.eval()
|
||||
|
||||
# Calculate the parameters based on the respective derivatives calculated
|
||||
w_in_i_val = tensor_w_in.data()[0]
|
||||
w_in_j_val = tensor_w_in.data()[1]
|
||||
b_in_val = tensor_b_in.data()[0]
|
||||
|
||||
for j_iter in range(tensor_b_out.size()):
|
||||
w_in_i_val -= learning_rate * tensor_w_out_i.data()[j_iter]
|
||||
w_in_j_val -= learning_rate * tensor_w_out_j.data()[j_iter]
|
||||
b_in_val -= learning_rate * tensor_b_out.data()[j_iter]
|
||||
|
||||
# Update the parameters to process inference again
|
||||
tensor_w_in.set_data([w_in_i_val, w_in_j_val])
|
||||
tensor_b_in.set_data([b_in_val])
|
||||
tensor_w_in.set(0, tensor_w_in.get(0) - learning_rate * tensor_w_out_i.data()[j_iter])
|
||||
tensor_w_in.set(1, tensor_w_in.get(1) - learning_rate * tensor_w_out_j.data()[j_iter])
|
||||
tensor_b_in.set(0, tensor_b_in.get(0) - learning_rate * tensor_b_out.data()[j_iter])
|
||||
|
||||
assert tensor_w_in.data()[0] < 0.01
|
||||
assert tensor_w_in.data()[0] > 0.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue