Closed
Labels: bug (Something isn't working)
Description
Bug Description
The first dimension of the output disappears when performing inference with the cross-compiled exported program on Windows.
To Reproduce
Run on Linux
from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + x


with torch.inference_mode():
    model = MyModule().eval().cuda()
    inputs = (torch.zeros(2, 4, 6, 8, dtype=torch.float, device="cuda"),)
    exported_program = torch.export.export(model, inputs)
    trt_model = torch_tensorrt.dynamo.cross_compile_for_windows(
        exported_program,
        inputs,
        enabled_precisions={torch.float},
        debug=True,
        min_block_size=1,
    )
    torch_tensorrt.dynamo.save_cross_compiled_exported_program(trt_model, "trt_windows.ep")
TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/16/2025-22:50:47] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
return (add,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
return (add,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
return (add,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
return (add,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
return (add,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.add.Tensor + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.add.Tensor + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(2, 4, 6, 8)]
graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {})
return add
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[2, 4, 6, 8], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (2, 4, 6, 8)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /add (kind: aten.add.Tensor, args: ('x <Node>', 'x <Node>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.add.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.add.Tensor
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /add [aten.add.Tensor] (Inputs: (x: (2, 4, 6, 8)@torch.float32, x: (2, 4, 6, 8)@torch.float32) | Outputs: (add: (2, 4, 6, 8)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('add <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(2, 4, 6, 8), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (add: (2, 4, 6, 8)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.006162
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Setting runtime_platform as trt.RuntimePlatform.WINDOWS_AMD64
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.195916
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 16356 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++
The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage
Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=True, use_aot_joint_export=True)
Graph Structure:
Inputs: List[Tensor: (2, 4, 6, 8)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0
Engine Inputs: List[Tensor: (2, 4, 6, 8)@float32]
Number of Operators in Engine: 1
Engine Outputs: List[Tensor: (2, 4, 6, 8)@float32]
...
Outputs: List[Tensor: (2, 4, 6, 8)@float32]
------------------------- Aggregate Stats -------------------------
Average Number of Operators per TRT Engine: 1.0
Most Operators in a TRT Engine: 1
********** Recommendations **********
- For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
DEBUG:torch_tensorrt.dynamo._compiler:successfully saved the module for windows at trt_windows.ep
Run on Windows
from __future__ import annotations

import os

import torch
import torch_tensorrt

with torch.inference_mode():
    model = torch_tensorrt.dynamo.load_cross_compiled_exported_program("trt_windows.ep").module()
    inputs = (torch.randn(2, 4, 6, 8, dtype=torch.float, device="cuda"),)
    output = model(*inputs)
    print(f"{output.shape=}")
TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/16/2025-22:51:54] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
output.shape=torch.Size([4, 6, 8])
Expected behavior
output.shape should be [2, 4, 6, 8].
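Since the module only computes x + x, the output shape should equal the input shape. The following is a minimal sketch of a check that could be added to the Windows script to make the failure explicit. The helper name `describe_shape_mismatch` is hypothetical and not part of Torch-TensorRT; it works on plain tuples, so it runs without a GPU.

```python
from __future__ import annotations

from typing import Sequence


def describe_shape_mismatch(expected: Sequence[int], actual: Sequence[int]) -> str:
    """Hypothetical helper: describe how `actual` differs from `expected`."""
    expected, actual = tuple(expected), tuple(actual)
    if expected == actual:
        return "ok"
    # Detect the case from this report: `actual` is `expected` with
    # one or more leading dimensions dropped.
    dropped = len(expected) - len(actual)
    if dropped > 0 and expected[dropped:] == actual:
        return f"missing {dropped} leading dim(s): {expected[:dropped]}"
    return f"expected {expected}, got {actual}"


# Shapes taken from the report: input (2, 4, 6, 8), observed output (4, 6, 8).
print(describe_shape_mismatch((2, 4, 6, 8), (4, 6, 8)))
```

In the Windows script above, the same check could be called as `describe_shape_mismatch(inputs[0].shape, output.shape)`, since `torch.Size` converts to a tuple.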
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.6.0+cu126
- PyTorch Version (e.g. 1.0): 2.6.0+cu126
- CPU Architecture: x64
- OS (e.g., Linux): Ubuntu 24.04 and Windows 11
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.12.3
- CUDA version: 12.6
- GPU models and configuration: RTX 4060 Ti
- Any other relevant information: