
Commit ee26e9c

Add preprocessing information on the Weights documentation (#6009)
* Adding `__repr__` in presets
* Adds `describe()` methods to all presets.
* Adding transform descriptions in the documentation.
* Change "preprocessing" to "inference"
1 parent c67a583 commit ee26e9c

File tree: 2 files changed (+90, -24 lines)


docs/source/conf.py

Lines changed: 5 additions & 0 deletions

@@ -366,6 +366,11 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
             lines += [".. table::", ""]
             lines += textwrap.indent(table, " " * 4).split("\n")
             lines.append("")
+            lines.append(
+                f"The inference transforms are available at ``{str(field)}.transforms`` and "
+                f"perform the following operations: {field.transforms().describe()}"
+            )
+            lines.append("")


 def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
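With this hook in place, every weights entry in the generated docs gains a sentence built from its preset's new `describe()` output. A minimal sketch of the line the hook would emit, assuming torchvision's `ResNet50_Weights.IMAGENET1K_V1` as the `field` (any weights enum member works the same way):

    # Reproduce, outside Sphinx, the docs line that inject_weight_metadata builds.
    from torchvision.models import ResNet50_Weights

    field = ResNet50_Weights.IMAGENET1K_V1
    line = (
        f"The inference transforms are available at ``{str(field)}.transforms`` and "
        f"perform the following operations: {field.transforms().describe()}"
    )
    print(line)
    # The inference transforms are available at
    # ``ResNet50_Weights.IMAGENET1K_V1.transforms`` and perform the following
    # operations: The images are resized to ``resize_size=[256]`` using ...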

torchvision/transforms/_presets.py

Lines changed: 85 additions & 24 deletions

@@ -25,6 +25,12 @@ def forward(self, img: Tensor) -> Tensor:
             img = F.pil_to_tensor(img)
         return F.convert_image_dtype(img, torch.float)

+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return "The images are rescaled to ``[0.0, 1.0]``."
+

 class ImageClassification(nn.Module):
     def __init__(
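`ObjectDetection` takes no parameters, so the two new methods are trivial. A quick sketch of what they return, importing from the private `_presets` module this diff touches:

    from torchvision.transforms._presets import ObjectDetection

    preset = ObjectDetection()
    print(repr(preset))       # ObjectDetection()
    print(preset.describe())  # The images are rescaled to ``[0.0, 1.0]``.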
@@ -37,21 +43,38 @@ def __init__(
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._crop_size = [crop_size]
-        self._size = [resize_size]
-        self._mean = list(mean)
-        self._std = list(std)
-        self._interpolation = interpolation
+        self.crop_size = [crop_size]
+        self.resize_size = [resize_size]
+        self.mean = list(mean)
+        self.std = list(std)
+        self.interpolation = interpolation

     def forward(self, img: Tensor) -> Tensor:
-        img = F.resize(img, self._size, interpolation=self._interpolation)
-        img = F.center_crop(img, self._crop_size)
+        img = F.resize(img, self.resize_size, interpolation=self.interpolation)
+        img = F.center_crop(img, self.crop_size)
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
-        img = F.normalize(img, mean=self._mean, std=self._std)
+        img = F.normalize(img, mean=self.mean, std=self.std)
         return img

+    def __repr__(self) -> str:
+        format_string = self.__class__.__name__ + "("
+        format_string += f"\n    crop_size={self.crop_size}"
+        format_string += f"\n    resize_size={self.resize_size}"
+        format_string += f"\n    mean={self.mean}"
+        format_string += f"\n    std={self.std}"
+        format_string += f"\n    interpolation={self.interpolation}"
+        format_string += "\n)"
+        return format_string
+
+    def describe(self) -> str:
+        return (
+            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
+            f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
+            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+        )
+

 class VideoClassification(nn.Module):
     def __init__(
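The `ImageClassification` preset now prints its parameters. A sketch with `crop_size=224` and the signature's defaults (resize_size=256 and the ImageNet mean/std, assumed unchanged by this diff):

    from torchvision.transforms._presets import ImageClassification

    preset = ImageClassification(crop_size=224)
    print(preset)
    # ImageClassification(
    #     crop_size=[224]
    #     resize_size=[256]
    #     mean=[0.485, 0.456, 0.406]
    #     std=[0.229, 0.224, 0.225]
    #     interpolation=InterpolationMode.BILINEAR
    # )
    print(preset.describe())
    # The images are resized to ``resize_size=[256]`` using
    # ``interpolation=InterpolationMode.BILINEAR``, followed by a central crop
    # of ``crop_size=[224]``. Then the values are rescaled to ``[0.0, 1.0]``
    # and normalized using ``mean=[0.485, 0.456, 0.406]`` and
    # ``std=[0.229, 0.224, 0.225]``.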
@@ -64,11 +87,11 @@ def __init__(
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._crop_size = list(crop_size)
-        self._size = list(resize_size)
-        self._mean = list(mean)
-        self._std = list(std)
-        self._interpolation = interpolation
+        self.crop_size = list(crop_size)
+        self.resize_size = list(resize_size)
+        self.mean = list(mean)
+        self.std = list(std)
+        self.interpolation = interpolation

     def forward(self, vid: Tensor) -> Tensor:
         need_squeeze = False
@@ -79,18 +102,35 @@ def forward(self, vid: Tensor) -> Tensor:
         vid = vid.permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) => (N, T, C, H, W)
         N, T, C, H, W = vid.shape
         vid = vid.view(-1, C, H, W)
-        vid = F.resize(vid, self._size, interpolation=self._interpolation)
-        vid = F.center_crop(vid, self._crop_size)
+        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
+        vid = F.center_crop(vid, self.crop_size)
         vid = F.convert_image_dtype(vid, torch.float)
-        vid = F.normalize(vid, mean=self._mean, std=self._std)
-        H, W = self._crop_size
+        vid = F.normalize(vid, mean=self.mean, std=self.std)
+        H, W = self.crop_size
         vid = vid.view(N, T, C, H, W)
         vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

         if need_squeeze:
             vid = vid.squeeze(dim=0)
         return vid

+    def __repr__(self) -> str:
+        format_string = self.__class__.__name__ + "("
+        format_string += f"\n    crop_size={self.crop_size}"
+        format_string += f"\n    resize_size={self.resize_size}"
+        format_string += f"\n    mean={self.mean}"
+        format_string += f"\n    std={self.std}"
+        format_string += f"\n    interpolation={self.interpolation}"
+        format_string += "\n)"
+        return format_string
+
+    def describe(self) -> str:
+        return (
+            f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
+            f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
+            f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+        )
+

 class SemanticSegmentation(nn.Module):
     def __init__(
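The same pattern applies to video presets. A sketch using the (128, 171) resize and (112, 112) crop that the Kinetics-400 video weights typically use; these values and the Kinetics mean/std defaults are quoted here as assumptions for illustration:

    from torchvision.transforms._presets import VideoClassification

    preset = VideoClassification(crop_size=(112, 112), resize_size=(128, 171))
    print(preset.describe())
    # The video frames are resized to ``resize_size=[128, 171]`` using
    # ``interpolation=InterpolationMode.BILINEAR``, followed by a central crop
    # of ``crop_size=[112, 112]``. Then the values are rescaled to
    # ``[0.0, 1.0]`` and normalized using
    # ``mean=[0.43216, 0.394666, 0.37645]`` and
    # ``std=[0.22803, 0.22145, 0.216989]``.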
@@ -102,20 +142,35 @@ def __init__(
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._size = [resize_size] if resize_size is not None else None
-        self._mean = list(mean)
-        self._std = list(std)
-        self._interpolation = interpolation
+        self.resize_size = [resize_size] if resize_size is not None else None
+        self.mean = list(mean)
+        self.std = list(std)
+        self.interpolation = interpolation

     def forward(self, img: Tensor) -> Tensor:
-        if isinstance(self._size, list):
-            img = F.resize(img, self._size, interpolation=self._interpolation)
+        if isinstance(self.resize_size, list):
+            img = F.resize(img, self.resize_size, interpolation=self.interpolation)
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
-        img = F.normalize(img, mean=self._mean, std=self._std)
+        img = F.normalize(img, mean=self.mean, std=self.std)
         return img

+    def __repr__(self) -> str:
+        format_string = self.__class__.__name__ + "("
+        format_string += f"\n    resize_size={self.resize_size}"
+        format_string += f"\n    mean={self.mean}"
+        format_string += f"\n    std={self.std}"
+        format_string += f"\n    interpolation={self.interpolation}"
+        format_string += "\n)"
+        return format_string
+
+    def describe(self) -> str:
+        return (
+            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
+            f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
+        )
+

 class OpticalFlow(nn.Module):
     def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
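For segmentation, `resize_size` may be None, in which case `forward()` skips the resize but `describe()` would still interpolate the stored value and read ``resize_size=None``, as follows directly from the f-string above. A sketch with the common 520-pixel resize (ImageNet mean/std defaults assumed):

    from torchvision.transforms._presets import SemanticSegmentation

    preset = SemanticSegmentation(resize_size=520)
    print(preset.describe())
    # The images are resized to ``resize_size=[520]`` using
    # ``interpolation=InterpolationMode.BILINEAR``. Then the values are
    # rescaled to ``[0.0, 1.0]`` and normalized using
    # ``mean=[0.485, 0.456, 0.406]`` and ``std=[0.229, 0.224, 0.225]``.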
@@ -135,3 +190,9 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
         img2 = img2.contiguous()

         return img1, img2
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "()"
+
+    def describe(self) -> str:
+        return "The images are rescaled to ``[-1.0, 1.0]``."
