
Commit 576d360

Bring in JIT version of optimized swish activation from gen_efficientnet as default (while working on feature extraction functionality here).
Parent: 1f39d15 · Commit: 576d360

1 file changed: 21 additions, 9 deletions

1 file changed

+21
-9
lines changed

timm/models/gen_efficientnet.py

@@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
 
 _USE_SWISH_OPT = True
 if _USE_SWISH_OPT:
-    class SwishAutoFn(torch.autograd.Function):
-        """ Memory Efficient Swish
-        From: https://blog.ceshine.net/post/pytorch-memory-swish/
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))
+
+
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by conversation between Jeremy Howard & Adam Paszke
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
+
        @staticmethod
        def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
            ctx.save_for_backward(x)
-            return result
+            return swish_jit_fwd(x)
 
        @staticmethod
        def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)
 
 
    def swish(x, inplace=False):
-        return SwishAutoFn.apply(x)
+        # inplace ignored
+        return SwishJitAutoFn.apply(x)
 else:
    def swish(x, inplace=False):
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
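For reference, the hand-written backward in swish_jit_bwd is the analytic derivative of swish. With \sigma denoting torch.sigmoid, the product and chain rules give

    f(x) = x\,\sigma(x)
    f'(x) = \sigma(x) + x\,\sigma(x)\,(1 - \sigma(x)) = \sigma(x)\,\bigl(1 + x\,(1 - \sigma(x))\bigr)

which is exactly x_sigmoid * (1 + x * (1 - x_sigmoid)) scaled by grad_output. As in the memory-efficient version it replaces, only x is saved for backward and the sigmoid is recomputed there, trading a little extra compute for activation memory; torch.jit.script additionally lets PyTorch fuse the elementwise ops in each direction.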

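A quick way to sanity-check a custom autograd Function like this is torch.autograd.gradcheck, which compares the hand-written backward against numerical gradients. A minimal sketch, assuming the module above is importable; the import path and test shape here are illustrative, not part of the commit:

    import torch
    # Illustrative import; swish is defined at module level in this file
    from timm.models.gen_efficientnet import swish

    # gradcheck wants double precision and requires_grad=True
    x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)

    # Numerically verify the hand-written JIT backward
    print(torch.autograd.gradcheck(swish, (x,)))  # prints True if gradients match

    # Forward should agree with the naive expression
    print(torch.allclose(swish(x), x * torch.sigmoid(x)))

gradcheck raises (or returns False with raise_exception=False) when the analytic and numerical gradients disagree, so it would catch a typo in swish_jit_bwd.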