server: add auto-sleep after N seconds of idle (#18228)

* implement sleeping at queue level

* implement server-context suspend

* add test

* add docs

* optimization: add fast path

* make sure to free llama_init

* nits

* fix use-after-free

* allow /models to be accessed during sleeping, fix use-after-free

* don't allow accessing /models during sleep; it is not thread-safe

* fix data race on accessing props and model_meta

* small clean up

* trailing whitespace

* rm outdated comments
Xuan-Son Nguyen 2025-12-21 02:24:42 +01:00 committed by GitHub
parent 52ab19df63
commit ddcb75dd8a
12 changed files with 355 additions and 122 deletions
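Note: the sketch below is a hypothetical client-side check, not part of this diff. It assumes a server listening on http://localhost:8080 and that libcurl and nlohmann::json are available. With this change, /health and /props stay reachable while the server sleeps (sleeping is controlled by sleep_idle_seconds in common_params; a negative value disables it), /props now reports is_sleeping, and model-dependent endpoints block in wait_until_no_sleep() until the model has been reloaded.

// Hypothetical client: polls /props and reads the new "is_sleeping" field.
// Assumes localhost:8080, libcurl and nlohmann::json (not part of this change).
#include <curl/curl.h>
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

// libcurl write callback: append the response body to a std::string
static size_t write_cb(char * ptr, size_t size, size_t nmemb, void * userdata) {
    auto * out = static_cast<std::string *>(userdata);
    out->append(ptr, size * nmemb);
    return size * nmemb;
}

int main() {
    std::string body;
    CURL * curl = curl_easy_init();
    if (!curl) {
        return 1;
    }
    // /props is served even while the server sleeps (bypass_sleep = true)
    curl_easy_setopt(curl, CURLOPT_URL, "http://localhost:8080/props");
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_cb);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &body);
    CURLcode rc = curl_easy_perform(curl);
    curl_easy_cleanup(curl);
    if (rc != CURLE_OK) {
        std::cerr << "request failed: " << curl_easy_strerror(rc) << "\n";
        return 1;
    }
    nlohmann::json props = nlohmann::json::parse(body);
    // "is_sleeping" is populated from queue_tasks.is_sleeping() in get_props
    std::cout << std::boolalpha
              << "is_sleeping: " << props.value("is_sleeping", false) << "\n";
    return 0;
}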


@@ -544,7 +544,9 @@ struct server_context_impl {
server_metrics metrics;
json webui_settings = json::object();
// cached responses for HTTP API (read-only from HTTP threads)
json json_server_props = json::object();
json json_server_model_meta = json::object();
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
@@ -554,8 +556,23 @@ struct server_context_impl {
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
bool sleeping = false;
~server_context_impl() {
if (!sleeping) {
// destroy() is already called when entering sleeping state
// we don't call it again here to avoid double free
destroy();
}
}
void destroy() {
llama_init.reset();
ctx = nullptr;
model = nullptr;
mtmd_free(mctx);
mctx = nullptr;
// Clear any sampling context
for (server_slot & slot : slots) {
@@ -571,22 +588,29 @@ struct server_context_impl {
llama_batch_free(batch);
}
void handle_sleeping_state(bool new_state) {
GGML_ASSERT(sleeping != new_state);
if (new_state) {
SRV_INF("%s", "server is entering sleeping state\n");
destroy();
} else {
SRV_INF("%s", "server is exiting sleeping state\n");
if (!load_model(params_base)) {
GGML_ABORT("failed to reload model after sleeping");
}
}
sleeping = new_state;
}
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(const common_params & params) {
bool is_resume = sleeping;
SRV_INF("loading model '%s'\n", params.model.path.c_str());
params_base = params;
webui_settings = json::object();
if (!params_base.webui_config_json.empty()) {
try {
webui_settings = json::parse(params_base.webui_config_json);
} catch (const std::exception & e) {
SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
return false;
}
}
llama_init = common_init_from_params(params_base);
model = llama_init->model();
@@ -654,7 +678,9 @@ struct server_context_impl {
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
if (!is_resume) {
mtmd_helper_log_set(common_log_default_callback, nullptr);
}
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
@@ -699,19 +725,6 @@ struct server_context_impl {
}
}
return true;
}
// initialize slots and server-related data
void init() {
// wiring up server queues
queue_tasks.on_new_task([this](server_task && task) {
process_single_task(std::move(task));
});
queue_tasks.on_update_slots([this]() {
update_slots();
});
// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
@@ -726,6 +739,7 @@ struct server_context_impl {
n_ctx_slot = n_ctx_train;
}
slots.clear();
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@@ -742,13 +756,13 @@ struct server_context_impl {
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n");
return;
return false;
}
slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
if (slot.spec == nullptr) {
SRV_ERR("%s", "failed to create speculator\n");
return;
return false;
}
for (auto & pair : params_base.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
@@ -782,8 +796,6 @@ struct server_context_impl {
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
metrics.init();
if (params_base.cache_ram_mib != 0) {
if (params_base.cache_ram_mib < 0) {
SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
@@ -832,6 +844,103 @@ struct server_context_impl {
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
common_chat_templates_source(chat_templates.get()),
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
if (!is_resume) {
return init();
}
return true;
}
// unlike load_model(), this is only called once during initialization
bool init() {
GGML_ASSERT(ctx != nullptr);
GGML_ASSERT(model != nullptr);
GGML_ASSERT(!sleeping);
// wiring up server queues
queue_tasks.on_new_task([this](server_task && task) {
process_single_task(std::move(task));
});
queue_tasks.on_update_slots([this]() {
update_slots();
});
queue_tasks.on_sleeping_state([this](bool sleeping) {
handle_sleeping_state(sleeping);
});
metrics.init();
if (!populate_json_responses()) {
SRV_ERR("%s", "failed to populate JSON responses\n");
return false;
}
return true;
}
bool populate_json_responses() {
// populate webui settings
json json_webui_settings = json::object();
{
if (!params_base.webui_config_json.empty()) {
try {
json_webui_settings = json::parse(params_base.webui_config_json);
} catch (const std::exception & e) {
SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
return false;
}
}
}
// populate server properties
{
task_params params;
params.sampling = params_base.sampling;
json default_generation_settings_for_props = json {
{"params", params.to_json(true)},
{"n_ctx", get_slot_n_ctx()},
};
json_server_props = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", params_base.n_parallel },
{ "model_alias", model_name },
{ "model_path", params_base.model.path },
{ "modalities", json {
{"vision", oai_parser_opt.allow_image},
{"audio", oai_parser_opt.allow_audio},
} },
{ "endpoint_slots", params_base.endpoint_slots },
{ "endpoint_props", params_base.endpoint_props },
{ "endpoint_metrics", params_base.endpoint_metrics },
{ "webui", params_base.webui },
{ "webui_settings", json_webui_settings },
{ "chat_template", common_chat_templates_source(chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx, llama_vocab_bos(vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx, llama_vocab_eos(vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (params_base.use_jinja) {
if (auto tool_use_src = common_chat_templates_source(chat_templates.get(), "tool_use")) {
json_server_props["chat_template_tool_use"] = tool_use_src;
}
}
}
// populate model metadata
{
json_server_model_meta = {
{"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_vocab_n_tokens (vocab)},
{"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)},
};
}
return true;
}
server_slot * get_slot_by_id(int id) {
@@ -2635,17 +2744,6 @@ struct server_context_impl {
SRV_DBG("%s", "run slots completed\n");
}
json model_meta() const {
return json {
{"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_vocab_n_tokens (vocab)},
{"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)},
};
}
int get_slot_n_ctx() {
return slots.back().n_ctx;
}
@@ -2662,16 +2760,13 @@ struct server_context_impl {
server_context::server_context() : impl(new server_context_impl()) {}
server_context::~server_context() = default;
void server_context::init() {
impl->init();
}
bool server_context::load_model(const common_params & params) {
return impl->load_model(params);
}
void server_context::start_loop() {
impl->queue_tasks.start_loop();
auto & params = impl->params_base;
impl->queue_tasks.start_loop(params.sleep_idle_seconds * 1000);
}
void server_context::terminate() {
@@ -2698,10 +2793,17 @@ server_context_info server_context::get_info() const {
// generator-like API for HTTP response generation
// may have bypass_sleep = true if the task does not use ctx_server
struct server_res_generator : server_http_res {
server_response_reader rd;
server_res_generator(server_context_impl & ctx_server)
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
server_res_generator(server_context_impl & ctx_server, bool bypass_sleep = false)
: rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {
// fast path in case sleeping is disabled
bypass_sleep |= ctx_server.params_base.sleep_idle_seconds < 0;
if (!bypass_sleep) {
ctx_server.queue_tasks.wait_until_no_sleep();
}
}
void ok(const json & response_data) {
status = 200;
data = safe_json_to_str(response_data);
@@ -2719,6 +2821,7 @@ struct server_res_generator : server_http_res {
//
static std::unique_ptr<server_res_generator> handle_completions_impl(
std::unique_ptr<server_res_generator> && res_ptr,
server_context_impl & ctx_server,
server_task_type type,
const json & data,
@@ -2727,7 +2830,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
task_response_type res_type) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
auto res = std::make_unique<server_res_generator>(ctx_server);
auto res = std::move(res_ptr);
auto completion_id = gen_chatcmplid();
auto & rd = res->rd;
@@ -2931,9 +3034,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}
void server_routes::init_routes() {
// IMPORTANT: all lambda functions must start with std::make_unique<server_res_generator>
// this is to ensure that the server_res_generator can handle sleeping case correctly
this->get_health = [this](const server_http_req &) {
// error and loading states are handled by middleware
auto res = std::make_unique<server_res_generator>(ctx_server);
auto res = std::make_unique<server_res_generator>(ctx_server, true);
res->ok({{"status", "ok"}});
return res;
};
@@ -3115,46 +3221,10 @@ void server_routes::init_routes() {
};
this->get_props = [this](const server_http_req &) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json default_generation_settings_for_props;
{
task_params params;
params.sampling = ctx_server.params_base.sampling;
default_generation_settings_for_props = json {
{"params", params.to_json(true)},
{"n_ctx", ctx_server.get_slot_n_ctx()},
};
}
json data = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_alias", ctx_server.model_name },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json {
{"vision", ctx_server.oai_parser_opt.allow_image},
{"audio", ctx_server.oai_parser_opt.allow_audio},
} },
{ "endpoint_slots", params.endpoint_slots },
{ "endpoint_props", params.endpoint_props },
{ "endpoint_metrics", params.endpoint_metrics },
{ "webui", params.webui },
{ "webui_settings", ctx_server.webui_settings },
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja) {
if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
}
res->ok(data);
auto res = std::make_unique<server_res_generator>(ctx_server, true);
auto props = ctx_server.json_server_props;
props["is_sleeping"] = ctx_server.queue_tasks.is_sleeping();
res->ok(props);
return res;
};
@@ -3272,6 +3342,7 @@ void server_routes::init_routes() {
std::vector<raw_buffer> files; // dummy
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_INFILL,
data,
@@ -3281,9 +3352,11 @@ void server_routes::init_routes() {
};
this->post_completions = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files; // dummy
const json body = json::parse(req.body);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body,
@@ -3293,9 +3366,11 @@ void server_routes::init_routes() {
};
this->post_completions_oai = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files; // dummy
const json body = json::parse(req.body);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body,
@@ -3305,6 +3380,7 @@ void server_routes::init_routes() {
};
this->post_chat_completions = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files;
json body = json::parse(req.body);
json body_parsed = oaicompat_chat_params_parse(
@@ -3312,6 +3388,7 @@ void server_routes::init_routes() {
ctx_server.oai_parser_opt,
files);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body_parsed,
@@ -3321,6 +3398,7 @@ void server_routes::init_routes() {
};
this->post_anthropic_messages = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files;
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
@@ -3328,6 +3406,7 @@ void server_routes::init_routes() {
ctx_server.oai_parser_opt,
files);
return handle_completions_impl(
std::move(res),
ctx_server,
SERVER_TASK_TYPE_COMPLETION,
body_parsed,
@@ -3365,11 +3444,13 @@ void server_routes::init_routes() {
return res;
};
// TODO: this endpoint is unsafe to access during model reloading (i.e. wake up from sleeping)
// how to make it work even during load_model()?
this->get_models = [this](const server_http_req &) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json model_meta = nullptr;
if (is_ready()) {
model_meta = ctx_server.model_meta();
model_meta = ctx_server.json_server_model_meta;
}
bool has_mtmd = ctx_server.mctx != nullptr;
json models = {