Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions torchvision/prototype/models/depth/stereo/raft_stereo.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from functools import partial
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import register_model, WeightsEnum
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation
from torchvision.prototype.transforms._presets import StereoMatching
from torchvision.utils import _log_api_usage_once


Expand Down Expand Up @@ -624,11 +626,97 @@ def _raft_stereo(


class Raft_Stereo_Realtime_Weights(WeightsEnum):
    # Pretrained checkpoints for the "realtime" RAFT-Stereo variant.
    #
    # Each member wraps a ``Weights`` record holding the checkpoint URL, the
    # preprocessing transform factory and a ``meta`` dict (parameter count,
    # training recipe link and evaluation metrics).
    SCENEFLOW_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_realtime-cf345ccb.pth",
        # ``partial`` defers construction of the preset: callers instantiate
        # the transform themselves, with the resize size already bound.
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 8077152,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                # NOTE(review): the dataset is usually spelled "KITTI 2015"; the key
                # is kept verbatim because consumers may look it up by this string.
                "Kitty2015": {
                    # Ratio of pixels with difference less than 3px from ground truth
                    "3px": 0.9409,
                }
            },
        },
    )

    # Alias: the SceneFlow checkpoint is the only one available for this variant.
    DEFAULT = SCENEFLOW_V1


class Raft_Stereo_Base_Weights(WeightsEnum):
    # Pretrained checkpoints for the "base" RAFT-Stereo variant.
    #
    # Three checkpoints are exposed (SceneFlow, Middlebury, ETH3D); all share the
    # same parameter count and preprocessing preset and differ only in the
    # checkpoint URL and the reported metrics. ``DEFAULT`` aliases the
    # Middlebury checkpoint.
    SCENEFLOW_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_sceneflow-eff3f2e6.pth",
        # ``partial`` defers construction of the preset: callers instantiate
        # the transform themselves, with the resize size already bound.
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                # Using the standard metric for each dataset
                # NOTE(review): the dataset is usually spelled "KITTI 2015"; the key
                # is kept verbatim because consumers may look it up by this string.
                "Kitty2015": {
                    # Ratio of pixels with difference less than 3px from ground truth
                    "3px": 0.9426,
                },
                # For middlebury, ratio of pixels with difference less than 2px from ground truth
                # on full, half, and quarter image resolution
                "Middlebury2014-val-full": {
                    "2px": 0.8167,
                },
                "Middlebury2014-val-half": {
                    "2px": 0.8741,
                },
                "Middlebury2014-val-quarter": {
                    "2px": 0.9064,
                },
                "ETH3D-val": {
                    # Ratio of pixels with difference less than 1px from ground truth
                    "1px": 0.9672,
                },
            },
        },
    )

    MIDDLEBURY_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_middlebury-afa9d252.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                # NOTE(review): "mae" is presumably mean absolute error and the
                # "Npx" entries the ratio of pixels within N px of ground truth,
                # as for the SceneFlow checkpoint -- confirm against the paper.
                "Middlebury-test": {
                    "mae": 1.27,
                    "1px": 0.9063,
                    "2px": 0.9526,
                    "5px": 0.9725,
                }
            },
        },
    )

    ETH3D_V1 = Weights(
        # Weights ported from https://github.com/princeton-vl/RAFT-Stereo
        url="https://download.pytorch.org/models/raft_stereo_base_eth3d-d4830f22.pth",
        transforms=partial(StereoMatching, resize_size=(224, 224)),
        meta={
            "num_params": 11116176,
            "recipe": "https://github.com/princeton-vl/RAFT-Stereo",
            "_metrics": {
                # Following metrics from paper: https://arxiv.org/abs/2109.07547
                "ETH3D-test": {
                    "mae": 0.18,
                    "1px": 0.9756,
                    "2px": 0.9956,
                }
            },
        },
    )

    # Alias: the Middlebury checkpoint is served when no specific weights are requested.
    DEFAULT = MIDDLEBURY_V1


@register_model()
Expand Down