
Commit 97b4073
Fixed SDPA slow down and linear slow down (#3700)
1 parent 220cd64 commit 97b4073

8 files changed: +65 -10 lines

examples/apps/flux_demo.py
Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ def forward_loop(mod):
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
-        "use_python_runtime": False,
+        "use_python_runtime": True,
         "immutable_weights": False,
         "offload_module_to_cpu": args.low_vram_mode,
         "use_explicit_typing": use_explicit_typing,

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines changed: 1 addition & 0 deletions

@@ -532,6 +532,7 @@ def aten_ops_gelu(


 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
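
The added decorator registers the matmul converter under the concrete OpOverload (torch.ops.aten.matmul.default) in addition to the OpOverloadPacket (torch.ops.aten.matmul). A small sketch, not library code, of why both lookup keys matter: exported FX graphs usually carry the OpOverload as the node target, and the two objects are distinct.

import torch

packet = torch.ops.aten.matmul            # OpOverloadPacket
overload = torch.ops.aten.matmul.default  # concrete OpOverload, the usual FX node target
print(type(packet).__name__)    # OpOverloadPacket
print(type(overload).__name__)  # OpOverload
print(packet is overload)       # False: a converter registered under only one key misses the other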

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
Lines changed: 8 additions & 2 deletions

@@ -1,8 +1,8 @@
+import logging
 import operator
 import warnings
 from typing import Any, Callable, Optional, Union

-import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -20,6 +20,8 @@
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor

+logger = logging.getLogger(__name__)
+

 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -148,7 +150,11 @@ def convert_binary_elementwise(
         ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
     )

-    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
+    if len(lhs_val.shape) == len(rhs_val.shape) and all(
+        a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape)
+    ):
+        logger.info(f"skip broadcast for {name}")
+    elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
         lhs_val, rhs_val = broadcast(
             ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
         )
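
The new branch skips the broadcast() helper when both operands already have the same rank and every dimension pair is broadcast-compatible, a case TensorRT's elementwise layer can handle directly. A standalone sketch of the predicate on plain shape tuples (assuming static shapes; not the library code itself):

def can_skip_broadcast(lhs_shape, rhs_shape):
    # Same rank and each dimension pair equal or 1: no extra shuffle/broadcast
    # layers need to be inserted before the elementwise layer.
    return len(lhs_shape) == len(rhs_shape) and all(
        a == b or a == 1 or b == 1 for a, b in zip(lhs_shape, rhs_shape)
    )

print(can_skip_broadcast((2, 1, 8), (2, 4, 8)))  # True: ranks match, dims compatible
print(can_skip_broadcast((4, 8), (2, 4, 8)))     # False: ranks differ, broadcast() still runs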

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Lines changed: 1 addition & 0 deletions

@@ -170,6 +170,7 @@
     aten.upsample_trilinear3d.vec,
     aten.upsample_bicubic2d.vec,
     aten.linear.default,
+    aten.matmul.default,
 }

py/torch_tensorrt/dynamo/lowering/_decompositions.py
Lines changed: 10 additions & 5 deletions

@@ -463,11 +463,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

     if attn_mask is not None:
@@ -480,7 +482,7 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))

     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)).to(
@@ -490,9 +492,12 @@ def scaled_dot_product_attention_decomposition(
     else:
         attn_weight = attn_weight * scale

-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        # We only add attn_bias when we have to, otherwise this will have a negative impact on the performance even it's 0.
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)


 @register_torch_trt_decomposition(
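
Net effect: for the common path with no attention mask and is_causal=False, the decomposition no longer materializes or adds a zero attn_bias tensor. A standalone reference sketch of that fast path (not the library decomposition itself), checked against PyTorch's SDPA:

import torch

def sdpa_no_mask(query, key, value, scale=None):
    # Only matmul -> scale -> softmax -> matmul; no attn_bias is created or added.
    if scale is None:
        scale = 1.0 / (query.size(-1) ** 0.5)
    attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)

q = k = v = torch.randn(2, 8, 16, 64)
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(torch.allclose(sdpa_no_mask(q, k, v), expected, atol=1e-5))  # True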

py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
Lines changed: 13 additions & 0 deletions

@@ -10,6 +10,18 @@


 def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate
+    `add` and `mm` nodes. This is useful for passes that want to insert additional
+    logic (such as FP32 accumulation) specifically around the matrix multiplication
+    operation, rather than the fused addmm.
+
+    Args:
+        gm (torch.fx.GraphModule): The FX graph module to transform.
+
+    Returns:
+        torch.fx.GraphModule: The modified FX graph module with addmm nodes split.
+    """
     target = torch.ops.aten.addmm.default
     addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
     for addmm_node in addmm_nodes:
@@ -52,6 +64,7 @@ def accumulate_fp32_matmul(
     matmul_targets = [
         torch.ops.aten.mm.default,
         torch.ops.aten.bmm.default,
+        torch.ops.aten.matmul.default,
     ]

     # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
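
For context, split_addmm_nodes relies on the identity addmm(bias, mat1, mat2) == bias + mm(mat1, mat2); splitting the fused node lets the FP32-accumulation casts wrap only the matrix multiplication. A quick sketch of that equivalence (illustrative only, not the pass itself):

import torch

bias = torch.randn(4)
mat1 = torch.randn(3, 5)
mat2 = torch.randn(5, 4)

fused = torch.addmm(bias, mat1, mat2)
split = torch.add(bias, torch.mm(mat1, mat2))
print(torch.allclose(fused, split))  # True: addmm == add(bias, mm(mat1, mat2))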

tools/perf/Flux/benchmark.sh
Lines changed: 12 additions & 1 deletion

@@ -1,9 +1,20 @@
 #TODO: Enter the HF Token
 huggingface-cli login --token HF_TOKEN

+nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> pytorch_fp16_gpu_utilization.txt &
+NVIDIA_SMI_PID=$!
+python flux_perf.py --pytorch --max_batch_size 3 > pytorch_fp16_benchmark.txt
+kill $NVIDIA_SMI_PID
+
 nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp8_gpu_utilization.txt &
 NVIDIA_SMI_PID=$!
-python flux_perf.py --dtype fp8 --low_vram_mode> fp8_benchmark.txt
+python flux_perf.py --dtype fp8 --max_batch_size 3 > fp8_benchmark.txt
+kill $NVIDIA_SMI_PID
+
+
+nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp16_gpu_utilization.txt &
+NVIDIA_SMI_PID=$!
+python flux_perf.py --dtype fp16 --max_batch_size 3 > fp16_benchmark.txt
 kill $NVIDIA_SMI_PID

tools/perf/Flux/flux_perf.py
Lines changed: 19 additions & 1 deletion

@@ -44,9 +44,22 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
     return


+from diffusers import FluxPipeline
+
+
 def main(args):
     print(f"Running flux_perfwith args: {args}")
-    pipe, backbone, trt_gm = compile_model(args)
+    if not args.pytorch:
+        pipe, backbone, trt_gm = compile_model(args)
+    else:
+        pipe = (
+            FluxPipeline.from_pretrained(
+                "black-forest-labs/FLUX.1-dev",
+                torch_dtype=torch.float16,
+            )
+            .to(torch.float16)
+            .to("cuda:0")
+        )

     benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)

@@ -83,6 +96,11 @@ def main(args):
         action="store_true",
         help="Use dynamic shapes",
     )
+    parser.add_argument(
+        "--pytorch",
+        action="store_true",
+        help="Use pytorch runtime and no tensorrt",
+    )
     parser.add_argument("--max_batch_size", type=int, default=1)
     args = parser.parse_args()
     main(args)
