From f507439eb7b01adbbaf5221257761577aac8cc4b Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Sun, 14 Feb 2021 07:18:46 +0000 Subject: [PATCH] Updated tests to match kompute workgroup --- test/TestLogisticRegression.cpp | 8 ++++---- test/TestMultipleAlgoExecutions.cpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/TestLogisticRegression.cpp b/test/TestLogisticRegression.cpp index 4dfb6334f..aa728641f 100644 --- a/test/TestLogisticRegression.cpp +++ b/test/TestLogisticRegression.cpp @@ -44,7 +44,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression) #ifdef KOMPUTE_SHADER_FROM_STRING sq->record( params, "test/shaders/glsl/test_logistic_regression.comp", - kp::OpAlgoBase::KomputeWorkgroup(), std::vector({5.0})); + std::array(), std::vector({5.0})); #else sq->record( params, @@ -52,7 +52,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegression) kp::shader_data::shaders_glsl_logisticregression_comp_spv, kp::shader_data::shaders_glsl_logisticregression_comp_spv + kp::shader_data::shaders_glsl_logisticregression_comp_spv_len), - kp::OpAlgoBase::KomputeWorkgroup(), std::vector({5.0})); + std::array(), std::vector({5.0})); #endif sq->record({ wOutI, wOutJ, bOut, lOut }); @@ -129,7 +129,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy) #ifdef KOMPUTE_SHADER_FROM_STRING sq->record( params, "test/shaders/glsl/test_logistic_regression.comp.spv", - kp::OpAlgoBase::KomputeWorkgroup(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}}); + std::array(), kp::Algorithm::SpecializationContainer{{(uint32_t)5}}); #else sq->record( params, @@ -137,7 +137,7 @@ TEST(TestLogisticRegressionAlgorithm, TestMainLogisticRegressionManualCopy) kp::shader_data::shaders_glsl_logisticregression_comp_spv, kp::shader_data::shaders_glsl_logisticregression_comp_spv + kp::shader_data::shaders_glsl_logisticregression_comp_spv_len), - kp::OpAlgoBase::KomputeWorkgroup(), std::vector({5.0})); + std::array(), std::vector({5.0})); #endif sq->record({ wOutI, wOutJ, bOut, lOut }); diff --git a/test/TestMultipleAlgoExecutions.cpp b/test/TestMultipleAlgoExecutions.cpp index 859a742e5..127550d50 100644 --- a/test/TestMultipleAlgoExecutions.cpp +++ b/test/TestMultipleAlgoExecutions.cpp @@ -385,7 +385,7 @@ TEST(TestMultipleAlgoExecutions, TestAlgorithmSpecialized) sq->record( { tensorA, tensorB }, std::vector(shader.begin(), shader.end()), - kp::OpAlgoBase::KomputeWorkgroup(), spec); + std::array(), spec); sq->end(); sq->eval();