Skip to content

Commit c0225ef

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Make test precision stricter for Classification (#6380)
Summary: * Make test precision stricter for Classification * Update classification threshold. * Update quantized classification threshold. Reviewed By: datumbox Differential Revision: D38824223 fbshipit-source-id: aa5adbf9fa7d55c0343c97cbe162c40a7ca0f984
1 parent a48bd0c commit c0225ef

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/test_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def test_classification_model(model_fn, dev):
614614
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
615615
x = torch.rand(input_shape).to(device=dev)
616616
out = model(x)
617-
_assert_expected(out.cpu(), model_name, prec=0.1)
617+
_assert_expected(out.cpu(), model_name, prec=1e-3)
618618
assert out.shape[-1] == num_classes
619619
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
620620
_check_fx_compatible(model, x, eager_out=out)
@@ -841,7 +841,7 @@ def test_video_model(model_fn, dev):
841841
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
842842
x = torch.rand(input_shape).to(device=dev)
843843
out = model(x)
844-
_assert_expected(out.cpu(), model_name, prec=0.1)
844+
_assert_expected(out.cpu(), model_name, prec=1e-5)
845845
assert out.shape[-1] == num_classes
846846
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
847847
_check_fx_compatible(model, x, eager_out=out)
@@ -884,7 +884,7 @@ def test_quantized_classification_model(model_fn):
884884
out = model(x)
885885

886886
if model_name not in quantized_flaky_models:
887-
_assert_expected(out, model_name + "_quantized", prec=0.1)
887+
_assert_expected(out, model_name + "_quantized", prec=2e-2)
888888
assert out.shape[-1] == 5
889889
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
890890
_check_fx_compatible(model, x, eager_out=out)

0 commit comments

Comments
 (0)