
Commit c533944

Update on "[llama-mm] Onboard Llama3.2 mm vision encoder"

Summary: Add the Llama 3.2 multimodal vision encoder to examples/models. We swap in an export-friendly TilePositionEmbedding module to make sure the vision encoder is exportable.

Test Plan: Unit tests.

[ghstack-poisoned]

2 parents: 2e77105 + 64ff85e
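The swap described in the summary follows the usual recursive module-replacement pattern. A minimal sketch, assuming a generic helper (the real replace_tile_positional_embedding in extension/llm/modules is specific to TilePositionEmbedding; the names below are illustrative):

    import torch.nn as nn

    def replace_submodules(model: nn.Module, target_cls: type, make_replacement) -> nn.Module:
        """Recursively swap every instance of target_cls in place."""
        for name, child in model.named_children():
            if isinstance(child, target_cls):
                # Build the export-friendly replacement from the old module.
                setattr(model, name, make_replacement(child))
            else:
                replace_submodules(child, target_cls, make_replacement)
        return model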

File tree: 4 files changed, +13 / -7 lines


.ci/scripts/gather_test_models.py (1 addition, 0 deletions)

@@ -24,6 +24,7 @@
     "ic4": "linux.12xlarge",
     "resnet50": "linux.12xlarge",
     "llava": "linux.12xlarge",
+    "llama3_2_vision_encoder": "linux.12xlarge",
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",

examples/models/llama3_2_vision/vision_encoder/model.py (5 additions, 2 deletions)

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass, field
+from typing import Optional

 import torch

@@ -41,8 +42,10 @@ class VisionEncoderConfig:


 class FlamingoVisionEncoderModel(EagerModelBase):
-    def __init__(self, config: VisionEncoderConfig = VisionEncoderConfig()):
+    def __init__(self, config: Optional[VisionEncoderConfig] = None):
         super().__init__()
+        if config is None:
+            config = VisionEncoderConfig()
         self.config = config
         self.model = flamingo_vision_encoder(
             patch_size=config.patch_size,
@@ -56,6 +59,7 @@ def __init__(self, config: VisionEncoderConfig = VisionEncoderConfig()):
             max_num_tiles=config.max_num_tiles,
             in_channels=config.in_channels,
         )
+        self.model = replace_tile_positional_embedding(self.model)
         self.image = torch.randn(
             1, 1, 4, 3, self.config.tile_size, self.config.tile_size
         )
@@ -66,7 +70,6 @@ def __init__(self, config: VisionEncoderConfig = VisionEncoderConfig()):
         )

     def get_eager_model(self, **kwargs):
-        self.model = replace_tile_positional_embedding(self.model)
         return self.model

     def get_example_inputs(self):
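The signature change avoids Python's mutable-default pitfall: a default like VisionEncoderConfig() is evaluated once, at function definition time, so every caller would share (and could mutate) a single config instance. A minimal sketch of the difference, using a made-up Config rather than the PR's dataclass:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class Config:
        tags: List[str] = field(default_factory=list)

    def shared_default(cfg: Config = Config()):  # one Config, built at definition time
        cfg.tags.append("x")
        return len(cfg.tags)

    def fresh_default(cfg: Optional[Config] = None):  # new Config per call
        if cfg is None:
            cfg = Config()
        cfg.tags.append("x")
        return len(cfg.tags)

    assert shared_default() == 1 and shared_default() == 2  # state leaks across calls
    assert fresh_default() == 1 and fresh_default() == 1    # independent instances

Moving replace_tile_positional_embedding into __init__ also means the swap happens exactly once at construction, instead of re-wrapping the model on every get_eager_model call.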

examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py (3 additions, 4 deletions)

@@ -14,17 +14,16 @@

 from executorch.examples.models.llama3_2_vision.vision_encoder import (
     FlamingoVisionEncoderModel,
-    VisionEncoderConfig,
 )
-from torch._inductor.package import load_package, package_aoti
+from torch._inductor.package import package_aoti


 class FlamingoVisionEncoderTest(unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()

     def test_flamingo_vision_encoder(self) -> None:
-        model = FlamingoVisionEncoderModel(VisionEncoderConfig())
+        model = FlamingoVisionEncoderModel()
         encoder = model.model
         eager_res = encoder.forward(*model.get_example_inputs())

@@ -38,7 +37,7 @@ def test_flamingo_vision_encoder(self) -> None:
         with tempfile.TemporaryDirectory() as tmpdir:
             path = package_aoti(os.path.join(tmpdir, "vision_encoder.pt2"), so)
             print(path)
-            encoder_aoti = load_package(path)
+            encoder_aoti = torch._inductor.aoti_load_package(path)

             y = encoder_aoti(*model.get_example_inputs())
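The load side moves from torch._inductor.package.load_package to the torch._inductor.aoti_load_package entry point. A hedged sketch of the resulting round trip, using only the calls the test itself exercises (these are private, version-dependent APIs; "vision_encoder.pt2" and example_inputs stand in for the test's fixtures):

    import torch

    # Load a previously packaged AOTInductor artifact...
    runner = torch._inductor.aoti_load_package("vision_encoder.pt2")
    # ...and call it like the original module.
    y = runner(*example_inputs)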

extension/llm/modules/_position_embeddings.py (4 additions, 1 deletion)

@@ -188,10 +188,13 @@ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
         torch._check(n_tiles_w >= 1)
         torch._check(n_tiles_h <= self.max_num_tiles)
         torch._check(n_tiles_w <= self.max_num_tiles)
+        # TODO: Remove this once pytorch/pytorch#120288 is fixed
         padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
         pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

-        # Add pos encoding to the non padded tiles.
+        # We need to do a clone here in order to make this model export
+        # friendly as the reshape is collapsing dim 0 and dim 1 into a
+        # single dim.
         pos_embed = pos_embed.clone()
         pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)
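A small standalone illustration of why the clone matters, with made-up sizes (the real code slices a learned embedding, and n_non_padded_tiles equals n_tiles_h * n_tiles_w): slicing produces a non-contiguous view, and cloning materializes it before the reshape merges dims 0 and 1, which keeps the operation export-friendly.

    import torch
    import torch.nn.functional as F

    max_num_tiles, embed_dim = 4, 8
    embedding = torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)

    n_tiles_h, n_tiles_w = 2, 3
    # Pad one extra row/column of tiles so the slice below stays in range
    # (the workaround referenced by the TODO above).
    padded = F.pad(embedding, (0, 0, 0, 0, 0, 1, 0, 1))  # -> (5, 5, 1, 8)
    pos_embed = padded[:n_tiles_h, :n_tiles_w, :, :]  # non-contiguous view
    pos_embed = pos_embed.clone()                     # materialize the view
    pos_embed = pos_embed.reshape(n_tiles_h * n_tiles_w, 1, embed_dim)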
