Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 35db038

Browse files
authored
Move export_aoti into export.py (#983)
* Move export_aoti into export + minor tidyness * Lint * Remove mismatched arg
1 parent e2721d2 commit 35db038

File tree

2 files changed

+40
-37
lines changed

2 files changed

+40
-37
lines changed

export.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
import argparse
88
import os
9+
from typing import Optional
910

1011
import torch
12+
import torch.nn as nn
1113

1214
from build.builder import (
1315
_initialize_model,
@@ -20,7 +22,8 @@
2022

2123
from build.utils import set_backend, set_precision
2224
from cli import add_arguments_for_verb, arg_init, check_args
23-
from export_util.export_aoti import export_model as export_model_aoti
25+
26+
from torch.export import Dim
2427

2528
try:
2629
executorch_export_available = True
@@ -33,6 +36,41 @@
3336
default_device = "cpu"
3437

3538

39+
def export_for_server(
40+
model: nn.Module, device: Optional[str] = "cpu", output_path: str = "model.dso"
41+
) -> str:
42+
"""
43+
Export the model using AOT Compile to get a .dso for server use cases.
44+
45+
Args:
46+
model: The model to be exported.
47+
device: The device to run the model on.
48+
output_path: The path to save the exported model.
49+
Returns:
50+
The path to the exported model.
51+
"""
52+
max_seq_length = 350
53+
54+
input = (
55+
torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
56+
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
57+
)
58+
59+
seq = Dim("seq", min=1, max=max_seq_length)
60+
# Specify that the first dimension of each input is that batch size
61+
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
62+
63+
model.to(device)
64+
so = torch._export.aot_compile(
65+
model,
66+
args=input,
67+
options={"aot_inductor.output_path": output_path},
68+
dynamic_shapes=dynamic_shapes,
69+
)
70+
print(f"The generated DSO model can be found at: {so}")
71+
return so
72+
73+
3674
def main(args):
3775
builder_args = BuilderArgs.from_args(args)
3876
quantize = args.quantize
@@ -107,7 +145,7 @@ def main(args):
107145
if output_dso_path:
108146
output_dso_path = str(os.path.abspath(output_dso_path))
109147
print(f"Exporting model using AOT Inductor to {output_dso_path}")
110-
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
148+
export_for_server(model_to_dso, builder_args.device, output_dso_path)
111149

112150

113151
if __name__ == "__main__":

export_util/export_aoti.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

0 commit comments

Comments
 (0)