|
6 | 6 |
|
7 | 7 | import argparse |
8 | 8 | import os |
| 9 | +from typing import Optional |
9 | 10 |
|
10 | 11 | import torch |
| 12 | +import torch.nn as nn |
11 | 13 |
|
12 | 14 | from build.builder import ( |
13 | 15 | _initialize_model, |
|
20 | 22 |
|
21 | 23 | from build.utils import set_backend, set_precision |
22 | 24 | from cli import add_arguments_for_verb, arg_init, check_args |
23 | | -from export_util.export_aoti import export_model as export_model_aoti |
| 25 | + |
| 26 | +from torch.export import Dim |
24 | 27 |
|
25 | 28 | try: |
26 | 29 | executorch_export_available = True |
|
33 | 36 | default_device = "cpu" |
34 | 37 |
|
35 | 38 |
|
| 39 | +def export_for_server( |
| 40 | + model: nn.Module, device: Optional[str] = "cpu", output_path: str = "model.dso" |
| 41 | +) -> str: |
| 42 | + """ |
| 43 | + Export the model using AOT Compile to get a .dso for server use cases. |
| 44 | +
|
| 45 | + Args: |
| 46 | + model: The model to be exported. |
| 47 | + device: The device to run the model on. |
| 48 | + output_path: The path to save the exported model. |
| 49 | + Returns: |
| 50 | + The path to the exported model. |
| 51 | + """ |
| 52 | + max_seq_length = 350 |
| 53 | + |
| 54 | + input = ( |
| 55 | + torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), |
| 56 | + torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), |
| 57 | + ) |
| 58 | + |
| 59 | + seq = Dim("seq", min=1, max=max_seq_length) |
| 60 | + # Specify that the first dimension of each input is that batch size |
| 61 | + dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} |
| 62 | + |
| 63 | + model.to(device) |
| 64 | + so = torch._export.aot_compile( |
| 65 | + model, |
| 66 | + args=input, |
| 67 | + options={"aot_inductor.output_path": output_path}, |
| 68 | + dynamic_shapes=dynamic_shapes, |
| 69 | + ) |
| 70 | + print(f"The generated DSO model can be found at: {so}") |
| 71 | + return so |
| 72 | + |
| 73 | + |
36 | 74 | def main(args): |
37 | 75 | builder_args = BuilderArgs.from_args(args) |
38 | 76 | quantize = args.quantize |
@@ -107,7 +145,7 @@ def main(args): |
107 | 145 | if output_dso_path: |
108 | 146 | output_dso_path = str(os.path.abspath(output_dso_path)) |
109 | 147 | print(f"Exporting model using AOT Inductor to {output_dso_path}") |
110 | | - export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args) |
| 148 | + export_for_server(model_to_dso, builder_args.device, output_dso_path) |
111 | 149 |
|
112 | 150 |
|
113 | 151 | if __name__ == "__main__": |
|
0 commit comments