Updated OpAlgoBase to not copy data as optensorsync operations are introduced

This commit is contained in:
Alejandro Saucedo 2020-09-12 09:14:35 +01:00
parent 4171786b6f
commit 9f8508075a
10 changed files with 92 additions and 180 deletions

View file

@ -27,17 +27,16 @@ TEST(TestMultipleAlgoExecutions, SingleSequenceRecord) {
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
false, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
false, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
true, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->record<kp::OpTensorSyncLocal>({ tensorA });
sq->end();
sq->eval();
}
@ -70,7 +69,6 @@ TEST(TestMultipleAlgoExecutions, MultipleCmdBufRecords) {
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
false, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
@ -80,7 +78,6 @@ TEST(TestMultipleAlgoExecutions, MultipleCmdBufRecords) {
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
false, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
@ -90,11 +87,18 @@ TEST(TestMultipleAlgoExecutions, MultipleCmdBufRecords) {
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
true, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
sq->eval();
sq->begin();
sq->record<kp::OpTensorSyncLocal>(
{ tensorA });
sq->end();
sq->eval();
}
sqWeakPtr.reset();
@ -126,7 +130,6 @@ TEST(TestMultipleAlgoExecutions, MultipleSequences) {
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
true, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
@ -134,12 +137,11 @@ TEST(TestMultipleAlgoExecutions, MultipleSequences) {
}
std::weak_ptr<kp::Sequence> sqWeakPtr2 = mgr.getOrCreateManagedSequence("newSequence2");
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr.lock()) {
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr2.lock()) {
sq->begin();
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
true, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
@ -148,18 +150,28 @@ TEST(TestMultipleAlgoExecutions, MultipleSequences) {
std::weak_ptr<kp::Sequence> sqWeakPtr3 = mgr.getOrCreateManagedSequence("newSequence3");
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr.lock()) {
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr3.lock()) {
sq->begin();
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
true, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
sq->eval();
}
std::weak_ptr<kp::Sequence> sqWeakPtr4 = mgr.getOrCreateManagedSequence("newSequence5");
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr4.lock()) {
sq->begin();
sq->record<kp::OpTensorSyncLocal>(
{ tensorA });
sq->end();
sq->eval();
}
EXPECT_EQ(tensorA->data(), std::vector<float>({3, 3, 3}));
}
@ -190,12 +202,11 @@ TEST(TestMultipleAlgoExecutions, SingleRecordMultipleEval) {
}
std::weak_ptr<kp::Sequence> sqWeakPtr2 = mgr.getOrCreateManagedSequence("newSequence2");
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr.lock()) {
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr2.lock()) {
sq->begin();
sq->record<kp::OpAlgoBase<3, 1, 1>>(
{ tensorA },
true, // Whether to copy output from device
std::vector<char>(shader.begin(), shader.end()));
sq->end();
@ -205,6 +216,20 @@ TEST(TestMultipleAlgoExecutions, SingleRecordMultipleEval) {
sq->eval();
}
std::weak_ptr<kp::Sequence> sqWeakPtr3 = mgr.getOrCreateManagedSequence("newSequence3");
if (std::shared_ptr<kp::Sequence> sq = sqWeakPtr2.lock()) {
sq->begin();
sq->record<kp::OpTensorSyncLocal>(
{ tensorA });
sq->end();
sq->eval();
sq->eval();
sq->eval();
}
EXPECT_EQ(tensorA->data(), std::vector<float>({3, 3, 3}));
}