|
@@ -26,6 +26,9 @@
 from executorch.exir._serialize._program import deserialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
+    ExecutorBackendPartitioner,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.error import InternalError
|
@@ -60,7 +63,7 @@
 from functorch.experimental import control_flow
 from torch import nn

-from torch.export import Dim, export
+from torch.export import Dim, export, export_for_training


 class WrapperModule(torch.nn.Module):
@@ -1660,3 +1663,59 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
+
+    def test_delegate_deduplicate(self) -> None:
+        class SharedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class Module1(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        class Module2(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
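+        # Both wrappers invoke the same SharedModule instance, so their
+        # exported graphs are identical.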
+        shared_module = SharedModule()
+        module_1 = Module1(shared_module)
+        module_2 = Module2(shared_module)
+        example_inputs = (torch.randn(2, 2),)
+        module_1(*example_inputs)
+        module_2(*example_inputs)
+
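+        # Capture each wrapper as a separate exported program.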
+        ep1 = export_for_training(module_1, example_inputs)
+        ep2 = export_for_training(module_2, example_inputs)
+
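+        # Build a single multi-method edge program from both exports.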
+        edge_program_manager = exir.to_edge(
+            {"forward1": ep1, "forward2": ep2},
+            compile_config=exir.EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+
+        edge_program_manager = edge_program_manager.to_backend(
+            ExecutorBackendPartitioner()
+        ).to_executorch()
+
+        # Expect a single delegate data blob: the two methods are identical,
+        # so their lowered payloads are deduplicated.
+        self.assertEqual(
+            len(edge_program_manager.executorch_program.backend_delegate_data), 1
+        )
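
The deduplication the new test asserts can be pictured with a small standalone sketch. This is illustrative only, not ExecuTorch's actual serialization code: it just shows the idea that byte-identical delegate payloads are stored once and referenced by index.

# Illustrative sketch only -- not the actual ExecuTorch implementation.
from typing import Dict, List, Tuple


def dedupe_delegate_payloads(
    payloads: List[bytes],
) -> Tuple[List[bytes], List[int]]:
    table: List[bytes] = []  # unique blobs, in first-seen order
    index_of: Dict[bytes, int] = {}  # blob -> position in `table`
    indices: List[int] = []  # per-method reference into `table`
    for blob in payloads:
        if blob not in index_of:
            index_of[blob] = len(table)
            table.append(blob)
        indices.append(index_of[blob])
    return table, indices


# Two methods that lower to the same bytes share one table entry:
table, indices = dedupe_delegate_payloads([b"payload", b"payload"])
assert len(table) == 1 and indices == [0, 0]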