77import os
88import tempfile
99import zipfile
10- from typing import Any , Optional , Tuple
10+ from collections import defaultdict
11+ from typing import Optional , Tuple
1112
1213import torch
1314
1415
16+ def flatten_args (args ) -> tuple | list :
17+ flattened_args : list = []
18+ if isinstance (args , torch .Tensor ):
19+ return [args ]
20+
21+ for arg in args :
22+ if isinstance (arg , (tuple , list )):
23+ flattened_args .extend (arg )
24+ else :
25+ flattened_args .append (arg )
26+
27+ return tuple (flattened_args )
28+
29+
1530class GenericModelEvaluator :
1631 def __init__ (
1732 self ,
@@ -32,31 +47,34 @@ def __init__(
3247 else :
3348 self .tosa_output_path = None
3449
35- def get_model_error (self ) -> tuple [ float , float , float , float ] :
50+ def get_model_error (self ) -> defaultdict :
3651 """
37- Returns the following metrics between the outputs of the FP32 and INT8 model:
52+ Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
3853 - Maximum error
3954 - Maximum absolute error
4055 - Maximum percentage error
4156 - Mean absolute error
4257 """
43- fp32_output = self .fp32_model (* self .example_input )
44- int8_output = self .int8_model (* self .example_input )
45-
46- difference = fp32_output - int8_output
47- percentage_error = torch .div (difference , fp32_output ) * 100
48-
49- max_error = torch .max (difference ).item ()
50- max_absolute_error = torch .max (torch .abs (difference )).item ()
51- max_percentage_error = torch .max (percentage_error ).item ()
52- mean_absolute_error = torch .mean (torch .abs (difference ).float ()).item ()
53-
54- return (
55- float (max_error ),
56- float (max_absolute_error ),
57- float (max_percentage_error ),
58- float (mean_absolute_error ),
59- )
58+ fp32_outputs = flatten_args (self .fp32_model (* self .example_input ))
59+ int8_outputs = flatten_args (self .int8_model (* self .example_input ))
60+
61+ model_error_dict = defaultdict (list )
62+
63+ for fp32_output , int8_output in zip (fp32_outputs , int8_outputs ):
64+ difference = fp32_output - int8_output
65+ percentage_error = torch .div (difference , fp32_output ) * 100
66+ model_error_dict ["max_error" ].append (torch .max (difference ).item ())
67+ model_error_dict ["max_absolute_error" ].append (
68+ torch .max (torch .abs (difference )).item ()
69+ )
70+ model_error_dict ["max_percentage_error" ].append (
71+ torch .max (percentage_error ).item ()
72+ )
73+ model_error_dict ["mean_absolute_error" ].append (
74+ torch .mean (torch .abs (difference ).float ()).item ()
75+ )
76+
77+ return model_error_dict
6078
6179 def get_compression_ratio (self ) -> float :
6280 """Compute the compression ratio of the outputted TOSA flatbuffer."""
@@ -72,19 +90,10 @@ def get_compression_ratio(self) -> float:
7290
7391 return compression_ratio
7492
75- def evaluate (self ) -> dict [str , Any ]:
76- max_error , max_absolute_error , max_percent_error , mean_absolute_error = (
77- self .get_model_error ()
78- )
79- output_metrics = {
80- "name" : self .model_name ,
81- "metrics" : {
82- "max_error" : max_error ,
83- "max_absolute_error" : max_absolute_error ,
84- "max_percentage_error" : max_percent_error ,
85- "mean_absolute_error" : mean_absolute_error ,
86- },
87- }
93+ def evaluate (self ) -> dict [any ]:
94+ model_error_dict = self .get_model_error ()
95+
96+ output_metrics = {"name" : self .model_name , "metrics" : dict (model_error_dict )}
8897
8998 if self .tosa_output_path :
9099 # We know output_metrics["metrics"] is list since we just defined it, safe to ignore.
0 commit comments