Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 79c4194

Browse files
stephenyan1231 authored and facebook-github-bot committed
new video data augmentation transform (#53)
Summary: Pull Request resolved: #53 - We add video transforms in TorchVision: pytorch/vision#1306 - In ClassyVision, we add default transforms for the training / test stages. Alternatively, users can explicitly provide a transform config in the JSON config input. See an example in the unit test. - Video data transforms support the audio modality in the video dataset. - Compared with image transforms, which only return a torch.Tensor, video transforms return a dict where each key is a modality name (e.g. "video", "audio") and each value is a torch.Tensor holding that modality's data. Reviewed By: taylorgordon20 Differential Revision: D16999453 fbshipit-source-id: 112b66a3965cba4201bbb12c99f3fdd2f1fce86f
1 parent 1f8755f commit 79c4194

File tree

7 files changed

+345
-5
lines changed

7 files changed

+345
-5
lines changed

.circleci/config.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ jobs:
1616
# Download and cache dependencies
1717
- restore_cache:
1818
keys:
19-
- v2-dependencies-{{ checksum "requirements.txt" }}
19+
- v3-dependencies-{{ checksum "requirements.txt" }}
2020
# fallback to using the latest cache if no exact match is found
21-
- v2-dependencies-
21+
- v3-dependencies-
2222

2323
- run:
2424
name: install dependencies
@@ -31,7 +31,7 @@ jobs:
3131
- save_cache:
3232
paths:
3333
- ./venv
34-
key: v2-dependencies-{{ checksum "requirements.txt" }}
34+
key: v3-dependencies-{{ checksum "requirements.txt" }}
3535

3636
- run:
3737
name: run tests

classy_vision/dataset/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from .dataset import Dataset
99
from .list_dataset import ListDataset
1010
from .random_image_datasets import RandomImageBinaryClassDataset, RandomImageDataset
11+
from .random_video_datasets import RandomVideoDataset
1112
from .resample_dataset import ResampleDataset
1213
from .shuffle_dataset import ShuffleDataset
1314
from .transform_dataset import TransformDataset
1415
from .wrap_dataset import WrapDataset
16+
from .wrap_torchvision_video_dataset import WrapTorchVisionVideoDataset
1517

1618

1719
# TODO: Fix this:
@@ -23,8 +25,10 @@
2325
"ListDataset",
2426
"RandomImageBinaryClassDataset",
2527
"RandomImageDataset",
28+
"RandomVideoDataset",
2629
"ResampleDataset",
2730
"ShuffleDataset",
2831
"TransformDataset",
2932
"WrapDataset",
33+
"WrapTorchVisionVideoDataset",
3034
]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
9+
from ...generic.util import torch_seed
10+
from .dataset import Dataset
11+
12+
13+
class RandomVideoDataset(Dataset):
    """Synthetic dataset that generates random video/audio samples on the fly.

    Every sample is a dict of the form
    ``{"input": {"video": ..., "audio": ...}, "target": ...}`` where the
    video is a random uint8 tensor of shape
    ``(frames_per_clip, video_height, video_width, 3)``, the audio is a
    random float tensor of shape ``(audio_samples, 1)`` and the target is an
    integer drawn from ``[0, num_classes)``.

    Generation happens inside ``torch_seed`` blocks keyed on ``seed`` and the
    sample index (presumably making samples reproducible per index; confirm
    against ``torch_seed``).
    """

    def __init__(
        self,
        num_classes,
        split,
        num_samples,
        frames_per_clip,
        video_width,
        video_height,
        audio_samples,
        clips_per_video,
        seed=10,
    ):
        # label space and dataset split
        self.num_classes = num_classes
        self.split = split
        # dataset length and random seed
        self.num_samples = num_samples
        self.seed = seed
        # video clip geometry; clips always carry 3 channels
        self.video_channels = 3
        self.frames_per_clip = frames_per_clip
        self.video_height = video_height
        self.video_width = video_width
        # audio length and number of clips taken from each video
        self.audio_samples = audio_samples
        self.clips_per_video = clips_per_video

    def __getitem__(self, idx):
        """Return the random sample dict for index ``idx``."""
        # Training is assumed to sample a single clip per video, so each index
        # gets its own label seed. For any other split, all clips of the same
        # video must share one target label, so the label seed is keyed on the
        # video index instead of the clip index.
        if self.split == "train":
            label_seed_offset = idx
        else:
            label_seed_offset = idx // self.clips_per_video

        with torch_seed(self.seed + label_seed_offset):
            label = torch.randint(0, self.num_classes, (1,)).item()

        clip_shape = (
            self.frames_per_clip,
            self.video_height,
            self.video_width,
            self.video_channels,
        )
        # Draw the video before the audio so the random stream is consumed in
        # a fixed order.
        with torch_seed(self.seed + idx):
            video = torch.randint(0, 256, clip_shape, dtype=torch.uint8)
            audio = torch.rand((self.audio_samples, 1), dtype=torch.float)

        return {"input": {"video": video, "audio": audio}, "target": label}

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .dataset import Dataset
8+
9+
10+
class WrapTorchVisionVideoDataset(Dataset):
    """
    Adapter exposing a TorchVision video dataset through our core dataset
    interface. Samples may carry both video and audio data.
    """

    def __init__(self, dataset):
        """Store ``dataset``, which must be a ``torch.utils.data.Dataset``."""
        from torch.utils.data import Dataset as TorchDataset

        assert isinstance(dataset, TorchDataset)
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, idx):
        """Repack the wrapped ``(video, audio, target)`` tuple into a dict."""
        video, audio, target = self.dataset[idx]
        sample_input = {"video": video, "audio": audio}
        return {"input": sample_input, "target": target}

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def get_classy_state(self):
        """Return a minimal state dict; Pytorch datasets don't have state."""
        # The wrapper type is recorded only to help debug saved states.
        debug_state = {"dataset_type": type(self)}
        return {"state": debug_state}

