Conversation

@SunMarc (Member) commented Nov 5, 2025

What does this PR do?

This PR fixes bnb support (8-bit + 4-bit) in the new weight-loading logic.

Testing

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantize a full-precision checkpoint on the fly
model_name = "meta-llama/Llama-3.2-3B-Instruct"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")

# Alternatively, load a prequantized checkpoint (don't pass quantization_config):
# model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map=0,
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, do_sample=False, max_new_tokens=1024)
print(tokenizer.decode(outputs[0]))
  • check why memory usage is way too high when quantizing on the fly
  • add bnb tests
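
Since the PR covers both paths, the 8-bit case can be smoke-tested with a minimal variant of the script above. This is a sketch: load_in_8bit=True is the standard BitsAndBytesConfig flag, and the model and prompt are simply reused from the test above.

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Llama-3.2-3B-Instruct"
# 8-bit on-the-fly quantization; no 4-bit-specific options needed
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map=0,
)

inputs = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=64)
print(tokenizer.decode(outputs[0]))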

@SunMarc (Member, Author) commented Nov 17, 2025

@bot /style

github-actions bot (Contributor) commented Nov 17, 2025

Style bot fixed some files and pushed the changes.

@ArthurZucker (Collaborator) left a comment

LGTM. Just tying twice is my nightmare, but good otherwise.

Comment on lines +487 to +489

if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
    converter.quantization_operation = hf_quantizer.get_quantize_ops()
    _dtype = dtype

@ArthurZucker (Collaborator):

nice
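
For context, a minimal sketch of the shape of the two hooks this branch relies on. The method names param_needs_quantization and get_quantize_ops come from the diff above; everything inside the class body here is hypothetical, not the actual bnb quantizer implementation.

import torch.nn as nn

class SketchQuantizer:
    """Illustrative only; mirrors the shape of the two hooks used above."""

    def param_needs_quantization(self, model: nn.Module, param_name: str) -> bool:
        # Hypothetical rule: quantize weights of Linear modules, skip the rest.
        if not param_name.endswith(".weight"):
            return False
        module = model.get_submodule(param_name.rsplit(".", 1)[0])
        return isinstance(module, nn.Linear)

    def get_quantize_ops(self):
        # Hypothetical: return the operation the weight converter applies when
        # materializing the parameter (the real op packs to 4-/8-bit storage).
        def quantize(tensor):
            return tensor  # placeholder
        return quantize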

@SunMarc (Member, Author) commented Nov 18, 2025

run-slow: bnb

github-actions bot (Contributor):

This comment contains run-slow, running the specified jobs:

models: []
quantizations: ["quantization/bnb"]

@SunMarc (Member, Author) commented Nov 18, 2025

> LGTM. Just tying twice is my nightmare, but good otherwise.

I've upstreamed some code from accelerate to fix tied weights. This should make things easier, and we can better tweak device-map-related code in the future.
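
For context, a minimal sketch of what weight tying means here, using a toy module (assumption: this mirrors how e.g. lm_head is tied to the input embeddings in causal LMs). Any loading or device-placement logic has to treat the tied names as one tensor:

import torch.nn as nn

class ToyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie: both names now point at the same Parameter object.
        self.lm_head.weight = self.embed.weight

model = ToyLM()
assert model.lm_head.weight is model.embed.weight
# A device map must keep tied parameters together, e.g. grouped as:
tied_parameters = [["embed.weight", "lm_head.weight"]]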

@SunMarc (Member, Author) commented Nov 18, 2025

@bot /style

github-actions bot (Contributor):

Style fix is beginning... View the workflow run here.

@SunMarc (Member, Author) commented Nov 18, 2025

@bot /style

github-actions bot (Contributor) commented Nov 18, 2025

Style bot fixed some files and pushed the changes.

github-actions bot (Contributor):

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@SunMarc (Member, Author) commented Nov 18, 2025

run-slow: bnb

github-actions bot (Contributor):

This comment contains run-slow, running the specified jobs:

models: []
quantizations: ["quantization/bnb"]

github-actions bot (Contributor):

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

Comment on lines 592 to 594

if tied_parameters is None and len(model.all_tied_weights_keys) > 0:
    # create a list of lists of tied params
    tied_parameters = [list(t) for t in model.all_tied_weights_keys.items()]

@SunMarc (Member, Author):

changed this
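
To illustrate the structure being built in the snippet above, here is a hypothetical all_tied_weights_keys mapping each tied target name to its source (the exact schema in transformers may differ):

# Hypothetical mapping of tied-weight keys (target -> source).
all_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

# Each dict item becomes one group of names that must stay together.
tied_parameters = [list(t) for t in all_tied_weights_keys.items()]
print(tied_parameters)  # [['lm_head.weight', 'model.embed_tokens.weight']]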

Comment on lines +615 to +624

def infer_auto_device_map(
    model: nn.Module,
    max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
    no_split_module_classes: Optional[list[str]] = None,
    verbose: bool = False,
    clean_result: bool = True,
    offload_buffers: bool = False,
    tied_parameters: Optional[list[list[str]]] = None,
    hf_quantizer: "HfQuantizer | None" = None,
):

@SunMarc (Member, Author):

Removed dtype and special_dtypes; compute_module_sizes now relies on hf_quantizer instead.
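
A minimal sketch of the idea, not the actual accelerate/transformers implementation: when sizing modules for the device map, ask the quantizer for the storage footprint of parameters it will quantize, instead of threading explicit dtype overrides through. The quantized_nbytes hook below is hypothetical.

from collections import defaultdict
import torch.nn as nn

def compute_module_sizes_sketch(model: nn.Module, hf_quantizer=None):
    """Byte size per module, letting a quantizer override parameter sizes."""
    sizes = defaultdict(int)
    for name, param in model.named_parameters():
        nbytes = param.numel() * param.element_size()
        if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, name):
            # Hypothetical hook: the quantizer reports post-quantization
            # storage (e.g. ~0.5 byte per element for nf4).
            nbytes = hf_quantizer.quantized_nbytes(param)
        # Attribute the parameter's size to every parent module prefix
        # ("" is the whole model).
        parts = name.split(".")
        for i in range(len(parts)):
            sizes[".".join(parts[:i])] += nbytes
    return dict(sizes)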

github-actions bot (Contributor):

[For maintainers] Suggested jobs to run (before merge)

run-slow: bnb, finegrained_fp8

@ArthurZucker merged commit 67302b0 into main on Nov 18, 2025 (22 of 24 checks passed).
@ArthurZucker deleted the fix-bnb branch on Nov 18, 2025 at 17:28.