Updated to use all uint32_t to avoid ambiguity on passing strings

This commit is contained in:
Alejandro Saucedo 2021-02-20 18:09:02 +00:00
parent 5bc3ac9c06
commit 56d9a3a933
18 changed files with 65 additions and 68 deletions

View file

@ -43,10 +43,10 @@ TEST(TestLogisticRegression, TestMainLogisticRegression)
sq->record<kp::OpAlgoBase>(
params,
std::vector<char>(
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
std::vector<uint32_t>(
(uint32_t*)kp::shader_data::shaders_glsl_logisticregression_comp_spv,
(uint32_t*)(kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len)),
kp::Workgroup(), kp::Constants({5.0}));
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });
@ -81,7 +81,7 @@ TEST(TestLogisticRegression, TestMainLogisticRegression)
bIn->data()[0]);
}
TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
TEST(TestLogisticRegression, TestMainLogisticRegressionManualCopy)
{
uint32_t ITERATIONS = 100;
@ -120,19 +120,13 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy)
// Record op algo base
sq->begin();
#ifdef KOMPUTE_SHADER_FROM_STRING
sq->record<kp::OpAlgoBase>(
params, "test/shaders/glsl/test_logistic_regression.comp.spv",
kp::Workgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}});
#else
sq->record<kp::OpAlgoBase>(
params,
std::vector<char>(
kp::shader_data::shaders_glsl_logisticregression_comp_spv,
kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len),
kp::Workgroup(), kp::Constants({5.0}));
#endif
std::vector<uint32_t>(
(uint32_t*)kp::shader_data::shaders_glsl_logisticregression_comp_spv,
(uint32_t*)(kp::shader_data::shaders_glsl_logisticregression_comp_spv +
kp::shader_data::shaders_glsl_logisticregression_comp_spv_len)),
kp::Workgroup(), kp::Constants({5.0}));
sq->record<kp::OpTensorSyncLocal>({ wOutI, wOutJ, bOut, lOut });