common : inhibit lazy grammar sampler while reasoning is active (#20970)

* common : inhibit grammar while reasoning budget is active

* cont : update force_pos in accept

* cont : fix tests

* cont : tweak should apply logic

* cont : return early not using grammar sampler

* Add tests

* cont : prevent backend sampling when reasoning budget enabled

* cont : fix typo

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
This commit is contained in:
Aldehir Rojas 2026-03-27 12:30:40 -05:00 committed by GitHub
parent ff934e29bc
commit 59d840209a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 295 additions and 106 deletions

View file

@ -115,9 +115,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
break;
}
case REASONING_BUDGET_FORCING:
// force_pos is advanced in apply(), not here.
// This ensures the first forced token isn't skipped when the sampler
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
break;
case REASONING_BUDGET_DONE:
break;
@ -144,14 +146,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
@ -261,3 +255,10 @@ struct llama_sampler * common_reasoning_budget_init(
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
}
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
if (!smpl) {
return REASONING_BUDGET_IDLE;
}
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
}

View file

@ -51,3 +51,5 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);

View file

@ -7,6 +7,7 @@
#include <algorithm>
#include <cctype>
#include <climits>
#include <cmath>
#include <cstring>
#include <unordered_map>
@ -109,6 +110,7 @@ struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * rbudget;
struct llama_sampler * chain;
ring_buffer<llama_token> prev;
@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
lparams.no_perf = params.no_perf;
llama_sampler * grmr = nullptr;
llama_sampler * rbudget = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
std::vector<llama_sampler *> samplers;
@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
if (grmr) {
if (grmr && !params.grammar_lazy) {
try {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
// reasoning budget sampler
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
rbudget = common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
prefill_tokens));
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
prefill_tokens);
}
if (params.has_logit_bias()) {
@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .rbudget = */ rbudget,
/* .chain = */ chain,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
@ -398,11 +402,27 @@ void common_sampler_free(struct common_sampler * gsmpl) {
}
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->rbudget);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
}
static bool grammar_should_apply(struct common_sampler * gsmpl) {
if (!gsmpl->grmr) {
return false;
}
if (!gsmpl->rbudget) {
return true;
}
if (gsmpl->params.grammar_lazy) {
// if grammar is lazy, only apply when reasoning budget is not active
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
}
return true;
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
@ -410,6 +430,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
const auto tm = gsmpl->tm();
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
llama_sampler_accept(gsmpl->rbudget, token);
if (gsmpl->grmr && accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & rbudget = gsmpl->rbudget;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
if (id != LLAMA_TOKEN_NULL) {
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
// TODO: simplify
gsmpl->cur.resize(1);
@ -524,7 +552,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
gsmpl->set_logits(ctx, idx);
if (grammar_first) {
// apply reasoning budget first
llama_sampler_apply(rbudget, &cur_p);
if (grammar_first && grammar_should_apply(gsmpl)) {
llama_sampler_apply(grmr, &cur_p);
}
@ -532,7 +563,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
if (grammar_first || !grammar_should_apply(gsmpl)) {
return id;
}
@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(rbudget, &cur_p);
if (grammar_should_apply(gsmpl)) {
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");