Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions tensorflow_addons/losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
"""Implements contrastive loss."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike, Number
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
Expand All @@ -36,10 +36,19 @@ def contrastive_loss(
`a` and `b` with shape `[batch_size, hidden_size]` can be computed
as follows:

```python
# y_pred = \sqrt (\sum_i (a[:, i] - b[:, i])^2)
y_pred = tf.linalg.norm(a - b, axis=1)
```
>>> a = tf.constant([[1, 2],
... [3, 4],
... [5, 6]], dtype=tf.float16)
>>> b = tf.constant([[5, 9],
... [3, 6],
... [1, 8]], dtype=tf.float16)
>>> y_pred = tf.linalg.norm(a - b, axis=1)
>>> y_pred
<tf.Tensor: shape=(3,), dtype=float16, numpy=array([8.06 , 2. , 4.473],
dtype=float16)>

<... Note: constants a & b have been used purely for
example purposes and have no significant value ...>

See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

Expand Down Expand Up @@ -79,10 +88,17 @@ class ContrastiveLoss(LossFunctionWrapper):
`a` and `b` with shape `[batch_size, hidden_size]` can be computed
as follows:

```python
# y_pred = \sqrt (\sum_i (a[:, i] - b[:, i])^2)
y_pred = tf.linalg.norm(a - b, axis=1)
```
>>> a = tf.constant([[1, 2],
... [3, 4],[5, 6]], dtype=tf.float16)
>>> b = tf.constant([[5, 9],
... [3, 6],[1, 8]], dtype=tf.float16)
>>> y_pred = tf.linalg.norm(a - b, axis=1)
>>> y_pred
<tf.Tensor: shape=(3,), dtype=float16, numpy=array([8.06 , 2. , 4.473],
dtype=float16)>

<... Note: constants a & b have been used purely for
example purposes and have no significant value ...>

Args:
margin: `Float`, margin term in the loss definition.
Expand Down
23 changes: 9 additions & 14 deletions tensorflow_addons/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

import tensorflow as tf
import tensorflow.keras.backend as K
from typeguard import typechecked

from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
Expand All @@ -37,22 +37,17 @@ class SigmoidFocalCrossEntropy(LossFunctionWrapper):

Usage:

```python
fl = tfa.losses.SigmoidFocalCrossEntropy()
loss = fl(
y_true = [[1.0], [1.0], [0.0]],
y_pred = [[0.97], [0.91], [0.03]])
print('Loss: ', loss.numpy()) # Loss: [6.8532745e-06,
1.9097870e-04,
2.0559824e-05]
```
>>> fl = tfa.losses.SigmoidFocalCrossEntropy()
>>> loss = fl(
... y_true = [[1.0], [1.0], [0.0]],y_pred = [[0.97], [0.91], [0.03]])
>>> loss
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([6.8532745e-06, 1.9097870e-04, 2.0559824e-05],
dtype=float32)>

Usage with `tf.keras` API:

```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())
```
>>> model = tf.keras.Model()
>>> model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())

Args:
alpha: balancing factor, default value is 0.25.
Expand Down
25 changes: 12 additions & 13 deletions tensorflow_addons/losses/giou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# ==============================================================================
"""Implements GIoU loss."""

from typing import Optional

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike
from typing import Optional
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
Expand All @@ -33,20 +35,17 @@ class GIoULoss(LossFunctionWrapper):

Usage:

```python
gl = tfa.losses.GIoULoss()
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
loss = gl(boxes1, boxes2)
print('Loss: ', loss.numpy()) # Loss: [1.07500000298023224, 1.9333333373069763]
```
>>> gl = tfa.losses.GIoULoss()
>>> boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
>>> boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
>>> loss = gl(boxes1, boxes2)
>>> loss
<tf.Tensor: shape=(), dtype=float32, numpy=1.5041667>

Usage with `tf.keras` API:

```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.GIoULoss())
```
>>> model = tf.keras.Model()
>>> model.compile('sgd', loss=tfa.losses.GIoULoss())

Args:
mode: one of ['giou', 'iou'], decided to calculate GIoU or IoU loss.
Expand Down
32 changes: 16 additions & 16 deletions tensorflow_addons/losses/kappa_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,24 @@ class WeightedKappaLoss(tf.keras.losses.Loss):

Usage:

```python
kappa_loss = WeightedKappaLoss(num_classes=4)
y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
[1, 0, 0, 0], [0, 0, 0, 1]])
y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
[0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
loss = kappa_loss(y_true, y_pred)
print('Loss: ', loss.numpy()) # Loss: -1.1611923
```
>>> kappa_loss = tfa.losses.WeightedKappaLoss(num_classes=4)
>>> y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
... [1, 0, 0, 0], [0, 0, 0, 1]])
>>> y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
... [0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
>>> loss = kappa_loss(y_true, y_pred)
>>> loss
<tf.Tensor: shape=(), dtype=float32, numpy=-1.1611925>

Usage with `tf.keras` API:
```python
# outputs should be softmax results
# if you want to weight the samples, just multiply the outputs
# by the sample weight.
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.WeightedKappa(num_classes=4))
```

>>> model = tf.keras.Model()
>>> model.compile('sgd', loss=tfa.losses.WeightedKappaLoss(num_classes=4))

<... outputs should be softmax results
if you want to weight the samples, just multiply the outputs
by the sample weight ...>

"""

