
Commit 702e642

Delete FlexAttention + NJT composition from tutorial (#3561)
Fix broken CI: https://github.com/pytorch/tutorials/actions/runs/17684562516/job/50266336429?pr=3553#step:9:11267

We should delete this because NJT support will be deleted from FlexAttention in pytorch/pytorch#161734.

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent d81e832 · commit 702e642

1 file changed

intermediate_source/transformer_building_blocks.py

Lines changed: 0 additions & 61 deletions
@@ -564,7 +564,6 @@ def benchmark(func, *args, **kwargs):
 #
 # * Cross Attention
 # * Fully masked rows no longer cause NaNs
-# * Modifying attention score: ALiBi with FlexAttention and NJT
 # * Packed Projection
 
 ###############################################################################
@@ -668,66 +667,6 @@ def benchmark(func, *args, **kwargs):
 # appropriately makes it possible to properly express empty sequences.
 
 
-################################################################################
-# FlexAttention + NJT
-# ---------------------------------------------------------------------
-# NJT also composes with the ``FlexAttention`` module. This is a generalization
-# of the ``MultiheadAttention`` layer that allows for arbitrary modifications
-# to the attention score. The example below takes the ``alibi_mod``
-# that implements `ALiBi <https://arxiv.org/abs/2108.12409>`_ from
-# `attention gym <https://github.com/meta-pytorch/attention-gym>`_ and uses it
-# with nested input tensors.
-
-from torch.nn.attention.flex_attention import flex_attention
-
-
-def generate_alibi_bias(H: int):
-    """Returns an alibi bias score_mod given the number of heads H
-    Args:
-        H: number of heads
-    Returns:
-        alibi_bias: alibi bias score_mod
-    """
-
-    def alibi_mod(score, b, h, q_idx, kv_idx):
-        scale = torch.exp2(-((h + 1) * 8.0 / H))
-        bias = (q_idx - kv_idx) * scale
-        return score + bias
-
-    return alibi_mod
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-n_heads, D = 8, E_q // 8
-alibi_score_mod = generate_alibi_bias(n_heads)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
-
-###############################################################################
-# In addition, one can also use the ``block_mask`` utility of ``FlexAttention``
-# with NJTs via the ``create_nested_block_mask`` function. This is useful for
-# taking advantage of the sparsity of the mask to speed up the attention computation.
-# In particular, the function creates a sparse block mask for a "stacked sequence" of all
-# the variable length sequences in the NJT combined into one, while properly masking out
-# inter-sequence attention. In the following example, we show how to create a
-# causal block mask using this utility.
-
-from torch.nn.attention.flex_attention import create_nested_block_mask
-
-
-def causal_mask(b, h, q_idx, kv_idx):
-    return q_idx >= kv_idx
-
-
-query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
-block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
-query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
-out_flex = flex_attention(query, key, value, block_mask=block_mask)
-
 ###############################################################################
 # Packed Projection
 # -----------------
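
Note: only the FlexAttention + NJT composition section is removed here; FlexAttention with regular dense tensors is unaffected by pytorch/pytorch#161734. The snippet below is a minimal sketch, not part of this commit or the tutorial, of the dense-tensor path with a causal block mask. It assumes a recent PyTorch with ``torch.nn.attention.flex_attention``, and all shapes are made-up placeholders.

# Minimal sketch (not part of this commit): dense-tensor FlexAttention with a
# causal block mask. All shapes below are illustrative placeholders.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 8, 128, 64  # batch, heads, padded sequence length, head dim


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


query = torch.randn(B, H, S, D)
key = torch.randn(B, H, S, D)
value = torch.randn(B, H, S, D)

# Dense block mask over the padded length, in place of create_nested_block_mask.
block_mask = create_block_mask(causal_mask, B, H, S, S, device="cpu")

# flex_attention is typically wrapped in torch.compile for speed; eager mode is
# enough for a small example like this.
out = flex_attention(query, key, value, block_mask=block_mask)

Score modifications such as ALiBi can still be supplied through the ``score_mod`` argument exactly as in the deleted snippet; only the nested (NJT) inputs go away.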
