
Commit b7f169f

[llama-mm] Onboard Llama3.2 mm vision encoder
Summary: Add llama3.2 mm vision encoder to examples/models. We need to do module swapping for TilePositionEmbedding to make sure the vision encoder is exportable.

Test Plan: Unit tests.

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: f996a61
Pull Request resolved: #6653
1 parent 5436d8a commit b7f169f
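
The module swap mentioned in the summary is the core of the change: torchtune's TilePositionalEmbedding is replaced with an export-friendly implementation before the encoder is exported. Below is a minimal sketch of that swap (illustrative only, not part of the commit), using the same builder call and default values that appear in model.py further down; it assumes torchtune and the executorch LLM extension modules are installed.

from executorch.extension.llm.modules._position_embeddings import (
    replace_tile_positional_embedding,
)
from torchtune.models.flamingo._component_builders import flamingo_vision_encoder

# Build the torchtune flamingo vision encoder with the defaults used by
# VisionEncoderConfig in model.py below.
encoder = flamingo_vision_encoder(
    patch_size=14,
    num_heads=16,
    clip_embed_dim=1280,
    clip_num_layers=32,
    clip_hidden_states=[3, 7, 15, 23, 30],
    decoder_embed_dim=4096,
    num_layers_projection=8,
    tile_size=560,
    max_num_tiles=4,
    in_channels=3,
)

# Swap torchtune's TilePositionalEmbedding modules for export-friendly ones so
# that torch.export can trace the encoder.
encoder = replace_tile_positional_embedding(encoder)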

File tree

7 files changed: +148 −1 lines changed


.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
     "ic4": "linux.12xlarge",
     "resnet50": "linux.12xlarge",
     "llava": "linux.12xlarge",
+    "llama3_2_vision_encoder": "linux.12xlarge",
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
     "llama2": ("llama", "Llama2Model"),
     "llama": ("llama", "Llama2Model"),
+    "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
     "lstm": ("lstm", "LSTMModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),

examples/models/llama3_2_vision/vision_encoder/__init__.py

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import FlamingoVisionEncoderModel, VisionEncoderConfig
+
+__all__ = [
+    "FlamingoVisionEncoderModel",
+    "VisionEncoderConfig",
+]

examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+
+from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules._position_embeddings import (
+    replace_tile_positional_embedding,
+)
+from torchtune.models.flamingo._component_builders import flamingo_vision_encoder
+
+max_seq_len = 8192
+in_channels = 3
+tile_size = 560
+max_num_tiles = 4
+# how many tokens per image generated by the vision encoder
+tokens_per_image = 6404
+# how many images to cache in the kv cache in cross attention
+kv_cache_image_num = 1
+# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention
+encoder_max_seq_len = tokens_per_image * kv_cache_image_num
+
+
+@dataclass
+class VisionEncoderConfig:
+    patch_size: int = 14
+    num_heads: int = 16
+    clip_embed_dim: int = 1280
+    clip_num_layers: int = 32
+    clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
+    decoder_embed_dim: int = 4096
+    num_layers_projection: int = 8
+    tile_size: int = 560
+    max_num_tiles: int = 4
+    in_channels: int = 3
+
+
+class FlamingoVisionEncoderModel(EagerModelBase):
+    def __init__(self, config: Optional[VisionEncoderConfig] = None):
+        super().__init__()
+        if config is None:
+            config = VisionEncoderConfig()
+        self.config = config
+        self.model = flamingo_vision_encoder(
+            patch_size=config.patch_size,
+            num_heads=config.num_heads,
+            clip_embed_dim=config.clip_embed_dim,
+            clip_num_layers=config.clip_num_layers,
+            clip_hidden_states=config.clip_hidden_states,
+            decoder_embed_dim=config.decoder_embed_dim,
+            num_layers_projection=config.num_layers_projection,
+            tile_size=config.tile_size,
+            max_num_tiles=config.max_num_tiles,
+            in_channels=config.in_channels,
+        )
+        self.model = replace_tile_positional_embedding(self.model)
+        self.image = torch.randn(
+            1, 1, 4, 3, self.config.tile_size, self.config.tile_size
+        )
+        self.aspect_ratio = torch.tensor([[[1, 2]]])
+        self.sample_inputs = (
+            self.image,
+            self.aspect_ratio,
+        )
+
+    def get_eager_model(self, **kwargs):
+        return self.model
+
+    def get_example_inputs(self):
+        return self.sample_inputs
+
+    def get_dynamic_shapes(self):
+        dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
+        image_dynamic_dim = {
+            0: 1,
+            1: 1,
+            2: dim,
+            3: 3,
+            4: self.config.tile_size,
+            5: self.config.tile_size,
+        }
+        return (image_dynamic_dim, None)
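
A minimal export sketch (not part of the commit) showing how the wrapper's example inputs and dynamic shapes could drive torch.export; it assumes torchtune is installed and relies only on the EagerModelBase-style accessors defined above.

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
)

model = FlamingoVisionEncoderModel()

# Dim 2 of the image input (num_tiles) is declared dynamic with min=1, max=4;
# every other dimension stays static, and the aspect_ratio input has no dynamic dims.
ep = torch.export.export(
    model.get_eager_model(),
    model.get_example_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)
print(ep)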

examples/models/llama3_2_vision/vision_encoder/test/__init__.py

Whitespace-only changes.
Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh.
+# Only test AOTI in this file
+import os
+import tempfile
+import unittest
+
+import torch
+
+from executorch.examples.models.llama3_2_vision.vision_encoder import (
+    FlamingoVisionEncoderModel,
+)
+from torch._inductor.package import package_aoti
+
+
+class FlamingoVisionEncoderTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def test_flamingo_vision_encoder(self) -> None:
+        model = FlamingoVisionEncoderModel()
+        encoder = model.model
+        eager_res = encoder.forward(*model.get_example_inputs())
+
+        # AOTI
+        so = torch._export.aot_compile(
+            encoder,
+            model.get_example_inputs(),
+            options={"aot_inductor.package": True},
+            dynamic_shapes=model.get_dynamic_shapes(),
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = package_aoti(os.path.join(tmpdir, "vision_encoder.pt2"), so)
+            print(path)
+            encoder_aoti = torch._inductor.aoti_load_package(path)
+
+            y = encoder_aoti(*model.get_example_inputs())
+
+        self.assertTrue(torch.allclose(y, eager_res))

pytest.ini

Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,8 @@ addopts =
     devtools/
     # examples
     examples/models/llama/tests
-    examples/models/llama3_2_vision/preprocess
+    examples/models/llama3_2_vision/preprocess/test
+    examples/models/llama3_2_vision/vision_encoder/test
     # examples/models/llava/test TODO: enable this
     # exir
     exir/_serialize/test
