From 5040b641b6eaf2ced59dd7042c8d0991832c3113 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 6 Aug 2020 18:17:05 -0700 Subject: [PATCH 1/2] Add spectrogram normalization option --- examples/pipeline_wavernn/main.py | 5 ++++- examples/pipeline_wavernn/processing.py | 14 +++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/pipeline_wavernn/main.py b/examples/pipeline_wavernn/main.py index 9620dd5d45..28ecb76fb9 100644 --- a/examples/pipeline_wavernn/main.py +++ b/examples/pipeline_wavernn/main.py @@ -163,6 +163,9 @@ def parse_args(): parser.add_argument( "--file-path", default="", type=str, help="the path of audio files", ) + parser.add_argument( + "--normalization", default=True, action="store_true", help="if True, spectrogram is normalized", + ) args = parser.parse_args() return args @@ -273,7 +276,7 @@ def main(args): n_mels=args.n_freq, fmin=args.f_min, ), - NormalizeDB(min_level_db=args.min_level_db), + NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization), ) train_dataset, val_dataset = split_process_dataset(args, transforms) diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py index b22d60dae4..9230d1f813 100644 --- a/examples/pipeline_wavernn/processing.py +++ b/examples/pipeline_wavernn/processing.py @@ -31,15 +31,19 @@ class NormalizeDB(nn.Module): r"""Normalize the spectrogram with a minimum db value """ - def __init__(self, min_level_db): + def __init__(self, min_level_db, normalization): super().__init__() self.min_level_db = min_level_db + self.normalization = normalization def forward(self, specgram): - specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) - return torch.clamp( - (self.min_level_db - specgram) / self.min_level_db, min=0, max=1 - ) + if self.normalization: + specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) + return torch.clamp( + (self.min_level_db - specgram) / self.min_level_db, min=0, max=1 + ) + else: + return torch.log10(torch.clamp(specgram, min=1e-5)) def normalized_waveform_to_bits(waveform, bits): From 4487565339570be523ac6dc177058bf7a3d6919b Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Fri, 7 Aug 2020 08:02:28 -0700 Subject: [PATCH 2/2] Update the processing format --- examples/pipeline_wavernn/processing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/pipeline_wavernn/processing.py b/examples/pipeline_wavernn/processing.py index 9230d1f813..dca45e97f4 100644 --- a/examples/pipeline_wavernn/processing.py +++ b/examples/pipeline_wavernn/processing.py @@ -37,13 +37,12 @@ def __init__(self, min_level_db, normalization): self.normalization = normalization def forward(self, specgram): + specgram = torch.log10(torch.clamp(specgram, min=1e-5)) if self.normalization: - specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) return torch.clamp( - (self.min_level_db - specgram) / self.min_level_db, min=0, max=1 + (self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1 ) - else: - return torch.log10(torch.clamp(specgram, min=1e-5)) + return specgram def normalized_waveform_to_bits(waveform, bits):