Fixed compiling tests with the new test layout

Signed-off-by: Fabian Sauter <sauter.fabian@mailbox.org>
This commit is contained in:
Fabian Sauter 2022-05-19 13:06:55 +02:00
parent b95df8d0a0
commit 34f9d58722
22 changed files with 216 additions and 69 deletions

View file

@ -4,7 +4,7 @@
#include "kompute/Kompute.hpp"
#include "test_logistic_regression.hpp"
#include "test_logistic_regression_shader.hpp"
TEST(TestLogisticRegression, TestMainLogisticRegression)
{
@ -39,12 +39,11 @@ TEST(TestLogisticRegression, TestMainLogisticRegression)
mgr.sequence()->eval<kp::OpTensorSyncDevice>(params);
std::vector<uint32_t> spirv = std::vector<uint32_t>(
(const uint32_t*)kp::TEST_LOGISTIC_REGRESSION_COMP_SPV.data(),
(const uint32_t*)(kp::shader_data::
test_shaders_glsl_test_logistic_regression_comp_spv +
kp::shader_data::
test_shaders_glsl_test_logistic_regression_comp_spv_len));
std::vector<uint32_t> spirv2{ 0x1, 0x2 };
std::vector<uint32_t> spirv(
kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.begin(),
kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.end());
std::shared_ptr<kp::Algorithm> algorithm = mgr.algorithm(
params, spirv, kp::Workgroup({ 5 }), std::vector<float>({ 5.0 }));
@ -57,7 +56,6 @@ TEST(TestLogisticRegression, TestMainLogisticRegression)
// Iterate across all expected iterations
for (size_t i = 0; i < ITERATIONS; i++) {
sq->eval();
for (size_t j = 0; j < bOut->size(); j++) {
@ -118,12 +116,9 @@ TEST(TestLogisticRegression, TestMainLogisticRegressionManualCopy)
mgr.sequence()->record<kp::OpTensorSyncDevice>(params)->eval();
std::vector<uint32_t> spirv = std::vector<uint32_t>(
(uint32_t*)kp::shader_data::shaders_glsl_logisticregression_comp_spv,
(uint32_t*)(kp::shader_data::
shaders_glsl_logisticregression_comp_spv +
kp::shader_data::
shaders_glsl_logisticregression_comp_spv_len));
std::vector<uint32_t> spirv(
kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.begin(),
kp::TEST_LOGISTIC_REGRESSION_SHADER_COMP_SPV.end());
std::shared_ptr<kp::Algorithm> algorithm = mgr.algorithm(
params, spirv, kp::Workgroup(), std::vector<float>({ 5.0 }));
@ -161,3 +156,16 @@ TEST(TestLogisticRegression, TestMainLogisticRegressionManualCopy)
bIn->data()[0]);
}
}
int
main(int argc, char* argv[])
{
testing::InitGoogleTest(&argc, argv);
#if KOMPUTE_ENABLE_SPDLOG
spdlog::set_level(
static_cast<spdlog::level::level_enum>(KOMPUTE_LOG_LEVEL));
#endif
return RUN_ALL_TESTS();
}