@@ -220,31 +220,134 @@ const std::string get_filetype(const std::string path) {
220220  return  ext;
221221}
222222
223- sox_encoding_t  get_encoding (
224-     const  std::string filetype,
225-     const  caffe2::TypeMeta dtype) {
226-   if  (filetype == " mp3" 
227-     return  SOX_ENCODING_MP3;
228-   if  (filetype == " flac" 
229-     return  SOX_ENCODING_FLAC;
230-   if  (filetype == " ogg" " vorbis" 
231-     return  SOX_ENCODING_VORBIS;
232-   if  (filetype == " wav" " amb" 
233-     if  (dtype == torch::kUInt8 )
234-       return  SOX_ENCODING_UNSIGNED;
235-     if  (dtype == torch::kInt16 )
236-       return  SOX_ENCODING_SIGN2;
237-     if  (dtype == torch::kInt32 )
238-       return  SOX_ENCODING_SIGN2;
239-     if  (dtype == torch::kFloat32 )
240-       return  SOX_ENCODING_FLOAT;
241-     throw  std::runtime_error (" Unsupported dtype." 
223+ namespace  {
224+ 
225+ std::tuple<sox_encoding_t , unsigned > get_save_encoding_for_wav (
226+     const  std::string format,
227+     const  c10::optional<std::string>& encoding,
228+     const  c10::optional<int64_t >& bits_per_sample) {
229+   if  (!encoding.has_value ()) {
230+     if  (!bits_per_sample.has_value ())
231+       return  std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
232+     auto  val = static_cast <unsigned >(bits_per_sample.value ());
233+     if  (val == 8 )
234+       return  std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
235+     return  std::make_tuple<>(SOX_ENCODING_SIGN2, val);
242236  }
243-   if  (filetype == " sph" 
244-     return  SOX_ENCODING_SIGN2;
245-   if  (filetype == " amr-nb" 
246-     return  SOX_ENCODING_AMR_NB;
247-   throw  std::runtime_error (" Unsupported file type: " 
237+   if  (encoding == ENCODING_PCM_SIGNED) {
238+     if  (!bits_per_sample.has_value ())
239+       return  std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
240+     auto  val = static_cast <unsigned >(bits_per_sample.value ());
241+     if  (val == 8 ) {
242+       TORCH_WARN_ONCE (" %s does not support 8-bit signed PCM encoding. Using 16-bit." 
243+       val = 16 ;
244+     }
245+     return  std::make_tuple<>(SOX_ENCODING_SIGN2, val);
246+   }
247+   if  (encoding == ENCODING_PCM_UNSIGNED) {
248+     if  (!bits_per_sample.has_value ())
249+       return  std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
250+     auto  val = static_cast <unsigned >(bits_per_sample.value ());
251+     if  (val != 8 )
252+       TORCH_WARN_ONCE (" %s only supports 8-bit for unsigned PCM encoding. Using 8-bit." 
253+     return  std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
254+   }
255+   if  (encoding == ENCODING_PCM_FLOAT) {
256+     auto  val = static_cast <unsigned >(bits_per_sample.value_or (32 ));
257+     if  (val != 32 )
258+       TORCH_WARN_ONCE (" %s only supports 32-bit for floating point PCM encoding. Using 32-bit." 
259+     return  std::make_tuple<>(SOX_ENCODING_FLOAT, 32 );
260+   }
261+   if  (encoding == ENCODING_ULAW) {
262+     auto  val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
263+     if  (val != 8 )
264+       TORCH_WARN_ONCE (" %s only supports 8-bit for mu-law encoding. Using 8-bit." 
265+     return  std::make_tuple<>(SOX_ENCODING_ULAW, 8 );
266+   }
267+   if  (encoding == ENCODING_ALAW) {
268+     auto  val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
269+     if  (val != 8 )
270+       TORCH_WARN_ONCE (" %s only supports 8-bit for a-law encoding. Using 8-bit." 
271+     return  std::make_tuple<>(SOX_ENCODING_ALAW, 8 );      
272+   }
273+   std::ostringstream message;
274+   message << format << "  format does not support encoding: " value ();
275+   throw  std::runtime_error (message.str ());
276+ }
277+ 
278+ std::tuple<sox_encoding_t , unsigned > get_save_encoding (
279+     const  std::string& format,
280+     const  c10::optional<std::string>& encoding,
281+     const  c10::optional<int64_t >& bits_per_sample) {
282+   if  (format == " mp3" 
283+     if  (encoding.has_value ()) {
284+       TORCH_WARN_ONCE (" mp3 does not support `encoding` option. Ignoring." 
285+     }
286+     if  (bits_per_sample.has_value ()) {
287+       TORCH_WARN_ONCE (" mp3 does not `bits_per_sample` option. Ignoring." 
288+     }
289+     return  std::make_tuple<>(SOX_ENCODING_MP3, 16 );
290+   }
291+   if  (format == " ogg" " vorbis" 
292+     if  (encoding.has_value ()) {
293+       TORCH_WARN_ONCE (" ogg/vorbis does not support `encoding` option. Ignoring." 
294+     }
295+     if  (bits_per_sample.has_value ()) {
296+       TORCH_WARN_ONCE (" ogg/vorbis does not `bits_per_sample` option. Ignoring." 
297+     }
298+     return  std::make_tuple<>(SOX_ENCODING_VORBIS, 16 );
299+   }
300+   if  (format == " amr-nb" 
301+     if  (encoding.has_value ()) {
302+       TORCH_WARN_ONCE (" amr-nb does not support `encoding` option. Ignoring." 
303+     }
304+     if  (bits_per_sample.has_value ()) {
305+       TORCH_WARN_ONCE (" amr-nb does not `bits_per_sample` option. Ignoring." 
306+     }
307+     return  std::make_tuple<>(SOX_ENCODING_AMR_NB, 16 );
308+   }
309+   if  (format == " wav" " amb" 
310+     return  get_save_encoding_for_wav (format, encoding, bits_per_sample);
311+   }
312+   if  (format == " flac" 
313+     if  (encoding.has_value ()) {
314+       TORCH_WARN_ONCE (" flac does not support `encoding` option. Ignoring." 
315+     }
316+     unsigned  bps = [&](){
317+       unsigned  val = static_cast <unsigned >(bits_per_sample.value_or (24 ));
318+       if  (val > 24 ) {
319+         TORCH_WARN_ONCE (" flac does not support bits_per_sample larger than 24. Using 24." 
320+         val = 24 ;
321+       }
322+       return  val;
323+     }();
324+     return  std::make_tuple<>(SOX_ENCODING_FLAC, bps);
325+   }
326+   if  (format == " sph" 
327+     if  (!encoding.has_value () || encoding == ENCODING_PCM_SIGNED) {
328+       if  (!bits_per_sample.has_value ())
329+         return  std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
330+       auto  val = static_cast <unsigned >(bits_per_sample.value ());
331+       return  std::make_tuple<>(SOX_ENCODING_SIGN2, val);
332+     }
333+     if  (encoding == ENCODING_PCM_UNSIGNED || encoding == ENCODING_PCM_FLOAT) {
334+       TORCH_WARN_ONCE (" sph does not support unsigned integer PCM or floating point PCM. Using signed interger PCM" 
335+       auto  val = static_cast <unsigned >(bits_per_sample.value_or (16 ));
336+       return  std::make_tuple<>(SOX_ENCODING_UNSIGNED, val);
337+     }
338+     if  (encoding == ENCODING_ULAW) {
339+       auto  val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
340+       if  (val != 8 )
341+         TORCH_WARN_ONCE (" sph only supports 8-bit for mu-law encoding. Using 8-bit." 
342+       return  std::make_tuple<>(SOX_ENCODING_ULAW, 8 );
343+     }
344+     if  (encoding == ENCODING_ALAW) {
345+       auto  val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
346+       return  std::make_tuple<>(SOX_ENCODING_ALAW, val);
347+     }
348+     throw  std::runtime_error (" sph format does not support encoding: " value ());
349+   }
350+   throw  std::runtime_error (" Unsupported format: " 
248351}
249352
250353unsigned  get_precision (
@@ -278,6 +381,8 @@ unsigned get_precision(
278381  throw  std::runtime_error (" Unsupported file type: " 
279382}
280383
384+ } //  namepsace
385+ 
281386sox_signalinfo_t  get_signalinfo (
282387    const  torch::Tensor* waveform,
283388    const  int64_t  sample_rate,
@@ -326,12 +431,14 @@ sox_encodinginfo_t get_tensor_encodinginfo(
326431}
327432
328433sox_encodinginfo_t  get_encodinginfo_for_save (
329-     const  std::string filetype,
330-     const  caffe2::TypeMeta dtype,
331-     c10::optional<double >& compression) {
434+     const  std::string& format,
435+     const  c10::optional<double >& compression,
436+     const  c10::optional<std::string>& encoding,
437+     const  c10::optional<int64_t >& bits_per_sample) {
438+   auto  enc = get_save_encoding (format, encoding, bits_per_sample);
332439  return  sox_encodinginfo_t {
333-       /* encoding=*/ get_encoding (filetype, dtype ),
334-       /* bits_per_sample=*/ get_precision (filetype, dtype ),
440+       /* encoding=*/ std::get< 0 >(enc ),
441+       /* bits_per_sample=*/ std::get< 1 >(enc ),
335442      /* compression=*/ value_or (HUGE_VAL),
336443      /* reverse_bytes=*/ 
337444      /* reverse_nibbles=*/ 
0 commit comments