Skip to content

Commit c97b349

Browse files
committed
Remove causal option and add batch mask
1 parent 9f65f8e commit c97b349

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

examples/source_separation/conv_tasnet/train.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def _get_model(
101101
msk_num_hidden_feats=512,
102102
msk_num_layers=8,
103103
msk_num_stacks=3,
104-
causal=False,
105104
):
106105
model = conv_tasnet.model.ConvTasNet(
107106
num_sources=num_sources,
@@ -112,7 +111,6 @@ def _get_model(
112111
msk_num_hidden_feats=msk_num_hidden_feats,
113112
msk_num_layers=msk_num_layers,
114113
msk_num_stacks=msk_num_stacks,
115-
causal=causal,
116114
)
117115
_LG.info_on_master("Model Configuration:")
118116
_LG.info_on_master(" - N: %d", enc_num_feats)

examples/source_separation/conv_tasnet/trainer.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,14 @@ def train_one_epoch(self):
9898

9999
num_batches = len(self.train_loader)
100100
for i, batch in enumerate(self.train_loader, start=1):
101-
mixed = batch.mix.to(self.device)
102-
sources = batch.src.to(self.device)
101+
mix = batch.mix.to(self.device)
102+
src = batch.src.to(self.device)
103+
mask = batch.mask.to(self.device)
103104

104-
estimate = self.model(mixed)
105-
si_snri, sdri = si_sdr_improvement(estimate, sources, mixed)
105+
estimate = self.model(mix)
106+
estimate = estimate * mask
107+
108+
si_snri, sdri = si_sdr_improvement(estimate, src, mix)
106109
si_snri = si_snri.mean()
107110
sdri = sdri.mean()
108111

@@ -131,29 +134,28 @@ def validate(self):
131134
def _test(self, loader):
132135
self.model.eval()
133136

134-
total_si_snri = 0.0
135-
total_sdri = 0.0
137+
total_si_snri = torch.zeros(1, dtype=torch.float32, device=self.device)
138+
total_sdri = torch.zeros(1, dtype=torch.float32, device=self.device)
136139

137-
for samples in loader:
138-
# Due to the possible length difference, we run evaluation sample-wise
139-
for sample in samples:
140-
mixed = sample.mix.to(self.device)
141-
sources = sample.src.to(self.device)
140+
for batch in loader:
141+
mix = batch.mix.to(self.device)
142+
src = batch.src.to(self.device)
143+
mask = batch.mask.to(self.device)
142144

143-
estimate = self.model(mixed)
144-
si_snri, sdri = si_sdr_improvement(estimate, sources, mixed)
145-
si_snri = si_snri.sum()
146-
sdri = sdri.sum()
145+
estimate = self.model(mix)
146+
estimate = estimate * mask
147147

148-
dist.all_reduce(si_snri, dist.ReduceOp.SUM)
149-
dist.all_reduce(sdri, dist.ReduceOp.SUM)
148+
si_snri, sdri = si_sdr_improvement(estimate, src, mix)
150149

151-
total_si_snri += si_snri.item()
152-
total_sdri += sdri.item()
150+
total_si_snri += si_snri.sum()
151+
total_sdri += sdri.sum()
153152

154153
if self.debug:
155154
break
156155

156+
dist.all_reduce(total_si_snri, dist.ReduceOp.SUM)
157+
dist.all_reduce(total_sdri, dist.ReduceOp.SUM)
158+
157159
num_samples = len(loader.dataset)
158-
metric = Metric(total_si_snri / num_samples, total_sdri / num_samples)
160+
metric = Metric(total_si_snri.item() / num_samples, total_sdri.item() / num_samples)
159161
return metric

0 commit comments

Comments
 (0)