|
| 1 | +from functools import partial |
1 | 2 | from typing import Callable, List, Optional, Tuple |
2 | 3 |
|
3 | 4 | import torch |
4 | 5 | import torch.nn as nn |
5 | 6 | import torch.nn.functional as F |
6 | 7 | import torchvision.models.optical_flow.raft as raft |
7 | 8 | from torch import Tensor |
8 | | -from torchvision.models._api import register_model, WeightsEnum |
| 9 | +from torchvision.models._api import register_model, Weights, WeightsEnum |
9 | 10 | from torchvision.models._utils import handle_legacy_interface |
10 | 11 | from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow |
11 | 12 | from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock |
12 | 13 | from torchvision.ops import Conv2dNormActivation |
| 14 | +from torchvision.prototype.transforms._presets import StereoMatching |
13 | 15 | from torchvision.utils import _log_api_usage_once |
14 | 16 |
|
15 | 17 |
|
@@ -624,11 +626,97 @@ def _raft_stereo( |
624 | 626 |
|
625 | 627 |
|
626 | 628 | class Raft_Stereo_Realtime_Weights(WeightsEnum): |
627 | | - pass |
| 629 | + SCENEFLOW_V1 = Weights( |
| 630 | + # Weights ported from https://github.com/princeton-vl/RAFT-Stereo |
| 631 | + url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth", |
| 632 | + transforms=partial(StereoMatching, resize_size=(224, 224)), |
| 633 | + meta={ |
| 634 | + "num_params": 8077152, |
| 635 | + "recipe": "https://github.com/princeton-vl/RAFT-Stereo", |
| 636 | + "_metrics": { |
| 637 | + # Following metrics from paper: https://arxiv.org/abs/2109.07547 |
| 638 | + "Kitty2015": { |
| 639 | + "3px": 0.9409, |
| 640 | + } |
| 641 | + }, |
| 642 | + }, |
| 643 | + ) |
| 644 | + |
| 645 | + DEFAULT = SCENEFLOW_V1 |
628 | 646 |
|
629 | 647 |
|
630 | 648 | class Raft_Stereo_Base_Weights(WeightsEnum): |
631 | | - pass |
| 649 | + SCENEFLOW_V1 = Weights( |
| 650 | + # Weights ported from https://github.com/princeton-vl/RAFT-Stereo |
| 651 | + url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth", |
| 652 | + transforms=partial(StereoMatching, resize_size=(224, 224)), |
| 653 | + meta={ |
| 654 | + "num_params": 11116176, |
| 655 | + "recipe": "https://github.com/princeton-vl/RAFT-Stereo", |
| 656 | + "_metrics": { |
| 657 | + # Following metrics from paper: https://arxiv.org/abs/2109.07547 |
| 658 | + # Using standard metrics for each datasets |
| 659 | + "Kitty2015": { |
| 660 | + # Ratio of pixels with difference less than 3px from ground truth |
| 661 | + "3px": 0.9426, |
| 662 | + }, |
| 663 | + # For middlebury, ratio of pixels with difference less than 2px from ground truth |
| 664 | + # on full, half, and quarter image resolution |
| 665 | + "Middlebury2014-val-full": { |
| 666 | + "2px": 0.8167, |
| 667 | + }, |
| 668 | + "Middlebury2014-val-half": { |
| 669 | + "2px": 0.8741, |
| 670 | + }, |
| 671 | + "Middlebury2014-val-quarter": { |
| 672 | + "2px": 0.9064, |
| 673 | + }, |
| 674 | + "ETH3D-val": { |
| 675 | + # Ratio of pixels with difference less than 1px from ground truth |
| 676 | + "1px": 0.9672, |
| 677 | + }, |
| 678 | + }, |
| 679 | + }, |
| 680 | + ) |
| 681 | + |
| 682 | + MIDDLEBURY_V1 = Weights( |
| 683 | + # Weights ported from https://github.com/princeton-vl/RAFT-Stereo |
| 684 | + url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth", |
| 685 | + transforms=partial(StereoMatching, resize_size=(224, 224)), |
| 686 | + meta={ |
| 687 | + "num_params": 11116176, |
| 688 | + "recipe": "https://github.com/princeton-vl/RAFT-Stereo", |
| 689 | + "_metrics": { |
| 690 | + # Following metrics from paper: https://arxiv.org/abs/2109.07547 |
| 691 | + "Middlebury-test": { |
| 692 | + "mae": 1.27, |
| 693 | + "1px": 0.9063, |
| 694 | + "2px": 0.9526, |
| 695 | + "5px": 0.9725, |
| 696 | + } |
| 697 | + }, |
| 698 | + }, |
| 699 | + ) |
| 700 | + |
| 701 | + ETH3D_V1 = Weights( |
| 702 | + # Weights ported from https://github.com/princeton-vl/RAFT-Stereo |
| 703 | + url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth", |
| 704 | + transforms=partial(StereoMatching, resize_size=(224, 224)), |
| 705 | + meta={ |
| 706 | + "num_params": 11116176, |
| 707 | + "recipe": "https://github.com/princeton-vl/RAFT-Stereo", |
| 708 | + "_metrics": { |
| 709 | + # Following metrics from paper: https://arxiv.org/abs/2109.07547 |
| 710 | + "ETH3D-test": { |
| 711 | + "mae": 0.18, |
| 712 | + "1px": 0.9756, |
| 713 | + "2px": 0.9956, |
| 714 | + } |
| 715 | + }, |
| 716 | + }, |
| 717 | + ) |
| 718 | + |
| 719 | + DEFAULT = MIDDLEBURY_V1 |
632 | 720 |
|
633 | 721 |
|
634 | 722 | @register_model() |
|
0 commit comments