210 changes: 157 additions & 53 deletions src/transformers/core_model_loading.py
@@ -148,33 +148,34 @@ def convert(
class Chunk(ConversionOps):
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""

reverse_op: type[ConversionOps]

def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
if chunks is not None and chunks <= 0:
raise ValueError("`chunks` must be a strictly positive integer.")
self.dim = dim
self.chunks = chunks
self.sizes = list(sizes) if sizes is not None else None
self.reverse_op = Concatenate

def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
# chunk requires a single tensor input
if len(value) != 1 or len(value[0]) != 1:
raise ValueError("Chunk operation requires a single tensor input.")
@property
def reverse_op(self) -> ConversionOps:
return Concatenate(self.dim)

def convert(self, value: Sequence[Sequence[torch.Tensor]], concrete_target_keys=None, *args, **kwargs) -> Union[list[torch.Tensor], list[dict[str, torch.Tensor]]]:
# when loading, chunk expects a single tensor; when saving we may receive several tensors to split back
update_ = []
if concrete_target_keys is not None: # saving path: split each collected tensor back into its concrete target keys
chunk_size = len(concrete_target_keys)
for layer in value:
for tensors in layer:
update_ += [dict(zip(concrete_target_keys, torch.chunk(tensors, chunks=chunk_size, dim=self.dim)))]
return update_
if self.chunks is None and self.sizes is None:
raise ValueError("`chunks` or `sizes` must be provided to split a single tensor at load time.")
return list(torch.chunk(value[0][0], self.chunks, dim=self.dim))
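
A minimal sketch of both paths, assuming the nested `[[tensor]]` layout the loader passes in; the shapes and target keys are illustrative, not taken from a real mapping:

import torch

fused = torch.randn(12, 4)
op = Chunk(dim=0, chunks=3)

# load path: one fused tensor in, a list of equally sized chunks out
q, k, v = op.convert([[fused]])

# save path: the chunk count is derived from the concrete target keys instead
keys = ["q_proj.weight", "k_proj.weight", "v_proj.weight"]
per_key = op.convert([[fused]], concrete_target_keys=keys)
assert set(per_key[0]) == set(keys)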


class Concatenate(ConversionOps):
"""Concatenate tensors along `dim` using a reusable buffer."""

reverse_op: type[ConversionOps]

def __init__(self, dim: int = 0):
self.dim = dim
self.reverse_op = Chunk

@property
def reverse_op(self) -> ConversionOps:
return Chunk(self.dim)

@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
@@ -187,16 +188,19 @@ def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
return torch.cat(tuple(tensors), dim=self.dim)
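
Making `reverse_op` a property (instead of a class attribute holding the type) means the inverse op comes back already instantiated with the matching `dim`. A small sketch of what callers can now rely on:

op = Chunk(dim=2, chunks=4)
inverse = op.reverse_op  # a ready-to-use Concatenate(dim=2) instance
assert isinstance(inverse, Concatenate) and inverse.dim == 2
# previously, reverse_op was the class itself and every caller had to
# re-instantiate it and remember to forward the right dim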


class MergeModulelist(Concatenate):
class MergeModulelist(ConversionOps):
"""
Merge a list of tensors into a single tensor along the first dimension.
We explicitly define this because for EP or TP you want to make sure you know what you are doing!

"""

def __init__(self, dim: int = 0):
super().__init__(dim=dim)
self.reverse_op = SplitModulelist
self.dim = dim

@property
def reverse_op(self) -> ConversionOps:
return SplitModulelist(self.dim)

@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
@@ -212,26 +216,24 @@ def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
class SplitModulelist(ConversionOps):
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""

def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
self.sizes = [list(sub) for sub in sizes]
self.dim = dim
self.reverse_op = MergeModulelist
def __init__(self, dim: int = 0):
self.dim = dim

@property
def reverse_op(self) -> ConversionOps:
return MergeModulelist(self.dim)

@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
if not isinstance(value, Sequence):
raise TypeError("SplitModulelist expects a sequence of tensors.")
if len(value) != len(self.sizes):
raise ValueError("Number of tensors does not match the provided split specifications.")

result: list[list[torch.Tensor]] = []
for tensor, split_sizes in zip(value, self.sizes):
if not isinstance(tensor, torch.Tensor):
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
splits = torch.split(tensor, split_sizes, dim=self.dim)
result.append(list(splits))
def convert(self, value, concrete_target_keys=None, config=None, *args, **kwargs) -> list[dict[str, torch.Tensor]]:
result = []
for layer_idx, layers in enumerate(value):
tmp = {}
if not isinstance(layers, dict):
layers = dict(zip(concrete_target_keys, layers))
for k, v in layers.items():
splits = torch.chunk(v, config.num_experts, dim=self.dim)
tmp.update({k.replace("*", str(expert_idx)): split for expert_idx, split in enumerate(splits)})
result.append(tmp)
return result
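
A sketch of the save-time path, assuming experts were merged along dim 0 at load time; `config` only needs a `num_experts` attribute here, and the key pattern is illustrative:

import torch
from types import SimpleNamespace

config = SimpleNamespace(num_experts=2)
merged = torch.randn(2 * 8, 16)  # two experts of shape (8, 16) stacked on dim 0

split = MergeModulelist(dim=0).reverse_op  # SplitModulelist(dim=0)
out = split.convert([{"layers.0.experts.*.w1": merged}], config=config)
# -> [{"layers.0.experts.0.w1": <8x16>, "layers.0.experts.1.w1": <8x16>}]
assert list(out[0]) == ["layers.0.experts.0.w1", "layers.0.experts.1.w1"]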


@@ -318,6 +320,11 @@ def _materialize_copy(tensor, dtype=None):
tensor = tensor.to(dtype)
return tensor

def spawn_dematerialize(thread_pool, tensor, dtype=None) -> Future:
def _job():
return tensor.detach()  # optionally .cpu() here to offload to host memory

return thread_pool.submit(_job)
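
A sketch of how these futures are meant to be consumed; the pool size is arbitrary here:

import torch
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)
future = spawn_dematerialize(pool, torch.randn(4, 4))
tensor = future.result()  # detached tensor, resolved on a worker thread
pool.shutdown()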

def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
def _job():
@@ -369,7 +376,7 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
values, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[layer_name] = (
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {values}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
@@ -598,19 +605,116 @@ def convert_and_load_state_dict_in_model(


# TODO this is not done yet!
def revert_weight_conversion(model, state_dict):
def revert_weight_conversion(model, state_dict, weight_mapping):
mapping = getattr(model, "_checkpoint_conversion_mapping", {}) # FIXME: unclear why, but setting this breaks all llava models.
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
original_state_dict = {}
for key, value in state_dict.items():
for pattern, inverse_converter in reverse_key_mapping:
# TODO FIXME you name it
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict
reverse_key_mapping = [(v, k) for k, v in mapping.items()] # TODO: also take this mapping into account

tp_plan = model.tp_plan or {} # {glob_pattern: plan_obj_or_key}
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}

misc = {}
final_state_dict = {}
# Global thread_pool
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)

_patterns = list(itertools.chain.from_iterable([k.target_keys for k in weight_mapping]))
target_to_source = {sk: k for k in weight_mapping for sk in k.target_keys}
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))

state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
# 1. Create the conversion entries
by_conversion_pattern: dict[str, ConversionEntry] = {}
for original_key, tensor in state_dict:
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
if matched_pattern is not None:
converter = target_to_source[matched_pattern] # TODO: make sure this is the same reference
sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key)
entry_key = "|".join(map(sub_with_extractor, converter.source_keys))
target_key = entry_key # at this point we don't know how many we'll collect :)
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
converter_key = sub_with_extractor(matched_pattern)
else:
converter = WeightConverter(original_key)
converter_key = entry_key = target_key = original_key
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))

if False and quantizer is not None and quantizer.param_needs_quantization(model, tensor): # disabled: quantization on save is not wired up yet
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
from .integrations.finegrained_fp8 import Fp8Quantize

converter.quantization_operation = Fp8Quantize() # TODO support other methods
else:
raise ValueError("This quantization method is gonna be supported SOOOON")

future = None
# if device_mesh:
# if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
# if getattr(converter, "distributed_operation", {}) is None:
# tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
# converter.distributed_operation = tp_layer(
# device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
# )
# # VERY IMPORTANT: this tells us whether we collected things or not.
# shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
# future = spawn_tp_dematerialize(
# thread_pool,
# tensor,
# converter.distributed_operation,
# shard_index,
# )

if future is None: # If not TP, asynchronously detach/offload the tensors. TODO handle disk offload?
future = spawn_dematerialize(thread_pool, tensor) # -> should we always move it to CPU?
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)

# 2. Actually convert the ckpt
keys = list(by_conversion_pattern.keys())


with logging.tqdm(total=len(keys), desc="saving weights") as pbar:
for key in keys[::-1]: # reversed so that simple keys are processed first
group = by_conversion_pattern.pop(key)
converter = group.weight_converter
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
pbar.update(1)
pbar.set_postfix({"Materializing param": layer_name})
pbar.refresh()
concrete_target_keys = layer_name.split("|")
try:
with log_to_misc(layer_name, misc):
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]

for op in operations[::-1]:
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
reverse_op = op.reverse_op
values = reverse_op.convert(values, concrete_target_keys, config=model.config)

values = [values] if not isinstance(values, list) else values
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
if len(values) == 1 and isinstance(values[0], dict):
realized_value = values[0]
else:
realized_value = dict(zip(concrete_target_keys, values))

for k in list(realized_value.keys()):
if op := converter.quantization_operation: # dequantize
with log_to_misc(layer_name, misc, op=op):
realized_value.update(
op.convert(
{k: realized_value.pop(k)}, # quant_config=quantizer.quantization_config
)
)

for k, output_value in realized_value.items():
final_state_dict[k] = output_value[0] if isinstance(output_value, list) else output_value

# TODO @Cyrilvallez handle scheduled saving, gathering, etc.
# schedule the saving of the weights using the threadpool (`save_file`)

except SkipLayer:
continue
del group
print(misc)
return final_state_dict
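
An end-to-end sketch of how `save_pretrained` is expected to drive this (see the `modeling_utils.py` hunk below); `model` is assumed to be an already loaded model, and the safetensors call is one possible way to write the result:

from safetensors.torch import save_file

weight_mapping = get_checkpoint_conversion_mapping(model.config.model_type)
ckpt_state_dict = revert_weight_conversion(model, model.state_dict(), weight_mapping)
save_file(ckpt_state_dict, "model.safetensors")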
5 changes: 3 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3049,7 +3049,7 @@ def save_pretrained(
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
save_peft_format: bool = True,
save_original_format: bool = False, # TODO next PR will make it go to True
save_original_format: bool = True,
**kwargs,
):
"""
@@ -3341,7 +3341,8 @@ def save_pretrained(
# MEGA BIG TODO HERE: self._conversion_ops should be used to save the final ckpt
# based on what was loaded. On its own it won't work though, because we need the
# mapping even when the loaded files were not legacy -> in that case no conversion happened
state_dict = revert_weight_conversion(self, state_dict)
weight_mapping = get_checkpoint_conversion_mapping(self.config.model_type)
state_dict = revert_weight_conversion(self, state_dict, weight_mapping)

# Shard the model if it is too big.
if not _hf_peft_config_loaded: