common : refactor common_sampler + grammar logic changes (#17937)

* common : refactor common_sampler + grammar logic changes

* tests : increase max_tokens to get needed response

* batched : fix uninitialized samplers
This commit is contained in:
Georgi Gerganov 2025-12-14 10:11:13 +02:00 committed by GitHub
parent 3238b1400c
commit 254098a279
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 372 additions and 293 deletions

View file

@ -71,10 +71,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
common_init_result llama_init_tgt = common_init_from_params(params);
auto llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get();
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// load the draft model
params.devices = params.speculative.devices;
@ -87,10 +87,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
common_init_result llama_init_dft = common_init_from_params(params);
auto llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get();
model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
bool accept = false;
if (params.sampling.temp > 0) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
continue;
}
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);