
Commit f4b6184

Fixed SDPA perf gap
1 parent e9b099a commit f4b6184

File tree

3 files changed (+13, -9 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
@@ -531,7 +531,7 @@ def aten_ops_gelu(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
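Note on the converter change above: torch.ops.aten.matmul is an OpOverloadPacket, while Dynamo-produced FX graphs use concrete OpOverload targets such as torch.ops.aten.matmul.default, so registering on .default keys the converter to the node target that actually appears in the graph. A minimal sketch of that distinction in plain PyTorch (nothing here is Torch-TensorRT-specific):

import torch

# torch.ops.aten.matmul is an OpOverloadPacket (a bundle of overloads);
# torch.ops.aten.matmul.default is the single OpOverload that FX/Dynamo
# graph nodes use as their call target.
packet = torch.ops.aten.matmul
overload = torch.ops.aten.matmul.default

print(type(packet).__name__)    # OpOverloadPacket
print(type(overload).__name__)  # OpOverload
print(overload in {getattr(packet, name) for name in packet.overloads()})  # True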

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 1 addition & 0 deletions
@@ -171,6 +171,7 @@
     aten.upsample_bilinear2d.vec,
     aten.upsample_trilinear3d.vec,
     aten.upsample_bicubic2d.vec,
+    aten.matmul.default,
 }

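Note on the set above: judging by its neighbors (upsample ops that already have dedicated TensorRT converters), this appears to be the group of decompositions that Torch-TensorRT disables so the op reaches its converter intact; the set's name is outside this hunk, so that reading is an inference. A small diagnostic sketch, using only stock PyTorch, for checking whether aten.matmul.default would otherwise be decomposed:

import torch
from torch._decomp import core_aten_decompositions

# Whether matmul.default shows up in PyTorch's stock decomposition table varies
# by version; the point of excluding it here appears to be keeping the whole
# matmul node visible to the TensorRT matmul converter instead of lowering it
# into smaller mm/bmm/view ops. Treat the printed result as informational.
table = core_aten_decompositions()
print(torch.ops.aten.matmul.default in table)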
py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 11 additions & 8 deletions
@@ -9,7 +9,6 @@
     _get_decomp_for_cia,
 )
 from torch._ops import OpOverload
-
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.utils import to_torch_device
@@ -423,8 +422,8 @@ def instance_norm_decomposition(
 
 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-) # type: ignore
-def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+)
+def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
     fill_value = args[1]
@@ -454,11 +453,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
 
     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
 
     if attn_mask is not None:
@@ -471,17 +472,19 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
 
-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))
 
     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
         attn_weight = attn_weight / scale
     else:
         attn_weight = attn_weight * scale
 
-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)
 
 
 @register_torch_trt_decomposition(
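For readability, a condensed sketch of the decomposed SDPA math after the hunks above are applied: the (L, S) bias tensor is now only materialized and added when is_causal or an attn_mask can actually change the result, which trims the extra elementwise work from the common mask-free path. The attn_mask branch below is simplified to an additive float mask and the grouped-query repeat_interleave step is omitted, since those parts are not fully shown in this diff:

import torch

def sdpa_decomposition_sketch(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Condensed from the hunks above: build and add the (L, S) bias only when a
    # causal mask or an explicit attn_mask makes it non-zero.
    L, S = query.size(-2), key.size(-2)
    device = query.device

    if is_causal or attn_mask is not None:
        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
        if is_causal:
            assert attn_mask is None, "attn_mask must be None when is_causal=True"
            temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
            attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
        if attn_mask is not None:
            # Assumption: attn_mask is an additive float mask; the real
            # decomposition's mask handling is not shown in this diff.
            attn_bias = attn_bias + attn_mask

    attn_weight = torch.matmul(query, key.transpose(-2, -1))
    if scale is None:
        attn_weight = attn_weight / torch.sqrt(
            torch.scalar_tensor(query.size(-1), dtype=torch.int)
        )
    else:
        attn_weight = attn_weight * scale

    if is_causal or attn_mask is not None:
        attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)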
