* ggml-hexagon: add IQ4_NL and MXFP4 HMX matmul support
  - Add IQ4_NL quantization type support to the Hexagon backend (buffer set/get tensor repack, mul_mat, mul_mat_id dispatch).
  - Implement HVX IQ4_NL vec_dot kernels (1x1, 2x1, 2x2) with LUT-based dequantization of 4-bit indices to int8 kvalues.
  - Add an MXFP4 HMX dequantization path with E8M0 scale conversion, including a batch-4 fast path and a single-tile fallback.
  - Unify the quantized row-size / scale-offset logic so the DMA fetch path handles Q4_0, Q8_0, IQ4_NL, and MXFP4.
* ggml-hexagon: fix SKIP_QUANTIZE src1 address mismatch in mixed-quant models.
* Fix the pragma indentation.
1705 lines
79 KiB
C
1705 lines
79 KiB
C
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
|
||
#pragma clang diagnostic ignored "-Wunused-function"
|
||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||
|
||
#include <assert.h>
|
||
#include <stdbool.h>
|
||
#include <stddef.h>
|
||
#include <stdint.h>
|
||
#include <string.h>
|
||
|
||
#include <HAP_farf.h>
|
||
#include <HAP_compute_res.h>
|
||
|
||
#define GGML_COMMON_DECL_C
|
||
#include "ggml-common.h"
|
||
|
||
#include "hex-dma.h"
|
||
#include "hvx-utils.h"
|
||
#include "hvx-dump.h"
|
||
#include "worker-pool.h"
|
||
#include "htp-ctx.h"
|
||
#include "htp-msg.h"
|
||
|
||
#include "hmx-utils.h"
|
||
#include "hmx-ops.h"
|
||
#include "hmx-profile.h"
|
||
|
||
// Q4_0 dequantization LUT for Q6_Wh_vlut16: maps a 4-bit quant index i (0..15)
// to the fp16 value (i - 8). Entries are interleaved with fp16 zeros because
// vlut16 reads halfword table entries; only the first 32 halfwords are
// initialized, the remaining 32 are implicitly zero (unused by the lookup).
static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
    -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
};
|
||
|
||
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
|
||
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||
// MXFP4 dequantization LUT: maps a 4-bit index to its fp16 mantissa value.
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
// (sign bit is the MSB of the 4-bit code). Entries are interleaved with fp16
// zeros for vlut16 halfword lookups; the tail of the table stays zero.
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
    0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
};
|
||
|
||
// IQ4_NL dequantization LUT: maps a 4-bit index to the IQ4_NL non-linear
// kvalue (the standard ggml kvalues_iq4nl table), stored as fp16 so the
// block scale can be applied with a single fp16 multiply after vlut16.
// Interleaved with fp16 zeros for halfword lookups, same layout as the
// Q4_0/MXFP4 tables above.
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
    -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
    1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
};
|
||
|
||
// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile.
|
||
// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile.
|
||
// Column offset (n*4) is added at runtime. Only entries 0..15 are used (masked by predicate).
|
||
// vscatter offsets for fused dequant+transpose: write K-values directly to a [K][N] tile.
// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile.
// The column offset (n*4 bytes) is added at runtime. Only entries 0..15 are used
// (the scatter predicate masks to the first 64 bytes / 16 words); the upper half
// is zero padding to fill the HVX vector.
static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
    0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128,
    8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
|
||
|
||
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
|
||
#define HMX_X4X2_SCALES_PER_BLK 8
|
||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
|
||
#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
|
||
|
||
// Exchange the pointers stored at *p1 and *p2.
static inline void swap_ptr(void **p1, void **p2) {
    void *tmp = *p2;
    *p2 = *p1;
    *p1 = tmp;
}
|
||
|
||
// Per-task state for the DMA fetch of quantized weight rows from DDR into a
// VTCM sub-block. Each row carries a quant region and a scale region; the two
// (offset, width) pairs let the fetcher handle Q4_0 / Q8_0 / IQ4_NL / MXFP4
// row layouts uniformly.
typedef struct {
    uint8_t *dst;           // VTCM destination base
    const uint8_t *src;     // DDR source base
    dma_queue *dma;         // DMA queue used for the transfers
    size_t n_rows;          // number of weight rows to fetch
    size_t src_stride;      // DDR row stride (full row_stride)
    size_t dst_stride;      // VTCM sub-block row stride
    size_t quant_off;       // quant byte offset in each DDR row
    size_t quant_width;     // quant bytes to copy per row
    size_t scale_off;       // scale byte offset in each DDR row
    size_t scale_width;     // scale bytes to copy per row
} qweight_fetch_task_state_t;
|
||
|
||
// Compute the byte stride of one row in x4x2 format.
|
||
// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because
|
||
// x4x2 packing has the same density as block_q4_0 / block_q8_0.
|
||
// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes]
|
||
// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8).
|
||
// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req).
|
||
static inline size_t get_x4x2_row_stride(int weight_type, int k) {
|
||
int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
|
||
switch (weight_type) {
|
||
case HTP_TYPE_Q4_0:
|
||
case HTP_TYPE_IQ4_NL:
|
||
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||
case HTP_TYPE_Q8_0:
|
||
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||
case HTP_TYPE_MXFP4:
|
||
return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb
|
||
default:
|
||
return 0;
|
||
}
|
||
}
|
||
|
||
// --- Overflow-safe arithmetic for VTCM budget calculation ---
|
||
|
||
// Multiply a * b into *out. Returns true (leaving *out untouched) when the
// product would not fit in size_t, false on success.
static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) {
    const bool would_overflow = (a != 0) && (b > SIZE_MAX / a);
    if (!would_overflow) {
        *out = a * b;
    }
    return would_overflow;
}
|
||
|
||
// Add a + b into *out. Returns true (leaving *out untouched) when the sum
// would wrap past SIZE_MAX, false on success.
static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) {
    if (b > SIZE_MAX - a) {
        return true;
    }
    *out = a + b;
    return false;
}
|
||
|
||
// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget.
|
||
//
|
||
// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead
|
||
// per_n_cost: bytes per nc column (weight + scratch buffers)
|
||
// per_m_cost: bytes per mc row (activation)
|
||
// per_mn_cost: bytes per mc*nc element (output)
|
||
// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.)
|
||
//
|
||
// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max.
|
||
// Returns 0 on success, -1 if VTCM is insufficient.
|
||
// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within the VTCM budget.
//
// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead
//   per_n_cost:  bytes per nc column (weight + scratch buffers)
//   per_m_cost:  bytes per mc row (activation)
//   per_mn_cost: bytes per mc*nc element (output)
//   overhead:    fixed bytes (scales 256B, eye_tile 2048B, etc.)
//
// Algorithm: nc sweeps from n_max down in tile-width steps; for each nc the
// largest feasible mc is solved analytically from the linear cost model, then
// rounded down to a tile-row multiple. All size_t arithmetic is guarded with
// hmx_mul_overflow / hmx_add_overflow.
//
// Returns 0 on success (writing *m_chunk_out, *n_chunk_out, *total_out),
// -1 if the VTCM budget is insufficient or the inputs are degenerate.
static int hmx_compute_chunks(
    size_t vtcm_total, size_t overhead,
    size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost,
    int m, int n,
    size_t *m_chunk_out, size_t *n_chunk_out,
    size_t *total_out)
{
    // Degenerate inputs: nothing to tile, no budget, or a zero cost that would
    // make the analytic solve meaningless.
    if (m <= 0 || n <= 0) return -1;
    if (vtcm_total <= overhead) return -1;
    if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;

    const size_t usable = vtcm_total - overhead;
    size_t best_mn = 0, best_m = 0, best_n = 0;

    const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS);
    for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) {
        // Early exit: if nc * m_max cannot beat best, smaller nc won't either
        if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn)
            break;

        size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
        // NOTE(review): this path uses `continue` (skipping the underflow guard
        // at next_nc) while the others use `goto next_nc`. Harmless today —
        // when nc == HMX_FP16_TILE_N_COLS the decrement makes nc == 0 and the
        // loop condition terminates without wrapping — but inconsistent.
        if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
        if (n_fixed >= usable) goto next_nc;

        if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
        if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;

        {
            // Largest mc satisfying: mc * (per_m_cost + nc*per_mn_cost) <= usable - n_fixed
            size_t remain = usable - n_fixed;
            size_t mc = remain / mc_denom;
            mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS);
            mc = hex_smin(mc, (size_t)m);

            if (mc > 0 && mc * nc > best_mn) {
                best_mn = mc * nc;
                best_m = mc;
                best_n = nc;
            }
        }

    next_nc:
        if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow
    }

    if (best_m == 0 || best_n == 0) return -1;

    // Compute exact total (with overflow checks)
    size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
    if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1;
    if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1;
    if (hmx_mul_overflow(best_m, best_n, &mn)) return -1;
    if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1;
    if (hmx_add_overflow(t0, t1, &total)) return -1;
    if (hmx_add_overflow(total, t2, &total)) return -1;
    if (hmx_add_overflow(total, overhead, &total)) return -1;

    *m_chunk_out = best_m;
    *n_chunk_out = best_n;
    *total_out = total;
    return 0;
}
|
||
|
||
// forward declaration – defined after transfer_activation_chunk_fp32_to_fp16
|
||
void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride);
|
||
|
||
// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles.
|
||
// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer
|
||
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
|
||
// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles.
// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer
// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16
//
// Two N-rows are processed per iteration so that one scatter-offset vector pair
// (v_off0/v_off1, one column apart) can be reused across all K-tiles of those rows.
// The scatter predicate q_mask64 limits writes to the first 64 bytes (16 words)
// of each offset vector — matching the 16 initialized entries of
// weight_transpose_scatter_offsets.
static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst,
                                                  const __fp16 *restrict vtcm_src,
                                                  int n_cols, int k) {
    assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
    assert(k % HMX_FP16_TILE_N_COLS == 0);

    const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
    const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets);
    const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);   // 4 bytes = one fp16 column pair step
    const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);  // enable only the first 16 words

    for (int r = 0; r < n_cols; r += 2) {
        int ct = r / HMX_FP16_TILE_N_ROWS;       // N-dimension tile index
        int local_r = r % HMX_FP16_TILE_N_ROWS;  // intra-tile row index
        const bool next_row_valid = (r + 1) < n_cols;

        // Offset vectors for N-columns local_r and local_r+1, reused across K-tiles.
        HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
        HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);

        for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) {
            int kt = c / HMX_FP16_TILE_N_COLS;
            int tile_idx = ct * n_k_tiles + kt;
            __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS;

            // Unaligned loads: vtcm_src rows need not be vector-aligned.
            HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c);
            // Out-of-range second row is replaced with zeros (tail padding).
            HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero();

            Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0);
            Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1);
        }
    }
}
|
||
|
||
// --- x4x2 format dequantizers ---
|
||
|
||
// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes.
|
||
// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
|
||
// of the same 32 packed bytes.
|
||
// Dequantize one x4x2 Q4_0/IQ4_NL group (32 elements from 32 packed bytes) -> 32 FP16
// in the first 64 bytes of the returned vector.
// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
// of the same 32 packed bytes. `vlut_cvt` selects the quant type (q4_0 vs iq4_nl LUT).
static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
    const uint8_t *packed_32, bool upper_nibbles,
    const __fp16 *scale, const HVX_Vector vlut_cvt) {
    HVX_Vector vq = hvx_vmemu(packed_32);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    // Broadcast the group's single fp16 scale across the vector.
    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
    // q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
    HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);
    // Shuffle before LUT
    v_quants = Q6_Vb_vshuff_Vb(v_quants);
    // Use standard vlut16 (not _nomatch) to avoid stale-register NaN.
    // _nomatch retains the previous destination-register value for colliding
    // indices, but the C intrinsic doesn't model the implicit read so the
    // compiler may allocate a register containing garbage/NaN.
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_hf = Q6_V_lo_W(vp);

    // Apply the block scale in qf16 and convert back to IEEE fp16.
    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
