model: support GLM-OCR (#19677)
* model: support GLM-OCR * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
e99f1083a0
commit
eeef3cfced
8 changed files with 122 additions and 43 deletions
|
|
@ -1784,7 +1784,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
|
||||
|
||||
// NextN/MTP parameters (GLM-OCR)
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 17: type = LLM_TYPE_1B; break; // GLM-OCR
|
||||
case 40: type = LLM_TYPE_9B; break;
|
||||
case 61: type = LLM_TYPE_32B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
|
|
@ -5410,30 +5418,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (layer.wqkv == nullptr) {
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
int flags = 0;
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
// skip all tensors in the NextN layers
|
||||
flags |= TENSOR_SKIP;
|
||||
}
|
||||
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0);
|
||||
if (layer.wqkv == nullptr) {
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
|
||||
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags);
|
||||
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
|
||||
|
||||
// Optional tensors
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue