Example pipeline with wav2letter #632
Conversation
(force-pushed 2da10fd to 728f6c8)

`examples/pipeline/wav2letter.py` (outdated):
```python
    return len(self._iterable)


class MapMemoryCache(torch.utils.data.Dataset):
```
This is great! Seems like a pretty generic object that should also be useful for core. It's effectively a readthrough cache backed by RAM similar to how diskcache_iterator is a readthrough cache backed by disk.
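For readers skimming the thread, the object under discussion looks roughly like this sketch (reconstructed from the excerpts quoted later in this review, assuming cached items are never legitimately None):

```python
import torch


class MapMemoryCache(torch.utils.data.Dataset):
    """Wrap a map-style dataset and cache each item in RAM on first access."""

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        # Read-through: fill the cache on first access, serve from RAM after.
        if self._cache[n] is None:
            self._cache[n] = self.dataset[n]
        return self._cache[n]

    def __len__(self):
        return len(self.dataset)
```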
        
          
`examples/pipeline/wav2letter.py` (outdated):
```python
# return create(["train-clean-100", "train-clean-360", "train-other-500"]), create(["dev-clean", "dev-other"]), None


def which_set(filename, validation_percentage, testing_percentage):
```
Why is this necessary as opposed to a seeded shuffle + split?
This is the script recommended for splitting train/dev/test in SpeechCommands' README, adapted to this use case. I suggest we include it with SpeechCommands in torchaudio.
The advantage of this approach is that words and speakers are better distributed between the different splits.
How are they better distributed in comparison to a random shuffle?
Sorry, the real advantage of this approach is listed in the docstring:

> We want to keep files in the same training, validation, or testing sets even if new ones are added over time. This makes it less likely that testing samples will accidentally be reused in training when long runs are restarted, for example.
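For context, the hash-based assignment from the SpeechCommands README is roughly the following (a sketch adapted from the TensorFlow speech_commands tutorial that the README points to; the constants and regex come from that source, not from this PR):

```python
import hashlib
import os
import re

MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1  # ~134M; stabilizes the hash bucketing


def which_set(filename, validation_percentage, testing_percentage):
    """Deterministically assign a file to "training", "validation", or "testing".

    The assignment depends only on the file name, so it stays stable when new
    files are added to the dataset later.
    """
    base_name = os.path.basename(filename)
    # Strip the "_nohash_" suffix so all clips from the same speaker/word
    # land in the same split.
    hash_name = re.sub(r"_nohash_.*$", "", base_name)
    hash_name_hashed = hashlib.sha1(hash_name.encode("utf-8")).hexdigest()
    percentage_hash = (int(hash_name_hashed, 16) % (MAX_NUM_WAVS_PER_CLASS + 1)) * (
        100.0 / MAX_NUM_WAVS_PER_CLASS
    )
    if percentage_hash < validation_percentage:
        return "validation"
    if percentage_hash < validation_percentage + testing_percentage:
        return "testing"
    return "training"
```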
        
          
`examples/pipeline/wav2letter.py` (outdated):
```python
if c is None:
    c = count
else:
    c = c + count
```
I think you could also use `update` and initialize `c` as an empty `Counter`. Creating a new `Counter` and adding it on each iteration might be quite slow in comparison.
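Something along these lines (a sketch; `counts` is a stand-in for whatever yields the per-item Counter objects):

```python
from collections import Counter

c = Counter()
for count in counts:  # hypothetical iterable of per-item Counters
    c.update(count)   # in-place accumulation, no intermediate Counter objects
```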
        
          
`examples/pipeline/wav2letter.py` (outdated):
```python
    return output[:, 0, :]


def levenshtein_distance_list(r, h):
```
This seems unused
Indeed, only one version is needed. The list version is faster than the PyTorch version, though.
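For reference, a plain-Python two-row dynamic-programming version looks roughly like this (a sketch of the standard algorithm, not necessarily the PR's exact code):

```python
def levenshtein_distance_list(r, h):
    """Edit distance between two sequences via two rolling rows: O(len(r) * len(h))."""
    prev = list(range(len(h) + 1))
    for i, r_i in enumerate(r, start=1):
        curr = [i] + [0] * len(h)
        for j, h_j in enumerate(h, start=1):
            curr[j] = min(
                prev[j] + 1,                 # deletion
                curr[j - 1] + 1,             # insertion
                prev[j - 1] + (r_i != h_j),  # substitution (free if equal)
            )
        prev = curr
    return prev[len(h)]
```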
        
          
`examples/pipeline/wav2letter.py` (outdated):
```python
    data = LIBRISPEECH(
        root, tag, folder_in_archive=folder_in_archive, download=False)
else:
    data = torch.utils.data.ConcatDataset([LIBRISPEECH(
```
This could also be done using `sum`, since a `ConcatDataset` can be created via `__add__`.
Good point
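A sketch of that suggestion (assuming `tags`, `root`, and `folder_in_archive` from the surrounding code):

```python
import functools
import operator

from torchaudio.datasets import LIBRISPEECH

# torch.utils.data.Dataset defines __add__, which returns a ConcatDataset,
# so the per-tag datasets can simply be summed together.
datasets = [
    LIBRISPEECH(root, tag, folder_in_archive=folder_in_archive, download=False)
    for tag in tags
]
data = functools.reduce(operator.add, datasets)
```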
Looks pretty good :)
Codecov Report

```
@@           Coverage Diff           @@
##           master     #632   +/-   ##
=======================================
  Coverage   89.99%   89.99%
=======================================
  Files          35       35
  Lines        2719     2719
=======================================
  Hits         2447     2447
  Misses        272      272
```

Continue to review the full report at Codecov.
I'd add "split into multiple files" as another todo as well.
        
          
`examples/pipeline/metrics.py` (outdated):
```python
import torch


def levenshtein_distance(r: str, h: str, device: Optional[str] = None):
```
This already seems worth a separate PR. Do you agree? In particular, with the C++ extension we can create a JIT-able, fast version of this.
        
          
`examples/pipeline/wav2letter.py` (outdated):
```python
    return args


def signal_handler(a, b):
```
The need for functions like this worries me, because I'd imagine most users are not aware of their necessity or purpose.
They are not "needed" :) I'll remove them to avoid confusion.
(force-pushed c9b4d6c to d347fa4)
```python
        weight_decay=args.weight_decay,
    )
else:
    raise ValueError("Selected optimizer not supported")
```
Repeat the given option, i.e. `"Selected optimizer {} not supported".format(args.optimizer)`, if you're going for this type of input sanitization, to make it easier for the user to debug.
Also, this code is unreachable from the CLI, so `NotImplementedError` makes more sense: the only way to reach it is to add a new choice to the CLI parser and forget to add the actual implementation.
Also, I would extract this into a helper function, e.g. `_get_optimizer(...)`, so that the main logic becomes more readable.
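A sketch of such a helper, combining the suggestions above (the argument names and optimizer choices are assumptions, not the PR's exact code):

```python
import torch


def _get_optimizer(model, args):
    if args.optimizer == "sgd":
        return torch.optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    if args.optimizer == "adadelta":
        return torch.optim.Adadelta(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    # Unreachable from the CLI when argparse restricts choices; reaching it
    # means a parser choice was added without a matching implementation.
    raise NotImplementedError(
        "Selected optimizer {} not supported".format(args.optimizer)
    )
```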
```python
        device=tensors[0].device,
    )

tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
```
A wrapped / generalized version of this could form a useful torchaudio function.
pad_sequence requires transposes as it is, since in torchaudio it is the last dimension that we want to pad. I re-implemented pad_sequence for this use case.
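A sketch of what a last-dimension variant might look like (the name and signature are assumptions; this is not the PR's implementation):

```python
import torch


def pad_last_dim(tensors, padding_value=0.0):
    """Pad a list of tensors along the last dimension to a common length."""
    max_len = max(t.shape[-1] for t in tensors)
    return torch.stack([
        # F.pad's (left, right) pair applies to the last dimension.
        torch.nn.functional.pad(t, (0, max_len - t.shape[-1]), value=padding_value)
        for t in tensors
    ])
```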
```python
def encode(self, iterable):
    if isinstance(iterable, list):
        return [self.encode(i) for i in iterable]
    else:
```
What if I pass an iterable that yields lists? What's the base-case type here? Maybe that's an easier case to branch on. Also, as a very minor nit, I actually like using returns to avoid `else`, so you could write:

```python
if isinstance(iterable, list):
    return [self.encode(i) for i in iterable]
return [self.mapping[i] + self.mapping[self.char_blank] for i in iterable]
```
```python
from torch import topk


class GreedyDecoder:
```
You could generalize this file, call it "decoders.py", and also fold in things such as `compute_error_rates`.
This class is stateless. Can it be a function?
It could be a functional corresponding to a transform, but really it's a step towards our beam-search work.
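As a sketch of the stateless functional version (assuming `outputs` has shape (batch, time, num_tokens); this is not the PR's exact class):

```python
import torch


def greedy_decode(outputs):
    """Return the index of the most likely token at every time step."""
    _, indices = torch.topk(outputs, k=1, dim=-1)
    return indices[..., 0]
```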
```python
metric["dataset length"] += metric["batch size"]
metric["iteration"] += 1
metric["loss"] = loss.item()
metric["cumulative loss"] += metric["loss"]
```
I'd abstract this accumulation for both training and evaluation and merge it into a single function. That way you'll always be sure that both training and evaluation use the exact same calculations, since that's the last place you'd want to be buggy.
Second that; all this logging logic should live on the logger side as a method. That will make the training loop more readable and achieve better decoupling.
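A sketch of moving the bookkeeping onto the logger (the method name and fields are assumptions based on the excerpt above, not the PR's code):

```python
from collections import defaultdict


class MetricLogger(defaultdict):
    def __init__(self):
        super().__init__(float)

    def record_batch(self, loss, batch_size):
        # One shared code path for training and evaluation.
        self["batch size"] = batch_size
        self["dataset length"] += batch_size
        self["iteration"] += 1
        self["loss"] = loss
        self["cumulative loss"] += loss
        self["average loss"] = self["cumulative loss"] / self["iteration"]
```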
```python
logging.info("Start time: %s", datetime.now())

# Explicitly set seed to make sure models created in separate processes
```
If this is for distributed training I'd worry that this isn't already happening. Did you have a case where this became necessary in order to avoid a bug?
```python
    collate_fn=collate_fn_train,
    **loader_training_params,
)
loader_validation = DataLoader(
```
For validation, "drop_last" is usually undesired because you can end up not running on the entire dataset.
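I.e., something like this sketch (the dataset, args, and collate names are assumed from the surrounding code):

```python
from torch.utils.data import DataLoader

loader_validation = DataLoader(
    dataset_validation,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,  # keep the final partial batch so every sample is scored
    collate_fn=collate_fn_valid,
)
```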
```python
    "Checkpoint: loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
)
else:
    logging.info("Checkpoint: not found")
```
This seems like a case I'd error on. If the user intends to resume from this checkpoint and it wasn't found, that's probably a mistake.
Also, the logic here is strange. If the user does not pass the checkpoint option (training from scratch), there is no need to say "not found".
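A sketch of the suggested behavior, silent when no checkpoint is requested and a hard error when a requested one is missing (the function name and checkpoint keys are assumptions based on the excerpt):

```python
import logging
import os

import torch


def maybe_load_checkpoint(model, checkpoint_path):
    """Load a checkpoint if one was requested; stay silent otherwise."""
    if not checkpoint_path:
        return  # training from scratch: no "not found" message
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError("Checkpoint not found: {}".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint["state_dict"])
    logging.info(
        "Checkpoint: loaded '%s' at epoch %s", checkpoint_path, checkpoint["epoch"]
    )
```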
```python
    not args.reduce_lr_valid,
)

if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
```
nit: I sometimes like to save on the indent and write something like:

```python
if epoch < args.epochs - 1 and (epoch + 1) % args.print_freq:
    continue
```
```python
class UnsqueezeFirst(torch.nn.Module):
    def forward(self, tensor):
        return tensor.unsqueeze(0)
```
This is, in my opinion, the sort of issue that makes me dislike using `nn.Sequential` over a function: you end up wrapping simple, small commands in modules.
However, if you write one (or two) collate functions, you'll probably end up writing function factories that essentially do the same.
```python
def save_checkpoint(state, is_best, filename, disable):
    """
    Save the model to a temporary file first,
    then copy it to filename, in case the signal interrupts
```
Has this happened? I think the scheduler is supposed to signal you and then you get a bunch of time to catch the signal and shutdown gracefully.
Also, in my opinion this logic does not correspond to the name of the function: `save_checkpoint` should do saving and only saving. The logic for handling a temporary file for the sake of interruption solves a different concern and should live in a different function.
I'm not aware of this happening. I'll remove this logic.
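If the temporary-file behavior were kept, a sketch of separating the two concerns might look like this (the helper name and the best-model file name are placeholders, not the PR's code):

```python
import os
import shutil
import tempfile

import torch


def _atomic_save(obj, filename):
    # Write to a temporary file in the same directory, then rename;
    # os.replace is atomic when source and destination share a filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(filename) or ".")
    os.close(fd)
    torch.save(obj, tmp_path)
    os.replace(tmp_path, filename)


def save_checkpoint(state, is_best, filename):
    # Saving and only saving; atomicity lives in _atomic_save.
    _atomic_save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")  # placeholder name
```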
I have concerns about the accuracy metrics. I strongly believe that the WER computation is incorrect, and we should not reinvent the wheel; we should use SCTK or something similar.
```python
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
```
This can be simplified:

```python
if self._cache[n] is None:
    self._cache[n] = self.dataset[n]
return self._cache[n]
```
```python
    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
```
This operation is not generic and requires a specific item type, and since it uses index slicing it is very difficult to understand what it does. Please add a docstring.
```python
if isinstance(transforms, list):
    transform_list = transforms
else:
    transform_list = [transforms]
```
Since this is example code, and the helper functions exist to keep the main code simple, making the helpers more specific helps maintainability. Instead of allowing multiple types, it's simpler to accept only one type and do the equivalent type conversion in the client code.
```python
def collate_fn(batch):

    tensors = [transforms(b[0]) for b in batch if b]
```
It is very difficult to understand what is being transformed here.

- `for b in batch if b`: why is there a case where an item in the batch (denoted `b`) can be an invalid sample?
- What does `b[0]` represent?
- `if b` is no longer needed; removed :)
- `b[0]` is the waveform from the processed data point tuple; added a comment.
```python
        self.char_space = char_space
        self.char_blank = char_blank

        labels = [l for l in labels]
```
Can this not be `labels = list(labels)`? What is the expected type of the input `labels`?
Good catch. Yes, it's just a string.
```python
from typing import List, Union


def levenshtein_distance(r: Union[str, List[str]], h: Union[str, List[str]]):
```
If this moves into the library, the docstring needs to be improved with the equation.
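For reference, the standard recurrence such a docstring could state, with $d_{i,j}$ the distance between the first $i$ tokens of $r$ and the first $j$ tokens of $h$:

```latex
d_{i,j} =
\begin{cases}
\max(i, j) & \text{if } \min(i, j) = 0, \\
\min\bigl( d_{i-1,j} + 1,\; d_{i,j-1} + 1,\; d_{i-1,j-1} + [r_i \neq h_j] \bigr) & \text{otherwise,}
\end{cases}
```

so that the distance between $r$ and $h$ is $d_{|r|,|h|}$, where $[\cdot]$ is the Iverson bracket.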
```python
class MetricLogger(defaultdict):
    def __init__(self, name, print_freq=1, disable=False):
        super().__init__(lambda: 0.0)
```
I think `super().__init__(float)` is better (`float()` returns `0.0`, so it works as the default factory without a lambda).
```python
    """

    if disable:
        return
```
I do not think this should be part of the `save_checkpoint` function; it goes against the single-responsibility principle. It is the caller's responsibility to decide when to save.
```python
    return

if filename == "":
    return
```
I think giving an empty string as the file location should be an error.
(force-pushed d8c6de6 to c378c48)
This PR is getting very, very large and has been up for a long time. Let's merge it, since it's already working, and revisit some of these suggested improvements on a PR-by-PR basis. It'll also help us share code across examples etc.
(force-pushed c378c48 to a2b6ad2)
Following comment, reverted to Aug 19, merging, and moved follow-up to vincentqb#3.
As mentioned in the README, we can get less than 13.8% "cer over target length" after 30 epochs. See sample output grepped for validation:
We implement a reference pipeline using the wav2letter model to train on LibriSpeech. The structure is inspired by torchvision's reference implementation.

As discussed here, this code was initially implemented in this python script, which was converted from this notebook to be run with SLURM using a bash script and sbatch. There are at least a few more things to do:

- Add an option to activate torch.autograd.set_detect_anomaly(True)
- Bring back the viterbi decoder
- Add 10 ms shift data augmentation
- Publish pre-trained weights

Note: see also the post by assemblyai, and internal.

cc @zhangguanheng66 for pytorch/text#767