Merged

Commits (39)
b42a40e
Fix TPU test CI
carmocca Sep 28, 2022
ebf46cd
+x first
carmocca Sep 28, 2022
5627c8b
Lite first to uncover errors faster
carmocca Sep 28, 2022
05627a2
Fixes
carmocca Sep 28, 2022
9cdb886
One more
carmocca Sep 28, 2022
684e257
Simplify XLALauncher wrapping to avoid pickle error
carmocca Sep 28, 2022
dc7dad8
debug
awaelchli Sep 28, 2022
fc0bd77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2022
cbcf684
Debug commit successful. Trying local definitions
carmocca Sep 28, 2022
2dd01bb
Require tpu for mock test
carmocca Sep 28, 2022
d37217d
ValueError: The number of devices must be either 1 or 8, got 4 instead
carmocca Sep 28, 2022
a6d4505
Fix mock test
carmocca Sep 28, 2022
cefbeb5
Simplify call, rely on defaults
carmocca Sep 29, 2022
aa2ac10
Skip OSError for now. Maybe upgrading will help
carmocca Sep 29, 2022
f61e5f2
Simplify launch tests, move some to lite
carmocca Sep 29, 2022
47a5d80
Stricter typing
carmocca Sep 29, 2022
f65107e
RuntimeError: Accessing the XLA device before processes have spawned …
carmocca Sep 29, 2022
bf6839c
Revert "RuntimeError: Accessing the XLA device before processes have …
carmocca Sep 29, 2022
cd7ab95
Alternative boring solution to the reverted commit
carmocca Sep 29, 2022
5c45b74
Fix failing test on CUDA machine
carmocca Sep 29, 2022
6eefee6
Workarounds
carmocca Sep 29, 2022
d06813a
Try latest mkl
akihironitta Sep 29, 2022
9551df6
Revert "Try latest mkl"
akihironitta Sep 29, 2022
495249a
Wrong exception
carmocca Sep 29, 2022
def27b4
xfail
carmocca Sep 29, 2022
ece552e
Mypy
carmocca Sep 29, 2022
652e448
Comment change
carmocca Sep 29, 2022
8bebce3
Spawn launch refactor
carmocca Sep 29, 2022
47b3939
Merge branch 'master' into ci/fix-circleci
carmocca Sep 29, 2022
f558553
Accept that we cannot lazy init now
carmocca Sep 29, 2022
67e5165
Merge branch 'master' into ci/fix-circleci
awaelchli Sep 29, 2022
6977f91
Merge branch 'master' into ci/fix-circleci
carmocca Sep 29, 2022
a7d870a
Fix mypy and launch test failures
carmocca Sep 29, 2022
5ba7307
The base dockerfile already includes mkl-2022.1.0 - what if we use it?
carmocca Sep 30, 2022
ca8748c
Merge branch 'master' into ci/fix-circleci
awaelchli Sep 30, 2022
f1ad0d6
Merge branch 'master' into ci/fix-circleci
carmocca Sep 30, 2022
4f5af35
try a different mkl version
carmocca Sep 30, 2022
d095bcd
Revert mkl version changes
carmocca Sep 30, 2022
21aab4b
Merge branch 'master' into ci/fix-circleci
carmocca Oct 1, 2022
6 changes: 5 additions & 1 deletion .circleci/config.yml
@@ -14,6 +14,9 @@ parameters:
GHA_Event:
type: string
default: ""
GHA_Meta:
type: string
default: ""

references:

@@ -49,9 +52,10 @@ references:
update_jsonnet: &update_jsonnet
run:
name: Update jsonnet
environment:
PR_NUMBER: << pipeline.parameters.GHA_Meta >>
command: |
export SHA=$(git rev-parse --short HEAD)
export PR_NUMBER=$(git ls-remote origin "pull/*/head" | grep -F -f $SHA | awk -F'/' '{print $3}')
python -c "fname = 'dockers/tpu-tests/tpu_test_cases.jsonnet' ; data = open(fname).read().replace('{PYTORCH_VERSION}', '$XLA_VER')
data = data.replace('{PYTHON_VERSION}', '$PYTHON_VER').replace('{PR_NUMBER}', '$PR_NUMBER').replace('{SHA}', '$SHA') ; open(fname, 'w').write(data)"
cat dockers/tpu-tests/tpu_test_cases.jsonnet
2 changes: 2 additions & 0 deletions .github/workflows/ci-circleci.yml
@@ -30,3 +30,5 @@ jobs:
- uses: CircleCI-Public/[email protected]
env:
CCI_TOKEN: ${{ secrets.CCI_TOKEN }}
with:
GHA_Meta: ${{ github.event.pull_request.number }}
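Together, these two changes thread the pull-request number from the GitHub event into CircleCI as the `GHA_Meta` pipeline parameter, replacing the removed `git ls-remote` lookup. For reference, here is a minimal sketch (a hypothetical helper, not part of this PR) of what that removed pipeline attempted; note that the old `grep -F -f $SHA` read its patterns from a *file* named after the SHA instead of matching the SHA itself, one reason the lookup was unreliable:

```python
import subprocess
from typing import Optional


def pr_number_for_sha(sha: str, remote: str = "origin") -> Optional[str]:
    """Scan the remote's pull refs for one whose head matches the given SHA."""
    out = subprocess.run(
        ["git", "ls-remote", remote, "pull/*/head"],
        capture_output=True, text=True, check=True,
    ).stdout
    for line in out.splitlines():
        ref_sha, _, ref = line.partition("\t")  # "<sha>\trefs/pull/<number>/head"
        if ref_sha.startswith(sha):
            return ref.split("/")[2]
    return None
```

Passing the number explicitly through the pipeline parameter avoids this extra remote round-trip entirely.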
49 changes: 30 additions & 19 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -20,29 +20,40 @@ local tputests = base.BaseTest {

command: utils.scriptCommand(
|||
set +x # turn off tracing, spammy
set -e # exit on error

source ~/.bashrc
set -e
conda activate lightning
mkdir -p /home/runner/work/lightning && cd /home/runner/work/lightning
git clone https://github.com/Lightning-AI/lightning.git
cd lightning
echo $PWD
git ls-remote --refs origin
git fetch origin "refs/pull/{PR_NUMBER}/head"
git checkout {SHA}
export PACKAGE_NAME=pytorch
export FREEZE_REQUIREMENTS=1
pip install -e .[test]

echo "--- Fetch the SHA's changes ---"
git clone --single-branch --depth 1 https://github.com/Lightning-AI/lightning.git /home/runner/work/lightning
cd /home/runner/work/lightning
git fetch origin --depth 1 pull/{PR_NUMBER}/head:test/{PR_NUMBER}
git -c advice.detachedHead=false checkout {SHA}

echo "--- Install PL ---"
PACKAGE_NAME=pytorch FREEZE_REQUIREMENTS=1 pip install -e .[test]
pip list

echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
export PL_RUN_TPU_TESTS=1
cd tests/tests_pytorch
coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
echo "\n||| Running standalone tests |||\n"
export PL_STANDALONE_TESTS_SOURCE=pytorch_lightning
export PL_STANDALONE_TESTS_BATCH_SIZE=1
bash run_standalone_tests.sh
echo "\n||| END PYTEST LOGS |||\n"

echo "--- Running Lite tests ---"
cd tests/tests_lite
PL_RUN_TPU_TESTS=1 coverage run --source=lightning_lite -m pytest -vv --durations=0 ./

echo "--- Running standalone Lite tests ---"
PL_STANDALONE_TESTS_SOURCE=lightning_lite PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh

echo "--- Running PL tests ---"
cd ../tests_pytorch
PL_RUN_TPU_TESTS=1 coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./

echo "--- Running standalone PL tests ---"
PL_STANDALONE_TESTS_SOURCE=pytorch_lightning PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh

echo "--- Generating coverage ---"
coverage xml
cat coverage.xml | tr -d '\t'
|||
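The reworked script runs the Lite suites first (per the "Lite first to uncover errors faster" commit), gates TPU tests behind `PL_RUN_TPU_TESTS=1`, and runs standalone tests one at a time via `PL_STANDALONE_TESTS_BATCH_SIZE=1`. As a hypothetical distillation (not the repo's actual `RunIf` implementation), such an environment flag can gate a pytest suite like this:

```python
import os

import pytest

# Skip unless the CI script exported the flag; assumes the same
# PL_RUN_TPU_TESTS convention used above.
requires_tpu = pytest.mark.skipif(
    os.getenv("PL_RUN_TPU_TESTS", "0") != "1",
    reason="TPU tests only run when PL_RUN_TPU_TESTS=1",
)


@requires_tpu
def test_something_tpu_only():
    assert True
```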
45 changes: 15 additions & 30 deletions src/lightning_lite/strategies/launchers/xla.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from functools import wraps
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING

from torch.multiprocessing import get_context, ProcessContext
from torch.multiprocessing import get_context

from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from lightning_lite.utilities import _TPU_AVAILABLE
@@ -67,7 +66,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""
context = get_context(self._start_method)
return_queue = context.SimpleQueue()
_save_spawn(
xmp.spawn(
self._wrapping_function,
args=(function, args, kwargs, return_queue),
nprocs=self._strategy.num_processes,
@@ -90,30 +89,16 @@ def _wrapping_function(
if process_idx == 0:
return_queue.put(move_data_to_device(results, "cpu"))

_rank_teardown(process_idx)

def _save_spawn(
fn: Callable,
args: Tuple = (),
nprocs: Optional[int] = None,
join: bool = True,
daemon: bool = False,
start_method: str = "spawn",
) -> Optional[ProcessContext]:
"""Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
processes."""

@wraps(fn)
def wrapped(rank: int, *_args: Any) -> None:
fn(rank, *_args)

import torch_xla.core.xla_model as xm

# Make all processes wait for each other before joining
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
xm.rendezvous("end-process")
# Ensure that the rank 0 process is the one exiting last
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if rank == 0:
time.sleep(1)

return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)

def _rank_teardown(rank: int) -> None:
import torch_xla.core.xla_model as xm

# Make all processes wait for each other before joining
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
xm.rendezvous("end-process")
# Ensure that the rank 0 process is the one exiting last
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if rank == 0:
time.sleep(1)
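With the `_save_spawn` wrapper gone, the teardown is now an explicit call each worker makes itself. A minimal sketch of the resulting call pattern, with `_worker` as a hypothetical stand-in for the launcher's `_wrapping_function`:

```python
import torch_xla.distributed.xla_multiprocessing as xmp

from lightning_lite.strategies.launchers.xla import _rank_teardown


def _worker(rank: int) -> None:
    # ... run the wrapped function; rank 0 puts its results on the return queue ...
    _rank_teardown(rank)  # all ranks rendezvous, then rank 0 exits last


if __name__ == "__main__":
    xmp.spawn(_worker, nprocs=8, start_method="fork")
```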
6 changes: 4 additions & 2 deletions src/pytorch_lightning/strategies/launchers/xla.py
@@ -18,7 +18,7 @@
import torch.multiprocessing as mp

import pytorch_lightning as pl
from lightning_lite.strategies.launchers.xla import _save_spawn
from lightning_lite.strategies.launchers.xla import _rank_teardown
from lightning_lite.utilities import move_data_to_device
from pytorch_lightning.strategies.launchers.multiprocessing import (
_FakeQueue,
@@ -74,7 +74,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
"""
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()
_save_spawn(
xmp.spawn(
self._wrapping_function,
args=(trainer, function, args, kwargs, return_queue),
nprocs=self._strategy.num_processes,
@@ -106,6 +106,8 @@ def _wrapping_function(
if process_idx == 0:
return_queue.put(move_data_to_device(results, "cpu"))

_rank_teardown(process_idx)

def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
rank_zero_debug("Collecting results from rank 0 process.")
checkpoint_callback = trainer.checkpoint_callback
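The commit message "Simplify XLALauncher wrapping to avoid pickle error" motivates this refactor in both packages: a function returned from inside another function generally cannot be pickled, which matters once multiprocessing has to ship the callable to workers. A hypothetical illustration (not this PR's code) of why the old `wraps`-based wrapper was fragile:

```python
import pickle
from functools import wraps


def make_wrapper(fn):
    @wraps(fn)  # copies fn.__qualname__, so pickle resolves the name to a different object
    def wrapped(rank, *args):
        fn(rank, *args)
    return wrapped


def work(rank):
    pass


pickle.dumps(work)  # fine: a plain module-level function
try:
    pickle.dumps(make_wrapper(work))
except pickle.PicklingError as err:
    print(f"not picklable: {err}")
```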
6 changes: 2 additions & 4 deletions tests/tests_lite/strategies/launchers/test_xla.py
@@ -1,5 +1,5 @@
from unittest import mock
from unittest.mock import ANY, Mock
from unittest.mock import Mock

from tests_lite.helpers.runif import RunIf

@@ -29,11 +29,9 @@ def test_xla_launcher_xmp_spawn(get_context_mock, xmp_mock):
queue = get_context_mock.return_value.SimpleQueue.return_value
get_context_mock.assert_called_with("fork")
xmp_mock.spawn.assert_called_with(
ANY,
launcher._wrapping_function,
args=(function, ("positional-arg",), {"keyword_arg": 0}, queue),
nprocs=strategy.num_processes,
join=True,
daemon=False,
start_method="fork",
)
queue.get.assert_called_once_with()
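Because `_wrapping_function` is now handed to `xmp.spawn` directly, the test can assert on the exact callable where it previously had to accept `ANY` for the anonymous wrapper. A generic sketch of that stricter mock assertion, with hypothetical names:

```python
from unittest import mock


def launch(spawn, worker):
    spawn(worker, nprocs=8)


spawn_mock = mock.Mock()
worker = object()
launch(spawn_mock, worker)
spawn_mock.assert_called_once_with(worker, nprocs=8)  # the exact object, not mock.ANY
```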
2 changes: 2 additions & 0 deletions tests/tests_pytorch/accelerators/test_tpu.py
@@ -20,6 +20,7 @@
import pytest
import torch
from torch import nn
from torch.multiprocessing import ProcessExitedException
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
@@ -69,6 +70,7 @@ def test_resume_training_on_cpu(tmpdir):

@RunIf(tpu=True)
@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.xfail(raises=ProcessExitedException, reason="https://github.com/pytorch/xla/issues/1666")
def test_if_test_works_after_train(tmpdir):
"""Ensure that .test() works after .fit()"""
model = BoringModel()
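The new marker records the known crash as an expected failure while still surfacing any other error or an unexpected pass. A generic sketch of the `xfail(raises=...)` pattern, substituting `RuntimeError` for `ProcessExitedException` so the example stays self-contained:

```python
import pytest


@pytest.mark.xfail(raises=RuntimeError, reason="tracking a known upstream bug")
def test_known_crash():
    raise RuntimeError("simulated known failure")
```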