File tree: 1 file changed, +8 −3 lines changed

Diff (original line numbers on the left, new on the right):
11  11   # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  12   # See the License for the specific language governing permissions and
13  13   # limitations under the License.
    14  +  from unittest import mock
    15  +
14  16   import torch
15  17   from torch.nn.parallel import DistributedDataParallel
16  18
@@ -61,13 +63,16 @@ def on_train_start(self):


 @RunIf(min_gpus=4, special=True)
-def test_ddp_barrier_non_consecutive_device_ids(tmpdir):
-
+@mock.patch("torch.distributed.barrier")
+def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir):
+    """ Test correct usage of barriers when device ids do not start at 0 or are not consecutive. """
     model = BoringModel()
+    gpus = [1, 3]
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=1,
-        gpus=[1, 3],
+        gpus=gpus,
         accelerator="ddp",
     )
     trainer.fit(model)
+    barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]])
You can’t perform that action at this time.
0 commit comments