@@ -93,7 +93,9 @@ def wiki2_eval(
 
 
 # adapted from Hicham Badri (@mobicham)
-def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
+def benchmark(
+    model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda"
+):
     import lm_eval
     import numpy as np
 
@@ -126,21 +128,33 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("truthfulqa_mc2", 0)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "winogrande" in tasks:
         for task in [("winogrande", 5)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "arc_challenge" in tasks:
         for task in [("arc_challenge", 25)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
 
@@ -149,14 +163,22 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("hellaswag", 10)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "gsm8k" in tasks:
         for task in [("gsm8k", 5)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     # ############################################
@@ -167,7 +189,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("mmlu", 5)]:
             tag, fewshot = task
             results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results_mmlu[tag])
 
@@ -188,7 +214,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("leaderboard_bbh", 3)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
             results["bbh"] = results[tag]
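
Note: the new evaluation_limit argument is simply forwarded as limit to every lm_eval.evaluator.simple_evaluate call in this function. A minimal standalone sketch of that behavior, assuming lm-eval is installed; the task name and the small placeholder model id are assumptions, not taken from this diff:

import lm_eval
from lm_eval.models.huggingface import HFLM

# Wrap any HF causal LM for lm-eval (placeholder model id, used only for illustration).
model_eval = HFLM(pretrained="facebook/opt-125m")
results = lm_eval.evaluator.simple_evaluate(
    model_eval,
    tasks=["hellaswag"],
    num_fewshot=0,
    batch_size=1,
    limit=8,  # like evaluation_limit=8: score only the first 8 examples of the task
)["results"]
print(results["hellaswag"])
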
@@ -202,7 +232,7 @@ def quantize_and_eval(
     tasks: list[str],
     max_seq_length: int,
     calibration_limit: int,
-    validation_size: int,
+    evaluation_limit: int,
     device: str,
     precision: torch.dtype,
     compile: bool,
@@ -223,18 +253,26 @@ def quantize_and_eval(
     if quant.startswith("awq-int4wo"):
         group_size = int(quant.split("-")[2])
         print(f"running {quant} quantization with group size {group_size}")
-        # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
-        from torchao.quantization import FbgemmConfig
+        from torchao.dtypes import Int4CPULayout
+        from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig
 
         # use_hqq = True
         # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            # TODO: this is temporary, we'll be using Int4WeightOnlyConfig for CUDA soon
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         print(f"running {quant} prepare and calibrate")
         t0 = time.time()
         quant_config = AWQConfig(base_config, step="prepare")
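
Note: only the device-dependent base config and the "prepare" step are visible in this hunk. A hedged sketch of the surrounding AWQ flow is below; the torchao.prototype.awq import path and the "convert" step after calibration are assumptions based on the rest of the script, which is not shown here:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.prototype.awq import AWQConfig  # assumed import path

def awq_int4_quantize(model, calibration_batches, group_size=128):
    # Base int4 weight-only config; the CUDA/CPU branching from the diff is elided here.
    base_config = Int4WeightOnlyConfig(group_size=group_size)
    # Step 1: attach AWQ observers so activation statistics can be collected.
    quantize_(model, AWQConfig(base_config, step="prepare"))
    # Step 2: calibrate by running a handful of representative batches (calibration_limit of them).
    with torch.no_grad():
        for batch in calibration_batches:
            model(batch)
    # Step 3 (assumed): fold the AWQ scales in and apply the base int4 config.
    quantize_(model, AWQConfig(base_config, step="convert"))
    return model
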
@@ -291,7 +329,14 @@ def quantize_and_eval(
     if compile:
         model = torch.compile(model)
 
-    return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
+    return benchmark(
+        model,
+        tokenizer,
+        max_seq_length,
+        tasks=tasks,
+        evaluation_limit=evaluation_limit,
+        device=device,
+    )
 
 
 if __name__ == "__main__":
@@ -310,8 +355,8 @@ def quantize_and_eval(
         "--tasks",
         nargs="+",
         type=str,
-        help="Task to benchmark model on. Either PPL or QA",
-        default=["PPL"],
+        help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md",
+        default=["hellaswag"],
     )
     parser.add_argument(
         "--calibration_limit",
@@ -320,7 +365,10 @@ def quantize_and_eval(
         help="Number of samples to use for calibration. Default is 10.",
     )
     parser.add_argument(
-        "--validation_size", type=int, default=1, help="Validation size. Default is 1."
+        "--evaluation_limit",
+        type=int,
+        default=None,
+        help="Number of samples to use for evaluation. Default is None (all).",
     )
     parser.add_argument(
         "--device",
@@ -368,7 +416,7 @@ def quantize_and_eval(
         args.tasks,
         args.max_seq_length,
         args.calibration_limit,
-        args.validation_size,
+        args.evaluation_limit,
         args.device,
         args.precision,
         args.compile,
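
Note: with the argparse changes above, a run can now be capped on both ends (calibration samples and evaluation samples). A hedged invocation example follows; the script path and the model/quant flag names are assumptions, while --tasks, --calibration_limit, --evaluation_limit, and --device come from the options shown in this diff:

import subprocess

subprocess.run(
    [
        "python",
        "torchao/prototype/awq/example.py",  # assumed script location
        "--model", "meta-llama/Llama-2-7b-chat-hf",  # assumed flag name and model id
        "--quant", "awq-int4wo-128",  # assumed flag name; AWQ int4 weight-only, group size 128
        "--tasks", "hellaswag", "winogrande",
        "--calibration_limit", "10",
        "--evaluation_limit", "100",  # cap each lm_eval task at 100 examples
        "--device", "cuda",
    ],
    check=True,
)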