common : add parser for ministral/mistral large 3/devstral 2 (#17713)

This commit is contained in:
Aldehir Rojas 2025-12-09 17:31:04 -06:00 committed by GitHub
parent 63391852b0
commit 2fbe3b7bb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 415 additions and 0 deletions

View file

@ -1,5 +1,6 @@
#include "chat.h"
#include "chat-parser.h"
#include "chat-peg-parser.h"
#include "common.h"
#include "json-partial.h"
#include "json-schema-to-grammar.h"
@ -150,6 +151,7 @@ struct templates_params {
common_chat_tool_choice tool_choice;
json json_schema;
bool parallel_tool_calls;
common_reasoning_format reasoning_format;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
@ -589,6 +591,16 @@ common_chat_templates_ptr common_chat_templates_init(
"{%- if false %}");
}
// TODO @aldehir : this is a temporary fix, pending Minja changes
// Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664
if (default_template_src.find("[TOOL_CALLS]") != std::string::npos
// search for the error message and patch it
&& default_template_src.find("if (message['content'] is none or") != std::string::npos) {
string_replace_all(default_template_src,
"{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}",
"{%- if false %}");
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
bool add_bos = false;
@ -987,6 +999,118 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
return data;
}
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
auto role = msg.value("role", "");
if (role != "system" && role != "assistant") {
// Only adjust system and assistant messages. Interestingly, the system message may contain thinking.
adjusted_messages.push_back(msg);
continue;
}
auto content = json::array();
// If message contains `reasoning_content`, add it as a block of type `thinking`
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
content.push_back({
{"type", "thinking"},
{"thinking", msg.at("reasoning_content").get<std::string>()},
});
}
// If message contains `content`, add it as a block of type `text`
if (msg.contains("content")) {
if (msg.at("content").is_string()) {
content.push_back({
{"type", "text"},
{"text", msg.at("content").get<std::string>()},
});
} else if (msg.at("content").is_array()) {
auto blocks = msg.at("content");
content.insert(content.end(), blocks.begin(), blocks.end());
}
}
auto adjusted = msg;
adjusted["content"] = content;
adjusted.erase("reasoning_content");
adjusted_messages.push_back(adjusted);
}
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
"[THINK]",
"[/THINK]",
"[TOOL_CALLS]",
"[ARGS]",
};
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
// Ministral wants to emit json surrounded by code fences
return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```";
}
// Tool call parser
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
auto tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
const auto & schema = function.at("parameters");
tool_choice |= p.rule("tool-" + name,
p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]")
+ p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))
);
});
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));
return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
}
// Content only parser
include_grammar = false;
return reasoning << p.content(p.rest());
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}
};
}
return data;
}
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@ -2341,6 +2465,7 @@ static common_chat_params common_chat_templates_apply_jinja(
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;
params.enable_thinking = inputs.enable_thinking;
params.grammar = inputs.grammar;
params.now = inputs.now;
@ -2504,6 +2629,13 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
}
// Ministral/Mistral Large 3
if (src.find("[SYSTEM_PROMPT]") != std::string::npos &&
src.find("[TOOL_CALLS]") != std::string::npos &&
src.find("[ARGS]") != std::string::npos) {
return common_chat_params_init_ministral_3(tmpl, params);
}
if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
return common_chat_params_init_magistral(tmpl, params);
}