 #include "norm.cuh"
+#include <cstdint>
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
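+    // src0 may be non-contiguous: x is offset by per-sample/channel/row element strides, dst is written contiguously.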
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float2 mean_var = make_float2(0.0f, 0.0f);
 
@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 
 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
 
-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
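+    // One thread block per row; gridDim.y and gridDim.z cover channels and samples.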
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
 
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
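+    // GGML_TENSOR_UNARY_OP_LOCALS defines the ne00..ne03/nb00..nb03 shape and byte-stride locals for src0 (and ne/nb locals for dst) used below.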
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
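+    // Convert the nb01/nb02/nb03 byte strides to element strides for the kernel; elements within a row must be contiguous (nb00 == ts0).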
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {