vulkan: Multi-pass softmax for large number of cols (#17892)
When the number of cols is large, split each row across multiple workgroups. There are three phases that communicate partial results through temp buffers: (1) compute max partials; (2) take the max of the partials and compute sum(exp(x-max)) partials; (3) sum the partials and compute the scaled result.
This commit is contained in:
parent
3c6391e748
commit
303f8615e9
7 changed files with 331 additions and 2 deletions
|
|
@@ -7652,6 +7652,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
|
||||
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
for (float scale : {1.0f, 0.1f}) {
|
||||
for (int64_t ne0 : {16, 1024}) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue