|
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | import torch._export |
16 | | -from executorch.exir._serialize import _serialize_pte_binary |
17 | 16 | from executorch.exir._serialize._cord import Cord |
| 17 | +from executorch.exir._serialize._serialize import serialize |
| 18 | +from executorch.exir._serialize.data_serializer import DataSerializer |
18 | 19 | from executorch.exir._warnings import experimental |
19 | 20 | from executorch.exir.backend.backend_api import to_backend |
20 | 21 | from executorch.exir.backend.partitioner import Partitioner |
|
56 | 57 | EXIREdgeDialectVerifier, |
57 | 58 | get_aten_verifier, |
58 | 59 | ) |
| 60 | +from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer |
59 | 61 | from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass |
60 | 62 | from torch.export import ExportedProgram |
61 | 63 | from torch.export._remove_auto_functionalized_pass import ( |
@@ -494,23 +496,23 @@ def __init__( |
494 | 496 | ) |
495 | 497 | self.exported_program = exir_exported_program.exported_program |
496 | 498 | self._pte_data: Optional[Cord] = None |
| 499 | + self._data_files: Optional[Dict[str, Cord]] = None |
497 | 500 | self._buffer: Optional[bytes] = None |
498 | 501 | self._emitter_output: Optional[EmitterOutput] = None |
499 | 502 | self._emit_stacktrace: bool = emit_stacktrace |
500 | 503 | self._extract_delegate_segments: bool = extract_delegate_segments |
501 | 504 | self._segment_alignment: int = segment_alignment |
502 | 505 | self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment |
503 | 506 | self._delegate_alignment: Optional[int] = delegate_alignment |
| 507 | + self._data_serializer: DataSerializer = FlatTensorSerializer() |
504 | 508 |
|
505 | 509 | def _get_pte_data(self) -> Cord: |
506 | 510 | if self._pte_data is None: |
507 | | - self._pte_data = _serialize_pte_binary( |
508 | | - program=self.program, |
509 | | - extract_delegate_segments=self._extract_delegate_segments, |
510 | | - segment_alignment=self._segment_alignment, |
511 | | - constant_tensor_alignment=self._constant_tensor_alignment, |
512 | | - delegate_alignment=self._delegate_alignment, |
| 511 | + assert self._emitter_output is not None |
| 512 | + self._pte_data, self._data_files = serialize( |
| 513 | + self._emitter_output, ExecutorchBackendConfig(), self._data_serializer |
513 | 514 | ) |
| 515 | + assert self._pte_data is not None |
514 | 516 | return self._pte_data |
515 | 517 |
|
516 | 518 | @property |
@@ -1443,14 +1445,11 @@ def __init__( |
1443 | 1445 | self._config_methods, |
1444 | 1446 | ) |
1445 | 1447 |
|
| 1448 | + self._data_serializer = FlatTensorSerializer() |
| 1449 | + |
1446 | 1450 | # Serialize emitter output, ready to be written to a file. |
1447 | | - self._pte_data: Cord = _serialize_pte_binary( |
1448 | | - program=self._emitter_output.program, |
1449 | | - mutable_data=self._emitter_output.mutable_data, |
1450 | | - extract_delegate_segments=backend_config.extract_delegate_segments, |
1451 | | - segment_alignment=backend_config.segment_alignment, |
1452 | | - constant_tensor_alignment=backend_config.constant_tensor_alignment, |
1453 | | - delegate_alignment=backend_config.delegate_alignment, |
| 1451 | + self._pte_data, self._data_files = serialize( |
| 1452 | + self._emitter_output, ExecutorchBackendConfig(), self._data_serializer |
1454 | 1453 | ) |
1455 | 1454 | self._buffer: Optional[bytes] = None |
1456 | 1455 |
|
@@ -1532,3 +1531,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None: |
1532 | 1531 | reducing the peak memory usage. |
1533 | 1532 | """ |
1534 | 1533 | self._pte_data.write_to_file(open_file) |
| 1534 | + |
| 1535 | + for filename, cord in self._data_files.items(): |
| 1536 | + filename = filename + ".ptd" |
| 1537 | + with open(filename, "wb") as file: |
| 1538 | + cord.write_to_file(file) |
0 commit comments