Skip to content

Commit ebbf0ed

Browse files
committed
Remove circular dependence by moving mat_to_blk_diag into from_mat in BlkDiagMatrix
1 parent e690cbd commit ebbf0ed

File tree

2 files changed

+13
-21
lines changed

2 files changed

+13
-21
lines changed

src/aspire/utils/blk_diag_matrix.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def empty(nblocks, dtype=np.float64):
9393
"""
9494

9595
# Empty partition has block dims of zero until they are assigned
96-
partition = [(0,0)] * nblocks
96+
partition = [(0, 0)] * nblocks
9797

9898
return BlkDiagMatrix(partition, dtype=dtype)
9999

@@ -784,10 +784,10 @@ def from_list(blk_diag, dtype=np.float64):
784784
@staticmethod
785785
def from_mat(mat, blk_partition, dtype=np.float64):
786786
"""
787-
Convert full block diagonal matrix into list representation.
787+
Convert full block diagonal matrix into BlkDiagMatrix.
788788
789789
:param mat; The full block diagonal matrix including the zero elements
790-
ofnon-diagonal blocks.
790+
of non-diagonal blocks.
791791
:param blk_partition: The matrix block partition in the form of a
792792
K-element list storing all shapes of K diagonal matrix blocks,
793793
where `blk_partition[i]` corresponds to the shape (number rows
@@ -796,15 +796,21 @@ def from_mat(mat, blk_partition, dtype=np.float64):
796796
:return: The BlkDiagMatrix instance.
797797
"""
798798

799-
# TODO: maybe can improve this implementation
800-
801799
A = BlkDiagMatrix(blk_partition, dtype=dtype)
802800

803801
rows = blk_partition[:, 0]
804802
cols = blk_partition[:, 1]
805803
cellarray = Cell2D(rows, cols, dtype=mat.dtype)
806-
blk_diag = cellarray.mat_to_blk_diag(mat, rows, cols)
807-
A.data = BlkDiagMatrix.from_list(blk_diag)
804+
805+
offset = 0
806+
blk_ind = 0
807+
for i in range(0, cellarray.nrow):
808+
for j in range(0, cellarray.ncol):
809+
offset += 1
810+
if i == j:
811+
blk_ind += 1
812+
A[blk_ind] = cellarray.cell_list[offset]
813+
808814
return A
809815

810816
def solve(self, Y):

src/aspire/utils/cell.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import numpy as np
22
import logging
33

4-
# ugh circular depends
5-
#from aspire.utils.blk_diag_matrix import BlkDiagMatrix
6-
74
logger = logging.getLogger(__name__)
85

96

@@ -45,14 +42,3 @@ def mat2cell(self, mat, rows, cols):
4542
offsetc += cols[j]
4643
offsetr += rows[i]
4744
return self.cell_list
48-
49-
def mat_to_blk_diag(self, mat, rows, cols):
50-
self.mat2cell(mat, rows, cols)
51-
blk_diag = BlkDiagMatrix.empty(0, self.dtype)
52-
offset = 0
53-
for i in range(0, self.nrow):
54-
for j in range(0, self.ncol):
55-
offset += 1
56-
if i == j:
57-
blk_diag.append(self.cell_list[offset])
58-
return blk_diag

0 commit comments

Comments
 (0)