|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +from torch.nn.modules.batchnorm import BatchNorm2d |
| 4 | +from torch.nn.modules.instancenorm import InstanceNorm2d |
| 5 | +from torchvision.models.optical_flow import RAFT |
| 6 | +from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock |
| 7 | + |
| 8 | +# from torchvision.prototype.transforms import RaftEval |
| 9 | + |
| 10 | +from .._api import WeightsEnum |
| 11 | + |
| 12 | +# from .._api import Weights |
| 13 | +from .._utils import handle_legacy_interface |
| 14 | + |
| 15 | + |
# Public API of this module: the RAFT model class, its two builder
# functions, and the (currently empty) weight enums for each size.
__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
    "Raft_Large_Weights",
    "Raft_Small_Weights",
)
| 23 | + |
| 24 | + |
class Raft_Large_Weights(WeightsEnum):
    """Pretrained weights for :func:`raft_large`.

    No weights are available yet: the enum is intentionally empty and the
    entries below are kept as commented-out placeholders (urls, recipes and
    epe metrics still TODO) until trained checkpoints are published.
    """

    pass
    # C_T_V1 = Weights(
    #     # Chairs + Things
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
    #     # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_K_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
    #     # Same as CT_SKHT with extra fine-tuning on Kitti
    #     # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # default = C_T_V1
| 63 | + |
| 64 | + |
class Raft_Small_Weights(WeightsEnum):
    """Pretrained weights for :func:`raft_small`.

    No weights are available yet: the enum is intentionally empty and the
    entry below is a commented-out placeholder until a trained checkpoint
    is published.
    """

    pass
    # C_T_V1 = Weights(
    #     url="",  # TODO
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # default = C_T_V1
| 76 | + |
| 77 | + |
@handle_legacy_interface(weights=("pretrained", None))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress: bool = True, **kwargs):
    """RAFT model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Large_Weights, optional): TODO not implemented yet
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
            to override any default.

    Returns:
        nn.Module: The model.
    """
    # Normalizes strings / legacy values into the enum; no checkpoint is
    # actually loaded yet (the enum is empty), so `progress` is currently
    # unused beyond matching the standard builder signature.
    weights = Raft_Large_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_block=ResidualBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_block=ResidualBlock,
        context_encoder_norm_layer=BatchNorm2d,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(256, 192),
        motion_encoder_flow_layers=(128, 64),
        motion_encoder_out_channels=128,
        # Recurrent block
        recurrent_block_hidden_state_size=128,
        recurrent_block_kernel_size=((1, 5), (5, 1)),
        recurrent_block_padding=((0, 2), (2, 0)),
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        use_mask_predictor=True,
        **kwargs,
    )
| 121 | + |
| 122 | + |
@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress: bool = True, **kwargs):
    """RAFT "small" model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Small_Weights, optional): TODO not implemented yet
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
            to override any default.

    Returns:
        nn.Module: The model.
    """
    # Normalizes strings / legacy values into the enum; no checkpoint is
    # actually loaded yet (the enum is empty), so `progress` is currently
    # unused beyond matching the standard builder signature.
    weights = Raft_Small_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(32, 32, 64, 96, 128),
        feature_encoder_block=BottleneckBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(32, 32, 64, 96, 160),
        context_encoder_block=BottleneckBlock,
        context_encoder_norm_layer=None,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=3,
        # Motion encoder
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        # Recurrent block
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        # Flow head
        flow_head_hidden_size=128,
        # Mask predictor
        use_mask_predictor=False,
        **kwargs,
    )
0 commit comments