Added further tests including lr

This commit is contained in:
Alejandro Saucedo 2021-01-21 21:51:37 +00:00
parent fbfef16105
commit d620093a27
2 changed files with 24 additions and 3 deletions

View file

@ -3,6 +3,8 @@
#include "kompute/Kompute.hpp"
#include "kompute_test/shaders/shadertest_logistic_regression.hpp"
TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
{
@ -44,8 +46,18 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
sq->record<kp::OpTensorSyncDevice>({ wIn, bIn });
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv");
#else
sq->record<kp::OpAlgoBase>(
params,
std::vector<char>(
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::
shaders_glsl_logisticregression_comp_spv_len));
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
@ -123,8 +135,18 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
// Record op algo base
sq->begin();
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv");
#else
sq->record<kp::OpAlgoBase>(
params,
std::vector<char>(
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::
shaders_glsl_logisticregression_comp_spv_len));
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });