common/parser: handle reasoning budget (#20297)

* v1

* Finished!

* Handlie cli

* Reasoning sampler

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Less explosive terminology :)

* Add utf-8 case and tests

* common : migrate reasoning budget sampler to common

* cont : clean up

* cont : expose state and allow passing as initial state

* cont : remove unused imports

* cont : update state machine doc string

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Alde Rojas <hello@alde.dev>
This commit is contained in:
Piotr Wilkin (ilintar) 2026-03-11 10:26:12 +01:00 committed by GitHub
parent 5f91b1d5d5
commit acb7c79069
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 670 additions and 10 deletions

View file

@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h

View file

@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
auto parsed = json::parse(value);
for (const auto & item : parsed.items()) {
if (item.key() == "enable_thinking") {
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
"Use --reasoning on / --reasoning off instead.\n");
}
params.default_template_kwargs[item.key()] = item.value().dump();
}
}
@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.reasoning_format = common_reasoning_format_from_name(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
{"-rea", "--reasoning"}, "[on|off|auto]",
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
[](common_params & params, const std::string & value) {
if (is_truthy(value)) {
params.enable_reasoning = 1;
params.default_template_kwargs["enable_thinking"] = "true";
} else if (is_falsey(value)) {
params.enable_reasoning = 0;
params.default_template_kwargs["enable_thinking"] = "false";
} else if (is_autoy(value)) {
params.enable_reasoning = -1;
} else {
throw std::invalid_argument(
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
add_opt(common_arg(
{"--reasoning-budget"}, "N",
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
[](common_params & params, int value) {
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
if (value < -1) { throw std::invalid_argument("invalid value"); }
params.reasoning_budget = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
add_opt(common_arg(
{"--reasoning-budget-message"}, "MESSAGE",
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
[](common_params & params, const std::string & value) {
params.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(

View file

@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template)
return p.reasoning(p.until(end)) + end;
// However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
}
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)

View file

@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
data.supports_thinking = true;
data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = {
@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
const autoparser::templates_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = {
"<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>",
@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
}
return auto_params;
} catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());

View file

@ -213,6 +213,8 @@ struct common_chat_params {
bool grammar_lazy = false;
bool thinking_forced_open = false;
bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>"
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;

View file

@ -235,6 +235,14 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
bool backend_sampling = false;
bool has_logit_bias() const {
@ -536,7 +544,9 @@ struct common_params {
bool use_jinja = true; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1;
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time

219
common/reasoning-budget.cpp Normal file
View file

@ -0,0 +1,219 @@
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"
#include "log.h"
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
struct token_matcher {
std::vector<llama_token> tokens;
size_t pos = 0;
bool advance(llama_token token) {
if (tokens.empty()) {
return false;
}
if (token == tokens[pos]) {
pos++;
if (pos >= tokens.size()) {
pos = 0;
return true;
}
} else {
pos = 0;
if (token == tokens[0]) {
pos = 1;
}
}
return false;
}
void reset() { pos = 0; }
};
struct common_reasoning_budget_ctx {
const llama_vocab * vocab;
token_matcher start_matcher;
token_matcher end_matcher;
std::vector<llama_token> forced_tokens;
int32_t budget; // maximum tokens in reasoning block
int32_t remaining; // tokens remaining in budget
common_reasoning_budget_state state;
// for forcing
size_t force_pos; // next position in forced_tokens to force
};
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
return "reasoning-budget";
}
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
switch (ctx->state) {
case REASONING_BUDGET_IDLE:
{
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
}
}
break;
}
case REASONING_BUDGET_COUNTING:
case REASONING_BUDGET_WAITING_UTF8:
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
break;
}
bool utf8_complete = true;
if (ctx->vocab != nullptr) {
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
utf8_complete = common_utf8_is_complete(piece);
}
if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
if (ctx->remaining <= 0) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
}
}
}
break;
}
case REASONING_BUDGET_FORCING:
// force_pos is advanced in apply(), not here.
// This ensures the first forced token isn't skipped when the sampler
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
break;
case REASONING_BUDGET_DONE:
break;
}
}
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
}
if (ctx->force_pos >= ctx->forced_tokens.size()) {
return;
}
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
// set all logits to -inf except the forced token
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
ctx->state = REASONING_BUDGET_IDLE;
ctx->remaining = ctx->budget;
ctx->start_matcher.reset();
ctx->end_matcher.reset();
ctx->force_pos = 0;
}
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->budget,
ctx->state);
}
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
}
static struct llama_sampler_i common_reasoning_budget_i = {
/* .name = */ common_reasoning_budget_name,
/* .accept = */ common_reasoning_budget_accept,
/* .apply = */ common_reasoning_budget_apply,
/* .reset = */ common_reasoning_budget_reset,
/* .clone = */ common_reasoning_budget_clone,
/* .free = */ common_reasoning_budget_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
}
return llama_sampler_init(
/* .iface = */ &common_reasoning_budget_i,
/* .ctx = */ new common_reasoning_budget_ctx {
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
}
);
}

41
common/reasoning-budget.h Normal file
View file

@ -0,0 +1,41 @@
#pragma once
#include "llama.h"
#include <cstdint>
#include <vector>
enum common_reasoning_budget_state {
REASONING_BUDGET_IDLE, // waiting for start sequence
REASONING_BUDGET_COUNTING, // counting down tokens
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
REASONING_BUDGET_DONE, // passthrough forever
};
// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
// IDLE: passthrough, watching for start_tokens sequence
// COUNTING: counting down remaining tokens, watching for natural end_tokens
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);

View file

@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "reasoning-budget.h"
#include <algorithm>
#include <cmath>
@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
}
if (params.has_logit_bias()) {
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}

View file

@ -1,8 +1,10 @@
#include "unicode.h"
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>
#include <string>
#include <vector>
// implementation adopted from src/unicode.cpp
@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
return utf8_parse_result(utf8_parse_result::INVALID);
}
bool common_utf8_is_complete(const std::string & s) {
if (s.empty()) {
return true;
}
for (int i = 1; i <= std::min(4, (int)s.size()); i++) {
unsigned char c = s[s.size() - i];
if ((c & 0xC0) != 0x80) {
int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1;
return i >= expected;
}
}
return false;
}
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
for (size_t i = 0; i < cps.size(); ++i) {

View file

@ -20,6 +20,9 @@ struct utf8_parse_result {
// Returns 0 for invalid first bytes
size_t common_utf8_sequence_length(unsigned char first_byte);
// Check if a string ends with a complete UTF-8 sequence.
bool common_utf8_is_complete(const std::string & s);
// Parse a single UTF-8 codepoint from input
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);