vulkan: Implement top-k (#17418)

* vulkan: Implement top-k

Each pass launches workgroups that each sort 2^N elements (where N is usually 7-10)
and discards all but the top K. Repeat until only K are left. And there's a fast
path when K==1 to just find the max value rather than sorting.

* fix pipeline selection

* vulkan: Add N-ary search algorithm for topk

* microoptimizations
This commit is contained in:
Jeff Bolz 2025-11-26 09:45:43 -06:00 committed by GitHub
parent 6ab4e50d9c
commit 879d673759
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 480 additions and 1 deletions

View file

@ -7635,6 +7635,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
}
for (int i = 0; i < 20; ++i) {
for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
if (k <= 1<<i) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
}
}
}
for (int k : {1, 2, 3, 7, 15}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
@ -8032,7 +8040,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
for (auto k : {1, 10, 40}) {
for (auto nrows : {1, 16}) {
for (auto cols : {k, 1000, 65000, 200000}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
}
}
}
return test_cases;
}