44from torch .nn .modules .instancenorm import InstanceNorm2d
55from torchvision .models .optical_flow import RAFT
66from torchvision .models .optical_flow .raft import _raft , BottleneckBlock , ResidualBlock
7-
8- # from torchvision.prototype. transforms import RaftEval
7+ from torchvision . prototype . transforms import RaftEval
8+ from torchvision .transforms . functional import InterpolationMode
99
1010from .._api import WeightsEnum
11-
12- # from .._api import Weights
11+ from .._api import Weights
1312from .._utils import handle_legacy_interface
1413
1514
2221)
2322
2423
24+ _COMMON_META = {"interpolation" : InterpolationMode .BILINEAR }
25+
26+
2527class Raft_Large_Weights (WeightsEnum ):
26- pass
27- # C_T_V1 = Weights(
28- # # Chairs + Things
29- # url="",
30- # transforms=RaftEval,
31- # meta={
32- # "recipe": "",
33- # "epe": -1234,
34- # },
35- # )
28+ C_T_V1 = Weights (
29+ # Chairs + Things, ported from original paper repo (raft-things.pth)
30+ url = "https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth" ,
31+ transforms = RaftEval ,
32+ meta = {
33+ ** _COMMON_META ,
34+ "recipe" : "https://github.com/princeton-vl/RAFT" ,
35+ "sintel_train_cleanpass_epe" : 1.4411 ,
36+ "sintel_train_finalpass_epe" : 2.7894 ,
37+ },
38+ )
39+
40+ C_T_V2 = Weights (
41+ # Chairs + Things
42+ url = "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth" ,
43+ transforms = RaftEval ,
44+ meta = {
45+ ** _COMMON_META ,
46+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/optical_flow" ,
47+ "sintel_train_cleanpass_epe" : 1.3822 ,
48+ "sintel_train_finalpass_epe" : 2.7161 ,
49+ },
50+ )
3651
3752 # C_T_SKHT_V1 = Weights(
3853 # # Chairs + Things + Sintel fine-tuning, i.e.:
@@ -59,7 +74,7 @@ class Raft_Large_Weights(WeightsEnum):
5974 # },
6075 # )
6176
62- # default = C_T_V1
77+ default = C_T_V2
6378
6479
6580class Raft_Small_Weights (WeightsEnum ):
@@ -75,13 +90,13 @@ class Raft_Small_Weights(WeightsEnum):
7590 # default = C_T_V1
7691
7792
78- @handle_legacy_interface (weights = ("pretrained" , None ))
93+ @handle_legacy_interface (weights = ("pretrained" , Raft_Large_Weights . C_T_V2 ))
7994def raft_large (* , weights : Optional [Raft_Large_Weights ] = None , progress = True , ** kwargs ):
8095 """RAFT model from
8196 `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
8297
8398 Args:
84- weights(Raft_Large_weights, optinal ): TODO not implemented yet
99+ weights(Raft_Large_weights, optional ): pretrained weights to use.
85100 progress (bool): If True, displays a progress bar of the download to stderr
86101 kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
87102 to override any default.
@@ -92,7 +107,7 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
92107
93108 weights = Raft_Large_Weights .verify (weights )
94109
95- return _raft (
110+ model = _raft (
96111 # Feature encoder
97112 feature_encoder_layers = (64 , 64 , 96 , 128 , 256 ),
98113 feature_encoder_block = ResidualBlock ,
@@ -119,6 +134,11 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, *
119134 ** kwargs ,
120135 )
121136
137+ if weights is not None :
138+ model .load_state_dict (weights .get_state_dict (progress = progress ))
139+
140+ return model
141+
122142
123143@handle_legacy_interface (weights = ("pretrained" , None ))
124144def raft_small (* , weights : Optional [Raft_Small_Weights ] = None , progress = True , ** kwargs ):
@@ -138,7 +158,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
138158
139159 weights = Raft_Small_Weights .verify (weights )
140160
141- return _raft (
161+ model = _raft (
142162 # Feature encoder
143163 feature_encoder_layers = (32 , 32 , 64 , 96 , 128 ),
144164 feature_encoder_block = BottleneckBlock ,
@@ -164,3 +184,7 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, *
164184 use_mask_predictor = False ,
165185 ** kwargs ,
166186 )
187+
188+ if weights is not None :
189+ model .load_state_dict (weights .get_state_dict (progress = progress ))
190+ return model
0 commit comments