
Distributed TPU Training, training data stored in GCS #2690

@tottenjordan

Description


We have built a Terraform script that spins up 4 VMs and uses a v3-32 TPU for ResNet-50 training. The ImageNet training and validation data are stored in a GCS bucket. The full code repo can be found here.

  • We launch training with torch_xla.distributed.xla_dist.
  • We use test_train_mp_imagenet.py, altered only to use our GCS data loader (a rough sketch of the loader follows this list).
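The sketch below shows the general shape of a GCS-streaming dataset that can stand in for torchvision.datasets.ImageFolder inside test_train_mp_imagenet.py. The class name, bucket layout, and helper method here are illustrative, not the exact loader from our repo:

```python
import io

from google.cloud import storage  # pip install google-cloud-storage
from PIL import Image
from torch.utils.data import Dataset


class GCSImageNetDataset(Dataset):
    """Streams JPEGs from a bucket laid out like ImageFolder:
    gs://<bucket>/<prefix>/<wnid>/<image>.JPEG
    """

    def __init__(self, bucket_name, prefix, transform=None):
        self.bucket_name = bucket_name
        self.transform = transform
        # List every object once up front; labels come from the class
        # (wnid) directory name embedded in the object path.
        client = storage.Client()
        self.paths = [
            b.name
            for b in client.list_blobs(bucket_name, prefix=prefix)
            if b.name.endswith(".JPEG")
        ]
        classes = sorted({p.split("/")[-2] for p in self.paths})
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        # Created lazily so each DataLoader worker builds its own client.
        self._bucket = None

    def _get_bucket(self):
        if self._bucket is None:
            self._bucket = storage.Client().bucket(self.bucket_name)
        return self._bucket

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        data = self._get_bucket().blob(path).download_as_bytes()
        img = Image.open(io.BytesIO(data)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.class_to_idx[path.split("/")[-2]]
```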

For the questions below, I've attached a log file (with metrics_debug) and used the following configuration (an example launch command is shown after the list):

  • VM machine types = n2-custom (72 vCPUs, 512 GB memory)
  • NUM_EPOCHS=20
  • BATCH_SIZE=512
  • TEST_BATCH_SIZE=64
  • NUM_WORKERS=8
  • log_steps=200
  • --conda-env=torch-xla-1.7
  • --env XLA_USE_BF16=1
  • default learning rate and schedule
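For reference, these settings translate into a launch roughly like the one below. $TPU_POD_NAME and the script path are placeholders, and the exact flag names accepted by test_train_mp_imagenet.py may differ slightly between torch_xla versions:

```bash
python -m torch_xla.distributed.xla_dist \
    --tpu=$TPU_POD_NAME \
    --conda-env=torch-xla-1.7 \
    --env XLA_USE_BF16=1 \
    -- python test_train_mp_imagenet.py \
        --num_epochs=20 \
        --batch_size=512 \
        --test_set_batch_size=64 \
        --num_workers=8 \
        --log_steps=200 \
        --metrics_debug
```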

Questions

  • Not sure what baseline to compare against, but epoch training time seems to be around 5-6 minutes.
    • This is true for 8 workers at batch sizes of 128, 256, and 512 (a batch size of 128 with 32 workers drops to the low 4-minute range per epoch).
    • Is there anything we could do from a code or configuration perspective to improve this? 32 workers seems like overkill, but we've seen better results with it.
  • Sometimes we get BrokenPipeError: [Errno 32] Broken pipe or unhealthy mesh errors, and training automatically restarts (see line 20689 in the log file for a Broken Pipe error during epoch 13).
    • Is there anything we can do to overcome this?

imagenetraw_logfiles4-v3-32-512batch-8workers.txt

@zcain117
@shanemhansen
