|  | 
| 21 | 21 | 
 | 
| 22 | 22 | 
 | 
| 23 | 23 | _MODELS_URLS = { | 
| 24 |  | -    "raft_large": "https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth", | 
|  | 24 | +    "raft_large": "https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth", | 
| 25 | 25 |     "raft_small": "https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth", | 
| 26 | 26 | } | 
| 27 | 27 | 
 | 
| @@ -587,10 +587,16 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): | 
| 587 | 587 |     `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. | 
| 588 | 588 | 
 | 
| 589 | 589 |     Args: | 
| 590 |  | -        pretrained (bool): Whether to use pretrained weights. | 
| 591 |  | -        progress (bool): If True, displays a progress bar of the download to stderr | 
| 592 |  | -        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class | 
| 593 |  | -            to override any default. | 
|  | 590 | +        pretrained (bool): Whether to use weights that have been pre-trained on | 
|  | 591 | +            :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D` | 
|  | 592 | +            with two fine-tuning steps: | 
|  | 593 | +
 | 
|  | 594 | +            - one on :class:`~torchvsion.datasets.Sintel` + :class:`~torchvsion.datasets.FlyingThings3D` | 
|  | 595 | +            - one on :class:`~torchvsion.datasets.KittiFlow`. | 
|  | 596 | +
 | 
|  | 597 | +            This corresponds to the ``C+T+S/K`` strategy in the paper. | 
|  | 598 | +
 | 
|  | 599 | +        progress (bool): If True, displays a progress bar of the download to stderr. | 
| 594 | 600 | 
 | 
| 595 | 601 |     Returns: | 
| 596 | 602 |         nn.Module: The model. | 
| @@ -632,10 +638,9 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): | 
| 632 | 638 |     `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_. | 
| 633 | 639 | 
 | 
| 634 | 640 |     Args: | 
| 635 |  | -        pretrained (bool): Whether to use pretrained weights. | 
|  | 641 | +        pretrained (bool): Whether to use weights that have been pre-trained on | 
|  | 642 | +            :class:`~torchvsion.datasets.FlyingChairs` + :class:`~torchvsion.datasets.FlyingThings3D`. | 
| 636 | 643 |         progress (bool): If True, displays a progress bar of the download to stderr | 
| 637 |  | -        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class | 
| 638 |  | -            to override any default. | 
| 639 | 644 | 
 | 
| 640 | 645 |     Returns: | 
| 641 | 646 |         nn.Module: The model. | 
|  | 
0 commit comments