llama : add adaptive-p sampler (#17927)

* initial commit for branch

* simplify constants

* add params to `struct common_params_sampling`, add reference to PR

* explicitly clamp `min_target` and `max_target` to `[0.0, 1.0]`

* add args, rename `queue_size` -> `window_size`

* improved comments

* minor

* remove old unused code from algorithm

* minor

* add power law case to `common_sampler_init`, add sampler name mappings

* clarify behaviour when `window_size = 0`

* add missing enums

* remove `target_range` param, make `target == 1` no-op, cleanup code

* oops, straggler

* add missing parameters in `server-task.cpp`

* copy from author

ref:
https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069

* remove old debug log, style nit

* fix compiler warning, add commented-out logging per token

* re-write + change parameters + simplify

* oops forgot args.cpp

* fix leftover `window_size`

* add missing values to `common_params_sampling::print()`

* with logging

* does this fix it?

* no, but does this?

* update default decay

* optimize

* fix bad merge

my git skills are lacking

* silence `missing initializer for member`

* update default decay to 0.9

* fix logging

* format (double)

* add power law to the new `samplers` vector

* log sampler init values

* improve logging messages in llama_sampler_power_law

* remove extraneous logging

* simplify target computation

last commit with debug logging!

* remove debug logging, explicitly clamp params at init

* add `use_power_law` flag + logic, minor cleanup

* update `power-law` -> `adaptive-p`

* fix cold start EMA

- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f -
clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f -
clamped_decay)`

this fixes a "cold start" problem with the moving average (see the worked
check at the end of this list)

* update `SHARPNESS` constant to `10.0f`

* minor style fixes

no functional changes

* minor style fixes cont.

* update `llama_sampler_adaptive_p_i` for backend sampling (ref: #17004)

* separate into `apply` + `accept` functions

* `pending_token_idx`: switch from `llama_token` to `int32`

functionally identical (`llama.h` has `typedef int32_t llama_token;`),
but it's more correct now

* don't transform logits <= -1e9f

* fix masking in backend top-p, min-p

* address review comments

* typo in comments `RND` -> `RNG`

* add docs

* add recommended values in completion docs

* address PR feedback

* remove trailing whitespace (for CI `editorconfig`)

* add adaptive-p to `common_sampler_types_from_chars`
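
For reference, a worked check of the cold-start fix above: the EMA the sampler reads is `weighted_sum / total_weight`, so with the new initialization

    (target / (1.0f - decay)) / (1.0f / (1.0f - decay)) = target

and the first adapted target is `2 * target - target = target`, exactly as if the sampler already had a long history of tokens selected at the target probability.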

ddh0, 2026-01-15 11:16:29 -06:00, committed by GitHub
commit 13f1e4a9ca (parent a04c2b06a3)
9 changed files with 297 additions and 52 deletions

@@ -1513,12 +1513,9 @@ static void llama_sampler_top_p_backend_apply(
     mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
     mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
 
-    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
-    // top_p_bias = (mask * 1e9f) - 1e9f.
-    // So entries in the mask that we want to discard will become -1e9f, and
-    // others will be 0 (meaning that will not effect the logits).
-    const float large_val = 1e9f;
-    struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    // Apply -INFINITY bias for masked-out tokens
+    // log(1) = 0 (keep), log(0) = -INF (discard)
+    struct ggml_tensor * top_p_bias = ggml_log(ctx, mask);
     ggml_set_name(top_p_bias, "top_p_bias");
 
     data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
@@ -1673,15 +1670,11 @@ static void llama_sampler_min_p_backend_apply(
     struct ggml_tensor * mask = ggml_step(ctx, sub);
     ggml_set_name(mask, "min_p_mask");
-
-    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
-    // min_p_bias = (mask * 1e9f) - 1e9f.
-    // So entries in the mask that we want to discard will become -1e9f, and
-    // others will be 0 (meaning that will not effect the logits).
-    const float large_val = 1e9f;
-    struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    // Apply -INFINITY bias for masked-out tokens
+    // log(1) = 0 (keep), log(0) = -INF (discard)
+    struct ggml_tensor * min_p_bias = ggml_log(ctx, mask);
     ggml_set_name(min_p_bias, "min_p_bias");
 
     // Add the min_p bias to the logits.
     data->logits = ggml_add(ctx, data->logits, min_p_bias);
     ggml_set_name(data->logits, "min_p_logits");
@@ -3293,6 +3286,170 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
    return result;
}

// adaptive-p sampler state
//
// maintains an exponential moving average of the *ORIGINAL* probabilities
// of selected tokens, used to compute an adapted target at each sampling step.
//
// see llama.h for a full description of the sampler
//
// ref: https://github.com/ggml-org/llama.cpp/pull/17927
//
struct llama_sampler_adaptive_p {
    const float target; // target probability (0.0 - 1.0; negative = disabled)
    const float decay;  // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99)

    const uint32_t seed;     // original RNG seed
          uint32_t seed_cur; // actual RNG seed

    std::mt19937 rng; // RNG state

    float weighted_sum; // sum(p_i * decay^i)
    float total_weight; // sum(decay^i), converges to 1/(1-decay)
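    // the EMA of selected-token probabilities read in _apply is
    // weighted_sum / total_weight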
    std::vector<float> original_probs; // pre-transform probs, cached for EMA update

    llama_token pending_token_id;  // token ID of selected token
    int32_t     pending_token_idx; // index of orig. prob. of selected token in original_probs
};

// adaptive probability transformation constants
static constexpr float DISTRIBUTION_WIDTH = 0.3f;
static constexpr float PEAK_LOGIT_VALUE   = 5.0f;
static constexpr float SHARPNESS          = 10.0f;
static constexpr float INV_WIDTH          = 1.0f / DISTRIBUTION_WIDTH;

static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) {
    return "adaptive-p";
}

static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;

    llama_sampler_softmax_impl(cur_p, false);

    if (ctx->target < 0.0f) {
        // at negative target values, adaptive-p is a no-op:
        // we simply sample from the existing distribution
        cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
        return;
    }

    // store the original probabilities
    ctx->original_probs.resize(cur_p->size);
    for (size_t i = 0; i < cur_p->size; ++i) {
        ctx->original_probs[i] = cur_p->data[i].p;
    }

    // using the EMA, compute the adapted target probability for the current sampling step
    auto target = std::clamp(ctx->target, 0.0f, 1.0f);

    float adapted_target = std::clamp(
        ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight),
        0.0f, 1.0f
    );
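
    // note: 2*target - EMA mirrors the running average around the target: if the
    // EMA of selected-token probabilities has drifted above the target, the next
    // step aims symmetrically below it (and vice versa), steering the long-run
    // average back toward the target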

    // adaptive probability transform
    //
    // quadratic near target for fine differentiation, transitioning to linear decay in the
    // tails. unbounded negative logits ensure proper suppression of far-from-target tokens
    // after the softmax.
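    //
    // in symbols, with p_i the pre-transform probability of token i:
    //
    //   dist    = |p_i - adapted_target| / DISTRIBUTION_WIDTH
    //   logit_i = PEAK_LOGIT_VALUE - SHARPNESS * dist^2 / (1 + dist)
    //
    // for dist << 1 the penalty term is ~ SHARPNESS * dist^2 (quadratic), while
    // for dist >> 1 it approaches SHARPNESS * dist (linear)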
    //
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].logit == -INFINITY) {
            // don't transform logits that are -INFINITY
            // (as masked out by e.g. min-p and top-p when using backend sampling)
            continue;
        }

        float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH);
        cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist);
    }

    // softmax and sample from the transformed distribution
    llama_sampler_softmax_impl(cur_p, false);

    const int idx = llama_sample_dist(cur_p, ctx->rng);
    cur_p->selected = idx;

    // store the selected token ID for acceptance later
    ctx->pending_token_id  = cur_p->data[idx].id;
    ctx->pending_token_idx = idx;
}

static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;
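
    // only update the EMA when the accepted token is the one this sampler just
    // selected in _apply; accept() can also be called with tokens this sampler
    // did not choose (e.g. tokens replayed through the chain)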
    if (ctx->pending_token_id == token) {
        GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL);
        GGML_ASSERT(ctx->pending_token_idx != -1);

        // update EMA with the original probability of the selected token
        ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum;
        ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight;
    }

    ctx->pending_token_id  = LLAMA_TOKEN_NULL;
    ctx->pending_token_idx = -1;
}

static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) {
    auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx;

    // ctx->target and ctx->decay never change after init, so it's safe to keep them as is.
    // original_probs is completely overwritten on every call to _apply,
    // so we only need to reset the EMA state and pending token.
    ctx->weighted_sum = ctx->target / (1.0f - ctx->decay);
    ctx->total_weight = 1.0f / (1.0f - ctx->decay);
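
    // this seeds the EMA at its fixed point (weighted_sum / total_weight == target),
    // with total_weight starting at the converged value of the recurrence in _accept,
    // so the first adapted target is exactly `target` (no cold-start bias)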

    ctx->pending_token_id  = LLAMA_TOKEN_NULL;
    ctx->pending_token_idx = -1;

    ctx->seed_cur = get_rng_seed(ctx->seed);
    ctx->rng.seed(ctx->seed_cur);
}

static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx;

    auto * result     = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed);
    auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx;

    // copy everything (target, decay, seed, and RNG are already set)
    result_ctx->weighted_sum      = ctx->weighted_sum;
    result_ctx->total_weight      = ctx->total_weight;
    result_ctx->pending_token_id  = ctx->pending_token_id;
    result_ctx->pending_token_idx = ctx->pending_token_idx;

    return result;
}

static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) {
    delete (llama_sampler_adaptive_p *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_adaptive_p_i = {
    /* .name              = */ llama_sampler_adaptive_p_name,
    /* .accept            = */ llama_sampler_adaptive_p_accept,
    /* .apply             = */ llama_sampler_adaptive_p_apply,
    /* .reset             = */ llama_sampler_adaptive_p_reset,
    /* .clone             = */ llama_sampler_adaptive_p_clone,
    /* .free              = */ llama_sampler_adaptive_p_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};

struct llama_sampler * llama_sampler_init_adaptive_p(
        float    target,
        float    decay,
        uint32_t seed
) {
    auto seed_cur = get_rng_seed(seed);

    float clamped_decay = std::clamp(decay, 0.0f, 0.99f);
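    // note: decay is clamped below 1.0f so the 1/(1 - decay) terms in the
    // initializers stay finite (decay == 1.0f would divide by zero)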

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_adaptive_p_i,
        /* .ctx   = */ new llama_sampler_adaptive_p {
            /* .target            = */ target,
            /* .decay             = */ clamped_decay,
            /* .seed              = */ seed,
            /* .seed_cur          = */ seed_cur,
            /* .rng               = */ std::mt19937(seed_cur),
            /* .weighted_sum      = */ target / (1.0f - clamped_decay),
            /* .total_weight      = */ 1.0f / (1.0f - clamped_decay),
            /* .original_probs    = */ {},
            /* .pending_token_id  = */ LLAMA_TOKEN_NULL,
            /* .pending_token_idx = */ -1
        }
    );
}
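
// example usage (a minimal sketch, not part of this change; the parameter
// values are illustrative, `lctx` is an existing llama_context *, and the
// chain API is the standard one from llama.h):
//
//     llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//     llama_sampler_chain_add(chain, llama_sampler_init_adaptive_p(/*target=*/0.35f, /*decay=*/0.9f, LLAMA_DEFAULT_SEED));
//
//     llama_token tok = llama_sampler_sample(chain, lctx, -1); // runs _apply, then _accept
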
// logit-bias
struct llama_sampler_logit_bias : public llama_sampler_backend {