
Commit b720976

fix ut bug for inc keras api change (#259)
1 parent 3264df3 commit b720976

4 files changed: 144 additions & 115 deletions

nlp_toolkit/optimization/optimizer_tf.py

Lines changed: 141 additions & 112 deletions
@@ -22,6 +22,9 @@
 from neural_compressor import __version__
 from neural_compressor.experimental import common
 from neural_compressor.model.model import saved_model_session
+from neural_compressor.model.model import get_model_type
+from neural_compressor.experimental.data.transforms.imagenet_transform import LabelShift
+from neural_compressor.experimental.metric.metric import TensorflowTopK
 from nlp_toolkit import (DistillationConfig, QuantizationConfig, PruningConfig)
 from nlp_toolkit.optimization.quantization import QuantizationMode
 from nlp_toolkit.optimization.utils.metrics import Metric
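The three imports added in this hunk all come from neural_compressor: get_model_type probes which kind of TensorFlow artifact the evaluator received, while LabelShift and TensorflowTopK are the post-processing and metric helpers used by the new Keras SavedModel branch in the next hunk. A minimal, self-contained sketch of how those two helpers compose on a made-up batch (the logits and labels below are illustrative, not taken from the commit):

import numpy as np
from neural_compressor.experimental.data.transforms.imagenet_transform import LabelShift
from neural_compressor.experimental.metric.metric import TensorflowTopK

# LabelShift(label_shift=1) shifts 1-based labels down to 0-based before scoring;
# TensorflowTopK(k=1) accumulates top-1 accuracy across update() calls.
postprocess = LabelShift(label_shift=1)
metric = TensorflowTopK(k=1)

# Toy batch: 2 samples, 3 classes, labels given 1-based (illustrative values only).
predictions = np.array([[0.1, 0.7, 0.2],
                        [0.8, 0.1, 0.1]])
labels = np.array([2, 1])  # LabelShift maps these to [1, 0]

predictions, labels = postprocess((predictions, labels))
metric.update(predictions, labels)
print(metric.result())  # expected: 1.0 on this toy batch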
@@ -146,135 +149,161 @@ def eval_dataset(self, eval_dataset):
         self._eval_dataset = eval_dataset
 
     def builtin_eval_func(self, model):
-        """Evaluate the model for specified metric on validation dataset.
+        """
+        Custom Evaluate function to inference the model for specified metric on validation dataset.
 
         Args:
-            model ([Graph, GraphDef or Path String]): The model could be the graph,
-            graph_def object, the frozen pb or ckpt/savedmodel folder path.
+            model ([tf.saved_model.load]): The model will be the class of tf.saved_model.load(quantized_model_path).
 
         Returns:
             [float]: evaluation result, the larger is better.
         """
-        num_examples = sum(1 for _ in (self._eval_dataset.unbatch(
-        ) if hasattr(self._eval_dataset, "unbatch") else self._eval_dataset))
-
+        model_type = None
+        try:
+            model_type = get_model_type(model)
+        except ValueError:
+            logger.info("use keras savedModel")
+
+        num_examples = sum(1 for _ in (
+            self._eval_dataset.unbatch() if hasattr(self._eval_dataset, "unbatch") else self._eval_dataset))
         logger.info(f"***** Running Evaluation *****")
         logger.info(f" Num examples in dataset = {num_examples}")
         logger.info(f" Batch size = {self.args.per_device_eval_batch_size}")
 
-        from neural_compressor.adaptor.tf_utils.util import get_tensor_by_name
-        input_tensor = [get_tensor_by_name(\
-            model, x) for x in self.input_names]
-        output_tensor = [get_tensor_by_name(\
-            model, x) for x in self.output_names]
-
-        logger.info("Start to evaluate the TensorFlow model.")
-
-        total_time = 0
-        config = tf.compat.v1.ConfigProto()
-        config.use_per_session_threads = 1
-        config.inter_op_parallelism_threads = 1
-        sess = tf.compat.v1.Session(graph=model, config=config)
-        feed_dict = {}
-        label_ids: np.ndarray = None
-        preds: np.ndarray = None
-        for idx, (inputs, labels) in enumerate(self._eval_dataset):
-            assert len(input_tensor) == len(inputs), \
-                'inputs len must equal with input_tensor'
+        if model_type is None:
+            infer = model.signatures["serving_default"]
+            output_dict_keys = infer.structured_outputs.keys()
+            output_name = list(output_dict_keys)[0]
+
+            postprocess = LabelShift(label_shift=1)
+            metric = TensorflowTopK(k=1)
+
+            def eval_func(dataloader, metric):
+                for idx, (inputs, labels) in enumerate(dataloader):
+                    for name in inputs:
+                        inputs[name] = tf.constant(inputs[name].numpy(), dtype=tf.int32)
+
+                    predictions = infer(**inputs)[output_name]
+                    predictions = predictions.numpy()
+                    predictions, labels = postprocess((predictions, labels))
+                    metric.update(predictions, labels)
+
+            eval_func(self._eval_dataset, metric)
+            acc = metric.result()
+            return acc
+        else:  # pragma: no cover
+            from neural_compressor.adaptor.tf_utils.util import get_tensor_by_name
+            input_tensor = [get_tensor_by_name(\
+                model, x) for x in self.input_names]
+            output_tensor = [get_tensor_by_name(\
+                model, x) for x in self.output_names]
+
+            logger.info("Start to evaluate the TensorFlow model.")
+
+            total_time = 0
+            config = tf.compat.v1.ConfigProto()
+            config.use_per_session_threads = 1
+            config.inter_op_parallelism_threads = 1
+            sess = tf.compat.v1.Session(graph=model, config=config)
             feed_dict = {}
-            for name in inputs:
-                for tensor in input_tensor:
-                    pos = tensor.name.rfind(":")
-                    t_name = tensor.name if pos < 0 else tensor.name[:pos]
-                    if name == t_name:
-                        feed_dict[tensor] = inputs[name].numpy()
-                        break
-
-            start = time.time()
-            logits = sess.run(output_tensor, feed_dict)
-            total_time += time.time() - start
-            if not self.args.prediction_loss_only:
-                if isinstance(logits, tuple):
-                    logits = logits[0]
-
-                if isinstance(labels, tuple):
-                    labels = labels[0].numpy()
-
-                if isinstance(logits,
-                              list) and len(logits) > 1:  # pragma: no cover
-                    for val in logits:
+            label_ids: np.ndarray = None
+            preds: np.ndarray = None
+            for idx, (inputs, labels) in enumerate(self._eval_dataset):
+                assert len(input_tensor) == len(inputs), \
+                    'inputs len must equal with input_tensor'
+                feed_dict = {}
+                for name in inputs:
+                    for tensor in input_tensor:
+                        pos = tensor.name.rfind(":")
+                        t_name = tensor.name if pos < 0 else tensor.name[:pos]
+                        if name == t_name:
+                            feed_dict[tensor] = inputs[name].numpy()
+                            break
+
+                start = time.time()
+                logits = sess.run(output_tensor, feed_dict)
+                total_time += time.time() - start
+                if not self.args.prediction_loss_only:
+                    if isinstance(logits, tuple):
+                        logits = logits[0]
+
+                    if isinstance(labels, tuple):
+                        labels = labels[0].numpy()
+
+                    if isinstance(logits,
+                                  list) and len(logits) > 1:  # pragma: no cover
+                        for val in logits:
+                            if preds is None:
+                                preds = val
+                            else:
+                                preds = np.append(preds, val, axis=0)
+
+                        for val in labels:
+                            if label_ids is None:
+                                label_ids = val.numpy()
+                            else:
+                                label_ids = np.append(label_ids,
+                                                      val.numpy(),
+                                                      axis=0)
+                    else:
                         if preds is None:
-                            preds = val
+                            preds = logits[0] if isinstance(logits,
+                                                            list) else logits
                         else:
-                            preds = np.append(preds, val, axis=0)
+                            preds = np.append(
+                                preds,
+                                logits[0] if isinstance(logits, list) else logits,
+                                axis=0)
 
-                    for val in labels:
                         if label_ids is None:
-                            label_ids = val.numpy()
+                            label_ids = labels[0].numpy() if isinstance(
+                                labels, list) else labels.numpy()
                         else:
-                            label_ids = np.append(label_ids,
-                                                  val.numpy(),
-                                                  axis=0)
-                else:
-                    if preds is None:
-                        preds = logits[0] if isinstance(logits,
-                                                        list) else logits
-                    else:
-                        preds = np.append(
-                            preds,
-                            logits[0] if isinstance(logits, list) else logits,
-                            axis=0)
-
-                    if label_ids is None:
-                        label_ids = labels[0].numpy() if isinstance(
-                            labels, list) else labels.numpy()
-                    else:
-                        label_ids = np.append(
-                            label_ids,
-                            labels[0].numpy()
-                            if isinstance(labels, list) else labels.numpy(),
-                            axis=0)
-
-        if self.compute_metrics is not None and preds is not None and label_ids is not None:
-            try:
-                loss = self.criterion(
-                    label_ids, preds) if self.criterion is not None else None
-            except Exception as e:  # pragma: no cover
-                logger.info(e)
-                logger.info("There is no loss function or loss compute error, \
-                            Please compute loss in compute_metrics function"
-                            )
-                loss = None
-            results = self.compute_metrics({"logits": preds}, label_ids)
-            if loss is not None:
-                results["loss"] = loss.numpy()
-
-            if isinstance(self.metrics, list):
-                nums = len(self.metrics)
-                for metric in self.metrics:
-                    assert metric.name in results.keys(), \
-                        "Please set metric from {}".format(results.keys())
-                if nums == 1:
-                    result = results.get(self.metrics[0].name)
-                else:  # pragma: no cover
-                    result = 0
+                            label_ids = np.append(
+                                label_ids,
+                                labels[0].numpy()
+                                if isinstance(labels, list) else labels.numpy(),
+                                axis=0)
+
+            if self.compute_metrics is not None and preds is not None and label_ids is not None:
+                try:
+                    loss = self.criterion(
+                        label_ids, preds) if self.criterion is not None else None
+                except Exception as e:  # pragma: no cover
+                    logger.info(e)
+                    logger.info("There is no loss function or loss compute error, \
+                                Please compute loss in compute_metrics function"
+                                )
+                    loss = None
+                results = self.compute_metrics({"logits": preds}, label_ids)
+                if loss is not None:
+                    results["loss"] = loss.numpy()
+
+                if isinstance(self.metrics, list):
+                    nums = len(self.metrics)
                     for metric in self.metrics:
-                        assert metric.weight_ratio is not None, \
-                            "Please set weights for metric if you want to use more than one metric"
-                        result += results[metric.name] * metric.weighted
-                logger.info("metric Accuracy: {}".format(result))
-            elif isinstance(self.metrics, Metric):
-                assert self.metrics.name in results.keys(), \
-                    "Please set metric from {}".format(results.keys())
-                result = results.get(self.metrics.name)
-                logger.info("metric Accuracy: {}".format(result))
-            else:  # pragma: no cover
-                assert False, "Please set the correct metrics format from the README"
-        else:
-            result = 0
-
-        logger.info("Throughput: {} samples/sec".format(num_examples / total_time))
-        return result
+                        assert metric.name in results.keys(), \
+                            "Please set metric from {}".format(results.keys())
+                    if nums == 1:
+                        result = results.get(self.metrics[0].name)
+                    else:  # pragma: no cover
+                        result = 0
+                        for metric in self.metrics:
+                            assert metric.weight_ratio is not None, \
+                                "Please set weights for metric if you want to use more than one metric"
+                            result += results[metric.name] * metric.weighted
+                    logger.info("metric Accuracy: {}".format(result))
+                elif isinstance(self.metrics, Metric):
+                    assert self.metrics.name in results.keys(), \
+                        "Please set metric from {}".format(results.keys())
+                    result = results.get(self.metrics.name)
+                    logger.info("metric Accuracy: {}".format(result))
+                else:  # pragma: no cover
+                    assert False, "Please set the correct metrics format from the README"
+            else:
+                result = 0
+            logger.info("Throughput: {} samples/sec".format(num_examples / total_time))
+            return result
 
     def init_quantizer(
         self,
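When get_model_type raises ValueError (the quantized model is a Keras SavedModel under the new INC behavior), builtin_eval_func now drives inference through the model's serving_default signature instead of a tf.compat.v1 session. A minimal sketch of that signature-based path, assuming a quantized SavedModel already exists at a hypothetical ./quantized_model directory and the feature names below; the INC LabelShift/TensorflowTopK helpers are swapped for a plain argmax accuracy so the snippet stays self-contained:

import numpy as np
import tensorflow as tf

# Hypothetical location of an exported, quantized Keras SavedModel.
quantized_model_path = "./quantized_model"
model = tf.saved_model.load(quantized_model_path)

# Same lookups as the new branch above: default serving signature and its first output key.
infer = model.signatures["serving_default"]
output_name = list(infer.structured_outputs.keys())[0]

# Stand-in for self._eval_dataset: batches of (feature dict, labels); names are assumed.
eval_dataset = [
    ({"input_ids": tf.constant([[101, 2023, 102]]),
      "attention_mask": tf.constant([[1, 1, 1]])},
     tf.constant([1])),
]

correct = total = 0
for inputs, labels in eval_dataset:
    # The commit casts every feature tensor to int32 before calling the signature.
    inputs = {name: tf.constant(value.numpy(), dtype=tf.int32) for name, value in inputs.items()}
    logits = infer(**inputs)[output_name].numpy()
    preds = np.argmax(logits, axis=-1)
    correct += int(np.sum(preds == labels.numpy()))
    total += labels.numpy().shape[0]

print("top-1 accuracy:", correct / max(total, 1))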

tests/test_tf_distillation.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def preprocess_function(examples):
             drop_remainder=False,
             # `label_cols` is needed for user-defined losses, such as in this example
             # datasets v2.3.x need "labels", not "label"
-            label_cols=["label", "labels"]
+            label_cols=["labels"]
             if "label" in dataset.column_names else None,
         )
         parser = HfArgumentParser(TFTrainingArguments)
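The same one-line change repeats in the pruning and quantization tests below: with datasets v2.3+ the collated label column is exposed as "labels", so requesting both "label" and "labels" breaks the TF dataset construction. A hedged sketch of how the corrected argument is typically passed; the surrounding to_tf_dataset call, checkpoint, and column names are illustrative assumptions, only the label_cols expression comes from the diff:

from datasets import load_dataset
from transformers import AutoTokenizer, DefaultDataCollator

raw_dataset = load_dataset("glue", "sst2", split="validation[:64]")  # small illustrative slice
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    return tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=64)

dataset = raw_dataset.map(preprocess_function, batched=True)
collator = DefaultDataCollator(return_tensors="tf")  # renames "label" to "labels" when collating

tf_eval_dataset = dataset.to_tf_dataset(
    columns=["input_ids", "attention_mask"],
    shuffle=False,
    batch_size=8,
    collate_fn=collator,
    drop_remainder=False,
    # datasets v2.3+ expose the collated label column as "labels", not "label",
    # hence only "labels" is requested here (the fix above).
    label_cols=["labels"] if "label" in dataset.column_names else None,
)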

tests/test_tf_pruning.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def preprocess_function(examples):
             drop_remainder=False,
             # `label_cols` is needed for user-defined losses, such as in this example
             # datasets v2.3.x need "labels", not "label"
-            label_cols=["label", "labels"] if "label" in dataset.column_names else None,
+            label_cols=["labels"] if "label" in dataset.column_names else None,
         )
         parser = HfArgumentParser(TFTrainingArguments)
         self.args = parser.parse_args_into_dataclasses(args=["--output_dir", "./quantized_model",

tests/test_tf_quantization.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def preprocess_function(examples):
             drop_remainder=False,
             # `label_cols` is needed for user-defined losses, such as in this example
             # datasets v2.3.x need "labels", not "label"
-            label_cols=["label", "labels"] if "label" in dataset.column_names else None,
+            label_cols=["labels"] if "label" in dataset.column_names else None,
         )
 
 
0 commit comments
