context : reserve new scheduler when graph topology changes (#18547)
* context : reserve new scheduler when graph topology changes * cont : fix * cont : fix reserve * cont : reserve only when changes occur + timing * context : add comments * llama : reserve on sampler changes * common : allow null common_sampler * server : task declares needs (embd, logits, sampling) * server : do not init sampler if not needed * llama : fix need_reserve when unsetting a sampler * server : consolidate slot reset/clear logic
This commit is contained in:
parent
5c662d21a3
commit
39173bcacb
9 changed files with 328 additions and 216 deletions
|
|
@ -45,26 +45,6 @@ enum server_state {
|
|||
SERVER_STATE_READY, // Server is ready and model is loaded
|
||||
};
|
||||
|
||||
static bool server_task_type_need_embd(server_task_type task_type) {
|
||||
switch (task_type) {
|
||||
case SERVER_TASK_TYPE_EMBEDDING:
|
||||
case SERVER_TASK_TYPE_RERANK:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool server_task_type_need_logits(server_task_type task_type) {
|
||||
switch (task_type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
|
||||
|
|
@ -147,6 +127,17 @@ struct server_slot {
|
|||
return res;
|
||||
}
|
||||
|
||||
void prompt_clear(bool allow_processing) {
|
||||
if (!allow_processing) {
|
||||
GGML_ASSERT(!is_processing());
|
||||
}
|
||||
|
||||
SLT_INF(*this, "clearing prompt with %zu tokens\n", prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
std::vector<common_adapter_lora_info> lora;
|
||||
int32_t alora_invocation_start = -1;
|
||||
|
||||
|
|
@ -196,30 +187,24 @@ struct server_slot {
|
|||
n_draft_total = 0;
|
||||
n_draft_accepted = 0;
|
||||
|
||||
task_prev = std::move(task);
|
||||
task.reset();
|
||||
task_prev.reset();
|
||||
|
||||
llama_set_sampler(ctx, id, nullptr);
|
||||
|
||||
// clear alora start
|
||||
alora_invocation_start = -1;
|
||||
}
|
||||
|
||||
// remove cached prompt + tokens
|
||||
void clear(bool allow_processing) {
|
||||
if (!allow_processing) {
|
||||
GGML_ASSERT(!is_processing());
|
||||
void init_sampler() const {
|
||||
common_sampler_reset(smpl.get());
|
||||
|
||||
if (!task->need_sampling()) {
|
||||
return;
|
||||
}
|
||||
|
||||
SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());
|
||||
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
|
||||
prompt.tokens.clear();
|
||||
}
|
||||
|
||||
void init_sampler() const {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
common_sampler_reset(smpl.get());
|
||||
|
||||
int n_text = 0;
|
||||
|
||||
for (int i = 0; i < (int) prompt.tokens.size(); i++) {
|
||||
|
|
@ -235,25 +220,13 @@ struct server_slot {
|
|||
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool need_embd() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return server_task_type_need_embd(task->type);
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool need_logits() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return server_task_type_need_logits(task->type);
|
||||
}
|
||||
|
||||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
// also we cannot split if the pooling would require any past tokens
|
||||
bool can_split() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return
|
||||
!need_embd() ||
|
||||
!task->need_embd() ||
|
||||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
|
||||
}
|
||||
|
||||
|
|
@ -349,11 +322,10 @@ struct server_slot {
|
|||
|
||||
// do not keep context of the child slots - the parent's context is enough
|
||||
if (is_child()) {
|
||||
clear(false);
|
||||
prompt_clear(false);
|
||||
}
|
||||
|
||||
task_prev = std::move(task);
|
||||
task.reset();
|
||||
reset();
|
||||
|
||||
callback_on_release(id);
|
||||
}
|
||||
|
|
@ -801,6 +773,7 @@ private:
|
|||
|
||||
slots.clear();
|
||||
|
||||
// initialize slots
|
||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||
server_slot slot;
|
||||
|
||||
|
|
@ -1049,7 +1022,7 @@ private:
|
|||
ret->prompt_save(*prompt_cache);
|
||||
|
||||
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
|
||||
ret->clear(false);
|
||||
ret->prompt_clear(false);
|
||||
}
|
||||
|
||||
prompt_cache->update();
|
||||
|
|
@ -1081,7 +1054,7 @@ private:
|
|||
if (slot.prompt.n_tokens() > 0) {
|
||||
SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
|
||||
|
||||
slot.clear(false);
|
||||
slot.prompt_clear(false);
|
||||
|
||||
res = true;
|
||||
|
||||
|
|
@ -1107,8 +1080,6 @@ private:
|
|||
}
|
||||
|
||||
bool launch_slot_with_task(server_slot & slot, server_task && task) {
|
||||
slot.reset();
|
||||
|
||||
// process per-request lora adapters
|
||||
if (!task.params.lora.empty()) {
|
||||
auto task_loras = construct_lora_list(task.params.lora);
|
||||
|
|
@ -1182,7 +1153,7 @@ private:
|
|||
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
||||
|
||||
// initialize samplers
|
||||
{
|
||||
if (task.need_sampling()) {
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
|
||||
if (slot.smpl == nullptr) {
|
||||
|
|
@ -1211,6 +1182,8 @@ private:
|
|||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
} else {
|
||||
slot.smpl.reset();
|
||||
}
|
||||
|
||||
// initialize draft batch
|
||||
|
|
@ -1864,7 +1837,7 @@ private:
|
|||
// Erase token cache
|
||||
const size_t n_erased = slot->prompt.tokens.size();
|
||||
|
||||
slot->clear(false);
|
||||
slot->prompt_clear(false);
|
||||
|
||||
auto res = std::make_unique<server_task_result_slot_erase>();
|
||||
res->id = task.id;
|
||||
|
|
@ -2161,7 +2134,7 @@ private:
|
|||
}
|
||||
|
||||
// TODO: support memory-less logits computation
|
||||
if (slot.need_logits() && !llama_get_memory(ctx)) {
|
||||
if (slot.task->need_logits() && !llama_get_memory(ctx)) {
|
||||
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
|
|
@ -2421,7 +2394,7 @@ private:
|
|||
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
|
||||
SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
|
||||
|
||||
slot.clear(true);
|
||||
slot.prompt_clear(true);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_prompt_tokens_cache = 0;
|
||||
|
|
@ -2500,7 +2473,7 @@ private:
|
|||
cur_tok,
|
||||
slot.prompt.tokens.pos_next(),
|
||||
{ slot.id },
|
||||
slot.need_embd());
|
||||
slot.task->need_embd());
|
||||
slot.prompt.tokens.push_back(cur_tok);
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
|
@ -2590,7 +2563,7 @@ private:
|
|||
slot_batched->lora[alora_disabled_id].scale = alora_scale;
|
||||
}
|
||||
|
||||
llama_set_embeddings(ctx, slot_batched->need_embd());
|
||||
llama_set_embeddings(ctx, slot_batched->task->need_embd());
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
|
|
@ -2648,7 +2621,7 @@ private:
|
|||
|
||||
// note: it's complicated to keep track of how much of the current batch has been
|
||||
// processed before the error occurred, so we simply clear the entire context
|
||||
slot.clear(false);
|
||||
slot.prompt_clear(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2727,6 +2700,8 @@ private:
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
GGML_ASSERT(slot.task->need_sampling());
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue