1616
1717import pytest
1818import torch
19- from torchmetrics .functional . regression . mean_absolute_percentage_error import mean_absolute_percentage_error
19+ from torchmetrics .functional import mean_relative_error
2020
2121from pytorch_lightning import seed_everything , Trainer
2222from pytorch_lightning .callbacks import QuantizationAwareTraining
@@ -42,7 +42,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
4242 trainer = Trainer (** trainer_args )
4343 trainer .fit (model , datamodule = dm )
4444 org_size = get_model_size_mb (model )
45- org_score = torch .mean (torch .tensor ([mean_absolute_percentage_error (model (x ), y ) for x , y in dm .test_dataloader ()]))
45+ org_score = torch .mean (torch .tensor ([mean_relative_error (model (x ), y ) for x , y in dm .test_dataloader ()]))
4646
4747 fusing_layers = [(f"layer_{ i } " , f"layer_{ i } a" ) for i in range (3 )] if fuse else None
4848 qcb = QuantizationAwareTraining (observer_type = observe , modules_to_fuse = fusing_layers , quantize_on_fit_end = convert )
@@ -51,7 +51,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
5151
5252 quant_calls = qcb ._forward_calls
5353 assert quant_calls == qcb ._forward_calls
54- quant_score = torch .mean (torch .tensor ([mean_absolute_percentage_error (qmodel (x ), y ) for x , y in dm .test_dataloader ()]))
54+ quant_score = torch .mean (torch .tensor ([mean_relative_error (qmodel (x ), y ) for x , y in dm .test_dataloader ()]))
5555 # test that the test score is almost the same as with pure training
5656 assert torch .allclose (org_score , quant_score , atol = 0.45 )
5757 model_path = trainer .checkpoint_callback .best_model_path
@@ -70,7 +70,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
7070
7171 # todo: make it work also with strict loading
7272 qmodel2 = RegressionModel .load_from_checkpoint (model_path , strict = False )
73- quant2_score = torch .mean (torch .tensor ([mean_absolute_percentage_error (qmodel2 (x ), y ) for x , y in dm .test_dataloader ()]))
73+ quant2_score = torch .mean (torch .tensor ([mean_relative_error (qmodel2 (x ), y ) for x , y in dm .test_dataloader ()]))
7474 assert torch .allclose (org_score , quant2_score , atol = 0.45 )
7575
7676
0 commit comments