common : refactor common_sampler + grammar logic changes (#17937)

* common : refactor common_sampler + grammar logic changes

* tests : increase max_tokens to get needed response

* batched : fix uninitialized samplers
This commit is contained in:
Author: Georgi Gerganov — 2025-12-14 10:11:13 +02:00, committed by GitHub
parent 3238b1400c
commit 254098a279
No known key found for this signature in database.
GPG key ID: B5690EEEBB952194
27 changed files with 372 additions and 293 deletions

View file

@ -141,13 +141,15 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
common_init_result llama_init = common_init_from_params(params);
model = llama_init.model.get();
ctx = llama_init.context.get();
auto llama_init = common_init_from_params(params);
if (model == NULL) {
LOG_ERR("%s: error: unable to load model\n", __func__);
ctx = llama_init->context();
model = llama_init->model();
smpl = llama_init->sampler(0);
if (ctx == NULL) {
LOG_ERR("%s: error: unable to create context\n", __func__);
return 1;
}
@ -474,12 +476,6 @@ int main(int argc, char ** argv) {
}
}
smpl = common_sampler_init(model, sparams);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
return 1;
}
LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl));
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str());
@ -993,8 +989,6 @@ int main(int argc, char ** argv) {
LOG("\n\n");
common_perf_print(ctx, smpl);
common_sampler_free(smpl);
llama_backend_free();
ggml_threadpool_free_fn(threadpool);