Merged
52 commits
b714395  add xla environment class (awaelchli, Jan 5, 2022)
078a01f  add api reference (awaelchli, Jan 5, 2022)
64c57c4  integrate (awaelchli, Jan 5, 2022)
cee674b  use xenv (awaelchli, Jan 5, 2022)
f509dc9  remove properties (awaelchli, Jan 5, 2022)
7d192cb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 5, 2022)
ce427e5  test environment selection (awaelchli, Jan 5, 2022)
460349f  Merge branch 'master' into feature/xla_environment (awaelchli, Feb 6, 2022)
6df74bc  update (awaelchli, Feb 6, 2022)
35a7474  Merge branch 'master' into feature/xla_environment (awaelchli, Feb 7, 2022)
46cd7a3  notebooks (awaelchli, Feb 7, 2022)
ad7acc0  notebooks (awaelchli, Feb 7, 2022)
e5fae8f  update (awaelchli, Feb 7, 2022)
d084d51  update (awaelchli, Feb 7, 2022)
5dca2f8  test tests (awaelchli, Feb 7, 2022)
472e200  include test case (awaelchli, Feb 7, 2022)
1833b62  fix test (awaelchli, Feb 7, 2022)
dcf3ccb  fix (awaelchli, Feb 13, 2022)
970c1b0  Merge branch 'master' into feature/xla_environment (awaelchli, Feb 13, 2022)
1131bf7  reset (awaelchli, Feb 14, 2022)
32390a9  temp fix (kaushikb11, Feb 16, 2022)
2fcdb64  Merge branch 'master' into feature/xla_environment (kaushikb11, Feb 18, 2022)
7dcf6c4  Update (kaushikb11, Feb 18, 2022)
d2700b8  Update (kaushikb11, Feb 18, 2022)
ba54586  Update (kaushikb11, Feb 18, 2022)
983a9e7  Update tests (kaushikb11, Feb 23, 2022)
4c6f73a  Update tests (kaushikb11, Feb 24, 2022)
1d16728  Update tests (kaushikb11, Feb 24, 2022)
f1d9cd9  Merge branch 'master' into feature/xla_environment (awaelchli, Apr 5, 2022)
d3fac36  Merge branch 'master' into feature/xla_environment (awaelchli, May 14, 2022)
ac90b96  debug (awaelchli, May 14, 2022)
51c9239  select env (awaelchli, May 14, 2022)
e3bfbac  debug (awaelchli, May 14, 2022)
478a705  debug (awaelchli, May 14, 2022)
1f34ba7  debug (awaelchli, May 14, 2022)
cbbd80e  debug (awaelchli, May 14, 2022)
bf51487  remove (awaelchli, May 14, 2022)
976ee6c  format (awaelchli, May 14, 2022)
39e9aa8  add changelog (awaelchli, May 15, 2022)
8d1b7c9  fix test entry (awaelchli, May 15, 2022)
b8224b0  remove unused import (awaelchli, May 15, 2022)
f61133f  simplify (awaelchli, May 15, 2022)
c1943eb  update (awaelchli, May 15, 2022)
d76a493  Merge branch 'master' into refactor/tpu-rank (awaelchli, Jun 22, 2022)
b66ec28  update (awaelchli, Jun 22, 2022)
5948bbd  update (awaelchli, Jun 22, 2022)
96ae47a  update (awaelchli, Jun 22, 2022)
1ec44c9  Merge branch 'master' into refactor/tpu-rank (awaelchli, Jul 11, 2022)
2e3bf6d  update (awaelchli, Jul 11, 2022)
57a887c  update changelog (awaelchli, Jul 11, 2022)
c7c0ea9  Merge branch 'master' into refactor/tpu-rank (awaelchli, Jul 14, 2022)
8f54a67  Merge branch 'master' into refactor/tpu-rank (awaelchli, Jul 14, 2022)
3 changes: 2 additions & 1 deletion dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -42,7 +42,8 @@ local tputests = base.BaseTest {
           profilers/test_xla_profiler.py \
           accelerators/test_tpu.py \
           models/test_tpu.py \
-          plugins/environments/test_xla_environment.py
+          plugins/environments/test_xla_environment.py \
+          utilities/test_xla_device_utils.py
       test_exit_code=$?
       echo "\n||| END PYTEST LOGS |||\n"
       coverage xml
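The added line pulls the XLA device-utility tests into the TPU CI run. A minimal sketch of invoking that one file locally with pytest (the path is relative to the tests root used by the CI command above and may differ in your checkout):

import sys

import pytest

if __name__ == "__main__":
    # Run only the newly included test module, verbosely, and propagate
    # pytest's exit code the same way the CI script's $? check does.
    sys.exit(pytest.main(["-v", "utilities/test_xla_device_utils.py"]))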
8 changes: 8 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -272,6 +272,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `get_progress_bar_dict` property from `LightningModule` ([#12839](https://github.com/PyTorchLightning/pytorch-lightning/pull/12839))
 
+
 - Removed sanity check for multi-optimizer support with habana backends ([#13217](https://github.com/PyTorchLightning/pytorch-lightning/pull/13217))
 
 
@@ -302,6 +303,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `LightningModule.on_post_move_to_device` ([#13548](https://github.com/Lightning-AI/lightning/pull/13548))
 
 
+- Removed `TPUSpawnStrategy.{tpu_local_core_rank,tpu_global_core_rank}` attributes in favor of `TPUSpawnStrategy.{local_rank,global_rank}` ([#11163](https://github.com/PyTorchLightning/pytorch-lightning/pull/11163))
+
+
+- Removed `SingleTPUStrategy.{tpu_local_core_rank,tpu_global_core_rank}` attributes in favor of `SingleTPUStrategy.{local_rank,global_rank}` ([#11163](https://github.com/PyTorchLightning/pytorch-lightning/pull/11163))
+
+
+
 ### Fixed
 
 
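Both new entries describe the same rename: TPU-specific rank attributes give way to the generic Strategy properties. A minimal migration sketch for user code (the Trainer arguments are assumptions for illustration; only the attribute names come from the entries above):

from pytorch_lightning import Trainer

trainer = Trainer(accelerator="tpu", devices=8)  # assumed TPU configuration

# Before #11163 (now removed):
#   local_rank = trainer.strategy.tpu_local_core_rank
#   global_rank = trainer.strategy.tpu_global_core_rank

# After #11163: the generic properties work uniformly across strategies.
local_rank = trainer.strategy.local_rank
global_rank = trainer.strategy.global_rank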
6 changes: 0 additions & 6 deletions src/pytorch_lightning/strategies/single_tpu.py
@@ -44,10 +44,7 @@ def __init__(
             checkpoint_io=checkpoint_io,
             precision_plugin=precision_plugin,
         )
-
         self.debug = debug
-        self.tpu_local_core_rank = 0
-        self.tpu_global_core_rank = 0
 
     @property
     def is_distributed(self) -> bool:
@@ -63,9 +60,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
-
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
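The deleted xm.get_local_ordinal() / xm.get_ordinal() lookups do not vanish: per the commit history ("add xla environment class"), rank resolution moves behind a cluster-environment object. A rough standalone sketch of that idea (the class and method names here are assumptions, not Lightning's exact ClusterEnvironment interface):

import torch_xla.core.xla_model as xm


class XLAEnvironmentSketch:
    """Reads ranks directly from the XLA runtime instead of caching
    them as mutable attributes on the strategy."""

    def world_size(self) -> int:
        return xm.xrt_world_size()

    def global_rank(self) -> int:
        return xm.get_ordinal()

    def local_rank(self) -> int:
        return xm.get_local_ordinal()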
22 changes: 6 additions & 16 deletions src/pytorch_lightning/strategies/tpu_spawn.py
@@ -72,8 +72,6 @@ def __init__(
             precision_plugin=precision_plugin,
         )
         self.debug = debug
-        self.tpu_local_core_rank = 0
-        self.tpu_global_core_rank = 0
         self.start_method = "fork"
 
     @property
@@ -152,12 +150,6 @@ def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
     def configure_ddp(self) -> None:
         pass
 
-    def init_dist_connection(self, global_rank: int, world_size: int) -> None:
-        pass
-
-    def set_world_ranks(self, process_idx: int = 0) -> None:
-        pass
-
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
@@ -203,9 +195,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
     def _worker_setup(self, process_idx: int):
         reset_seed()
-        self._local_rank = xm.get_local_ordinal()
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
+        self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
 
     def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
@@ -237,7 +227,7 @@ def _pod_progress_bar_force_stdout(self) -> None:
         # from different vms to the main worker doesn't work well with tqdm
         # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
         # The print statement seems to force tqdm to flush stdout.
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+        if self.global_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
             print()
 
     def save_checkpoint(
@@ -276,6 +266,10 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             tensor = tensor.unsqueeze(0)
         return xm.all_gather(tensor)
 
+    def teardown(self) -> None:
+        super().teardown()
+        os.environ.pop("PT_XLA_DEBUG", None)
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
@@ -287,7 +281,3 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             cls,
             description=f"{cls.__class__.__name__}",
         )
-
-    def teardown(self) -> None:
-        super().teardown()
-        os.environ.pop("PT_XLA_DEBUG", None)
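With _worker_setup now delegating to set_world_ranks, rank checks on TPU flow through the same properties as every other strategy. A small usage sketch (the module and its print are illustrative, not part of the PR):

import pytorch_lightning as pl


class RankAwareModule(pl.LightningModule):
    def on_train_start(self) -> None:
        # Replaces the removed tpu_local_core_rank / tpu_global_core_rank.
        strategy = self.trainer.strategy
        if strategy.global_rank == 0:
            print(f"rank 0 of {strategy.world_size}, local_rank={strategy.local_rank}")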