ggml: add GATED_DELTA_NET op (#19504)
* ggml: add GATED_DELTA_NET op * remove the transpose * add KDA * add qwen35 dense * llama : check for fused gated delta net backend support --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
6fce5c6a7d
commit
c5a778891b
15 changed files with 627 additions and 10 deletions
|
|
@ -150,6 +150,9 @@ llama_context::llama_context(
|
|||
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
|
||||
|
||||
cparams.fused_gdn_ar = true;
|
||||
cparams.fused_gdn_ch = false; // TODO: implement
|
||||
|
||||
// with causal attention, the batch size is limited by the context size
|
||||
cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
|
||||
|
||||
|
|
@ -422,7 +425,7 @@ void llama_context::sched_reserve() {
|
|||
if (cparams.auto_fa) {
|
||||
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to split graph for Flash Attention check");
|
||||
throw std::runtime_error("failed to reserve graph for Flash Attention check");
|
||||
}
|
||||
|
||||
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
|
||||
|
|
@ -432,8 +435,7 @@ void llama_context::sched_reserve() {
|
|||
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
|
||||
continue;
|
||||
}
|
||||
ggml_backend_dev_t device_fa = ggml_backend_get_device(
|
||||
ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
||||
ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
||||
|
||||
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
|
||||
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
|
||||
|
|
@ -448,6 +450,7 @@ void llama_context::sched_reserve() {
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (fa_device_mismatch) {
|
||||
cparams.flash_attn = false;
|
||||
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
|
||||
|
|
@ -459,6 +462,39 @@ void llama_context::sched_reserve() {
|
|||
cparams.auto_fa = false;
|
||||
}
|
||||
|
||||
if (cparams.fused_gdn_ar) {
|
||||
auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
|
||||
}
|
||||
|
||||
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
|
||||
bool gdn_device_mismatch = false;
|
||||
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
|
||||
ggml_tensor * n = ggml_graph_node(gf, i);
|
||||
if (n->op != GGML_OP_GATED_DELTA_NET) {
|
||||
continue;
|
||||
}
|
||||
ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
|
||||
|
||||
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
|
||||
const int il = std::stoi(n->name + prefix_len);
|
||||
ggml_backend_dev_t device_kv = model.dev_layer(il);
|
||||
if (device_gdn != device_kv) {
|
||||
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
|
||||
"is assigned to device %s (usually due to missing support)\n",
|
||||
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
|
||||
gdn_device_mismatch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (gdn_device_mismatch) {
|
||||
cparams.fused_gdn_ar = false;
|
||||
LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// reserve worst-case graph
|
||||
int n_splits_pp = -1;
|
||||
int n_nodes_pp = -1;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue