Skip to content

Commit d29aa0f

Browse files
add block_mask & retrain_free features (#775)
Signed-off-by: Zhang, Weiwei1 <[email protected]> Co-authored-by: wenhuach21 <[email protected]>
1 parent e45c022 commit d29aa0f

File tree

10 files changed

+2878
-122
lines changed

10 files changed

+2878
-122
lines changed

examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py

Lines changed: 924 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/bin/bash
2+
set -x
3+
python \
4+
examples/pytorch/nlp/huggingface_models/language-modeling/pruning/eager/run_clm_no_trainer.py \
5+
--model_name_or_path /path/to/your/model \
6+
--dataset_name lambada \
7+
--per_device_train_batch_size 2 \
8+
--per_device_eval_batch_size 16 \
9+
--max_train_steps 3002 \
10+
--weight_decay 0 \
11+
--block_size 512 \
12+
--do_prune \
13+
--auto_slim \
14+
--output_dir sparse_clm_models/ \
15+
--target_sparsity 0.2 \
16+
--pruning_pattern channelx1 \
17+
--pruning_frequency 500 \
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import time
2+
import torch
3+
class CPUTimer:
4+
def __init__(self, timelogs):
5+
self.timelogs = timelogs
6+
7+
def __enter__(self):
8+
self.start = time.time()
9+
10+
def __exit__(self):
11+
end = time.time()
12+
self.timelogs.append((end - self.start) * 1000) # ms
13+
14+
def get_avg_time(self):
15+
return sum(self.timelogs) / len(self.timelogs)
16+
17+
class GPUTimer:
18+
def __init__(self, timelogs):
19+
self.timelogs = timelogs
20+
21+
def __enter__(self):
22+
self.start_event = torch.cuda.Event(enable_timing=True)
23+
self.end_event = torch.cuda.Event(enable_timing=True)
24+
self.start_event.record()
25+
26+
def __exit__(self):
27+
self.end_event.record()
28+
self.end_event.synchronize()
29+
elapsed_time = self.start_event.elapsed_time(self.end_event)
30+
self.timelogs.append(elapsed_time)
31+
32+
def get_avg_time(self):
33+
return sum(self.timelogs) / len(self.timelogs)

0 commit comments

Comments
 (0)