Skip to content

Commit fb6f64e

Browse files
committed
Control skip_maybe_compile behavior directly
1 parent f059da2 commit fb6f64e

File tree

2 files changed

+108
-63
lines changed

2 files changed

+108
-63
lines changed

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 60 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..utils import (get_model_extra_attrs,
1515
get_per_request_piecewise_cuda_graph_flag,
1616
get_piecewise_cuda_graph_flag, make_weak_ref,
17-
set_piecewise_running)
17+
skip_maybe_compile)
1818
from .multi_stream.auto_multi_stream import multi_stream_schedule
1919
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
2020

@@ -171,68 +171,73 @@ def __call__(self, *args):
171171
or not get_per_request_piecewise_cuda_graph_flag()):
172172
return self.default_callable(*args)
173173

174-
if self.is_first_runner or self.is_last_runner:
175-
if self.is_first_runner == self.is_last_runner:
176-
set_piecewise_running(False)
177-
else:
178-
set_piecewise_running(self.is_first_runner)
179-
180-
entry = self.entries[runtime_num_of_token]
181-
182-
if entry.enable_inductor and not entry.compiled:
183-
entry.callable = compile_fx(entry.callable, args)
184-
entry.compiled = True
185-
186-
if entry.cuda_graph is None:
187-
188-
if not get_capture_piecewise_cuda_graph_flag():
189-
return entry.callable(*args)
190-
191-
if entry.warmup_count < 3:
192-
entry.warmup_count += 1
193-
return entry.callable(*args)
194-
195-
entry.input_addresses = [
196-
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
197-
]
198-
199-
graph = torch.cuda.CUDAGraph()
200-
201-
# Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
202-
# We patch it to do nothing.
203-
with patch("gc.collect", lambda: None):
204-
# TODO: consider to use `make_graphed_callables()` when
205-
# it's ready rather than capture it ourselves
206-
# Graph Capture would override the stream. We need to setup the stream correctly.
207-
extra_attrs = get_model_extra_attrs()
208-
with torch.cuda.graph(graph, pool=self.graph_pool_handle):
174+
# Determine if we should skip compilation in @maybe_compile decorated functions:
175+
# - First runner only: skip compilation (to avoid overhead)
176+
# - Last runner only: skip compilation (to avoid overhead)
177+
# - Both first and last (single runner): allow compilation (normal mode)
178+
# - Middle runner: allow compilation (normal mode)
179+
should_skip = (self.is_first_runner or self.is_last_runner) and \
180+
not (self.is_first_runner and self.is_last_runner)
181+
182+
# Use context manager to directly control @maybe_compile behavior
183+
# This makes the relationship explicit: PiecewiseRunner → skip_maybe_compile → @maybe_compile
184+
with skip_maybe_compile(should_skip):
185+
entry = self.entries[runtime_num_of_token]
186+
187+
if entry.enable_inductor and not entry.compiled:
188+
entry.callable = compile_fx(entry.callable, args)
189+
entry.compiled = True
190+
191+
if entry.cuda_graph is None:
192+
193+
if not get_capture_piecewise_cuda_graph_flag():
194+
return entry.callable(*args)
195+
196+
if entry.warmup_count < 3:
197+
entry.warmup_count += 1
198+
return entry.callable(*args)
199+
200+
entry.input_addresses = [
201+
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
202+
]
203+
204+
graph = torch.cuda.CUDAGraph()
205+
206+
# Torch's cuda graph will call gc.collect() internally. This will slow down the performance.
207+
# We patch it to do nothing.
208+
with patch("gc.collect", lambda: None):
209+
# TODO: consider to use `make_graphed_callables()` when
210+
# it's ready rather than capture it ourselves
211+
# Graph Capture would override the stream. We need to setup the stream correctly.
212+
extra_attrs = get_model_extra_attrs()
213+
with torch.cuda.graph(graph, pool=self.graph_pool_handle):
214+
extra_attrs["global_stream"] = torch.cuda.current_stream()
215+
output = entry.callable(*args)
209216
extra_attrs["global_stream"] = torch.cuda.current_stream()
210-
output = entry.callable(*args)
211-
extra_attrs["global_stream"] = torch.cuda.current_stream()
212217

213-
entry.cuda_graph = graph
214-
# Mark weak ref here. The intermediate activation tensor should be freed properly.
215-
# Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
216-
entry.output = make_weak_ref(output)
217-
entry.output_addresses = [
218-
i.data_ptr() for i in output if isinstance(i, torch.Tensor)
219-
]
218+
entry.cuda_graph = graph
219+
# Mark weak ref here. The intermediate activation tensor should be freed properly.
220+
# Here we don't use python native weakref since we still need the object to be alive when the graph is replayed.
221+
entry.output = make_weak_ref(output)
222+
entry.output_addresses = [
223+
i.data_ptr() for i in output if isinstance(i, torch.Tensor)
224+
]
220225

