Binary file modified docs/source/_static/imgs/pruning/Pruning_patterns.PNG
3 changes: 2 additions & 1 deletion docs/source/pruning_details.md
@@ -251,7 +251,8 @@ Pattern_lock pruning type uses masks of a fixed pattern during the pruning process



Progressive pruning smooths structured pruning by automatically interpolating a group of interval masks during the pruning process. In this method, a sequence of masks is generated to make the pruning process more flexible, and these masks gradually converge to the target pruning structure.
Progressive pruning is used mainly for channel-wise pruning and currently supports only the NxM pruning pattern.
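
Below is a minimal, hypothetical sketch of how a progressive NxM pruning run could be configured with the `WeightPruningConfig` and `Pruning` APIs used in the example script of this PR. The `snip_momentum_progressive` pruning type, the `4x1` pattern, the toy model, and the step values are illustrative assumptions, not settings taken from this PR:

```python
import torch
from neural_compressor.pruning import Pruning
from neural_compressor.pruner.utils import WeightPruningConfig

# Toy stand-in for the real model being trained (assumption for this sketch).
model = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.ReLU(), torch.nn.Linear(768, 2))

pruning_configs = [
    {
        "pruning_type": "snip_momentum_progressive",  # progressive variant of snip_momentum (assumed available)
        "pruning_scope": "global",
    }
]
config = WeightPruningConfig(
    pruning_configs,
    target_sparsity=0.8,
    pattern="4x1",                  # progressive pruning targets NxM (channel-wise) patterns
    pruning_op_types=["Linear"],
)

pruner = Pruning(config)
# Prune between step 0 and step 10000, updating masks every 500 steps (illustrative values).
pruner.update_config(start_step=0, end_step=10000, pruning_frequency=500)
pruner.model = model
pruner.on_train_begin()             # the per-step pruning callbacks run inside the training loop (not shown)
```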



@@ -57,6 +57,8 @@
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from utils_qa import postprocess_qa_predictions
from neural_compressor.pruning import Pruning
from neural_compressor.pruner.utils import WeightPruningConfig

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.21.0.dev0")
@@ -118,32 +120,32 @@ def parse_args():
help="The configuration name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--train_file",
type=str,
default=None,
"--train_file",
type=str,
default=None,
help="A csv or a json file containing the training data."
)
parser.add_argument(
"--preprocessing_num_workers",
type=int, default=4,
"--preprocessing_num_workers",
type=int, default=4,
help="A csv or a json file containing the training data."
)

parser.add_argument(
"--do_predict",
action="store_true",
"--do_predict",
action="store_true",
help="To do prediction on the question answering model"
)
parser.add_argument(
"--validation_file",
type=str,
default=None,
"--validation_file",
type=str,
default=None,
help="A csv or a json file containing the validation data."
)
parser.add_argument(
"--test_file",
type=str,
default=None,
"--test_file",
type=str,
default=None,
help="A csv or a json file containing the Prediction data."
)
parser.add_argument(
@@ -163,15 +165,14 @@ def parse_args():
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models."
)
parser.add_argument(
"--teacher_model_name_or_path",
type=str,
default=None,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=False
)
parser.add_argument(
"--config_name",
@@ -199,8 +200,8 @@ def parse_args():
parser.add_argument(
"--distill_loss_weight",
type=float,
default=0.0,
help="Distillation loss weight."
)
parser.add_argument(
"--per_device_eval_batch_size",
@@ -215,15 +216,15 @@ def parse_args():
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.0,
"--weight_decay",
type=float,
default=0.0,
help="Weight decay to use."
)
parser.add_argument(
"--num_train_epochs",
type=int,
default=3,
"--num_train_epochs",
type=int,
default=3,
help="Total number of training epochs to perform."
)
parser.add_argument(
@@ -245,29 +246,28 @@ def parse_args():
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)

parser.add_argument(
"--warm_epochs",
type=int,
default=0,
"--warm_epochs",
type=int,
default=0,
help="Number of epochs the network not be purned"
)
parser.add_argument(
"--num_warmup_steps",
type=int,
default=0,
"--num_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
"--output_dir",
type=str,
default=None,
help="Where to store the final model."
)
parser.add_argument(
"--seed",
type=int,
default=None,
"--seed",
type=int,
default=None,
help="A seed for reproducible training."
)
parser.add_argument(
@@ -341,33 +341,18 @@ def parse_args():
choices=MODEL_TYPES,
)

parser.add_argument(
"--push_to_hub",
action="store_true",
"--push_to_hub",
action="store_true",
help="Whether or not to push the model to the Hub."
)
parser.add_argument(
"--hub_model_id",
type=str,
"--hub_model_id",
type=str,
help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument(
"--hub_token",
type=str,
"--hub_token",
type=str,
help="The token to use to push to the Model Hub."
)
parser.add_argument(
@@ -382,13 +367,6 @@ def parse_args():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--with_tracking",
action="store_true",
@@ -405,6 +383,35 @@
),
)

parser.add_argument(
"--cooldown_epochs",
type=int, default=0,
help="Cooling epochs after pruning."
)
parser.add_argument(
"--do_prune", action="store_true",
help="Whether or not to prune the model"
)
# parser.add_argument(
# "--keep_conf", action="store_true",
# help="Whether or not to keep the prune config infos"
# )
parser.add_argument(
"--pruning_pattern",
type=str, default="1x1",
help="pruning pattern type, we support NxM and N:M."
)
parser.add_argument(
"--target_sparsity",
type=float, default=0.8,
help="Target sparsity of the model."
)
parser.add_argument(
"--pruning_frequency",
type=int, default=-1,
help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps."
)

args = parser.parse_args()

# Sanity checks
@@ -435,7 +442,7 @@ def parse_args():
def main():
args = parse_args()

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
# send_example_telemetry("run_qa_no_trainer", args)

@@ -528,10 +535,13 @@ def main():
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

if args.distill_loss_weight > 0:
teacher_path = args.teacher_model_name_or_path
if teacher_path is None:
teacher_path = args.model_name_or_path
teacher_model = AutoModelForQuestionAnswering.from_pretrained(
teacher_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
)
@@ -815,7 +825,6 @@ def post_processing_function(examples, features, predictions, stage="eval"):
def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
"""
Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor

Args:
start_or_end_logits(:obj:`tensor`):
This is the output predictions of the model. We can only enter either start or end logits.
@@ -847,6 +856,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
no_decay_outputs = ["bias", "LayerNorm.weight", "qa_outputs"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
@@ -876,10 +886,11 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
num_training_steps=args.max_train_steps,
)

if args.distill_loss_weight > 0:
teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
teacher_model.eval()
else:
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
@@ -949,36 +960,36 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
starting_epoch = resume_step // len(train_dataloader)
resume_step -= starting_epoch * len(train_dataloader)


# Pruning preparation
pruning_configs=[
{
"pruning_type": "snip_momentum",
"pruning_scope": "global",
"sparsity_decay_type": "exp"
}
]
config = WeightPruningConfig(
pruning_configs,
target_sparsity=args.target_sparsity,
excluded_op_names=["qa_outputs", "pooler", ".*embeddings*"],
pruning_op_types=["Linear"],
max_sparsity_ratio_per_op=0.98,
pruning_scope="global",
pattern=args.pruning_pattern,
pruning_frequency=1000
)
pruner = Pruning(config)
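# Pruning schedule: prune between the end of the warm-up phase (warm_epochs plus the LR warm-up steps)
# and the start of the cooldown epochs; unless --pruning_frequency is set, masks are updated every
# quarter of that window.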
num_iterations = len(train_dataset) / total_batch_size
total_iterations = num_iterations * (args.num_train_epochs \
- args.warm_epochs - args.cooldown_epochs) - args.num_warmup_steps
if args.do_prune:
start = int(args.warm_epochs * num_iterations+args.num_warmup_steps)
end = int(total_iterations)
frequency = int((end - start + 1) / 4) if (args.pruning_frequency == -1) else args.pruning_frequency
pruner.update_config(start_step=start, end_step=end, pruning_frequency=frequency)##iterative
else:
total_step = num_iterations * args.num_train_epochs + 1
pruner.update_config(start_step=total_step, end_step=total_step)
pruner.model = model
pruner.on_train_begin()

@@ -994,14 +1005,14 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
# We keep track of the loss at each epoch
if args.with_tracking:
total_loss += loss.detach().float()
if args.distill_loss_weight > 0:
distill_loss_weight = args.distill_loss_weight
with torch.no_grad():
teacher_outputs = teacher_model(**batch)
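# Distillation: compute the start- and end-logit losses against the frozen teacher,
# each scaled by distill_loss_weight / 2.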
loss = (distill_loss_weight) / 2 * get_loss_one_logit(outputs['start_logits'],
teacher_outputs['start_logits']) \
+ (distill_loss_weight) / 2 * get_loss_one_logit(outputs['end_logits'],
teacher_outputs['end_logits'])

loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
@@ -1160,3 +1171,4 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if __name__ == "__main__":
main()