|
||
|
||
// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using
|
||
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
|
||
// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes.
|
||
// Batch-dequantize 4 contiguous x4x2 Q4_0/IQ4_NL groups (4x32 = 128 packed bytes) using
// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls.
// scales_4: the 4 consecutive fp16 scales for groups 0..3.
// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes (callers
// scatter with a 64-byte predicate, so the upper halves are don't-care).
static inline void dequantize_x4x2_q4_0_x4groups_hvx(
    const uint8_t *packed_128, bool upper_nibbles,
    const __fp16 *scales_4, const HVX_Vector vlut_cvt,
    HVX_Vector out[4]) {
    // Load all 128 packed bytes (4 contiguous 32-byte groups)
    HVX_Vector vq = hvx_vmemu(packed_128);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);

    // Shuffle before LUT
    v_quants = Q6_Vb_vshuff_Vb(v_quants);

    // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16]
    HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]

    // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
    HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1]));
    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3]));

    v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
    v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));

    // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
    out[0] = v_lo;                    // group0 already in [0:63]
    out[1] = Q6_V_vror_VR(v_lo, 64);  // group1 rotated to [0:63]
    out[2] = v_hi;                    // group2 already in [0:63]
    out[3] = Q6_V_vror_VR(v_hi, 64);  // group3 rotated to [0:63]
}
|
||
|
||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in the first 64 bytes.
// int8 -> int16 via vunpack, then int16 -> fp16 conversion, scaled by the
// group's single fp16 scale.
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
    const int8_t *quants_32, const __fp16 *scale) {
    HVX_Vector vq = hvx_vmemu(quants_32);
    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
    // Sign-extend the low 32 int8 lanes to int16.
    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
}
|
||
|
||
// --- MXFP4 E8M0 scale conversion and dequantization ---
|
||
//
|
||
// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
|
||
// Scalar loads from the stack array execute on the scalar pipeline, in parallel
|
||
// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
|
||
// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
|
||
// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
|
||
|
||
// 8 fp16 scales for one x4x2 MXFP4 block, converted from E8M0 bytes.
// Kept in a small aligned stack struct so scalar loads of individual scales
// run on the scalar pipeline in parallel with HVX work (see the note above
// mxfp4_convert_scales).
typedef struct {
    __fp16 v[8] __attribute__((aligned(16)));
} mxfp4_scales_t;
|
||
|
||
// HVX batch-convert 8 E8M0 scale bytes (one x4x2 block) to __fp16[8].
// Arithmetic (builds the fp16 bit pattern directly):
//   fp16_bits = clamp(e - 112, 0, 30) << 10
//   e=0..112   -> 0 (underflow to zero)
//   e=113..142 -> valid normal fp16 power of two
//   e>=143     -> clamped to 2^15 (largest representable fp16 exponent here)
// NOTE(review): this clamps rather than producing inf/NaN for large exponents —
// presumably acceptable for weight scales; confirm against the MX spec use.
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
    mxfp4_scales_t s;
    HVX_Vector v = hvx_vmemu(e8m0_8);
    // Widen the 8 scale bytes to uint16 lanes.
    HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
    // Rebias: E8M0 bias 127 -> fp16 exponent bias 15 (127 - 15 = 112).
    vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
    vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
    vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
    // Shift into the fp16 exponent field (mantissa = 0 -> exact power of two).
    vh = Q6_Vh_vasl_VhR(vh, 10);
    hvx_vec_store_u(s.v, 16, vh);
    return s;
}
|
||
|
||
// Broadcast one of the 8 converted MXFP4 scales (idx 0..7) across a full
// HVX vector of fp16 lanes. The struct is passed by value so the scalar
// load comes from the caller's stack copy.
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
    return hvx_vec_splat_f16(scales.v[idx]);
}
|
||
|
||
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
|
||
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16
// in the first 64 bytes. Same nibble layout as Q4_0 x4x2 (sub-blocks 0..3 lower
// nibbles, 4..7 upper nibbles); `sub_blk` selects which of the 8 pre-converted
// E8M0 scales applies to this group.
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
                                                         bool upper_nibbles,
                                                         int sub_blk,
                                                         const HVX_Vector vlut_cvt,
                                                         mxfp4_scales_t scales) {
    HVX_Vector vq = hvx_vmemu(packed_32);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);

    HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);

    // Shuffle before LUT (same vlut16 pipeline as the Q4_0 path).
    v_quants = Q6_Vb_vshuff_Vb(v_quants);
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_hf = Q6_V_lo_W(vp);

    return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
}
|
||
|
||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
|
||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
// Mirrors dequantize_x4x2_q4_0_x4groups_hvx but takes the pre-converted
// E8M0 scales struct; sub_blk_base (0 or 4) selects which 4 of the 8 scales
// apply. out[0..3] each hold 32 FP16 values in the first 64 bytes.
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
                                                      bool upper_nibbles,
                                                      int sub_blk_base,
                                                      const HVX_Vector vlut_cvt,
                                                      mxfp4_scales_t scales,
                                                      HVX_Vector out[4]) {
    // One full-width load covers all 4 groups.
    HVX_Vector vq = hvx_vmemu(packed_128);
    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
    HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
    v_quants = Q6_V_vand_VV(v_quants, mask_h4);

    // Shuffle before LUT.
    v_quants = Q6_Vb_vshuff_Vb(v_quants);

    // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair.
    HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
    HVX_Vector v_lo = Q6_V_lo_W(vp);
    HVX_Vector v_hi = Q6_V_hi_W(vp);

    // Per-group scale vectors: first 64 bytes use one scale, last 64 the next.
    HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
                                           mxfp4_extract_splat(scales, sub_blk_base + 1));
    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
                                           mxfp4_extract_splat(scales, sub_blk_base + 3));

    v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
    v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));

    // Rotate groups 1 and 3 down so each output's payload sits in bytes [0:63].
    out[0] = v_lo;
    out[1] = Q6_V_vror_VR(v_lo, 64);
    out[2] = v_hi;
    out[3] = Q6_V_vror_VR(v_hi, 64);
}
|
||
|
||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
|
||
// Output: vtcm_dst in tile-major FP16 layout.
|
||
// Dequantize a tile range [start_tile, end_tile) from x4x2 weight data (already
// in VTCM) to tile-major FP16, fusing dequantization with the [N][K] -> [K][N]
// transpose via HVX vscatter.
//
// vtcm_src: n_cols rows of x4x2 data, each row_stride bytes
//           (per-row layout: [quants: qrow_size bytes][scales]).
// vtcm_dst: tile-major FP16 output, HMX_FP16_TILE_N_ELMS elements per tile.
// Tiles are numbered t = ct * n_k_tiles + kt (column tile major).
//
// Q4_0/IQ4_NL and MXFP4 have batch-4 fast paths (4 K-tiles per vlut16 when
// 4 contiguous K-tiles fall in the same column tile); otherwise a single-tile
// fallback handles Q4_0/IQ4_NL, MXFP4, and Q8_0.
static void dequantize_x4x2_weight_to_fp16_tiles_task(
    __fp16 *restrict vtcm_dst,
    const uint8_t *restrict vtcm_src,
    int n_cols, int k_block,
    size_t row_stride, int weight_type,
    int start_tile, int end_tile) {

    const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
    // Quant-region size per row: Q8_0 is 1 byte/elem, the 4-bit types are 1/2.
    const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);

    // Select the index->fp16 lookup table for the weight type.
    const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
                                (weight_type == HTP_TYPE_MXFP4)  ? hvx_vmem(mxfp4_to_fp16_lut)  :
                                                                   hvx_vmem(q4_0_to_fp16_lut);

    // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
    // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
    // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row.
    const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets);
    const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);   // 4 bytes = 1 column step
    const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);  // first 16 words (64 bytes)

    for (int t = start_tile; t < end_tile; ) {
        int ct = t / n_k_tiles; // column tile index
        int kt = t % n_k_tiles; // K tile index

        // --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
        if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
            ((t + 3) / n_k_tiles == ct)) {
            int blk_idx = (kt * 32) / QK_Q4_0x4x2;
            int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
            bool upper = (sub_blk_base >= 4);
            int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
            int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
                          + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales

            __fp16 *tile_bases[4];
            for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }

            HVX_Vector v_off = v_scat_base;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;
                // NOTE(review): r1 is formed before the row1 < n_cols check; it is
                // only dereferenced when valid, but the pointer arithmetic itself
                // can go past the buffer — harmless in practice, technically UB.
                const uint8_t *r0 = vtcm_src + row0 * row_stride;
                const uint8_t *r1 = vtcm_src + row1 * row_stride;

                HVX_Vector v0[4], v1[4];
                dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
                if (row1 < n_cols) {
                    dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1);
                } else {
                    // Zero-pad the tail row of an odd-sized chunk.
                    v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
                }

                for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); }
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); }
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }

            // Drain pending scatters for these tiles (see comment at function end).
            for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }

            t += 4;
            continue;
        }

        // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
        if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
            int blk_idx = (kt * 32) / QK_MXFP4x4x2;
            int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
            bool upper = (sub_blk_base >= 4);
            int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
            int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales

            __fp16 * tile_bases[4];
            for (int g = 0; g < 4; g++) {
                tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
            }

            HVX_Vector v_off = v_scat_base;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;
                const uint8_t * r0 = vtcm_src + row0 * row_stride;
                const uint8_t * r1 = vtcm_src + row1 * row_stride;

                // Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
                mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);

                HVX_Vector v0[4], v1[4];
                dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
                if (row1 < n_cols) {
                    mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
                    dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
                } else {
                    v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
                }

                for (int g = 0; g < 4; g++) {
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
                }
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                for (int g = 0; g < 4; g++) {
                    Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
                }
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }

            // Drain pending scatters for these tiles.
            for (int g = 0; g < 4; g++) {
                (void) *(volatile HVX_Vector *) (tile_bases[g]);
            }

            t += 4;
            continue;
        }

        // --- Single-tile fallback ---
        __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;

        if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
            int blk_idx = (kt * 32) / QK_Q4_0x4x2;
            int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
            bool upper = (sub_blk >= 4);
            // Sub-blocks 4..7 share bytes with 0..3 (upper nibbles), hence the fold.
            int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
            int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);

            HVX_Vector v_off = v_scat_base; // reset to column 0
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;

                const uint8_t *r0 = vtcm_src + row0 * row_stride;
                const uint8_t *r1 = vtcm_src + row1 * row_stride;

                HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
                    r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
                HVX_Vector v1 = (row1 < n_cols)
                    ? dequantize_x4x2_q4_0_group_hvx(
                          r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
                    : Q6_V_vzero();

                Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            (void) *(volatile HVX_Vector *)(tile_base);
        } else if (weight_type == HTP_TYPE_MXFP4) {
            int blk_idx = (kt * 32) / QK_MXFP4x4x2;
            int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
            bool upper = (sub_blk >= 4);
            int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
            int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;

            HVX_Vector v_off = v_scat_base;
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;

                const uint8_t * r0 = vtcm_src + row0 * row_stride;
                const uint8_t * r1 = vtcm_src + row1 * row_stride;

                // Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
                mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);

                HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
                HVX_Vector v1;
                if (row1 < n_cols) {
                    mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
                    v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
                } else {
                    v1 = Q6_V_vzero();
                }

                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            (void) *(volatile HVX_Vector *) (tile_base);
        } else {
            // Q8_0 (and default): one byte per element, no nibble split.
            int blk_idx = (kt * 32) / QK_Q8_0x4x2;
            int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32;
            int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32;
            int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);

            HVX_Vector v_off = v_scat_base; // reset to column 0
            for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
                int row0 = ct * HMX_FP16_TILE_N_COLS + r;
                int row1 = row0 + 1;

                const uint8_t *r0 = vtcm_src + row0 * row_stride;
                const uint8_t *r1 = vtcm_src + row1 * row_stride;

                HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx(
                    (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
                HVX_Vector v1 = (row1 < n_cols)
                    ? dequantize_x4x2_q8_0_group_hvx(
                          (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off))
                    : Q6_V_vzero();

                Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
                Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
                v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
            }
            (void) *(volatile HVX_Vector *)(tile_base);
        }
        ++t;
    }

    // Drain HVX scatter write buffer: a vmem load on the same HW thread retires
    // all pending scatter entries to VTCM. Without this, the main thread's HMX
    // reads may see stale data because atomic_fetch_sub (release) only orders
    // regular stores, not the HVX scatter buffer.
    if (start_tile < end_tile) {
        (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS);
    }
}
|
||
|
||
// Shared state for the multi-threaded x4x2 dequantize: the worker loop strides
// over `n_tasks` fixed-size tile ranges, each `n_tiles_per_task` tiles of the
// `n_tot_tiles` total.
typedef struct {
    __fp16 *dst;            // tile-major fp16 output in VTCM
    const uint8_t *src;     // x4x2 source rows in VTCM
    int n_cols;             // number of weight rows (N dimension)
    int k_block;            // K extent of this chunk
    size_t row_stride;      // bytes per x4x2 source row
    int weight_type;        // HTP_TYPE_* of the packed weights
    int n_tot_tiles;        // total output tiles
    int n_tiles_per_task;   // tiles handled by one task
    int n_tasks;            // ceil(n_tot_tiles / n_tiles_per_task)
} x4x2_dequantize_state_t;
|
||
|
||
static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||
|
||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||
int start = task_id * state->n_tiles_per_task;
|
||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||
|
||
dequantize_x4x2_weight_to_fp16_tiles_task(
|
||
state->dst, state->src, state->n_cols, state->k_block,
|
||
state->row_stride, state->weight_type, start, end);
|
||
}
|
||
}
|
||
|
||
// Dequantize one x4x2 weight chunk (resident in VTCM) into tile-major FP16,
// splitting the tile range across the worker pool.
static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
    struct htp_context *ctx, __fp16 *vtcm_dst,
    const void *vtcm_src, int n_cols, int k_block,
    size_t row_stride, int weight_type) {

    assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
    assert(k_block % HMX_FP16_TILE_N_COLS == 0);

    const int n_tot_tiles = (n_cols / HMX_FP16_TILE_N_COLS) * (k_block / HMX_FP16_TILE_N_COLS);
    const size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads);

    x4x2_dequantize_state_t state = {
        .dst = vtcm_dst,
        .src = (const uint8_t *) vtcm_src,
        .n_cols = n_cols,
        .k_block = k_block,
        .row_stride = row_stride,
        .weight_type = weight_type,
        .n_tot_tiles = n_tot_tiles,
        .n_tiles_per_task = n_tiles_per_task,
        .n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task,
    };

    worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads);
}
|
||
|
||
// --- End x4x2 dequantizers ---
|
||
|
||
// requires external HMX lock
|
||
static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales,
|
||
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
|
||
hmx_set_output_scales(scales);
|
||
|
||
for (int r = 0; r < n_row_tiles; ++r) {
|
||
for (int c = 0; c < n_col_tiles; ++c) {
|
||
Q6_mxclracc_hf();
|
||
|
||
const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||
const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
|
||
|
||
for (int k = 0; k < n_dot_tiles; ++k) {
|
||
int offset = k * HMX_FP16_TILE_N_ELMS;
|
||
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
|
||
}
|
||
|
||
__fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
|
||
hmx_consume_accumulator_fp16(out_tile);
|
||
}
|
||
}
|
||
}
|
||
|
||
// Convert a tile-major FP16 output chunk in VTCM to row-major FP32 in DDR.
// dst:      row-major float output, n floats per row (full DDR row stride)
// vtcm_src: tile-major fp16, n_col_tiles tiles per row of tiles
// n_rows:   rows to transfer; n_cols: columns in this chunk (tile-width multiple)
// Two output rows are produced per iteration because Q6_Wqf32_vmpy_VhfVhf widens
// one fp16 vector (holding two interleaved tile rows) into a qf32 pair.
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) {
    assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
    const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS;

    // fp16 -> fp32 via multiply-by-one (widening multiply is the cheapest convert).
    const HVX_Vector one = hvx_vec_splat_f16(1.0);

    for (int r = 0; r < n_rows; r += 2) {
        int r0 = r / HMX_FP16_TILE_N_ROWS;  // tile-row index
        int r1 = r % HMX_FP16_TILE_N_ROWS;  // row pair within the tile

        #pragma unroll(4)
        for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) {
            int c0 = c / HMX_FP16_TILE_N_COLS;

            const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS;

            // One aligned vector holds rows r and r+1 of this tile, interleaved.
            HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2];
            HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);

            // NOTE(review): aligned vector stores — assumes dst + r*n + c is
            // 128-byte aligned (i.e. n and the chunk base keep vector alignment);
            // confirm against the callers' layout.
            volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0));
            volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory

            *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
            if (r + 1 < n_rows) {
                *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
            }
        }
    }
}
|
||
|
||
// Shared state for the threaded fp16 -> fp32 output transfer. A "chunk" here is
// one output row; workers stride over groups of n_chunks_per_task rows.
typedef struct {
    const __fp16 *vtcm_src;  // tile-major fp16 source in VTCM
    float *dst;              // row-major fp32 destination in DDR
    int n_tasks;             // number of row groups
    int n_tot_chunks;        // total rows to transfer
    int n_chunks_per_task;   // rows per task (tile-row multiple)
    int n_cols;              // columns in this chunk
    int n;                   // DDR row stride (total output columns)
} output_transfer_task_state_t;
|
||
|
||
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||
output_transfer_task_state_t *st = (output_transfer_task_state_t *) data;
|
||
|
||
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
|
||
int chunk_idx = task_id * st->n_chunks_per_task;
|
||
size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task);
|
||
|
||
float *dst = st->dst + chunk_idx * st->n;
|
||
const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols;
|
||
transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n);
|
||
}
|
||
}
|
||
|
||
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
|
||
int n_rows, int n_cols, int n) {
|
||
assert(n_cols % HMX_FP16_TILE_N_COLS == 0);
|
||
|
||
size_t n_tot_chunks = n_rows;
|
||
size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32)
|
||
|
||
output_transfer_task_state_t state;
|
||
state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task;
|
||
state.n_tot_chunks = n_tot_chunks;
|
||
state.n_chunks_per_task = n_chunks_per_task;
|
||
state.dst = dst;
|
||
state.vtcm_src = vtcm_src;
|
||
state.n_cols = n_cols;
|
||
state.n = n;
|
||
|
||
worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads);
|
||
}
|
||
|
||
static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) {
|
||
return params->ne02 > 0 ? params->ne12 / params->ne02 : 1;
|
||
}
|
||
|
||
static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) {
|
||
return params->ne03 > 0 ? params->ne13 / params->ne03 : 1;
|
||
}
|
||
|
||
static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||
int dst_b2, int dst_b3) {
|
||
const int r2 = hmx_matmul_batch_r2(params);
|
||
const int r3 = hmx_matmul_batch_r3(params);
|
||
return (const __fp16 *) ((const uint8_t *) params->permuted_weight +
|
||
(size_t) (dst_b2 / r2) * params->src0_nb2 +
|
||
(size_t) (dst_b3 / r3) * params->src0_nb3);
|
||
}
|
||
|
||
static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||
int dst_b2, int dst_b3) {
|
||
return (const float *) ((const uint8_t *) params->activation +
|
||
(size_t) dst_b2 * params->src1_nb2 +
|
||
(size_t) dst_b3 * params->src1_nb3);
|
||
}
|
||
|
||
static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params,
|
||
int dst_b2, int dst_b3) {
|
||
return (float *) ((uint8_t *) params->dst +
|
||
(size_t) dst_b2 * params->dst_nb2 +
|
||
(size_t) dst_b3 * params->dst_nb3);
|
||
}
|
||
|
||
static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx,
|
||
const hmx_matmul_w16a32_batched_params_t *params) {
|
||
int ret = 0;
|
||
for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) {
|
||
for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) {
|
||
ret = hmx_mat_mul_permuted_w16a32(ctx,
|
||
hmx_matmul_dst_batch_ptr(params, b2, b3),
|
||
hmx_matmul_activation_batch_ptr(params, b2, b3),
|
||
hmx_matmul_weight_batch_ptr(params, b2, b3),
|
||
params->m, params->k, params->n,
|
||
params->act_stride, params->weight_stride);
|
||
}
|
||
}
|
||
return ret;
|
||
}
|
||
|
||
/**
 * Batched (4D) HMX matmul with GQA-style weight reuse along dim2.
 *
 * For each (b2, b3) batch, dst = activation x permuted_weight, where
 * ne12/ne02 activation heads (group_size) share one weight slice. When there
 * is dim2 reuse and the grouped VTCM layout fits, the weight chunk is
 * DMA'd/interleaved once and reused for every head in the group; otherwise
 * the plain per-batch legacy loop is used.
 *
 * Returns 0 on success, -1 on invalid arguments/alignment, or the fallback
 * path's error code.
 */
