File tree Expand file tree Collapse file tree 3 files changed +16
-8
lines changed Expand file tree Collapse file tree 3 files changed +16
-8
lines changed Original file line number Diff line number Diff line change 77from torch .ao .quantization .quantize_pt2e import prepare_pt2e , convert_pt2e
88from torch .ao .quantization .quantizer import QuantizationSpec , Quantizer
99
10- from torch ._export import capture_pre_autograd_graph
1110from torch .testing ._internal .common_quantization import (
1211 NodeSpec as ns ,
1312 QuantizationTestCase ,
2524 QuantizationAnnotation ,
2625)
2726import copy
27+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_5
2828
2929
3030def _apply_weight_only_uint4_quant (model ):
@@ -203,10 +203,16 @@ def forward(self, x):
203203
204204 # program capture
205205 m = copy .deepcopy (m_eager )
206- m = capture_pre_autograd_graph (
207- m ,
208- example_inputs ,
209- )
206+ if TORCH_VERSION_AT_LEAST_2_5 :
207+ m = torch .export .texport_for_training (
208+ m ,
209+ example_inputs ,
210+ ).module ()
211+ else :
212+ m = torch ._export .capture_pre_autograd_graph (
213+ m ,
214+ example_inputs ,
215+ ).module ()
210216
211217 m = prepare_pt2e (m , quantizer )
212218 # Calibrate
Original file line number Diff line number Diff line change @@ -1484,11 +1484,13 @@ def forward(self, x):
14841484
14851485 # make sure it compiles
14861486 example_inputs = (x ,)
1487- from torch ._export import capture_pre_autograd_graph
14881487 # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
14891488 # we can re-enable this after non-functional IR is enabled in export
14901489 # model = torch.export.export(model, example_inputs).module()
1491- model = capture_pre_autograd_graph (model , example_inputs )
1490+ if TORCH_VERSION_AT_LEAST_2_5 :
1491+ model = torch .export .export_for_training (model , example_inputs ).module ()
1492+ else :
1493+ model = torch ._export .capture_pre_autograd_graph (model , example_inputs )
14921494 after_export = model (x )
14931495 self .assertTrue (torch .equal (after_export , ref ))
14941496 if api is _int8da_int8w_api :
Original file line number Diff line number Diff line change @@ -180,7 +180,7 @@ def _the_op_that_needs_to_be_preserved(...)
180180
181181 # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
182182 # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
183- # torch.export.export / torch._export.capture_pre_autograd_graph
183+ # torch.export.export / torch._export.export_for_training
184184
185185 """
186186 from torch ._inductor .decomposition import register_decomposition
You can’t perform that action at this time.
0 commit comments