@@ -34,6 +34,8 @@ class Raft_Large_Weights(WeightsEnum):
3434 "recipe" : "https://github.com/princeton-vl/RAFT" ,
3535 "sintel_train_cleanpass_epe" : 1.4411 ,
3636 "sintel_train_finalpass_epe" : 2.7894 ,
37+ "kitti_train_per_image_epe" : 5.0172 ,
38+ "kitti_train_f1-all" : 17.4506 ,
3739 },
3840 )
3941
@@ -46,6 +48,8 @@ class Raft_Large_Weights(WeightsEnum):
4648 "recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
4749 "sintel_train_cleanpass_epe" : 1.3822 ,
4850 "sintel_train_finalpass_epe" : 2.7161 ,
51+ "kitti_train_per_image_epe" : 4.5118 ,
52+ "kitti_train_f1-all" : 16.0679 ,
4953 },
5054 )
5155
@@ -87,10 +91,25 @@ class Raft_Small_Weights(WeightsEnum):
8791 "recipe" : "https://github.com/princeton-vl/RAFT" ,
8892 "sintel_train_cleanpass_epe" : 2.1231 ,
8993 "sintel_train_finalpass_epe" : 3.2790 ,
94+ "kitti_train_per_image_epe" : 7.6557 ,
95+ "kitti_train_f1-all" : 25.2801 ,
96+ },
97+ )
98+ C_T_V2 = Weights (
99+ # Chairs + Things
100+ url = "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth" ,
101+ transforms = RaftEval ,
102+ meta = {
103+ ** _COMMON_META ,
104+ "recipe" : "https://github.com/princeton-vl/RAFT" ,
105+ "sintel_train_cleanpass_epe" : 1.9901 ,
106+ "sintel_train_finalpass_epe" : 3.2831 ,
107+ "kitti_train_per_image_epe" : 7.5978 ,
108+ "kitti_train_f1-all" : 25.2369 ,
90109 },
91110 )
92111
93- default = C_T_V1 # TODO: Change to V2 once we upload our own weights
112+ default = C_T_V2
94113
95114
96115@handle_legacy_interface (weights = ("pretrained" , Raft_Large_Weights .C_T_V2 ))
@@ -143,14 +162,13 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
143162 return model
144163
145164
146- # TODO: change to V2 once we upload our own weights
147- @handle_legacy_interface (weights = ("pretrained" , Raft_Small_Weights .C_T_V1 ))
165+ @handle_legacy_interface (weights = ("pretrained" , Raft_Small_Weights .C_T_V2 ))
148166def raft_small (* , weights : Optional [Raft_Small_Weights ] = None , progress = True , ** kwargs ):
149167 """RAFT "small" model from
150168 `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
151169
152170 Args:
153- weights(Raft_Small_weights, optinal ): TODO not implemented yet
171+ weights(Raft_Small_weights, optional ): pretrained weights to use.
154172 progress (bool): If True, displays a progress bar of the download to stderr
155173 kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
156174 to override any default.
0 commit comments