🐛 Bug
Let's say we are using DDP with a single dataloader, the number of data points on a given process is 140, and the batch size is 64.
When the PredictionWriter's write_on_epoch_end is called on that process, the sizes of the predictions and batch_indices parameters are as follows:
len(predictions) == 1
len(predictions[0]) == 3
len(predictions[0][0]) == 64
len(predictions[0][1]) == 64
len(predictions[0][2]) == 12
len(batch_indices) == 1
len(batch_indices[0]) == 3
len(batch_indices[0][0]) == 12
len(batch_indices[0][1]) == 12
len(batch_indices[0][2]) == 12
Also, the contents of batch_indices[0][0], batch_indices[0][1], and batch_indices[0][2] are identical. The predictions are correctly different. A minimal reproduction sketch is given below.
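Here is a hypothetical minimal reproduction sketch; the model, dataset, and writer below are my own illustration, with sizes chosen so that 420 samples split across 3 DDP processes give 140 samples per process (per-process batches of 64, 64, 12):

```python
# Hypothetical reproduction sketch, not the original code from my project.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter


class PrintingWriter(BasePredictionWriter):
    def __init__(self):
        super().__init__(write_interval="epoch")

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # Expected: [64, 64, 12]; observed: [12, 12, 12] with identical contents.
        print([len(b) for b in batch_indices[0]])
        print(batch_indices[0][0] == batch_indices[0][1] == batch_indices[0][2])


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def predict_step(self, batch, batch_idx):
        return self.layer(batch[0])


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(420, 8))
    loader = DataLoader(dataset, batch_size=64)
    trainer = pl.Trainer(gpus=3, strategy="ddp", callbacks=[PrintingWriter()])
    trainer.predict(LitModel(), dataloaders=loader)
```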
The source of batch_indices is trainer.predict_loop.epoch_batch_indices, which is populated from dl_batch_indices.
dl_batch_indices comes from PredictionEpochLoop._all_batch_indices,
which is filled in the _store_batch_indices method (sketched below).
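For reference, in 1.5.x that method looks roughly like the following (paraphrased from memory; the exact source may differ slightly). The key point is that it appends whatever batch_sampler.batch_indices currently holds:

```python
# Paraphrased sketch of PredictionEpochLoop._store_batch_indices in
# pytorch_lightning 1.5.x; may differ slightly from the actual source.
# It stores the *current* value of batch_sampler.batch_indices, which is
# only correct if that attribute still reflects the batch just processed.
def _store_batch_indices(self, dataloader_idx: int) -> None:
    batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
    if isinstance(batch_sampler, IndexBatchSamplerWrapper):
        self.current_batch_indices = batch_sampler.batch_indices
        if self.should_store_predictions:
            self._all_batch_indices.append(batch_sampler.batch_indices)
```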
During debugging I see that batch_sampler.batch_indices in fact holds the last batch's indices every time this method is called.
I couldn't dig deeper, but my guess is that the sampler yields all of its batches before the program reaches _store_batch_indices, as the toy sketch below illustrates.
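To make the suspected mechanism concrete, here is a self-contained toy sketch (hypothetical, not Lightning code): if a batch sampler that records its current batch is exhausted eagerly, e.g. by a prefetching DataLoader, every later read of its batch_indices attribute sees only the final batch:

```python
# Toy illustration of the suspected mechanism; not Lightning code.
class RecordingBatchSampler:
    def __init__(self, n, batch_size):
        self.n, self.batch_size = n, batch_size
        self.batch_indices = []  # overwritten on every yield

    def __iter__(self):
        for start in range(0, self.n, self.batch_size):
            self.batch_indices = list(range(start, min(start + self.batch_size, self.n)))
            yield self.batch_indices


sampler = RecordingBatchSampler(n=140, batch_size=64)
batches = list(iter(sampler))  # sampler exhausted up front, e.g. by prefetching

stored = []
for _ in batches:  # the loop later "stores" the indices once per batch...
    stored.append(sampler.batch_indices)  # ...but always reads the last value

print([len(b) for b in batches])  # [64, 64, 12] -- what *should* be stored
print([len(s) for s in stored])   # [12, 12, 12] -- what actually gets stored
print(stored[0] == stored[1] == stored[2])  # True, matching the report
```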
Environment
- PyTorch Lightning Version (e.g., 1.5.0): 1.5.3
- PyTorch Version (e.g., 1.10): 1.9.0
- Python version (e.g., 3.9): 3.8
- OS (e.g., Linux): Linux (Ubuntu)
- CUDA/cuDNN version: 11.3
- GPU models and configuration: 3x RTX 3080
- How you installed PyTorch (conda, pip, source): Using NVIDIA's Docker image
- If compiling from source, the output of torch.__config__.show():
- Any other relevant information: