diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 31e69c443e..ef5632ea57 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -24,11 +24,15 @@ def info(filepath: str) -> AudioMetaData: """Get signal information of an audio file. Args: - filepath (str): Path to audio file + filepath (str or pathlib.Path): + Path to audio file. This function also handles ``pathlib.Path`` objects, but is annotated as + ``str`` for TorchScript compiler compatibility. Returns: AudioMetaData: meta data of the given audio. """ + # Cast to str in case type is `pathlib.Path` + filepath = str(filepath) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath) return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels()) @@ -80,8 +84,9 @@ def load( ``[-1.0, 1.0]``. Args: - filepath (str): - Path to audio file + filepath (str or pathlib.Path): + Path to audio file. This function also handles ``pathlib.Path`` objects, but is + annotated as ``str`` for TorchScript compiler compatibility. frame_offset (int): Number of frames to skip before start reading data. num_frames (int): @@ -105,6 +110,8 @@ def load( integer type, else ``float32`` type. If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``. """ + # Cast to str in case type is `pathlib.Path` + filepath = str(filepath) signal = torch.ops.torchaudio.sox_io_load_audio_file( filepath, frame_offset, num_frames, normalize, channels_first) return signal.get_tensor(), signal.get_sample_rate() @@ -140,7 +147,9 @@ def save( and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc. Args: - filepath (str): Path to save file. + filepath (str or pathlib.Path): + Path to save file. This function also handles ``pathlib.Path`` objects, but is annotated + as ``str`` for TorchScript compiler compatibility. tensor (torch.Tensor): Audio data to save. must be 2D tensor. sample_rate (int): sampling rate channels_first (bool): @@ -158,6 +167,8 @@ def save( See the detail at http://sox.sourceforge.net/soxformat.html. """ + # Cast to str in case type is `pathlib.Path` + filepath = str(filepath) if compression is None: ext = str(filepath).split('.')[-1].lower() if ext in ['wav', 'sph']: