@@ -220,31 +220,134 @@ const std::string get_filetype(const std::string path) {
220220 return ext;
221221}
222222
223- sox_encoding_t get_encoding (
224- const std::string filetype,
225- const caffe2::TypeMeta dtype) {
226- if (filetype == " mp3" )
227- return SOX_ENCODING_MP3;
228- if (filetype == " flac" )
229- return SOX_ENCODING_FLAC;
230- if (filetype == " ogg" || filetype == " vorbis" )
231- return SOX_ENCODING_VORBIS;
232- if (filetype == " wav" || filetype == " amb" ) {
233- if (dtype == torch::kUInt8 )
234- return SOX_ENCODING_UNSIGNED;
235- if (dtype == torch::kInt16 )
236- return SOX_ENCODING_SIGN2;
237- if (dtype == torch::kInt32 )
238- return SOX_ENCODING_SIGN2;
239- if (dtype == torch::kFloat32 )
240- return SOX_ENCODING_FLOAT;
241- throw std::runtime_error (" Unsupported dtype." );
223+ namespace {
224+
225+ std::tuple<sox_encoding_t , unsigned > get_save_encoding_for_wav (
226+ const std::string format,
227+ const c10::optional<std::string>& encoding,
228+ const c10::optional<int64_t >& bits_per_sample) {
229+ if (!encoding.has_value ()) {
230+ if (!bits_per_sample.has_value ())
231+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
232+ auto val = static_cast <unsigned >(bits_per_sample.value ());
233+ if (val == 8 )
234+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
235+ return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
242236 }
243- if (filetype == " sph" )
244- return SOX_ENCODING_SIGN2;
245- if (filetype == " amr-nb" )
246- return SOX_ENCODING_AMR_NB;
247- throw std::runtime_error (" Unsupported file type: " + filetype);
237+ if (encoding == ENCODING_PCM_SIGNED) {
238+ if (!bits_per_sample.has_value ())
239+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
240+ auto val = static_cast <unsigned >(bits_per_sample.value ());
241+ if (val == 8 ) {
242+ TORCH_WARN_ONCE (" %s does not support 8-bit signed PCM encoding. Using 16-bit." , format);
243+ val = 16 ;
244+ }
245+ return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
246+ }
247+ if (encoding == ENCODING_PCM_UNSIGNED) {
248+ if (!bits_per_sample.has_value ())
249+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
250+ auto val = static_cast <unsigned >(bits_per_sample.value ());
251+ if (val != 8 )
252+ TORCH_WARN_ONCE (" %s only supports 8-bit for unsigned PCM encoding. Using 8-bit." , format);
253+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
254+ }
255+ if (encoding == ENCODING_PCM_FLOAT) {
256+ auto val = static_cast <unsigned >(bits_per_sample.value_or (32 ));
257+ if (val != 32 )
258+ TORCH_WARN_ONCE (" %s only supports 32-bit for floating point PCM encoding. Using 32-bit." , format);
259+ return std::make_tuple<>(SOX_ENCODING_FLOAT, 32 );
260+ }
261+ if (encoding == ENCODING_ULAW) {
262+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
263+ if (val != 8 )
264+ TORCH_WARN_ONCE (" %s only supports 8-bit for mu-law encoding. Using 8-bit." , format);
265+ return std::make_tuple<>(SOX_ENCODING_ULAW, 8 );
266+ }
267+ if (encoding == ENCODING_ALAW) {
268+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
269+ if (val != 8 )
270+ TORCH_WARN_ONCE (" %s only supports 8-bit for a-law encoding. Using 8-bit." , format);
271+ return std::make_tuple<>(SOX_ENCODING_ALAW, 8 );
272+ }
273+ std::ostringstream message;
274+ message << format << " format does not support encoding: " << encoding.value ();
275+ throw std::runtime_error (message.str ());
276+ }
277+
278+ std::tuple<sox_encoding_t , unsigned > get_save_encoding (
279+ const std::string& format,
280+ const c10::optional<std::string>& encoding,
281+ const c10::optional<int64_t >& bits_per_sample) {
282+ if (format == " mp3" ) {
283+ if (encoding.has_value ()) {
284+ TORCH_WARN_ONCE (" mp3 does not support `encoding` option. Ignoring." );
285+ }
286+ if (bits_per_sample.has_value ()) {
287+ TORCH_WARN_ONCE (" mp3 does not `bits_per_sample` option. Ignoring." );
288+ }
289+ return std::make_tuple<>(SOX_ENCODING_MP3, 16 );
290+ }
291+ if (format == " ogg" || format == " vorbis" ) {
292+ if (encoding.has_value ()) {
293+ TORCH_WARN_ONCE (" ogg/vorbis does not support `encoding` option. Ignoring." );
294+ }
295+ if (bits_per_sample.has_value ()) {
296+ TORCH_WARN_ONCE (" ogg/vorbis does not `bits_per_sample` option. Ignoring." );
297+ }
298+ return std::make_tuple<>(SOX_ENCODING_VORBIS, 16 );
299+ }
300+ if (format == " amr-nb" ) {
301+ if (encoding.has_value ()) {
302+ TORCH_WARN_ONCE (" amr-nb does not support `encoding` option. Ignoring." );
303+ }
304+ if (bits_per_sample.has_value ()) {
305+ TORCH_WARN_ONCE (" amr-nb does not `bits_per_sample` option. Ignoring." );
306+ }
307+ return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16 );
308+ }
309+ if (format == " wav" || format == " amb" ) {
310+ return get_save_encoding_for_wav (format, encoding, bits_per_sample);
311+ }
312+ if (format == " flac" ) {
313+ if (encoding.has_value ()) {
314+ TORCH_WARN_ONCE (" flac does not support `encoding` option. Ignoring." );
315+ }
316+ unsigned bps = [&](){
317+ unsigned val = static_cast <unsigned >(bits_per_sample.value_or (24 ));
318+ if (val > 24 ) {
319+ TORCH_WARN_ONCE (" flac does not support bits_per_sample larger than 24. Using 24." );
320+ val = 24 ;
321+ }
322+ return val;
323+ }();
324+ return std::make_tuple<>(SOX_ENCODING_FLAC, bps);
325+ }
326+ if (format == " sph" ) {
327+ if (!encoding.has_value () || encoding == ENCODING_PCM_SIGNED) {
328+ if (!bits_per_sample.has_value ())
329+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
330+ auto val = static_cast <unsigned >(bits_per_sample.value ());
331+ return std::make_tuple<>(SOX_ENCODING_SIGN2, val);
332+ }
333+ if (encoding == ENCODING_PCM_UNSIGNED || encoding == ENCODING_PCM_FLOAT) {
334+ TORCH_WARN_ONCE (" sph does not support unsigned integer PCM or floating point PCM. Using signed interger PCM" );
335+ auto val = static_cast <unsigned >(bits_per_sample.value_or (16 ));
336+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, val);
337+ }
338+ if (encoding == ENCODING_ULAW) {
339+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
340+ if (val != 8 )
341+ TORCH_WARN_ONCE (" sph only supports 8-bit for mu-law encoding. Using 8-bit." );
342+ return std::make_tuple<>(SOX_ENCODING_ULAW, 8 );
343+ }
344+ if (encoding == ENCODING_ALAW) {
345+ auto val = static_cast <unsigned >(bits_per_sample.value_or (8 ));
346+ return std::make_tuple<>(SOX_ENCODING_ALAW, val);
347+ }
348+ throw std::runtime_error (" sph format does not support encoding: " + encoding.value ());
349+ }
350+ throw std::runtime_error (" Unsupported format: " + format);
248351}
249352
250353unsigned get_precision (
@@ -278,6 +381,8 @@ unsigned get_precision(
278381 throw std::runtime_error (" Unsupported file type: " + filetype);
279382}
280383
384+ } // namespace
385+
281386sox_signalinfo_t get_signalinfo (
282387 const torch::Tensor* waveform,
283388 const int64_t sample_rate,
@@ -326,12 +431,14 @@ sox_encodinginfo_t get_tensor_encodinginfo(
326431}
327432
328433sox_encodinginfo_t get_encodinginfo_for_save (
329- const std::string filetype,
330- const caffe2::TypeMeta dtype,
331- c10::optional<double >& compression) {
434+ const std::string& format,
435+ const c10::optional<double >& compression,
436+ const c10::optional<std::string>& encoding,
437+ const c10::optional<int64_t >& bits_per_sample) {
438+ auto enc = get_save_encoding (format, encoding, bits_per_sample);
332439 return sox_encodinginfo_t {
333- /* encoding=*/ get_encoding (filetype, dtype ),
334- /* bits_per_sample=*/ get_precision (filetype, dtype ),
440+ /* encoding=*/ std::get< 0 >(enc ),
441+ /* bits_per_sample=*/ std::get< 1 >(enc ),
335442 /* compression=*/ compression.value_or (HUGE_VAL),
336443 /* reverse_bytes=*/ sox_option_default,
337444 /* reverse_nibbles=*/ sox_option_default,
0 commit comments