Skip to content

Commit afc0ec7

Browse files
committed
Make metadata lazy for table row classes, replace container classes with dataclasses
1 parent 1510d2a commit afc0ec7

File tree

8 files changed

+508
-492
lines changed

8 files changed

+508
-492
lines changed

python/CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@
44

55
**Breaking changes**
66

7+
- `Mutation.position` and `Mutation.index` which were deprecated in 0.2.2 (Sep '19) have
8+
been removed.
9+
710
**Features**
811

912
- SVG visualization of a single tree allows all mutations on an edge to be plotted
1013
via the ``all_edge_mutations`` param (:user:`hyanwong`,:issue:`1253`, :pr:`1258`).
1114

15+
- Entity classes such as `Mutation`, `Node` are now python dataclasses
16+
(:user:`benjeffery`, :pr:`1261`).
17+
18+
- Metadata decoding for table row access is now lazy (:user:`benjeffery`, :pr:`1261`).
19+
20+
1221
**Fixes**
1322

1423
--------------------

python/tests/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,13 @@ def __init__(self, tree_sequence, breakpoints=None):
234234
def make_mutation(id_):
235235
site, node, derived_state, parent, metadata, time = ll_ts.get_mutation(id_)
236236
return tskit.Mutation(
237-
id_=id_,
237+
id=id_,
238238
site=site,
239239
node=node,
240240
time=time,
241241
derived_state=derived_state,
242242
parent=parent,
243-
encoded_metadata=metadata,
243+
metadata=metadata,
244244
metadata_decoder=tskit.metadata.parse_metadata_schema(
245245
ll_ts.get_table_metadata_schemas().mutation
246246
).decode_row,
@@ -250,11 +250,11 @@ def make_mutation(id_):
250250
pos, ancestral_state, ll_mutations, id_, metadata = ll_ts.get_site(j)
251251
self._sites.append(
252252
tskit.Site(
253-
id_=id_,
253+
id=id_,
254254
position=pos,
255255
ancestral_state=ancestral_state,
256256
mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations],
257-
encoded_metadata=metadata,
257+
metadata=metadata,
258258
metadata_decoder=tskit.metadata.parse_metadata_schema(
259259
ll_ts.get_table_metadata_schemas().site
260260
).decode_row,

python/tests/test_highlevel.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,17 +1241,7 @@ def verify_mutations(self, ts):
12411241
assert ts.num_mutations == len(other_mutations)
12421242
assert ts.num_mutations == len(mutations)
12431243
for mut, other_mut in zip(mutations, other_mutations):
1244-
# We cannot compare these directly as the mutations obtained
1245-
# from the mutations iterator will have extra deprecated
1246-
# attributes.
1247-
assert mut.id == other_mut.id
1248-
assert mut.site == other_mut.site
1249-
assert mut.parent == other_mut.parent
1250-
assert mut.node == other_mut.node
1251-
assert mut.metadata == other_mut.metadata
1252-
# Check the deprecated attrs.
1253-
assert mut.position == ts.site(mut.site).position
1254-
assert mut.index == mut.site
1244+
assert mut == other_mut
12551245

12561246
def test_sites_mutations(self):
12571247
# Check that the mutations iterator returns the correct values.
@@ -2103,17 +2093,7 @@ def verify_mutations(self, tree):
21032093
assert tree.num_mutations == len(other_mutations)
21042094
assert tree.num_mutations == len(mutations)
21052095
for mut, other_mut in zip(mutations, other_mutations):
2106-
# We cannot compare these directly as the mutations obtained
2107-
# from the mutations iterator will have extra deprecated
2108-
# attributes.
2109-
assert mut.id == other_mut.id
2110-
assert mut.site == other_mut.site
2111-
assert mut.parent == other_mut.parent
2112-
assert mut.node == other_mut.node
2113-
assert mut.metadata == other_mut.metadata
2114-
# Check the deprecated attrs.
2115-
assert mut.position == tree.tree_sequence.site(mut.site).position
2116-
assert mut.index == mut.site
2096+
assert mut == other_mut
21172097

21182098
def test_simple_mutations(self):
21192099
tree = self.get_tree()
@@ -2991,17 +2971,18 @@ def test_metadata(self):
29912971
(inst,) = self.get_instances(1)
29922972
(inst2,) = self.get_instances(1)
29932973
assert inst == inst2
2994-
inst._metadata_decoder = lambda m: "different decoder"
2974+
inst.metadata
29952975
assert inst == inst2
2996-
inst._encoded_metadata = b"different"
2997-
assert not (inst == inst2)
2976+
inst._metadata = "different"
2977+
assert inst != inst2
29982978

29992979
def test_decoder_run_once(self):
30002980
# For a given instance, the decoded metadata should be cached, with the decoder
30012981
# called once
30022982
(inst,) = self.get_instances(1)
30032983
times_run = 0
30042984

2985+
# Hack in a tracing decoder
30052986
def decoder(m):
30062987
nonlocal times_run
30072988
times_run += 1
@@ -3019,12 +3000,12 @@ class TestIndividualContainer(SimpleContainersMixin, SimpleContainersWithMetadat
30193000
def get_instances(self, n):
30203001
return [
30213002
tskit.Individual(
3022-
id_=j,
3003+
id=j,
30233004
flags=j,
30243005
location=[j],
30253006
parents=[j],
30263007
nodes=[j],
3027-
encoded_metadata=b"x" * j,
3008+
metadata=b"x" * j,
30283009
metadata_decoder=lambda m: m.decode() + "decoded",
30293010
)
30303011
for j in range(n)
@@ -3035,12 +3016,12 @@ class TestNodeContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin
30353016
def get_instances(self, n):
30363017
return [
30373018
tskit.Node(
3038-
id_=j,
3019+
id=j,
30393020
flags=j,
30403021
time=j,
30413022
population=j,
30423023
individual=j,
3043-
encoded_metadata=b"x" * j,
3024+
metadata=b"x" * j,
30443025
metadata_decoder=lambda m: m.decode() + "decoded",
30453026
)
30463027
for j in range(n)
@@ -3055,9 +3036,9 @@ def get_instances(self, n):
30553036
right=j,
30563037
parent=j,
30573038
child=j,
3058-
encoded_metadata=b"x" * j,
3039+
metadata=b"x" * j,
30593040
metadata_decoder=lambda m: m.decode() + "decoded",
3060-
id_=j,
3041+
id=j,
30613042
)
30623043
for j in range(n)
30633044
]
@@ -3067,11 +3048,11 @@ class TestSiteContainer(SimpleContainersMixin, SimpleContainersWithMetadataMixin
30673048
def get_instances(self, n):
30683049
return [
30693050
tskit.Site(
3070-
id_=j,
3051+
id=j,
30713052
position=j,
30723053
ancestral_state="A" * j,
30733054
mutations=TestMutationContainer().get_instances(j),
3074-
encoded_metadata=b"x" * j,
3055+
metadata=b"x" * j,
30753056
metadata_decoder=lambda m: m.decode() + "decoded",
30763057
)
30773058
for j in range(n)
@@ -3082,46 +3063,46 @@ class TestMutationContainer(SimpleContainersMixin, SimpleContainersWithMetadataM
30823063
def get_instances(self, n):
30833064
return [
30843065
tskit.Mutation(
3085-
id_=j,
3066+
id=j,
30863067
site=j,
30873068
node=j,
30883069
time=j,
30893070
derived_state="A" * j,
30903071
parent=j,
3091-
encoded_metadata=b"x" * j,
3072+
metadata=b"x" * j,
30923073
metadata_decoder=lambda m: m.decode() + "decoded",
30933074
)
30943075
for j in range(n)
30953076
]
30963077

30973078
def test_nan_equality(self):
30983079
a = tskit.Mutation(
3099-
id_=42,
3080+
id=42,
31003081
site=42,
31013082
node=42,
31023083
time=UNKNOWN_TIME,
31033084
derived_state="A" * 42,
31043085
parent=42,
3105-
encoded_metadata=b"x" * 42,
3086+
metadata=b"x" * 42,
31063087
metadata_decoder=lambda m: m.decode() + "decoded",
31073088
)
31083089
b = tskit.Mutation(
3109-
id_=42,
3090+
id=42,
31103091
site=42,
31113092
node=42,
31123093
derived_state="A" * 42,
31133094
parent=42,
3114-
encoded_metadata=b"x" * 42,
3095+
metadata=b"x" * 42,
31153096
metadata_decoder=lambda m: m.decode() + "decoded",
31163097
)
31173098
c = tskit.Mutation(
3118-
id_=42,
3099+
id=42,
31193100
site=42,
31203101
node=42,
31213102
time=math.nan,
31223103
derived_state="A" * 42,
31233104
parent=42,
3124-
encoded_metadata=b"x" * 42,
3105+
metadata=b"x" * 42,
31253106
metadata_decoder=lambda m: m.decode() + "decoded",
31263107
)
31273108
assert a == a
@@ -3139,13 +3120,14 @@ class TestMigrationContainer(SimpleContainersMixin, SimpleContainersWithMetadata
31393120
def get_instances(self, n):
31403121
return [
31413122
tskit.Migration(
3123+
id=j,
31423124
left=j,
31433125
right=j,
31443126
node=j,
31453127
source=j,
31463128
dest=j,
31473129
time=j,
3148-
encoded_metadata=b"x" * j,
3130+
metadata=b"x" * j,
31493131
metadata_decoder=lambda m: m.decode() + "decoded",
31503132
)
31513133
for j in range(n)
@@ -3156,8 +3138,8 @@ class TestPopulationContainer(SimpleContainersMixin, SimpleContainersWithMetadat
31563138
def get_instances(self, n):
31573139
return [
31583140
tskit.Population(
3159-
id_=j,
3160-
encoded_metadata=b"x" * j,
3141+
id=j,
3142+
metadata=b"x" * j,
31613143
metadata_decoder=lambda m: m.decode() + "decoded",
31623144
)
31633145
for j in range(n)
@@ -3167,7 +3149,7 @@ def get_instances(self, n):
31673149
class TestProvenanceContainer(SimpleContainersMixin):
31683150
def get_instances(self, n):
31693151
return [
3170-
tskit.Provenance(id_=j, timestamp="x" * j, record="y" * j) for j in range(n)
3152+
tskit.Provenance(id=j, timestamp="x" * j, record="y" * j) for j in range(n)
31713153
]
31723154

31733155

python/tests/test_stats.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,17 @@ def verify_max_distance(self, ts):
108108
A = ldc.get_r2_matrix()
109109
j = len(mutations) // 2
110110
for k in range(j):
111-
x = mutations[j + k].position - mutations[j].position
111+
x = (
112+
ts.site(mutations[j + k].site).position
113+
- ts.site(mutations[j].site).position
114+
)
112115
a = ldc.get_r2_array(j, max_distance=x)
113116
assert a.shape[0] == k
114117
assert np.allclose(A[j, j + 1 : j + 1 + k], a)
115-
x = mutations[j].position - mutations[j - k].position
118+
x = (
119+
ts.site(mutations[j].site).position
120+
- ts.site(mutations[j - k].site).position
121+
)
116122
a = ldc.get_r2_array(j, max_distance=x, direction=tskit.REVERSE)
117123
assert a.shape[0] == k
118124
assert np.allclose(A[j, j - k : j], a[::-1])

python/tskit/metadata.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
Classes for metadata decoding, encoding and validation
2424
"""
2525
import abc
26+
import builtins
2627
import collections
2728
import copy
2829
import json
2930
import pprint
3031
import struct
32+
import types
3133
from itertools import islice
3234
from typing import Any
3335
from typing import Mapping
@@ -39,6 +41,8 @@
3941
import tskit
4042
import tskit.exceptions as exceptions
4143

44+
__builtin__object__setattr__ = builtins.object.__setattr__
45+
4246

4347
def replace_root_refs(obj):
4448
if type(obj) == list:
@@ -656,3 +660,56 @@ def parse_metadata_schema(encoded_schema: str) -> MetadataSchema:
656660
except json.decoder.JSONDecodeError:
657661
raise ValueError(f"Metadata schema is not JSON, found {encoded_schema}")
658662
return MetadataSchema(decoded)
663+
664+
665+
class _CachedMetadata:
666+
"""
667+
Descriptor for lazy decoding of metadata on attribute access.
668+
"""
669+
670+
def __get__(self, row, owner):
671+
if row._metadata_decoder is not None:
672+
# Some classes that use this are frozen so we need to directly setattr.
673+
__builtin__object__setattr__(
674+
row, "_metadata", row._metadata_decoder(row._metadata)
675+
)
676+
# Decoder being None indicates that metadata is decoded
677+
__builtin__object__setattr__(row, "_metadata_decoder", None)
678+
return row._metadata
679+
680+
def __set__(self, row, value):
681+
__builtin__object__setattr__(row, "_metadata", value)
682+
683+
684+
def lazy_decode(cls):
685+
"""
686+
Modifies a dataclass such that it lazily decodes metadata, if it is encoded.
687+
If the metadata passed to the constructor is encoded a `decoder` parameter must be
688+
also be passed.
689+
"""
690+
wrapped_init = cls.__init__
691+
692+
# Intercept the init to catch the table reference needed for the decoder
693+
def new_init(self, *args, metadata_decoder=None, **kwargs):
694+
__builtin__object__setattr__(self, "_metadata_decoder", metadata_decoder)
695+
wrapped_init(self, *args, **kwargs)
696+
697+
cls.__init__ = new_init
698+
699+
# Add a descriptor to the class to decode and cache metadata
700+
cls.metadata = _CachedMetadata()
701+
702+
# Add slots needed to the class
703+
slots = cls.__slots__
704+
slots.extend(["_metadata", "_metadata_decoder"])
705+
dict_ = dict()
706+
sloted_members = dict()
707+
for k, v in cls.__dict__.items():
708+
if k not in slots:
709+
dict_[k] = v
710+
elif not isinstance(v, types.MemberDescriptorType):
711+
sloted_members[k] = v
712+
new_cls = type(cls.__name__, cls.__bases__, dict_)
713+
for k, v in sloted_members.items():
714+
setattr(new_cls, k, v)
715+
return new_cls

0 commit comments

Comments
 (0)