classy_vision/dataset/transforms/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Callable, Dict, List
1010

1111
import torchvision.transforms as transforms
12+
import torchvision.transforms._transforms_video as transforms_video
1213
from classy_vision.generic.registry_utils import import_all_modules
1314

1415
from .classy_transform import ClassyTransform
@@ -34,11 +35,14 @@ def build_transform(transform_config: Dict[str, Any]) -> Callable:
3435
if name in TRANSFORM_REGISTRY:
3536
return TRANSFORM_REGISTRY[name].from_config(transform_args)
3637
# the name should be available in torchvision.transforms
37-
assert hasattr(transforms, name), (
38+
assert hasattr(transforms, name) or hasattr(transforms_video, name), (
3839
f"{name} isn't a registered tranform"
3940
", nor is it available in torchvision.transforms"
4041
)
41-
return getattr(transforms, name)(**transform_args)
42+
if hasattr(transforms, name):
43+
return getattr(transforms, name)(**transform_args)
44+
else:
45+
return getattr(transforms_video, name)(**transform_args)
4246

4347

4448
def build_transforms(transforms_config: List[Dict[str, Any]]) -> Callable:
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Callable, Dict, List, Optional
8+
9+
import torch
10+
import torchvision.transforms as transforms
11+
import torchvision.transforms._transforms_video as transforms_video
12+
13+
from . import ClassyTransform, build_transforms, register_transform
14+
from .util import FieldTransform, ImagenetConstants
15+
16+
17+
class VideoConstants:
    """Default normalization statistics and crop size for video transforms.

    The mean/std are deliberately the ones used for image classification so
    that parameters of 2D convolutions in an image model can be inflated
    into 3D convolutions in a video model.
    """

    MEAN, STD = ImagenetConstants.MEAN, ImagenetConstants.STD
    CROP_SIZE = 112
25+
26+
27+
@register_transform("video_default_augment")
class VideoDefaultAugmentTransform(ClassyTransform):
    """Default video transform with augmentation.

    Chains ``ToTensorVideo``, ``RandomResizedCropVideo(crop_size)``,
    ``RandomHorizontalFlipVideo`` and ``NormalizeVideo(mean, std)``, in that
    order. ``ClassyVideoGenericTransform`` uses it as the default video
    transform for the "train" split.
    """

    def __init__(
        self,
        crop_size: int = VideoConstants.CROP_SIZE,
        mean: List[float] = VideoConstants.MEAN,
        std: List[float] = VideoConstants.STD,
    ):
        pipeline = [
            transforms_video.ToTensorVideo(),
            transforms_video.RandomResizedCropVideo(crop_size),
            transforms_video.RandomHorizontalFlipVideo(),
            transforms_video.NormalizeVideo(mean=mean, std=std),
        ]
        self._transform = transforms.Compose(pipeline)

    def __call__(self, video):
        """Apply the composed transform to ``video`` and return the result."""
        return self._transform(video)
46+
47+
48+
@register_transform("video_default_no_augment")
class VideoDefaultNoAugmentTransform(ClassyTransform):
    """Default video transform without augmentation.

    Chains ``ToTensorVideo`` and ``NormalizeVideo(mean, std)`` only.
    """

    def __init__(
        self,
        mean: List[float] = VideoConstants.MEAN,
        std: List[float] = VideoConstants.STD,
    ):
        # At testing stage, central cropping is not used because we conduct
        # fully convolutional-style testing, so no crop appears here.
        pipeline = [
            transforms_video.ToTensorVideo(),
            transforms_video.NormalizeVideo(mean=mean, std=std),
        ]
        self._transform = transforms.Compose(pipeline)

    def __call__(self, video):
        """Apply the composed transform to ``video`` and return the result."""
        return self._transform(video)
66+
67+
68+
@register_transform("dummy_audio_transform")
class DummyAudioTransform(ClassyTransform):
    """
    Placeholder audio transform that discards its input and always returns an
    empty float tensor of shape (0, 1).

    Useful when the real audio is a raw waveform whose number of samples
    varies between clips, which makes minibatch assembling impossible.
    """

    def __init__(self):
        # Nothing to configure.
        pass

    def __call__(self, _audio):
        """Ignore ``_audio`` and return an empty (0, 1) float tensor."""
        empty_shape = (0, 1)
        return torch.zeros(empty_shape, dtype=torch.float)
81+
82+
83+
class ClassyVideoGenericTransform(object):
    """Apply per-modality transforms to a dict of video data.

    Holds one transform per modality ("video" and "audio"). Defaults are
    chosen from ``split``; entries in ``config`` replace the default of the
    corresponding modality with the result of ``build_transforms``.
    """

    def __init__(
        self,
        config: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        split: str = "train",
    ):
        """
        Args:
            config: optional mapping from modality name ("video" / "audio")
                to a list of transform configs for that modality.
            split: dataset split; "train" selects the augmenting default
                video transform, anything else the non-augmenting one.
        """
        if split == "train":
            default_video_transform = VideoDefaultAugmentTransform()
        else:
            default_video_transform = VideoDefaultNoAugmentTransform()

        self.transforms = {
            "video": default_video_transform,
            "audio": DummyAudioTransform(),
        }

        if config is None:
            return

        # User-provided configs override the defaults, one modality at a time.
        for mode, modal_config in config.items():
            assert mode in ["video", "audio"], (
                "unknown video data modality %s" % mode
            )
            self.transforms[mode] = build_transforms(modal_config)

    def __call__(self, video):
        """Transform each known modality of ``video`` in place and return it.

        Modalities without a registered transform are left untouched.
        """
        assert isinstance(video, dict), "video data is expected be a dict"
        # Only existing keys are reassigned, so the dict never changes size
        # while it is being iterated.
        for mode in video:
            if mode not in self.transforms:
                continue
            video[mode] = self.transforms[mode](video[mode])
        return video
108+
109+
110+
def build_video_field_transform_default(
    config: Optional[Dict[str, List[Dict[str, Any]]]],
    split: str = "train",
    key: str = "input",
) -> Callable:
    """
    Build a FieldTransform that applies the generic video transform to the
    field stored under ``key``.

    Args:
        config: per-modality transform configs forwarded to
            ClassyVideoGenericTransform (may be None).
        split: dataset split forwarded to ClassyVideoGenericTransform.
        key: name of the field the transform is applied to.
    """
    return FieldTransform(ClassyVideoGenericTransform(config, split), key=key)

0 commit comments

Comments
 (0)