Skip to content

Commit b2a948a

Browse files
Joao Gomesfacebook-github-bot
authored andcommitted
[fbsync] Add raft_stereo weights (#6786)
Summary: * Add raft_stereo weights * Update the metrics layout Reviewed By: YosuaMichael Differential Revision: D40588172 fbshipit-source-id: 32c56534abc6f71effcf25ff11f906306cec34fe
1 parent 87c13ee commit b2a948a

File tree

1 file changed

+91
-3
lines changed

1 file changed

+91
-3
lines changed

torchvision/prototype/models/depth/stereo/raft_stereo.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
from functools import partial
12
from typing import Callable, List, Optional, Tuple
23

34
import torch
45
import torch.nn as nn
56
import torch.nn.functional as F
67
import torchvision.models.optical_flow.raft as raft
78
from torch import Tensor
8-
from torchvision.models._api import register_model, WeightsEnum
9+
from torchvision.models._api import register_model, Weights, WeightsEnum
910
from torchvision.models._utils import handle_legacy_interface
1011
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
1112
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
1213
from torchvision.ops import Conv2dNormActivation
14+
from torchvision.prototype.transforms._presets import StereoMatching
1315
from torchvision.utils import _log_api_usage_once
1416

1517

@@ -624,11 +626,97 @@ def _raft_stereo(
624626

625627

626628
class Raft_Stereo_Realtime_Weights(WeightsEnum):
627-
pass
629+
SCENEFLOW_V1 = Weights(
630+
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
631+
url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth",
632+
transforms=partial(StereoMatching, resize_size=(224, 224)),
633+
meta={
634+
"num_params": 8077152,
635+
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
636+
"_metrics": {
637+
# Following metrics from paper: https://arxiv.org/abs/2109.07547
638+
"Kitty2015": {
639+
"3px": 0.9409,
640+
}
641+
},
642+
},
643+
)
644+
645+
DEFAULT = SCENEFLOW_V1
628646

629647

630648
class Raft_Stereo_Base_Weights(WeightsEnum):
631-
pass
649+
SCENEFLOW_V1 = Weights(
650+
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
651+
url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth",
652+
transforms=partial(StereoMatching, resize_size=(224, 224)),
653+
meta={
654+
"num_params": 11116176,
655+
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
656+
"_metrics": {
657+
# Following metrics from paper: https://arxiv.org/abs/2109.07547
658+
# Using standard metrics for each datasets
659+
"Kitty2015": {
660+
# Ratio of pixels with difference less than 3px from ground truth
661+
"3px": 0.9426,
662+
},
663+
# For middlebury, ratio of pixels with difference less than 2px from ground truth
664+
# on full, half, and quarter image resolution
665+
"Middlebury2014-val-full": {
666+
"2px": 0.8167,
667+
},
668+
"Middlebury2014-val-half": {
669+
"2px": 0.8741,
670+
},
671+
"Middlebury2014-val-quarter": {
672+
"2px": 0.9064,
673+
},
674+
"ETH3D-val": {
675+
# Ratio of pixels with difference less than 1px from ground truth
676+
"1px": 0.9672,
677+
},
678+
},
679+
},
680+
)
681+
682+
MIDDLEBURY_V1 = Weights(
683+
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
684+
url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth",
685+
transforms=partial(StereoMatching, resize_size=(224, 224)),
686+
meta={
687+
"num_params": 11116176,
688+
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
689+
"_metrics": {
690+
# Following metrics from paper: https://arxiv.org/abs/2109.07547
691+
"Middlebury-test": {
692+
"mae": 1.27,
693+
"1px": 0.9063,
694+
"2px": 0.9526,
695+
"5px": 0.9725,
696+
}
697+
},
698+
},
699+
)
700+
701+
ETH3D_V1 = Weights(
702+
# Weights ported from https://github.com/princeton-vl/RAFT-Stereo
703+
url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth",
704+
transforms=partial(StereoMatching, resize_size=(224, 224)),
705+
meta={
706+
"num_params": 11116176,
707+
"recipe": "https://github.com/princeton-vl/RAFT-Stereo",
708+
"_metrics": {
709+
# Following metrics from paper: https://arxiv.org/abs/2109.07547
710+
"ETH3D-test": {
711+
"mae": 0.18,
712+
"1px": 0.9756,
713+
"2px": 0.9956,
714+
}
715+
},
716+
},
717+
)
718+
719+
DEFAULT = MIDDLEBURY_V1
632720

633721

634722
@register_model()

0 commit comments

Comments
 (0)