
Commit 0ca582e

WeiweiZhang1 and yintong-lu authored and committed
add pruning examples and docs (#262)
* update qa example
  Signed-off-by: Zhang, Weiwei1 <[email protected]>
* pruning doc modify
  Signed-off-by: Lu, Yintong <[email protected]>
* pruning doc modify
  Signed-off-by: Lu, Yintong <[email protected]>
* pruning doc modify
  Signed-off-by: Lu, Yintong <[email protected]>
* pruning doc modify
  Signed-off-by: Lu, Yintong <[email protected]>
* pruning doc modify
  Signed-off-by: Lu, Yintong <[email protected]>

Signed-off-by: Zhang, Weiwei1 <[email protected]>
Signed-off-by: Lu, Yintong <[email protected]>
Co-authored-by: Lu, Yintong <[email protected]>
Signed-off-by: zehao-intel <[email protected]>
1 parent 7afdcf9 commit 0ca582e

File tree

7 files changed (+128, -114 lines changed)

[4 updated image files (previews not rendered): 84.9 KB, -21 KB, 77.6 KB, 36.9 KB]

docs/source/pruning_details.md

Lines changed: 2 additions & 1 deletion
@@ -251,7 +251,8 @@ Pattern_lock pruning type uses masks of a fixed pattern during the pruning process
 
 
 
-Progressive pruning aims at smoothing the structured pruning by automatically interpolating a group of interval masks during the pruning process. In this method, a sequence of masks are generated to enable a more flexible pruning process and those masks would gradually change into ones to fit the target pruning structure.
+Progressive pruning aims at smoothing the structured pruning by automatically interpolating a group of interval masks during the pruning process. In this method, a sequence of masks are generated to enable a more flexible pruning process and those masks would gradually change into ones to fit the target pruning structure.
+Progressive pruning is used mainly for channel-wise pruning and currently only supports NxM pruning pattern.
 
 
 
examples/pytorch/nlp/huggingface_models/question-answering/pruning/pytorch_pruner/eager/run_qa_no_trainer.py

Lines changed: 111 additions & 99 deletions
@@ -57,6 +57,8 @@
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from utils_qa import postprocess_qa_predictions
+from neural_compressor.pruning import Pruning
+from neural_compressor.pruner.utils import WeightPruningConfig
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.21.0.dev0")
@@ -118,32 +120,32 @@ def parse_args():
         help="The configuration name of the dataset to use (via the datasets library).",
     )
     parser.add_argument(
-        "--train_file",
-        type=str,
-        default=None,
+        "--train_file",
+        type=str,
+        default=None,
         help="A csv or a json file containing the training data."
     )
     parser.add_argument(
-        "--preprocessing_num_workers",
-        type=int, default=4,
+        "--preprocessing_num_workers",
+        type=int, default=4,
         help="A csv or a json file containing the training data."
     )
 
     parser.add_argument(
-        "--do_predict",
-        action="store_true",
+        "--do_predict",
+        action="store_true",
         help="To do prediction on the question answering model"
     )
     parser.add_argument(
-        "--validation_file",
-        type=str,
-        default=None,
+        "--validation_file",
+        type=str,
+        default=None,
         help="A csv or a json file containing the validation data."
     )
     parser.add_argument(
-        "--test_file",
-        type=str,
-        default=None,
+        "--test_file",
+        type=str,
+        default=None,
         help="A csv or a json file containing the Prediction data."
     )
     parser.add_argument(
@@ -163,15 +165,14 @@ def parse_args():
     parser.add_argument(
         "--model_name_or_path",
         type=str,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-        required=False,
+        help="Path to pretrained model or model identifier from huggingface.co/models."
     )
     parser.add_argument(
         "--teacher_model_name_or_path",
         type=str,
         default=None,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
-        required=False,
+        required=False
     )
     parser.add_argument(
         "--config_name",
@@ -199,8 +200,8 @@ def parse_args():
     parser.add_argument(
         "--distill_loss_weight",
         type=float,
-        default=1.0,
-        help="distiller loss weight",
+        default=0.0,
+        help="distiller loss weight"
     )
     parser.add_argument(
         "--per_device_eval_batch_size",
@@ -215,15 +216,15 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
-        "--weight_decay",
-        type=float,
-        default=0.0,
+        "--weight_decay",
+        type=float,
+        default=0.0,
         help="Weight decay to use."
     )
     parser.add_argument(
-        "--num_train_epochs",
-        type=int,
-        default=3,
+        "--num_train_epochs",
+        type=int,
+        default=3,
         help="Total number of training epochs to perform."
     )
     parser.add_argument(
@@ -245,29 +246,28 @@ def parse_args():
         help="The scheduler type to use.",
         choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
     )
-
     parser.add_argument(
-        "--warm_epochs",
-        type=int,
-        default=0,
+        "--warm_epochs",
+        type=int,
+        default=0,
         help="Number of epochs the network not be purned"
     )
     parser.add_argument(
-        "--num_warmup_steps",
-        type=int,
-        default=0,
+        "--num_warmup_steps",
+        type=int,
+        default=0,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--output_dir",
-        type=str,
-        default=None,
+        "--output_dir",
+        type=str,
+        default=None,
         help="Where to store the final model."
     )
     parser.add_argument(
-        "--seed",
-        type=int,
-        default=None,
+        "--seed",
+        type=int,
+        default=None,
         help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -341,33 +341,18 @@ def parse_args():
         choices=MODEL_TYPES,
     )
     parser.add_argument(
-        "--cooldown_epochs",
-        type=int, default=0,
-        help="Cooling epochs after pruning."
-    )
-    parser.add_argument(
-        "--do_prune", action="store_true",
-        help="Whether or not to prune the model"
-    )
-    parser.add_argument(
-        "--pruning_config",
-        type=str,
-        help="pruning_config",
-    )
-
-    parser.add_argument(
-        "--push_to_hub",
-        action="store_true",
+        "--push_to_hub",
+        action="store_true",
         help="Whether or not to push the model to the Hub."
     )
     parser.add_argument(
-        "--hub_model_id",
-        type=str,
+        "--hub_model_id",
+        type=str,
         help="The name of the repository to keep in sync with the local `output_dir`."
     )
     parser.add_argument(
-        "--hub_token",
-        type=str,
+        "--hub_token",
+        type=str,
         help="The token to use to push to the Model Hub."
     )
     parser.add_argument(
@@ -382,13 +367,6 @@ def parse_args():
         default=None,
         help="If the training should continue from a checkpoint folder.",
     )
-    parser.add_argument(
-        "--distiller",
-        type=str,
-        default=None,
-        help="teacher model path",
-    )
-
     parser.add_argument(
         "--with_tracking",
         action="store_true",
@@ -405,6 +383,35 @@ def parse_args():
         ),
     )
 
+    parser.add_argument(
+        "--cooldown_epochs",
+        type=int, default=0,
+        help="Cooling epochs after pruning."
+    )
+    parser.add_argument(
+        "--do_prune", action="store_true",
+        help="Whether or not to prune the model"
+    )
+    # parser.add_argument(
+    #     "--keep_conf", action="store_true",
+    #     help="Whether or not to keep the prune config infos"
+    # )
+    parser.add_argument(
+        "--pruning_pattern",
+        type=str, default="1x1",
+        help="pruning pattern type, we support NxM and N:M."
+    )
+    parser.add_argument(
+        "--target_sparsity",
+        type=float, default=0.8,
+        help="Target sparsity of the model."
+    )
+    parser.add_argument(
+        "--pruning_frequency",
+        type=int, default=-1,
+        help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps."
+    )
+
     args = parser.parse_args()
 
     # Sanity checks
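
The three pruning flags added above feed the pruning schedule that the script computes further down (see the @@ -949,36 +960,36 @@ hunk). A self-contained sketch of that arithmetic with purely illustrative numbers (8,800 training samples, total batch size 16, 2 epochs, no warm or cooldown epochs, no LR warmup steps):

# Illustrative numbers only; none of these values come from the diff.
num_train_samples, total_batch_size = 8800, 16
num_train_epochs, warm_epochs, cooldown_epochs, num_warmup_steps = 2, 0, 0, 0

num_iterations = num_train_samples / total_batch_size                 # 550 steps per epoch
total_iterations = num_iterations * (num_train_epochs - warm_epochs - cooldown_epochs) \
                   - num_warmup_steps                                  # 1100-step pruning window
start = int(warm_epochs * num_iterations + num_warmup_steps)           # 0
end = int(total_iterations)                                            # 1100
# --pruning_frequency -1 means "a quarter of the pruning window":
frequency = int((end - start + 1) / 4)                                 # 275
print(start, end, frequency)                                           # 0 1100 275
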
@@ -435,7 +442,7 @@ def parse_args():
 def main():
     args = parse_args()
 
-    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     # send_example_telemetry("run_qa_no_trainer", args)
 
@@ -528,10 +535,13 @@ def main():
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
         )
-
-    if args.teacher_model_name_or_path != None:
+
+    if args.distill_loss_weight > 0:
+        teacher_path = args.teacher_model_name_or_path
+        if teacher_path is None:
+            teacher_path = args.model_name_or_path
         teacher_model = AutoModelForQuestionAnswering.from_pretrained(
-            args.teacher_model_name_or_path,
+            teacher_path,
             from_tf=bool(".ckpt" in args.model_name_or_path),
             config=config,
         )
@@ -815,7 +825,6 @@ def post_processing_function(examples, features, predictions, stage="eval"):
     def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
         """
         Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor
-
         Args:
             start_or_end_logits(:obj:`tensor`):
                 This is the output predictions of the model. We can only enter either start or end logits.
@@ -847,6 +856,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
     # Optimizer
     # Split weights in two groups, one with weight decay and the other not.
     no_decay = ["bias", "LayerNorm.weight"]
+    no_decay_outputs = ["bias", "LayerNorm.weight", "qa_outputs"]
     optimizer_grouped_parameters = [
         {
             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
@@ -876,10 +886,11 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
         num_training_steps=args.max_train_steps,
     )
 
-    if args.teacher_model_name_or_path != None:
+    if args.distill_loss_weight > 0:
         teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
             teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
         )
+        teacher_model.eval()
     else:
         # Prepare everything with our `accelerator`.
         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
@@ -949,36 +960,36 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
         starting_epoch = resume_step // len(train_dataloader)
         resume_step -= starting_epoch * len(train_dataloader)
 
-    params = [(n, p) for (n, p) in model.named_parameters() if
-              "bias" not in n and "LayerNorm" not in n and "embeddings" not in n \
-              and "layer.3.attention.output.dense.weight" not in n and "qa_outputs" not in n]
-
-    params_keys = [n for (n, p) in params]
-    for key in params_keys:
-        print(key)
-
     # Pruning preparation
+    pruning_configs=[
+        {
+            "pruning_type": "snip_momentum",
+            "pruning_scope": "global",
+            "sparsity_decay_type": "exp"
+        }
+    ]
+    config = WeightPruningConfig(
+        pruning_configs,
+        target_sparsity=args.target_sparsity,
+        excluded_op_names=["qa_outputs", "pooler", ".*embeddings*"],
+        pruning_op_types=["Linear"],
+        max_sparsity_ratio_per_op=0.98,
+        pruning_scope="global",
+        pattern=args.pruning_pattern,
+        pruning_frequency=1000
+    )
+    pruner = Pruning(config)
     num_iterations = len(train_dataset) / total_batch_size
-    total_iterations = num_iterations * (args.num_train_epochs \
-                       - args.warm_epochs - args.cooldown_epochs) - args.num_warmup_steps
-    completed_pruned_cnt = 0
-    total_cnt = 0
-    for n, param in params:
-        total_cnt += param.numel()
-    print(f"The total param quantity is {total_cnt}")
-
-    if args.teacher_model_name_or_path != None:
-        teacher_model.eval()
-
-    from pytorch_pruner.pruning import Pruning
-    pruner = Pruning(args.pruning_config)
+    total_iterations = num_iterations * (args.num_train_epochs - args.warm_epochs - args.cooldown_epochs) \
+                       - args.num_warmup_steps
     if args.do_prune:
-        pruner.update_items_for_all_pruners( \
-            start_step = int(args.warm_epochs*num_iterations+args.num_warmup_steps), \
-            end_step = int(total_iterations))##iterative
+        start = int(args.warm_epochs * num_iterations+args.num_warmup_steps)
+        end = int(total_iterations)
+        frequency = int((end - start + 1) / 4) if (args.pruning_frequency == -1) else args.pruning_frequency
+        pruner.update_config(start_step=start, end_step=end, pruning_frequency=frequency)##iterative
     else:
         total_step = num_iterations * args.num_train_epochs + 1
-        pruner.update_items_for_all_pruners(start_step=total_step, end_step=total_step)
+        pruner.update_config(start_step=total_step, end_step=total_step)
     pruner.model = model
     pruner.on_train_begin()
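
For context on how the configured pruner is driven: only pruner.model, pruner.on_train_begin() and update_config() appear in this hunk. The loop below is a hypothetical sketch of the surrounding training loop, reusing objects the script already builds (args, model, optimizer, accelerator, lr_scheduler, train_dataloader); the per-step hook names (on_step_begin, on_before_optimizer_step, on_after_optimizer_step) are assumptions modeled on the same callback style and should be checked against the pruner API of your neural_compressor version.

pruner.model = model
pruner.on_train_begin()
for epoch in range(args.num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        pruner.on_step_begin(step)          # assumed hook: update masks on pruning steps
        outputs = model(**batch)
        loss = outputs.loss / args.gradient_accumulation_steps
        accelerator.backward(loss)
        pruner.on_before_optimizer_step()   # assumed hook: mask gradients before the update
        optimizer.step()
        pruner.on_after_optimizer_step()    # assumed hook: re-apply weight masks after the update
        lr_scheduler.step()
        optimizer.zero_grad()
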

@@ -994,14 +1005,14 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
             # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
-            if args.teacher_model_name_or_path != None:
+            if args.distill_loss_weight > 0:
                distill_loss_weight = args.distill_loss_weight
                with torch.no_grad():
                    teacher_outputs = teacher_model(**batch)
                loss = (distill_loss_weight) / 2 * get_loss_one_logit(outputs['start_logits'],
-                                                                      teacher_outputs['start_logits']) \
+                                                                      teacher_outputs['start_logits']) \
                     + (distill_loss_weight) / 2 * get_loss_one_logit(outputs['end_logits'],
-                                                                      teacher_outputs['end_logits'])
+                                                                      teacher_outputs['end_logits'])
 
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
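
get_loss_one_logit is referenced in this hunk but its body sits outside the diff. As a hypothetical stand-in (not the script's actual implementation), a per-logit distillation loss of this kind is commonly a temperature-scaled KL divergence; the helper name kd_loss_one_logit and the temperature value below are illustrative assumptions.

import torch.nn.functional as F

def kd_loss_one_logit(student_logits, teacher_logits, temperature=2.0):
    """Hypothetical stand-in: KL divergence between softened student and teacher logits."""
    return F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
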
@@ -1160,3 +1171,4 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
 if __name__ == "__main__":
     main()
 
+

0 commit comments