Skip to content

Commit d467afa

Browse files
YosuaMichael and facebook-github-bot
authored and committed
[fbsync] Add output_format to video datasets and readers (#6061)
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]> Reviewed By: NicolasHug Differential Revision: D36760918 fbshipit-source-id: bfaa11b43cb0ebffb41b0e24fef1b6b65b6deef4
1 parent d9b227b commit d467afa

File tree

5 files changed

+40
-8
lines changed

5 files changed

+40
-8
lines changed

torchvision/datasets/hmdb51.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ class HMDB51(VisionDataset):
3737
otherwise from the ``test`` split.
3838
transform (callable, optional): A function/transform that takes in a TxHxWxC video
3939
and returns a transformed version.
40+
output_format (str, optional): The format of the output video tensors (before transforms).
41+
Can be either "THWC" (default) or "TCHW".
4042
4143
Returns:
4244
tuple: A 3-tuple with the following entries:
4345
44-
- video (Tensor[T, H, W, C]): The `T` video frames
46+
- video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
4547
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
4648
and `L` is the number of points
4749
- label (int): class of the video clip
@@ -71,6 +73,7 @@ def __init__(
7173
_video_height: int = 0,
7274
_video_min_dimension: int = 0,
7375
_audio_samples: int = 0,
76+
output_format: str = "THWC",
7477
) -> None:
7578
super().__init__(root)
7679
if fold not in (1, 2, 3):
@@ -96,6 +99,7 @@ def __init__(
9699
_video_height=_video_height,
97100
_video_min_dimension=_video_min_dimension,
98101
_audio_samples=_audio_samples,
102+
output_format=output_format,
99103
)
100104
# we bookkeep the full version of video clips because we want to be able
101105
# to return the meta data of full version rather than the subset version of

torchvision/datasets/kinetics.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,14 @@ class Kinetics(VisionDataset):
6262
download (bool): Download the official version of the dataset to root folder.
6363
num_workers (int): Use multiple workers for VideoClips creation
6464
num_download_workers (int): Use multiprocessing in order to speed up download.
65+
output_format (str, optional): The format of the output video tensors (before transforms).
66+
Can be either "THWC" or "TCHW" (default).
67+
Note that in most other utils and datasets, the default is actually "THWC".
6568
6669
Returns:
6770
tuple: A 3-tuple with the following entries:
6871
69-
- video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor
72+
- video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
7073
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
7174
and `L` is the number of points in torch.float tensor
7275
- label (int): class of the video clip
@@ -106,6 +109,7 @@ def __init__(
106109
_audio_samples: int = 0,
107110
_audio_channels: int = 0,
108111
_legacy: bool = False,
112+
output_format: str = "TCHW",
109113
) -> None:
110114

