Skip to content

Commit 9e0644a

Browse files
authored
Refine pruning API reaction when obtain wrong pruner type. (#314)
* Refine pruning API reaction when obtain wrong pruner type. Signed-off-by: YIYANGCAI <[email protected]> * Add pattern_lock valid type. Signed-off-by: YIYANGCAI <[email protected]> --------- Signed-off-by: YIYANGCAI <[email protected]>
1 parent 3273d4f commit 9e0644a

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

neural_compressor/pruner/criteria.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
torch = LazyImport('torch')
2020

2121

22-
CRITERIAS = {}
22+
CRITERIA = {}
2323

2424

2525
def register_criterion(name):
2626
"""Register a criterion to the registry."""
2727

2828
def register(criterion):
29-
CRITERIAS[name] = criterion
29+
CRITERIA[name] = criterion
3030
return criterion
3131

3232
return register
@@ -35,9 +35,9 @@ def register(criterion):
3535
def get_criterion(config, modules):
3636
"""Get registered criterion class."""
3737
name = config["criterion_type"]
38-
if name not in CRITERIAS.keys():
39-
assert False, f"criteria does not support {name}, currently only support {CRITERIAS.keys()}"
40-
return CRITERIAS[name](modules, config)
38+
if name not in CRITERIA.keys():
39+
assert False, f"criteria does not support {name}, currently only support {CRITERIA.keys()}"
40+
return CRITERIA[name](modules, config)
4141

4242

4343
class PruningCriterion:

neural_compressor/pruner/pruners.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@
2121
torch = LazyImport('torch')
2222
from .patterns import get_pattern
2323
from .schedulers import get_scheduler
24-
from .criteria import get_criterion, CRITERIAS
24+
from .criteria import get_criterion, CRITERIA
2525
from .regs import get_reg
2626
from .logger import logger
2727

2828
PRUNERS = {}
2929

30-
3130
def register_pruner(name):
3231
"""Class decorator to register a Pruner subclass to the registry.
3332
@@ -48,6 +47,14 @@ def register(pruner):
4847

4948
return register
5049

50+
def parse_valid_pruner_types():
51+
"""Get all valid pruner names"""
52+
valid_pruner_types = []
53+
for x in CRITERIA.keys():
54+
for p in ["", "_progressive"]:
55+
valid_pruner_types.append(x + p)
56+
valid_pruner_types.append("pattern_lock")
57+
return valid_pruner_types
5158

5259
def get_pruner(config, modules):
5360
"""Get registered pruner class.
@@ -71,7 +78,7 @@ def get_pruner(config, modules):
7178
# if progressive, delete "progressive" words and reset config["progressive"]
7279
name = config["pruning_type"][0:-12]
7380
config["progressive"] = True
74-
if name in CRITERIAS:
81+
if name in CRITERIA:
7582
if config["progressive"] == False:
7683
config['criterion_type'] = name
7784
name = "basic" ##return the basic pruner
@@ -80,7 +87,7 @@ def get_pruner(config, modules):
8087
name = "progressive" ## return the progressive pruner
8188

8289
if name not in PRUNERS.keys():
83-
assert False, f"does not support {name}, currently only support {PRUNERS.keys()}"
90+
assert False, f"does not support {name}, currently only support {parse_valid_pruner_types()}"
8491
return PRUNERS[name](config, modules)
8592

8693

0 commit comments

Comments
 (0)