
Commit c43677d

q10 authored and facebook-github-bot committed
Migrate LFU cache populate kernels to FBGEMM_LAUNCH_KERNEL (#4805)
Summary:
X-link: facebookresearch/FBGEMM#1831
Pull Request resolved: #4805

- Migrate LFU cache populate kernels to `FBGEMM_LAUNCH_KERNEL`

Reviewed By: ionuthristodorescu

Differential Revision: D79974662

fbshipit-source-id: cc46d389e19ae796523621393d8021df3ca2b538
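For context on the pattern being migrated: the old code launched the kernel with the raw `<<<...>>>` syntax, built accessor arguments with `MAKE_PTA_WITH_NAME` (which needs a `func_name` string under `FBGEMM_GPU_MEMCHECK`), and followed the launch with `C10_CUDA_KERNEL_LAUNCH_CHECK()`. `FBGEMM_LAUNCH_KERNEL` folds this into a single call taking the kernel, grid size, block size, shared-memory bytes, stream, and then the kernel arguments, with `PTA_B(...)` as the shorter accessor builder; since the explicit launch check is deleted in the diff, the macro presumably performs it internally. A minimal sketch of the before/after shape, where `my_kernel`, `grid_size`, `block_size`, `arg_a`, and `arg_b` are placeholders rather than names from this file:

// Before: raw CUDA launch, followed by an explicit post-launch error check.
my_kernel<<<grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
    arg_a, arg_b);
C10_CUDA_KERNEL_LAUNCH_CHECK();

// After: one macro call carrying the same launch configuration; the
// argument order (kernel, grid, block, smem bytes, stream, kernel args)
// mirrors the diff below.
FBGEMM_LAUNCH_KERNEL(
    (my_kernel),
    grid_size,
    block_size,
    0,
    at::cuda::getCurrentCUDAStream(),
    arg_a,
    arg_b);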
1 parent d51741c commit c43677d

File tree

1 file changed: +19 −25 lines changed

fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu

Lines changed: 19 additions & 25 deletions
@@ -54,7 +54,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
         lfu_state,
     const int64_t row_alignment) {
   const int32_t C = lxu_cache_state.size(0);
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (uint32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
@@ -81,17 +81,17 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
 
     // now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
-    const int32_t slot = threadIdx.x;
+    const uint32_t slot = threadIdx.x;
     const int64_t current_idx = lxu_cache_state[cache_set][slot];
     const int64_t current_lfu_cost =
         (current_idx != static_cast<int64_t>(kCacheStateInvalid))
         ? lfu_state[current_idx]
         : -1;
     int64_t costs[1] = {current_lfu_cost};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};
 
-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const uint32_t sorted_slot = slots[0];
     const int64_t sorted_lfu_cost = costs[0];
 
     for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
@@ -126,7 +126,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
           &weights[weights_offset_insert + idx_insert * D_insert_bytes + 0]);
       auto cache_row = reinterpret_cast<uint4*>(
           &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0]);
-      for (int32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
+      for (uint32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
            d += blockDim.x) {
         cache_row[d] = row[d];
       }
@@ -173,33 +173,27 @@ void lfu_cache_insert_byte_cuda(
       cache_set_sorted_unique_indices.scalar_type(),
       "lfu_cache_insert_byte_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "lfu_cache_insert_byte_kernel";
-#endif
-        lfu_cache_insert_byte_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (lfu_cache_insert_byte_kernel<index_t>),
             std::min(
                 div_round_up(N, kCacheMaxThreads / kWarpSize),
                 get_max_thread_blocks_for_cache_kernels_()),
             dim3(kWarpSize, kCacheMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, weights, uint8_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_index_table_map, int32_t, 1, 64),
-            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(weights, uint8_t, 1, 64),
+            PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
+            PTA_B(cache_index_table_map, int32_t, 1, 64),
+            PTA_B(weights_offsets, int64_t, 1, 32),
+            PTA_B(weights_tys, uint8_t, 1, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
             (uint64_t*)sorted_cache_sets.data_ptr<int64_t>(),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_set_sorted_unique_indices, index_t, 1, 32),
+            PTA_B(cache_set_sorted_unique_indices, index_t, 1, 32),
             unique_indices_length.data_ptr<int32_t>(),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, uint8_t, 2, 64),
-            MAKE_PTA_WITH_NAME(func_name, lfu_state, int64_t, 1, 64),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_weights, uint8_t, 2, 64),
+            PTA_B(lfu_state, int64_t, 1, 64),
             row_alignment);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
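Aside from the launch-macro migration, the diff also flips several loop and slot indices from `int32_t` to `uint32_t`. CUDA's built-ins `threadIdx`, `blockIdx`, `blockDim`, and `gridDim` are unsigned, so unsigned counters avoid implicit signed/unsigned conversions in expressions like `blockIdx.x * blockDim.y + threadIdx.y`. A standalone grid-stride-loop sketch of that pattern (illustrative kernel, not from this file):

// Each thread starts at its global index and strides by the total thread
// count; uint32_t matches the unsigned CUDA built-ins used to compute it.
__global__ void scale_kernel(float* data, uint32_t n, float alpha) {
  for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] = alpha * data[i];
  }
}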
