
namespace nteffectivetransformer {

// gelu code from
// https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v1/fastertransformer/cuda/cuda_kernels.cu#L26-L45
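// Tanh approximation of GELU:
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// where 0.7978845608028654f below is sqrt(2/pi).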
template <typename T>
__inline__ __device__
T gelu(T x)
{
  float cdf = 0.5f *
    (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

// reduce code from
// https://github.com/NVIDIA/DeepLearningExamples/blob/master/FasterTransformer/v1/fastertransformer/cuda/cuda_kernels.cu#L47-L73

#define FINAL_MASK 0xffffffff
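// All-lanes mask for the __shfl_*_sync warp shuffles used by the reductions
// (warpReduceSum itself is elided in this view; see the link above).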
@@ -53,9 +53,9 @@ template <typename T>
__inline__ __device__
T blockReduceSum(T val)
{
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;
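
  // Two-stage block reduction: warpReduceSum folds each warp's 32 values
  // with shuffles; the tail of this function (elided here; see the
  // FasterTransformer link above) parks each warp's partial in shared[wid]
  // and has the first warp reduce those partials into the block-wide sum.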
  val = warpReduceSum<T>(val);

@@ -71,7 +71,7 @@ T blockReduceSum(T val)
/// ***************************** add_bias + gelu *****************************

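// Fuses the bias add with the GELU activation so each element takes one
// read-modify-write of out[] instead of two separate kernel passes.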
template <typename T>
__global__
void add_bias_act(T* out, const T* bias, int m, int n)
{
  T val, reg_bias;
@@ -112,9 +112,9 @@ template void add_bias_act_kernelLauncher<float>(
/// ************************** add_bias + layer_norm **************************

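// Fused residual + bias + LayerNorm: out = LayerNorm(out + input + bias),
// with blockIdx.x indexing the row of n hidden units.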
template <typename T>
__global__
void add_bias_input_layernorm(
  T* out, const T* input, const T* bias, const T* gamma,
  const T* beta, int m, int n)
{
  int tid = threadIdx.x;
@@ -126,7 +126,7 @@ void add_bias_input_layernorm(

  float local_out = 0.0f;
  for (int i = tid; i < n; i += blockDim.x)
    local_out += (float)(out[blockIdx.x * n + i]
        + input[blockIdx.x * n + i] + __ldg(&bias[i]));

  mean = blockReduceSum<float>(local_out);
@@ -141,14 +141,14 @@ void add_bias_input_layernorm(
  __syncthreads();

  for (int i = tid; i < n; i += blockDim.x)
    out[blockIdx.x * n + i] =
      (T)(((local_out - s_mean) * rsqrtf(s_variance))
        * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i])));
}

template <typename T>
void add_bias_input_layernorm_kernelLauncher(
  T* out, const T* input, const T* bias,
  const T* gamma, const T* beta, int m, int n, cudaStream_t stream)
{
  assert(n < 1024);
@@ -159,28 +159,28 @@ void add_bias_input_layernorm_kernelLauncher(
}

template void add_bias_input_layernorm_kernelLauncher<float>(
  float* out, const float* input,
  const float* bias, const float* gamma, const float* beta,
  int m, int n, cudaStream_t stream);

/// *********************************** fin ***********************************


/// *********************** compress transformer input ***********************

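// Padding removal: mask flags real tokens, prefix_sum (an exclusive scan of
// the mask) assigns each one a slot in the compacted tensor, and batch_idx /
// word_idx record its original coordinates for the later restore step.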
__global__
void compress_bert_input(
  // const T* from_tensor,
  const int* mask, const int* prefix_sum,
  // T* to_tensor,
  int* batch_idx, int* word_idx,
  int batch_size, int seq_len, int hidden_dim)
{
  int bid = blockIdx.y;  // batch
  int wid = blockIdx.x;  // word
  int tid = threadIdx.x; //

  /// 1. count pos for from tensor
  int mask_idx = bid * seq_len + wid;

  if (mask[mask_idx] > 0.5) {
@@ -191,7 +191,7 @@ void compress_bert_input(
    batch_idx[valid_idx] = bid;
    word_idx[valid_idx] = wid;
  }

  // /// 3. copy src data
  // float* src_ptr = (float*)from_tensor;
  // float* dst_ptr = (float*)to_tensor;
@@ -203,10 +203,10 @@ void compress_bert_input(

void compressBertInput_kernelLauncher(
  // const T* from_tensor,
  const int* mask, const int* prefix_sum,
  // T* to_tensor,
  int* batch_idx, int* word_idx,
  int batch_size, int seq_len, int hidden_dim, cudaStream_t stream)
{
  /// TODO: fp32
  dim3 grid(seq_len, batch_size); // one block per (word, batch) position
@@ -215,7 +215,7 @@ void compressBertInput_kernelLauncher(
  assert(hidden_dim <= 1024);
  compress_bert_input<<<grid, block, 0, stream>>>(
    // from_tensor,
    mask, prefix_sum,
    // to_tensor,
    batch_idx, word_idx,
    batch_size, seq_len, hidden_dim);
@@ -229,11 +229,11 @@ template<typename T>
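// Scatters compacted rows back to their padded [batch, seq_len, hidden]
// positions using the batch_idx / word_idx saved during compression.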
__global__
void restore_bert_output(
  T* to_tensor,
  const T* from_tensor, const int* batch_idx, const int* word_idx,
  int valid_word_num, int seq_len, int hidden_dim)
{
  int bid = batch_idx[blockIdx.x];
  int wid = word_idx[blockIdx.x];
  int tid = threadIdx.x;
  int vid = blockIdx.x;

@@ -248,24 +248,24 @@ void restore_bert_output(
template <typename T>
void restoreBertOutput_kernelLauncher(
  T* to_tensor,
  const T* from_tensor, const int* batch_idx, const int* word_idx,
  int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream)
{
  // TODO: fp32
  dim3 grid(valid_word_num);
  dim3 block(hidden_dim);
  assert(hidden_dim <= 1024);
  restore_bert_output<<<grid, block, 0, stream>>>(
    to_tensor,
    from_tensor, batch_idx, word_idx,
    valid_word_num, seq_len, hidden_dim);
}

template void restoreBertOutput_kernelLauncher<float>(
  float* to_tensor,
  const float* from_tensor, const int* batch_idx, const int* word_idx,
  int valid_word_num, int seq_len, int hidden_dim, cudaStream_t stream);

/// *********************************** fin ***********************************

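// A minimal sketch of how these launchers chain together on the host side
// (buffer names d_mask, d_prefix_sum, d_batch_idx, d_word_idx, d_out_* and
// the surrounding allocation/encoder code are assumptions, not part of this
// file):
//
//   // 1. exclusive prefix sum over the 0/1 attention mask
//   exclusiveScan_kernelLauncher(d_prefix_sum, d_mask,
//                                batch_size * seq_len, stream);
//   // 2. record batch/word coordinates of every non-padding token
//   compressBertInput_kernelLauncher(d_mask, d_prefix_sum,
//                                    d_batch_idx, d_word_idx,
//                                    batch_size, seq_len, hidden_dim, stream);
//   // 3. ... run the transformer layers on the compacted tensor ...
//   // 4. scatter results back to the padded [batch, seq_len, hidden] layout
//   restoreBertOutput_kernelLauncher(d_out_padded, d_out_compact,
//                                    d_batch_idx, d_word_idx,
//                                    valid_word_num, seq_len, hidden_dim,
//                                    stream);
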
/// ***************************** exclusive scan ******************************
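// Work-efficient Blelloch exclusive scan (up-sweep builds partial sums in
// shared memory, the root is zeroed, the down-sweep distributes prefixes);
// used here to turn the attention mask into compaction offsets.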
@@ -279,14 +279,14 @@ int ELEMENTS_PER_BLOCK = THREADS_PER_BLOCK * 2;
#define LOG_MEM_BANKS 5
#define CONFLICT_FREE_OFFSET(n) ((n) >> LOG_MEM_BANKS)
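// Shared memory has 32 (= 2^LOG_MEM_BANKS) banks; padding every index by
// index/32 keeps the power-of-two-strided tree accesses on distinct banks.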

__global__ void prescan_large(int *output, const int *input, int n, int *sums)
{
  extern __shared__ int temp[];

  int blockID = blockIdx.x;
  int threadID = threadIdx.x;
  int blockOffset = blockID * n;

  int ai = threadID;
  int bi = threadID + (n / 2);
  int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
@@ -312,11 +312,11 @@ __global__ void prescan_large(int *output, const int *input, int n, int *sums)
  __syncthreads();

  if (threadID == 0) {
    // save this block's total for the second-level scan, then clear the
    // root so the down-sweep produces an exclusive scan
    sums[blockID] = temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)];
    temp[n - 1 + CONFLICT_FREE_OFFSET(n - 1)] = 0;
  }

  for (int d = 1; d < n; d *= 2) // traverse down tree & build scan
  {
    offset >>= 1;
@@ -350,7 +350,7 @@ __global__ void prescan_arbitrary(
  int bankOffsetA = CONFLICT_FREE_OFFSET(ai);
  int bankOffsetB = CONFLICT_FREE_OFFSET(bi);

  if (threadID < n) {
    temp[ai + bankOffsetA] = input[ai];
    temp[bi + bankOffsetB] = input[bi];
@@ -359,11 +359,11 @@ __global__ void prescan_arbitrary(
    temp[ai + bankOffsetA] = 0;
    temp[bi + bankOffsetB] = 0;
  }

  int offset = 1;
  // build sum in place up the tree
  for (int d = powerOfTwo >> 1; d > 0; d >>= 1)
  {
    __syncthreads();
    if (threadID < d)
@@ -380,7 +380,7 @@ __global__ void prescan_arbitrary(

  if (threadID == 0) {
    // clear the last element
    temp[powerOfTwo - 1 + CONFLICT_FREE_OFFSET(powerOfTwo - 1)] = 0;
  }

  for (int d = 1; d < powerOfTwo; d *= 2) // traverse down tree & build scan
@@ -435,15 +435,15 @@ int nextPowerOfTwo(int x) {
void scanSmallDeviceArray(
  int *d_out, const int *d_in, const int length, const cudaStream_t stream);
void scanLargeDeviceArray(
  int *d_out, const int *d_in, const int length, int *d_buf,
  const cudaStream_t stream);
void scanLargeEvenDeviceArray(
  int *d_out, const int *d_in, const int length, int *d_buf,
  const cudaStream_t stream);

void scanLargeEvenDeviceArray(
  int *d_out, const int *d_in, const int length, int *d_buf,
  const cudaStream_t stream)
{
  const int blocks = length / ELEMENTS_PER_BLOCK;
  const int sharedMemArraySize = ELEMENTS_PER_BLOCK * sizeof(int);
@@ -471,18 +471,18 @@ void scanLargeEvenDeviceArray(
}

void scanSmallDeviceArray(
  int *d_out, const int *d_in, const int length, const cudaStream_t stream)
{
  int powerOfTwo = nextPowerOfTwo(length);
  prescan_arbitrary
    <<<1, (length + 1) / 2, 2 * powerOfTwo * sizeof(int), stream>>>(
      d_out, d_in, length, powerOfTwo);
}

///
void scanLargeDeviceArray(
  int *d_out, const int *d_in, const int length, int *d_buf,
  const cudaStream_t stream)
{
  int remainder = length % (ELEMENTS_PER_BLOCK);
  if (remainder == 0) {
@@ -493,20 +493,20 @@ void scanLargeDeviceArray(
    int lengthMultiple = length - remainder;
    scanLargeEvenDeviceArray(d_out, d_in, lengthMultiple, d_buf, stream);

    // scan the remaining elements and add the (inclusive)
    // last element of the large scan to this
    int *startOfOutputArray = &(d_out[lengthMultiple]);
    scanSmallDeviceArray(
      startOfOutputArray, &(d_in[lengthMultiple]), remainder, stream);

    add<<<1, remainder, 0, stream>>>(
      startOfOutputArray, remainder, &(d_in[lengthMultiple - 1]),
      &(d_out[lengthMultiple - 1]));
  }
}
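// Note: for length > ELEMENTS_PER_BLOCK the launcher parks per-block sums at
// d_out + length, so the d_out allocation must include that scratch space.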
void exclusiveScan_kernelLauncher(
  int *d_out, const int *d_in, const int length, const cudaStream_t stream)
{
  if (length > ELEMENTS_PER_BLOCK) {
    scanLargeDeviceArray(d_out, d_in, length, d_out + length, stream);