Skip to content

Commit 6619eba

Browse files
committed
[Python][UHI] Start introducing UHI serialization
1 parent c8e955e commit 6619eba

File tree

6 files changed

+124
-3
lines changed

6 files changed

+124
-3
lines changed

bindings/pyroot/pythonizations/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ set(py_sources
128128
ROOT/_pythonization/_uhi/tags.py
129129
ROOT/_pythonization/_uhi/indexing.py
130130
ROOT/_pythonization/_uhi/plotting.py
131+
ROOT/_pythonization/_uhi/serialization.py
131132
${PYROOT_EXTRA_PYTHON_SOURCES}
132133
)
133134

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_th1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _FillWithNumpyArray(self, *args):
249249
def pythonize_th1(klass):
250250
# Parameters:
251251
# klass: class to be pythonized
252-
from ROOT._pythonization._uhi.main import _add_indexing_features
252+
from ROOT._pythonization._uhi.main import _add_indexing_features, _add_serialization_features
253253

254254
# Support hist *= scalar
255255
klass.__imul__ = _imul
@@ -261,7 +261,8 @@ def pythonize_th1(klass):
261261
klass._Original_SetDirectory = klass.SetDirectory
262262
klass.SetDirectory = _SetDirectory_SetOwnership
263263

264-
# Add UHI indexing features
264+
# Add UHI indexing and serialization features
265265
_add_indexing_features(klass)
266+
_add_serialization_features(klass)
266267

267268
inject_clone_releasing_ownership(klass)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,14 @@ def _add_plotting_features(klass: Any) -> None:
6464
klass.counts = _counts
6565
klass.axes = property(_axes)
6666
klass.values = values_func_dict.get(klass.__name__, _values_default)
67+
68+
69+
"""
70+
Implementation of the serialization component of the UHI
71+
"""
72+
73+
74+
def _add_serialization_features(klass: Any) -> None:
75+
from .serialization import _to_uhi_
76+
77+
klass._to_uhi_ = _to_uhi_

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_uhi/plotting.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def circular(self) -> bool:
3838
def discrete(self) -> bool:
3939
return self._discrete
4040

41+
@property
42+
def underflow(self) -> bool:
43+
return True
44+
45+
@property
46+
def overflow(self) -> bool:
47+
return True
48+
4149

4250
class PlottableAxisBase(ABC):
4351
def __init__(self, tAxis: Any) -> None:
@@ -104,7 +112,7 @@ def _hasWeights(hist: Any) -> bool:
104112
def _axes(self) -> Tuple[Union[PlottableAxisContinuous, PlottableAxisDiscrete], ...]:
105113
return tuple(PlottableAxisFactory.create(_get_axis(self, i)) for i in range(self.GetDimension()))
106114

107-
115+
# TODO this is not correct?
108116
def _kind(self) -> Kind:
109117
return Kind.COUNT if not _hasWeights(self) else Kind.MEAN
110118

@@ -174,6 +182,8 @@ def _counts(self) -> np.typing.NDArray[Any]: # noqa: F821
174182
where=sum_of_weights_squared != 0,
175183
)
176184

185+
def _get_sum_of_weights(self) -> np.typing.NDArray[Any]: # noqa: F821
186+
return self.values()
177187

178188
def _get_sum_of_weights_squared(self) -> np.typing.NDArray[Any]: # noqa: F821
179189
import numpy as np
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Author: Silia Taider CERN 10/2025
2+
3+
################################################################################
4+
# Copyright (C) 1995-2025, Rene Brun and Fons Rademakers. #
5+
# All rights reserved. #
6+
# #
7+
# For the licensing terms see $ROOTSYS/LICENSE. #
8+
# For the list of contributors see $ROOTSYS/README/CREDITS. #
9+
################################################################################
10+
from __future__ import annotations
11+
12+
from typing import Any
13+
14+
import ROOT
15+
16+
from .plotting import PlottableAxisBase, _get_sum_of_weights, _get_sum_of_weights_squared, _hasWeights
17+
from .tags import _get_axis
18+
19+
"""
20+
Implementation of the serialization component of the UHI
21+
"""
22+
23+
24+
def _axis_to_dict(root_axis: ROOT.TAxis, uhi_axis: PlottableAxisBase) -> dict[str, Any]:
25+
return {
26+
"type": "regular",
27+
"lower": root_axis.GetBinLowEdge(root_axis.GetFirst()),
28+
"upper": root_axis.GetBinUpEdge(root_axis.GetLast()),
29+
"bins": root_axis.GetNbins(),
30+
"underflow": uhi_axis.traits.underflow,
31+
"overflow": uhi_axis.traits.overflow,
32+
"circular": uhi_axis.traits.circular,
33+
}
34+
35+
36+
def _storage_to_dict(hist: Any) -> dict[str, Any]:
37+
"""
38+
Logic:
39+
- If histogram is a profile (TProfile*) --> Kind="MEAN":
40+
- if histogram has Sumw2: type is weighted_mean_storage (if _hasWeights(hist))
41+
- else: storage type is mean_storage
42+
- Else (TH1*/TH2*/TH3*) --> Kind="COUNT":
43+
- if histogram has Sumw2: type is weighted_storage
44+
- else if histogram is TH*I: type is int_storage
45+
- else: type is double_storage
46+
"""
47+
storage_dict = {
48+
"values": hist.values(),
49+
}
50+
51+
if hist.kind == "MEAN":
52+
storage_dict["variances"] = hist.variances()
53+
54+
if _hasWeights(hist):
55+
storage_dict["type"] = "weighted_mean"
56+
storage_dict["sum_of_weights"] = _get_sum_of_weights(hist)
57+
storage_dict["sum_of_weights_squared"] = _get_sum_of_weights_squared(hist)
58+
else:
59+
storage_dict["type"] = "mean"
60+
storage_dict["counts"] = hist.counts()
61+
62+
else: # COUNT
63+
if _hasWeights(hist):
64+
storage_dict["type"] = "weighted"
65+
storage_dict["variances"] = hist.variances()
66+
else:
67+
if hist.ClassName().endswith("I"):
68+
storage_dict["type"] = "int"
69+
else:
70+
storage_dict["type"] = "double"
71+
72+
return storage_dict
73+
74+
75+
def _to_uhi_(self) -> dict[str, Any]:
76+
return {
77+
"uhi_schema": 1,
78+
"writer_info": {"ROOT": {"version": ROOT.__version__, "class": self.ClassName()}},
79+
"axes": [_axis_to_dict(_get_axis(self, i), self.axes[i]) for i in range(self.GetDimension())],
80+
"storage": _storage_to_dict(self),
81+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import json
2+
3+
import ROOT
4+
import uhi.io.json
5+
6+
import hist
7+
8+
h = ROOT.TH1D("h", "h", 10, -5, 5)
9+
h[...] = range(10)
10+
print("\nh=", h)
11+
print("values=", h.values())
12+
13+
ob = json.dumps(h, default=uhi.io.json.default)
14+
ir = json.loads(ob, object_hook=uhi.io.json.object_hook)
15+
16+
h_loaded = hist.Hist(ir)
17+
print("\nh_loaded =\n", h_loaded)

0 commit comments

Comments
 (0)