 using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
+namespace {
+
+#ifdef USE_ROCM
+constexpr int kGroupIndexWarpSize = kWarpSize;
+#else
+constexpr int kGroupIndexWarpSize = kWarpSize;
+#endif
 
-// TODO: Update UNROLL_FACTOR
 constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
 constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
-    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
-
-// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * kGroupIndexWarpSize;
 constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
     log2_calc<GROUP_INDEX_SELECT_COLS_PER_WARP>::value;
 
-int get_group_index_select_cols_per_warp() {
-  return GROUP_INDEX_SELECT_COLS_PER_WARP;
-}
+#ifdef USE_ROCM
 
 template <
     typename index_t,
@@ -40,17 +42,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t* indices_ptrs,
     const int64_t* warp_offsets_group,
     const int32_t* num_cols_group,
-    const int64_t num_work_rows, // number of rows to work on per member
+    const int64_t num_work_rows,
     const int64_t group_size) {
   const auto total_num_warps = warp_offsets_group[group_size];
-  // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
   if (USE_INDEX_SELECT) {
     for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
          warp_id < total_num_warps;
          warp_id += gridDim.x * blockDim.y) {
       int32_t member_id, member_warp_id, num_cols, warps_per_row;
       if (USE_VAR_COLS) {
-        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
         if (threadIdx.x == 0) {
           binary_search_range(
               &member_ids[threadIdx.y],
@@ -64,7 +65,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_warp_id = warp_id - warp_offsets_group[member_id];
       } else {
-        // All columns are the same
         num_cols = num_cols_group[0];
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_id = warp_id / (warps_per_row * num_work_rows);
@@ -78,7 +78,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
           reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
       scalar_t* output =
           reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
       index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
       const index_t idx = indices[row];
 #pragma unroll
@@ -87,8 +86,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       }
     }
   } else {
-    // Cache a handful of scatter destinations per warp so we can merge
-    // consecutive updates that hit the same index before touching global memory.
     constexpr int kCacheSlots = 2;
     index_t cached_idx[kCacheSlots];
     scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
@@ -135,7 +132,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
          warp_id += gridDim.x * blockDim.y) {
       int32_t member_id, member_warp_id, num_cols, warps_per_row;
       if (USE_VAR_COLS) {
-        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
         if (threadIdx.x == 0) {
           binary_search_range(
               &member_ids[threadIdx.y],
@@ -149,7 +146,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_warp_id = warp_id - warp_offsets_group[member_id];
       } else {
-        // All columns are the same
         num_cols = num_cols_group[0];
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_id = warp_id / (warps_per_row * num_work_rows);
@@ -258,6 +254,88 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
   }
 }
 
+#else // !USE_ROCM
+
+template <
+    typename index_t,
+    typename scalar_t,
+    bool USE_INDEX_SELECT,
+    bool USE_VAR_COLS,
+    int UNROLL_FACTOR,
+    int COLS_PER_WARP,
+    int LOG_COLS_PER_WARP>
+__global__
+__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
+    const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
+    const int64_t* indices_ptrs,
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
+    const int64_t num_work_rows,
+    const int64_t group_size) {
+  const auto total_num_warps = warp_offsets_group[group_size];
+  int32_t num_cols = 0;
+  int32_t warps_per_row = 0;
+
+  if constexpr (!USE_VAR_COLS) {
+    num_cols = num_cols_group[0];
+    warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+  }
+
+  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+       warp_id < total_num_warps;
+       warp_id += gridDim.x * blockDim.y) {
+    int32_t member_id = 0;
+    int32_t member_warp_id = 0;
+    if constexpr (USE_VAR_COLS) {
+      __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
+      if (threadIdx.x == 0) {
+        binary_search_range(
+            &member_ids[threadIdx.y],
+            warp_offsets_group + 1,
+            warp_id,
+            group_size);
+      }
+      syncwarp();
+      member_id = member_ids[threadIdx.y];
+      num_cols = num_cols_group[member_id];
+      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+      member_warp_id = warp_id - warp_offsets_group[member_id];
+    } else {
+      member_id = warp_id / (warps_per_row * num_work_rows);
+      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+    }
+    const auto row = member_warp_id / warps_per_row;
+    const auto col_offset =
+        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+        (threadIdx.x * UNROLL_FACTOR);
+    scalar_t* input =
+        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+    scalar_t* output =
+        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+    const index_t idx = indices[row];
+#pragma unroll
+    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      if constexpr (USE_INDEX_SELECT) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+      } else {
+        gpuAtomicAddNoReturn(
+            &output[idx * num_cols + i], input[row * num_cols + i]);
+      }
+    }
+  }
+}
+
+#endif // USE_ROCM
+
+} // namespace
+
+int get_group_index_select_cols_per_warp() {
+  return GROUP_INDEX_SELECT_COLS_PER_WARP;
+}
+
 DLL_PUBLIC void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
     const int64_t* output_ptrs,
@@ -278,36 +356,15 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
 
   at::cuda::OptionalCUDAGuard device_guard(device);
 
-  // Partition work based on num_work_rows
-  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+  uint32_t num_warps_per_threadblock = kMaxThreads / kGroupIndexWarpSize;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
       cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
       max_grid_size);
-  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
-
-#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
-  FBGEMM_LAUNCH_KERNEL(                                                  \
-      (group_index_select_or_add_2d_kernel<                              \
-          index_t,                                                       \
-          scalar_t,                                                      \
-          USE_INDEX_SELECT,                                              \
-          USE_VAR_COLS,                                                  \
-          GROUP_INDEX_SELECT_UNROLL_FACTOR,                              \
-          GROUP_INDEX_SELECT_COLS_PER_WARP,                              \
-          GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>),                        \
-      grid_size,                                                         \
-      block_size,                                                        \
-      0,                                                                 \
-      at::cuda::getCurrentCUDAStream(),                                  \
-      input_ptrs,                                                        \
-      output_ptrs,                                                       \
-      indices_ptrs,                                                      \
-      warp_offsets_group,                                                \
-      num_cols_group,                                                    \
-      num_work_rows,                                                     \
-      group_size)
+  dim3 block_size(kGroupIndexWarpSize, num_warps_per_threadblock, 1);
+
+#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT_FLAG, USE_VAR_COLS_FLAG) FBGEMM_LAUNCH_KERNEL((group_index_select_or_add_2d_kernel<index_t, scalar_t, USE_INDEX_SELECT_FLAG, USE_VAR_COLS_FLAG, GROUP_INDEX_SELECT_UNROLL_FACTOR, GROUP_INDEX_SELECT_COLS_PER_WARP, GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), grid_size, block_size, 0, at::cuda::getCurrentCUDAStream(), input_ptrs, output_ptrs, indices_ptrs, warp_offsets_group, num_cols_group, num_work_rows, group_size)
 
   AT_DISPATCH_INDEX_TYPES(
       indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
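
For readers following the kernel's indexing math, here is a minimal host-side sketch (not part of the patch) of how a flattened warp_id is decomposed into (member, row, column) coordinates in the fixed-column (!USE_VAR_COLS) path. The concrete values (warp size 32, num_cols = 70, num_work_rows = 4, group_size = 3) and the standalone main() are assumptions chosen only for the walkthrough; the real kernel derives everything from its arguments and from GROUP_INDEX_SELECT_COLS_PER_WARP. The shift by the log constant is valid because COLS_PER_WARP is UNROLL_FACTOR (= 1) times the warp size, i.e. a power of two.

// Illustrative sketch only: mirrors, on the host, the warp_id ->
// (member_id, row, col_offset) arithmetic used by the kernel when all
// group members have the same number of columns.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example values (CUDA warp size 32).
  constexpr int kColsPerWarp = 32;   // GROUP_INDEX_SELECT_COLS_PER_WARP
  constexpr int kLogColsPerWarp = 5; // log2(kColsPerWarp)
  const int32_t num_cols = 70;       // columns per member (same for all)
  const int64_t num_work_rows = 4;   // rows each member works on
  const int64_t group_size = 3;

  // warps_per_row = ceil(num_cols / kColsPerWarp), done with a shift as in
  // the kernel.
  const int32_t warps_per_row =
      (num_cols + kColsPerWarp - 1) >> kLogColsPerWarp;

  const int64_t total_num_warps = group_size * warps_per_row * num_work_rows;
  for (int64_t warp_id = 0; warp_id < total_num_warps; ++warp_id) {
    // Which group member this warp belongs to, and its warp index within
    // that member's slice of the work.
    const int64_t member_id = warp_id / (warps_per_row * num_work_rows);
    const int64_t member_warp_id =
        warp_id - member_id * warps_per_row * num_work_rows;
    // Row handled by this warp and the first column handled by lane 0;
    // inside the kernel, lane x starts at col_offset + x * UNROLL_FACTOR.
    const int64_t row = member_warp_id / warps_per_row;
    const int64_t col_offset =
        (member_warp_id % warps_per_row) << kLogColsPerWarp;
    const int64_t col_end =
        std::min<int64_t>(col_offset + kColsPerWarp, num_cols);
    std::printf(
        "warp %2lld -> member %lld, row %lld, cols [%lld, %lld)\n",
        (long long)warp_id, (long long)member_id, (long long)row,
        (long long)col_offset, (long long)col_end);
  }
  return 0;
}

The variable-column path computes the same quantities, but first binary-searches warp_offsets_group to find which member a warp belongs to, since warps_per_row then differs per member.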