Skip to content

Commit 3b264a6

Browse files
committed
Add encoder_kwargs to DeepLab and Linknet
1 parent 8bf52c7 commit 3b264a6

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
activation: Optional[str] = None,
5757
upsampling: int = 8,
5858
aux_params: Optional[dict] = None,
59+
encoder_kwargs: dict = {}
5960
):
6061
super().__init__()
6162

@@ -65,6 +66,7 @@ def __init__(
6566
depth=encoder_depth,
6667
weights=encoder_weights,
6768
output_stride=8,
69+
**encoder_kwargs
6870
)
6971

7072
self.decoder = DeepLabV3Decoder(
@@ -137,6 +139,7 @@ def __init__(
137139
activation: Optional[str] = None,
138140
upsampling: int = 4,
139141
aux_params: Optional[dict] = None,
142+
encoder_kwargs: dict = {}
140143
):
141144
super().__init__()
142145

@@ -149,6 +152,7 @@ def __init__(
149152
depth=encoder_depth,
150153
weights=encoder_weights,
151154
output_stride=encoder_output_stride,
155+
**encoder_kwargs
152156
)
153157

154158
self.decoder = DeepLabV3PlusDecoder(

segmentation_models_pytorch/decoders/linknet/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
classes: int = 1,
6262
activation: Optional[Union[str, callable]] = None,
6363
aux_params: Optional[dict] = None,
64+
encoder_kwargs: dict = {}
6465
):
6566
super().__init__()
6667

@@ -69,6 +70,7 @@ def __init__(
6970
in_channels=in_channels,
7071
depth=encoder_depth,
7172
weights=encoder_weights,
73+
**encoder_kwargs
7274
)
7375

7476
self.decoder = LinknetDecoder(

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44

55
class TimmUniversalEncoder(nn.Module):
6-
def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
6+
def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32, **kwargs):
77
super().__init__()
88
kwargs = dict(
9+
**kwargs,
910
in_chans=in_channels,
1011
features_only=True,
1112
output_stride=output_stride,

0 commit comments

Comments
 (0)