|
5 | 5 | # credit: |
6 | 6 | # Amin Rezaei (original author) |
7 | 7 | # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) |
| 8 | +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) |
8 | 9 | # implementation of: |
9 | 10 | # "Self-attention Does Not Need O(n^2) Memory": |
10 | 11 | # https://arxiv.org/abs/2112.05682v2 |
|
16 | 17 | import math |
17 | 18 | from typing import Optional, NamedTuple, Protocol, List |
18 | 19 |
|
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    """Like torch.narrow, but truncates the window instead of erroring
    when start + length runs past the end of dimension `dim`.

    Returns a narrowed view of `input` along `dim`, beginning at `start`
    and spanning at most `length` elements (fewer when the dimension is
    too short to supply the full window).
    """
    # Clamp the requested span to what actually remains past `start`.
    span = min(length, input.size(dim) - start)
    return torch.narrow(input, dim, start, span)
26 | 27 |
|
27 | 28 | class AttnChunk(NamedTuple): |
28 | 29 | exp_values: Tensor |
@@ -76,15 +77,17 @@ def _query_chunk_attention( |
76 | 77 | _, _, v_channels_per_head = value.shape |
77 | 78 |
|
78 | 79 | def chunk_scanner(chunk_idx: int) -> AttnChunk: |
79 | | - key_chunk = dynamic_slice( |
| 80 | + key_chunk = narrow_trunc( |
80 | 81 | key, |
81 | | - (0, chunk_idx, 0), |
82 | | - (batch_x_heads, kv_chunk_size, k_channels_per_head) |
| 82 | + 1, |
| 83 | + chunk_idx, |
| 84 | + kv_chunk_size |
83 | 85 | ) |
84 | | - value_chunk = dynamic_slice( |
| 86 | + value_chunk = narrow_trunc( |
85 | 87 | value, |
86 | | - (0, chunk_idx, 0), |
87 | | - (batch_x_heads, kv_chunk_size, v_channels_per_head) |
| 88 | + 1, |
| 89 | + chunk_idx, |
| 90 | + kv_chunk_size |
88 | 91 | ) |
89 | 92 | return summarize_chunk(query, key_chunk, value_chunk) |
90 | 93 |
|
@@ -161,10 +164,11 @@ def efficient_dot_product_attention( |
161 | 164 | kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) |
162 | 165 |
|
163 | 166 | def get_query_chunk(chunk_idx: int) -> Tensor: |
164 | | - return dynamic_slice( |
| 167 | + return narrow_trunc( |
165 | 168 | query, |
166 | | - (0, chunk_idx, 0), |
167 | | - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) |
| 169 | + 1, |
| 170 | + chunk_idx, |
| 171 | + min(query_chunk_size, q_tokens) |
168 | 172 | ) |
169 | 173 |
|
170 | 174 | summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) |
|
0 commit comments