updated tests on constants and workgroup typedefs
This commit is contained in:
parent
9adfa34fd3
commit
e481c3afac
2 changed files with 7 additions and 7 deletions
|
|
@ -44,7 +44,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
|
|||
#ifdef KOMPUTE_SHADER_FROM_STRING
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, "test/shaders/glsl/test_logistic_regression.comp",
|
||||
std::array<uint32_t, 3>(), std::vector<float>({5.0}));
|
||||
kp::Workgroup(), kp::Constants({5.0}));
|
||||
#else
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params,
|
||||
|
|
@ -52,7 +52,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression)
|
|||
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
|
||||
std::array<uint32_t, 3>(), std::vector<float>({5.0}));
|
||||
kp::Workgroup(), kp::Constants({5.0}));
|
||||
#endif
|
||||
|
||||
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
|
||||
|
|
@ -93,7 +93,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
|
|||
uint32_t ITERATIONS = 100;
|
||||
float learningRate = 0.1;
|
||||
|
||||
std::vector<float> wInVec = { 0.001, 0.001 };
|
||||
kp::Constants wInVec = { 0.001, 0.001 };
|
||||
std::vector<float> bInVec = { 0 };
|
||||
|
||||
std::shared_ptr<kp::Tensor> xI{ new kp::Tensor({ 0, 1, 1, 1, 1 }) };
|
||||
|
|
@ -129,7 +129,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
|
|||
#ifdef KOMPUTE_SHADER_FROM_STRING
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
|
||||
std::array<uint32_t, 3>(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
|
||||
kp::Workgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
|
||||
#else
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params,
|
||||
|
|
@ -137,7 +137,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
|
|||
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
|
||||
std::array<uint32_t, 3>(), std::vector<float>({5.0}));
|
||||
kp::Workgroup(), kp::Constants({5.0}));
|
||||
#endif
|
||||
|
||||
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue