diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
index 1fa91c519e..2f2b1e9eae 100644
--- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
+++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -54,7 +54,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
         lfu_state,
     const int64_t row_alignment) {
   const int32_t C = lxu_cache_state.size(0);
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (uint32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
@@ -81,17 +81,17 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
 
     // now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
-    const int32_t slot = threadIdx.x;
+    const uint32_t slot = threadIdx.x;
     const int64_t current_idx = lxu_cache_state[cache_set][slot];
     const int64_t current_lfu_cost =
         (current_idx != static_cast<int64_t>(kCacheStateInvalid))
         ? lfu_state[current_idx]
         : -1;
     int64_t costs[1] = {current_lfu_cost};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};
 
-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const uint32_t sorted_slot = slots[0];
     const int64_t sorted_lfu_cost = costs[0];
 
     for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
@@ -126,7 +126,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
           &weights[weights_offset_insert + idx_insert * D_insert_bytes + 0]);
       auto cache_row = reinterpret_cast<uint4*>(
           &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0]);
-      for (int32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
+      for (uint32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
            d += blockDim.x) {
         cache_row[d] = row[d];
       }
@@ -173,33 +173,27 @@ void lfu_cache_insert_byte_cuda(
       cache_set_sorted_unique_indices.scalar_type(),
       "lfu_cache_insert_byte_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "lfu_cache_insert_byte_kernel";
-#endif
-        lfu_cache_insert_byte_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (lfu_cache_insert_byte_kernel<index_t>),
             std::min(
                 div_round_up(N, kCacheMaxThreads / kWarpSize),
                 get_max_thread_blocks_for_cache_kernels_()),
             dim3(kWarpSize, kCacheMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, weights, uint8_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_index_table_map, int32_t, 1, 64),
-            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(weights, uint8_t, 1, 64),
+            PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
+            PTA_B(cache_index_table_map, int32_t, 1, 64),
+            PTA_B(weights_offsets, int64_t, 1, 32),
+            PTA_B(weights_tys, uint8_t, 1, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
             (uint64_t*)sorted_cache_sets.data_ptr<int64_t>(),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_set_sorted_unique_indices, index_t, 1, 32),
+            PTA_B(cache_set_sorted_unique_indices, index_t, 1, 32),
             unique_indices_length.data_ptr<int32_t>(),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, uint8_t, 2, 64),
-            MAKE_PTA_WITH_NAME(func_name, lfu_state, int64_t, 1, 64),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_weights, uint8_t, 2, 64),
+            PTA_B(lfu_state, int64_t, 1, 64),
            row_alignment);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
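
For context on the index-type changes above: the kernel walks rows with a warp-per-row, grid-stride loop and copies each row in 16-byte (uint4) chunks, and this diff switches the loop variables n, slot, and d from int32_t to uint32_t so they match the unsigned CUDA built-ins (blockIdx, blockDim, threadIdx). The following is a minimal, self-contained sketch of that indexing pattern only; it is not FBGEMM code, warp_per_row_copy_kernel and all host-side names in it are hypothetical, and error checking is omitted for brevity.

// Hypothetical standalone sketch of a warp-per-row, grid-stride copy loop
// with unsigned loop indices, mirroring the pattern touched by this diff.
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void warp_per_row_copy_kernel(
    const uint4* __restrict__ src,
    uint4* __restrict__ dst,
    const uint32_t* __restrict__ num_rows, // row count read from device memory
    uint32_t row_bytes) {
  // blockDim.x = lanes per warp, blockDim.y = warps per block; each warp owns
  // one row per iteration of the grid-stride loop.
  for (uint32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *num_rows;
       n += gridDim.x * blockDim.y) {
    // Lanes stride across the row in 16-byte (uint4) chunks; row_bytes is
    // assumed to be a multiple of sizeof(uint4).
    const uint32_t row_elems = row_bytes / sizeof(uint4);
    for (uint32_t d = threadIdx.x; d * sizeof(uint4) < row_bytes;
         d += blockDim.x) {
      dst[n * row_elems + d] = src[n * row_elems + d];
    }
  }
}

int main() {
  constexpr uint32_t kRows = 8;
  constexpr uint32_t kRowBytes = 64; // 4 uint4 chunks per row
  const size_t elems = kRows * (kRowBytes / sizeof(uint4));

  uint4 *src, *dst;
  uint32_t* num_rows;
  cudaMalloc(&src, elems * sizeof(uint4));
  cudaMalloc(&dst, elems * sizeof(uint4));
  cudaMalloc(&num_rows, sizeof(uint32_t));
  cudaMemset(src, 0xAB, elems * sizeof(uint4));
  cudaMemcpy(num_rows, &kRows, sizeof(uint32_t), cudaMemcpyHostToDevice);

  // 32 lanes x 4 warps per block, 2 blocks: 8 warps cover the 8 rows.
  warp_per_row_copy_kernel<<<2, dim3(32, 4)>>>(src, dst, num_rows, kRowBytes);
  cudaDeviceSynchronize();

  std::vector<unsigned char> host(elems * sizeof(uint4));
  cudaMemcpy(host.data(), dst, host.size(), cudaMemcpyDeviceToHost);
  printf("dst byte 0 = 0x%02X\n", static_cast<unsigned>(host[0])); // expect 0xAB
  cudaFree(src);
  cudaFree(dst);
  cudaFree(num_rows);
  return 0;
}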