 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from utils_qa import postprocess_qa_predictions
+from neural_compressor.pruning import Pruning
+from neural_compressor.pruner.utils import WeightPruningConfig
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.21.0.dev0")
@@ -118,32 +120,32 @@ def parse_args():
         help="The configuration name of the dataset to use (via the datasets library).",
     )
     parser.add_argument(
-        "--train_file",
-        type=str,
-        default=None,
+        "--train_file",
+        type=str,
+        default=None,
         help="A csv or a json file containing the training data."
     )
     parser.add_argument(
-        "--preprocessing_num_workers",
-        type=int, default=4,
+        "--preprocessing_num_workers",
+        type=int, default=4,
         help="The number of processes to use for preprocessing."
     )
 
     parser.add_argument(
-        "--do_predict",
-        action="store_true",
+        "--do_predict",
+        action="store_true",
         help="To do prediction on the question answering model"
     )
     parser.add_argument(
-        "--validation_file",
-        type=str,
-        default=None,
+        "--validation_file",
+        type=str,
+        default=None,
         help="A csv or a json file containing the validation data."
     )
     parser.add_argument(
-        "--test_file",
-        type=str,
-        default=None,
+        "--test_file",
+        type=str,
+        default=None,
         help="A csv or a json file containing the Prediction data."
     )
     parser.add_argument(
@@ -163,15 +165,14 @@ def parse_args():
     parser.add_argument(
         "--model_name_or_path",
         type=str,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-        required=False,
+        help="Path to pretrained model or model identifier from huggingface.co/models."
     )
     parser.add_argument(
         "--teacher_model_name_or_path",
         type=str,
         default=None,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
-        required=False,
+        required=False
     )
     parser.add_argument(
         "--config_name",
@@ -199,8 +200,8 @@ def parse_args():
     parser.add_argument(
         "--distill_loss_weight",
         type=float,
-        default=1.0,
-        help="distiller loss weight",
+        default=0.0,
+        help="distiller loss weight"
     )
     parser.add_argument(
         "--per_device_eval_batch_size",
@@ -215,15 +216,15 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--weight_decay",
-        type=float,
-        default=0.0,
+        "--weight_decay",
+        type=float,
+        default=0.0,
         help="Weight decay to use."
     )
     parser.add_argument(
-        "--num_train_epochs",
-        type=int,
-        default=3,
+        "--num_train_epochs",
+        type=int,
+        default=3,
         help="Total number of training epochs to perform."
     )
     parser.add_argument(
@@ -245,29 +246,28 @@ def parse_args():
         help="The scheduler type to use.",
         choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
     )
-
     parser.add_argument(
-        "--warm_epochs",
-        type=int,
-        default=0,
+        "--warm_epochs",
+        type=int,
+        default=0,
         help="Number of epochs during which the network is not pruned"
     )
     parser.add_argument(
-        "--num_warmup_steps",
-        type=int,
-        default=0,
+        "--num_warmup_steps",
+        type=int,
+        default=0,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--output_dir",
-        type=str,
-        default=None,
+        "--output_dir",
+        type=str,
+        default=None,
         help="Where to store the final model."
     )
     parser.add_argument(
-        "--seed",
-        type=int,
-        default=None,
+        "--seed",
+        type=int,
+        default=None,
         help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -341,33 +341,18 @@ def parse_args():
         choices=MODEL_TYPES,
     )
     parser.add_argument(
-        "--cooldown_epochs",
-        type=int, default=0,
-        help="Cooling epochs after pruning."
-    )
-    parser.add_argument(
-        "--do_prune", action="store_true",
-        help="Whether or not to prune the model"
-    )
-    parser.add_argument(
-        "--pruning_config",
-        type=str,
-        help="pruning_config",
-    )
-
-    parser.add_argument(
-        "--push_to_hub",
-        action="store_true",
+        "--push_to_hub",
+        action="store_true",
         help="Whether or not to push the model to the Hub."
     )
     parser.add_argument(
-        "--hub_model_id",
-        type=str,
+        "--hub_model_id",
+        type=str,
         help="The name of the repository to keep in sync with the local `output_dir`."
     )
     parser.add_argument(
-        "--hub_token",
-        type=str,
+        "--hub_token",
+        type=str,
         help="The token to use to push to the Model Hub."
     )
     parser.add_argument(
@@ -382,13 +367,6 @@ def parse_args():
         default=None,
         help="If the training should continue from a checkpoint folder.",
     )
-    parser.add_argument(
-        "--distiller",
-        type=str,
-        default=None,
-        help="teacher model path",
-    )
-
     parser.add_argument(
         "--with_tracking",
         action="store_true",
@@ -405,6 +383,35 @@ def parse_args():
         ),
     )
 
+    parser.add_argument(
+        "--cooldown_epochs",
+        type=int, default=0,
+        help="Cooling epochs after pruning."
+    )
+    parser.add_argument(
+        "--do_prune", action="store_true",
+        help="Whether or not to prune the model"
+    )
+    # parser.add_argument(
+    #     "--keep_conf", action="store_true",
+    #     help="Whether or not to keep the prune config infos"
+    # )
+    parser.add_argument(
+        "--pruning_pattern",
+        type=str, default="1x1",
+        help="pruning pattern type, we support NxM and N:M."
+    )
+    parser.add_argument(
+        "--target_sparsity",
+        type=float, default=0.8,
+        help="Target sparsity of the model."
+    )
+    parser.add_argument(
+        "--pruning_frequency",
+        type=int, default=-1,
+        help="Sparse step frequency for iterative pruning, defaults to a quarter of the pruning steps."
+    )
+
     args = parser.parse_args()
 
     # Sanity checks
@@ -435,7 +442,7 @@ def parse_args():
 def main():
     args = parse_args()
 
-    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     # send_example_telemetry("run_qa_no_trainer", args)
 
@@ -528,10 +535,13 @@ def main():
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
-
-    if args.teacher_model_name_or_path != None:
+
+    if args.distill_loss_weight > 0:
+        teacher_path = args.teacher_model_name_or_path
+        if teacher_path is None:
+            teacher_path = args.model_name_or_path
         teacher_model = AutoModelForQuestionAnswering.from_pretrained(
-            args.teacher_model_name_or_path,
+            teacher_path,
             from_tf=bool(".ckpt" in args.model_name_or_path),
             config=config,
         )
@@ -815,7 +825,6 @@ def post_processing_function(examples, features, predictions, stage="eval"):
     def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
         """
         Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
-
         Args:
             start_or_end_logits(:obj:`tensor`):
                 This is the output predictions of the model. We can only enter either start or end logits.
@@ -847,6 +856,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
     # Optimizer
     # Split weights in two groups, one with weight decay and the other not.
     no_decay = ["bias", "LayerNorm.weight"]
+    no_decay_outputs = ["bias", "LayerNorm.weight", "qa_outputs"]
     optimizer_grouped_parameters = [
         {
             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
@@ -876,10 +886,11 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
         num_training_steps=args.max_train_steps,
     )
 
-    if args.teacher_model_name_or_path != None:
+    if args.distill_loss_weight > 0:
         teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
             teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
         )
+        teacher_model.eval()
     else:
         # Prepare everything with our `accelerator`.
         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
@@ -949,36 +960,36 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
 
-    params = [(n, p) for (n, p) in model.named_parameters() if
-              "bias" not in n and "LayerNorm" not in n and "embeddings" not in n \
-              and "layer.3.attention.output.dense.weight" not in n and "qa_outputs" not in n]
-
-    params_keys = [n for (n, p) in params]
-    for key in params_keys:
-        print(key)
-
     # Pruning preparation
+    pruning_configs = [
+        {
+            "pruning_type": "snip_momentum",
+            "pruning_scope": "global",
+            "sparsity_decay_type": "exp"
+        }
+    ]
+    config = WeightPruningConfig(
+        pruning_configs,
+        target_sparsity=args.target_sparsity,
+        excluded_op_names=["qa_outputs", "pooler", ".*embeddings*"],
+        pruning_op_types=["Linear"],
+        max_sparsity_ratio_per_op=0.98,
+        pruning_scope="global",
+        pattern=args.pruning_pattern,
+        pruning_frequency=1000
+    )
+    pruner = Pruning(config)
     num_iterations = len(train_dataset) / total_batch_size
-    total_iterations = num_iterations * (args.num_train_epochs \
-                       - args.warm_epochs - args.cooldown_epochs) - args.num_warmup_steps
-    completed_pruned_cnt = 0
-    total_cnt = 0
-    for n, param in params:
-        total_cnt += param.numel()
-    print(f"The total param quantity is {total_cnt}")
-
-    if args.teacher_model_name_or_path != None:
-        teacher_model.eval()
-
-    from pytorch_pruner.pruning import Pruning
-    pruner = Pruning(args.pruning_config)
+    total_iterations = num_iterations * (args.num_train_epochs - args.warm_epochs - args.cooldown_epochs) \
+                       - args.num_warmup_steps
     if args.do_prune:
-        pruner.update_items_for_all_pruners( \
-            start_step=int(args.warm_epochs * num_iterations + args.num_warmup_steps), \
-            end_step=int(total_iterations))  # iterative
+        start = int(args.warm_epochs * num_iterations + args.num_warmup_steps)
+        end = int(total_iterations)
+        frequency = int((end - start + 1) / 4) if (args.pruning_frequency == -1) else args.pruning_frequency
+        pruner.update_config(start_step=start, end_step=end, pruning_frequency=frequency)  # iterative
     else:
         total_step = num_iterations * args.num_train_epochs + 1
-        pruner.update_items_for_all_pruners(start_step=total_step, end_step=total_step)
+        pruner.update_config(start_step=total_step, end_step=total_step)
     pruner.model = model
     pruner.on_train_begin()
 
@@ -994,14 +1005,14 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
             # We keep track of the loss at each epoch
             if args.with_tracking:
                 total_loss += loss.detach().float()
-            if args.teacher_model_name_or_path != None:
+            if args.distill_loss_weight > 0:
                 distill_loss_weight = args.distill_loss_weight
                 with torch.no_grad():
                     teacher_outputs = teacher_model(**batch)
                 loss = (distill_loss_weight) / 2 * get_loss_one_logit(outputs['start_logits'],
-                                                                      teacher_outputs['start_logits']) \
+                                                                      teacher_outputs['start_logits']) \
                     + (distill_loss_weight) / 2 * get_loss_one_logit(outputs['end_logits'],
-                                                                     teacher_outputs['end_logits'])
+                                                                     teacher_outputs['end_logits'])
 
             loss = loss / args.gradient_accumulation_steps
             accelerator.backward(loss)
@@ -1160,3 +1171,4 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
 if __name__ == "__main__":
     main()
 
+