diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 2147a45d7503..b621d7a8a7a0 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -148,33 +148,34 @@ def convert(
 class Chunk(ConversionOps):
     """Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""

-    reverse_op: type[ConversionOps]
-
     def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
-        if chunks is None and sizes is None:
-            raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
-        if chunks is not None and chunks <= 0:
-            raise ValueError("`chunks` must be a strictly positive integer.")
         self.dim = dim
         self.chunks = chunks
-        self.sizes = list(sizes) if sizes is not None else None
-        self.reverse_op = Concatenate

-    def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
-        # chunk requires a single tensor input
-        if len(value) != 1 or len(value[0]) != 1:
-            raise ValueError("Chunk operation requires a single tensor input.")
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return Concatenate(self.dim)
+
+    def convert(self, value: torch.Tensor, concrete_target_keys=None, *args, **kwargs) -> list[torch.Tensor]:
+        # chunk takes a single tensor when loading; when saving we may receive several
+        update_ = []
+        if concrete_target_keys is not None:  # when saving we have multiple tensors
+            for layer in value:
+                for tensors in layer:
+                    chunk_size = len(concrete_target_keys)
+                    update_ += [dict(zip(concrete_target_keys, torch.chunk(tensors, chunks=chunk_size, dim=self.dim)))]
+            return update_
         return list(torch.chunk(value[0][0], self.chunks, dim=self.dim))


 class Concatenate(ConversionOps):
     """Concatenate tensors along `dim` using a reusable buffer."""
-
-    reverse_op: type[ConversionOps]
-
     def __init__(self, dim: int = 0):
         self.dim = dim
-        self.reverse_op = Chunk
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return Chunk(self.dim)

     @torch.no_grad
     def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
@@ -187,7 +188,7 @@ def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tenso
         return torch.cat(tuple(tensors), dim=self.dim)


-class MergeModulelist(Concatenate):
+class MergeModulelist(ConversionOps):
     """
     Merge a list of tensors into a single tensor along the first dimension. We explicitly define this because
     for EP or TP you want to make sure you know what you are doing!
@@ -195,8 +196,11 @@ class MergeModulelist(Concatenate):
     """

     def __init__(self, dim: int = 0):
-        super().__init__(dim=dim)
-        self.reverse_op = SplitModulelist
+        self.dim = dim
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return SplitModulelist(self.dim)

     @torch.no_grad
     def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
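As context for the `reverse_op` hunks above: turning `reverse_op` from a bare class attribute into a property means the reverse operation is instantiated with the same `dim`, so a chunk/concatenate pair round-trips. A minimal standalone sketch with simplified stand-ins for the real `ConversionOps` subclasses (not the patched code itself):

```python
import torch

class Concatenate:
    def __init__(self, dim: int = 0):
        self.dim = dim

    def convert(self, tensors):
        return torch.cat(tuple(tensors), dim=self.dim)

class Chunk:
    def __init__(self, dim: int = 0, chunks: int = 2):
        self.dim = dim
        self.chunks = chunks

    @property
    def reverse_op(self):
        # an *instance* carrying the same dim, not a bare class reference
        return Concatenate(self.dim)

    def convert(self, tensor):
        return list(torch.chunk(tensor, self.chunks, dim=self.dim))

op = Chunk(dim=1, chunks=2)
x = torch.arange(8).reshape(2, 4)
assert torch.equal(op.reverse_op.convert(op.convert(x)), x)  # round-trips along dim=1
```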
@@ -212,26 +216,24 @@ def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.
 class SplitModulelist(ConversionOps):
     """Inverse of :class:`MergeModulelist` using explicit split sizes per group."""

-    def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
-        if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
-            raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
-        self.sizes = [list(sub) for sub in sizes]
-        self.dim = dim
-        self.reverse_op = MergeModulelist
+    def __init__(self, dim: int = 0):
+        self.dim = dim
+
+    @property
+    def reverse_op(self) -> ConversionOps:
+        return MergeModulelist(self.dim)

     @torch.no_grad
-    def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
-        if not isinstance(value, Sequence):
-            raise TypeError("SplitModulelist expects a sequence of tensors.")
-        if len(value) != len(self.sizes):
-            raise ValueError("Number of tensors does not match the provided split specifications.")
-
-        result: list[list[torch.Tensor]] = []
-        for tensor, split_sizes in zip(value, self.sizes):
-            if not isinstance(tensor, torch.Tensor):
-                raise TypeError("SplitModulelist can only split torch.Tensor instances.")
-            splits = torch.split(tensor, split_sizes, dim=self.dim)
-            result.append(list(splits))
+    def convert(self, value: Sequence[torch.Tensor], concrete_target_keys=None, config=None, *args, **kwargs) -> list[list[torch.Tensor]]:
+        result = []
+        for i, layers in enumerate(value):
+            tmp = {}
+            if not isinstance(layers, dict):
+                layers = {concrete_target_keys[j]: layers[j] for j in range(len(layers))}
+            for k, v in layers.items():
+                splits = torch.chunk(v, config.num_experts, dim=self.dim)
+                tmp.update({k.replace("*", str(j)): split for j, split in enumerate(splits)})
+            result.append(tmp)
         return result

@@ -318,6 +320,11 @@ def _materialize_copy(tensor, dtype=None):
         tensor = tensor.to(dtype)
     return tensor

+def spawn_dematerialize(thread_pool, tensor, dtype=None) -> Future:
+    def _job():
+        return tensor.detach()  # .cpu()
+
+    return thread_pool.submit(_job)

 def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
     def _job():
@@ -369,7 +376,7 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) ->
         values, target_keys = extras
         descriptor = f"{op_name} " if op_name else ""
         misc[layer_name] = (
-            f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
+            f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {values}"
         )
     elif isinstance(extras, str):
         suffix = f" via {op_name}" if op_name else ""
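For the rewritten `SplitModulelist.convert`: it undoes `MergeModulelist` by chunking each fused tensor back into `config.num_experts` pieces and substituting the expert index for the `*` in the target key. A hedged sketch of just that chunk-and-rename step; the key pattern, the `num_experts` value, and the `SimpleNamespace` config stub are illustrative assumptions:

```python
import torch
from types import SimpleNamespace

config = SimpleNamespace(num_experts=4)   # stand-in for model.config
key = "experts.*.gate_proj.weight"        # hypothetical glob-style target key
fused = torch.randn(4 * 8, 16)            # dim 0 stacks the 4 experts

splits = torch.chunk(fused, config.num_experts, dim=0)
per_expert = {key.replace("*", str(j)): split for j, split in enumerate(splits)}
assert per_expert["experts.3.gate_proj.weight"].shape == (8, 16)
```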
@@ -598,19 +605,116 @@ def convert_and_load_state_dict_in_model(


 # TODO this is not done yet!
-def revert_weight_conversion(model, state_dict):
+def revert_weight_conversion(model, state_dict, weight_mapping):
     mapping = getattr(model, "_checkpoint_conversion_mapping", {})  # IDK why but setting this will fail all llava.
-    reverse_key_mapping = [(v, k) for k, v in mapping.items()]
-    original_state_dict = {}
-    for key, value in state_dict.items():
-        for pattern, inverse_converter in reverse_key_mapping:
-            # TODO FIXME you name it
-            replacement = inverse_converter.lstrip("^")  # strip off un-needed chars and patterns
-            replacement = re.sub(r"\(.*\)", "", replacement)
-            key, n_replace = re.subn(pattern, replacement, key)
-            # Early exit of the loop
-            if n_replace > 0:
-                break
-        original_state_dict[key] = value
-    state_dict = original_state_dict
-    return state_dict
+    reverse_key_mapping = [(v, k) for k, v in mapping.items()]  # TODO: also take this mapping into account
+
+    tp_plan = model.tp_plan or {}  # {glob_pattern: plan_obj_or_key}
+    weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
+
+    misc = {}
+    final_state_dict = {}
+    # Global thread pool
+    thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
+
+    _patterns = list(itertools.chain.from_iterable([k.target_keys for k in weight_mapping]))
+    target_to_source = {sk: k for k in weight_mapping for sk in k.target_keys}
+    weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
+    tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
+
+    state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
+    # 1. Create the conversion entries
+    by_conversion_pattern: dict[str, ConversionEntry] = {}
+    for original_key, tensor in state_dict:
+        matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
+        if matched_pattern is not None:
+            converter = target_to_source[matched_pattern]  # TODO: make sure it's the ref
+            sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key)
+            entry_key = "|".join(map(sub_with_extractor, converter.source_keys))
+            target_key = entry_key  # at this point we don't know how many we'll collect :)
+            entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
+            converter_key = sub_with_extractor(matched_pattern)
+        else:
+            converter = WeightConverter(original_key)
+            converter_key = entry_key = target_key = original_key
+            entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
+
+        if False and quantizer is not None and quantizer.param_needs_quantization(model, tensor):
+            if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
+                from .integrations.finegrained_fp8 import Fp8Quantize
+
+                converter.quantization_operation = Fp8Quantize()  # TODO: support other methods
+            else:
+                raise ValueError("This quantization method is not supported yet.")
+
+        future = None
+        # if device_mesh:
+        #     if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
+        #         if getattr(converter, "distributed_operation", {}) is None:
+        #             tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
+        #             converter.distributed_operation = tp_layer(
+        #                 device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
+        #             )
+        #         # VERY IMPORTANT: this tells us whether we already collected shards or not.
+        #         shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
+        #         future = spawn_tp_dematerialize(
+        #             thread_pool,
+        #             tensor,
+        #             converter.distributed_operation,
+        #             shard_index,
+        #         )

+        if future is None:  # if not TP, async-materialize the tensors. TODO: handle disk offload?
+            future = spawn_dematerialize(thread_pool, tensor)  # -> should we always move it to CPU?
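+        # `future` resolves to the dematerialized tensor; queue it on its
+        # ConversionEntry so that pass 2 below can gather all futures for a
+        # target key and replay the conversion operations in reverse.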
+        entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
+
+    # 2. Actually convert the ckpt
+    keys = list(by_conversion_pattern.keys())
+
+    with logging.tqdm(total=len(keys), desc="saving weights") as pbar:
+        for key in keys[::-1]:  # reversed, so that simple keys are processed first
+            group = by_conversion_pattern.pop(key)
+            converter = group.weight_converter
+            operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
+            for layer_name, tensors_for_this_layer in group.collected_tensors.items():
+                pbar.update(1)
+                pbar.set_postfix({"Materializing param": layer_name})
+                pbar.refresh()
+                concrete_target_keys = layer_name.split("|")
+                try:
+                    with log_to_misc(layer_name, misc):
+                        values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
+
+                    for op in operations[::-1]:
+                        with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                            reverse_op = op.reverse_op
+                            values = reverse_op.convert(values, concrete_target_keys, config=model.config)
+
+                    values = [values] if not isinstance(values, list) else values
+                    with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
+                        if len(values) == 1 and isinstance(values[0], dict):
+                            realized_value = values[0]
+                        else:
+                            realized_value = dict(zip(concrete_target_keys, values))
+
+                        for k in list(realized_value.keys()):
+                            if op := converter.quantization_operation:  # dequantize
+                                with log_to_misc(layer_name, misc, op=op):
+                                    realized_value.update(
+                                        op.convert(
+                                            {k: realized_value.pop(k)},  # quant_config=quantizer.quantization_config
+                                        )
+                                    )
+
+                    for k, output_value in realized_value.items():
+                        final_state_dict[k] = output_value[0] if isinstance(output_value, list) else output_value
+
+                    # TODO @Cyrilvallez: handle scheduled saving, gathering, etc.
+                    # Schedule the saving of the weights using the thread pool (`save_file`).
+
+                except SkipLayer:
+                    continue
+            del group
+    print(misc)
+    return final_state_dict
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 965084d0c24a..16c9d72fe3b4 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3049,7 +3049,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         token: Optional[Union[str, bool]] = None,
         save_peft_format: bool = True,
-        save_original_format: bool = False,  # TODO next PR will make it go to True
+        save_original_format: bool = True,
         **kwargs,
     ):
         """
@@ -3341,7 +3341,8 @@ def save_pretrained(
             # MEGA BIG TODO HERE: self._conversion_ops needs to be used to save the final ckpt
             # using what was loaded. Actually self._conversion_ops wont work because we need it
             # even if the files are not legacy -> thus no conversion happened
-            state_dict = revert_weight_conversion(self, state_dict)
+            weight_mapping = get_checkpoint_conversion_mapping(self.config.model_type)
+            state_dict = revert_weight_conversion(self, state_dict, weight_mapping)
             # Shard the model if it is too big.
             if not _hf_peft_config_loaded:
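End to end, a hedged sketch of the save path this diff wires up; the checkpoint id and output directory are placeholders, and the commented internals simply restate the two `modeling_utils.py` hunks above:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("org/model-id")  # placeholder checkpoint

# With save_original_format=True (now the default), save_pretrained resolves the
# model type's weight-conversion mapping and replays it in reverse, so the files
# on disk keep the original (pre-conversion) key layout:
#   weight_mapping = get_checkpoint_conversion_mapping(self.config.model_type)
#   state_dict = revert_weight_conversion(self, state_dict, weight_mapping)
model.save_pretrained("./saved", save_original_format=True)
```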