Skip to content

Commit c2ab7f8

Browse files
committed
CI: parameterize TPU tests (#15876)
* update
* param
* Apply suggestions from code review

(cherry picked from commit 77006a2)
1 parent 6ed0e00 commit c2ab7f8

File tree

5 files changed

+100
-29
lines changed

5 files changed

+100
-29
lines changed

.github/workflows/tpu-tests.yml

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,22 @@ env:
3030
GKE_CLUSTER: lightning-cluster
3131
GKE_ZONE: us-central1-a
3232

33+
defaults:
34+
run:
35+
shell: bash
36+
3337
jobs:
34-
# TODO: package parametrization
3538
test-on-tpus:
3639
runs-on: ubuntu-22.04
3740
if: github.event.pull_request.draft == false
3841
env:
3942
PYTHON_VER: 3.7
43+
strategy:
44+
fail-fast: false
45+
max-parallel: 1 # run sequentially
46+
matrix:
47+
# TODO: also add lightning
48+
pkg-name: ["lite", "pytorch"]
4049
timeout-minutes: 100 # should match the timeout in `tpu_workflow.jsonnet`
4150

4251
steps:
@@ -64,14 +73,24 @@ jobs:
6473

6574
- name: Update jsonnet
6675
env:
76+
SCOPE: ${{ matrix.pkg-name }}
6777
XLA_VER: 1.12
6878
PR_NUMBER: ${{ github.event.pull_request.number }}
6979
SHA: ${{ github.event.pull_request.head.sha }}
7080
run: |
71-
python -c "fname = 'dockers/base-xla/tpu_workflow.jsonnet' ; data = open(fname).read().replace('{PYTORCH_VERSION}', '$XLA_VER')
72-
data = data.replace('{PYTHON_VERSION}', '$PYTHON_VER').replace('{PR_NUMBER}', '$PR_NUMBER').replace('{SHA}', '$SHA') ; open(fname, 'w').write(data)"
73-
cat dockers/base-xla/tpu_workflow.jsonnet
74-
shell: bash
81+
import os
82+
fname = f'dockers/base-xla/tpu_workflow_{os.getenv("SCOPE")}.jsonnet'
83+
with open(fname) as fo:
84+
data = fo.read()
85+
data = data.replace('{PYTORCH_VERSION}', os.getenv("XLA_VER"))
86+
data = data.replace('{PYTHON_VERSION}', os.getenv("PYTHON_VER"))
87+
data = data.replace('{PR_NUMBER}', os.getenv("PR_NUMBER"))
88+
data = data.replace('{SHA}', os.getenv("SHA"))
89+
with open(fname, "w") as fw:
90+
fw.write(data)
91+
shell: python
92+
- name: Show jsonnet
93+
run: cat dockers/base-xla/tpu_workflow_${{ matrix.pkg-name }}.jsonnet
7594

7695
- uses: google-github-actions/auth@v1
7796
with:
@@ -86,7 +105,7 @@ jobs:
86105
- name: Deploy cluster
87106
run: |
88107
export PATH=$PATH:$HOME/go/bin
89-
job_name=$(jsonnet -J ml-testing-accelerators/ dockers/base-xla/tpu_workflow.jsonnet | kubectl create -f -)
108+
job_name=$(jsonnet -J ml-testing-accelerators/ dockers/base-xla/tpu_workflow_${{ matrix.pkg-name }}.jsonnet | kubectl create -f -)
90109
job_name=${job_name#job.batch/}
91110
job_name=${job_name% created}
92111
pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}')

dockers/base-xla/Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ RUN pip --version && \
8686
rm *.whl
8787

8888
# Get package
89-
COPY ./ ./pytorch-lightning/
89+
COPY ./ ./lightning/
9090

9191
RUN \
9292
python --version && \
93-
cd pytorch-lightning && \
93+
cd lightning && \
9494
pip install -q -r .actions/requirements.txt && \
9595
# Pin mkl version to avoid OSError on torch import
9696
# OSError: libmkl_intel_lp64.so.1: cannot open shared object file: No such file or directory
@@ -103,7 +103,7 @@ RUN \
103103
# install PL dependencies
104104
pip install --requirement ./requirements/pytorch/devel.txt --no-cache-dir && \
105105
cd .. && \
106-
rm -rf pytorch-lightning && \
106+
rm -rf lightning && \
107107
rm -rf /root/.cache
108108

109109
RUN \

dockers/base-xla/tpu_workflow.jsonnet renamed to dockers/base-xla/tpu_workflow_lite.jsonnet

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,26 @@ local tputests = base.BaseTest {
2626
source ~/.bashrc
2727
conda activate lightning
2828
29-
echo "--- Fetch the SHA's changes ---"
29+
echo "--- Cloning lightning repo ---"
3030
git clone --single-branch --depth 1 https://github.com/Lightning-AI/lightning.git
3131
cd lightning
32+
# PR triggered it, check it out
3233
if [ -n "{PR_NUMBER}" ]; then # if PR number is not empty
33-
# PR triggered it, check it out
34+
echo "--- Fetch the PR changes ---"
3435
git fetch origin --depth 1 pull/{PR_NUMBER}/head:test/{PR_NUMBER}
36+
echo "--- Checkout PR changes ---"
3537
git -c advice.detachedHead=false checkout {SHA}
3638
fi
3739
3840
echo "--- Install packages ---"
39-
PACKAGE_NAME=lite pip install -e .[dev]
40-
PACKAGE_NAME=pytorch pip install -e .[dev]
41+
PACKAGE_NAME=lite pip install .[dev]
4142
pip list
4243
4344
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
4445
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
4546
4647
echo "--- Sanity check TPU availability ---"
4748
python -c "from lightning_lite.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
48-
python -c "from pytorch_lightning.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
4949
echo "Sanity check passed!"
5050
5151
echo "--- Running Lite tests ---"
@@ -55,13 +55,6 @@ local tputests = base.BaseTest {
5555
echo "--- Running standalone Lite tests ---"
5656
PL_STANDALONE_TESTS_SOURCE=lightning_lite PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh
5757
58-
echo "--- Running PL tests ---"
59-
cd ../tests_pytorch
60-
PL_RUN_TPU_TESTS=1 coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
61-
62-
echo "--- Running standalone PL tests ---"
63-
PL_STANDALONE_TESTS_SOURCE=pytorch_lightning PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh
64-
6558
echo "--- Generating coverage ---"
6659
coverage xml
6760
cat coverage.xml | tr -d '\t'

dockers/base-xla/tpu_workflow_pytorch.jsonnet

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
local base = import 'templates/base.libsonnet';
2+
local tpus = import 'templates/tpus.libsonnet';
3+
local utils = import "templates/utils.libsonnet";
4+
5+
local tputests = base.BaseTest {
6+
frameworkPrefix: 'pl',
7+
modelName: 'tpu-tests',
8+
mode: 'postsubmit',
9+
configMaps: [],
10+
11+
timeout: 6000, # 100 minutes, in seconds.
12+
13+
image: 'pytorchlightning/pytorch_lightning',
14+
imageTag: 'base-xla-py{PYTHON_VERSION}-torch{PYTORCH_VERSION}',
15+
16+
tpuSettings+: {
17+
softwareVersion: 'pytorch-{PYTORCH_VERSION}',
18+
},
19+
accelerator: tpus.v3_8,
20+
21+
command: utils.scriptCommand(
22+
|||
23+
set +x # turn off tracing, spammy
24+
set -e # exit on error
25+
26+
source ~/.bashrc
27+
conda activate lightning
28+
29+
echo "--- Cloning lightning repo ---"
30+
git clone --single-branch --depth 1 https://github.com/Lightning-AI/lightning.git
31+
cd lightning
32+
# PR triggered it, check it out
33+
if [ -n "{PR_NUMBER}" ]; then # if PR number is not empty
34+
echo "--- Fetch the PR changes ---"
35+
git fetch origin --depth 1 pull/{PR_NUMBER}/head:test/{PR_NUMBER}
36+
echo "--- Checkout PR changes ---"
37+
git -c advice.detachedHead=false checkout {SHA}
38+
fi
39+
40+
echo "--- Install packages ---"
41+
PACKAGE_NAME=pytorch pip install .[dev]
42+
pip list
43+
44+
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
45+
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
46+
47+
echo "--- Sanity check TPU availability ---"
48+
python -c "from pytorch_lightning.accelerators import TPUAccelerator; assert TPUAccelerator.is_available()"
49+
echo "Sanity check passed!"
50+
51+
echo "--- Running PL tests ---"
52+
cd tests/tests_pytorch
53+
PL_RUN_TPU_TESTS=1 coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
54+
55+
echo "--- Running standalone PL tests ---"
56+
PL_STANDALONE_TESTS_SOURCE=pytorch_lightning PL_STANDALONE_TESTS_BATCH_SIZE=1 bash run_standalone_tests.sh
57+
58+
echo "--- Generating coverage ---"
59+
coverage xml
60+
cat coverage.xml | tr -d '\t'
61+
|||
62+
),
63+
};
64+
65+
tputests.oneshotJob

environment.yml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,15 @@ dependencies:
3232
- pytorch>=1.9.*
3333
- future>=0.17.1
3434
- PyYAML>=5.1
35-
- tqdm>=4.41.0
35+
- tqdm>=4.57.0
3636
- fsspec[http]>=2021.06.1
3737
#- tensorboard>=2.2.0 # not needed, already included in pytorch
3838

3939
# Optional
4040
#- nvidia-apex # missing for py3.8
41-
- scikit-learn>=0.20.0
41+
- scikit-learn >0.22.1
4242
- matplotlib>=3.1.1
4343
- omegaconf>=2.0.5
4444

4545
# Examples
4646
- torchvision>=0.10.*
47-
48-
- pip:
49-
- mlflow>=1.0.0
50-
- comet_ml>=3.1.12
51-
- wandb>=0.10.22
52-
- neptune-client>=0.10.0

0 commit comments

Comments
 (0)