@@ -580,8 +580,18 @@ def to_datetime(self, dayfirst=False):
580580 return DatetimeIndex (self .values )
581581
582582 def _assert_can_do_setop (self , other ):
583+ if not com .is_list_like (other ):
584+ raise TypeError ('Input must be Index or array-like' )
583585 return True
584586
587+ def _convert_can_do_setop (self , other ):
588+ if not isinstance (other , Index ):
589+ other = Index (other , name = self .name )
590+ result_name = self .name
591+ else :
592+ result_name = self .name if self .name == other .name else None
593+ return other , result_name
594+
585595 @property
586596 def nlevels (self ):
587597 return 1
@@ -1364,16 +1374,14 @@ def union(self, other):
13641374 -------
13651375 union : Index
13661376 """
1367- if not hasattr (other , '__iter__' ):
1368- raise TypeError ( 'Input must be iterable.' )
1377+ self . _assert_can_do_setop (other )
1378+ other = _ensure_index ( other )
13691379
13701380 if len (other ) == 0 or self .equals (other ):
13711381 return self
13721382
13731383 if len (self ) == 0 :
1374- return _ensure_index (other )
1375-
1376- self ._assert_can_do_setop (other )
1384+ return other
13771385
13781386 if not is_dtype_equal (self .dtype ,other .dtype ):
13791387 this = self .astype ('O' )
@@ -1439,11 +1447,7 @@ def intersection(self, other):
14391447 -------
14401448 intersection : Index
14411449 """
1442- if not hasattr (other , '__iter__' ):
1443- raise TypeError ('Input must be iterable!' )
1444-
14451450 self ._assert_can_do_setop (other )
1446-
14471451 other = _ensure_index (other )
14481452
14491453 if self .equals (other ):
@@ -1492,18 +1496,12 @@ def difference(self, other):
14921496
14931497 >>> index.difference(index2)
14941498 """
1495-
1496- if not hasattr (other , '__iter__' ):
1497- raise TypeError ('Input must be iterable!' )
1499+ self ._assert_can_do_setop (other )
14981500
14991501 if self .equals (other ):
15001502 return Index ([], name = self .name )
15011503
1502- if not isinstance (other , Index ):
1503- other = np .asarray (other )
1504- result_name = self .name
1505- else :
1506- result_name = self .name if self .name == other .name else None
1504+ other , result_name = self ._convert_can_do_setop (other )
15071505
15081506 theDiff = sorted (set (self ) - set (other ))
15091507 return Index (theDiff , name = result_name )
@@ -1517,7 +1515,7 @@ def sym_diff(self, other, result_name=None):
15171515 Parameters
15181516 ----------
15191517
1520- other : array-like
1518+ other : Index or array-like
15211519 result_name : str
15221520
15231521 Returns
@@ -1545,13 +1543,10 @@ def sym_diff(self, other, result_name=None):
15451543 >>> idx1 ^ idx2
15461544 Int64Index([1, 5], dtype='int64')
15471545 """
1548- if not hasattr (other , '__iter__' ):
1549- raise TypeError ('Input must be iterable!' )
1550-
1551- if not isinstance (other , Index ):
1552- other = Index (other )
1553- result_name = result_name or self .name
1554-
1546+ self ._assert_can_do_setop (other )
1547+ other , result_name_update = self ._convert_can_do_setop (other )
1548+ if result_name is None :
1549+ result_name = result_name_update
15551550 the_diff = sorted (set ((self .difference (other )).union (other .difference (self ))))
15561551 return Index (the_diff , name = result_name )
15571552
@@ -5460,12 +5455,11 @@ def union(self, other):
54605455 >>> index.union(index2)
54615456 """
54625457 self ._assert_can_do_setop (other )
5458+ other , result_names = self ._convert_can_do_setop (other )
54635459
54645460 if len (other ) == 0 or self .equals (other ):
54655461 return self
54665462
5467- result_names = self .names if self .names == other .names else None
5468-
54695463 uniq_tuples = lib .fast_unique_multiple ([self .values , other .values ])
54705464 return MultiIndex .from_arrays (lzip (* uniq_tuples ), sortorder = 0 ,
54715465 names = result_names )
@@ -5483,12 +5477,11 @@ def intersection(self, other):
54835477 Index
54845478 """
54855479 self ._assert_can_do_setop (other )
5480+ other , result_names = self ._convert_can_do_setop (other )
54865481
54875482 if self .equals (other ):
54885483 return self
54895484
5490- result_names = self .names if self .names == other .names else None
5491-
54925485 self_tuples = self .values
54935486 other_tuples = other .values
54945487 uniq_tuples = sorted (set (self_tuples ) & set (other_tuples ))
@@ -5509,18 +5502,10 @@ def difference(self, other):
55095502 diff : MultiIndex
55105503 """
55115504 self ._assert_can_do_setop (other )
5505+ other , result_names = self ._convert_can_do_setop (other )
55125506
5513- if not isinstance (other , MultiIndex ):
5514- if len (other ) == 0 :
5507+ if len (other ) == 0 :
55155508 return self
5516- try :
5517- other = MultiIndex .from_tuples (other )
5518- except :
5519- raise TypeError ('other must be a MultiIndex or a list of'
5520- ' tuples' )
5521- result_names = self .names
5522- else :
5523- result_names = self .names if self .names == other .names else None
55245509
55255510 if self .equals (other ):
55265511 return MultiIndex (levels = [[]] * self .nlevels ,
@@ -5537,15 +5522,30 @@ def difference(self, other):
55375522 return MultiIndex .from_tuples (difference , sortorder = 0 ,
55385523 names = result_names )
55395524
5540- def _assert_can_do_setop (self , other ):
5541- pass
5542-
55435525 def astype (self , dtype ):
55445526 if not is_object_dtype (np .dtype (dtype )):
55455527 raise TypeError ('Setting %s dtype to anything other than object '
55465528 'is not supported' % self .__class__ )
55475529 return self ._shallow_copy ()
55485530
5531+ def _convert_can_do_setop (self , other ):
5532+ result_names = self .names
5533+
5534+ if not hasattr (other , 'names' ):
5535+ if len (other ) == 0 :
5536+ other = MultiIndex (levels = [[]] * self .nlevels ,
5537+ labels = [[]] * self .nlevels ,
5538+ verify_integrity = False )
5539+ else :
5540+ msg = 'other must be a MultiIndex or a list of tuples'
5541+ try :
5542+ other = MultiIndex .from_tuples (other )
5543+ except :
5544+ raise TypeError (msg )
5545+ else :
5546+ result_names = self .names if self .names == other .names else None
5547+ return other , result_names
5548+
55495549 def insert (self , loc , item ):
55505550 """
55515551 Make new MultiIndex inserting new item at location
0 commit comments