A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F

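For context, a minimal usage sketch of the common interface described in the docstring above. This is not from the commit itself; the import path is an assumption and varies by timm version:

    import torch
    from timm.models.layers.activations_me import SwishMe  # assumed path; differs across timm versions

    x = torch.randn(4, 8, requires_grad=True)
    act = SwishMe(inplace=True)    # `inplace` is accepted for interface parity but not used
    y = act(x)                     # forward via the custom autograd Function; only x is saved for backward
    y.sum().backward()             # backward applies the hand-written swish_bwd gradient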
-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
    return x.mul(torch.sigmoid(x))


-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
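# Gradient used by swish_bwd above (for reference):
#   d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
#                         = sigmoid(x) * (1 + x * (1 - sigmoid(x)))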


-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Pazske
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """
@@ -37,123 +35,117 @@ def symbolic(g, x):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)


def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)


class SwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)


-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
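# Gradient used by mish_bwd above: with sp(x) = softplus(x) and sp'(x) = sigmoid(x),
#   d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh(sp(x))**2)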


-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)


def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)


class MishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishMe, self).__init__()

    def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)


-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
    return (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
    return grad_output * m
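# Gradient used by hard_sigmoid_bwd above: hard_sigmoid(x) = clamp(x + 3, 0, 6) / 6,
# so d/dx = 1/6 for -3 <= x <= 3 and 0 elsewhere.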


-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)


-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
    return x * (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= 3.)
    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
    return grad_output * m
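# Gradient used by hard_swish_bwd above: hard_swish(x) = x * clamp(x + 3, 0, 6) / 6,
# so d/dx = 0 for x < -3, (2x + 3) / 6 = x / 3 + 0.5 for -3 <= x <= 3, and 1 for x > 3.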


-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, self):
@@ -164,55 +156,53 @@ def symbolic(g, self):


def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)


class HardSwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishMe, self).__init__()

    def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)


-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= -2.)
    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
    return grad_output * m
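# Gradient used by hard_mish_bwd above: hard_mish(x) = 0.5 * x * clamp(x + 2, 0, 2),
# so d/dx = 0 for x < -2, x + 1 for -2 <= x <= 0, and 1 for x > 0.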


-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
    https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)


def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)


class HardMishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishMe, self).__init__()

    def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)