11import io
22import torch
33from torchvision import ops
4+ from torchvision .models .detection .image_list import ImageList
45from torchvision .models .detection .transform import GeneralizedRCNNTransform
6+ from torchvision .models .detection .rpn import AnchorGenerator , RPNHead , RegionProposalNetwork
7+ from torchvision .models .detection .backbone_utils import resnet_fpn_backbone
58
69from collections import OrderedDict
710
@@ -20,7 +23,7 @@ class ONNXExporterTester(unittest.TestCase):
2023 def setUpClass (cls ):
2124 torch .manual_seed (123 )
2225
23- def run_model (self , model , inputs_list ):
26+ def run_model (self , model , inputs_list , tolerate_small_mismatch = False ):
2427 model .eval ()
2528
2629 onnx_io = io .BytesIO ()
@@ -36,9 +39,9 @@ def run_model(self, model, inputs_list):
3639 test_ouputs = model (* test_inputs )
3740 if isinstance (test_ouputs , torch .Tensor ):
3841 test_ouputs = (test_ouputs ,)
39- self .ort_validate (onnx_io , test_inputs , test_ouputs )
42+ self .ort_validate (onnx_io , test_inputs , test_ouputs , tolerate_small_mismatch )
4043
41- def ort_validate (self , onnx_io , inputs , outputs ):
44+ def ort_validate (self , onnx_io , inputs , outputs , tolerate_small_mismatch = False ):
4245
4346 inputs , _ = torch .jit ._flatten (inputs )
4447 outputs , _ = torch .jit ._flatten (outputs )
@@ -58,7 +61,13 @@ def to_numpy(tensor):
5861 ort_outs = ort_session .run (None , ort_inputs )
5962
6063 for i in range (0 , len (outputs )):
61- torch .testing .assert_allclose (outputs [i ], ort_outs [i ], rtol = 1e-03 , atol = 1e-05 )
64+ try :
65+ torch .testing .assert_allclose (outputs [i ], ort_outs [i ], rtol = 1e-03 , atol = 1e-05 )
66+ except AssertionError as error :
67+ if tolerate_small_mismatch :
68+ assert ("(0.00%)" in str (error )), str (error )
69+ else :
70+ assert False , str (error )
6271
6372 def test_nms (self ):
6473 boxes = torch .rand (5 , 4 )
@@ -91,11 +100,7 @@ def test_transform_images(self):
91100 class TransformModule (torch .nn .Module ):
92101 def __init__ (self_module ):
93102 super (TransformModule , self_module ).__init__ ()
94- min_size = 800
95- max_size = 1333
96- image_mean = [0.485 , 0.456 , 0.406 ]
97- image_std = [0.229 , 0.224 , 0.225 ]
98- self_module .transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
103+ self_module .transform = self ._init_test_generalized_rcnn_transform ()
99104
100105 def forward (self_module , images ):
101106 return self_module .transform (images )[0 ].tensors
@@ -104,6 +109,66 @@ def forward(self_module, images):
104109 input_test = [torch .rand (3 , 800 , 1280 ), torch .rand (3 , 800 , 800 )]
105110 self .run_model (TransformModule (), [input , input_test ])
106111
def _init_test_generalized_rcnn_transform(self):
    """Build the GeneralizedRCNNTransform used by the transform export test.

    Returns:
        A ``GeneralizedRCNNTransform`` built with a minimum size of 800, a
        maximum size of 1333 and the per-channel mean/std values below.
    """
    # Per-channel normalization statistics handed to the transform.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    # Positional order: min_size, max_size, image_mean, image_std.
    return GeneralizedRCNNTransform(800, 1333, channel_mean, channel_std)
119+
def _init_test_rpn(self):
    """Build the RegionProposalNetwork exercised by ``test_rpn``.

    The network is assembled from an ``AnchorGenerator`` with five anchor
    sizes (one per entry) that all share the same three aspect ratios, and
    an ``RPNHead`` operating on 256 channels.

    Returns:
        The constructed ``RegionProposalNetwork``.
    """
    sizes_per_level = ((32,), (64,), (128,), (256,), (512,))
    # One identical ratio triple for every entry of ``sizes_per_level``.
    ratios_per_level = tuple((0.5, 1.0, 2.0) for _ in sizes_per_level)
    anchor_generator = AnchorGenerator(sizes_per_level, ratios_per_level)

    feature_channels = 256
    # The head is sized from the first entry of the per-location anchor counts.
    anchors_per_location = anchor_generator.num_anchors_per_location()[0]
    head = RPNHead(feature_channels, anchors_per_location)

    # Proposal budgets before and after NMS, keyed by mode.
    pre_nms_top_n = {"training": 2000, "testing": 1000}
    post_nms_top_n = {"training": 2000, "testing": 1000}

    return RegionProposalNetwork(
        anchor_generator, head,
        0.7, 0.3,  # foreground / background IoU thresholds
        256, 0.5,  # batch size per image, positive fraction
        pre_nms_top_n, post_nms_top_n,
        0.7,  # NMS threshold
    )
140+
def test_rpn(self):
    """Run the region proposal network through the ONNX export check.

    The RPN is wrapped in a module that holds a fixed ``ImageList`` so that
    only the feature maps are passed as model inputs. Two independently
    sampled feature sets of identical shapes are handed to ``run_model``,
    with ``tolerate_small_mismatch`` enabled.
    """
    class RPNModule(torch.nn.Module):
        def __init__(self_module, images):
            super(RPNModule, self_module).__init__()
            self_module.rpn = self._init_test_rpn()
            # Record each image's trailing two dimensions as its size.
            image_sizes = [img.shape[-2:] for img in images]
            self_module.images = ImageList(images, image_sizes)

        def forward(self_module, features):
            return self_module.rpn(self_module.images, features)

    def get_features(images):
        # Five random feature maps keyed '0'..'4', each spatially reduced
        # from the image by the matching stride.
        height, width = images.shape[-2:]
        strides = (4, 8, 16, 32, 64)
        return OrderedDict(
            (str(level), torch.rand(2, 256, height // stride, width // stride))
            for level, stride in enumerate(strides)
        )

    images = torch.rand(2, 3, 600, 600)
    features = get_features(images)
    test_features = get_features(images)

    model = RPNModule(images)
    model.eval()
    # Run the module once in eager mode before handing it to run_model.
    model(features)
    self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
171+
107172 def test_multi_scale_roi_align (self ):
108173
109174 class TransformModule (torch .nn .Module ):
0 commit comments