
Bug in smooth_quant.py #1265

@MatteoPagliani

Hi, I think there's a bug affecting both the get_module and set_module functions in neural_compressor/adaptor/torch_utils/smooth_quant.py.

The original methods are:

def get_module(model, key):
    """Get module from model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
    """
    attrs = key.split(".")
    module = model
    for attr in attrs:
        try:
            attr = int(attr)
            module = module[attr]
        except:
            module = getattr(module, attr)
    return module


def set_module(model, key, new_module):
    """Set new module into model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
        new_module (torch.nn.Module): new module to be inserted
    """
    attrs = key.split(".")
    module = model
    for attr in attrs[:-1]:
        try:
            attr = int(attr)
            module = module[attr]
        except:
            module = getattr(module, attr)
    setattr(module, attrs[-1], new_module)

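When int(attr) succeeds but the indexing on the next line fails (for example, because the module has a child named "0" but does not support module[attr]), execution goes into the except block with attr already rebound to an int, and getattr() raises this error: "TypeError: getattr(): attribute name must be string". So I propose the following modifications, which convert attr to int only inside the indexing expression and leave the original string available for getattr():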
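The failure described above can be reproduced without PyTorch. The following is a minimal sketch, not the library's test case: Leaf and Container are hypothetical stand-ins (they are not part of neural_compressor or torch) for a module whose child is named "0" but which does not implement indexing, and get_module_original is a copy of the original function under a different name.

```python
class Leaf:
    """Hypothetical stand-in for a leaf module such as a linear layer."""


class Container:
    """Hypothetical stand-in: has a child named "0" but no __getitem__."""

    def __init__(self):
        setattr(self, "0", Leaf())


def get_module_original(model, key):
    """Copy of the original get_module from the issue, kept verbatim."""
    attrs = key.split(".")
    module = model
    for attr in attrs:
        try:
            attr = int(attr)          # rebinds attr: "0" becomes 0
            module = module[attr]     # fails: Container is not subscriptable
        except:
            # attr is now an int, so getattr() rejects it with a TypeError
            module = getattr(module, attr)
    return module


model = Container()
try:
    get_module_original(model, "0")
    error = None
except TypeError as exc:
    error = exc

# The exact message wording differs between Python versions.
print(type(error).__name__, error)
```

The TypeError escapes because it is raised inside the except block itself, where nothing is left to catch it.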

def get_module(model, key):
    """Get module from model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
    """
    attrs = key.split(".")
    module = model
    for attr in attrs:
        try:
            module = module[int(attr)]
        except:
            module = getattr(module, attr)
    return module


def set_module(model, key, new_module):
    """Set new module into model by key name.

    Args:
        model (torch.nn.Module): original model
        key (str): module name to be replaced
        new_module (torch.nn.Module): new module to be inserted
    """
    attrs = key.split(".")
    module = model
    for attr in attrs[:-1]:
        try:
            module = module[int(attr)]
        except:
            module = getattr(module, attr)
    setattr(module, attrs[-1], new_module)

With these changes, my code runs without errors. Let me know what you think. Thanks.
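As a rough check of the proposed functions, the sketch below runs them against plain-Python stand-ins rather than real torch.nn modules. Leaf, ListLike, NumericNames, and Root are hypothetical classes introduced only for illustration. The one deliberate difference from the proposal is that the bare except is written as except Exception, which is an assumption about style and not part of the issue.

```python
class Leaf:
    """Hypothetical stand-in for a leaf module."""

    def __init__(self, name):
        self.name = name


class ListLike:
    """Hypothetical stand-in for an indexable container."""

    def __init__(self, *children):
        self._children = list(children)

    def __getitem__(self, idx):
        return self._children[idx]


class NumericNames:
    """Hypothetical stand-in: child named "0", no __getitem__."""

    def __init__(self, child):
        setattr(self, "0", child)


class Root:
    def __init__(self):
        self.layers = ListLike(Leaf("a"), Leaf("b"))
        self.block = NumericNames(Leaf("c"))


def get_module(model, key):
    """Proposed version: attr stays a string for the getattr() fallback."""
    module = model
    for attr in key.split("."):
        try:
            module = module[int(attr)]
        except Exception:  # the issue uses a bare except here
            module = getattr(module, attr)
    return module


def set_module(model, key, new_module):
    """Proposed version: walk to the parent, then set the last attribute."""
    attrs = key.split(".")
    module = model
    for attr in attrs[:-1]:
        try:
            module = module[int(attr)]
        except Exception:  # the issue uses a bare except here
            module = getattr(module, attr)
    setattr(module, attrs[-1], new_module)


root = Root()
found_by_index = get_module(root, "layers.1")  # resolved via module[1]
found_by_name = get_module(root, "block.0")    # resolved via getattr(module, "0")

# The path passes through the numeric name "0" before the final setattr().
set_module(root, "block.0.inner", Leaf("d"))
replaced = get_module(root, "block.0.inner")
```

Both lookup paths succeed here: the integer index is used where the container supports it, and the string name is used otherwise.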
