|
5 | 5 | namespace torchaudio { |
6 | 6 | namespace sox_utils { |
7 | 7 |
|
| 8 | +void set_seed(const int64_t seed) { |
| 9 | + sox_get_globals()->ranqd1 = static_cast<sox_int32_t>(seed); |
| 10 | +} |
| 11 | + |
| 12 | +void set_verbosity(const int64_t verbosity) { |
| 13 | + sox_get_globals()->verbosity = static_cast<unsigned>(verbosity); |
| 14 | +} |
| 15 | + |
| 16 | +void set_use_threads(const bool use_threads) { |
| 17 | + sox_get_globals()->use_threads = static_cast<sox_bool>(use_threads); |
| 18 | +} |
| 19 | + |
| 20 | +void set_buffer_size(const int64_t buffer_size) { |
| 21 | + sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size); |
| 22 | +} |
| 23 | + |
| 24 | +std::vector<std::vector<std::string>> list_effects() { |
| 25 | + std::vector<std::vector<std::string>> effects; |
| 26 | + for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) { |
| 27 | + const sox_effect_handler_t* handler = (*fns)(); |
| 28 | + if (handler && handler->name) { |
| 29 | + if (UNSUPPORTED_EFFECTS.find(handler->name) == |
| 30 | + UNSUPPORTED_EFFECTS.end()) { |
| 31 | + effects.emplace_back(std::vector<std::string>{ |
| 32 | + handler->name, |
| 33 | + handler->usage ? std::string(handler->usage) : std::string("")}); |
| 34 | + } |
| 35 | + } |
| 36 | + } |
| 37 | + return effects; |
| 38 | +} |
| 39 | + |
| 40 | +std::vector<std::string> list_formats() { |
| 41 | + std::vector<std::string> formats; |
| 42 | + for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { |
| 43 | + for (const char* const* names = fns->fn()->names; *names; ++names) { |
| 44 | + if (!strchr(*names, '/')) |
| 45 | + formats.emplace_back(*names); |
| 46 | + } |
| 47 | + } |
| 48 | + return formats; |
| 49 | +} |
| 50 | + |
8 | 51 | TensorSignal::TensorSignal( |
9 | 52 | torch::Tensor tensor_, |
10 | 53 | int64_t sample_rate_, |
@@ -205,13 +248,13 @@ unsigned get_precision( |
205 | 248 | } |
206 | 249 |
|
207 | 250 | sox_signalinfo_t get_signalinfo( |
208 | | - const torch::Tensor& tensor, |
209 | | - const int64_t sample_rate, |
210 | | - const bool channels_first, |
| 251 | + const TensorSignal* signal, |
211 | 252 | const std::string filetype) { |
| 253 | + auto tensor = signal->getTensor(); |
212 | 254 | return sox_signalinfo_t{ |
213 | | - /*rate=*/static_cast<sox_rate_t>(sample_rate), |
214 | | - /*channels=*/static_cast<unsigned>(tensor.size(channels_first ? 0 : 1)), |
| 255 | + /*rate=*/static_cast<sox_rate_t>(signal->getSampleRate()), |
| 256 | + /*channels=*/ |
| 257 | + static_cast<unsigned>(tensor.size(signal->getChannelsFirst() ? 0 : 1)), |
215 | 258 | /*precision=*/get_precision(filetype, tensor.dtype()), |
216 | 259 | /*length=*/static_cast<uint64_t>(tensor.numel())}; |
217 | 260 | } |
|
0 commit comments