|  | 
|  | 1 | +import PIL.Image | 
|  | 2 | +import pytest | 
|  | 3 | + | 
|  | 4 | +import torch | 
|  | 5 | + | 
|  | 6 | +from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask | 
|  | 7 | + | 
|  | 8 | +from torchvision.prototype import features | 
|  | 9 | +from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor | 
|  | 10 | +from torchvision.prototype.transforms.functional import to_image_pil | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +IMAGE = make_image(color_space=features.ColorSpace.RGB) | 
|  | 14 | +BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size) | 
|  | 15 | +SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size) | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +@pytest.mark.parametrize( | 
|  | 19 | +    ("sample", "types", "expected"), | 
|  | 20 | +    [ | 
|  | 21 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True), | 
|  | 22 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True), | 
|  | 23 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True), | 
|  | 24 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True), | 
|  | 25 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True), | 
|  | 26 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True), | 
|  | 27 | +        ((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False), | 
|  | 28 | +        ((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False), | 
|  | 29 | +        ((IMAGE,), (features.BoundingBox, features.SegmentationMask), False), | 
|  | 30 | +        ( | 
|  | 31 | +            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), | 
|  | 32 | +            (features.Image, features.BoundingBox, features.SegmentationMask), | 
|  | 33 | +            True, | 
|  | 34 | +        ), | 
|  | 35 | +        ((), (features.Image, features.BoundingBox, features.SegmentationMask), False), | 
|  | 36 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True), | 
|  | 37 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False), | 
|  | 38 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True), | 
|  | 39 | +        ((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True), | 
|  | 40 | +        ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True), | 
|  | 41 | +        ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True), | 
|  | 42 | +    ], | 
|  | 43 | +) | 
|  | 44 | +def test_has_any(sample, types, expected): | 
|  | 45 | +    assert has_any(sample, *types) is expected | 
|  | 46 | + | 
|  | 47 | + | 
|  | 48 | +@pytest.mark.parametrize( | 
|  | 49 | +    ("sample", "types", "expected"), | 
|  | 50 | +    [ | 
|  | 51 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True), | 
|  | 52 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True), | 
|  | 53 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True), | 
|  | 54 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True), | 
|  | 55 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True), | 
|  | 56 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True), | 
|  | 57 | +        ( | 
|  | 58 | +            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), | 
|  | 59 | +            (features.Image, features.BoundingBox, features.SegmentationMask), | 
|  | 60 | +            True, | 
|  | 61 | +        ), | 
|  | 62 | +        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False), | 
|  | 63 | +        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False), | 
|  | 64 | +        ((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False), | 
|  | 65 | +        ( | 
|  | 66 | +            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), | 
|  | 67 | +            (features.Image, features.BoundingBox, features.SegmentationMask), | 
|  | 68 | +            True, | 
|  | 69 | +        ), | 
|  | 70 | +        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False), | 
|  | 71 | +        ((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False), | 
|  | 72 | +        ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False), | 
|  | 73 | +        ( | 
|  | 74 | +            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), | 
|  | 75 | +            (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),), | 
|  | 76 | +            True, | 
|  | 77 | +        ), | 
|  | 78 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False), | 
|  | 79 | +        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True), | 
|  | 80 | +    ], | 
|  | 81 | +) | 
|  | 82 | +def test_has_all(sample, types, expected): | 
|  | 83 | +    assert has_all(sample, *types) is expected | 
0 commit comments