File tree Expand file tree Collapse file tree 2 files changed +25
-24
lines changed
py/torch_tensorrt/fx/converters Expand file tree Collapse file tree 2 files changed +25
-24
lines changed Original file line number Diff line number Diff line change @@ -543,3 +543,27 @@ def type_cast(
543543 layer_i .set_output_type (0 , cast_type )
544544 set_layer_name (layer_i , target , f"{ name } _dtype_change" )
545545 return layer_i .get_output (0 )
546+
547+
548+ def to_numpy (tensor : Optional [torch .Tensor ]) -> Optional [np .ndarray ]:
549+ """
550+ Convert a PyTorch Tensor to a Numpy Array. If the tensor is
551+ quantized it will be dequantized first.
552+
553+ Args:
554+ tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
555+
556+ Returns:
557+ A Numpy array.
558+ """
559+
560+ if tensor is None :
561+ return tensor
562+
563+ assert isinstance (
564+ tensor , torch .Tensor
565+ ), f"to_numpy can only be called on None or a torch.Tensor, got: { tensor } "
566+ if tensor .is_quantized :
567+ tensor = tensor .dequantize ()
568+
569+ return tensor .cpu ().detach ().contiguous ().numpy ()
Original file line number Diff line number Diff line change 2222from .converter_utils import prepend_ones
2323from .converter_utils import has_dynamic_shape
2424from .converter_utils import get_shape_with_dynamic_shape
25+ from .converter_utils import to_numpy
2526
2627from ..types import (
2728 Shape ,
@@ -278,30 +279,6 @@ def trunc_div(
278279 return output
279280
280281
281- def to_numpy (tensor : Optional [torch .Tensor ]) -> Optional [np .ndarray ]:
282- """
283- Convert a PyTorch Tensor to a Numpy Array. If the tensor is
284- quantized it will be dequantized first.
285-
286- Args:
287- tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
288-
289- Returns:
290- A Numpy array.
291- """
292-
293- if tensor is None :
294- return tensor
295-
296- assert isinstance (
297- tensor , torch .Tensor
298- ), f"to_numpy can only be called on None or a torch.Tensor, got: { tensor } "
299- if tensor .is_quantized :
300- tensor = tensor .dequantize ()
301-
302- return tensor .cpu ().detach ().contiguous ().numpy ()
303-
304-
305282def trt_dtype_to_torch_dtype (trt_dtype ):
306283 table = {
307284 trt .bool : torch .bool ,
You can’t perform that action at this time.
0 commit comments