
TPU: Crashes using trainer.test() #6230

@zcain117


🐛 Bug

trainer.test() does not work with TPUs.

There are a few different ways we've seen it crash.

1. Looks like a call to barrier() coming from __test_using_best_weights:

RuntimeError                              Traceback (most recent call last)
<ipython-input-17-587e2a9e3858> in <module>
----> 1 trainer.test(datamodule=dm)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path, verbose, datamodule)
    922             results = self.__test_given_model(model, test_dataloaders)
    923         else:
--> 924             results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
    925 
    926         self.teardown('test')

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in __test_using_best_weights(self, ckpt_path, test_dataloaders)
    950                 return {}
    951             if not self._device_type == DeviceType.TPU:
--> 952                 self.accelerator.barrier()
    953 
    954             ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in barrier(self, name)
    375 
    376     def barrier(self, name: Optional[str] = None) -> None:
--> 377         self.training_type_plugin.barrier(name=name)
    378 
    379     def broadcast(self, obj: object, src: int = 0) -> object:

/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py in barrier(self, name)
    113 
    114     def barrier(self, name: Optional[str] = None) -> None:
--> 115         rendezvous(f"pl.Trainer.{name}")
    116 
    117     def transfer_distrib_spawn_state_on_fit_end(self, results):

/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py in rendezvous(tag, payload, replicas)
    859     ordinal `i` at position `i` in the returned tuple.
    860   """
--> 861   return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
    862 
    863 

RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:316 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds))  

So the barrier call is coming from here. It is strange that barrier() is being called at all. I think this means that if not self._device_type == DeviceType.TPU is mistakenly evaluating to True. I think PyTorch Lightning spins up 8 processes for 8 TPU cores, so is it possible that the condition evaluates to True in only some of them?
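For reference, Python's not binds more loosely than ==, so the guard parses as not (self._device_type == DeviceType.TPU), which is the same as self._device_type != DeviceType.TPU. The minimal check below uses a hypothetical stand-in enum (not Lightning's actual DeviceType class). It shows that, in this sketch, the barrier branch is taken only when the device type is something other than TPU. That supports the idea that a process reaching barrier() here has a non-TPU _device_type.

```python
from enum import Enum


class DeviceType(Enum):
    # Hypothetical stand-in for Lightning's DeviceType enum,
    # with only the members needed for this illustration.
    CPU = "CPU"
    GPU = "GPU"
    TPU = "TPU"


def would_call_barrier(device_type):
    # `not` has lower precedence than `==`, so this parses as
    # not (device_type == DeviceType.TPU), i.e. device_type != DeviceType.TPU.
    return not device_type == DeviceType.TPU


for d in DeviceType:
    print(d.name, would_call_barrier(d))
```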

Basically, it seems that at least one process never reaches this point. The other processes therefore wait in the barrier, the rendezvous never happens, and we get the RuntimeError shown above.
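Here is the failure mode in isolation. This minimal sketch uses Python's threading.Barrier as an analogy for the XLA rendezvous; it is not torch_xla code. Workers listed in skip_barrier take a branch that never reaches the barrier, and every worker that does reach it eventually times out.

```python
import threading


def run_workers(n_workers, skip_barrier, timeout=1.0):
    """Simulate n_workers replicas meeting at a rendezvous.

    Workers whose index is in skip_barrier never reach the barrier,
    mimicking a process that took a different branch. Returns a dict
    mapping worker index -> "ok", "timeout", or "skipped".
    """
    barrier = threading.Barrier(n_workers)  # all n_workers must arrive
    results = {}
    lock = threading.Lock()

    def worker(i):
        if i in skip_barrier:
            status = "skipped"
        else:
            try:
                barrier.wait(timeout=timeout)
                status = "ok"
            except threading.BrokenBarrierError:
                # Raised in every waiter once one of them times out.
                status = "timeout"
        with lock:
            results[i] = status

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_workers(8, skip_barrier={3}))
```

If all 8 workers call the barrier, every one returns "ok". Skipping a single worker turns the other 7 into "timeout", which is the analogue of the mesh_service check failure above.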

2. Looks like a call to xm.save() is being misused:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py", line 103, in new_process
    self.transfer_distrib_spawn_state_on_fit_end(results)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py", line 129, in transfer_distrib_spawn_state_on_fit_end
    xm.save(self.lightning_module.state_dict(), last_path)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 817, in save
    rendezvous('torch_xla.core.xla_model.save')
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 861, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py", line 103, in new_process
    self.transfer_distrib_spawn_state_on_fit_end(results)
  File "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/tpu_spawn.py", line 129, in transfer_distrib_spawn_state_on_fit_end
    xm.save(self.lightning_module.state_dict(), last_path)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 817, in save
    rendezvous('torch_xla.core.xla_model.save')
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 861, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
Exception in device=TPU:6: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
Exception in device=TPU:3: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)

I think the problem is the usage of xm.save() here, in transfer_distrib_spawn_state_on_fit_end.

xm.save() already handles the multiprocess case: it checks the ordinal and writes to disk only if the process is on the master ordinal. As the traceback shows, it also calls rendezvous() internally, so every process has to call it. In general, if you surround xm.save() with an if statement, some TPU cores enter the branch and others do not. The cores that entered will wait for those that didn't, and eventually the wait times out and crashes.
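The pattern can be sketched without a TPU. Below, save_master_only is a hypothetical stand-in modeled on the behavior described above: only ordinal 0 writes, then every replica meets at a rendezvous. threading.Barrier again plays the role of the rendezvous. Guarding the call with if ordinal == 0 leaves the master waiting alone until it times out. Calling it unconditionally from every replica succeeds and still writes the file exactly once.

```python
import json
import os
import tempfile
import threading


def save_master_only(data, path, ordinal, barrier, timeout=1.0):
    # Hypothetical stand-in for a collective save: only ordinal 0 writes,
    # but every caller must then reach the rendezvous.
    if ordinal == 0:
        with open(path, "w") as f:
            json.dump(data, f)
    barrier.wait(timeout=timeout)


def simulate(n_workers, guard_with_if, path):
    """Return the ordinals whose rendezvous failed."""
    barrier = threading.Barrier(n_workers)
    errors = []
    lock = threading.Lock()

    def worker(ordinal):
        try:
            if guard_with_if:
                # Anti-pattern: only the master calls the collective save.
                if ordinal == 0:
                    save_master_only({"w": 1}, path, ordinal, barrier)
            else:
                # Every replica calls it; the function decides who writes.
                save_master_only({"w": 1}, path, ordinal, barrier)
        except threading.BrokenBarrierError:
            with lock:
                errors.append(ordinal)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors


tmp = tempfile.mkdtemp()
print(simulate(8, guard_with_if=False, path=os.path.join(tmp, "ok.json")))
print(simulate(8, guard_with_if=True, path=os.path.join(tmp, "bad.json")))
```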

Repro methods

1. (Colab) Make 3 modifications to the BoringModel notebook:

  1. Switch the runtime type to TPU
  2. Add this cell as the first cell:
VERSION = "1.7"  #@param ["1.7" , "20200516", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION
  3. Add tpu_cores=8 to the trainer cell

2. (Google Cloud) Use the attached repro.py file in the following way:

  1. Create a TPU
  2. Create a Google Cloud VM (I used e2-standard-32, but the size shouldn't matter much)
  3. SSH into the VM
  4. (VM) conda activate torch-xla-1.7
  5. (VM) pip install pytorch-lightning==1.2.1
  6. (VM) export TPU_IP_ADDRESS=my.tpu.ip.addr
  7. (VM) export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
  8. (VM) python3 repro.py
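Steps 6 and 7 only set two environment variables. If you prefer to do this from Python, for example at the top of a script like repro.py, the sketch below builds the same XRT_TPU_CONFIG string from TPU_IP_ADDRESS. The helper names are hypothetical. The format and port 8470 are taken from step 7. The sketch assumes the variable only needs to be set before torch_xla is first imported.

```python
import os


def xrt_tpu_config(tpu_ip_address, port=8470):
    # Same format as step 7: "tpu_worker;0;<ip>:<port>".
    return f"tpu_worker;0;{tpu_ip_address}:{port}"


def configure_tpu_env(environ=os.environ):
    # Hypothetical helper: derive XRT_TPU_CONFIG from TPU_IP_ADDRESS (step 6).
    ip = environ.get("TPU_IP_ADDRESS")
    if not ip:
        raise RuntimeError("TPU_IP_ADDRESS is not set (see step 6)")
    environ["XRT_TPU_CONFIG"] = xrt_tpu_config(ip)
    return environ["XRT_TPU_CONFIG"]
```

Passing a plain dict instead of os.environ makes the helper easy to check without touching the real environment.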

3. (Your CI setup) Modify TPU unit tests as follows:

  • Add a trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32)) call after a call to trainer.fit
  • For example, I changed the test_model_tpu_early_stop test to look like this:
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_early_stop(tmpdir):
    """Test if single TPU core training works"""

    # todo: Test on 8 cores - hanging.

    class CustomBoringModel(BoringModel):

        def validation_step(self, *args, **kwargs):
            out = super().validation_step(*args, **kwargs)
            self.log('val_loss', out['x'])
            return out

    tutils.reset_seed()
    model = CustomBoringModel()
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor='val_loss')],
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        tpu_cores=[1],
    )
    trainer.fit(model)
+    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))

I ran the tests with coverage run --source=pytorch_lightning -m pytest tests/models/test_tpu.py -v. This should make it possible to reproduce the crash in the CI framework.

Environment

  • PyTorch Version (e.g., 1.0): 1.7
  • OS (e.g., Linux): Linux
  • Build command you used (if compiling from source): pip install pytorch-lightning==1.2.1 (note that earlier versions hang due to #5841, "Hanging with TPUs on GCE VM")
  • Python version: 3.6
  • Any other relevant information:


Labels: accelerator: tpu, bug, help wanted, priority: 0
