@@ -184,7 +184,7 @@ def _init_test_generalized_rcnn_transform(self):
184184 transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
185185 return transform
186186
187- def _init_test_rpn (self ):
187+ def _init_test_rpn (self , score_threshold = 0.0 ):
188188 anchor_sizes = ((32 ,), (64 ,), (128 ,), (256 ,), (512 ,))
189189 aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
190190 rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios )
@@ -197,7 +197,7 @@ def _init_test_rpn(self):
197197 rpn_pre_nms_top_n = dict (training = 2000 , testing = 1000 )
198198 rpn_post_nms_top_n = dict (training = 2000 , testing = 1000 )
199199 rpn_nms_thresh = 0.7
200- rpn_score_thresh = 0.0
200+ rpn_score_thresh = score_threshold
201201
202202 rpn = RegionProposalNetwork (
203203 rpn_anchor_generator , rpn_head ,
@@ -260,7 +260,7 @@ def test_rpn(self):
260260 class RPNModule (torch .nn .Module ):
261261 def __init__ (self_module ):
262262 super (RPNModule , self_module ).__init__ ()
263- self_module .rpn = self ._init_test_rpn ()
263+ self_module .rpn = self ._init_test_rpn (0.5 )
264264
265265 def forward (self_module , images , features ):
266266 images = ImageList (images , [i .shape [- 2 :] for i in images ])
0 commit comments