@@ -98,11 +98,14 @@ def train_one_epoch(self):
9898
9999 num_batches = len (self .train_loader )
100100 for i , batch in enumerate (self .train_loader , start = 1 ):
101- mixed = batch .mix .to (self .device )
102- sources = batch .src .to (self .device )
101+ mix = batch .mix .to (self .device )
102+ src = batch .src .to (self .device )
103+ mask = batch .mask .to (self .device )
103104
104- estimate = self .model (mixed )
105- si_snri , sdri = si_sdr_improvement (estimate , sources , mixed )
105+ estimate = self .model (mix )
106+ estimate = estimate * mask
107+
108+ si_snri , sdri = si_sdr_improvement (estimate , src , mix )
106109 si_snri = si_snri .mean ()
107110 sdri = sdri .mean ()
108111
@@ -131,29 +134,28 @@ def validate(self):
131134 def _test (self , loader ):
132135 self .model .eval ()
133136
134- total_si_snri = 0.0
135- total_sdri = 0.0
137+ total_si_snri = torch . zeros ( 1 , dtype = torch . float32 , device = self . device )
138+ total_sdri = torch . zeros ( 1 , dtype = torch . float32 , device = self . device )
136139
137- for samples in loader :
138- # Due to the possible length difference, we run evaluation sample-wise
139- for sample in samples :
140- mixed = sample .mix .to (self .device )
141- sources = sample .src .to (self .device )
140+ for batch in loader :
141+ mix = batch .mix .to (self .device )
142+ src = batch .src .to (self .device )
143+ mask = batch .mask .to (self .device )
142144
143- estimate = self .model (mixed )
144- si_snri , sdri = si_sdr_improvement (estimate , sources , mixed )
145- si_snri = si_snri .sum ()
146- sdri = sdri .sum ()
145+ estimate = self .model (mix )
146+ estimate = estimate * mask
147147
148- dist .all_reduce (si_snri , dist .ReduceOp .SUM )
149- dist .all_reduce (sdri , dist .ReduceOp .SUM )
148+ si_snri , sdri = si_sdr_improvement (estimate , src , mix )
150149
151- total_si_snri += si_snri .item ()
152- total_sdri += sdri .item ()
150+ total_si_snri += si_snri .sum ()
151+ total_sdri += sdri .sum ()
153152
154153 if self .debug :
155154 break
156155
156+ dist .all_reduce (total_si_snri , dist .ReduceOp .SUM )
157+ dist .all_reduce (total_sdri , dist .ReduceOp .SUM )
158+
157159 num_samples = len (loader .dataset )
158- metric = Metric (total_si_snri / num_samples , total_sdri / num_samples )
160+ metric = Metric (total_si_snri . item () / num_samples , total_sdri . item () / num_samples )
159161 return metric
0 commit comments