36 changes: 32 additions & 4 deletions src/transformers/activations.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math
from collections import OrderedDict

@@ -26,7 +27,8 @@
logger = logging.get_logger(__name__)


class PytorchGELUTanh(nn.Module):
Member:
Won't it be a breaking change? 👀

Contributor Author:
I don't think so, since the class is only used in the mapping at the end of the file.

Collaborator:
Not part of the public API, so it should be alright!
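(For context, a minimal sketch of how modeling code typically reaches these classes, which is why the rename stays internal: configs refer to activations by string key, and the key is resolved through the mapping at the end of `activations.py`. `ACT2FN` and `get_activation` are the existing entry points in `transformers.activations`; the exact lookup shown is illustrative.)

```python
from transformers.activations import ACT2FN, get_activation

# Configs refer to activations by string key, e.g. config.hidden_act = "gelu_pytorch_tanh".
# The key is resolved through the mapping at the end of activations.py, so the class name
# (PytorchGELUTanh vs. GELUTanh) never appears in user-facing code.
act = get_activation("gelu_pytorch_tanh")   # instance of the (renamed) GELUTanh class
hidden_act = ACT2FN["gelu_pytorch_tanh"]    # same lookup as used inside model files
```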

@use_kernel_forward_from_hub("GeluTanh")
class GELUTanh(nn.Module):
"""
A fast C implementation of the tanh approximation of the GeLU activation function. See
https://huggingface.co/papers/1606.08415.
@@ -35,8 +37,18 @@ class PytorchGELUTanh(nn.Module):
match due to rounding errors.
"""

def __init__(self, use_gelu_tanh_python: bool = False):
super().__init__()
if use_gelu_tanh_python:
self.act = self._gelu_tanh_python
else:
self.act = functools.partial(nn.functional.gelu, approximate="tanh")

def _gelu_tanh_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

def forward(self, input: Tensor) -> Tensor:
return nn.functional.gelu(input, approximate="tanh")
return self.act(input)


@use_kernel_forward_from_hub("NewGELU")
@@ -50,6 +62,7 @@ def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))


@use_kernel_forward_from_hub("GeLU")
class GELUActivation(nn.Module):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
@@ -72,6 +85,20 @@ def forward(self, input: Tensor) -> Tensor:
return self.act(input)


@use_kernel_forward_from_hub("SiLU")
class SiLUActivation(nn.Module):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""

def forward(self, input: Tensor) -> Tensor:
return nn.functional.silu(input)


@use_kernel_forward_from_hub("FastGELU")
class FastGELUActivation(nn.Module):
"""
@@ -290,7 +317,8 @@ def forward(self, input: Tensor) -> Tensor:
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"gelu_pytorch_tanh": PytorchGELUTanh,
"gelu_pytorch_tanh": GELUTanh,
"gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
"gelu_accurate": AccurateGELUActivation,
"laplace": LaplaceActivation,
"leaky_relu": nn.LeakyReLU,
@@ -301,7 +329,7 @@ def forward(self, input: Tensor) -> Tensor:
"relu2": ReLUSquaredActivation,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": nn.SiLU,
"silu": SiLUActivation,
"swish": nn.SiLU,
"tanh": nn.Tanh,
"prelu": nn.PReLU,
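For reference, a quick sanity check of the two code paths added above, assuming the new `gelu_python_tanh` key is available in the installed version: the pure-Python formula and `nn.functional.gelu(approximate="tanh")` implement the same tanh approximation, so they should agree up to rounding error, as the docstring notes.

```python
import torch
from transformers.activations import get_activation

x = torch.randn(8)

fused = get_activation("gelu_pytorch_tanh")   # functools.partial(F.gelu, approximate="tanh")
python = get_activation("gelu_python_tanh")   # explicit 0.5 * x * (1 + tanh(...)) formula

# Same approximation, two implementations; differences should be rounding-level only.
assert torch.allclose(fused(x), python(x), atol=1e-6)
```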
21 changes: 21 additions & 0 deletions src/transformers/integrations/hub_kernels.py
@@ -111,6 +111,27 @@
)
}
},
"SiLU": {
"cuda": {
Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation", layer_name="Silu", version=">=0.1.0"
)
}
},
"GeLU": {
"cuda": {
Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation", layer_name="Gelu", version=">=0.1.0"
)
}
},
"GeluTanh": {
"cuda": {
Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation", layer_name="GeluTanh", version=">=0.1.0"
)
}
},
}

register_kernel_mapping(_KERNEL_MAPPING)
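Roughly how these registrations are consumed (a sketch, not the PR's own test): the decorated activation classes keep their plain PyTorch `forward`, and a kernelization pass looks up the decorator name ("SiLU", "GeLU", "GeluTanh") in the registered mapping by device and mode to swap in the Hub kernel. The `kernelize` call and its exact signature below are assumptions based on the `kernels` library used above.

```python
import torch
from kernels import Mode, kernelize  # assumed entry point of the kernels library
from transformers.activations import get_activation

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)
        # SiLUActivation is decorated with @use_kernel_forward_from_hub("SiLU")
        self.act = get_activation("silu")

    def forward(self, x):
        return self.act(self.fc(x))

model = TinyBlock().to("cuda")
# For CUDA inference, "SiLU" resolves to kernels-community/activation:Silu per the mapping
# above; if no kernel matches (e.g. on CPU), the original forward is kept as a fallback.
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")  # signature assumed
```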