From 523a31759bdc0012bfe4c2eb44f1bf53fe4a1158 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 11:57:27 +0000 Subject: [PATCH 01/53] Add gfx950 build support + fp16 fix + index type fix --- fbgemm_gpu/cmake/Hip.cmake | 8 ++++++++ .../embedding_backward_split_template.cu | 2 +- ..._backward_split_device_kernel_template.hip | 2 +- .../include/fbgemm_gpu/rocm/cdna_guard.h | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 20 ++++++++++++++++++- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 2 +- 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 17640b7254..2011a34c33 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,6 +78,14 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) + list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -ggdb) + + # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) + #list(APPEND HIP_CXX_FLAGS -adhln) + list(APPEND HIP_CXX_FLAGS -save-temps) + list(APPEND HIP_CXX_FLAGS -fverbose-asm) + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76eba64c99..76a2b347d8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1193,7 +1193,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..5acc61382e 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..c96da01063 
100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -215,6 +215,24 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } + +}; + + template < typename emb_t, int32_t embedding_dim, @@ -471,7 +489,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index dfea2dce8a..361059020e 100644 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - std::execution::par, + // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From aee3078a77b5aad122ceac022a16ce9dfad0e9c0 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 13:16:41 +0000 Subject: [PATCH 02/53] Change int64_t to index_t as template parameters in load_raw_per_warp --- .../rocm/embedding_backward_split_device_kernel_template.hip | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 5acc61382e..d5841d6e00 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -452,7 +452,7 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); From 5a1ac2e85de86b772ed609d0c5662b80b2cc0e3d Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 14:39:22 +0000 Subject: [PATCH 03/53] Implement llvm fp16 buffer load for gfx950 --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c96da01063..4b33fd1422 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t 
srsrc, @@ -72,7 +77,12 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, From 78569031c9133f7175241617fdd6e81eca6c2b5c Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:23:47 +0000 Subject: [PATCH 04/53] Fix c-style half to float cast --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 4b33fd1422..238a83440a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -261,7 +261,14 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } From e1e246a825386da2c6ccc206ba522e988aa6ba0f Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:24:29 +0000 Subject: [PATCH 05/53] Patch 256 half stores --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 238a83440a..974eae2594 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -294,6 +294,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + } +}; + + template <> struct store_row_per_warp { static __device__ void run(float* acc, float* p_output, int lane_id) { From 6a99fe08a80f133d74502167bad5ff1ce3143019 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Fri, 8 Aug 2025 05:02:58 +0000 Subject: [PATCH 06/53] cta_per_row workgroup optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76a2b347d8..9412edc1a5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1042,7 +1042,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,7 +1053,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t 
cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, (kMaxThreads/2)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From 349a7b5fc90935325e58eef100fcf308d6df17d7 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 11 Aug 2025 21:06:48 +0000 Subject: [PATCH 07/53] Added mi350 guards --- ...ding_backward_split_indice_weights_template.cu | 15 ++++++++++++++- .../backward/embedding_backward_split_template.cu | 10 ++++++++++ .../forward/embedding_forward_split_template.cu | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9e1f71ef4e --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); - + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 9412edc1a5..9e9e7aac68 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 37e774bb49..2861b631a0 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -459,6 +463,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} From 1178cd101308894e5463d63bccba652bd9784d23 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 15:09:39 +0000 Subject: [PATCH 08/53] Fix index overflow in row load --- ..._backward_split_device_kernel_template.hip | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d5841d6e00..d1a874805a 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -290,7 +290,7 @@ 
__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -420,7 +420,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -441,7 +441,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); From 606ad34c72469df96096853603ea56abe976949e Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 20:13:09 
+0000 Subject: [PATCH 09/53] cta_per_row workgroup reduce by 4 optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 9e9e7aac68..c59f6fe9aa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1052,7 +1052,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1063,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/2)), + div_round_up(total_unique_indices, (kMaxThreads/4)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From a22ddebacc0aef64c2fb0569835eaed9786b5f0e Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 13 Aug 2025 13:21:38 +0000 Subject: [PATCH 10/53] Fix mixed_D frontend to backend connection --- .../training/backward/embedding_backward_split_template.cu | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 1 + .../split_table_batched_embeddings_ops_training.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c59f6fe9aa..c8a846a552 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1203,7 +1203,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 0f4814fb41..da0c69ad21 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,6 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} // Default values for Dynamo tracing diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 0a2a7ab0a1..4f1741b3dc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ 
b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -820,7 +820,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2517,6 +2517,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2535,6 +2536,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2544,6 +2546,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) From 677545246e4091eac87236831092dad39b3e3fe1 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Fri, 15 Aug 2025 15:32:19 +0000 Subject: [PATCH 11/53] changed max_segment_length_per_cta to 4096 --- .../training/backward/embedding_backward_split_template.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c8a846a552..1ddcea55b2 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { From 90e6ba7a5f3793e52a1f283456e1b3bbc7a4634b Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Mon, 18 Aug 2025 22:32:58 +0000 Subject: [PATCH 12/53] added rocm guards and removed comment --- .../embedding_backward_split_template.cu | 19 ++++++++++++++++--- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 - 2 files changed, 16 insertions(+), 4 deletions(-) mode change 100644 => 100755 fbgemm_gpu/src/tbe/eeg/indices_generator.cpp diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 1ddcea55b2..099c7e5685 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #ifdef USE_ROCM + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #else + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 1024; + #endif Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { @@ -1052,7 +1056,11 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #ifdef USE_ROCM + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #else + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1071,12 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/4)), + #ifdef USE_ROCM + div_round_up(total_unique_indices, (kMaxThreads/4)), + #else + div_round_up(total_unique_indices, kMaxThreads), + #endif + get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755 index 361059020e..715acd8c0c --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,6 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From a9073ac78c2b6689e317d92bbc278014699ff7a7 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 20 Aug 2025 03:00:56 +0000 Subject: [PATCH 13/53] clean debug statements in Hip.cmake --- fbgemm_gpu/cmake/Hip.cmake | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 2011a34c33..17640b7254 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,14 +78,6 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) - list(APPEND HIP_CXX_FLAGS -g) - list(APPEND HIP_CXX_FLAGS -ggdb) - - # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) - #list(APPEND HIP_CXX_FLAGS -adhln) - list(APPEND HIP_CXX_FLAGS -save-temps) - list(APPEND HIP_CXX_FLAGS -fverbose-asm) - set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use From 9a16e1265308a13dec7170e08f1b6736fc595e8a Mon Sep 17 00:00:00 2001 From: Shreya Date: Thu, 28 Aug 2025 11:43:32 -0500 Subject: [PATCH 14/53] Merge pull request #121 warp per row wg change --- .../embedding_backward_split_template.cu | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 099c7e5685..2425322948 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1056,10 +1056,21 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); + int32_t total_L = indices.numel(); #ifdef USE_ROCM - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1){ + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + 
else{ + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } #else int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t work_group_size = kMaxThreads; #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, @@ -1071,17 +1082,13 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - #ifdef USE_ROCM - div_round_up(total_unique_indices, (kMaxThreads/4)), - #else - div_round_up(total_unique_indices, kMaxThreads), - #endif - + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, + // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1185,7 +1192,18 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #ifdef USE_ROCM + int32_t num_warp_per_row_groups; + + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + #else + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #endif int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { From 68630daf7e1c8b9742eac9492d220b7a55e28cf1 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 Subject: [PATCH 15/53] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 974eae2594..46c4603381 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From bac0610aeb2d7842e8a4cd0efac729d196d35cee Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 18 Sep 2025 16:28:31 +0000 Subject: [PATCH 16/53] fix the bug in dimention 160 in ROCm optimization --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 46c4603381..8a97579d6a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int 
lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From a12112f10f485a6f0263c71379ed683c99103be6 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 19 Aug 2025 13:41:17 +0000 Subject: [PATCH 17/53] Cleanup optimized warp_per_raw kernel --- fbgemm_gpu/cmake/tbe_sources.py | 2 - .../genscript/generate_backward_split.py | 10 +- ...ing_backward_split_kernel_warp_template.cu | 40 +++----- .../embedding_backward_split_template.cu | 18 ++-- ..._backward_split_device_kernel_template.hip | 94 +++++-------------- 5 files changed, 54 insertions(+), 110 deletions(-) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a5277a906a..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..1158721526 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if 
is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - + auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, @@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 2425322948..fb125101e7 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -862,8 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - 
and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1226,8 +1233,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d1a874805a..951cff4399 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -221,21 +208,15 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ @@ -244,23 +225,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & 
info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,23 +261,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,21 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} + table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); @@ -346,23 +314,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,23 +341,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> 
info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,12 +374,9 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -435,12 +392,9 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( From 3ef64f7c3bcca2904a74a40e794eccfa1ad3043b Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 20 Aug 2025 12:15:37 +0000 Subject: [PATCH 18/53] Add 320 embedding dim support for optimized warp_per_row kernel --- ...ing_backward_split_kernel_warp_template.cu | 2 +- .../embedding_backward_split_template.cu | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1158721526..e61b3fc0aa 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -766,7 +766,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fb125101e7..7eb2b6880f 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1243,7 +1243,7 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- 
for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 8a97579d6a..5b9d69d910 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -205,6 +205,22 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + } +}; + template struct load_row_per_warp { static __device__ void @@ -304,6 +320,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + p_output[lane_id + 256] = acc[4]; + } +}; + template <> struct store_row_per_warp { From f601e553f0ce84af2ac728ec3266ac49dce8b7eb Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Sep 2025 19:34:16 +0000 Subject: [PATCH 19/53] changed the max length per warp and cta per row WG size --- .../backward/embedding_backward_split_host_template.cpp | 2 +- .../training/backward/embedding_backward_split_template.cu | 6 +----- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 134a03b983..8a0cd9daca 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7eb2b6880f..86d4ce8b8b 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,11 +987,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - #ifdef USE_ROCM - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - #else - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - #endif + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 18378b6106..00673abc8b 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 4096; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index da0c69ad21..bf5b56c079 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1060,7 +1060,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From 04916da836e07175e3ecc4b27daa2960a994a360 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 9 Sep 2025 20:25:30 +0000 Subject: [PATCH 20/53] added DPP and changed max length per warp to 16k --- .../embedding_backward_split_host_template.cpp | 2 +- .../index_select/batch_index_select_dim0_host.cpp | 4 ++-- .../embedding_split_host_pt2_autograd_template.cpp | 2 +- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 14 ++++++++------ 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 8a0cd9daca..2ea96a107e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 00673abc8b..02529f2d89 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -342,7 +342,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; 
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index bf5b56c079..bcecc7c91c 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1060,7 +1060,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 0d65c4798a..a1d9819017 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,11 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); } DEVICE_INLINE void syncwarp() { From 1e09555b32d622cf1e903db762639e1dac267a11 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 19:33:44 +0000 Subject: [PATCH 21/53] guard max segment warp based on emb dim --- ...dding_split_host_pt2_autograd_template.cpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index bcecc7c91c..cbd7aceda9 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1060,7 +1060,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 16384; + int32_t max_segment_length_per_warp = 64; + // Workaround. Should not be upstreamed in any way. + // Redistribute all cta_per_row work to warp_per_row. 
+ {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + const auto B = (offsets.size(0) - 1) / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From b41192be39dc00545379016f3688c6d16ba0641d Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 22:00:20 +0000 Subject: [PATCH 22/53] added guarding opt of max segment for the case batch size list=1 --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index cbd7aceda9..787a9b6d2f 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1063,6 +1063,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. + int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and @@ -1071,9 +1072,10 @@ static torch::autograd::variable_list backward( (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); - const auto B = (offsets.size(0) - 1) / T; + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } From 2b08f965ec9e4cbd4d98ef7d4e20ccbff812341d Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 09:26:57 +0000 Subject: [PATCH 23/53] opt for grad_indice_weights kernel --- ..._backward_split_indice_weights_template.cu | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 9e1f71ef4e..b30e3e5c77 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -214,7 +214,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ) {%- endif %} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + int32_t j = 0; + {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = 
shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + if (placement == PlacementType::MANAGED_CACHING) { + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + } else { + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + } + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + {%- endif %} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From 0c264705036d86ffa84ef1215bed646d90c3dc3e Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 23 Sep 2025 02:09:26 +0000 Subject: [PATCH 24/53] added store row per warp on emb 192 and added accuracy test functionality --- ...plit_table_batched_embeddings_benchmark.py | 223 +++++++++++++----- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 125 ++++++++-- .../fbgemm_gpu/rocm/split_embeddings_common.h | 18 +- 3 files changed, 277 insertions(+), 89 deletions(-) diff --git 
a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..3fad8f53fe 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,7 +7,8 @@ # pyre-strict - +import gzip +import yaml import logging import os import tempfile @@ -1011,7 +1012,15 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1040,39 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1081,11 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1164,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = 
weights_precision.bit_rate() / 8.0 @@ -1130,53 +1192,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests - + + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
+ sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests assert len(requests) == iters - + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1203,34 +1280,44 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) + if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. 
+ grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1331,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 00b51bbbe0..6d20a42c04 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,7 +19,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device - +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen logging.basicConfig(level=logging.DEBUG) @@ -248,36 +249,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -285,7 +293,86 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), 
f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) + print("PASS") for it in range(iters): req = requests[it % num_reqs] diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 5b9d69d910..745499ac08 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,7 +24,6 @@ #include #include #include -#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -62,7 +61,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -79,7 +78,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); @@ -165,7 +164,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = 
llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { @@ -320,6 +319,15 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); + } +}; + template <> struct store_row_per_warp { static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { @@ -619,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm +} // namespace fbgemm_gpu::rocm \ No newline at end of file From d6b491bb9c95af0142f663818c556c82960002a7 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 25/53] workgroup tuning and loop unrolled --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index aada1cdad5..69ad8cf8ca --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 70ed5e261f3eb6561b3b6625e5f4be1a5b36a50e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Sep 2025 22:38:17 +0200 Subject: [PATCH 26/53] specialize --- ..._backward_split_indice_weights_template.cu | 145 ++++++++++++------ 1 file changed, 95 insertions(+), 50 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index b30e3e5c77..0052d96406 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -217,33 +217,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only - for (; j < kWarpSize && l_start + j + 3 < L; j += 
4) { - const auto offset_idx_j0 = shfl_sync(offset_idx, j); - const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); - const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); - const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); - - const auto cache_idx_j0 = shfl_sync(cache_idx, j); - const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); - const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); - const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); - - at::acc_type grad_indice_weight0 = 0.0; - at::acc_type grad_indice_weight1 = 0.0; - at::acc_type grad_indice_weight2 = 0.0; - at::acc_type grad_indice_weight3 = 0.0; - - [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); - [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); - [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); - [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { - const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - Vec4T> weight0, weight1, weight2, weight3; - if (placement == PlacementType::MANAGED_CACHING) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + 
l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; weight0 = (cache_idx_j0 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : weight_row0.load(d); @@ -259,33 +308,29 @@ __global__ __launch_bounds__(kForwardMaxThreads) void weight3 = (cache_idx_j3 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : weight_row3.load(d); - } else { - weight0 = weight_row0.load(d); - weight1 = weight_row1.load(d); - weight2 = weight_row2.load(d); - weight3 = weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; } - grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + - weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; - grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + - weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; - grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + - weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; - grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + - weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; - } - - grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); - grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); - 
grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); - grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - if (threadIdx.x == 0) { - grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; - grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; - grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; - grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } } } {%- endif %} @@ -447,7 +492,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] From cf6a2b1a965925e9fe9896d0ae96341d0564e582 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 24 Sep 2025 00:48:35 +0000 Subject: [PATCH 27/53] explicitly link to tbb --- cmake/modules/CppLibrary.cmake | 12 ++++++++++++ cmake/modules/GpuCppLibrary.cmake | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 92a93a60b6..388d3ac779 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,6 +168,18 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + target_link_libraries(${lib_name} PUBLIC TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) + endif() + endif() + # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index 51c30df750..e662848348 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,6 +302,18 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + list(APPEND library_dependencies TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + list(APPEND library_dependencies ${TBB_LIB}) + endif() + endif() + # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) From 1be9bd8e5c693cfb55fca99b437fece631721b47 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Thu, 25 Sep 2025 19:00:23 +0000 Subject: [PATCH 28/53] added warpReduceAllSum with rocm guards --- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) mode change 100644 => 100755 fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh diff --git 
a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index a1d9819017..d51e3fa475 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -140,11 +140,19 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { - return rocm::wave_reduce< - rocm::reduce_op::sum, // Sum reduction - T, // Data type - ReduceWidth // Wave/Warp size - >(val); + #ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); + #else + #pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); + } + return val; + #endif } DEVICE_INLINE void syncwarp() { From 9d3ee64987ebd8296f897897cbd1310ff6f1d040 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 13 Oct 2025 20:34:59 +0000 Subject: [PATCH 29/53] revert unroll and wg tuning --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 69ad8cf8ca..acbf4563f3 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 8; + constexpr int kManualUnrollLength = 4; {%- endif %} // Determine the linearized warp ID, and exit early if needed From a5a3b1ec1a9dcba556fb38b179093838bac94e5e Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 13 Oct 2025 15:46:07 -0500 Subject: [PATCH 30/53] Minor update embedding_forward_split_kernel_template.cu --- .../forward/embedding_forward_split_kernel_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index acbf4563f3..aada1cdad5 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { {%- 
else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { + for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group From 28e93c020f2a72d20a75168771e73f9ffeccb37c Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 17 Oct 2025 21:17:37 +0000 Subject: [PATCH 31/53] add tbb-devel to the install_build_tools () --- .github/scripts/utils_build.bash | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 82fa3e26a2..709e7b62f4 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -370,6 +370,7 @@ install_build_tools () { patchelf \ rhash \ scikit-build \ + tbb-devel \ tbb \ wheel \ xz \ From 842846c7b37e37ee3df6534e5102bce25bb9d11c Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 21 Oct 2025 18:54:56 +0000 Subject: [PATCH 32/53] fix lint issues --- ...plit_table_batched_embeddings_benchmark.py | 35 +++++++++---------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 4 +-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 3fad8f53fe..02d3820b07 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -1192,7 +1192,6 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: requests = [] for i in range(iters): @@ -1253,7 +1252,6 @@ def device_with_spec( # noqa C901 torch.save(req.offsets, f"{save}/{i}_offsets.pt") torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") - sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1280,23 +1278,22 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) - + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative @@ -1315,7 +1312,7 @@ def device_with_spec( # noqa C901 # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. 
grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - + if save: torch.save(grad_output, f"{save}/grad_output.pt") # backward diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 6d20a42c04..f435370e36 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -305,11 +305,11 @@ def benchmark_requests( # noqa: C901 torch.save(out, f) else: torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - + out.backward(grad) torch.cuda.synchronize() torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - + if sliced: for id, t in enumerate(emb.split_embedding_weights()): if compressed: From 97aeb8329e99f2ca44902a39f02806ea5581d9c8 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 21 Oct 2025 21:23:38 +0000 Subject: [PATCH 33/53] solve lint issues --- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 18 +++++++----------- .../fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index f435370e36..3278252b5f 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -314,14 +314,13 @@ def benchmark_requests( # noqa: C901 for id, t in enumerate(emb.split_embedding_weights()): if compressed: with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) + torch.save(t[slice_min:slice_max, :].clone(), f) else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - else: if compressed: with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: @@ -332,11 +331,9 @@ def benchmark_requests( # noqa: C901 if load and emb: for it in range(iters): req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() out = emb(indices, offsets, weights) torch.cuda.synchronize() - out.backward(grad) torch.cuda.synchronize() emb_ref = copy.deepcopy(emb) @@ -346,8 +343,8 @@ def benchmark_requests( # noqa: C901 emb_ref.load_state_dict(torch.load(f)) else: emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: for id, t in enumerate(emb.split_embedding_weights()): if compressed: @@ -355,15 +352,15 @@ def benchmark_requests( # noqa: C901 w_ref = torch.load(f) else: w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + torch.testing.assert_close(t[slice_min:slice_max, :], w_ref, msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) else: for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) print("PASS") - print(f"[{it + 1}/{iters}] Backward momentum check... 
", end="", flush=True) + if sliced: m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") @@ -375,7 +372,6 @@ def benchmark_requests( # noqa: C901 print("PASS") for it in range(iters): req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() if bwd_only: # Run forward before profiling if does backward only diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 745499ac08..aa869fe2b5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -627,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm \ No newline at end of file +} // namespace fbgemm_gpu::rocm From 00976c79cfb5bc9cf62a4cf3946dde20d19688a9 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 22 Oct 2025 18:45:41 +0000 Subject: [PATCH 34/53] applied jinja is_rocm onto optimizations for backward and forward parameters --- ..._backward_split_indice_weights_template.cu | 5 ++++- .../embedding_backward_split_template.cu | 22 ++++++++++--------- .../embedding_forward_split_template.cu | 4 ++-- .../batch_index_select_dim0_host.cpp | 4 ++-- ...dding_split_host_pt2_autograd_template.cpp | 4 ++++ 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 0052d96406..c58ba89f78 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -213,7 +213,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} - + {%- if is_rocm %} int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only @@ -335,6 +335,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } {%- endif %} for (; j < kWarpSize && l_start + j < L; ++j) { + {%- else %} // if is_rocm + for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 86d4ce8b8b..759bbfd9bb 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,8 +987,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1059,8 +1062,8 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t total_L = indices.numel(); - #ifdef USE_ROCM + {% if is_rocm %} + int32_t total_L = indices.numel(); int32_t num_cta_per_row_groups; int32_t work_group_size; if (total_L/total_B > 1){ @@ -1071,10 +1074,10 @@ Tensor {{ embedding_cuda_op }}( num_cta_per_row_groups = kMaxThreads / kWarpSize; work_group_size = kMaxThreads; } - #else + {%- else %} int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; int32_t work_group_size = kMaxThreads; - #endif + {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1091,7 +1094,6 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1195,7 +1197,7 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - #ifdef USE_ROCM + {%- if is_rocm %} int32_t num_warp_per_row_groups; if (total_L/total_B > 1){ @@ -1204,9 +1206,9 @@ Tensor {{ embedding_cuda_op }}( else{ num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; } - #else + {%- else %} int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - #endif + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 2861b631a0..dac49631cf 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -463,7 +463,7 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); - #ifdef USE_ROCM + {% if is_rocm %} if (!rocm::is_supported_cdna()) { TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); } @@ -471,7 +471,7 @@ batch_index_select_dim0_codegen_forward_cuda( // Ensure we're running on a supported CDNA architecture (including MI350) TORCH_WARN_ONCE("Running on CDNA architecture"); } - #endif + {%- endif %} {%- if not nobag %} int32_t T = D_offsets.numel() - 1; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 02529f2d89..608f6017ec 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -342,7 +342,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 16384; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 16384; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 787a9b6d2f..21cb348c21 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,8 +743,10 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + {% if is_rocm %} const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} + {%- endif %} // Default values for Dynamo tracing // SymInt does not support bitshifts operator @@ -1063,7 +1065,9 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. 
+ {% if is_rocm %} int32_t total_L = indices.numel(); + {%- endif %} {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and From 4c19030121bc36e2ea2c5a2da959bb23f9df8f50 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:54:56 +0000 Subject: [PATCH 35/53] Guard supported grad_t for optimized warp_per_row dispatch --- .../training/backward/embedding_backward_split_template.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 759bbfd9bb..18beeae1ff 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1238,7 +1238,9 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; + + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} From 9991cf1ccec41bc3e6daa7ed5f412839ed9fad89 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:56:05 +0000 Subject: [PATCH 36/53] Forward index_t to the optimizer --- .../backward/embedding_backward_split_kernel_warp_template.cu | 2 +- .../rocm/embedding_backward_split_device_kernel_template.hip | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index e61b3fc0aa..b757f64d36 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -650,7 +650,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd opt_karg.weight_decay = weight_decay; rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 951cff4399..87d259ebee 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { From b61bd196fa96c748c4c5b9fd9cbcf6c9829edd60 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 
Subject: [PATCH 37/53] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index aa869fe2b5..08e1efa3e9 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From c38ff6f72b2734ec955de4415f9068e22c683040 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:59:56 +0000 Subject: [PATCH 38/53] Fix buffer offset for emb_dim == 160 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 08e1efa3e9..c1d98d3e9f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From e201e8b7ee5b1633acdd63dd3b8fac957dff3cea Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 27 Oct 2025 14:36:52 +0000 Subject: [PATCH 39/53] Remove sanity check --- ...plit_table_batched_embeddings_benchmark.py | 190 +++++------------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 123 ++---------- 2 files changed, 70 insertions(+), 243 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 02d3820b07..4ffb7341a5 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,8 +7,7 @@ # pyre-strict -import gzip -import yaml + import logging import os import tempfile @@ -1012,15 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) 
-@click.option("--slice-max", type=int, default=None) -@click.pass_context def device_with_spec( # noqa C901 - ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1040,39 +1031,7 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, ) -> None: - if load: - with open(f"{load}/params.yaml", "r") as f: - ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) - alpha = ctx.params["alpha"] - bag_size_list = ctx.params["bag_size_list"] - bag_size_sigma_list = ctx.params["bag_size_sigma_list"] - batch_size = ctx.params["batch_size"] - embedding_dim_list = ctx.params["embedding_dim_list"] - weights_precision = ctx.params["weights_precision"] - cache_precision = ctx.params["cache_precision"] - stoc = ctx.params["stoc"] - iters = ctx.params["iters"] - warmup_runs = ctx.params["warmup_runs"] - managed = ctx.params["managed"] - num_embeddings_list = ctx.params["num_embeddings_list"] - reuse = ctx.params["reuse"] - row_wise = ctx.params["row_wise"] - weighted = ctx.params["weighted"] - pooling = ctx.params["pooling"] - bounds_check_mode = ctx.params["bounds_check_mode"] - flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] - output_dtype = ctx.params["output_dtype"] - random_weights = ctx.params["random_weights"] - compressed = ctx.params["compressed"] - slice_min = ctx.params["slice_min"] - slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1081,11 +1040,6 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" - params = ctx.params - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1164,22 +1118,6 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) - elif random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1192,66 +1130,53 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: - requests = [] - for i in range(iters): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") - requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) - else: - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
- sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - del all_requests + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + + del all_requests + assert len(requests) == iters - if save: - for i in range(iters): - req = requests[i] - torch.save(req.indices, f"{save}/{i}_indices.pt") - torch.save(req.offsets, f"{save}/{i}_offsets.pt") - torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") - torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1299,22 +1224,13 @@ def device_with_spec( # noqa C901 # backward bench not representative return - if load: - grad_output = torch.load(f"{load}/grad_output.pt") + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) - else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. 
- grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - - if save: - torch.save(grad_output, f"{save}/grad_output.pt") + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -1328,12 +1244,6 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 3278252b5f..00b51bbbe0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,7 +11,6 @@ import statistics import threading import time -import gzip from subprocess import Popen from typing import Callable, Optional @@ -19,7 +18,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device -from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen + logging.basicConfig(level=logging.DEBUG) @@ -249,43 +248,36 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - import copy + if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - if not (load or save): - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - sliced = slice_min is not None and slice_max is not None + if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -293,85 +285,10 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] - if save and emb: - for it in range(iters): - req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - if compressed: - with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: - torch.save(out, f) - else: - torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - - out.backward(grad) - torch.cuda.synchronize() - torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max, 
:].clone(), f) - else: - torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it in range(iters): - req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max, :], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) - - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) - print("PASS") for it in range(iters): req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() if bwd_only: # Run forward before profiling if does backward only From aaaf80c07c9c52e023d6b43fbce263527b33c156 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 27 Oct 2025 20:14:58 +0000 Subject: [PATCH 40/53] address the potential lint issues and revert the change in indices_generator.cpp --- ...dding_backward_split_kernel_warp_template.cu | 11 +++++------ .../embedding_backward_split_template.cu | 13 ++++++------- ...ng_backward_split_device_kernel_template.hip | 17 ++++++++--------- .../forward/embedding_forward_split_template.cu | 2 +- ...bedding_split_host_pt2_autograd_template.cpp | 14 +++++++------- .../fbgemm_gpu/rocm/split_embeddings_common.h | 4 ++-- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 + 7 files changed, 30 insertions(+), 32 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index b757f64d36..7b3b5b653a 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,13 +32,13 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else 
"lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and not dense and - not nobag and + not nobag and not is_index_select and - not is_gwd_kernel and - not vbe and + not is_gwd_kernel and + not vbe and not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" @@ -621,7 +621,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ) { int32_t T = D_offsets.size(0) - 1; - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 18beeae1ff..72cf189ccc 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,13 +48,13 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and not dense and - not nobag and + not nobag and not is_index_select and - not is_gwd_kernel and - not vbe and + not is_gwd_kernel and + not vbe and not ssd %} template < @@ -669,7 +669,7 @@ Tensor {{ embedding_cuda_op }}( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -1199,7 +1199,6 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; - if (total_L/total_B > 1){ num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; } diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 87d259ebee..2a747731cc 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -225,7 +225,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -234,7 +234,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -270,7 +270,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ 
-301,7 +301,6 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; @@ -314,7 +313,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -323,7 +322,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -341,7 +340,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -350,7 +349,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index dac49631cf..a3edb6b965 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -472,7 +472,7 @@ batch_index_select_dim0_codegen_forward_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } {%- endif %} - + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 21cb348c21..ce068b54d3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1068,18 +1068,18 @@ static torch::autograd::variable_list backward( {% if is_rocm %} int32_t total_L = indices.numel(); {%- endif %} - {%- if (not nobag) and - (optimizer == "rowwise_adagrad") and - (not vbe) and - (not is_gwd) and - (not ssd) and - (not is_index_select) and + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); auto total_B = (offsets.size(0) - 1); const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c1d98d3e9f..b5aa74c1ab 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -217,7 +217,7 @@ struct load_row_per_warp { *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + emb_data[4] = 
p_emb_table[row_index * 320 + 256 + lane_id]; } }; @@ -335,7 +335,7 @@ struct store_row_per_warp { auto out = reinterpret_cast(p_output); out[lane_id] = *reinterpret_cast(acc); out[lane_id + 64] = *reinterpret_cast(&acc[2]); - p_output[lane_id + 256] = acc[4]; + p_output[lane_id + 256] = acc[4]; } }; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index 715acd8c0c..dfea2dce8a 100755 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,6 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( + std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From b8aea67fb3f53276a66fb3c6365a17593ddeab13 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 27 Oct 2025 20:37:50 +0000 Subject: [PATCH 41/53] address code style issue --- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 608f6017ec..18378b6106 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; From b9a7759e8a40eeac3cd58a64b5f6336b1164026b Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 28 Oct 2025 19:16:27 +0000 Subject: [PATCH 42/53] removed guard rocm on mixed_D and refactored mixed_D var assignment --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index ce068b54d3..661f7b9b45 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,9 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; - {% if is_rocm %} - const auto mixed_D = aux_bool[IDX_MIXED_D]; - {%- endif %} + const auto mixed_D = static_cast(aux_bool[IDX_MIXED_D]); {%- endif %} // Default values for Dynamo tracing @@ -860,7 +858,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]); + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; From a4b44316fe9ab1f790f37cf2215950b212b01461 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 24 Oct 2025 14:44:16 +0000 Subject: [PATCH 43/53] Remove general load/store methods --- ..._backward_split_device_kernel_template.hip | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 397
++++++++++++------ 2 files changed, 259 insertions(+), 140 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2a747731cc..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -410,6 +410,6 @@ L_tail_grad_acc: optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b5aa74c1ab..e6e575e6e5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,7 +21,11 @@ * ******************************************************************************/ #pragma once + +#include #include +#include + #include #include #include @@ -47,10 +51,10 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr, const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -60,8 +64,8 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) + int32_t soffset = 0, + int32_t glc_slc = 0) #if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else @@ -71,33 +75,59 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) + int32_t soffset = 0, + int32_t glc_slc = 0) #if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); #endif +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif + __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t 
soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -107,35 +137,15 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host side + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { - emb_data[i] = 0.f; + static_assert(false, "HIP: Optimized load operation is not supported yet"); } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); } - } - } }; template @@ -145,7 +155,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -156,7 +166,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); } }; @@ -165,15 +175,11 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(half) * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -184,9 +190,9 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -198,10 +204,10 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = 
llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; @@ -210,35 +216,15 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 320); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -256,9 +242,97 @@ struct load_row_per_warp { lane_id ); } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * 
sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + } }; +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 256) * sizeof(float)); + } +}; template < typename emb_t, @@ -291,116 +365,161 @@ struct accumulate_row_per_warp { } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. 
Currently, the kernel dispatch for unsupported type is guarded on host function + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert(false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - out[lane_id + 64] = *reinterpret_cast(&acc[2]); +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - out[lane_id + 64] = *reinterpret_cast(&acc[2]); - p_output[lane_id + 256] = acc[4]; +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); } }; +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + 
reinterpret_cast(p_emb_table), + lane_id + ); + } +}; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32( + acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(&acc[2]), + out_res, + (lane_id + 64) * sizeof(floatx2_t)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), + *reinterpret_cast(&acc[2]), out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + (lane_id + 64) * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 256) * sizeof(float)); } }; From 5d4f2cd349cbb2f9f1d027851ba342e156ae4f01 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 24 Oct 2025 14:57:41 +0000 Subject: [PATCH 44/53] Move weight type check to compile-time --- 
.../training/backward/embedding_backward_split_template.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 72cf189ccc..82acd61baa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1234,9 +1234,7 @@ Tensor {{ embedding_cuda_op }}( const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; - + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) From d3b7d7a893d0d69f968e5ce15c4d1b54836f1bcb Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 27 Oct 2025 11:51:43 +0000 Subject: [PATCH 45/53] Switch to 256B stores for float type --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 54 +++++++------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index e6e575e6e5..5475f74ddd 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -449,8 +449,7 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32( - acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); } }; @@ -458,10 +457,8 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); } }; @@ -469,12 +466,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -482,12 +476,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - 
acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -495,14 +486,10 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); } }; @@ -510,16 +497,11 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[4], out_res, (lane_id + 256) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[4], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[5], out_res, (lane_id + 256) * sizeof(float)); } }; From 878d00f7a9483d121f38428101f8fce87d1b64d0 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Mon, 3 Nov 2025 20:19:56 +0000 Subject: [PATCH 46/53] removed jinja is_rocm on total_L as USE_ROCM is already applied --- .../training/pt2/embedding_split_host_pt2_autograd_template.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 661f7b9b45..2b359ad06e 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1063,9 +1063,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row.
- {% if is_rocm %} int32_t total_L = indices.numel(); - {%- endif %} {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and From d2596c712588f23589a5f960a715d1afc604d949 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:26:50 +0000 Subject: [PATCH 47/53] Change mixed_D default value to false --- .../training/backward/embedding_backward_dense_host_cpu.cpp | 2 +- .../backward/embedding_backward_split_host_template.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 2ea96a107e..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} From e076556674df4879f96b64a433dd7f1812fafee1 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:30:55 +0000 Subject: [PATCH 48/53] Make const work_group_size for CUDA --- .../embedding_backward_split_template.cu | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 82acd61baa..f07ef5830e 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1063,20 +1063,20 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>); {% if is_rocm %} - int32_t total_L = indices.numel(); - int32_t num_cta_per_row_groups; - int32_t work_group_size; - if (total_L/total_B > 1){ - num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; - work_group_size = (kMaxThreads/4); - } - else{ - num_cta_per_row_groups = kMaxThreads / kWarpSize; - work_group_size = kMaxThreads; - } + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } {%- else %} - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; - int32_t work_group_size = kMaxThreads; + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size =
kMaxThreads; {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, From 585300ded18ac828bd183a51a57dba6ff788836d Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:33:04 +0000 Subject: [PATCH 49/53] Add jinja comments to grad_indice_weights kernel --- .../embedding_backward_split_indice_weights_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index c58ba89f78..57c6804e66 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -333,9 +333,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } } } - {%- endif %} + {%- endif %}{#-/* if not ssd and not dense and not use_vec_blocking and not vbe */#} for (; j < kWarpSize && l_start + j < L; ++j) { - {%- else %} // if is_rocm + {%- else %}{#-/* if is_rocm*/#} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); From e0db2f1b4a582dcbea9f6b63a75a16a4814890fa Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:48:15 +0000 Subject: [PATCH 50/53] Remove redundant comment --- .../training/pt2/embedding_split_host_pt2_autograd_template.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 2b359ad06e..a2304b3fb3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1061,8 +1061,6 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; int32_t max_segment_length_per_warp = 64; - // Workaround. Should not be upstreamed in any way. - // Redistribute all cta_per_row work to warp_per_row.
int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and From bf143c73992f8f0242fae89a17c783977cf13694 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 11:27:11 +0000 Subject: [PATCH 51/53] Unify cuda and rocm loops --- .../embedding_backward_split_indice_weights_template.cu | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 57c6804e66..9ffaea3a67 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -213,9 +213,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} - {%- if is_rocm %} int32_t j = 0; - {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only if (placement != PlacementType::MANAGED_CACHING) { for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { @@ -333,11 +332,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } } } - {%- endif %}{#-/* if not ssd and not dense and not use_vec_blocking and not vbe */#} + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} for (; j < kWarpSize && l_start + j < L; ++j) { - {%- else %}{#-/* if is_rocm*/#} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From c6b0a8827fafc860f4c4b59db951749e1b0b71df Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Tue, 11 Nov 2025 17:21:13 +0000 Subject: [PATCH 52/53] Added BLOCK_SIZE_ROCM --- .../backward/embedding_backward_split_kernel_warp_template.cu | 2 +- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 7b3b5b653a..deffa8bfab 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -655,7 +655,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd cache_t, grad_t, index_t, - BLOCK_SIZE, + BLOCK_SIZE_ROCM, embedding_dim, segment_prefetch, segment_unroll, diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 5475f74ddd..38c1ac1ea4 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -36,7 +36,7 @@ typedef float floatx2_t __attribute__((ext_vector_type(2))); #define AMDGCN_BUFFER_RES_3 0x00027000 #define AMDGCN_WAVE_SIZE 64 #define THREADS_PER_ROW 64 -#define BLOCK_SIZE 256 +#define BLOCK_SIZE_ROCM 256 namespace fbgemm_gpu::rocm { template From 122a583418a4d9377a13710fd2c9c53a0bb36889 Mon Sep 17 00:00:00 2001 From: Li Li Date: 
Fri, 14 Nov 2025 20:36:46 +0000 Subject: [PATCH 53/53] revert the link to tbb --- cmake/modules/CppLibrary.cmake | 12 ------------ cmake/modules/GpuCppLibrary.cmake | 12 ------------ 2 files changed, 24 deletions(-) diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 388d3ac779..92a93a60b6 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,18 +168,6 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() - if(NOT TARGET TBB::tbb) - find_package(TBB QUIET) - endif() - if(TBB_FOUND) - target_link_libraries(${lib_name} PUBLIC TBB::tbb) - else() - find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) - if(TBB_LIB) - target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) - endif() - endif() - # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index e662848348..51c30df750 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,18 +302,6 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() - if(NOT TARGET TBB::tbb) - find_package(TBB QUIET) - endif() - if(TBB_FOUND) - list(APPEND library_dependencies TBB::tbb) - else() - find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) - if(TBB_LIB) - list(APPEND library_dependencies ${TBB_LIB}) - endif() - endif() - # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies})
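
Note on the TBB revert above: the two removed blocks were the only TBB lookups inside cpp_library() and gpu_cpp_library(), so after this patch neither helper links libtbb, regardless of what happens to be installed on the build host. If the dependency is ever needed again, one way to keep it out of default builds is to gate the lookup behind an explicit cache option. The CMake sketch below is illustrative only and is not part of the patch series; the FBGEMM_USE_TBB option name is hypothetical, and ${lib_name} stands for the target created by the surrounding helper function.

# Hypothetical opt-in hook, shown for illustration only.
option(FBGEMM_USE_TBB "Link oneTBB into FBGEMM C++ targets" OFF)
if(FBGEMM_USE_TBB)
  if(NOT TARGET TBB::tbb)
    # CONFIG mode uses oneTBB's exported TBBConfig.cmake when it is installed.
    find_package(TBB CONFIG QUIET)
  endif()
  if(TARGET TBB::tbb)
    target_link_libraries(${lib_name} PUBLIC TBB::tbb)
  else()
    message(FATAL_ERROR "FBGEMM_USE_TBB is ON but oneTBB could not be found")
  endif()
endif()

Keeping the lookup behind an option means a missing TBB installation only fails when the feature is explicitly requested, instead of silently changing what the default build links.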