@@ -25,6 +25,12 @@ def forward(self, img: Tensor) -> Tensor:
2525 img = F .pil_to_tensor (img )
2626 return F .convert_image_dtype (img , torch .float )
2727
28+ def __repr__ (self ) -> str :
29+ return self .__class__ .__name__ + "()"
30+
31+ def describe (self ) -> str :
32+ return "The images are rescaled to ``[0.0, 1.0]``."
33+
2834
2935class ImageClassification (nn .Module ):
3036 def __init__ (
@@ -37,21 +43,38 @@ def __init__(
3743 interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
3844 ) -> None :
3945 super ().__init__ ()
40- self ._crop_size = [crop_size ]
41- self ._size = [resize_size ]
42- self ._mean = list (mean )
43- self ._std = list (std )
44- self ._interpolation = interpolation
46+ self .crop_size = [crop_size ]
47+ self .resize_size = [resize_size ]
48+ self .mean = list (mean )
49+ self .std = list (std )
50+ self .interpolation = interpolation
4551
4652 def forward (self , img : Tensor ) -> Tensor :
47- img = F .resize (img , self ._size , interpolation = self ._interpolation )
48- img = F .center_crop (img , self ._crop_size )
53+ img = F .resize (img , self .resize_size , interpolation = self .interpolation )
54+ img = F .center_crop (img , self .crop_size )
4955 if not isinstance (img , Tensor ):
5056 img = F .pil_to_tensor (img )
5157 img = F .convert_image_dtype (img , torch .float )
52- img = F .normalize (img , mean = self ._mean , std = self ._std )
58+ img = F .normalize (img , mean = self .mean , std = self .std )
5359 return img
5460
61+ def __repr__ (self ) -> str :
62+ format_string = self .__class__ .__name__ + "("
63+ format_string += f"\n crop_size={ self .crop_size } "
64+ format_string += f"\n resize_size={ self .resize_size } "
65+ format_string += f"\n mean={ self .mean } "
66+ format_string += f"\n std={ self .std } "
67+ format_string += f"\n interpolation={ self .interpolation } "
68+ format_string += "\n )"
69+ return format_string
70+
71+ def describe (self ) -> str :
72+ return (
73+ f"The images are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
74+ f"followed by a central crop of ``crop_size={ self .crop_size } ``. Then the values are rescaled to "
75+ f"``[0.0, 1.0]`` and normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
76+ )
77+
5578
5679class VideoClassification (nn .Module ):
5780 def __init__ (
@@ -64,11 +87,11 @@ def __init__(
6487 interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
6588 ) -> None :
6689 super ().__init__ ()
67- self ._crop_size = list (crop_size )
68- self ._size = list (resize_size )
69- self ._mean = list (mean )
70- self ._std = list (std )
71- self ._interpolation = interpolation
90+ self .crop_size = list (crop_size )
91+ self .resize_size = list (resize_size )
92+ self .mean = list (mean )
93+ self .std = list (std )
94+ self .interpolation = interpolation
7295
7396 def forward (self , vid : Tensor ) -> Tensor :
7497 need_squeeze = False
@@ -79,18 +102,35 @@ def forward(self, vid: Tensor) -> Tensor:
79102 vid = vid .permute (0 , 1 , 4 , 2 , 3 ) # (N, T, H, W, C) => (N, T, C, H, W)
80103 N , T , C , H , W = vid .shape
81104 vid = vid .view (- 1 , C , H , W )
82- vid = F .resize (vid , self ._size , interpolation = self ._interpolation )
83- vid = F .center_crop (vid , self ._crop_size )
105+ vid = F .resize (vid , self .resize_size , interpolation = self .interpolation )
106+ vid = F .center_crop (vid , self .crop_size )
84107 vid = F .convert_image_dtype (vid , torch .float )
85- vid = F .normalize (vid , mean = self ._mean , std = self ._std )
86- H , W = self ._crop_size
108+ vid = F .normalize (vid , mean = self .mean , std = self .std )
109+ H , W = self .crop_size
87110 vid = vid .view (N , T , C , H , W )
88111 vid = vid .permute (0 , 2 , 1 , 3 , 4 ) # (N, T, C, H, W) => (N, C, T, H, W)
89112
90113 if need_squeeze :
91114 vid = vid .squeeze (dim = 0 )
92115 return vid
93116
117+ def __repr__ (self ) -> str :
118+ format_string = self .__class__ .__name__ + "("
119+ format_string += f"\n crop_size={ self .crop_size } "
120+ format_string += f"\n resize_size={ self .resize_size } "
121+ format_string += f"\n mean={ self .mean } "
122+ format_string += f"\n std={ self .std } "
123+ format_string += f"\n interpolation={ self .interpolation } "
124+ format_string += "\n )"
125+ return format_string
126+
127+ def describe (self ) -> str :
128+ return (
129+ f"The video frames are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``, "
130+ f"followed by a central crop of ``crop_size={ self .crop_size } ``. Then the values are rescaled to "
131+ f"``[0.0, 1.0]`` and normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
132+ )
133+
94134
95135class SemanticSegmentation (nn .Module ):
96136 def __init__ (
@@ -102,20 +142,35 @@ def __init__(
102142 interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
103143 ) -> None :
104144 super ().__init__ ()
105- self ._size = [resize_size ] if resize_size is not None else None
106- self ._mean = list (mean )
107- self ._std = list (std )
108- self ._interpolation = interpolation
145+ self .resize_size = [resize_size ] if resize_size is not None else None
146+ self .mean = list (mean )
147+ self .std = list (std )
148+ self .interpolation = interpolation
109149
110150 def forward (self , img : Tensor ) -> Tensor :
111- if isinstance (self ._size , list ):
112- img = F .resize (img , self ._size , interpolation = self ._interpolation )
151+ if isinstance (self .resize_size , list ):
152+ img = F .resize (img , self .resize_size , interpolation = self .interpolation )
113153 if not isinstance (img , Tensor ):
114154 img = F .pil_to_tensor (img )
115155 img = F .convert_image_dtype (img , torch .float )
116- img = F .normalize (img , mean = self ._mean , std = self ._std )
156+ img = F .normalize (img , mean = self .mean , std = self .std )
117157 return img
118158
159+ def __repr__ (self ) -> str :
160+ format_string = self .__class__ .__name__ + "("
161+ format_string += f"\n resize_size={ self .resize_size } "
162+ format_string += f"\n mean={ self .mean } "
163+ format_string += f"\n std={ self .std } "
164+ format_string += f"\n interpolation={ self .interpolation } "
165+ format_string += "\n )"
166+ return format_string
167+
168+ def describe (self ) -> str :
169+ return (
170+ f"The images are resized to ``resize_size={ self .resize_size } `` using ``interpolation={ self .interpolation } ``. "
171+ f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={ self .mean } `` and ``std={ self .std } ``."
172+ )
173+
119174
120175class OpticalFlow (nn .Module ):
121176 def forward (self , img1 : Tensor , img2 : Tensor ) -> Tuple [Tensor , Tensor ]:
@@ -135,3 +190,9 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
135190 img2 = img2 .contiguous ()
136191
137192 return img1 , img2
193+
194+ def __repr__ (self ) -> str :
195+ return self .__class__ .__name__ + "()"
196+
197+ def describe (self ) -> str :
198+ return "The images are rescaled to ``[-1.0, 1.0]``."
0 commit comments