From 31b8c4d97bee8349ffd4bffc3dc0abf5931eb9d4 Mon Sep 17 00:00:00 2001 From: Sergii Khomenko Date: Fri, 8 Oct 2021 16:44:47 +0100 Subject: [PATCH] Switch from np.frombuffer to torch.frombuffer --- test/test_image.py | 4 +--- test/test_video_reader.py | 2 +- torchvision/io/_video_opt.py | 7 +++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 35ec677ba5c..9c6a73b8362 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -476,9 +476,7 @@ def test_encode_jpeg(img_path): buf = io.BytesIO() pil_img.save(buf, format="JPEG", quality=75) - # pytorch can't read from raw bytes so we go through numpy - pil_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8) - encoded_jpeg_pil = torch.as_tensor(pil_bytes) + encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8) for src_img in [img, img.contiguous()]: encoded_jpeg_torch = encode_jpeg(src_img, quality=75) diff --git a/test/test_video_reader.py b/test/test_video_reader.py index 282ce653322..73c4d8a1b85 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -258,7 +258,7 @@ def _get_video_tensor(video_dir, video_file): assert os.path.exists(full_path), "File not found: %s" % full_path with open(full_path, "rb") as fp: - video_tensor = torch.from_numpy(np.frombuffer(fp.read(), dtype=np.uint8)) + video_tensor = torch.frombuffer(fp.read(), dtype=torch.uint8) return full_path, video_tensor diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index a887e35e08e..d9dbc4a4f32 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -3,7 +3,6 @@ from fractions import Fraction from typing import List, Tuple -import numpy as np import torch from .._internally_replaced_utils import _get_extension_path @@ -338,7 +337,7 @@ def _read_video_from_memory( _validate_pts(audio_pts_range) if not isinstance(video_data, torch.Tensor): - video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8)) + video_data = torch.frombuffer(video_data, dtype=torch.uint8) result = torch.ops.video_reader.read_video_from_memory( video_data, @@ -378,7 +377,7 @@ def _read_video_timestamps_from_memory(video_data): is much faster than read_video(...) """ if not isinstance(video_data, torch.Tensor): - video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8)) + video_data = torch.frombuffer(video_data, dtype=torch.uint8) result = torch.ops.video_reader.read_video_from_memory( video_data, 0, # seek_frame_margin @@ -415,7 +414,7 @@ def _probe_video_from_memory(video_data): This function is torchscriptable """ if not isinstance(video_data, torch.Tensor): - video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8)) + video_data = torch.frombuffer(video_data, dtype=torch.uint8) result = torch.ops.video_reader.probe_video_from_memory(video_data) vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)