Skip to content

Commit 5cef977

Browse files
authored
Add tests for GCS filesystem (#7946)
1 parent ced2c94 commit 5cef977

File tree

2 files changed

+108
-1
lines changed

2 files changed

+108
-1
lines changed

requirements/extra.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ matplotlib>3.1
44
horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed
55
omegaconf>=2.0.1
66
torchtext>=0.5
7-
# onnx>=1.7.0
7+
onnx>=1.7.0
88
onnxruntime>=1.3.0
99
hydra-core>=1.0
1010
jsonargparse[signatures]>=3.15.0
11+
gcsfs>=2021.5.0
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import fsspec
17+
import pytest
18+
19+
from pytorch_lightning import Trainer
20+
from pytorch_lightning.callbacks import ModelCheckpoint
21+
from pytorch_lightning.loggers import TensorBoardLogger
22+
from tests.helpers import BoringModel
23+
24+
GCS_BUCKET_PATH = os.getenv("GCS_BUCKET_PATH", None)
25+
_GCS_BUCKET_PATH_AVAILABLE = GCS_BUCKET_PATH is not None
26+
27+
gcs_fs = fsspec.filesystem("gs") if _GCS_BUCKET_PATH_AVAILABLE else None
28+
29+
30+
def gcs_path_join(dir_path):
31+
return GCS_BUCKET_PATH + str(dir_path)
32+
33+
34+
def gcs_rm_dir(dir_path):
35+
gcs_fs.rm(dir_path, recursive=True)
36+
return True
37+
38+
39+
@pytest.mark.skipif(not _GCS_BUCKET_PATH_AVAILABLE, reason="Test requires GCS bucket path")
40+
def test_gcs_model_checkpoint_contents(tmpdir):
41+
dir_path = gcs_path_join(tmpdir)
42+
43+
model = BoringModel()
44+
checkpoint_callback = ModelCheckpoint(dirpath=dir_path, save_top_k=-1, save_last=True)
45+
epochs = 2
46+
47+
trainer = Trainer(
48+
default_root_dir=dir_path,
49+
callbacks=[checkpoint_callback],
50+
limit_train_batches=10,
51+
limit_val_batches=10,
52+
max_epochs=2,
53+
logger=False,
54+
)
55+
56+
trainer.fit(model)
57+
58+
assert checkpoint_callback.best_model_path == os.path.join(dir_path, 'epoch=1-step=19.ckpt')
59+
assert checkpoint_callback.last_model_path == os.path.join(dir_path, 'last.ckpt')
60+
61+
expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])]
62+
expected.append('last.ckpt')
63+
64+
gcs_ckpt_paths = [os.path.basename(path) for path in gcs_fs.listdir(dir_path, detail=False)]
65+
assert gcs_ckpt_paths == expected
66+
67+
assert gcs_rm_dir(dir_path)
68+
69+
70+
@pytest.mark.skipif(not _GCS_BUCKET_PATH_AVAILABLE, reason="Test requires GCS bucket path")
71+
def test_gcs_logging(tmpdir):
72+
dir_path = gcs_path_join(tmpdir)
73+
74+
name = "tb_versioning"
75+
log_dir = os.path.join(dir_path, name)
76+
gcs_fs.mkdir(log_dir)
77+
expected_version = "101"
78+
79+
logger = TensorBoardLogger(save_dir=dir_path, name=name, version=expected_version)
80+
logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})
81+
82+
assert logger.version == expected_version
83+
84+
gcs_paths = [os.path.basename(path) for path in gcs_fs.listdir(log_dir, detail=False)]
85+
gcs_paths = list(filter(lambda x: len(x) > 0, gcs_paths))
86+
87+
assert gcs_paths == [expected_version]
88+
assert gcs_fs.listdir(os.path.join(log_dir, expected_version), detail=False)
89+
90+
assert gcs_rm_dir(dir_path)
91+
92+
93+
@pytest.mark.skipif(not _GCS_BUCKET_PATH_AVAILABLE, reason="Test requires GCS bucket path")
94+
def test_gcs_save_hparams_to_yaml_file(tmpdir):
95+
dir_path = gcs_path_join(tmpdir)
96+
97+
model = BoringModel()
98+
logger = TensorBoardLogger(save_dir=dir_path, default_hp_metric=False)
99+
trainer = Trainer(max_steps=1, default_root_dir=dir_path, logger=logger)
100+
assert trainer.log_dir == trainer.logger.log_dir
101+
trainer.fit(model)
102+
103+
hparams_file = "hparams.yaml"
104+
assert gcs_fs.isfile(os.path.join(trainer.log_dir, hparams_file))
105+
106+
assert gcs_rm_dir(dir_path)

0 commit comments

Comments
 (0)