server: add auto-sleep after N seconds of idle (#18228)
* implement sleeping at queue level
* implement server-context suspend
* add test
* add docs
* optimization: add fast path
* make sure to free llama_init
* nits
* fix use-after-free
* allow /models to be accessed during sleeping, fix use-after-free
* don't allow accessing /models during sleep, it is not thread-safe
* fix data race on accessing props and model_meta
* small clean up
* trailing whitespace
* rm outdated comments
parent
52ab19df63
commit
ddcb75dd8a
12 changed files with 355 additions and 122 deletions
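The pieces fit together as follows: the task queue owns the idle timer and the sleeping flag (start_loop() now takes the idle timeout, params.sleep_idle_seconds * 1000 ms), an on_sleeping_state callback lets server_context_impl free the model via destroy() on sleep and reload it via load_model() on wake, and HTTP handlers go through server_res_generator, which calls wait_until_no_sleep() unless the request can bypass sleeping (as /health and /props do). Below is a minimal, self-contained sketch of that queue-level pattern; the type and member names are illustrative only and do not match the actual server_queue implementation in this commit.

// illustrative sketch only -- names and structure are hypothetical,
// not the server_queue code from this commit
#include <chrono>
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>

struct idle_sleep_queue {
    std::mutex                mtx;
    std::condition_variable   cv;
    std::deque<int>           tasks;               // stand-in for server_task
    bool                      sleeping       = false;
    bool                      wake_requested = false;
    std::function<void(bool)> on_sleeping_state;   // true = free model, false = reload it

    // HTTP threads call this before posting work; it blocks until the
    // queue thread has reloaded the model
    void wait_until_no_sleep() {
        std::unique_lock<std::mutex> lock(mtx);
        wake_requested = true;
        cv.notify_all();
        cv.wait(lock, [&] { return !sleeping; });
    }

    void post(int task) {
        std::lock_guard<std::mutex> lock(mtx);
        tasks.push_back(task);
        cv.notify_all();
    }

    // queue thread: after timeout_ms of inactivity, enter the sleeping state
    // (assumes timeout_ms > 0; a negative value disables sleeping in the real server)
    void start_loop(int timeout_ms) {
        std::unique_lock<std::mutex> lock(mtx);
        while (true) {
            bool activity = cv.wait_for(lock, std::chrono::milliseconds(timeout_ms),
                                        [&] { return !tasks.empty() || wake_requested; });
            if (!activity) {
                if (!sleeping) {
                    sleeping = true;
                    on_sleeping_state(true);  // e.g. destroy() the llama context
                }
                continue;
            }
            if (sleeping) {
                on_sleeping_state(false);     // e.g. load_model() again
                sleeping = false;
                cv.notify_all();              // unblock wait_until_no_sleep()
            }
            wake_requested = false;
            while (!tasks.empty()) {
                int task = tasks.front();
                tasks.pop_front();
                (void) task;                  // process the task here
            }
        }
    }
};

Caching json_server_props and json_server_model_meta once in populate_json_responses() is what removes the data race mentioned above: HTTP threads serve /props and /models from read-only JSON and never touch the model or context while the queue thread may be tearing them down.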
@@ -544,7 +544,9 @@ struct server_context_impl {
    server_metrics metrics;

    json webui_settings = json::object();
    // cached responses for HTTP API (read-only from HTTP threads)
    json json_server_props = json::object();
    json json_server_model_meta = json::object();

    // Necessary similarity of prompt for slot selection
    float slot_prompt_similarity = 0.0f;
@@ -554,8 +556,23 @@ struct server_context_impl {
    common_chat_templates_ptr chat_templates;
    oaicompat_parser_options oai_parser_opt;

    bool sleeping = false;

    ~server_context_impl() {
        if (!sleeping) {
            // destroy() is already called when entering sleeping state
            // we don't call it again here to avoid double free
            destroy();
        }
    }

    void destroy() {
        llama_init.reset();
        ctx = nullptr;
        model = nullptr;

        mtmd_free(mctx);
        mctx = nullptr;

        // Clear any sampling context
        for (server_slot & slot : slots) {
@@ -571,22 +588,29 @@ struct server_context_impl {
        llama_batch_free(batch);
    }

    void handle_sleeping_state(bool new_state) {
        GGML_ASSERT(sleeping != new_state);
        if (new_state) {
            SRV_INF("%s", "server is entering sleeping state\n");
            destroy();
        } else {
            SRV_INF("%s", "server is exiting sleeping state\n");
            if (!load_model(params_base)) {
                GGML_ABORT("failed to reload model after sleeping");
            }
        }
        sleeping = new_state;
    }

    // load the model and initialize llama_context
    // this may also be called to resume from sleeping state
    bool load_model(const common_params & params) {
        bool is_resume = sleeping;

        SRV_INF("loading model '%s'\n", params.model.path.c_str());

        params_base = params;

        webui_settings = json::object();
        if (!params_base.webui_config_json.empty()) {
            try {
                webui_settings = json::parse(params_base.webui_config_json);
            } catch (const std::exception & e) {
                SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
                return false;
            }
        }

        llama_init = common_init_from_params(params_base);

        model = llama_init->model();
@@ -654,7 +678,9 @@ struct server_context_impl {

        std::string & mmproj_path = params_base.mmproj.path;
        if (!mmproj_path.empty()) {
            mtmd_helper_log_set(common_log_default_callback, nullptr);
            if (!is_resume) {
                mtmd_helper_log_set(common_log_default_callback, nullptr);
            }

            mtmd_context_params mparams = mtmd_context_params_default();
            mparams.use_gpu = params_base.mmproj_use_gpu;
@@ -699,19 +725,6 @@ struct server_context_impl {
            }
        }

        return true;
    }

    // initialize slots and server-related data
    void init() {
        // wiring up server queues
        queue_tasks.on_new_task([this](server_task && task) {
            process_single_task(std::move(task));
        });
        queue_tasks.on_update_slots([this]() {
            update_slots();
        });

        // Necessary similarity of prompt for slot selection
        slot_prompt_similarity = params_base.slot_prompt_similarity;
@@ -726,6 +739,7 @@ struct server_context_impl {
            n_ctx_slot = n_ctx_train;
        }

        slots.clear();
        for (int i = 0; i < params_base.n_parallel; i++) {
            server_slot slot;
@@ -742,13 +756,13 @@ struct server_context_impl {
            slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
            if (slot.ctx_dft == nullptr) {
                SRV_ERR("%s", "failed to create draft context\n");
                return;
                return false;
            }

            slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft);
            if (slot.spec == nullptr) {
                SRV_ERR("%s", "failed to create speculator\n");
                return;
                return false;
            }
            for (auto & pair : params_base.speculative.replacements) {
                common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
@@ -782,8 +796,6 @@ struct server_context_impl {
            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
        }

        metrics.init();

        if (params_base.cache_ram_mib != 0) {
            if (params_base.cache_ram_mib < 0) {
                SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit");
@@ -832,6 +844,103 @@ struct server_context_impl {
        LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
            common_chat_templates_source(chat_templates.get()),
            common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());

        if (!is_resume) {
            return init();
        }

        return true;
    }

    // unlike load_model(), this is only called once during initialization
    bool init() {
        GGML_ASSERT(ctx != nullptr);
        GGML_ASSERT(model != nullptr);
        GGML_ASSERT(!sleeping);

        // wiring up server queues
        queue_tasks.on_new_task([this](server_task && task) {
            process_single_task(std::move(task));
        });
        queue_tasks.on_update_slots([this]() {
            update_slots();
        });
        queue_tasks.on_sleeping_state([this](bool sleeping) {
            handle_sleeping_state(sleeping);
        });

        metrics.init();

        if (!populate_json_responses()) {
            SRV_ERR("%s", "failed to populate JSON responses\n");
            return false;
        }

        return true;
    }

    bool populate_json_responses() {
        // populate webui settings
        json json_webui_settings = json::object();
        {
            if (!params_base.webui_config_json.empty()) {
                try {
                    json_webui_settings = json::parse(params_base.webui_config_json);
                } catch (const std::exception & e) {
                    SRV_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
                    return false;
                }
            }
        }

        // populate server properties
        {
            task_params params;
            params.sampling = params_base.sampling;
            json default_generation_settings_for_props = json {
                {"params", params.to_json(true)},
                {"n_ctx", get_slot_n_ctx()},
            };

            json_server_props = {
                { "default_generation_settings", default_generation_settings_for_props },
                { "total_slots", params_base.n_parallel },
                { "model_alias", model_name },
                { "model_path", params_base.model.path },
                { "modalities", json {
                    {"vision", oai_parser_opt.allow_image},
                    {"audio", oai_parser_opt.allow_audio},
                } },
                { "endpoint_slots", params_base.endpoint_slots },
                { "endpoint_props", params_base.endpoint_props },
                { "endpoint_metrics", params_base.endpoint_metrics },
                { "webui", params_base.webui },
                { "webui_settings", json_webui_settings },
                { "chat_template", common_chat_templates_source(chat_templates.get()) },
                { "bos_token", common_token_to_piece(ctx, llama_vocab_bos(vocab), /* special= */ true)},
                { "eos_token", common_token_to_piece(ctx, llama_vocab_eos(vocab), /* special= */ true)},
                { "build_info", build_info },
            };
            if (params_base.use_jinja) {
                if (auto tool_use_src = common_chat_templates_source(chat_templates.get(), "tool_use")) {
                    json_server_props["chat_template_tool_use"] = tool_use_src;
                }
            }
        }

        // populate model metadata
        {
            json_server_model_meta = {
                {"vocab_type", llama_vocab_type (vocab)},
                {"n_vocab", llama_vocab_n_tokens (vocab)},
                {"n_ctx_train", llama_model_n_ctx_train(model)},
                {"n_embd", llama_model_n_embd (model)},
                {"n_params", llama_model_n_params (model)},
                {"size", llama_model_size (model)},
            };
        }

        return true;
    }

    server_slot * get_slot_by_id(int id) {
@@ -2635,17 +2744,6 @@ struct server_context_impl {
        SRV_DBG("%s", "run slots completed\n");
    }

    json model_meta() const {
        return json {
            {"vocab_type", llama_vocab_type (vocab)},
            {"n_vocab", llama_vocab_n_tokens (vocab)},
            {"n_ctx_train", llama_model_n_ctx_train(model)},
            {"n_embd", llama_model_n_embd (model)},
            {"n_params", llama_model_n_params (model)},
            {"size", llama_model_size (model)},
        };
    }

    int get_slot_n_ctx() {
        return slots.back().n_ctx;
    }
@@ -2662,16 +2760,13 @@ struct server_context_impl {
server_context::server_context() : impl(new server_context_impl()) {}
server_context::~server_context() = default;

void server_context::init() {
    impl->init();
}

bool server_context::load_model(const common_params & params) {
    return impl->load_model(params);
}

void server_context::start_loop() {
    impl->queue_tasks.start_loop();
    auto & params = impl->params_base;
    impl->queue_tasks.start_loop(params.sleep_idle_seconds * 1000);
}

void server_context::terminate() {
@@ -2698,10 +2793,17 @@ server_context_info server_context::get_info() const {

// generator-like API for HTTP response generation
// may have bypass_sleep = true if the task does not use ctx_server
struct server_res_generator : server_http_res {
    server_response_reader rd;
    server_res_generator(server_context_impl & ctx_server)
        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
    server_res_generator(server_context_impl & ctx_server, bool bypass_sleep = false)
        : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {
        // fast path in case sleeping is disabled
        bypass_sleep |= ctx_server.params_base.sleep_idle_seconds < 0;
        if (!bypass_sleep) {
            ctx_server.queue_tasks.wait_until_no_sleep();
        }
    }
    void ok(const json & response_data) {
        status = 200;
        data = safe_json_to_str(response_data);
@@ -2719,6 +2821,7 @@ struct server_res_generator : server_http_res {
//

static std::unique_ptr<server_res_generator> handle_completions_impl(
    std::unique_ptr<server_res_generator> && res_ptr,
    server_context_impl & ctx_server,
    server_task_type type,
    const json & data,
@@ -2727,7 +2830,7 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
    task_response_type res_type) {
    GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

    auto res = std::make_unique<server_res_generator>(ctx_server);
    auto res = std::move(res_ptr);
    auto completion_id = gen_chatcmplid();
    auto & rd = res->rd;
@@ -2931,9 +3034,12 @@ static std::unique_ptr<server_res_generator> handle_completions_impl(
}

void server_routes::init_routes() {
    // IMPORTANT: all lambda functions must start with std::make_unique<server_res_generator>
    // this is to ensure that the server_res_generator can handle sleeping case correctly

    this->get_health = [this](const server_http_req &) {
        // error and loading states are handled by middleware
        auto res = std::make_unique<server_res_generator>(ctx_server);
        auto res = std::make_unique<server_res_generator>(ctx_server, true);
        res->ok({{"status", "ok"}});
        return res;
    };
@@ -3115,46 +3221,10 @@ void server_routes::init_routes() {
    };

    this->get_props = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        json default_generation_settings_for_props;

        {
            task_params params;

            params.sampling = ctx_server.params_base.sampling;

            default_generation_settings_for_props = json {
                {"params", params.to_json(true)},
                {"n_ctx", ctx_server.get_slot_n_ctx()},
            };
        }

        json data = {
            { "default_generation_settings", default_generation_settings_for_props },
            { "total_slots", ctx_server.params_base.n_parallel },
            { "model_alias", ctx_server.model_name },
            { "model_path", ctx_server.params_base.model.path },
            { "modalities", json {
                {"vision", ctx_server.oai_parser_opt.allow_image},
                {"audio", ctx_server.oai_parser_opt.allow_audio},
            } },
            { "endpoint_slots", params.endpoint_slots },
            { "endpoint_props", params.endpoint_props },
            { "endpoint_metrics", params.endpoint_metrics },
            { "webui", params.webui },
            { "webui_settings", ctx_server.webui_settings },
            { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
            { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
            { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
            { "build_info", build_info },
        };
        if (ctx_server.params_base.use_jinja) {
            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
                data["chat_template_tool_use"] = tool_use_src;
            }
        }

        res->ok(data);
        auto res = std::make_unique<server_res_generator>(ctx_server, true);
        auto props = ctx_server.json_server_props;
        props["is_sleeping"] = ctx_server.queue_tasks.is_sleeping();
        res->ok(props);
        return res;
    };
@@ -3272,6 +3342,7 @@ void server_routes::init_routes() {

        std::vector<raw_buffer> files; // dummy
        return handle_completions_impl(
            std::move(res),
            ctx_server,
            SERVER_TASK_TYPE_INFILL,
            data,
@@ -3281,9 +3352,11 @@ void server_routes::init_routes() {
    };

    this->post_completions = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        std::vector<raw_buffer> files; // dummy
        const json body = json::parse(req.body);
        return handle_completions_impl(
            std::move(res),
            ctx_server,
            SERVER_TASK_TYPE_COMPLETION,
            body,
@@ -3293,9 +3366,11 @@ void server_routes::init_routes() {
    };

    this->post_completions_oai = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        std::vector<raw_buffer> files; // dummy
        const json body = json::parse(req.body);
        return handle_completions_impl(
            std::move(res),
            ctx_server,
            SERVER_TASK_TYPE_COMPLETION,
            body,
@@ -3305,6 +3380,7 @@ void server_routes::init_routes() {
    };

    this->post_chat_completions = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        std::vector<raw_buffer> files;
        json body = json::parse(req.body);
        json body_parsed = oaicompat_chat_params_parse(
@@ -3312,6 +3388,7 @@ void server_routes::init_routes() {
            ctx_server.oai_parser_opt,
            files);
        return handle_completions_impl(
            std::move(res),
            ctx_server,
            SERVER_TASK_TYPE_COMPLETION,
            body_parsed,
@@ -3321,6 +3398,7 @@ void server_routes::init_routes() {
    };

    this->post_anthropic_messages = [this](const server_http_req & req) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        std::vector<raw_buffer> files;
        json body = convert_anthropic_to_oai(json::parse(req.body));
        json body_parsed = oaicompat_chat_params_parse(
@@ -3328,6 +3406,7 @@ void server_routes::init_routes() {
            ctx_server.oai_parser_opt,
            files);
        return handle_completions_impl(
            std::move(res),
            ctx_server,
            SERVER_TASK_TYPE_COMPLETION,
            body_parsed,
@@ -3365,11 +3444,13 @@ void server_routes::init_routes() {
        return res;
    };

    // TODO: this endpoint is unsafe to access during model reloading (i.e. wake up from sleeping)
    // how to make it work even during load_model()?
    this->get_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_res_generator>(ctx_server);
        json model_meta = nullptr;
        if (is_ready()) {
            model_meta = ctx_server.model_meta();
            model_meta = ctx_server.json_server_model_meta;
        }
        bool has_mtmd = ctx_server.mctx != nullptr;
        json models = {