Skip to content

Commit bfecbaf

Browse files
borisfomgs-olive
authored and committed
Fixing clamp to not use Torch
Signed-off-by: Boris Fomitchev <[email protected]>
1 parent 521fed1 commit bfecbaf

File tree

1 file changed

+7
-13
lines changed
  • py/torch_tensorrt/dynamo/converters/impl/elementwise

1 file changed

+7
-13
lines changed

py/torch_tensorrt/dynamo/converters/impl/elementwise/clamp.py

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
1+
import numpy as np
12
from typing import Optional
23
import tensorrt as trt
3-
import torch
44
from torch.fx.node import Target
55

66
from torch_tensorrt.dynamo.converters import SourceIR
@@ -27,24 +27,18 @@ def add_clamp(network, input, val, op, name):
2727
acc_ops_clamp_trt = get_trt_tensor(
2828
network,
2929
squeeze_left(
30-
torch.tensor(
31-
[val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
30+
np.array(
31+
[val], dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY)
3232
)
3333
),
3434
f"{name}_clamp_{val}",
3535
)
3636
else:
3737
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
38-
acc_ops_clamp_tensor = (
39-
(
40-
val
41-
* torch.ones(
42-
acc_ops_clamp_shape,
43-
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
44-
)
45-
)
46-
.cpu()
47-
.numpy()
38+
acc_ops_clamp_tensor = np.full(
39+
acc_ops_clamp_shape,
40+
val,
41+
dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
4842
)
4943
acc_ops_clamp_trt = network.add_constant(
5044
acc_ops_clamp_shape, acc_ops_clamp_tensor

0 commit comments

Comments
 (0)