Closed
Description
Hi, I think there's a bug affecting both the `get_module` and `set_module` functions in `neural_compressor/adaptor/torch_utils/smooth_quant.py`.
The original functions are:

```python
def get_module(model, key):
    """Get module from model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
    """
    attrs = key.split(".")
    module = model
    for attr in attrs:
        try:
            attr = int(attr)
            module = module[attr]
        except:
            module = getattr(module, attr)
    return module


def set_module(model, key, new_module):
    """Set new module into model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
        new_module (torch.nn.Module): new module to be inserted
    """
    attrs = key.split(".")
    module = model
    for attr in attrs[:-1]:
        try:
            attr = int(attr)
            module = module[attr]
        except:
            module = getattr(module, attr)
    setattr(module, attrs[-1], new_module)
```

When execution reaches the `except` block after `int(attr)` has already succeeded (the key segment is numeric, e.g. `"0"`, but `module[attr]` fails because that module does not support integer indexing), `getattr()` raises `TypeError: getattr(): attribute name must be string`. This happens because `attr` was rebound to an `int` inside the `try` block. So I propose the following modifications, which index with `int(attr)` without overwriting `attr`:
```python
def get_module(model, key):
    """Get module from model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
    """
    attrs = key.split(".")
    module = model
    for attr in attrs:
        try:
            module = module[int(attr)]
        except:
            module = getattr(module, attr)
    return module


def set_module(model, key, new_module):
    """Set new module into model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
        new_module (torch.nn.Module): new module to be inserted
    """
    attrs = key.split(".")
    module = model
    for attr in attrs[:-1]:
        try:
            module = module[int(attr)]
        except:
            module = getattr(module, attr)
    setattr(module, attrs[-1], new_module)
```

With these changes, my code runs without errors. Let me know what you think. Thanks.
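To make the failure easy to reproduce without installing PyTorch or Neural Compressor, here is a minimal sketch. `FakeModule` and `FakeSequential` are hypothetical, torch-free stand-ins that only mimic how `torch.nn.Module` keeps children in `_modules` and resolves them through `__getattr__`; the `*_original` / `*_fixed` names are also illustrative, and the bare `except:` is narrowed to `except Exception:`. The trigger is a numeric key segment such as `"0"` that names a child on a module without integer indexing. In real PyTorch this can presumably happen with a plain `nn.Module` that registers a child via `add_module("0", ...)`, or with an `nn.ModuleDict` keyed by `"0"` (where `module[0]` raises `KeyError`).

```python
# Hypothetical torch-free stand-ins: they only mimic how torch.nn.Module keeps
# child modules in `_modules` and resolves them through __getattr__.
class FakeModule:
    def __init__(self):
        # Write straight into __dict__ so our __setattr__ is not involved.
        self.__dict__["_modules"] = {}

    def add_module(self, name, child):
        self._modules[name] = child

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, as in nn.Module.
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # Register submodules in `_modules`; store anything else normally.
        if isinstance(value, FakeModule):
            self._modules[name] = value
        else:
            self.__dict__[name] = value


class FakeSequential(FakeModule):
    """Supports integer indexing, like nn.Sequential."""

    def __getitem__(self, idx):
        return list(self._modules.values())[idx]


def get_module_original(model, key):
    # Same logic as the reported code (bare except narrowed to Exception).
    module = model
    for attr in key.split("."):
        try:
            attr = int(attr)        # rebinds attr to an int on success...
            module = module[attr]   # ...so if indexing fails here...
        except Exception:
            module = getattr(module, attr)  # ...getattr receives an int
    return module


def get_module_fixed(model, key):
    module = model
    for attr in key.split("."):
        try:
            module = module[int(attr)]
        except Exception:
            # attr is still the original string, so getattr accepts it.
            module = getattr(module, attr)
    return module


def set_module_fixed(model, key, new_module):
    attrs = key.split(".")
    module = model
    for attr in attrs[:-1]:
        try:
            module = module[int(attr)]
        except Exception:
            module = getattr(module, attr)
    setattr(module, attrs[-1], new_module)


leaf, leaf2 = FakeModule(), FakeModule()
block = FakeModule()            # not indexable
block.add_module("0", leaf)     # child registered under a numeric name
seq = FakeSequential()          # indexable, like nn.Sequential
seq.add_module("0", leaf2)
model = FakeModule()
model.add_module("block", block)
model.add_module("seq", seq)

try:
    get_module_original(model, "block.0")
except TypeError as err:
    print("original:", err)

print(get_module_fixed(model, "block.0") is leaf)  # True
```

Both versions resolve `"seq.0"`, because indexing succeeds there; only the non-indexable `"block.0"` path differs. Since `nn.Module.__getattr__` already looks up numeric child names like `"0"` in `_modules`, it may also be worth checking whether `torch.nn.Module.get_submodule` (PyTorch 1.10+) could replace the manual traversal entirely.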