Commit f21e4a3

Enhancement benchmark with dataloader (#269)
Signed-off-by: Cheng, Penghui <[email protected]>
1 parent ae3cf56 commit f21e4a3

File tree: 1 file changed (+81, −29)
neural_compressor/adaptor/pytorch.py

Lines changed: 81 additions & 29 deletions
```diff
@@ -865,37 +865,89 @@ def model_calibration(self,
 
     def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration, conf=None):
         results = []
-        for idx, (input, label) in enumerate(dataloader):
-            if measurer is not None:
-                measurer.start()
-
-            output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
-            if self.device != "cpu":  # pragma: no cover
-                output = output.to("cpu")
-                label = label.to("cpu")
-            if measurer is not None:
-                measurer.end()
-            if postprocess is not None:
-                output, label = postprocess((output, label))
-            if metrics:
-                for metric in metrics:
-                    if not hasattr(metric, "compare_label") or \
-                        (hasattr(metric, "compare_label") and metric.compare_label):
-                        metric.update(output, label)
-
-            # If distributed dataloader, gather all outputs to update metric
-            if getattr(dataloader, 'distributed', False) or \
-                isinstance(dataloader.sampler, \
-                torch.utils.data.distributed.DistributedSampler):
-                hvd.init()
+        try:
+            for idx, (input, label) in enumerate(dataloader):
+                if measurer is not None:
+                    measurer.start()
+
+                output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
+                if self.device != "cpu":  # pragma: no cover
+                    output = output.to("cpu")
+                    label = label.to("cpu")
+                if measurer is not None:
+                    measurer.end()
+                if postprocess is not None:
+                    output, label = postprocess((output, label))
+                if metrics:
                     for metric in metrics:
-                    metric.hvd = hvd
+                        if not hasattr(metric, "compare_label") or \
+                            (hasattr(metric, "compare_label") and metric.compare_label):
+                            metric.update(output, label)
+
+                # If distributed dataloader, gather all outputs to update metric
+                if getattr(dataloader, 'distributed', False) or \
+                    isinstance(dataloader.sampler, \
+                    torch.utils.data.distributed.DistributedSampler):
+                    hvd.init()
+                    for metric in metrics:
+                        metric.hvd = hvd
+
+                if self.fp32_preds_as_label:
+                    self.fp32_results.append(output) if self.is_baseline else \
+                        results.append(output)
+                if idx + 1 == iteration:
+                    break
+        except Exception as e:
+            logger.warning("The dataloader didn't include label, will try input without label!")
+            for idx, input in enumerate(dataloader):
+                if (isinstance(input, dict) or isinstance(input, UserDict)):
+                    if not self.benchmark:
+                        assert "label" in input, \
+                            "The dataloader must include label to measure the metric!"
+                        label = input["label"].to("cpu")
+                elif not self.benchmark:
+                    assert False, "The dataloader must include label to measure the metric!"
+
+                if measurer is not None:
+                    measurer.start()
+
+                output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
+
+                if measurer is not None:
+                    measurer.end()
+
+                if self.device != "cpu" and not self.benchmark:  # pragma: no cover
+                    if isinstance(output, dict) or isinstance(input, UserDict):
+                        for key in output:
+                            output[key] = output[key].to("cpu")
+                    elif isinstance(output, list) or isinstance(output, tuple):
+                        for tensor in output:
+                            tensor = tensor.to("cpu")
+                    else:
+                        output = output.to("cpu")
 
-            if self.fp32_preds_as_label:
-                self.fp32_results.append(output) if self.is_baseline else \
-                    results.append(output)
-            if idx + 1 == iteration:
-                break
+                if postprocess is not None and not self.benchmark:
+                    output, label = postprocess((output, label))
+
+                if metrics and not self.benchmark:
+                    for metric in metrics:
+                        if not hasattr(metric, "compare_label") or \
+                            (hasattr(metric, "compare_label") and metric.compare_label):
+                            metric.update(output, label)
+
+                # If distributed dataloader, gather all outputs to update metric
+                if getattr(dataloader, 'distributed', False) or \
+                    isinstance(dataloader.sampler, \
+                    torch.utils.data.distributed.DistributedSampler):
+                    hvd.init()
+                    for metric in metrics:
+                        metric.hvd = hvd
+
+                if self.fp32_preds_as_label:
+                    self.fp32_results.append(output) if self.is_baseline else \
+                        results.append(output)
+                if idx + 1 == iteration:
+                    break
         return results
 
     def model_eval(self,
```
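The enhancement is a fallback protocol: evaluation first assumes each batch is an `(input, label)` pair, and if that unpacking fails it re-iterates the dataloader treating batches as bare inputs, which is only permitted when `self.benchmark` is set (no metric needs a label). Below is a minimal standalone sketch of the same pattern; the `run_model` callable, `metric` object, and `benchmark` flag are hypothetical stand-ins for `pytorch_forward_wrapper`, the adaptor's metrics, and its config — an illustration, not the library's API:

```python
from collections import UserDict

def eval_with_label_fallback(run_model, dataloader, metric=None, benchmark=False):
    """Evaluate a dataloader that may yield (input, label) pairs or bare inputs."""
    results = []
    try:
        # Happy path: every batch unpacks into an (input, label) pair.
        for input, label in dataloader:
            output = run_model(input)
            if metric is not None:
                metric.update(output, label)
            results.append(output)
    except (TypeError, ValueError):
        # Fallback: batches carry no separate label. A dict-like batch may
        # still embed one under a "label" key; a missing label is acceptable
        # only in benchmark mode, where no metric is measured.
        results.clear()  # restart cleanly; some batches may already be consumed
        for input in dataloader:
            label = input.get("label") if isinstance(input, (dict, UserDict)) else None
            if label is None and not benchmark:
                raise ValueError(
                    "The dataloader must include label to measure the metric!")
            output = run_model(input)
            if metric is not None and label is not None:
                metric.update(output, label)
            results.append(output)
    return results
```

For example, `eval_with_label_fallback(model, [x1, x2, x3], benchmark=True)` runs a pure throughput pass over unlabeled inputs, while the same call without `benchmark=True` raises with the same message as the assertion this commit adds.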
