metal : add FA specialization for HSK = 320, HSV = 256 (#20549)

This commit is contained in:
Georgi Gerganov 2026-03-14 23:15:47 +02:00 committed by GitHub
parent b4768955c4
commit b30a5fdf37
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 4 deletions

View file

@ -8576,11 +8576,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 576 }) {
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 576 }) {
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue;
if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
if (hsk == 320 && hsv != 256) continue; // MLA
for (bool mask : { true, false } ) {
for (bool sinks : { true, false } ) {
@ -8589,12 +8590,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (float logit_softcap : {0.0f, 10.0f}) {
if (hsk != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 1, 4 }) {
if (nh == 1 && hsk != 576) continue; // GLM 4.7 Flash
if (nh == 1 && hsk != 320 && hsk != 576) continue; // GLM 4.7 Flash
for (int nr3 : { 1, 3, }) {
if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
for (int nr2 : { 1, 4, 12, 20 }) {
for (int nr2 : { 1, 4, 12, 20, 32 }) {
if (nr2 == 12 && hsk != 128) continue;
if (nr2 == 20 && (nh != 1 || hsk != 576)) continue;
if (nr2 == 32 && (nh != 1 || hsk != 320)) continue;
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
for (int kv : { 113, 512, 1024, }) {
if (nr2 != 1 && kv != 512) continue;