tests : add unit test coverage for llama_tensor_get_type (#20112)

* Add unit test coverage for llama_tensor_get_type

* Fix merge conflicts, add more schemas

* clang formatter changes

* Trailing whitespace

* Update name

* Start rebase

* Updating files with upstream changes prior to rebase

* Changes needed from rebase

* Update attn_qkv schema, change throw behaviour

* Fix merge conflicts

* White space

* Update with latest changes to state counters

* Revert accidental personal CLAUDE.md changes

* Change quotation mark

* Reuse metadata.name since we have it

* Move test-only stuff out of llama-quant.cpp

* Hide the regex functionality back in llama-quant.cpp, use a unique pointer to a new struct 'compiled_tensor_type_patterns' which contains the patterns

* cont : inital deslop guidelines

* Cleanup based on review comments

* Continue cleanup

* Small cleanup

* Manually set proper ordering of tensors, mostly applies to gemma

* Formatting

* Update tests/test-quant-type-selection.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Fix merge conflicts

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Bartowski 2026-04-02 16:53:58 -04:00 committed by GitHub
parent a1cfb64530
commit 7992aa7c8e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 35301 additions and 51 deletions

1
tests/.gitignore vendored
View file

@ -1,5 +1,6 @@
*
!*.*
!snapshots/
*.o
ggml-common.h
**/*.swp

View file

@ -274,6 +274,12 @@ if (TARGET cpp-httplib)
add_executable(test-gguf-model-data test-gguf-model-data.cpp)
target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common)
llama_test(test-gguf-model-data LABEL "model")
# test-quant-type-selection requires gguf-model-data for remote model metadata
llama_build_and_test(test-quant-type-selection.cpp LABEL "model")
target_link_libraries(test-quant-type-selection PRIVATE gguf-model-data)
target_compile_definitions(test-quant-type-selection PRIVATE
SNAPSHOT_DIR="${CMAKE_CURRENT_SOURCE_DIR}/snapshots")
endif()
endif()

View file

@ -125,6 +125,35 @@ static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
}
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
// Handle array-valued fields (e.g. per-layer head counts in hybrid models)
// by reading the first element as a representative value.
if (vtype == GGUF_TYPE_ARRAY) {
int32_t elem_type;
uint64_t count;
if (!r.read_val(elem_type)) {
return false;
}
if (!r.read_val(count)) {
return false;
}
if (count == 0) {
return false;
}
// Read first element, skip the rest
if (!gguf_read_uint32_val(r, elem_type, out)) {
return false;
}
for (uint64_t i = 1; i < count; i++) {
size_t sz = gguf_val_type_size(elem_type);
if (sz == 0) {
return false;
}
if (!r.skip(sz)) {
return false;
}
}
return true;
}
if (vtype == GGUF_TYPE_UINT8) {
uint8_t v;
if (!r.read_val(v)) {
@ -487,7 +516,8 @@ static std::string detect_gguf_filename(const std::string & repo, const std::str
static std::optional<gguf_remote_model> fetch_and_parse(
const std::string & repo,
const std::string & filename,
const std::string & cache_path) {
const std::string & cache_path,
bool verbose) {
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
// Progressive download inspired by RangeView.fetchChunk()
@ -496,7 +526,9 @@ static std::optional<gguf_remote_model> fetch_and_parse(
const size_t max_chunk = 64 * 1024 * 1024;
while (chunk_size <= max_chunk) {
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
if (verbose) {
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
}
char range_buf[64];
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
@ -542,7 +574,8 @@ static std::optional<gguf_remote_model> fetch_or_cached(
const std::string & repo,
const std::string & filename,
const std::string & cdir,
const std::string & repo_part) {
const std::string & repo_part,
bool verbose) {
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
{
@ -550,20 +583,23 @@ static std::optional<gguf_remote_model> fetch_or_cached(
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
auto result = gguf_parse_meta(cached);
if (result.has_value()) {
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
if (verbose) {
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
}
return result;
}
}
}
fs_create_directory_with_parents(cdir);
return fetch_and_parse(repo, filename, cache_path);
return fetch_and_parse(repo, filename, cache_path, verbose);
}
std::optional<gguf_remote_model> gguf_fetch_model_meta(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir) {
const std::string & cache_dir,
bool verbose) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
@ -573,7 +609,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return std::nullopt;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return std::nullopt;
@ -588,8 +624,10 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
return std::nullopt;
}
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
if (verbose) {
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
}
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
@ -597,7 +635,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
return std::nullopt;
@ -620,7 +658,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
gguf_context_ptr gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant,
const std::string & cache_dir) {
const std::string & cache_dir,
bool verbose) {
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
std::string repo_part = sanitize_for_path(repo);
@ -631,7 +670,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
return nullptr;
}
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
if (!model_opt.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
return nullptr;
@ -659,8 +698,10 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
return nullptr;
}
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
if (verbose) {
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
model.n_split, model.n_split - 1);
}
for (int i = 2; i <= model.n_split; i++) {
char num_buf[6], total_buf[6];
@ -668,7 +709,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
if (!shard.has_value()) {
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
return nullptr;

View file

@ -40,9 +40,11 @@ struct gguf_remote_model {
std::optional<gguf_remote_model> gguf_fetch_model_meta(
const std::string & repo,
const std::string & quant = "Q8_0",
const std::string & cache_dir = ""); // empty = default
const std::string & cache_dir = "", // empty = default
bool verbose = true);
gguf_context_ptr gguf_fetch_gguf_ctx(
const std::string & repo,
const std::string & quant = "Q8_0",
const std::string & cache_dir = "");
const std::string & cache_dir = "",
bool verbose = true);

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -116,6 +116,39 @@ int main() {
// Verify tensor count
TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780");
// Test a hybrid-attention model with array-valued head counts
auto result4 = gguf_fetch_model_meta("ggml-org/Step-3.5-Flash-GGUF", "Q4_K");
if (!result4.has_value()) {
fprintf(stderr, "FAIL: could not fetch Step-3.5-Flash metadata\n");
return 1;
}
const auto & model4 = result4.value();
fprintf(stderr, "Architecture: %s\n", model4.architecture.c_str());
fprintf(stderr, "n_embd: %u\n", model4.n_embd);
fprintf(stderr, "n_ff: %u\n", model4.n_ff);
fprintf(stderr, "n_vocab: %u\n", model4.n_vocab);
fprintf(stderr, "n_layer: %u\n", model4.n_layer);
fprintf(stderr, "n_head: %u\n", model4.n_head);
fprintf(stderr, "n_head_kv: %u\n", model4.n_head_kv);
fprintf(stderr, "n_expert: %u\n", model4.n_expert);
fprintf(stderr, "n_embd_head_k: %u\n", model4.n_embd_head_k);
fprintf(stderr, "n_embd_head_v: %u\n", model4.n_embd_head_v);
fprintf(stderr, "tensors: %zu\n", model4.tensors.size());
TEST_ASSERT(model4.architecture == "step35", "expected architecture 'step35'");
TEST_ASSERT(model4.n_layer == 45, "expected n_layer == 45");
TEST_ASSERT(model4.n_embd == 4096, "expected n_embd == 4096");
TEST_ASSERT(model4.n_ff == 11264, "expected n_ff == 11264");
TEST_ASSERT(model4.n_head == 64, "expected n_head == 64 (first element of per-layer array)");
TEST_ASSERT(model4.n_head_kv == 8, "expected n_head_kv == 8 (first element of per-layer array)");
TEST_ASSERT(model4.n_expert == 288, "expected n_expert == 288");
TEST_ASSERT(model4.n_embd_head_k == 128, "expected n_embd_head_k == 128");
TEST_ASSERT(model4.n_embd_head_v == 128, "expected n_embd_head_v == 128");
TEST_ASSERT(model4.n_vocab == 128896, "expected n_vocab == 128896");
TEST_ASSERT(model4.tensors.size() == 754, "expected tensor count == 754");
fprintf(stderr, "=== ALL TESTS PASSED ===\n");
return 0;
}

View file

@ -0,0 +1,520 @@
#include "../src/llama-ext.h"
#include "ggml-cpp.h"
#include "gguf-model-data.h"
#include "llama.h"
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
// ---------------------------------------------------------------------------
// ftype name <-> enum mapping
// ---------------------------------------------------------------------------
struct ftype_name_entry {
const char * name;
llama_ftype ftype;
};
static const ftype_name_entry ftype_name_table[] = {
{ "F32", LLAMA_FTYPE_ALL_F32 },
{ "F16", LLAMA_FTYPE_MOSTLY_F16 },
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16 },
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0 },
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1 },
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0 },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1 },
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0 },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S },
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S },
{ "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M },
{ "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L },
{ "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S },
{ "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M },
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S },
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M },
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS },
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS },
{ "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S },
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS },
{ "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS },
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S },
{ "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M },
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL },
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS },
{ "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0 },
{ "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0 },
{ "MXFP4_MOE", LLAMA_FTYPE_MOSTLY_MXFP4_MOE },
{ "NVFP4", LLAMA_FTYPE_MOSTLY_NVFP4 },
};
static llama_ftype llama_ftype_from_name(const char * name) {
for (const auto & e : ftype_name_table) {
if (strcmp(name, e.name) == 0) {
return e.ftype;
}
}
return (llama_ftype) -1;
}
static const char * llama_ftype_to_name(llama_ftype ftype) {
for (const auto & e : ftype_name_table) {
if (e.ftype == ftype) {
return e.name;
}
}
return nullptr;
}
// ---------------------------------------------------------------------------
// ggml_type name lookup
// ---------------------------------------------------------------------------
static ggml_type ggml_type_from_name(const std::string & name) {
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
const char * tname = ggml_type_name((ggml_type) i);
if (tname && name == tname) {
return (ggml_type) i;
}
}
return GGML_TYPE_COUNT;
}
// ---------------------------------------------------------------------------
// File parser for snapshot files (quant type schemas)
// ---------------------------------------------------------------------------
struct snapshot_section {
llama_ftype ftype;
ggml_type default_type;
std::vector<std::pair<std::string, ggml_type>> overrides;
};
// This function is pretty ugly, but it's a trade-off of readable snapshot files
// versus readable parsing code
static bool parse_snapshot_file(const std::string & path, std::vector<snapshot_section> & sections) {
std::ifstream f(path);
if (!f.good()) {
return false;
}
snapshot_section * cur = nullptr;
std::string line;
while (std::getline(f, line)) {
if (line.empty() || line[0] == '#') {
continue;
}
// section header: [FTYPE_NAME] default_type
if (line[0] == '[') {
auto close = line.find(']');
if (close == std::string::npos) {
fprintf(stderr, "parse error: missing ] in '%s'\n", line.c_str());
return false;
}
std::string ftype_str = line.substr(1, close - 1);
std::string default_str;
size_t pos = close + 1;
while (pos < line.size() && line[pos] == ' ') {
pos++;
}
default_str = line.substr(pos);
llama_ftype ftype = llama_ftype_from_name(ftype_str.c_str());
if ((int) ftype < 0) {
fprintf(stderr, "parse error: unknown ftype '%s'\n", ftype_str.c_str());
return false;
}
ggml_type dtype = ggml_type_from_name(default_str);
if (dtype == GGML_TYPE_COUNT) {
fprintf(stderr, "parse error: unknown default type '%s'\n", default_str.c_str());
return false;
}
sections.push_back({ ftype, dtype, {} });
cur = &sections.back();
continue;
}
if (!cur) {
fprintf(stderr, "parse error: tensor line before any section: '%s'\n", line.c_str());
return false;
}
auto sp = line.rfind(' ');
if (sp == std::string::npos) {
fprintf(stderr, "parse error: no space in tensor line: '%s'\n", line.c_str());
return false;
}
std::string tname = line.substr(0, sp);
std::string ttype = line.substr(sp + 1);
ggml_type gt = ggml_type_from_name(ttype);
if (gt == GGML_TYPE_COUNT) {
fprintf(stderr, "parse error: unknown type '%s' for tensor '%s'\n", ttype.c_str(), tname.c_str());
return false;
}
cur->overrides.push_back({ tname, gt });
}
return true;
}
// ---------------------------------------------------------------------------
// Remote model support using gguf-model-data.cpp
// ---------------------------------------------------------------------------
struct remote_model_spec {
const char * repo;
const char * quant;
};
// Get model name from repo: strip org prefix, strip -GGUF suffix,
// and strip anything up to and including first '_' (e.g. "deepseek-ai_DeepSeek-V3.1").
static std::string model_name_from_repo(const char * repo) {
std::string s(repo);
auto slash = s.find('/');
if (slash != std::string::npos) {
s = s.substr(slash + 1);
}
const std::string suffix = "-GGUF";
if (s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0) {
s = s.substr(0, s.size() - suffix.size());
}
auto underscore = s.find('_');
if (underscore != std::string::npos) {
s = s.substr(underscore + 1);
}
return s;
}
static std::string snapshot_file_from_name(const std::string & name) {
std::string lower = name;
for (auto & c : lower) {
c = std::tolower(c);
}
return lower;
}
static const remote_model_spec model_specs[] = {
{ "ggml-org/Qwen3-0.6B-GGUF", "Q8_0" },
{ "ggml-org/GLM-4.6V-GGUF", "Q8_0" },
{ "ggml-org/Step-3.5-Flash-GGUF", "Q4_K" },
{ "ggml-org/Qwen3-Coder-Next-GGUF", "Q8_0" },
{ "ggml-org/Qwen3-14B-GGUF", "Q8_0" },
{ "ggml-org/Nemotron-Nano-3-30B-A3B-GGUF", "Q8_0" },
{ "ggml-org/gpt-oss-120b-GGUF", "mxfp4" },
{ "ggml-org/gemma-3-4b-it-GGUF", "Q8_0" },
{ "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", "Q4_K_M" },
{ "bartowski/deepseek-ai_DeepSeek-V3.1-GGUF", "IQ1_M" },
{ "bartowski/Qwen_Qwen3.5-397B-A17B-GGUF", "IQ1_S" }, // TODO: swap with ggml-org if/when it's released
{ "bartowski/Qwen_Qwen3.5-27B-GGUF", "Q8_0" }, // TODO: swap with ggml-org if/when it's released
};
static const int n_model_specs = (int) (sizeof(model_specs) / sizeof(model_specs[0]));
static llama_model * build_mock_model_from_remote(const gguf_remote_model & remote) {
llama_quant_model_desc desc = {};
desc.architecture = remote.architecture.c_str();
desc.n_embd = remote.n_embd;
desc.n_ff = remote.n_ff;
desc.n_layer = remote.n_layer;
desc.n_head = remote.n_head;
desc.n_head_kv = remote.n_head_kv;
desc.n_expert = remote.n_expert;
desc.n_embd_head_k = remote.n_embd_head_k;
desc.n_embd_head_v = remote.n_embd_head_v;
return llama_quant_model_from_metadata(&desc);
}
// Single ggml context holding all quantizable tensors for a model.
struct mock_tensors {
ggml_context_ptr ctx;
std::vector<ggml_tensor *> tensors;
};
static mock_tensors build_mock_tensors(const quantize_state_impl * qs, const gguf_remote_model & remote) {
const size_t ctx_size = remote.tensors.size() * ggml_tensor_overhead();
struct ggml_init_params params = { ctx_size, nullptr, true };
ggml_context_ptr ctx(ggml_init(params));
std::vector<ggml_tensor *> result;
for (const auto & t : remote.tensors) {
ggml_tensor * gt = ggml_new_tensor_4d(ctx.get(), GGML_TYPE_F32, t.ne[0], t.ne[1], t.ne[2], t.ne[3]);
ggml_set_name(gt, t.name.c_str());
if (llama_quant_tensor_allows_quantization(qs, gt)) {
result.push_back(gt);
}
}
// sort by layer index then name, matching llama_model_loader::weight_name_comparer
std::sort(result.begin(), result.end(), [](const ggml_tensor * a, const ggml_tensor * b) {
int a_layer = -1, b_layer = -1;
sscanf(a->name, "blk.%d.", &a_layer);
sscanf(b->name, "blk.%d.", &b_layer);
if (a_layer != b_layer) {
return a_layer < b_layer;
}
return strcmp(a->name, b->name) < 0;
});
return { std::move(ctx), std::move(result) };
}
// ---------------------------------------------------------------------------
// Generate mode: regenerate all snapshot files
// Use this when either adding new models or modifying quants
// ---------------------------------------------------------------------------
static std::string generate_snapshot(const std::string & name,
const gguf_remote_model & remote,
quantize_state_impl * qs,
mock_tensors & mt) {
std::ostringstream out;
out << "# Model: " << name << "\n";
out << "# n_embd=" << remote.n_embd << ", n_ff=" << remote.n_ff << ", n_vocab=" << remote.n_vocab
<< ", n_layer=" << remote.n_layer << ", n_head=" << remote.n_head << ", n_head_kv=" << remote.n_head_kv;
if (remote.n_expert > 0) {
out << ", n_expert=" << remote.n_expert;
}
out << "\n";
for (int i = 0; i < LLAMA_FTYPE_GUESSED; i++) {
llama_ftype ft = (llama_ftype) i;
ggml_type default_type = llama_ftype_get_default_type(ft);
if (default_type == GGML_TYPE_COUNT) {
continue;
}
const char * fname = llama_ftype_to_name(ft);
if (!fname) {
continue;
}
std::vector<ggml_type> result_types(mt.tensors.size());
llama_quant_compute_types(qs, ft, mt.tensors.data(), result_types.data(), mt.tensors.size());
out << "\n[" << fname << "] " << ggml_type_name(default_type) << "\n";
for (size_t j = 0; j < mt.tensors.size(); j++) {
if (result_types[j] != default_type) {
out << ggml_get_name(mt.tensors[j]) << " " << ggml_type_name(result_types[j]) << "\n";
}
}
}
return out.str();
}
static int run_generate(const std::string & snapshot_dir) {
fprintf(stderr, "This will overwrite all snapshot files in:\n %s\n", snapshot_dir.c_str());
fprintf(stderr, "Continue? [y/N] ");
int ch = fgetc(stdin);
if (ch != 'y' && ch != 'Y') {
fprintf(stderr, "Aborted.\n");
return 1;
}
fprintf(stderr, "\n");
int n_written = 0;
for (int m = 0; m < n_model_specs; m++) {
const auto & spec = model_specs[m];
std::string name = model_name_from_repo(spec.repo);
fprintf(stderr, "Fetching model metadata for %s from %s...\n", name.c_str(), spec.repo);
auto result = gguf_fetch_model_meta(spec.repo, spec.quant);
if (!result.has_value()) {
fprintf(stderr, "ERROR: could not fetch model metadata for %s\n", name.c_str());
return 1;
}
const auto & remote = result.value();
llama_model * model = build_mock_model_from_remote(remote);
llama_model_quantize_params qparams = llama_model_quantize_default_params();
quantize_state_impl * qs = llama_quant_init(model, &qparams);
auto mt = build_mock_tensors(qs, remote);
std::string content = generate_snapshot(name, remote, qs, mt);
std::string path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
std::ofstream f(path);
if (!f.good()) {
fprintf(stderr, "ERROR: could not write %s\n", path.c_str());
llama_quant_free(qs);
llama_model_free(model);
return 1;
}
f << content;
n_written++;
fprintf(stderr, " wrote %s\n", path.c_str());
llama_quant_free(qs);
llama_model_free(model);
}
fprintf(stderr, "%d files written\n", n_written);
return 0;
}
// ---------------------------------------------------------------------------
// Test mode: compare against snapshot files
// ---------------------------------------------------------------------------
static bool run_test_section(quantize_state_impl * qs, mock_tensors & mt, const snapshot_section & section) {
// verify default_type matches what llama_ftype_get_default_type returns
ggml_type computed_default = llama_ftype_get_default_type(section.ftype);
if (computed_default != section.default_type) {
printf(" FAIL [%s] default type mismatch: file says %s, code says %s\n", llama_ftype_to_name(section.ftype),
ggml_type_name(section.default_type), ggml_type_name(computed_default));
return false;
}
std::vector<ggml_type> result_types(mt.tensors.size());
llama_quant_compute_types(qs, section.ftype, mt.tensors.data(), result_types.data(), mt.tensors.size());
std::map<std::string, ggml_type> override_map(section.overrides.begin(), section.overrides.end());
bool all_pass = true;
int n_override_found = 0;
for (size_t i = 0; i < mt.tensors.size(); i++) {
const char * name = ggml_get_name(mt.tensors[i]);
ggml_type got = result_types[i];
ggml_type expected = section.default_type;
auto it = override_map.find(name);
if (it != override_map.end()) {
expected = it->second;
n_override_found++;
}
if (got != expected) {
printf(" FAIL %-50s %-10s expected %s, got %s\n", name, llama_ftype_to_name(section.ftype),
ggml_type_name(expected), ggml_type_name(got));
all_pass = false;
}
}
if (n_override_found != (int) section.overrides.size()) {
printf(" FAIL [%s] override count mismatch: listed %d, matched %d\n", llama_ftype_to_name(section.ftype),
(int) section.overrides.size(), n_override_found);
all_pass = false;
}
return all_pass;
}
static int run_remote_tests(const std::string & snapshot_dir, const char * argv0) {
int total_pass = 0;
int total_fail = 0;
int total_skip = 0;
for (int m = 0; m < n_model_specs; m++) {
const auto & spec = model_specs[m];
std::string name = model_name_from_repo(spec.repo);
printf("=== %s ===\n", name.c_str());
auto result = gguf_fetch_model_meta(spec.repo, spec.quant, "", false);
if (!result.has_value()) {
printf(" SKIP (could not fetch model metadata)\n\n");
total_skip++;
continue;
}
const auto & remote = result.value();
llama_model * model = build_mock_model_from_remote(remote);
llama_model_quantize_params qparams = llama_model_quantize_default_params();
quantize_state_impl * qs = llama_quant_init(model, &qparams);
auto mt = build_mock_tensors(qs, remote);
std::string snapshot_path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
std::vector<snapshot_section> sections;
if (!parse_snapshot_file(snapshot_path, sections)) {
printf(" SKIP (could not read snapshot file: %s)\n\n", snapshot_path.c_str());
llama_quant_free(qs);
llama_model_free(model);
total_skip++;
continue;
}
int model_pass = 0;
int model_fail = 0;
for (const auto & section : sections) {
bool pass = run_test_section(qs, mt, section);
if (pass) {
model_pass++;
} else {
model_fail++;
}
}
printf(" %s %s: %d/%d ftype sections passed (%d tensors)\n", model_fail == 0 ? "PASS" : "FAIL", name.c_str(),
model_pass, model_pass + model_fail, (int) mt.tensors.size());
printf("\n");
if (model_fail == 0) {
total_pass++;
} else {
total_fail++;
}
llama_quant_free(qs);
llama_model_free(model);
}
printf("%d/%d models passed", total_pass, total_pass + total_fail);
if (total_skip > 0) {
printf(", %d skipped", total_skip);
}
printf("\n");
if (total_fail > 0) {
printf("\nIf these changes are intentional, regenerate snapshot files with:\n");
printf(" %s --generate\n", argv0);
}
return total_fail > 0 ? 1 : 0;
}
int main(int argc, char ** argv) {
std::string snapshot_dir = SNAPSHOT_DIR;
bool generate = false;
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "--generate") == 0) {
generate = true;
} else if (strcmp(argv[i], "--snapshot-dir") == 0 && i + 1 < argc) {
snapshot_dir = argv[++i];
}
}
if (generate) {
return run_generate(snapshot_dir);
}
// suppress llama log warnings during test (e.g. tensor type fallback messages)
llama_log_set([](enum ggml_log_level, const char *, void *) {}, nullptr);
return run_remote_tests(snapshot_dir, argv[0]);
}