diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 82fa3e26a2..709e7b62f4 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -370,6 +370,7 @@ install_build_tools () { patchelf \ rhash \ scikit-build \ + tbb-devel \ tbb \ wheel \ xz \ diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..4dd8b3dbb3 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -1506,4 +1506,4 @@ def context_factory(on_trace_ready: Callable[[profile], None]): if __name__ == "__main__": - cli() + cli() \ No newline at end of file diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a5277a906a..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index c61e6843f9..8c25dc0d8f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]: at::acc_type multiplier = 0.0; at::acc_type correction = 0.0; + """ + split_precomputation_preload = split_precomputation + split_precomputation += """ if (threadIdx.x == 0) { auto new_sum_square_grads = g_avg_square; @@ -228,6 +231,38 @@ def rowwise_adagrad() -> Dict[str, Any]: multiplier = SHFL_SYNC(multiplier, 0); correction = SHFL_SYNC(correction, 0); """ + split_precomputation_preload += """ + if (threadIdx.x == 0) { + auto new_sum_square_grads = g_avg_square; + + // Update the optimizer state. 
Use optimizer state offloading only if + // SSD and if enabled by the user + if (enable_optimizer_offloading) { + // Fetch the pointer to the optimizer state along the cache row + auto* optimizer = weight_row_template.template optimizer_state_ptr(); + new_sum_square_grads += optimizer->momentum; + optimizer->momentum = new_sum_square_grads; + + } else { + new_sum_square_grads += momentum1_val; + momentum1[idx] = new_sum_square_grads; + } + + multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); + if (weight_decay_mode == 1) { + // L2 regularization + correction = 1.0 - multiplier * weight_decay; + } else if (weight_decay_mode == 2 || weight_decay_mode == 5) { + // Decoupled weight decay + correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; + } + } + multiplier = SHFL_SYNC(multiplier, 0); + correction = SHFL_SYNC(correction, 0); + """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; for (int64_t d = 0; d < D; ++d) { @@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]: }, ), "split_precomputation": split_precomputation, + "split_precomputation_preload": split_precomputation_preload, "split_weight_update": split_weight_update, "split_post_update": split_post_update, "split_weight_update_cpu": split_weight_update_cpu, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..bb15b24f15 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,100 @@ using namespace fbgemm_gpu; +{%- if is_rocm %} +// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) +#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); +#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); +#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i); +#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i); +#define IDX_WEIGHT(i, j) at::acc_type idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i); + +#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j); +#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j); +#define REPEAT_2(X, j) X(1, j); +#define REPEAT_1(X, j) // No additional variables needed for block size 1 + +#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n); +#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); +#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n); +#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1 + 
+// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1) +// if nobag and is_index_select +#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]); +// elif nobag +#define GRAD_VEC_N(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[l_j_##i][d]); +// elif vbe +#define GRAD_VEC_V(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]); +// else +#define GRAD_VEC(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d); + +// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1) +#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i); +// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1) +#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i); + +// Core macro: Process blocks of specified size (block_size = 8/4/2/1) +// Parameters: +// - block_size: Size of each block to process +// - unroll_count: Number of unroll iterations for the inner loop +#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \ + for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \ + {%- if nobag %} + int32_t l_j_0 = SHFL_SYNC(l, j); \ + REPEAT_##block_size(L, j) \ + {%- elif vbe %} + /* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \ + const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \ + /* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \ + REPEAT_##block_size(GRAD_OFFSET, j) \ + {%- else %} + int32_t b_j_0 = SHFL_SYNC(b, j); \ + REPEAT_##block_size(B, j) \ + int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \ + REPEAT_##block_size(D_START, j) \ + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \ + REPEAT_##block_size(IDX_WEIGHT, j) \ + {%- endif %} + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + \ + for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \ + const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \ + /* Generate block_size Vec4TAcc objects and accumulate them */ \ + Vec4TAcc grad_out_vec_0( \ + {%- if nobag and is_index_select %} + &grad_output[grad_offset + l_j_0 * grad_stride + d] \ + {%- elif nobag %} + &grad_output[l_j_0][d] \ + {%- elif vbe %} + &grad_output[0][grad_offset_j_0 + d] \ + {%- else %} + &grad_output[b_j_0][0] + D_start_j_0 + d \ + {%- endif %} + ); \ + {%- if nobag and is_index_select %} + REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \ + {%- elif nobag %} + REPEAT_##block_size(GRAD_VEC_N, d) \ + {%- elif vbe %} + REPEAT_##block_size(GRAD_VEC_V, d) \ + {%- else %} + REPEAT_##block_size(GRAD_VEC, d) \ + {%- endif %} + \ + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \ + REPEAT_##block_size(FMA_GRAD, vec) \ + {%- else %} + grad_sum[vec].add_(grad_out_vec_0); \ + REPEAT_##block_size(ADD_GRAD, vec) \ + {%- endif %} + } \ + } +{%- endif %} + {%- if gen_once %} {#- /* The kernels in this section will be generated only once for all TBE configs @@ -141,7 +235,25 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( ? 
sorted_indice_weights[segment_start + sl_j] : 0.0; {%- endif %} - for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) { + int32_t j = 0; + + {%- if is_rocm %} + // Process blocks of different sizes with loop unrolling + if constexpr (sizeof(grad_t) <= 2) { + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + } + PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + +#undef PROCESS_BLOCK + + {%- else %} + for (; j < kThreadGroupSize && sl + j < sl_end; ++j) { {%- if nobag %} int32_t l_j = SHFL_SYNC(l, j); {%- elif vbe %} @@ -180,6 +292,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} } } + {%- endif %} } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} @@ -198,4 +311,4 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} - // clang-format on + // clang-format on \ No newline at end of file diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 134a03b983..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; @@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9ffaea3a67 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} + int32_t j = 0; + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = 
shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, 
weight1, weight2, weight3; + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); @@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 25f7119a7a..b10eb1312e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -625,7 +625,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..50bddcfeb1 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,22 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -333,6 +349,307 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( } } +{%- if enable_optimized_hip_mixed_D_kernel %} +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t"}}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense 
%} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { + {%- if not nobag %} + int32_t T = D_offsets.size(0) - 1; + {%- else %} + int32_t T = weights_offsets.size(0); + {%- endif %} + const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; + +#define SUBWARP_SHFL_SYNC(val, srcLane) __shfl_sync(UINT64_MAX, val, srcLane, kThreadGroupSize) + +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE + const unsigned int shfl_sync_mask = + ((1L << kThreadGroupSize) - 1) << + (threadIdx.y % (kWarpSize / kThreadGroupSize) * kThreadGroupSize); +#else + const unsigned int shfl_sync_mask = 0xffffffffu; +#endif + +#define BROADCAST(val, srcLane) __builtin_amdgcn_readlane(val,srcLane) + + constexpr int VEC_WIDTH = 4; + constexpr auto kIsInt8 = std::is_same::value; + + struct SharedMemory> smem; + const int32_t grad_sum_stride = max_D / VEC_WIDTH; + auto* smem_grad_sum = (kUseVecBlocking || kIsInt8) + ? smem.getPointer() + threadIdx.y * grad_sum_stride + : nullptr; + + constexpr int num_unroll = kThreadGroupSize; + + auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); + + for (uint32_t out_run_id = start_run_id * num_unroll; out_run_id < num_run_id; out_run_id += gridDim.x * blockDim.y * num_unroll) { + auto num_valid_id = min(num_unroll, num_run_id - out_run_id); + auto is_valid = threadIdx.x < num_valid_id; + + int32_t s_segment_start = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x)] : -1; + int32_t s_segment_end = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x + 1)] : -1; + int64_t s_idx = is_valid? sorted_linear_indices_run[out_run_id + threadIdx.x] : -1; + + {%- if not nobag %} + uint32_t s_t_0 = is_valid? reinterpret_cast(&sorted_infos[0])[s_segment_start] : -1; + s_t_0 = s_t_0 >> info_B_num_bits; + {%- else %} + auto s_t_0 = is_valid? sorted_infos[s_segment_start] : -1; + s_t_0 = s_t_0 % T; + {%- endif %} + + int64_t s_hash_size = is_valid? hash_size_cumsum[s_t_0] : -1; + s_idx -= s_hash_size; + {%- if not nobag %} + int32_t s_D_offsets_0 = is_valid? D_offsets[s_t_0] : 0; + int32_t s_D_offsets_1 = is_valid? D_offsets[s_t_0 + 1] : 0; + auto s_D = s_D_offsets_1 - s_D_offsets_0; + {%- endif %} + + int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; + int64_t s_weights_offset = is_valid? 
weights_offsets[s_t_0] : 0; + int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; + + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ s_{{ tensor }}; + const auto s_{{ tensor }}_placement = {{ tensor }}_placements[s_t_0]; + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[s_t_0]; + if (static_cast(s_{{ tensor }}_placement) == PlacementType::DEVICE) { + s_{{ tensor }} = &{{ tensor }}_dev[s_{{ tensor }}_offset]; + } else { + s_{{ tensor }} = &{{ tensor }}_uvm[s_{{ tensor }}_offset]; + } + {{ args.split_tensor_types[tensor] }} s_{{tensor}}_val = is_valid? s_{{tensor}}[s_idx] : 0; + + {%- endfor %} + + for (auto i = 0; i < num_valid_id; ++i) { + auto segment_start = SUBWARP_SHFL_SYNC(s_segment_start, i); + auto segment_end = SUBWARP_SHFL_SYNC(s_segment_end, i); + const int32_t SL = segment_end - segment_start; + if (SL >= max_segment_length_per_warp) { + continue; + } + + auto run_id = out_run_id + i; + auto t_0 = SUBWARP_SHFL_SYNC(s_t_0, i); + auto idx = SUBWARP_SHFL_SYNC(s_idx, i); + + {%- if not nobag %} + auto D = SUBWARP_SHFL_SYNC(s_D, i); + {%- endif %} + int32_t table_unique_indice_offset = SUBWARP_SHFL_SYNC(s_table_unique_indice_offset, i); + + {%- for tensor in args.split_tensors %} + const auto {{ tensor }}_placement = SUBWARP_SHFL_SYNC(s_{{ tensor }}_placement, i); + const int64_t {{ tensor }}_offset = SUBWARP_SHFL_SYNC(s_{{ tensor }}_offset, i); + {{ args.split_tensor_types[tensor] }} {{tensor}}_val = SUBWARP_SHFL_SYNC(s_{{ tensor }}_val, i); + {%- endfor %} + + // const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + // const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + // auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + // auto momentum1_val = momentum1[idx]; + + // now, each segment corresponds to exactly one table `t` and row in + // that table (`idx`). Thus, we can hoist out some of the book-keeping. + + const int32_t SL_per_warp = div_round_up(SL, blockDim.y); + const int32_t sl_start = 0; + const int32_t sl_end = SL; + + Vec4TAcc grad_sum[kFixedMaxVecsPerThread]; + constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH; + const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth; + + compute_grad_sum_{{ kdesc }}< + grad_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_sum, + smem_grad_sum, + grad_output, + {%- if not nobag or is_index_select %} + D_offsets, + {%- endif %} + D, + T, + sorted_infos, + {%- if weighted %} + sorted_indice_weights, + {%- endif %} + {%- if not nobag and vbe %} + B_offsets, + row_output_offsets, + {%- endif %} + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + segment_start, + sl_start, + sl_end, + shfl_sync_mask, + num_vecs + ); + + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; + + {%- if not dense and optimizer != "none" %} + const int64_t weights_offset = SUBWARP_SHFL_SYNC(s_weights_offset, i); + const int32_t weights_placement = SUBWARP_SHFL_SYNC(s_weights_placement, i); + {{ mdesc }}_{{ optimizer }}_table_update_kernel< + emb_t, + cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_name + "_ph_t" }}, + {%- endfor %} + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + dev_weights, + uvm_weights, + lxu_cache_weights, + weights_placement, + weights_offset, + sorted_{{ locs_or_addrs_tensor }}, + grad_sum, + smem_grad_sum, + smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) + stochastic_rounding, + stochastic_rounding_philox_args, + run_id, + use_uniq_cache_locations + ? (run_id - table_unique_indices_offsets[t_0]) + : segment_start, + D, + t_0, + idx, + {%- if is_gwd_kernel %} + global_weight_decay, + {%- elif has_global_weight_decay_support %} + {# /* cases where gwd is not enabled/supported */ #} + 1, // global_weight_decay + {%- endif %} + shfl_sync_mask, + max_vecs, + {%- if ssd %} + enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + {{ tensor }}_placement, + {{ tensor }}_offset, + {{ tensor }}_val, + {%- endfor %} + {{ args.split_kernel_arg_names | join(", ") }} + ); + {%- else %} + // Write deduplicated gradient to grad_dev_weights gradient is sparse + // for split_embedding and dense for dense_embedding + {%- if dense %} + const int64_t weights_offset = weights_offsets[t_0]; + {%- else %} + // Compute offset of sparse gradient + const int64_t weights_offset = run_id * max_D; + idx = 0; + {%- endif %} + store_grad_sum< + emb_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_dev_weights, + grad_sum, + kUseVecBlocking ? 
smem_grad_sum : nullptr, + D, + weights_offset, + idx, + max_vecs + ); + {%- endif %} // if not dense and optimizer != "none" + } + } +} +{%- endif %} + //////////////////////////////////////////////////////////////////////////////// // Explicit Template Instantiations @@ -447,6 +764,85 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row }} {%- endif %} ); + +{%- if enable_optimized_hip_mixed_D_kernel %} + +template __global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 +< {{ emb_type }}, + {{ grad_type }}, + {{ cache_type }}, + {{ index_type }}, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_type_combo[ph_name].primitive_type }}, + {%- endfor %} + {{ kFixedMaxVecsPerThread }}, + {{ kThreadGroupSize }}, + {{ kUseVecBlocking }} +> ( + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, + pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | + replace_pta_namespace() | + replace_placeholder_types(ph_type_combo) | + join(",\n ") | + replace("cache_t", cache_type) + }} + {%- endif %} +); + +{%- endif %} {%- endmacro %} {%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} @@ -530,7 +926,7 @@ 
batch_index_select_dim0_codegen_backward_kernel_warp_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif @@ -538,7 +934,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +1008,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +1023,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,30 +1035,15 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, grad_t, index_t, - BLOCK_SIZE, + BLOCK_SIZE_ROCM, embedding_dim, segment_prefetch, segment_unroll, @@ -680,16 +1054,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} @@ -784,7 +1153,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 
index 76eba64c99..f29e32024c --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,23 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +244,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -299,6 +315,147 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ); {%- endif %} + +{%- if enable_optimized_hip_mixed_D_kernel %} + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} // if optimizer != "none" + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32 long_run_ids, + const pta::PackedTensorAccessor32 num_long_run_ids, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- if optimizer == "none" %} + const int32_t max_D, + {%- endif %} + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const 
pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const pta::PackedTensorAccessor32 long_run_id_to_really_long_run_ids, + pta::PackedTensorAccessor32, 2, at::RestrictPtrTraits> temp_grad_accum, + pta::PackedTensorAccessor32 grad_accum_counter, + const int32_t max_segment_length_per_cta, + const bool use_deterministic_algorithms, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); +{%- endif %} + {% if is_index_select %} namespace index_select { {% else %} @@ -652,6 +809,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -852,15 +1019,24 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, wdesc, vdesc, ) - %} + %} + {%- endif %} + + {%- if enable_optimized_hip_mixed_D_kernel %} + {%- set hip_mixed_d_warp_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} {%- endif %} AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { @@ -970,8 +1146,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1009,6 +1188,10 @@ Tensor {{ embedding_cuda_op }}( {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, aligned_grad_output.options().dtype(std::is_same::value ? 
at::kDouble : at::kFloat)); + {%- if enable_optimized_hip_mixed_D_kernel %} + const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); + {%- endif %} + DISPATCH_PLACEHOLDER_TYPES( {%- for ph_name in args.placeholder_tensor_names %} {{ ph_name + "_dev" }}.scalar_type(), @@ -1027,7 +1210,7 @@ Tensor {{ embedding_cuda_op }}( ) %} - const auto backward_cta_per_row_kernel = + auto backward_cta_per_row_kernel = {{ cta_kernel }} ; + + {% if is_rocm %} + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } + {%- else %} + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size = kMaxThreads; + {%- endif %} + {%- if enable_optimized_hip_mixed_D_kernel %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); + if (max_D <= 128) { + backward_cta_per_row_kernel = + {{ cta_kernel }} + ; + + cta_blockSize = dim3(32, num_cta_per_row_groups); + } + {%- else %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); + {%- endif %} // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,13 +1273,13 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(kThreadGroupSize, num_cta_per_row_groups), + cta_blockSize, cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1161,8 +1381,53 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; - // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- if is_rocm %} + int32_t num_warp_per_row_groups; + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + {%- else %} + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- endif %} + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + {%- if enable_optimized_hip_mixed_D_kernel %} + {%- if vbe %} + if (use_hip_kernel) { + {%- else %} + if (use_hip_kernel && mixed_D) { + {%- endif %} + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + if (max_D <= 128) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + blockSize = dim3(32, num_warp_per_row_groups); + } + } + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1177,26 +1442,22 @@ Tensor {{ embedding_cuda_op }}( backward_warp_per_row_kernel, used_shared_bytes); } - - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and 
optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { @@ -1221,7 +1482,6 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} #endif - FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -179,7 +166,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & 
length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; @@ -221,22 +208,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -244,24 +225,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,24 +261,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,22 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + 
table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -346,24 +313,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,24 +340,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,13 +373,10 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -435,13 +391,10 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -452,11 +405,11 @@ L_tail_grad_acc: } // load the old emb weight 
data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 37e774bb49..a3edb6b965 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -459,6 +463,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + {% if is_rocm %} + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + {%- endif %} + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index e4fb6c548c..b4c943f769 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -11,8 +11,42 @@ #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" -#define GROUP_REDUCE_ALL_SUM(val, ...) 
\ - warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} + +template +DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { + static_assert(kThreadGroupSize == 8 || kThreadGroupSize == 16 || kThreadGroupSize == 32 || kThreadGroupSize == 64, "Wavefront size must be 16/32/64"); + if (kThreadGroupSize == 16) { + // Reduce across 4 groups of 16 threads + value += __shfl_xor(value, 1, 16); + value += __shfl_xor(value, 2, 16); + value += __shfl_xor(value, 4, 16); + value += __shfl_xor(value, 8, 16); + } else if (kThreadGroupSize == 32) { + // Reduce across 2 groups of 32 threads + value += __shfl_xor(value, 1, 32); + value += __shfl_xor(value, 2, 32); + value += __shfl_xor(value, 4, 32); + value += __shfl_xor(value, 8, 32); + value += __shfl_xor(value, 16, 32); + } else if (kThreadGroupSize == 64) { + value += __shfl_xor(value, 1, 64); + value += __shfl_xor(value, 2, 64); + value += __shfl_xor(value, 4, 64); + value += __shfl_xor(value, 8, 64); + value += __shfl_xor(value, 16, 64); + value += __shfl_xor(value, 32, 64); + } + return value; +} + +#define GROUP_REDUCE_ALL_SUM(val, ...) subwarp_reduce_add(val) {%- set mdesc = "ssd" if ssd else "split" %} {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} @@ -176,4 +210,164 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } +{%- if enable_optimized_hip_mixed_D_kernel %} +template < + typename emb_t, + typename cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {%- set ph_type = "{}_ph_t".format(ph_name) %} + typename {{ ph_type }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( + pta::PackedTensorAccessor64& dev_weights, + pta::PackedTensorAccessor64& uvm_weights, + pta::PackedTensorAccessor64& lxu_cache_weights, + const int32_t weights_placement, + const int64_t weights_offset, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + Vec4TAcc* shared_weight_update_row, + const bool stochastic_rounding, + const at::PhiloxCudaState& stochastic_rounding_philox_args, + const uint32_t run_id, + const uint32_t cache_loc_run_id, + const int32_t D, + const int32_t t, + const int64_t idx, + {%- if has_global_weight_decay_support %} + const float global_weight_decay, + {%- endif %} + const uint32_t shfl_sync_mask, + const int32_t max_vecs_per_thread, + {%- if ssd %} + const bool enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + const int32_t {{ tensor }}_placement, + const int64_t {{ tensor }}_offset, + const {{ args.split_tensor_types[tensor] }} {{ tensor }}_val, + {%- endfor %} + {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} +) { + constexpr auto kIsInt8 = std::is_same_v; + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; + emb_t* __restrict__ weights {nullptr}; + cache_t* __restrict__ cache_weights {nullptr}; + int32_t D_emb = D; + if constexpr (kIsInt8) { + D_emb += kINT8QparamsBytes; + } + if (static_cast(weights_placement) == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset + idx * D_emb]; + } else { + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset + idx * D_emb]" }}; + } + if (static_cast(weights_placement) == PlacementType::MANAGED_CACHING) { + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; + } + {%- endif %} + } + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + // const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + // const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if (static_cast({{ tensor }}_placement) == PlacementType::DEVICE) { + {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; + } else { + {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; + } + {%- endfor %} + + auto weight_row_template = + WeightRow>( + weights, + cache_weights, + D, + stochastic_rounding, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); + + float2 qparams_template; + if constexpr (kIsInt8) { + if (!cache_weights) { + qparams_template = weight_row_template.load_qparams(); + } + } + + {%- if not ssd %} + [[maybe_unused]] constexpr auto enable_optimizer_offloading = false; + {%- endif %} + + {{ split_precomputation_preload }} + + {# /* Note: technically, global weight decay (gwd) compensation should be done before + `split_precomputation`). But since decouple mode in `rowwise_adagrad` only computes correction, + the order of applying gwd does not matter. We perform gwd update before `split_weight_update` + below to minimize number of times to load weights. + So, note that the behavior may be different if you want to enable gwd for other optimizers + such as `lamb` or `partial_rowwise_lamb`. 
+ */#} + float2 qparams_new; + {{ + generate_optimized_grad_sum_loop_access( + """ + Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); + Vec4TAcc& grad = {grad_vec}; + {global_weight_decay_update} + {split_weight_update} + if (kIsInt8 && !cache_weights) { + shared_weight_update_row[d_vec] = weight_new; + } else { + // qparams_new not used if type is not int8 + weight_row_template.store(weight_new, d, qparams_new); + } + """, + other_formats={ + "split_weight_update": split_weight_update, + "global_weight_decay_update": "weight_new.mul_(global_weight_decay);" if has_global_weight_decay_support else "" + }, + ) + }} + + if constexpr (kIsInt8) { + if (!cache_weights) { + // Calculate new qparams after row update + qparams_new = thrust_find_qparams>( + shared_weight_update_row, D); + weight_row_template.store_qparams(qparams_new); + + // Fetch cached updated row from shared mem and quantize on-the-fly + // when saving to lowp embedding + for (int32_t vec = 0; + (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++vec) { + const auto d_vec = vec * kThreadGroupSize + threadIdx.x; + const int32_t d = d_vec * VEC_WIDTH; + weight_row_template.store( + shared_weight_update_row[d_vec], + d, + qparams_new); + } + } + } + + {{ split_post_update }} +} +{%- endif %} + // clang-format on diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 0f4814fb41..a2304b3fb3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,6 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = static_cast(aux_bool[IDX_MIXED_D]); {%- endif %} // Default values for Dynamo tracing @@ -857,7 +858,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]); + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; @@ -1059,7 +1060,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + int32_t max_segment_length_per_warp = 64; + int32_t total_L = indices.numel(); + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 0a2a7ab0a1..4f1741b3dc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ 
b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -820,7 +820,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2517,6 +2517,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2535,6 +2536,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2544,6 +2546,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 00b51bbbe0..f0ac6f1a70 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -706,4 +706,4 @@ def benchmark_vbe( # pyre-ignore[61] bwd_time_sec = statistics.median(bwd_times_sec) - return fwd_time_sec, bwd_time_sec + return fwd_time_sec, bwd_time_sec \ No newline at end of file diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..38c1ac1ea4 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,9 +21,14 @@ * ******************************************************************************/ #pragma once + +#include #include +#include + #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -31,7 +36,7 @@ typedef float floatx2_t __attribute__((ext_vector_type(2))); #define AMDGCN_BUFFER_RES_3 0x00027000 #define AMDGCN_WAVE_SIZE 64 #define THREADS_PER_ROW 64 -#define BLOCK_SIZE 256 +#define BLOCK_SIZE_ROCM 256 namespace fbgemm_gpu::rocm { template @@ -46,10 +51,10 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr, const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -59,34 +64,70 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t soffset = 0, + int32_t 
glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -96,35 +137,15 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. 
Currently, the kernel dispatch for unsupported type is guarded on host side + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { - emb_data[i] = 0.f; + static_assert(false, "HIP: Optimized load operation is not supported yet"); } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); } - } - } }; template @@ -134,7 +155,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -145,7 +166,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); } }; @@ -154,15 +175,11 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(half) * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -173,9 +190,9 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -187,31 +204,133 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; template -struct load_row_per_warp { +struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); + } +}; + +template +struct 
load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 256) * sizeof(float)); } }; @@ -233,93 +352,156 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * 
row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host function + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert(false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); 
+ llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = 
amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[4], out_res, (lane_id + 256) * sizeof(float)); } }; @@ -471,7 +653,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index 0d65c4798a..d51e3fa475 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,19 @@ template <typename T, int ReduceWidth = kWarpSize> DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast<unsigned>(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + #ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); + #else + #pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); + } + return val; + #endif } DEVICE_INLINE void syncwarp() { diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755
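The GROUP_REDUCE_ALL_SUM, subwarp_reduce_add, and warpReduceAllSum changes above all rely on the same butterfly pattern: log2(width) XOR shuffles, after which every lane in the group holds the group-wide sum. Below is a minimal, standalone CUDA sketch of that pattern, for illustration only; group_all_reduce_sum and demo are hypothetical names rather than FBGEMM or HIP APIs, and the CUDA __shfl_xor_sync intrinsic stands in for the HIP __shfl_xor used in the patch.

#include <cstdio>
#include <cuda_runtime.h>

// Butterfly all-reduce: after log2(kGroupSize) XOR exchanges, every lane
// in the group holds the sum of the group's values.
template <int kGroupSize>
__device__ float group_all_reduce_sum(float v) {
#pragma unroll
  for (int mask = kGroupSize / 2; mask > 0; mask >>= 1) {
    v += __shfl_xor_sync(0xffffffffu, v, mask, kGroupSize);
  }
  return v;
}

__global__ void demo(const float* in, float* out) {
  const float sum = group_all_reduce_sum<32>(in[threadIdx.x]);
  if (threadIdx.x == 0) {
    *out = sum;  // every lane holds the same value; lane 0 writes it out
  }
}

int main() {
  float h_in[32], h_out = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected sum: 32
  cudaMalloc(reinterpret_cast<void**>(&d_in), sizeof(h_in));
  cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("group sum = %f\n", h_out);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}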