From 523a31759bdc0012bfe4c2eb44f1bf53fe4a1158 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 11:57:27 +0000 Subject: [PATCH 01/63] Add gfx950 build support + fp16 fix + index type fix --- fbgemm_gpu/cmake/Hip.cmake | 8 ++++++++ .../embedding_backward_split_template.cu | 2 +- ..._backward_split_device_kernel_template.hip | 2 +- .../include/fbgemm_gpu/rocm/cdna_guard.h | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 20 ++++++++++++++++++- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 2 +- 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 17640b7254..2011a34c33 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,6 +78,14 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) + list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -ggdb) + + # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) + #list(APPEND HIP_CXX_FLAGS -adhln) + list(APPEND HIP_CXX_FLAGS -save-temps) + list(APPEND HIP_CXX_FLAGS -fverbose-asm) + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76eba64c99..76a2b347d8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1193,7 +1193,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..5acc61382e 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..c96da01063 
100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -215,6 +215,24 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } + +}; + + template < typename emb_t, int32_t embedding_dim, @@ -471,7 +489,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index dfea2dce8a..361059020e 100644 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - std::execution::par, + // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From aee3078a77b5aad122ceac022a16ce9dfad0e9c0 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 13:16:41 +0000 Subject: [PATCH 02/63] Change int64_t to index_t as template parameters in load_raw_per_warp --- .../rocm/embedding_backward_split_device_kernel_template.hip | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 5acc61382e..d5841d6e00 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -452,7 +452,7 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); From 5a1ac2e85de86b772ed609d0c5662b80b2cc0e3d Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 14:39:22 +0000 Subject: [PATCH 03/63] Implement llvm fp16 buffer load for gfx950 --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c96da01063..4b33fd1422 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t 
srsrc, @@ -72,7 +77,12 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, From 78569031c9133f7175241617fdd6e81eca6c2b5c Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:23:47 +0000 Subject: [PATCH 04/63] Fix c-style half to float cast --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 4b33fd1422..238a83440a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -261,7 +261,14 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } From e1e246a825386da2c6ccc206ba522e988aa6ba0f Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:24:29 +0000 Subject: [PATCH 05/63] Patch 256 half stores --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 238a83440a..974eae2594 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -294,6 +294,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + } +}; + + template <> struct store_row_per_warp { static __device__ void run(float* acc, float* p_output, int lane_id) { From 6a99fe08a80f133d74502167bad5ff1ce3143019 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Fri, 8 Aug 2025 05:02:58 +0000 Subject: [PATCH 06/63] cta_per_row workgroup optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76a2b347d8..9412edc1a5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1042,7 +1042,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,7 +1053,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t 
cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, (kMaxThreads/2)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From 349a7b5fc90935325e58eef100fcf308d6df17d7 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 11 Aug 2025 21:06:48 +0000 Subject: [PATCH 07/63] Added mi350 guards --- ...ding_backward_split_indice_weights_template.cu | 15 ++++++++++++++- .../backward/embedding_backward_split_template.cu | 10 ++++++++++ .../forward/embedding_forward_split_template.cu | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9e1f71ef4e --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); - + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 9412edc1a5..9e9e7aac68 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 37e774bb49..2861b631a0 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -459,6 +463,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} From 1178cd101308894e5463d63bccba652bd9784d23 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 15:09:39 +0000 Subject: [PATCH 08/63] Fix index overflow in row load --- ..._backward_split_device_kernel_template.hip | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d5841d6e00..d1a874805a 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -290,7 +290,7 @@ 
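// Aside: a minimal illustrative sketch (made-up sizes, not FBGEMM code) of why
// the index template parameter of load_row_per_warp in this patch moves from
// int32_t to index_t. The row offset handed to the loader is multiplied by the
// embedding dimension inside it, and for large batches that product can exceed
// INT32_MAX:
//
//   int32_t bag_index = 9000000, num_tables = 4, embedding_dim = 256;
//   int32_t bad  = bag_index * num_tables * embedding_dim;          // exceeds INT32_MAX, wraps
//   int64_t good = int64_t{bag_index} * num_tables * embedding_dim; // 9216000000, exact
//
// Doing the offset arithmetic in index_t (typically int64_t) keeps it exact.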
__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -420,7 +420,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -441,7 +441,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); From 606ad34c72469df96096853603ea56abe976949e Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 20:13:09 
+0000 Subject: [PATCH 09/63] cta_per_row workgroup reduce by 4 optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 9e9e7aac68..c59f6fe9aa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1052,7 +1052,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1063,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/2)), + div_round_up(total_unique_indices, (kMaxThreads/4)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From a22ddebacc0aef64c2fb0569835eaed9786b5f0e Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 13 Aug 2025 13:21:38 +0000 Subject: [PATCH 10/63] Fix mixed_D frontend to backend connection --- .../training/backward/embedding_backward_split_template.cu | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 1 + .../split_table_batched_embeddings_ops_training.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c59f6fe9aa..c8a846a552 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1203,7 +1203,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 0f4814fb41..da0c69ad21 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,6 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} // Default values for Dynamo tracing diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 0a2a7ab0a1..4f1741b3dc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ 
b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -820,7 +820,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2517,6 +2517,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2535,6 +2536,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2544,6 +2546,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) From 677545246e4091eac87236831092dad39b3e3fe1 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Fri, 15 Aug 2025 15:32:19 +0000 Subject: [PATCH 11/63] changed max_segment_length_per_cta to 4096 --- .../training/backward/embedding_backward_split_template.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c8a846a552..1ddcea55b2 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { From 90e6ba7a5f3793e52a1f283456e1b3bbc7a4634b Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Mon, 18 Aug 2025 22:32:58 +0000 Subject: [PATCH 12/63] added rocm guards and removed comment --- .../embedding_backward_split_template.cu | 19 ++++++++++++++++--- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 - 2 files changed, 16 insertions(+), 4 deletions(-) mode change 100644 => 100755 fbgemm_gpu/src/tbe/eeg/indices_generator.cpp diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 1ddcea55b2..099c7e5685 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #ifdef USE_ROCM + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #else + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 1024; + #endif Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { @@ -1052,7 +1056,11 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #ifdef USE_ROCM + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #else + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1071,12 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/4)), + #ifdef USE_ROCM + div_round_up(total_unique_indices, (kMaxThreads/4)), + #else + div_round_up(total_unique_indices, kMaxThreads), + #endif + get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755 index 361059020e..715acd8c0c --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,6 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From a9073ac78c2b6689e317d92bbc278014699ff7a7 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 20 Aug 2025 03:00:56 +0000 Subject: [PATCH 13/63] clean debug statements in Hip.cmake --- fbgemm_gpu/cmake/Hip.cmake | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 2011a34c33..17640b7254 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,14 +78,6 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) - list(APPEND HIP_CXX_FLAGS -g) - list(APPEND HIP_CXX_FLAGS -ggdb) - - # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) - #list(APPEND HIP_CXX_FLAGS -adhln) - list(APPEND HIP_CXX_FLAGS -save-temps) - list(APPEND HIP_CXX_FLAGS -fverbose-asm) - set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use From 9a16e1265308a13dec7170e08f1b6736fc595e8a Mon Sep 17 00:00:00 2001 From: Shreya Date: Thu, 28 Aug 2025 11:43:32 -0500 Subject: [PATCH 14/63] Merge pull request #121 warp per row wg change --- .../embedding_backward_split_template.cu | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 099c7e5685..2425322948 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1056,10 +1056,21 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); + int32_t total_L = indices.numel(); #ifdef USE_ROCM - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1){ + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + 
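// Aside: the branch above (and its else below) keys off the average pooling
// factor. total_L is the total number of indices and total_B the total number
// of bags, so total_L / total_B is the mean segment length per bag. With
// made-up sizes, total_L = 4,194,304 indices over total_B = 1,048,576 bags
// gives a pooling factor of 4 > 1, so the pooled branch shrinks the workgroup
// to kMaxThreads / 4 (num_cta_per_row_groups = (kMaxThreads / 4) / kWarpSize),
// while an unpooled workload (factor == 1) keeps the full kMaxThreads-sized
// workgroup.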
else{ + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } #else int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t work_group_size = kMaxThreads; #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, @@ -1071,17 +1082,13 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - #ifdef USE_ROCM - div_round_up(total_unique_indices, (kMaxThreads/4)), - #else - div_round_up(total_unique_indices, kMaxThreads), - #endif - + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, + // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1185,7 +1192,18 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #ifdef USE_ROCM + int32_t num_warp_per_row_groups; + + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + #else + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #endif int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { From 68630daf7e1c8b9742eac9492d220b7a55e28cf1 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 Subject: [PATCH 15/63] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 974eae2594..46c4603381 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From bac0610aeb2d7842e8a4cd0efac729d196d35cee Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 18 Sep 2025 16:28:31 +0000 Subject: [PATCH 16/63] fix the bug in dimention 160 in ROCm optimization --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 46c4603381..8a97579d6a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int 
lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From a12112f10f485a6f0263c71379ed683c99103be6 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 19 Aug 2025 13:41:17 +0000 Subject: [PATCH 17/63] Cleanup optimized warp_per_raw kernel --- fbgemm_gpu/cmake/tbe_sources.py | 2 - .../genscript/generate_backward_split.py | 10 +- ...ing_backward_split_kernel_warp_template.cu | 40 +++----- .../embedding_backward_split_template.cu | 18 ++-- ..._backward_split_device_kernel_template.hip | 94 +++++-------------- 5 files changed, 54 insertions(+), 110 deletions(-) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a5277a906a..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..1158721526 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if 
is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - + auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, @@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 2425322948..fb125101e7 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -862,8 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - 
and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1226,8 +1233,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d1a874805a..951cff4399 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -221,21 +208,15 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ @@ -244,23 +225,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & 
info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,23 +261,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,21 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} + table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); @@ -346,23 +314,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,23 +341,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> 
info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,12 +374,9 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -435,12 +392,9 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( From 3ef64f7c3bcca2904a74a40e794eccfa1ad3043b Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 20 Aug 2025 12:15:37 +0000 Subject: [PATCH 18/63] Add 320 embedding dim support for optimized warp_per_row kernel --- ...ing_backward_split_kernel_warp_template.cu | 2 +- .../embedding_backward_split_template.cu | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1158721526..e61b3fc0aa 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -766,7 +766,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fb125101e7..7eb2b6880f 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1243,7 +1243,7 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- 
for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 8a97579d6a..5b9d69d910 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -205,6 +205,22 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + } +}; + template struct load_row_per_warp { static __device__ void @@ -304,6 +320,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + p_output[lane_id + 256] = acc[4]; + } +}; + template <> struct store_row_per_warp { From f601e553f0ce84af2ac728ec3266ac49dce8b7eb Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Sep 2025 19:34:16 +0000 Subject: [PATCH 19/63] changed the max length per warp and cta per row WG size --- .../backward/embedding_backward_split_host_template.cpp | 2 +- .../training/backward/embedding_backward_split_template.cu | 6 +----- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 134a03b983..8a0cd9daca 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7eb2b6880f..86d4ce8b8b 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,11 +987,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - #ifdef USE_ROCM - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - #else - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - #endif + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 18378b6106..00673abc8b 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 4096; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index da0c69ad21..bf5b56c079 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1060,7 +1060,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From 04916da836e07175e3ecc4b27daa2960a994a360 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 9 Sep 2025 20:25:30 +0000 Subject: [PATCH 20/63] added DPP and changed max length per warp to 16k --- .../embedding_backward_split_host_template.cpp | 2 +- .../index_select/batch_index_select_dim0_host.cpp | 4 ++-- .../embedding_split_host_pt2_autograd_template.cpp | 2 +- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 14 ++++++++------ 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 8a0cd9daca..2ea96a107e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 00673abc8b..02529f2d89 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -342,7 +342,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; 
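// Aside: a rough host-side sketch (hypothetical names, not FBGEMM code) of what
// max_segment_length_per_warp controls. After sorting, the backward pass groups
// indices into one segment per unique embedding row; each segment is handled
// either by the warp_per_row kernel (one warp owns the whole segment) or, when
// it is at least this threshold long, by the cooperative cta_per_row kernel.
// Raising the threshold from 32/64 to 4096 and then 16384 therefore keeps
// nearly all segments on the warp_per_row path.
#include <cstdint>
#include <vector>

struct SegmentRouting {
  std::vector<int32_t> warp_per_row_ids; // segment ids handled by a single warp
  std::vector<int32_t> cta_per_row_ids;  // long segments handled by a whole CTA
};

inline SegmentRouting route_segments(
    const std::vector<int32_t>& segment_lengths,
    int32_t max_segment_length_per_warp) {
  SegmentRouting routing;
  for (int32_t id = 0; id < static_cast<int32_t>(segment_lengths.size()); ++id) {
    if (segment_lengths[id] < max_segment_length_per_warp) {
      routing.warp_per_row_ids.push_back(id);
    } else {
      routing.cta_per_row_ids.push_back(id);
    }
  }
  return routing;
}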
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index bf5b56c079..bcecc7c91c 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1060,7 +1060,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 0d65c4798a..a1d9819017 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,11 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); } DEVICE_INLINE void syncwarp() { From 1e09555b32d622cf1e903db762639e1dac267a11 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 19:33:44 +0000 Subject: [PATCH 21/63] guard max segment warp based on emb dim --- ...dding_split_host_pt2_autograd_template.cpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index bcecc7c91c..cbd7aceda9 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1060,7 +1060,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 16384; + int32_t max_segment_length_per_warp = 64; + // Workaround. Should not be upstreamed in any way. + // Redistribute all cta_per_row work to warp_per_row. 
+ {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + const auto B = (offsets.size(0) - 1) / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From b41192be39dc00545379016f3688c6d16ba0641d Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 22:00:20 +0000 Subject: [PATCH 22/63] added guarding opt of max segment for the case batch size list=1 --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index cbd7aceda9..787a9b6d2f 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1063,6 +1063,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. + int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and @@ -1071,9 +1072,10 @@ static torch::autograd::variable_list backward( (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); - const auto B = (offsets.size(0) - 1) / T; + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } From 2b08f965ec9e4cbd4d98ef7d4e20ccbff812341d Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 09:26:57 +0000 Subject: [PATCH 23/63] opt for grad_indice_weights kernel --- ..._backward_split_indice_weights_template.cu | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 9e1f71ef4e..b30e3e5c77 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -214,7 +214,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ) {%- endif %} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + int32_t j = 0; + {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = 
shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + if (placement == PlacementType::MANAGED_CACHING) { + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + } else { + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + } + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + {%- endif %} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From 0c264705036d86ffa84ef1215bed646d90c3dc3e Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 23 Sep 2025 02:09:26 +0000 Subject: [PATCH 24/63] added store row per warp on emb 192 and added accuracy test functionality --- ...plit_table_batched_embeddings_benchmark.py | 223 +++++++++++++----- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 125 ++++++++-- .../fbgemm_gpu/rocm/split_embeddings_common.h | 18 +- 3 files changed, 277 insertions(+), 89 deletions(-) diff --git 
a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..3fad8f53fe 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,7 +7,8 @@ # pyre-strict - +import gzip +import yaml import logging import os import tempfile @@ -1011,7 +1012,15 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1040,39 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1081,11 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1164,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = 
weights_precision.bit_rate() / 8.0 @@ -1130,53 +1192,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests - + + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
+ sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests assert len(requests) == iters - + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1203,34 +1280,44 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) + if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. 
+ grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1331,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 00b51bbbe0..6d20a42c04 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,7 +19,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device - +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen logging.basicConfig(level=logging.DEBUG) @@ -248,36 +249,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -285,7 +293,86 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), 
f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) + print("PASS") for it in range(iters): req = requests[it % num_reqs] diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 5b9d69d910..745499ac08 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,7 +24,6 @@ #include #include #include -#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -62,7 +61,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -79,7 +78,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); @@ -165,7 +164,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = 
llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { @@ -320,6 +319,15 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); + } +}; + template <> struct store_row_per_warp { static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { @@ -619,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm +} // namespace fbgemm_gpu::rocm \ No newline at end of file From d6b491bb9c95af0142f663818c556c82960002a7 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 25/63] workgroup tuning and loop unrolled --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index aada1cdad5..69ad8cf8ca --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 70ed5e261f3eb6561b3b6625e5f4be1a5b36a50e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Sep 2025 22:38:17 +0200 Subject: [PATCH 26/63] specialize --- ..._backward_split_indice_weights_template.cu | 145 ++++++++++++------ 1 file changed, 95 insertions(+), 50 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index b30e3e5c77..0052d96406 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -217,33 +217,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only - for (; j < kWarpSize && l_start + j + 3 < L; j += 
4) { - const auto offset_idx_j0 = shfl_sync(offset_idx, j); - const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); - const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); - const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); - - const auto cache_idx_j0 = shfl_sync(cache_idx, j); - const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); - const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); - const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); - - at::acc_type grad_indice_weight0 = 0.0; - at::acc_type grad_indice_weight1 = 0.0; - at::acc_type grad_indice_weight2 = 0.0; - at::acc_type grad_indice_weight3 = 0.0; - - [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); - [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); - [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); - [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { - const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - Vec4T> weight0, weight1, weight2, weight3; - if (placement == PlacementType::MANAGED_CACHING) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + 
l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; weight0 = (cache_idx_j0 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : weight_row0.load(d); @@ -259,33 +308,29 @@ __global__ __launch_bounds__(kForwardMaxThreads) void weight3 = (cache_idx_j3 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : weight_row3.load(d); - } else { - weight0 = weight_row0.load(d); - weight1 = weight_row1.load(d); - weight2 = weight_row2.load(d); - weight3 = weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; } - grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + - weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; - grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + - weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; - grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + - weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; - grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + - weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; - } - - grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); - grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); - 
grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); - grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - if (threadIdx.x == 0) { - grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; - grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; - grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; - grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } } } {%- endif %} @@ -447,7 +492,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] From cf6a2b1a965925e9fe9896d0ae96341d0564e582 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 24 Sep 2025 00:48:35 +0000 Subject: [PATCH 27/63] explicitly link to tbb --- cmake/modules/CppLibrary.cmake | 12 ++++++++++++ cmake/modules/GpuCppLibrary.cmake | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 92a93a60b6..388d3ac779 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,6 +168,18 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + target_link_libraries(${lib_name} PUBLIC TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) + endif() + endif() + # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index 51c30df750..e662848348 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,6 +302,18 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + list(APPEND library_dependencies TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + list(APPEND library_dependencies ${TBB_LIB}) + endif() + endif() + # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) From 1be9bd8e5c693cfb55fca99b437fece631721b47 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Thu, 25 Sep 2025 19:00:23 +0000 Subject: [PATCH 28/63] added warpReduceAllSum with rocm guards --- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) mode change 100644 => 100755 fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh diff --git 
a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index a1d9819017..d51e3fa475 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -140,11 +140,19 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { - return rocm::wave_reduce< - rocm::reduce_op::sum, // Sum reduction - T, // Data type - ReduceWidth // Wave/Warp size - >(val); + #ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); + #else + #pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); + } + return val; + #endif } DEVICE_INLINE void syncwarp() { From 9d3ee64987ebd8296f897897cbd1310ff6f1d040 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 13 Oct 2025 20:34:59 +0000 Subject: [PATCH 29/63] revert unroll and wg tuning --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 69ad8cf8ca..acbf4563f3 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -628,7 +628,7 @@ batch_index_select_dim0_codegen_forward_kernel( constexpr int VEC_WIDTH = 4; {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 8; + constexpr int kManualUnrollLength = 4; {%- endif %} // Determine the linearized warp ID, and exit early if needed From a5a3b1ec1a9dcba556fb38b179093838bac94e5e Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 13 Oct 2025 15:46:07 -0500 Subject: [PATCH 30/63] Minor update embedding_forward_split_kernel_template.cu --- .../forward/embedding_forward_split_kernel_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index acbf4563f3..aada1cdad5 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -461,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { {%- 
else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { + for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group From 28e93c020f2a72d20a75168771e73f9ffeccb37c Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 17 Oct 2025 21:17:37 +0000 Subject: [PATCH 31/63] add tbb-devel to the install_build_tools () --- .github/scripts/utils_build.bash | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 82fa3e26a2..709e7b62f4 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -370,6 +370,7 @@ install_build_tools () { patchelf \ rhash \ scikit-build \ + tbb-devel \ tbb \ wheel \ xz \ From 842846c7b37e37ee3df6534e5102bce25bb9d11c Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 21 Oct 2025 18:54:56 +0000 Subject: [PATCH 32/63] fix lint issues --- ...plit_table_batched_embeddings_benchmark.py | 35 +++++++++---------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 4 +-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 3fad8f53fe..02d3820b07 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -1192,7 +1192,6 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: requests = [] for i in range(iters): @@ -1253,7 +1252,6 @@ def device_with_spec( # noqa C901 torch.save(req.offsets, f"{save}/{i}_offsets.pt") torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") - sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1280,23 +1278,22 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) - + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative @@ -1315,7 +1312,7 @@ def device_with_spec( # noqa C901 # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. 
grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - + if save: torch.save(grad_output, f"{save}/grad_output.pt") # backward diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 6d20a42c04..f435370e36 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -305,11 +305,11 @@ def benchmark_requests( # noqa: C901 torch.save(out, f) else: torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - + out.backward(grad) torch.cuda.synchronize() torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - + if sliced: for id, t in enumerate(emb.split_embedding_weights()): if compressed: From 97aeb8329e99f2ca44902a39f02806ea5581d9c8 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 21 Oct 2025 21:23:38 +0000 Subject: [PATCH 33/63] solve lint issues --- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 18 +++++++----------- .../fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index f435370e36..3278252b5f 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -314,14 +314,13 @@ def benchmark_requests( # noqa: C901 for id, t in enumerate(emb.split_embedding_weights()): if compressed: with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) + torch.save(t[slice_min:slice_max, :].clone(), f) else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - else: if compressed: with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: @@ -332,11 +331,9 @@ def benchmark_requests( # noqa: C901 if load and emb: for it in range(iters): req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() out = emb(indices, offsets, weights) torch.cuda.synchronize() - out.backward(grad) torch.cuda.synchronize() emb_ref = copy.deepcopy(emb) @@ -346,8 +343,8 @@ def benchmark_requests( # noqa: C901 emb_ref.load_state_dict(torch.load(f)) else: emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: for id, t in enumerate(emb.split_embedding_weights()): if compressed: @@ -355,15 +352,15 @@ def benchmark_requests( # noqa: C901 w_ref = torch.load(f) else: w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + torch.testing.assert_close(t[slice_min:slice_max, :], w_ref, msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) else: for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) print("PASS") - print(f"[{it + 1}/{iters}] Backward momentum check... 
", end="", flush=True) + if sliced: m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") @@ -375,7 +372,6 @@ def benchmark_requests( # noqa: C901 print("PASS") for it in range(iters): req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() if bwd_only: # Run forward before profiling if does backward only diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 745499ac08..aa869fe2b5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -627,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm \ No newline at end of file +} // namespace fbgemm_gpu::rocm From 00976c79cfb5bc9cf62a4cf3946dde20d19688a9 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 22 Oct 2025 18:45:41 +0000 Subject: [PATCH 34/63] applied jinja is_rocm onto optimizations for backward and forward parameters --- ..._backward_split_indice_weights_template.cu | 5 ++++- .../embedding_backward_split_template.cu | 22 ++++++++++--------- .../embedding_forward_split_template.cu | 4 ++-- .../batch_index_select_dim0_host.cpp | 4 ++-- ...dding_split_host_pt2_autograd_template.cpp | 4 ++++ 5 files changed, 24 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 0052d96406..c58ba89f78 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -213,7 +213,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} - + {%- if is_rocm %} int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only @@ -335,6 +335,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } {%- endif %} for (; j < kWarpSize && l_start + j < L; ++j) { + {%- else %} // if is_rocm + for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 86d4ce8b8b..759bbfd9bb 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,8 +987,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1059,8 +1062,8 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t total_L = indices.numel(); - #ifdef USE_ROCM + {% if is_rocm %} + int32_t total_L = indices.numel(); int32_t num_cta_per_row_groups; int32_t work_group_size; if (total_L/total_B > 1){ @@ -1071,10 +1074,10 @@ Tensor {{ embedding_cuda_op }}( num_cta_per_row_groups = kMaxThreads / kWarpSize; work_group_size = kMaxThreads; } - #else + {%- else %} int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; int32_t work_group_size = kMaxThreads; - #endif + {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1091,7 +1094,6 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1195,7 +1197,7 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - #ifdef USE_ROCM + {%- if is_rocm %} int32_t num_warp_per_row_groups; if (total_L/total_B > 1){ @@ -1204,9 +1206,9 @@ Tensor {{ embedding_cuda_op }}( else{ num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; } - #else + {%- else %} int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - #endif + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index 2861b631a0..dac49631cf 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -463,7 +463,7 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); - #ifdef USE_ROCM + {% if is_rocm %} if (!rocm::is_supported_cdna()) { TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); } @@ -471,7 +471,7 @@ batch_index_select_dim0_codegen_forward_cuda( // Ensure we're running on a supported CDNA architecture (including MI350) TORCH_WARN_ONCE("Running on CDNA architecture"); } - #endif + {%- endif %} {%- if not nobag %} int32_t T = D_offsets.numel() - 1; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 02529f2d89..608f6017ec 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -342,7 +342,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 16384; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 16384; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 787a9b6d2f..21cb348c21 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,8 +743,10 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + {% if is_rocm %} const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} + {%- endif %} // Default values for Dynamo tracing // SymInt does not support bitshifts operator @@ -1063,7 +1065,9 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. 
+ {% if is_rocm %} int32_t total_L = indices.numel(); + {%- endif %} {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and From 4c19030121bc36e2ea2c5a2da959bb23f9df8f50 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:54:56 +0000 Subject: [PATCH 35/63] Guard supported grad_t for optimized warp_per_row dispatch --- .../training/backward/embedding_backward_split_template.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 759bbfd9bb..18beeae1ff 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1238,7 +1238,9 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; + + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} From 9991cf1ccec41bc3e6daa7ed5f412839ed9fad89 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:56:05 +0000 Subject: [PATCH 36/63] Forward index_t to the optimizer --- .../backward/embedding_backward_split_kernel_warp_template.cu | 2 +- .../rocm/embedding_backward_split_device_kernel_template.hip | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index e61b3fc0aa..b757f64d36 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -650,7 +650,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd opt_karg.weight_decay = weight_decay; rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 951cff4399..87d259ebee 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { From b61bd196fa96c748c4c5b9fd9cbcf6c9829edd60 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 
Subject: [PATCH 37/63] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index aa869fe2b5..08e1efa3e9 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From c38ff6f72b2734ec955de4415f9068e22c683040 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 23 Oct 2025 13:59:56 +0000 Subject: [PATCH 38/63] Fix buffer offset for emb_dim == 160 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 08e1efa3e9..c1d98d3e9f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From e201e8b7ee5b1633acdd63dd3b8fac957dff3cea Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 27 Oct 2025 14:36:52 +0000 Subject: [PATCH 39/63] Remove sanity check --- ...plit_table_batched_embeddings_benchmark.py | 190 +++++------------- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 123 ++---------- 2 files changed, 70 insertions(+), 243 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 02d3820b07..4ffb7341a5 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,8 +7,7 @@ # pyre-strict -import gzip -import yaml + import logging import os import tempfile @@ -1012,15 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) 
-@click.option("--slice-max", type=int, default=None) -@click.pass_context def device_with_spec( # noqa C901 - ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1040,39 +1031,7 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, ) -> None: - if load: - with open(f"{load}/params.yaml", "r") as f: - ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) - alpha = ctx.params["alpha"] - bag_size_list = ctx.params["bag_size_list"] - bag_size_sigma_list = ctx.params["bag_size_sigma_list"] - batch_size = ctx.params["batch_size"] - embedding_dim_list = ctx.params["embedding_dim_list"] - weights_precision = ctx.params["weights_precision"] - cache_precision = ctx.params["cache_precision"] - stoc = ctx.params["stoc"] - iters = ctx.params["iters"] - warmup_runs = ctx.params["warmup_runs"] - managed = ctx.params["managed"] - num_embeddings_list = ctx.params["num_embeddings_list"] - reuse = ctx.params["reuse"] - row_wise = ctx.params["row_wise"] - weighted = ctx.params["weighted"] - pooling = ctx.params["pooling"] - bounds_check_mode = ctx.params["bounds_check_mode"] - flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] - output_dtype = ctx.params["output_dtype"] - random_weights = ctx.params["random_weights"] - compressed = ctx.params["compressed"] - slice_min = ctx.params["slice_min"] - slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1081,11 +1040,6 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" - params = ctx.params - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1164,22 +1118,6 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) - elif random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1192,66 +1130,53 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: - requests = [] - for i in range(iters): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") - requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) - else: - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
- sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - del all_requests + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + + del all_requests + assert len(requests) == iters - if save: - for i in range(iters): - req = requests[i] - torch.save(req.indices, f"{save}/{i}_indices.pt") - torch.save(req.offsets, f"{save}/{i}_offsets.pt") - torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") - torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1299,22 +1224,13 @@ def device_with_spec( # noqa C901 # backward bench not representative return - if load: - grad_output = torch.load(f"{load}/grad_output.pt") + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) - else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. 
- grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - - if save: - torch.save(grad_output, f"{save}/grad_output.pt") + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -1328,12 +1244,6 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 3278252b5f..00b51bbbe0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,7 +11,6 @@ import statistics import threading import time -import gzip from subprocess import Popen from typing import Callable, Optional @@ -19,7 +18,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device -from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen + logging.basicConfig(level=logging.DEBUG) @@ -249,43 +248,36 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - import copy + if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - if not (load or save): - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - sliced = slice_min is not None and slice_max is not None + if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -293,85 +285,10 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] - if save and emb: - for it in range(iters): - req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - if compressed: - with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: - torch.save(out, f) - else: - torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - - out.backward(grad) - torch.cuda.synchronize() - torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max, 
:].clone(), f) - else: - torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max, :].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it in range(iters): - req = requests[it % num_reqs] - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max, :], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) - - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) - print("PASS") for it in range(iters): req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() if bwd_only: # Run forward before profiling if does backward only From aaaf80c07c9c52e023d6b43fbce263527b33c156 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 27 Oct 2025 20:14:58 +0000 Subject: [PATCH 40/63] address the potential lint issues and revert the change in indices_generator.cpp --- ...dding_backward_split_kernel_warp_template.cu | 11 +++++------ .../embedding_backward_split_template.cu | 13 ++++++------- ...ng_backward_split_device_kernel_template.hip | 17 ++++++++--------- .../forward/embedding_forward_split_template.cu | 2 +- ...bedding_split_host_pt2_autograd_template.cpp | 14 +++++++------- .../fbgemm_gpu/rocm/split_embeddings_common.h | 4 ++-- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 + 7 files changed, 30 insertions(+), 32 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index b757f64d36..7b3b5b653a 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,13 +32,13 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else 
"lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and not dense and - not nobag and + not nobag and not is_index_select and - not is_gwd_kernel and - not vbe and + not is_gwd_kernel and + not vbe and not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" @@ -621,7 +621,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ) { int32_t T = D_offsets.size(0) - 1; - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 18beeae1ff..72cf189ccc 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,13 +48,13 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and not dense and - not nobag and + not nobag and not is_index_select and - not is_gwd_kernel and - not vbe and + not is_gwd_kernel and + not vbe and not ssd %} template < @@ -669,7 +669,7 @@ Tensor {{ embedding_cuda_op }}( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -1199,7 +1199,6 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; - if (total_L/total_B > 1){ num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; } diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 87d259ebee..2a747731cc 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -225,7 +225,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -234,7 +234,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -270,7 +270,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ 
-301,7 +301,6 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; @@ -314,7 +313,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -323,7 +322,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; @@ -341,7 +340,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; @@ -350,7 +349,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index dac49631cf..a3edb6b965 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -472,7 +472,7 @@ batch_index_select_dim0_codegen_forward_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } {%- endif %} - + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 21cb348c21..ce068b54d3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1068,18 +1068,18 @@ static torch::autograd::variable_list backward( {% if is_rocm %} int32_t total_L = indices.numel(); {%- endif %} - {%- if (not nobag) and - (optimizer == "rowwise_adagrad") and - (not vbe) and - (not is_gwd) and - (not ssd) and - (not is_index_select) and + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); auto total_B = (offsets.size(0) - 1); const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c1d98d3e9f..b5aa74c1ab 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -217,7 +217,7 @@ struct load_row_per_warp { *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + emb_data[4] = 
p_emb_table[row_index * 320 + 256 + lane_id]; } }; @@ -335,7 +335,7 @@ struct store_row_per_warp { auto out = reinterpret_cast(p_output); out[lane_id] = *reinterpret_cast(acc); out[lane_id + 64] = *reinterpret_cast(&acc[2]); - p_output[lane_id + 256] = acc[4]; + p_output[lane_id + 256] = acc[4]; } }; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index 715acd8c0c..dfea2dce8a 100755 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,6 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( + std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From b8aea67fb3f53276a66fb3c6365a17593ddeab13 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 27 Oct 2025 20:37:50 +0000 Subject: [PATCH 41/63] addresss code style issue --- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 608f6017ec..18378b6106 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -658,7 +658,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 32; auto grad_output = grad_outputs[0]; From b9a7759e8a40eeac3cd58a64b5f6336b1164026b Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 28 Oct 2025 19:16:27 +0000 Subject: [PATCH 42/63] removed guard rocm on mixed_D and refactored mixed_D var assignment --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index ce068b54d3..661f7b9b45 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,9 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; - {% if is_rocm %} - const auto mixed_D = aux_bool[IDX_MIXED_D]; - {%- endif %} + const auto mixed_D = static_cast(aux_bool[IDX_MIXED_D]); {%- endif %} // Default values for Dynamo tracing @@ -860,7 +858,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]); + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; From a4b44316fe9ab1f790f37cf2215950b212b01461 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 24 Oct 2025 14:44:16 +0000 Subject: [PATCH 43/63] Remove general load/store methods --- ..._backward_split_device_kernel_template.hip | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 397 
++++++++++++------ 2 files changed, 259 insertions(+), 140 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2a747731cc..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -410,6 +410,6 @@ L_tail_grad_acc: optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b5aa74c1ab..e6e575e6e5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,7 +21,11 @@ * ******************************************************************************/ #pragma once + +#include #include +#include + #include #include #include @@ -47,10 +51,10 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr, const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -60,8 +64,8 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) + int32_t soffset = 0, + int32_t glc_slc = 0) #if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else @@ -71,33 +75,59 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) + int32_t soffset = 0, + int32_t glc_slc = 0) #if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); #endif +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif + __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t 
soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -107,35 +137,15 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host side + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { - emb_data[i] = 0.f; + static_assert(false, "HIP: Optimized load operation is not supported yet"); } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); } - } - } }; template @@ -145,7 +155,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -156,7 +166,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); } }; @@ -165,15 +175,11 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(half) * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -184,9 +190,9 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -198,10 +204,10 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = 
llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; @@ -210,35 +216,15 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 320); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -256,9 +242,97 @@ struct load_row_per_warp { lane_id ); } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * 
sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + } }; +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 256) * sizeof(float)); + } +}; template < typename emb_t, @@ -291,116 +365,161 @@ struct accumulate_row_per_warp { } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. 
Currently, the kernel dispatch for unsupported type is guarded on host function + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert(false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - out[lane_id + 64] = *reinterpret_cast(&acc[2]); +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { - auto out = reinterpret_cast(p_output); - out[lane_id] = *reinterpret_cast(acc); - out[lane_id + 64] = *reinterpret_cast(&acc[2]); - p_output[lane_id + 256] = acc[4]; +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); } }; +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + 
reinterpret_cast(p_emb_table), + lane_id + ); + } +}; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32( + acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + *reinterpret_cast(acc), + out_res, + lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(&acc[2]), + out_res, + (lane_id + 64) * sizeof(floatx2_t)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32x2( + *reinterpret_cast(acc), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + lane_id * sizeof(floatx2_t)); llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), + *reinterpret_cast(&acc[2]), out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + (lane_id + 64) * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 256) * sizeof(float)); } }; From 5d4f2cd349cbb2f9f1d027851ba342e156ae4f01 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Fri, 24 Oct 2025 14:57:41 +0000 Subject: [PATCH 44/63] Move weight type check to compile-time --- 
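Note (not part of the commit): a minimal, self-contained sketch of the dispatch pattern this patch moves to. With emb_t and grad_t available as template parameters, the supported-type test can be folded into constexpr booleans and resolved with if constexpr, so unsupported instantiations never reference the optimized HIP kernel. The half_t stand-in and the dispatch_backward name below are illustrative only and do not exist in fbgemm_gpu.

#include <cstdio>
#include <type_traits>

// Stand-in for c10::Half, only to keep the sketch self-contained.
struct half_t { unsigned short bits; };

template <typename emb_t, typename grad_t>
void dispatch_backward() {
  // Both tests are compile-time constants of this instantiation, so the
  // branch below is pruned by the compiler instead of being re-checked
  // against dev_weights.scalar_type() at run time, mirroring the
  // supported_weights_type / supported_grad_type guards in this patch.
  constexpr bool supported_weights_type =
      std::is_same_v<emb_t, half_t> || std::is_same_v<emb_t, float>;
  constexpr bool supported_grad_type =
      std::is_same_v<grad_t, half_t> || std::is_same_v<grad_t, float>;

  if constexpr (supported_weights_type && supported_grad_type) {
    std::puts("optimized HIP kernel path");
  } else {
    std::puts("generic fallback path");
  }
}

int main() {
  dispatch_backward<float, float>();   // optimized path
  dispatch_backward<double, float>();  // pruned to the fallback at compile time
  return 0;
}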
.../training/backward/embedding_backward_split_template.cu | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 72cf189ccc..82acd61baa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1234,9 +1234,7 @@ Tensor {{ embedding_cuda_op }}( const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; - + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) From d3b7d7a893d0d69f968e5ce15c4d1b54836f1bcb Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 27 Oct 2025 11:51:43 +0000 Subject: [PATCH 45/63] Switch to 256B stores for float type --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 54 +++++++------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index e6e575e6e5..5475f74ddd 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -449,8 +449,7 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32( - acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); } }; @@ -458,10 +457,8 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); } }; @@ -469,12 +466,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -482,12 +476,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - 
acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -482,12 +476,9 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); } }; @@ -495,14 +486,10 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); } }; @@ -510,16 +497,11 @@ template <> struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[4], out_res, (lane_id + 256) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[4], out_res, (lane_id + 256) * sizeof(float)); } }; From 878d00f7a9483d121f38428101f8fce87d1b64d0 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Mon, 3 Nov 2025 20:19:56 +0000 Subject: [PATCH 46/63] removed jinja is_rocm on total_L as USE_ROCM is already applied --- .../training/pt2/embedding_split_host_pt2_autograd_template.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 661f7b9b45..2b359ad06e 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1063,9 +1063,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. 
- {% if is_rocm %} int32_t total_L = indices.numel(); - {%- endif %} {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and From d2596c712588f23589a5f960a715d1afc604d949 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:26:50 +0000 Subject: [PATCH 47/63] Change mixed_D default value to false --- .../training/backward/embedding_backward_dense_host_cpu.cpp | 2 +- .../backward/embedding_backward_split_host_template.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 2ea96a107e..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} From e076556674df4879f96b64a433dd7f1812fafee1 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:30:55 +0000 Subject: [PATCH 48/63] Make const work_group_size for CUDA --- .../embedding_backward_split_template.cu | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 82acd61baa..f07ef5830e 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1063,20 +1063,20 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); {% if is_rocm %} - int32_t total_L = indices.numel(); - int32_t num_cta_per_row_groups; - int32_t work_group_size; - if (total_L/total_B > 1){ - num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; - work_group_size = (kMaxThreads/4); - } - else{ - num_cta_per_row_groups = kMaxThreads / kWarpSize; - work_group_size = kMaxThreads; - } + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } {%- else %} - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; - int32_t work_group_size = kMaxThreads; + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size = 
kMaxThreads; {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, From 585300ded18ac828bd183a51a57dba6ff788836d Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:33:04 +0000 Subject: [PATCH 49/63] Add jinja comments to grad_indice_weights kernel --- .../embedding_backward_split_indice_weights_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index c58ba89f78..57c6804e66 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -333,9 +333,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } } } - {%- endif %} + {%- endif %}{#-/* if not ssd and not dense and not use_vec_blocking and not vbe */#} for (; j < kWarpSize && l_start + j < L; ++j) { - {%- else %} // if is_rocm + {%- else %}{#-/* if is_rocm*/#} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); From e0db2f1b4a582dcbea9f6b63a75a16a4814890fa Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 09:48:15 +0000 Subject: [PATCH 50/63] Remove redundant comment --- .../training/pt2/embedding_split_host_pt2_autograd_template.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 2b359ad06e..a2304b3fb3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1061,8 +1061,6 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; int32_t max_segment_length_per_warp = 64; - // Workaround. Should not be upstreamed in any way. - // Redistribute all cta_per_row work to warp_per_row.
int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and From bf143c73992f8f0242fae89a17c783977cf13694 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Thu, 6 Nov 2025 11:27:11 +0000 Subject: [PATCH 51/63] Unify cuda and rocm loops --- .../embedding_backward_split_indice_weights_template.cu | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 57c6804e66..9ffaea3a67 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -213,9 +213,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} - {%- if is_rocm %} int32_t j = 0; - {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only if (placement != PlacementType::MANAGED_CACHING) { for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { @@ -333,11 +332,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void } } } - {%- endif %}{#-/* if not ssd and not dense and not use_vec_blocking and not vbe */#} + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} for (; j < kWarpSize && l_start + j < L; ++j) { - {%- else %}{#-/* if is_rocm*/#} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - {%- endif %} // if is_rocm const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From c6b0a8827fafc860f4c4b59db951749e1b0b71df Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Tue, 11 Nov 2025 17:21:13 +0000 Subject: [PATCH 52/63] Added BLOCK_SIZE_ROCM --- .../backward/embedding_backward_split_kernel_warp_template.cu | 2 +- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 7b3b5b653a..deffa8bfab 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -655,7 +655,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd cache_t, grad_t, index_t, - BLOCK_SIZE, + BLOCK_SIZE_ROCM, embedding_dim, segment_prefetch, segment_unroll, diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 5475f74ddd..38c1ac1ea4 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -36,7 +36,7 @@ typedef float floatx2_t __attribute__((ext_vector_type(2))); #define AMDGCN_BUFFER_RES_3 0x00027000 #define AMDGCN_WAVE_SIZE 64 #define THREADS_PER_ROW 64 -#define BLOCK_SIZE 256 +#define BLOCK_SIZE_ROCM 256 namespace fbgemm_gpu::rocm { template From 122a583418a4d9377a13710fd2c9c53a0bb36889 Mon Sep 17 00:00:00 2001 From: Li Li Date: 
Fri, 14 Nov 2025 20:36:46 +0000 Subject: [PATCH 53/63] revert the link to tbb --- cmake/modules/CppLibrary.cmake | 12 ------------ cmake/modules/GpuCppLibrary.cmake | 12 ------------ 2 files changed, 24 deletions(-) diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 388d3ac779..92a93a60b6 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,18 +168,6 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() - if(NOT TARGET TBB::tbb) - find_package(TBB QUIET) - endif() - if(TBB_FOUND) - target_link_libraries(${lib_name} PUBLIC TBB::tbb) - else() - find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) - if(TBB_LIB) - target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) - endif() - endif() - # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index e662848348..51c30df750 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,18 +302,6 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() - if(NOT TARGET TBB::tbb) - find_package(TBB QUIET) - endif() - if(TBB_FOUND) - list(APPEND library_dependencies TBB::tbb) - else() - find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) - if(TBB_LIB) - list(APPEND library_dependencies ${TBB_LIB}) - endif() - endif() - # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) From 0b05877828c0264e561eef20ff2a24671b6c8fb4 Mon Sep 17 00:00:00 2001 From: Wulley Date: Sun, 2 Nov 2025 03:11:09 +0000 Subject: [PATCH 54/63] hack param --- .../embedding_backward_split_kernel_warp_template.cu | 4 ++-- .../backward/embedding_backward_split_template.cu | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index deffa8bfab..426c4581a4 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,7 +32,7 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -546,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if is_optimized_hip_kernel_supported_mode_ori %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index f07ef5830e..9f81ac7100 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,7 +48,7 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) 
%} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -236,7 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if is_optimized_hip_kernel_supported_mode_ori %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -870,7 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if is_optimized_hip_kernel_supported_mode_ori %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1230,7 +1230,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if is_optimized_hip_kernel_supported_mode_ori %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); From 4f5c9ed82c5560d05640ae435faf5ae3023d36aa Mon Sep 17 00:00:00 2001 From: Wulley Date: Mon, 27 Oct 2025 06:36:55 +0000 Subject: [PATCH 55/63] support opt code_gen --- ...ing_backward_split_kernel_warp_template.cu | 339 ++++++++++++++++++ .../embedding_backward_split_template.cu | 209 ++++++++++- 2 files changed, 546 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 426c4581a4..043b1eccc7 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -41,6 +41,14 @@ not vbe and not ssd %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" @@ -341,6 +349,258 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( } } +{%- if is_optimized_hip_kernel_supported_mode %} +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t"}}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const 
pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { + {%- if not nobag %} + int32_t T = D_offsets.size(0) - 1; + {%- else %} + int32_t T = weights_offsets.size(0); + {%- endif %} + const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; + +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE + const unsigned int shfl_sync_mask = + ((1L << kThreadGroupSize) - 1) << + (threadIdx.y % (kWarpSize / kThreadGroupSize) * kThreadGroupSize); +#else + const unsigned int shfl_sync_mask = 0xffffffffu; +#endif + +#define BROADCAST(val, srcLane) __builtin_amdgcn_readlane(val,srcLane) + + constexpr int VEC_WIDTH = 4; + constexpr auto kIsInt8 = std::is_same::value; + + struct SharedMemory> smem; + const int32_t grad_sum_stride = max_D / VEC_WIDTH; + auto* smem_grad_sum = (kUseVecBlocking || kIsInt8) + ? smem.getPointer() + threadIdx.y * grad_sum_stride + : nullptr; + + constexpr int num_unroll = 32; + + auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); + + for (uint32_t out_run_id = start_run_id * num_unroll; out_run_id < num_run_id; out_run_id += gridDim.x * blockDim.y * num_unroll) { + auto stride = gridDim.x * blockDim.y; + auto num_valid_id = min(num_unroll, num_run_id - out_run_id); + auto is_valid = threadIdx.x < num_valid_id; + + int32_t s_segment_start = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x)] : -1; + int32_t s_segment_end = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x + 1)] : -1; + int64_t s_idx = is_valid? sorted_linear_indices_run[out_run_id + threadIdx.x] : -1; + + {%- if not nobag %} + uint32_t s_t_0 = is_valid? reinterpret_cast(&sorted_infos[0])[s_segment_start] : -1; + s_t_0 = s_t_0 >> info_B_num_bits; + {%- else %} + auto s_t_0 = is_valid? sorted_infos[s_segment_start] : -1; + s_t_0 = s_t_0 % T; + {%- endif %} + + int64_t s_hash_size = is_valid? hash_size_cumsum[s_t_0] : -1; + s_idx -= s_hash_size; + {%- if not nobag %} + int32_t s_D_offsets_0 = is_valid? D_offsets[s_t_0] : 0; + int32_t s_D_offsets_1 = is_valid? 
D_offsets[s_t_0 + 1] : 0; + auto s_D = s_D_offsets_1 - s_D_offsets_0; + {%- endif %} + + int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; + int64_t s_weights_offset = is_valid? weights_offsets[s_t_0] : 0; + int64_t s_momentum1_offset = is_valid? momentum1_offsets[s_t_0] : 0; + int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; + int32_t s_momentum1_placement = is_valid? momentum1_placements[s_t_0] : 0; + + at::acc_type* __restrict__ s_momentum1; + if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { + s_momentum1 = &momentum1_dev[s_momentum1_offset]; + } else { + s_momentum1 = &momentum1_uvm[s_momentum1_offset]; + } + + for (auto i = 0; i < num_valid_id; ++i) { + auto run_id = out_run_id + i; + auto t_0 = BROADCAST(s_t_0, i); + auto idx = BROADCAST(s_idx, i); + auto segment_start = BROADCAST(s_segment_start, i); + auto segment_end = BROADCAST(s_segment_end, i); + auto D = BROADCAST(s_D, i); + int32_t table_unique_indice_offset = BROADCAST(s_table_unique_indice_offset, i); + const int32_t SL = segment_end - segment_start; + + const int64_t weights_offset = SHFL_SYNC(s_weights_offset, i); + const auto weights_placement = static_cast(SHFL_SYNC(s_weights_placement, i)); + + const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + auto momentum1_val = momentum1[idx]; + + if (SL >= max_segment_length_per_warp) { + continue; + } + + // now, each segment corresponds to exactly one table `t` and row in + // that table (`idx`). Thus, we can hoist out some of the book-keeping. + + const int32_t SL_per_warp = div_round_up(SL, blockDim.y); + const int32_t sl_start = 0; + const int32_t sl_end = SL; + + Vec4TAcc grad_sum[kFixedMaxVecsPerThread]; + constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH; + const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth; + + compute_grad_sum_{{ kdesc }}< + grad_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_sum, + smem_grad_sum, + grad_output, + {%- if not nobag or is_index_select %} + D_offsets, + {%- endif %} + D, + T, + sorted_infos, + {%- if weighted %} + sorted_indice_weights, + {%- endif %} + {%- if not nobag and vbe %} + B_offsets, + row_output_offsets, + {%- endif %} + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + segment_start, + sl_start, + sl_end, + shfl_sync_mask, + num_vecs + ); + + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; + split_rowwise_adagrad_table_update_kernel< + emb_t, + cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_name + "_ph_t" }}, + {%- endfor %} + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + dev_weights, + uvm_weights, + lxu_cache_weights, + weights_placements, + weights_offsets, + sorted_{{ locs_or_addrs_tensor }}, + grad_sum, + smem_grad_sum, + smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) + stochastic_rounding, + stochastic_rounding_philox_args, + run_id, + use_uniq_cache_locations + ? 
(run_id - table_unique_indices_offsets[t_0]) + : segment_start, + D, + t_0, + idx, + {%- if is_gwd_kernel %} + global_weight_decay, + {%- elif has_global_weight_decay_support %} + {# /* cases where gwd is not enabled/supported */ #} + 1, // global_weight_decay + {%- endif %} + shfl_sync_mask, + max_vecs, + momentum1, momentum1_val, learning_rate, eps, weight_decay, weight_decay_mode, max_norm + ); + } + } +} +{%- endif %} + //////////////////////////////////////////////////////////////////////////////// // Explicit Template Instantiations @@ -455,6 +715,85 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row }} {%- endif %} ); + +{%- if is_optimized_hip_kernel_supported_mode %} + +template __global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 +< {{ emb_type }}, + {{ grad_type }}, + {{ cache_type }}, + {{ index_type }}, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_type_combo[ph_name].primitive_type }}, + {%- endfor %} + {{ kFixedMaxVecsPerThread }}, + {{ kThreadGroupSize }}, + {{ kUseVecBlocking }} +> ( + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, + pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | + 
replace_pta_namespace() | + replace_placeholder_types(ph_type_combo) | + join(",\n ") | + replace("cache_t", cache_type) + }} + {%- endif %} +); + +{%- endif %} {%- endmacro %} {%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 9f81ac7100..f25ac1b656 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -56,6 +56,14 @@ using namespace fbgemm_gpu; not is_gwd_kernel and not vbe and not ssd %} + +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} template < typename emb_t, @@ -307,6 +315,147 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ); {%- endif %} + +{%- if is_optimized_hip_kernel_supported_mode %} + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} // if optimizer != "none" + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32 long_run_ids, + const pta::PackedTensorAccessor32 num_long_run_ids, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- if optimizer == "none" %} + const int32_t max_D, + {%- endif %} + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const pta::PackedTensorAccessor32 
long_run_id_to_really_long_run_ids, + pta::PackedTensorAccessor32, 2, at::RestrictPtrTraits> temp_grad_accum, + pta::PackedTensorAccessor32 grad_accum_counter, + const int32_t max_segment_length_per_cta, + const bool use_deterministic_algorithms, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); +{%- endif %} + {% if is_index_select %} namespace index_select { {% else %} @@ -877,7 +1026,25 @@ Tensor {{ embedding_cuda_op }}( wdesc, vdesc, ) - %} + %} + {%- endif %} + + {%- if is_optimized_hip_kernel_supported_mode %} + {%- set hip_mixed_d_warp_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} + + {%- set hip_mixed_d_cta_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_cta_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} {%- endif %} AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { @@ -1029,6 +1196,10 @@ Tensor {{ embedding_cuda_op }}( {use_deterministic_algorithms ? 
0 : grad_accum_counter.numel(), max_D}, aligned_grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat)); + {%- if is_optimized_hip_kernel_supported_mode %} + const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); + {%- endif %} + DISPATCH_PLACEHOLDER_TYPES( {%- for ph_name in args.placeholder_tensor_names %} {{ ph_name + "_dev" }}.scalar_type(), @@ -1047,7 +1218,7 @@ Tensor {{ embedding_cuda_op }}( ) %} - const auto backward_cta_per_row_kernel = + auto backward_cta_per_row_kernel = {{ cta_kernel }} ; + + {%- if is_optimized_hip_kernel_supported_mode %} + if (use_hip_kernel && mixed_D) { + backward_cta_per_row_kernel = + {{ hip_mixed_d_cta_kernel }} + ; + } + {%- endif %} // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); @@ -1196,6 +1384,23 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + {%- if is_optimized_hip_kernel_supported_mode %} + if (use_hip_kernel && mixed_D) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + } + {%- endif %} + // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; From bec7db478f6c3cbfd6c9a7cc6793077c574995cf Mon Sep 17 00:00:00 2001 From: yadai Date: Wed, 6 Aug 2025 11:29:38 +0000 Subject: [PATCH 56/63] support subwarp --- ...plit_table_batched_embeddings_benchmark.py | 525 +++++++++++------- fbgemm_gpu/codegen/genscript/optimizers.py | 36 ++ ...ding_backward_split_kernel_cta_template.cu | 2 +- ...ing_backward_split_kernel_warp_template.cu | 114 ++-- .../embedding_backward_split_template.cu | 23 +- ...optimizer_split_device_kernel_template.cuh | 198 ++++++- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 251 ++++++++- 7 files changed, 882 insertions(+), 267 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..2d3755fe06 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -8,11 +8,13 @@ # pyre-strict +import gzip import logging import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional +import yaml import click import numpy as np @@ -1011,7 +1013,31 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--batch-size", default=512) +@click.option("--embedding-dim-list", type=str, default="128") +@click.option("--weights-precision", type=SparseType, default=SparseType.FP32) +@click.option("--cache-precision", type=SparseType, default=None) +@click.option("--stoc", is_flag=True, default=False) +@click.option("--iters", default=100) +@click.option("--warmup-runs", default=0) +@click.option("--managed", default="device") +@click.option("--num-embeddings-list", type=str, default="100000") +@click.option("--reuse", default=0.0) +@click.option("--row-wise/--no-row-wise", default=True) +@click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") +@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) +@click.option("--flush-gpu-cache-size-mb", default=0) +@click.option("--output-dtype", 
type=SparseType, default=SparseType.FP32) +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1057,40 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] + np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1099,12 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1183,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1130,52 +1211,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
- sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests + assert len(requests) == iters + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: @@ -1201,36 +1298,44 @@ def device_with_spec( # noqa C901 f"Accessed weights per batch: {B * sum_DLs * param_size_multiplier / 1.0e9: .2f} GB" ) + if load is None and save is None: # forward - time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - 
per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) - logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + time_per_iter = benchmark_requests( + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) + logging.info( + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") + # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1349,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " @@ -1256,19 +1367,19 @@ def device_with_spec( # noqa C901 @click.option( "--batch-size-list", type=str, - required=True, + required=False, help="A comma separated list of batch sizes (B) for each table.", ) @click.option( "--embedding-dim-list", type=str, - required=True, + required=False, help="A comma separated list of embedding dimensions (D) for each table.", ) @click.option( "--bag-size-list", type=str, - required=True, + required=False, help="A comma separated list of bag sizes (L) for each table.", ) @click.option( @@ -1281,7 +1392,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-embeddings-list", type=str, - required=True, + required=False, help="A comma separated list of number of embeddings (E) for each table.", ) @click.option( @@ -1294,7 +1405,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-tables", type=int, - required=True, + required=False, help="The number of tables.", ) @click.option( @@ -1303,16 +1414,12 @@ def device_with_spec( # noqa C901 default=False, help="Whether the table is weighted or not", ) -@click.option( - "--print-kernel-summary", - is_flag=True, - default=False, - help="Whether the table is weighted or not", -) -@click.option("--ssd", is_flag=True, default=False) -@click.option( - "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix" -) +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, 
default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) @TBEBenchmarkingConfigLoader.options @EmbeddingOpsCommonConfigLoader.options @click.pass_context @@ -1326,9 +1433,12 @@ def vbe( alpha_list: str, num_tables: int, weighted: bool, - print_kernel_summary: bool, - ssd: bool, - ssd_prefix: str, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, # pyre-ignore[2] **kwargs, ) -> None: @@ -1340,6 +1450,28 @@ def vbe( np.random.seed(42) torch.manual_seed(42) + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(context.params, f, sort_keys=False) + + if load: + with open(f"{load}/params.yaml", "r") as f: + context.params = yaml.load(f, Loader=yaml.UnsafeLoader) + params = context.params + batch_size_list = params["batch_size_list"] + embedding_dim_list = params["embedding_dim_list"] + bag_size_list = params["bag_size_list"] + bag_size_sigma_list = params["bag_size_sigma_list"] + num_embeddings_list = params["num_embeddings_list"] + alpha_list = params["alpha_list"] + num_tables = params["num_tables"] + weighted = params["weighted"] + random_weights = params["random_weights"] + compressed = params["compressed"] + slice_min = params["slice_min"] + slice_max = params["slice_max"] + # Load general TBE benchmarking configuration from cli arguments benchconfig = TBEBenchmarkingConfigLoader.load(context) if benchconfig.num_requests != benchconfig.iterations: @@ -1348,6 +1480,9 @@ def vbe( if benchconfig.flush_gpu_cache_size_mb != 0: raise ValueError("--bench-flush-gpu-cache-size is not supported.") + if benchconfig.export_trace: + raise ValueError("--bench-export-trace is not supported.") + # Load common embedding op configuration from cli arguments embconfig = EmbeddingOpsCommonConfigLoader.load(context) if embconfig.uvm_host_mapped: @@ -1384,126 +1519,122 @@ def vbe( else EmbeddingLocation.HOST ) - common_split_args: dict[str, Any] = { - "weights_precision": embconfig.weights_dtype, - "stochastic_rounding": embconfig.stochastic_rounding, - "output_dtype": embconfig.output_dtype, - "pooling_mode": embconfig.pooling_mode, - "bounds_check_mode": embconfig.bounds_check_mode, - "optimizer": optimizer, - "learning_rate": 0.1, - "eps": 0.1, - "feature_table_map": list(range(T)), - } - - if ssd: - cache_set = max(T * max(Bs), 1) - tempdir = tempfile.mkdtemp(prefix=ssd_prefix) - emb = SSDTableBatchedEmbeddingBags( - [(E, D) for E, D in zip(Es, Ds)], - cache_sets=cache_set, - ssd_storage_directory=tempdir, - ssd_cache_location=EmbeddingLocation.DEVICE, - ssd_rocksdb_shards=8, - **common_split_args, - ) - else: - emb = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - E, - D, - managed_option, - get_available_compute_device(), - ) - for E, D in zip(Es, Ds) - ], - cache_precision=embconfig.cache_dtype, - **common_split_args, - ) - emb = emb.to(get_device()) - all_requests = { - "indices": [[] for _ in range(benchconfig.iterations)], - "offsets": [[] for _ in range(benchconfig.iterations)], - "weights": [[] for _ in range(benchconfig.iterations)], - } - for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): - # Generate a request for a single table. 
- local_requests = generate_requests( - benchconfig.iterations, - B, - 1, - L, - E, - alpha=alpha, - weighted=weighted, - sigma_L=sigma_L, - zipf_oversample_ratio=3 if L > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - - # Store requests for each table in all_requests. - for i, req in enumerate(local_requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - # pyre-ignore[53] - def _kineto_trace_handler( - p: profile, emb_op_type: str = "vbe", print_summary: bool = False - ) -> None: - p.export_chrome_trace( - benchconfig.trace_url.format(emb_op_type=emb_op_type, ospid=os.getpid()) - ) - if print_summary: - print(p.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + E, + D, + managed_option, + get_available_compute_device(), + ) + for E, D in zip(Es, Ds) + ], + optimizer=optimizer, + learning_rate=0.1, + eps=0.1, + cache_precision=embconfig.cache_dtype, + weights_precision=embconfig.weights_dtype, + stochastic_rounding=embconfig.stochastic_rounding, + output_dtype=embconfig.output_dtype, + pooling_mode=embconfig.pooling_mode, + bounds_check_mode=embconfig.bounds_check_mode, + ).to(get_device()) + + if random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") - emb_op_type = "vbe" + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) - # pyre-ignore[3, 53] - def context_factory(on_trace_ready: Callable[[profile], None]): - return ( - profile(on_trace_ready=on_trace_ready) - if benchconfig.export_trace - else nullcontext() - ) + if load: + requests = [] + for i in range(benchconfig.iterations): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + requests.append((indices, offsets, per_sample_weights)) + else: + all_requests = { + "indices": [[] for _ in range(benchconfig.iterations)], + "offsets": [[] for _ in range(benchconfig.iterations)], + "weights": [[] for _ in range(benchconfig.iterations)], + } + for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): + # Generate a request for a single table. + local_requests = generate_requests( + benchconfig.iterations, + B, + 1, + L, + E, + alpha=alpha, + weighted=weighted, + sigma_L=sigma_L, + zipf_oversample_ratio=3 if L > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) - # Combine the requests for all tables by - requests = [ - ( - torch.concat(all_requests["indices"][i]), - torch.concat(all_requests["offsets"][i]), - torch.concat(all_requests["weights"][i]) if weighted else None, - ) - for i in range(benchconfig.iterations) - ] + # Store requests for each table in all_requests. 
+ for i, req in enumerate(local_requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + # Combine the requests for all tables by + requests = [ + ( + torch.concat(all_requests["indices"][i]), + torch.concat(all_requests["offsets"][i]), + torch.concat(all_requests["weights"][i]) if weighted else None, + ) + for i in range(benchconfig.iterations) + ] + + del all_requests - del all_requests + if save: + for i, (indices, offsets, weights) in enumerate(requests): + torch.save(indices, f"{save}/{i}_indices.pt") + torch.save(offsets, f"{save}/{i}_offsets.pt") + torch.save(weights, f"{save}/{i}_per_sample_weights.pt") - with context_factory( - lambda p: _kineto_trace_handler(p, emb_op_type, print_kernel_summary) - ): - fwd_time_sec, bwd_time_sec = benchmark_vbe( - requests, - func=lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - batch_size_per_feature_per_rank=[[B] for B in Bs], - ), - num_warmups=benchconfig.warmup_iterations, - ) + fwd_time_sec, bwd_time_sec = benchmark_vbe( + requests, + func=lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=[[B] for B in Bs], + ), + num_warmups=benchconfig.warmup_iterations, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, + ) logging.info( f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n" f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us" ) - if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index c61e6843f9..8c25dc0d8f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]: at::acc_type multiplier = 0.0; at::acc_type correction = 0.0; + """ + split_precomputation_preload = split_precomputation + split_precomputation += """ if (threadIdx.x == 0) { auto new_sum_square_grads = g_avg_square; @@ -228,6 +231,38 @@ def rowwise_adagrad() -> Dict[str, Any]: multiplier = SHFL_SYNC(multiplier, 0); correction = SHFL_SYNC(correction, 0); """ + split_precomputation_preload += """ + if (threadIdx.x == 0) { + auto new_sum_square_grads = g_avg_square; + + // Update the optimizer state. 
Use optimizer state offloading only if + // SSD and if enabled by the user + if (enable_optimizer_offloading) { + // Fetch the pointer to the optimizer state along the cache row + auto* optimizer = weight_row_template.template optimizer_state_ptr(); + new_sum_square_grads += optimizer->momentum; + optimizer->momentum = new_sum_square_grads; + + } else { + new_sum_square_grads += momentum1_val; + momentum1[idx] = new_sum_square_grads; + } + + multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); + if (weight_decay_mode == 1) { + // L2 regularization + correction = 1.0 - multiplier * weight_decay; + } else if (weight_decay_mode == 2 || weight_decay_mode == 5) { + // Decoupled weight decay + correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; + } + } + multiplier = SHFL_SYNC(multiplier, 0); + correction = SHFL_SYNC(correction, 0); + """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; for (int64_t d = 0; d < D; ++d) { @@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]: }, ), "split_precomputation": split_precomputation, + "split_precomputation_preload": split_precomputation_preload, "split_weight_update": split_weight_update, "split_post_update": split_post_update, "split_weight_update_cpu": split_weight_update_cpu, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 25f7119a7a..b10eb1312e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -625,7 +625,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 043b1eccc7..1e1976fb3d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -426,6 +426,8 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; +#define SUBWARP_SHFL_SYNC(val, srcLane) __shfl_sync(UINT64_MAX, val, srcLane, kThreadGroupSize) + #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = ((1L << kThreadGroupSize) - 1) << @@ -445,7 +447,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc ? smem.getPointer() + threadIdx.y * grad_sum_stride : nullptr; - constexpr int num_unroll = 32; + constexpr int num_unroll = kThreadGroupSize; auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); @@ -476,39 +478,49 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; int64_t s_weights_offset = is_valid? weights_offsets[s_t_0] : 0; - int64_t s_momentum1_offset = is_valid? 
momentum1_offsets[s_t_0] : 0; int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; - int32_t s_momentum1_placement = is_valid? momentum1_placements[s_t_0] : 0; - at::acc_type* __restrict__ s_momentum1; - if (static_cast(s_momentum1_placement) == PlacementType::DEVICE) { - s_momentum1 = &momentum1_dev[s_momentum1_offset]; + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ s_{{ tensor }}; + const auto s_{{ tensor }}_placement = {{ tensor }}_placements[s_t_0]; + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[s_t_0]; + if (static_cast(s_{{ tensor }}_placement) == PlacementType::DEVICE) { + s_{{ tensor }} = &{{ tensor }}_dev[s_{{ tensor }}_offset]; } else { - s_momentum1 = &momentum1_uvm[s_momentum1_offset]; + s_{{ tensor }} = &{{ tensor }}_uvm[s_{{ tensor }}_offset]; } + {{ args.split_tensor_types[tensor] }} s_{{tensor}}_val = is_valid? s_{{tensor}}[s_idx] : 0; + + {%- endfor %} for (auto i = 0; i < num_valid_id; ++i) { - auto run_id = out_run_id + i; - auto t_0 = BROADCAST(s_t_0, i); - auto idx = BROADCAST(s_idx, i); - auto segment_start = BROADCAST(s_segment_start, i); - auto segment_end = BROADCAST(s_segment_end, i); - auto D = BROADCAST(s_D, i); - int32_t table_unique_indice_offset = BROADCAST(s_table_unique_indice_offset, i); + auto segment_start = SUBWARP_SHFL_SYNC(s_segment_start, i); + auto segment_end = SUBWARP_SHFL_SYNC(s_segment_end, i); const int32_t SL = segment_end - segment_start; - - const int64_t weights_offset = SHFL_SYNC(s_weights_offset, i); - const auto weights_placement = static_cast(SHFL_SYNC(s_weights_placement, i)); - - const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); - const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); - auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); - auto momentum1_val = momentum1[idx]; - if (SL >= max_segment_length_per_warp) { continue; } + auto run_id = out_run_id + i; + auto t_0 = SUBWARP_SHFL_SYNC(s_t_0, i); + auto idx = SUBWARP_SHFL_SYNC(s_idx, i); + + {%- if not nobag %} + auto D = SUBWARP_SHFL_SYNC(s_D, i); + {%- endif %} + int32_t table_unique_indice_offset = SUBWARP_SHFL_SYNC(s_table_unique_indice_offset, i); + + {%- for tensor in args.split_tensors %} + const auto {{ tensor }}_placement = SUBWARP_SHFL_SYNC(s_{{ tensor }}_placement, i); + const int64_t {{ tensor }}_offset = SUBWARP_SHFL_SYNC(s_{{ tensor }}_offset, i); + {{ args.split_tensor_types[tensor] }} {{tensor}}_val = SUBWARP_SHFL_SYNC(s_{{ tensor }}_val, i); + {%- endfor %} + + // const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + // const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + // auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + // auto momentum1_val = momentum1[idx]; + // now, each segment corresponds to exactly one table `t` and row in // that table (`idx`). Thus, we can hoist out some of the book-keeping. @@ -558,7 +570,11 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc // when kUseVecBlocking == false const int32_t max_vecs = kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; - split_rowwise_adagrad_table_update_kernel< + + {%- if not dense and optimizer != "none" %} + const int64_t weights_offset = SUBWARP_SHFL_SYNC(s_weights_offset, i); + const int32_t weights_placement = SUBWARP_SHFL_SYNC(s_weights_placement, i); + {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, {%- for ph_name in args.placeholder_tensor_names %} @@ -571,8 +587,8 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc dev_weights, uvm_weights, lxu_cache_weights, - weights_placements, - weights_offsets, + weights_placement, + weights_offset, sorted_{{ locs_or_addrs_tensor }}, grad_sum, smem_grad_sum, @@ -594,8 +610,42 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} shfl_sync_mask, max_vecs, - momentum1, momentum1_val, learning_rate, eps, weight_decay, weight_decay_mode, max_norm + {%- if ssd %} + enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + {{ tensor }}_placement, + {{ tensor }}_offset, + {{ tensor }}_val, + {%- endfor %} + {{ args.split_kernel_arg_names | join(", ") }} + ); + {%- else %} + // Write deduplicated gradient to grad_dev_weights gradient is sparse + // for split_embedding and dense for dense_embedding + {%- if dense %} + const int64_t weights_offset = weights_offsets[t_0]; + {%- else %} + // Compute offset of sparse gradient + const int64_t weights_offset = run_id * max_D; + idx = 0; + {%- endif %} + store_grad_sum< + emb_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_dev_weights, + grad_sum, + kUseVecBlocking ? smem_grad_sum : nullptr, + D, + weights_offset, + idx, + max_vecs ); + {%- endif %} // if not dense and optimizer != "none" } } } @@ -877,7 +927,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif @@ -1101,10 +1151,10 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} - {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} - {%- for cache_type in ['float', 'at::Half'] %} - {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} + {%- for emb_type in (['float', 'at::Half', 'at::BFloat16'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} + {%- for cache_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for index_type in ['int32_t', 'int64_t', 'at::BFloat16'] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index f25ac1b656..42aabae479 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1232,9 +1232,9 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; 
{%- if is_optimized_hip_kernel_supported_mode %} - if (use_hip_kernel && mixed_D) { + if (!kUseVecBlocking) { backward_cta_per_row_kernel = - {{ hip_mixed_d_cta_kernel }} + {{ cta_kernel }} ; + 1, + 32, + false>; } {%- endif %} @@ -1282,7 +1282,7 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(kThreadGroupSize, num_cta_per_row_groups), + dim3(32, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1385,7 +1385,8 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; {%- if is_optimized_hip_kernel_supported_mode %} - if (use_hip_kernel && mixed_D) { + if (!kUseVecBlocking) { + printf("%s:%d call here\n", __FILE__, __LINE__); backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} ; + 1, + 32, + false>; } {%- endif %} @@ -1429,6 +1430,7 @@ Tensor {{ embedding_cuda_op }}( } auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + // auto blockSize = dim3(32, num_warp_per_row_groups); int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), @@ -1470,7 +1472,6 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} #endif - FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index e4fb6c548c..ef1a011e1d 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -11,8 +11,42 @@ #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" -#define GROUP_REDUCE_ALL_SUM(val, ...) \ - warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + +template +DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { + static_assert(kThreadGroupSize == 8 || kThreadGroupSize == 16 || kThreadGroupSize == 32 || kThreadGroupSize == 64, "Wavefront size must be 16/32/64"); + if (kThreadGroupSize == 16) { + // Reduce across 4 groups of 16 threads + value += __shfl_xor(value, 1, 16); + value += __shfl_xor(value, 2, 16); + value += __shfl_xor(value, 4, 16); + value += __shfl_xor(value, 8, 16); + } else if (kThreadGroupSize == 32) { + // Reduce across 2 groups of 32 threads + value += __shfl_xor(value, 1, 32); + value += __shfl_xor(value, 2, 32); + value += __shfl_xor(value, 4, 32); + value += __shfl_xor(value, 8, 32); + value += __shfl_xor(value, 16, 32); + } else if (kThreadGroupSize == 64) { + value += __shfl_xor(value, 1, 64); + value += __shfl_xor(value, 2, 64); + value += __shfl_xor(value, 4, 64); + value += __shfl_xor(value, 8, 64); + value += __shfl_xor(value, 16, 64); + value += __shfl_xor(value, 32, 64); + } + return value; +} + +#define GROUP_REDUCE_ALL_SUM(val, ...) 
subwarp_reduce_add(val) {%- set mdesc = "ssd" if ssd else "split" %} {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} @@ -176,4 +210,164 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } +{%- if is_optimized_hip_kernel_supported_mode %} +template < + typename emb_t, + typename cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {%- set ph_type = "{}_ph_t".format(ph_name) %} + typename {{ ph_type }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( + pta::PackedTensorAccessor64& dev_weights, + pta::PackedTensorAccessor64& uvm_weights, + pta::PackedTensorAccessor64& lxu_cache_weights, + const int32_t weights_placement, + const int64_t weights_offset, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + Vec4TAcc* shared_weight_update_row, + const bool stochastic_rounding, + const at::PhiloxCudaState& stochastic_rounding_philox_args, + const uint32_t run_id, + const uint32_t cache_loc_run_id, + const int32_t D, + const int32_t t, + const int64_t idx, + {%- if has_global_weight_decay_support %} + const float global_weight_decay, + {%- endif %} + const uint32_t shfl_sync_mask, + const int32_t max_vecs_per_thread, + {%- if ssd %} + const bool enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + const int32_t {{ tensor }}_placement, + const int64_t {{ tensor }}_offset, + const int64_t {{ tensor }}_val, + {%- endfor %} + {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} +) { + constexpr auto kIsInt8 = std::is_same_v; + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; + emb_t* __restrict__ weights {nullptr}; + cache_t* __restrict__ cache_weights {nullptr}; + int32_t D_emb = D; + if constexpr (kIsInt8) { + D_emb += kINT8QparamsBytes; + } + if (static_cast(weights_placement) == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset + idx * D_emb]; + } else { + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset + idx * D_emb]" }}; + } + if (static_cast(weights_placement) == PlacementType::MANAGED_CACHING) { + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; + } + {%- endif %} + } + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + // const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + // const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if (static_cast({{ tensor }}_placement) == PlacementType::DEVICE) { + {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; + } else { + {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; + } + {%- endfor %} + + auto weight_row_template = + WeightRow>( + weights, + cache_weights, + D, + stochastic_rounding, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); + + float2 qparams_template; + if constexpr (kIsInt8) { + if (!cache_weights) { + qparams_template = weight_row_template.load_qparams(); + } + } + + {%- if not ssd %} + [[maybe_unused]] constexpr auto enable_optimizer_offloading = false; + {%- endif %} + + {{ split_precomputation_preload }} + + {# /* Note: technically, global weight decay (gwd) compensation should be done before + `split_precomputation`). But since decouple mode in `rowwise_adagrad` only computes correction, + the order of applying gwd does not matter. We perform gwd update before `split_weight_update` + below to minimize number of times to load weights. + So, note that the behavior may be different if you want to enable gwd for other optimizers + such as `lamb` or `partial_rowwise_lamb`. 
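+     Also note: this preload variant does not read the split optimizer state
+     (e.g. momentum1) from global memory inside the per-row loop. The caller
+     gathers the per-row value together with its placement/offset up front and
+     broadcasts them via subwarp shuffles, and `split_precomputation_preload`
+     uses the preloaded value for the read (in the non-offloaded path), only
+     writing the updated state back through the split-tensor pointer.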
+ */#} + float2 qparams_new; + {{ + generate_optimized_grad_sum_loop_access( + """ + Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); + Vec4TAcc& grad = {grad_vec}; + {global_weight_decay_update} + {split_weight_update} + if (kIsInt8 && !cache_weights) { + shared_weight_update_row[d_vec] = weight_new; + } else { + // qparams_new not used if type is not int8 + weight_row_template.store(weight_new, d, qparams_new); + } + """, + other_formats={ + "split_weight_update": split_weight_update, + "global_weight_decay_update": "weight_new.mul_(global_weight_decay);" if has_global_weight_decay_support else "" + }, + ) + }} + + if constexpr (kIsInt8) { + if (!cache_weights) { + // Calculate new qparams after row update + qparams_new = thrust_find_qparams>( + shared_weight_update_row, D); + weight_row_template.store_qparams(qparams_new); + + // Fetch cached updated row from shared mem and quantize on-the-fly + // when saving to lowp embedding + for (int32_t vec = 0; + (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++vec) { + const auto d_vec = vec * kThreadGroupSize + threadIdx.x; + const int32_t d = d_vec * VEC_WIDTH; + weight_row_template.store( + shared_weight_update_row[d_vec], + d, + qparams_new); + } + } + } + + {{ split_post_update }} +} +{%- endif %} + // clang-format on diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 00b51bbbe0..fb6e0d97d7 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,6 +19,9 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen + +import copy logging.basicConfig(level=logging.DEBUG) @@ -248,35 +252,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() @@ -286,6 +298,94 @@ def benchmark_requests( # noqa: C901 start_events 
= [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out_ref = torch.load(f"{load}/{it}_fwd_grad_out.pt") + torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) + + print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) + print("FWD PASS") + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... 
", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) + print("PASS") + + for it in range(iters): req = requests[it % num_reqs] @@ -609,6 +709,12 @@ def benchmark_vbe( requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> tuple[float, float]: """ A benchmark function to return the average execution time in seconds of @@ -633,14 +739,16 @@ def benchmark_vbe( """ use_cuda = torch.cuda.is_available() + sliced = slice_min is not None and slice_max is not None + if not (load or save): # Warm-ups. - for _ in range(num_warmups): - # Warm-up using the first request as done in benchmark_requests - indices, offsets, weights = requests[0] - out = func(indices, offsets, weights) - grad = torch.rand_like(out) - out.backward(grad) + for _ in range(num_warmups): + # Warm-up using the first request as done in benchmark_requests + indices, offsets, weights = requests[0] + out = func(indices, offsets, weights) + grad = torch.rand_like(out) + out.backward(grad) iters = len(requests) if use_cuda: @@ -654,6 +762,101 @@ def benchmark_vbe( fwd_times_sec = [] bwd_times_sec = [] + if save and emb: + for it, req in enumerate(requests): + + indices, offsets, weights = req + out = func(indices, offsets, weights) + torch.cuda.synchronize() + + torch.save(out, f"{save}/{it}_fwd_out.pt") + + grad = torch.rand_like(out) + if compressed: + with gzip.open(f"{save}/{it}_grad.pt.gz", "wb") as f: + torch.save(grad, f) + else: + torch.save(grad, f"{save}/{it}_grad.pt") + + out.backward(grad) + torch.cuda.synchronize() + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it, req in enumerate(requests): + + indices, offsets, weights = req + out = func(indices, offsets, weights) + torch.cuda.synchronize() + + out_ref = torch.load(f"{load}/{it}_fwd_out.pt") + torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) + + print(f"[{it + 1}/{iters}] Forward output check... 
", end="", flush=True) + print("FWD PASS") + + if compressed: + with gzip.open(f"{load}/{it}_grad.pt.gz", "rb") as f: + grad = torch.load(f) + else: + grad = torch.load(f"{load}/{it}_grad.pt") + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) + print("PASS") + + for i, (indices, offsets, weights) in enumerate(requests): # forward if use_cuda: @@ -706,4 +909,4 @@ def benchmark_vbe( # pyre-ignore[61] bwd_time_sec = statistics.median(bwd_times_sec) - return fwd_time_sec, bwd_time_sec + return fwd_time_sec, bwd_time_sec \ No newline at end of file From a530e5c7baabed2e1848286749f1d0e7726683b0 Mon Sep 17 00:00:00 2001 From: Wulley Date: Tue, 28 Oct 2025 08:59:14 +0000 Subject: [PATCH 57/63] update subwarp kernel --- ...ing_backward_split_kernel_warp_template.cu | 1 + .../embedding_backward_split_template.cu | 49 ++++++++++++++----- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1e1976fb3d..b45202b244 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -46,6 +46,7 @@ not dense and not is_index_select and not is_gwd_kernel and + not nobag and not vbe and not ssd %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 42aabae479..107fdb085d 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -62,7 +62,8 @@ using namespace fbgemm_gpu; not dense and not is_index_select and not is_gwd_kernel and - not vbe and + not vbe and + not nobag and not ssd %} template < @@ -1231,8 +1232,10 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; {%- if is_optimized_hip_kernel_supported_mode %} - if (!kUseVecBlocking) { + auto cta_blockSize = dim3(kThreadGroupSize, 
num_cta_per_row_groups); + if (max_D <= 128) { backward_cta_per_row_kernel = {{ cta_kernel }} ; + + auto cta_blockSize = dim3(32, num_cta_per_row_groups); } + {%- else %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} // Compute shared memory size for cta_per_row @@ -1282,7 +1289,7 @@ Tensor {{ embedding_cuda_op }}( FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(32, num_cta_per_row_groups), + cta_blockSize, cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1384,9 +1391,10 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- if is_optimized_hip_kernel_supported_mode %} - if (!kUseVecBlocking) { - printf("%s:%d call here\n", __FILE__, __LINE__); + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + if (use_hip_kernel && mixed_D) { backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} ; + kFixedMaxVecsPerThread, + kThreadGroupSize, + kUseVecBlocking>; + if (max_D <= 128) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + + blockSize = dim3(32, num_warp_per_row_groups); + } } + {%- else %} + // Compute shared memory size for warp_per_row + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); {%- endif %} - // Compute shared memory size for warp_per_row {%- if is_rocm %} int32_t num_warp_per_row_groups; if (total_L/total_B > 1){ @@ -1414,6 +1440,7 @@ Tensor {{ embedding_cuda_op }}( {%- else %} int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- endif %} + int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1428,10 +1455,6 @@ Tensor {{ embedding_cuda_op }}( backward_warp_per_row_kernel, used_shared_bytes); } - - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - // auto blockSize = dim3(32, num_warp_per_row_groups); - int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); From f3054d92e28ee6fef3922fc8aec855a6c21fe6fa Mon Sep 17 00:00:00 2001 From: xzhu Date: Mon, 27 Oct 2025 03:02:34 +0000 Subject: [PATCH 58/63] grad sum kernel unroll improvement --- ..._backward_split_device_kernel_template.cuh | 144 +++++++++++++----- 1 file changed, 106 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..d58f67bcb0 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,98 @@ using namespace fbgemm_gpu; +// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) +#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); +#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); +#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i); +#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i); +#define IDX_WEIGHT(i, j) at::acc_type idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i); + +#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j); +#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j); +#define REPEAT_2(X, j) X(1, j); +#define REPEAT_1(X, j) // No additional variables needed 
for block size 1 + +#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n); +#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); +#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n); +#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1 + +// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1) +// if nobag and is_index_select +#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]); +// elif nobag +#define GRAD_VEC_N(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[l_j_##i][d]); +// elif vbe +#define GRAD_VEC_V(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]); +// else +#define GRAD_VEC(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d); + +// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1) +#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i); +// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1) +#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i); + +// Core macro: Process blocks of specified size (block_size = 8/4/2/1) +// Parameters: +// - block_size: Size of each block to process +// - unroll_count: Number of unroll iterations for the inner loop +#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \ + for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \ + {%- if nobag %} + int32_t l_j_0 = SHFL_SYNC(l, j); \ + REPEAT_##block_size(L, j) \ + {%- elif vbe %} + /* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \ + const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \ + /* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \ + REPEAT_##block_size(GRAD_OFFSET, j) \ + {%- else %} + int32_t b_j_0 = SHFL_SYNC(b, j); \ + REPEAT_##block_size(B, j) \ + int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \ + REPEAT_##block_size(D_START, j) \ + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \ + REPEAT_##block_size(IDX_WEIGHT, j) \ + {%- endif %} + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + \ + for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \ + const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \ + /* Generate block_size Vec4TAcc objects and accumulate them */ \ + Vec4TAcc grad_out_vec_0( \ + {%- if nobag and is_index_select %} + &grad_output[grad_offset + l_j_0 * grad_stride + d] \ + {%- elif nobag %} + &grad_output[l_j_0][d] \ + {%- elif vbe %} + &grad_output[0][grad_offset_j_0 + d] \ + {%- else %} + &grad_output[b_j_0][0] + D_start_j_0 + d \ + {%- endif %} + ); \ + {%- if nobag and is_index_select %} + REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \ + {%- elif nobag %} + REPEAT_##block_size(GRAD_VEC_N, d) \ + {%- elif vbe %} + REPEAT_##block_size(GRAD_VEC_V, d) \ + {%- else %} + REPEAT_##block_size(GRAD_VEC, d) \ + {%- endif %} + \ + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \ + REPEAT_##block_size(FMA_GRAD, vec) \ + {%- else %} + grad_sum[vec].add_(grad_out_vec_0); \ + 
REPEAT_##block_size(ADD_GRAD, vec) \ + {%- endif %} + } \ + } + {%- if gen_once %} {#- /* The kernels in this section will be generated only once for all TBE configs @@ -141,45 +233,21 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( ? sorted_indice_weights[segment_start + sl_j] : 0.0; {%- endif %} - for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) { - {%- if nobag %} - int32_t l_j = SHFL_SYNC(l, j); - {%- elif vbe %} - const auto grad_offset_j = SHFL_SYNC(grad_offset, j); - {%- else %} - int32_t b_j = SHFL_SYNC(b, j); - int32_t D_start_j = SHFL_SYNC(D_start, j); - {%- endif %} - - {%- if weighted %} - at::acc_type idx_weight_j = SHFL_SYNC(idx_weight, j); - {%- endif %} + int32_t j = 0; - {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} - - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { - const int32_t d = {{ d }}; - Vec4TAcc grad_out_vec( - {%- if nobag and is_index_select %} - // grad_output is 1d - &grad_output[grad_offset + l_j * grad_stride + d] - {%- elif nobag %} - &grad_output[l_j][d] - {%- elif vbe %} - &grad_output[0][grad_offset_j + d] - {%- else %} - &grad_output[b_j][0] + D_start_j + d - {%- endif %} // if nobag - ); - - {%- if weighted %} - grad_sum[vec].fma_(grad_out_vec, idx_weight_j); - {%- else %} - grad_sum[vec].add_(grad_out_vec); - {%- endif %} - } - } + // Process blocks of different sizes with loop unrolling + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} From ac9e798c609f4309771ccf8201ba5e0fe83d8195 Mon Sep 17 00:00:00 2001 From: yadai Date: Wed, 29 Oct 2025 08:29:02 +0000 Subject: [PATCH 59/63] fix performance issuse --- .../embedding_backward_split_template.cu | 67 +++++++++---------- ...optimizer_split_device_kernel_template.cuh | 2 +- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 107fdb085d..c6e2ff2083 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1232,7 +1232,22 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; + {% if is_rocm %} + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } + {%- else %} int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t 
work_group_size = kMaxThreads; + {%- endif %} {%- if is_optimized_hip_kernel_supported_mode %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); if (max_D <= 128) { @@ -1249,30 +1264,15 @@ Tensor {{ embedding_cuda_op }}( 32, false>; - auto cta_blockSize = dim3(32, num_cta_per_row_groups); + cta_blockSize = dim3(32, num_cta_per_row_groups); } {%- else %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} + // printf("%s:%d %d\n", __FILE__, __LINE__, num_cta_per_row_groups); // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - {% if is_rocm %} - int32_t total_L = indices.numel(); - int32_t num_cta_per_row_groups; - int32_t work_group_size; - if (total_L/total_B > 1) { - num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; - work_group_size = (kMaxThreads/4); - } - else { - num_cta_per_row_groups = kMaxThreads / kWarpSize; - work_group_size = kMaxThreads; - } - {%- else %} - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; - const int32_t work_group_size = kMaxThreads; - {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1391,9 +1391,20 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if is_rocm %} + int32_t num_warp_per_row_groups; + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + {%- else %} + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- endif %} auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + {%- if is_optimized_hip_kernel_supported_mode %} + // printf("%s:%d warp kernel %d %d %d\n", __FILE__, __LINE__, num_warp_per_row_groups, use_hip_kernel, mixed_D); if (use_hip_kernel && mixed_D) { backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} @@ -1420,27 +1431,11 @@ Tensor {{ embedding_cuda_op }}( 1, 32, false>; - blockSize = dim3(32, num_warp_per_row_groups); + // printf("%s:%d warp kernel %d\n", __FILE__, __LINE__, num_warp_per_row_groups); } } - {%- else %} - // Compute shared memory size for warp_per_row - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); {%- endif %} - - {%- if is_rocm %} - int32_t num_warp_per_row_groups; - if (total_L/total_B > 1){ - num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; - } - else{ - num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - } - {%- else %} - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; - {%- endif %} - int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index ef1a011e1d..514d8428b9 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -251,7 +251,7 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {%- for tensor in args.split_tensors %} const int32_t {{ tensor }}_placement, const int64_t {{ tensor }}_offset, - const int64_t {{ tensor }}_val, 
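+    // Use the split tensor's declared element type for the preloaded value;
+    // an int64_t parameter would truncate fractional optimizer state such as
+    // momentum1 on the way into the kernel.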
+ const {{ args.split_tensor_types[tensor] }} {{ tensor }}_val, {%- endfor %} {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} ) { From 97ef821f14569ea012d060b1165d1f4275787d2b Mon Sep 17 00:00:00 2001 From: Wulley Date: Sun, 2 Nov 2025 08:03:11 +0000 Subject: [PATCH 60/63] fix vbe opt not imply --- ...plit_table_batched_embeddings_benchmark.py | 527 +++++++----------- ..._backward_split_device_kernel_template.cuh | 8 +- ...ing_backward_split_kernel_warp_template.cu | 20 +- .../embedding_backward_split_template.cu | 38 +- ...optimizer_split_device_kernel_template.cuh | 16 +- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 249 +-------- 6 files changed, 259 insertions(+), 599 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 2d3755fe06..4dd8b3dbb3 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -8,13 +8,11 @@ # pyre-strict -import gzip import logging import os import tempfile from contextlib import nullcontext -from typing import Any, Callable, Dict, Optional -import yaml +from typing import Any, Callable, Optional import click import numpy as np @@ -1013,31 +1011,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options -@click.option("--batch-size", default=512) -@click.option("--embedding-dim-list", type=str, default="128") -@click.option("--weights-precision", type=SparseType, default=SparseType.FP32) -@click.option("--cache-precision", type=SparseType, default=None) -@click.option("--stoc", is_flag=True, default=False) -@click.option("--iters", default=100) -@click.option("--warmup-runs", default=0) -@click.option("--managed", default="device") -@click.option("--num-embeddings-list", type=str, default="100000") -@click.option("--reuse", default=0.0) -@click.option("--row-wise/--no-row-wise", default=True) -@click.option("--weighted", is_flag=True, default=False) -@click.option("--pooling", type=str, default="sum") -@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value) -@click.option("--flush-gpu-cache-size-mb", default=0) -@click.option("--output-dtype", type=SparseType, default=SparseType.FP32) -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) -@click.option("--slice-max", type=int, default=None) -@click.pass_context def device_with_spec( # noqa C901 - ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1057,40 +1031,7 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, ) -> None: - if load: - with open(f"{load}/params.yaml", "r") as f: - ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) - alpha = ctx.params["alpha"] - bag_size_list = ctx.params["bag_size_list"] - bag_size_sigma_list = ctx.params["bag_size_sigma_list"] - batch_size = ctx.params["batch_size"] - embedding_dim_list = ctx.params["embedding_dim_list"] - weights_precision = ctx.params["weights_precision"] - cache_precision = 
ctx.params["cache_precision"] - stoc = ctx.params["stoc"] - iters = ctx.params["iters"] - warmup_runs = ctx.params["warmup_runs"] - managed = ctx.params["managed"] - num_embeddings_list = ctx.params["num_embeddings_list"] - reuse = ctx.params["reuse"] - row_wise = ctx.params["row_wise"] - weighted = ctx.params["weighted"] - pooling = ctx.params["pooling"] - bounds_check_mode = ctx.params["bounds_check_mode"] - flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] - output_dtype = ctx.params["output_dtype"] - random_weights = ctx.params["random_weights"] - compressed = ctx.params["compressed"] - slice_min = ctx.params["slice_min"] - slice_max = ctx.params["slice_max"] - np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1099,12 +1040,6 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" - - params = ctx.params - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1183,22 +1118,6 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) - elif random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 @@ -1211,68 +1130,52 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - if load: - requests = [] - for i in range(iters): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") - requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) - else: - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
- sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. + sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + + del all_requests - del all_requests - assert len(requests) == iters - if save: - for i in range(iters): - req = requests[i] - torch.save(req.indices, f"{save}/{i}_indices.pt") - torch.save(req.offsets, f"{save}/{i}_offsets.pt") - torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") - torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: @@ -1298,44 +1201,36 @@ def device_with_spec( # noqa C901 f"Accessed weights per batch: {B * sum_DLs * param_size_multiplier / 1.0e9: .2f} GB" ) - if load is None and save is None: # forward - time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) - logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + time_per_iter = benchmark_requests( + 
requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) + logging.info( + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) if output_dtype == SparseType.INT8: # backward bench not representative return - if load: - grad_output = torch.load(f"{load}/grad_output.pt") + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) else: - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) - else: - # Obtain B * L from indices len - # pyre-ignore[19] - # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) - - if save: - torch.save(grad_output, f"{save}/grad_output.pt") - + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. + grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) # backward time_per_iter = benchmark_requests( requests, @@ -1349,12 +1244,6 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " @@ -1367,19 +1256,19 @@ def device_with_spec( # noqa C901 @click.option( "--batch-size-list", type=str, - required=False, + required=True, help="A comma separated list of batch sizes (B) for each table.", ) @click.option( "--embedding-dim-list", type=str, - required=False, + required=True, help="A comma separated list of embedding dimensions (D) for each table.", ) @click.option( "--bag-size-list", type=str, - required=False, + required=True, help="A comma separated list of bag sizes (L) for each table.", ) @click.option( @@ -1392,7 +1281,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-embeddings-list", type=str, - required=False, + required=True, help="A comma separated list of number of embeddings (E) for each table.", ) @click.option( @@ -1405,7 +1294,7 @@ def device_with_spec( # noqa C901 @click.option( "--num-tables", type=int, - required=False, + required=True, help="The number of tables.", ) @click.option( @@ -1414,12 +1303,16 @@ def device_with_spec( # noqa C901 default=False, help="Whether the table is weighted or not", ) -@click.option("--save", type=str, default=None) -@click.option("--load", type=str, default=None) -@click.option("--random-weights", is_flag=True, default=False) -@click.option("--compressed", is_flag=True, default=False) -@click.option("--slice-min", type=int, default=None) -@click.option("--slice-max", type=int, default=None) +@click.option( + "--print-kernel-summary", + is_flag=True, + default=False, + help="Whether the table is weighted or not", +) +@click.option("--ssd", is_flag=True, default=False) +@click.option( + "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix" +) @TBEBenchmarkingConfigLoader.options @EmbeddingOpsCommonConfigLoader.options @click.pass_context @@ -1433,12 +1326,9 @@ def vbe( alpha_list: str, num_tables: int, weighted: bool, - save: str, - load: str, - random_weights: bool, - compressed: bool, - slice_min: int, - slice_max: int, + 
print_kernel_summary: bool, + ssd: bool, + ssd_prefix: str, # pyre-ignore[2] **kwargs, ) -> None: @@ -1450,28 +1340,6 @@ def vbe( np.random.seed(42) torch.manual_seed(42) - if save: - os.makedirs(f"{save}", exist_ok=True) - with open(f"{save}/params.yaml", "w") as f: - yaml.dump(context.params, f, sort_keys=False) - - if load: - with open(f"{load}/params.yaml", "r") as f: - context.params = yaml.load(f, Loader=yaml.UnsafeLoader) - params = context.params - batch_size_list = params["batch_size_list"] - embedding_dim_list = params["embedding_dim_list"] - bag_size_list = params["bag_size_list"] - bag_size_sigma_list = params["bag_size_sigma_list"] - num_embeddings_list = params["num_embeddings_list"] - alpha_list = params["alpha_list"] - num_tables = params["num_tables"] - weighted = params["weighted"] - random_weights = params["random_weights"] - compressed = params["compressed"] - slice_min = params["slice_min"] - slice_max = params["slice_max"] - # Load general TBE benchmarking configuration from cli arguments benchconfig = TBEBenchmarkingConfigLoader.load(context) if benchconfig.num_requests != benchconfig.iterations: @@ -1480,9 +1348,6 @@ def vbe( if benchconfig.flush_gpu_cache_size_mb != 0: raise ValueError("--bench-flush-gpu-cache-size is not supported.") - if benchconfig.export_trace: - raise ValueError("--bench-export-trace is not supported.") - # Load common embedding op configuration from cli arguments embconfig = EmbeddingOpsCommonConfigLoader.load(context) if embconfig.uvm_host_mapped: @@ -1519,122 +1384,126 @@ def vbe( else EmbeddingLocation.HOST ) - emb = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - E, - D, - managed_option, - get_available_compute_device(), - ) - for E, D in zip(Es, Ds) - ], - optimizer=optimizer, - learning_rate=0.1, - eps=0.1, - cache_precision=embconfig.cache_dtype, - weights_precision=embconfig.weights_dtype, - stochastic_rounding=embconfig.stochastic_rounding, - output_dtype=embconfig.output_dtype, - pooling_mode=embconfig.pooling_mode, - bounds_check_mode=embconfig.bounds_check_mode, - ).to(get_device()) - - if random_weights: - emb.init_embedding_weights_uniform(-1.0, 1.0) - - if save: - if compressed: - with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/model_state.pth") - - if load: - if compressed: - with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: - emb.load_state_dict(torch.load(f)) - else: - emb.load_state_dict(torch.load(f"{load}/model_state.pth")) + common_split_args: dict[str, Any] = { + "weights_precision": embconfig.weights_dtype, + "stochastic_rounding": embconfig.stochastic_rounding, + "output_dtype": embconfig.output_dtype, + "pooling_mode": embconfig.pooling_mode, + "bounds_check_mode": embconfig.bounds_check_mode, + "optimizer": optimizer, + "learning_rate": 0.1, + "eps": 0.1, + "feature_table_map": list(range(T)), + } - if load: - requests = [] - for i in range(benchconfig.iterations): - indices = torch.load(f"{load}/{i}_indices.pt") - offsets = torch.load(f"{load}/{i}_offsets.pt") - per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") - requests.append((indices, offsets, per_sample_weights)) + if ssd: + cache_set = max(T * max(Bs), 1) + tempdir = tempfile.mkdtemp(prefix=ssd_prefix) + emb = SSDTableBatchedEmbeddingBags( + [(E, D) for E, D in zip(Es, Ds)], + cache_sets=cache_set, + ssd_storage_directory=tempdir, + ssd_cache_location=EmbeddingLocation.DEVICE, + ssd_rocksdb_shards=8, + **common_split_args, + ) else: - 
all_requests = { - "indices": [[] for _ in range(benchconfig.iterations)], - "offsets": [[] for _ in range(benchconfig.iterations)], - "weights": [[] for _ in range(benchconfig.iterations)], - } - for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): - # Generate a request for a single table. - local_requests = generate_requests( - benchconfig.iterations, - B, - 1, - L, - E, - alpha=alpha, - weighted=weighted, - sigma_L=sigma_L, - zipf_oversample_ratio=3 if L > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + E, + D, + managed_option, + get_available_compute_device(), + ) + for E, D in zip(Es, Ds) + ], + cache_precision=embconfig.cache_dtype, + **common_split_args, + ) + emb = emb.to(get_device()) + all_requests = { + "indices": [[] for _ in range(benchconfig.iterations)], + "offsets": [[] for _ in range(benchconfig.iterations)], + "weights": [[] for _ in range(benchconfig.iterations)], + } + for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)): + # Generate a request for a single table. + local_requests = generate_requests( + benchconfig.iterations, + B, + 1, + L, + E, + alpha=alpha, + weighted=weighted, + sigma_L=sigma_L, + zipf_oversample_ratio=3 if L > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) - # Store requests for each table in all_requests. - for i, req in enumerate(local_requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - # Combine the requests for all tables by - requests = [ - ( - torch.concat(all_requests["indices"][i]), - torch.concat(all_requests["offsets"][i]), - torch.concat(all_requests["weights"][i]) if weighted else None, - ) - for i in range(benchconfig.iterations) - ] - - del all_requests + # Store requests for each table in all_requests. 
+ for i, req in enumerate(local_requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) - if save: - for i, (indices, offsets, weights) in enumerate(requests): - torch.save(indices, f"{save}/{i}_indices.pt") - torch.save(offsets, f"{save}/{i}_offsets.pt") - torch.save(weights, f"{save}/{i}_per_sample_weights.pt") + # pyre-ignore[53] + def _kineto_trace_handler( + p: profile, emb_op_type: str = "vbe", print_summary: bool = False + ) -> None: + p.export_chrome_trace( + benchconfig.trace_url.format(emb_op_type=emb_op_type, ospid=os.getpid()) + ) + if print_summary: + print(p.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - fwd_time_sec, bwd_time_sec = benchmark_vbe( - requests, - func=lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - batch_size_per_feature_per_rank=[[B] for B in Bs], - ), - num_warmups=benchconfig.warmup_iterations, - emb=emb, - save=save, - load=load, - compressed=compressed, - slice_min=slice_min, - slice_max=slice_max, - ) + emb_op_type = "vbe" + + # pyre-ignore[3, 53] + def context_factory(on_trace_ready: Callable[[profile], None]): + return ( + profile(on_trace_ready=on_trace_ready) + if benchconfig.export_trace + else nullcontext() + ) + + # Combine the requests for all tables by + requests = [ + ( + torch.concat(all_requests["indices"][i]), + torch.concat(all_requests["offsets"][i]), + torch.concat(all_requests["weights"][i]) if weighted else None, + ) + for i in range(benchconfig.iterations) + ] + + del all_requests + + with context_factory( + lambda p: _kineto_trace_handler(p, emb_op_type, print_kernel_summary) + ): + fwd_time_sec, bwd_time_sec = benchmark_vbe( + requests, + func=lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=[[B] for B in Bs], + ), + num_warmups=benchconfig.warmup_iterations, + ) logging.info( f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n" f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us" ) + if __name__ == "__main__": - cli() + cli() \ No newline at end of file diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index d58f67bcb0..6e25c40f10 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -236,9 +236,11 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( int32_t j = 0; // Process blocks of different sizes with loop unrolling - #pragma unroll kFixedMaxVecsPerThread - PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ - vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + if constexpr (sizeof(grad_t) <= 2) { + #pragma unroll kFixedMaxVecsPerThread + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + } #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) 
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index b45202b244..50e2477a1c 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -41,14 +41,13 @@ not vbe and not ssd %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and - not dense and - not is_index_select and - not is_gwd_kernel and - not nobag and - not vbe and - not ssd %} +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -350,7 +349,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( } } -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template < typename emb_t, typename grad_t, @@ -453,7 +452,6 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); for (uint32_t out_run_id = start_run_id * num_unroll; out_run_id < num_run_id; out_run_id += gridDim.x * blockDim.y * num_unroll) { - auto stride = gridDim.x * blockDim.y; auto num_valid_id = min(num_unroll, num_run_id - out_run_id); auto is_valid = threadIdx.x < num_valid_id; @@ -767,7 +765,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} ); -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template __global__ __launch_bounds__(kBackwardMaxThreads) void hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c6e2ff2083..f8b1a24cf1 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -57,14 +57,13 @@ using namespace fbgemm_gpu; not vbe and not ssd %} -{%- set is_optimized_hip_kernel_supported_mode = is_rocm and - optimizer == "rowwise_adagrad" and - not dense and - not is_index_select and - not is_gwd_kernel and - not vbe and - not nobag and - not ssd %} +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} template < typename emb_t, @@ -317,7 +316,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd ); {%- endif %} -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template < typename emb_t, @@ -1030,7 +1029,7 @@ Tensor {{ embedding_cuda_op }}( %} {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if enable_optimized_hip_mixed_D_kernel %} {%- set hip_mixed_d_warp_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1038,14 +1037,6 @@ Tensor {{ embedding_cuda_op }}( vdesc, ) %} - - {%- set hip_mixed_d_cta_kernel = 
"hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_cta_per_row_1".format( - ndesc, - optimizer, - wdesc, - vdesc, - ) - %} {%- endif %} AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { @@ -1197,7 +1188,7 @@ Tensor {{ embedding_cuda_op }}( {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, aligned_grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat)); - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if enable_optimized_hip_mixed_D_kernel %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); {%- endif %} @@ -1248,7 +1239,7 @@ Tensor {{ embedding_cuda_op }}( int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; const int32_t work_group_size = kMaxThreads; {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode %} + {%- if enable_optimized_hip_mixed_D_kernel %} auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); if (max_D <= 128) { backward_cta_per_row_kernel = @@ -1403,9 +1394,12 @@ Tensor {{ embedding_cuda_op }}( int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; {%- endif %} auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - {%- if is_optimized_hip_kernel_supported_mode %} - // printf("%s:%d warp kernel %d %d %d\n", __FILE__, __LINE__, num_warp_per_row_groups, use_hip_kernel, mixed_D); + {%- if enable_optimized_hip_mixed_D_kernel %} + {%- if vbe %} + if (use_hip_kernel) { + {%- else %} if (use_hip_kernel && mixed_D) { + {%- endif %} backward_warp_per_row_kernel = {{ hip_mixed_d_warp_kernel }} DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { @@ -210,7 +210,7 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } -{%- if is_optimized_hip_kernel_supported_mode %} +{%- if enable_optimized_hip_mixed_D_kernel %} template < typename emb_t, typename cache_t, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index fb6e0d97d7..f0ac6f1a70 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,7 +11,6 @@ import statistics import threading import time -import gzip from subprocess import Popen from typing import Callable, Optional @@ -19,9 +18,6 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device -from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen - -import copy logging.basicConfig(level=logging.DEBUG) @@ -252,43 +248,35 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - import copy + if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - if not (load or save): - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - 
bwd_only=bwd_only, - grad=grad, - ) + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() @@ -298,94 +286,6 @@ def benchmark_requests( # noqa: C901 start_events = [] end_events = [] - if save and emb: - for it in range(iters): - req = requests[it % num_reqs] - - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - if compressed: - with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: - torch.save(out, f) - else: - torch.save(out, f"{save}/{it}_fwd_grad_out.pt") - - out.backward(grad) - torch.cuda.synchronize() - torch.save(out, f"{save}/{it}_bwd_grad_out.pt") - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it in range(iters): - req = requests[it % num_reqs] - - indices, offsets, weights = req.unpack_3() - out = emb(indices, offsets, weights) - torch.cuda.synchronize() - - out_ref = torch.load(f"{load}/{it}_fwd_grad_out.pt") - torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) - - print(f"[{it + 1}/{iters}] Forward output check... ", end="", flush=True) - print("FWD PASS") - - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - - print(f"[{it + 1}/{iters}] Backward momentum check... 
", end="", flush=True) - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) - print("PASS") - - for it in range(iters): req = requests[it % num_reqs] @@ -709,12 +609,6 @@ def benchmark_vbe( requests: list[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor], num_warmups: int = 0, - emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, - save: Optional[str] = None, - load: Optional[str] = None, - compressed: bool = False, - slice_min: Optional[int] = None, - slice_max: Optional[int] = None, ) -> tuple[float, float]: """ A benchmark function to return the average execution time in seconds of @@ -739,16 +633,14 @@ def benchmark_vbe( """ use_cuda = torch.cuda.is_available() - sliced = slice_min is not None and slice_max is not None - if not (load or save): # Warm-ups. - for _ in range(num_warmups): - # Warm-up using the first request as done in benchmark_requests - indices, offsets, weights = requests[0] - out = func(indices, offsets, weights) - grad = torch.rand_like(out) - out.backward(grad) + for _ in range(num_warmups): + # Warm-up using the first request as done in benchmark_requests + indices, offsets, weights = requests[0] + out = func(indices, offsets, weights) + grad = torch.rand_like(out) + out.backward(grad) iters = len(requests) if use_cuda: @@ -762,101 +654,6 @@ def benchmark_vbe( fwd_times_sec = [] bwd_times_sec = [] - if save and emb: - for it, req in enumerate(requests): - - indices, offsets, weights = req - out = func(indices, offsets, weights) - torch.cuda.synchronize() - - torch.save(out, f"{save}/{it}_fwd_out.pt") - - grad = torch.rand_like(out) - if compressed: - with gzip.open(f"{save}/{it}_grad.pt.gz", "wb") as f: - torch.save(grad, f) - else: - torch.save(grad, f"{save}/{it}_grad.pt") - - out.backward(grad) - torch.cuda.synchronize() - - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: - torch.save(t[slice_min:slice_max,:].clone(), f) - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - else: - torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") - torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") - torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") - - else: - if compressed: - with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: - torch.save(emb.state_dict(), f) - else: - torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") - - if load and emb: - for it, req in enumerate(requests): - - indices, offsets, weights = req - out = func(indices, offsets, weights) - torch.cuda.synchronize() - - out_ref = torch.load(f"{load}/{it}_fwd_out.pt") - torch.testing.assert_close(out, out_ref, atol=1.0e-3, rtol=1.0e-3) - - print(f"[{it + 1}/{iters}] Forward output check... 
", end="", flush=True) - print("FWD PASS") - - if compressed: - with gzip.open(f"{load}/{it}_grad.pt.gz", "rb") as f: - grad = torch.load(f) - else: - grad = torch.load(f"{load}/{it}_grad.pt") - - out.backward(grad) - torch.cuda.synchronize() - emb_ref = copy.deepcopy(emb) - if not sliced: - if compressed: - with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: - emb_ref.load_state_dict(torch.load(f)) - else: - emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) - - print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) - if sliced: - for id, t in enumerate(emb.split_embedding_weights()): - if compressed: - with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: - w_ref = torch.load(f) - else: - w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") - torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - else: - for id, t in enumerate(emb.split_embedding_weights()): - torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], - msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) - print("PASS") - - print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) - if sliced: - m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") - m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") - else: - m_dev_ref = emb_ref.momentum1_dev - m_uvm_ref = emb_ref.momentum1_uvm - torch.testing.assert_close(emb.momentum1_dev, m_dev_ref) - torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref) - print("PASS") - - for i, (indices, offsets, weights) in enumerate(requests): # forward if use_cuda: From f19cb5d0223dedb84158c833b8b7893025b40e1d Mon Sep 17 00:00:00 2001 From: Wulley Date: Mon, 3 Nov 2025 03:08:43 +0000 Subject: [PATCH 61/63] fix smybol bug & rm comment --- .../embedding_backward_split_kernel_warp_template.cu | 12 ++++++------ .../backward/embedding_backward_split_template.cu | 10 ++++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 50e2477a1c..50bddcfeb1 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,7 +32,7 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} -{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -934,7 +934,7 @@ hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc {%- endif %} -{%- if is_optimized_hip_kernel_supported_mode_ori %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -1150,10 +1150,10 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} - {%- for emb_type in (['float', 'at::Half', 'at::BFloat16'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} - {%- for cache_type in ['float', 'at::Half', 'at::BFloat16'] %} 
- {%- for index_type in ['int32_t', 'int64_t', 'at::BFloat16'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} + {%- for cache_type in ['float', 'at::Half'] %} + {%- for index_type in ['int32_t', 'int64_t'] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index f8b1a24cf1..f29e32024c 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,7 +48,7 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} -{%- set is_optimized_hip_kernel_supported_mode_ori = is_rocm and +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and optimizer == "rowwise_adagrad" and not dense and not nobag and @@ -244,7 +244,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_optimized_hip_kernel_supported_mode_ori %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -1019,7 +1019,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_optimized_hip_kernel_supported_mode_ori %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1261,7 +1261,6 @@ Tensor {{ embedding_cuda_op }}( auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); {%- endif %} - // printf("%s:%d %d\n", __FILE__, __LINE__, num_cta_per_row_groups); // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( @@ -1426,7 +1425,6 @@ Tensor {{ embedding_cuda_op }}( 32, false>; blockSize = dim3(32, num_warp_per_row_groups); - // printf("%s:%d warp kernel %d\n", __FILE__, __LINE__, num_warp_per_row_groups); } } {%- endif %} @@ -1449,7 +1447,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_optimized_hip_kernel_supported_mode_ori %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); From bc733994242cad6545b005c0e5b346e4f747718f Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 13 Nov 2025 07:08:58 +0000 Subject: [PATCH 62/63] eliminate warning of process_block --- .../embedding_backward_split_device_kernel_template.cuh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 6e25c40f10..32d61bc1c8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -237,17 +237,13 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( // Process blocks of different sizes with loop unrolling if constexpr 
(sizeof(grad_t) <= 2) { - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) - #pragma unroll kFixedMaxVecsPerThread PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) } @@ -266,6 +262,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( } } +#undef PROCESS_BLOCK {%- endif %} // clang-format on From d4bfd1bf7d18fe53edd6d646b51a10d78a615ba4 Mon Sep 17 00:00:00 2001 From: Wulley Date: Thu, 13 Nov 2025 14:18:19 +0000 Subject: [PATCH 63/63] add rocm for macro --- ..._backward_split_device_kernel_template.cuh | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 32d61bc1c8..bb15b24f15 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,7 @@ using namespace fbgemm_gpu; +{%- if is_rocm %} // Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) #define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); #define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); @@ -105,6 +106,7 @@ using namespace fbgemm_gpu; {%- endif %} } \ } +{%- endif %} {%- if gen_once %} {#- /* @@ -235,6 +237,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} int32_t j = 0; + {%- if is_rocm %} // Process blocks of different sizes with loop unrolling if constexpr (sizeof(grad_t) <= 2) { PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ @@ -246,6 +249,50 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + +#undef PROCESS_BLOCK + + {%- else %} + for (; j < kThreadGroupSize && sl + j < sl_end; ++j) { + {%- if nobag %} + int32_t l_j = SHFL_SYNC(l, j); + {%- elif vbe %} + const auto grad_offset_j = SHFL_SYNC(grad_offset, j); + {%- else %} + int32_t b_j = SHFL_SYNC(b, j); + int32_t D_start_j = SHFL_SYNC(D_start, j); + {%- endif %} + + {%- if weighted %} + at::acc_type idx_weight_j = SHFL_SYNC(idx_weight, j); + {%- endif %} + + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { + const int32_t d = {{ d }}; + Vec4TAcc grad_out_vec( + {%- if nobag and is_index_select %} + // grad_output is 1d + &grad_output[grad_offset + l_j * grad_stride + d] + {%- elif nobag %} + &grad_output[l_j][d] + {%- elif vbe %} + &grad_output[0][grad_offset_j + d] + {%- else %} + &grad_output[b_j][0] + D_start_j + 
d + {%- endif %} // if nobag + ); + + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec, idx_weight_j); + {%- else %} + grad_sum[vec].add_(grad_out_vec); + {%- endif %} + } + } + {%- endif %} } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} @@ -262,7 +309,6 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( } } -#undef PROCESS_BLOCK {%- endif %} - // clang-format on + // clang-format on \ No newline at end of file
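
Note on the request-merging pattern used in the benchmark changes above: both `device_with_spec` and the reworked `vbe` path build one combined request per iteration by concatenating the per-table indices and CSR-style offsets, dropping each later table's leading zero offset and shifting it by the last offset accumulated so far. A minimal, self-contained sketch of that bookkeeping — hypothetical tensors only, independent of `TBERequest` and `generate_requests` — could look like this:

import torch

# Hypothetical per-table (indices, offsets) pairs in CSR form; each offsets
# tensor starts at 0 and has length B + 1 for a batch of B bags.
table_requests = [
    (torch.tensor([3, 7, 7, 1]), torch.tensor([0, 2, 4])),  # table 0, B = 2
    (torch.tensor([5, 9, 2]), torch.tensor([0, 1, 3])),     # table 1, B = 2
]

merged_indices = []
merged_offsets = []
for t, (indices, offsets) in enumerate(table_requests):
    if t > 0:
        # Drop the leading 0 and shift by the running total so the
        # concatenated offsets stay monotonically increasing.
        offsets = offsets[1:] + merged_offsets[-1][-1]
    merged_indices.append(indices)
    merged_offsets.append(offsets)

indices = torch.concat(merged_indices)  # tensor([3, 7, 7, 1, 5, 9, 2])
offsets = torch.concat(merged_offsets)  # tensor([0, 2, 4, 5, 7])

The shift is what keeps the combined offsets tensor a valid CSR boundary list for all T * B bags laid out back to back, which is the precondition for handing a single (indices, offsets) pair to the embedding op's forward call.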