Skip to content
42 changes: 39 additions & 3 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,17 @@ class RandomHorizontalFlip(object):
def __init__(self, p=0.5):
self.p = p

@staticmethod
def get_params(p):
"""Get parameters for ``crop`` for a random crop.

Args:
p : probability of flipping
Returns:
tuple: bool
"""
return random.random() < p

def __call__(self, img):
"""
Args:
Expand All @@ -485,7 +496,8 @@ def __call__(self, img):
Returns:
PIL Image: Randomly flipped image.
"""
if random.random() < self.p:
to_flip = self.get_params(self.p)
if to_flip:
return F.hflip(img)
return img

Expand All @@ -503,6 +515,17 @@ class RandomVerticalFlip(object):
def __init__(self, p=0.5):
self.p = p

@staticmethod
def get_params(p):
"""Get parameters for ``crop`` for a random crop.

Args:
p : probability of flipping
Returns:
tuple: bool
"""
return random.random() < p

def __call__(self, img):
"""
Args:
Expand All @@ -511,7 +534,8 @@ def __call__(self, img):
Returns:
PIL Image: Randomly flipped image.
"""
if random.random() < self.p:
to_flip = self.get_params(self.p)
if to_flip:
return F.vflip(img)
return img

Expand Down Expand Up @@ -1068,6 +1092,17 @@ class RandomGrayscale(object):
def __init__(self, p=0.1):
self.p = p

@staticmethod
def get_params(p):
"""Get parameters for ``crop`` for a random crop.

Args:
p : probability of converting to grayscale
Returns:
tuple: bool
"""
return random.random() < p

def __call__(self, img):
"""
Args:
Expand All @@ -1077,7 +1112,8 @@ def __call__(self, img):
PIL Image: Randomly grayscaled image.
"""
num_output_channels = 1 if img.mode == 'L' else 3
if random.random() < self.p:
to_convert = self.get_params(self.p)
if to_convert:
return F.to_grayscale(img, num_output_channels=num_output_channels)
return img

Expand Down