Updated logistic regression model
This commit is contained in:
parent
5822850ef2
commit
8285f2f878
5 changed files with 16 additions and 26 deletions
|
|
@ -24,7 +24,7 @@ protected:
|
|||
|
||||
private:
|
||||
kp::Manager mManager;
|
||||
std::weak_ptr<kp::Sequence> mSequence;
|
||||
std::shared_ptr<kp::Sequence> mSequence;
|
||||
std::shared_ptr<kp::Tensor> mPrimaryTensor;
|
||||
std::shared_ptr<kp::Tensor> mSecondaryTensor;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -16,12 +16,7 @@ void KomputeSummator::add(float value) {
|
|||
// Set the new data in the local device
|
||||
this->mSecondaryTensor->setData({value});
|
||||
// Execute recorded sequence
|
||||
if (std::shared_ptr<kp::Sequence> sq = this->mSequence.lock()) {
|
||||
sq->eval();
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Sequence pointer no longer available");
|
||||
}
|
||||
this->mSequence->eval();
|
||||
}
|
||||
|
||||
void KomputeSummator::reset() {
|
||||
|
|
@ -38,9 +33,7 @@ void KomputeSummator::_init() {
|
|||
this->mSequence = this->mManager.getOrCreateManagedSequence("AdditionSeq");
|
||||
|
||||
// We now record the steps in the sequence
|
||||
if (std::shared_ptr<kp::Sequence> sq = this->mSequence.lock())
|
||||
{
|
||||
|
||||
std::string shader(R"(
|
||||
#version 450
|
||||
|
||||
|
|
@ -55,26 +48,23 @@ void KomputeSummator::_init() {
|
|||
}
|
||||
)");
|
||||
|
||||
sq->begin();
|
||||
this->mSequence->begin();
|
||||
|
||||
// First we ensure secondary tensor loads to GPU
|
||||
// No need to sync the primary tensor as it should not be changed
|
||||
sq->record<kp::OpTensorSyncDevice>(
|
||||
this->mSequence->record<kp::OpTensorSyncDevice>(
|
||||
{ this->mSecondaryTensor });
|
||||
|
||||
// Then we run the operation with both tensors
|
||||
sq->record<kp::OpAlgoBase<>>(
|
||||
this->mSequence->record<kp::OpAlgoBase>(
|
||||
{ this->mPrimaryTensor, this->mSecondaryTensor },
|
||||
std::vector<char>(shader.begin(), shader.end()));
|
||||
|
||||
// We map the result back to local
|
||||
sq->record<kp::OpTensorSyncLocal>(
|
||||
this->mSequence->record<kp::OpTensorSyncLocal>(
|
||||
{ this->mPrimaryTensor });
|
||||
|
||||
sq->end();
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Sequence pointer no longer available");
|
||||
this->mSequence->end();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ public:
|
|||
|
||||
private:
|
||||
kp::Manager mManager;
|
||||
std::weak_ptr<kp::Sequence> mSequence;
|
||||
std::shared_ptr<kp::Sequence> mSequence;
|
||||
std::shared_ptr<kp::Tensor> mPrimaryTensor;
|
||||
std::shared_ptr<kp::Tensor> mSecondaryTensor;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -51,14 +51,14 @@ void KomputeModelMLNode::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
kp::Manager mgr;
|
||||
|
||||
std::shared_ptr<kp::Sequence> sqTensor =
|
||||
mgr.createManagedSequence().lock();
|
||||
mgr.createManagedSequence();
|
||||
|
||||
sqTensor->begin();
|
||||
sqTensor->record<kp::OpTensorCreate>(params);
|
||||
sqTensor->end();
|
||||
sqTensor->eval();
|
||||
|
||||
std::shared_ptr<kp::Sequence> sq = mgr.createManagedSequence().lock();
|
||||
std::shared_ptr<kp::Sequence> sq = mgr.createManagedSequence();
|
||||
|
||||
// Record op algo base
|
||||
sq->begin();
|
||||
|
|
@ -67,11 +67,11 @@ void KomputeModelMLNode::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
|
||||
#ifdef KOMPUTE_ANDROID_SHADER_FROM_STRING
|
||||
// Newer versions of Android are able to use shaderc to read raw string
|
||||
sq->record<kp::OpAlgoBase<>>(
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, std::vector<char>(LR_SHADER.begin(), LR_SHADER.end()));
|
||||
#else
|
||||
// Older versions of Android require the SPIRV binary directly
|
||||
sq->record<kp::OpAlgoBase<>>(
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, std::vector<char>(
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv
|
||||
|
|
|
|||
|
|
@ -56,14 +56,14 @@ void KomputeModelML::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
|
||||
{
|
||||
std::shared_ptr<kp::Sequence> sqTensor =
|
||||
mgr.createManagedSequence().lock();
|
||||
mgr.createManagedSequence();
|
||||
|
||||
sqTensor->begin();
|
||||
sqTensor->record<kp::OpTensorCreate>(params);
|
||||
sqTensor->end();
|
||||
sqTensor->eval();
|
||||
|
||||
std::shared_ptr<kp::Sequence> sq = mgr.createManagedSequence().lock();
|
||||
std::shared_ptr<kp::Sequence> sq = mgr.createManagedSequence();
|
||||
|
||||
// Record op algo base
|
||||
sq->begin();
|
||||
|
|
@ -72,11 +72,11 @@ void KomputeModelML::train(Array yArr, Array xIArr, Array xJArr) {
|
|||
|
||||
#ifdef KOMPUTE_ANDROID_SHADER_FROM_STRING
|
||||
// Newer versions of Android are able to use shaderc to read raw string
|
||||
sq->record<kp::OpAlgoBase<>>(
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, std::vector<char>(LR_SHADER.begin(), LR_SHADER.end()));
|
||||
#else
|
||||
// Older versions of Android require the SPIRV binary directly
|
||||
sq->record<kp::OpAlgoBase<>>(
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, std::vector<char>(
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue