From 0e35ea1dcf4c742284cf35aa242ae7873a7de293 Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Sat, 11 Sep 2021 16:51:33 +0100 Subject: [PATCH] Added test for errors in sequence Signed-off-by: Alejandro Saucedo --- test/TestSequence.cpp | 57 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/test/TestSequence.cpp b/test/TestSequence.cpp index 631632ad1..6e9687562 100644 --- a/test/TestSequence.cpp +++ b/test/TestSequence.cpp @@ -185,3 +185,60 @@ TEST(TestSequence, UtilsClearRecordingRunning) EXPECT_EQ(tensorOut->vector(), std::vector({ 2, 4, 6 })); } + +TEST(TestSequence, CorrectSequenceRunningError) +{ + kp::Manager mgr; + + std::shared_ptr sq = mgr.sequence(); + + std::shared_ptr> tensorA = mgr.tensor({ 1, 2, 3 }); + std::shared_ptr> tensorB = mgr.tensor({ 2, 2, 2 }); + std::shared_ptr> tensorOut = mgr.tensor({ 0, 0, 0 }); + + sq->eval({ tensorA, tensorB, tensorOut }); + + std::vector spirv = compileSource(R"( + #version 450 + + layout (local_size_x = 1) in; + + // The input tensors bind index is relative to index in parameter passed + layout(set = 0, binding = 0) buffer bina { float tina[]; }; + layout(set = 0, binding = 1) buffer binb { float tinb[]; }; + layout(set = 0, binding = 2) buffer bout { float tout[]; }; + + void main() { + uint index = gl_GlobalInvocationID.x; + tout[index] = tina[index] * tinb[index]; + } + )"); + + std::shared_ptr algo = + mgr.algorithm({ tensorA, tensorB, tensorOut }, spirv); + + sq->record(algo)->record( + { tensorA, tensorB, tensorOut }); + + EXPECT_TRUE(sq->isRecording()); + + sq->evalAsync(); + + EXPECT_TRUE(sq->isRunning()); + + // Sequence should throw when running + EXPECT_ANY_THROW(sq->begin()); + EXPECT_ANY_THROW(sq->end()); + EXPECT_ANY_THROW(sq->evalAsync()); + + // Errors should still not get into inconsystent state + sq->evalAwait(); + + // Sequence should not throw when finished + EXPECT_NO_THROW(sq->evalAwait()); + EXPECT_NO_THROW(sq->evalAwait(10)); + + EXPECT_FALSE(sq->isRunning()); + + EXPECT_EQ(tensorOut->vector(), std::vector({ 2, 4, 6 })); +}