Commit d0cd469

change llm model test from gemma3 to qwen to skip auth (#3807)
1 parent d76574f · commit d0cd469

File tree

2 files changed: +10 −6 lines


tests/py/dynamo/models/test_llm_models.py

Lines changed: 6 additions & 3 deletions

@@ -15,13 +15,13 @@
 
 @pytest.mark.unit
 @pytest.mark.parametrize("precision", ["FP16", "BF16", "FP32"])
-def test_gemma3_decoder_layer(precision):
+def test_llm_decoder_layer(precision):
 
     with torch.inference_mode():
         args = argparse.Namespace()
         args.debug = False
         args.num_tokens = 128
-        args.model = "google/gemma-3-1b-it"
+        args.model = "Qwen/Qwen2.5-0.5B-Instruct"
         args.precision = precision
         args.min_block_size = 1
         args.prompt = "What is parallel programming ?"
@@ -44,7 +44,10 @@ def test_gemma3_decoder_layer(precision):
             .to("cuda")
         )
 
-        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+        if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
+            register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+        else:
+            register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
         model = model.to(dtype)
         # use randint will generate nan values in the logits, use a fixed input_ids for now
         # input_ids = torch.randint(0, model.config.vocab_size, (1, args.num_tokens)).to("cuda")
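
The new lookup in the test follows a registry-with-default pattern: _SDPA_MAPPING keys model names to registration callables, and "default" covers any model without a dedicated entry. A minimal self-contained sketch of that pattern, with a hypothetical registry (the real one lives in tools/llm/torchtrt_ext/register_sdpa.py and its contents are not shown in this diff):

from typing import Callable, Dict

# Hypothetical stand-in for the real registry in
# tools/llm/torchtrt_ext/register_sdpa.py.
_SDPA_MAPPING: Dict[str, Callable] = {
    "google/gemma-3-1b-it": lambda model_config: print("gemma3-specific SDPA registration"),
    "default": lambda model_config: print("generic SDPA registration"),
}

def register_model_sdpa(model_name: str, model_config: object) -> None:
    # Prefer a model-specific entry; fall back to "default", mirroring
    # the if/else added in the diff above.
    handler = _SDPA_MAPPING.get(model_name) or _SDPA_MAPPING["default"]
    handler(model_config=model_config)

# Qwen has no dedicated entry, so the "default" handler fires, which is
# why the test can switch models without touching the registry.
register_model_sdpa("Qwen/Qwen2.5-0.5B-Instruct", model_config=None)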

tools/llm/torchtrt_ext/sdpa_converter.py

Lines changed: 4 additions & 3 deletions

@@ -257,9 +257,10 @@ def scaled_dot_product_attention(
         attn_bias = impl.unary.log(
             ctx, target, source_ir, name + "_log", one_minus_temp_mask
         )
-        scaled_add_attn_bias = impl.elementwise.add(
-            ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
-        )
+
+    scaled_add_attn_bias = impl.elementwise.add(
+        ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
+    )
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
     )
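
For reference, the converter calls above are the TensorRT-side spelling of standard attention math: mm holds the scaled Q·Kᵀ product, the elementwise add applies attn_bias, and softmax normalizes over the last dimension before the value matmul. A minimal eager-PyTorch sketch of that sequence, with names matching the converter's intermediates; the shapes and causal mask are purely illustrative assumptions:

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=1, heads=2, seq_len=4, head_dim=8.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Additive causal bias: 0 where attending is allowed, -inf where masked.
causal = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
attn_bias = torch.zeros(4, 4).masked_fill(causal, float("-inf"))

mm = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)  # scaled Q @ K^T
scaled_add_attn_bias = mm + attn_bias                 # the elementwise add above
softmax = F.softmax(scaled_add_attn_bias, dim=-1)     # softmax over last dim
out = softmax @ v

# Cross-check against PyTorch's fused implementation.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out, ref, atol=1e-5)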
