diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 740ba073f329..344ed84034db 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -40,6 +40,7 @@
 )
 from transformers.testing_utils import (
     CaptureLogger,
+    is_flaky,
     require_accelerate,
     require_flash_attn,
     require_flash_attn_3,
@@ -732,6 +733,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
                 self._check_generate_outputs(output, model.config, use_cache=True)
 
     @pytest.mark.generate
+    @is_flaky()  # `is_flaky` is a decorator factory: it must be called, else the test body never runs
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
         # This test is mostly a copy of test_assisted_decoding_matches_greedy_search