Commit 76b963d

added cuda implementation and rocm guards

1 parent a474843 commit 76b963d
File tree

1 file changed: +98 −41 lines changed

fbgemm_gpu/src/sparse_ops/sparse_group_index.cu

Lines changed: 98 additions & 41 deletions
@@ -11,19 +11,21 @@
 using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
+namespace {
+
+#ifdef USE_ROCM
+constexpr int kGroupIndexWarpSize = kWarpSize;
+#else
+constexpr int kGroupIndexWarpSize = kWarpSize;
+#endif
 
-// TODO: Update UNROLL_FACTOR
 constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
 constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
-    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
-
-// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * kGroupIndexWarpSize;
 constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
     log2_calc<GROUP_INDEX_SELECT_COLS_PER_WARP>::value;
 
-int get_group_index_select_cols_per_warp() {
-  return GROUP_INDEX_SELECT_COLS_PER_WARP;
-}
+#ifdef USE_ROCM
 
 template <
     typename index_t,
@@ -40,17 +42,16 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t* indices_ptrs,
     const int64_t* warp_offsets_group,
     const int32_t* num_cols_group,
-    const int64_t num_work_rows, // number of rows to work on per member
+    const int64_t num_work_rows,
     const int64_t group_size) {
   const auto total_num_warps = warp_offsets_group[group_size];
-  // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
   if (USE_INDEX_SELECT) {
     for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
          warp_id < total_num_warps;
          warp_id += gridDim.x * blockDim.y) {
       int32_t member_id, member_warp_id, num_cols, warps_per_row;
       if (USE_VAR_COLS) {
-        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
         if (threadIdx.x == 0) {
           binary_search_range(
               &member_ids[threadIdx.y],
@@ -64,7 +65,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_warp_id = warp_id - warp_offsets_group[member_id];
       } else {
-        // All columns are the same
         num_cols = num_cols_group[0];
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_id = warp_id / (warps_per_row * num_work_rows);
@@ -78,7 +78,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
           reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
       scalar_t* output =
           reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
       index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
       const index_t idx = indices[row];
 #pragma unroll
@@ -87,8 +86,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
       }
     }
   } else {
-    // Cache a handful of scatter destinations per warp so we can merge
-    // consecutive updates that hit the same index before touching global memory.
     constexpr int kCacheSlots = 2;
     index_t cached_idx[kCacheSlots];
     scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
@@ -135,7 +132,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
          warp_id += gridDim.x * blockDim.y) {
       int32_t member_id, member_warp_id, num_cols, warps_per_row;
       if (USE_VAR_COLS) {
-        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
         if (threadIdx.x == 0) {
           binary_search_range(
               &member_ids[threadIdx.y],
@@ -149,7 +146,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_warp_id = warp_id - warp_offsets_group[member_id];
       } else {
-        // All columns are the same
         num_cols = num_cols_group[0];
         warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
         member_id = warp_id / (warps_per_row * num_work_rows);
@@ -258,6 +254,88 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
   }
 }
 
+#else // !USE_ROCM
+
+template <
+    typename index_t,
+    typename scalar_t,
+    bool USE_INDEX_SELECT,
+    bool USE_VAR_COLS,
+    int UNROLL_FACTOR,
+    int COLS_PER_WARP,
+    int LOG_COLS_PER_WARP>
+__global__
+__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
+    const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
+    const int64_t* indices_ptrs,
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
+    const int64_t num_work_rows,
+    const int64_t group_size) {
+  const auto total_num_warps = warp_offsets_group[group_size];
+  int32_t num_cols = 0;
+  int32_t warps_per_row = 0;
+
+  if constexpr (!USE_VAR_COLS) {
+    num_cols = num_cols_group[0];
+    warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+  }
+
+  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+       warp_id < total_num_warps;
+       warp_id += gridDim.x * blockDim.y) {
+    int32_t member_id = 0;
+    int32_t member_warp_id = 0;
+    if constexpr (USE_VAR_COLS) {
+      __shared__ int member_ids[kMaxThreads / kGroupIndexWarpSize];
+      if (threadIdx.x == 0) {
+        binary_search_range(
+            &member_ids[threadIdx.y],
+            warp_offsets_group + 1,
+            warp_id,
+            group_size);
+      }
+      syncwarp();
+      member_id = member_ids[threadIdx.y];
+      num_cols = num_cols_group[member_id];
+      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+      member_warp_id = warp_id - warp_offsets_group[member_id];
+    } else {
+      member_id = warp_id / (warps_per_row * num_work_rows);
+      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+    }
+    const auto row = member_warp_id / warps_per_row;
+    const auto col_offset =
+        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+        (threadIdx.x * UNROLL_FACTOR);
+    scalar_t* input =
+        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+    scalar_t* output =
+        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+
+    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+    const index_t idx = indices[row];
+#pragma unroll
+    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      if constexpr (USE_INDEX_SELECT) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+      } else {
+        gpuAtomicAddNoReturn(
+            &output[idx * num_cols + i], input[row * num_cols + i]);
+      }
+    }
+  }
+}
+
+#endif // USE_ROCM
+
+} // namespace
+
+int get_group_index_select_cols_per_warp() {
+  return GROUP_INDEX_SELECT_COLS_PER_WARP;
+}
+
 DLL_PUBLIC void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
     const int64_t* output_ptrs,
@@ -278,36 +356,35 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
 
   at::cuda::OptionalCUDAGuard device_guard(device);
 
-  // Partition work based on num_work_rows
-  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
+  uint32_t num_warps_per_threadblock = kMaxThreads / kGroupIndexWarpSize;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
       cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
      max_grid_size);
-  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
-
-#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
-  FBGEMM_LAUNCH_KERNEL( \
-      (group_index_select_or_add_2d_kernel< \
-          index_t, \
-          scalar_t, \
-          USE_INDEX_SELECT, \
-          USE_VAR_COLS, \
-          GROUP_INDEX_SELECT_UNROLL_FACTOR, \
-          GROUP_INDEX_SELECT_COLS_PER_WARP, \
-          GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \
-      grid_size, \
-      block_size, \
-      0, \
-      at::cuda::getCurrentCUDAStream(), \
-      input_ptrs, \
-      output_ptrs, \
-      indices_ptrs, \
-      warp_offsets_group, \
-      num_cols_group, \
-      num_work_rows, \
-      group_size)
+  dim3 block_size(kGroupIndexWarpSize, num_warps_per_threadblock, 1);
+
+#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT_FLAG, USE_VAR_COLS_FLAG) \
+  FBGEMM_LAUNCH_KERNEL( \
+      (group_index_select_or_add_2d_kernel< \
+          index_t, \
+          scalar_t, \
+          USE_INDEX_SELECT_FLAG, \
+          USE_VAR_COLS_FLAG, \
+          GROUP_INDEX_SELECT_UNROLL_FACTOR, \
+          GROUP_INDEX_SELECT_COLS_PER_WARP, \
+          GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>), \
+      grid_size, \
+      block_size, \
+      0, \
+      at::cuda::getCurrentCUDAStream(), \
+      input_ptrs, \
+      output_ptrs, \
+      indices_ptrs, \
+      warp_offsets_group, \
+      num_cols_group, \
+      num_work_rows, \
+      group_size)
 
   AT_DISPATCH_INDEX_TYPES(
       indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
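
For reference, the index arithmetic shared by both kernel variants can be checked on the host. The sketch below is a minimal plain-C++ model, not part of the commit: the constants mirror the kernel's under the assumption that kWarpSize is 32 on CUDA and UNROLL_FACTOR is 1, and threadIdx_x plus the example sizes are illustrative stand-ins. It also shows why GROUP_INDEX_SELECT_COLS_PER_WARP needs to be a power of two: the kernel replaces ceil-division with an add-and-shift.

#include <cassert>
#include <cstdint>

// Stand-ins mirroring the kernel's constants (assumed CUDA case: warp size 32).
constexpr int kGroupIndexWarpSize = 32;
constexpr int UNROLL_FACTOR = 1;
constexpr int COLS_PER_WARP = UNROLL_FACTOR * kGroupIndexWarpSize; // 32
constexpr int LOG_COLS_PER_WARP = 5; // log2(32); valid only for power-of-two COLS_PER_WARP

int main() {
  // Ceil-division by a power of two via add-and-shift, as in the kernel:
  // warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP
  const int32_t num_cols = 100;
  const int32_t warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
  assert(warps_per_row == 4); // ceil(100 / 32)

  // Fixed-column case: recover (member_id, row, column offset) from a flat warp id.
  const int64_t num_work_rows = 8; // rows per group member (illustrative)
  const int64_t warp_id = 70;      // some warp visited by the grid-stride loop
  const int64_t member_id = warp_id / (warps_per_row * num_work_rows);
  const int64_t member_warp_id = warp_id - member_id * warps_per_row * num_work_rows;
  const int64_t row = member_warp_id / warps_per_row;
  // Each lane (threadIdx.x) covers UNROLL_FACTOR columns starting at:
  const int threadIdx_x = 3; // stand-in for threadIdx.x
  const int64_t col_offset =
      ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
      threadIdx_x * UNROLL_FACTOR;
  assert(member_id == 2 && row == 1 && col_offset == 67);
  return 0;
}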

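Both kernel variants also share the same work-distribution pattern: the warp_id loop is a grid-stride loop over warps rather than threads, where each (blockIdx.x, threadIdx.y) pair starts at a distinct warp id and strides by gridDim.x * blockDim.y. A small host-side simulation (all sizes illustrative, not from the commit) confirms the partition touches every warp id exactly once:

#include <cassert>
#include <cstdint>
#include <vector>

// Host-side simulation of the kernel's grid-stride loop over warps
// (illustrative stand-ins for gridDim.x and blockDim.y; not FBGEMM code).
int main() {
  const int64_t total_num_warps = 1000; // plays the role of warp_offsets_group[group_size]
  const int64_t grid_dim_x = 7;         // number of thread blocks
  const int64_t block_dim_y = 32;       // warps per block (kMaxThreads / warp size)

  std::vector<int> visits(total_num_warps, 0);
  for (int64_t block_idx_x = 0; block_idx_x < grid_dim_x; ++block_idx_x) {
    for (int64_t thread_idx_y = 0; thread_idx_y < block_dim_y; ++thread_idx_y) {
      // Exactly the loop header used in the kernel:
      for (int64_t warp_id = thread_idx_y * grid_dim_x + block_idx_x;
           warp_id < total_num_warps;
           warp_id += grid_dim_x * block_dim_y) {
        visits[warp_id]++;
      }
    }
  }
  for (int v : visits) {
    assert(v == 1); // every warp's worth of work is claimed exactly once
  }
  return 0;
}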