diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 93d5f17fcbe..9beded4c957 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1359,7 +1359,7 @@ def test_ctor(self, transform_cls, trfms): class TestRandomChoice: def test_assertions(self): - with pytest.raises(ValueError, match="The number of probabilities doesn't match the number of transforms"): + with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"): transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1]) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 059a230ee5c..a8a87cd43dd 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -822,7 +822,7 @@ def test_random_choice(self, probabilities): v2_transforms.Resize(256), legacy_transforms.CenterCrop(224), ], - probabilities=probabilities, + p=probabilities, ) legacy_transform = legacy_transforms.RandomChoice( [ diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index 27affc7100b..7f9df337352 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -139,7 +139,7 @@ def __init__( p = [1] * len(transforms) elif len(p) != len(transforms): raise ValueError( - f"The number of p doesn't match the number of transforms: " f"{len(p)} != {len(transforms)}" + f"Length of p doesn't match the number of transforms: " f"{len(p)} != {len(transforms)}" ) super().__init__()