Added test and updated LR tests to confirm works

This commit is contained in:
Alejandro Saucedo 2021-02-13 19:38:15 +00:00
parent 0b84876c95
commit 6b62990dbc
2 changed files with 55 additions and 4 deletions

View file

@ -43,14 +43,16 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv");
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
#else
sq->record<kp::OpAlgoBase>(
params,
std::vector<char>(
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len));
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
@ -126,14 +128,16 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv");
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
#else
sq->record<kp::OpAlgoBase>(
params,
std::vector<char>(
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len));
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });