diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 46b9669f9..22b91a079 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -155,6 +155,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)): return orig == new + if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new): + return True # This should be at the end of all numpy checking try: diff --git a/tests/test_comparator.py b/tests/test_comparator.py index d10a48d58..de0d753a0 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -468,6 +468,36 @@ def test_pandas(): assert comparator(an, ao) assert not comparator(an, ap) + assert comparator(pd.NA, pd.NA) + assert not comparator(pd.NA, None) + assert not comparator(None, pd.NA) + + s1 = pd.Series([1, 2, pd.NA, 4]) + s2 = pd.Series([1, 2, pd.NA, 4]) + s3 = pd.Series([1, 2, None, 4]) + + assert comparator(s1, s2) + assert not comparator(s1, s3) + + df1 = pd.DataFrame({'a': [1, 2, pd.NA], 'b': [4, pd.NA, 6]}) + df2 = pd.DataFrame({'a': [1, 2, pd.NA], 'b': [4, pd.NA, 6]}) + df3 = pd.DataFrame({'a': [1, 2, None], 'b': [4, None, 6]}) + assert comparator(df1, df2) + assert not comparator(df1, df3) + + d1 = {'a': pd.NA, 'b': [1, pd.NA, 3]} + d2 = {'a': pd.NA, 'b': [1, pd.NA, 3]} + d3 = {'a': None, 'b': [1, None, 3]} + assert comparator(d1, d2) + assert not comparator(d1, d3) + + s1 = pd.Series([1, 2, pd.NA, 4]) + s2 = pd.Series([1, 2, pd.NA, 4]) + + filtered1 = s1[s1 > 1] + filtered2 = s2[s2 > 1] + assert comparator(filtered1, filtered2) + def test_pyrsistent(): try: