llama-cpp-turboquant/common/jinja/caps.cpp
Xuan-Son Nguyen c15395f73c
common : implement new jinja template engine (#18462)
* jinja vm

* lexer

* add vm types

* demo

* clean up

* parser ok

* binary_expression::execute

* shadow naming

* bin ops works!

* fix map object

* add string builtins

* add more builtins

* wip

* use mk_val

* eval with is_user_input

* render gemma tmpl ok

* track input string even after transformations

* support bound functions

* keyword arguments and slicing array

* use shared_ptr for values

* add mk_stmt

* allow print source on exception

* fix negate test

* testing more templates

* mostly works

* add filter_statement

* allow func to access ctx

* add jinja-value.cpp

* impl global_from_json

* a lot of fixes

* more tests

* more fix, more tests

* more fixes

* rm workarounds

* demo: type inference

* add placeholder for tojson

* improve function args handling

* rm type inference

* no more std::regex

* trailing spaces

* make testing more flexible

* make output a bit cleaner

* (wip) redirect minja calls

* test: add --output

* fix crash on macro kwargs

* add minimal caps system

* add some workarounds

* rm caps_apply_workarounds

* get rid of preprocessing

* more fixes

* fix test-chat-template

* move test-chat-jinja into test-chat-template

* rm test-chat-jinja from cmake

* test-chat-template: use common

* fix build

* fix build (2)

* rename vm --> interpreter

* improve error reporting

* correct lstrip behavior

* add tojson

* more fixes

* disable tests for COMMON_CHAT_FORMAT_GENERIC

* make sure tojson output correct order

* add object.length

* fully functional selectattr / rejectattr

* improve error reporting

* more builtins added, more fixes

* create jinja rendering tests

* fix testing.h path

* adjust whitespace rules

* more fixes

* temporary disable test for ibm-granite

* r/lstrip behavior matched with hf.js

* minimax, glm4.5 ok

* add append and pop

* kimi-k2 ok

* test-chat passed

* fix lstrip_block

* add more jinja tests

* cast to unsigned char

* allow dict key to be numeric

* nemotron: rm windows newline

* tests ok

* fix test

* rename interpreter --> runtime

* fix build

* add more checks

* bring back generic format support

* fix Apertus

* [json.exception.out_of_range.403] key 'content' not found

* rm generic test

* refactor input marking

* add docs

* fix windows build

* clarify error message

* improved tests

* split/rsplit with maxsplit

* non-inverse maxsplit

forgot to change after simplifying

* implement separators for tojson and fix indent

* i like to move it move it

* rename null -- > none

* token::eof

* some nits + comments

* add exception classes for lexer and parser

* null -> none

* rename global -> env

* rm minja

* update docs

* docs: add input marking caveats

* implement missing jinja-tests functions

* oops

* support trim filter with args, remove bogus to_json reference

* numerous argument fixes

* updated tests

* implement optional strip chars parameter

* use new chars parameter

* float filter also has default

* always leave at least one decimal in float string

* jinja : static analysis + header cleanup + minor fixes

* add fuzz test

* add string.cpp

* fix chat_template_kwargs

* nits

* fix build

* revert

* unrevert

sorry :)

* add fuzz func_args, refactor to be safer

* fix array.map()

* loosen ensure_vals max count condition, add not impl for map(int)

* hopefully fix windows

* check if empty first

* normalize newlines

---------

Co-authored-by: Alde Rojas <hello@alde.dev>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-01-16 11:22:06 +01:00

237 lines
7.3 KiB
C++

