From 1d44f2995168e637b9361dcaff0ca4760a7f2912 Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 20:01:29 +0530 Subject: [PATCH 1/8] modified code to accept list of pil images if input is a list of PIL images it returns a list of transformed imgs else it retains its old behaviour. if params need to be computed for every img then the params are computed based on the first img of the list. this change was made to ensure that a set of img have the same random transforms applied to them, for example in the image segmentation. --- torchvision/transforms/transforms.py | 162 ++++++++++++++++++++------- 1 file changed, 119 insertions(+), 43 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a640ea403f5..45660ae17ba 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -82,7 +82,10 @@ def __call__(self, pic): Returns: Tensor: Converted image. """ - return F.to_tensor(pic) + if not isinstance(pic,list) : + return F.to_tensor(pic) + else : + return [F.to_tensor(p) for p in pic] def __repr__(self): return self.__class__.__name__ + '()' @@ -116,7 +119,10 @@ def __call__(self, pic): PIL Image: Image converted to PIL Image. """ - return F.to_pil_image(pic, self.mode) + if not isinstance(pic,list) : + return F.to_pil_image(pic, self.mode) + else : + return [F.to_pil_image(p, self.mode) for p in pic] def __repr__(self): format_string = self.__class__.__name__ + '(' @@ -184,7 +190,10 @@ def __call__(self, img): Returns: PIL Image: Rescaled image. """ - return F.resize(img, self.size, self.interpolation) + if not isinstance(img,list) : + return F.resize(img, self.size, self.interpolation) + else : + return [F.resize(im, self.size, self.interpolation) for im in img] def __repr__(self): interpolate_str = _pil_interpolation_to_str[self.interpolation] @@ -224,7 +233,10 @@ def __call__(self, img): Returns: PIL Image: Cropped image. """ - return F.center_crop(img, self.size) + if not isinstance(img,list) : + return F.center_crop(img, self.size) + else : + return [ F.center_crop(im, self.size) for im in img] def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) @@ -280,7 +292,10 @@ def __call__(self, img): Returns: PIL Image: Padded image. """ - return F.pad(img, self.padding, self.fill, self.padding_mode) + if not isinstance(img,list) : + return F.pad(img, self.padding, self.fill, self.padding_mode) + else : + return [F.pad(im, self.padding, self.fill, self.padding_mode) for im in img] def __repr__(self): return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ @@ -299,7 +314,10 @@ def __init__(self, lambd): self.lambd = lambd def __call__(self, img): - return self.lambd(img) + if not isinstance(img,list) : + return self.lambd(img) + else : + return [self.lambd(im) for im in img] def __repr__(self): return self.__class__.__name__ + '()' @@ -344,7 +362,10 @@ def __call__(self, img): if self.p < random.random(): return img for t in self.transforms: - img = t(img) + if not isinstance(img,list) : + img = t(img) + else : + img = [t(im) for im in img] return img def __repr__(self): @@ -364,7 +385,10 @@ def __call__(self, img): order = list(range(len(self.transforms))) random.shuffle(order) for i in order: - img = self.transforms[i](img) + if not isinstance(img,list) : + img = self.transforms[i](img) + else : + img = [self.transforms[i](im) for im in img] return img @@ -373,7 +397,10 @@ class RandomChoice(RandomTransforms): """ def __call__(self, img): t = random.choice(self.transforms) - return t(img) + if not isinstance(img,list) : + return t(img) + else : + return [t(im) for im in img] class RandomCrop(object): @@ -441,7 +468,7 @@ def get_params(img, output_size): j = random.randint(0, w - tw) return i, j, th, tw - def __call__(self, img): + def __call__(self, imgs): """ Args: img (PIL Image): Image to be cropped. @@ -449,20 +476,29 @@ def __call__(self, img): Returns: PIL Image: Cropped image. """ - if self.padding is not None: - img = F.pad(img, self.padding, self.fill, self.padding_mode) - - # pad the width if needed - if self.pad_if_needed and img.size[0] < self.size[1]: - img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) - # pad the height if needed - if self.pad_if_needed and img.size[1] < self.size[0]: - img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) - - i, j, h, w = self.get_params(img, self.size) - - return F.crop(img, i, j, h, w) - + not_list_flag = False + if not isinstance(imgs,list) : + imgs = [imgs] + not_list_flag = True + outputs = [] + for img in imgs : + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + # pad the width if needed + if self.pad_if_needed and img.size[0] < self.size[1]: + img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and img.size[1] < self.size[0]: + img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + + outputs.append(img) + i, j, h, w = self.get_params(imgs[0], self.size) + outputs = [F.crop(img, i, j, h, w)) for img in outputs] + if not_list_flag : + return outputs[0] + else : + return outputs def __repr__(self): return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) @@ -486,7 +522,10 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ if random.random() < self.p: - return F.hflip(img) + if not isinstance(img,list) : + return F.hflip(img) + else : + return [F.hflip(im) for im in img] return img def __repr__(self): @@ -512,7 +551,10 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ if random.random() < self.p: - return F.vflip(img) + if not isinstance(img,list) : + return F.vflip(img) + else : + return [F.vflip(im) for im in img] return img def __repr__(self): @@ -575,7 +617,7 @@ def get_params(img, scale, ratio): j = (img.size[0] - w) // 2 return i, j, w, w - def __call__(self, img): + def __call__(self, imgs): """ Args: img (PIL Image): Image to be cropped and resized. @@ -583,8 +625,18 @@ def __call__(self, img): Returns: PIL Image: Randomly cropped and resized image. """ - i, j, h, w = self.get_params(img, self.scale, self.ratio) - return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) + outputs = [] + not_list_flag = False + if not isinstance(imgs,list) : + imgs = [imgs] + not_list_flag = True + i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) + for img in imgs : + outputs.append(F.resized_crop(img, i, j, h, w, self.size, self.interpolation)) + if not_list_flag : + return outputs[0] + else : + return outputs def __repr__(self): interpolate_str = _pil_interpolation_to_str[self.interpolation] @@ -638,7 +690,10 @@ def __init__(self, size): self.size = size def __call__(self, img): - return F.five_crop(img, self.size) + if not isinstance(img,list) : + return F.five_crop(img, self.size) + else : + return [F.five_crop(im, self.size) for im in img] def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) @@ -681,8 +736,10 @@ def __init__(self, size, vertical_flip=False): self.vertical_flip = vertical_flip def __call__(self, img): - return F.ten_crop(img, self.size, self.vertical_flip) - + if not isinstance(img,list) : + return F.ten_crop(img, self.size, self.vertical_flip) + else : + return [F.ten_crop(im, self.size, self.vertical_flip) for im in img] def __repr__(self): return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) @@ -819,7 +876,10 @@ def __call__(self, img): """ transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) - return transform(img) + if not isinstance(img,list) : + return transform(img) + else : + return [transform(im) for im in img] def __repr__(self): format_string = self.__class__.__name__ + '(' @@ -886,9 +946,11 @@ def __call__(self, img): """ angle = self.get_params(self.degrees) - - return F.rotate(img, angle, self.resample, self.expand, self.center) - + if not isinstance(img,list) : + return F.rotate(img, angle, self.resample, self.expand, self.center) + else : + return [F.rotate(im, angle, self.resample, self.expand, self.center) for im in img] + def __repr__(self): format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) format_string += ', resample={0}'.format(self.resample) @@ -993,16 +1055,25 @@ def get_params(degrees, translate, scale_ranges, shears, img_size): return angle, translations, scale, shear - def __call__(self, img): + def __call__(self, imgs): """ img (PIL Image): Image to be transformed. Returns: PIL Image: Affine transformed image. """ - ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) - return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) - + not_list_flag = False + outputs = [] + if not isinstance(imgs,list) : + imgs = [imgs] + not_list_flag = True + ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, imgs[0].size) + outputs = [ F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) for img in imgs] + if not_list_flag : + return outputs[0] + else : + return outputs + def __repr__(self): s = '{name}(degrees={degrees}' if self.translate is not None: @@ -1045,8 +1116,10 @@ def __call__(self, img): Returns: PIL Image: Randomly grayscaled image. """ - return F.to_grayscale(img, num_output_channels=self.num_output_channels) - + if not isinstance(img,list) : + return F.to_grayscale(img, num_output_channels=self.num_output_channels) + else : + return [ F.to_grayscale(im, num_output_channels=self.num_output_channels) for im in img] def __repr__(self): return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) @@ -1078,7 +1151,10 @@ def __call__(self, img): """ num_output_channels = 1 if img.mode == 'L' else 3 if random.random() < self.p: - return F.to_grayscale(img, num_output_channels=num_output_channels) + if not isinstance(img,list) : + return F.to_grayscale(img, num_output_channels=num_output_channels) + else : + return [ F.to_grayscale(im, num_output_channels=num_output_channels) for im in img] return img def __repr__(self): From dd6411ec2d1037f8b6f404634bb644583b6e1afc Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 21:25:03 +0530 Subject: [PATCH 2/8] Revert "modified code to accept list of pil images" --- torchvision/transforms/transforms.py | 162 +++++++-------------------- 1 file changed, 43 insertions(+), 119 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 45660ae17ba..a640ea403f5 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -82,10 +82,7 @@ def __call__(self, pic): Returns: Tensor: Converted image. """ - if not isinstance(pic,list) : - return F.to_tensor(pic) - else : - return [F.to_tensor(p) for p in pic] + return F.to_tensor(pic) def __repr__(self): return self.__class__.__name__ + '()' @@ -119,10 +116,7 @@ def __call__(self, pic): PIL Image: Image converted to PIL Image. """ - if not isinstance(pic,list) : - return F.to_pil_image(pic, self.mode) - else : - return [F.to_pil_image(p, self.mode) for p in pic] + return F.to_pil_image(pic, self.mode) def __repr__(self): format_string = self.__class__.__name__ + '(' @@ -190,10 +184,7 @@ def __call__(self, img): Returns: PIL Image: Rescaled image. """ - if not isinstance(img,list) : - return F.resize(img, self.size, self.interpolation) - else : - return [F.resize(im, self.size, self.interpolation) for im in img] + return F.resize(img, self.size, self.interpolation) def __repr__(self): interpolate_str = _pil_interpolation_to_str[self.interpolation] @@ -233,10 +224,7 @@ def __call__(self, img): Returns: PIL Image: Cropped image. """ - if not isinstance(img,list) : - return F.center_crop(img, self.size) - else : - return [ F.center_crop(im, self.size) for im in img] + return F.center_crop(img, self.size) def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) @@ -292,10 +280,7 @@ def __call__(self, img): Returns: PIL Image: Padded image. """ - if not isinstance(img,list) : - return F.pad(img, self.padding, self.fill, self.padding_mode) - else : - return [F.pad(im, self.padding, self.fill, self.padding_mode) for im in img] + return F.pad(img, self.padding, self.fill, self.padding_mode) def __repr__(self): return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ @@ -314,10 +299,7 @@ def __init__(self, lambd): self.lambd = lambd def __call__(self, img): - if not isinstance(img,list) : - return self.lambd(img) - else : - return [self.lambd(im) for im in img] + return self.lambd(img) def __repr__(self): return self.__class__.__name__ + '()' @@ -362,10 +344,7 @@ def __call__(self, img): if self.p < random.random(): return img for t in self.transforms: - if not isinstance(img,list) : - img = t(img) - else : - img = [t(im) for im in img] + img = t(img) return img def __repr__(self): @@ -385,10 +364,7 @@ def __call__(self, img): order = list(range(len(self.transforms))) random.shuffle(order) for i in order: - if not isinstance(img,list) : - img = self.transforms[i](img) - else : - img = [self.transforms[i](im) for im in img] + img = self.transforms[i](img) return img @@ -397,10 +373,7 @@ class RandomChoice(RandomTransforms): """ def __call__(self, img): t = random.choice(self.transforms) - if not isinstance(img,list) : - return t(img) - else : - return [t(im) for im in img] + return t(img) class RandomCrop(object): @@ -468,7 +441,7 @@ def get_params(img, output_size): j = random.randint(0, w - tw) return i, j, th, tw - def __call__(self, imgs): + def __call__(self, img): """ Args: img (PIL Image): Image to be cropped. @@ -476,29 +449,20 @@ def __call__(self, imgs): Returns: PIL Image: Cropped image. """ - not_list_flag = False - if not isinstance(imgs,list) : - imgs = [imgs] - not_list_flag = True - outputs = [] - for img in imgs : - if self.padding is not None: - img = F.pad(img, self.padding, self.fill, self.padding_mode) - - # pad the width if needed - if self.pad_if_needed and img.size[0] < self.size[1]: - img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) - # pad the height if needed - if self.pad_if_needed and img.size[1] < self.size[0]: - img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) - - outputs.append(img) - i, j, h, w = self.get_params(imgs[0], self.size) - outputs = [F.crop(img, i, j, h, w)) for img in outputs] - if not_list_flag : - return outputs[0] - else : - return outputs + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + # pad the width if needed + if self.pad_if_needed and img.size[0] < self.size[1]: + img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and img.size[1] < self.size[0]: + img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + + return F.crop(img, i, j, h, w) + def __repr__(self): return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) @@ -522,10 +486,7 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ if random.random() < self.p: - if not isinstance(img,list) : - return F.hflip(img) - else : - return [F.hflip(im) for im in img] + return F.hflip(img) return img def __repr__(self): @@ -551,10 +512,7 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ if random.random() < self.p: - if not isinstance(img,list) : - return F.vflip(img) - else : - return [F.vflip(im) for im in img] + return F.vflip(img) return img def __repr__(self): @@ -617,7 +575,7 @@ def get_params(img, scale, ratio): j = (img.size[0] - w) // 2 return i, j, w, w - def __call__(self, imgs): + def __call__(self, img): """ Args: img (PIL Image): Image to be cropped and resized. @@ -625,18 +583,8 @@ def __call__(self, imgs): Returns: PIL Image: Randomly cropped and resized image. """ - outputs = [] - not_list_flag = False - if not isinstance(imgs,list) : - imgs = [imgs] - not_list_flag = True - i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) - for img in imgs : - outputs.append(F.resized_crop(img, i, j, h, w, self.size, self.interpolation)) - if not_list_flag : - return outputs[0] - else : - return outputs + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) def __repr__(self): interpolate_str = _pil_interpolation_to_str[self.interpolation] @@ -690,10 +638,7 @@ def __init__(self, size): self.size = size def __call__(self, img): - if not isinstance(img,list) : - return F.five_crop(img, self.size) - else : - return [F.five_crop(im, self.size) for im in img] + return F.five_crop(img, self.size) def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) @@ -736,10 +681,8 @@ def __init__(self, size, vertical_flip=False): self.vertical_flip = vertical_flip def __call__(self, img): - if not isinstance(img,list) : - return F.ten_crop(img, self.size, self.vertical_flip) - else : - return [F.ten_crop(im, self.size, self.vertical_flip) for im in img] + return F.ten_crop(img, self.size, self.vertical_flip) + def __repr__(self): return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) @@ -876,10 +819,7 @@ def __call__(self, img): """ transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) - if not isinstance(img,list) : - return transform(img) - else : - return [transform(im) for im in img] + return transform(img) def __repr__(self): format_string = self.__class__.__name__ + '(' @@ -946,11 +886,9 @@ def __call__(self, img): """ angle = self.get_params(self.degrees) - if not isinstance(img,list) : - return F.rotate(img, angle, self.resample, self.expand, self.center) - else : - return [F.rotate(im, angle, self.resample, self.expand, self.center) for im in img] - + + return F.rotate(img, angle, self.resample, self.expand, self.center) + def __repr__(self): format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) format_string += ', resample={0}'.format(self.resample) @@ -1055,25 +993,16 @@ def get_params(degrees, translate, scale_ranges, shears, img_size): return angle, translations, scale, shear - def __call__(self, imgs): + def __call__(self, img): """ img (PIL Image): Image to be transformed. Returns: PIL Image: Affine transformed image. """ - not_list_flag = False - outputs = [] - if not isinstance(imgs,list) : - imgs = [imgs] - not_list_flag = True - ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, imgs[0].size) - outputs = [ F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) for img in imgs] - if not_list_flag : - return outputs[0] - else : - return outputs - + ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) + return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) + def __repr__(self): s = '{name}(degrees={degrees}' if self.translate is not None: @@ -1116,10 +1045,8 @@ def __call__(self, img): Returns: PIL Image: Randomly grayscaled image. """ - if not isinstance(img,list) : - return F.to_grayscale(img, num_output_channels=self.num_output_channels) - else : - return [ F.to_grayscale(im, num_output_channels=self.num_output_channels) for im in img] + return F.to_grayscale(img, num_output_channels=self.num_output_channels) + def __repr__(self): return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) @@ -1151,10 +1078,7 @@ def __call__(self, img): """ num_output_channels = 1 if img.mode == 'L' else 3 if random.random() < self.p: - if not isinstance(img,list) : - return F.to_grayscale(img, num_output_channels=num_output_channels) - else : - return [ F.to_grayscale(im, num_output_channels=num_output_channels) for im in img] + return F.to_grayscale(img, num_output_channels=num_output_channels) return img def __repr__(self): From 3de832b2035c0e86ab25798ddfc45b4740d466ec Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 21:34:20 +0530 Subject: [PATCH 3/8] added a get_params method for horizontal flip, vertical flip and random grayscale --- torchvision/transforms/transforms.py | 41 ++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a640ea403f5..76e0a89eaea 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -477,6 +477,17 @@ class RandomHorizontalFlip(object): def __init__(self, p=0.5): self.p = p + @staticmethod + def get_params(p): + """Get parameters for ``crop`` for a random crop. + + Args: + p : probability of flipping + Returns: + tuple: bool that would determine the flip + """ + return random.random() < p + def __call__(self, img): """ Args: @@ -485,7 +496,8 @@ def __call__(self, img): Returns: PIL Image: Randomly flipped image. """ - if random.random() < self.p: + to_flip = self.get_params(self.p) + if to_flip : return F.hflip(img) return img @@ -503,6 +515,16 @@ class RandomVerticalFlip(object): def __init__(self, p=0.5): self.p = p + def get_params(p): + """Get parameters for ``crop`` for a random crop. + + Args: + p : probability of flipping + Returns: + tuple: bool that would determine the flip + """ + return random.random() < p + def __call__(self, img): """ Args: @@ -511,7 +533,8 @@ def __call__(self, img): Returns: PIL Image: Randomly flipped image. """ - if random.random() < self.p: + to_flip = self.get_params(self.p) + if to_flip : return F.vflip(img) return img @@ -1037,6 +1060,7 @@ class Grayscale(object): def __init__(self, num_output_channels=1): self.num_output_channels = num_output_channels + def __call__(self, img): """ Args: @@ -1068,6 +1092,16 @@ class RandomGrayscale(object): def __init__(self, p=0.1): self.p = p + def get_params(p): + """Get parameters for ``crop`` for a random crop. + + Args: + p : probability of converting to grayscale + Returns: + tuple: bool that would determine the flip + """ + return random.random() < p + def __call__(self, img): """ Args: @@ -1077,7 +1111,8 @@ def __call__(self, img): PIL Image: Randomly grayscaled image. """ num_output_channels = 1 if img.mode == 'L' else 3 - if random.random() < self.p: + to_convert = self.get_params(self.p) + if to_convert : return F.to_grayscale(img, num_output_channels=num_output_channels) return img From 503e636248be8bfbf09e4ea5e4a5632823299d1b Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 21:35:23 +0530 Subject: [PATCH 4/8] fixed comments --- torchvision/transforms/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 76e0a89eaea..18da3734e90 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -484,7 +484,7 @@ def get_params(p): Args: p : probability of flipping Returns: - tuple: bool that would determine the flip + tuple: bool """ return random.random() < p @@ -521,7 +521,7 @@ def get_params(p): Args: p : probability of flipping Returns: - tuple: bool that would determine the flip + tuple: bool """ return random.random() < p @@ -1098,7 +1098,7 @@ def get_params(p): Args: p : probability of converting to grayscale Returns: - tuple: bool that would determine the flip + tuple: bool """ return random.random() < p From ed485018d4468304ab9e229f02731b5e13578c48 Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 21:42:39 +0530 Subject: [PATCH 5/8] added the static decorator --- torchvision/transforms/transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 18da3734e90..a5727420a76 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1092,6 +1092,7 @@ class RandomGrayscale(object): def __init__(self, p=0.1): self.p = p + @staticmethod def get_params(p): """Get parameters for ``crop`` for a random crop. From db61d3846f044b7874cecbc9838c69b1748f6d8e Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 21:44:43 +0530 Subject: [PATCH 6/8] added static decorator --- torchvision/transforms/transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index a5727420a76..6b61000ad2f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -515,6 +515,7 @@ class RandomVerticalFlip(object): def __init__(self, p=0.5): self.p = p + @staticmethod def get_params(p): """Get parameters for ``crop`` for a random crop. From 509b9b20d4c286d9412ee2c8d393833c219e0b21 Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 21:46:13 +0530 Subject: [PATCH 7/8] removing space before colons --- torchvision/transforms/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 6b61000ad2f..3f9f41b5c6e 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -497,7 +497,7 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ to_flip = self.get_params(self.p) - if to_flip : + if to_flip: return F.hflip(img) return img @@ -535,7 +535,7 @@ def __call__(self, img): PIL Image: Randomly flipped image. """ to_flip = self.get_params(self.p) - if to_flip : + if to_flip: return F.vflip(img) return img @@ -1114,7 +1114,7 @@ def __call__(self, img): """ num_output_channels = 1 if img.mode == 'L' else 3 to_convert = self.get_params(self.p) - if to_convert : + if to_convert: return F.to_grayscale(img, num_output_channels=num_output_channels) return img From 74cf1ef0b0d8a92348b8026d88c13e9b28f088f4 Mon Sep 17 00:00:00 2001 From: quantummole <35193295+quantummole@users.noreply.github.com> Date: Mon, 24 Sep 2018 22:03:39 +0530 Subject: [PATCH 8/8] removed newline @ 1063 --- torchvision/transforms/transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 3f9f41b5c6e..108b5dfdcc0 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1061,7 +1061,6 @@ class Grayscale(object): def __init__(self, num_output_channels=1): self.num_output_channels = num_output_channels - def __call__(self, img): """ Args: