grammar: Fix grammar root symbol check (#19761)
* grammar: fix bad check for root symbol, correct error logging * add tests to demonstrate root symbol check failure
This commit is contained in:
parent
deee23863b
commit
0a10c34dc1
2 changed files with 42 additions and 6 deletions
|
|
@ -15,8 +15,12 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static llama_grammar * build_grammar_with_root(const std::string & grammar_str, const char * grammar_root) {
|
||||
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), grammar_root, false, nullptr, 0, nullptr, 0);
|
||||
}
|
||||
|
||||
static llama_grammar * build_grammar(const std::string & grammar_str) {
|
||||
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
|
||||
return build_grammar_with_root(grammar_str, "root");
|
||||
}
|
||||
|
||||
static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||
|
|
@ -860,6 +864,36 @@ static void test_failure_left_recursion() {
|
|||
fprintf(stderr, " ✅︎ Passed\n");
|
||||
}
|
||||
|
||||
static void test_failure_missing_root_symbol() {
|
||||
fprintf(stderr, "⚫ Testing missing root symbol:\n");
|
||||
|
||||
const std::string grammar_str = R"""(
|
||||
root ::= "foobar"
|
||||
)""";
|
||||
|
||||
llama_grammar * failure_result = build_grammar_with_root(grammar_str, "nonexistent");
|
||||
assert(failure_result == nullptr);
|
||||
|
||||
fprintf(stderr, " ✅︎ Passed\n");
|
||||
}
|
||||
|
||||
static void test_custom_root_symbol_check() {
|
||||
fprintf(stderr, "⚫ Testing custom root symbol check:\n");
|
||||
|
||||
const std::string custom_root_grammar_str = R"""(
|
||||
foobar ::= "foobar"
|
||||
)""";
|
||||
|
||||
llama_grammar * failure_result = build_grammar_with_root(custom_root_grammar_str, "root");
|
||||
assert(failure_result == nullptr);
|
||||
|
||||
llama_grammar * success_result = build_grammar_with_root(custom_root_grammar_str, "foobar");
|
||||
assert(success_result != nullptr);
|
||||
llama_grammar_free_impl(success_result);
|
||||
|
||||
fprintf(stderr, " ✅︎ Passed\n");
|
||||
}
|
||||
|
||||
static void test_json_schema() {
|
||||
// Note that this is similar to the regular grammar tests,
|
||||
// but we convert each json schema to a grammar before parsing.
|
||||
|
|
@ -1433,6 +1467,8 @@ int main() {
|
|||
test_failure_missing_root();
|
||||
test_failure_missing_reference();
|
||||
test_failure_left_recursion();
|
||||
test_failure_missing_root_symbol();
|
||||
test_custom_root_symbol_check();
|
||||
test_json_schema();
|
||||
fprintf(stdout, "All tests passed.\n");
|
||||
return 0;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue