@@ -64,22 +64,30 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
6464 strict = False ,
6565 )
6666
67- # aoti_path = torch._inductor.aot_compile(
68- # exported_model.module(),
69- # model.get_example_inputs(),
70- # )
67+ aoti_path = torch ._inductor .aot_compile (
68+ exported_model .module (),
69+ model .get_example_inputs (),
70+ )
7171
7272 edge_program = to_edge (
7373 exported_model , compile_config = EdgeCompileConfig (_check_ir_validity = False )
7474 )
7575 executorch_model = edge_program .to_executorch ()
7676
77+ # Re-export as ExecuTorch edits the ExportedProgram.
78+ exported_model = torch .export .export (
79+ model .get_eager_model (),
80+ model .get_example_inputs (),
81+ dynamic_shapes = model .get_dynamic_shapes (),
82+ strict = False ,
83+ )
84+
7785 return {
7886 "config" : config ,
7987 "reference_model" : reference_model ,
8088 "model" : model ,
8189 "exported_model" : exported_model ,
82- # "aoti_path": aoti_path,
90+ "aoti_path" : aoti_path ,
8391 "executorch_model" : executorch_model ,
8492 }
8593
@@ -265,11 +273,13 @@ def run_preprocess(
265273 ), f"Executorch model: expected { reference_ar } but got { et_ar .tolist ()} "
266274
267275 # Run aoti model and check it matches reference model.
268- # aoti_path = models["aoti_path"]
269- # aoti_model = torch._export.aot_load(aoti_path, "cpu")
270- # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
271- # self.assertTrue(torch.allclose(reference_image, aoti_image))
272- # self.assertEqual(reference_ar, aoti_ar.tolist())
276+ aoti_path = models ["aoti_path" ]
277+ aoti_model = torch ._export .aot_load (aoti_path , "cpu" )
278+ aoti_image , aoti_ar = aoti_model (image_tensor , inscribed_size , best_resolution )
279+ assert torch .allclose (reference_image , aoti_image )
280+ assert (
281+ reference_ar == aoti_ar .tolist ()
282+ ), f"AOTI model: expected { reference_ar } but got { aoti_ar .tolist ()} "
273283
274284 # This test setup mirrors the one in torchtune:
275285 # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
0 commit comments