
Commit 70874be

llm: register sdpa variant (#3802)
Parent commit: 9241476

File tree: 7 files changed, +428 −101 lines

docsrc/tutorials/compile_hf_models.rst

Lines changed: 6 additions & 2 deletions
@@ -59,6 +59,10 @@ We have officially verified support for the following LLM families:
       | Qwen/Qwen2.5-7B-Instruct
     - FP16, FP32
     - Yes
+   * - Gemma 3
+     - | google/gemma-3-1b-it
+     - FP16, FP32
+     - Yes

 Getting Started with run_llm.py
 -------------------------------
@@ -185,8 +189,8 @@ The number of key/value cache tensors is equal to the number of attention heads

 Generating Outputs
 -------------------
-We use custom `generate <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L112>`_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
-There is also a `generate_with_static_cache <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L141>`_ function that performs autoregressive decoding with KV caching.
+We use custom `generate <https://github.com/pytorch/TensorRT/blob/9241476a868af46169348ab730d18907365a66ee/tools/llm/utils.py#L112>`_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
+There is also a `generate_with_static_cache <https://github.com/pytorch/TensorRT/blob/9241476a868af46169348ab730d18907365a66ee/tools/llm/utils.py#L141>`_ function that performs autoregressive decoding with KV caching.

 The ``generate_with_static_cache`` function takes care of preparing the inputs to the model compiled with static KV cache.
 The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ...., ``start_idx``, ``end_idx``.
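
For readers of this tutorial section, here is a minimal sketch of what standard autoregressive decoding without KV caching looks like. This is an illustrative greedy-decoding loop, not the actual `generate` implementation in `tools/llm/utils.py`; the helper name `greedy_generate` and the exact argument set are assumptions.

```python
import torch


def greedy_generate(model, input_ids, max_new_tokens, eos_token_id):
    """Illustrative greedy autoregressive decoding without KV caching.

    Each step re-feeds the full sequence, so attention over all
    previous tokens is recomputed on every forward pass.
    """
    seq = input_ids
    for _ in range(max_new_tokens):
        position_ids = torch.arange(seq.shape[1], device=seq.device).unsqueeze(0)
        logits = model(seq, position_ids=position_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        seq = torch.cat([seq, next_token], dim=-1)
        if next_token.item() == eos_token_id:
            break
    return seq
```

A KV-cached variant such as ``generate_with_static_cache`` avoids this recomputation by threading the cache tensors and ``start_idx``/``end_idx`` markers through each step, as described above.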

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 23 additions & 3 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union

 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -53,20 +53,28 @@
 def _aten_lowering_pass(
     *args: LoweringPassSignature,
     index: Optional[int] = None,
+    **kwargs: Any,
 ) -> Union[
     LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
 ]:
     """Adds a lowering pass to the registry, at a specified index if desired

     If no index is specified, the lowering pass is inserted at the end of the list
+
+    Additional keyword arguments can be passed to configure the lowering pass behavior.
+    These will be stored as metadata on the pass function.
     """

     def add_lowering_pass(
         lowering_pass: LoweringPassSignature,
     ) -> LoweringPassSignature:
+        # Store additional parameters as metadata on the function
+        if kwargs:
+            lowering_pass._lowering_pass_config = kwargs
+
         ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
         logger.debug(
-            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
+            f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}"
         )
         return lowering_pass

@@ -81,7 +89,7 @@ def add_lowering_pass(
             f"aten_lowering_pass decorator called with invalid arguments {args} "
             "To specify an index to insert the pass, use the keyword 'index='"
         )
-    # If no arguments are specified, the decorator was called with an index keyword
+    # If no arguments are specified, the decorator was called with keyword arguments
     else:
         return add_lowering_pass

@@ -95,6 +103,18 @@ def _remove_lowering_pass(*, index: int) -> None:
     return


+def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]:
+    """Get the configuration parameters for a lowering pass function
+
+    Args:
+        lowering_pass: The lowering pass function
+
+    Returns:
+        Dictionary containing the configuration parameters, or empty dict if none
+    """
+    return getattr(lowering_pass, "_lowering_pass_config", {})
+
+
 def post_lowering(
     gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
 ) -> torch.fx.GraphModule:
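
As a hedged sketch of how the extended decorator and ``get_lowering_pass_config`` could be used together: the pass name ``gemma_specific_pass`` and the ``model_name`` keyword below are hypothetical examples, and the ``(gm, settings)`` pass signature is assumed from ``LoweringPassSignature``.

```python
import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
    get_lowering_pass_config,
)


# Hypothetical pass: the extra keyword argument is stored as metadata on the function
@_aten_lowering_pass(index=0, model_name="google/gemma-3-1b-it")
def gemma_specific_pass(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    # Retrieve the configuration the decorator attached via _lowering_pass_config
    config = get_lowering_pass_config(gemma_specific_pass)
    if config.get("model_name") == "google/gemma-3-1b-it":
        pass  # model-specific graph rewriting would go here
    return gm
```

Because the configuration travels with the function object itself (as ``_lowering_pass_config``), the pass registry and ``post_lowering`` flow do not need to change to support per-pass options.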
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import os
+import sys
+
+import pytest
+import torch
+import torch_tensorrt
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../tools/llm"))
+import argparse
+
+from run_llm import compile_torchtrt
+from torchtrt_ext import register_sdpa
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"])
+def test_gemma3_decoder_layer(precision):
+
+    with torch.inference_mode():
+        args = argparse.Namespace()
+        args.debug = False
+        args.num_tokens = 128
+        args.model = "google/gemma-3-1b-it"
+        args.precision = precision
+        args.min_block_size = 1
+        args.prompt = "What is parallel programming ?"
+        if args.precision == "FP16":
+            dtype = torch.float16
+        elif args.precision == "BF16":
+            dtype = torch.bfloat16
+        else:
+            args.precision = "FP32"
+            dtype = torch.float32
+
+        model = (
+            AutoModelForCausalLM.from_pretrained(
+                args.model,
+                use_cache=False,
+                attn_implementation="sdpa",
+                num_hidden_layers=1,
+            )
+            .eval()
+            .to("cuda")
+        )
+
+        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+        model = model.to(dtype)
+        # use randint will generate nan values in the logits, use a fixed input_ids for now
+        # input_ids = torch.randint(0, model.config.vocab_size, (1, args.num_tokens)).to("cuda")
+        input_ids = torch.tensor([[2, 3689, 563, 10616, 14929, 2360]]).to("cuda")
+
+        position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to("cuda")
+        pyt_outputs = model(input_ids.clone(), position_ids=position_ids.clone())
+        trt_model = compile_torchtrt(model, input_ids, args)
+        trt_outputs = trt_model(input_ids, position_ids=position_ids)
+
+        torch.testing.assert_close(
+            pyt_outputs.logits, trt_outputs.logits, rtol=5e-1, atol=5e-1
+        )

tools/llm/README.md

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ We have officially verified support for the following models:
 | LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct<br>meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes |
 | Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct<br>Qwen/Qwen2.5-1.5B-Instruct<br>Qwen/Qwen2.5-4B-Instruct<br>Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes |
 | Qwen 3 | Qwen/Qwen3-0.6B<br>Qwen/Qwen3-1.7B<br>Qwen/Qwen3-4B<br>Qwen/Qwen3-8B | FP16, FP32 | Yes |
+| Gemma 3 | google/gemma-3-1b-it | FP16, FP32 | Yes |


 ### Usage

tools/llm/run_llm.py

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,11 @@ def get_model(args):
         .eval()
         .cuda()
     )
+    # register SDPA variant for the model
+    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
+        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+    else:
+        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)

     if args.precision == "FP16":
         model = model.to(torch.float16)
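
A hedged sketch of what this mapping-with-fallback pattern enables outside ``run_llm.py``: only the shape of ``_SDPA_MAPPING`` (a model id or ``"default"`` key mapping to a callable that accepts ``model_config=``) is taken from the diff; the model id below and the ``sys.path`` setup (assuming the repository root as the working directory) are illustrative.

```python
import os
import sys

from transformers import AutoConfig

# The SDPA registration helpers live under tools/llm in this repo
sys.path.append(os.path.join("tools", "llm"))
from torchtrt_ext import register_sdpa  # noqa: E402

# Illustrative model id; any supported Hugging Face causal-LM id could appear here
model_id = "google/gemma-3-1b-it"
config = AutoConfig.from_pretrained(model_id)

# Resolve the model-specific SDPA variant, falling back to the generic one,
# mirroring the lookup added to get_model() above
register_fn = register_sdpa._SDPA_MAPPING.get(
    model_id, register_sdpa._SDPA_MAPPING["default"]
)
register_fn(model_config=config)
```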
