
Commit 0004113

Merge pull request #431 from OpenCOMPES/energy_calibration_performance_fix
faster version of per_file channels
2 parents ffe2013 + 6854d39 commit 0004113

2 files changed: +94 −101 lines changed

sed/core/dfops.py

Lines changed: 9 additions & 42 deletions
@@ -390,53 +390,20 @@ def offset_by_other_columns(
             "Please open a request on GitHub if this feature is required.",
         )
 
-    # calculate the mean of the columns to reduce
-    means = {
-        col: dask.delayed(df[col].mean())
-        for col, red, pm in zip(offset_columns, reductions, preserve_mean)
-        if red or pm
-    }
-
-    # define the functions to apply the offsets
-    def shift_by_mean(x, cols, signs, means, flip_signs=False):
-        """Shift the target column by the mean of the offset columns."""
-        for col in cols:
-            s = -signs[col] if flip_signs else signs[col]
-            x[target_column] = x[target_column] + s * means[col]
-        return x[target_column]
-
-    def shift_by_row(x, cols, signs):
-        """Apply the offsets to the target column."""
-        for col in cols:
-            x[target_column] = x[target_column] + signs[col] * x[col]
-        return x[target_column]
-
     # apply offset from the reduced columns
-    df[target_column] = df.map_partitions(
-        shift_by_mean,
-        cols=[col for col, red in zip(offset_columns, reductions) if red],
-        signs=signs_dict,
-        means=means,
-        meta=df[target_column].dtype,
-    )
+    for col, red in zip(offset_columns, reductions):
+        if red == "mean":
+            df[target_column] = df[target_column] + signs_dict[col] * df[col].mean()
 
     # apply offset from the offset columns
-    df[target_column] = df.map_partitions(
-        shift_by_row,
-        cols=[col for col, red in zip(offset_columns, reductions) if not red],
-        signs=signs_dict,
-        meta=df[target_column].dtype,
-    )
+    for col, red in zip(offset_columns, reductions):
+        if not red:
+            df[target_column] = df[target_column] + signs_dict[col] * df[col]
 
     # compensate shift from the preserved mean columns
     if any(preserve_mean):
-        df[target_column] = df.map_partitions(
-            shift_by_mean,
-            cols=[col for col, pmean in zip(offset_columns, preserve_mean) if pmean],
-            signs=signs_dict,
-            means=means,
-            flip_signs=True,
-            meta=df[target_column].dtype,
-        )
+        for col, pmean in zip(offset_columns, preserve_mean):
+            if pmean:
+                df[target_column] = df[target_column] - signs_dict[col] * df[col].mean()
 
     return df
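
The rewrite above trades the map_partitions helpers for plain Dask column arithmetic: df[col].mean() stays a lazy scalar, so each offset is folded into the task graph without custom per-partition functions. A minimal sketch of that pattern with made-up column names (not the sed API):

    # Lazy column arithmetic in dask.dataframe (toy data, hypothetical names).
    import dask.dataframe as ddf
    import pandas as pd

    pdf = pd.DataFrame({"energy": [1.0, 2.0, 3.0, 4.0], "bias": [0.5, 0.5, 1.5, 1.5]})
    df = ddf.from_pandas(pdf, npartitions=2)

    sign = -1
    # "mean" reduction: shift by the (lazily computed) column mean
    df["energy"] = df["energy"] + sign * df["bias"].mean()
    # no reduction: shift row-wise by the column itself
    df["energy"] = df["energy"] + sign * df["bias"]

    print(df.compute())  # the offsets are only evaluated here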

sed/loader/mpes/loader.py

Lines changed: 85 additions & 59 deletions
@@ -67,30 +67,26 @@ def hdf5_to_dataframe(
         seach_pattern="Stream",
     )
 
-    channel_list = []
+    electron_channels = []
     column_names = []
 
     for name, channel in channels.items():
-        if (
-            channel["format"] == "per_electron"
-            and channel["dataset_key"] in test_proc
-            or channel["format"] == "per_file"
-            and channel["dataset_key"] in test_proc.attrs
-        ):
-            channel_list.append(channel)
-            column_names.append(name)
-        else:
-            print(
-                f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
-                "Skipping the channel.",
-            )
+        if channel["format"] == "per_electron":
+            if channel["dataset_key"] in test_proc:
+                electron_channels.append(channel)
+                column_names.append(name)
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
 
     if time_stamps:
         column_names.append(time_stamp_alias)
 
     test_array = hdf5_to_array(
         h5file=test_proc,
-        channels=channel_list,
+        channels=electron_channels,
         time_stamps=time_stamps,
         ms_markers_key=ms_markers_key,
         first_event_time_stamp_key=first_event_time_stamp_key,
@@ -101,7 +97,7 @@ def hdf5_to_dataframe(
         da.from_delayed(
             dask.delayed(hdf5_to_array)(
                 h5file=h5py.File(f),
-                channels=channel_list,
+                channels=electron_channels,
                 time_stamps=time_stamps,
                 ms_markers_key=ms_markers_key,
                 first_event_time_stamp_key=first_event_time_stamp_key,
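
Each file is still read lazily here: dask.delayed wraps the reader and da.from_delayed turns the promised NumPy block into one chunk of a dask array, which is then concatenated and transposed in the next hunk. A generic sketch of that delayed-array pattern with a dummy reader (not the sed loader itself):

    # Build a dask array from lazily loaded per-file blocks (dummy reader).
    import dask
    import dask.array as da
    import numpy as np

    def read_block(i):
        """Stand-in for hdf5_to_array: pretend to read one file's (channels, events) block."""
        return np.full((3, 5), float(i), dtype="float32")

    files = [0, 1, 2]
    arrays = [
        da.from_delayed(dask.delayed(read_block)(f), shape=(3, 5), dtype="float32")
        for f in files
    ]
    array_stack = da.concatenate(arrays, axis=1).T
    print(array_stack.compute().shape)  # (15, 3): events x channels
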
@@ -113,7 +109,25 @@ def hdf5_to_dataframe(
         ]
     array_stack = da.concatenate(arrays, axis=1).T
 
-    return ddf.from_dask_array(array_stack, columns=column_names)
+    dataframe = ddf.from_dask_array(array_stack, columns=column_names)
+
+    for name, channel in channels.items():
+        if channel["format"] == "per_file":
+            if channel["dataset_key"] in test_proc.attrs:
+                values = [float(get_attribute(h5py.File(f), channel["dataset_key"])) for f in files]
+                delayeds = [
+                    add_value(partition, name, value)
+                    for partition, value in zip(dataframe.partitions, values)
+                ]
+                dataframe = ddf.from_delayed(delayeds)
+
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
+
+    return dataframe
 
 
 def hdf5_to_timed_dataframe(
@@ -156,30 +170,26 @@ def hdf5_to_timed_dataframe(
         seach_pattern="Stream",
     )
 
-    channel_list = []
+    electron_channels = []
     column_names = []
 
    for name, channel in channels.items():
-        if (
-            channel["format"] == "per_electron"
-            and channel["dataset_key"] in test_proc
-            or channel["format"] == "per_file"
-            and channel["dataset_key"] in test_proc.attrs
-        ):
-            channel_list.append(channel)
-            column_names.append(name)
-        else:
-            print(
-                f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
-                "Skipping the channel.",
-            )
+        if channel["format"] == "per_electron":
+            if channel["dataset_key"] in test_proc:
+                electron_channels.append(channel)
+                column_names.append(name)
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
 
     if time_stamps:
         column_names.append(time_stamp_alias)
 
     test_array = hdf5_to_timed_array(
         h5file=test_proc,
-        channels=channel_list,
+        channels=electron_channels,
         time_stamps=time_stamps,
         ms_markers_key=ms_markers_key,
         first_event_time_stamp_key=first_event_time_stamp_key,
@@ -190,7 +200,7 @@ def hdf5_to_timed_dataframe(
         da.from_delayed(
             dask.delayed(hdf5_to_timed_array)(
                 h5file=h5py.File(f),
-                channels=channel_list,
+                channels=electron_channels,
                 time_stamps=time_stamps,
                 ms_markers_key=ms_markers_key,
                 first_event_time_stamp_key=first_event_time_stamp_key,
@@ -202,7 +212,41 @@ def hdf5_to_timed_dataframe(
         ]
     array_stack = da.concatenate(arrays, axis=1).T
 
-    return ddf.from_dask_array(array_stack, columns=column_names)
+    dataframe = ddf.from_dask_array(array_stack, columns=column_names)
+
+    for name, channel in channels.items():
+        if channel["format"] == "per_file":
+            if channel["dataset_key"] in test_proc.attrs:
+                values = [float(get_attribute(h5py.File(f), channel["dataset_key"])) for f in files]
+                delayeds = [
+                    add_value(partition, name, value)
+                    for partition, value in zip(dataframe.partitions, values)
+                ]
+                dataframe = ddf.from_delayed(delayeds)
+
+            else:
+                print(
+                    f"Entry \"{channel['dataset_key']}\" for channel \"{name}\" not found.",
+                    "Skipping the channel.",
+                )
+
+    return dataframe
+
+
+@dask.delayed
+def add_value(partition: ddf.DataFrame, name: str, value: float) -> ddf.DataFrame:
+    """Dask delayed helper function to add a value to each dataframe partition
+
+    Args:
+        partition (ddf.DataFrame): Dask dataframe partition
+        name (str): Name of the column to add
+        value (float): value to add to this partition
+
+    Returns:
+        ddf.DataFrame: Dataframe partition with added column
+    """
+    partition[name] = value
+    return partition
 
 
 def get_datasets_and_aliases(
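
With this change the per_file values are no longer broadcast into per-electron arrays inside the readers; each file's scalar attribute is attached to the matching dataframe partition through the delayed add_value helper above. A self-contained sketch of that per-partition pattern, mirroring the diff with toy data in place of the sed loader:

    # Attach one constant value per partition via dask.delayed (toy example).
    import dask
    import dask.dataframe as ddf
    import pandas as pd

    @dask.delayed
    def add_value(partition, name, value):
        # each partition arrives as a pandas DataFrame; add the scalar as a column
        partition[name] = value
        return partition

    # one partition per "file"; each gets that file's per_file attribute value
    df = ddf.from_pandas(pd.DataFrame({"x": range(6)}), npartitions=2)
    values = [0.5, 1.5]  # hypothetical per-file scalars

    delayeds = [
        add_value(part, "sampleBias", val)
        for part, val in zip(df.partitions, values)
    ]
    df = ddf.from_delayed(delayeds)
    print(df.compute())
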
@@ -256,7 +300,7 @@ def hdf5_to_array(
     Args:
         h5file (h5py.File):
             hdf5 file handle to read from
-        electron_channels (Sequence[Dict[str, any]]):
+        channels (Sequence[Dict[str, any]]):
             channel dicts containing group names and types to read.
         time_stamps (bool, optional):
             Option to calculate time stamps. Defaults to False.
@@ -270,40 +314,25 @@ def hdf5_to_array(
     """
 
     # Delayed array for loading an HDF5 file of reasonable size (e.g. < 1GB)
-
-    # determine group length from per_electron column:
-    nelectrons = 0
-    for channel in channels:
-        if channel["format"] == "per_electron":
-            nelectrons = len(h5file[channel["dataset_key"]])
-            break
-    if nelectrons == 0:
-        raise ValueError("No 'per_electron' columns defined, or no hits found in file.")
-
     # Read out groups:
     data_list = []
     for channel in channels:
         if channel["format"] == "per_electron":
             g_dataset = np.asarray(h5file[channel["dataset_key"]])
-        elif channel["format"] == "per_file":
-            value = float(get_attribute(h5file, channel["dataset_key"]))
-            g_dataset = np.asarray([value] * nelectrons)
         else:
             raise ValueError(
                 f"Invalid 'format':{channel['format']} for channel {channel['dataset_key']}.",
             )
-        if "data_type" in channel.keys():
-            g_dataset = g_dataset.astype(channel["data_type"])
+        if "dtype" in channel.keys():
+            g_dataset = g_dataset.astype(channel["dtype"])
         else:
             g_dataset = g_dataset.astype("float32")
-        if len(g_dataset) != nelectrons:
-            raise ValueError(f"Inconsistent entries found for channel {channel['dataset_key']}.")
         data_list.append(g_dataset)
 
     # calculate time stamps
     if time_stamps:
         # create target array for time stamps
-        time_stamp_data = np.zeros(nelectrons)
+        time_stamp_data = np.zeros(len(data_list[0]))
         # the ms marker contains a list of events that occurred at full ms intervals.
         # It's monotonically increasing, and can contain duplicates
         ms_marker = np.asarray(h5file[ms_markers_key])
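
After this simplification hdf5_to_array reads only per_electron datasets, an optional dtype override now lives under the "dtype" key (previously "data_type"), and per_file channels are handled afterwards at the dataframe level as shown above. A hypothetical channels mapping illustrating the keys the loaders consume (names are illustrative, not the sed defaults):

    # Illustrative channel definitions (hypothetical dataset keys and names).
    channels = {
        "X": {"format": "per_electron", "dataset_key": "Stream_0", "dtype": "float32"},
        "Y": {"format": "per_electron", "dataset_key": "Stream_1"},  # dtype defaults to float32
        "sampleBias": {"format": "per_file", "dataset_key": "SampleBiasVoltage"},  # read from file attrs
    }
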
@@ -357,7 +386,7 @@ def hdf5_to_timed_array(
     Args:
         h5file (h5py.File):
             hdf5 file handle to read from
-        electron_channels (Sequence[Dict[str, any]]):
+        channels (Sequence[Dict[str, any]]):
             channel dicts containing group names and types to read.
         time_stamps (bool, optional):
             Option to calculate time stamps. Defaults to False.
@@ -382,15 +411,12 @@ def hdf5_to_timed_array(
             g_dataset = np.asarray(h5file[channel["dataset_key"]])
             for i, point in enumerate(ms_marker):
                 timed_dataset[i] = g_dataset[int(point) - 1]
-        elif channel["format"] == "per_file":
-            value = float(get_attribute(h5file, channel["dataset_key"]))
-            timed_dataset[:] = value
         else:
             raise ValueError(
                 f"Invalid 'format':{channel['format']} for channel {channel['dataset_key']}.",
             )
-        if "data_type" in channel.keys():
-            timed_dataset = timed_dataset.astype(channel["data_type"])
+        if "dtype" in channel.keys():
+            timed_dataset = timed_dataset.astype(channel["dtype"])
         else:
             timed_dataset = timed_dataset.astype("float32")
