Updated tests to match kompute workgroup

This commit is contained in:
Alejandro Saucedo 2021-02-14 07:18:46 +00:00
parent c8370e0a3a
commit f507439eb7
2 changed files with 5 additions and 5 deletions

View file

@ -44,7 +44,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp",
kp::OpAlgoBase::KomputeWorkgroup(), std::vector<float>({5.0}));
std::array<uint32_t, 3>(), std::vector<float>({5.0}));
#else
sq->record<kp::OpAlgoBase>(
params,
@ -52,7 +52,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::OpAlgoBase::KomputeWorkgroup(), std::vector<float>({5.0}));
std::array<uint32_t, 3>(), std::vector<float>({5.0}));
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
@ -129,7 +129,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
std::array<uint32_t, 3>(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
#else
sq->record<kp::OpAlgoBase>(
params,
@ -137,7 +137,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::OpAlgoBase::KomputeWorkgroup(), std::vector<float>({5.0}));
std::array<uint32_t, 3>(), std::vector<float>({5.0}));
#endif
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });

View file

@ -385,7 +385,7 @@ TEST(TestMultipleAlgoExecutions, TestAlgorithmSpecialized)
sq->record<kp::OpAlgoBase>(
{ tensorA, tensorB },
std::vector<char>(shader.begin(), shader.end()),
kp::OpAlgoBase::KomputeWorkgroup(), spec);
std::array<uint32_t, 3>(), spec);
sq->end();
sq->eval();