@@ -19,6 +19,10 @@ class BlkDiagMatrix:
1919 """
2020 Define a BlkDiagMatrix class which implements operations for
2121 block diagonal matrices as used by ASPIRE.
22+
23+ Currently BlkDiagMatrix is implemented only for square blocks.
24+ While in the future this can be extended, at this time assigning
25+ a non square array will raise NotImplementedError.
2226 """
2327
2428 # Developers' Note:
@@ -39,7 +43,7 @@ def __init__(self, partition, dtype=np.float64):
3943 diagonal matrix blocks, where `partition[i]` corresponds to
4044 the shape (number of rows and columns) of the `i` matrix block.
4145 :param dtype: Datatype for blocks, defaults to np.float64.
42- :return BlkDiagMatrix instance.
46+ :return: BlkDiagMatrix instance.
4347 """
4448
4549 self .nblocks = len (partition )
@@ -48,6 +52,22 @@ def __init__(self, partition, dtype=np.float64):
4852 self ._cached_blk_sizes = np .array (partition )
4953 if len (partition ):
5054 assert self ._cached_blk_sizes .shape [1 ] == 2
55+ assert all ([BlkDiagMatrix .check_square (shp ) for shp in partition ])
56+
57+ @staticmethod
58+ def check_square (shp ):
59+ """
60+ Check if supplied shape tuple is square.
61+
62+ :param shp: Shape to test, expressed as a 2-tuple.
63+ :return: True or raises NotImplementedError.
64+
65+ """
66+ if shp [0 ] != shp [1 ]:
67+ raise NotImplementedError ("Currently BlkDiagMatrix only supports"
68+ " square blocks. Received {}" .format (
69+ shp ))
70+ return True
5171
5272 def reset_cache (self ):
5373 """
@@ -126,6 +146,7 @@ def __setitem__(self, key, value):
126146 Convenience wrapper, setter on self.data.
127147 """
128148
149+ BlkDiagMatrix .check_square (value .shape )
129150 self .data [key ] = value
130151 self .reset_cache ()
131152
0 commit comments