diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index b086582dbe..9c34846b9e 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -6,17 +6,21 @@ from parameterized import parameterized from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin from torchaudio_unittest.common_utils import ( skipIfRocm, ) -class Functional(common_utils.TestBaseMixin): +class Functional(TempDirMixin, TestBaseMixin): """Implements test for `functinoal` modul that are performed for different devices""" def _assert_consistency(self, func, tensor, shape_only=False): tensor = tensor.to(device=self.device, dtype=self.dtype) - ts_func = torch.jit.script(func) + path = self.get_temp_path('func.zip') + torch.jit.script(func).save(path) + ts_func = torch.jit.load(path) + output = func(tensor) ts_output = ts_func(tensor) if shape_only: @@ -565,7 +569,7 @@ def func(tensor): self._assert_consistency(func, tensor) -class FunctionalComplex: +class FunctionalComplex(TempDirMixin, TestBaseMixin): complex_dtype = None real_dtype = None device = None @@ -573,7 +577,10 @@ class FunctionalComplex: def _assert_consistency(self, func, tensor, test_pseudo_complex=False): assert tensor.is_complex() tensor = tensor.to(device=self.device, dtype=self.complex_dtype) - ts_func = torch.jit.script(func) + + path = self.get_temp_path('func.zip') + torch.jit.script(func).save(path) + ts_func = torch.jit.load(path) if test_pseudo_complex: tensor = torch.view_as_real(tensor) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 40e9a287dd..64ecbe53f7 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -7,16 +7,21 @@ from torchaudio_unittest import common_utils from torchaudio_unittest.common_utils import ( skipIfRocm, + TempDirMixin, + TestBaseMixin, ) -class Transforms(common_utils.TestBaseMixin): +class Transforms(TempDirMixin, TestBaseMixin): """Implements test for Transforms that are performed for different devices""" def _assert_consistency(self, transform, tensor): tensor = tensor.to(device=self.device, dtype=self.dtype) transform = transform.to(device=self.device, dtype=self.dtype) - ts_transform = torch.jit.script(transform) + path = self.get_temp_path('transform.zip') + torch.jit.script(transform).save(path) + ts_transform = torch.jit.load(path) + output = transform(tensor) ts_output = ts_transform(tensor) self.assertEqual(ts_output, output) @@ -39,8 +44,8 @@ def test_AmplitudeToDB(self): self._assert_consistency(T.AmplitudeToDB(), spec) def test_MelScale(self): - spec_f = torch.rand((1, 6, 201)) - self._assert_consistency(T.MelScale(), spec_f) + spec_f = torch.rand((1, 201, 6)) + self._assert_consistency(T.MelScale(n_stft=201), spec_f) def test_MelSpectrogram(self): tensor = torch.rand((1, 1000)) @@ -100,7 +105,7 @@ def test_SpectralCentroid(self): self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) -class TransformsComplex: +class TransformsComplex(TempDirMixin, TestBaseMixin): complex_dtype = None real_dtype = None device = None @@ -109,7 +114,10 @@ def _assert_consistency(self, transform, tensor, test_pseudo_complex=False): assert tensor.is_complex() tensor = tensor.to(device=self.device, dtype=self.complex_dtype) transform = transform.to(device=self.device, dtype=self.real_dtype) - ts_transform = torch.jit.script(transform) + + path = self.get_temp_path('transform.zip') + torch.jit.script(transform).save(path) + ts_transform = torch.jit.load(path) if test_pseudo_complex: tensor = torch.view_as_real(tensor)