diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py
index 136e117538..69a5f6c1ba 100644
--- a/fbgemm_gpu/bench/sparse_ops_benchmark.py
+++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -993,6 +993,111 @@ def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
     ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
 
 
+@cli.command()
+@click.option("--num-segments", default=100)
+@click.option("--max-segment-length", default=10000)
+@click.option(
+    "--index-dtype", type=click.Choice(["int", "int64", "float"]), default="float"
+)
+@click.option("--has-weight", is_flag=True, default=False)
+@click.option("--device", type=click.Choice(["cpu", "cuda"]), default="cuda")
+def permute_1d_sparse_data_bench(
+    num_segments: int,
+    max_segment_length: int,
+    index_dtype: str,
+    has_weight: bool,
+    device: str,
+) -> None:
+    """Benchmark the permute_1D_sparse_data operator.
+
+    This operator permutes sparse features (indices and optional weights)
+    according to a given permutation. It is commonly used in recommendation
+    systems to reorder embedding tables.
+    """
+    if index_dtype == "int":
+        index_dtype = torch.int32
+    elif index_dtype == "int64":
+        index_dtype = torch.int64
+    elif index_dtype == "float":
+        index_dtype = torch.float32
+    else:
+        raise RuntimeError(f"Unsupported index dtype {index_dtype}")
+
+    # Generate variable-length segments to test vectorization
+    emb_dim = 256
+    lengths = (
+        torch.randint(
+            low=max_segment_length // 2,
+            high=max_segment_length,
+            size=(num_segments,),
+            dtype=torch.int32,
+            device=device,
+        )
+        * emb_dim
+    )
+    total_indices = int(lengths.sum().item())
+
+    # Generate indices
+    if index_dtype == torch.float32:
+        indices = torch.rand(total_indices, dtype=index_dtype, device=device)
+    else:
+        indices = torch.randint(
+            low=0,
+            high=2**31 - 1,
+            size=(total_indices,),
+            dtype=index_dtype,
+            device=device,
+        )
+
+    # Generate optional weights
+    weights = (
+        torch.rand(total_indices, dtype=torch.float32, device=device)
+        if has_weight
+        else None
+    )
+
+    # Generate a random permutation
+    permute_list = list(range(num_segments))
+    random.shuffle(permute_list)
+    permute = torch.IntTensor(permute_list).to(device)
+
+    # Benchmark the operation
+    time, (permuted_lengths, permuted_indices, permuted_weights) = (
+        benchmark_torch_function(
+            torch.ops.fbgemm.permute_1D_sparse_data,
+            (permute, lengths, indices, weights, None),
+            num_warmups=100,
+            iters=1000,
+        )
+    )
+
+    # Count the bytes moved, for the memory-bandwidth estimate
+    num_bytes = (
+        permute.numel() * permute.element_size()
+        + lengths.numel() * lengths.element_size()
+        + indices.numel() * indices.element_size()
+        + permuted_lengths.numel() * permuted_lengths.element_size()
+        + permuted_indices.numel() * permuted_indices.element_size()
+    )
+    if has_weight:
+        assert weights is not None
+        assert permuted_weights is not None
+        num_bytes += (
+            weights.numel() * weights.element_size()  # pyre-ignore [16]
+            + permuted_weights.numel() * permuted_weights.element_size()
+        )
+
+    logging.info(
+        f"permute_1D_sparse_data_bench ("
+        f"num_segments={num_segments}, "
+        f"max_segment_length={max_segment_length}, "
+        f"total_indices={total_indices}, "
+        f"dtype={index_dtype}, "
+        f"with_weights={has_weight}, "
+        f"device={device})"
+    )
+    logging.info(
+        f"fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
+    )
+
+
 @cli.command()
 @click.option("--row-size", default=2560000)
 @click.option("--batch-size", default=2048)
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
index 6e7ca51614..76a8f4d659 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
@@ -61,6 +61,115 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
   }
 }
 
+// Vectorized kernel for permuting the indices and weights of sparse data.
+// Uses vec4 loads and stores for improved memory bandwidth.
+template <
+    bool has_weight,
+    typename offsets_t,
+    typename indices_t,
+    typename weights_t>
+__global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel_vec(
+    int32_t permuted_indices_size,
+    int32_t permuted_lengths_size,
+    const indices_t* __restrict__ indices,
+    const weights_t* __restrict__ weights,
+    const int32_t* __restrict__ permute,
+    const offsets_t* __restrict__ input_offsets,
+    const offsets_t* __restrict__ output_offsets,
+    indices_t* __restrict__ permuted_indices,
+    weights_t* __restrict__ permuted_weights) {
+  // Select vector types based on element size (vec4 for 4x bandwidth)
+  using indices_vec4_t = typename std::
+      conditional<sizeof(indices_t) == 8, ulonglong4, uint4>::type;
+  using weights_vec4_t = typename std::
+      conditional<sizeof(weights_t) == 8, ulonglong4, uint4>::type;
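+  // The vec4 copies below move raw bytes, so the exact vector element type
+  // does not matter as long as one vec4 value packs exactly four
+  // indices_t / weights_t elements; whether a segment may use them is decided
+  // by the per-segment alignment checks further down.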
+
+  const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
+  const auto stride = gridDim.x * blockDim.y;
+
+  for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
+    // Read offsets once - use int32_t for segment_length as it fits in 32 bits
+    const offsets_t output_start = output_offsets[b_t];
+    const offsets_t output_end = (b_t == permuted_lengths_size - 1)
+        ? permuted_indices_size
+        : output_offsets[b_t + 1];
+    const int32_t segment_length =
+        static_cast<int32_t>(output_end - output_start);
+    const offsets_t input_start = input_offsets[permute[b_t]];
+
+    // Compute source and destination pointers
+    indices_t* __restrict__ indices_dst_ptr = permuted_indices + output_start;
+    const indices_t* __restrict__ indices_src_ptr = indices + input_start;
+    weights_t* __restrict__ weights_dst_ptr =
+        has_weight ? permuted_weights + output_start : nullptr;
+    const weights_t* __restrict__ weights_src_ptr =
+        has_weight ? weights + input_start : nullptr;
+
+    // Check alignment once per segment
+    const bool indices_vec4_aligned =
+        (sizeof(indices_t) == 4 || sizeof(indices_t) == 8) &&
+        (reinterpret_cast<uintptr_t>(indices_dst_ptr) &
+         (alignof(indices_vec4_t) - 1)) == 0 &&
+        (reinterpret_cast<uintptr_t>(indices_src_ptr) &
+         (alignof(indices_vec4_t) - 1)) == 0;
+
+    const bool weights_vec4_aligned = !has_weight ||
+        ((sizeof(weights_t) == 4 || sizeof(weights_t) == 8) &&
+         (reinterpret_cast<uintptr_t>(weights_dst_ptr) &
+          (alignof(weights_vec4_t) - 1)) == 0 &&
+         (reinterpret_cast<uintptr_t>(weights_src_ptr) &
+          (alignof(weights_vec4_t) - 1)) == 0);
+
+    if (indices_vec4_aligned && weights_vec4_aligned) {
+      // Vectorized path - process both indices and weights together
+      const int32_t vec4_count = segment_length / 4;
+      const int32_t remainder = segment_length & 3; // segment_length % 4
+
+      auto indices_dst = reinterpret_cast<indices_vec4_t*>(indices_dst_ptr);
+      auto indices_src =
+          reinterpret_cast<const indices_vec4_t*>(indices_src_ptr);
+
+      if (has_weight) {
+        auto weights_dst = reinterpret_cast<weights_vec4_t*>(weights_dst_ptr);
+        auto weights_src =
+            reinterpret_cast<const weights_vec4_t*>(weights_src_ptr);
+
+// copy both indices and weights
+#pragma unroll
+        for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
+          indices_dst[i] = indices_src[i];
+          weights_dst[i] = weights_src[i];
+        }
+        // Handle remainder elements (0-3 elements)
+        if (threadIdx.x < remainder) {
+          const auto offset = vec4_count * 4 + threadIdx.x;
+          indices_dst_ptr[offset] = indices_src_ptr[offset];
+          weights_dst_ptr[offset] = weights_src_ptr[offset];
+        }
+      } else {
+// copy only indices
+#pragma unroll
+        for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
+          indices_dst[i] = indices_src[i];
+        }
+
+        // Handle remainder elements (0-3 elements)
+        if (threadIdx.x < remainder) {
+          const auto offset = vec4_count * 4 + threadIdx.x;
+          indices_dst_ptr[offset] = indices_src_ptr[offset];
+        }
+      }
+    } else {
+      // Scalar fallback path
+      for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
+        indices_dst_ptr[i] = indices_src_ptr[i];
+        if (has_weight) {
+          weights_dst_ptr[i] = weights_src_ptr[i];
+        }
+      }
+    }
+  }
+}
+
 DLL_PUBLIC std::tuple<Tensor, Tensor, std::optional<Tensor>>
 permute_1D_sparse_data_cuda(
     const Tensor& permute,
@@ -124,17 +233,17 @@ permute_1D_sparse_data_cuda(
     permuted_indices_size = output_offsets[-1].item();
   }
 
-  constexpr int32_t BT_blocks = 32;
-  dim3 threads_2(32, BT_blocks);
+  constexpr int32_t BT_blocks = 16;
+  dim3 threads_2(64, BT_blocks);
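+  // 64 threads in x give every segment two warps for the vec4 copies, and 16
+  // segments per block keep the block within the kMaxThreads launch bound.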
   const auto blocks_2 =
       cuda_calc_xblock_count(permuted_lengths_size, BT_blocks);
   permuted_indices = at::empty(permuted_indices_size, indices.options());
 
   AT_DISPATCH_INDEX_TYPES(
-      input_offsets.scalar_type(), "permute_1D_data_kernel_1", [&] {
+      input_offsets.scalar_type(), "permute_1D_data_kernel_vec_1", [&] {
        using offsets_t = index_t;
        FBGEMM_DISPATCH_ALL_TYPES(
-            indices.scalar_type(), "permute_1D_data_kernel_2", [&] {
+            indices.scalar_type(), "permute_1D_data_kernel_vec_2", [&] {
              using indices_t = scalar_t;
              if (weights.has_value()) {
                const Tensor weights_value = weights.value();
@@ -143,11 +252,11 @@ permute_1D_sparse_data_cuda(
                    at::empty(permuted_indices_size, weights_value.options());
                FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
                    weights_value.scalar_type(),
-                    "permute_1D_data_kernel_3",
+                    "permute_1D_data_kernel_vec_3",
                    [&] {
                      using weights_t = scalar_t;
                      FBGEMM_LAUNCH_KERNEL(
-                          (permute_1D_data_kernel<
+                          (permute_1D_data_kernel_vec<
                              true,
                              offsets_t,
                              indices_t,
@@ -168,7 +277,7 @@ permute_1D_sparse_data_cuda(
                    }); // for each weights_t
              } else {
                FBGEMM_LAUNCH_KERNEL(
-                    (permute_1D_data_kernel<
+                    (permute_1D_data_kernel_vec<
                        false,
                        offsets_t,
                        indices_t,
diff --git a/fbgemm_gpu/test/sparse/permute_indices_test.py b/fbgemm_gpu/test/sparse/permute_indices_test.py
index a48f43c80e..b575735a8f 100644
--- a/fbgemm_gpu/test/sparse/permute_indices_test.py
+++ b/fbgemm_gpu/test/sparse/permute_indices_test.py
@@ -68,7 +68,6 @@ def test_permute_indices(
             lengths = torch.cat(length_splits, dim=1)
         else:
             lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
-        # pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
         #  typing.Tuple[int, ...]]` but got `Union[bool, float, int]`.
         weights = torch.rand(lengths.sum().item()).float() if has_weight else None
 
@@ -128,6 +127,9 @@ def test_permute_indices(
             assert permuted_weights_cpu is None and permuted_weights_ref is None
 
         if gpu_available:
+            weights_cuda = (
+                weights.cuda() if (has_weight and weights is not None) else None
+            )
            if is_1D:
                (
                    permuted_lengths_gpu,
@@ -137,8 +139,7 @@ def test_permute_indices(
                    permute.cuda(),
                    lengths.cuda(),
                    indices.cuda(),
-                    # pyre-fixme[16]: `Optional` has no attribute `cuda`.
-                    weights.cuda() if has_weight else None,
+                    weights_cuda,
                    None,
                )
            else:
@@ -150,7 +151,7 @@ def test_permute_indices(
                    permute.cuda(),
                    lengths.cuda(),
                    indices.cuda(),
-                    weights.cuda() if has_weight else None,
+                    weights_cuda,
                    None,
                )
        torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
@@ -320,6 +321,9 @@ def test_permute_indices_with_repeats(
             assert permuted_weights_cpu is None and permuted_weights_ref is None
 
         if gpu_available:
+            weights_cuda = (
+                weights.cuda() if (has_weight and weights is not None) else None
+            )
            (
                permuted_lengths_gpu,
                permuted_indices_gpu,
@@ -328,8 +332,7 @@ def test_permute_indices_with_repeats(
                permute.cuda(),
                lengths.cuda(),
                indices.cuda(),
-                # pyre-fixme[16]: `Optional` has no attribute `cuda`.
-                weights.cuda() if has_weight else None,
+                weights_cuda,
            )
        torch.testing.assert_close(permuted_indices_gpu.cpu(), permuted_indices_cpu)
        torch.testing.assert_close(permuted_lengths_gpu.cpu(), permuted_lengths_cpu)
@@ -340,6 +343,155 @@ def test_permute_indices_with_repeats(
         else:
             assert permuted_weights_cpu is None
 
+    @given(
+        num_segments=st.integers(min_value=20, max_value=100),
+        max_segment_length=st.integers(min_value=100, max_value=1000),
+        index_dtype=st.sampled_from([torch.int32, torch.int64, torch.float32]),
+        has_weight=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=100, deadline=None)
+    @unittest.skipIf(*gpu_unavailable)
+    def test_permute_1D_sparse_data_vec(
+        self,
+        num_segments: int,
+        max_segment_length: int,
+        index_dtype: torch.dtype,
+        has_weight: bool,
+    ) -> None:
+        """
+        Test the vectorized permute_1D_sparse_data kernel with vec4 optimization.
+
+        Validates:
+        - Correctness for various segment lengths (vec4 path and remainder handling)
+        - Alignment-based vectorization (vec4 when aligned, scalar fallback otherwise)
+        - With and without weights (the weights_vec4_aligned short-circuit logic)
+        - Different index dtypes (int32, int64, float32)
+        - Edge cases: segment lengths at vec4 boundaries (1, 3, 4, 5, 15, 16, etc.)
+ """ + + # Generate variable-length segments to test vectorization + lengths = torch.randint( + low=max_segment_length // 2, + high=max_segment_length, + size=(num_segments,), + dtype=torch.int32, + ) + total_indices = int(lengths.sum().item()) + # Generate indices + if index_dtype == torch.float32: + indices = torch.rand(total_indices, dtype=index_dtype) + else: + indices = torch.randint( + low=0, + high=2**31 - 1, + size=(total_indices,), + dtype=index_dtype, + ) + + # Generate optional weights + weights = torch.rand(total_indices, dtype=torch.float32) if has_weight else None + + # Generate random permutation + permute_list = list(range(num_segments)) + random.shuffle(permute_list) + permute = torch.IntTensor(permute_list) + + # CPU reference (uses scalar kernel) + ( + permuted_lengths_cpu, + permuted_indices_cpu, + permuted_weights_cpu, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + permute, lengths, indices, weights, None + ) + + # GPU vectorized kernel (uses vec4 when aligned) + ( + permuted_lengths_gpu, + permuted_indices_gpu, + permuted_weights_gpu, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + permute.cuda(), + lengths.cuda(), + indices.cuda(), + weights.cuda() if has_weight and weights is not None else None, + None, + ) + + # Validate correctness + torch.testing.assert_close( + permuted_lengths_gpu.cpu(), + permuted_lengths_cpu, + ) + torch.testing.assert_close( + permuted_indices_gpu.cpu(), + permuted_indices_cpu, + ) + + if has_weight: + torch.testing.assert_close( + permuted_weights_gpu.cpu(), + permuted_weights_cpu, + ) + else: + self.assertIsNone(permuted_weights_gpu) + self.assertIsNone(permuted_weights_cpu) + + # Test edge cases with specific segment lengths at vec4 boundaries + # This validates remainder handling (segment_length % 4 = 0, 1, 2, 3) + edge_case_lengths = [1, 3, 4, 5, 15, 16, 17, 63, 64, 127, 128] + for segment_length in edge_case_lengths: + lengths_edge = torch.tensor([segment_length], dtype=torch.int32) + if index_dtype == torch.float32: + indices_edge = torch.rand(segment_length, dtype=index_dtype) + else: + indices_edge = torch.randint( + 0, 2**31 - 1, size=(segment_length,), dtype=index_dtype + ) + weights_edge = ( + torch.rand(segment_length, dtype=torch.float32) if has_weight else None + ) + permute_edge = torch.IntTensor([0]) + + ( + permuted_lengths_cpu_edge, + permuted_indices_cpu_edge, + permuted_weights_cpu_edge, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + permute_edge, lengths_edge, indices_edge, weights_edge, None + ) + + weights_edge_cuda = ( + weights_edge.cuda() + if (has_weight and weights_edge is not None) + else None + ) + ( + permuted_lengths_gpu_edge, + permuted_indices_gpu_edge, + permuted_weights_gpu_edge, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + permute_edge.cuda(), + lengths_edge.cuda(), + indices_edge.cuda(), + weights_edge_cuda, + None, + ) + torch.testing.assert_close( + permuted_lengths_gpu_edge.cpu(), + permuted_lengths_cpu_edge, + ) + torch.testing.assert_close( + permuted_indices_gpu_edge.cpu(), + permuted_indices_cpu_edge, + ) + + if has_weight: + torch.testing.assert_close( + permuted_weights_gpu_edge.cpu(), + permuted_weights_cpu_edge, + ) + extend_test_class(PermuteIndicesTest)