@@ -54,7 +54,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
         lfu_state,
     const int64_t row_alignment) {
   const int32_t C = lxu_cache_state.size(0);
-  for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
+  for (uint32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
     const bool segment_start =
@@ -81,17 +81,17 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
 
     // now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
-    const int32_t slot = threadIdx.x;
+    const uint32_t slot = threadIdx.x;
     const int64_t current_idx = lxu_cache_state[cache_set][slot];
     const int64_t current_lfu_cost =
         (current_idx != static_cast<int64_t>(kCacheStateInvalid))
         ? lfu_state[current_idx]
         : -1;
     int64_t costs[1] = {current_lfu_cost};
-    int32_t slots[1] = {slot};
+    uint32_t slots[1] = {slot};
 
-    BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
+    BitonicSort<int64_t, uint32_t, 1, Comparator<int64_t>>::sort(costs, slots);
+    const uint32_t sorted_slot = slots[0];
     const int64_t sorted_lfu_cost = costs[0];
 
     for (int32_t l = 0; l < min(SL, kWarpSize); ++l) {
@@ -126,7 +126,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel(
           &weights[weights_offset_insert + idx_insert * D_insert_bytes + 0]);
       auto cache_row = reinterpret_cast<uint4*>(
           &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0]);
-      for (int32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
+      for (uint32_t d = threadIdx.x; d * sizeof(uint4) < D_insert_bytes;
            d += blockDim.x) {
         cache_row[d] = row[d];
       }
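
Note on the int32_t to uint32_t changes in the three hunks above: the CUDA builtins threadIdx, blockIdx, blockDim, and gridDim expose unsigned int fields, so deriving loop indices from them as signed int32_t forces implicit sign conversions and signed/unsigned comparison warnings. A minimal standalone sketch of the same grid-stride pattern with an unsigned index (hypothetical kernel, not part of this patch):

#include <cstdint>

// Hypothetical kernel, for illustration only: the index derived from the
// unsigned CUDA builtins stays uint32_t, so no sign conversion is needed.
__global__ void copy_grid_stride(const int32_t* in, int32_t* out, uint32_t n) {
  for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = in[i];
  }
}
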
@@ -173,33 +173,27 @@ void lfu_cache_insert_byte_cuda(
       cache_set_sorted_unique_indices.scalar_type(),
       "lfu_cache_insert_byte_cuda",
       [&] {
-#ifdef FBGEMM_GPU_MEMCHECK
-        const char* func_name = "lfu_cache_insert_byte_kernel";
-#endif
-        lfu_cache_insert_byte_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (lfu_cache_insert_byte_kernel<index_t>),
             std::min(
                 div_round_up(N, kCacheMaxThreads / kWarpSize),
                 get_max_thread_blocks_for_cache_kernels_()),
             dim3(kWarpSize, kCacheMaxThreads / kWarpSize),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(func_name, weights, uint8_t, 1, 64),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_hash_size_cumsum, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_index_table_map, int32_t, 1, 64),
-            MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, weights_tys, uint8_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(weights, uint8_t, 1, 64),
+            PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
+            PTA_B(cache_index_table_map, int32_t, 1, 64),
+            PTA_B(weights_offsets, int64_t, 1, 32),
+            PTA_B(weights_tys, uint8_t, 1, 32),
+            PTA_B(D_offsets, int32_t, 1, 32),
             (uint64_t*)sorted_cache_sets.data_ptr<int64_t>(),
-            MAKE_PTA_WITH_NAME(
-                func_name, cache_set_sorted_unique_indices, index_t, 1, 32),
+            PTA_B(cache_set_sorted_unique_indices, index_t, 1, 32),
             unique_indices_length.data_ptr<int32_t>(),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
-            MAKE_PTA_WITH_NAME(func_name, lxu_cache_weights, uint8_t, 2, 64),
-            MAKE_PTA_WITH_NAME(func_name, lfu_state, int64_t, 1, 64),
+            PTA_B(lxu_cache_state, int64_t, 2, 32),
+            PTA_B(lxu_cache_weights, uint8_t, 2, 64),
+            PTA_B(lfu_state, int64_t, 1, 64),
             row_alignment);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 }
 
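
In the hunk above, the kernel is now passed to FBGEMM_LAUNCH_KERNEL as an explicit template instantiation, parenthesized so the comma inside the template argument list does not split the macro argument. The removed lines show the usual manual launch pattern in PyTorch CUDA extensions: launch on the current stream with <<<>>>, then call C10_CUDA_KERNEL_LAUNCH_CHECK() to surface launch errors; FBGEMM_LAUNCH_KERNEL presumably wraps that sequence, including the error check, which is why the explicit check is dropped. A minimal sketch of the manual pattern, with a hypothetical kernel and arguments:

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Hypothetical kernel, for illustration only.
template <typename index_t>
__global__ void fill_kernel(index_t* data, int64_t n, index_t value) {
  for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] = value;
  }
}

template <typename index_t>
void launch_fill(index_t* data, int64_t n, index_t value) {
  const dim3 blocks(32);
  const dim3 threads(256);
  // Manual launch on the current PyTorch CUDA stream ...
  fill_kernel<index_t>
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          data, n, value);
  // ... followed by the explicit launch-error check that a macro-based
  // launcher would fold in.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}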