diff --git a/src/Examples/Program.cs b/src/Examples/Program.cs
index 352f81e24..12ca4d3a3 100644
--- a/src/Examples/Program.cs
+++ b/src/Examples/Program.cs
@@ -12,6 +12,7 @@ public static void Main(string[] args)
             //SequenceToSequence.Main(args);
             //TextClassification.Main(args);
             //ImageTransforms.Main(args);
+            //SpeechCommands.Main(args);
             IOReadWrite.Main(args);
         }
     }
diff --git a/src/Examples/SpeechCommands.cs b/src/Examples/SpeechCommands.cs
new file mode 100644
index 000000000..6293a5983
--- /dev/null
+++ b/src/Examples/SpeechCommands.cs
@@ -0,0 +1,286 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using System.Runtime.InteropServices;
+using static TorchSharp.torch;
+
+using static TorchSharp.torch.nn;
+using static TorchSharp.torch.nn.functional;
+using static TorchSharp.torch.utils.data;
+using static TorchSharp.torchaudio;
+using static TorchSharp.torchaudio.datasets;
+
+namespace TorchSharp.Examples
+{
+    /// <summary>
+    /// SpeechCommands model with convolutions.
+    /// </summary>
+    /// <remarks>
+    /// Translated from Python implementation
+    /// https://pytorch.org/tutorials/intermediate/speech_command_classification_with_torchaudio_tutorial.html
+    /// </remarks>
+    public class SpeechCommands
+    {
+        // The 35 keyword classes of the SpeechCommands dataset.
+        private static readonly string[] Labels = new string[] {
+            "bed", "bird", "backward", "cat", "dog", "down", "eight",
+            "five", "follow", "forward", "four", "go", "happy", "house",
+            "learn", "left", "marvin", "nine", "no", "off", "on", "one",
+            "right", "seven", "six", "sheila", "stop", "three", "tree",
+            "two", "up", "visual", "wow", "yes", "zero"
+        };
+
+        private static int _epochs = 1;
+        private static int _trainBatchSize = 64;
+        private static int _testBatchSize = 128;
+        private static int _sample_rate = 16000;
+        private static int _new_sample_rate = 8000;
+
+        private readonly static int _logInterval = 200;
+
+        // Maps a label string (e.g. "yes") to its class index; built once in Main.
+        private static IDictionary<string, int> _labelToIndex;
+
+        internal static void Main(string[] args)
+        {
+            var dataset = args.Length > 0 ? args[0] : "speechcommands";
+            var datasetPath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData);
+
+            torchaudio.backend.utils.set_audio_backend(new WaveAudioBackend());
+            torch.random.manual_seed(1);
+
+            var cwd = Environment.CurrentDirectory;
+            Console.WriteLine(datasetPath);
+
+            var device = torch.device(torch.cuda.is_available() ? "cuda" : "cpu");
+
+            Console.WriteLine($"Running SpeechCommands on {device.type.ToString()}");
+            Console.WriteLine($"Dataset: {dataset}");
+
+            if (device.type == DeviceType.CUDA) {
+                _trainBatchSize *= 4;
+                _testBatchSize *= 4;
+            }
+
+            var model = new M5("model");
+            model.to(device);
+
+            // Resampling 16 kHz -> 8 kHz reduces compute without hurting keyword discrimination.
+            var transform = torchaudio.transforms.Resample(_sample_rate, _new_sample_rate, device: device);
+
+            _labelToIndex = Labels.Select((label, index) => (label, index)).ToDictionary(t => t.label, t => t.index);
+
+            using (var train_data = SPEECHCOMMANDS(datasetPath, subset: "training", download: true))
+            using (var test_data = SPEECHCOMMANDS(datasetPath, subset: "testing", download: true)) {
+                TrainingLoop("speechcommands", device, model, transform, train_data, test_data);
+            }
+        }
+
+        /// <summary>
+        /// Pads a batch of variable-length waveforms to a common length and maps labels to class indices.
+        /// </summary>
+        private static BatchItem Collate(IEnumerable<SpeechCommandsDatasetItem> items, torch.Device device)
+        {
+            var audio_sequences = items.Select(item => item.waveform.t());
+            var padded_audio = torch.nn.utils.rnn.pad_sequence(audio_sequences, batch_first: true, padding_value: 0.0);
+            padded_audio = padded_audio.permute(0, 2, 1);
+            var labels = items.Select(item => _labelToIndex[item.label]).ToArray();
+            return new BatchItem {
+                audio = padded_audio.to(device),
+                label = torch.tensor(labels, dtype: torch.int64, device: device)
+            };
+        }
+
+        internal static void TrainingLoop(string dataset, Device device, M5 model, ITransform transform, Dataset<SpeechCommandsDatasetItem> train_data, Dataset<SpeechCommandsDatasetItem> test_data)
+        {
+            using (var train_loader = new DataLoader<SpeechCommandsDatasetItem, BatchItem>(
+                train_data, _trainBatchSize, Collate, shuffle: true, device: device))
+            using (var test_loader = new DataLoader<SpeechCommandsDatasetItem, BatchItem>(
+                test_data, _testBatchSize, Collate, shuffle: false, device: device)) {
+                if (device.type == DeviceType.CUDA) {
+                    _epochs *= 4;
+                }
+
+                var optimizer = optim.Adam(model.parameters(), lr: 0.01, weight_decay: 0.0001);
+                var scheduler = optim.lr_scheduler.StepLR(optimizer, step_size: 20, gamma: 0.1);
+
+                Stopwatch sw = new Stopwatch();
+                sw.Start();
+
+                for (var epoch = 1; epoch <= _epochs; epoch++) {
+                    Train(model, transform, optimizer, nll_loss(reduction: torch.nn.Reduction.Mean), train_loader, epoch, train_data.Count);
+                    Test(model, transform, nll_loss(reduction: torch.nn.Reduction.Sum), test_loader, test_data.Count);
+
+                    Console.WriteLine($"End-of-epoch memory use: {GC.GetTotalMemory(false)}");
+                    scheduler.step();
+                }
+
+                sw.Stop();
+                Console.WriteLine($"Elapsed time: {sw.Elapsed.TotalSeconds:F1} s.");
+
+                Console.WriteLine("Saving model to '{0}'", dataset + ".model.bin");
+                model.save(dataset + ".model.bin");
+            }
+        }
+
+        private static void Train(
+            M5 model,
+            ITransform transform,
+            torch.optim.Optimizer optimizer,
+            Loss criteria,
+            DataLoader<SpeechCommandsDatasetItem, BatchItem> dataLoader,
+            int epoch,
+            long size)
+        {
+            int batchId = 1;
+            long total = 0;
+
+            Console.WriteLine($"Epoch: {epoch}...");
+
+            using (var d = torch.NewDisposeScope()) {
+
+                model.train();
+                foreach (var batch in dataLoader) {
+                    var audio = transform.forward(batch.audio);
+                    var target = batch.label;
+                    // Train on the resampled audio, not the raw batch waveform.
+                    var output = model.forward(audio).squeeze();
+                    var loss = criteria(output, target);
+                    optimizer.zero_grad();
+                    loss.backward();
+                    optimizer.step();
+                    total += target.shape[0];
+
+                    if (batchId % _logInterval == 0 || total == size) {
+                        Console.WriteLine($"\rTrain: epoch {epoch} [{total} / {size}] Loss: {loss.ToSingle():F4}");
+                    }
+
+                    batchId++;
+
+                    d.DisposeEverything();
+                }
+            }
+        }
+
+        private static void Test(
+            M5 model,
+            ITransform transform,
+            Loss criteria,
+            DataLoader<SpeechCommandsDatasetItem, BatchItem> dataLoader,
+            long size)
+        {
+            model.eval();
+
+            double testLoss = 0;
+            int correct = 0;
+
+            using (var d = torch.NewDisposeScope()) {
+
+                foreach (var batch in dataLoader) {
+                    var audio = transform.forward(batch.audio);
+                    var target = batch.label;
+                    // Evaluate on the resampled audio, matching the training input.
+                    var output = model.forward(audio).squeeze();
+                    var loss = criteria(output, target);
+                    testLoss += loss.ToSingle();
+
+                    var pred = output.argmax(1);
+                    correct += pred.eq(batch.label).sum().ToInt32();
+
+                    d.DisposeEverything();
+                }
+            }
+
+            Console.WriteLine($"Size: {size}, Total: {size}");
+
+            Console.WriteLine($"\rTest set: Average loss {(testLoss / size):F4} | Accuracy {((double)correct / size):P2}");
+        }
+
+        private class WaveAudioBackend : torchaudio.backend.AudioBackend
+        {
+            public override (torch.Tensor, int) load(string filepath, long frame_offset = 0, long num_frames = -1, bool normalize = true, bool channels_first = true, torchaudio.AudioFormat? format = null)
+            {
+                byte[] data = File.ReadAllBytes(filepath);
+                // In many cases, the first 44 bytes are for RIFF header.
+                short[] waveform = MemoryMarshal.Cast<byte, short>(data.AsSpan(11 * 4)).ToArray();
+                return (torch.tensor(waveform).unsqueeze(0).to(torch.float32) / short.MaxValue, 16000);
+            }
+
+            public override void save(string filepath, torch.Tensor src, int sample_rate, bool channels_first = true, float? compression = null, torchaudio.AudioFormat? format = null, torchaudio.AudioEncoding? encoding = null, int? bits_per_sample = null)
+            {
+                throw new NotImplementedException();
+            }
+
+            public override torchaudio.AudioMetaData info(string filepath, torchaudio.AudioFormat? format = null)
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        private class BatchItem
+        {
+            public torch.Tensor audio;
+            public torch.Tensor label;
+        }
+
+        internal class M5 : Module
+        {
+            private readonly Module conv1;
+            private readonly Module bn1;
+            private readonly Module pool1;
+            private readonly Module conv2;
+            private readonly Module bn2;
+            private readonly Module pool2;
+            private readonly Module conv3;
+            private readonly Module bn3;
+            private readonly Module pool3;
+            private readonly Module conv4;
+            private readonly Module bn4;
+            private readonly Module pool4;
+            private readonly Module fc1;
+
+            public M5(string name, int n_input = 1, int n_output = 35, int stride = 16, int n_channel = 32) : base(name)
+            {
+                conv1 = nn.Conv1d(n_input, n_channel, kernelSize: 80, stride: stride);
+                bn1 = nn.BatchNorm1d(n_channel);
+                pool1 = nn.MaxPool1d(4);
+                conv2 = nn.Conv1d(n_channel, n_channel, kernelSize: 3);
+                bn2 = nn.BatchNorm1d(n_channel);
+                pool2 = nn.MaxPool1d(4);
+                conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernelSize: 3);
+                bn3 = nn.BatchNorm1d(2 * n_channel);
+                pool3 = nn.MaxPool1d(4);
+                conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernelSize: 3);
+                bn4 = nn.BatchNorm1d(2 * n_channel);
+                pool4 = nn.MaxPool1d(4);
+                fc1 = nn.Linear(2 * n_channel, n_output);
+                RegisterComponents();
+            }
+
+            public override Tensor forward(Tensor input)
+            {
+                var x = input;
+                x = conv1.forward(x);
+                x = relu(bn1.forward(x));
+                x = pool1.forward(x);
+                x = conv2.forward(x);
+                x = relu(bn2.forward(x));
+                x = pool2.forward(x);
+                x = conv3.forward(x);
+                x = relu(bn3.forward(x));
+                x = pool3.forward(x);
+                x = conv4.forward(x);
+                x = relu(bn4.forward(x));
+                x = pool4.forward(x);
+                // Global average pooling over the time dimension, then classify.
+                x = avg_pool1d(x, x.shape[x.dim() - 1]);
+                x = x.permute(0, 2, 1);
+                x = fc1.forward(x);
+                return log_softmax(x, dimension: 2);
+            }
+        }
+    }
+}