
Conversation

sunjiweiswift commented Nov 3, 2025

NHD: the last 3 dimensions are organized as (seq_len, num_heads, head_dim).
HND: the last 3 dimensions are organized as (num_heads, seq_len, head_dim).

In VLLM/sglang, NHD is the more commonly used format. Support for NHD has been added in this PR.
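
For illustration, a minimal sketch (not code from this PR) of how the same logical element maps to different linear offsets under the two layouts, assuming a contiguous 4-D tensor with a leading batch dimension:

```cpp
#include <cstddef>
#include <cstdio>

struct Shape { std::size_t batch, seq_len, num_heads, head_dim; };

// NHD: memory order (batch, seq_len, num_heads, head_dim)
std::size_t offset_nhd(const Shape& s, std::size_t b, std::size_t h,
                       std::size_t t, std::size_t d) {
  return ((b * s.seq_len + t) * s.num_heads + h) * s.head_dim + d;
}

// HND: memory order (batch, num_heads, seq_len, head_dim)
std::size_t offset_hnd(const Shape& s, std::size_t b, std::size_t h,
                       std::size_t t, std::size_t d) {
  return ((b * s.num_heads + h) * s.seq_len + t) * s.head_dim + d;
}

int main() {
  Shape s{1, 4, 2, 8};
  // Same logical element (batch 0, head 1, token 2, dim 3),
  // different linear offsets depending on the layout:
  std::printf("NHD: %zu\n", offset_nhd(s, 0, 1, 2, 3)); // 43
  std::printf("HND: %zu\n", offset_hnd(s, 0, 1, 2, 3)); // 51
}
```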

sunjiweiswift marked this pull request as ready for review November 4, 2025 04:29
sunjiweiswift (Author):

@petercad please review

sunjiweiswift changed the title from NHD v1.0 to NHD layout on Nov 4, 2025
sunjiweiswift (Author):

@jiyang1011 @taozha2 @tdeng5 please review

tdeng5 requested review from jiyang1011 and tdeng5 November 5, 2025 03:39
Antonyvance requested a review from Copilot November 5, 2025 07:29

Copilot AI left a comment

Pull Request Overview

This PR adds support for NHD (seq_len, num_heads, head_dim) layout in addition to the existing HND (num_heads, seq_len, head_dim) layout for the BMG flash attention example. The NHD layout is commonly used in VLLM/sglang frameworks and is set as the new default.

Key changes:

  • Added --layout command-line option with validation for "NHD" and "HND" values
  • Updated stride calculations to support both layout formats (see the sketch after this list)
  • Modified the verification function to handle layout-specific tensor indexing and data reordering
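
A minimal sketch of what the layout-dependent stride selection could look like (the Strides struct and make_strides function are hypothetical, not the PR's actual code):

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

// Strides in elements for one (seq_len, num_heads, head_dim) tensor;
// the head_dim (innermost) stride is 1 in both layouts.
struct Strides { std::size_t batch, head, seq; };

Strides make_strides(const std::string& layout, std::size_t seq_len,
                     std::size_t num_heads, std::size_t head_dim) {
  if (layout == "NHD")  // (seq_len, num_heads, head_dim): heads adjacent per token
    return {seq_len * num_heads * head_dim, head_dim, num_heads * head_dim};
  if (layout == "HND")  // (num_heads, seq_len, head_dim): tokens adjacent per head
    return {num_heads * seq_len * head_dim, seq_len * head_dim, head_dim};
  throw std::invalid_argument("--layout must be \"NHD\" or \"HND\"");
}
```

Under NHD, consecutive tokens of one head sit num_heads * head_dim elements apart, which is why the kernel's stride setup has to vary with the layout rather than being fixed.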

sunjiweiswift (Author):

@copilot open a new pull request to apply changes based on the comments in this thread

sunjiweiswift (Author):

@rolandschulz please review and merge

hshen14 commented Nov 6, 2025

@sunjiweiswift did you observe any perf gains using HND over NHD when low precision, e.g. FP8, is enabled? FlashInfer says it's more friendly for GPU implementation.

sunjiweiswift (Author):

@hshen14 There is not much difference in the BF16 case; FP8 and other low-precision types are not currently supported.
