Skip to content

Commit 7711382

Browse files
committed
Revert "Fix quantization test import error"
This reverts commit 983cc1f.
1 parent 983cc1f commit 7711382

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/callbacks/test_quantization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818
import torch
19-
from torchmetrics.functional.regression.mean_absolute_percentage_error import mean_absolute_percentage_error
19+
from torchmetrics.functional import mean_relative_error
2020

2121
from pytorch_lightning import seed_everything, Trainer
2222
from pytorch_lightning.callbacks import QuantizationAwareTraining
@@ -42,7 +42,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
4242
trainer = Trainer(**trainer_args)
4343
trainer.fit(model, datamodule=dm)
4444
org_size = get_model_size_mb(model)
45-
org_score = torch.mean(torch.tensor([mean_absolute_percentage_error(model(x), y) for x, y in dm.test_dataloader()]))
45+
org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))
4646

4747
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
4848
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
@@ -51,7 +51,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
5151

5252
quant_calls = qcb._forward_calls
5353
assert quant_calls == qcb._forward_calls
54-
quant_score = torch.mean(torch.tensor([mean_absolute_percentage_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
54+
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
5555
# test that the test score is almost the same as with pure training
5656
assert torch.allclose(org_score, quant_score, atol=0.45)
5757
model_path = trainer.checkpoint_callback.best_model_path
@@ -70,7 +70,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
7070

7171
# todo: make it work also with strict loading
7272
qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
73-
quant2_score = torch.mean(torch.tensor([mean_absolute_percentage_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
73+
quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
7474
assert torch.allclose(org_score, quant2_score, atol=0.45)
7575

7676

0 commit comments

Comments
 (0)