Skip to content

Commit 3a93746

Browse files
committed
Make metadata lazy for table row classes, replace container classes with dataclasses
1 parent 3655c70 commit 3a93746

File tree

8 files changed

+507
-492
lines changed

8 files changed

+507
-492
lines changed

python/CHANGELOG.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@
44

55
**Breaking changes**
66

7+
- `Mutation.position` and `Mutation.index` which were deprecated in 0.2.2 (Sep '19) have
8+
been removed.
9+
710
**Features**
811

12+
- Entity classes such as `Mutation`, `Node` are now python dataclasses
13+
(:user:`benjeffery`, :pr:`1261`).
14+
15+
- Metadata decoding for table row access is now lazy (:user:`benjeffery`, :pr:`1261`).
16+
917
**Fixes**
1018

1119
--------------------

python/tests/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,13 @@ def __init__(self, tree_sequence, breakpoints=None):
234234
def make_mutation(id_):
235235
site, node, derived_state, parent, metadata, time = ll_ts.get_mutation(id_)
236236
return tskit.Mutation(
237-
id_=id_,
237+
id=id_,
238238
site=site,
239239
node=node,
240240
time=time,
241241
derived_state=derived_state,
242242
parent=parent,
243-
encoded_metadata=metadata,
243+
metadata=metadata,
244244
metadata_decoder=tskit.metadata.parse_metadata_schema(
245245
ll_ts.get_table_metadata_schemas().mutation
246246
).decode_row,
@@ -250,11 +250,11 @@ def make_mutation(id_):
250250
pos, ancestral_state, ll_mutations, id_, metadata = ll_ts.get_site(j)
251251
self._sites.append(
252252
tskit.Site(
253-
id_=id_,
253+
id=id_,
254254
position=pos,
255255
ancestral_state=ancestral_state,
256256
mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations],
257-
encoded_metadata=metadata,
257+
metadata=metadata,
258258
metadata_decoder=tskit.metadata.parse_metadata_schema(
259259
ll_ts.get_table_metadata_schemas().site
260260
).decode_row,

python/tests/test_highlevel.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,17 +1239,7 @@ def verify_mutations(self, ts):
12391239
assert ts.num_mutations == len(other_mutations)
12401240
assert ts.num_mutations == len(mutations)
12411241
for mut, other_mut in zip(mutations, other_mutations):
1242-
# We cannot compare these directly as the mutations obtained
1243-
# from the mutations iterator will have extra deprecated
1244-
# attributes.
1245-
assert mut.id == other_mut.id
1246-
assert mut.site == other_mut.site
1247-
assert mut.parent == other_mut.parent
1248-
assert mut.node == other_mut.node
1249-
assert mut.metadata == other_mut.metadata
1250-
# Check the deprecated attrs.
1251-
assert mut.position == ts.site(mut.site).position
1252-
assert mut.index == mut.site
1242+
assert mut == other_mut
12531243

12541244
def test_sites_mutations(self):
12551245
# Check that the mutations iterator returns the correct values.
@@ -2084,17 +2074,7 @@ def verify_mutations(self, tree):
20842074
assert tree.num_mutations == len(other_mutations)
20852075
assert tree.num_mutations == len(mutations)
20862076
for mut, other_mut in zip(mutations, other_mutations):
2087-
# We cannot compare these directly as the mutations obtained
2088-
# from the mutations iterator will have extra deprecated
2089-
# attributes.
2090-
assert mut.id == other_mut.id
2091-
assert mut.site == other_mut.site
2092-
assert mut.parent == other_mut.parent
2093-
assert mut.node == other_mut.node
2094-
assert mut.metadata == other_mut.metadata
2095-
# Check the deprecated attrs.
2096-
assert mut.position == tree.tree_sequence.site(mut.site).position
2097-
assert mut.index == mut.site
2077+
assert mut == other_mut
20982078

20992079
def test_simple_mutations(self):
21002080
tree = self.get_tree()
@@ -2948,17 +2928,18 @@ def test_metadata(self):
29482928
(inst,) = self.get_instances(1)
29492929
(inst2,) = self.get_instances(1)
29502930
assert inst == inst2
2951-
inst._metadata_decoder = lambda m: "different decoder"
2931+
inst.metadata
29522932
assert inst == inst2
2953-
inst._encoded_metadata = b"different"
2954-
assert not (inst == inst2)
2933+
inst._metadata = "different"
2934+
assert inst != inst2
29552935

29562936
def test_decoder_run_once(self):
29572937
# For a given instance, the decoded metadata should be cached, with the decoder
29582938
# called once
29592939
(inst,) = self.get_instances(1)
29602940
times_run = 0
29612941

2942+
# Hack in a tracing decoder
29622943
def decoder(m):
29632944
nonlocal times_run
29642945
times_run += 1
@@ -2976,12 +2957,12 @@ class TestIndividualContainer(SimpleContainersMixin, SimpleContainersWithMetadat
29762957
def get_instances(self, n):
29772958
return [
29782959
tskit.Individual(
2979-
id_=j,
2960+
id=j,
29802961
flags=j,
29812962
location=[j],
29822963
parents=[j],
29832964
nodes=[j],
2984-
encoded_metadata=b"x" * j,
2965+
metadata=b"x" * j,
29852966
metadata_decoder=lambda m: m.decode() + "decoded",
29862967
)
29872968
for j in range(n)
@@ -2992,12 +2973,12 @@ class TestNodeContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin
29922973
def get_instances(self, n):
29932974
return [
29942975
tskit.Node(
2995-
id_=j,
2976+
id=j,
29962977
flags=j,
29972978
time=j,
29982979
population=j,
29992980
individual=j,
3000-
encoded_metadata=b"x" * j,
2981+
metadata=b"x" * j,
30012982
metadata_decoder=lambda m: m.decode() + "decoded",
30022983
)
30032984
for j in range(n)
@@ -3012,9 +2993,9 @@ def get_instances(self, n):
30122993
right=j,
30132994
parent=j,
30142995
child=j,
3015-
encoded_metadata=b"x" * j,
2996+
metadata=b"x" * j,
30162997
metadata_decoder=lambda m: m.decode() + "decoded",
3017-
id_=j,
2998+
id=j,
30182999
)
30193000
for j in range(n)
30203001
]
@@ -3024,11 +3005,11 @@ class TestSiteContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin
30243005
def get_instances(self, n):
30253006
return [
30263007
tskit.Site(
3027-
id_=j,
3008+
id=j,
30283009
position=j,
30293010
ancestral_state="A" * j,
30303011
mutations=TestMutationContainer().get_instances(j),
3031-
encoded_metadata=b"x" * j,
3012+
metadata=b"x" * j,
30323013
metadata_decoder=lambda m: m.decode() + "decoded",
30333014
)
30343015
for j in range(n)
@@ -3039,46 +3020,46 @@ class TestMutationContainer(SimpleContainersMixin, SimpleContainersWithMetadataM
30393020
def get_instances(self, n):
30403021
return [
30413022
tskit.Mutation(
3042-
id_=j,
3023+
id=j,
30433024
site=j,
30443025
node=j,
30453026
time=j,
30463027
derived_state="A" * j,
30473028
parent=j,
3048-
encoded_metadata=b"x" * j,
3029+
metadata=b"x" * j,
30493030
metadata_decoder=lambda m: m.decode() + "decoded",
30503031
)
30513032
for j in range(n)
30523033
]
30533034

30543035
def test_nan_equality(self):
30553036
a = tskit.Mutation(
3056-
id_=42,
3037+
id=42,
30573038
site=42,
30583039
node=42,
30593040
time=UNKNOWN_TIME,
30603041
derived_state="A" * 42,
30613042
parent=42,
3062-
encoded_metadata=b"x" * 42,
3043+
metadata=b"x" * 42,
30633044
metadata_decoder=lambda m: m.decode() + "decoded",
30643045
)
30653046
b = tskit.Mutation(
3066-
id_=42,
3047+
id=42,
30673048
site=42,
30683049
node=42,
30693050
derived_state="A" * 42,
30703051
parent=42,
3071-
encoded_metadata=b"x" * 42,
3052+
metadata=b"x" * 42,
30723053
metadata_decoder=lambda m: m.decode() + "decoded",
30733054
)
30743055
c = tskit.Mutation(
3075-
id_=42,
3056+
id=42,
30763057
site=42,
30773058
node=42,
30783059
time=math.nan,
30793060
derived_state="A" * 42,
30803061
parent=42,
3081-
encoded_metadata=b"x" * 42,
3062+
metadata=b"x" * 42,
30823063
metadata_decoder=lambda m: m.decode() + "decoded",
30833064
)
30843065
assert a == a
@@ -3096,13 +3077,14 @@ class TestMigrationContainer(SimpleContainersMixin, SimpleContainersWithMetadata
30963077
def get_instances(self, n):
30973078
return [
30983079
tskit.Migration(
3080+
id=j,
30993081
left=j,
31003082
right=j,
31013083
node=j,
31023084
source=j,
31033085
dest=j,
31043086
time=j,
3105-
encoded_metadata=b"x" * j,
3087+
metadata=b"x" * j,
31063088
metadata_decoder=lambda m: m.decode() + "decoded",
31073089
)
31083090
for j in range(n)
@@ -3113,8 +3095,8 @@ class TestPopulationContainer(SimpleContainersMixin, SimpleContainersWithMetadat
31133095
def get_instances(self, n):
31143096
return [
31153097
tskit.Population(
3116-
id_=j,
3117-
encoded_metadata=b"x" * j,
3098+
id=j,
3099+
metadata=b"x" * j,
31183100
metadata_decoder=lambda m: m.decode() + "decoded",
31193101
)
31203102
for j in range(n)
@@ -3124,7 +3106,7 @@ def get_instances(self, n):
31243106
class TestProvenanceContainer(SimpleContainersMixin):
31253107
def get_instances(self, n):
31263108
return [
3127-
tskit.Provenance(id_=j, timestamp="x" * j, record="y" * j) for j in range(n)
3109+
tskit.Provenance(id=j, timestamp="x" * j, record="y" * j) for j in range(n)
31283110
]
31293111

31303112

python/tests/test_stats.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,17 @@ def verify_max_distance(self, ts):
108108
A = ldc.get_r2_matrix()
109109
j = len(mutations) // 2
110110
for k in range(j):
111-
x = mutations[j + k].position - mutations[j].position
111+
x = (
112+
ts.site(mutations[j + k].site).position
113+
- ts.site(mutations[j].site).position
114+
)
112115
a = ldc.get_r2_array(j, max_distance=x)
113116
assert a.shape[0] == k
114117
assert np.allclose(A[j, j + 1 : j + 1 + k], a)
115-
x = mutations[j].position - mutations[j - k].position
118+
x = (
119+
ts.site(mutations[j].site).position
120+
- ts.site(mutations[j - k].site).position
121+
)
116122
a = ldc.get_r2_array(j, max_distance=x, direction=tskit.REVERSE)
117123
assert a.shape[0] == k
118124
assert np.allclose(A[j, j - k : j], a[::-1])

python/tskit/metadata.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
Classes for metadata decoding, encoding and validation
2424
"""
2525
import abc
26+
import builtins
2627
import collections
2728
import copy
2829
import json
2930
import pprint
3031
import struct
32+
import types
3133
from itertools import islice
3234
from typing import Any
3335
from typing import Mapping
@@ -39,6 +41,8 @@
3941
import tskit
4042
import tskit.exceptions as exceptions
4143

44+
__builtin__object__setattr__ = builtins.object.__setattr__
45+
4246

4347
def replace_root_refs(obj):
4448
if type(obj) == list:
@@ -656,3 +660,56 @@ def parse_metadata_schema(encoded_schema: str) -> MetadataSchema:
656660
except json.decoder.JSONDecodeError:
657661
raise ValueError(f"Metadata schema is not JSON, found {encoded_schema}")
658662
return MetadataSchema(decoded)
663+
664+
665+
class _CachedMetadata:
666+
"""
667+
Descriptor for lazy decoding of metadata on attribute access.
668+
"""
669+
670+
def __get__(self, row, owner):
671+
if row._metadata_decoder is not None:
672+
# Some classes that use this are frozen so we need to directly setattr.
673+
__builtin__object__setattr__(
674+
row, "_metadata", row._metadata_decoder(row._metadata)
675+
)
676+
# Decoder being None indicates that metadata is decoded
677+
__builtin__object__setattr__(row, "_metadata_decoder", None)
678+
return row._metadata
679+
680+
def __set__(self, row, value):
681+
__builtin__object__setattr__(row, "_metadata", value)
682+
683+
684+
def lazy_decode(cls):
685+
"""
686+
Modifies a dataclass such that it lazily decodes metadata, if it is encoded.
687+
If the metadata passed to the constructor is encoded a `decoder` parameter must be
688+
also be passed.
689+
"""
690+
wrapped_init = cls.__init__
691+
692+
# Intercept the init to catch the table reference needed for the decoder
693+
def new_init(self, *args, metadata_decoder=None, **kwargs):
694+
__builtin__object__setattr__(self, "_metadata_decoder", metadata_decoder)
695+
wrapped_init(self, *args, **kwargs)
696+
697+
cls.__init__ = new_init
698+
699+
# Add a descriptor to the class to decode and cache metadata
700+
cls.metadata = _CachedMetadata()
701+
702+
# Add slots needed to the class
703+
slots = cls.__slots__
704+
slots.extend(["_metadata", "_metadata_decoder"])
705+
dict_ = dict()
706+
sloted_members = dict()
707+
for k, v in cls.__dict__.items():
708+
if k not in slots:
709+
dict_[k] = v
710+
elif not isinstance(v, types.MemberDescriptorType):
711+
sloted_members[k] = v
712+
new_cls = type(cls.__name__, cls.__bases__, dict_)
713+
for k, v in sloted_members.items():
714+
setattr(new_cls, k, v)
715+
return new_cls

0 commit comments

Comments
 (0)