diff --git a/test/test_transforms.py b/test/test_transforms.py
index 952d411d609..5d0275f946f 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -86,7 +86,7 @@ def test_scale(self):
         owidth = random.randint(5, 12) * 2
         result = transforms.Compose([
             transforms.ToPILImage(),
-            transforms.Scale((owidth, oheight)),
+            transforms.Scale((oheight, owidth)),
             transforms.ToTensor(),
         ])(img)
         assert result.size(1) == oheight
@@ -94,7 +94,7 @@ def test_scale(self):
 
         result = transforms.Compose([
             transforms.ToPILImage(),
-            transforms.Scale([owidth, oheight]),
+            transforms.Scale([oheight, owidth]),
             transforms.ToTensor(),
         ])(img)
         assert result.size(1) == oheight
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index 85f0a82c8c8..da58aa12b9a 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -165,7 +165,7 @@ class Scale(object):
 
     Args:
         size (sequence or int): Desired output size. If size is a sequence like
-            (w, h), output size will be matched to this. If size is an int,
+            (h, w), output size will be matched to this. If size is an int,
             smaller edge of the image will be matched to this number.
             i.e, if height > width, then image will be rescaled to
             (size * height / width, size)
@@ -199,7 +199,7 @@ def __call__(self, img):
                 ow = int(self.size * w / h)
                 return img.resize((ow, oh), self.interpolation)
         else:
-            return img.resize(self.size, self.interpolation)
+            return img.resize(self.size[::-1], self.interpolation)
 
 
 class CenterCrop(object):
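
After this patch, Scale interprets a sequence size as (h, w), matching the C x H x W layout of torch tensors, while PIL's Image.resize still expects (width, height); the self.size[::-1] reversal bridges the two conventions. A minimal usage sketch, assuming a torchvision checkout with this patch applied and PIL available (the concrete sizes are illustrative only):

from PIL import Image
from torchvision import transforms

# PIL constructs and reports sizes as (width, height).
img = Image.new('RGB', (640, 480))

# With the patch, a sequence passed to Scale is read as (h, w).
scaled = transforms.Scale((240, 320))(img)
assert scaled.size == (320, 240)  # PIL still reports (width, height)

# An int still matches the smaller edge and keeps the aspect ratio.
assert transforms.Scale(240)(img).size == (320, 240)

# ToTensor yields C x H x W, which is why the updated test checks
# size(1) against the height and size(2) against the width.
tensor = transforms.ToTensor()(scaled)
assert tensor.size(1) == 240 and tensor.size(2) == 320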