3
3
4
4
from . import initialization as init
5
5
from .hub_mixin import SMPHubMixin
6
+ from .utils import is_torch_compiling
6
7
7
8
T = TypeVar ("T" , bound = "SegmentationModel" )
8
9
9
10
10
11
class SegmentationModel (torch .nn .Module , SMPHubMixin ):
11
12
"""Base class for all segmentation models."""
12
13
13
- # if model supports shape not divisible by 2 ^ n
14
- # set to False
14
+ _is_torch_scriptable = True
15
+ _is_torch_exportable = True
16
+ _is_torch_compilable = True
17
+
18
+ # if model supports shape not divisible by 2 ^ n set to False
15
19
requires_divisible_input_shape = True
16
20
17
21
# Fix type-hint for models, to avoid HubMixin signature
@@ -29,6 +33,9 @@ def check_input_shape(self, x):
29
33
"""Check if the input shape is divisible by the output stride.
30
34
If not, raise a RuntimeError.
31
35
"""
36
+ if not self .requires_divisible_input_shape :
37
+ return
38
+
32
39
h , w = x .shape [- 2 :]
33
40
output_stride = self .encoder .output_stride
34
41
if h % output_stride != 0 or w % output_stride != 0 :
@@ -50,11 +57,13 @@ def check_input_shape(self, x):
50
57
def forward (self , x ):
51
58
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
52
59
53
- if not torch .jit .is_tracing () and self .requires_divisible_input_shape :
60
+ if not (
61
+ torch .jit .is_scripting () or torch .jit .is_tracing () or is_torch_compiling ()
62
+ ):
54
63
self .check_input_shape (x )
55
64
56
65
features = self .encoder (x )
57
- decoder_output = self .decoder (* features )
66
+ decoder_output = self .decoder (features )
58
67
59
68
masks = self .segmentation_head (decoder_output )
60
69
@@ -81,3 +90,29 @@ def predict(self, x):
81
90
x = self .forward (x )
82
91
83
92
return x
93
+
94
+ def load_state_dict (self , state_dict , ** kwargs ):
95
+ # for compatibility of weights for
96
+ # timm- ported encoders with TimmUniversalEncoder
97
+ from segmentation_models_pytorch .encoders import TimmUniversalEncoder
98
+
99
+ if not isinstance (self .encoder , TimmUniversalEncoder ):
100
+ return super ().load_state_dict (state_dict , ** kwargs )
101
+
102
+ patterns = ["regnet" , "res2" , "resnest" , "mobilenetv3" , "gernet" ]
103
+
104
+ is_deprecated_encoder = any (
105
+ self .encoder .name .startswith (pattern ) for pattern in patterns
106
+ )
107
+
108
+ if is_deprecated_encoder :
109
+ keys = list (state_dict .keys ())
110
+ for key in keys :
111
+ new_key = key
112
+ if key .startswith ("encoder." ) and not key .startswith ("encoder.model." ):
113
+ new_key = "encoder.model." + key .removeprefix ("encoder." )
114
+ if "gernet" in self .encoder .name :
115
+ new_key = new_key .replace (".stages." , ".stages_" )
116
+ state_dict [new_key ] = state_dict .pop (key )
117
+
118
+ return super ().load_state_dict (state_dict , ** kwargs )
0 commit comments