2828from pytorch_lightning .accelerators import CPUAccelerator
2929from pytorch_lightning .metrics .classification .accuracy import Accuracy
3030from pytorch_lightning .trainer .states import TrainerState
31- from pytorch_lightning .utilities import _APEX_AVAILABLE , _HOROVOD_AVAILABLE , _NATIVE_AMP_AVAILABLE
31+ from pytorch_lightning .utilities import _HOROVOD_AVAILABLE
3232from tests .helpers import BoringModel
3333from tests .helpers .advanced_models import BasicGAN
3434from tests .helpers .runif import RunIf
@@ -120,8 +120,7 @@ def test_horovod_multi_gpu(tmpdir):
120120
121121@pytest .mark .skip (reason = "Horovod has a problem with broadcast when using apex?" )
122122@pytest .mark .skipif (not _HOROVOD_NCCL_AVAILABLE , reason = "test requires Horovod with NCCL support" )
123- @RunIf (min_gpus = 2 , skip_windows = True )
124- @pytest .mark .skipif (not _APEX_AVAILABLE , reason = "test requires apex" )
123+ @RunIf (min_gpus = 2 , skip_windows = True , amp_apex = True )
125124def test_horovod_apex (tmpdir ):
126125 """Test Horovod with multi-GPU support using apex amp."""
127126 trainer_options = dict (
@@ -143,8 +142,7 @@ def test_horovod_apex(tmpdir):
143142
144143@pytest .mark .skip (reason = "Skip till Horovod fixes integration with Native torch.cuda.amp" )
145144@pytest .mark .skipif (not _HOROVOD_NCCL_AVAILABLE , reason = "test requires Horovod with NCCL support" )
146- @RunIf (min_gpus = 2 , skip_windows = True )
147- @pytest .mark .skipif (not _NATIVE_AMP_AVAILABLE , reason = "test requires torch.cuda.amp" )
145+ @RunIf (min_gpus = 2 , skip_windows = True , amp_native = True )
148146def test_horovod_amp (tmpdir ):
149147 """Test Horovod with multi-GPU support using native amp."""
150148 trainer_options = dict (
0 commit comments