Skip to content

Commit 8677cba

Browse files
committed
update ddp test
1 parent df167ab commit 8677cba

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/plugins/test_ddp_plugin.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from unittest import mock
15+
1416
import torch
1517
from torch.nn.parallel import DistributedDataParallel
1618

@@ -61,13 +63,16 @@ def on_train_start(self):
6163

6264

6365
@RunIf(min_gpus=4, special=True)
@mock.patch("torch.distributed.barrier")
def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir):
    """Test correct usage of barriers when device ids do not start at 0 or are not consecutive.

    The barrier mock lets us verify that the DDP plugin passes the locally mapped
    device id (not a default starting at 0) to ``torch.distributed.barrier``.
    """
    device_ids = [1, 3]
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        gpus=device_ids,
        accelerator="ddp",
    )
    trainer.fit(BoringModel())
    # Each rank must call the barrier with its own (non-consecutive) device id.
    barrier_mock.assert_any_call(device_ids=[device_ids[trainer.local_rank]])

0 commit comments

Comments (0)