Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
d1ef242
add weighted kappa loss
wenmin-wu Dec 13, 2019
cbcec18
add unit tests
wenmin-wu Dec 13, 2019
4237f91
change some docs
wenmin-wu Dec 13, 2019
61cba30
change python files format
wenmin-wu Dec 19, 2019
e58fa6a
shorten some lines
wenmin-wu Dec 19, 2019
dcfc504
rename and update README and BUILD
wenmin-wu Dec 21, 2019
dfa27d9
Merge branch 'master' of github.com:wenmin-wu/addons
wenmin-wu Dec 21, 2019
a9f64b7
Merge branch 'master' into master
wenmin-wu Jan 7, 2020
ea9e106
Merge branch 'master' into wenmin-wu_master
gabrieldemarmiesse Feb 26, 2020
ed21d89
Merge branch 'master' into wenmin-wu_master
gabrieldemarmiesse Feb 26, 2020
fd03150
resolve conversations
wenmin-wu Mar 23, 2020
8d7bd15
Merge branch 'master' into master
wenmin-wu Mar 23, 2020
2dca416
resolve converstions
wenmin-wu Mar 23, 2020
ac06219
remove escape
wenmin-wu Mar 23, 2020
a950ce7
reformat tensorflow_addons/losses/kappa_loss* with black
wenmin-wu Mar 23, 2020
174e7f4
reformat code
wenmin-wu Mar 23, 2020
1836ae3
reformat code
wenmin-wu Mar 23, 2020
5055c9a
reformat code with black
wenmin-wu Mar 23, 2020
da72e82
Merge branch 'master' into wenmin-wu_master
gabrieldemarmiesse Apr 5, 2020
3715631
Update tensorflow_addons/losses/kappa_loss.py
wenmin-wu Apr 10, 2020
0b69dd6
[KappaLoss] change according to review
wenmin-wu Apr 10, 2020
58e6fd9
Merge branch 'master' of github.com:wenmin-wu/addons
wenmin-wu Apr 10, 2020
c8cb0c3
Update tensorflow_addons/losses/kappa_loss.py
wenmin-wu Apr 10, 2020
32a3a86
Update tensorflow_addons/losses/kappa_loss.py
wenmin-wu Apr 10, 2020
5020c1b
[KappaLoss] change accroding to code review
wenmin-wu Apr 10, 2020
238c671
Merge branch 'master' into master
wenmin-wu Apr 10, 2020
8cc4bb2
[KappaLoss] change code format
wenmin-wu Apr 10, 2020
1dea775
Merge branch 'master' of github.com:wenmin-wu/addons
wenmin-wu Apr 10, 2020
491396b
[SoftKappaLoss] mv kappa_loss_test.py to losses/tests
wenmin-wu Apr 12, 2020
a31f279
Update .github/CODEOWNERS
wenmin-wu Apr 12, 2020
44c0deb
[SoftKappaLoss] refine codes according to code review
wenmin-wu Apr 13, 2020
f557ef2
[SoftKappaLoss] reformat codes
wenmin-wu Apr 13, 2020
ae5d52e
[SoftKappaLoss] fix np_deep not defined
wenmin-wu Apr 13, 2020
42821bb
[SoftKappaLoss] fix tests problem
wenmin-wu Apr 13, 2020
ebfea4c
[SoftKappaLoss] unnecessary change to tigger CI
wenmin-wu Apr 13, 2020
6d6dc27
Merge branch 'master' into wenmin-wu_master
gabrieldemarmiesse Apr 13, 2020
8633927
Default value for the seed is not needed.
gabrieldemarmiesse Apr 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@
/tensorflow_addons/losses/tests/sparsemax_loss_test.py @andreasmadsen
/tensorflow_addons/losses/triplet.py @lc0
/tensorflow_addons/losses/tests/triplet_test.py @lc0
/tensorflow_addons/losses/kappa_loss.py @wenmin-wu
/tensorflow_addons/losses/tests/kappa_loss_test.py @wenmin-wu

