llama-fit-params: free memory target per device (#18679)
This commit is contained in:
parent
9a5724dee2
commit
64848deb18
6 changed files with 83 additions and 39 deletions
|
|
@ -2255,7 +2255,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
std::vector<std::string> split_arg{ it, {} };
|
||||
if (split_arg.size() >= llama_max_devices()) {
|
||||
throw std::invalid_argument(
|
||||
string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices())
|
||||
string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices())
|
||||
);
|
||||
}
|
||||
for (size_t i = 0; i < llama_max_devices(); ++i) {
|
||||
|
|
@ -2295,10 +2295,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_env("LLAMA_ARG_FIT"));
|
||||
add_opt(common_arg(
|
||||
{ "-fitt", "--fit-target" }, "MiB",
|
||||
string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)),
|
||||
[](common_params & params, int value) {
|
||||
params.fit_params_target = value * size_t(1024*1024);
|
||||
{ "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...",
|
||||
string_format("target margin per device for --fit, comma-separated list of values, "
|
||||
"single value is broadcast across all devices, default: %zu", params.fit_params_target[0]/(1024*1024)),
|
||||
[](common_params & params, const std::string & value) {
|
||||
std::string arg_next = value;
|
||||
|
||||
// split string by , and /
|
||||
const std::regex regex{ R"([,/]+)" };
|
||||
std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
|
||||
std::vector<std::string> split_arg{ it, {} };
|
||||
if (split_arg.size() >= llama_max_devices()) {
|
||||
throw std::invalid_argument(
|
||||
string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices())
|
||||
);
|
||||
}
|
||||
if (split_arg.size() == 1) {
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < split_arg.size(); i++) {
|
||||
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_FIT_TARGET"));
|
||||
add_opt(common_arg(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue