llama : remove write/read of output ids/logits/embeddings (#18862)

* llama : remove write/read of output ids/logits/embeddings

This commit removes the write/read of output ids, logits and
embeddings from the llama context state.

Refs: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941

* completion : add replying of session state

This commit updates the session handing in the completion tool to handle
the that logits are no longer stored in the session file. Instead, we
need to replay the last token to get the logits for sampling.

* common : add common_prompt_batch_decode function

This commit adds a new function which is responsible for decoding prompt
and optionally handle the saving for session data.

* update save-state.cpp to use llama_state_load_file

This commit updates the save-load-state example to utilize the new
llama_state_load_file function for loading the model state from a file.
And it also replays the last token after loading since this state is now
stored before the last token is processed.

* examples : set n_seq_max = 2 for ctx3

This commit updates the save-load-state example to set the n_seq_max
parameter to 2 when initializing the ctx3 context.

The motivation for this change is that using 1 as n_parallel/n_seq_max
the context only supports one sequence, but the test laster tries to
use a second sequence which results in the following error:
```console
main : loaded state with 4 tokens
main : seq 0 copied, 225760 bytes
main : kv cache cleared
find_slot: seq_id=1 >= n_seq_max=1 Try using a bigger --parallel value
state_read_meta: failed to find available cells in kv cache
```
This seems to only happen for recurrent/hybrid models.
This commit is contained in:
Daniel Bevenius 2026-02-23 07:04:30 +01:00 committed by GitHub
parent e8e261699a
commit 2b6dfe824d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 132 additions and 200 deletions

View file

@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
// write output ids
{
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
std::vector<int32_t> w_output_pos;
w_output_pos.resize(n_outputs);
// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}
io.write(&n_outputs, sizeof(n_outputs));
if (n_outputs) {
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
}
}
// [TAG_CONTEXT_STATE_LOGITS]
// write logits
{
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
io.write(&logits_size, sizeof(logits_size));
if (logits_size) {
io.write(logits.data, logits_size * sizeof(float));
}
}
// write embeddings
{
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
io.write(&embd_size, sizeof(embd_size));
if (embd_size) {
io.write(embd.data, embd_size * sizeof(float));
}
}
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
// TODO: add more info which needs to be identical but which is not verified otherwise
}
// read output ids
{
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));
if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}
std::vector<int32_t> output_pos;
if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
int32_t id = output_pos[i];
if ((uint32_t) id >= n_batch()) {
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
}
this->output_ids[id] = i;
}
this->n_outputs = n_outputs;
}
}
// read logits
{
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
uint64_t logits_size;
io.read_to(&logits_size, sizeof(logits_size));
if (this->logits.size < logits_size) {
throw std::runtime_error("logits buffer too small");
}
if (logits_size) {
io.read_to(this->logits.data, logits_size * sizeof(float));
}
}
// read embeddings
{
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
uint64_t embd_size;
io.read_to(&embd_size, sizeof(embd_size));
if (this->embd.size < embd_size) {
throw std::runtime_error("embeddings buffer too small");
}
if (embd_size) {
io.read_to(this->embd.data, embd_size * sizeof(float));
}
}
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);