Skip to content

Commit d2982d5

Browse files
authored
training ir torchao migration
Differential Revision: D63859678 Pull Request resolved: #1006
1 parent 5dd0132 commit d2982d5

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

test/dtypes/test_uint4.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
88
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
99

10-
from torch._export import capture_pre_autograd_graph
1110
from torch.testing._internal.common_quantization import (
1211
NodeSpec as ns,
1312
QuantizationTestCase,
@@ -25,6 +24,7 @@
2524
QuantizationAnnotation,
2625
)
2726
import copy
27+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
2828

2929

3030
def _apply_weight_only_uint4_quant(model):
@@ -203,10 +203,16 @@ def forward(self, x):
203203

204204
# program capture
205205
m = copy.deepcopy(m_eager)
206-
m = capture_pre_autograd_graph(
207-
m,
208-
example_inputs,
209-
)
206+
if TORCH_VERSION_AT_LEAST_2_5:
207+
m = torch.export.texport_for_training(
208+
m,
209+
example_inputs,
210+
).module()
211+
else:
212+
m = torch._export.capture_pre_autograd_graph(
213+
m,
214+
example_inputs,
215+
).module()
210216

211217
m = prepare_pt2e(m, quantizer)
212218
# Calibrate

test/integration/test_integration.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,11 +1484,13 @@ def forward(self, x):
14841484

14851485
# make sure it compiles
14861486
example_inputs = (x,)
1487-
from torch._export import capture_pre_autograd_graph
14881487
# TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
14891488
# we can re-enable this after non-functional IR is enabled in export
14901489
# model = torch.export.export(model, example_inputs).module()
1491-
model = capture_pre_autograd_graph(model, example_inputs)
1490+
if TORCH_VERSION_AT_LEAST_2_5:
1491+
model = torch.export.export_for_training(model, example_inputs).module()
1492+
else:
1493+
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
14921494
after_export = model(x)
14931495
self.assertTrue(torch.equal(after_export, ref))
14941496
if api is _int8da_int8w_api:

torchao/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
180180
181181
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as
182182
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
183-
# torch.export.export / torch._export.capture_pre_autograd_graph
183+
# torch.export.export / torch._export.export_for_training
184184
185185
"""
186186
from torch._inductor.decomposition import register_decomposition

0 commit comments

Comments
 (0)