@@ -91,11 +91,13 @@ def check_dataloader_idx(self, result: Result) -> bool:
9191 random_key = list (result .keys ())[- 1 ]
9292 return result ["meta" ][random_key ]["dataloader_idx" ] is not None
9393
94- def get_latest_from_func_name (self , latest_result , func_name : str , * args , ** kwargs ) -> Dict :
94+ def get_latest_from_func_name (self , latest_result_opt , func_name : str , * args , ** kwargs ) -> Dict :
9595 results = {}
96- add_dataloader_idx = self .check_dataloader_idx (latest_result )
97- func = getattr (latest_result , func_name )
98- results .update (func (* args , add_dataloader_idx = add_dataloader_idx , ** kwargs ))
96+ for opt_idx in latest_result_opt :
97+ latest_result = latest_result_opt [opt_idx ]
98+ add_dataloader_idx = self .check_dataloader_idx (latest_result )
99+ func = getattr (latest_result , func_name )
100+ results .update (func (* args , add_dataloader_idx = add_dataloader_idx , ** kwargs ))
99101 return results
100102
101103 def run_latest_batch_metrics_with_func_name (self , func_name , * args , ** kwargs ) -> List [Dict ]:
@@ -156,6 +158,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
156158 assert isinstance (result , Result )
157159 if dataloader_idx is None :
158160 dataloader_idx = 0
161+
159162 if extra_info is None :
160163 extra_info = {}
161164
@@ -166,22 +169,27 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
166169 if dataloader_idx not in self ._internals :
167170 self ._internals [dataloader_idx ] = {}
168171 self ._internals_reduced [dataloader_idx ] = defaultdict (dict )
172+ self ._latest_ref [dataloader_idx ] = {}
169173
170174 # extract infos
171175 opt_idx = extra_info ["opt_idx" ]
172176 batch_idx = extra_info ["batch_idx" ]
173177
174178 self ._append_to_structure (self ._internals [dataloader_idx ], opt_idx , batch_idx , result )
175179
176- self ._latest_ref [dataloader_idx ] = result
180+ self ._latest_ref [dataloader_idx ][ opt_idx ] = result
177181
178182 # [dataloader_idx] is a list
179183 else :
180184 self ._internal_type = ResultStoreType .OUTSIDE_BATCH_TRAIN_LOOP
181185 self ._internals .setdefault (dataloader_idx , [])
182186 self ._internals [dataloader_idx ].append (result )
183187
184- self ._latest_ref [dataloader_idx ] = result
188+ if dataloader_idx not in self ._latest_ref :
189+ self ._latest_ref [dataloader_idx ] = {}
190+ self ._latest_ref [dataloader_idx ][0 ] = {}
191+
192+ self ._latest_ref [dataloader_idx ][0 ] = result
185193
186194 def auto_reduce_results_on_epoch_end (self ) -> None :
187195 """
@@ -206,13 +214,9 @@ def auto_reduce_results_on_epoch_end(self) -> None:
206214 # TODO: How to start training in middle of epoch
207215 opt_outputs = epoch_metrics [opt_idx ]
208216
209- num_batch_idx = len (self ._internals [dl_idx ][num_opt_idx ]) - 1
210- assert num_batch_idx >= 0
211- batch_indexes = self ._internals [dl_idx ][num_opt_idx ].keys ()
212-
213217 # reduce across time first
214218 time_reduced_outputs = []
215- for batch_idx in batch_indexes :
219+ for batch_idx in opt_outputs . keys () :
216220 tbptt_outs = opt_outputs [batch_idx ]
217221 tbptt_outs = tbptt_outs [0 ].__class__ .reduce_across_time (tbptt_outs )
218222 if len (tbptt_outs ) > 1 :
0 commit comments