
Commit 67302b0

Authored by SunMarc, ArthurZucker, Cyrilvallez, matthewdouglas, and github-actions[bot]
Fix bnb for the weights refactor (#42043)
* small fix
* nits
* ish
* up
* rev
* fix more tie weights keys
* small fixes
* nit
* update
* fix and fix
* fix a test
* glubs
* current shitty changes
* ship validated ones
* more
* more update
* more
* more
* more
* mllama
* more up
* fix ernie
* fix xopies
* up more
* more fixes
* up
* up
* fix-copies
* fix more
* more updates
* AI UPDATE
* up
* hoey
* make it fast
* fix
* lol
* fix asjusting
* more fixes
* _dtype nit
* up
* nit
* update
* update
* remove semaphores
* fix import to avoid jit execution
* try to remove custom tiing logic when its stupid
* fix more individual models
* fix whisper as well
* fix?
* fox umt5
* improve tqdm bar
* cleanup a bit
* oupsi
* some updates
* improve
* remove all buffering -> much faster without it
* remove some tie_weights custome funcs when not needed
* more fixes related to strict matching regex
* remove ALL custom tie weights
* small update
* revert change to init scheme (no need for params)
* fix
* mixtral init
* try less strict source check
* tied weight first shot to the fiiiixxxxxx
* does this help?
* :)
* fix some ppolry defined tied_weights_keys for now
* fixes for more models torch_bc
* nits and fixes
* last update
* Revert "tied weight first shot to the fiiiixxxxxx"
  This reverts commit 3fea865.
* here we go again
* an attempt
* up?
* nits
* Fix bnb loading !
* rm print
* subclass nn.Parameters
* up
* lol
* Ouiiii
* fix led
* fix long cat flash
* fix qwen and long cat flash
* properly fix qwen init
* just push this for now
* propnet is dumb
* update
* rm import
* update
* push
* Update src/transformers/core_model_loading.py
  Co-authored-by: Matthew Douglas <[email protected]>
* remove explict sharing of some tied keys.
* update decoder.bias
* moe case
* Fix loadedparam
* rm report
* more changes to untangle old hardcoded ting
* fixup
* fix big faileurs
* Fix tests single gpu
* should fix it
* fix prophnet
* fix resize token embeddings
* nits
* fix xcodex
* asyncio?
* fix smart apply
* fix data-2-vec
* [build-ci-image]
* checkout
* uupdate
* fix hunyuan
* update error message
* fix deformable detr
* fixes
* fix init weights for non param gate up projs
* shared todo?
* guard needed for compressed-tensors
* deal with buffers
* update some models
* big revert, don't break this behaviour
* ty @SunMarc this fixes the buffers
  Co-authored-by: SunMarc <[email protected]>
* mt5 fuck
* fix lxmbert
* nuke slow test fetcher
* fix
* fix zamba and deepcopy for now
* fix zamba tied weight keys! ~
* fix-copies
* update fetch terst
* fix gradient for test modeling common!
* break "shared" for now I will fix tomorrow changes are properly isoalted now :)
* does this fix marian? probably not
* fix some vlms
* D fine seems to handle this well
* glob is fine actually
* fix dab detr
* small steps
* opusy
* fix some more models?
* yups
* better erro
* fix?
* fix double escape
* escape wehere it makes sense
* ??
* fix ibert
* fix tvp as well
* more fxes
* try always download ref PR
* ONONONO
* big fixup
* more fixup
* small step
* small nits
* nits
* brut force some stuff
* fix vilt
* make sure special models that always need tie always tie
* cleaning up
* small nits
* fix zamba and bridge tower!
* just fixup
* potential culprits
* revert bark and fix bridgetower
* remove now non existant tie_weights
* ?
* lol reformer actually had nothing tied!
* wow these two fucking models were really not well made
* fix sam family!
* fix bark revision
* fix speech2test ?
* push this for now....
* upsy
* the fuck
* fix rtdetr
* update
* proper
* wow that one 's annoying
* update
* try to find the culprit
* get some help on common
* nit about general init and cls.padding_idx
* revert num workers update
* remove old loading func
* fix glob
* add annotations
* fix re
* small improvements
* clean some stuff
* improvements
* someone did not understannnnnnd what I tried to dooo or does BNB not support that either?
* gluos
* fix case when `.` is just not there
* for now let's do this
* fix
* fix small test
* style
* fix merge conflits
* style
* 8bit fixed ?
* fix
* fix 8bit dtype
* fix
* rm copy
* Apply suggestions from code review
  Co-authored-by: Arthur <[email protected]>
* style
* test
* fix
* finally ?
* Apply style fixes
* fix
* fix
* Apply style fixes
* tie weights
* warning
* Apply style fixes
* init
* default

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: SunMarc <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 9f31104 commit 67302b0

File tree

13 files changed (+627, -185 lines)

src/transformers/core_model_loading.py

Lines changed: 17 additions & 36 deletions
@@ -46,22 +46,6 @@

 logger = logging.get_logger(__name__)

-str_to_torch_dtype = {
-    "BOOL": torch.bool,
-    "U8": torch.uint8,
-    "I8": torch.int8,
-    "I16": torch.int16,
-    "F16": torch.float16,
-    "BF16": torch.bfloat16,
-    "I32": torch.int32,
-    "F32": torch.float32,
-    "F64": torch.float64,
-    "I64": torch.int64,
-    "F8_E4M3": torch.float8_e4m3fn,
-    "F8_E5M2": torch.float8_e5m2,
-}
-
-
 logger = logging.get_logger(__name__)


@@ -389,11 +373,15 @@ def set_param_for_module(
     missing_keys: MutableSet[str],
     misc: MutableMapping[str, Any],
     distributed_operation: Optional[TensorParallelLayer],
+    hf_quantizer: HfQuantizer,
 ):
     with log_to_misc(layer_name, misc, layer_name):
         module_path, _, param_name = layer_name.rpartition(".")
         module_obj = model.get_submodule(module_path) if module_path else model
-        param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
+        if isinstance(param_value, list):
+            param_value = param_value[0]
+        elif not isinstance(param_value, torch.nn.Parameter):
+            param_value = param_value[...]
         ref = getattr(module_obj, param_name)

         use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
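The extra `torch.nn.Parameter` check above matters because indexing a Parameter with `[...]` yields a plain Tensor, which would strip the subclass (and any quantization metadata a library such as bitsandbytes attaches to its Parameter subclasses). A minimal sketch, not part of the diff:

import torch

# Illustration only: `[...]` on a Parameter returns a plain Tensor, so any
# Parameter subclass (and its attached metadata) would be lost if the new
# isinstance check did not pass such values through untouched.
p = torch.nn.Parameter(torch.ones(2, 2), requires_grad=False)
print(type(p))       # <class 'torch.nn.parameter.Parameter'>
print(type(p[...]))  # <class 'torch.Tensor'>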
@@ -415,7 +403,7 @@ def set_param_for_module(

         # Remove from missing keys (it's either mismatched, or all good)
         missing_keys.discard(layer_name)
-        if ref is not None and ref.shape != param_value.shape:
+        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
             mismatch_keys.add((layer_name, param_value.shape, ref.shape))
             module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
         else:
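The added `hf_quantizer is None` condition keeps quantized loads from being reported as shape mismatches: quantized checkpoints typically store a packed payload whose shape no longer matches the module's placeholder parameter. A rough sketch of the situation, assuming a 4-bit scheme that packs two values per byte (the exact packed layout depends on the quantizer):

import torch

# Hypothetical shapes: a (4, 8) float weight vs. its packed 4-bit payload.
out_features, in_features = 4, 8
ref = torch.empty(out_features, in_features, device="meta")                   # placeholder on the model
packed = torch.empty(out_features * in_features // 2, 1, dtype=torch.uint8)   # packed 4-bit storage

print(ref.shape, packed.shape)    # torch.Size([4, 8]) torch.Size([16, 1])
print(ref.shape != packed.shape)  # True -> would be flagged as a mismatch without the new guard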
@@ -434,7 +422,7 @@ def convert_and_load_state_dict_in_model(
     state_dict: dict[str, Any],
     weight_mapping: dict[str, WeightConverter] | None,
     tp_plan: dict[str, str] | None,
-    quantizer: HfQuantizer | None,
+    hf_quantizer: HfQuantizer | None,
     dtype: torch.dtype | None = None,
     device_map: dict | None = None,
     dtype_plan: dict | None = None,
@@ -499,20 +487,14 @@ def convert_and_load_state_dict_in_model(
                 unexpected_keys.add(t)
                 continue

-            if quantizer is not None and quantizer.param_needs_quantization(model, t):
-                if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
-                    from .integrations.finegrained_fp8 import Fp8Quantize
-
-                    converter.quantization_operation = Fp8Quantize()  # TODO support other methods
-                else:
-                    raise ValueError("This quantization method is gonna be supported SOOOON")
-            else:
-                _dtype = dtype
-                matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
-                if matched_dtype_pattern is not None:
-                    _dtype = dtype_plan[matched_dtype_pattern]
-                elif empty_param.dtype != _dtype:
-                    _dtype = empty_param.dtype
+            if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
+                converter.quantization_operation = hf_quantizer.get_quantize_ops()
+            _dtype = dtype
+            matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
+            if matched_dtype_pattern is not None:
+                _dtype = dtype_plan[matched_dtype_pattern]
+            elif empty_param.dtype != _dtype:
+                _dtype = empty_param.dtype

             first_target_key = new_target_key[0]
             target_key = "|".join(new_target_key)
@@ -575,9 +557,7 @@ def convert_and_load_state_dict_in_model(
             if op := converter.quantization_operation:
                 with log_to_misc(layer_name, misc, op=op):
                     realized_value.update(
-                        op.convert(
-                            {k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
-                        )
+                        op.convert({k: realized_value.pop(k)}, model=model, missing_keys=missing_keys)
                     )

             for k, output_value in realized_value.items():
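The conversion op is now called with the model and the running `missing_keys` set instead of a `quant_config`, and whatever mapping it returns is merged back into `realized_value`. A hypothetical op conforming to that call shape:

import torch

# Hypothetical no-op conversion matching the new call site:
# convert({target_key: tensor}, model=..., missing_keys=...) -> dict of tensors to set.
class SketchQuantizeOp:
    def convert(self, tensors: dict[str, torch.Tensor], *, model=None, missing_keys=None) -> dict[str, torch.Tensor]:
        # A real op (e.g. Bnb4bitQuantize) would quantize here and may register
        # extra tensors (scales, quant state) or discard entries from missing_keys.
        return tensors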
@@ -591,6 +571,7 @@ def convert_and_load_state_dict_in_model(
                     missing_keys,
                     misc,
                     converter.distributed_operation,
+                    hf_quantizer,
                 )

         except SkipLayer:

src/transformers/integrations/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@
         "unpack_weights",
     ],
     "bitsandbytes": [
+        "Bnb4bitQuantize",
         "dequantize_and_replace",
-        "get_keys_to_not_convert",
         "replace_with_bnb_linear",
         "validate_bnb_backend_availability",
     ],
@@ -177,8 +177,8 @@
         unpack_weights,
     )
     from .bitsandbytes import (
+        Bnb4bitQuantize,
         dequantize_and_replace,
-        get_keys_to_not_convert,
         replace_with_bnb_linear,
         validate_bnb_backend_availability,
     )
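With these export changes, `Bnb4bitQuantize` becomes importable from the integrations package while `get_keys_to_not_convert` is no longer re-exported from it, so downstream imports would change along these lines (a sketch, assuming the post-commit package layout):

# Post-commit import sketch
from transformers.integrations import Bnb4bitQuantize, replace_with_bnb_linear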