221-
entry.cuda_graph.replay()
226+
entry.cuda_graph.replay()
222227

223-
return output
228+
return output
224229

225-
if enable_llm_debug():
226-
runtime_input_addresses = [
227-
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
228-
]
230+
if enable_llm_debug():
231+
runtime_input_addresses = [
232+
i.data_ptr() for i in args if isinstance(i, torch.Tensor)
233+
]
229234

230-
assert (entry.input_addresses == runtime_input_addresses
231-
), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
235+
assert (entry.input_addresses == runtime_input_addresses
236+
), f"{entry.input_addresses} vs\n {runtime_input_addresses}"
232237

233-
entry.cuda_graph.replay()
238+
entry.cuda_graph.replay()
234239

235-
return entry.output
240+
return entry.output
236241

237242

238243
def piecewise_optimizer(

tensorrt_llm/_torch/utils.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from tensorrt_llm.quantization.utils import fp4_utils
1313

1414
_torch_compiling = threading.local()
15-
_piecewise_running = threading.local()
15+
# Controls whether @maybe_compile decorator should skip compilation
16+
# Set directly by PiecewiseRunner to avoid compilation overhead
17+
_skip_maybe_compile = threading.local()
1618

1719
aux_stream_name_list = [
1820
'Attention',
@@ -53,12 +55,35 @@ def is_torch_compiling() -> bool:
5355
return getattr(_torch_compiling, 'flag', False)
5456

5557

56-
def set_piecewise_running(enable: bool):
57-
_piecewise_running.flag = enable
58+
@contextlib.contextmanager
59+
def skip_maybe_compile(skip: bool = True):
60+
"""
61+
Context manager to directly control @maybe_compile decorator behavior.
62+
63+
When skip=True, functions decorated with @maybe_compile will skip torch.compile
64+
to avoid compilation overhead. Used by PiecewiseRunner to control compilation.
65+
66+
This makes the relationship between PiecewiseRunner and @maybe_compile explicit.
67+
68+
Args:
69+
skip: Whether to skip compilation in @maybe_compile decorated functions
70+
71+
Example:
72+
with skip_maybe_compile(True):
73+
# Functions with @maybe_compile will NOT be compiled
74+
result = some_function()
75+
"""
76+
old_state = getattr(_skip_maybe_compile, 'flag', False)
77+
_skip_maybe_compile.flag = skip
78+
try:
79+
yield
80+
finally:
81+
_skip_maybe_compile.flag = old_state
5882

5983

60-
def is_piecewise_running() -> bool:
61-
return getattr(_piecewise_running, 'flag', False)
84+
def _should_skip_maybe_compile() -> bool:
85+
"""Check if @maybe_compile should skip compilation."""
86+
return getattr(_skip_maybe_compile, 'flag', False)
6287

6388

6489
_global_attrs = threading.local()
@@ -340,19 +365,34 @@ def get_device_uuid(device_idx: int) -> str:
340365
def maybe_compile(func=None, **compile_kwargs):
341366
"""
342367
Conditionally compile a function with torch.compile.
343-
If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in attention op.
368+
369+
Compilation is skipped when running within a skip_maybe_compile(True) context,
370+
which is used by PiecewiseRunner to avoid compilation overhead.
371+
344372
Args:
345373
func: The function to decorate (optional, for direct decoration).
346374
**compile_kwargs: Keyword arguments for torch.compile.
347375
Returns:
348-
The conditionally compiled function..
376+
The conditionally compiled function.
377+
378+
Example:
379+
@maybe_compile
380+
def my_function(x):
381+
return x * 2
382+
383+
# Normal usage: function is compiled
384+
result = my_function(tensor)
385+
386+
# With skip_maybe_compile: function runs uncompiled
387+
with skip_maybe_compile(True):
388+
result = my_function(tensor) # Not compiled
349389
"""
350390

351391
def decorator(f):
352392
compiled_func = torch.compile(f, **compile_kwargs)
353393

354394
def wrapper(*args, **kwargs):
355-
if is_piecewise_running():
395+
if _should_skip_maybe_compile():
356396
return f(*args, **kwargs)
357397
return compiled_func(*args, **kwargs)
358398

0 commit comments

Comments
 (0)