sampling : delegate input allocation to the scheduler (#19266)

* sampling : delegate input allocation to the scheduler

* graph : compute backend samplers only if needed
Georgi Gerganov 2026-02-03 22:16:16 +02:00 committed by GitHub
parent 32b17abdb0
commit faa1bc26ee
3 changed files with 33 additions and 73 deletions
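
In short: rather than each backend sampler pre-allocating a private input buffer at init time, input tensors are now created directly in the compute graph's context and marked with ggml_set_input, so the backend scheduler allocates them together with the rest of the graph; their values are uploaded afterwards via set_input. A minimal sketch of that pattern, assuming a ggml_backend_sched_t sched, a graph context ctx, and a graph gf built elsewhere (names illustrative, not code from this commit):

    // graph build time: declare the input in the graph context, do not allocate it ourselves
    ggml_tensor * inp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    ggml_set_name (inp, "uniform");
    ggml_set_input(inp); // marked as input so the scheduler keeps it allocated

    // ... add the ops that consume inp to gf ...

    // the scheduler allocates every tensor in the graph, including the marked inputs
    ggml_backend_sched_alloc_graph(sched, gf);

    // only now does inp have backing memory, so the value can be written
    const float u = 0.5f; // placeholder value
    ggml_backend_tensor_set(inp, &u, 0, sizeof(u));

    ggml_backend_sched_graph_compute(sched, gf);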


@@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
     std::mt19937 rng;
     // backend input
-    struct ggml_tensor * inp_uniform;
-    ggml_context_ptr inp_ctx;
-    ggml_backend_buffer_ptr inp_buf;
+    ggml_tensor * inp_uniform;
 };
 static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
@@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
-    // allocate inputs
-    {
-        ggml_init_params params = {
-            /*.mem_size =*/ ggml_tensor_overhead(),
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc =*/ true,
-        };
-        sctx->inp_ctx.reset(ggml_init(params));
-        // Create the uniform random scalar input tensor. This will be set by
-        // llama_sampler_dist_backend_set_input after this graph is built.
-        sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
-        ggml_set_name (sctx->inp_uniform, "uniform");
-        ggml_set_input(sctx->inp_uniform);
-        // Allocate all tensors from our context to the backend
-        sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
-        ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
-    }
     const bool res = llama_sampler_backend_support(smpl, buft);
     sctx->init(res);
-    if (!res) {
-        sctx->inp_ctx.reset(nullptr);
-        sctx->inp_buf.reset(nullptr);
-    }
     return res;
 }
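
Pieced together from the surviving context lines, the function now reduces to a plain capability check:

    static bool llama_sampler_dist_backend_init(
            struct llama_sampler * smpl,
            ggml_backend_buffer_type_t buft) {
        auto * sctx = (llama_sampler_dist *) smpl->ctx;

        const bool res = llama_sampler_backend_support(smpl, buft);

        sctx->init(res);

        return res;
    }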
@@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply(
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    ggml_set_name (sctx->inp_uniform, "uniform");
+    ggml_set_input(sctx->inp_uniform);
     struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
     ggml_set_name(probs, "dist_probs");
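
Since inp_uniform is now a fresh tensor in the graph context on every apply, the pointer is only valid for the current graph, and its value can be written only after the scheduler has allocated that graph. The expected call order on the caller side, sketched with illustrative names since the actual wiring lives in the graph/scheduler code and is not part of this diff:

    llama_sampler_dist_backend_apply(smpl, ctx, gf, &data); // creates inp_uniform in the graph
    ggml_backend_sched_alloc_graph(sched, gf);              // scheduler allocates the inputs
    llama_sampler_dist_backend_set_input(smpl);             // now safe to upload the value
    ggml_backend_sched_graph_compute(sched, gf);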
@@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply(
 static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    GGML_ASSERT(sctx->inp_uniform != nullptr);
     // We sample in double precision and cast to float to match rnd numbers of
@@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
             /* .seed_cur = */ seed_cur,
             /* .rng = */ std::mt19937(seed_cur),
             /* .inp_uniform = */ nullptr,
-            /* .inp_ctx = */ nullptr,
-            /* .inp_buf = */ nullptr,
         }
     );
 }
@@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {
     struct ggml_tensor * inp_logit_bias;
     struct ggml_tensor * inp_logit_idxs;
-    ggml_context_ptr inp_ctx;
-    ggml_backend_buffer_ptr inp_buf;
 };
 static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
@@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply(
         return;
     }
     const size_t n = sctx->logit_bias.size();
+    sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
+    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
+    ggml_set_input(sctx->inp_logit_bias);
+    sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
+    ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
+    ggml_set_input(sctx->inp_logit_idxs);
     ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
     cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
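
The same ordering constraint applies here: set_input may upload the bias table only after the scheduler has allocated the graph that apply just built. The upload side is not shown in this diff; a hypothetical sketch of what it presumably does, using the token and bias fields of llama_logit_bias:

    // hypothetical body of llama_sampler_logit_bias_backend_set_input:
    const size_t n = sctx->logit_bias.size();
    std::vector<float>   bias(n);
    std::vector<int32_t> idxs(n);
    for (size_t i = 0; i < n; ++i) {
        bias[i] = sctx->logit_bias[i].bias;
        idxs[i] = sctx->logit_bias[i].token;
    }
    ggml_backend_tensor_set(sctx->inp_logit_bias, bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
    ggml_backend_tensor_set(sctx->inp_logit_idxs, idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));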
@@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
 static bool llama_sampler_logit_bias_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
     auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
     sctx->init(true);
@@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init(
         return true;
     }
-    ggml_init_params params = {
-        /*.mem_size =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc =*/ true,
-    };
-    sctx->inp_ctx.reset(ggml_init(params));
-    const size_t n = sctx->logit_bias.size();
-    sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
-    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
-    ggml_set_input(sctx->inp_logit_bias);
-    sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
-    ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
-    ggml_set_input(sctx->inp_logit_idxs);
-    // Allocate all tensors from our context to the backend
-    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
-    ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
     return true;
 }
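
With the allocation block gone, init is reduced to bookkeeping. Combining the two hunks above, the function now reads roughly as follows (a few unchanged lines between the hunks are not visible in this diff and are elided):

    static bool llama_sampler_logit_bias_backend_init(
            struct llama_sampler * smpl,
            ggml_backend_buffer_type_t buft) {
        GGML_UNUSED(buft);

        auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;

        sctx->init(true);

        // ... (elided context lines) ...

        return true;
    }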
@@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias(
             /* .to_search = */ {},
             /* .inp_logit_bias = */ nullptr,
             /* .inp_logit_idxs = */ nullptr,
-            /* .inp_ctx = */ nullptr,
-            /* .inp_buf = */ nullptr,
         }
     );
 }
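
None of this changes the public sampling API: both samplers are still constructed and chained exactly as before, and the reshuffling toward scheduler-managed inputs is internal. For reference, typical usage via the standard llama.h calls (n_vocab, lctx, and the bias entry are placeholders):

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    const llama_logit_bias biases[] = {
        { /*.token =*/ 42, /*.bias =*/ -1.0f }, // placeholder entry
    };
    llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(n_vocab, 1, biases));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // sample from the last set of logits computed by lctx
    const llama_token tok = llama_sampler_sample(chain, lctx, -1);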