Skip to content

Commit 47612a7

Browse files
authored
[NLP] Update evaluate.py for results format changes (#2533)
Adapt the test script for the result format changes in #2376
1 parent 5d6ce32 commit 47612a7

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

bin/pytorch_inference/evaluate.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,15 @@ def compare_results(expected, actual, tolerance):
170170

171171
request_id = actual['request_id']
172172

173-
if len(expected['inference']) != len(actual['inference']):
173+
actual_result = actual['result']
174+
175+
if len(expected['inference']) != len(actual_result['inference']):
174176
print("[{}] len(inference) does not match [{}], [{}]".format(request_id, len(expected['inference']), len(actual['inference'])), flush=True)
175177
return False
176178

177179
for i in range(len(expected['inference'])):
178180
expected_array = expected['inference'][i]
179-
actual_array = actual['inference'][i]
181+
actual_array = actual_result['inference'][i]
180182

181183
if len(expected_array) != len(actual_array):
182184
print("[{}] array [{}] lengths are not equal [{}], [{}]".format(request_id, i, len(expected_array), len(actual_array)), flush=True)
@@ -280,16 +282,17 @@ def test_evaluation(args):
280282
result_docs = json.load(output_file)
281283
except:
282284
print("Error parsing json: ", sys.exc_info()[0])
283-
return
285+
return
286+
284287

285288
for result in result_docs:
289+
286290
if 'error' in result:
287291
print(f"Inference failed. Request: {result['error']['request_id']}, Msg: {result['error']['error']}")
288292
results_match = False
289293
continue
290294

291-
if 'thread_settings' in result:
292-
print(f"Thread settings read: {result}")
295+
if 'thread_settings' in result:
293296
continue
294297

295298
expected = test_evaluation[doc_count]['expected_output']
@@ -301,7 +304,7 @@ def test_evaluation(args):
301304
total_time_ms += result['time_ms']
302305

303306
# compare to expected
304-
if compare_results(expected, result['result'], tolerance) == False:
307+
if compare_results(expected, result, tolerance) == False:
305308
print()
306309
print(f'ERROR: inference result [{doc_count}] does not match expected results')
307310
print()

0 commit comments

Comments
 (0)