#include "value.h"
#include "runtime.h"
#include "caps.h"
// note: the json dependency is only for defining input in a convenient way
// we can remove it in the future when we figure out a better way to define inputs using jinja::value
#include <nlohmann/json.hpp>
#include <functional>
#include <sstream>
#define FILENAME "jinja-caps"
using json = nlohmann::ordered_json;
namespace jinja {
using caps_json_fn = std::function<json()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;

// Render `prog` once against a synthetic environment and report the outcome.
//
// The environment is built from `messages_fn()` and `tools_fn()` (evaluated in
// that order) plus fixed values: empty `bos_token` / `eos_token` and
// `add_generation_prompt = true`. Statistics collection is enabled on the
// context (`is_get_stats`) so that `analyze_fn` can inspect how the template
// used the `messages` and `tools` values after rendering.
//
// A std::exception thrown while constructing the runtime or executing the
// program is logged and swallowed; `analyze_fn` is still invoked, with its
// first argument set to false in that case and true otherwise.
static void caps_try_execute(jinja::program & prog,
                             const caps_json_fn & messages_fn,
                             const caps_json_fn & tools_fn,
                             const caps_analyze_fn & analyze_fn) {
    context ctx;
    ctx.is_get_stats = true;

    jinja::global_from_json(ctx,
        json{
            {"messages",              messages_fn()},
            {"tools",                 tools_fn()},
            {"bos_token",             ""},
            {"eos_token",             ""},
            {"add_generation_prompt", true},
        },
        true);

    // Keep handles to the input values so their usage stats can be read once
    // the template has run.
    auto input_messages = ctx.get_val("messages");
    auto input_tools    = ctx.get_val("tools");

    bool rendered_ok = true;
    try {
        jinja::runtime rt(ctx);
        rt.execute(prog);
    } catch (const std::exception & e) {
        // capability probing is best-effort: a template that rejects the
        // synthetic input is reported to analyze_fn instead of propagating
        rendered_ok = false;
        JJ_DEBUG("Exception during execution: %s", e.what());
    }

    analyze_fn(rendered_ok, input_messages, input_tools);
}
// Debug helper: logs, under the label `path`, the runtime type of `v`, whether
// the template marked it as used, and every operation name recorded in
// `v->stats.ops` (each followed by a single space).
static void caps_print_stats(value & v, const std::string & path) {
    std::ostringstream op_list;
    for (const auto & op : v->stats.ops) {
        op_list << op << ' ';
    }
    const std::string ops_text = op_list.str();
    const char * used_tag = v->stats.used ? "(used)" : "";

    JJ_DEBUG("Value %s, type: %s %s, ops: %s",
             path.c_str(),
             v->type().c_str(),
             used_tag,
             ops_text.c_str());
}
// Human-readable, multi-line dump of the detected capabilities, one
// "name=value" line per flag wrapped in "Caps(" ... ")". Values use the
// stream's default formatting (no std::boolalpha is set).
std::string caps::to_string() const {
    std::ostringstream out;
    const auto emit = [&out](const char * key, const auto & flag) {
        out << " " << key << '=' << flag << '\n';
    };

    out << "Caps(\n";
    emit("requires_typed_content",       requires_typed_content);
    emit("supports_tools",               supports_tools);
    emit("supports_tool_calls",          supports_tool_calls);
    emit("supports_parallel_tool_calls", supports_parallel_tool_calls);
    emit("supports_system_role",         supports_system_role);
    out << ")";
    return out.str();
}
caps caps_get(jinja::program & prog) {
caps result;
static const auto has_op = [](value & v, const std::string & op_name) {
return v->stats.ops.find(op_name) != v->stats.ops.end();
};
// case: typed content requirement
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "content"}
}
});
},
[&]() {
// tools
return json{nullptr};
},
[&](bool, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
// accessed as an array
result.requires_typed_content = true;
}
}
);
// case: system prompt support
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "system"},
{"content", "System message"}
},
{
{"role", "user"},
{"content", "User message"}
},
});
},
[&]() {
// tools
return json::array();
},
[&](bool, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!content->stats.used) {
result.supports_system_role = false;
}
}
);
// case: tools support
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "User message"},
},
{
{"role", "assistant"},
{"content", "Assistant message"},
{"tool_calls", json::array({
{
{"id", "call1"},
{"type", "function"},
{"function", {
{"name", "tool1"},
{"arguments", {
{"arg", "value"}
}}
}}
},
{
{"id", "call2"},
{"type", "function"},
{"function", {
{"name", "tool2"},
{"arguments", {
{"arg", "value"}
}}
}}
}
})}
},
{
{"role", "user"},
{"content", "User message"},
},
});
},
[&]() {
// tools
return json::array({
{
{"name", "tool"},
{"type", "function"},
{"function", {
{"name", "tool"},
{"description", "Tool description"},
{"parameters", {
{"type", "object"},
{"properties", {
{"arg", {
{"type", "string"},
{"description", "Arg description"},
}},
}},
{"required", json::array({ "arg" })},
}},
}},
},
});
},
[&](bool success, value & messages, value & tools) {
if (!success) {
result.supports_tool_calls = false;
result.supports_tools = false;
return;
}
auto & tool_name = tools->at(0)->at("function")->at("name");
caps_print_stats(tool_name, "tools[0].function.name");
if (!tool_name->stats.used) {
result.supports_tools = false;
}
auto & tool_calls = messages->at(1)->at("tool_calls");;
caps_print_stats(tool_calls, "messages[1].tool_calls");
if (!tool_calls->stats.used) {
result.supports_tool_calls = false;
}
// check for second tool call usage
auto & tool_call_1 = tool_calls->at(1)->at("function");
caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
if (!tool_call_1->stats.used) {
result.supports_parallel_tool_calls = false;
}
}
);
JJ_DEBUG("%s\n", result.to_string().c_str());
return result;
}
} // namespace jinja