llama : add adaptive-p sampler (#17927)

* initial commit for branch

* simplify constants

* add params to `struct common_params_sampling`, add reference to PR

* explicitly clamp `min_target` and `max_target` to `[0.0, 1.0]`

* add args, rename `queue_size` -> `window_size`

* improved comments

* minor

* remove old unused code from algorithm

* minor

* add power law case to `common_sampler_init`, add sampler name mappings

* clarify behaviour when `window_size = 0`

* add missing enums

* remove `target_range` param, make `target == 1` no-op, cleanup code

* oops, straggler

* add missing parameters in `server-task.cpp`

* copy from author

ref:
https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069

* remove old debug log, style nit

* fix compiler warning, add commented-out logging per token

* re-write + change parameters + simplify

* oops forgot args.cpp

* fix leftover `window_size`

* add missing values to `common_params_sampling::print()`

* with logging

* does this fix it?

* no, but does this?

* update default decay

* optimize

* fix bad merge

my git skills are lacking

* silence `missing initializer for member`

* update default decay to 0.9

* fix logging

* format (double)

* add power law to the new `samplers` vector

* log sampler init values

* improve logging messages in llama_sampler_power_law

* remove extraneous logging

* simplify target computation

last commit with debug logging!

* remove debug logging, explicitly clamp params at init

* add `use_power_law` flag + logic, minor cleanup

* update `power-law` -> `adaptive-p`

* fix cold start EMA

- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f -
clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f -
clamped_decay)`

this fixes a "cold start" problem with the moving average

* update `SHARPNESS` constant to `10.0f`

* minor style fixes

no functional changes

* minor style fixes cont.

* update `llama_sampler_adaptive_p_i` for backend sampling (ref: #17004)

* separate into `apply` + `accept` functions

* `pending_token_idx`: switch from `llama_token` to `int32`

functionally identical (`llama.h` has `typedef int32_t llama_token;`),
but its more correct now

* don't transform logits <= -1e9f

* fix masking in backend top-p, min-p

* address review comments

* typo in comments `RND` -> `RNG`

* add docs

* add recommended values in completion docs

* address PR feedback

* remove trailing whitespace (for CI `editorconfig`)

* add to adaptive-p to `common_sampler_types_from_chars`
This commit is contained in:
ddh0 2026-01-15 11:16:29 -06:00 committed by GitHub
parent a04c2b06a3
commit 13f1e4a9ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 297 additions and 52 deletions

View file

@ -1729,6 +1729,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_sparam());
add_opt(common_arg(
{"--adaptive-target"}, "N",
string_format("adaptive-p: select tokens near this probability (valid range 0.0 "
"to 1.0; negative = disabled) (default: %.2f)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/17927)",
(double)params.sampling.adaptive_target),
[](common_params & params, const std::string & value) {
params.sampling.adaptive_target = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--adaptive-decay"}, "N",
string_format("adaptive-p: decay rate for target adaptation over time. lower values "
"are more reactive, higher values are more stable.\n"
"(valid range 0.0 to 0.99) (default: %.2f)",
(double)params.sampling.adaptive_decay),
[](common_params & params, const std::string & value) {
params.sampling.adaptive_decay = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),

View file

@ -119,6 +119,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12,
};
// dimensionality reduction methods, used by cvector-generator
@ -166,32 +167,34 @@ enum common_params_sampling_config : uint64_t {
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers

View file

@ -167,11 +167,11 @@ std::string common_params_sampling::print() const {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau);
mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay);
return std::string(result);
}
@ -255,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
if (params.mirostat == 0) {
bool use_adaptive_p = false; // see below
for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
@ -264,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
samplers.push_back(llama_sampler_init_top_k (params.top_k));
samplers.push_back(llama_sampler_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
samplers.push_back(llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
samplers.push_back(llama_sampler_init_infill (vocab));
samplers.push_back(llama_sampler_init_infill(vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
case COMMON_SAMPLER_TYPE_ADAPTIVE_P:
// the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects
// a single token, so we will add `dist` at the end of the chain by default,
// unless the user specifically included `adaptive-p`. we set this flag here
// so we know to add the sampler at the very end.
use_adaptive_p = true;
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
samplers.push_back(llama_sampler_init_dist(params.seed));
if (use_adaptive_p) {
// only if user explicitly included adaptive-p sampler
samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed));
} else {
// default: sample from distribution
samplers.push_back(llama_sampler_init_dist(params.seed));
}
} else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@ -625,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a';
default : return '?';
}
}
@ -641,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p";
default : return "";
}
}
@ -657,6 +673,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
// since samplers names are written multiple ways
@ -672,6 +689,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> samplers;
@ -708,6 +726,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> samplers;