105 changes: 105 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -993,6 +993,111 @@ def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)


@cli.command()
@click.option("--num-segments", default=100)
@click.option("--max-segment-length", default=10000)
@click.option(
"--index-dtype", type=click.Choice(["int", "int64", "float"]), default="float"
)
@click.option("--has-weight", is_flag=True, default=False)
@click.option("--device", type=click.Choice(["cpu", "cuda"]), default="cuda")
def permute_1d_sparse_data_bench(
num_segments: int,
max_segment_length: int,
index_dtype: str,
has_weight: bool,
device: str,
) -> None:
"""Benchmark permute_1D_sparse_data operator.

This operator permutes sparse features (indices and optional weights) according
to a given permutation. Commonly used in recommendation systems to reorder
embedding tables.
"""
if index_dtype == "int":
index_dtype = torch.int32
elif index_dtype == "int64":
index_dtype = torch.int64
elif index_dtype == "float":
index_dtype = torch.float32
else:
raise RuntimeError(f"Does not support data type {index_dtype}")

# Generate variable-length segments to test vectorization
emb_dim = 256
lengths = (
torch.randint(
low=max_segment_length // 2,
high=max_segment_length,
size=(num_segments,),
dtype=torch.int32,
device=device,
)
* emb_dim
)
total_indices = int(lengths.sum().item())
# Generate indices
if index_dtype == torch.float32:
indices = torch.rand(total_indices, dtype=index_dtype, device=device)
else:
indices = torch.randint(
low=0,
high=2**31 - 1,
size=(total_indices,),
dtype=index_dtype,
device=device,
)

# Generate optional weights
weights = (
torch.rand(total_indices, dtype=torch.float32, device=device)
if has_weight
else None
)
# Generate random permutation
permute_list = list(range(num_segments))
random.shuffle(permute_list)
permute = torch.tensor(permute_list, dtype=torch.int32, device=device)
# Benchmark the operation
time, (permuted_lengths, permuted_indices, permuted_weights) = (
benchmark_torch_function(
torch.ops.fbgemm.permute_1D_sparse_data,
(permute, lengths, indices, weights, None),
num_warmups=100,
iters=1000,
)
)

# Calculate memory bandwidth
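# Bytes moved by the op: permute, lengths and indices are read; permuted_lengths
# and permuted_indices are written (weights/permuted_weights are added below
# when present).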
num_bytes = (
permute.numel() * permute.element_size()
+ lengths.numel() * lengths.element_size()
+ indices.numel() * indices.element_size()
+ permuted_lengths.numel() * permuted_lengths.element_size()
+ permuted_indices.numel() * permuted_indices.element_size()
)
if has_weight:
assert weights is not None
assert permuted_weights is not None
num_bytes += (
weights.numel() * weights.element_size() # pyre-ignore [16]
+ permuted_weights.numel() * permuted_weights.element_size()
)

logging.info(
f"permute_1D_sparse_data_bench ("
f"num_segments={num_segments}, "
f"max_segment_length={max_segment_length}, "
f"total_indices={total_indices}, "
f"dtype={index_dtype}, "
f"with_weights={has_weight}, "
f"device={device})"
)
logging.info(
f"fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
)


@cli.command()
@click.option("--row-size", default=2560000)
@click.option("--batch-size", default=2048)
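A minimal sketch of the operator exercised by this benchmark, assuming fbgemm_gpu is installed so that torch.ops.fbgemm.permute_1D_sparse_data is registered (the tensors and values below are illustrative, not taken from the PR):

import torch
import fbgemm_gpu  # noqa: F401  (importing registers the fbgemm operators)

lengths = torch.tensor([2, 1, 3], dtype=torch.int32)
indices = torch.tensor([10, 11, 20, 30, 31, 32], dtype=torch.int32)
permute = torch.tensor([2, 0, 1], dtype=torch.int32)

# Output segment i is a copy of input segment permute[i].
permuted_lengths, permuted_indices, permuted_weights = (
    torch.ops.fbgemm.permute_1D_sparse_data(permute, lengths, indices, None, None)
)
# permuted_lengths -> [3, 2, 1]
# permuted_indices -> [30, 31, 32, 10, 11, 20]
# permuted_weights -> None (no weights were passed)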
123 changes: 116 additions & 7 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
@@ -61,6 +61,115 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
}
}

// Vectorized kernel for permuting the indices and (optionally) the weights of
// 1D sparse data. Uses vec4 loads/stores for higher effective memory bandwidth.
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
__global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel_vec(
int32_t permuted_indices_size,
int32_t permuted_lengths_size,
const indices_t* __restrict__ indices,
const weights_t* __restrict__ weights,
const int32_t* __restrict__ permute,
const offsets_t* __restrict__ input_offsets,
const offsets_t* __restrict__ output_offsets,
indices_t* __restrict__ permuted_indices,
weights_t* __restrict__ permuted_weights) {
// Select 16B/32B vector types by element size (8-byte elements use long4,
// otherwise float4), so four elements are moved per load/store.
using indices_vec4_t =
typename std::conditional<sizeof(indices_t) == 8, long4, float4>::type;
using weights_vec4_t =
typename std::conditional<sizeof(weights_t) == 8, long4, float4>::type;

const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
const auto stride = gridDim.x * blockDim.y;

for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
// Read offsets once - use int32_t for segment_length as it fits in 32 bits
const offsets_t output_start = output_offsets[b_t];
const offsets_t output_end = (b_t == permuted_lengths_size - 1)
? permuted_indices_size
: output_offsets[b_t + 1];
const int32_t segment_length =
static_cast<int32_t>(output_end - output_start);
const offsets_t input_start = input_offsets[permute[b_t]];

// Compute pointers
indices_t* __restrict__ indices_dst_ptr = permuted_indices + output_start;
const indices_t* __restrict__ indices_src_ptr = indices + input_start;
weights_t* __restrict__ weights_dst_ptr =
has_weight ? permuted_weights + output_start : nullptr;
const weights_t* __restrict__ weights_src_ptr =
has_weight ? weights + input_start : nullptr;

// Check alignment once per segment
const bool indices_vec4_aligned =
(sizeof(indices_t) == 4 || sizeof(indices_t) == 8) &&
(reinterpret_cast<uintptr_t>(indices_dst_ptr) &
(alignof(indices_vec4_t) - 1)) == 0 &&
(reinterpret_cast<uintptr_t>(indices_src_ptr) &
(alignof(indices_vec4_t) - 1)) == 0;

const bool weights_vec4_aligned = !has_weight ||
((reinterpret_cast<uintptr_t>(weights_dst_ptr) &
(alignof(weights_vec4_t) - 1)) == 0 &&
(reinterpret_cast<uintptr_t>(weights_src_ptr) &
(alignof(weights_vec4_t) - 1)) == 0);

if (indices_vec4_aligned && weights_vec4_aligned) {
// Vectorized path - process both indices and weights together
const int32_t vec4_count = segment_length / 4;
const int32_t remainder = segment_length & 3; // segment_length % 4

auto indices_dst = reinterpret_cast<indices_vec4_t*>(indices_dst_ptr);
auto indices_src =
reinterpret_cast<const indices_vec4_t*>(indices_src_ptr);

if (has_weight) {
auto weights_dst = reinterpret_cast<weights_vec4_t*>(weights_dst_ptr);
auto weights_src =
reinterpret_cast<const weights_vec4_t*>(weights_src_ptr);

// copy both indices and weights
#pragma unroll
for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
indices_dst[i] = indices_src[i];
weights_dst[i] = weights_src[i];
}
// Handle remainder elements (0-3 elements)
if (threadIdx.x < remainder) {
const auto offset = vec4_count * 4 + threadIdx.x;
indices_dst_ptr[offset] = indices_src_ptr[offset];
weights_dst_ptr[offset] = weights_src_ptr[offset];
}
} else {
// copy only indices
#pragma unroll
for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
indices_dst[i] = indices_src[i];
}

// Handle remainder elements (0-3 elements)
if (threadIdx.x < remainder) {
const auto offset = vec4_count * 4 + threadIdx.x;
indices_dst_ptr[offset] = indices_src_ptr[offset];
}
}
} else {
// Scalar fallback path
for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
indices_dst_ptr[i] = indices_src_ptr[i];
if (has_weight) {
weights_dst_ptr[i] = weights_src_ptr[i];
}
}
}
}
}

DLL_PUBLIC std::tuple<Tensor, Tensor, std::optional<Tensor>>
permute_1D_sparse_data_cuda(
const Tensor& permute,
Expand Down Expand Up @@ -124,17 +233,17 @@ permute_1D_sparse_data_cuda(
permuted_indices_size = output_offsets[-1].item<int64_t>();
}

- constexpr int32_t BT_blocks = 32;
- dim3 threads_2(32, BT_blocks);
+ constexpr int32_t BT_blocks = 16;
+ dim3 threads_2(64, BT_blocks);
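// Launch shape: 64 threads (x) cooperate on copying one segment; 16 segments (y) per thread block.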
const auto blocks_2 =
cuda_calc_xblock_count(permuted_lengths_size, BT_blocks);
permuted_indices = at::empty(permuted_indices_size, indices.options());

AT_DISPATCH_INDEX_TYPES(
input_offsets.scalar_type(), "permute_1D_data_kernel_1", [&] {
input_offsets.scalar_type(), "permute_1D_data_kernel_vec_1", [&] {
using offsets_t = index_t;
FBGEMM_DISPATCH_ALL_TYPES(
indices.scalar_type(), "permute_1D_data_kernel_2", [&] {
indices.scalar_type(), "permute_1D_data_kernel_vec_2", [&] {
using indices_t = scalar_t;
if (weights.has_value()) {
const Tensor weights_value = weights.value();
@@ -143,11 +252,11 @@ permute_1D_sparse_data_cuda(
at::empty(permuted_indices_size, weights_value.options());
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
weights_value.scalar_type(),
"permute_1D_data_kernel_3",
"permute_1D_data_kernel_vec_3",
[&] {
using weights_t = scalar_t;
FBGEMM_LAUNCH_KERNEL(
- (permute_1D_data_kernel<
+ (permute_1D_data_kernel_vec<
true,
offsets_t,
indices_t,
Expand All @@ -168,7 +277,7 @@ permute_1D_sparse_data_cuda(
}); // for each weights_t
} else {
FBGEMM_LAUNCH_KERNEL(
- (permute_1D_data_kernel<
+ (permute_1D_data_kernel_vec<
false,
offsets_t,
indices_t,