Commit b0d9e3f

Update on "[llama-mm] Add export-friendly tile position embedding"
Summary: Before a decision is made on whether torchtune takes this export-friendly version of `TilePositionEmbedding`, we put it under `extension/llm` so that users can start using it. Unit tests were added to make sure the behavior matches the reference implementation in torchtune and that export, AOTI, and ET all work properly.
1 parent df66f00 commit b0d9e3f

File tree

1 file changed (+4, -1 lines)


extension/llm/modules/_position_embeddings.py

Lines changed: 4 additions & 1 deletion
@@ -188,10 +188,13 @@ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         torch._check(n_tiles_w >= 1)
         torch._check(n_tiles_h <= self.max_num_tiles)
         torch._check(n_tiles_w <= self.max_num_tiles)
+        # TODO: Remove this once pytorch/pytorch#120288 is fixed
         padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
         pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
 
-        # Add pos encoding to the non padded tiles.
+        # We need to do a clone here in order to make this model export
+        # friendly as the reshape is collapsing dim 0 and dim 1 into a
+        # single dim.
         pos_embed = pos_embed.clone()
         pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)

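For context, below is a minimal, self-contained sketch of the pattern this change relies on: bound the data-dependent tile counts with torch._check, pad the learned embedding so the slice stays in bounds, and clone() the slice before the reshape that collapses dim 0 and dim 1. This is an illustration only, not the actual torchtune/ExecuTorch module; the class name TilePositionEmbeddingSketch, the (bsz, max_num_tiles, n_tokens, embed_dim) layout of x, and the (rows, cols) format of aspect_ratio are assumptions made for the example.

# Minimal sketch of an export-friendly tile position embedding.
# Assumed names and shapes are illustrative, not the real implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TilePositionEmbeddingSketch(nn.Module):
    def __init__(self, max_num_tiles: int, embed_dim: int):
        super().__init__()
        self.max_num_tiles = max_num_tiles
        self.embed_dim = embed_dim
        # Learned per-tile positional embedding, one vector per (row, col) tile slot.
        self.embedding = nn.Parameter(
            torch.zeros(max_num_tiles, max_num_tiles, 1, embed_dim)
        )

    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
        # Assumption: aspect_ratio holds the tile grid of the current image as (rows, cols).
        n_tiles_h = aspect_ratio[0].item()
        n_tiles_w = aspect_ratio[1].item()
        # Give the exporter bounds on the data-dependent tile counts.
        torch._check(n_tiles_h >= 1)
        torch._check(n_tiles_w >= 1)
        torch._check(n_tiles_h <= self.max_num_tiles)
        torch._check(n_tiles_w <= self.max_num_tiles)
        n_non_padded_tiles = n_tiles_h * n_tiles_w

        # Pad dims 0 and 1 by one slot so the slice below stays in bounds
        # (the workaround the TODO in the diff refers to).
        padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
        pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

        # clone() materializes the slice so the reshape that collapses
        # dim 0 and dim 1 does not operate on a strided view with
        # data-dependent sizes; this keeps the graph export-friendly.
        pos_embed = pos_embed.clone()
        pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)

        # Add the encoding to the non-padded tiles (broadcast over batch and tokens);
        # clone x first so the caller's tensor is not mutated in place.
        x = x.clone()
        x[:, :n_non_padded_tiles] = x[:, :n_non_padded_tiles] + pos_embed
        return x


# Example usage (eager): a 2x2 tile grid out of a maximum 4x4 grid.
emb = TilePositionEmbeddingSketch(max_num_tiles=4, embed_dim=8)
x = torch.randn(1, 16, 5, 8)
out = emb(x, torch.tensor([2, 2]))

The design point, per the comment added in the diff, is that the slice of the padded embedding is a view whose sizes are data-dependent; cloning it before the reshape that merges dim 0 and dim 1 gives export a plain contiguous tensor to reshape instead of a strided view.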