@@ -1024,7 +1024,12 @@ def _raise(left, right, err_msg):
10241024
10251025
10261026def assert_extension_array_equal (
1027- left , right , check_dtype = True , check_less_precise = False , check_exact = False
1027+ left ,
1028+ right ,
1029+ check_dtype = True ,
1030+ check_less_precise = False ,
1031+ check_exact = False ,
1032+ index_values = None ,
10281033):
10291034 """
10301035 Check that left and right ExtensionArrays are equal.
@@ -1041,6 +1046,8 @@ def assert_extension_array_equal(
10411046 If int, then specify the digits to compare.
10421047 check_exact : bool, default False
10431048 Whether to compare number exactly.
1049+ index_values : numpy.ndarray, default None
1050+ optional index (shared by both left and right), used in output.
10441051
10451052 Notes
10461053 -----
@@ -1056,24 +1063,31 @@ def assert_extension_array_equal(
10561063 if hasattr (left , "asi8" ) and type (right ) == type (left ):
10571064 # Avoid slow object-dtype comparisons
10581065 # np.asarray for case where we have a np.MaskedArray
1059- assert_numpy_array_equal (np .asarray (left .asi8 ), np .asarray (right .asi8 ))
1066+ assert_numpy_array_equal (
1067+ np .asarray (left .asi8 ), np .asarray (right .asi8 ), index_values = index_values
1068+ )
10601069 return
10611070
10621071 left_na = np .asarray (left .isna ())
10631072 right_na = np .asarray (right .isna ())
1064- assert_numpy_array_equal (left_na , right_na , obj = "ExtensionArray NA mask" )
1073+ assert_numpy_array_equal (
1074+ left_na , right_na , obj = "ExtensionArray NA mask" , index_values = index_values
1075+ )
10651076
10661077 left_valid = np .asarray (left [~ left_na ].astype (object ))
10671078 right_valid = np .asarray (right [~ right_na ].astype (object ))
10681079 if check_exact :
1069- assert_numpy_array_equal (left_valid , right_valid , obj = "ExtensionArray" )
1080+ assert_numpy_array_equal (
1081+ left_valid , right_valid , obj = "ExtensionArray" , index_values = index_values
1082+ )
10701083 else :
10711084 _testing .assert_almost_equal (
10721085 left_valid ,
10731086 right_valid ,
10741087 check_dtype = check_dtype ,
10751088 check_less_precise = check_less_precise ,
10761089 obj = "ExtensionArray" ,
1090+ index_values = index_values ,
10771091 )
10781092
10791093
@@ -1206,12 +1220,17 @@ def assert_series_equal(
12061220 check_less_precise = check_less_precise ,
12071221 check_dtype = check_dtype ,
12081222 obj = str (obj ),
1223+ index_values = np .asarray (left .index ),
12091224 )
12101225 elif is_extension_array_dtype (left .dtype ) and is_extension_array_dtype (right .dtype ):
1211- assert_extension_array_equal (left ._values , right ._values )
1226+ assert_extension_array_equal (
1227+ left ._values , right ._values , index_values = np .asarray (left .index )
1228+ )
12121229 elif needs_i8_conversion (left .dtype ) or needs_i8_conversion (right .dtype ):
12131230 # DatetimeArray or TimedeltaArray
1214- assert_extension_array_equal (left ._values , right ._values )
1231+ assert_extension_array_equal (
1232+ left ._values , right ._values , index_values = np .asarray (left .index )
1233+ )
12151234 else :
12161235 _testing .assert_almost_equal (
12171236 left ._values ,
0 commit comments