Commit d4bfd1b

add rocm for macro
1 parent bc73399 · commit d4bfd1b

File tree: 1 file changed (+48 −2 lines)


fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh

Lines changed: 48 additions & 2 deletions
@@ -14,6 +14,7 @@
 
 using namespace fbgemm_gpu;
 
+{%- if is_rocm %}
 // Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1)
 #define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i);
 #define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i);
@@ -105,6 +106,7 @@ using namespace fbgemm_gpu;
   {%- endif %}
   } \
   }
+{%- endif %}
 
 {%- if gen_once %}
 {#- /*
@@ -235,6 +237,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
   {%- endif %}
   int32_t j = 0;
 
+  {%- if is_rocm %}
   // Process blocks of different sizes with loop unrolling
   if constexpr (sizeof(grad_t) <= 2) {
     PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
@@ -246,6 +249,50 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
                   vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
   PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
                 vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
+
+  #undef PROCESS_BLOCK
+
+  {%- else %}
+  for (; j < kThreadGroupSize && sl + j < sl_end; ++j) {
+    {%- if nobag %}
+    int32_t l_j = SHFL_SYNC(l, j);
+    {%- elif vbe %}
+    const auto grad_offset_j = SHFL_SYNC(grad_offset, j);
+    {%- else %}
+    int32_t b_j = SHFL_SYNC(b, j);
+    int32_t D_start_j = SHFL_SYNC(D_start, j);
+    {%- endif %}
+
+    {%- if weighted %}
+    at::acc_type<cache_t, true> idx_weight_j = SHFL_SYNC(idx_weight, j);
+    {%- endif %}
+
+    {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
+
+    #pragma unroll kFixedMaxVecsPerThread
+    for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
+      const int32_t d = {{ d }};
+      Vec4TAcc<grad_t> grad_out_vec(
+        {%- if nobag and is_index_select %}
+        // grad_output is 1d
+        &grad_output[grad_offset + l_j * grad_stride + d]
+        {%- elif nobag %}
+        &grad_output[l_j][d]
+        {%- elif vbe %}
+        &grad_output[0][grad_offset_j + d]
+        {%- else %}
+        &grad_output[b_j][0] + D_start_j + d
+        {%- endif %} // if nobag
+      );
+
+      {%- if weighted %}
+      grad_sum[vec].fma_(grad_out_vec, idx_weight_j);
+      {%- else %}
+      grad_sum[vec].add_(grad_out_vec);
+      {%- endif %}
+    }
+  }
+  {%- endif %}
   }
   {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}
 
@@ -262,7 +309,6 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
   }
 }
 
-#undef PROCESS_BLOCK
 {%- endif %}
 
-// clang-format on
+// clang-format on
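
For context: the is_rocm branch replaces the per-lane shuffle loop with a PROCESS_BLOCK cascade that handles 8, 4, 2, and finally 1 iterations at a time, with helper macros such as GRAD_OFFSET and L broadcasting a whole block of per-lane values up front (via SHFL_SYNC) before the unrolled body runs. The sketch below is a minimal, host-side illustration of that block-unrolling cascade only, not the actual device code from this template; process_block, the lane_vals array, and the float accumulator are hypothetical stand-ins for PROCESS_BLOCK, SHFL_SYNC, and grad_sum.

// block_unroll_sketch.cpp -- CPU-only illustration of the 8/4/2/1 unrolling
// pattern; all names here are placeholders, not FBGEMM APIs.
#include <cstdio>
#include <vector>

template <int BlockSize>
void process_block(const std::vector<int>& lane_vals,
                   int& j, int group_size, float& acc) {
  while (j + BlockSize <= group_size) {
    // "Broadcast" BlockSize values up front (what GRAD_OFFSET/L generate
    // with SHFL_SYNC in the real kernel).
    int v[BlockSize];
    for (int i = 0; i < BlockSize; ++i) v[i] = lane_vals[j + i];
    // Unrolled body: consume the pre-broadcast values without re-entering
    // the outer per-lane loop header for each element.
    for (int i = 0; i < BlockSize; ++i) acc += static_cast<float>(v[i]);
    j += BlockSize;
  }
}

int main() {
  const int kThreadGroupSize = 13;  // deliberately not a multiple of 8
  std::vector<int> lane_vals(kThreadGroupSize);
  for (int i = 0; i < kThreadGroupSize; ++i) lane_vals[i] = i;

  float acc = 0.f;
  int j = 0;
  // Mirrors the PROCESS_BLOCK(8/4/2/1, ...) sequence in the template:
  // big blocks first, smaller blocks pick up the remainder.
  process_block<8>(lane_vals, j, kThreadGroupSize, acc);
  process_block<4>(lane_vals, j, kThreadGroupSize, acc);
  process_block<2>(lane_vals, j, kThreadGroupSize, acc);
  process_block<1>(lane_vals, j, kThreadGroupSize, acc);

  std::printf("accumulated %.1f over %d lanes\n", acc, kThreadGroupSize);
  return 0;
}

The non-ROCm path shown in the {%- else %} branch keeps the original behavior: one SHFL_SYNC broadcast per lane per loop iteration, with the vector loop over kFixedMaxVecsPerThread doing the actual accumulation into grad_sum.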
