diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index b0b51040510..23e12ab40f6 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -865,37 +865,89 @@ def model_calibration(self, def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration, conf=None): results = [] - for idx, (input, label) in enumerate(dataloader): - if measurer is not None: - measurer.start() - - output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf) - if self.device != "cpu": # pragma: no cover - output = output.to("cpu") - label = label.to("cpu") - if measurer is not None: - measurer.end() - if postprocess is not None: - output, label = postprocess((output, label)) - if metrics: - for metric in metrics: - if not hasattr(metric, "compare_label") or \ - (hasattr(metric, "compare_label") and metric.compare_label): - metric.update(output, label) - - # If distributed dataloader, gather all outputs to update metric - if getattr(dataloader, 'distributed', False) or \ - isinstance(dataloader.sampler, \ - torch.utils.data.distributed.DistributedSampler): - hvd.init() + try: + for idx, (input, label) in enumerate(dataloader): + if measurer is not None: + measurer.start() + + output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf) + if self.device != "cpu": # pragma: no cover + output = output.to("cpu") + label = label.to("cpu") + if measurer is not None: + measurer.end() + if postprocess is not None: + output, label = postprocess((output, label)) + if metrics: for metric in metrics: - metric.hvd = hvd + if not hasattr(metric, "compare_label") or \ + (hasattr(metric, "compare_label") and metric.compare_label): + metric.update(output, label) + + # If distributed dataloader, gather all outputs to update metric + if getattr(dataloader, 'distributed', False) or \ + isinstance(dataloader.sampler, \ + torch.utils.data.distributed.DistributedSampler): + hvd.init() + for metric in metrics: + metric.hvd = hvd + + if self.fp32_preds_as_label: + self.fp32_results.append(output) if self.is_baseline else \ + results.append(output) + if idx + 1 == iteration: + break + except Exception as e: + logger.warning("The dataloader didn't include label, will try input without label!") + for idx, input in enumerate(dataloader): + if (isinstance(input, dict) or isinstance(input, UserDict)): + if not self.benchmark: + assert "label" in input, \ + "The dataloader must include label to measure the metric!" + label = input["label"].to("cpu") + elif not self.benchmark: + assert False, "The dataloader must include label to measure the metric!" + + if measurer is not None: + measurer.start() + + output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf) + + if measurer is not None: + measurer.end() + + if self.device != "cpu" and not self.benchmark: # pragma: no cover + if isinstance(output, dict) or isinstance(input, UserDict): + for key in output: + output[key] = output[key].to("cpu") + elif isinstance(output, list) or isinstance(output, tuple): + for tensor in output: + tensor = tensor.to("cpu") + else: + output = output.to("cpu") - if self.fp32_preds_as_label: - self.fp32_results.append(output) if self.is_baseline else \ - results.append(output) - if idx + 1 == iteration: - break + if postprocess is not None and not self.benchmark: + output, label = postprocess((output, label)) + + if metrics and not self.benchmark: + for metric in metrics: + if not hasattr(metric, "compare_label") or \ + (hasattr(metric, "compare_label") and metric.compare_label): + metric.update(output, label) + + # If distributed dataloader, gather all outputs to update metric + if getattr(dataloader, 'distributed', False) or \ + isinstance(dataloader.sampler, \ + torch.utils.data.distributed.DistributedSampler): + hvd.init() + for metric in metrics: + metric.hvd = hvd + + if self.fp32_preds_as_label: + self.fp32_results.append(output) if self.is_baseline else \ + results.append(output) + if idx + 1 == iteration: + break return results def model_eval(self,