batch_indices passed to PredictionWriter write_on_epoch_end is wrong when strategy=ddp #10782

@ozen

Description

🐛 Bug

Let's say we are using DDP with a single dataloader, the number of data points assigned to a process is 140, and the batch size is 64.

When the PredictionWriter's write_on_epoch_end is called in that process, the sizes of the predictions and batch_indices parameters are as follows:

len(predictions) == 1
len(predictions[0]) == 3
len(predictions[0][0]) == 64
len(predictions[0][1]) == 64
len(predictions[0][2]) == 12

len(batch_indices) == 1
len(batch_indices[0]) == 3
len(batch_indices[0][0]) == 12
len(batch_indices[0][1]) == 12
len(batch_indices[0][2]) == 12

Moreover, the contents of batch_indices[0][0], batch_indices[0][1], and batch_indices[0][2] are all identical. The predictions, by contrast, are correctly different.
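For reference, with 140 samples and a batch size of 64 one would expect three batches of 64, 64, and 12 indices. A plain-Python sketch of the expected structure (the helper name is hypothetical, not Lightning code):

```python
def expected_batch_indices(num_samples, batch_size):
    """Split range(num_samples) into consecutive batches, as a
    non-shuffling sampler would for a single dataloader."""
    indices = list(range(num_samples))
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]

batches = expected_batch_indices(140, 64)
print([len(b) for b in batches])  # [64, 64, 12]
```

Instead, every entry in batch_indices[0] has the size and contents of the final 12-element batch.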

The source of batch_indices is trainer.predict_loop.epoch_batch_indices which is populated here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/3089dc3829c6456d74ab95aef06891927519eab9/pytorch_lightning/loops/dataloader/prediction_loop.py#L91-L95

dl_batch_indices comes from PredictionEpochLoop._all_batch_indices:

https://github.com/PyTorchLightning/pytorch-lightning/blob/a002f872eaad5a0f8f2069c49295cdda3568a406/pytorch_lightning/loops/epoch/prediction_epoch_loop.py#L102-L109

which is filled in the _store_batch_indices method:

https://github.com/PyTorchLightning/pytorch-lightning/blob/a002f872eaad5a0f8f2069c49295cdda3568a406/pytorch_lightning/loops/epoch/prediction_epoch_loop.py#L163-L171

During debugging, I see that batch_sampler.batch_indices in fact holds the indices of the last batch every time this method is called.

I couldn't dig deeper, but my guess is that the sampler yields all batches before the program reaches _store_batch_indices, so only the final batch's indices are ever recorded.
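The hypothesized mechanism can be reproduced without Lightning. The class below is a toy stand-in (all names are illustrative) for a batch sampler that exposes its most recently yielded batch via a batch_indices attribute; if the consumer reads that attribute only after the sampler has been fully exhausted, every read sees the final batch:

```python
class ToyIndexBatchSampler:
    """Toy stand-in for a sampler wrapper that records the most
    recently yielded batch in `batch_indices`."""
    def __init__(self, indices, batch_size):
        self.indices = indices
        self.batch_size = batch_size
        self.batch_indices = None

    def __iter__(self):
        for i in range(0, len(self.indices), self.batch_size):
            # Update the attribute just before yielding each batch.
            self.batch_indices = self.indices[i:i + self.batch_size]
            yield self.batch_indices

# Buggy order: exhaust the sampler first, then read `batch_indices`
# once per batch -- every read sees the final 12-element batch.
sampler = ToyIndexBatchSampler(list(range(140)), 64)
batches = list(sampler)
stored_late = [sampler.batch_indices for _ in batches]
print([len(b) for b in stored_late])  # [12, 12, 12]

# Correct order: read `batch_indices` inside the loop, per batch.
sampler = ToyIndexBatchSampler(list(range(140)), 64)
stored_per_step = [sampler.batch_indices for _ in sampler]
print([len(b) for b in stored_per_step])  # [64, 64, 12]
```

This matches the observed symptom exactly: three stored entries, each of size 12 and with identical contents.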

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): 1.5.3
  • PyTorch Version (e.g., 1.10): 1.9.0
  • Python version (e.g., 3.9): 3.8
  • OS (e.g., Linux): Linux (Ubuntu)
  • CUDA/cuDNN version: 11.3
  • GPU models and configuration: 3x RTX 3080
  • How you installed PyTorch (conda, pip, source): Using NVIDIA's Docker Image
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:
