diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 9620dd5d45..28ecb76fb9 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -163,6 +163,9 @@ def parse_args(): parser.add_argument( "--file-path", default="", type=str, help="the path of audio files", ) + parser.add_argument( + "--normalization", default=True, action="store_true", help="if True, spectrogram is normalized", + ) args = parser.parse_args() return args @@ -273,7 +276,7 @@ def main(args): n_mels=args.n_freq, fmin=args.f_min, ), - NormalizeDB(min_level_db=args.min_level_db), + NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization), ) train_dataset, val_dataset = split_process_dataset(args, transforms) diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py index b22d60dae4..dca45e97f4 100644 --- a/examples/pipeline_wavernn/processing.py +++ b/examples/pipeline_wavernn/processing.py @@ -31,15 +31,18 @@ class NormalizeDB(nn.Module): r"""Normalize the spectrogram with a minimum db value """ - def __init__(self, min_level_db): + def __init__(self, min_level_db, normalization): super().__init__() self.min_level_db = min_level_db + self.normalization = normalization def forward(self, specgram): - specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) - return torch.clamp( - (self.min_level_db - specgram) / self.min_level_db, min=0, max=1 - ) + specgram = torch.log10(torch.clamp(specgram, min=1e-5)) + if self.normalization: + return torch.clamp( + (self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1 + ) + return specgram def normalized_waveform_to_bits(waveform, bits):