@@ -865,37 +865,89 @@ def model_calibration(self,
 
     def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration, conf=None):
         results = []
-        for idx, (input, label) in enumerate(dataloader):
-            if measurer is not None:
-                measurer.start()
-
-            output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
-            if self.device != "cpu":  # pragma: no cover
-                output = output.to("cpu")
-                label = label.to("cpu")
-            if measurer is not None:
-                measurer.end()
-            if postprocess is not None:
-                output, label = postprocess((output, label))
-            if metrics:
-                for metric in metrics:
-                    if not hasattr(metric, "compare_label") or \
-                        (hasattr(metric, "compare_label") and metric.compare_label):
-                        metric.update(output, label)
-
-                # If distributed dataloader, gather all outputs to update metric
-                if getattr(dataloader, 'distributed', False) or \
-                    isinstance(dataloader.sampler, \
-                    torch.utils.data.distributed.DistributedSampler):
-                    hvd.init()
-                    for metric in metrics:
-                        metric.hvd = hvd
-
-            if self.fp32_preds_as_label:
-                self.fp32_results.append(output) if self.is_baseline else \
-                    results.append(output)
-            if idx + 1 == iteration:
-                break
+        try:
+            for idx, (input, label) in enumerate(dataloader):
+                if measurer is not None:
+                    measurer.start()
+
+                output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
+                if self.device != "cpu":  # pragma: no cover
+                    output = output.to("cpu")
+                    label = label.to("cpu")
+                if measurer is not None:
+                    measurer.end()
+                if postprocess is not None:
+                    output, label = postprocess((output, label))
+                if metrics:
+                    for metric in metrics:
+                        if not hasattr(metric, "compare_label") or \
+                            (hasattr(metric, "compare_label") and metric.compare_label):
+                            metric.update(output, label)
+
+                    # If distributed dataloader, gather all outputs to update metric
+                    if getattr(dataloader, 'distributed', False) or \
+                        isinstance(dataloader.sampler, \
+                        torch.utils.data.distributed.DistributedSampler):
+                        hvd.init()
+                        for metric in metrics:
+                            metric.hvd = hvd
+
+                if self.fp32_preds_as_label:
+                    dst = self.fp32_results if self.is_baseline else results
+                    dst.append(output)
+                if idx + 1 == iteration:
+                    break
+        except Exception:  # batches don't unpack into (input, label); retry input-only
+            logger.warning("The dataloader didn't include labels, will try input without labels!")
+            for idx, input in enumerate(dataloader):
+                if isinstance(input, (dict, UserDict)):
+                    if not self.benchmark:
+                        assert "label" in input, \
+                            "The dataloader must include labels to measure the metric!"
+                        label = input["label"].to("cpu")
+                elif not self.benchmark:
+                    assert False, "The dataloader must include labels to measure the metric!"
+
+                if measurer is not None:
+                    measurer.start()
+
+                output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
+
+                if measurer is not None:
+                    measurer.end()
+
+                if self.device != "cpu" and not self.benchmark:  # pragma: no cover
+                    if isinstance(output, (dict, UserDict)):
+                        for key in output:
+                            output[key] = output[key].to("cpu")
+                    elif isinstance(output, (list, tuple)):
+                        # Rebuild the sequence; rebinding the loop variable moves nothing.
+                        output = [tensor.to("cpu") for tensor in output]
+                    else:
+                        output = output.to("cpu")
+
+                if postprocess is not None and not self.benchmark:
+                    output, label = postprocess((output, label))
+
+                if metrics and not self.benchmark:
+                    for metric in metrics:
+                        if not hasattr(metric, "compare_label") or \
+                            (hasattr(metric, "compare_label") and metric.compare_label):
+                            metric.update(output, label)
+
+                    # If distributed dataloader, gather all outputs to update metric
+                    if getattr(dataloader, 'distributed', False) or \
+                        isinstance(dataloader.sampler, \
+                        torch.utils.data.distributed.DistributedSampler):
+                        hvd.init()
+                        for metric in metrics:
+                            metric.hvd = hvd
+
+                if self.fp32_preds_as_label:
+                    dst = self.fp32_results if self.is_baseline else results
+                    dst.append(output)
+                if idx + 1 == iteration:
+                    break
         return results
 
     def model_eval(self,
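Reviewer note: the new `try`/`except` path exists because some dataloaders yield `(input, label)` pairs while others yield bare inputs or dict-style batches. Below is a minimal, self-contained sketch of that fallback pattern; the names (`evaluate`, the toy `Linear` model) are hypothetical, and it deliberately catches only `(TypeError, ValueError)` from the failed unpacking rather than the hunk's bare `Exception`:

```python
import torch

model = torch.nn.Linear(4, 2)

def evaluate(dataloader):
    results = []
    try:
        # First pass: assume every batch unpacks into an (input, label) pair.
        for input, label in dataloader:
            results.append(model(input))
    except (TypeError, ValueError):
        # Unpacking failed: retry, treating each batch as input only.
        results = [model(input) for input in dataloader]
    return results

with_labels = [(torch.randn(8, 4), torch.zeros(8, dtype=torch.long))]
without_labels = [torch.randn(8, 4)]
print(len(evaluate(with_labels)), len(evaluate(without_labels)))  # -> 1 1
```

Catching the narrower exception pair avoids masking genuine forward-pass failures, which a bare `except Exception` would silently reroute into the no-label branch.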
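A related aside: the branchy CPU transfer over dict/list/tuple outputs could be collapsed into one recursive helper. A minimal sketch, assuming outputs are tensors nested in plain dicts, lists, or tuples; the helper name `to_cpu` is hypothetical (not part of this codebase) and it would not handle namedtuples or custom containers:

```python
import torch

def to_cpu(output):
    """Recursively move tensors nested in dicts, lists, and tuples to the CPU."""
    if isinstance(output, torch.Tensor):
        return output.to("cpu")
    if isinstance(output, dict):
        return {key: to_cpu(value) for key, value in output.items()}
    if isinstance(output, (list, tuple)):
        return type(output)(to_cpu(item) for item in output)
    return output  # non-tensor leaves pass through unchanged

nested = {"logits": torch.randn(2, 3), "hidden": [torch.randn(2, 8)]}
moved = to_cpu(nested)  # every tensor in the structure now lives on the CPU
```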