Skip to content

Commit d466cb8

Browse files
authored
#2066 doctest update losses (#2138)
* updated contrastive.py * updated focal_loss.py * updated giou_loss.py * updated kappa_loss.py * updated npairs.py * updated quantiles.py * reformatting giou_loss.py * updated testable docs contrastive.py * updated doctests focal_loss.py * updated testdocs giou_loss.py * updated testdocs kappa_loss.py * updated testdocs npairs.py * updated testdocs quantiles.py * minor changes to formatting * fixing multi-line error * adding empty line inbetween imports * adding empty line inbetween imports * reformatting imports * reverting code formatting changes * reformatting imports
1 parent 8d3789f commit d466cb8

File tree

6 files changed

+133
-88
lines changed

6 files changed

+133
-88
lines changed

tensorflow_addons/losses/contrastive.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
"""Implements contrastive loss."""
1616

1717
import tensorflow as tf
18+
from typeguard import typechecked
1819

1920
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
2021
from tensorflow_addons.utils.types import TensorLike, Number
21-
from typeguard import typechecked
2222

2323

2424
@tf.keras.utils.register_keras_serializable(package="Addons")
@@ -36,10 +36,19 @@ def contrastive_loss(
3636
`a` and `b` with shape `[batch_size, hidden_size]` can be computed
3737
as follows:
3838
39-
```python
40-
# y_pred = \sqrt (\sum_i (a[:, i] - b[:, i])^2)
41-
y_pred = tf.linalg.norm(a - b, axis=1)
42-
```
39+
>>> a = tf.constant([[1, 2],
40+
... [3, 4],
41+
... [5, 6]], dtype=tf.float16)
42+
>>> b = tf.constant([[5, 9],
43+
... [3, 6],
44+
... [1, 8]], dtype=tf.float16)
45+
>>> y_pred = tf.linalg.norm(a - b, axis=1)
46+
>>> y_pred
47+
<tf.Tensor: shape=(3,), dtype=float16, numpy=array([8.06 , 2. , 4.473],
48+
dtype=float16)>
49+
50+
<... Note: constants a & b have been used purely for
51+
example purposes and have no significant value ...>
4352
4453
See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
4554
@@ -79,10 +88,17 @@ class ContrastiveLoss(LossFunctionWrapper):
7988
`a` and `b` with shape `[batch_size, hidden_size]` can be computed
8089
as follows:
8190
82-
```python
83-
# y_pred = \sqrt (\sum_i (a[:, i] - b[:, i])^2)
84-
y_pred = tf.linalg.norm(a - b, axis=1)
85-
```
91+
>>> a = tf.constant([[1, 2],
92+
... [3, 4],[5, 6]], dtype=tf.float16)
93+
>>> b = tf.constant([[5, 9],
94+
... [3, 6],[1, 8]], dtype=tf.float16)
95+
>>> y_pred = tf.linalg.norm(a - b, axis=1)
96+
>>> y_pred
97+
<tf.Tensor: shape=(3,), dtype=float16, numpy=array([8.06 , 2. , 4.473],
98+
dtype=float16)>
99+
100+
<... Note: constants a & b have been used purely for
101+
example purposes and have no significant value ...>
86102
87103
Args:
88104
margin: `Float`, margin term in the loss definition.

tensorflow_addons/losses/focal_loss.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
import tensorflow as tf
1818
import tensorflow.keras.backend as K
19+
from typeguard import typechecked
1920

2021
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
2122
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
22-
from typeguard import typechecked
2323

2424

2525
@tf.keras.utils.register_keras_serializable(package="Addons")
@@ -37,22 +37,17 @@ class SigmoidFocalCrossEntropy(LossFunctionWrapper):
3737
3838
Usage:
3939
40-
```python
41-
fl = tfa.losses.SigmoidFocalCrossEntropy()
42-
loss = fl(
43-
y_true = [[1.0], [1.0], [0.0]],
44-
y_pred = [[0.97], [0.91], [0.03]])
45-
print('Loss: ', loss.numpy()) # Loss: [6.8532745e-06,
46-
1.9097870e-04,
47-
2.0559824e-05]
48-
```
40+
>>> fl = tfa.losses.SigmoidFocalCrossEntropy()
41+
>>> loss = fl(
42+
... y_true = [[1.0], [1.0], [0.0]],y_pred = [[0.97], [0.91], [0.03]])
43+
>>> loss
44+
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([6.8532745e-06, 1.9097870e-04, 2.0559824e-05],
45+
dtype=float32)>
4946
5047
Usage with `tf.keras` API:
5148
52-
```python
53-
model = tf.keras.Model(inputs, outputs)
54-
model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())
55-
```
49+
>>> model = tf.keras.Model()
50+
>>> model.compile('sgd', loss=tfa.losses.SigmoidFocalCrossEntropy())
5651
5752
Args:
5853
alpha: balancing factor, default value is 0.25.

tensorflow_addons/losses/giou_loss.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
# ==============================================================================
1515
"""Implements GIoU loss."""
1616

17+
from typing import Optional
18+
1719
import tensorflow as tf
20+
from typeguard import typechecked
21+
1822
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
1923
from tensorflow_addons.utils.types import TensorLike
20-
from typing import Optional
21-
from typeguard import typechecked
2224

2325

2426
@tf.keras.utils.register_keras_serializable(package="Addons")
@@ -33,20 +35,17 @@ class GIoULoss(LossFunctionWrapper):
3335
3436
Usage:
3537
36-
```python
37-
gl = tfa.losses.GIoULoss()
38-
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
39-
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
40-
loss = gl(boxes1, boxes2)
41-
print('Loss: ', loss.numpy()) # Loss: [1.07500000298023224, 1.9333333373069763]
42-
```
38+
>>> gl = tfa.losses.GIoULoss()
39+
>>> boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
40+
>>> boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
41+
>>> loss = gl(boxes1, boxes2)
42+
>>> loss
43+
<tf.Tensor: shape=(), dtype=float32, numpy=1.5041667>
4344
4445
Usage with `tf.keras` API:
4546
46-
```python
47-
model = tf.keras.Model(inputs, outputs)
48-
model.compile('sgd', loss=tfa.losses.GIoULoss())
49-
```
47+
>>> model = tf.keras.Model()
48+
>>> model.compile('sgd', loss=tfa.losses.GIoULoss())
5049
5150
Args:
5251
mode: one of ['giou', 'iou'], decided to calculate GIoU or IoU loss.

tensorflow_addons/losses/kappa_loss.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,24 @@ class WeightedKappaLoss(tf.keras.losses.Loss):
3737
3838
Usage:
3939
40-
```python
41-
kappa_loss = WeightedKappaLoss(num_classes=4)
42-
y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
43-
[1, 0, 0, 0], [0, 0, 0, 1]])
44-
y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
45-
[0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
46-
loss = kappa_loss(y_true, y_pred)
47-
print('Loss: ', loss.numpy()) # Loss: -1.1611923
48-
```
40+
>>> kappa_loss = tfa.losses.WeightedKappaLoss(num_classes=4)
41+
>>> y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
42+
... [1, 0, 0, 0], [0, 0, 0, 1]])
43+
>>> y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
44+
... [0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
45+
>>> loss = kappa_loss(y_true, y_pred)
46+
>>> loss
47+
<tf.Tensor: shape=(), dtype=float32, numpy=-1.1611925>
4948
5049
Usage with `tf.keras` API:
51-
```python
52-
# outputs should be softmax results
53-
# if you want to weight the samples, just multiply the outputs
54-
# by the sample weight.
55-
model = tf.keras.Model(inputs, outputs)
56-
model.compile('sgd', loss=tfa.losses.WeightedKappa(num_classes=4))
57-
```
50+
51+
>>> model = tf.keras.Model()
52+
>>> model.compile('sgd', loss=tfa.losses.WeightedKappaLoss(num_classes=4))
53+
54+
<... outputs should be softmax results
55+
if you want to weight the samples, just multiply the outputs
56+
by the sample weight ...>
57+
5858
"""
5959

6060
@typechecked

tensorflow_addons/losses/npairs.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
"""Implements npairs loss."""
1616

1717
import tensorflow as tf
18+
from typeguard import typechecked
1819

1920
from tensorflow_addons.utils.types import TensorLike
20-
from typeguard import typechecked
2121

2222

2323
@tf.keras.utils.register_keras_serializable(package="Addons")
@@ -33,10 +33,21 @@ def npairs_loss(y_true: TensorLike, y_pred: TensorLike) -> tf.Tensor:
3333
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
3434
with shape `[batch_size, hidden_size]` can be computed as follows:
3535
36-
```python
37-
# y_pred = a * b^T
38-
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
39-
```
36+
>>> a = tf.constant([[1, 2],
37+
... [3, 4],
38+
... [5, 6]], dtype=tf.float16)
39+
>>> b = tf.constant([[5, 9],
40+
... [3, 6],
41+
... [1, 8]], dtype=tf.float16)
42+
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
43+
>>> y_pred
44+
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
45+
array([[23., 15., 17.],
46+
[51., 33., 35.],
47+
[79., 51., 53.]], dtype=float16)>
48+
49+
<... Note: constants a & b have been used purely for
50+
example purposes and have no significant value ...>
4051
4152
See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
4253
@@ -89,10 +100,21 @@ def npairs_multilabel_loss(y_true: TensorLike, y_pred: TensorLike) -> tf.Tensor:
89100
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
90101
with shape `[batch_size, hidden_size]` can be computed as follows:
91102
92-
```python
93-
# y_pred = a * b^T
94-
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
95-
```
103+
>>> a = tf.constant([[1, 2],
104+
... [3, 4],
105+
... [5, 6]], dtype=tf.float16)
106+
>>> b = tf.constant([[5, 9],
107+
... [3, 6],
108+
... [1, 8]], dtype=tf.float16)
109+
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
110+
>>> y_pred
111+
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
112+
array([[23., 15., 17.],
113+
[51., 33., 35.],
114+
[79., 51., 53.]], dtype=float16)>
115+
116+
<... Note: constants a & b have been used purely for
117+
example purposes and have no significant value ...>
96118
97119
See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
98120
@@ -139,10 +161,21 @@ class NpairsLoss(tf.keras.losses.Loss):
139161
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
140162
with shape `[batch_size, hidden_size]` can be computed as follows:
141163
142-
```python
143-
# y_pred = a * b^T
144-
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
145-
```
164+
>>> a = tf.constant([[1, 2],
165+
... [3, 4],
166+
... [5, 6]], dtype=tf.float16)
167+
>>> b = tf.constant([[5, 9],
168+
... [3, 6],
169+
... [1, 8]], dtype=tf.float16)
170+
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
171+
>>> y_pred
172+
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
173+
array([[23., 15., 17.],
174+
[51., 33., 35.],
175+
[79., 51., 53.]], dtype=float16)>
176+
177+
<... Note: constants a & b have been used purely for
178+
example purposes and have no significant value ...>
146179
147180
See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
148181
@@ -184,10 +217,21 @@ class NpairsMultilabelLoss(tf.keras.losses.Loss):
184217
The similarity matrix `y_pred` between two embedding matrices `a` and `b`
185218
with shape `[batch_size, hidden_size]` can be computed as follows:
186219
187-
```python
188-
# y_pred = a * b^T
189-
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
190-
```
220+
>>> a = tf.constant([[1, 2],
221+
... [3, 4],
222+
... [5, 6]], dtype=tf.float16)
223+
>>> b = tf.constant([[5, 9],
224+
... [3, 6],
225+
... [1, 8]], dtype=tf.float16)
226+
>>> y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
227+
>>> y_pred
228+
<tf.Tensor: shape=(3, 3), dtype=float16, numpy=
229+
array([[23., 15., 17.],
230+
[51., 33., 35.],
231+
[79., 51., 53.]], dtype=float16)>
232+
233+
<... Note: constants a & b have been used purely for
234+
example purposes and have no significant value ...>
191235
192236
See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
193237

tensorflow_addons/losses/quantiles.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,11 @@ def pinball_loss(
3535
See: https://en.wikipedia.org/wiki/Quantile_regression
3636
3737
Usage:
38-
```python
39-
loss = pinball_loss([0., 0., 1., 1.], [1., 1., 1., 0.], tau=.1)
4038
41-
# loss = max(0.1 * (y_true - y_pred), (0.1 - 1) * (y_true - y_pred))
42-
# = (0.9 + 0.9 + 0 + 0.1) / 4
43-
44-
print('Loss: ', loss.numpy()) # Loss: 0.475
45-
```
39+
>>> loss = tfa.losses.pinball_loss([0., 0., 1., 1.],
40+
... [1., 1., 1., 0.], tau=.1)
41+
>>> loss
42+
<tf.Tensor: shape=(), dtype=float32, numpy=0.475>
4643
4744
Args:
4845
y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`
@@ -84,22 +81,16 @@ class PinballLoss(LossFunctionWrapper):
8481
See: https://en.wikipedia.org/wiki/Quantile_regression
8582
8683
Usage:
87-
```python
88-
pinball = tfa.losses.PinballLoss(tau=.1)
89-
loss = pinball([0., 0., 1., 1.], [1., 1., 1., 0.])
90-
91-
# loss = max(0.1 * (y_true - y_pred), (0.1 - 1) * (y_true - y_pred))
92-
# = (0.9 + 0.9 + 0 + 0.1) / 4
9384
94-
print('Loss: ', loss.numpy()) # Loss: 0.475
95-
```
85+
>>> pinball = tfa.losses.PinballLoss(tau=.1)
86+
>>> loss = pinball([0., 0., 1., 1.], [1., 1., 1., 0.])
87+
>>> loss
88+
<tf.Tensor: shape=(), dtype=float32, numpy=0.475>
9689
9790
Usage with the `tf.keras` API:
9891
99-
```python
100-
model = tf.keras.Model(inputs, outputs)
101-
model.compile('sgd', loss=tfa.losses.PinballLoss(tau=.1))
102-
```
92+
>>> model = tf.keras.Model()
93+
>>> model.compile('sgd', loss=tfa.losses.PinballLoss(tau=.1))
10394
10495
Args:
10596
tau: (Optional) Float in [0, 1] or a tensor taking values in [0, 1] and

0 commit comments

Comments
 (0)