Skip to content

Commit f6d892a

Browse files
authored
[feat] Support custom filesystems in LightningModule.to_torchscript (#7617)
* [feat] Support custom filesystems in LightningModule.to_torchscript * Update CHANGELOG.md * Update test_torchscript.py * Update test_torchscript.py * Update CHANGELOG.md * Update test_torchscript.py
1 parent e8a46be commit f6d892a

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
13+
14+
1215
- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow
1316

1417

pytorch_lightning/core/lightning.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from pytorch_lightning.core.step_result import Result
4141
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
4242
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
43+
from pytorch_lightning.utilities.cloud_io import get_filesystem
4344
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
4445
from pytorch_lightning.utilities.exceptions import MisconfigurationException
4546
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
@@ -1860,7 +1861,9 @@ def to_torchscript(
18601861
self.train(mode)
18611862

18621863
if file_path is not None:
1863-
torch.jit.save(torchscript_module, file_path)
1864+
fs = get_filesystem(file_path)
1865+
with fs.open(file_path, "wb") as f:
1866+
torch.jit.save(torchscript_module, f)
18641867

18651868
return torchscript_module
18661869

tests/models/test_torchscript.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
17+
import fsspec
1518
import pytest
1619
import torch
20+
from fsspec.implementations.local import LocalFileSystem
1721

22+
from pytorch_lightning.utilities.cloud_io import get_filesystem
1823
from tests.helpers import BoringModel
1924
from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
2025
from tests.helpers.datamodules import MNISTDataModule
@@ -139,6 +144,34 @@ def test_torchscript_save_load(tmpdir, modelclass):
139144
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
140145

141146

147+
@pytest.mark.parametrize("modelclass", [
148+
BoringModel,
149+
ParityModuleRNN,
150+
BasicGAN,
151+
])
152+
@RunIf(min_torch="1.5.0")
153+
def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
154+
""" Test that scripted LightningModule is correctly saved and can be loaded with custom filesystems. """
155+
156+
_DUMMY_PRFEIX = "dummy"
157+
_PREFIX_SEPARATOR = "://"
158+
159+
class DummyFileSystem(LocalFileSystem):
160+
...
161+
162+
fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True)
163+
164+
model = modelclass()
165+
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmpdir, "model.pt")
166+
script = model.to_torchscript(file_path=output_file)
167+
168+
fs = get_filesystem(output_file)
169+
with fs.open(output_file, "rb") as f:
170+
loaded_script = torch.jit.load(f)
171+
172+
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
173+
174+
142175
def test_torchcript_invalid_method(tmpdir):
143176
"""Test that an error is thrown with invalid torchscript method"""
144177
model = BoringModel()

0 commit comments

Comments
 (0)