|
14 | 14 | from ..utils import (get_model_extra_attrs, |
15 | 15 | get_per_request_piecewise_cuda_graph_flag, |
16 | 16 | get_piecewise_cuda_graph_flag, make_weak_ref, |
17 | | - set_piecewise_running) |
| 17 | + skip_maybe_compile) |
18 | 18 | from .multi_stream.auto_multi_stream import multi_stream_schedule |
19 | 19 | from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function |
20 | 20 |
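The import swap above replaces the `set_piecewise_running` flag with a `skip_maybe_compile` context manager, which `__call__` uses below to scope the eager fallback. As a minimal sketch of how such a context manager and a `@maybe_compile` decorator might cooperate (the real implementations live in `..utils` and are not part of this diff; the contextvar name, the lazy `torch.compile` wrapping, and the decorator body here are assumptions):

```python
import contextvars
from contextlib import contextmanager
from functools import wraps

import torch

# Hypothetical flag; the actual mechanism in ..utils may differ.
_skip_compile = contextvars.ContextVar("skip_maybe_compile", default=False)


@contextmanager
def skip_maybe_compile(skip: bool):
    # Temporarily force @maybe_compile-decorated functions to run eagerly.
    token = _skip_compile.set(skip)
    try:
        yield
    finally:
        _skip_compile.reset(token)


def maybe_compile(fn):
    compiled = None  # compile lazily so decoration itself stays cheap

    @wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal compiled
        if _skip_compile.get():
            return fn(*args, **kwargs)  # eager path: no compile overhead
        if compiled is None:
            compiled = torch.compile(fn)
        return compiled(*args, **kwargs)

    return wrapper
```

With this shape, `with skip_maybe_compile(should_skip):` confines the "run eagerly" decision to the wrapped region and restores the previous state on exit, which is the explicit PiecewiseRunner → skip_maybe_compile → @maybe_compile relationship the new comments below call out.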
|
@@ -171,68 +171,73 @@ def __call__(self, *args): |
171 | 171 | or not get_per_request_piecewise_cuda_graph_flag()): |
172 | 172 | return self.default_callable(*args) |
173 | 173 |
|
174 | | - if self.is_first_runner or self.is_last_runner: |
175 | | - if self.is_first_runner == self.is_last_runner: |
176 | | - set_piecewise_running(False) |
177 | | - else: |
178 | | - set_piecewise_running(self.is_first_runner) |
179 | | - |
180 | | - entry = self.entries[runtime_num_of_token] |
181 | | - |
182 | | - if entry.enable_inductor and not entry.compiled: |
183 | | - entry.callable = compile_fx(entry.callable, args) |
184 | | - entry.compiled = True |
185 | | - |
186 | | - if entry.cuda_graph is None: |
187 | | - |
188 | | - if not get_capture_piecewise_cuda_graph_flag(): |
189 | | - return entry.callable(*args) |
190 | | - |
191 | | - if entry.warmup_count < 3: |
192 | | - entry.warmup_count += 1 |
193 | | - return entry.callable(*args) |
194 | | - |
195 | | - entry.input_addresses = [ |
196 | | - i.data_ptr() for i in args if isinstance(i, torch.Tensor) |
197 | | - ] |
198 | | - |
199 | | - graph = torch.cuda.CUDAGraph() |
200 | | - |
201 | | - # Torch's cuda graph will call gc.collect() internally. This will slow down the performance. |
202 | | - # We patch it to do nothing. |
203 | | - with patch("gc.collect", lambda: None): |
204 | | - # TODO: consider to use `make_graphed_callables()` when |
205 | | - # it's ready rather than capture it ourselves |
206 | | - # Graph Capture would override the stream. We need to setup the stream correctly. |
207 | | - extra_attrs = get_model_extra_attrs() |
208 | | - with torch.cuda.graph(graph, pool=self.graph_pool_handle): |
| 174 | + # Determine whether to skip compilation in @maybe_compile-decorated functions: |
| 175 | + # - First runner only: skip compilation (to avoid overhead) |
| 176 | + # - Last runner only: skip compilation (to avoid overhead) |
| 177 | + # - Both first and last (single runner): allow compilation (normal mode) |
| 178 | + # - Middle runner: allow compilation (normal mode) |
| 179 | + should_skip = (self.is_first_runner or self.is_last_runner) and \ |
| 180 | + not (self.is_first_runner and self.is_last_runner) |
| 181 | + |
| 182 | + # Use a context manager to control @maybe_compile behavior directly. |
| 183 | + # This makes the relationship explicit: PiecewiseRunner → skip_maybe_compile → @maybe_compile |
| 184 | + with skip_maybe_compile(should_skip): |
| 185 | + entry = self.entries[runtime_num_of_token] |
| 186 | + |
| 187 | + if entry.enable_inductor and not entry.compiled: |
| 188 | + entry.callable = compile_fx(entry.callable, args) |
| 189 | + entry.compiled = True |
| 190 | + |
| 191 | + if entry.cuda_graph is None: |
| 192 | + |
| 193 | + if not get_capture_piecewise_cuda_graph_flag(): |
| 194 | + return entry.callable(*args) |
| 195 | + |
| 196 | + if entry.warmup_count < 3: |
| 197 | + entry.warmup_count += 1 |
| 198 | + return entry.callable(*args) |
| 199 | + |
| 200 | + entry.input_addresses = [ |
| 201 | + i.data_ptr() for i in args if isinstance(i, torch.Tensor) |
| 202 | + ] |
| 203 | + |
| 204 | + graph = torch.cuda.CUDAGraph() |
| 205 | + |
| 206 | + # Torch's CUDA graph capture calls gc.collect() internally, which hurts performance. |
| 207 | + # We patch it to a no-op. |
| 208 | + with patch("gc.collect", lambda: None): |
| 209 | + # TODO: consider using `make_graphed_callables()` once |
| 210 | + # it's ready, rather than capturing the graph ourselves. |
| 211 | + # Graph capture overrides the stream, so we need to set the stream up explicitly. |
| 212 | + extra_attrs = get_model_extra_attrs() |
| 213 | + with torch.cuda.graph(graph, pool=self.graph_pool_handle): |
| 214 | + extra_attrs["global_stream"] = torch.cuda.current_stream() |
| 215 | + output = entry.callable(*args) |
209 | 216 | extra_attrs["global_stream"] = torch.cuda.current_stream() |
210 | | - output = entry.callable(*args) |
211 | | - extra_attrs["global_stream"] = torch.cuda.current_stream() |
212 | 217 |
|
213 | | - entry.cuda_graph = graph |
214 | | - # Mark weak ref here. The intermediate activation tensor should be freed properly. |
215 | | - # Here we don't use python native weakref since we still need the object to be alive when the graph is replayed. |
216 | | - entry.output = make_weak_ref(output) |
217 | | - entry.output_addresses = [ |
218 | | - i.data_ptr() for i in output if isinstance(i, torch.Tensor) |
219 | | - ] |
| 218 | + entry.cuda_graph = graph |
| 219 | + # Mark a weak ref here so the intermediate activation tensors are freed properly. |
| 220 | + # We don't use Python's native weakref since the output must stay alive when the graph is replayed. |
| 221 | + entry.output = make_weak_ref(output) |
| 222 | + entry.output_addresses = [ |
| 223 | + i.data_ptr() for i in output if isinstance(i, torch.Tensor) |
| 224 | + ] |
220 | 225 |
|
221 | | - entry.cuda_graph.replay() |
| 226 | + entry.cuda_graph.replay() |
222 | 227 |
|
223 | | - return output |
| 228 | + return output |
224 | 229 |
|
225 | | - if enable_llm_debug(): |
226 | | - runtime_input_addresses = [ |
227 | | - i.data_ptr() for i in args if isinstance(i, torch.Tensor) |
228 | | - ] |
| 230 | + if enable_llm_debug(): |
| 231 | + runtime_input_addresses = [ |
| 232 | + i.data_ptr() for i in args if isinstance(i, torch.Tensor) |
| 233 | + ] |
229 | 234 |
|
230 | | - assert (entry.input_addresses == runtime_input_addresses |
231 | | - ), f"{entry.input_addresses} vs\n {runtime_input_addresses}" |
| 235 | + assert (entry.input_addresses == runtime_input_addresses |
| 236 | + ), f"{entry.input_addresses} vs\n {runtime_input_addresses}" |
232 | 237 |
|
233 | | - entry.cuda_graph.replay() |
| 238 | + entry.cuda_graph.replay() |
234 | 239 |
|
235 | | - return entry.output |
| 240 | + return entry.output |
236 | 241 |
|
237 | 242 |
|
238 | 243 | def piecewise_optimizer( |
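For orientation, here is a condensed, self-contained sketch of the warmup → capture → replay lifecycle that the runner above implements; the three-call warmup, the `gc.collect` patch, and the static output buffers mirror the diff, while `GraphEntry` and `run_piecewise` are illustrative names, not the project's API:

```python
from dataclasses import dataclass
from typing import Any, Optional
from unittest.mock import patch

import torch


@dataclass
class GraphEntry:
    # Illustrative stand-in for the runner's per-token-count entry.
    warmup_count: int = 0
    cuda_graph: Optional[torch.cuda.CUDAGraph] = None
    output: Any = None


def run_piecewise(entry: GraphEntry, fn, args, pool=None, warmup_steps=3):
    if entry.cuda_graph is None:
        if entry.warmup_count < warmup_steps:
            entry.warmup_count += 1
            return fn(*args)  # eager warmup lets allocator state settle

        graph = torch.cuda.CUDAGraph()
        # torch.cuda.graph() calls gc.collect() on entry; patch it out,
        # as the diff does, to avoid the capture-time stall.
        with patch("gc.collect", lambda: None):
            with torch.cuda.graph(graph, pool=pool):
                entry.output = fn(*args)  # outputs become static buffers
        entry.cuda_graph = graph
        return entry.output

    # Replay reuses the memory captured above; no new allocations occur.
    entry.cuda_graph.replay()
    return entry.output
```

Replay is only valid while the runtime inputs occupy the same device addresses they had at capture time, which is exactly what the `enable_llm_debug()` assertion in the diff verifies before calling `replay()`.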