|
#ifndef GGML_METAL_IMPL |
|
#define GGML_METAL_IMPL |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Parameter record `ggml_metal_kargs_concat`.
//
// Three groups of four int32_t "ne" fields, each followed by four uint64_t
// "nb" fields (suffixes 00-03, 10-13 and 0-3), plus a trailing int32_t `dim`.
//
// NOTE(review): the naming suggests kernel arguments for a Metal "concat" op,
// with ne* = per-dimension element counts and nb* = per-dimension byte strides
// for src0 / src1 / dst, and `dim` = the axis being concatenated. None of that
// is visible in this file -- confirm against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  dim;
} ggml_metal_kargs_concat;
|
|
|
// Parameter record `ggml_metal_kargs_bin`.
//
// Three groups of four int32_t "ne" fields, each followed by four uint64_t
// "nb" fields (suffixes 00-03, 10-13 and 0-3), plus a trailing uint64_t
// `offs`.
//
// NOTE(review): the naming suggests kernel arguments for Metal binary
// element-wise ops, with ne* = per-dimension element counts and nb* =
// per-dimension byte strides for src0 / src1 / dst, and `offs` = some byte
// offset. None of that is visible in this file -- confirm against the caller.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    uint64_t offs;
} ggml_metal_kargs_bin;
|
|
|
// Parameter record `ggml_metal_kargs_repeat`.
//
// Two groups of four int32_t "ne" fields, each followed by four uint64_t "nb"
// fields (suffixes 00-03 and 0-3).
//
// NOTE(review): the naming suggests kernel arguments for a Metal "repeat" op,
// with ne* = per-dimension element counts and nb* = per-dimension byte strides
// for the source (00-03) and destination (0-3). Not visible in this file --
// confirm against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_repeat;
|
|
|
// Parameter record `ggml_metal_kargs_cpy`.
//
// Two groups of four "ne" fields, each followed by four uint64_t "nb" fields
// (suffixes 00-03 and 0-3). Unlike the sibling kargs structs in this header,
// the ne* fields here are int64_t rather than int32_t, so every member is
// 8 bytes wide.
//
// NOTE(review): the naming suggests kernel arguments for a Metal copy op, with
// ne* = per-dimension element counts and nb* = per-dimension byte strides for
// the source (00-03) and destination (0-3). Not visible in this file --
// confirm against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int64_t  ne00;
    int64_t  ne01;
    int64_t  ne02;
    int64_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int64_t  ne0;
    int64_t  ne1;
    int64_t  ne2;
    int64_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
} ggml_metal_kargs_cpy;
|
|
|
// Parameter record `ggml_metal_kargs_rope`.
//
// Two groups of four int32_t "ne" fields, each followed by four uint64_t "nb"
// fields (suffixes 00-03 and 0-3), then three int32_t scalars (`n_past`,
// `n_dims`, `n_ctx_orig`) and six float scalars.
//
// NOTE(review): the naming suggests kernel arguments for a Metal RoPE (rotary
// position embedding) op, with ne*/nb* = element counts / byte strides for
// source and destination, and the float fields being frequency / scaling
// parameters. Their exact meaning is not visible in this file -- confirm
// against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne0;
    int32_t  ne1;
    int32_t  ne2;
    int32_t  ne3;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;
    int32_t  n_past;
    int32_t  n_dims;
    int32_t  n_ctx_orig;
    float    freq_base;
    float    freq_scale;
    float    ext_factor;
    float    attn_factor;
    float    beta_fast;
    float    beta_slow;
} ggml_metal_kargs_rope;
|
|
|
// Parameter record `ggml_metal_kargs_flash_attn_ext`.
//
// Mixes int32_t "ne" fields, uint64_t "nb" fields, four float scalars
// (`scale`, `max_bias`, `m0`, `m1`), a uint16_t `n_head_log2` and a final
// float `logit_softcap`.
//
// NOTE(review): the naming suggests kernel arguments for a Metal
// flash-attention op. The `_12_` infix (ne_12_2, nb_12_1, ...) looks like one
// value shared between operand 1 and operand 2 -- the struct itself does not
// say so; confirm against the code that fills it.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne01;
    int32_t  ne02;
    int32_t  ne03;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne11;
    int32_t  ne_12_2;
    int32_t  ne_12_3;
    uint64_t nb_12_1;
    uint64_t nb_12_2;
    uint64_t nb_12_3;
    uint64_t nb31;
    int32_t  ne1;
    int32_t  ne2;
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    uint16_t n_head_log2;   // 2-byte member between 4-byte floats: expect padding after it
    float    logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
