Skip to content

Commit b9a7678

Browse files
committed
Document __call__ of TensorTransforms consistently
Report their input/output shapes.
1 parent fa9e946 commit b9a7678

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

torchvision/transforms/transforms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ def __init__(self, mean, std, inplace=False):
165165
self.inplace = inplace
166166

167167
def __call__(self, tensor):
168-
"""Apply the transform to the given tensor and return the transformed tensor."""
168+
"""Apply the transform to the given tensor (C x H x W) and return the
169+
transformed tensor (C x H x W).
170+
"""
169171
return F.normalize(tensor, self.mean, self.std, self.inplace)
170172

171173
def __repr__(self):
@@ -784,6 +786,7 @@ def __repr__(self):
784786
class LinearTransformation(object):
785787
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
786788
offline.
789+
787790
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
788791
subtract mean_vector from it which is then followed by computing the dot
789792
product with the transformation matrix and then reshaping the tensor to its
@@ -813,7 +816,9 @@ def __init__(self, transformation_matrix, mean_vector):
813816
self.mean_vector = mean_vector
814817

815818
def __call__(self, tensor):
816-
"""Apply the transform to the given tensor and return the transformed tensor."""
819+
"""Apply the transform to the given tensor (C x H x W) and return the
820+
transformed tensor (C x H x W).
821+
"""
817822
if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
818823
raise ValueError("tensor and transformation matrix have incompatible shape." +
819824
"[{} x {} x {}] != ".format(*tensor.size()) +
@@ -1276,7 +1281,9 @@ def get_params(img, scale, ratio, value=0):
12761281
return 0, 0, img_h, img_w, img
12771282

12781283
def __call__(self, img):
1279-
"""Apply the transform to the given tensor and return the transformed tensor."""
1284+
"""Apply the transform to the given tensor (C x H x W) and return the
1285+
transformed tensor (C x H x W).
1286+
"""
12801287
if random.uniform(0, 1) < self.p:
12811288
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
12821289
return F.erase(img, x, y, h, w, v, self.inplace)

0 commit comments

Comments
 (0)