fix(nitro): Update metal file from llama.cpp f56e1ba (#293)
This commit is contained in:
parent
95940e7eb8
commit
780ec5c3fd
@ -24,12 +24,59 @@ typedef struct {
|
|||||||
int8_t qs[QK8_0]; // quants
|
int8_t qs[QK8_0]; // quants
|
||||||
} block_q8_0;
|
} block_q8_0;
|
||||||
|
|
||||||
|
// general-purpose kernel for addition of two tensors
|
||||||
|
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
|
||||||
|
// cons: not very efficient
|
||||||
kernel void kernel_add(
|
kernel void kernel_add(
|
||||||
device const float4 * src0,
|
device const char * src0,
|
||||||
device const float4 * src1,
|
device const char * src1,
|
||||||
device float4 * dst,
|
device char * dst,
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
constant int64_t & ne00,
|
||||||
dst[tpig] = src0[tpig] + src1[tpig];
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant int64_t & ne03,
|
||||||
|
constant int64_t & nb00,
|
||||||
|
constant int64_t & nb01,
|
||||||
|
constant int64_t & nb02,
|
||||||
|
constant int64_t & nb03,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne11,
|
||||||
|
constant int64_t & ne12,
|
||||||
|
constant int64_t & ne13,
|
||||||
|
constant int64_t & nb10,
|
||||||
|
constant int64_t & nb11,
|
||||||
|
constant int64_t & nb12,
|
||||||
|
constant int64_t & nb13,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & ne2,
|
||||||
|
constant int64_t & ne3,
|
||||||
|
constant int64_t & nb0,
|
||||||
|
constant int64_t & nb1,
|
||||||
|
constant int64_t & nb2,
|
||||||
|
constant int64_t & nb3,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = tgpig.z;
|
||||||
|
const int64_t i02 = tgpig.y;
|
||||||
|
const int64_t i01 = tgpig.x;
|
||||||
|
|
||||||
|
const int64_t i13 = i03 % ne13;
|
||||||
|
const int64_t i12 = i02 % ne12;
|
||||||
|
const int64_t i11 = i01 % ne11;
|
||||||
|
|
||||||
|
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
|
||||||
|
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
||||||
|
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
||||||
|
|
||||||
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||||
|
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
|
||||||
|
|
||||||
|
src0_ptr += ntg.x*nb00;
|
||||||
|
src1_ptr += ntg.x*nb10;
|
||||||
|
dst_ptr += ntg.x*nb0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assumption: src1 is a row
|
// assumption: src1 is a row
|
||||||
@ -38,7 +85,7 @@ kernel void kernel_add_row(
|
|||||||
device const float4 * src0,
|
device const float4 * src0,
|
||||||
device const float4 * src1,
|
device const float4 * src1,
|
||||||
device float4 * dst,
|
device float4 * dst,
|
||||||
constant int64_t & nb,
|
constant int64_t & nb [[buffer(27)]],
|
||||||
uint tpig[[thread_position_in_grid]]) {
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||||
}
|
}
|
||||||
@ -783,7 +830,9 @@ kernel void kernel_alibi_f32(
|
|||||||
constant uint64_t & nb1,
|
constant uint64_t & nb1,
|
||||||
constant uint64_t & nb2,
|
constant uint64_t & nb2,
|
||||||
constant uint64_t & nb3,
|
constant uint64_t & nb3,
|
||||||
constant float & m0,
|
constant float & m0,
|
||||||
|
constant float & m1,
|
||||||
|
constant int & n_heads_log2_floor,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]]) {
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
@ -799,37 +848,73 @@ kernel void kernel_alibi_f32(
|
|||||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||||
|
|
||||||
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
float m_k = pow(m0, i2 + 1);
|
float m_k;
|
||||||
|
if (i2 < n_heads_log2_floor) {
|
||||||
|
m_k = pow(m0, i2 + 1);
|
||||||
|
} else {
|
||||||
|
m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
|
||||||
|
}
|
||||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||||
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
|
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typedef void (rope_t)(
|
||||||
|
device const void * src0,
|
||||||
|
device const int32_t * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant int64_t & ne03,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant uint64_t & nb03,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & ne2,
|
||||||
|
constant int64_t & ne3,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant uint64_t & nb2,
|
||||||
|
constant uint64_t & nb3,
|
||||||
|
constant int & n_past,
|
||||||
|
constant int & n_dims,
|
||||||
|
constant int & mode,
|
||||||
|
constant float & freq_base,
|
||||||
|
constant float & freq_scale,
|
||||||
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
uint3 tptg[[threads_per_threadgroup]],
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]]);
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
kernel void kernel_rope(
|
kernel void kernel_rope(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device float * dst,
|
device const int32_t * src1,
|
||||||
constant int64_t & ne00,
|
device float * dst,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne02,
|
constant int64_t & ne01,
|
||||||
constant int64_t & ne03,
|
constant int64_t & ne02,
|
||||||
constant uint64_t & nb00,
|
constant int64_t & ne03,
|
||||||
constant uint64_t & nb01,
|
constant uint64_t & nb00,
|
||||||
constant uint64_t & nb02,
|
constant uint64_t & nb01,
|
||||||
constant uint64_t & nb03,
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne0,
|
constant uint64_t & nb03,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne2,
|
constant int64_t & ne1,
|
||||||
constant int64_t & ne3,
|
constant int64_t & ne2,
|
||||||
constant uint64_t & nb0,
|
constant int64_t & ne3,
|
||||||
constant uint64_t & nb1,
|
constant uint64_t & nb0,
|
||||||
constant uint64_t & nb2,
|
constant uint64_t & nb1,
|
||||||
constant uint64_t & nb3,
|
constant uint64_t & nb2,
|
||||||
constant int & n_past,
|
constant uint64_t & nb3,
|
||||||
constant int & n_dims,
|
constant int & n_past,
|
||||||
constant int & mode,
|
constant int & n_dims,
|
||||||
constant float & freq_base,
|
constant int & mode,
|
||||||
constant float & freq_scale,
|
constant float & freq_base,
|
||||||
|
constant float & freq_scale,
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint3 tptg[[threads_per_threadgroup]],
|
uint3 tptg[[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
@ -839,7 +924,9 @@ kernel void kernel_rope(
|
|||||||
|
|
||||||
const bool is_neox = mode & 2;
|
const bool is_neox = mode & 2;
|
||||||
|
|
||||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
device const int32_t * pos = src1;
|
||||||
|
|
||||||
|
const int64_t p = pos[i2];
|
||||||
|
|
||||||
const float theta_0 = freq_scale * (float)p;
|
const float theta_0 = freq_scale * (float)p;
|
||||||
const float inv_ndims = -1.f/n_dims;
|
const float inv_ndims = -1.f/n_dims;
|
||||||
@ -851,11 +938,11 @@ kernel void kernel_rope(
|
|||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
const float sin_theta = sin(theta);
|
const float sin_theta = sin(theta);
|
||||||
|
|
||||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
const float x0 = src[0];
|
const T x0 = src[0];
|
||||||
const float x1 = src[1];
|
const T x1 = src[1];
|
||||||
|
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||||
@ -870,8 +957,8 @@ kernel void kernel_rope(
|
|||||||
|
|
||||||
const int64_t i0 = ib*n_dims + ic/2;
|
const int64_t i0 = ib*n_dims + ic/2;
|
||||||
|
|
||||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||||
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
const float x0 = src[0];
|
const float x0 = src[0];
|
||||||
const float x1 = src[n_dims/2];
|
const float x1 = src[n_dims/2];
|
||||||
@ -883,6 +970,9 @@ kernel void kernel_rope(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||||
|
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||||
|
|
||||||
kernel void kernel_cpy_f16_f16(
|
kernel void kernel_cpy_f16_f16(
|
||||||
device const half * src0,
|
device const half * src0,
|
||||||
device half * dst,
|
device half * dst,
|
||||||
@ -1273,8 +1363,8 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||||||
|
|
||||||
float yl[32];
|
float yl[32];
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x3030;
|
//const uint16_t kmask1 = 0x3030;
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
//const uint16_t kmask2 = 0x0f0f;
|
||||||
|
|
||||||
const int tid = tiisg/4;
|
const int tid = tiisg/4;
|
||||||
const int ix = tiisg%4;
|
const int ix = tiisg%4;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user