|
|
|
// Parameter record `ggml_metal_kargs_mul_mm`.
//
// Holds a subset of int32_t "ne" and uint64_t "nb" fields for suffix groups
// 0x, 1x and x, followed by two int16_t values `r2` and `r3`.
//
// NOTE(review): the naming suggests kernel arguments for a Metal
// matrix-matrix multiplication, with ne* = element counts and nb* = byte
// strides; `r2` / `r3` look like broadcast ratios for dimensions 2 and 3.
// None of that is visible in this file -- confirm against the code that fills
// the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mm;
|
|
|
// Parameter record `ggml_metal_kargs_mul_mv`.
//
// Holds int32_t "ne" and uint64_t "nb" fields for suffix groups 0x and 1x,
// the int32_t pair `ne0` / `ne1`, and two int16_t values `r2` and `r3`.
//
// NOTE(review): the naming suggests kernel arguments for a Metal
// matrix-vector multiplication, with ne* = element counts and nb* = byte
// strides; `r2` / `r3` look like broadcast ratios for dimensions 2 and 3.
// None of that is visible in this file -- confirm against the code that fills
// the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
} ggml_metal_kargs_mul_mv;
|
|
|
// Parameter record `ggml_metal_kargs_mul_mv_ext`.
//
// Same leading members as `ggml_metal_kargs_mul_mv` (ne00 .. r3, in the same
// order and with the same types), followed by three more int16_t values:
// `nsg`, `nxpsg` and `r1ptg`.
//
// NOTE(review): the naming suggests kernel arguments for an extended Metal
// matrix-vector multiplication; the three trailing fields look like launch /
// tiling parameters. Their meaning is not visible in this file -- confirm
// against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    int32_t  ne0;
    int32_t  ne1;
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;
    int16_t  nxpsg;
    int16_t  r1ptg;
} ggml_metal_kargs_mul_mv_ext;
|
|
|
// Parameter record `ggml_metal_kargs_mul_mm_id`.
//
// Starts with an "i" group (`nei0`, `nei1`, `nbi1`), then a subset of int32_t
// "ne" and uint64_t "nb" fields for suffix groups 0x and 1x, and ends with the
// int32_t pair `ne0` / `ne1`.
//
// NOTE(review): the naming suggests kernel arguments for an indexed ("_id")
// Metal matrix-matrix multiplication, where the nei*/nbi* fields describe an
// index tensor. None of that is visible in this file -- confirm against the
// code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
} ggml_metal_kargs_mul_mm_id;
|
|
|
// Parameter record `ggml_metal_kargs_mul_mv_id`.
//
// Starts with an "i" group (`nei0`, `nei1`, `nbi1`), then int32_t "ne" and
// uint64_t "nb" fields for suffix groups 0x and 1x, the int32_t pair
// `ne0` / `ne1`, and a trailing uint64_t `nb1`.
//
// NOTE(review): the naming suggests kernel arguments for an indexed ("_id")
// Metal matrix-vector multiplication, where the nei*/nbi* fields describe an
// index tensor. None of that is visible in this file -- confirm against the
// code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  nei0;
    int32_t  nei1;
    uint64_t nbi1;
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    int32_t  ne0;
    int32_t  ne1;
    uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
|
|
|
// Parameter record `ggml_metal_kargs_norm`.
//
// Two int32_t values (`ne00`, `ne00_4`), one uint64_t `nb01` and one float
// `eps`.
//
// NOTE(review): the naming suggests kernel arguments for a Metal "norm" op:
// `ne00` looks like a row length, `ne00_4` like that length in units of four
// elements, `nb01` like a row stride in bytes and `eps` like the usual
// numerical-stability epsilon. None of that is visible in this file --
// confirm against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_norm;
|
|
|
// Parameter record `ggml_metal_kargs_rms_norm`.
//
// Two int32_t values (`ne00`, `ne00_4`), one uint64_t `nb01` and one float
// `eps` -- the same member list as `ggml_metal_kargs_norm`, kept as a
// distinct type.
//
// NOTE(review): the naming suggests kernel arguments for a Metal RMS-norm op:
// `ne00` looks like a row length, `ne00_4` like that length in units of four
// elements, `nb01` like a row stride in bytes and `eps` like the usual
// numerical-stability epsilon. None of that is visible in this file --
// confirm against the code that fills the struct.
// NOTE(review): presumably the byte layout is shared with GPU-side code, so
// keep field order and types exactly as declared -- confirm before reordering.
typedef struct {
    int32_t  ne00;
    int32_t  ne00_4;
    uint64_t nb01;
    float    eps;
} ggml_metal_kargs_rms_norm;
|
|
|
#endif |
|
|