@typechecked
Expand Down
78 changes: 61 additions & 17 deletions tensorflow_addons/losses/npairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"""Implements npairs loss."""

import tensorflow as tf
from typeguard import typechecked

from tensorflow_addons.utils.types import TensorLike
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
Expand All @@ -33,10 +33,21 @@ def npairs_loss(y_true: TensorLike, y_pred: TensorLike) -> tf.Tensor:
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```
>>> a = tf.constant([[1, 2],
... [3, 4],
... [5, 6]], dtype=tf.float16)
>>> b = tf.constant([[5, 9],
... [3, 6],
... [1, 8]], dtype=tf.float16)
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
>>> y_pred
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
array([[23., 15., 17.],
[51., 33., 35.],
[79., 51., 53.]], dtype=float16)>

<... Note: constants a & b have been used purely for
example purposes and have no significant value ...>

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Expand Down Expand Up @@ -89,10 +100,21 @@ def npairs_multilabel_loss(y_true: TensorLike, y_pred: TensorLike) -> tf.Tensor:
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```
>>> a = tf.constant([[1, 2],
... [3, 4],
... [5, 6]], dtype=tf.float16)
>>> b = tf.constant([[5, 9],
... [3, 6],
... [1, 8]], dtype=tf.float16)
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
>>> y_pred
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
array([[23., 15., 17.],
[51., 33., 35.],
[79., 51., 53.]], dtype=float16)>

<... Note: constants a & b have been used purely for
example purposes and have no significant value ...>

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Expand Down Expand Up @@ -139,10 +161,21 @@ class NpairsLoss(tf.keras.losses.Loss):
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```
>>> a = tf.constant([[1, 2],
... [3, 4],
... [5, 6]], dtype=tf.float16)
>>> b = tf.constant([[5, 9],
... [3, 6],
... [1, 8]], dtype=tf.float16)
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
>>> y_pred
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
array([[23., 15., 17.],
[51., 33., 35.],
[79., 51., 53.]], dtype=float16)>

<... Note: constants a & b have been used purely for
example purposes and have no significant value ...>

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Expand Down Expand Up @@ -184,10 +217,21 @@ class NpairsMultilabelLoss(tf.keras.losses.Loss):
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```
>>> a = tf.constant([[1, 2],
... [3, 4],
... [5, 6]], dtype=tf.float16)
>>> b = tf.constant([[5, 9],
... [3, 6],
... [1, 8]], dtype=tf.float16)
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
>>> y_pred
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
array([[23., 15., 17.],
[51., 33., 35.],
[79., 51., 53.]], dtype=float16)>

<... Note: constants a & b have been used purely for
example purposes and have no significant value ...>

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Expand Down
29 changes: 10 additions & 19 deletions tensorflow_addons/losses/quantiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,11 @@ def pinball_loss(
See: https://en.wikipedia.org/wiki/Quantile_regression

Usage:
```python
loss = pinball_loss([0., 0., 1., 1.], [1., 1., 1., 0.], tau=.1)

# loss = max(0.1 * (y_true - y_pred), (0.1 - 1) * (y_true - y_pred))
# = (0.9 + 0.9 + 0 + 0.1) / 4

print('Loss: ', loss.numpy()) # Loss: 0.475
```
>>> loss = tfa.losses.pinball_loss([0., 0., 1., 1.],
... [1., 1., 1., 0.], tau=.1)
>>> loss
<tf.Tensor: shape=(), dtype=float32, numpy=0.475>

Args:
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`
Expand Down Expand Up @@ -84,22 +81,16 @@ class PinballLoss(LossFunctionWrapper):
See: https://en.wikipedia.org/wiki/Quantile_regression

Usage:
```python
pinball = tfa.losses.PinballLoss(tau=.1)
loss = pinball([0., 0., 1., 1.], [1., 1., 1., 0.])

# loss = max(0.1 * (y_true - y_pred), (0.1 - 1) * (y_true - y_pred))
# = (0.9 + 0.9 + 0 + 0.1) / 4

print('Loss: ', loss.numpy()) # Loss: 0.475
```
>>> pinball = tfa.losses.PinballLoss(tau=.1)
>>> loss = pinball([0., 0., 1., 1.], [1., 1., 1., 0.])
>>> loss
<tf.Tensor: shape=(), dtype=float32, numpy=0.475>

Usage with the `tf.keras` API:

```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.PinballLoss(tau=.1))
```
>>> model = tf.keras.Model()
>>> model.compile('sgd', loss=tfa.losses.PinballLoss(tau=.1))

Args:
tau: (Optional) Float in [0, 1] or a tensor taking values in [0, 1] and
Expand Down