@@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
 
 _USE_SWISH_OPT = True
 if _USE_SWISH_OPT:
-    class SwishAutoFn(torch.autograd.Function):
-        """ Memory Efficient Swish
-        From: https://blog.ceshine.net/post/pytorch-memory-swish/
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))
+
+
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by conversation btw Jeremy Howard & Adam Pazske
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
+
         @staticmethod
         def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
             ctx.save_for_backward(x)
-            return result
+            return swish_jit_fwd(x)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)
 
 
     def swish(x, inplace=False):
-        return SwishAutoFn.apply(x)
+        # inplace ignored
+        return SwishJitAutoFn.apply(x)
 else:
     def swish(x, inplace=False):
         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
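For reference, a minimal self-contained sketch (not part of the commit) of how the new backward path can be sanity-checked. It mirrors the swish_jit_fwd/swish_jit_bwd pair and SwishJitAutoFn introduced above rather than importing them from the patched file (the module path is not shown here), and runs torch.autograd.gradcheck against the hand-written backward. Only a standard PyTorch install is assumed.

# Self-contained sketch mirroring the patched Swish implementation.
import torch


@torch.jit.script
def swish_jit_fwd(x):
    # forward: x * sigmoid(x)
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    # analytic derivative: sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


# gradcheck compares the hand-written backward with numerical finite
# differences; double precision input is needed for it to pass reliably.
x = torch.randn(16, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(SwishJitAutoFn.apply, (x,)))  # expected: True

Passing gradcheck confirms that the derivative sigmoid(x) * (1 + x * (1 - sigmoid(x))) used in swish_jit_bwd is consistent with the forward x * sigmoid(x), so the JIT-scripted custom Function can drop the saved forward result without changing gradients.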