/tensorflow_addons/metrics/cohens_kappa.py @aakashkumarnain
/tensorflow_addons/metrics/tests/cohens_kappa_test.py @aakashkumarnain
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
npairs_multilabel_loss,
NpairsMultilabelLoss,
)
from tensorflow_addons.losses.kappa_loss import WeightedKappaLoss
132 changes: 132 additions & 0 deletions tensorflow_addons/losses/kappa_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Weighted kappa loss."""

import tensorflow as tf
from tensorflow_addons.utils.types import Number
from typeguard import typechecked
from typing import Optional


@tf.keras.utils.register_keras_serializable(package="Addons")
class WeightedKappaLoss(tf.keras.losses.Loss):
"""Implements the Weighted Kappa loss function.

Weighted Kappa loss was introduced in the
[Weighted kappa loss function for multi-class classification
of ordinal data in deep learning]
(https://www.sciencedirect.com/science/article/abs/pii/S0167865517301666).
Weighted Kappa is widely used in Ordinal Classification Problems.
The loss value lies in [-inf, log 2], where log 2
means the random prediction.

Usage:

```python
kappa_loss = WeightedKappaLoss(num_classes=4)
y_true = tf.constant([[0, 0, 1, 0], [0, 1, 0, 0],
[1, 0, 0, 0], [0, 0, 0, 1]])
y_pred = tf.constant([[0.1, 0.2, 0.6, 0.1], [0.1, 0.5, 0.3, 0.1],
[0.8, 0.05, 0.05, 0.1], [0.01, 0.09, 0.1, 0.8]])
loss = kappa_loss(y_true, y_pred)
print('Loss: ', loss.numpy()) # Loss: -1.1611923
```

Usage with `tf.keras` API:
```python
# outputs should be softmax results
# if you want to weight the samples, just multiply the outputs
# by the sample weight.
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.WeightedKappa(num_classes=4))
```
"""

@typechecked
def __init__(
self,
num_classes: int,
weightage: Optional[str] = "quadratic",
name: Optional[str] = "cohen_kappa_loss",
epsilon: Optional[Number] = 1e-6,
dtype: Optional[tf.DType] = tf.float32,
reduction: str = tf.keras.losses.Reduction.NONE,
):
"""Creates a `WeightedKappa` instance.

Args:
num_classes: Number of unique classes in your dataset.
weightage: (Optional) Weighting to be considered for calculating
kappa statistics. A valid value is one of
['linear', 'quadratic']. Defaults to `quadratic` since it's
mostly used.
name: (Optional) String name of the metric instance.
epsilon: (Optional) increment to avoid log zero,
so the loss will be log(1 - k + epsilon), where k belongs to
[-1, 1], usually you can use the default value which is 1e-6.
dtype: (Optional) Data type of the metric result.
Defaults to `tf.float32`.
Raises:
ValueError: If the value passed for `weightage` is invalid
i.e. not any one of ['linear', 'quadratic']
"""

super().__init__(name=name, reduction=reduction)

if weightage not in ("linear", "quadratic"):
raise ValueError("Unknown kappa weighting type.")

self.weightage = weightage
self.num_classes = num_classes
self.epsilon = epsilon
self.dtype = dtype
label_vec = tf.range(num_classes, dtype=dtype)
self.row_label_vec = tf.reshape(label_vec, [1, num_classes])
self.col_label_vec = tf.reshape(label_vec, [num_classes, 1])
col_mat = tf.tile(self.col_label_vec, [1, num_classes])
row_mat = tf.tile(self.row_label_vec, [num_classes, 1])
if weightage == "linear":
self.weight_mat = tf.abs(col_mat - row_mat)
else:
self.weight_mat = (col_mat - row_mat) ** 2

def call(self, y_true, y_pred):
y_true = tf.cast(y_true, dtype=self.dtype)
batch_size = tf.shape(y_true)[0]
cat_labels = tf.matmul(y_true, self.col_label_vec)
cat_label_mat = tf.tile(cat_labels, [1, self.num_classes])
row_label_mat = tf.tile(self.row_label_vec, [batch_size, 1])
if self.weightage == "linear":
weight = tf.abs(cat_label_mat - row_label_mat)
else:
weight = (cat_label_mat - row_label_mat) ** 2
numerator = tf.reduce_sum(weight * y_pred)
label_dist = tf.reduce_sum(y_true, axis=0, keepdims=True)
pred_dist = tf.reduce_sum(y_pred, axis=0, keepdims=True)
w_pred_dist = tf.matmul(self.weight_mat, pred_dist, transpose_b=True)
denominator = tf.reduce_sum(tf.matmul(label_dist, w_pred_dist))
denominator /= tf.cast(batch_size, dtype=self.dtype)
loss = tf.math.divide_no_nan(numerator, denominator)
return tf.math.log(loss + self.epsilon)

def get_config(self):
config = {
"num_classes": self.num_classes,
"weightage": self.weightage,
"epsilon": self.epsilon,
"dtype": self.dtype,
}
base_config = super().get_config()
return {**base_config, **config}
92 changes: 92 additions & 0 deletions tensorflow_addons/losses/tests/kappa_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Weighted Kappa Loss."""

import pytest
import numpy as np
import tensorflow as tf
from tensorflow_addons.losses.kappa_loss import WeightedKappaLoss


def weighted_kappa_loss_np(y_true, y_pred, weightage="quadratic", eps=1e-6):
num_samples, num_classes = y_true.shape
cat_labels = y_true.argmax(axis=1).reshape((-1, 1))
label_mat = np.tile(cat_labels, (1, num_classes))
row_label_vec = np.arange(num_classes).reshape((1, num_classes))
label_mat_ = np.tile(row_label_vec, (num_samples, 1))
if weightage == "linear":
weight = np.abs(label_mat - label_mat_)
else:
weight = (label_mat - label_mat_) ** 2
numerator = (y_pred * weight).sum()
label_dist = y_true.sum(axis=0, keepdims=True)
pred_dist = y_pred.sum(axis=0, keepdims=True)

col_label_vec = row_label_vec.T
row_mat = np.tile(row_label_vec, (num_classes, 1))
col_mat = np.tile(col_label_vec, (1, num_classes))
if weightage == "quadratic":
weight_ = (col_mat - row_mat) ** 2
else:
weight_ = np.abs(col_mat - row_mat)
weighted_pred_dist = np.matmul(weight_, pred_dist.T)
denominator = np.matmul(label_dist, weighted_pred_dist).sum()
denominator /= num_samples
return np.log(np.nan_to_num(numerator / denominator) + eps)


def gen_labels_and_preds(num_samples, num_classes, seed):
np.random.seed(seed)
rands = np.random.uniform(size=(num_samples, num_classes))
cat_labels = rands.argmax(axis=1)
y_true = np.eye(num_classes, dtype="int")[cat_labels]
y_pred = np.random.uniform(size=(num_samples, num_classes))
y_pred /= y_pred.sum(axis=1, keepdims=True)
return y_true, y_pred


@pytest.mark.parametrize("np_seed", [0, 1, 2, 3])
def test_linear_weighted_kappa_loss(np_seed):
y_true, y_pred = gen_labels_and_preds(50, 4, np_seed)
kappa_loss = WeightedKappaLoss(num_classes=4, weightage="linear")
y_pred = y_pred.astype(kappa_loss.dtype.as_numpy_dtype)
loss = kappa_loss(y_true, y_pred)
loss_np = weighted_kappa_loss_np(y_true, y_pred, weightage="linear")
np.testing.assert_allclose(loss, loss_np, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize("np_seed", [0, 1, 2, 3])
def test_quadratic_weighted_kappa_loss(np_seed):
y_true, y_pred = gen_labels_and_preds(100, 3, np_seed)
kappa_loss = WeightedKappaLoss(num_classes=3)
y_pred = y_pred.astype(kappa_loss.dtype.as_numpy_dtype)
loss = kappa_loss(y_true, y_pred)
loss_np = weighted_kappa_loss_np(y_true, y_pred)
np.testing.assert_allclose(loss, loss_np, rtol=1e-5, atol=1e-5)


def test_config():
kappa_loss = WeightedKappaLoss(
num_classes=4, weightage="linear", name="kappa_loss", epsilon=0.001,
)
assert kappa_loss.num_classes == 4
assert kappa_loss.weightage == "linear"
assert kappa_loss.name == "kappa_loss"
np.testing.assert_allclose(kappa_loss.epsilon, 0.001, 1e-6)


def test_serialization():
loss = WeightedKappaLoss(num_classes=3)
tf.keras.losses.deserialize(tf.keras.losses.serialize(loss))