diff --git a/src/corner/core.py b/src/corner/core.py index a08b3b8..7a63023 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -904,16 +904,21 @@ def _parse_input(xs): def _get_fig_axes(fig, K): if not fig.axes: return fig.subplots(K, K), True - try: - axarr = np.array(fig.axes).reshape((K, K)) - return axarr.item() if axarr.size == 1 else axarr.squeeze(), False - except ValueError: + + axarr = np.array(fig.axes) + axarr_size = axarr.size + if np.sqrt(axarr_size) != int(np.sqrt(axarr_size)): raise ValueError( - ( - "Provided figure has {0} axes, but data has " - "dimensions K={1}" - ).format(len(fig.axes), K) + f"Provided figure has {axarr_size} axes. Must be a square number" ) + if axarr.size == K**2: + axarr = np.array(fig.axes).reshape((K, K)) + return axarr.item() if axarr.size == 1 else axarr.squeeze(), False + if axarr.size > K**2: + axarr_ndim = int(np.sqrt(axarr_size)) + axarr = axarr.reshape((axarr_ndim, axarr_ndim)) # Reshape to square + axarr = axarr[:K, :K] + return axarr.squeeze(), False def _set_xlim(force, new_fig, ax, new_xlim):