Updated and renamed classes for lr example
This commit is contained in:
parent
143baa4db3
commit
8959d90fa6
37 changed files with 402 additions and 60 deletions
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "KomputeSummatorNode.h"
|
||||
#include "KomputeModelMLNode.h"
|
||||
|
||||
KomputeSummatorNode::KomputeSummatorNode() {
|
||||
KomputeModelMLNode::KomputeModelMLNode() {
|
||||
this->_init();
|
||||
}
|
||||
|
||||
void KomputeSummatorNode::add(float value) {
|
||||
void KomputeModelMLNode::add(float value) {
|
||||
// Set the new data in the local device
|
||||
this->mSecondaryTensor->setData({value});
|
||||
// Execute recorded sequence
|
||||
|
|
@ -20,14 +20,14 @@ void KomputeSummatorNode::add(float value) {
|
|||
}
|
||||
}
|
||||
|
||||
void KomputeSummatorNode::reset() {
|
||||
void KomputeModelMLNode::reset() {
|
||||
}
|
||||
|
||||
float KomputeSummatorNode::get_total() const {
|
||||
float KomputeModelMLNode::get_total() const {
|
||||
return this->mPrimaryTensor->data()[0];
|
||||
}
|
||||
|
||||
void KomputeSummatorNode::_init() {
|
||||
void KomputeModelMLNode::_init() {
|
||||
std::cout << "CALLING INIT" << std::endl;
|
||||
this->mPrimaryTensor = this->mManager.buildTensor({ 0.0 });
|
||||
this->mSecondaryTensor = this->mManager.buildTensor({ 0.0 });
|
||||
|
|
@ -74,16 +74,16 @@ void KomputeSummatorNode::_init() {
|
|||
}
|
||||
}
|
||||
|
||||
void KomputeSummatorNode::_process(float delta) {
|
||||
void KomputeModelMLNode::_process(float delta) {
|
||||
|
||||
}
|
||||
|
||||
void KomputeSummatorNode::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("_process", "delta"), &KomputeSummatorNode::_process);
|
||||
ClassDB::bind_method(D_METHOD("_init"), &KomputeSummatorNode::_init);
|
||||
void KomputeModelMLNode::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("_process", "delta"), &KomputeModelMLNode::_process);
|
||||
ClassDB::bind_method(D_METHOD("_init"), &KomputeModelMLNode::_init);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("add", "value"), &KomputeSummatorNode::add);
|
||||
ClassDB::bind_method(D_METHOD("reset"), &KomputeSummatorNode::reset);
|
||||
ClassDB::bind_method(D_METHOD("get_total"), &KomputeSummatorNode::get_total);
|
||||
ClassDB::bind_method(D_METHOD("add", "value"), &KomputeModelMLNode::add);
|
||||
ClassDB::bind_method(D_METHOD("reset"), &KomputeModelMLNode::reset);
|
||||
ClassDB::bind_method(D_METHOD("get_total"), &KomputeModelMLNode::get_total);
|
||||
}
|
||||
|
||||
|
|
@ -6,11 +6,11 @@
|
|||
|
||||
#include "scene/main/node.h"
|
||||
|
||||
class KomputeSummatorNode : public Node {
|
||||
GDCLASS(KomputeSummatorNode, Node);
|
||||
class KomputeModelMLNode : public Node {
|
||||
GDCLASS(KomputeModelMLNode, Node);
|
||||
|
||||
public:
|
||||
KomputeSummatorNode();
|
||||
KomputeModelMLNode();
|
||||
|
||||
void add(float value);
|
||||
void reset();
|
||||
|
|
@ -3,10 +3,10 @@
|
|||
#include "register_types.h"
|
||||
|
||||
#include "core/class_db.h"
|
||||
#include "KomputeSummatorNode.h"
|
||||
#include "KomputeModelMLNode.h"
|
||||
|
||||
void register_kompute_summator_types() {
|
||||
ClassDB::register_class<KomputeSummatorNode>();
|
||||
ClassDB::register_class<KomputeModelMLNode>();
|
||||
}
|
||||
|
||||
void unregister_kompute_summator_types() {
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
/* register_types.h */
|
||||
#pragma once
|
||||
|
||||
void register_kompute_summator_types();
|
||||
void unregister_kompute_summator_types();
|
||||
void register_kompute_model_ml_types();
|
||||
void unregister_kompute_model_ml_types();
|
||||
/* yes, the word in the middle must be the same as the module folder name */
|
||||
|
|
@ -16,7 +16,7 @@ find_package(Vulkan REQUIRED)
|
|||
|
||||
add_library(kompute_godot
|
||||
SHARED
|
||||
src/KomputeSummator.cpp
|
||||
src/KomputeModelML.cpp
|
||||
src/KomputeGdNative.cpp)
|
||||
|
||||
target_include_directories(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
#include "KomputeSummator.hpp"
|
||||
#include "KomputeModelML.hpp"
|
||||
|
||||
extern "C" void GDN_EXPORT godot_gdnative_init(godot_gdnative_init_options *o) {
|
||||
godot::Godot::gdnative_init(o);
|
||||
|
|
@ -11,5 +11,5 @@ extern "C" void GDN_EXPORT godot_gdnative_terminate(godot_gdnative_terminate_opt
|
|||
extern "C" void GDN_EXPORT godot_nativescript_init(void *handle) {
|
||||
godot::Godot::nativescript_init(handle);
|
||||
|
||||
godot::register_class<godot::KomputeSummator>();
|
||||
godot::register_class<godot::KomputeModelML>();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,15 +4,16 @@
|
|||
#include <string>
|
||||
#include <iostream>
|
||||
|
||||
#include "KomputeSummator.hpp"
|
||||
#include "KomputeModelML.hpp"
|
||||
|
||||
namespace godot {
|
||||
|
||||
KomputeSummator::KomputeSummator() {
|
||||
|
||||
KomputeModelML::KomputeModelML() {
|
||||
std::cout << "CALLING CONSTRUCTOR" << std::endl;
|
||||
this->_init();
|
||||
}
|
||||
|
||||
void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
|
||||
void KomputeModelML::train(Array yArr, Array xIArr, Array xJArr) {
|
||||
|
||||
assert(y.size() == xI.size());
|
||||
assert(xI.size() == xJ.size());
|
||||
|
|
@ -22,7 +23,7 @@ void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
std::vector<float> xJData;
|
||||
std::vector<float> zerosData;
|
||||
|
||||
for (int i = 0; i < yArr.size(); i++) {
|
||||
for (size_t i = 0; i < yArr.size(); i++) {
|
||||
yData.push_back(yArr[i]);
|
||||
xIData.push_back(xIArr[i]);
|
||||
xJData.push_back(xJArr[i]);
|
||||
|
|
@ -76,11 +77,11 @@ void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
sq->end();
|
||||
|
||||
// Iterate across all expected iterations
|
||||
for (int i = 0; i < ITERATIONS; i++) {
|
||||
for (size_t i = 0; i < ITERATIONS; i++) {
|
||||
|
||||
sq->eval();
|
||||
|
||||
for (int j = 0; j < bOut->size(); j++) {
|
||||
for (size_t j = 0; j < bOut->size(); j++) {
|
||||
wIn->data()[0] -= learningRate * wOutI->data()[j];
|
||||
wIn->data()[1] -= learningRate * wOutJ->data()[j];
|
||||
bIn->data()[0] -= learningRate * bOut->data()[j];
|
||||
|
|
@ -98,7 +99,7 @@ void KomputeSummator::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
this->mBias = kp::Tensor(bIn->data());
|
||||
}
|
||||
|
||||
Array KomputeSummator::predict(Array xI, Array xJ) {
|
||||
Array KomputeModelML::predict(Array xI, Array xJ) {
|
||||
assert(xI.size() == xJ.size());
|
||||
|
||||
Array retArray;
|
||||
|
|
@ -106,7 +107,7 @@ Array KomputeSummator::predict(Array xI, Array xJ) {
|
|||
// We run the inference in the CPU for simplicity
|
||||
// BUt you can also implement the inference on GPU
|
||||
// GPU implementation would speed up minibatching
|
||||
for (int i = 0; i < xI.size(); i++) {
|
||||
for (size_t i = 0; i < xI.size(); i++) {
|
||||
float xIVal = xI[i];
|
||||
float xJVal = xJ[i];
|
||||
float result = (xIVal * this->mWeights.data()[0]
|
||||
|
|
@ -121,9 +122,38 @@ Array KomputeSummator::predict(Array xI, Array xJ) {
|
|||
return retArray;
|
||||
}
|
||||
|
||||
void KomputeSummator::_register_methods() {
|
||||
register_method((char *)"train", &KomputeSummator::train);
|
||||
register_method((char *)"predict", &KomputeSummator::predict);
|
||||
Array KomputeModelML::get_params() {
|
||||
Array retArray;
|
||||
|
||||
SPDLOG_INFO(this->mWeights.size() + this->mBias.size());
|
||||
|
||||
if(this->mWeights.size() + this->mBias.size() == 0) {
|
||||
return retArray;
|
||||
}
|
||||
|
||||
retArray.push_back(this->mWeights.data()[0]);
|
||||
retArray.push_back(this->mWeights.data()[1]);
|
||||
retArray.push_back(this->mBias.data()[0]);
|
||||
retArray.push_back(99.0);
|
||||
|
||||
return retArray;
|
||||
}
|
||||
|
||||
void KomputeModelML::_init() {
|
||||
std::cout << "CALLING INIT" << std::endl;
|
||||
}
|
||||
|
||||
void KomputeModelML::_process(float delta) {
|
||||
|
||||
}
|
||||
|
||||
void KomputeModelML::_register_methods() {
|
||||
register_method((char *)"_process", &KomputeModelML::_process);
|
||||
register_method((char *)"_init", &KomputeModelML::_init);
|
||||
|
||||
register_method((char *)"train", &KomputeModelML::train);
|
||||
register_method((char *)"predict", &KomputeModelML::predict);
|
||||
register_method((char *)"get_params", &KomputeModelML::get_params);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -9,16 +9,22 @@
|
|||
#include "kompute/Kompute.hpp"
|
||||
|
||||
namespace godot {
|
||||
class KomputeSummator : public Node2D {
|
||||
class KomputeModelML : public Node2D {
|
||||
private:
|
||||
GODOT_CLASS(KomputeSummator, Node2D);
|
||||
GODOT_CLASS(KomputeModelML, Node2D);
|
||||
|
||||
public:
|
||||
KomputeSummator();
|
||||
KomputeModelML();
|
||||
|
||||
void train(Array y, Array xI, Array xJ);
|
||||
|
||||
Array predict(Array xI, Array xJ);
|
||||
|
||||
Array get_params();
|
||||
|
||||
void _process(float delta);
|
||||
void _init();
|
||||
|
||||
static void _register_methods();
|
||||
|
||||
private:
|
||||
|
|
@ -1,10 +1,309 @@
|
|||
[gd_scene load_steps=3 format=2]
|
||||
[gd_scene load_steps=10 format=2]
|
||||
|
||||
[ext_resource path="res://godot_resources/scripts/DynamicExampleScript.gd" type="Script" id=1]
|
||||
[ext_resource path="res://godot_resources/scripts/KomputeNativeClass.gdns" type="Script" id=2]
|
||||
[ext_resource path="res://godot_resources/assets/icon.png" type="Texture" id=3]
|
||||
[ext_resource path="res://godot_resources/assets/TextFormat.theme" type="Theme" id=4]
|
||||
|
||||
[sub_resource type="GradientTexture" id=1]
|
||||
|
||||
[sub_resource type="StyleBoxTexture" id=2]
|
||||
texture = SubResource( 1 )
|
||||
region_rect = Rect2( 0, 0, 2048, 1 )
|
||||
|
||||
[sub_resource type="DynamicFontData" id=3]
|
||||
font_path = "res://godot_resources/assets/roboto.ttf"
|
||||
|
||||
[sub_resource type="DynamicFont" id=4]
|
||||
size = 27
|
||||
font_data = SubResource( 3 )
|
||||
|
||||
[sub_resource type="Theme" id=5]
|
||||
default_font = SubResource( 4 )
|
||||
|
||||
[node name="Parent" type="Node2D"]
|
||||
script = ExtResource( 1 )
|
||||
|
||||
[node name="KomputeNode" type="Node2D" parent="."]
|
||||
script = ExtResource( 2 )
|
||||
|
||||
[node name="UI" type="Node" parent="."]
|
||||
|
||||
[node name="UIVBoxContainer" type="VBoxContainer" parent="UI"]
|
||||
anchor_right = 1.0
|
||||
anchor_bottom = 1.0
|
||||
theme = ExtResource( 4 )
|
||||
__meta__ = {
|
||||
"_edit_use_anchors_": false
|
||||
}
|
||||
|
||||
[node name="TitleLabel" type="Label" parent="UI/UIVBoxContainer"]
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 60.0
|
||||
text = "Godot ML Kompute "
|
||||
align = 1
|
||||
|
||||
[node name="LogoHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
|
||||
margin_top = 64.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 160.0
|
||||
alignment = 1
|
||||
|
||||
[node name="TextureRect" type="TextureRect" parent="UI/UIVBoxContainer/LogoHBoxContainer"]
|
||||
margin_left = 464.0
|
||||
margin_right = 560.0
|
||||
margin_bottom = 96.0
|
||||
texture = ExtResource( 3 )
|
||||
|
||||
[node name="XIHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
|
||||
margin_top = 164.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 234.0
|
||||
|
||||
[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"]
|
||||
margin_right = 20.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label" type="Label" parent="UI/UIVBoxContainer/XIHBoxContainer"]
|
||||
margin_left = 24.0
|
||||
margin_top = 5.0
|
||||
margin_right = 193.0
|
||||
margin_bottom = 65.0
|
||||
text = "Xi Input"
|
||||
|
||||
[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"]
|
||||
margin_left = 197.0
|
||||
margin_right = 217.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/XIHBoxContainer"]
|
||||
margin_left = 221.0
|
||||
margin_right = 1000.0
|
||||
margin_bottom = 70.0
|
||||
size_flags_horizontal = 3
|
||||
text = "[ 0, 0, 1, 1, 1, 1 ]"
|
||||
align = 1
|
||||
|
||||
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/XIHBoxContainer"]
|
||||
margin_left = 1004.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="XJHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
|
||||
margin_top = 238.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 308.0
|
||||
|
||||
[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"]
|
||||
margin_right = 20.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label" type="Label" parent="UI/UIVBoxContainer/XJHBoxContainer"]
|
||||
margin_left = 24.0
|
||||
margin_top = 5.0
|
||||
margin_right = 193.0
|
||||
margin_bottom = 65.0
|
||||
text = "Xj Input"
|
||||
|
||||
[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"]
|
||||
margin_left = 197.0
|
||||
margin_right = 217.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/XJHBoxContainer"]
|
||||
margin_left = 221.0
|
||||
margin_right = 1000.0
|
||||
margin_bottom = 70.0
|
||||
size_flags_horizontal = 3
|
||||
text = "[ 0, 0, 0, 0, 1, 1 ]"
|
||||
align = 1
|
||||
|
||||
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/XJHBoxContainer"]
|
||||
margin_left = 1004.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="YHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer"]
|
||||
margin_top = 312.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 382.0
|
||||
|
||||
[node name="VSeparator" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"]
|
||||
margin_right = 20.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label" type="Label" parent="UI/UIVBoxContainer/YHBoxContainer"]
|
||||
margin_left = 24.0
|
||||
margin_top = 5.0
|
||||
margin_right = 192.0
|
||||
margin_bottom = 65.0
|
||||
text = "Y Input "
|
||||
|
||||
[node name="VSeparator2" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"]
|
||||
margin_left = 196.0
|
||||
margin_right = 216.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="LineEdit" type="LineEdit" parent="UI/UIVBoxContainer/YHBoxContainer"]
|
||||
margin_left = 220.0
|
||||
margin_right = 1000.0
|
||||
margin_bottom = 70.0
|
||||
size_flags_horizontal = 3
|
||||
text = "[ 0, 0, 0, 0, 1, 1 ]"
|
||||
align = 1
|
||||
|
||||
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/YHBoxContainer"]
|
||||
margin_left = 1004.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 70.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Button" type="Button" parent="UI/UIVBoxContainer"]
|
||||
margin_top = 386.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 452.0
|
||||
text = "Kompute Train & Predict ML"
|
||||
|
||||
[node name="Panel" type="PanelContainer" parent="UI/UIVBoxContainer"]
|
||||
margin_top = 456.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 600.0
|
||||
size_flags_vertical = 3
|
||||
custom_styles/panel = SubResource( 2 )
|
||||
|
||||
[node name="VBoxContainer" type="VBoxContainer" parent="UI/UIVBoxContainer/Panel"]
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 144.0
|
||||
|
||||
[node name="VSplitContainer2" type="VSplitContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 10.0
|
||||
rect_min_size = Vector2( 0, 10 )
|
||||
|
||||
[node name="PredHBoxContainer" type="HBoxContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
|
||||
margin_top = 14.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 47.0
|
||||
theme = SubResource( 5 )
|
||||
__meta__ = {
|
||||
"_edit_use_anchors_": false
|
||||
}
|
||||
|
||||
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_right = 20.0
|
||||
margin_bottom = 33.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 24.0
|
||||
margin_right = 144.0
|
||||
margin_bottom = 33.0
|
||||
text = "Weight 1: "
|
||||
|
||||
[node name="Weight1Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 148.0
|
||||
margin_right = 332.0
|
||||
margin_bottom = 33.0
|
||||
size_flags_horizontal = 3
|
||||
text = "n/a"
|
||||
align = 1
|
||||
|
||||
[node name="VSeparator4" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 336.0
|
||||
margin_right = 356.0
|
||||
margin_bottom = 33.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="VSeparator5" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 360.0
|
||||
margin_right = 380.0
|
||||
margin_bottom = 33.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label2" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 384.0
|
||||
margin_right = 504.0
|
||||
margin_bottom = 33.0
|
||||
text = "Weight 2: "
|
||||
|
||||
[node name="Weight2Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 508.0
|
||||
margin_right = 692.0
|
||||
margin_bottom = 33.0
|
||||
size_flags_horizontal = 3
|
||||
text = "n/a"
|
||||
align = 1
|
||||
|
||||
[node name="VSeparator6" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 696.0
|
||||
margin_right = 716.0
|
||||
margin_bottom = 33.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="VSeparator7" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 720.0
|
||||
margin_right = 740.0
|
||||
margin_bottom = 33.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label3" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 744.0
|
||||
margin_right = 811.0
|
||||
margin_bottom = 33.0
|
||||
text = "Bias: "
|
||||
|
||||
[node name="BiasLabel" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 815.0
|
||||
margin_right = 999.0
|
||||
margin_bottom = 33.0
|
||||
size_flags_horizontal = 3
|
||||
text = "n/a"
|
||||
align = 1
|
||||
|
||||
[node name="VSeparator8" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer"]
|
||||
margin_left = 1003.0
|
||||
margin_right = 1023.0
|
||||
margin_bottom = 33.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="VSplitContainer" type="VSplitContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
|
||||
margin_top = 51.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 71.0
|
||||
rect_min_size = Vector2( 0, 20 )
|
||||
|
||||
[node name="PredHBoxContainer2" type="HBoxContainer" parent="UI/UIVBoxContainer/Panel/VBoxContainer"]
|
||||
margin_top = 75.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 135.0
|
||||
__meta__ = {
|
||||
"_edit_use_anchors_": false
|
||||
}
|
||||
|
||||
[node name="VSeparator3" type="VSeparator" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"]
|
||||
margin_right = 20.0
|
||||
margin_bottom = 60.0
|
||||
rect_min_size = Vector2( 20, 0 )
|
||||
|
||||
[node name="Label" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"]
|
||||
margin_left = 24.0
|
||||
margin_right = 399.0
|
||||
margin_bottom = 60.0
|
||||
text = "Prediction result:"
|
||||
|
||||
[node name="PredictionsLabel" type="Label" parent="UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2"]
|
||||
margin_left = 403.0
|
||||
margin_right = 1024.0
|
||||
margin_bottom = 60.0
|
||||
size_flags_horizontal = 3
|
||||
text = "n/a"
|
||||
align = 1
|
||||
[connection signal="pressed" from="UI/UIVBoxContainer/Button" to="." method="compute_ml"]
|
||||
|
|
|
|||
BIN
examples/godot_logistic_regression/godot_resources/assets/TextFormat.theme
Executable file
BIN
examples/godot_logistic_regression/godot_resources/assets/TextFormat.theme
Executable file
Binary file not shown.
BIN
examples/godot_logistic_regression/godot_resources/assets/roboto.ttf
Executable file
BIN
examples/godot_logistic_regression/godot_resources/assets/roboto.ttf
Executable file
Binary file not shown.
|
|
@ -1,29 +1,36 @@
|
|||
extends Node2D
|
||||
|
||||
onready var xi_node = $UI/UIVBoxContainer/XIHBoxContainer/LineEdit
|
||||
onready var xj_node = $UI/UIVBoxContainer/XJHBoxContainer/LineEdit
|
||||
onready var y_node = $UI/UIVBoxContainer/XJHBoxContainer/LineEdit
|
||||
onready var preds_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer2/PredictionsLabel
|
||||
onready var w1_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/Weight1Label
|
||||
onready var w2_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/Weight2Label
|
||||
onready var bias_node = $UI/UIVBoxContainer/Panel/VBoxContainer/PredHBoxContainer/BiasLabel
|
||||
|
||||
# Called when the node enters the scene tree for the first time.
|
||||
func _ready():
|
||||
pass
|
||||
|
||||
var xi = [0, 1, 1, 1, 1, 1]
|
||||
var xj = [0, 0, 0, 1, 1, 1]
|
||||
func compute_ml():
|
||||
|
||||
var y_train_1 = [0, 0, 0, 1, 1, 1]
|
||||
var xi = str2var(xi_node.text)
|
||||
var xj = str2var(xj_node.text)
|
||||
var y = str2var(y_node.text)
|
||||
|
||||
print("Training with " + str(y_train_1))
|
||||
$KomputeNode.train(y_train_1, xi, xj)
|
||||
var s = KomputeModelML.new()
|
||||
|
||||
print("Now running prediction with " + str(xi) + " and " + str(xj))
|
||||
print($KomputeNode.predict(xi, xj))
|
||||
s.train(y, xi, xj)
|
||||
|
||||
# We can also reference the class as named in editor
|
||||
# and create a new instance
|
||||
var s = KomputeSummator.new()
|
||||
var preds = s.predict(xi, xj)
|
||||
|
||||
# We can use a new prediciton value to see how weights change
|
||||
var y_train_2 = [0, 0, 1, 1, 1, 1]
|
||||
preds_node.text = str(preds)
|
||||
|
||||
var params = s.get_params()
|
||||
|
||||
w1_node.set_text(str(params[0]))
|
||||
w2_node.set_text(str(params[1]))
|
||||
bias_node.set_text(str(params[2]))
|
||||
|
||||
print("\nTraining with " + str(y_train_2))
|
||||
s.train(y_train_2, xi, xj)
|
||||
|
||||
print("Now running prediction with " + str(xi) + " and " + str(xj))
|
||||
print(s.predict(xi, xj))
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,6 @@
|
|||
[ext_resource path="res://godot_resources/scripts/KomputeNativeLibrary.gdnlib" type="GDNativeLibrary" id=1]
|
||||
|
||||
[resource]
|
||||
class_name = "KomputeSummator"
|
||||
class_name = "KomputeModelML"
|
||||
library = ExtResource( 1 )
|
||||
script_class_name = "KomputeSummator"
|
||||
script_class_name = "KomputeModelML"
|
||||
|
|
|
|||
|
|
@ -9,13 +9,13 @@
|
|||
config_version=4
|
||||
|
||||
_global_script_classes=[ {
|
||||
"base": "",
|
||||
"class": "KomputeSummator",
|
||||
"base": "Node2D",
|
||||
"class": "KomputeModelML",
|
||||
"language": "NativeScript",
|
||||
"path": "res://godot_resources/scripts/KomputeNativeClass.gdns"
|
||||
} ]
|
||||
_global_script_class_icons={
|
||||
"KomputeSummator": ""
|
||||
"KomputeModelML": ""
|
||||
}
|
||||
|
||||
[application]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue