src/transformers/integrations/accelerate.py (81 changes: 20 additions & 61 deletions)
@@ -159,61 +159,6 @@ def wrapper(*args, **kwargs):
setattr(torch, torch_function_name, old_torch_function)


def find_tied_parameters(model: "nn.Module", **kwargs):
"""
Find the tied parameters in a given model.

<Tip warning={true}>

The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
them.

</Tip>

Args:
model (`torch.nn.Module`): The model to inspect.

Returns:
list[list[str]]: A list of lists of parameter names being all tied together.

Example:

```py
>>> from collections import OrderedDict
>>> import torch.nn as nn

>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
>>> model.linear2.weight = model.linear1.weight
>>> find_tied_parameters(model)
[['linear1.weight', 'linear2.weight']]
```
"""

# get ALL model parameters and their names
all_named_parameters = dict(model.named_parameters(remove_duplicate=False))

# get ONLY unique named parameters,
# if a parameter is tied and has multiple names, it will be included only once
no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))

# the difference of the two sets will give us the tied parameters
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())

# 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
# which names refer to the same parameter. To identify this, we need to group them together.
tied_param_groups = {}
for tied_param_name in tied_param_names:
tied_param = all_named_parameters[tied_param_name]
for param_name, param in no_duplicate_named_parameters.items():
# compare if parameters are the same, if so, group their names together
if param is tied_param:
if param_name not in tied_param_groups:
tied_param_groups[param_name] = []
tied_param_groups[param_name].append(tied_param_name)

return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]


def check_and_set_device_map(device_map: "torch.device | int | str | dict | None") -> dict | str | None:
from ..modeling_utils import get_torch_context_manager_or_global_device

@@ -271,11 +216,21 @@ def compute_module_sizes(
leaves_module_sizes = defaultdict(int)

if buffers_only:
named_tensors = model.named_buffers(recurse=True)
iterator = model.named_buffers()
else:
named_tensors = model.state_dict().items()

for name, param in named_tensors:
# We need parameters + buffers here, as state_dict does not include non-persistent buffers, which still take up space
def all_tensors():
yield from model.named_parameters()
yield from model.named_buffers()

iterator = all_tensors()

tied_keys = getattr(model, "all_tied_weights_keys", {}).keys()
for name, param in iterator:
# Do not count tied keys (the model is usually not tied yet here, so they will appear in the iterator)
# If the model is already tied, then they simply do not appear in the iterator anyway (remove_duplicate=True by default)
if name in tied_keys:
continue
if hf_quantizer is not None:
dtype_size = hf_quantizer.param_element_size(model, name)
else:
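
To make the reasoning in the new comments concrete, here is a minimal sketch (the `ToyModule` class below is made up for illustration and is not part of the PR) of why the loop now iterates over `named_parameters()` plus `named_buffers()` instead of `state_dict()`: a non-persistent buffer never shows up in `state_dict()`, yet it still occupies memory and should be counted.

```py
import torch
import torch.nn as nn


class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Non-persistent buffer: excluded from state_dict(), but still allocated in memory
        self.register_buffer("rope_cache", torch.zeros(16), persistent=False)


model = ToyModule()

print("rope_cache" in model.state_dict())           # False -> state_dict misses it
print("rope_cache" in dict(model.named_buffers()))  # True  -> named_buffers reports it


# Same shape as the new helper above: parameters first, then buffers
def all_tensors():
    yield from model.named_parameters()
    yield from model.named_buffers()


total_bytes = sum(t.numel() * t.element_size() for _, t in all_tensors())
print(total_bytes)  # (16 + 4) * 4 bytes of fp32 weight/bias + 16 * 4 bytes of buffer = 144
```

The `tied_keys` check is a separate concern: names listed in `model.all_tied_weights_keys` will end up sharing storage with their target once the model is tied, so skipping them avoids counting the same weight twice.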
@@ -591,8 +546,12 @@ def _init_infer_auto_device_map(

if tied_parameters is None:
if len(model.all_tied_weights_keys) > 0:
# create a list of list of tied params
tied_parameters = [list(t) for t in model.all_tied_weights_keys.items()]
# create a list of lists of tied params based on unique tied groups
groups = set(model.all_tied_weights_keys.values())
tied_parameters = [
sorted([k for k, v in model.all_tied_weights_keys.items() if v == target] + [target])
for target in groups
]
else:
tied_parameters = [[]]

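For reference, a quick sketch of what the new grouping logic produces, assuming `all_tied_weights_keys` maps each tied parameter name to the name of the parameter it is tied to (the names below are made up):

```py
# Hypothetical mapping: tied parameter name -> name of the parameter it is tied to
all_tied_weights_keys = {
    "lm_head.weight": "model.embed_tokens.weight",
    "decoder.embed_tokens.weight": "model.embed_tokens.weight",
}

# One group per unique target, each group containing the target and every name tied to it
groups = set(all_tied_weights_keys.values())
tied_parameters = [
    sorted([k for k, v in all_tied_weights_keys.items() if v == target] + [target])
    for target in groups
]
print(tied_parameters)
# [['decoder.embed_tokens.weight', 'lm_head.weight', 'model.embed_tokens.weight']]
```

Compared to the previous `[list(t) for t in model.all_tied_weights_keys.items()]`, which produced one two-element pair per dict entry, this collapses every name tied to the same target into a single group.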