Skip to content

Commit af31bce

Browse files
committed
chore: move functions to organize code better
1 parent c4965cd commit af31bce

File tree

1 file changed

+34
-34
lines changed

1 file changed

+34
-34
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,40 @@ def aten_ops_tile(
800800
)
801801

802802

803+
def zero_output_validator(node: Node) -> bool:
804+
if 0 in node.args[1]:
805+
_LOGGER.debug(
806+
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
807+
)
808+
return False
809+
else:
810+
return True
811+
812+
813+
@dynamo_tensorrt_converter(
814+
torch.ops.aten.as_strided.default,
815+
capability_validator=zero_output_validator,
816+
)
817+
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
818+
def aten_ops_as_strided(
819+
ctx: ConversionContext,
820+
target: Target,
821+
args: Tuple[Argument, ...],
822+
kwargs: Dict[str, Argument],
823+
name: str,
824+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
825+
return impl.slice.as_strided(
826+
ctx,
827+
target,
828+
source_ir=SourceIR.ATEN,
829+
name=name,
830+
input=args[0],
831+
size=args[1],
832+
stride=args[2],
833+
storage_offset=args_bounds_check(args, 3, None),
834+
)
835+
836+
803837
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
804838
@enforce_tensor_types(
805839
{
@@ -2186,40 +2220,6 @@ def aten_ops_linear(
21862220
)
21872221

21882222

2189-
def zero_output_validator(node: Node) -> bool:
2190-
if 0 in node.args[1]:
2191-
_LOGGER.debug(
2192-
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
2193-
)
2194-
return False
2195-
else:
2196-
return True
2197-
2198-
2199-
@dynamo_tensorrt_converter(
2200-
torch.ops.aten.as_strided.default,
2201-
capability_validator=zero_output_validator,
2202-
)
2203-
@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default)
2204-
def aten_ops_as_strided(
2205-
ctx: ConversionContext,
2206-
target: Target,
2207-
args: Tuple[Argument, ...],
2208-
kwargs: Dict[str, Argument],
2209-
name: str,
2210-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2211-
return impl.slice.as_strided(
2212-
ctx,
2213-
target,
2214-
source_ir=SourceIR.ATEN,
2215-
name=name,
2216-
input=args[0],
2217-
size=args[1],
2218-
stride=args[2],
2219-
storage_offset=args_bounds_check(args, 3, None),
2220-
)
2221-
2222-
22232223
def avg_pool_param_validator(pool_node: Node) -> bool:
22242224
ceil_mode = args_bounds_check(pool_node.args, 4, False)
22252225
divisor_override = args_bounds_check(pool_node.args, 6)

0 commit comments

Comments
 (0)