Updated to use all uint32_t to avoid ambiguity on passing strings
This commit is contained in:
parent
5bc3ac9c06
commit
56d9a3a933
18 changed files with 65 additions and 68 deletions
|
|
@ -43,10 +43,10 @@ TEST(TestLogisticRegression, TestMainLogisticRegression)
|
|||
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params,
|
||||
std::vector<char>(
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
|
||||
std::vector<uint32_t>(
|
||||
(uint32_t*)kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
(uint32_t*)(kp::shader_data::shaders_glsl_logisticregression_comp_spv +
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len)),
|
||||
kp::Workgroup(), kp::Constants({5.0}));
|
||||
|
||||
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
|
||||
|
|
@ -81,7 +81,7 @@ TEST(TestLogisticRegression, TestMainLogisticRegression)
|
|||
bIn->data()[0]);
|
||||
}
|
||||
|
||||
TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
|
||||
TEST(TestLogisticRegression, TestMainLogisticRegressionManualCopy)
|
||||
{
|
||||
|
||||
uint32_t ITERATIONS = 100;
|
||||
|
|
@ -120,19 +120,13 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
|
|||
// Record op algo base
|
||||
sq->begin();
|
||||
|
||||
#ifdef KOMPUTE_SHADER_FROM_STRING
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
|
||||
kp::Workgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
|
||||
#else
|
||||
sq->record<kp::OpAlgoBase>(
|
||||
params,
|
||||
std::vector<char>(
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
|
||||
kp::Workgroup(), kp::Constants({5.0}));
|
||||
#endif
|
||||
std::vector<uint32_t>(
|
||||
(uint32_t*)kp::shader_data::shaders_glsl_logisticregression_comp_spv,
|
||||
(uint32_t*)(kp::shader_data::shaders_glsl_logisticregression_comp_spv +
|
||||
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len)),
|
||||
kp::Workgroup(), kp::Constants({5.0}));
|
||||
|
||||
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue