11import  io 
22import  itertools 
33
4+ import  torch 
45from  torchaudio .backend  import  sox_io_backend 
56from  parameterized  import  parameterized 
67
@@ -24,7 +25,7 @@ def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
2425        """`sox_io_backend.save` can save wav format.""" 
2526        path  =  self .get_temp_path ('data.wav' )
2627        expected  =  get_wav_data (dtype , num_channels , num_frames = num_frames )
27-         sox_io_backend .save (path , expected , sample_rate )
28+         sox_io_backend .save (path , expected , sample_rate ,  dtype = None )
2829        found , sr  =  load_wav (path )
2930        assert  sample_rate  ==  sr 
3031        self .assertEqual (found , expected )
@@ -68,7 +69,7 @@ def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
6869        save_wav (src_path , data , sample_rate )
6970        # 2.1. Convert the original wav to mp3 with torchaudio 
7071        sox_io_backend .save (
71-             mp3_path , load_wav (src_path )[0 ], sample_rate , compression = bit_rate )
72+             mp3_path , load_wav (src_path )[0 ], sample_rate , compression = bit_rate ,  dtype = None )
7273        # 2.2. Convert the mp3 to wav with Sox 
7374        sox_utils .convert_audio_file (mp3_path , wav_path )
7475        # 2.3. Load 
@@ -99,7 +100,7 @@ def assert_flac(self, sample_rate, num_channels, compression_level, duration):
99100        save_wav (src_path , data , sample_rate )
100101        # 2.1. Convert the original wav to flac with torchaudio 
101102        sox_io_backend .save (
102-             flc_path , load_wav (src_path )[0 ], sample_rate , compression = compression_level )
103+             flc_path , load_wav (src_path )[0 ], sample_rate , compression = compression_level ,  dtype = None )
103104        # 2.2. Convert the flac to wav with Sox 
104105        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. 
105106        sox_utils .convert_audio_file (flc_path , wav_path , bit_depth = 32 )
@@ -132,7 +133,7 @@ def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
132133        save_wav (src_path , data , sample_rate )
133134        # 2.1. Convert the original wav to vorbis with torchaudio 
134135        sox_io_backend .save (
135-             vbs_path , load_wav (src_path )[0 ], sample_rate , compression = quality_level )
136+             vbs_path , load_wav (src_path )[0 ], sample_rate , compression = quality_level ,  dtype = None )
136137        # 2.2. Convert the vorbis to wav with Sox 
137138        sox_utils .convert_audio_file (vbs_path , wav_path )
138139        # 2.3. Load 
@@ -184,7 +185,7 @@ def assert_sphere(self, sample_rate, num_channels, duration):
184185        data  =  get_wav_data ('float32' , num_channels , normalize = True , num_frames = duration  *  sample_rate )
185186        save_wav (src_path , data , sample_rate )
186187        # 2.1. Convert the original wav to sph with torchaudio 
187-         sox_io_backend .save (flc_path , load_wav (src_path )[0 ], sample_rate )
188+         sox_io_backend .save (flc_path , load_wav (src_path )[0 ], sample_rate ,  dtype = None )
188189        # 2.2. Convert the sph to wav with Sox 
189190        # converting to 32 bit because sph file has 24 bit depth which scipy cannot handle. 
190191        sox_utils .convert_audio_file (flc_path , wav_path , bit_depth = 32 )
@@ -216,7 +217,7 @@ def assert_amb(self, dtype, sample_rate, num_channels, duration):
216217        data  =  get_wav_data (dtype , num_channels , normalize = False , num_frames = duration  *  sample_rate )
217218        save_wav (src_path , data , sample_rate )
218219        # 2.1. Convert the original wav to amb with torchaudio 
219-         sox_io_backend .save (amb_path , load_wav (src_path , normalize = False )[0 ], sample_rate )
220+         sox_io_backend .save (amb_path , load_wav (src_path , normalize = False )[0 ], sample_rate ,  dtype = None )
220221        # 2.2. Convert the amb to wav with Sox 
221222        sox_utils .convert_audio_file (amb_path , wav_path )
222223        # 2.3. Load 
@@ -248,7 +249,7 @@ def assert_amr_nb(self, duration):
248249        data  =  get_wav_data ('int16' , num_channels , normalize = False , num_frames = duration  *  sample_rate )
249250        save_wav (src_path , data , sample_rate )
250251        # 2.1. Convert the original wav to amr_nb with torchaudio 
251-         sox_io_backend .save (amr_path , load_wav (src_path , normalize = False )[0 ], sample_rate )
252+         sox_io_backend .save (amr_path , load_wav (src_path , normalize = False )[0 ], sample_rate ,  dtype = None )
252253        # 2.2. Convert the amr_nb to wav with Sox 
253254        sox_utils .convert_audio_file (amr_path , wav_path )
254255        # 2.3. Load 
@@ -389,7 +390,7 @@ def test_channels_first(self, channels_first):
389390        path  =  self .get_temp_path ('data.wav' )
390391        data  =  get_wav_data ('int32' , 2 , channels_first = channels_first )
391392        sox_io_backend .save (
392-             path , data , 8000 , channels_first = channels_first )
393+             path , data , 8000 , channels_first = channels_first ,  dtype = None )
393394        found  =  load_wav (path )[0 ]
394395        expected  =  data  if  channels_first  else  data .transpose (1 , 0 )
395396        self .assertEqual (found , expected )
@@ -402,7 +403,7 @@ def test_noncontiguous(self, dtype):
402403        path  =  self .get_temp_path ('data.wav' )
403404        expected  =  get_wav_data (dtype , 4 )[::2 , ::2 ]
404405        assert  not  expected .is_contiguous ()
405-         sox_io_backend .save (path , expected , 8000 )
406+         sox_io_backend .save (path , expected , 8000 ,  dtype = None )
406407        found  =  load_wav (path )[0 ]
407408        self .assertEqual (found , expected )
408409
@@ -415,10 +416,24 @@ def test_tensor_preserve(self, dtype):
415416        expected  =  get_wav_data (dtype , 4 )[::2 , ::2 ]
416417
417418        data  =  expected .clone ()
418-         sox_io_backend .save (path , data , 8000 )
419+         sox_io_backend .save (path , data , 8000 ,  dtype = None )
419420
420421        self .assertEqual (data , expected )
421422
423+     @parameterized .expand ([  # each case: (dtype passed to `save`, tensor expected when the file is read back)
424+         ('float32' , torch .tensor ([- 1.0 , - 0.5 , 0 , 0.5 , 1.0 ]).to (torch .float32 )), 
425+         ('int32' , torch .tensor ([- 2147483648 , - 1073741824 , 0 , 1073741824 , 2147483647 ]).to (torch .int32 )), 
426+         ('int16' , torch .tensor ([- 32768 , - 16384 , 0 , 16384 , 32767 ]).to (torch .int16 )), 
427+         ('uint8' , torch .tensor ([0 , 64 , 128 , 192 , 255 ]).to (torch .uint8 )), 
428+     ]) 
429+     def  test_dtype_conversion (self , dtype , expected ):
430+         """`save` performs dtype conversion on float32 src tensors only: a float32 tensor is saved with the given `dtype`, read back with `normalize=False`, and compared with `expected`.""" 
431+         path  =  self .get_temp_path ("data.wav" )
432+         data  =  torch .tensor ([- 1.0 , - 0.5 , 0 , 0.5 , 1.0 ]).to (torch .float32 ).view (- 1 , 1 )  # five float32 samples reshaped to a (5, 1) tensor
433+         sox_io_backend .save (path , data , 8000 , dtype = dtype )  # 8000 is passed in the sample-rate position
434+         found  =  load_wav (path , normalize = False )[0 ]  # first element of the result is the loaded tensor (the second is the sample rate elsewhere in this file)
435+         self .assertEqual (found , expected .view (- 1 , 1 ))  # reshape `expected` to (5, 1) so it lines up with `data`
436+ 
422437
423438@skipIfNoExtension  
424439@skipIfNoExec ('sox' ) 
@@ -452,11 +467,11 @@ def test_fileobj(self, ext, compression):
452467        res_path  =  self .get_temp_path (f'test.{ ext }  ' )
453468        sox_io_backend .save (
454469            ref_path , data , channels_first = channels_first ,
455-             sample_rate = sample_rate , compression = compression )
470+             sample_rate = sample_rate , compression = compression ,  dtype = None )
456471        with  open (res_path , 'wb' ) as  fileobj :
457472            sox_io_backend .save (
458473                fileobj , data , channels_first = channels_first ,
459-                 sample_rate = sample_rate , compression = compression , format = ext )
474+                 sample_rate = sample_rate , compression = compression , format = ext ,  dtype = None )
460475
461476        expected_data , _  =  sox_io_backend .load (ref_path )
462477        data , sr  =  sox_io_backend .load (res_path )
@@ -489,11 +504,11 @@ def test_bytesio(self, ext, compression):
489504        res_path  =  self .get_temp_path (f'test.{ ext }  ' )
490505        sox_io_backend .save (
491506            ref_path , data , channels_first = channels_first ,
492-             sample_rate = sample_rate , compression = compression )
507+             sample_rate = sample_rate , compression = compression ,  dtype = None )
493508        fileobj  =  io .BytesIO ()
494509        sox_io_backend .save (
495510            fileobj , data , channels_first = channels_first ,
496-             sample_rate = sample_rate , compression = compression , format = ext )
511+             sample_rate = sample_rate , compression = compression , format = ext ,  dtype = None )
497512        fileobj .seek (0 )
498513        with  open (res_path , 'wb' ) as  file_ :
499514            file_ .write (fileobj .read ())
0 commit comments