model: support Ministral3 (#17644)

* conversion script

* support ministral 3

* maybe this is better?

* add TODO for rope_yarn_log_mul

* better ppl (tested on 14B-Instruct)

* Add Ministral3 support to Mistral format

* improve arch handling

* add sizes

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* nits

---------

Co-authored-by: Julien Denize <julien.denize@mistral.ai>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Xuan-Son Nguyen 2025-12-01 12:26:52 +01:00 committed by GitHub
parent 649495c9d9
commit cd3c118908
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 342 additions and 10 deletions

View file

@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
switch (arch) {
case LLM_ARCH_LLAMA:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
if (hparams.n_expert == 8) {
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_8x7B; break;
@ -663,8 +661,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
} else {
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192;
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
hparams.n_swa = 8192;
hparams.n_attn_temp_floor_scale = 8192;
hparams.f_attn_temp_scale = 0.1f;
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
}
@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MISTRAL3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
if (hparams.f_attn_temp_scale != 0.0f) {
hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
if (hparams.n_attn_temp_floor_scale == 0) {
throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
}
}
// TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
// but may need further verification with other values
if (hparams.rope_yarn_log_mul != 0.0f) {
float factor = 1.0f / hparams.rope_freq_scale_train;
float mscale = 1.0f;
float mscale_all_dims = hparams.rope_yarn_log_mul;
static auto get_mscale = [](float scale, float mscale) {
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
};
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
}
switch (hparams.n_layer) {
case 26: type = LLM_TYPE_3B; break;
case 34: type = LLM_TYPE_8B; break;
case 40: type = LLM_TYPE_14B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_MISTRAL3:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_qwen3next>(*this, params);
} break;
case LLM_ARCH_MISTRAL3:
{
llm = std::make_unique<llm_build_mistral3>(*this, params);
} break;
default:
GGML_ABORT("fatal error");
}
@ -7690,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCEE:
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_MISTRAL3:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2