Merged
neural_compressor/adaptor/pytorch.py (110 changes: 81 additions & 29 deletions)
@@ -865,37 +865,89 @@ def model_calibration(self,

     def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration, conf=None):
         results = []
-        for idx, (input, label) in enumerate(dataloader):
-            if measurer is not None:
-                measurer.start()
-
-            output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
-            if self.device != "cpu": # pragma: no cover
-                output = output.to("cpu")
-                label = label.to("cpu")
-            if measurer is not None:
-                measurer.end()
-            if postprocess is not None:
-                output, label = postprocess((output, label))
-            if metrics:
-                for metric in metrics:
-                    if not hasattr(metric, "compare_label") or \
-                        (hasattr(metric, "compare_label") and metric.compare_label):
-                        metric.update(output, label)
-
-                # If distributed dataloader, gather all outputs to update metric
-                if getattr(dataloader, 'distributed', False) or \
-                    isinstance(dataloader.sampler, \
-                    torch.utils.data.distributed.DistributedSampler):
-                    hvd.init()
-                    for metric in metrics:
-                        metric.hvd = hvd
-
-            if self.fp32_preds_as_label:
-                self.fp32_results.append(output) if self.is_baseline else \
-                    results.append(output)
-            if idx + 1 == iteration:
-                break
+        try:
+            for idx, (input, label) in enumerate(dataloader):
+                if measurer is not None:
+                    measurer.start()
+
+                output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
+                if self.device != "cpu": # pragma: no cover
+                    output = output.to("cpu")
+                    label = label.to("cpu")
+                if measurer is not None:
+                    measurer.end()
+                if postprocess is not None:
+                    output, label = postprocess((output, label))
+                if metrics:
+                    for metric in metrics:
+                        if not hasattr(metric, "compare_label") or \
+                            (hasattr(metric, "compare_label") and metric.compare_label):
+                            metric.update(output, label)
+
+                    # If distributed dataloader, gather all outputs to update metric
+                    if getattr(dataloader, 'distributed', False) or \
+                        isinstance(dataloader.sampler, \
+                        torch.utils.data.distributed.DistributedSampler):
+                        hvd.init()
+                        for metric in metrics:
+                            metric.hvd = hvd
+
+                if self.fp32_preds_as_label:
+                    self.fp32_results.append(output) if self.is_baseline else \
+                        results.append(output)
+                if idx + 1 == iteration:
+                    break
+        except Exception as e:
+            logger.warning("The dataloader didn't include label, will try input without label!")
+            for idx, input in enumerate(dataloader):
+                if (isinstance(input, dict) or isinstance(input, UserDict)):
+                    if not self.benchmark:
+                        assert "label" in input, \
+                            "The dataloader must include label to measure the metric!"
+                        label = input["label"].to("cpu")
+                elif not self.benchmark:
+                    assert False, "The dataloader must include label to measure the metric!"
+
+                if measurer is not None:
+                    measurer.start()
+
+                output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
+
+                if measurer is not None:
+                    measurer.end()
+
+                if self.device != "cpu" and not self.benchmark: # pragma: no cover
+                    if isinstance(output, dict) or isinstance(input, UserDict):
+                        for key in output:
+                            output[key] = output[key].to("cpu")
+                    elif isinstance(output, list) or isinstance(output, tuple):
+                        for tensor in output:
+                            tensor = tensor.to("cpu")
+                    else:
+                        output = output.to("cpu")
+
+                if postprocess is not None and not self.benchmark:
+                    output, label = postprocess((output, label))
+
+                if metrics and not self.benchmark:
+                    for metric in metrics:
+                        if not hasattr(metric, "compare_label") or \
+                            (hasattr(metric, "compare_label") and metric.compare_label):
+                            metric.update(output, label)
+
+                    # If distributed dataloader, gather all outputs to update metric
+                    if getattr(dataloader, 'distributed', False) or \
+                        isinstance(dataloader.sampler, \
+                        torch.utils.data.distributed.DistributedSampler):
+                        hvd.init()
+                        for metric in metrics:
+                            metric.hvd = hvd
+
+                if self.fp32_preds_as_label:
+                    self.fp32_results.append(output) if self.is_baseline else \
+                        results.append(output)
+                if idx + 1 == iteration:
+                    break
         return results
 
     def model_eval(self,
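For context on the new except branch: it targets dataloaders whose batches are not (input, label) tuples, for example dict-style batches. The sketch below is a hypothetical example of such a dataloader (the dataset and field names are illustrative assumptions, not taken from this PR); with the change above, batches like these are routed through the fallback path, which requires a "label" entry unless self.benchmark is set.

# Hypothetical dict-batch dataset; its batches are dicts rather than
# (input, label) tuples, which is the case the fallback branch handles.
import torch
from torch.utils.data import DataLoader, Dataset

class DictBatchDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "input_ids": torch.randint(0, 100, (16,)),  # model input
            "label": torch.tensor(idx % 2),              # read by the fallback branch
        }

loader = DataLoader(DictBatchDataset(), batch_size=4)
batch = next(iter(loader))
assert "label" in batch  # eval_func asserts this when self.benchmark is False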
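The distributed branch calls hvd.init() and attaches the Horovod module to each metric as metric.hvd; how that handle is used is left to the metric implementation. The snippet below is only a hedged sketch of one way a custom metric could use it, gathering per-rank predictions with horovod.torch.allgather before scoring. The class is hypothetical, not the neural_compressor metric API, and it assumes Horovod with PyTorch support is installed.

# Hypothetical metric that uses the injected `hvd` handle; illustrative only.
import torch

class DistributedAccuracy:
    def __init__(self):
        self.hvd = None      # set by eval_func when the dataloader is distributed
        self.correct = 0
        self.total = 0

    def update(self, output, label):
        preds = output.argmax(dim=-1)
        if self.hvd is not None:
            # Gather predictions and labels from every rank so each worker
            # scores the full evaluation split, not just its own shard.
            preds = self.hvd.allgather(preds)
            label = self.hvd.allgather(label)
        self.correct += int((preds == label).sum())
        self.total += int(label.numel())

    def result(self):
        return self.correct / max(self.total, 1)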