ggml-webgpu: Add supports for DIAG and TRI (#20664)
* Add supports for DIAG and TRI. * Remove extra ttype and add a comment for TRI op.
This commit is contained in:
parent
07ba6d275b
commit
ea01d196d7
5 changed files with 77 additions and 21 deletions
|
|
@ -244,13 +244,15 @@ struct ggml_webgpu_binary_pipeline_key_hash {
|
|||
/** Unary **/
|
||||
|
||||
struct ggml_webgpu_unary_pipeline_key {
|
||||
int type;
|
||||
int op;
|
||||
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
||||
bool inplace;
|
||||
int type;
|
||||
int op;
|
||||
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
|
||||
bool inplace;
|
||||
ggml_tri_type ttype; // only used for GGML_OP_TRI
|
||||
|
||||
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
|
||||
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
|
||||
ttype == other.ttype;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -261,6 +263,7 @@ struct ggml_webgpu_unary_pipeline_key_hash {
|
|||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.is_unary);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
ggml_webgpu_hash_combine(seed, key.ttype);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
|
@ -1058,6 +1061,7 @@ class ggml_webgpu_shader_lib {
|
|||
.op = op,
|
||||
.is_unary = is_unary,
|
||||
.inplace = context.inplace,
|
||||
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
|
||||
};
|
||||
|
||||
auto it = unary_pipelines.find(key);
|
||||
|
|
@ -1088,6 +1092,29 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_inplace";
|
||||
}
|
||||
|
||||
if (op == GGML_OP_TRI) {
|
||||
switch (key.ttype) {
|
||||
case GGML_TRI_TYPE_LOWER:
|
||||
defines.push_back("TRI_TYPE_LOWER");
|
||||
variant += "_tri_type_lower";
|
||||
break;
|
||||
case GGML_TRI_TYPE_LOWER_DIAG:
|
||||
defines.push_back("TRI_TYPE_LOWER_DIAG");
|
||||
variant += "_tri_type_lower_diag";
|
||||
break;
|
||||
case GGML_TRI_TYPE_UPPER:
|
||||
defines.push_back("TRI_TYPE_UPPER");
|
||||
variant += "_tri_type_upper";
|
||||
break;
|
||||
case GGML_TRI_TYPE_UPPER_DIAG:
|
||||
defines.push_back("TRI_TYPE_UPPER_DIAG");
|
||||
variant += "_tri_upper_diag";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported ggml_tri_type for unary shader");
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_unary, defines);
|
||||
|
|
|
|||
|
|
@ -2209,6 +2209,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_TRI:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_PAD:
|
||||
return ggml_webgpu_pad(ctx, src0, node);
|
||||
|
|
@ -3201,6 +3203,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
case GGML_OP_COS:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_PAD:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ enable f16;
|
|||
#define TYPE f32
|
||||
#endif
|
||||
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<TYPE>;
|
||||
|
||||
|
|
@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
return;
|
||||
}
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||
let i2 = i / (params.ne1 * params.ne0);
|
||||
i = i % (params.ne1 * params.ne0);
|
||||
let i1 = i / params.ne0;
|
||||
let i0 = i % params.ne0;
|
||||
let ne2 = params.ne2;
|
||||
#ifdef DIAG
|
||||
let ne1 = params.ne0;
|
||||
#else
|
||||
let ne1 = params.ne1;
|
||||
#endif
|
||||
let ne0 = params.ne0;
|
||||
|
||||
let i3 = i / (ne2 * ne1 * ne0);
|
||||
i = i % (ne2 * ne1 * ne0);
|
||||
let i2 = i / (ne1 * ne0);
|
||||
i = i % (ne1 * ne0);
|
||||
let i1 = i / ne0;
|
||||
let i0 = i % ne0;
|
||||
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
|
|
@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
|
||||
let res = TYPE(res_f32);
|
||||
#endif
|
||||
#ifdef DIAG
|
||||
let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
|
||||
#endif
|
||||
#ifdef TRI
|
||||
#ifdef TRI_TYPE_LOWER
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
|
||||
#elif TRI_TYPE_LOWER_DIAG
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
|
||||
#elif TRI_TYPE_UPPER
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
|
||||
#elif TRI_TYPE_UPPER_DIAG
|
||||
let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef INPLACE
|
||||
src[params.offset_src + src_idx] = res;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue