@@ -29,7 +29,10 @@ def __repr__(self) -> str:
2929 return self .__class__ .__name__ + "()"
3030
3131 def describe (self ) -> str :
32- return "The images are rescaled to ``[0.0, 1.0]``."
32+ return (
33+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
34+ "The images are rescaled to ``[0.0, 1.0]``."
35+ )
3336
3437
3538class ImageClassification (nn .Module ):
@@ -70,6 +73,7 @@ def __repr__(self) -> str:
7073
7174 def describe (self ) -> str :
7275 return (
76+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
7377 f"The images are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
7478 f"followed by a central crop of ``crop_size={ self .crop_size } ``. Finally the values are first rescaled to "
7579 f"``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
@@ -99,7 +103,6 @@ def forward(self, vid: Tensor) -> Tensor:
99103 vid = vid .unsqueeze (dim = 0 )
100104 need_squeeze = True
101105
102- vid = vid .permute (0 , 1 , 4 , 2 , 3 ) # (N, T, H, W, C) => (N, T, C, H, W)
103106 N , T , C , H , W = vid .shape
104107 vid = vid .view (- 1 , C , H , W )
105108 vid = F .resize (vid , self .resize_size , interpolation = self .interpolation )
@@ -126,9 +129,11 @@ def __repr__(self) -> str:
126129
127130 def describe (self ) -> str :
128131 return (
129- f"The video frames are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
132+ "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
133+ f"The frames are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
130134 f"followed by a central crop of ``crop_size={ self .crop_size } ``. Finally the values are first rescaled to "
131- f"``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
135+ f"``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and ``std={ self .std } ``. Finally the output "
136+ "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
132137 )
133138
134139
@@ -167,6 +172,7 @@ def __repr__(self) -> str:
167172
168173 def describe (self ) -> str :
169174 return (
175+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
170176 f"The images are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``. "
171177 f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={ self .mean } `` and "
172178 f"``std={ self .std } ``."
@@ -196,4 +202,7 @@ def __repr__(self) -> str:
196202 return self .__class__ .__name__ + "()"
197203
198204 def describe (self ) -> str :
199- return "The images are rescaled to ``[-1.0, 1.0]``."
205+ return (
206+ "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
207+ "The images are rescaled to ``[-1.0, 1.0]``."
208+ )
0 commit comments