 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 import unittest
 
-from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 import PIL
 import torch
 
+# Import these first. Otherwise, the custom ops are not registered.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
-from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
-    export_preprocess,
-    get_example_inputs,
-    lower_to_executorch_preprocess,
+from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa # usort: skip
+
+from executorch.examples.models.llama3_2_vision.preprocess.model import (
+    CLIPImageTransformModel,
+    PreprocessConfig,
 )
+
+from executorch.exir import EdgeCompileConfig, to_edge
+
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
 
-from parameterized import parameterized
 from PIL import Image
 
-from torchtune.models.clip.inference._transform import (
-    _CLIPImageTransform,
-    CLIPImageTransform,
-)
+from torchtune.models.clip.inference._transform import CLIPImageTransform
 
 from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
     find_supported_resolutions,
 from torchvision.transforms.v2 import functional as F
 
 
-@dataclass
-class PreprocessConfig:
-    image_mean: Optional[List[float]] = None
-    image_std: Optional[List[float]] = None
-    resize_to_max_canvas: bool = True
-    resample: str = "bilinear"
-    antialias: bool = False
-    tile_size: int = 224
-    max_num_tiles: int = 4
-    possible_resolutions = None
-
-
 class TestImageTransform(unittest.TestCase):
     """
     This unittest checks that the exported image transform model produces the
@@ -66,6 +52,53 @@ class TestImageTransform(unittest.TestCase):
     https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
     """
 
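+    # Builds every variant of the preprocess model under test: the torchtune
+    # reference, the eager model, the torch.export program, and the lowered
+    # ExecuTorch program. A @staticmethod, so setUpClass can call it before
+    # any test instance exists.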
+    @staticmethod
+    def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
+        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
+
+        reference_model = CLIPImageTransform(
+            image_mean=config.image_mean,
+            image_std=config.image_std,
+            resize_to_max_canvas=config.resize_to_max_canvas,
+            resample=config.resample,
+            antialias=config.antialias,
+            tile_size=config.tile_size,
+            max_num_tiles=config.max_num_tiles,
+            possible_resolutions=None,
+        )
+
+        model = CLIPImageTransformModel(config)
+
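+        # strict=False selects non-strict export, which traces the model with
+        # the Python interpreter instead of TorchDynamo.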
+        exported_model = torch.export.export(
+            model.get_eager_model(),
+            model.get_example_inputs(),
+            dynamic_shapes=model.get_dynamic_shapes(),
+            strict=False,
+        )
+
+        # aoti_path = torch._inductor.aot_compile(
+        #     exported_model.module(),
+        #     model.get_example_inputs(),
+        # )
+
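+        # Lower to the Edge dialect. IR validity checking is disabled, likely
+        # because the graph carries custom ops (see the op_tile_crop_aot
+        # import above) that fall outside the core ATen opset.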
+        edge_program = to_edge(
+            exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
+        )
+        executorch_model = edge_program.to_executorch()
+
+        return {
+            "config": config,
+            "reference_model": reference_model,
+            "model": model,
+            "exported_model": exported_model,
+            # "aoti_path": aoti_path,
+            "executorch_model": executorch_model,
+        }
+
+    @classmethod
+    def setUpClass(cls):
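+        # Exporting and lowering are expensive, so build the models once per
+        # resize configuration and share them across all tests.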
+        cls.models_no_resize = cls.initialize_models(resize_to_max_canvas=False)
+        cls.models_resize = cls.initialize_models(resize_to_max_canvas=True)
+
     def setUp(self):
         np.random.seed(0)
 
@@ -121,51 +154,7 @@ def prepare_inputs(
 
         return image_tensor, inscribed_size, best_resolution
 
-    # This test setup mirrors the one in torchtune:
-    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
-    # The values are slightly different, as torchtune uses antialias=True,
-    # and this test uses antialias=False, which is exportable (has a portable kernel).
-    @parameterized.expand(
-        [
-            (
-                (100, 400, 3),  # image_size
-                torch.Size([2, 3, 224, 224]),  # expected shape
-                False,  # resize_to_max_canvas
-                [0.2230, 0.1763],  # expected_tile_means
-                [1.0, 1.0],  # expected_tile_max
-                [0.0, 0.0],  # expected_tile_min
-                [1, 2],  # expected_aspect_ratio
-            ),
-            (
-                (1000, 300, 3),  # image_size
-                torch.Size([4, 3, 224, 224]),  # expected shape
-                True,  # resize_to_max_canvas
-                [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
-                [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
-                [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
-                [4, 1],  # expected_aspect_ratio
-            ),
-            (
-                (200, 200, 3),  # image_size
-                torch.Size([4, 3, 224, 224]),  # expected shape
-                True,  # resize_to_max_canvas
-                [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
-                [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
-                [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
-                [2, 2],  # expected_aspect_ratio
-            ),
-            (
-                (600, 200, 3),  # image_size
-                torch.Size([3, 3, 224, 224]),  # expected shape
-                False,  # resize_to_max_canvas
-                [0.4472, 0.4468, 0.3031],  # expected_tile_means
-                [1.0, 1.0, 1.0],  # expected_tile_max
-                [0.0, 0.0, 0.0],  # expected_tile_min
-                [3, 1],  # expected_aspect_ratio
-            ),
-        ]
-    )
-    def test_preprocess(
+    def run_preprocess(
         self,
         image_size: Tuple[int],
         expected_shape: torch.Size,
@@ -175,45 +164,7 @@ def test_preprocess(
         expected_tile_min: List[float],
         expected_ar: List[int],
     ) -> None:
-        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
-
-        reference_model = CLIPImageTransform(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resize_to_max_canvas=config.resize_to_max_canvas,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-            possible_resolutions=None,
-        )
-
-        eager_model = _CLIPImageTransform(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-        )
-
-        exported_model = export_preprocess(
-            image_mean=config.image_mean,
-            image_std=config.image_std,
-            resample=config.resample,
-            antialias=config.antialias,
-            tile_size=config.tile_size,
-            max_num_tiles=config.max_num_tiles,
-        )
-
-        executorch_model = lower_to_executorch_preprocess(exported_model)
-        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
-
-        aoti_path = torch._inductor.aot_compile(
-            exported_model.module(),
-            get_example_inputs(),
-        )
-
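+        # Select the cached models that were built with the matching resize
+        # configuration; each stage below is checked against the torchtune
+        # reference output.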
+        models = self.models_resize if resize_to_max_canvas else self.models_no_resize
         # Prepare image input.
         image = (
             np.random.randint(0, 256, np.prod(image_size))
@@ -223,6 +174,7 @@ def test_preprocess(
         image = PIL.Image.fromarray(image)
 
         # Run reference model.
+        reference_model = models["reference_model"]
         reference_output = reference_model(image=image)
         reference_image = reference_output["image"]
         reference_ar = reference_output["aspect_ratio"].tolist()
@@ -249,10 +201,11 @@ def test_preprocess(
         # Pre-work for eager and exported models. The reference model performs these
         # calculations and passes the result to _CLIPImageTransform, the exportable model.
         image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
-            image=image, config=config
+            image=image, config=models["config"]
         )
 
         # Run eager model and check it matches reference model.
+        eager_model = models["model"].get_eager_model()
         eager_image, eager_ar = eager_model(
             image_tensor, inscribed_size, best_resolution
         )
@@ -261,6 +214,7 @@ def test_preprocess(
         self.assertEqual(reference_ar, eager_ar)
 
         # Run exported model and check it matches reference model.
+        exported_model = models["exported_model"]
         exported_image, exported_ar = exported_model.module()(
             image_tensor, inscribed_size, best_resolution
         )
@@ -269,14 +223,65 @@ def test_preprocess(
         self.assertEqual(reference_ar, exported_ar)
 
         # Run executorch model and check it matches reference model.
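+        # The lowered program is held as a serialized buffer; load it into the
+        # pybindings runtime before calling forward().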
+        executorch_model = models["executorch_model"]
+        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
         et_image, et_ar = executorch_module.forward(
             (image_tensor, inscribed_size, best_resolution)
         )
         self.assertTrue(torch.allclose(reference_image, et_image))
         self.assertEqual(reference_ar, et_ar.tolist())
 
         # Run aoti model and check it matches reference model.
-        aoti_model = torch._export.aot_load(aoti_path, "cpu")
-        aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
-        self.assertTrue(torch.allclose(reference_image, aoti_image))
-        self.assertEqual(reference_ar, aoti_ar.tolist())
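+        # The AOTI path is disabled in this version of the test; the
+        # commented-out check is kept for reference.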
+        # aoti_path = models["aoti_path"]
+        # aoti_model = torch._export.aot_load(aoti_path, "cpu")
+        # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
+        # self.assertTrue(torch.allclose(reference_image, aoti_image))
+        # self.assertEqual(reference_ar, aoti_ar.tolist())
+
+    # This test setup mirrors the one in torchtune:
+    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
+    # The values are slightly different, as torchtune uses antialias=True,
+    # and this test uses antialias=False, which is exportable (has a portable kernel).
+    def test_preprocess1(self):
+        self.run_preprocess(
+            (100, 400, 3),  # image_size
+            torch.Size([2, 3, 224, 224]),  # expected shape
+            False,  # resize_to_max_canvas
+            [0.2230, 0.1763],  # expected_tile_means
+            [1.0, 1.0],  # expected_tile_max
+            [0.0, 0.0],  # expected_tile_min
+            [1, 2],  # expected_aspect_ratio
+        )
+
+    def test_preprocess2(self):
+        self.run_preprocess(
+            (1000, 300, 3),  # image_size
+            torch.Size([4, 3, 224, 224]),  # expected shape
+            True,  # resize_to_max_canvas
+            [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
+            [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
+            [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
+            [4, 1],  # expected_aspect_ratio
+        )
+
+    def test_preprocess3(self):
+        self.run_preprocess(
+            (200, 200, 3),  # image_size
+            torch.Size([4, 3, 224, 224]),  # expected shape
+            True,  # resize_to_max_canvas
+            [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
+            [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
+            [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
+            [2, 2],  # expected_aspect_ratio
+        )
+
+    def test_preprocess4(self):
+        self.run_preprocess(
+            (600, 200, 3),  # image_size
+            torch.Size([3, 3, 224, 224]),  # expected shape
+            False,  # resize_to_max_canvas
+            [0.4472, 0.4468, 0.3031],  # expected_tile_means
+            [1.0, 1.0, 1.0],  # expected_tile_max
+            [0.0, 0.0, 0.0],  # expected_tile_min
+            [3, 1],  # expected_aspect_ratio
+        )