Commits
63 commits
523a317
Add gfx950 build support + fp16 fix + index type fix
avbokovoy Jul 29, 2025
aee3078
Change int64_t to index_t as template parameters in load_raw_per_warp
avbokovoy Jul 29, 2025
5a1ac2e
Implement llvm fp16 buffer load for gfx950
avbokovoy Jul 29, 2025
7856903
Fix c-style half to float cast
avbokovoy Aug 11, 2025
e1e246a
Patch 256 half stores
avbokovoy Aug 11, 2025
6a99fe0
cta_per_row workgroup optim
shbiswas834 Aug 8, 2025
349a7b5
Added mi350 guards
shbiswas834 Aug 11, 2025
1178cd1
Fix index overflow in row load
shbiswas834 Aug 12, 2025
606ad34
cta_per_row workgroup reduce by 4 optim
shbiswas834 Aug 12, 2025
a22ddeb
Fix mixed_D frontend to backend connection
avbokovoy Aug 13, 2025
6775452
changed max_segment_length_per_cta to 4096
kudomcho Aug 15, 2025
90e6ba7
added rocm guards and removed comment
shbiswas834 Aug 18, 2025
a9073ac
clean debug statements in Hip.cmake
liligwu Aug 20, 2025
9a16e12
Merge pull request #121
shbiswas834 Aug 28, 2025
68630da
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
bac0610
fix the bug in dimension 160 in ROCm optimization
liligwu Sep 18, 2025
a12112f
Cleanup optimized warp_per_row kernel
avbokovoy Aug 19, 2025
3ef64f7
Add 320 embedding dim support for optimized warp_per_row kernel
avbokovoy Aug 20, 2025
f601e55
changed the max length per warp and cta per row WG size
Sep 8, 2025
04916da
added DPP and changed max length per warp to 16k
kudomcho Sep 9, 2025
1e09555
guard max segment warp based on emb dim
kudomcho Sep 10, 2025
b41192b
added guarding opt of max segment for the case batch size list=1
kudomcho Sep 10, 2025
2b08f96
opt for grad_indice_weights kernel
Sep 18, 2025
0c26470
added store row per warp on emb 192 and added accuracy test functiona…
kudomcho Sep 23, 2025
d6b491b
workgroup tuning and loop unrolled
shbiswas834 Sep 22, 2025
70ed5e2
specialize
Hardcode84 Sep 19, 2025
cf6a2b1
explicitly link to tbb
liligwu Sep 24, 2025
1be9bd8
added warpReduceAllSum with rocm guards
shbiswas834 Sep 25, 2025
9d3ee64
revert unroll and wg tuning
shbiswas834 Oct 13, 2025
a5a3b1e
Minor update embedding_forward_split_kernel_template.cu
liligwu Oct 13, 2025
28e93c0
add tbb-devel to the install_build_tools ()
liligwu Oct 17, 2025
842846c
fix lint issues
liligwu Oct 21, 2025
97aeb83
solve lint issues
liligwu Oct 21, 2025
00976c7
applied jinja is_rocm onto optimizations for backward and forward par…
kudomcho Oct 22, 2025
4c19030
Guard supported grad_t for optimized warp_per_row dispatch
avbokovoy Oct 23, 2025
9991cf1
Forward index_t to the optimizer
avbokovoy Oct 23, 2025
b61bd19
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
c38ff6f
Fix buffer offset for emb_dim == 160
avbokovoy Oct 23, 2025
e201e8b
Remove sanity check
avbokovoy Oct 27, 2025
aaaf80c
address the potential lint issues and revert the change in indices_ge…
liligwu Oct 27, 2025
b8aea67
address code style issue
liligwu Oct 27, 2025
b9a7759
removed guard rocm on mixed_D and refactored mixed_D var assignment
kudomcho Oct 28, 2025
a4b4431
Remove general load/store methods
avbokovoy Oct 24, 2025
5d4f2cd
Move weight type check to compile-time
avbokovoy Oct 24, 2025
d3b7d7a
Switch to 256B stores for float type
avbokovoy Oct 27, 2025
878d00f
removed jinja is_rocm on total_L as USE_ROCM is already applied
kudomcho Nov 3, 2025
d2596c7
Change mixed_D default value to false
avbokovoy Nov 6, 2025
e076556
Make const work_group_size for CUDA
avbokovoy Nov 6, 2025
585300d
Add jinja comments to grad_indice_weights kernel
avbokovoy Nov 6, 2025
e0db2f1
Remove redundant comment
avbokovoy Nov 6, 2025
bf143c7
Unify cuda and rocm loops
avbokovoy Nov 6, 2025
c6b0a88
Added BLOCK_SIZE_ROCM
shbiswas834 Nov 11, 2025
122a583
revert the link to tbb
liligwu Nov 14, 2025
0b05877
hack param
Bernard-Liu Nov 2, 2025
4f5c9ed
support opt code_gen
Bernard-Liu Oct 27, 2025
bec7db4
support subwarp
yadaish Aug 6, 2025
a530e5c
update subwarp kernel
Bernard-Liu Oct 28, 2025
f3054d9
grad sum kernel unroll improvement
XingerZhu Oct 27, 2025
ac9e798
fix performance issue
yadaish Oct 29, 2025
97ef821
fix vbe opt not applied
Bernard-Liu Nov 2, 2025
f19cb5d
fix symbol bug & rm comment
Bernard-Liu Nov 3, 2025
bc73399
eliminate warning of process_block
liligwu Nov 13, 2025
d4bfd1b
add rocm for macro
Bernard-Liu Nov 13, 2025
1 change: 1 addition & 0 deletions .github/scripts/utils_build.bash
@@ -370,6 +370,7 @@ install_build_tools () {
patchelf \
rhash \
scikit-build \
tbb-devel \
tbb \
wheel \
xz \
@@ -1506,4 +1506,4 @@ def context_factory(on_trace_ready: Callable[[profile], None]):


if __name__ == "__main__":
cli()
cli()
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
@@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
@@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

36 changes: 36 additions & 0 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]:

at::acc_type<cache_t, true> multiplier = 0.0;
at::acc_type<cache_t, true> correction = 0.0;
"""
split_precomputation_preload = split_precomputation
split_precomputation += """
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;

@@ -228,6 +231,38 @@ def rowwise_adagrad() -> Dict[str, Any]:
multiplier = SHFL_SYNC(multiplier, 0);
correction = SHFL_SYNC(correction, 0);
"""
split_precomputation_preload += """
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;

// Update the optimizer state. Use optimizer state offloading only if
// SSD and if enabled by the user
if (enable_optimizer_offloading) {
// Fetch the pointer to the optimizer state along the cache row
auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
new_sum_square_grads += optimizer->momentum;
optimizer->momentum = new_sum_square_grads;

} else {
new_sum_square_grads += momentum1_val;
momentum1[idx] = new_sum_square_grads;
}

multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
if (weight_decay_mode == 1) {
// L2 regularization
correction = 1.0 - multiplier * weight_decay;
} else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
// Decoupled weight decay
correction = 1.0 - learning_rate * weight_decay;
} else {
// default value
correction = 1.0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
correction = SHFL_SYNC(correction, 0);
"""
split_weight_update_cpu = """
at::acc_type<grad_t, true> g_local_sum_square = 0.0;
for (int64_t d = 0; d < D; ++d) {
@@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
},
),
"split_precomputation": split_precomputation,
"split_precomputation_preload": split_precomputation_preload,
"split_weight_update": split_weight_update,
"split_post_update": split_post_update,
"split_weight_update_cpu": split_weight_update_cpu,
@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = true */) {
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
@@ -14,6 +14,100 @@

using namespace fbgemm_gpu;

{%- if is_rocm %}
// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1)
#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i);
#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i);
#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i);
#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i);
#define IDX_WEIGHT(i, j) at::acc_type<cache_t, true> idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i);

#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j);
#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j);
#define REPEAT_2(X, j) X(1, j);
#define REPEAT_1(X, j) // No additional variables needed for block size 1

#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n);
#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n);
#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n);
#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1

// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1)
// if nobag and is_index_select
#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]);
// elif nobag
#define GRAD_VEC_N(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[l_j_##i][d]);
// elif vbe
#define GRAD_VEC_V(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]);
// else
#define GRAD_VEC(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d);

// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1)
#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i);
// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1)
#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i);

// Core macro: Process blocks of specified size (block_size = 8/4/2/1)
// Parameters:
// - block_size: Size of each block to process
// - unroll_count: Number of unroll iterations for the inner loop
#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \
for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \
{%- if nobag %}
int32_t l_j_0 = SHFL_SYNC(l, j); \
REPEAT_##block_size(L, j) \
{%- elif vbe %}
/* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \
const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \
/* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \
REPEAT_##block_size(GRAD_OFFSET, j) \
{%- else %}
int32_t b_j_0 = SHFL_SYNC(b, j); \
REPEAT_##block_size(B, j) \
int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \
REPEAT_##block_size(D_START, j) \
{%- endif %}
{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \
REPEAT_##block_size(IDX_WEIGHT, j) \
{%- endif %}
{%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
\
for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \
const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \
/* Generate block_size Vec4TAcc objects and accumulate them */ \
Vec4TAcc<grad_t> grad_out_vec_0( \
{%- if nobag and is_index_select %}
&grad_output[grad_offset + l_j_0 * grad_stride + d] \
{%- elif nobag %}
&grad_output[l_j_0][d] \
{%- elif vbe %}
&grad_output[0][grad_offset_j_0 + d] \
{%- else %}
&grad_output[b_j_0][0] + D_start_j_0 + d \
{%- endif %}
); \
{%- if nobag and is_index_select %}
REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \
{%- elif nobag %}
REPEAT_##block_size(GRAD_VEC_N, d) \
{%- elif vbe %}
REPEAT_##block_size(GRAD_VEC_V, d) \
{%- else %}
REPEAT_##block_size(GRAD_VEC, d) \
{%- endif %}
\
{%- if weighted %}
grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \
REPEAT_##block_size(FMA_GRAD, vec) \
{%- else %}
grad_sum[vec].add_(grad_out_vec_0); \
REPEAT_##block_size(ADD_GRAD, vec) \
{%- endif %}
} \
}
{%- endif %}

{%- if gen_once %}
{#- /*
The kernels in this section will be generated only once for all TBE configs
@@ -141,7 +235,25 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
? sorted_indice_weights[segment_start + sl_j]
: 0.0;
{%- endif %}
for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) {
int32_t j = 0;

{%- if is_rocm %}
// Process blocks of different sizes with loop unrolling
if constexpr (sizeof(grad_t) <= 2) {
PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
}
PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)

#undef PROCESS_BLOCK

{%- else %}
for (; j < kThreadGroupSize && sl + j < sl_end; ++j) {
{%- if nobag %}
int32_t l_j = SHFL_SYNC(l, j);
{%- elif vbe %}
@@ -180,6 +292,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
{%- endif %}
}
}
{%- endif %}
}
{%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}

@@ -198,4 +311,4 @@

{%- endif %}

// clang-format on
// clang-format on
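
The `PROCESS_BLOCK` and `REPEAT_##block_size` machinery above unrolls the inner gradient-accumulation loop by stamping out per-lane variables with token pasting. The standalone program below is a minimal sketch of that preprocessor pattern only (plain C++, no HIP; `LOAD`, `ACC`, and the data are invented), mirroring how index 0 is handled explicitly while indices 1..block_size-1 come from the REPEAT macro.

```cpp
#include <cstdio>

// Toy stand-ins for the kernel's per-lane values; names are illustrative.
#define LOAD(i, j)  int v_##i = data[(j) + (i)];
#define ACC(i, j)   sum += v_##i;

// REPEAT_4 stamps out the macro for i = 1..3; i = 0 is written explicitly,
// just like the GRAD_OFFSET / L / B / D_START handling in PROCESS_BLOCK.
#define REPEAT_4(X, j) X(1, j) X(2, j) X(3, j)

int main() {
  int data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  int sum = 0;
  for (int j = 0; j + 3 < 8; j += 4) {  // process blocks of 4, like PROCESS_BLOCK(4, ...)
    int v_0 = data[j];                  // explicit i = 0 case
    REPEAT_4(LOAD, j)                   // expands to the v_1, v_2, v_3 loads
    ACC(0, j)                           // accumulate v_0
    REPEAT_4(ACC, j)                    // accumulate v_1..v_3
  }
  std::printf("%d\n", sum);             // prints 36 = 1 + 2 + ... + 8
  return 0;
}
```

Compared with a plain loop, this trades code size for having all loads of a block issued before the accumulation, which appears to be the motivation behind the kernel macros.
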
@@ -960,7 +960,7 @@ class {{ autograd_func }} :

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
@@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
{%- else %}
const c10::SymInt vbe_output_size = -1,
{%- endif %}
const bool mixed_D = true
const bool mixed_D = false
) {
// TODO: refactor into macro
{%- if has_gpu_support %}
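
The first hunk above raises `max_segment_length_per_warp` on ROCm from 64 to 16384, matching the 16k value mentioned in the commit messages. As hedged background (the routing code itself is not part of this diff), segments whose sorted-index run fits under this threshold are typically reduced by the warp-per-row backward kernel, while longer runs fall back to the cooperative cta-per-row kernel; the sketch below illustrates that host-side split with invented names.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical host-side routing sketch (names invented): short segments are
// reduced by a single warp (warp_per_row kernel); longer ones are handled by
// the cooperative cta_per_row kernel. Raising the threshold to 16384 keeps
// far more segments on the cheaper warp_per_row path.
struct SegmentPlan {
  std::vector<int64_t> warp_segments;  // handled by warp_per_row
  std::vector<int64_t> cta_segments;   // handled by cta_per_row
};

SegmentPlan plan_segments(const std::vector<int64_t>& segment_lengths,
                          int32_t max_segment_length_per_warp) {
  SegmentPlan plan;
  for (int64_t seg = 0; seg < static_cast<int64_t>(segment_lengths.size()); ++seg) {
    if (segment_lengths[seg] <= max_segment_length_per_warp) {
      plan.warp_segments.push_back(seg);
    } else {
      plan.cta_segments.push_back(seg);
    }
  }
  return plan;
}
```
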