Skip to content

Commit 2cffd1a

Browse files
committed
Rough in code to write out NN graph as Weighted Adjacency List
1 parent fec9f22 commit 2cffd1a

File tree

1 file changed

+60
-2
lines changed

1 file changed

+60
-2
lines changed

src/aspire/classification/rir_class2d.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
bispectrum_freq_cutoff=None,
3030
large_pca_implementation="legacy",
3131
nn_implementation="legacy",
32+
output_nn_filename=None,
3233
bispectrum_implementation="legacy",
3334
aligner=None,
3435
dtype=None,
@@ -47,7 +48,8 @@ def __init__(
4748
Z. Zhao, Y. Shkolnisky, A. Singer, Rotationally Invariant Image Representation
4849
for Viewing Direction Classification in Cryo-EM. (2014)
4950
50-
:param src: Source instance
51+
:param src: Source instance. Note it is possible to use one `source` for classification (ie CWF),
52+
and a different `source` for stacking in the `aligner`.
5153
:param pca_basis: Optional FSPCA Basis instance
5254
:param fspca_components: Components (top eigvals) to keep from full FSCPA, default truncates to 400.
5355
:param alpha: Amplitude Power Scale, default 1/3 (eq 20 from RIIR paper).
@@ -119,6 +121,7 @@ def __init__(
119121
f"Provided nn_implementation={nn_implementation} not in {nn_implementations.keys()}"
120122
)
121123
self._nn_classification = nn_implementations[nn_implementation]
124+
self.output_nn_filename = output_nn_filename
122125

123126
# # Do we have a sane Large Dataset PCA
124127
large_pca_implementations = {
@@ -185,6 +188,8 @@ def classify(self, diagnostics=False):
185188
# # Stage 2: Compute Nearest Neighbors
186189
logger.info("Calculate Nearest Neighbors")
187190
classes, reflections, distances = self.nn_classification(coef_b, coef_b_r)
191+
if self.output_nn_filename is not None:
192+
self._save_nn(classes, reflections, distances)
188193

189194
if diagnostics:
190195
# Lets peek at the distribution of distances
@@ -351,7 +356,7 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000):
351356
# Check with Joakim about preference.
352357
# I (GBW) think class[i] should have class[i][0] be the original image index.
353358
classes[start:finish] = np.argsort(-corr, axis=1)[:, :n_nbor]
354-
# Store the corr values for the n_nhors in this batch
359+
# Store the corr values for the n_nbors in this batch
355360
distances[start:finish] = np.take_along_axis(
356361
corr, classes[start:finish], axis=1
357362
)
@@ -366,6 +371,59 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r, batch_size=2000):
366371

367372
return classes, refl, distances
368373

374+
def _save_nn(self, classes, reflections, distances):
375+
"""
376+
Output the Nearest Neighbors graph as a weighted adjacency list.
377+
378+
Vertices are indexed by their natural index in `source`.
379+
Note reflected images are represented by `index + src.n`.
380+
381+
Only the output of the Nearest Neighbor call is saved.
382+
If you want a complete graph, specify 2*src.n neighbors,
383+
that is all images and their reflections.
384+
385+
Because this is mixed datatypes (int and floating),
386+
this will be output as a space delimited text file.
387+
388+
Vi1 Vj1 W_i1_j1 Vj2 Wi1_j2 ...
389+
Vi2 Vj1 W_i2_j1 Vj2 Wi2_j2 ...
390+
...
391+
392+
"""
393+
394+
# Construct the weighted adjacency list
395+
AdjList = []
396+
for k in range(len(classes)):
397+
398+
row = []
399+
vik = classes[k][0]
400+
row.append(vik)
401+
402+
for j in range(1, len(classes[k])):
403+
404+
# Neighbor index
405+
vj = classes[k][j]
406+
if reflections[k][j]:
407+
vj += self.src.n
408+
row.append(vj)
409+
410+
# Neighbor Weight (distance)
411+
wt = distances[k][j]
412+
row.append(wt)
413+
414+
# Store this row of the AdjList
415+
AdjList.append(row)
416+
417+
logger.info(
418+
"Writing Nearest Neighbors as Weighted Adjacency List"
419+
f" to {self.output_nn_filename}"
420+
)
421+
422+
# Output
423+
with open(self.output_nn_filename, "w") as fh:
424+
for row in AdjList:
425+
fh.write(" ".join(str(x) for x in row) + "\n")
426+
369427
def _legacy_pca(self, M):
370428
"""
371429
This is more or less the historic implementation ported

0 commit comments

Comments
 (0)