Description
We have built a Terraform script that spins up 4 VMs and uses a v3-32 TPU for ResNet-50 training. We store the ImageNet training and validation data in a GCS bucket. The full code repo can be found here
- We use `torch_xla.distributed.xla_dist` as the launcher.
- We use `test_train_mp_imagenet.py`, only altering it to use our GCS data loader (a rough sketch of such a loader is below).
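
A minimal sketch of what such a GCS-backed dataset can look like (not the exact loader from the repo; it assumes the `google-cloud-storage` client, an ImageFolder-style `<prefix>/<class>/<image>.JPEG` layout in the bucket, and `GCSImageNet` is just a placeholder name):

```python
import io

from PIL import Image
from google.cloud import storage
from torch.utils.data import Dataset


class GCSImageNet(Dataset):
    """Sketch of an ImageNet dataset that reads JPEGs straight from a GCS bucket."""

    def __init__(self, bucket_name, prefix, transform=None):
        self.bucket_name = bucket_name
        self.transform = transform
        client = storage.Client()
        # Assumes an ImageFolder-style layout: <prefix>/<class>/<image>.JPEG
        self.samples = [b.name for b in client.list_blobs(bucket_name, prefix=prefix)]
        classes = sorted({name.split('/')[-2] for name in self.samples})
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        self._bucket = None  # created lazily so each DataLoader worker builds its own client

    def _get_bucket(self):
        if self._bucket is None:
            self._bucket = storage.Client().bucket(self.bucket_name)
        return self._bucket

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        name = self.samples[index]
        data = self._get_bucket().blob(name).download_as_bytes()
        image = Image.open(io.BytesIO(data)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.class_to_idx[name.split('/')[-2]]
```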
For the questions below, I've attached a log file (with `metrics_debug` enabled) and used the following configuration (a simplified sketch of where these settings land follows the list):
- VM machine type: `n2-custom` (72 vCPUs, 512 GB memory)
- `NUM_EPOCHS=20`
- `BATCH_SIZE=512`
- `TEST_BATCH_SIZE=64`
- `NUM_WORKERS=8`
- `log_steps=200`
- `--conda-env=torch-xla-1.7`
- `--env XLA_USE_BF16=1`
- default learning rate and schedule
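
For reference, a heavily simplified sketch of roughly how these settings map onto the per-core setup in `test_train_mp_imagenet.py` (it uses the public PyTorch/XLA APIs but is not the script verbatim; `make_dataset()` stands in for the GCS-backed dataset above, and the optimizer values are illustrative, since we keep the script's defaults):

```python
import torch
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def _mp_fn(index):
    device = xm.xla_device()
    model = torchvision.models.resnet50().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()

    train_set = make_dataset()  # placeholder for the GCS-backed dataset
    # Each of the 32 processes (4 VMs x 8 cores) trains on its own shard;
    # BATCH_SIZE and NUM_WORKERS are per process.
    sampler = DistributedSampler(
        train_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    loader = DataLoader(
        train_set,
        batch_size=512,   # BATCH_SIZE (per core)
        sampler=sampler,
        num_workers=8,    # NUM_WORKERS
        drop_last=True)

    for epoch in range(20):  # NUM_EPOCHS
        sampler.set_epoch(epoch)
        model.train()
        # ParallelLoader prefetches host-side batches and moves them to the device.
        device_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
        for data, target in device_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            # Performs the cross-core gradient all-reduce before stepping.
            xm.optimizer_step(optimizer)


if __name__ == '__main__':
    # xla_dist runs this entry point on each of the 4 VMs.
    xmp.spawn(_mp_fn, nprocs=8)
```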
Questions
- Not sure what baseline to compare against, but epoch training time seems to be around 5-6 minutes.
  - This holds for 8 workers at batch sizes of 128, 256, and 512 (batch size 128 with 32 workers comes in at the low 4-minute range per epoch).
  - Is there anything we could do, from a code or configuration perspective, to improve this? 32 workers seems like overkill, but we've seen better results with it (as a first check, see the metrics snippet after this list).
- Sometimes we get `BrokenPipeError: [Errno 32] Broken pipe` or `unhealthy mesh` errors and training automatically restarts (see line 20689 in the log file for the Broken Pipe error during epoch 13).
  - Is there anything we can do to overcome this? (See the checkpoint sketch after this list.)
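
On the first question, one cheap check is to look for recompilations and CPU fallbacks in the metrics report, since either inflates per-epoch time. A minimal snippet that dumps the same report `metrics_debug` logs, e.g. at the end of each epoch:

```python
import torch_xla.debug.metrics as met

# Print the XLA metrics report (the same data metrics_debug emits).
# "CompileTime" counts that keep growing after the first epoch, or any
# "aten::*" counters, usually point at recompilations or CPU fallbacks.
print(met.metrics_report())
```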
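
On the restarts, one mitigation (a sketch, not something the stock script does) is to checkpoint with `xm.save` every epoch so an automatic restart resumes rather than starting over; `CKPT_PATH` is a placeholder:

```python
import torch
import torch_xla.core.xla_model as xm

CKPT_PATH = '/tmp/resnet50_ckpt.pt'  # placeholder; could also be a mounted GCS path


def save_checkpoint(model, optimizer, epoch):
    # xm.save rendezvouses all ordinals, moves XLA tensors to CPU, and
    # writes the file from the master ordinal only.
    xm.save({'epoch': epoch,
             'model': model.state_dict(),
             'optimizer': optimizer.state_dict()}, CKPT_PATH)


def maybe_resume(model, optimizer):
    # On restart, load the last checkpoint (if any) and return the epoch to resume from.
    try:
        state = torch.load(CKPT_PATH)
    except FileNotFoundError:
        return 0
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    return state['epoch'] + 1
```

With something like this in place, a Broken Pipe restart costs at most the in-progress epoch rather than the whole run.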