Skip to content

Commit f3d6ee5

Browse files
Add testing module, with an function check_metric_serialization (#1295)
* Added testing module * Added an example in cohens_kappa_test.py * Should work with bazel now. * Changed import order. * Forgot the if __name__ == "__main__" * Use the new way of testing.
1 parent 1e9a398 commit f3d6ee5

File tree

9 files changed

+204
-1
lines changed

9 files changed

+204
-1
lines changed

tensorflow_addons/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ py_library(
2626
"//tensorflow_addons/optimizers",
2727
"//tensorflow_addons/rnn",
2828
"//tensorflow_addons/seq2seq",
29+
"//tensorflow_addons/testing",
2930
"//tensorflow_addons/text",
3031
"//tensorflow_addons/utils",
3132
],

tensorflow_addons/metrics/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ py_library(
1515
"utils.py",
1616
],
1717
deps = [
18+
"//tensorflow_addons/testing",
1819
"//tensorflow_addons/utils",
1920
],
2021
)

tensorflow_addons/metrics/cohens_kappa.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def __init__(
7171
sparse_labels: bool = False,
7272
regression: bool = False,
7373
dtype: AcceptableDTypes = None,
74-
**kwargs
7574
):
7675
"""Creates a `CohenKappa` instance.
7776

tensorflow_addons/metrics/cohens_kappa_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tensorflow as tf
2020
from tensorflow_addons.metrics import CohenKappa
2121
from tensorflow_addons.utils import test_utils
22+
from tensorflow_addons.testing.serialization import check_metric_serialization
2223

2324

2425
@test_utils.run_all_in_graph_and_eager_modes
@@ -232,3 +233,12 @@ def test_with_ohe_labels():
232233

233234
obj.update_state(y_true, y_pred)
234235
np.testing.assert_allclose(0.19999999, obj.result().numpy())
236+
237+
238+
def test_cohen_kappa_serialization():
239+
actuals = np.array([4, 4, 3, 3, 2, 2, 1, 1], dtype=np.int32)
240+
preds = np.array([1, 2, 4, 1, 3, 3, 4, 4], dtype=np.int32)
241+
weights = np.array([1, 1, 2, 5, 10, 2, 3, 3], dtype=np.int32)
242+
243+
ck = CohenKappa(num_classes=5, sparse_labels=True, weightage="quadratic")
244+
check_metric_serialization(ck, actuals, preds, weights)

tensorflow_addons/testing/BUILD

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
licenses(["notice"]) # Apache 2.0
2+
3+
package(default_visibility = ["//visibility:public"])
4+
5+
py_library(
6+
name = "testing",
7+
srcs = [
8+
"__init__.py",
9+
"serialization.py",
10+
],
11+
)
12+
13+
py_test(
14+
name = "serialization_test",
15+
size = "small",
16+
srcs = glob(["*_test.py"]),
17+
main = "run_all_test.py",
18+
deps = [
19+
":testing",
20+
],
21+
)

tensorflow_addons/testing/__init__.py

Whitespace-only changes.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pathlib import Path
2+
import sys
3+
4+
import pytest
5+
6+
if __name__ == "__main__":
7+
dirname = Path(__file__).absolute().parent
8+
sys.exit(pytest.main([str(dirname)]))
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Union
2+
import inspect
3+
4+
import numpy as np
5+
from tensorflow.keras.metrics import Metric
6+
import typeguard
7+
8+
9+
@typeguard.typechecked
10+
def check_metric_serialization(
11+
metric: Metric,
12+
y_true: Union[tuple, np.ndarray],
13+
y_pred: Union[tuple, np.ndarray],
14+
sample_weight: Union[tuple, np.ndarray, None] = None,
15+
strict: bool = True,
16+
):
17+
config = metric.get_config()
18+
class_ = metric.__class__
19+
20+
check_config(config, class_, strict)
21+
22+
metric_copy = class_(**config)
23+
metric_copy.set_weights(metric.get_weights())
24+
25+
if isinstance(y_true, tuple):
26+
y_true = get_random_array(y_true)
27+
if isinstance(y_pred, tuple):
28+
y_pred = get_random_array(y_pred)
29+
if isinstance(sample_weight, tuple) and sample_weight is not None:
30+
sample_weight = get_random_array(sample_weight)
31+
32+
# the behavior should be the same for the original and the copy
33+
if sample_weight is None:
34+
metric.update_state(y_true, y_pred)
35+
metric_copy.update_state(y_true, y_pred)
36+
else:
37+
metric.update_state(y_true, y_pred, sample_weight)
38+
metric_copy.update_state(y_true, y_pred, sample_weight)
39+
40+
assert_all_arrays_close(metric.get_weights(), metric_copy.get_weights())
41+
metric_result = metric.result().numpy()
42+
metric_copy_result = metric_copy.result().numpy()
43+
if metric_result != metric_copy_result:
44+
raise ValueError(
45+
"The original gave a result of {} after an "
46+
"`.update_states()` call, but the copy gave "
47+
"a result of {} after the same "
48+
"call.".format(metric_result, metric_copy_result)
49+
)
50+
51+
52+
def check_config(config, class_, strict):
53+
init_signature = inspect.signature(class_.__init__)
54+
55+
for parameter_name in init_signature.parameters:
56+
if parameter_name == "self":
57+
continue
58+
elif parameter_name == "args" and strict:
59+
raise KeyError(
60+
"Please do not use args in the class constructor of {}, "
61+
"as it hides the real signature "
62+
"and degrades the user experience. "
63+
"If you have no alternative to *args, "
64+
"use `strict=False` in check_metric_serialization.".format(
65+
class_.__name__
66+
)
67+
)
68+
elif parameter_name == "kwargs" and strict:
69+
raise KeyError(
70+
"Please do not use kwargs in the class constructor of {}, "
71+
"as it hides the real signature "
72+
"and degrades the user experience. "
73+
"If you have no alternative to **kwargs, "
74+
"use `strict=False` in check_metric_serialization.".format(
75+
class_.__name__
76+
)
77+
)
78+
if parameter_name not in config:
79+
raise KeyError(
80+
"The constructor parameter {} is not present in the config dict "
81+
"obtained with `.get_config()` of {}. All parameters should be set to "
82+
"ensure a perfect copy of the keras object can be obtained when "
83+
"serialized.".format(parameter_name, class_.__name__)
84+
)
85+
86+
87+
def assert_all_arrays_close(list1, list2):
88+
for array1, array2 in zip(list1, list2):
89+
np.testing.assert_allclose(array1, array2)
90+
91+
92+
def get_random_array(shape):
93+
return np.random.uniform(size=shape).astype(np.float32)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import numpy as np
2+
import pytest
3+
import tensorflow as tf
4+
5+
from tensorflow.keras.metrics import MeanAbsoluteError, TrueNegatives, Metric
6+
from tensorflow_addons.testing.serialization import check_metric_serialization
7+
8+
9+
def test_check_metric_serialization_mae():
10+
check_metric_serialization(MeanAbsoluteError(), (2, 2), (2, 2))
11+
check_metric_serialization(MeanAbsoluteError(name="hello"), (2, 2), (2, 2))
12+
check_metric_serialization(MeanAbsoluteError(), (2, 2, 2), (2, 2, 2))
13+
check_metric_serialization(MeanAbsoluteError(), (2, 2, 2), (2, 2, 2), (2, 2, 1))
14+
15+
16+
def get_random_booleans():
17+
return np.random.uniform(0, 2, size=(2, 2))
18+
19+
20+
def test_check_metric_serialization_true_negative():
21+
check_metric_serialization(
22+
TrueNegatives(0.8),
23+
np.random.uniform(0, 2, size=(2, 2)).astype(np.bool),
24+
np.random.uniform(0, 1, size=(2, 2)).astype(np.float32),
25+
)
26+
27+
28+
class MyDummyMetric(Metric):
29+
def __init__(self, stuff, name):
30+
super().__init__(name)
31+
self.stuff = stuff
32+
33+
def update_state(self, y_true, y_pred, sample_weights):
34+
pass
35+
36+
def get_config(self):
37+
return super().get_config()
38+
39+
def result(self):
40+
return 3
41+
42+
43+
def test_missing_arg():
44+
with pytest.raises(KeyError) as exception_info:
45+
check_metric_serialization(MyDummyMetric("dodo", "dada"), (2,), (2,))
46+
47+
assert "stuff" in str(exception_info.value)
48+
49+
50+
class MyOtherDummyMetric(Metric):
51+
def __init__(self, to_add, name=None, dtype=None):
52+
super().__init__(name, dtype)
53+
self.to_add = to_add
54+
self.sum_of_y_pred = self.add_weight(name="my_sum", initializer="zeros")
55+
56+
def update_state(self, y_true, y_pred, sample_weights=None):
57+
self.sum_of_y_pred.assign_add(tf.math.reduce_sum(y_pred) + self.to_add)
58+
59+
def get_config(self):
60+
config = {"to_add": self.to_add + 1}
61+
config.update(super().get_config())
62+
return config
63+
64+
def result(self):
65+
return self.sum_of_y_pred
66+
67+
68+
def test_wrong_serialization():
69+
with pytest.raises(AssertionError):
70+
check_metric_serialization(MyOtherDummyMetric(5), (2,), (2,))

0 commit comments

Comments
 (0)