vulkan/cuda: fix topk_moe with exp_probs_b (#18071)
I updated test_topk_moe to more closely match llm_graph_context::build_moe_ffn and added coverage for exp_probs_b and some other missing combinations. This exposed a bug in both CUDA and Vulkan backends where they were assuming the input to argsort and the input to get_rows are the same. I'd like to optimize this graph in another change, but for now just get it functional. CUDA also had a bug where it got n_experts from the wrong place, leading to GGML_ASSERT failures in some of the new tests.
This commit is contained in:
parent
cb64222b0c
commit
b365c3ff01
5 changed files with 117 additions and 35 deletions
|
|
@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
enum MoeGatingFunc {
|
||||
GATING_FUNC_SOFTMAX,
|
||||
GATING_FUNC_SIGMOID,
|
||||
GATING_FUNC_SOFTMAX_WEIGHT,
|
||||
};
|
||||
|
||||
struct test_topk_moe : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int n_expert_used;
|
||||
const bool with_norm;
|
||||
const bool delayed_softmax;
|
||||
const bool bias_probs;
|
||||
const MoeGatingFunc gating_func;
|
||||
const float scale_w;
|
||||
|
||||
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
|
||||
int n_expert_used = 1,
|
||||
bool with_norm = false,
|
||||
bool delayed_softmax = false) :
|
||||
bool bias_probs = false,
|
||||
MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
|
||||
float scale_w = 0.0f) :
|
||||
ne(ne),
|
||||
n_expert_used(n_expert_used),
|
||||
with_norm(with_norm),
|
||||
delayed_softmax(delayed_softmax) {
|
||||
bias_probs(bias_probs),
|
||||
gating_func(gating_func),
|
||||
scale_w(scale_w) {
|
||||
GGML_ASSERT(n_expert_used <= ne[0]);
|
||||
GGML_ASSERT(!(with_norm && delayed_softmax));
|
||||
}
|
||||
|
||||
std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
|
||||
std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
|
|
@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
|
|||
const int n_tokens = ne[1];
|
||||
|
||||
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
|
||||
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
ggml_tensor * probs =
|
||||
(gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
|
||||
(gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
|
||||
ggml_set_name(probs, "probs");
|
||||
|
||||
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
ggml_tensor * selection_probs = probs;
|
||||
if (bias_probs) {
|
||||
ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_set_name(exp_probs_b, "exp_probs_b");
|
||||
selection_probs = ggml_add(ctx, probs, exp_probs_b);
|
||||
ggml_set_name(selection_probs, "selection_probs");
|
||||
}
|
||||
|
||||
if (delayed_softmax) {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
ggml_set_name(selected_experts, "selected_experts");
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
ggml_set_name(weights, "weights");
|
||||
|
||||
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
|
||||
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
|
||||
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
|
||||
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
|
||||
if (with_norm) {
|
||||
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
|
||||
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
|
||||
ggml_set_name(weights_sum, "weights_sum");
|
||||
|
||||
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
|
||||
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
|
||||
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
|
||||
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
|
||||
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
if (scale_w) {
|
||||
weights = ggml_scale(ctx, weights, scale_w);
|
||||
}
|
||||
|
||||
ggml_set_name(weights, "weights");
|
||||
return weights;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -7991,19 +8021,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
}
|
||||
}
|
||||
|
||||
for (bool with_norm : {false, true}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
|
||||
for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
|
||||
for (bool with_norm : {false, true}) {
|
||||
for (bool bias_probs : {false, true}) {
|
||||
for (float scale_w : {0.0f, 2.0f}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
|
||||
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
|
||||
|
||||
#if 0
|
||||
// these tests are disabled to save execution time, sbut they can be handy for debugging
|
||||
test_cases.emplace_back(new test_llama(2, true));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue