common/parser: fix call ID detection (Mistral parser mostly) + atomicity for tag-json parsers (#21230)
* Fix call ID detection (Mistral parser mostly) + atomicity for tag-json parsers * Rename * Update common/chat-auto-parser-generator.cpp Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
af5c13841f
commit
f1f793ad06
7 changed files with 242 additions and 153 deletions
|
|
@ -6,6 +6,7 @@
|
|||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
|
@ -317,6 +318,44 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
|
|||
p.end();
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
|
||||
const common_peg_parser & call_id_section, bool have_call_id,
|
||||
const common_peg_parser & args,
|
||||
std::optional<common_peg_parser> atomic_peek) const {
|
||||
auto open = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix);
|
||||
bool matched_atomic = false;
|
||||
common_peg_parser func_parser = p.eps();
|
||||
|
||||
if (!function.name_suffix.empty()) {
|
||||
func_parser = open + call_id_section + p.space() + args;
|
||||
matched_atomic = true;
|
||||
} else if (have_call_id) {
|
||||
func_parser = p.atomic(open + call_id_section) + p.space() + args;
|
||||
matched_atomic = true;
|
||||
} else if (atomic_peek.has_value()) {
|
||||
func_parser = p.atomic(open + call_id_section + p.space() + *atomic_peek) + args;
|
||||
matched_atomic = true;
|
||||
} else {
|
||||
func_parser = open + call_id_section + p.space() + args;
|
||||
}
|
||||
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
|
||||
} else if (!format.per_call_end.empty()) {
|
||||
// When there's no func_close but there is a per_call_end marker, use peek() to ensure
|
||||
// we only emit tool_close when we can actually see the closing marker. This prevents
|
||||
// premature closing during partial parsing when we've seen e.g. "</" which could be
|
||||
// either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
|
||||
func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
|
||||
} else {
|
||||
func_parser = func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
|
||||
}
|
||||
if (!matched_atomic) {
|
||||
func_parser = p.atomic(func_parser);
|
||||
}
|
||||
return func_parser;
|
||||
}
|
||||
|
||||
common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
|
||||
auto & p = ctx.p;
|
||||
const auto & inputs = ctx.inputs;
|
||||
|
|
@ -330,17 +369,27 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
|
|||
const auto & schema = func.contains("parameters") ? func.at("parameters") : json::object();
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
bool have_call_id = false;
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
|
||||
!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
|
||||
(!call_id.suffix.empty() || !arguments.start.empty())) {
|
||||
if (!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
|
||||
} else {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
|
||||
}
|
||||
have_call_id = true;
|
||||
}
|
||||
auto args_parser = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
if (!arguments.start.empty()) {
|
||||
args_parser = p.literal(arguments.start) + args_parser;
|
||||
}
|
||||
if (!arguments.end.empty()) {
|
||||
args_parser = args_parser + p.literal(arguments.end);
|
||||
}
|
||||
|
||||
auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + function.close;
|
||||
}
|
||||
auto atomic_peek = !arguments.start.empty() ? std::optional(p.peek(p.literal(arguments.start))) : std::nullopt;
|
||||
auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_parser, atomic_peek);
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
|
|
@ -470,52 +519,31 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
|||
args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size());
|
||||
}
|
||||
|
||||
if (!arguments.start.empty()) {
|
||||
args_seq = p.literal(arguments.start) + args_seq;
|
||||
}
|
||||
if (!arguments.end.empty()) {
|
||||
args_seq = args_seq + p.literal(arguments.end);
|
||||
}
|
||||
|
||||
// Build call_id parser based on position (if supported)
|
||||
common_peg_parser call_id_section = p.eps();
|
||||
bool have_call_id = false;
|
||||
if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
|
||||
!call_id.suffix.empty()) {
|
||||
(!call_id.suffix.empty() || !arguments.start.empty())) {
|
||||
have_call_id = true;
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
|
||||
}
|
||||
|
||||
bool matched_atomic = false;
|
||||
common_peg_parser func_parser = p.eps();
|
||||
if (!function.name_suffix.empty()) {
|
||||
func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (have_call_id) {
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section) + p.space() + args_seq;
|
||||
matched_atomic = true;
|
||||
} else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
|
||||
matched_atomic = true;
|
||||
} else {
|
||||
func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
|
||||
call_id_section + p.space() + args_seq;
|
||||
}
|
||||
|
||||
if (!function.close.empty()) {
|
||||
func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
|
||||
} else if (!format.per_call_end.empty()) {
|
||||
// When there's no func_close but there is a per_call_end marker, use peek() to ensure
|
||||
// we only emit tool_close when we can actually see the closing marker. This prevents
|
||||
// premature closing during partial parsing when we've seen e.g. "</" which could be
|
||||
// either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
|
||||
func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
|
||||
} else {
|
||||
func_parser =
|
||||
func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
|
||||
}
|
||||
if (!matched_atomic) {
|
||||
func_parser = p.atomic(func_parser);
|
||||
if (!call_id.suffix.empty()) {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
|
||||
} else {
|
||||
call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
|
||||
}
|
||||
}
|
||||
|
||||
// Only peek for an arg tag when there are required args that must follow.
|
||||
// When all args are optional, the model may emit no arg tags at all (#20650).
|
||||
auto atomic_peek = (!arguments.name_prefix.empty() && !required_parsers.empty()) ?
|
||||
std::optional(p.peek(p.literal(arguments.name_prefix))) : std::nullopt;
|
||||
auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_seq, atomic_peek);
|
||||
tool_choice |= p.rule("tool-" + name, func_parser);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -356,6 +356,13 @@ struct analyze_tools : analyze_base {
|
|||
common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
|
||||
common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
|
||||
|
||||
// Shared helper: builds func_parser from open+call_id+args, handling atomic wrapping and close.
|
||||
// atomic_peek: if present, used as the peek expression in the third atomicity branch.
|
||||
common_peg_parser build_func_parser(common_chat_peg_builder & p, const std::string & name,
|
||||
const common_peg_parser & call_id_section, bool have_call_id,
|
||||
const common_peg_parser & args,
|
||||
std::optional<common_peg_parser> atomic_peek) const;
|
||||
common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ static const std::string ARG_SECOND = "BB_ARG_SND_BB";
|
|||
static const std::string USER_MSG = "U_USER_MSG Hello END_U";
|
||||
static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
|
||||
static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
|
||||
static const std::string CALL_ID_001 = "call00001";
|
||||
static const std::string CALL_ID_002 = "call00002";
|
||||
static const std::string CALL_ID_999 = "call99999";
|
||||
|
||||
static std::vector<std::function<void(const common_chat_template & tmpl, autoparser &)>> workarounds(
|
||||
{ // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to
|
||||
|
|
@ -131,6 +134,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
|||
analysis.tools.function.name_prefix = "<|tool▁sep|>";
|
||||
analysis.tools.format.per_call_end = "<|tool▁call▁end|>";
|
||||
analysis.tools.function.close = "```";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -158,7 +162,7 @@ static json user_msg = json{
|
|||
{ "content", USER_MSG }
|
||||
};
|
||||
|
||||
static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") {
|
||||
static json build_tool_call(const std::string & name, const json & args, const std::string & id = CALL_ID_001) {
|
||||
return json{
|
||||
{ "id", id },
|
||||
{ "type", "function" },
|
||||
|
|
@ -166,17 +170,17 @@ static json build_tool_call(const std::string & name, const json & args, const s
|
|||
};
|
||||
}
|
||||
|
||||
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001");
|
||||
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001");
|
||||
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001");
|
||||
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001");
|
||||
static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), CALL_ID_001);
|
||||
static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, CALL_ID_001);
|
||||
static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, CALL_ID_001);
|
||||
static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, CALL_ID_001);
|
||||
|
||||
static json first_tool_call =
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001");
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_001);
|
||||
static json second_tool_call =
|
||||
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002");
|
||||
build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_002);
|
||||
static json first_tool_call_alt_id =
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999");
|
||||
build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_999);
|
||||
|
||||
template <typename T>
|
||||
static std::string mode_to_str(T mode) {
|
||||
|
|
@ -215,6 +219,11 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
|||
LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
|
||||
LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
|
||||
LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
|
||||
LOG_DBG("call_id_prefix: '%s'\n", tools.call_id.prefix.c_str());
|
||||
LOG_DBG("call_id_suffix: '%s'\n", tools.call_id.suffix.c_str());
|
||||
LOG_DBG("call_id_pos: '%s'\n", mode_to_str(tools.call_id.pos).c_str());
|
||||
LOG_DBG("args_start: '%s'\n", tools.arguments.start.c_str());
|
||||
LOG_DBG("args_end: '%s'\n", tools.arguments.end.c_str());
|
||||
LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
|
||||
LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
|
||||
LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());
|
||||
|
|
@ -583,12 +592,15 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
|
|||
if (caps.supports_parallel_tool_calls) {
|
||||
check_per_call_markers();
|
||||
}
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3a: Function call analysis\n" ANSI_RESET);
|
||||
extract_function_markers();
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3b: Argument analysis\n" ANSI_RESET);
|
||||
if (format.mode == tool_format::TAG_WITH_TAGGED) {
|
||||
analyze_arguments();
|
||||
}
|
||||
extract_argument_separator();
|
||||
extract_args_markers();
|
||||
LOG_DBG(ANSI_ORANGE "Phase 3c: Call id analysis\n" ANSI_RESET);
|
||||
extract_call_id_markers();
|
||||
}
|
||||
}
|
||||
|
|
@ -979,8 +991,6 @@ void analyze_tools::extract_function_markers() {
|
|||
}
|
||||
|
||||
void analyze_tools::analyze_arguments() {
|
||||
LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET);
|
||||
|
||||
extract_argument_name_markers();
|
||||
extract_argument_value_markers();
|
||||
}
|
||||
|
|
@ -1189,7 +1199,7 @@ void analyze_tools::extract_args_markers() {
|
|||
|
||||
const auto & diff = comparison->diff;
|
||||
|
||||
if (format.mode != tool_format::JSON_NATIVE) {
|
||||
if (format.mode == tool_format::JSON_NATIVE) {
|
||||
std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
|
||||
std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end;
|
||||
// these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones
|
||||
|
|
@ -1211,6 +1221,10 @@ void analyze_tools::extract_args_markers() {
|
|||
if (find_fun != std::string::npos) {
|
||||
args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size());
|
||||
}
|
||||
size_t find_call_id = args_start.find(CALL_ID_001);
|
||||
if (find_call_id != std::string::npos) {
|
||||
args_start = args_start.substr(find_call_id + CALL_ID_001.size(), args_start.size() - find_call_id - CALL_ID_001.size());
|
||||
}
|
||||
arguments.start = args_start;
|
||||
arguments.end = args_end;
|
||||
}
|
||||
|
|
@ -1250,8 +1264,8 @@ void analyze_tools::extract_call_id_markers() {
|
|||
return;
|
||||
}
|
||||
|
||||
std::string id_value_1 = "call00001";
|
||||
std::string id_value_2 = "call99999";
|
||||
std::string id_value_1 = CALL_ID_001;
|
||||
std::string id_value_2 = CALL_ID_999;
|
||||
|
||||
size_t common_id_prefix_len = 0;
|
||||
for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) {
|
||||
|
|
@ -1350,6 +1364,14 @@ void analyze_tools::extract_call_id_markers() {
|
|||
call_id.suffix = find_first_marker(before_func);
|
||||
}
|
||||
|
||||
if (call_id.prefix == arguments.end) {
|
||||
call_id.prefix = "";
|
||||
}
|
||||
|
||||
if (call_id.suffix == arguments.start) {
|
||||
call_id.suffix = "";
|
||||
}
|
||||
|
||||
// When call_id is detected, per_call_end may have been incorrectly set to include
|
||||
// the call_id_suffix and sample args. Clear it if it starts with call_id_suffix.
|
||||
if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() &&
|
||||
|
|
|
|||
|
|
@ -1631,7 +1631,7 @@ static json common_chat_extra_context() {
|
|||
return ctx;
|
||||
}
|
||||
|
||||
static std::optional<common_chat_params> try_specialized_template(
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params) {
|
||||
|
|
@ -1778,7 +1778,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
return data;
|
||||
}
|
||||
|
||||
if (auto result = try_specialized_template(tmpl, src, params)) {
|
||||
if (auto result = common_chat_try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -270,3 +270,8 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
|||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue