44 changes: 19 additions & 25 deletions fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate_byte.cu
@@ -54,7 +54,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
         lfu_state,
     const int64_t row_alignment) {
   const int32_t C = lxu_cache_state.size(0);
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (uint32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
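The loop counter changes from int32_t to uint32_t so that the grid-stride index arithmetic stays in the same unsigned type as the blockIdx/blockDim/threadIdx built-ins. A minimal sketch of the pattern in isolation (kernel and buffer names here are made up for illustration, not FBGEMM code):

```cuda
#include <cstdint>

// Sketch of a warp-per-item grid-stride loop with an unsigned counter,
// mirroring the int32_t -> uint32_t change above. Each blockDim.y "row"
// of the block (one warp) handles one item n.
__global__ void grid_stride_example(const int32_t* N_unique, float* out) {
  const uint32_t bound = static_cast<uint32_t>(*N_unique);
  // blockIdx / blockDim / threadIdx are unsigned built-ins, so a uint32_t
  // counter keeps the index arithmetic in one unsigned type.
  for (uint32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < bound;
       n += gridDim.x * blockDim.y) {
    if (threadIdx.x == 0) { // one lane per warp records the item it visited
      out[n] = static_cast<float>(n);
    }
  }
}
```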
@@ -81,17 +81,17 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
 
     // now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
-    const int32_t slot = threadIdx.x;
+    const uint32_t slot = threadIdx.x;
     const int64_t current_idx = lxu_cache_state[cache_set][slot];
     const int64_t current_lfu_cost =
         (current_idx != static_cast<int64_t>(kCacheStateInvalid))
         ? lfu_state[current_idx]
         : -1;
     int64_t costs[1] = {current_lfu_cost};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};
 
-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const uint32_t sorted_slot = slots[0];
     const int64_t sorted_lfu_cost = costs[0];
 
     for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
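Here each lane of the warp owns one cache slot (slot = threadIdx.x) and its LFU cost, and the single-element-per-lane BitonicSort orders the (cost, slot) pairs across the warp so that lane l ends up holding the l-th cheapest slot; the eviction loop below then broadcasts those slots with shuffles. A rough sketch of what a warp-wide key/value bitonic sort looks like with shuffle intrinsics (an illustration of the pattern only, not FBGEMM's BitonicSort utility):

```cuda
#include <cstdint>

// Illustrative warp-wide bitonic sort of one (key, value) pair per lane.
// After the call, lane 0 holds the smallest key and lane 31 the largest.
// This sketches the pattern; it is not FBGEMM's BitonicSort implementation.
__device__ void warp_bitonic_sort_kv(int64_t& key, uint32_t& val) {
  const uint32_t lane = threadIdx.x & 31u;
  for (uint32_t k = 2; k <= 32; k <<= 1) {      // size of the runs being merged
    for (uint32_t j = k >> 1; j > 0; j >>= 1) { // compare-exchange distance
      const int64_t other_key = __shfl_xor_sync(0xffffffff, key, j);
      const uint32_t other_val = __shfl_xor_sync(0xffffffff, val, j);
      const bool ascending = ((lane & k) == 0); // direction of this lane's run
      const bool is_lower = ((lane & j) == 0);  // lower lane of the pair
      const bool keep_smaller = (ascending == is_lower);
      if (keep_smaller ? (key > other_key) : (key < other_key)) {
        key = other_key;
        val = other_val;
      }
    }
  }
}
```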
@@ -126,7 +126,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
           &weights[weights_offset_insert + idx_insert * D_insert_bytes + 0]);
       auto cache_row = reinterpret_cast<uint4*>(
           &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0]);
-      for (int32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
+      for (uint32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
            d += blockDim.x) {
         cache_row[d] = row[d];
       }
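This loop has the warp copy one embedding row into its cache slot in 16-byte uint4 chunks, with lanes striding across the row. A standalone sketch of the same vectorized copy (assuming, as the surrounding kernel's row-alignment handling is meant to ensure, that both pointers are 16-byte aligned and the byte count is a multiple of sizeof(uint4)):

```cuda
#include <cstdint>

// Illustrative vectorized row copy: the warp's lanes copy one row of
// `num_bytes` bytes in 16-byte (uint4) chunks. Assumes both pointers are
// 16-byte aligned and num_bytes is a multiple of sizeof(uint4).
__device__ void copy_row_uint4(
    const uint8_t* src_row,
    uint8_t* dst_row,
    uint32_t num_bytes) {
  const auto* src = reinterpret_cast<const uint4*>(src_row);
  auto* dst = reinterpret_cast<uint4*>(dst_row);
  for (uint32_t d = threadIdx.x; d * sizeof(uint4) < num_bytes;
       d += blockDim.x) {
    dst[d] = src[d]; // one 16-byte load/store per lane per iteration
  }
}
```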
@@ -173,33 +173,27 @@ void lfu_cache_insert_byte_cuda(
       cache_set_sorted_unique_indices.scalar_type(),
       "lfu_cache_insert_byte_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "lfu_cache_insert_byte_kernel";
-#endif
-        lfu_cache_insert_byte_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (lfu_cache_insert_byte_kernel<index_t>),
             std::min(
                 div_round_up(N, kCacheMaxThreads / kWarpSize),
                 get_max_thread_blocks_for_cache_kernels_()),
             dim3(kWarpSize, kCacheMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, weights, uint8_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_index_table_map, int32_t, 1, 64),
-            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(weights, uint8_t, 1, 64),
+            PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
+            PTA_B(cache_index_table_map, int32_t, 1, 64),
+            PTA_B(weights_offsets, int64_t, 1, 32),
+            PTA_B(weights_tys, uint8_t, 1, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
             (uint64_t*)sorted_cache_sets.data_ptr<int64_t>(),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_set_sorted_unique_indices, index_t, 1, 32),
+            PTA_B(cache_set_sorted_unique_indices, index_t, 1, 32),
             unique_indices_length.data_ptr<int32_t>(),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, uint8_t, 2, 64),
-            MAKE_PTA_WITH_NAME(func_name, lfu_state, int64_t, 1, 64),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_weights, uint8_t, 2, 64),
+            PTA_B(lfu_state, int64_t, 1, 64),
             row_alignment);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }

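The launch-site rewrite replaces the raw <<<...>>> launch, the #ifdef FBGEMM_GPU_MEMCHECK func_name bookkeeping for MAKE_PTA_WITH_NAME, and the trailing C10_CUDA_KERNEL_LAUNCH_CHECK() with FBGEMM's FBGEMM_LAUNCH_KERNEL macro and the shorter PTA_B accessor helper; as the diff shows, the explicit kernel name and post-launch check go away along with it. As a point of comparison, a minimal host-side helper that folds a launch and its error check together might look like the sketch below (a generic illustration of the pattern, not FBGEMM_LAUNCH_KERNEL itself):

```cuda
#include <cuda_runtime.h>
#include <cstdio>
#include <utility>

// Generic illustration of a checked kernel launch: run the launch with the
// usual grid/block/smem/stream configuration, then surface any launch error
// immediately instead of relying on a separate check at every call site.
// This only sketches the pattern; FBGEMM's macro is its own implementation.
template <typename Kernel, typename... Args>
void launch_checked(
    Kernel kernel,
    dim3 grid,
    dim3 block,
    size_t smem_bytes,
    cudaStream_t stream,
    Args&&... args) {
  kernel<<<grid, block, smem_bytes, stream>>>(std::forward<Args>(args)...);
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(err));
  }
}
```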