model: support Ministral3 (#17644)
* conversion script * support ministral 3 * maybe this is better? * add TODO for rope_yarn_log_mul * better ppl (tested on 14B-Instruct) * Add Ministral3 support to Mistral format * improve arch handling * add sizes * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * nits --------- Co-authored-by: Julien Denize <julien.denize@mistral.ai> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
649495c9d9
commit
cd3c118908
11 changed files with 342 additions and 10 deletions
|
|
@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
if (hparams.n_expert == 8) {
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_8x7B; break;
|
||||
|
|
@ -663,8 +661,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192;
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192;
|
||||
hparams.n_attn_temp_floor_scale = 8192;
|
||||
hparams.f_attn_temp_scale = 0.1f;
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
}
|
||||
|
||||
|
|
@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
||||
|
||||
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
|
||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||
hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
|
||||
if (hparams.n_attn_temp_floor_scale == 0) {
|
||||
throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
|
||||
// but may need further verification with other values
|
||||
if (hparams.rope_yarn_log_mul != 0.0f) {
|
||||
float factor = 1.0f / hparams.rope_freq_scale_train;
|
||||
float mscale = 1.0f;
|
||||
float mscale_all_dims = hparams.rope_yarn_log_mul;
|
||||
static auto get_mscale = [](float scale, float mscale) {
|
||||
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
||||
};
|
||||
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 26: type = LLM_TYPE_3B; break;
|
||||
case 34: type = LLM_TYPE_8B; break;
|
||||
case 40: type = LLM_TYPE_14B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
|
|
@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_qwen3next>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
llm = std::make_unique<llm_build_mistral3>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -7690,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_ARCEE:
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue