#ifndef GGML_METAL_IMPL
#define GGML_METAL_IMPL

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
//   however, beware of int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t
//
// NOTE(review): field order and types determine each struct's memory layout;
//               presumably this must match what the kernels expect, so do not
//               reorder or retype fields without confirming against the users
//               of these structs.

// kargs: concat (two sources ne0x/nb0x and ne1x/nb1x, one destination ne/nb)
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;

// kargs: bin (same counters/strides as concat, with a trailing `offs` instead of `dim`)
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
} ggml_metal_kargs_bin;

// kargs: repeat (one source ne0x/nb0x, one destination ne/nb)
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;

// kargs: cpy
// note: unlike the other structs in this header, the element counters here are int64_t
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;

// kargs: rope (counters/strides followed by the scalar rope parameters)
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;

// kargs: flash_attn_ext
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2; // assume K and V are same shape
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2;
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;

// kargs: mul_mm
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;

// kargs: mul_mv
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;

// kargs: mul_mv_ext (the mul_mv fields followed by nsg, nxpsg and r1ptg)
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;
    int16_t  nxpsg;
    int16_t  r1ptg;
} ggml_metal_kargs_mul_mv_ext;

// kargs: mul_mm_id
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;

// kargs: mul_mv_id
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;

// kargs: norm
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;

// kargs: rms_norm (same fields as ggml_metal_kargs_norm)
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_rms_norm;

#endif // GGML_METAL_IMPL