|
1 | | -import unittest |
2 | | - |
3 | | - |
4 | 1 | import torch |
5 | 2 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone |
6 | 3 |
|
| 4 | +import pytest |
7 | 5 |
|
8 | | -class ResnetFPNBackboneTester(unittest.TestCase): |
9 | | - @classmethod |
10 | | - def setUpClass(cls): |
11 | | - cls.dtype = torch.float32 |
12 | | - |
13 | | - def test_resnet18_fpn_backbone(self): |
14 | | - device = torch.device('cpu') |
15 | | - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) |
16 | | - resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False) |
17 | | - y = resnet18_fpn(x) |
18 | | - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) |
19 | 6 |
|
20 | | - def test_resnet50_fpn_backbone(self): |
21 | | - device = torch.device('cpu') |
22 | | - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) |
23 | | - resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False) |
24 | | - y = resnet50_fpn(x) |
25 | | - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) |
| 7 | +@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) |
| 8 | +def test_resnet_fpn_backbone(backbone_name): |
| 9 | + x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') |
| 10 | + y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) |
| 11 | + assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] |
0 commit comments