int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) {
    // Argument validation: non-null pointers, non-zero dims, strides large
    // enough, consistent broadcast ratios, tile-divisible k and n.
    if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; }
    if (!params->m || !params->k || !params->n) { return -1; }
    if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; }
    if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; }
    if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; }
    if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; }

    if (!hex_is_aligned(params->dst, VLEN) ||
        !hex_is_aligned(params->activation, VLEN) ||
        !hex_is_aligned(params->permuted_weight, VLEN)) {
        return -1;
    }

    const int group_size = hmx_matmul_batch_r2(params);

    if (group_size <= 1) {
        FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size);
        return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
    }

    // Grouped path: reuse interleaved weight across all q_heads sharing a
    // kv_head. Each q_head gets its own activation buffer in VTCM (so
    // activation is loaded once per m_chunk and reused across all n_chunks),
    // and each q_head is computed individually to avoid tile-major packing
    // issues. m_chunk_n_rows is always a multiple of 32 (from
    // hmx_compute_chunks), so per-head tile arrays don't overlap.
    const size_t vtcm_budget = ctx->vtcm_scratch_size;
    const size_t vec_dot_size = params->k * sizeof(__fp16); // one fp16 row of length k

    // When the activation has a large stride (e.g. permuted Q tensor with
    // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache.
    // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather
    // strided rows into a contiguous block before the F32->F16 conversion.
    const bool use_dma_activation = (params->act_stride > params->k);
    const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0;

    // Size m/n chunks so all working buffers fit the VTCM budget.
    size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
    if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
                           /*per_n=*/3 * vec_dot_size,
                           /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m,
                           /*per_mn=*/sizeof(__fp16),
                           params->m, params->n,
                           &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
        FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__);
        return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
    }

    const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads
    const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
    const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t f32_scratch_size = use_dma_activation
        ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0;

    // VTCM layout: weight | per-head activation | output | DMA scratch x2 |
    // column scales | optional f32 gather scratch.
    uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
    __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
    __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
    __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
    void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
    void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
    __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
    float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL;

    if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) {
        FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__);
        return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params);
    }

    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0

    FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu",
         __func__, params->m, params->k, params->n, group_size, params->ne13,
         m_chunk_n_rows, n_chunk_n_cols,
         (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);

    TIMER_DEFINE(activation_load);
    TIMER_DEFINE(weight_load);
    TIMER_DEFINE(hmx_core);
    TIMER_DEFINE(output_store);
    TIMER_DEFINE(total);

    TIMER_START(total);

    const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
    const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);

    for (int b3 = 0; b3 < params->ne13; ++b3) {
        for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) {
            // All heads [b2_base, b2_base + group_size) share this weight slice.
            const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3);

            for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) {
                const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows);

                // Pre-load activations for all heads in the group (once per m_chunk).
                // When the source is strided (permuted Q), use 2D DMA to gather
                // contiguous rows into a VTCM scratch buffer first, then HVX
                // converts from the contiguous VTCM buffer. This avoids L2 cache
                // thrashing from HVX loads at large strides.
                TIMER_START(activation_load);
                for (int g = 0; g < group_size; ++g) {
                    const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
                    __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
                    if (use_dma_activation) {
                        const size_t row_bytes = (size_t) params->k * sizeof(float);
                        const size_t stride_bytes = (size_t) params->act_stride * sizeof(float);
                        dma_queue_push(ctx->dma[0],
                                       dma_make_ptr(vtcm_f32_act, activation_chunk),
                                       row_bytes, stride_bytes, row_bytes, n_rows);
                        dma_queue_pop(ctx->dma[0]); // synchronous: wait for the gather
                        transfer_activation_chunk_threaded(ctx, vtcm_act_g,
                                                           vtcm_f32_act, (int) n_rows,
                                                           params->k, params->k);
                    } else {
                        transfer_activation_chunk_threaded(ctx, vtcm_act_g,
                                                           activation_chunk, (int) n_rows,
                                                           params->k, params->act_stride);
                    }
                }
                TIMER_STOP(activation_load);

                // Double-buffered weight DMA: buf_curr is being consumed while
                // buf_next is being filled.
                void *buf_curr = vtcm_scratch0;
                void *buf_next = vtcm_scratch1;

                // kick off the first weight chunk before entering the n loop
                {
                    const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols);
                    dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group),
                                   fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
                }

                HAP_compute_res_hmx_lock(ctx->vtcm_rctx);

                for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) {
                    const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);

                    TIMER_START(weight_load);
                    {
                        dma_queue_pop(ctx->dma[0]); // wait for the current weight chunk

                        // prefetch the next weight chunk while this one is interleaved
                        const size_t nc_next = nc + n_chunk_n_cols;
                        if (nc_next < (size_t) params->n) {
                            const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols);
                            const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride;

                            dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
                                           fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
                        }

                        interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k);
                        swap_ptr(&buf_curr, &buf_next);
                    }
                    TIMER_STOP(weight_load);

                    // Reuse the interleaved weight for every q_head in this GQA group
                    for (int g = 0; g < group_size; ++g) {
                        TIMER_START(hmx_core);
                        {
                            const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
                            const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS);
                            const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
                            core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales,
                                                n_row_tiles, n_col_tiles, params->k / 32);
                        }
                        TIMER_STOP(hmx_core);

                        TIMER_START(output_store);
                        {
                            // copy this head's [n_rows x n_cols] result back to its dst slice
                            float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
                            transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride);
                        }
                        TIMER_STOP(output_store);
                    }
                }

                HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
            }
        }
    }

    TIMER_STOP(total);

#if defined(ENABLE_PROFILE_TIMERS)
    FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
         params->m, params->k, params->n, group_size);
    FARF(HIGH, "  activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
         TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
#endif

    return 0;
}
|
||
|
||
/**
 * Single-batch HMX matmul: dst[m][n] (f32) = activation[m][k] (f32, row
 * stride act_stride) x permuted weight (one k-length fp16 row per output
 * column, row stride weight_stride).
 *
 * Work is tiled into m_chunk x n_chunk pieces sized to the VTCM budget.
 * Weight chunks are double-buffered via 2D DMA and interleaved into HMX
 * tile layout; the fp16 result tiles are converted back to f32 in DDR.
 *
 * Requires k % 32 == 0, n % 32 == 0 and VLEN-aligned pointers.
 * Returns 0 on success, -1 on invalid arguments or VTCM exhaustion.
 */
int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
                                const __fp16 *restrict permuted_weight, int m, int k, int n,
                                int act_stride, int weight_stride) {
    if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
    if (act_stride < k || weight_stride < k) { return -1; }
    if (k % 32 != 0 || n % 32 != 0) { return -1; }

    if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) {
        return -1;
    }

    // --- Dynamic VTCM layout ---
    const size_t vtcm_budget = ctx->vtcm_scratch_size;
    const size_t vec_dot_size = k * sizeof(__fp16); // one fp16 row of length k

    // DMA-based activation gather for strided tensors (see batched path comment).
    const bool use_dma_activation = (act_stride > k);
    const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0;

    size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
    if (hmx_compute_chunks(vtcm_budget,
                           /*overhead=*/ 256,
                           /*per_n=*/ 3 * vec_dot_size, // W + S0 + S1
                           /*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch
                           /*per_mn=*/ sizeof(__fp16), // O
                           m, n,
                           &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
        FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget);
        return -1;
    }

    const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);
    const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t f32_scratch_size = use_dma_activation
        ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0;

    // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch]
    uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
    __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
    __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
    __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
    void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
    void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size);
    __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
    float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL;
    if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) {
        FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__,
             (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
        return -1;
    }

    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0

    FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu",
         __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols,
         (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);

    TIMER_DEFINE(activation_load);
    TIMER_DEFINE(weight_load);
    TIMER_DEFINE(hmx_core);
    TIMER_DEFINE(output_store);

    TIMER_DEFINE(total);
    TIMER_START(total);

    HAP_compute_res_hmx_lock(ctx->vtcm_rctx);

    // m > 0 was checked above, so the int -> size_t loop comparison is safe
    for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
        // transfer activation matrix chunk into VTCM
        size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);

        TIMER_START(activation_load);
        {
            const float *activation_chunk = activation + mr * act_stride;
            if (use_dma_activation) {
                // gather the strided f32 rows contiguously first (see above)
                const size_t row_bytes = (size_t) k * sizeof(float);
                const size_t stride_bytes = (size_t) act_stride * sizeof(float);
                dma_queue_push(ctx->dma[0],
                               dma_make_ptr(vtcm_f32_act, activation_chunk),
                               row_bytes, stride_bytes, row_bytes, n_rows);
                dma_queue_pop(ctx->dma[0]); // synchronous gather
                transfer_activation_chunk_threaded(ctx, vtcm_activation,
                                                   vtcm_f32_act, n_rows, k, k);
            } else {
                transfer_activation_chunk_threaded(ctx, vtcm_activation,
                                                   activation_chunk, n_rows, k, act_stride);
            }
        }
        TIMER_STOP(activation_load);

        const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16);
        const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16);

        void *buf_curr = vtcm_scratch0;
        void *buf_next = vtcm_scratch1;

        // issue async DMA for the first weight chunk
        // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow.
        // The source rows can be strided (e.g. KV-cache K after ggml_permute).
        {
            const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);

            dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight),
                           fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
        }

        for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
            size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);

            TIMER_START(weight_load);
            {
                dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready

                // issue async DMA for the next weight chunk (double buffering)
                const size_t nc_next = nc + n_chunk_n_cols;
                if (nc_next < n) {
                    const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);
                    const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride;

                    dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
                                   fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
                }

                // interleave row-major fp16 from scratch into tile-major in vtcm_weight
                interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k);

                swap_ptr(&buf_curr, &buf_next);
            }
            TIMER_STOP(weight_load);

            TIMER_START(hmx_core);
            {
                const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
                const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
                core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
            }
            TIMER_STOP(hmx_core);

            TIMER_START(output_store);
            {
                // convert fp16 tiles back to f32 rows at dst[mr..][nc..]
                float *output = dst + (mr * n + nc);
                transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n);
            }
            TIMER_STOP(output_store);
        }

    }

    HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);

    TIMER_STOP(total);

#if defined(ENABLE_PROFILE_TIMERS)
    FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n);
    FARF(HIGH, "  activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
         TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
    {
        size_t weight_size = (size_t)k * n * sizeof(__fp16);
        float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
        FARF(HIGH, "  weight load bandwidth: %.2f GB/s", bandwidth);
    }
#endif

    return 0;
}
|
||
|
||
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
|
||
int k, int n, int w_type);
|
||
|
||
/**
 * Quantized-weight HMX matmul: dst[m][n] (f32) = activation[m][k] (f32,
 * contiguous) x quantized weight (one packed row of row_stride bytes per
 * output column; packing selected by weight_type).
 *
 * Three execution paths:
 *  - out-stationary kernel for large prefill shapes (m >= 128, k > n, n > 1024);
 *  - 4-stage software pipeline (DMA load / dequant / HMX matmul / store)
 *    when m >= 128 and k <= n;
 *  - sequential chunked loop otherwise.
 *
 * Requires k % 32 == 0, n % 32 == 0 and VLEN-aligned pointers.
 * Returns 0 on success, -1 on invalid arguments or VTCM exhaustion.
 */
int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation,
                                     const uint8_t *restrict permuted_weight, int m, int k, int n,
                                     int weight_type) {
    if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; }
    if (k % 32 != 0 || n % 32 != 0) { return -1; }

    if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) {
        return -1;
    }

    // for large m, k (e.g. prefill FFN Down), use out-stationary version
    if (m >= 128 && k > n && n > 1024) {
        FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)",
             m, k, n, weight_type, (k + 511) / 512);
        return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type);
    }

    // bytes per packed weight row; 0 means unsupported weight_type
    size_t row_stride = get_x4x2_row_stride(weight_type, k);
    if (row_stride == 0) {
        return -1;
    }

    FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type);

    // --- Dynamic VTCM layout ---
    const size_t vtcm_budget = ctx->vtcm_scratch_size;
    const size_t vec_dot_size = k * sizeof(__fp16); // one dequantized fp16 row
    const bool use_pipeline = (m >= 128) && (k <= n);

    // Select cost parameters based on execution path
    size_t per_n_cost, per_mn_cost;
    if (use_pipeline) {
        per_n_cost = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs)
        per_mn_cost = 2 * sizeof(__fp16); // O x 2 (output double buffer)
    } else {
        per_n_cost = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs)
        per_mn_cost = sizeof(__fp16); // O x 1
    }

    size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0;
    if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256,
                           per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost,
                           m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) {
        FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)",
             __func__, m, k, n, use_pipeline, vtcm_budget);
        return -1;
    }

    // Compute precise buffer sizes per execution path
    const size_t weight_area_size = hex_align_up(
        n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE);
    const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE);
    const size_t output_area_size = hex_align_up(
        m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE);

    size_t scratch0_size, scratch1_size, scratch2_size;
    if (use_pipeline) {
        scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0
        scratch1_size = scratch0_size; // dequant buf 1
        scratch2_size = output_area_size; // output buf 1
    } else {
        scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0
        scratch1_size = scratch0_size; // x4x2 DMA buf 1
        scratch2_size = 0; // unused
    }

    // VTCM layout: weight | activation | output | scratch0 | scratch1 | [scratch2] | scales
    uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
    __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size);
    __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size);
    __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size);
    void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size);
    void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size);
    void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL;
    __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
    if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) {
        FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__,
             (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);
        return -1;
    }

    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0

    FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu",
         __func__, m, k, n, weight_type, use_pipeline,
         m_chunk_n_rows, n_chunk_n_cols,
         (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);

    TIMER_DEFINE(activation_load);
    TIMER_DEFINE(weight_load);
    TIMER_DEFINE(hmx_core);
    TIMER_DEFINE(output_store);

    TIMER_DEFINE(total);
    TIMER_START(total);

    FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu",
         use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols,
         (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);

    HAP_compute_res_hmx_lock(ctx->vtcm_rctx);

    if (!use_pipeline) {
        // Sequential path: load A once per m_chunk, then for each n_chunk
        // DMA + dequantize the weights (double-buffered), matmul, store.
        for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
            // transfer activation matrix chunk into VTCM
            size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);

            TIMER_START(activation_load);
            {
                const float *activation_chunk = activation + mr * k;
                transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
            }
            TIMER_STOP(activation_load);

            void *buf_curr = vtcm_scratch0;
            void *buf_next = vtcm_scratch1;

            // issue async DDR data transfer for the first weight chunk
            // NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D
            // because UDMA roiwidth is 16-bit and total size can exceed 65535.
            {
                const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
                dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
            }

            for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
                size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);

                TIMER_START(weight_load);
                {
                    dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready

                    // prefetch the next quantized chunk (double buffering)
                    const size_t nc_next = nc + n_chunk_n_cols;
                    if (nc_next < n) {
                        const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols);

                        const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride;

                        dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
                    }

                    // Dequant + vscatter writes directly to [K, N] transposed tiles.
                    // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight.
                    dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type);

                    swap_ptr(&buf_curr, &buf_next);
                }
                TIMER_STOP(weight_load);

                TIMER_START(hmx_core);
                {
                    const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS);
                    const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS);
                    core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32);
                }
                TIMER_STOP(hmx_core);

                TIMER_START(output_store);
                {
                    float *output = dst + (mr * n + nc);
                    transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n);
                }
                TIMER_STOP(output_store);
            }
        }
    } else {
        // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D)
        // stage B and D (dequantize and store) are expected to be on the critical path

        // A --> B: vtcm_qweight, 1 buffer
        // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers
        // C --> D: vtcm_output0/vtcm_output1, 2 buffers

        //
        // LD ||A3| | B3 ||
        // MM || C2 ||
        // ST || D1 | ||

        int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
        for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) {
            const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows);

            // In the pipeline path the "weight" area holds the raw quantized
            // chunk; scratch0/1 hold the two dequantized fp16 tile buffers.
            void *vtcm_qweight = vtcm_weight;
            void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 };
            void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 };

            // prologue: A0
            const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols);
            {
                // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow.
                const uint8_t *qweight_chunk_A0 = permuted_weight;
                dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
            }

            {
                const float *activation_chunk = activation + mr * k;
                transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k);
            }

            // prologue: B0, A1, C0, B1
            {
                // B0
                dma_queue_pop(ctx->dma[0]);
                dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type);

                // A1
                const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
                if (1 < n_chunk_cnt) {
                    const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
                    dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
                }

                // C0
                core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales,
                                    hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);

                // B1
                if (1 < n_chunk_cnt) {
                    dma_queue_pop(ctx->dma[0]);
                    dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type);
                }
            }

            // main loop: steady state runs A_{i+2} / C_{i+1} / D_i / B_{i+2}
            // per iteration; the n_cols_p1/p2 values are only used under the
            // matching i+1/i+2 bounds checks.
            for (int i = 0; i < n_chunk_cnt; ++i) {
                const size_t nc = i * n_chunk_n_cols;
                const size_t nc_p1 = nc + 1 * n_chunk_n_cols;
                const size_t nc_p2 = nc + 2 * n_chunk_n_cols;

                const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols);
                const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols);
                const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols);

                // issue A_{i+2}
                if (i + 2 < n_chunk_cnt) {
                    const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
                    dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
                }

                // wait for HMX (C_{i}) -- C_{i} is done

                // result of B_{i+1} (input of C_{i+1}) should be ready now

                // issue C_{i+1}
                if (i + 1 < n_chunk_cnt) {
                    core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales,
                                        hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS);
                }

                // compute D_{i}
                float *output_chunk = dst + (mr * n + nc);
                transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n);

                // wait for DMA (A_{i+2}), compute B_{i+2}
                if (i + 2 < n_chunk_cnt) {
                    dma_queue_pop(ctx->dma[0]);
                    dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type);
                }
            }
        }
    }

    HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);

    TIMER_STOP(total);

#if defined(ENABLE_PROFILE_TIMERS)
    FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline);
    if (!use_pipeline) {
        FARF(HIGH, "  activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
             TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
        size_t weight_size = (size_t)n * row_stride;
        float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
        FARF(HIGH, "  weight load bandwidth: %.2f GB/s", bandwidth);
    }
#endif

    return 0;
}
|
||
|
||
// C += AB (or C = AB when zero_init) using the HMX fp16 tile engine.
//
// a:          activation tiles, row-major in tile units: n_row_tiles x n_dot_tiles tiles
// b:          weight tiles, n_col_tiles x n_dot_tiles tiles (one run of n_dot_tiles per output column tile)
// c:          output/accumulator tiles, n_row_tiles x n_col_tiles tiles
// col_scales: per-column output scales programmed into the HMX unit before the loop
// eye_tile:   32x32 identity tile; multiplying the previous C tile by it feeds the old
//             value back into the freshly cleared accumulator, implementing "+="
// zero_init:  when true, skip the re-load of C so the accumulator starts at zero
//
// NOTE(review): relies on implicit HMX accumulator state — Q6_mxclracc_hf() clears it,
// each hmx_load_tile_pair_fp16() issues one tile MAC into it, and
// hmx_consume_accumulator_fp16() drains it to memory. Statement order is load-bearing.
void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile,
                         int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) {

    hmx_set_output_scales(col_scales);

    for (int i = 0; i < n_row_tiles; ++i) {
        for (int j = 0; j < n_col_tiles; ++j) {
            // start every output tile from a clean accumulator
            Q6_mxclracc_hf();

            // i-th row of A tiles and j-th column of B tiles share the k (dot) dimension
            const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS;
            const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS;

            __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
            if (!zero_init) {
                // accum_tile * I re-injects the previous C value into the accumulator
                hmx_load_tile_pair_fp16(accum_tile, eye_tile);
            }

            // accumulate the dot product over the k dimension, one tile pair at a time
            for (int k = 0; k < n_dot_tiles; ++k) {
                int offset = k * HMX_FP16_TILE_N_ELMS;
                hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
            }

            // write the finished (scaled) tile back to c
            hmx_consume_accumulator_fp16(accum_tile);
        }
    }
}
|
||
|
||
// Convert an fp32 activation chunk to fp16 and scatter it into the HMX tiled layout.
// Rows are processed in pairs: hvx_vec_f32_to_f16_shuff() packs 32 floats from each of
// two consecutive rows into one fp16 vector, which lands as one (interleaved) vector
// slot of the destination tile. A missing odd row (n_rows odd) is zero-filled.
// Assumes k_block and k_stride are multiples of HMX_FP16_TILE_N_COLS (asserted by caller).
static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows,
                                                   int k_block, int k_stride) {
    const int tiles_per_tile_row = k_block / HMX_FP16_TILE_N_COLS;

    for (int row = 0; row < n_rows; row += 2) {
        const int tile_row    = row / HMX_FP16_TILE_N_ROWS; // which tile row this pair lands in
        const int row_in_tile = row % HMX_FP16_TILE_N_ROWS; // even row index inside that tile
        const bool have_odd_row = (row + 1) < n_rows;

        const HVX_Vector *in_even = (const HVX_Vector *) (src + (row + 0) * k_stride);
        const HVX_Vector *in_odd  = (const HVX_Vector *) (src + (row + 1) * k_stride);

        // 32 fp32 lanes per vector (VLEN == 128 bytes)
        for (int col = 0; col < k_block; col += 32) {
            const HVX_Vector v_even = *in_even++;
            const HVX_Vector v_odd  = have_odd_row ? *in_odd++ : Q6_V_vzero();

            const HVX_Vector packed = hvx_vec_f32_to_f16_shuff(v_even, v_odd);

            // destination tile and vector slot for this row pair / column group
            const int tile_idx = tile_row * tiles_per_tile_row + col / HMX_FP16_TILE_N_COLS;
            HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS);
            tile[row_in_tile / 2] = packed;
        }
    }
}
|
||
|
||
// Shared state for the threaded fp32 -> fp16 activation repack
// (see transfer_activation_chunk_worker_fn). One chunk == one activation row.
typedef struct {
    __fp16 *dst;            // tiled fp16 destination in VTCM
    const float *src;       // fp32 source rows
    int n_tasks;            // total tasks = ceil(n_tot_chunks / n_chunks_per_task)
    int n_tot_chunks;       // total number of rows to convert
    int n_chunks_per_task;  // rows per task; multiple of 32 so each task starts on a tile-row boundary
    int k_block;            // row width (fp16 elements) written per destination row
    int k_stride;           // source row stride in fp32 elements
} activation_transfer_task_state_t;
|
||
|
||
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||
activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;
|
||
|
||
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
|
||
// one chunk: one row
|
||
int chunk_idx = task_id * st->n_chunks_per_task;
|
||
size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task);
|
||
|
||
__fp16 *dst = st->dst + chunk_idx * st->k_block;
|
||
const float *src = st->src + chunk_idx * st->k_stride;
|
||
transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride);
|
||
}
|
||
}
|
||
|
||
// Repack n_rows fp32 activation rows into the tiled fp16 layout expected by the HMX
// core, fanning the work out across the context's worker pool.
// k_block and k_stride must be multiples of HMX_FP16_TILE_N_COLS.
void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) {
    assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0);
    assert(VLEN == 32 * sizeof(float));

    const size_t total_rows = n_rows;
    const size_t rows_per_task = 32; // must be multiple of 32 to ensure correct destination address

    activation_transfer_task_state_t state = {
        .dst = dst,
        .src = src,
        .n_tasks = (total_rows + rows_per_task - 1) / rows_per_task,
        .n_tot_chunks = total_rows,
        .n_chunks_per_task = rows_per_task,
        .k_block = k_block,
        .k_stride = k_stride,
    };

    worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads);
}
|
||
|
||
// Blocked out-stationary matmul: out[m x n] = x[m x k] * W^T, where W is quantized
// (Q4_0 / Q8_0 / IQ4_NL use a shared 4/8-bit layout, MXFP4 uses E8M0 scales) in the
// x4x2 repacked format and dequantized to fp16 on the fly.
//
// "out-stationary": the (mr, nc) output block stays resident in VTCM while the k loop
// accumulates over K_BLOCK_SIZE slices; activations and weights are streamed in via
// 2D DMA. Returns 0 on success, -1 when k exceeds the 2D DMA limit, the weight type
// is unsupported (row_stride == 0), or the VTCM budget is too small.
//
// NOTE(review): assumes k % 32 == 0 && n % 32 == 0 (stated below, not checked) and that
// each quantized row stores all quants first, then all scales — TODO confirm against
// the x4x2 repack writer.
int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m,
                                       int k, int n, int weight_type) {
    // Runtime check -- k >= 16384 exceeds 2D DMA limit
    if (k >= 16384) {
        FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k);
        return -1;
    }
    // assume k % 32 == 0 && n % 32 == 0
    const size_t row_stride = get_x4x2_row_stride(weight_type, k);
    if (row_stride == 0) {
        return -1; // unsupported weight type
    }

    const size_t vtcm_budget = ctx->vtcm_scratch_size;

    // blocking factors along m (activation rows), n (weight rows) and k (dot dim)
    const size_t M_BLOCK_SIZE = 512;
    const size_t N_BLOCK_SIZE = 512;
    const size_t K_BLOCK_SIZE = 512;

    // Compute precise buffer sizes
    const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE);
    const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
    const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
    const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE);
    const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE);   // raw qweight block
    const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); // raw fp32 activation block

    // + one eye tile + 256 bytes of column scales (see allocations below)
    const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256;
    if (total_vtcm > vtcm_budget) {
        FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n);
        return -1;
    }

    // carve the VTCM scratch area into the buffers sized above
    uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base;
    __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size);
    __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size);
    __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size);
    uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz);
    uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz);
    __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE);
    __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256);
    assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget);

    FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu",
         __func__, m, k, n, weight_type,
         (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget);

    // initialize eye tile (32x32 identity matrix)
    // Build one vector holding two fp16 1.0 values (0x3c00) on the diagonal
    // positions of a row pair, then rotate it into place for all 16 vectors of the tile.
    {
        HVX_Vector v;
        v = Q6_V_vzero();
        v = Q6_Vw_vinsert_VwR(v, 0x3c000000);
        v = Q6_V_vror_VR(v, VLEN - 4);
        v = Q6_Vw_vinsert_VwR(v, 0x00003c00);
        for (int i = 0; i < 16; ++i) {
            ((HVX_Vector *) vtcm_eye_tile)[i] = v;
            v = Q6_V_vror_VR(v, VLEN - 8);
        }
    }
    hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0

    TIMER_DEFINE(fetch);
    TIMER_DEFINE(act_load);
    TIMER_DEFINE(wt_dequant);
    TIMER_DEFINE(core);

    HAP_compute_res_hmx_lock(ctx->vtcm_rctx);

    for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) {
        size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE);
        for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) {
            size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE);

            const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS);
            const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS);

            // accumulate over the k dimension into the resident output block
            for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) {
                size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE);

                TIMER_START(fetch);
                // fetch activation block into VTCM
                {
                    const float *activation_block = x + mr * k + kk;

                    dma_queue_push(ctx->dma[0],
                                   dma_make_ptr(vtcm_scratch1, activation_block),
                                   k_blk_sz * sizeof(float),   // dst row stride (packed)
                                   k * sizeof(float),          // src row stride
                                   k_blk_sz * sizeof(float),   // bytes per row
                                   m_blk_sz);                  // row count
                }

                // fetch weight block into VTCM (x4x2 sub-block: quants + scales)
                // A quantized row stores all quant bytes first, then all scale bytes;
                // the k sub-range therefore needs two 2D DMA transfers per block.
                {
                    qweight_fetch_task_state_t s;

                    const int blk_start = kk / QK_Q4_0x4x2;                         // first x4x2 super-block in this k slice
                    const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;  // super-blocks in this k slice
                    // quant bytes per full row: k for 8-bit types, k/2 for 4-bit types
                    const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
                    const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
                    // MXFP4 uses 1-byte E8M0 exponents; other types use fp16 d scales
                    const int scale_blk_size =
                        (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;

                    s.dst = vtcm_scratch0;
                    s.src = w + nc * row_stride;
                    s.n_rows = n_blk_sz;
                    s.src_stride = row_stride;
                    s.dst_stride = sub_row_stride;
                    s.quant_off =
                        (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
                    s.quant_width =
                        (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
                    s.scale_off = full_qrow + blk_start * scale_blk_size;
                    s.scale_width = nb_sub * scale_blk_size;

                    // 2D DMA: quants sub-range
                    dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
                                   s.dst_stride, s.src_stride, s.quant_width, s.n_rows);
                    // 2D DMA: scales sub-range (packed right after the quants in dst)
                    dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off),
                                   s.dst_stride, s.src_stride, s.scale_width, s.n_rows);
                }
                TIMER_STOP(fetch);

                TIMER_START(act_load);
                // load activation block
                {
                    dma_queue_pop(ctx->dma[0]); // wait for act DMA
                    transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz);
                }
                TIMER_STOP(act_load);

                TIMER_START(wt_dequant);
                // dequantize weight block
                {
                    // wait for both weight DMA transfers (quants + scales)
                    dma_queue_pop(ctx->dma[0]);
                    dma_queue_pop(ctx->dma[0]);
                    // vtcm_scratch0 is used to store the qweight chunk
                    // worker_pool_run_func already returned, so fetch is done
                    const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
                    dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0,
                                                               n_blk_sz, k_blk_sz, sub_row_stride, weight_type);
                }
                TIMER_STOP(wt_dequant);

                // core mma
                TIMER_START(core);
                {
                    // kk == 0 starts the accumulation; later k slices add into vtcm_output
                    core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles,
                                        n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0);
                }
                TIMER_STOP(core);
            }

            // store output block
            {
                float *output_block = out + (mr * n + nc);
                transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n);
            }
        }
    }

    HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);

#if defined(ENABLE_PROFILE_TIMERS)
    FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us",
         TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core));
#endif
    return 0;
}
|