Updated tests to match simplified specialisation constants

This commit is contained in:
Alejandro Saucedo 2021-02-14 06:55:24 +00:00
parent a7801cedd0
commit 119cdb2886
4 changed files with 348 additions and 358 deletions

View file

@ -43,8 +43,8 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
params, "test/shaders/glsl/test_logistic_regression.comp",
kp::OpAlgoBase::KomputeWorkgroup(), std::vector<float>({5.0}));
#else
sq->record<kp::OpAlgoBase>(
params,
@ -52,7 +52,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
kp::OpAlgoBase::KomputeWorkgroup(), std::vector<float>({5.0}));
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
@ -137,7 +137,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
kp::OpAlgoBase::KomputeWorkgroup(), std::vector<float>({5.0}));
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });