1
1
#include " softmax.cuh"
2
2
3
3
template <bool vals_smem, int ncols_template, int block_size_template>
4
- static __global__ void soft_max_f32 (const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
4
+ static __global__ void soft_max_f32 (const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
5
5
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
6
6
7
7
const int tid = threadIdx .x ;
@@ -43,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
43
43
const int ix = rowx*ncols + col;
44
44
const int iy = rowy*ncols + col;
45
45
46
- const float val = x[ix]*scale + (mask ? mask[iy] : 0 .0f ) + (pos ? slope*pos[col] : 0 .0f );
46
+ const float val = x[ix]*scale + (mask ? __half2float ( mask[iy]) : 0 .0f ) + (pos ? slope*__half2float ( pos[col]) : 0 .0f );
47
47
48
48
vals[col] = val;
49
49
max_val = max (max_val, val);
@@ -114,7 +114,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
114
114
}
115
115
}
116
116
117
- static void soft_max_f32_cuda (const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
117
+ static void soft_max_f32_cuda (const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
118
118
int nth = WARP_SIZE;
119
119
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2 ;
120
120
const dim3 block_dims (nth, 1 , 1 );
@@ -168,14 +168,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
168
168
const ggml_tensor * src0 = dst->src [0 ];
169
169
const ggml_tensor * src1 = dst->src [1 ];
170
170
const float * src0_d = (const float *)src0->data ;
171
- const float * src1_d = src1 ? (const float *)src1->data : nullptr ;
171
+ const half * src1_d = src1 ? (const half *)src1->data : nullptr ;
172
172
float * dst_d = (float *)dst->data ;
173
173
cudaStream_t stream = ctx.stream ();
174
174
175
175
GGML_ASSERT (src0->type == GGML_TYPE_F32);
176
176
GGML_ASSERT ( dst->type == GGML_TYPE_F32);
177
177
178
- GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F32 ); // src1 contains mask and it is optional
178
+ GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F16 ); // src1 contains mask and it is optional
179
179
180
180
const int64_t ne00 = src0->ne [0 ];
181
181
const int64_t nrows_x = ggml_nrows (src0);
@@ -188,13 +188,13 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
188
188
memcpy (&max_bias, (float *) dst->op_params + 1 , sizeof (float ));
189
189
190
190
// positions tensor
191
- float * src2_dd = nullptr ;
191
+ half * src2_dd = nullptr ;
192
192
193
193
ggml_tensor * src2 = dst->src [2 ];
194
194
const bool use_src2 = src2 != nullptr ;
195
195
196
196
if (use_src2) {
197
- src2_dd = (float *)src2->data ;
197
+ src2_dd = (half *)src2->data ;
198
198
}
199
199
200
200
soft_max_f32_cuda (src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
0 commit comments