     _get_decomp_for_cia,
 )
 from torch._ops import OpOverload
-
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.utils import to_torch_device
@@ -423,8 +422,8 @@ def instance_norm_decomposition(
 
 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-)  # type: ignore
-def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+)
+def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
     fill_value = args[1]
@@ -454,11 +453,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
 
     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
 
     if attn_mask is not None:
@@ -471,17 +472,19 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
 
-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))
 
     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
         attn_weight = attn_weight / scale
     else:
         attn_weight = attn_weight * scale
 
-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)
 
 
 @register_torch_trt_decomposition(
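
The net effect of the scaled_dot_product_attention change is that the (L, S) attn_bias tensor is only allocated and added when is_causal or an attn_mask actually requires it, so the common unmasked path no longer materializes a zero tensor plus an extra elementwise add in the lowered graph. Below is a minimal standalone sketch of the patched decomposition body, checked against torch.nn.functional.scaled_dot_product_attention; the name sdpa_decomposed, the test shapes, and the additive-float-mask assumption are illustrative, not part of the commit:

    import torch
    import torch.nn.functional as F

    def sdpa_decomposed(query, key, value, attn_mask=None, is_causal=False, scale=None):
        # Mirror of the patched decomposition: attn_bias is built lazily.
        L, S = query.size(-2), key.size(-2)
        device = query.device

        if is_causal or attn_mask is not None:
            attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

        if is_causal:
            assert attn_mask is None, "attn_mask must be None when is_causal=True"
            temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
            attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            attn_bias = attn_bias + attn_mask  # additive float mask assumed here

        # Broadcast grouped KV heads up to the query head count (factor 1 below).
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        attn_weight = torch.matmul(query, key.transpose(-2, -1))
        if scale is None:
            attn_weight = attn_weight / torch.sqrt(
                torch.scalar_tensor(query.size(-1), dtype=torch.int)
            )
        else:
            attn_weight = attn_weight * scale

        # Only add the bias when a mask was actually built.
        if is_causal or attn_mask is not None:
            attn_weight = attn_weight + attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)
        return torch.matmul(attn_weight, value)

    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out = sdpa_decomposed(q, k, v, is_causal=True)
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True (within tolerance)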