Skip to content

Commit e690cbd

Browse files
committed
add square array check for BlkDiagMatrix constructor, setter, and doc
1 parent 227000e commit e690cbd

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/aspire/utils/blk_diag_matrix.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ class BlkDiagMatrix:
1919
"""
2020
Define a BlkDiagMatrix class which implements operations for
2121
block diagonal matrices as used by ASPIRE.
22+
23+
Currently BlkDiagMatrix is implemented only for square blocks.
24+
While in the future this can be extended, at this time assigning
25+
a non square array will raise NotImplementedError.
2226
"""
2327

2428
# Developers' Note:
@@ -39,7 +43,7 @@ def __init__(self, partition, dtype=np.float64):
3943
diagonal matrix blocks, where `partition[i]` corresponds to
4044
the shape (number of rows and columns) of the `i` matrix block.
4145
:param dtype: Datatype for blocks, defaults to np.float64.
42-
:return BlkDiagMatrix instance.
46+
:return: BlkDiagMatrix instance.
4347
"""
4448

4549
self.nblocks = len(partition)
@@ -48,6 +52,22 @@ def __init__(self, partition, dtype=np.float64):
4852
self._cached_blk_sizes = np.array(partition)
4953
if len(partition):
5054
assert self._cached_blk_sizes.shape[1] == 2
55+
assert all([BlkDiagMatrix.check_square(shp) for shp in partition])
56+
57+
@staticmethod
58+
def check_square(shp):
59+
"""
60+
Check if supplied shape tuple is square.
61+
62+
:param shp: Shape to test, expressed as a 2-tuple.
63+
:return: True or raises NotImplementedError.
64+
65+
"""
66+
if shp[0] != shp[1]:
67+
raise NotImplementedError("Currently BlkDiagMatrix only supports"
68+
" square blocks. Received {}".format(
69+
shp))
70+
return True
5171

5272
def reset_cache(self):
5373
"""
@@ -126,6 +146,7 @@ def __setitem__(self, key, value):
126146
Convenience wrapper, setter on self.data.
127147
"""
128148

149+
BlkDiagMatrix.check_square(value.shape)
129150
self.data[key] = value
130151
self.reset_cache()
131152

0 commit comments

Comments
 (0)