|
import torch
import torch.nn as nn
3 | 3 |
|
4 | 4 |
|
5 | | -def crop(vid, i, j, h, w): |
6 | | - return vid[..., i:(i + h), j:(j + w)] |
| 5 | +class ConvertBHWCtoBCHW(nn.Module): |
| 6 | + """Convert tensor from (B, H, W, C) to (B, C, H, W) |
| 7 | + """ |
7 | 8 |
|
| 9 | + def forward(self, vid: torch.Tensor) -> torch.Tensor: |
| 10 | + return vid.permute(0, 3, 1, 2) |
8 | 11 |
|
9 | | -def center_crop(vid, output_size): |
10 | | - h, w = vid.shape[-2:] |
11 | | - th, tw = output_size |
12 | 12 |
|
13 | | - i = int(round((h - th) / 2.)) |
14 | | - j = int(round((w - tw) / 2.)) |
15 | | - return crop(vid, i, j, th, tw) |
| 13 | +class ConvertBCHWtoCBHW(nn.Module): |
| 14 | + """Convert tensor from (B, C, H, W) to (C, B, H, W) |
| 15 | + """ |
16 | 16 |
|
17 | | - |
18 | | -def hflip(vid): |
19 | | - return vid.flip(dims=(-1,)) |
20 | | - |
21 | | - |
22 | | -# NOTE: for those functions, which generally expect mini-batches, we keep them |
23 | | -# as non-minibatch so that they are applied as if they were 4d (thus image). |
24 | | -# this way, we only apply the transformation in the spatial domain |
25 | | -def resize(vid, size, interpolation='bilinear'): |
26 | | - # NOTE: using bilinear interpolation because we don't work on minibatches |
27 | | - # at this level |
28 | | - scale = None |
29 | | - if isinstance(size, int): |
30 | | - scale = float(size) / min(vid.shape[-2:]) |
31 | | - size = None |
32 | | - return torch.nn.functional.interpolate( |
33 | | - vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False) |
34 | | - |
35 | | - |
36 | | -def pad(vid, padding, fill=0, padding_mode="constant"): |
37 | | - # NOTE: don't want to pad on temporal dimension, so let as non-batch |
38 | | - # (4d) before padding. This works as expected |
39 | | - return torch.nn.functional.pad(vid, padding, value=fill, mode=padding_mode) |
40 | | - |
41 | | - |
42 | | -def to_normalized_float_tensor(vid): |
43 | | - return vid.permute(3, 0, 1, 2).to(torch.float32) / 255 |
44 | | - |
45 | | - |
46 | | -def normalize(vid, mean, std): |
47 | | - shape = (-1,) + (1,) * (vid.dim() - 1) |
48 | | - mean = torch.as_tensor(mean).reshape(shape) |
49 | | - std = torch.as_tensor(std).reshape(shape) |
50 | | - return (vid - mean) / std |
51 | | - |
52 | | - |
53 | | -# Class interface |
54 | | - |
55 | | -class RandomCrop(object): |
56 | | - def __init__(self, size): |
57 | | - self.size = size |
58 | | - |
59 | | - @staticmethod |
60 | | - def get_params(vid, output_size): |
61 | | - """Get parameters for ``crop`` for a random crop. |
62 | | - """ |
63 | | - h, w = vid.shape[-2:] |
64 | | - th, tw = output_size |
65 | | - if w == tw and h == th: |
66 | | - return 0, 0, h, w |
67 | | - i = random.randint(0, h - th) |
68 | | - j = random.randint(0, w - tw) |
69 | | - return i, j, th, tw |
70 | | - |
71 | | - def __call__(self, vid): |
72 | | - i, j, h, w = self.get_params(vid, self.size) |
73 | | - return crop(vid, i, j, h, w) |
74 | | - |
75 | | - |
76 | | -class CenterCrop(object): |
77 | | - def __init__(self, size): |
78 | | - self.size = size |
79 | | - |
80 | | - def __call__(self, vid): |
81 | | - return center_crop(vid, self.size) |
82 | | - |
83 | | - |
84 | | -class Resize(object): |
85 | | - def __init__(self, size): |
86 | | - self.size = size |
87 | | - |
88 | | - def __call__(self, vid): |
89 | | - return resize(vid, self.size) |
90 | | - |
91 | | - |
92 | | -class ToFloatTensorInZeroOne(object): |
93 | | - def __call__(self, vid): |
94 | | - return to_normalized_float_tensor(vid) |
95 | | - |
96 | | - |
97 | | -class Normalize(object): |
98 | | - def __init__(self, mean, std): |
99 | | - self.mean = mean |
100 | | - self.std = std |
101 | | - |
102 | | - def __call__(self, vid): |
103 | | - return normalize(vid, self.mean, self.std) |
104 | | - |
105 | | - |
106 | | -class RandomHorizontalFlip(object): |
107 | | - def __init__(self, p=0.5): |
108 | | - self.p = p |
109 | | - |
110 | | - def __call__(self, vid): |
111 | | - if random.random() < self.p: |
112 | | - return hflip(vid) |
113 | | - return vid |
114 | | - |
115 | | - |
116 | | -class Pad(object): |
117 | | - def __init__(self, padding, fill=0): |
118 | | - self.padding = padding |
119 | | - self.fill = fill |
120 | | - |
121 | | - def __call__(self, vid): |
122 | | - return pad(vid, self.padding, self.fill) |
| 17 | + def forward(self, vid: torch.Tensor) -> torch.Tensor: |
| 18 | + return vid.permute(1, 0, 2, 3) |
0 commit comments