Commit 08e69c5

cuda : adapt soft_max to F16 mask and pos
1 parent 3e318e7 commit 08e69c5

ggml-cuda/softmax.cu

Lines changed: 7 additions & 7 deletions
@@ -1,7 +1,7 @@
 #include "softmax.cuh"

 template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

     const int tid = threadIdx.x;
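
Note: the kernel keeps all arithmetic in F32 and only narrows the storage of mask and pos to F16, widening each element with __half2float at the point of use. A minimal standalone sketch of that access pattern (kernel name and shape are illustrative, not from this commit):

#include <cuda_fp16.h>

// Read an F16 bias buffer and accumulate in F32: storage is half,
// arithmetic is float, mirroring the signature change above.
__global__ void apply_f16_bias(const float * x, const half * bias, float * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = x[i] + (bias ? __half2float(bias[i]) : 0.0f);
    }
}
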
@@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;

-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
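
The slope factor applied to pos[col] is the per-head ALiBi slope. A hedged reconstruction of how the kernel's m0/m1/n_head_log2 parameters produce it (the standard ALiBi scheme used in ggml; this helper is not part of the hunk above):

#include <cstdint>

// Hypothetical helper: per-head ALiBi slope from the kernel parameters.
__device__ float alibi_slope(const float max_bias, const int head, const uint32_t n_head_log2, const float m0, const float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // no positional bias requested
    }
    const float base = head < (int) n_head_log2 ? m0 : m1;
    const int   e    = head < (int) n_head_log2 ? head + 1 : 2*(head - (int) n_head_log2) + 1;
    return powf(base, e);
}
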
@@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
     }
 }

-static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
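
The launch width doubles from the warp size until it covers the row or hits the block-size cap. A worked example with assumed constants (WARP_SIZE = 32, CUDA_SOFT_MAX_BLOCK_SIZE = 1024):

#include <cstdio>

int main() {
    const int ncols_x = 100; // assumed row width
    int nth = 32;            // WARP_SIZE
    while (nth < ncols_x && nth < 1024) nth *= 2; // 32 -> 64 -> 128
    printf("block width = %d\n", nth); // 128: first power-of-two width >= 100
    return 0;
}
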
@@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const float * src0_d = (const float *)src0->data;
-    const float * src1_d = src1 ? (const float *)src1->data : nullptr;
+    const half * src1_d = src1 ? (const half *)src1->data : nullptr;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional

     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
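
With the assertion now requiring GGML_TYPE_F16, callers must allocate the optional mask as F16 rather than F32. A hedged host-side sketch (ggml_new_tensor_2d, ggml_fp32_to_fp16 and ggml_nelements are real ggml API; the helper name, shape and fill value are assumptions for illustration):

#include "ggml.h"

// Hypothetical helper, not from this commit: builds an F16 soft_max mask.
static struct ggml_tensor * build_f16_mask(struct ggml_context * ctx, int64_t n_kv, int64_t n_tokens) {
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, n_tokens);

    ggml_fp16_t * data = (ggml_fp16_t *) mask->data;
    for (int64_t i = 0; i < ggml_nelements(mask); ++i) {
        data[i] = ggml_fp32_to_fp16(0.0f); // 0.0f = visible; -INFINITY masks a position
    }
    return mask;
}
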
@@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

     // positions tensor
-    float * src2_dd = nullptr;
+    half * src2_dd = nullptr;

     ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;

     if (use_src2) {
-        src2_dd = (float *)src2->data;
+        src2_dd = (half *)src2->data;
     }

     soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
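
Storing the mask in half precision is safe for its usual contents: both 0.0f and -INFINITY are exactly representable in F16, so the F32 -> F16 -> F32 round trip is lossless for them. A small check (assumes a recent CUDA toolkit where the cuda_fp16 conversions are host-callable):

#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>

int main() {
    printf("%f\n", __half2float(__float2half(0.0f)));      // 0.000000
    printf("%f\n", __half2float(__float2half(-INFINITY))); // -inf
    return 0;
}
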
