Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 62 additions & 55 deletions tensorrt_llm/_torch/compilation/piecewise_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..utils import (get_model_extra_attrs,
get_per_request_piecewise_cuda_graph_flag,
get_piecewise_cuda_graph_flag, make_weak_ref,
set_piecewise_running)
skip_maybe_compile)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function

Expand Down Expand Up @@ -171,68 +171,75 @@ def __call__(self, *args):
or not get_per_request_piecewise_cuda_graph_flag()):
return self.default_callable(*args)

if self.is_first_runner or self.is_last_runner:
if self.is_first_runner == self.is_last_runner:
set_piecewise_running(False)
else:
set_piecewise_running(self.is_first_runner)

entry = self.entries[runtime_num_of_token]

if entry.enable_inductor and not entry.compiled:
entry.callable = compile_fx(entry.callable, args)
entry.compiled = True

if entry.cuda_graph is None:

if not get_capture_piecewise_cuda_graph_flag():
return entry.callable(*args)

if entry.warmup_count < 3:
entry.warmup_count += 1
return entry.callable(*args)

entry.input_addresses = [
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
]

graph = torch.cuda.CUDAGraph()

# Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
# We patch it to do nothing.
with patch("gc.collect", lambda: None):
# TODO: consider to use `make_graphed_callables()` when
# it's ready rather than capture it ourselves
# Graph Capture would override the stream. We need to setup the stream correctly.
extra_attrs = get_model_extra_attrs()
with torch.cuda.graph(graph, pool=self.graph_pool_handle):
# Determine if we should skip compilation in @maybe_compile decorated functions:
# - First runner only: skip compilation (to avoid overhead)
# - Last runner only: skip compilation (to avoid overhead)
# - Both first and last (single runner): allow compilation (normal mode)
# - Middle runner: allow compilation (normal mode)
should_skip = (self.is_first_runner or self.is_last_runner) and \
not (self.is_first_runner and self.is_last_runner)

# Use context manager to directly control @maybe_compile behavior
# This makes the relationship explicit: PiecewiseRunner → skip_maybe_compile → @maybe_compile
with skip_maybe_compile(should_skip):
entry = self.entries[runtime_num_of_token]

if entry.enable_inductor and not entry.compiled:
entry.callable = compile_fx(entry.callable, args)
entry.compiled = True

if entry.cuda_graph is None:

if not get_capture_piecewise_cuda_graph_flag():
return entry.callable(*args)

if entry.warmup_count < 3:
entry.warmup_count += 1
return entry.callable(*args)

entry.input_addresses = [
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
]

graph = torch.cuda.CUDAGraph()

# Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
# We patch it to do nothing.
with patch("gc.collect", lambda: None):
# TODO: consider to use `make_graphed_callables()` when
# it's ready rather than capture it ourselves
# Graph Capture would override the stream. We need to setup the stream correctly.
extra_attrs = get_model_extra_attrs()
with torch.cuda.graph(graph, pool=self.graph_pool_handle):
extra_attrs[
"global_stream"] = torch.cuda.current_stream()
output = entry.callable(*args)
extra_attrs["global_stream"] = torch.cuda.current_stream()
output = entry.callable(*args)
extra_attrs["global_stream"] = torch.cuda.current_stream()

entry.cuda_graph = graph
# Mark weak ref here. The intermediate activation tensor should be freed properly.
# Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
entry.output = make_weak_ref(output)
entry.output_addresses = [
i.data_ptr() for i in output if isinstance(i, torch.Tensor)
]
entry.cuda_graph = graph
# Mark weak ref here. The intermediate activation tensor should be freed properly.
# Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
entry.output = make_weak_ref(output)
entry.output_addresses = [
i.data_ptr() for i in output if isinstance(i, torch.Tensor)
]

entry.cuda_graph.replay()
entry.cuda_graph.replay()

return output
return output

if enable_llm_debug():
runtime_input_addresses = [
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
]
if enable_llm_debug():
runtime_input_addresses = [
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
]

assert (entry.input_addresses == runtime_input_addresses
), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
assert (
entry.input_addresses == runtime_input_addresses
), f"{entry.input_addresses} vs\n {runtime_input_addresses}"

entry.cuda_graph.replay()
entry.cuda_graph.replay()

return entry.output
return entry.output


def piecewise_optimizer(
Expand Down
66 changes: 51 additions & 15 deletions tensorrt_llm/_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils

is_torch_compiling_flag = False
is_piecewise_running_flag = False
_torch_compiling = threading.local()
# Controls whether @maybe_compile decorator should skip compilation
# Set directly by PiecewiseRunner to avoid compilation overhead
_skip_maybe_compile = threading.local()

aux_stream_name_list = [
'Attention',
Expand Down Expand Up @@ -46,23 +48,42 @@ class ActivationType(IntEnum):


def set_torch_compiling(enable: bool):
global is_torch_compiling_flag
is_torch_compiling_flag = enable
_torch_compiling.flag = enable


def is_torch_compiling() -> bool:
global is_torch_compiling_flag
return is_torch_compiling_flag
return getattr(_torch_compiling, 'flag', False)


def set_piecewise_running(enable: bool):
global is_piecewise_running_flag
is_piecewise_running_flag = enable
@contextlib.contextmanager
def skip_maybe_compile(skip: bool = True):
"""
Context manager to directly control @maybe_compile decorator behavior.

When skip=True, functions decorated with @maybe_compile will skip torch.compile
to avoid compilation overhead. Used by PiecewiseRunner to control compilation.

This makes the relationship between PiecewiseRunner and @maybe_compile explicit.

Args:
skip: Whether to skip compilation in @maybe_compile decorated functions

Example:
with skip_maybe_compile(True):
# Functions with @maybe_compile will NOT be compiled
result = some_function()
"""
old_state = getattr(_skip_maybe_compile, 'flag', False)
_skip_maybe_compile.flag = skip
try:
yield
finally:
_skip_maybe_compile.flag = old_state


def is_piecewise_running() -> bool:
global is_piecewise_running_flag
return is_piecewise_running_flag
def _should_skip_maybe_compile() -> bool:
"""Check if @maybe_compile should skip compilation."""
return getattr(_skip_maybe_compile, 'flag', False)


_global_attrs = threading.local()
Expand Down Expand Up @@ -344,19 +365,34 @@ def get_device_uuid(device_idx: int) -> str:
def maybe_compile(func=None, **compile_kwargs):
"""
Conditionally compile a function with torch.compile.
If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in attention op.

Compilation is skipped when running within a skip_maybe_compile(True) context,
which is used by PiecewiseRunner to avoid compilation overhead.

Args:
func: The function to decorate (optional, for direct decoration).
**compile_kwargs: Keyword arguments for torch.compile.
Returns:
The conditionally compiled function..
The conditionally compiled function.

Example:
@maybe_compile
def my_function(x):
return x * 2

# Normal usage: function is compiled
result = my_function(tensor)

# With skip_maybe_compile: function runs uncompiled
with skip_maybe_compile(True):
result = my_function(tensor) # Not compiled
"""

def decorator(f):
compiled_func = torch.compile(f, **compile_kwargs)

def wrapper(*args, **kwargs):
if is_piecewise_running():
if _should_skip_maybe_compile():
return f(*args, **kwargs)
return compiled_func(*args, **kwargs)

Expand Down
Loading