111115
# TODO: support test
@@ -115,10 +119,12 @@ def __init__(
115119

116120
self.root = root
117121
self._legacy = _legacy
122+
118123
if _legacy:
119124
print("Using legacy structure")
120125
self.split_folder = root
121126
self.split = "unknown"
127+
output_format = "THWC"
122128
if download:
123129
raise ValueError("Cannot download the videos using legacy_structure.")
124130
else:
@@ -145,6 +151,7 @@ def __init__(
145151
_video_min_dimension=_video_min_dimension,
146152
_audio_samples=_audio_samples,
147153
_audio_channels=_audio_channels,
154+
output_format=output_format,
148155
)
149156
self.transform = transform
150157

@@ -233,9 +240,6 @@ def __len__(self) -> int:
233240

234241
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
235242
video, audio, info, video_idx = self.video_clips.get_clip(idx)
236-
if not self._legacy:
237-
# [T,H,W,C] --> [T,C,H,W]
238-
video = video.permute(0, 3, 1, 2)
239243
label = self.samples[video_idx][1]
240244

241245
if self.transform is not None:
@@ -308,7 +312,7 @@ def __init__(
308312
warnings.warn(
309313
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
310314
"Please use Kinetics(..., num_classes='400') instead."
311-
"Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
315+
"Note that Kinetics(..., num_classes='400') returns video in a Tensor[T, C, H, W] format."
312316
)
313317
if any(value is not None for value in (num_classes, split, download, num_download_workers)):
314318
raise RuntimeError(

torchvision/datasets/ucf101.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ class UCF101(VisionDataset):
3838
otherwise from the ``test`` split.
3939
transform (callable, optional): A function/transform that takes in a TxHxWxC video
4040
and returns a transformed version.
41+
output_format (str, optional): The format of the output video tensors (before transforms).
42+
Can be either "THWC" (default) or "TCHW".
4143
4244
Returns:
4345
tuple: A 3-tuple with the following entries:
4446
45-
- video (Tensor[T, H, W, C]): the `T` video frames
47+
- video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
4648
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
4749
and `L` is the number of points
4850
- label (int): class of the video clip
@@ -64,6 +66,7 @@ def __init__(
6466
_video_height: int = 0,
6567
_video_min_dimension: int = 0,
6668
_audio_samples: int = 0,
69+
output_format: str = "THWC",
6770
) -> None:
6871
super().__init__(root)
6972
if not 1 <= fold <= 3:
@@ -87,6 +90,7 @@ def __init__(
8790
_video_height=_video_height,
8891
_video_min_dimension=_video_min_dimension,
8992
_audio_samples=_audio_samples,
93+
output_format=output_format,
9094
)
9195
# we bookkeep the full version of video clips because we want to be able
9296
# to return the meta data of full version rather than the subset version of

torchvision/datasets/video_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class VideoClips:
9999
on the resampled video
100100
num_workers (int): how many subprocesses to use for data loading.
101101
0 means that the data will be loaded in the main process. (default: 0)
102+
output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
102103
"""
103104

104105
def __init__(
@@ -115,6 +116,7 @@ def __init__(
115116
_video_max_dimension: int = 0,
116117
_audio_samples: int = 0,
117118
_audio_channels: int = 0,
119+
output_format: str = "THWC",
118120
) -> None:
119121

120122
self.video_paths = video_paths
@@ -127,6 +129,9 @@ def __init__(
127129
self._video_max_dimension = _video_max_dimension
128130
self._audio_samples = _audio_samples
129131
self._audio_channels = _audio_channels
132+
self.output_format = output_format.upper()
133+
if self.output_format not in ("THWC", "TCHW"):
134+
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
130135

131136
if _precomputed_metadata is None:
132137
self._compute_frame_pts()
@@ -366,6 +371,11 @@ def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]
366371
video = video[resampling_idx]
367372
info["video_fps"] = self.frame_rate
368373
assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
374+
375+
if self.output_format == "TCHW":
376+
# [T,H,W,C] --> [T,C,H,W]
377+
video = video.permute(0, 3, 1, 2)
378+
369379
return video, audio, info, video_idx
370380

371381
def __getstate__(self) -> Dict[str, Any]:

torchvision/io/video.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def read_video(
239239
start_pts: Union[float, Fraction] = 0,
240240
end_pts: Optional[Union[float, Fraction]] = None,
241241
pts_unit: str = "pts",
242+
output_format: str = "THWC",
242243
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
243244
"""
244245
Reads a video from a file, returning both the video frames as well as
@@ -252,15 +253,20 @@ def read_video(
252253
The end presentation time
253254
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
254255
either 'pts' or 'sec'. Defaults to 'pts'.
256+
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
255257
256258
Returns:
257-
vframes (Tensor[T, H, W, C]): the `T` video frames
259+
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
258260
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
259261
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
260262
"""
261263
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
262264
_log_api_usage_once(read_video)
263265

266+
output_format = output_format.upper()
267+
if output_format not in ("THWC", "TCHW"):
268+
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
269+
264270
from torchvision import get_video_backend
265271

266272
if not os.path.exists(filename):
@@ -334,6 +340,10 @@ def read_video(
334340
else:
335341
aframes = torch.empty((1, 0), dtype=torch.float32)
336342

343+
if output_format == "TCHW":
344+
# [T,H,W,C] --> [T,C,H,W]
345+
vframes = vframes.permute(0, 3, 1, 2)
346+
337347
return vframes, aframes, info
338348

339349

0 commit comments

Comments
 (0)