1414
1515using namespace fbgemm_gpu ;
1616
17+ {%- if is_rocm %}
1718// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1)
1819#define GRAD_OFFSET (i, j ) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i);
1920#define L (i, j ) int32_t l_j_##i = SHFL_SYNC(l, j + i);
@@ -105,6 +106,7 @@ using namespace fbgemm_gpu;
105106 {%- endif %}
106107 } \
107108 }
109+ {%- endif %}
108110
109111{%- if gen_once %}
110112{#- /*
@@ -235,6 +237,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
235237 {%- endif %}
236238 int32_t j = 0 ;
237239
240+ {%- if is_rocm %}
238241 // Process blocks of different sizes with loop unrolling
239242 if constexpr (sizeof (grad_t ) <= 2 ) {
240243 PROCESS_BLOCK (8 , kFixedMaxVecsPerThread , grad_sum, grad_output, grad_offset, \
@@ -246,6 +249,50 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
246249 vec_start, kThreadGroupSize , threadIdx .x , VEC_WIDTH, D, j, sl, sl_end)
247250 PROCESS_BLOCK (1 , kFixedMaxVecsPerThread , grad_sum, grad_output, grad_offset, \
248251 vec_start, kThreadGroupSize , threadIdx .x , VEC_WIDTH, D, j, sl, sl_end)
252+
253+ #undef PROCESS_BLOCK
254+
255+ {%- else %}
256+ for (; j < kThreadGroupSize && sl + j < sl_end; ++j) {
257+ {%- if nobag %}
258+ int32_t l_j = SHFL_SYNC (l, j);
259+ {%- elif vbe %}
260+ const auto grad_offset_j = SHFL_SYNC (grad_offset, j);
261+ {%- else %}
262+ int32_t b_j = SHFL_SYNC (b, j);
263+ int32_t D_start_j = SHFL_SYNC (D_start, j);
264+ {%- endif %}
265+
266+ {%- if weighted %}
267+ at::acc_type<cache_t , true > idx_weight_j = SHFL_SYNC (idx_weight, j);
268+ {%- endif %}
269+
270+ {%- set d = " (((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
271+
272+ #pragma unroll kFixedMaxVecsPerThread
273+ for (int32_t vec = 0 ; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
274+ const int32_t d = {{ d }};
275+ Vec4TAcc<grad_t > grad_out_vec (
276+ {%- if nobag and is_index_select %}
277+ // grad_output is 1d
278+ &grad_output[grad_offset + l_j * grad_stride + d]
279+ {%- elif nobag %}
280+ &grad_output[l_j][d]
281+ {%- elif vbe %}
282+ &grad_output[0 ][grad_offset_j + d]
283+ {%- else %}
284+ &grad_output[b_j][0 ] + D_start_j + d
285+ {%- endif %} // if nobag
286+ );
287+
288+ {%- if weighted %}
289+ grad_sum[vec].fma_ (grad_out_vec, idx_weight_j);
290+ {%- else %}
291+ grad_sum[vec].add_ (grad_out_vec);
292+ {%- endif %}
293+ }
294+ }
295+ {%- endif %}
249296 }
250297 {%- set d_vec = " ((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}
251298
@@ -262,7 +309,6 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
262309 }
263310}
264311
265- #undef PROCESS_BLOCK
266312{%- endif %}
267313
268- // clang-format on
314+ // clang-format on
0 commit comments