Merge signed commit 'b5332' into sisyphus
Diff-After-Merge: 2 files changed, 7 insertions(+) # gpg: Signature made Fri May 9 23:15:39 2025 MSK # gpg: using RSA key B5690EEEBB952194 # gpg: Good signature from "GitHub <noreply@github.com>" [unknown]
This commit is contained in:
commit
2969e4c10c
77 changed files with 2686 additions and 1156 deletions
|
|
@ -16,8 +16,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
|||
|
||||
## Hot topics
|
||||
|
||||
- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
|
||||
- **GGML developer experience survey (organized and reviewed by NVIDIA):** [link](https://forms.gle/Gasw3cRgyhNEnrwK9)
|
||||
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141]((https://github.com/ggml-org/llama.cpp/pull/13141))), `libllava` will be deprecated
|
||||
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated
|
||||
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
|
||||
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
|
||||
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
|
||||
|
|
|
|||
|
|
@ -119,8 +119,8 @@ if (LLAMA_LLGUIDANCE)
|
|||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
# v0.7.10:
|
||||
GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8
|
||||
# v0.7.19 (+ fancy-regex build fix):
|
||||
GIT_TAG b59f98f85269892a7de3d3641ad155366f13daa6
|
||||
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||
BUILD_IN_SOURCE TRUE
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ using json = nlohmann::ordered_json;
|
|||
|
||||
std::initializer_list<enum llama_example> mmproj_examples = {
|
||||
LLAMA_EXAMPLE_LLAVA,
|
||||
// TODO: add LLAMA_EXAMPLE_SERVER when it's ready
|
||||
LLAMA_EXAMPLE_SERVER,
|
||||
};
|
||||
|
||||
static std::string read_file(const std::string & fname) {
|
||||
|
|
@ -2627,6 +2627,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.i_chunk = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
|
||||
add_opt(common_arg(
|
||||
{"--parse-special"},
|
||||
string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.parse_special = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
|
||||
add_opt(common_arg(
|
||||
{"-pps"},
|
||||
string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"),
|
||||
|
|
|
|||
|
|
@ -409,6 +409,7 @@ struct common_params {
|
|||
|
||||
bool process_output = false; // collect data for the output tensor
|
||||
bool compute_ppl = true; // whether to compute perplexity
|
||||
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
||||
|
||||
// cvector-generator params
|
||||
int n_pca_batch = 100;
|
||||
|
|
|
|||
69
docs/multimodal.md
Normal file
69
docs/multimodal.md
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# Multimodal
|
||||
|
||||
llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature:
|
||||
- [llama-mtmd-cli](../tools/mtmd/README.md)
|
||||
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
|
||||
|
||||
To enable it, can use use one of the 2 methods below:
|
||||
|
||||
- Use `-hf` option with a [supported model](../../docs/multimodal.md)
|
||||
- To load a model using `-hf` while disabling multimodal, use `--no-mmproj`
|
||||
- To load a model using `-hf` while using a custom mmproj file, use `--mmproj local_file.gguf`
|
||||
- Use `-m model.gguf` option with `--mmproj file.gguf` to specify text and multimodal projector respectively
|
||||
|
||||
By default, multimodal projector will be offloaded to GPU. To disable this, add `--no-mmproj-offload`
|
||||
|
||||
For example:
|
||||
|
||||
```sh
|
||||
# simple usage with CLI
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
|
||||
# simple usage with server
|
||||
llama-server -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
|
||||
# using local file
|
||||
llama-server -m gemma-3-4b-it-Q4_K_M.gguf --mmproj mmproj-gemma-3-4b-it-Q4_K_M.gguf
|
||||
|
||||
# no GPU offload
|
||||
llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||
```
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default.
|
||||
|
||||
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
||||
|
||||
NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
```sh
|
||||
# Gemma 3
|
||||
(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-3-12b-it-GGUF
|
||||
(tool_name) -hf ggml-org/gemma-3-27b-it-GGUF
|
||||
|
||||
# SmolVLM
|
||||
(tool_name) -hf ggml-org/SmolVLM-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/SmolVLM-256M-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/SmolVLM-500M-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
|
||||
|
||||
# Pixtral 12B
|
||||
(tool_name) -hf ggml-org/pixtral-12b-GGUF
|
||||
|
||||
# Qwen 2 VL
|
||||
(tool_name) -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
|
||||
|
||||
# Qwen 2.5 VL
|
||||
(tool_name) -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
|
||||
(tool_name) -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
|
||||
|
||||
# Mistral Small 3.1 24B (IQ2_M quantization)
|
||||
(tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF
|
||||
```
|
||||
|
|
@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
|
|||
|
||||
set(CUDA_CXX_FLAGS "")
|
||||
|
||||
set(CUDA_FLAGS -use_fast_math)
|
||||
set(CUDA_FLAGS -use_fast_math -extended-lambda)
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
|
||||
# Options are:
|
||||
|
|
|
|||
|
|
@ -296,6 +296,25 @@ static __device__ void no_device_code(
|
|||
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
|
||||
#endif // __CUDA_ARCH__
|
||||
|
||||
// The compiler is always able to unroll loops if they contain continue expressions.
|
||||
// In such cases loop unrolling can still be achieved via recursion:
|
||||
template <int n>
|
||||
struct ggml_cuda_unroll {
|
||||
template <typename Func, typename... Args>
|
||||
__device__ void operator()(const Func & f, Args... args) const {
|
||||
f(n - 1, args...);
|
||||
ggml_cuda_unroll<n - 1>{}(f, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ggml_cuda_unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
__device__ void operator()(const Func & f, Args... args) const {
|
||||
f(0, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
|
|
|||
|
|
@ -2,6 +2,17 @@
|
|||
|
||||
#include "common.cuh"
|
||||
|
||||
|
||||
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
return __cvta_generic_to_shared(generic_ptr);
|
||||
#else
|
||||
GGML_UNUSED(generic_ptr);
|
||||
NO_DEVICE_CODE;
|
||||
return 0;
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
}
|
||||
|
||||
// Copies data from global to shared memory, cg == cache global.
|
||||
// Both the src and dst pointers must be aligned to 16 bit.
|
||||
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
|
||||
|
|
|
|||
|
|
@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|||
nullptr;
|
||||
}
|
||||
|
||||
template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
|
||||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
|
|
@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
|
|||
fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
|
||||
fprintf(stderr, "Only f16 is supported.\n");
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int ncols1, int ncols2, int KQ_stride>
|
||||
template <int DV, int ncols1, int ncols2>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
|
||||
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
|
||||
|
|
@ -691,7 +691,7 @@ void launch_fattn(
|
|||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
|
|
@ -754,10 +754,13 @@ void launch_fattn(
|
|||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int max_blocks = 2*nsm;
|
||||
const int max_blocks = max_blocks_per_sm*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
||||
|
||||
|
|
@ -769,14 +772,11 @@ void launch_fattn(
|
|||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
|
|
@ -853,19 +853,19 @@ void launch_fattn(
|
|||
|
||||
if (stream_k) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<D>
|
||||
flash_attn_combine_results<DV>
|
||||
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
||||
} break;
|
||||
case 128: {
|
||||
|
|
@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
|
||||
} break;
|
||||
default: {
|
||||
|
|
|
|||
|
|
@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
||||
} break;
|
||||
case 128: {
|
||||
|
|
@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
constexpr int nwarps = 8;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
|
||||
launch_fattn<D, cols_per_block, 1, -1>
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
|
||||
} break;
|
||||
default: {
|
||||
|
|
|
|||
|
|
@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
|||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
|
|
|||
|
|
@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
|||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
|
|
|
|||
|
|
@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
|
|
|||
|
|
@ -8,58 +8,32 @@
|
|||
#include "fattn-wmma-f16.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
template <int D, int ncols2>
|
||||
template <int DKQ, int DV, int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
if constexpr (ncols2 <= 8) {
|
||||
if (Q->ne[1] <= 8/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 16/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32/ncols2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
|
||||
}
|
||||
|
||||
template <int ncols2>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
template <int DKQ, int DV>
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
|
|
@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const float use_gqa_opt = mask && max_bias == 0.0f;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 4) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio == 2) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64:
|
||||
GGML_ASSERT(V->ne[0] == 64);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
|
||||
break;
|
||||
case 80:
|
||||
GGML_ASSERT(V->ne[0] == 80);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
|
||||
break;
|
||||
case 96:
|
||||
GGML_ASSERT(V->ne[0] == 96);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
|
||||
break;
|
||||
case 112:
|
||||
GGML_ASSERT(V->ne[0] == 112);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
|
||||
break;
|
||||
case 128:
|
||||
GGML_ASSERT(V->ne[0] == 128);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
|
||||
break;
|
||||
case 256:
|
||||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
||||
GGML_ASSERT(use_gqa_opt);
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
|
|
@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
|
||||
const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
|
|
|
|||
|
|
@ -10,10 +10,11 @@ static __global__ void k_get_rows(
|
|||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = blockIdx.z / ne12;
|
||||
const int i12 = blockIdx.z % ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
|
|
@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
|
|||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = blockIdx.z / ne12;
|
||||
const int i12 = blockIdx.z % ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
|
|
@ -94,8 +96,8 @@ static void get_rows_cuda_q(
|
|||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
|
|
@ -127,8 +129,8 @@ static void get_rows_cuda_float(
|
|||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
|
|
|
|||
|
|
@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return false;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
// different head sizes of K and V are not supported yet
|
||||
return false;
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
|
||||
return false;
|
||||
}
|
||||
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
|
||||
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek MLA
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#include "../fattn-mma-f16.cuh"
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
|
|||
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
|
||||
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
|
||||
|
||||
TYPES_MMQ = [
|
||||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
|
|
@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
|
|||
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
||||
|
||||
for ncols in [8, 16, 32, 64, 128]:
|
||||
for ncols2 in [1, 2, 4, 8]:
|
||||
for ncols in [8, 16, 32, 64]:
|
||||
for ncols2 in [1, 2, 4, 8, 16]:
|
||||
if ncols2 > ncols:
|
||||
continue
|
||||
ncols1 = ncols // ncols2
|
||||
if ncols == 128:
|
||||
continue # Too much register pressure.
|
||||
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
if ncols == 128 and head_size == 256:
|
||||
continue # Needs too much shared memory.
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
|
||||
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type in TYPES_MMQ:
|
||||
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
||||
|
|
|
|||
|
|
@ -299,21 +299,42 @@ typedef struct {
|
|||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t nei0;
|
||||
int32_t nei1;
|
||||
uint64_t nbi1;
|
||||
int32_t ne10;
|
||||
int32_t ne11; // n_expert_used (bcast)
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t neh11; // n_tokens
|
||||
uint64_t nbh11;
|
||||
int32_t ne20; // n_expert_used
|
||||
uint64_t nb21;
|
||||
} ggml_metal_kargs_mul_mm_id_map0;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne20; // n_expert_used
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
uint64_t nbh1;
|
||||
uint64_t nbh2;
|
||||
int32_t ne0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_mul_mm_id_map1;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
uint64_t nb03;
|
||||
int32_t neh12;
|
||||
uint64_t nbh10;
|
||||
uint64_t nbh11;
|
||||
uint64_t nbh12;
|
||||
uint64_t nbh13;
|
||||
int32_t neh0;
|
||||
int32_t neh1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mm_id;
|
||||
|
||||
typedef struct {
|
||||
|
|
|
|||
|
|
@ -306,28 +306,30 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
||||
|
|
@ -650,7 +652,8 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
|
|||
}
|
||||
|
||||
if (mem_pool->heaps_to_remove.count > 0) {
|
||||
for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
|
||||
// remove in reverse order
|
||||
for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
|
||||
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
||||
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
||||
|
||||
|
|
@ -659,6 +662,10 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
|
|||
|
||||
[mem_pool->heaps removeObjectAtIndex:index];
|
||||
[ptr release];
|
||||
|
||||
if (i == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
[mem_pool->heaps_to_remove removeAllObjects];
|
||||
|
|
@ -672,7 +679,7 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
|
|||
}
|
||||
|
||||
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
|
||||
const size_t alignment = 32;
|
||||
const size_t alignment = 256;
|
||||
|
||||
const size_t size_aligned = GGML_PAD(size, alignment);
|
||||
|
||||
|
|
@ -1242,28 +1249,30 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
||||
|
|
@ -2999,7 +3008,7 @@ static bool ggml_metal_encode_node(
|
|||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
|
|
@ -3219,8 +3228,6 @@ static bool ggml_metal_encode_node(
|
|||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const int n_as = src0->ne[2];
|
||||
|
||||
// src2 = ids
|
||||
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
||||
|
||||
|
|
@ -3234,24 +3241,21 @@ static bool ggml_metal_encode_node(
|
|||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
const uint32_t r2 = 1;
|
||||
const uint32_t r3 = 1;
|
||||
|
||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||
// to the matrix-vector kernel
|
||||
// ne20 = n_used_experts
|
||||
// ne21 = n_rows
|
||||
const int dst_rows = ne20*ne21;
|
||||
const int dst_rows_min = n_as;
|
||||
const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
|
||||
|
||||
// max size of the rowids array in the kernel shared buffer
|
||||
//GGML_ASSERT(dst_rows <= dst_rows_max);
|
||||
// ne21 = n_rows (batch size)
|
||||
const int ne21_mm_id_min = 32;
|
||||
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
//ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
|
||||
dst_rows > dst_rows_min &&
|
||||
dst_rows <= dst_rows_max) {
|
||||
(ne21 >= ne21_mm_id_min)) {
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
|
|
@ -3262,62 +3266,169 @@ static bool ggml_metal_encode_node(
|
|||
default: break;
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
const int64_t neh10 = ne10; // n_embd
|
||||
const int64_t neh11 = ne21; // n_tokens
|
||||
const int64_t neh12 = ne02; // n_expert
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
||||
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
||||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
||||
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
||||
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
|
||||
const uint64_t nbh11 = nbh10*neh10;
|
||||
const uint64_t nbh12 = nbh11*neh11;
|
||||
const uint64_t nbh13 = nbh12*neh12;
|
||||
|
||||
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
|
||||
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
||||
if (!h_src1) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.nei0 =*/ ne20,
|
||||
/*.nei1 =*/ ne21,
|
||||
/*.nbi1 =*/ nb21,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
};
|
||||
const int64_t neh0 = ne0;
|
||||
const int64_t neh1 = ne21;
|
||||
const int64_t neh2 = ne02;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
|
||||
const uint64_t nbh1 = nbh0*neh0;
|
||||
const uint64_t nbh2 = nbh1*neh1;
|
||||
//const uint64_t nbh3 = nbh2*neh2;
|
||||
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
||||
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
|
||||
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
||||
if (!h_dst) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
||||
return false;
|
||||
}
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
// tokens per expert
|
||||
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
|
||||
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
||||
if (!h_tpe) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
|
||||
return false;
|
||||
}
|
||||
|
||||
// id map
|
||||
// [n_expert_used, n_tokens]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
|
||||
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
||||
if (!h_ids) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
const int nth = MIN(1024, ne10/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map0 args = {
|
||||
ne10,
|
||||
ne11, // n_expert_used (bcast)
|
||||
nb11,
|
||||
nb12,
|
||||
neh11, // n_tokens
|
||||
nbh11,
|
||||
ne20, // n_expert_used
|
||||
nb21,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:5];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
|
||||
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
|
||||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
|
||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
|
||||
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
||||
}
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.neh12 =*/ neh12,
|
||||
/*.nbh10 =*/ nbh10,
|
||||
/*.nbh11 =*/ nbh11,
|
||||
/*.nbh12 =*/ nbh12,
|
||||
/*.nbh13 =*/ nbh13,
|
||||
/*.neh0 =*/ neh0,
|
||||
/*.neh1 =*/ neh1,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:2];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:4];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(ne0 % 4 == 0);
|
||||
|
||||
const int nth = MIN(1024, ne0/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map1 args = {
|
||||
ne20, // n_expert_used
|
||||
neh0,
|
||||
neh1,
|
||||
nbh1,
|
||||
nbh2,
|
||||
ne0,
|
||||
nb1,
|
||||
nb2,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:1];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} else {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
|
|
@ -3511,7 +3622,7 @@ static bool ggml_metal_encode_node(
|
|||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||
|
||||
const int64_t _ne1 = 1;
|
||||
const int64_t ne123 = dst_rows;
|
||||
const int64_t ne123 = ne20*ne21;
|
||||
|
||||
if (smem > 0) {
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
|
|
|||
|
|
@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm(
|
|||
}
|
||||
}
|
||||
|
||||
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
||||
// TODO: this kernel needs to be reimplemented from scratch for better performance
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
void kernel_mul_mm_id_impl(
|
||||
int32_t ne00,
|
||||
int32_t ne02,
|
||||
uint64_t nb01,
|
||||
uint64_t nb02,
|
||||
int32_t ne11,
|
||||
int32_t ne12,
|
||||
uint64_t nb10,
|
||||
uint64_t nb11,
|
||||
uint64_t nb12,
|
||||
int32_t ne0,
|
||||
int32_t ne1,
|
||||
int64_t ne0ne1,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
threadgroup ushort2 * rowids,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
template<typename T4>
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * hsrc1,
|
||||
device char * htpe,
|
||||
device char * hids,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ide = tgpig[0]; // expert id
|
||||
|
||||
int n_all = 0;
|
||||
|
||||
device int32_t * ids_i32 = (device int32_t *) (hids);
|
||||
|
||||
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
||||
|
||||
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
||||
if (src2_i32[i20] != ide) {
|
||||
continue;
|
||||
}
|
||||
|
||||
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
||||
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
||||
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
||||
}
|
||||
|
||||
++n_all;
|
||||
}
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
||||
tpe_i32[ide] = n_all;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_mul_mm_id_map1(
|
||||
constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
||||
device const char * hdst,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i20 = tgpig[0]; // used expert
|
||||
const int i21 = tgpig[1]; // token
|
||||
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
||||
|
||||
const int id = ids_i32[i21*args.ne20 + i20];
|
||||
|
||||
const int ide = id / args.neh1;
|
||||
const int idt = id % args.neh1;
|
||||
|
||||
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
||||
dst_f32x4[i0] = hdst_f32x4[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
||||
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * tpe,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup half * sa = (threadgroup half *)(shmem);
|
||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
||||
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
const int im = tgpig.z;
|
||||
|
||||
if (r1*BLOCK_SIZE_N >= ne1) return;
|
||||
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
||||
|
||||
const int neh1 = tpe_i32[im];
|
||||
|
||||
if (r1*BLOCK_SIZE_N >= neh1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_half8x8 mb[2];
|
||||
simdgroup_float8x8 mc[8];
|
||||
for (int i = 0; i < 8; i++){
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
|
||||
ushort offset1 = il/nl;
|
||||
const int i12 = im%args.neh12;
|
||||
const int i13 = im/args.neh12;
|
||||
|
||||
threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
|
||||
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const short offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
|
||||
device const float * y = (device const float *)(src1
|
||||
+ nb12 * id[1]
|
||||
+ nb11 * (id[0] % ne11)
|
||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
device const block_q * x = (device const block_q *)(src0
|
||||
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
device const half * y = (device const half *)(src1
|
||||
+ args.nbh13*i13
|
||||
+ args.nbh12*i12
|
||||
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
||||
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
half4x4 temp_a;
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
||||
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
||||
#pragma unroll(16)
|
||||
for (short i = 0; i < 16; i++) {
|
||||
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||
}
|
||||
|
||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
||||
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
y += BLOCK_SIZE_K;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||
|
||||
#pragma unroll(BLOCK_SIZE_K/8)
|
||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||
#pragma unroll(4)
|
||||
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
#pragma unroll(2)
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
||||
}
|
||||
|
||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (int i = 0; i < 8; i++){
|
||||
for (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
|
||||
lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
||||
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
||||
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
|
|
@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
device const char * ids,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int32_t i02 = tgpig.z;
|
||||
|
||||
tgpig.z = 0;
|
||||
|
||||
device const char * src0 = src0s + i02*args.nb02;
|
||||
|
||||
// row indices
|
||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
||||
|
||||
// TODO: parallelize this loop
|
||||
int32_t _ne1 = 0;
|
||||
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
||||
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
||||
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
||||
if (id == i02) {
|
||||
if (tiitg == 0) {
|
||||
rowids[_ne1] = ushort2(ii0, ii1);
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
||||
args.ne00,
|
||||
args.ne02,
|
||||
args.nb01,
|
||||
args.nb02,
|
||||
args.ne11,
|
||||
args.ne12,
|
||||
args.nb10,
|
||||
args.nb11,
|
||||
args.nb12,
|
||||
args.ne0,
|
||||
_ne1,
|
||||
(int64_t)args.ne0*args.ne1,
|
||||
src0,
|
||||
src1,
|
||||
rowids,
|
||||
dst,
|
||||
shmem,
|
||||
tgpig,
|
||||
tiitg,
|
||||
sgitg);
|
||||
}
|
||||
|
||||
#define QK_NL 16
|
||||
|
||||
//
|
||||
|
|
@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
|
|||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
|
||||
//
|
||||
// matrix-vector multiplication
|
||||
|
|
|
|||
|
|
@ -151,6 +151,12 @@ struct rpc_msg_buffer_clear_req {
|
|||
uint8_t value;
|
||||
};
|
||||
|
||||
struct rpc_msg_set_tensor_hash_req {
|
||||
rpc_tensor tensor;
|
||||
uint64_t offset;
|
||||
uint64_t hash;
|
||||
};
|
||||
|
||||
struct rpc_msg_set_tensor_hash_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
|
|
@ -548,15 +554,12 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
|
|||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
||||
if (size > HASH_THRESHOLD) {
|
||||
// input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
|
||||
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
uint64_t hash = fnv_hash((const uint8_t*)data, size);
|
||||
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
|
||||
rpc_msg_set_tensor_hash_req request;
|
||||
request.tensor = rpc_tensor;
|
||||
request.offset = offset;
|
||||
request.hash = fnv_hash((const uint8_t*)data, size);
|
||||
rpc_msg_set_tensor_hash_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
if (response.result) {
|
||||
// the server has the same data, no need to send it
|
||||
|
|
@ -864,7 +867,7 @@ public:
|
|||
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
||||
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
||||
bool set_tensor(const std::vector<uint8_t> & input);
|
||||
bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
||||
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
|
|
@ -1101,18 +1104,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
|
||||
bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
|
||||
{
|
||||
// serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
|
||||
if (input.size() != sizeof(rpc_tensor) + 16) {
|
||||
return false;
|
||||
}
|
||||
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
||||
uint64_t offset;
|
||||
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
||||
const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
|
||||
std::vector<uint8_t> cached_file;
|
||||
if (!get_cached_file(*hash, cached_file)) {
|
||||
if (!get_cached_file(request.hash, cached_file)) {
|
||||
response.result = 0;
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1125,25 +1120,28 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
|
|||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||
GGML_ASSERT(ctx_ptr != nullptr);
|
||||
ggml_context * ctx = ctx_ptr.get();
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
|
||||
if (tensor == nullptr) {
|
||||
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
||||
return false;
|
||||
}
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
|
||||
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
|
||||
|
||||
// sanitize tensor->data
|
||||
{
|
||||
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
||||
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
||||
|
||||
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
||||
if (request.tensor.data + request.offset < p0
|
||||
|| request.tensor.data + request.offset >= p1
|
||||
|| size > (p1 - request.tensor.data - request.offset)) {
|
||||
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
|
||||
__func__, in_tensor->data, offset, size, *hash, p0, p1);
|
||||
__func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
|
||||
ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
|
||||
response.result = 1;
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1503,12 +1501,12 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|||
break;
|
||||
}
|
||||
case RPC_CMD_SET_TENSOR_HASH: {
|
||||
std::vector<uint8_t> input;
|
||||
if (!recv_msg(sockfd, input)) {
|
||||
rpc_msg_set_tensor_hash_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_set_tensor_hash_rsp response;
|
||||
if (!server.set_tensor_hash(input, response)) {
|
||||
if (!server.set_tensor_hash(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
|
|
|
|||
|
|
@ -14,23 +14,24 @@
|
|||
#define GGML_SYCL_BACKEND_HPP
|
||||
|
||||
#include "binbcast.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "common.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "gla.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "mmq.hpp"
|
||||
#include "mmvq.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "gla.hpp"
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr);
|
|||
|
||||
extern int g_ggml_sycl_debug;
|
||||
extern int g_ggml_sycl_disable_optimize;
|
||||
extern int g_ggml_sycl_prioritize_dmmv;
|
||||
|
||||
#define GGML_SYCL_DEBUG(...) \
|
||||
do { \
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
|
|||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
|
|
@ -195,11 +196,13 @@ static void ggml_check_sycl() try {
|
|||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
|
|
@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|||
std::exit(1);
|
||||
}
|
||||
|
||||
enum class mul_mat_algo {
|
||||
DMMV = 0,
|
||||
MMVQ = 1,
|
||||
MUL_MAT_SYCL = 2,
|
||||
};
|
||||
|
||||
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
// TODO: accuracy issues in MMQ
|
||||
GGML_UNUSED(type);
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
|
|
@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
|||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
||||
|
||||
stream->parallel_for(
|
||||
|
|
@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
|||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
}
|
||||
|
||||
/*
|
||||
* This function could be called when the OP (mul_mat) function support reorder optimizition.
|
||||
*/
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
||||
src0->type == GGML_TYPE_Q4_0 &&
|
||||
src1->ne[2]==1 && src1->ne[3]==1) {
|
||||
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
|
||||
return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
||||
dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
|
||||
if (!extra) return; //only happen in CI/UT permute case.
|
||||
|
||||
if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
|
||||
|
||||
reorder_qw(src0, ctx->stream());
|
||||
extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
|
||||
ggml_tensor * dst, mul_mat_algo mm_algorithm) {
|
||||
if (!should_reorder_tensor(*ctx, dst)) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
|
||||
if (!extra || extra->optimized_feature.reorder) {
|
||||
return; // Skip permutations and already reordered tensors
|
||||
}
|
||||
|
||||
switch (mm_algorithm) {
|
||||
case mul_mat_algo::DMMV:
|
||||
if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case mul_mat_algo::MMVQ:
|
||||
if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case mul_mat_algo::MUL_MAT_SYCL:
|
||||
if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
reorder_qw(src0, ctx->stream());
|
||||
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
|
@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
int64_t min_compute_capability = INT_MAX;
|
||||
|
||||
if (split) {
|
||||
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
ggml_backend_sycl_split_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
auto & tensor_split = buft_ctx->tensor_split;
|
||||
for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
|
||||
// skip devices that are not going to do any work:
|
||||
|
|
@ -2924,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
}
|
||||
}
|
||||
} else {
|
||||
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
|
||||
min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
|
||||
}
|
||||
|
||||
// check data types and tensor shapes for custom matrix multiplication kernels:
|
||||
|
|
@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
||||
#endif // SYCL_USE_XMX
|
||||
|
||||
|
||||
// mmvq path is faster in the CUDA backend.
|
||||
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
|
||||
if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
|
||||
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
|
||||
// is enabled takes precedence over DMMV, the current if-else implementation
|
||||
// requires disabling DMMV if both conditions are met
|
||||
|| (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
|
||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// TODO: Refactor and cleanup of mul mat dispatching.
|
||||
|
|
@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
// KQ + KQV multi-batch
|
||||
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||
} else if (use_dequantize_mul_mat_vec) {
|
||||
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
|
||||
// save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
|
||||
} else if (use_mul_mat_vec_q) {
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
|
||||
constexpr bool convert_src1_to_q8_1 = true;
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
|
||||
} else if (use_mul_mat_q) {
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
|
||||
constexpr bool convert_src1_to_q8_1 = true;
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
|
||||
} else {
|
||||
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
// MUL_MAT_SYCL supports reorder
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
|
||||
}
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,60 @@
|
|||
#include "mmvq.hpp"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "common.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "vecdotq.hpp"
|
||||
#include <cassert>
|
||||
|
||||
template <typename reorder_vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
|
||||
using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
|
||||
using block_traits = typename block_type::traits;
|
||||
|
||||
const auto sg = nd_item.get_sub_group();
|
||||
const int sg_range = sg.get_group_linear_range();
|
||||
const int workgroup_id = nd_item.get_group_linear_id();
|
||||
const int sg_id = sg.get_group_linear_id();
|
||||
const int row = workgroup_id * sg_range + sg_id;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / block_traits::qk;
|
||||
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
||||
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
||||
|
||||
static_assert(blocks_per_subgroup > 0);
|
||||
static_assert(block_elements_per_subgroup > 0);
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
float partial_sum = 0.0f;
|
||||
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
||||
const int bx_offset = block_type::get_block_offset(ibx);
|
||||
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
|
||||
// Y block index that aligns with ibx
|
||||
const int iby = i * block_type::block_to_q8_1_ratio();
|
||||
|
||||
#pragma unroll
|
||||
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
|
||||
}
|
||||
}
|
||||
|
||||
auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
|
||||
|
||||
if (sg.leader()) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
|
|
@ -480,26 +534,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
|||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -916,93 +983,95 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_mul_mat_vec_q(
|
||||
ggml_backend_sycl_context & ctx,
|
||||
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
|
||||
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
|
||||
const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
const dpct::queue_ptr &stream) {
|
||||
|
||||
void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
|
||||
const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
const dpct::queue_ptr & stream) {
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
int id;
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
const size_t q8_1_ts = sizeof(block_q8_1);
|
||||
const size_t q8_1_bs = QK8_1;
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
|
||||
for (int i = 0; i < src1_ncols; i++)
|
||||
{
|
||||
for (int i = 0; i < src1_ncols; i++) {
|
||||
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
|
||||
const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
|
||||
float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
||||
const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
|
||||
float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
case GGML_TYPE_Q4_0:
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
|
||||
reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
|
||||
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(src1);
|
||||
|
|
|
|||
61
ggml/src/ggml-sycl/quants.hpp
Normal file
61
ggml/src/ggml-sycl/quants.hpp
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Codeplay Software Ltd.
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_QUANTS_HPP
|
||||
#define GGML_SYCL_QUANTS_HPP
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
namespace ggml_sycl_reordered {
|
||||
|
||||
|
||||
// The reordered block moves quants (qs) and scales(d) to two
|
||||
// uniform regions of memory that is contiguous in the same tensor.
|
||||
// What this means is that instead of having:
|
||||
// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]
|
||||
// We have:
|
||||
// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN]
|
||||
//
|
||||
// Notes: out-of-bounds qs will run into d values
|
||||
// Aligment relies on the allocated size of qs
|
||||
|
||||
template <ggml_type type> struct block_q_t;
|
||||
|
||||
|
||||
// qk number of weights / quants in a block
|
||||
// qr number of weights in a byte (described as 'before dequantization')
|
||||
// for quantization types that has low and high bits split, qr is calculated with
|
||||
// using the lower bits, e.g for Q6 quants QR6 is 2
|
||||
// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field)
|
||||
// See ggml-common.h to see how these are calculated
|
||||
template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK4_0;
|
||||
static constexpr uint32_t qi = QI4_0;
|
||||
static constexpr uint32_t qr = QR4_0;
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
|
|
@ -14,8 +14,11 @@
|
|||
#define GGML_SYCL_VECDOTQ_HPP
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
#include "quants.hpp"
|
||||
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs);
|
||||
|
||||
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
|
||||
const uint16_t* x16 =
|
||||
|
|
@ -252,13 +255,60 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
|
|||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
template <ggml_type T> struct reorder_vec_dot_q_sycl {
|
||||
static_assert(T != T, "ggml_type for reorder vecdot not implemented");
|
||||
};
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q4_0;
|
||||
|
||||
using q4_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_0>;
|
||||
using q4_0_traits = typename q4_0_block::traits;
|
||||
|
||||
__dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) {
|
||||
int sumi = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
||||
const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
|
||||
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
|
||||
|
||||
// SIMD dot product of quantized values
|
||||
sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
|
||||
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
|
||||
}
|
||||
|
||||
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
int u[2 * q4_0_traits::vdr_mmvq];
|
||||
|
||||
#pragma unroll
|
||||
|
||||
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
||||
v[i] = get_int_from_uint8(bq4_0, iqs + i);
|
||||
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
|
||||
}
|
||||
|
||||
return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
|
||||
};
|
||||
};
|
||||
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
template <int vdr>
|
||||
static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
|
||||
const float &d4,
|
||||
const sycl::half2 &ds8) {
|
||||
static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4,
|
||||
const sycl::half2 & ds8) {
|
||||
int sumi = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vdr; ++i) {
|
||||
|
|
@ -270,8 +320,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
|
|||
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
|
||||
}
|
||||
|
||||
const sycl::float2 ds8f =
|
||||
ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
|
||||
|
|
@ -456,13 +505,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
|
|||
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
|
||||
|
||||
int v[VDR_Q4_0_Q8_1_MMVQ];
|
||||
int u[2*VDR_Q4_0_Q8_1_MMVQ];
|
||||
int u[2 * VDR_Q4_0_Q8_1_MMVQ];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
|
||||
v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
|
||||
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
|
||||
v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
|
||||
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
|
||||
}
|
||||
|
||||
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
|
||||
|
|
|
|||
|
|
@ -1632,7 +1632,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|||
const uint32_t warps = warptile[0] / warptile[10];
|
||||
|
||||
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
||||
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
||||
|
||||
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
||||
|
|
@ -5260,7 +5260,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
|
||||
const uint64_t nei0 = ids->ne[0];
|
||||
const uint64_t nei1 = ids->ne[1];
|
||||
GGML_ASSERT(nei0 * nei1 <= 3072);
|
||||
GGML_ASSERT(nei0 * nei1 <= 4096);
|
||||
|
||||
const uint32_t nbi1 = ids->nb[1];
|
||||
const uint32_t nbi2 = ids->nb[2];
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
|||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
|||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
shared u16vec4 row_ids[3072];
|
||||
shared u16vec4 row_ids[4096];
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
B_TYPE b[];
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
|||
#define LOAD_VEC_B 4
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
shared u16vec2 row_ids[4096];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
|
|
|||
|
|
@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec(
|
|||
c = ggml_mul_mat_id(ctx, as, b, ids);
|
||||
|
||||
as -> [cols, rows, n_expert]
|
||||
ids -> [n_experts_used, n_tokens] (i32)
|
||||
b -> [cols, n_expert_used, n_tokens]
|
||||
ids -> [n_expert_used, n_tokens] (i32)
|
||||
c -> [rows, n_expert_used, n_tokens]
|
||||
|
||||
in b, n_experts_used can be broadcasted to match the n_expert_used of ids
|
||||
in b, n_expert_used can be broadcasted to match the n_expert_used of ids
|
||||
|
||||
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
|||
std::vector<ggml_backend_buffer_type_t> buft_extra;
|
||||
{
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
|
|
@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
|||
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
|||
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
|
||||
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
|
||||
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
|
||||
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
|
||||
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
|
||||
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
|
||||
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
|
||||
|
|
@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
|
|||
if (add_ass) {
|
||||
ss << "<|im_start|>assistant\n";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
|
||||
// Official mistral 'v7' template
|
||||
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
|
||||
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
|
||||
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
std::string content(message->content);
|
||||
if (role == "system") {
|
||||
ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
|
||||
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
|
||||
} else if (role == "user") {
|
||||
ss << "[INST] " << content << "[/INST]";
|
||||
}
|
||||
else {
|
||||
ss << " " << content << "</s>";
|
||||
ss << "[INST]" << trailing_space << content << "[/INST]";
|
||||
} else {
|
||||
ss << trailing_space << content << "</s>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ enum llm_chat_template {
|
|||
LLM_CHAT_TEMPLATE_MISTRAL_V3,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7,
|
||||
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
|
||||
LLM_CHAT_TEMPLATE_PHI_3,
|
||||
LLM_CHAT_TEMPLATE_PHI_4,
|
||||
LLM_CHAT_TEMPLATE_FALCON_3,
|
||||
|
|
|
|||
|
|
@ -1227,8 +1227,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
|
||||
if (v_mla) {
|
||||
#if 0
|
||||
// v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
|
||||
// However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
|
||||
cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
|
||||
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
||||
#else
|
||||
// It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
|
||||
// The permutations are noops and only change how the tensor data is interpreted.
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_mul_mat(ctx0, v_mla, cur);
|
||||
cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
|
||||
cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
|
||||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
|
|
|
|||
|
|
@ -823,6 +823,10 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
|||
mmaps_used.reserve(files.size());
|
||||
for (const auto & file : files) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
||||
if (!reg) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
||||
mmaps_used.emplace_back(mapping->size(), 0);
|
||||
|
|
|
|||
|
|
@ -299,6 +299,10 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
|
|||
// add extra buffer types, only if no GPU device is present
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
||||
|
|
@ -1484,6 +1488,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
}
|
||||
|
||||
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (cpu_dev == nullptr) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
|
||||
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
|
||||
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
||||
|
|
@ -1672,6 +1679,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
auto * buft_dev = ggml_backend_buft_get_device(buft);
|
||||
if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
throw std::runtime_error("no CPU backend found");
|
||||
}
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
}
|
||||
|
||||
|
|
@ -4122,6 +4132,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
if (!dev) {
|
||||
// FIXME: workaround for CPU backend buft having a NULL device
|
||||
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!dev) {
|
||||
throw std::runtime_error(format("%s: no CPU backend found", __func__));
|
||||
}
|
||||
}
|
||||
ggml_backend_dev_props props;
|
||||
ggml_backend_dev_get_props(dev, &props);
|
||||
|
|
@ -13387,6 +13400,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
|
|||
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
|
||||
const auto & it = model->gguf_kv.find(key);
|
||||
if (it == model->gguf_kv.end()) {
|
||||
// one-off fix for very popular models (so we are not flooded with issues)
|
||||
// do not extend this list unless absolutely necessary
|
||||
// Mistral-Small-2503 does not have built-in chat template
|
||||
llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
|
||||
if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
|
||||
return "mistral-v7-tekken";
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ static void print_usage(int, char ** argv) {
|
|||
LOG("\n %s \\\n"
|
||||
" -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] \\\n"
|
||||
" [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
|
||||
" [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n" , argv[0]);
|
||||
" [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...] \\\n"
|
||||
" [--parse-special]\n" , argv[0]);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
|
|
@ -439,7 +440,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
|||
auto tim1 = std::chrono::high_resolution_clock::now();
|
||||
LOG_INF("%s: tokenizing the input ..\n", __func__);
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true, params.parse_special);
|
||||
|
||||
auto tim2 = std::chrono::high_resolution_clock::now();
|
||||
LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
|
||||
|
|
|
|||
|
|
@ -152,7 +152,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
|
||||
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
if (!cpu_dev) {
|
||||
LOG_ERR("%s: no CPU backend found\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
auto * reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_new");
|
||||
auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free");
|
||||
|
||||
|
|
|
|||
|
|
@ -16,38 +16,7 @@ The naming and structure related to multimodal support have evolved, which might
|
|||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default:
|
||||
|
||||
```sh
|
||||
# Gemma 3
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
|
||||
|
||||
# SmolVLM
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
|
||||
|
||||
# Pixtral 12B
|
||||
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
|
||||
|
||||
# Qwen 2 VL
|
||||
llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
|
||||
|
||||
# Qwen 2.5 VL
|
||||
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
|
||||
|
||||
# Mistral Small 3.1 24B (IQ2_M quantization)
|
||||
llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
|
||||
```
|
||||
See the list of pre-quantized model [here](../../docs/multimodal.md)
|
||||
|
||||
## How it works and what is `mmproj`?
|
||||
|
||||
|
|
|
|||
|
|
@ -352,9 +352,12 @@ struct clip_ctx {
|
|||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
backend = ctx_params.use_gpu
|
||||
? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
|
||||
: nullptr;
|
||||
if (!backend_cpu) {
|
||||
throw std::runtime_error("failed to initialize CPU backend");
|
||||
}
|
||||
backend = ctx_params.use_gpu
|
||||
? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
|
||||
: nullptr;
|
||||
|
||||
if (backend) {
|
||||
LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
|
|
@ -2185,9 +2188,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity) {
|
|||
|
||||
struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) {
|
||||
g_logger_state.verbosity_thold = ctx_params.verbosity;
|
||||
clip_ctx * ctx_clip = new clip_ctx(ctx_params);
|
||||
clip_ctx * ctx_clip = nullptr;
|
||||
|
||||
try {
|
||||
ctx_clip = new clip_ctx(ctx_params);
|
||||
clip_model_loader loader(fname, *ctx_clip);
|
||||
loader.load_hparams();
|
||||
loader.load_tensors();
|
||||
|
|
|
|||
|
|
@ -212,6 +212,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
|||
ggml_build_forward_expand(gf, flatten);
|
||||
|
||||
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
|
||||
GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend");
|
||||
ggml_backend_graph_compute(backend.get(), gf);
|
||||
|
||||
struct ggml_tensor* result = ggml_graph_node(gf, -1);
|
||||
|
|
|
|||
|
|
@ -554,14 +554,19 @@ struct decode_embd_batch {
|
|||
llama_batch get_view(int offset, int n_tokens) {
|
||||
llama_pos * pos_ptr;
|
||||
pos_view.clear();
|
||||
pos_view.resize(n_tokens * n_pos_per_embd);
|
||||
pos_view.reserve(n_tokens * n_pos_per_embd);
|
||||
if (n_pos_per_embd > 1) {
|
||||
// mrope
|
||||
// for example, with layout of src: 1234...1234...1234...1234...
|
||||
// offset 2 will give us dst: 34...34...34...34...
|
||||
for (int i = 0; i < n_pos_per_embd; i++) {
|
||||
auto src = pos.begin() + i * batch.n_tokens + offset;
|
||||
pos_view.insert(pos_view.end(), src, src + n_tokens);
|
||||
// assume n_tokens is less than or equal to batch.n_tokens
|
||||
// batch.n_tokens is number of **total** tokens
|
||||
// n_tokens is number of viewed token
|
||||
size_t src_idx = i * batch.n_tokens + offset;
|
||||
pos_view.insert(pos_view.end(),
|
||||
pos.data() + src_idx,
|
||||
pos.data() + src_idx + n_tokens);
|
||||
}
|
||||
pos_ptr = pos_view.data();
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -237,15 +237,17 @@ static ggml_backend_t create_backend(const rpc_server_params & params) {
|
|||
backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
if (backend) {
|
||||
fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
|
||||
|
||||
// set the number of threads
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
||||
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
||||
if (reg) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
ggml_backend_set_n_threads_fn(backend, params.n_threads);
|
||||
// set the number of threads
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
|
||||
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
||||
if (reg) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
ggml_backend_set_n_threads_fn(backend, params.n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ Examples:
|
|||
llama-run ollama://smollm:135m
|
||||
llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
|
||||
llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
|
||||
llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
|
||||
llama-run modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
|
||||
llama-run https://example.com/some-file1.gguf
|
||||
llama-run some-file2.gguf
|
||||
llama-run file://some-file3.gguf
|
||||
|
|
|
|||
|
|
@ -267,7 +267,7 @@ class Opt {
|
|||
"Commands:\n"
|
||||
" model\n"
|
||||
" Model is a string with an optional prefix of \n"
|
||||
" huggingface:// (hf://), ollama://, https:// or file://.\n"
|
||||
" huggingface:// (hf://), modelscope:// (ms://), ollama://, https:// or file://.\n"
|
||||
" If no protocol is specified and a file exists in the specified\n"
|
||||
" path, file:// is assumed, otherwise if a file does not exist in\n"
|
||||
" the specified path, ollama:// is assumed. Models that are being\n"
|
||||
|
|
@ -282,6 +282,9 @@ class Opt {
|
|||
" llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
|
||||
" llama-run "
|
||||
"huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
|
||||
" llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
|
||||
" llama-run "
|
||||
"modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
|
||||
" llama-run https://example.com/some-file1.gguf\n"
|
||||
" llama-run some-file2.gguf\n"
|
||||
" llama-run file://some-file3.gguf\n"
|
||||
|
|
@ -689,7 +692,7 @@ class LlamaData {
|
|||
return 0;
|
||||
}
|
||||
|
||||
int huggingface_dl(std::string & model, const std::string & bn) {
|
||||
int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) {
|
||||
// Find the second occurrence of '/' after protocol string
|
||||
size_t pos = model.find('/');
|
||||
pos = model.find('/', pos + 1);
|
||||
|
|
@ -697,8 +700,6 @@ class LlamaData {
|
|||
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
|
||||
std::string url;
|
||||
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
|
||||
if (pos == std::string::npos) {
|
||||
auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/");
|
||||
hfr = model_name;
|
||||
|
|
@ -720,6 +721,16 @@ class LlamaData {
|
|||
return download(url, bn, true, headers);
|
||||
}
|
||||
|
||||
int modelscope_dl(std::string & model, const std::string & bn) {
|
||||
std::string model_endpoint = "https://modelscope.cn/models/";
|
||||
return dl_from_endpoint(model_endpoint, model, bn);
|
||||
}
|
||||
|
||||
int huggingface_dl(std::string & model, const std::string & bn) {
|
||||
std::string model_endpoint = get_model_endpoint();
|
||||
return dl_from_endpoint(model_endpoint, model, bn);
|
||||
}
|
||||
|
||||
int ollama_dl(std::string & model, const std::string & bn) {
|
||||
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
|
||||
if (model.find('/') == std::string::npos) {
|
||||
|
|
@ -837,6 +848,9 @@ class LlamaData {
|
|||
rm_until_substring(model_, "hf.co/");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = huggingface_dl(model_, bn);
|
||||
} else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) {
|
||||
rm_until_substring(model_, "://");
|
||||
ret = modelscope_dl(model_, bn);
|
||||
} else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
|
||||
!string_starts_with(model_, "https://ollama.com/library/")) {
|
||||
ret = download(model_, bn, true);
|
||||
|
|
|
|||
|
|
@ -34,8 +34,9 @@ endforeach()
|
|||
add_executable(${TARGET} ${TARGET_SRCS})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ../llava)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if (LLAMA_SERVER_SSL)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
|
|
|
|||
|
|
@ -193,6 +193,12 @@ services:
|
|||
LLAMA_ARG_PORT: 8080
|
||||
```
|
||||
|
||||
### Multimodal support
|
||||
|
||||
Multimodal support was added in [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) and is currently an experimental feature.
|
||||
|
||||
For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
|
||||
|
||||
## Build
|
||||
|
||||
`llama-server` is built alongside everything else from the root of the project
|
||||
|
|
@ -749,6 +755,9 @@ This endpoint is public (no API key check). By default, it is read-only. To make
|
|||
"total_slots": 1,
|
||||
"model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
|
||||
"chat_template": "...",
|
||||
"modalities": {
|
||||
"vision": false
|
||||
},
|
||||
"build_info": "b(build number)-(build commit hash)"
|
||||
}
|
||||
```
|
||||
|
|
@ -757,6 +766,7 @@ This endpoint is public (no API key check). By default, it is read-only. To make
|
|||
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
|
||||
- `model_path` - the path to model file (same with `-m` argument)
|
||||
- `chat_template` - the model's original Jinja2 prompt template
|
||||
- `modalities` - the list of supported modalities
|
||||
|
||||
### POST `/props`: Change server global properties.
|
||||
|
||||
|
|
@ -1069,6 +1079,8 @@ print(completion.choices[0].text)
|
|||
|
||||
Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
|
||||
|
||||
If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.
|
||||
|
||||
*Options:*
|
||||
|
||||
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -7,6 +7,7 @@
|
|||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "mtmd.h"
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
|
|
@ -197,8 +198,8 @@ struct server_task {
|
|||
int id_target = -1;
|
||||
|
||||
// used by SERVER_TASK_TYPE_INFERENCE
|
||||
slot_params params;
|
||||
llama_tokens prompt_tokens;
|
||||
slot_params params;
|
||||
server_tokens prompt_tokens;
|
||||
int id_selected_slot = -1;
|
||||
|
||||
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
|
||||
|
|
@ -1248,6 +1249,9 @@ struct server_slot {
|
|||
llama_context * ctx = nullptr;
|
||||
llama_context * ctx_dft = nullptr;
|
||||
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
|
||||
common_speculative * spec = nullptr;
|
||||
|
||||
std::vector<common_adapter_lora_info> lora;
|
||||
|
|
@ -1275,14 +1279,14 @@ struct server_slot {
|
|||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
// input prompt tokens
|
||||
llama_tokens prompt_tokens;
|
||||
server_tokens prompt_tokens;
|
||||
|
||||
size_t last_nl_pos = 0;
|
||||
|
||||
std::string generated_text;
|
||||
llama_tokens generated_tokens;
|
||||
|
||||
llama_tokens cache_tokens;
|
||||
server_tokens cache_tokens;
|
||||
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
|
|
@ -1476,7 +1480,7 @@ struct server_slot {
|
|||
{"is_processing", is_processing()},
|
||||
{"non_causal", is_non_causal()},
|
||||
{"params", params.to_json()},
|
||||
{"prompt", common_detokenize(ctx, prompt_tokens)},
|
||||
{"prompt", prompt_tokens.detokenize(ctx, true)},
|
||||
{"next_token",
|
||||
{
|
||||
{"has_next_token", has_next_token},
|
||||
|
|
@ -1849,13 +1853,16 @@ struct server_context {
|
|||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
|
||||
// multimodal
|
||||
mtmd_context * mctx = nullptr;
|
||||
|
||||
const llama_vocab * vocab = nullptr;
|
||||
|
||||
llama_model * model_dft = nullptr;
|
||||
|
||||
llama_context_params cparams_dft;
|
||||
|
||||
llama_batch batch = {};
|
||||
llama_batch batch;
|
||||
|
||||
bool clean_kv_cache = true;
|
||||
bool add_bos_token = true;
|
||||
|
|
@ -1878,6 +1885,8 @@ struct server_context {
|
|||
common_chat_templates_ptr chat_templates;
|
||||
|
||||
~server_context() {
|
||||
mtmd_free(mctx);
|
||||
|
||||
// Clear any sampling context
|
||||
for (server_slot & slot : slots) {
|
||||
common_sampler_free(slot.smpl);
|
||||
|
|
@ -1965,6 +1974,36 @@ struct server_context {
|
|||
chat_templates = common_chat_templates_init(model, "chatml");
|
||||
}
|
||||
|
||||
std::string & mmproj_path = params_base.mmproj.path;
|
||||
if (!mmproj_path.empty()) {
|
||||
mtmd_context_params mparams = mtmd_context_params_default();
|
||||
mparams.use_gpu = params_base.mmproj_use_gpu;
|
||||
mparams.print_timings = false;
|
||||
mparams.n_threads = params_base.cpuparams.n_threads;
|
||||
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
|
||||
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
|
||||
if (mctx == nullptr) {
|
||||
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
return false;
|
||||
}
|
||||
SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
|
||||
|
||||
if (params_base.ctx_shift) {
|
||||
params_base.ctx_shift = false;
|
||||
SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
|
||||
}
|
||||
|
||||
if (params_base.n_cache_reuse) {
|
||||
params_base.n_cache_reuse = 0;
|
||||
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
|
||||
}
|
||||
|
||||
if (!params_base.speculative.model.path.empty()) {
|
||||
SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -1980,6 +2019,8 @@ struct server_context {
|
|||
slot.ctx = ctx;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
slot.n_predict = params_base.n_predict;
|
||||
slot.mctx = mctx;
|
||||
slot.cache_tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
if (model_dft) {
|
||||
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
|
||||
|
|
@ -2016,8 +2057,6 @@ struct server_context {
|
|||
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
||||
{
|
||||
const int32_t n_batch = llama_n_batch(ctx);
|
||||
|
||||
// only a single seq_id per token is needed
|
||||
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
|
||||
}
|
||||
|
||||
|
|
@ -2054,7 +2093,7 @@ struct server_context {
|
|||
}
|
||||
|
||||
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
|
||||
int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
|
||||
int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
|
||||
|
||||
// fraction of the common subsequence length compared to the current slot's prompt length
|
||||
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
|
||||
|
|
@ -2096,18 +2135,6 @@ struct server_context {
|
|||
return ret;
|
||||
}
|
||||
|
||||
bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
for (const auto & token : tokens) {
|
||||
if (token < 0 || token >= n_vocab) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool launch_slot_with_task(server_slot & slot, server_task && task) {
|
||||
slot.reset();
|
||||
slot.id_task = task.id;
|
||||
|
|
@ -2122,8 +2149,7 @@ struct server_context {
|
|||
slot.lora = slot.params.lora;
|
||||
}
|
||||
|
||||
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
|
||||
if (!can_detokenize) {
|
||||
if (!slot.prompt_tokens.validate(ctx)) {
|
||||
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
|
@ -2385,6 +2411,15 @@ struct server_context {
|
|||
queue_results.send(std::move(res));
|
||||
}
|
||||
|
||||
// if multimodal is enabled, send an error and return false
|
||||
bool ensure_no_mtmd(const int id_task) {
|
||||
if (mctx) {
|
||||
send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
|
||||
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
||||
|
||||
|
|
@ -2424,7 +2459,7 @@ struct server_context {
|
|||
res->content = std::move(slot.generated_text);
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
|
||||
res->response_fields = std::move(slot.params.response_fields);
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
|
|
@ -2734,6 +2769,10 @@ struct server_context {
|
|||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_SAVE:
|
||||
{
|
||||
if (!ensure_no_mtmd(task.id)) {
|
||||
break;
|
||||
}
|
||||
|
||||
int id_slot = task.slot_action.slot_id;
|
||||
server_slot * slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
|
|
@ -2753,7 +2792,8 @@ struct server_context {
|
|||
std::string filename = task.slot_action.filename;
|
||||
std::string filepath = task.slot_action.filepath;
|
||||
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
|
||||
const llama_tokens & tokens = slot->cache_tokens.get_text_tokens();
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_save_ms = (t_end - t_start) / 1000.0;
|
||||
|
|
@ -2770,6 +2810,7 @@ struct server_context {
|
|||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_RESTORE:
|
||||
{
|
||||
if (!ensure_no_mtmd(task.id)) break;
|
||||
int id_slot = task.slot_action.slot_id;
|
||||
server_slot * slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
|
|
@ -2788,15 +2829,18 @@ struct server_context {
|
|||
std::string filename = task.slot_action.filename;
|
||||
std::string filepath = task.slot_action.filepath;
|
||||
|
||||
slot->cache_tokens.resize(slot->n_ctx);
|
||||
llama_tokens tokens;
|
||||
tokens.resize(slot->n_ctx);
|
||||
size_t token_count = 0;
|
||||
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
||||
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
|
||||
if (nread == 0) {
|
||||
slot->cache_tokens.resize(0);
|
||||
slot->cache_tokens.clear(); // KV may already been invalidated?
|
||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||
break;
|
||||
}
|
||||
slot->cache_tokens.resize(token_count);
|
||||
tokens.resize(token_count);
|
||||
slot->cache_tokens.clear();
|
||||
slot->cache_tokens.insert(tokens);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
|
@ -2813,6 +2857,7 @@ struct server_context {
|
|||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||
{
|
||||
if (!ensure_no_mtmd(task.id)) break;
|
||||
int id_slot = task.slot_action.slot_id;
|
||||
server_slot * slot = get_slot_by_id(id_slot);
|
||||
if (slot == nullptr) {
|
||||
|
|
@ -2844,6 +2889,7 @@ struct server_context {
|
|||
res->id = task.id;
|
||||
queue_results.send(std::move(res));
|
||||
} break;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2889,6 +2935,12 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (mctx) {
|
||||
// we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
|
||||
// we don't support ctx_shift because an image chunk may contains multiple tokens
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
// Shift context
|
||||
const int n_keep = slot.params.n_keep + add_bos_token;
|
||||
const int n_left = slot.n_past - n_keep;
|
||||
|
|
@ -2900,11 +2952,14 @@ struct server_context {
|
|||
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
||||
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
|
||||
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
|
||||
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
|
||||
new_tokens[i - n_discard] = new_tokens[i];
|
||||
}
|
||||
|
||||
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
|
||||
new_tokens.resize(slot.cache_tokens.size() - n_discard);
|
||||
slot.cache_tokens.clear();
|
||||
slot.cache_tokens.insert(new_tokens);
|
||||
}
|
||||
|
||||
slot.n_past -= n_discard;
|
||||
|
|
@ -2982,7 +3037,7 @@ struct server_context {
|
|||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
if (1) {
|
||||
/*if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
|
|
@ -2992,7 +3047,7 @@ struct server_context {
|
|||
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
|
||||
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
// empty prompt passed -> release the slot and send empty response
|
||||
if (prompt_tokens.empty()) {
|
||||
|
|
@ -3034,21 +3089,27 @@ struct server_context {
|
|||
|
||||
// if input prompt is too big, truncate it
|
||||
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
if (mctx) {
|
||||
// we should never reach this
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens();
|
||||
llama_tokens new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
curr_tokens.begin(),
|
||||
curr_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
new_tokens.insert(
|
||||
new_tokens.end(),
|
||||
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
||||
prompt_tokens.end());
|
||||
curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
||||
curr_tokens.end());
|
||||
|
||||
prompt_tokens = std::move(new_tokens);
|
||||
prompt_tokens.clear();
|
||||
prompt_tokens.insert(new_tokens);
|
||||
|
||||
slot.truncated = true;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
|
@ -3060,13 +3121,18 @@ struct server_context {
|
|||
|
||||
if (slot.params.cache_prompt) {
|
||||
// reuse any previously computed tokens that are common with the new prompt
|
||||
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
|
||||
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
|
||||
|
||||
// reuse chunks from the cached prompt by shifting their KV cache in the new position
|
||||
if (params_base.n_cache_reuse > 0) {
|
||||
size_t head_c = slot.n_past; // cache
|
||||
size_t head_p = slot.n_past; // current prompt
|
||||
|
||||
if (mctx) {
|
||||
// we should never reach this
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
|
||||
|
||||
while (head_c < slot.cache_tokens.size() &&
|
||||
|
|
@ -3092,7 +3158,7 @@ struct server_context {
|
|||
llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]);
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
|
|
@ -3140,21 +3206,52 @@ struct server_context {
|
|||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.n_past < slot.n_prompt_tokens
|
||||
&& slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
|
||||
// process the image
|
||||
int32_t new_n_past;
|
||||
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
|
||||
int32_t n_pos = new_n_past - slot.n_past;
|
||||
|
||||
if (res != 0) {
|
||||
SLT_ERR(slot, "failed to process image, res = %d\n", res);
|
||||
slot.release();
|
||||
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
|
||||
slot.cache_tokens.push_back(chunk.get()); // copy
|
||||
}
|
||||
|
||||
slot.n_past += n_pos;
|
||||
slot.n_prompt_tokens_processed += n_pos;
|
||||
}
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
llama_token cur_tok = slot.prompt_tokens[slot.n_past];
|
||||
if (cur_tok == LLAMA_TOKEN_NULL) {
|
||||
break; // end of text chunk
|
||||
}
|
||||
|
||||
// without pooling, we want to output the embeddings for all the tokens in the batch
|
||||
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
|
||||
|
||||
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot.n_past++;
|
||||
}
|
||||
|
||||
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
|
||||
|
||||
// entire prompt has been processed
|
||||
|
|
@ -3162,12 +3259,16 @@ struct server_context {
|
|||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size());
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
// Process all prompt tokens through sampler system
|
||||
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
|
||||
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
|
||||
llama_token id = slot.prompt_tokens[i];
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
common_sampler_accept(slot.smpl, id, false);
|
||||
}
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
|
|
@ -3320,6 +3421,11 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (mctx) {
|
||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
||||
GGML_ABORT("not supported by multimodal");
|
||||
}
|
||||
|
||||
// determine the max draft that fits the current slot state
|
||||
int n_draft_max = slot.params.speculative.n_max;
|
||||
|
||||
|
|
@ -3346,7 +3452,8 @@ struct server_context {
|
|||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||
params_spec.p_min = slot.params.speculative.p_min;
|
||||
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
|
||||
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
|
||||
|
||||
// keep track of total number of tokens generated in the draft
|
||||
slot.n_draft_total += draft.size();
|
||||
|
|
@ -3380,7 +3487,7 @@ struct server_context {
|
|||
slot.n_draft_accepted += ids.size() - 1;
|
||||
|
||||
slot.cache_tokens.push_back(id);
|
||||
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
|
||||
slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
|
||||
|
||||
llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
|
||||
|
||||
|
|
@ -3903,6 +4010,7 @@ int main(int argc, char ** argv) {
|
|||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model.path },
|
||||
{ "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
|
||||
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
|
||||
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
|
||||
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
|
||||
|
|
@ -3950,9 +4058,10 @@ int main(int argc, char ** argv) {
|
|||
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
|
||||
server_task_type type,
|
||||
json & data,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const std::function<bool()> & is_connection_closed,
|
||||
httplib::Response & res,
|
||||
oaicompat_type oaicompat) {
|
||||
oaicompat_type oaicompat) -> void {
|
||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||
|
||||
if (ctx_server.params_base.embedding) {
|
||||
|
|
@ -3969,15 +4078,69 @@ int main(int argc, char ** argv) {
|
|||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
tasks.reserve(tokenized_prompts.size());
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
// process files
|
||||
mtmd::bitmaps bitmaps;
|
||||
const bool has_mtmd = ctx_server.mctx != nullptr;
|
||||
{
|
||||
if (!has_mtmd && !files.empty()) {
|
||||
throw std::runtime_error("This server does not support multimodal");
|
||||
}
|
||||
for (auto & file : files) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
|
||||
if (!bmp.ptr) {
|
||||
throw std::runtime_error("Failed to load image");
|
||||
}
|
||||
// calculate bitmap hash (for KV caching)
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
|
||||
bmp.set_id(hash.c_str());
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
}
|
||||
}
|
||||
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
if (oaicompat && !prompt.is_string()) {
|
||||
throw std::runtime_error("prompt must be a string");
|
||||
}
|
||||
|
||||
if (oaicompat && has_mtmd) {
|
||||
// multimodal
|
||||
std::string prompt_str = prompt.get<std::string>();
|
||||
mtmd_input_text inp_txt = {
|
||||
prompt_str.c_str(),
|
||||
/* add_special */ true,
|
||||
/* parse_special */ true,
|
||||
};
|
||||
mtmd::input_chunks chunks(mtmd_input_chunks_init());
|
||||
auto bitmaps_c_ptr = bitmaps.c_ptr();
|
||||
int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
|
||||
chunks.ptr.get(),
|
||||
&inp_txt,
|
||||
bitmaps_c_ptr.data(),
|
||||
bitmaps_c_ptr.size());
|
||||
if (tokenized != 0) {
|
||||
throw std::runtime_error("Failed to tokenize prompt");
|
||||
}
|
||||
|
||||
server_tokens tmp(chunks, true);
|
||||
inputs.push_back(std::move(tmp));
|
||||
} else {
|
||||
// non-multimodal version
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
for (auto & p : tokenized_prompts) {
|
||||
auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
|
||||
inputs.push_back(std::move(tmp));
|
||||
}
|
||||
}
|
||||
|
||||
tasks.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
server_task task = server_task(type);
|
||||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
task.prompt_tokens = std::move(inputs[i]);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
ctx_server.ctx,
|
||||
ctx_server.params_base,
|
||||
|
|
@ -4059,9 +4222,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = json::parse(req.body);
|
||||
return handle_completions_impl(
|
||||
std::vector<raw_buffer> files; // dummy
|
||||
handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE);
|
||||
|
|
@ -4069,9 +4234,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||
return handle_completions_impl(
|
||||
std::vector<raw_buffer> files; // dummy
|
||||
handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_COMPLETION);
|
||||
|
|
@ -4146,9 +4313,11 @@ int main(int argc, char ** argv) {
|
|||
tokenized_prompts[0]
|
||||
);
|
||||
|
||||
return handle_completions_impl(
|
||||
std::vector<raw_buffer> files; // dummy
|
||||
handle_completions_impl(
|
||||
SERVER_TASK_TYPE_INFILL,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||
|
|
@ -4162,11 +4331,19 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
|
||||
std::vector<raw_buffer> files;
|
||||
json data = oaicompat_completion_params_parse(
|
||||
body,
|
||||
params.use_jinja,
|
||||
params.reasoning_format,
|
||||
ctx_server.chat_templates.get(),
|
||||
ctx_server.mctx,
|
||||
files);
|
||||
|
||||
return handle_completions_impl(
|
||||
handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
files,
|
||||
req.is_connection_closed,
|
||||
res,
|
||||
OAICOMPAT_TYPE_CHAT);
|
||||
|
|
@ -4175,7 +4352,14 @@ int main(int argc, char ** argv) {
|
|||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
|
||||
std::vector<raw_buffer> files; // dummy, unused
|
||||
json data = oaicompat_completion_params_parse(
|
||||
body,
|
||||
params.use_jinja,
|
||||
params.reasoning_format,
|
||||
ctx_server.chat_templates.get(),
|
||||
ctx_server.mctx,
|
||||
files);
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
|
|
@ -4280,7 +4464,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
|
||||
for (const auto & tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
|
|
@ -4300,7 +4484,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
|
||||
|
||||
// OAI-compat
|
||||
task.params.oaicompat = oaicompat;
|
||||
|
|
@ -4394,13 +4578,14 @@ int main(int argc, char ** argv) {
|
|||
std::unordered_set<int> task_ids;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||
tasks.reserve(tokenized_docs.size());
|
||||
for (size_t i = 0; i < tokenized_docs.size(); i++) {
|
||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||
task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
|
||||
|
|
|
|||
59
tools/server/tests/unit/test_vision_api.py
Normal file
59
tools/server/tests/unit/test_vision_api.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
import base64
|
||||
import requests
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
||||
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
|
||||
|
||||
response = requests.get(IMG_URL_0)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinygemma3()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_url, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this:\n", IMG_URL_0, True, "(cat)+"),
|
||||
("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
||||
("What is this:\n", IMG_URL_1, True, "(frog)+"),
|
||||
("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
|
||||
("What is this:\n", "malformed", False, None),
|
||||
("What is this:\n", "https://google.com/404", False, None), # non-existent image
|
||||
("What is this:\n", "https://ggml.ai", False, None), # non-image data
|
||||
]
|
||||
)
|
||||
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
global server
|
||||
server.start(timeout_seconds=60) # vision model may take longer to load due to download size
|
||||
if image_url == "IMG_BASE64_0":
|
||||
image_url = IMG_BASE64_0
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {
|
||||
"url": image_url,
|
||||
}},
|
||||
]},
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
||||
|
|
@ -88,6 +88,7 @@ class ServerProcess:
|
|||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
server_path: str | None = None
|
||||
mmproj_url: str | None = None
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
|
|
@ -194,6 +195,8 @@ class ServerProcess:
|
|||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||
if self.mmproj_url:
|
||||
server_args.extend(["--mmproj-url", self.mmproj_url])
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"tests: starting server with: {' '.join(args)}")
|
||||
|
|
@ -379,6 +382,21 @@ class ServerPreset:
|
|||
server.server_reranking = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def tinygemma3() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
# mmproj is already provided by HF registry API
|
||||
server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
|
||||
server.model_hf_file = "tinygemma3-Q8_0.gguf"
|
||||
server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf"
|
||||
server.model_alias = "tinygemma3"
|
||||
server.n_ctx = 1024
|
||||
server.n_batch = 32
|
||||
server.n_slots = 2
|
||||
server.n_predict = 4
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
|
||||
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@
|
|||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "arg.h" // common_remote_get_content
|
||||
#include "base64.hpp"
|
||||
#include "mtmd.h"
|
||||
|
||||
// increase max payload length to allow use of larger context size
|
||||
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
||||
|
|
@ -21,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <cinttypes>
|
||||
|
||||
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
|
||||
|
||||
|
|
@ -41,6 +44,8 @@ using json = nlohmann::ordered_json;
|
|||
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
template <typename T>
|
||||
static T json_value(const json & body, const std::string & key, const T & default_value) {
|
||||
// Fallback null to default value
|
||||
|
|
@ -386,7 +391,7 @@ static inline bool is_base64(uint8_t c) {
|
|||
return (isalnum(c) || (c == '+') || (c == '/'));
|
||||
}
|
||||
|
||||
static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
|
||||
static inline raw_buffer base64_decode(const std::string & encoded_string) {
|
||||
int i = 0;
|
||||
int j = 0;
|
||||
int in_ = 0;
|
||||
|
|
@ -396,7 +401,7 @@ static inline std::vector<uint8_t> base64_decode(const std::string & encoded_str
|
|||
uint8_t char_array_4[4];
|
||||
uint8_t char_array_3[3];
|
||||
|
||||
std::vector<uint8_t> ret;
|
||||
raw_buffer ret;
|
||||
|
||||
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
|
||||
char_array_4[i++] = encoded_string[in_]; in_++;
|
||||
|
|
@ -579,7 +584,9 @@ static json oaicompat_completion_params_parse(
|
|||
const json & body, /* openai api json semantics */
|
||||
bool use_jinja,
|
||||
common_reasoning_format reasoning_format,
|
||||
const struct common_chat_templates * tmpls)
|
||||
const struct common_chat_templates * tmpls,
|
||||
bool allow_non_text,
|
||||
std::vector<raw_buffer> & out_files)
|
||||
{
|
||||
json llama_params;
|
||||
|
||||
|
|
@ -627,8 +634,77 @@ static json oaicompat_completion_params_parse(
|
|||
}
|
||||
}
|
||||
|
||||
// get input files
|
||||
if (!body.contains("messages")) {
|
||||
throw std::runtime_error("'messages' is required");
|
||||
}
|
||||
json messages = body.at("messages");
|
||||
if (!messages.is_array()) {
|
||||
throw std::runtime_error("Expected 'messages' to be an array");
|
||||
}
|
||||
for (auto & msg : messages) {
|
||||
json & content = msg.at("content");
|
||||
if (content.is_string()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!content.is_array()) {
|
||||
throw std::runtime_error("Expected 'content' to be a string or an array");
|
||||
}
|
||||
|
||||
for (auto & p : content) {
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
json image_url = json_value(p, "image_url", json::object());
|
||||
if (type == "image_url") {
|
||||
if (!allow_non_text) {
|
||||
throw std::runtime_error("image input is not supported by this server");
|
||||
}
|
||||
|
||||
std::string url = json_value(image_url, "url", std::string());
|
||||
if (string_starts_with(url, "http")) {
|
||||
// download remote image
|
||||
// TODO @ngxson : maybe make these params configurable
|
||||
common_remote_params params;
|
||||
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
|
||||
params.max_size = 1024 * 1024 * 10; // 10MB
|
||||
params.timeout = 10; // seconds
|
||||
SRV_INF("downloading image from '%s'\n", url.c_str());
|
||||
auto res = common_remote_get_content(url, params);
|
||||
if (200 <= res.first && res.first < 300) {
|
||||
SRV_INF("downloaded %ld bytes\n", res.second.size());
|
||||
raw_buffer data;
|
||||
data.insert(data.end(), res.second.begin(), res.second.end());
|
||||
out_files.push_back(data);
|
||||
} else {
|
||||
throw std::runtime_error("Failed to download image");
|
||||
}
|
||||
|
||||
} else {
|
||||
// try to decode base64 image
|
||||
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
|
||||
if (parts.size() != 2) {
|
||||
throw std::runtime_error("Invalid image_url.url value");
|
||||
} else if (!string_starts_with(parts[0], "data:image/")) {
|
||||
throw std::runtime_error("Invalid image_url.url format: " + parts[0]);
|
||||
} else if (!string_ends_with(parts[0], "base64")) {
|
||||
throw std::runtime_error("image_url.url must be base64 encoded");
|
||||
} else {
|
||||
auto base64_data = parts[1];
|
||||
auto decoded_data = base64_decode(base64_data);
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
// replace this chunk with a marker
|
||||
p["type"] = "text";
|
||||
p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
|
||||
p.erase("image_url");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
|
||||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
|
|
@ -935,3 +1011,286 @@ static std::vector<common_adapter_lora_info> parse_lora_request(
|
|||
|
||||
return lora;
|
||||
}
|
||||
|
||||
//
|
||||
// utils for interacting with libmtmd
|
||||
// (may need to refactor in near future)
|
||||
//
|
||||
|
||||
/**
|
||||
* server_tokens is a helper to manage the input tokens and image for the server.
|
||||
* it is made this way to simplify the logic of KV cache management.
|
||||
*/
|
||||
struct server_tokens {
|
||||
bool has_mtmd = false;
|
||||
|
||||
private: // disallow accessing these members directly, risking out-of-sync
|
||||
|
||||
// map a **start** position in tokens to the image chunk
|
||||
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
|
||||
|
||||
// list of tokens
|
||||
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
|
||||
// a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position**
|
||||
// important: for models using mrope, an image can contain multiple tokens but will use only one **position**
|
||||
llama_tokens tokens;
|
||||
|
||||
// for ex. with input of 5 text tokens and 2 images:
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
|
||||
// pos 0 1 2 3 4 5 6 7 8 9
|
||||
// map_pos_to_image will contain: {5, img0}, {8, img1}
|
||||
|
||||
public:
|
||||
server_tokens() = default;
|
||||
~server_tokens() = default;
|
||||
|
||||
// Prevent copying
|
||||
server_tokens(const server_tokens&) = delete;
|
||||
server_tokens& operator=(const server_tokens&) = delete;
|
||||
|
||||
// Allow moving (usually implicitly generated if members are movable)
|
||||
server_tokens(server_tokens&&) = default;
|
||||
server_tokens& operator=(server_tokens&&) = default;
|
||||
|
||||
// Allow accessing elements using [] operator
|
||||
llama_token operator[](size_t index) { return tokens[index]; }
|
||||
const llama_token& operator[](size_t index) const { return tokens[index]; }
|
||||
|
||||
server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) {
|
||||
for (size_t i = 0; i < mtmd_chunks.size(); ++i) {
|
||||
push_back(mtmd_chunks[i]);
|
||||
}
|
||||
}
|
||||
|
||||
server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
|
||||
|
||||
// for debugging
|
||||
std::string str() const {
|
||||
std::ostringstream oss;
|
||||
oss << "tokens: ";
|
||||
for (const auto & t : tokens) {
|
||||
if (t == LLAMA_TOKEN_NULL) {
|
||||
oss << "<embd> ";
|
||||
} else {
|
||||
oss << t << " ";
|
||||
}
|
||||
}
|
||||
oss << "\n";
|
||||
oss << "image pos: ";
|
||||
for (const auto & it : map_pos_to_image) {
|
||||
oss << it.first << ", ";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
|
||||
auto it = map_pos_to_image.find(pos);
|
||||
if (it != map_pos_to_image.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
throw std::runtime_error("Chunk not found");
|
||||
}
|
||||
}
|
||||
|
||||
void push_back(llama_token tok) {
|
||||
if (tok == LLAMA_TOKEN_NULL) {
|
||||
throw std::runtime_error("Invalid token");
|
||||
}
|
||||
tokens.emplace_back(tok);
|
||||
}
|
||||
|
||||
// will create a copy of the chunk if it contains non-text data
|
||||
void push_back(const mtmd_input_chunk * chunk) {
|
||||
auto type = mtmd_input_chunk_get_type(chunk);
|
||||
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
GGML_ASSERT(has_mtmd);
|
||||
auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
|
||||
llama_pos start_pos = tokens.size();
|
||||
for (int i = 0; i < n_pos; ++i) {
|
||||
tokens.emplace_back(LLAMA_TOKEN_NULL);
|
||||
}
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_pos_to_image[start_pos] = std::move(new_chunk);
|
||||
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens;
|
||||
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||
for (size_t i = 0; i < n_tokens; ++i) {
|
||||
push_back(text_tokens[i]);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("Invalid chunk type");
|
||||
}
|
||||
}
|
||||
|
||||
// for compatibility with context shift and prompt truncation
|
||||
void insert(const llama_tokens & inp_tokens) {
|
||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
||||
tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
|
||||
}
|
||||
|
||||
// for compatibility with speculative decoding, ctx shift, slot save/load
|
||||
const llama_tokens & get_text_tokens() const {
|
||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
||||
return tokens;
|
||||
}
|
||||
|
||||
// for compatibility with speculative decoding
|
||||
void set_token(llama_pos pos, llama_token id) {
|
||||
GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
|
||||
tokens[pos] = id;
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return tokens.empty();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
tokens.clear();
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
GGML_ASSERT(n <= tokens.size());
|
||||
if (has_mtmd) {
|
||||
// we throw an error if we try to remove a token in the middle of an image
|
||||
// for ex. with input of 5 text tokens and 2 images:
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
|
||||
// n 1 2 3 4 5 6 7 8 9 10
|
||||
// allowed to resize ^ ^
|
||||
// disallowed to resize ^ ^ ^
|
||||
if (n > 0) {
|
||||
llama_token last_token = tokens[n - 1];
|
||||
// make sure we never remove tokens in the middle of an image
|
||||
if (last_token == LLAMA_TOKEN_NULL) {
|
||||
find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
|
||||
}
|
||||
}
|
||||
// remove all image chunks that are not used anymore
|
||||
for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
|
||||
llama_pos pos = it->first;
|
||||
if (pos >= (llama_pos)n) {
|
||||
it = map_pos_to_image.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
tokens.resize(n);
|
||||
}
|
||||
|
||||
std::string detokenize(const llama_context * ctx, bool special) const {
|
||||
llama_tokens text_tokens;
|
||||
text_tokens.reserve(tokens.size());
|
||||
for (const auto & t : tokens) {
|
||||
if (t != LLAMA_TOKEN_NULL) {
|
||||
text_tokens.push_back(t);
|
||||
}
|
||||
}
|
||||
return common_detokenize(ctx, text_tokens, special);
|
||||
}
|
||||
|
||||
size_t get_common_prefix(const server_tokens & b) const {
|
||||
size_t max_idx = std::min(tokens.size(), b.tokens.size());
|
||||
for (size_t i = 0; i < max_idx; ++i) {
|
||||
auto & ai = tokens[i];
|
||||
auto & bi = b.tokens[i];
|
||||
|
||||
if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
|
||||
GGML_ASSERT(has_mtmd);
|
||||
const auto & a_chunk = find_chunk(i);
|
||||
const auto & b_chunk = b.find_chunk(i);
|
||||
GGML_ASSERT(a_chunk && b_chunk);
|
||||
const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
|
||||
const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
|
||||
std::string ai_id = mtmd_image_tokens_get_id(a_img);
|
||||
std::string bi_id = mtmd_image_tokens_get_id(b_img);
|
||||
size_t a_pos = mtmd_image_tokens_get_n_pos(a_img);
|
||||
size_t b_pos = mtmd_image_tokens_get_n_pos(b_img);
|
||||
if (ai_id == bi_id && a_pos == b_pos) {
|
||||
GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
|
||||
i += a_pos - 1; // will be +1 by the for loop
|
||||
continue;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
} else if (ai == bi) {
|
||||
continue;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return max_idx; // all tokens are equal
|
||||
}
|
||||
|
||||
// make sure all text tokens are within the vocab range
|
||||
bool validate(const struct llama_context * ctx) const {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); ++i) {
|
||||
auto & t = tokens[i];
|
||||
if (t == LLAMA_TOKEN_NULL) {
|
||||
try {
|
||||
const auto & chunk = find_chunk(i);
|
||||
const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
|
||||
size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
|
||||
i += n_pos - 1; // will be +1 by the for loop
|
||||
} catch (const std::exception & e) {
|
||||
return false;
|
||||
}
|
||||
} else if (t < 0 || t >= n_vocab) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// encode and decode the image chunk
|
||||
int32_t process_chunk(
|
||||
llama_context * ctx,
|
||||
mtmd_context * mctx,
|
||||
llama_pos n_past,
|
||||
int32_t seq_id,
|
||||
llama_pos & n_pos_out) {
|
||||
auto it = map_pos_to_image.find(n_past);
|
||||
if (it == map_pos_to_image.end()) {
|
||||
throw std::runtime_error("Chunk not found");
|
||||
}
|
||||
SRV_INF("%s\n", "processing image...");
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int64_t t0 = ggml_time_ms();
|
||||
llama_pos new_n_past = n_past;
|
||||
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
|
||||
it->second.get(), // chunk
|
||||
n_past,
|
||||
seq_id,
|
||||
n_batch,
|
||||
true, // logits last
|
||||
&new_n_past);
|
||||
SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
if (result != 0) {
|
||||
LOG_ERR("mtmd_helper_eval failed with status %d", result);
|
||||
n_pos_out = n_past;
|
||||
return result;
|
||||
}
|
||||
n_pos_out = new_n_past;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Computes FNV-1a hash of the data
|
||||
static std::string fnv_hash(const uint8_t * data, size_t len) {
|
||||
const uint64_t fnv_prime = 0x100000001b3ULL;
|
||||
uint64_t hash = 0xcbf29ce484222325ULL;
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
hash ^= data[i];
|
||||
hash *= fnv_prime;
|
||||
}
|
||||
return std::to_string(hash);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,19 +37,26 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
|||
break;
|
||||
}
|
||||
|
||||
if (mimeType.startsWith('image/') && mimeType !== 'image/svg+xml') {
|
||||
if (!serverProps?.has_multimodal) {
|
||||
if (mimeType.startsWith('image/')) {
|
||||
if (!serverProps?.modalities?.vision) {
|
||||
toast.error('Multimodal is not supported by this server or model.');
|
||||
break;
|
||||
}
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
reader.onload = async (event) => {
|
||||
if (event.target?.result) {
|
||||
let base64Url = event.target.result as string;
|
||||
|
||||
if (mimeType === 'image/svg+xml') {
|
||||
// Convert SVG to PNG
|
||||
base64Url = await svgBase64UrlToPngDataURL(base64Url);
|
||||
}
|
||||
|
||||
addItems([
|
||||
{
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url: event.target.result as string,
|
||||
base64Url,
|
||||
},
|
||||
]);
|
||||
}
|
||||
|
|
@ -172,3 +179,56 @@ export function isLikelyNotBinary(str: string): boolean {
|
|||
const ratio = suspiciousCharCount / sampleLength;
|
||||
return ratio <= options.suspiciousCharThresholdRatio;
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
// Converts a Base64URL encoded SVG string to a PNG Data URL using browser Canvas API.
|
||||
function svgBase64UrlToPngDataURL(base64UrlSvg: string): Promise<string> {
|
||||
const backgroundColor = 'white'; // Default background color for PNG
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
const img = new Image();
|
||||
|
||||
img.onload = () => {
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
|
||||
if (!ctx) {
|
||||
reject(new Error('Failed to get 2D canvas context.'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Use provided dimensions or SVG's natural dimensions, with fallbacks
|
||||
// Fallbacks (e.g., 300x300) are for SVGs without explicit width/height
|
||||
// or when naturalWidth/Height might be 0 before full processing.
|
||||
const targetWidth = img.naturalWidth || 300;
|
||||
const targetHeight = img.naturalHeight || 300;
|
||||
|
||||
canvas.width = targetWidth;
|
||||
canvas.height = targetHeight;
|
||||
|
||||
if (backgroundColor) {
|
||||
ctx.fillStyle = backgroundColor;
|
||||
ctx.fillRect(0, 0, canvas.width, canvas.height);
|
||||
}
|
||||
|
||||
ctx.drawImage(img, 0, 0, targetWidth, targetHeight);
|
||||
resolve(canvas.toDataURL('image/png'));
|
||||
};
|
||||
|
||||
img.onerror = () => {
|
||||
reject(
|
||||
new Error('Failed to load SVG image. Ensure the SVG data is valid.')
|
||||
);
|
||||
};
|
||||
|
||||
// Load SVG string into an Image element
|
||||
img.src = base64UrlSvg;
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const errorMessage = `Error converting SVG to PNG: ${message}`;
|
||||
toast.error(errorMessage);
|
||||
reject(new Error(errorMessage));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -118,6 +118,8 @@ export interface LlamaCppServerProps {
|
|||
build_info: string;
|
||||
model_path: string;
|
||||
n_ctx: number;
|
||||
has_multimodal: boolean;
|
||||
modalities?: {
|
||||
vision: boolean;
|
||||
};
|
||||
// TODO: support params
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue