Skip to content

Commit 7f88cb5

Browse files
committed
added support for eye(AxisCollection)
1 parent d0de183 commit 7f88cb5

File tree

3 files changed

+54
-13
lines changed

3 files changed

+54
-13
lines changed

doc/source/changes/version_0_33.rst.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ Miscellaneous improvements
5555
- support passing a dict as legend to customize the legend.
5656
- many tweaks to make several plots look better out of the box.
5757

58+
* :py:obj:`eye()` now supports an AxisCollection as argument, so you can use axes from another array by using
59+
``eye(other_array.axes)``.
5860

5961
Fixes
6062
^^^^^

larray/core/array.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9192,10 +9192,10 @@ def eye(rows, columns=None, k=0, title=None, dtype=None, meta=None):
91929192
91939193
Parameters
91949194
----------
9195-
rows : int or Axis
9196-
Rows of the output.
9195+
rows : int or Axis or tuple or length 2 AxisCollection
9196+
Rows of the output (if int or Axis) or rows and columns (if tuple or AxisCollection).
91979197
columns : int or Axis, optional
9198-
Columns of the output. If None, defaults to rows.
9198+
Columns of the output. Defaults to the value of `rows` if it is an int or Axis.
91999199
k : int, optional
92009200
Index of the diagonal: 0 (the default) refers to the main diagonal, a positive value refers to an upper
92019201
diagonal, and a negative value to a lower diagonal.
@@ -9214,16 +9214,16 @@ def eye(rows, columns=None, k=0, title=None, dtype=None, meta=None):
92149214
92159215
Examples
92169216
--------
9217+
>>> eye('sex=M,F')
9218+
sex\sex M F
9219+
M 1.0 0.0
9220+
F 0.0 1.0
92179221
>>> eye(2, dtype=int)
92189222
{0}*\{1}* 0 1
92199223
0 1 0
92209224
1 0 1
9221-
>>> sex = Axis('sex=M,F')
9222-
>>> eye(sex)
9223-
sex\sex M F
9224-
M 1.0 0.0
9225-
F 0.0 1.0
92269225
>>> age = Axis('age=0..2')
9226+
>>> sex = Axis('sex=M,F')
92279227
>>> eye(age, sex)
92289228
age\sex M F
92299229
0 1.0 0.0
@@ -9236,9 +9236,16 @@ def eye(rows, columns=None, k=0, title=None, dtype=None, meta=None):
92369236
2 0.0 0.0 0.0
92379237
"""
92389238
meta = _handle_meta(meta, title)
9239-
if columns is None:
9240-
columns = rows.copy() if isinstance(rows, Axis) else rows
9241-
axes = AxisCollection([rows, columns])
9239+
if isinstance(rows, AxisCollection):
9240+
assert columns is None
9241+
axes = rows
9242+
elif isinstance(rows, (tuple, list)):
9243+
assert columns is None
9244+
axes = AxisCollection(rows)
9245+
else:
9246+
if columns is None:
9247+
columns = rows.copy() if isinstance(rows, Axis) else rows
9248+
axes = AxisCollection([rows, columns])
92429249
shape = axes.shape
92439250
data = np.eye(shape[0], shape[1], k, dtype)
92449251
return Array(data, axes, meta=meta)

larray/tests/test_array.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
assert_array_equal, assert_array_nan_equal, assert_larray_equiv, assert_larray_equal,
1515
needs_xlwings, needs_pytables, needs_xlsxwriter, needs_openpyxl, needs_python37,
1616
must_warn)
17-
from larray import (Array, LArray, Axis, LGroup, union, zeros, zeros_like, ndtest, empty, ones, eye, diag, stack,
17+
from larray import (Array, LArray, Axis, AxisCollection, LGroup, IGroup,
18+
union, zeros, zeros_like, ndtest, empty, ones, eye, diag, stack,
1819
clip, exp, where, X, mean, isnan, round, read_hdf, read_csv, read_eurostat, read_excel,
19-
from_lists, from_string, open_excel, from_frame, sequence, nan, IGroup)
20+
from_lists, from_string, open_excel, from_frame, sequence, nan)
2021
from larray.inout.pandas import from_series
2122
from larray.core.axis import _to_ticks, _to_key
2223
from larray.util.misc import LHDFStore
@@ -4508,6 +4509,37 @@ def test_ufuncs(small_array):
45084509
assert_array_equal(rounded, np.round(raw + 0.6))
45094510

45104511

4512+
def test_eye():
4513+
age = Axis('age=0..2')
4514+
sex = Axis('sex=M,F')
4515+
4516+
# using one Axis object
4517+
res = eye(sex)
4518+
expected = from_string(r'''
4519+
sex\sex M F
4520+
M 1.0 0.0
4521+
F 0.0 1.0''')
4522+
assert_array_equal(res, expected)
4523+
4524+
# using an AxisCollection
4525+
res = eye(AxisCollection([age, sex]))
4526+
expected = from_string(r'''
4527+
age\sex M F
4528+
0 1.0 0.0
4529+
1 0.0 1.0
4530+
2 0.0 0.0''')
4531+
assert_array_equal(res, expected)
4532+
4533+
# using a tuple of axes
4534+
res = eye((age, sex))
4535+
expected = from_string(r"""
4536+
age\sex M F
4537+
0 1.0 0.0
4538+
1 0.0 1.0
4539+
2 0.0 0.0""")
4540+
assert_array_equal(res, expected)
4541+
4542+
45114543
def test_diag():
45124544
# 2D -> 1D
45134545
a = ndtest((3, 3))

0 commit comments

Comments
 (0)