* mtmd: llama.cpp DeepSeekOCR support init commit * loading sam tensors * mtmd: fix vision model processing * deepseek-ocr clip-vit model impl * mtmd: add DeepSeek-OCR LM support with standard attention * mtmd: successfully runs DeepSeek-OCR LM in llama-cli * mtmd: Fix RoPE type for DeepSeek-OCR LM. * loading LM testing Vision model loading * sam warmup working * sam erroneous return corrected * clip-vit: corrected cls_embd concat * clip-vit: model convert qkv_proj split * corrected combining of image encoders' results * fix: update callback for ffn_moe_weighted and add callback for attn_out in deepseek2 model * concat image_newline and image_seperator tokens * visual_model warmup (technically) works * window partitioning using standard ggml ops * sam implementation without using CPU only ops * clip: fixed warnings * Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr * mtmd: fix get_rel_pos * mtmd: fixed the wrong scaler for get_rel_pos * image encoding technically works but the output can't be checked singe image decoding fails * mtmd: minor changed * mtmd: add native resolution support * - image encoding debugged - issues fixed mainly related wrong config like n_patches etc. 
- configs need to be corrected in the converter * mtmd: correct token order * - dynamic resizing - changes are concerning PR https://github.com/sfallah/llama.cpp/pull/4 * mtmd: quick fix token order * mtmd: fix danling pointer * mtmd: SAM numerically works * mtmd: debug CLIP-L (vit_pre_ln) * mtmd: debug CLIP-L & first working DeepSeek-OCR model * mtmd : add --dsocr-mode CLI argument for DeepSeek-OCR resolution control & all native resolution modes work * mtmd: simplify SAM patch embedding * mtmd: adapt Pillow image resizing function * mtmd: simplify DeepSeek-OCR dynamic resolution preprocessing * mtmd: remove --dsocr-mode argument * mtmd: refactor code & remove unused helper functions * mtmd: fix tensor names for image newlines and view separator * clean up * reverting automatically removed spaces * reverting automatically removed spaces * mtmd: fixed bad ocr check in Deepseek2 (LM) * mtmd: support combined QKV projection in buid_vit * using common build_attn in sam * corrected code-branch when flash-attn disabled enabling usage of --flash-attn option * mtmd: minor fix * minor formatting and style * fixed flake8 lint issues * minor editorconfig-check fixes * minor editorconfig-check fixes * mtmd: simplify get_rel_pos * mtmd: make sam hparams configurable * mtmd: add detailed comments for resize_bicubic_pillow * mtmd: fixed wrong input setting * mtmd: convert model in FP16 * mtmd: minor fix * mtmd: remove tweak to llama-mtmd-cli & deepseek-ocr template * fix: test-1.jpg ORC issue with small (640) resolution setting min-resolution base (1024) max large (1280) for dynamic-resolution * minor: editconfig-check fix * merge with changes from https://github.com/ggml-org/llama.cpp/pull/17909 added new opt to tests.sh to disable flash-attn * minor: editconfig-check fix * testing deepseek-ocr quick and dirty test script comparing results of Qwen2.5-VL vs DeepSeek-OCR * quick and (potential) dirty merge with https://github.com/ggml-org/llama.cpp/pull/17909 * refactoring, one 
single builder function and static helpers * added deepseek-ocr test to tests.sh * minor formatting fixes * check with fixed expected resutls * minor formatting * editorconfig-check fix * merge with changes from https://github.com/ggml-org/llama.cpp/pull/18042 * minor - added GLM-4.6V to big tests - added missing deps for python test * convert: minor fix * mtmd: format code * convert: quick fix * convert: quick fix * minor python formatting * fixed merge build issue * merge resolved - fixed issues in convert - tested several deepseek models * minor fix * minor * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * - removed clip_is_deepseekocr - removed redundant RESIZE_ALGO_BICUBIC_PILLOW resize-algo - simplified image-preprocessing - removed/simplified debug functions * - cleaning commented out code * fixing instabilities issues reintroducing resize_bicubic_pillow * - use f16 model for deepseek-ocr test - ignore llama-arch test for deepseek-ocr * rename fc_w --> mm_fc_w * add links to OCR discussion * cleaner loading code * add missing .weight to some tensors * add default jinja template (to be used by server) * move test model to ggml-org * rolling back upscale change * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: bluebread <hotbread70127@gmail.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
122 lines
4.5 KiB
C++
#include "models.h"

ggml_cgraph * clip_graph_glm4v::build() {
    GGML_ASSERT(model.patch_bias != nullptr);
    GGML_ASSERT(model.class_embedding == nullptr);

    const int batch_size = 1;

    norm_type norm_t = NORM_TYPE_RMS;

    ggml_tensor * inp_raw = build_inp_raw();
    ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);

    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches * 4);
    ggml_set_name(positions, "positions");
    ggml_set_input(positions);

    GGML_ASSERT(img.nx % (patch_size * 2) == 0);
    GGML_ASSERT(img.ny % (patch_size * 2) == 0);

    // second conv dimension: apply the second patch-embedding kernel to the same image and sum,
    // then reorder the patches so that each 2x2 spatial block becomes 4 consecutive tokens
    {
        auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
        inp = ggml_add(ctx0, inp, inp_1);

        inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w, h, c, b] -> [c, w, h, b]
        inp = ggml_cont_4d(
            ctx0, inp,
            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
        inp = ggml_reshape_4d(
            ctx0, inp,
            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
        inp = ggml_cont_3d(
            ctx0, inp,
            n_embd, n_patches_x * n_patches_y, batch_size);
    }

    // add patch bias
    inp = ggml_add(ctx0, inp, model.patch_bias);
    cb(inp, "patch_bias", -1);

    // post-conv norm
    inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1);

    ggml_tensor * learned_pos_embd = nullptr;
    // Note: GLM-OCR does not have learned position embeddings
    if (model.position_embeddings != nullptr) {
        learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC);
        learned_pos_embd = ggml_cont_4d(
            ctx0, learned_pos_embd,
            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
        learned_pos_embd = ggml_reshape_4d(
            ctx0, learned_pos_embd,
            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
        learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3);
        learned_pos_embd = ggml_cont_3d(
            ctx0, learned_pos_embd,
            n_embd, n_patches_x * n_patches_y, batch_size);
        cb(learned_pos_embd, "learned_pos_embd", -1);
    }

    auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
        return ggml_rope_multi(
            ctx0, cur, positions, nullptr,
            d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION,
            32768, hparams.rope_theta, 1, 0, 1, 32, 1);
    };

    ggml_tensor * cur = build_vit(
        inp, n_patches,
        norm_t,
        hparams.ffn_op,
        learned_pos_embd,
        add_pos);

    cb(cur, "vit_out", -1);
    // cb(ggml_sum(ctx0, cur), "vit_out_sum", -1);

    // GLM4V projector
    // ref: https://github.com/huggingface/transformers/blob/40dc11cd3eb4126652aa41ef8272525affd4a636/src/transformers/models/glm4v/modeling_glm4v.py#L116-L130

    // patch merger (downsample)
    {
        int n_merge = hparams.n_merge;
        GGML_ASSERT(n_merge > 0);

        int n_token_out = n_patches / n_merge / n_merge;
        cur = ggml_reshape_4d(ctx0, cur, n_embd, n_merge, n_merge, n_token_out);
        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); // [n_merge, n_merge, n_embd, n_token_out]
        cur = ggml_conv_2d(ctx0, model.mm_patch_merger_w, cur, n_merge, n_merge, 0, 0, 1, 1);
        cur = ggml_reshape_2d(ctx0, cur, cur->ne[2], n_token_out); // [n_embd_out, n_token_out]

        cur = ggml_add(ctx0, cur, model.mm_patch_merger_b);
    }

    // FC projector
    {
        cur = build_mm(model.mm_fc_w, cur);
        // default LayerNorm (post_projection_norm)
        cur = build_norm(cur, model.mm_post_norm_w, model.mm_post_norm_b, NORM_TYPE_NORMAL, 1e-5, -1);
        cur = ggml_gelu_erf(ctx0, cur);
        cb(cur, "after_fc_proj", -1);
    }

    // FFN projector
    {
        cur = build_ffn(cur,
            model.mm_ffn_up_w, model.mm_ffn_up_b,
            model.mm_ffn_gate_w, model.mm_ffn_gate_b,
            model.mm_ffn_down_w, model.mm_ffn_down_b,
            hparams.ffn_op, -1);
        cb(cur, "after_ffn_proj", -1);
        // cb(ggml_sum(ctx0, cur), "merged_sum", -1);
    }

    // build the graph
    ggml_build_forward_expand(gf, cur);

    return gf;
}
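The reshape/permute sequence in the "second conv dimension" block of clip_graph_glm4v::build() is easiest to follow on patch IDs rather than tensors. The sketch below is a hypothetical helper (reorder_2x2 is not part of llama.cpp) that reproduces only the index arithmetic, assuming batch_size = 1 and treating each patch's n_embd channels as a single unit. Patches come out grouped into 2x2 spatial blocks, four consecutive tokens per block, with blocks in row-major order. Presumably this is the grouping the patch merger expects when n_merge is 2, since it reshapes the ViT output to [n_embd, n_merge, n_merge, n_token_out].

```cpp
#include <cassert>
#include <vector>

// Hypothetical index-level sketch of the reshape/permute sequence
// (batch_size = 1, hard-coded factor 2 as in the graph). Each patch is a
// single ID instead of n_embd channels, so "n_embd * 2" becomes "2 patches".
//
// grid: patch IDs in row-major order (index = x + W * y), i.e. the layout
//       after the first ggml_permute + ggml_cont_4d.
// W, H: number of patches along x and y (n_patches_x, n_patches_y).
std::vector<int> reorder_2x2(const std::vector<int>& grid, int W, int H) {
    assert(W % 2 == 0 && H % 2 == 0);  // mirrors the GGML_ASSERTs on img.nx / img.ny
    assert(static_cast<int>(grid.size()) == W * H);

    std::vector<int> out;
    out.reserve(grid.size());
    // Loop order follows the permuted shape {2 (dx), 2 (dy), W/2 (x2), H/2 (y2)},
    // innermost dimension first, which is the order ggml_cont_3d writes memory.
    for (int y2 = 0; y2 < H / 2; ++y2) {
        for (int x2 = 0; x2 < W / 2; ++x2) {
            for (int dy = 0; dy < 2; ++dy) {
                for (int dx = 0; dx < 2; ++dx) {
                    // Element (dx, x2, dy, y2) of the pre-permute view with shape
                    // {2, W/2, 2, H/2}; equals grid[x + W * y] with
                    // x = 2 * x2 + dx and y = 2 * y2 + dy.
                    const int src = dx + 2 * (x2 + (W / 2) * (dy + 2 * y2));
                    out.push_back(grid[src]);
                }
            }
        }
    }
    return out;
}
```

For a 4x2 patch grid whose IDs are 0..7 in row-major order, reorder_2x2 returns 0, 1, 4, 5, 2, 3, 6, 7: the left 2x2 block (patches 0, 1, 4, 5) first, then the right one.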
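The "FC projector" stage of the GLM4V projector applies a linear layer (build_mm with mm_fc_w; no bias is passed in this graph), a standard LayerNorm with eps 1e-5, and the erf-based GELU. As a reference for the per-token math only, here is a minimal scalar sketch in plain C++. The function names and the row-major [n_out][n_in] weight layout are assumptions for the example; it does not mirror ggml's memory layout or kernels, and it leaves out the FFN projector, whose activation depends on hparams.ffn_op.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Exact (erf-based) GELU, the variant that ggml_gelu_erf is named after.
float gelu_erf(float v) {
    return 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
}

// Hypothetical scalar sketch of the FC projector for a single token:
//   h = W x            (linear, no bias)
//   h = LayerNorm(h)   (subtract mean, divide by sqrt(var + eps), then gamma/beta)
//   h = GELU_erf(h)
// w is row-major [n_out][n_in].
std::vector<float> fc_projector(const std::vector<std::vector<float>>& w,
                                const std::vector<float>& x,
                                const std::vector<float>& gamma,
                                const std::vector<float>& beta,
                                float eps = 1e-5f) {
    const std::size_t n_out = w.size();
    assert(n_out > 0 && gamma.size() == n_out && beta.size() == n_out);

    // linear projection
    std::vector<float> h(n_out, 0.0f);
    for (std::size_t o = 0; o < n_out; ++o) {
        assert(w[o].size() == x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            h[o] += w[o][i] * x[i];
        }
    }

    // LayerNorm over the output features (biased variance)
    float mean = 0.0f;
    for (float v : h) mean += v;
    mean /= static_cast<float>(n_out);
    float var = 0.0f;
    for (float v : h) var += (v - mean) * (v - mean);
    var /= static_cast<float>(n_out);
    const float inv_std = 1.0f / std::sqrt(var + eps);

    // affine LayerNorm output followed by GELU
    for (std::size_t o = 0; o < n_out; ++o) {
        const float normed = (h[o] - mean) * inv_std * gamma[o] + beta[o];
        h[o] = gelu_erf(normed);
    }
    return h;
}
```

With an identity weight, input {1, -1}, gamma {1, 1} and beta {0, 0}, the LayerNorm output is approximately {1, -1}, so the result is close to {GELU(1), GELU(-1)} ≈ {0.8413, -0.1587}.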