llama : add adaptive-p sampler (#17927)

* initial commit for branch

* simplify constants

* add params to `struct common_params_sampling`, add reference to PR

* explicitly clamp `min_target` and `max_target` to `[0.0, 1.0]`

* add args, rename `queue_size` -> `window_size`

* improved comments

* minor

* remove old unused code from algorithm

* minor

* add power law case to `common_sampler_init`, add sampler name mappings

* clarify behaviour when `window_size = 0`

* add missing enums

* remove `target_range` param, make `target == 1` no-op, cleanup code

* oops, straggler

* add missing parameters in `server-task.cpp`

* copy from author

ref:
https://gist.github.com/MrJackSpade/9be99c7efbba7b95a41377e123b7b069

* remove old debug log, style nit

* fix compiler warning, add commented-out logging per token

* re-write + change parameters + simplify

* oops forgot args.cpp

* fix leftover `window_size`

* add missing values to `common_params_sampling::print()`

* with logging

* does this fix it?

* no, but does this?

* update default decay

* optimize

* fix bad merge

my git skills are lacking

* silence `missing initializer for member`

* update default decay to 0.9

* fix logging

* format (double)

* add power law to the new `samplers` vector

* log sampler init values

* improve logging messages in llama_sampler_power_law

* remove extraneous logging

* simplify target computation

last commit with debug logging!

* remove debug logging, explicitly clamp params at init

* add `use_power_law` flag + logic, minor cleanup

* update `power-law` -> `adaptive-p`

* fix cold start EMA

- `ctx->weighted_sum` is now initialized and reset to `target / (1.0f -
clamped_decay)`
- `ctx->total_weight` is now initialized and reset to `1.0f / (1.0f -
clamped_decay)`

this fixes a "cold start" problem with the moving average

* update `SHARPNESS` constant to `10.0f`

* minor style fixes

no functional changes

* minor style fixes cont.

* update `llama_sampler_adaptive_p_i` for backend sampling (ref: #17004)

* separate into `apply` + `accept` functions

* `pending_token_idx`: switch from `llama_token` to `int32`

functionally identical (`llama.h` has `typedef int32_t llama_token;`),
but its more correct now

* don't transform logits <= -1e9f

* fix masking in backend top-p, min-p

* address review comments

* typo in comments `RND` -> `RNG`

* add docs

* add recommended values in completion docs

* address PR feedback

* remove trailing whitespace (for CI `editorconfig`)

* add to adaptive-p to `common_sampler_types_from_chars`
This commit is contained in:
ddh0 2026-01-15 11:16:29 -06:00 committed by GitHub
parent a04c2b06a3
commit 13f1e4a9ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 297 additions and 52 deletions

View file

@ -436,6 +436,19 @@ The Min-P sampling method was designed as an alternative to Top-P, and aims to e
Example usage: `--min-p 0.05`
### Adaptive-P Sampling
- `--adaptive-target N`: select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
- `--adaptive-decay N`: EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
Adaptive-P: Select tokens near a configurable target probability over time.
The adaptive-p sampler transforms the token probability distribution to favor tokens that fall near a user-configurable probability target. Internally, the sampler maintains an exponential moving average of the *ORIGINAL* probabilities of selected tokens at each sampling step. It uses this EMA to compute an adapted target probability at each sampling step, thus maintaining the desired target probability over time. Only mild truncation before this sampler is recommended. It is suggested to apply min-p before adaptive-p as the only other active sampler.
Recommended starting values: `--adaptive-target 0.55 --adaptive-decay 0.9`
For more info, refer to: [llama.cpp#17927](https://github.com/ggml-org/llama.cpp/pull/17927)
### Locally Typical Sampling
- `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled).