Updated and renamed classes for lr example

This commit is contained in:
Alejandro Saucedo 2020-09-27 14:09:16 +01:00
parent 143baa4db3
commit 8959d90fa6
37 changed files with 402 additions and 60 deletions

View file

@ -2,13 +2,13 @@
#include <vector>
#include "KomputeSummatorNode.h"
#include "KomputeModelMLNode.h"
KomputeSummatorNode::KomputeSummatorNode() {
KomputeModelMLNode::KomputeModelMLNode() {
this->_init();
}
void KomputeSummatorNode::add(float value) {
void KomputeModelMLNode::add(float value) {
// Set the new data in the local device
this->mSecondaryTensor->setData({value});
// Execute recorded sequence
@ -20,14 +20,14 @@ void KomputeSummatorNode::add(float value) {
}
}
void KomputeSummatorNode::reset() {
void KomputeModelMLNode::reset() {
}
float KomputeSummatorNode::get_total() const {
float KomputeModelMLNode::get_total() const {
return this->mPrimaryTensor->data()[0];
}
void KomputeSummatorNode::_init() {
void KomputeModelMLNode::_init() {
std::cout << "CALLING INIT" << std::endl;
this->mPrimaryTensor = this->mManager.buildTensor({ 0.0 });
this->mSecondaryTensor = this->mManager.buildTensor({ 0.0 });
@ -74,16 +74,16 @@ void KomputeSummatorNode::_init() {
}
}
void KomputeSummatorNode::_process(float delta) {
void KomputeModelMLNode::_process(float delta) {
}
void KomputeSummatorNode::_bind_methods() {
ClassDB::bind_method(D_METHOD("_process", "delta"), &KomputeSummatorNode::_process);
ClassDB::bind_method(D_METHOD("_init"), &KomputeSummatorNode::_init);
void KomputeModelMLNode::_bind_methods() {
ClassDB::bind_method(D_METHOD("_process", "delta"), &KomputeModelMLNode::_process);
ClassDB::bind_method(D_METHOD("_init"), &KomputeModelMLNode::_init);
ClassDB::bind_method(D_METHOD("add", "value"), &KomputeSummatorNode::add);
ClassDB::bind_method(D_METHOD("reset"), &KomputeSummatorNode::reset);
ClassDB::bind_method(D_METHOD("get_total"), &KomputeSummatorNode::get_total);
ClassDB::bind_method(D_METHOD("add", "value"), &KomputeModelMLNode::add);
ClassDB::bind_method(D_METHOD("reset"), &KomputeModelMLNode::reset);
ClassDB::bind_method(D_METHOD("get_total"), &KomputeModelMLNode::get_total);
}

View file

@ -6,11 +6,11 @@
#include "scene/main/node.h"
class KomputeSummatorNode : public Node {
GDCLASS(KomputeSummatorNode, Node);
class KomputeModelMLNode : public Node {
GDCLASS(KomputeModelMLNode, Node);
public:
KomputeSummatorNode();
KomputeModelMLNode();
void add(float value);
void reset();

View file

@ -3,10 +3,10 @@
#include "register_types.h"
#include "core/class_db.h"
#include "KomputeSummatorNode.h"
#include "KomputeModelMLNode.h"
void register_kompute_summator_types() {
ClassDB::register_class<KomputeSummatorNode>();
ClassDB::register_class<KomputeModelMLNode>();
}
void unregister_kompute_summator_types() {

View file

@ -1,6 +1,6 @@
/* register_types.h */
#pragma once
void register_kompute_summator_types();
void unregister_kompute_summator_types();
void register_kompute_model_ml_types();
void unregister_kompute_model_ml_types();
/* yes, the word in the middle must be the same as the module folder name */

View file

@ -16,7 +16,7 @@ find_package(Vulkan REQUIRED)
add_library(kompute_godot
SHARED
src/KomputeSummator.cpp
src/KomputeModelML.cpp
src/KomputeGdNative.cpp)
target_include_directories(

View file

@ -1,4 +1,4 @@
#include "KomputeSummator.hpp"
#include "KomputeModelML.hpp"
extern "C" void GDN_EXPORT godot_gdnative_init(godot_gdnative_init_options *o) {
godot::Godot::gdnative_init(o);
@ -11,5 +11,5 @@ extern "C" void GDN_EXPORT godot_gdnative_terminate(godot_gdnative_terminate_opt
extern "C" void GDN_EXPORT godot_nativescript_init(void *handle) {
godot::Godot::nativescript_init(handle);
godot::register_class<godot::KomputeSummator>();
godot::register_class<godot::KomputeModelML>();
}

View file

@ -4,15 +4,16 @@
#include <string>
#include <iostream>
#include "KomputeSummator.hpp"
#include "KomputeModelML.hpp"
namespace godot {
KomputeSummator::KomputeSummator() {
KomputeModelML::KomputeModelML() {
std::cout << "CALLING CONSTRUCTOR" << std::endl;
this->_init();
}
void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
void KomputeModelML::train(Array yArr, Array xIArr, Array xJArr) {
assert(y.size() == xI.size());
assert(xI.size() == xJ.size());
@ -22,7 +23,7 @@ void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
std::vector<float> xJData;
std::vector<float> zerosData;
for (int i = 0; i < yArr.size(); i++) {
for (size_t i = 0; i < yArr.size(); i++) {
yData.push_back(yArr[i]);
xIData.push_back(xIArr[i]);
xJData.push_back(xJArr[i]);
@ -76,11 +77,11 @@ void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
sq->end();
// Iterate across all expected iterations
for (int i = 0; i < ITERATIONS; i++) {
for (size_t i = 0; i < ITERATIONS; i++) {
sq->eval();
for (int j = 0; j < bOut->size(); j++) {
for (size_t j = 0; j < bOut->size(); j++) {
wIn->data()[0] -= learningRate * wOutI->data()[j];
wIn->data()[1] -= learningRate * wOutJ->data()[j];
bIn->data()[0] -= learningRate * bOut->data()[j];
@ -98,7 +99,7 @@ void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
this->mBias = kp::Tensor(bIn->data());
}
Array KomputeSummator::predict(Array xI, Array xJ) {
Array KomputeModelML::predict(Array xI, Array xJ) {
assert(xI.size() == xJ.size());
Array retArray;
@ -106,7 +107,7 @@ Array KomputeSummator::predict(Array xI, Array xJ) {
// We run the inference in the CPU for simplicity
// BUt you can also implement the inference on GPU
// GPU implementation would speed up minibatching
for (int i = 0; i < xI.size(); i++) {
for (size_t i = 0; i < xI.size(); i++) {
float xIVal = xI[i];
float xJVal = xJ[i];
float result = (xIVal * this->mWeights.data()[0]
@ -121,9 +122,38 @@ Array KomputeSummator::predict(Array xI, Array xJ) {
return retArray;
}
void KomputeSummator::_register_methods() {
register_method((char *)"train", &KomputeSummator::train);
register_method((char *)"predict", &KomputeSummator::predict);
Array KomputeModelML::get_params() {
Array retArray;
SPDLOG_INFO(this->mWeights.size() + this->mBias.size());
if(this->mWeights.size() + this->mBias.size() == 0) {
return retArray;
}
retArray.push_back(this->mWeights.data()[0]);
retArray.push_back(this->mWeights.data()[1]);
retArray.push_back(this->mBias.data()[0]);
retArray.push_back(99.0);
return retArray;
}
void KomputeModelML::_init() {
std::cout << "CALLING INIT" << std::endl;
}
void KomputeModelML::_process(float delta) {
}
void KomputeModelML::_register_methods() {
register_method((char *)"_process", &KomputeModelML::_process);
register_method((char *)"_init", &KomputeModelML::_init);
register_method((char *)"train", &KomputeModelML::train);
register_method((char *)"predict", &KomputeModelML::predict);
register_method((char *)"get_params", &KomputeModelML::get_params);
}
}

View file

@ -9,16 +9,22 @@
#include "kompute/Kompute.hpp"
namespace godot {
class KomputeSummator : public Node2D {
class KomputeModelML : public Node2D {
private:
GODOT_CLASS(KomputeSummator, Node2D);
GODOT_CLASS(KomputeModelML, Node2D);
public:
KomputeSummator();
KomputeModelML();
void train(Array y, Array xI, Array xJ);
Array predict(Array xI, Array xJ);
Array get_params();
void _process(float delta);
void _init();
static void _register_methods();
private:

View file

@ -1,10 +1,309 @@
[gd_scene load_steps=3 format=2]
[gd_scene load_steps=10 format=2]
[ext_resource path="res://godot_resources/scripts/DynamicExampleScript.gd" type="Script" id=1]
[ext_resource path="res://godot_resources/scripts/KomputeNativeClass.gdns" type="Script" id=2]
[ext_resource path="res://godot_resources/assets/icon.png" type="Texture" id=3]
[ext_resource path="res://godot_resources/assets/TextFormat.theme" type="Theme" id=4]
[sub_resource type="GradientTexture" id=1]
[sub_resource type="StyleBoxTexture" id=2]
texture = SubResource( 1 )
region_rect = Rect2( 0, 0, 2048, 1 )
[sub_resource type="DynamicFontData" id=3]
font_path = "res://godot_resources/assets/roboto.ttf"
[sub_resource type="DynamicFont" id=4]
size = 27
font_data = SubResource( 3 )
[sub_resource type="Theme" id=5]
default_font = SubResource( 4 )
[node name="Parent" type="Node2D"]
script = ExtResource( 1 )
[node name="KomputeNode" type="Node2D" parent="."]
script = ExtResource( 2 )
[node name="UI" type="Node" parent="."]
[node name="UIVBoxContainer" type="VBoxContainer" parent="UI"]
anchor_right = 1.0
anchor_bottom = 1.0
theme = ExtResource( 4 )
__meta__ = {
"_edit_use_anchors_": false
}
[node name="TitleLabel" type="Label" parent="UI/UIVBoxContainer"]
margin_right = 1024.0
margin_bottom = 60.0
text = "Godot ML Kompute "
align = 1
[node name="LogoHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
margin_top = 64.0
margin_right = 1024.0
margin_bottom = 160.0
alignment = 1
[node name="TextureRect" type="TextureRect" parent="UI/UIVBoxContainer/LogoHBoxContainer"]
margin_left = 464.0
margin_right = 560.0
margin_bottom = 96.0
texture = ExtResource( 3 )
[node name="XIHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
margin_top = 164.0
margin_right = 1024.0
margin_bottom = 234.0
[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"]
margin_right = 20.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="Label" type="Label" parent="UI/UIVBoxContainer/XIHBoxContainer"]
margin_left = 24.0
margin_top = 5.0
margin_right = 193.0
margin_bottom = 65.0
text = "Xi Input"
[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"]
margin_left = 197.0
margin_right = 217.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/XIHBoxContainer"]
margin_left = 221.0
margin_right = 1000.0
margin_bottom = 70.0
size_flags_horizontal = 3
text = "[ 0, 0, 1, 1, 1, 1 ]"
align = 1
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"]
margin_left = 1004.0
margin_right = 1024.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="XJHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
margin_top = 238.0
margin_right = 1024.0
margin_bottom = 308.0
[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"]
margin_right = 20.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="Label" type="Label" parent="UI/UIVBoxContainer/XJHBoxContainer"]
margin_left = 24.0
margin_top = 5.0
margin_right = 193.0
margin_bottom = 65.0
text = "Xj Input"
[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"]
margin_left = 197.0
margin_right = 217.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/XJHBoxContainer"]
margin_left = 221.0
margin_right = 1000.0
margin_bottom = 70.0
size_flags_horizontal = 3
text = "[ 0, 0, 0, 0, 1, 1 ]"
align = 1
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"]
margin_left = 1004.0
margin_right = 1024.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="YHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
margin_top = 312.0
margin_right = 1024.0
margin_bottom = 382.0
[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"]
margin_right = 20.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="Label" type="Label" parent="UI/UIVBoxContainer/YHBoxContainer"]
margin_left = 24.0
margin_top = 5.0
margin_right = 192.0
margin_bottom = 65.0
text = "Y Input "
[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"]
margin_left = 196.0
margin_right = 216.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/YHBoxContainer"]
margin_left = 220.0
margin_right = 1000.0
margin_bottom = 70.0
size_flags_horizontal = 3
text = "[ 0, 0, 0, 0, 1, 1 ]"
align = 1
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"]
margin_left = 1004.0
margin_right = 1024.0
margin_bottom = 70.0
rect_min_size = Vector2( 20, 0 )
[node name="Button" type="Button" parent="UI/UIVBoxContainer"]
margin_top = 386.0
margin_right = 1024.0
margin_bottom = 452.0
text = "Kompute Train & Predict ML"
[node name="Panel" type="PanelContainer" parent="UI/UIVBoxContainer"]
margin_top = 456.0
margin_right = 1024.0
margin_bottom = 600.0
size_flags_vertical = 3
custom_styles/panel = SubResource( 2 )
[node name="VBoxContainer" type="VBoxContainer" parent="UI/UIVBoxContainer/Panel"]
margin_right = 1024.0
margin_bottom = 144.0
[node name="VSplitContainer2" type="VSplitContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
margin_right = 1024.0
margin_bottom = 10.0
rect_min_size = Vector2( 0, 10 )
[node name="PredHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
margin_top = 14.0
margin_right = 1024.0
margin_bottom = 47.0
theme = SubResource( 5 )
__meta__ = {
"_edit_use_anchors_": false
}
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_right = 20.0
margin_bottom = 33.0
rect_min_size = Vector2( 20, 0 )
[node name="Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 24.0
margin_right = 144.0
margin_bottom = 33.0
text = "Weight 1: "
[node name="Weight1Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 148.0
margin_right = 332.0
margin_bottom = 33.0
size_flags_horizontal = 3
text = "n/a"
align = 1
[node name="VSeparator4" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 336.0
margin_right = 356.0
margin_bottom = 33.0
rect_min_size = Vector2( 20, 0 )
[node name="VSeparator5" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 360.0
margin_right = 380.0
margin_bottom = 33.0
rect_min_size = Vector2( 20, 0 )
[node name="Label2" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 384.0
margin_right = 504.0
margin_bottom = 33.0
text = "Weight 2: "
[node name="Weight2Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 508.0
margin_right = 692.0
margin_bottom = 33.0
size_flags_horizontal = 3
text = "n/a"
align = 1
[node name="VSeparator6" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 696.0
margin_right = 716.0
margin_bottom = 33.0
rect_min_size = Vector2( 20, 0 )
[node name="VSeparator7" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 720.0
margin_right = 740.0
margin_bottom = 33.0
rect_min_size = Vector2( 20, 0 )
[node name="Label3" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 744.0
margin_right = 811.0
margin_bottom = 33.0
text = "Bias: "
[node name="BiasLabel" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 815.0
margin_right = 999.0
margin_bottom = 33.0
size_flags_horizontal = 3
text = "n/a"
align = 1
[node name="VSeparator8" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
margin_left = 1003.0
margin_right = 1023.0
margin_bottom = 33.0
rect_min_size = Vector2( 20, 0 )
[node name="VSplitContainer" type="VSplitContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
margin_top = 51.0
margin_right = 1024.0
margin_bottom = 71.0
rect_min_size = Vector2( 0, 20 )
[node name="PredHBoxContainer2" type="HBoxContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
margin_top = 75.0
margin_right = 1024.0
margin_bottom = 135.0
__meta__ = {
"_edit_use_anchors_": false
}
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"]
margin_right = 20.0
margin_bottom = 60.0
rect_min_size = Vector2( 20, 0 )
[node name="Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"]
margin_left = 24.0
margin_right = 399.0
margin_bottom = 60.0
text = "Prediction result:"
[node name="PredictionsLabel" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"]
margin_left = 403.0
margin_right = 1024.0
margin_bottom = 60.0
size_flags_horizontal = 3
text = "n/a"
align = 1
[connection signal="pressed" from="UI/UIVBoxContainer/Button" to="." method="compute_ml"]

View file

@ -1,29 +1,36 @@
extends Node2D
onready var xi_node = $UI/UIVBoxContainer/XIHBoxContainer/LineEdit
onready var xj_node = $UI/UIVBoxContainer/XJHBoxContainer/LineEdit
onready var y_node = $UI/UIVBoxContainer/XJHBoxContainer/LineEdit
onready var preds_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2/PredictionsLabel
onready var w1_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/Weight1Label
onready var w2_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/Weight2Label
onready var bias_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/BiasLabel
# Called when the node enters the scene tree for the first time.
func _ready():
pass
var xi = [0, 1, 1, 1, 1, 1]
var xj = [0, 0, 0, 1, 1, 1]
func compute_ml():
var y_train_1 = [0, 0, 0, 1, 1, 1]
var xi = str2var(xi_node.text)
var xj = str2var(xj_node.text)
var y = str2var(y_node.text)
print("Training with " + str(y_train_1))
$KomputeNode.train(y_train_1, xi, xj)
var s = KomputeModelML.new()
print("Now running prediction with " + str(xi) + " and " + str(xj))
print($KomputeNode.predict(xi, xj))
s.train(y, xi, xj)
# We can also reference the class as named in editor
# and create a new instance
var s = KomputeSummator.new()
var preds = s.predict(xi, xj)
# We can use a new prediciton value to see how weights change
var y_train_2 = [0, 0, 1, 1, 1, 1]
preds_node.text = str(preds)
var params = s.get_params()
w1_node.set_text(str(params[0]))
w2_node.set_text(str(params[1]))
bias_node.set_text(str(params[2]))
print("\nTraining with " + str(y_train_2))
s.train(y_train_2, xi, xj)
print("Now running prediction with " + str(xi) + " and " + str(xj))
print(s.predict(xi, xj))

View file

@ -3,6 +3,6 @@
[ext_resource path="res://godot_resources/scripts/KomputeNativeLibrary.gdnlib" type="GDNativeLibrary" id=1]
[resource]
class_name = "KomputeSummator"
class_name = "KomputeModelML"
library = ExtResource( 1 )
script_class_name = "KomputeSummator"
script_class_name = "KomputeModelML"

View file

@ -9,13 +9,13 @@
config_version=4
_global_script_classes=[ {
"base": "",
"class": "KomputeSummator",
"base": "Node2D",
"class": "KomputeModelML",
"language": "NativeScript",
"path": "res://godot_resources/scripts/KomputeNativeClass.gdns"
} ]
_global_script_class_icons={
"KomputeSummator": ""
"KomputeModelML": ""
}
[application]