* jinja vm * lexer * add vm types * demo * clean up * parser ok * binary_expression::execute * shadow naming * bin ops works! * fix map object * add string builtins * add more builtins * wip * use mk_val * eval with is_user_input * render gemma tmpl ok * track input string even after transformations * support binded functions * keyword arguments and slicing array * use shared_ptr for values * add mk_stmt * allow print source on exception * fix negate test * testing more templates * mostly works * add filter_statement * allow func to access ctx * add jinja-value.cpp * impl global_from_json * a lot of fixes * more tests * more fix, more tests * more fixes * rm workarounds * demo: type inferrence * add placeholder for tojson * improve function args handling * rm type inference * no more std::regex * trailing spaces * make testing more flexible * make output a bit cleaner * (wip) redirect minja calls * test: add --output * fix crash on macro kwargs * add minimal caps system * add some workarounds * rm caps_apply_workarounds * get rid of preprocessing * more fixes * fix test-chat-template * move test-chat-jinja into test-chat-template * rm test-chat-jinja from cmake * test-chat-template: use common * fix build * fix build (2) * rename vm --> interpreter * improve error reporting * correct lstrip behavior * add tojson * more fixes * disable tests for COMMON_CHAT_FORMAT_GENERIC * make sure tojson output correct order * add object.length * fully functional selectattr / rejectattr * improve error reporting * more builtins added, more fixes * create jinja rendering tests * fix testing.h path * adjust whitespace rules * more fixes * temporary disable test for ibm-granite * r/lstrip behavior matched with hf.js * minimax, glm4.5 ok * add append and pop * kimi-k2 ok * test-chat passed * fix lstrip_block * add more jinja tests * cast to unsigned char * allow dict key to be numeric * nemotron: rm windows newline * tests ok * fix test * rename interpreter --> runtime * fix build * 
add more checks * bring back generic format support * fix Apertus * [json.exception.out_of_range.403] key 'content' not found * rm generic test * refactor input marking * add docs * fix windows build * clarify error message * improved tests * split/rsplit with maxsplit * non-inverse maxsplit forgot to change after simplifying * implement separators for tojson and fix indent * i like to move it move it * rename null -- > none * token::eof * some nits + comments * add exception classes for lexer and parser * null -> none * rename global -> env * rm minja * update docs * docs: add input marking caveats * imlement missing jinja-tests functions * oops * support trim filter with args, remove bogus to_json reference * numerous argument fixes * updated tests * implement optional strip chars parameter * use new chars parameter * float filter also has default * always leave at least one decimal in float string * jinja : static analysis + header cleanup + minor fixes * add fuzz test * add string.cpp * fix chat_template_kwargs * nits * fix build * revert * unrevert sorry :) * add fuzz func_args, refactor to be safer * fix array.map() * loosen ensure_vals max count condition, add not impl for map(int) * hopefully fix windows * check if empty first * normalize newlines --------- Co-authored-by: Alde Rojas <hello@alde.dev> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
237 lines
7.3 KiB
C++
237 lines
7.3 KiB
C++
#include "value.h"
|
|
#include "runtime.h"
|
|
#include "caps.h"
|
|
|
|
// note: the json dependency is only for defining input in a convenient way
|
|
// we can remove it in the future when we figure out a better way to define inputs using jinja::value
|
|
#include <nlohmann/json.hpp>
|
|
|
|
#include <functional>
|
|
#include <sstream>
|
|
|
|
#define FILENAME "jinja-caps"
|
|
|
|
using json = nlohmann::ordered_json;
|
|
|
|
namespace jinja {
|
|
|
|
using caps_json_fn = std::function<json()>;
|
|
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
|
|
|
|
// Runs `prog` once against a synthetic context and reports the outcome.
//
//   messages_fn - produces the JSON bound to the `messages` global
//   tools_fn    - produces the JSON bound to the `tools` global
//   analyze_fn  - invoked exactly once after the run with: whether execution
//                 finished without a std::exception, and the `messages` /
//                 `tools` values looked up from the context before the run
//
// Exceptions derived from std::exception that escape the template run are
// logged and swallowed on purpose: a failing render is itself a signal for
// the capability analysis.
static void caps_try_execute(jinja::program & prog,
                             const caps_json_fn & messages_fn,
                             const caps_json_fn & tools_fn,
                             const caps_analyze_fn & analyze_fn) {
    context probe_ctx;
    probe_ctx.is_get_stats = true; // ask the context to collect usage stats

    // Globals handed to the template; the two generators are invoked in this
    // order (messages first, then tools).
    const json probe_globals = {
        {"messages",              messages_fn()},
        {"tools",                 tools_fn()},
        {"bos_token",             ""},
        {"eos_token",             ""},
        {"add_generation_prompt", true}
    };
    // NOTE(review): the trailing `true` presumably marks the values as input - confirm against global_from_json
    jinja::global_from_json(probe_ctx, probe_globals, true);

    // Grab the handles before running so the callback inspects the values
    // that were originally bound to these names.
    auto messages_val = probe_ctx.get_val("messages");
    auto tools_val    = probe_ctx.get_val("tools");

    bool ran_ok = false;
    try {
        jinja::runtime rt(probe_ctx);
        rt.execute(prog);
        ran_ok = true;
    } catch (const std::exception & e) {
        // ignore exceptions during capability analysis
        JJ_DEBUG("Exception during execution: %s", e.what());
    }

    analyze_fn(ran_ok, messages_val, tools_val);
}
|
|
|
|
// for debugging only
|
|
// Emits a single debug line describing `v`: the caller-supplied `path` label,
// the value's type, a "(used)" tag when its stats mark it as used, and the
// names of all ops recorded on it.
static void caps_print_stats(value & v, const std::string & path) {
    // every recorded op name, each followed by one space
    std::string op_list;
    for (const auto & op : v->stats.ops) {
        op_list.append(op);
        op_list.push_back(' ');
    }

    JJ_DEBUG("Value %s, type: %s %s, ops: %s",
             path.c_str(),
             v->type().c_str(),
             v->stats.used ? "(used)" : "",
             op_list.c_str());
}
|
|
|
|
// Renders the capability flags as a multi-line "Caps(...)" string, one
// `name=value` pair per line, for logging.
std::string caps::to_string() const {
    std::ostringstream out;
    out << "Caps(\n"
        << " requires_typed_content="       << requires_typed_content       << "\n"
        << " supports_tools="               << supports_tools               << "\n"
        << " supports_tool_calls="          << supports_tool_calls          << "\n"
        << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n"
        << " supports_system_role="         << supports_system_role         << "\n"
        << ")";
    return out.str();
}
|
|
|
|
// Probes a chat template for the features it supports.
//
// `prog` is executed several times via caps_try_execute(), each time against
// a small synthetic conversation. After every run the usage statistics
// recorded on selected input values (and, for the tools probe, whether the
// run succeeded at all) are used to update the flags of the returned `caps`.
//
// Probes, in order:
//   1. typed content  - is `messages[0].content` accessed like an array?
//   2. system role    - is the system message content used at all?
//   3. tools          - are the tool definitions, the assistant tool_calls and
//                       the second (parallel) tool call used?
caps caps_get(jinja::program & prog) {
    caps result;

    // true when `op_name` is among the operations recorded on `v`
    static const auto has_op = [](value & v, const std::string & op_name) {
        return v->stats.ops.find(op_name) != v->stats.ops.end();
    };

    // case: typed content requirement
    caps_try_execute(
        prog,
        [&]() {
            // messages
            return json::array({
                {
                    {"role", "user"},
                    {"content", "content"}
                }
            });
        },
        [&]() {
            // tools: none for this probe.
            // Parentheses are required here: brace-initialization
            // (`json{nullptr}`) selects the initializer-list constructor and
            // yields the one-element array `[null]` instead of `null`.
            return json(nullptr);
        },
        [&](bool, value & messages, value &) {
            auto & content = messages->at(0)->at("content");
            caps_print_stats(content, "messages[0].content");
            if (has_op(content, "selectattr") || has_op(content, "array_access")) {
                // accessed as an array
                result.requires_typed_content = true;
            }
        }
    );

    // case: system prompt support
    caps_try_execute(
        prog,
        [&]() {
            // messages
            return json::array({
                {
                    {"role", "system"},
                    {"content", "System message"}
                },
                {
                    {"role", "user"},
                    {"content", "User message"}
                },
            });
        },
        [&]() {
            // tools
            return json::array();
        },
        [&](bool, value & messages, value &) {
            // the system message is only considered supported when its
            // content was actually used by the template
            auto & content = messages->at(0)->at("content");
            caps_print_stats(content, "messages[0].content");
            if (!content->stats.used) {
                result.supports_system_role = false;
            }
        }
    );

    // case: tools support
    caps_try_execute(
        prog,
        [&]() {
            // messages: user -> assistant (with two tool calls) -> user
            return json::array({
                {
                    {"role", "user"},
                    {"content", "User message"},
                },
                {
                    {"role", "assistant"},
                    {"content", "Assistant message"},
                    {"tool_calls", json::array({
                        {
                            {"id", "call1"},
                            {"type", "function"},
                            {"function", {
                                {"name", "tool1"},
                                {"arguments", {
                                    {"arg", "value"}
                                }}
                            }}
                        },
                        {
                            {"id", "call2"},
                            {"type", "function"},
                            {"function", {
                                {"name", "tool2"},
                                {"arguments", {
                                    {"arg", "value"}
                                }}
                            }}
                        }
                    })}
                },
                {
                    {"role", "user"},
                    {"content", "User message"},
                },
            });
        },
        [&]() {
            // tools
            return json::array({
                {
                    {"name", "tool"},
                    {"type", "function"},
                    {"function", {
                        {"name", "tool"},
                        {"description", "Tool description"},
                        {"parameters", {
                            {"type", "object"},
                            {"properties", {
                                {"arg", {
                                    {"type", "string"},
                                    {"description", "Arg description"},
                                }},
                            }},
                            {"required", json::array({ "arg" })},
                        }},
                    }},
                },
            });
        },
        [&](bool success, value & messages, value & tools) {
            if (!success) {
                // the template could not render a conversation that contains
                // tools / tool calls at all, so none of the tool-related
                // capabilities can be claimed (parallel calls included)
                result.supports_tool_calls = false;
                result.supports_tools = false;
                result.supports_parallel_tool_calls = false;
                return;
            }

            // tool definitions: was the tool's name ever used?
            auto & tool_name = tools->at(0)->at("function")->at("name");
            caps_print_stats(tool_name, "tools[0].function.name");
            if (!tool_name->stats.used) {
                result.supports_tools = false;
            }

            // assistant tool calls: was the tool_calls list ever used?
            auto & tool_calls = messages->at(1)->at("tool_calls");
            caps_print_stats(tool_calls, "messages[1].tool_calls");
            if (!tool_calls->stats.used) {
                result.supports_tool_calls = false;
            }

            // check for second tool call usage
            auto & tool_call_1 = tool_calls->at(1)->at("function");
            caps_print_stats(tool_call_1, "messages[1].tool_calls[1].function");
            if (!tool_call_1->stats.used) {
                result.supports_parallel_tool_calls = false;
            }
        }
    );

    JJ_DEBUG("%s\n", result.to_string().c_str());

    return result;
}
|
|
|
|
} // namespace jinja
|