@@ -62,11 +62,14 @@ class Kinetics(VisionDataset):
6262 download (bool): Download the official version of the dataset to root folder.
6363 num_workers (int): Use multiple workers for VideoClips creation
6464 num_download_workers (int): Use multiprocessing in order to speed up download.
65+ output_format (str, optional): The format of the output video tensors (before transforms).
66+ Can be either "THWC" or "TCHW" (default).
67+ Note that in most other utils and datasets, the default is actually "THWC".
6568
6669 Returns:
6770 tuple: A 3-tuple with the following entries:
6871
69- - video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor
72+ - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
7073 - audio (Tensor[K, L]): the audio frames, where `K` is the number of channels
7174 and `L` is the number of points in torch.float tensor
7275 - label (int): class of the video clip
@@ -106,6 +109,7 @@ def __init__(
106109 _audio_samples : int = 0 ,
107110 _audio_channels : int = 0 ,
108111 _legacy : bool = False ,
112+ output_format : str = "TCHW" ,
109113 ) -> None :
110114
111115 # TODO: support test
@@ -115,10 +119,12 @@ def __init__(
115119
116120 self .root = root
117121 self ._legacy = _legacy
122+
118123 if _legacy :
119124 print ("Using legacy structure" )
120125 self .split_folder = root
121126 self .split = "unknown"
127+ output_format = "THWC"
122128 if download :
123129 raise ValueError ("Cannot download the videos using legacy_structure." )
124130 else :
@@ -145,6 +151,7 @@ def __init__(
145151 _video_min_dimension = _video_min_dimension ,
146152 _audio_samples = _audio_samples ,
147153 _audio_channels = _audio_channels ,
154+ output_format = output_format ,
148155 )
149156 self .transform = transform
150157
@@ -233,9 +240,6 @@ def __len__(self) -> int:
233240
234241 def __getitem__ (self , idx : int ) -> Tuple [Tensor , Tensor , int ]:
235242 video , audio , info , video_idx = self .video_clips .get_clip (idx )
236- if not self ._legacy :
237- # [T,H,W,C] --> [T,C,H,W]
238- video = video .permute (0 , 3 , 1 , 2 )
239243 label = self .samples [video_idx ][1 ]
240244
241245 if self .transform is not None :
@@ -308,7 +312,7 @@ def __init__(
308312 warnings .warn (
309313 "The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
310314 "Please use Kinetics(..., num_classes='400') instead."
311- "Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
315+ "Note that Kinetics(..., num_classes='400') returns video in a Tensor[T, C, H, W] format."
312316 )
313317 if any (value is not None for value in (num_classes , split , download , num_download_workers )):
314318 raise RuntimeError (
0 commit comments