@@ -237,17 +237,13 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
237237
238238 // Process blocks of different sizes with loop unrolling
239239 if constexpr (sizeof (grad_t ) <= 2 ) {
240- #pragma unroll kFixedMaxVecsPerThread
241240 PROCESS_BLOCK (8 , kFixedMaxVecsPerThread , grad_sum, grad_output, grad_offset, \
242241 vec_start, kThreadGroupSize , threadIdx .x , VEC_WIDTH, D, j, sl, sl_end)
243242 }
244- #pragma unroll kFixedMaxVecsPerThread
245243 PROCESS_BLOCK (4 , kFixedMaxVecsPerThread , grad_sum, grad_output, grad_offset, \
246244 vec_start, kThreadGroupSize , threadIdx .x , VEC_WIDTH, D, j, sl, sl_end)
247- #pragma unroll kFixedMaxVecsPerThread
248245 PROCESS_BLOCK (2 , kFixedMaxVecsPerThread , grad_sum, grad_output, grad_offset, \
249246 vec_start, kThreadGroupSize , threadIdx .x , VEC_WIDTH, D, j, sl, sl_end)
250- #pragma unroll kFixedMaxVecsPerThread
251247 PROCESS_BLOCK (1 , kFixedMaxVecsPerThread , grad_sum, grad_output, grad_offset, \
252248 vec_start, kThreadGroupSize , threadIdx .x , VEC_WIDTH, D, j, sl, sl_end)
253249 }
@@ -266,6 +262,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
266262 }
267263}
268264
265+ #undef PROCESS_BLOCK
269266{%- endif %}
270267
271268 // clang-format on
0 commit comments