Skip to content

Commit 1eefc9d

Browse files
committed
fixed an obscure bug in Session.binops
session1 binop session2 returned NotImplemented instead of nan for arrays present in session2 but not session1 This should become more relevant when we implement #514 and #515 Note that this is too obscure IMO to mention in the changelog.
1 parent c0b09dc commit 1eefc9d

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

larray/core/session.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from larray.core.axis import Axis
1212
from larray.core.array import LArray, larray_nan_equal, get_axes, ndtest, zeros, zeros_like, sequence
13-
from larray.util.misc import float_error_handler_factory, is_interactive_interpreter, renamed_to
13+
from larray.util.misc import float_error_handler_factory, is_interactive_interpreter, renamed_to, inverseop
1414
from larray.inout.session import check_pattern, handler_classes, ext_default_engine
1515

1616

@@ -662,6 +662,13 @@ def opmethod(self, other):
662662
# TypeError for str arrays, ValueError for incompatible axes, ...
663663
except Exception:
664664
res_array = np.nan
665+
# this should only ever happen when self_array is a non Array (eg. nan)
666+
if res_array is NotImplemented:
667+
try:
668+
res_array = getattr(other_array, '__%s__' % inverseop(opname))(self_array)
669+
# TypeError for str arrays, ValueError for incompatible axes, ...
670+
except Exception:
671+
res_array = np.nan
665672
res.append((name, res_array))
666673
return Session(res)
667674
opmethod.__name__ = opfullname

larray/util/misc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,18 @@ def wrapper(*args, **kwargs):
625625
warnings.warn(msg, FutureWarning, stacklevel=stacklevel)
626626
return newfunc(*args, **kwargs)
627627
return wrapper
628+
629+
630+
def inverseop(opname):
631+
comparison_ops = {
632+
'lt': 'gt',
633+
'gt': 'lt',
634+
'le': 'ge',
635+
'ge': 'le',
636+
'eq': 'eq',
637+
'ne': 'ne'
638+
}
639+
if opname in comparison_ops:
640+
return comparison_ops[opname]
641+
else:
642+
return 'r' + opname

0 commit comments

Comments
 (0)