
Commit 20f0603

faster version of per_file channels
1 parent 9fea323 commit 20f0603

1 file changed

sed/loader/mpes/loader.py

Lines changed: 85 additions & 59 deletions
@@ -65,30 +65,26 @@ def hdf5_to_dataframe(
         seach_pattern="Stream",
     )
 
-    channel_list = []
+    electron_channels = []
     column_names = []
 
     for name, channel in channels.items():
-        if (
-            channel["format"] == "per_electron"
-            and channel["dataset_key"] in test_proc
-            or channel["format"] == "per_file"
-            and channel["dataset_key"] in test_proc.attrs
-        ):
-            channel_list.append(channel)
-            column_names.append(name)
-        else:
-            print(
-                f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
-                "Skipping the channel.",
-            )
+        if channel["format"] == "per_electron":
+            if channel["dataset_key"] in test_proc:
+                electron_channels.append(channel)
+                column_names.append(name)
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
 
     if time_stamps:
         column_names.append(time_stamp_alias)
 
     test_array = hdf5_to_array(
         h5file=test_proc,
-        channels=channel_list,
+        channels=electron_channels,
         time_stamps=time_stamps,
         ms_markers_key=ms_markers_key,
         first_event_time_stamp_key=first_event_time_stamp_key,
@@ -99,7 +95,7 @@ def hdf5_to_dataframe(
         da.from_delayed(
             dask.delayed(hdf5_to_array)(
                 h5file=h5py.File(f),
-                channels=channel_list,
+                channels=electron_channels,
                 time_stamps=time_stamps,
                 ms_markers_key=ms_markers_key,
                 first_event_time_stamp_key=first_event_time_stamp_key,
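
Each file is still read lazily here: dask.delayed wraps the reader, da.from_delayed turns the promised result into a dask array chunk, and the chunks are concatenated across files. A stripped-down sketch of that pattern, with a placeholder reader and hard-coded shapes standing in for the loader's HDF5 logic:

import dask
import dask.array as da
import numpy as np


@dask.delayed
def read_file(path: str) -> np.ndarray:
    # placeholder standing in for hdf5_to_array; the real code reads HDF5 datasets
    return np.ones((3, 100), dtype="float32")


files = ["a.h5", "b.h5"]  # hypothetical file list
arrays = [
    da.from_delayed(read_file(f), shape=(3, 100), dtype="float32")
    for f in files
]
array_stack = da.concatenate(arrays, axis=1).T  # rows = events, columns = channels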
@@ -111,7 +107,25 @@ def hdf5_to_dataframe(
     ]
     array_stack = da.concatenate(arrays, axis=1).T
 
-    return ddf.from_dask_array(array_stack, columns=column_names)
+    dataframe = ddf.from_dask_array(array_stack, columns=column_names)
+
+    for name, channel in channels.items():
+        if channel["format"] == "per_file":
+            if channel["dataset_key"] in test_proc.attrs:
+                values = [float(get_attribute(h5py.File(f), channel["dataset_key"])) for f in files]
+                delayeds = [
+                    add_value(partition, name, value)
+                    for partition, value in zip(dataframe.partitions, values)
+                ]
+                dataframe = ddf.from_delayed(delayeds)
+
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
+
+    return dataframe
 
 
 def hdf5_to_timed_dataframe(
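
The two hunks above are the heart of the change: per_file channels are no longer broadcast into a per-electron array inside the reader. Instead, the dataframe is built from per_electron columns only, and each per_file value is attached afterwards as a constant column on the partition belonging to its file. A minimal, self-contained sketch of that pattern (column name and values are illustrative, not from the loader):

import dask
import dask.dataframe as ddf
import pandas as pd


@dask.delayed
def add_value(partition: pd.DataFrame, name: str, value: float) -> pd.DataFrame:
    # each partition arrives as a concrete pandas object when the task runs
    partition[name] = value
    return partition


# two "files" worth of data, one partition per file
dataframe = ddf.from_pandas(pd.DataFrame({"x": range(6)}), npartitions=2)
values = [1.5, 2.5]  # hypothetical per-file scalar attributes

delayeds = [
    add_value(partition, "sample_bias", value)
    for partition, value in zip(dataframe.partitions, values)
]
dataframe = ddf.from_delayed(delayeds)
print(dataframe.compute())  # every row of a partition carries its file's value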
@@ -154,30 +168,26 @@ def hdf5_to_timed_dataframe(
         seach_pattern="Stream",
     )
 
-    channel_list = []
+    electron_channels = []
     column_names = []
 
     for name, channel in channels.items():
-        if (
-            channel["format"] == "per_electron"
-            and channel["dataset_key"] in test_proc
-            or channel["format"] == "per_file"
-            and channel["dataset_key"] in test_proc.attrs
-        ):
-            channel_list.append(channel)
-            column_names.append(name)
-        else:
-            print(
-                f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
-                "Skipping the channel.",
-            )
+        if channel["format"] == "per_electron":
+            if channel["dataset_key"] in test_proc:
+                electron_channels.append(channel)
+                column_names.append(name)
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
 
     if time_stamps:
         column_names.append(time_stamp_alias)
 
     test_array = hdf5_to_timed_array(
         h5file=test_proc,
-        channels=channel_list,
+        channels=electron_channels,
         time_stamps=time_stamps,
         ms_markers_key=ms_markers_key,
         first_event_time_stamp_key=first_event_time_stamp_key,
@@ -188,7 +198,7 @@ def hdf5_to_timed_dataframe(
         da.from_delayed(
             dask.delayed(hdf5_to_timed_array)(
                 h5file=h5py.File(f),
-                channels=channel_list,
+                channels=electron_channels,
                 time_stamps=time_stamps,
                 ms_markers_key=ms_markers_key,
                 first_event_time_stamp_key=first_event_time_stamp_key,
@@ -200,7 +210,41 @@ def hdf5_to_timed_dataframe(
     ]
     array_stack = da.concatenate(arrays, axis=1).T
 
-    return ddf.from_dask_array(array_stack, columns=column_names)
+    dataframe = ddf.from_dask_array(array_stack, columns=column_names)
+
+    for name, channel in channels.items():
+        if channel["format"] == "per_file":
+            if channel["dataset_key"] in test_proc.attrs:
+                values = [float(get_attribute(h5py.File(f), channel["dataset_key"])) for f in files]
+                delayeds = [
+                    add_value(partition, name, value)
+                    for partition, value in zip(dataframe.partitions, values)
+                ]
+                dataframe = ddf.from_delayed(delayeds)
+
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
+
+    return dataframe
+
+
+@dask.delayed
+def add_value(partition: ddf.DataFrame, name: str, value: float) -> ddf.DataFrame:
+    """Dask delayed helper function to add a value to each dataframe partition
+
+    Args:
+        partition (ddf.DataFrame): Dask dataframe partition
+        name (str): Name of the column to add
+        value (float): value to add to this partition
+
+    Returns:
+        ddf.DataFrame: Dataframe partition with added column
+    """
+    partition[name] = value
+    return partition
 
 
 def get_datasets_and_aliases(
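
The per_file loop reads each value with get_attribute(h5py.File(f), channel["dataset_key"]), and availability is checked against test_proc.attrs, so the value lives in the file's HDF5 attributes rather than in a dataset. A sketch of what such a helper plausibly does, assuming a scalar root-level attribute; the file name and attribute key below are hypothetical:

import h5py


def get_attribute(h5file: h5py.File, attribute: str):
    # read a scalar attribute from the file's root group; decode bytes so
    # float() works on string-typed values
    value = h5file.attrs[attribute]
    if isinstance(value, bytes):
        value = value.decode("utf-8")
    return value


with h5py.File("scan_0001.h5", "r") as h5file:  # hypothetical file
    bias = float(get_attribute(h5file, "SampleBias"))  # hypothetical key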
@@ -254,7 +298,7 @@ def hdf5_to_array(
     Args:
         h5file (h5py.File):
             hdf5 file handle to read from
-        electron_channels (Sequence[Dict[str, any]]):
+        channels (Sequence[Dict[str, any]]):
             channel dicts containing group names and types to read.
         time_stamps (bool, optional):
             Option to calculate time stamps. Defaults to False.
@@ -268,40 +312,25 @@ def hdf5_to_array(
     """
 
     # Delayed array for loading an HDF5 file of reasonable size (e.g. < 1GB)
-
-    # determine group length from per_electron column:
-    nelectrons = 0
-    for channel in channels:
-        if channel["format"] == "per_electron":
-            nelectrons = len(h5file[channel["dataset_key"]])
-            break
-    if nelectrons == 0:
-        raise ValueError("No 'per_electron' columns defined, or no hits found in file.")
-
     # Read out groups:
     data_list = []
     for channel in channels:
         if channel["format"] == "per_electron":
             g_dataset = np.asarray(h5file[channel["dataset_key"]])
-        elif channel["format"] == "per_file":
-            value = float(get_attribute(h5file, channel["dataset_key"]))
-            g_dataset = np.asarray([value] * nelectrons)
         else:
             raise ValueError(
                 f"Invalid 'format':{channel['format']} for channel {channel['dataset_key']}.",
             )
-        if "data_type" in channel.keys():
-            g_dataset = g_dataset.astype(channel["data_type"])
+        if "dtype" in channel.keys():
+            g_dataset = g_dataset.astype(channel["dtype"])
         else:
             g_dataset = g_dataset.astype("float32")
-        if len(g_dataset) != nelectrons:
-            raise ValueError(f"Inconsistent entries found for channel {channel['dataset_key']}.")
         data_list.append(g_dataset)
 
     # calculate time stamps
     if time_stamps:
         # create target array for time stamps
-        time_stamp_data = np.zeros(nelectrons)
+        time_stamp_data = np.zeros(len(data_list[0]))
         # the ms marker contains a list of events that occurred at full ms intervals.
         # It's monotonically increasing, and can contain duplicates
         ms_marker = np.asarray(h5file[ms_markers_key])
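
With the per_file branch gone, the time-stamp buffer is sized from the first per-electron column rather than the removed nelectrons counter. For orientation, an illustrative (not verbatim) sketch of how markers of this kind expand into per-event time stamps, assuming marker j holds the cumulative event count reached at millisecond j:

import numpy as np

start_time = 1656000000.0           # hypothetical first-event time stamp (s)
ms_marker = np.array([3, 5, 5, 9])  # cumulative events after each full ms
time_stamp_data = np.zeros(int(ms_marker[-1]))

start = 0
for ms, end in enumerate(ms_marker):
    # events falling between two markers share that millisecond's time stamp;
    # duplicate markers simply assign an empty slice
    time_stamp_data[start:int(end)] = start_time + ms / 1000.0
    start = int(end)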
@@ -355,7 +384,7 @@ def hdf5_to_timed_array(
     Args:
         h5file (h5py.File):
             hdf5 file handle to read from
-        electron_channels (Sequence[Dict[str, any]]):
+        channels (Sequence[Dict[str, any]]):
             channel dicts containing group names and types to read.
         time_stamps (bool, optional):
             Option to calculate time stamps. Defaults to False.
@@ -380,15 +409,12 @@ def hdf5_to_timed_array(
             g_dataset = np.asarray(h5file[channel["dataset_key"]])
             for i, point in enumerate(ms_marker):
                 timed_dataset[i] = g_dataset[int(point) - 1]
-        elif channel["format"] == "per_file":
-            value = float(get_attribute(h5file, channel["dataset_key"]))
-            timed_dataset[:] = value
         else:
             raise ValueError(
                 f"Invalid 'format':{channel['format']} for channel {channel['dataset_key']}.",
             )
-        if "data_type" in channel.keys():
-            timed_dataset = timed_dataset.astype(channel["data_type"])
+        if "dtype" in channel.keys():
+            timed_dataset = timed_dataset.astype(channel["dtype"])
         else:
             timed_dataset = timed_dataset.astype("float32")
 
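Both readers now look up a "dtype" key (previously "data_type") in each channel dict and fall back to float32 when it is absent. A hypothetical channel entry, for illustration only:

channel = {
    "format": "per_electron",   # the only format the array readers still accept
    "dataset_key": "Stream_0",  # hypothetical HDF5 dataset name
    "dtype": "int32",           # optional; omitted -> cast to "float32"
}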