Skip to content

feat: revert linear converter #3703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3579,3 +3579,22 @@ def aten_ops_nonzero(
name,
args[0],
)


@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
def aten_ops_linear(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.linear.linear(
ctx,
target,
SourceIR.ATEN,
name,
input=args[0],
weight=args[1],
bias=args_bounds_check(args, 2, None),
)
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
embedding,
full,
grid,
linear,
matmul,
nccl_ops,
normalization,
Expand Down
56 changes: 56 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR, get_trt_tensor
from torch_tensorrt.dynamo.types import TRTTensor


def linear(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
weight: Union[TRTTensor, torch.Tensor, np.ndarray],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
) -> TRTTensor:
# Process weight terms
if not isinstance(weight, (TRTTensor, torch.Tensor, np.ndarray)):
raise RuntimeError(
f"Linear layer {name} has weight of type {type(weight)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
)
elif isinstance(weight, (torch.Tensor, np.ndarray)):
weight = get_trt_tensor(ctx, weight, f"{name}_weight")

# Process bias terms
if bias is not None and not isinstance(bias, (TRTTensor, torch.Tensor, np.ndarray)):
raise RuntimeError(
f"Linear layer {name} has bias of type {type(bias)}, Expect Union[TRTTensor, torch.Tensor, np.ndarray],"
)
elif isinstance(bias, (torch.Tensor, np.ndarray)):
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

# add IMatrixMultiplyLayer
out = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
f"{name}_matrix_multiply",
input,
weight,
input_matrix_op=trt.MatrixOperation.NONE,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)

if bias is not None:
# add bias
out = impl.elementwise.add(
ctx, target, source_ir, f"{name}_add_bias", out, bias
)

return out
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
aten.upsample_bilinear2d.vec,
aten.upsample_trilinear3d.vec,
aten.upsample_bicubic2d.vec,
aten.linear.default,
}


Expand Down
54 changes: 54 additions & 0 deletions tests/py/dynamo/conversion/test_linear_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestLinearConverter(DispatchTestCase):
@parameterized.expand(
[
(10, 10),
(10, 100),
(100, 10),
(100, 100),
]
)
def test_linear_converter(self, in_features, out_features):
class LinearModel(nn.Module):
def __init__(self, in_features, out_features):
super(LinearModel, self).__init__()
self.linear = nn.Linear(in_features, out_features)

def forward(self, x):
return self.linear(x)

model = LinearModel(in_features, out_features).eval().cuda()
inputs = [torch.randn(int(torch.randint(1, 20, (1,))), in_features).cuda()]
self.run_test(model, inputs, use_dynamo_tracer=True, enable_passes=True)

def test_linear_with_dynamic_shape(self):
class LinearModel(torch.nn.Module):
def forward(self, x, weight, bias):
return torch.ops.aten.linear.default(x, weight, bias)

input_specs = [
Input(
dtype=torch.float32,
min_shape=(1, 10),
opt_shape=(10, 10),
max_shape=(100, 10),
),
Input(dtype=torch.float32, shape=(20, 10)),
Input(dtype=torch.float32, shape=(20,)),
]

self.run_test_with_dynamic_shape(
LinearModel(), input_specs, use_dynamo_tracer=True, enable_passes=True
)


if __name__ == "__main__":
run_tests()
Loading