@@ -580,6 +580,8 @@ def to_datetime(self, dayfirst=False):
580580 return DatetimeIndex (self .values )
581581
582582 def _assert_can_do_setop (self , other ):
583+ if not com .is_list_like (other ):
584+ raise TypeError ('Input must be Index or array-like' )
583585 return True
584586
585587 @property
@@ -1364,16 +1366,14 @@ def union(self, other):
13641366 -------
13651367 union : Index
13661368 """
1367- if not hasattr (other , '__iter__' ):
1368- raise TypeError ( 'Input must be iterable.' )
1369+ self . _assert_can_do_setop (other )
1370+ other = _ensure_index ( other )
13691371
13701372 if len (other ) == 0 or self .equals (other ):
13711373 return self
13721374
13731375 if len (self ) == 0 :
1374- return _ensure_index (other )
1375-
1376- self ._assert_can_do_setop (other )
1376+ return other
13771377
13781378 if not is_dtype_equal (self .dtype ,other .dtype ):
13791379 this = self .astype ('O' )
@@ -1439,11 +1439,7 @@ def intersection(self, other):
14391439 -------
14401440 intersection : Index
14411441 """
1442- if not hasattr (other , '__iter__' ):
1443- raise TypeError ('Input must be iterable!' )
1444-
14451442 self ._assert_can_do_setop (other )
1446-
14471443 other = _ensure_index (other )
14481444
14491445 if self .equals (other ):
@@ -1492,9 +1488,7 @@ def difference(self, other):
14921488
14931489 >>> index.difference(index2)
14941490 """
1495-
1496- if not hasattr (other , '__iter__' ):
1497- raise TypeError ('Input must be iterable!' )
1491+ self ._assert_can_do_setop (other )
14981492
14991493 if self .equals (other ):
15001494 return Index ([], name = self .name )
@@ -1517,7 +1511,7 @@ def sym_diff(self, other, result_name=None):
15171511 Parameters
15181512 ----------
15191513
1520- other : array-like
1514+ other : Index or array-like
15211515 result_name : str
15221516
15231517 Returns
@@ -1545,9 +1539,7 @@ def sym_diff(self, other, result_name=None):
15451539 >>> idx1 ^ idx2
15461540 Int64Index([1, 5], dtype='int64')
15471541 """
1548- if not hasattr (other , '__iter__' ):
1549- raise TypeError ('Input must be iterable!' )
1550-
1542+ self ._assert_can_do_setop (other )
15511543 if not isinstance (other , Index ):
15521544 other = Index (other )
15531545 result_name = result_name or self .name
@@ -5460,12 +5452,11 @@ def union(self, other):
54605452 >>> index.union(index2)
54615453 """
54625454 self ._assert_can_do_setop (other )
5455+ other , result_names = self ._convert_can_do_setop (other )
54635456
54645457 if len (other ) == 0 or self .equals (other ):
54655458 return self
54665459
5467- result_names = self .names if self .names == other .names else None
5468-
54695460 uniq_tuples = lib .fast_unique_multiple ([self .values , other .values ])
54705461 return MultiIndex .from_arrays (lzip (* uniq_tuples ), sortorder = 0 ,
54715462 names = result_names )
@@ -5483,12 +5474,11 @@ def intersection(self, other):
54835474 Index
54845475 """
54855476 self ._assert_can_do_setop (other )
5477+ other , result_names = self ._convert_can_do_setop (other )
54865478
54875479 if self .equals (other ):
54885480 return self
54895481
5490- result_names = self .names if self .names == other .names else None
5491-
54925482 self_tuples = self .values
54935483 other_tuples = other .values
54945484 uniq_tuples = sorted (set (self_tuples ) & set (other_tuples ))
@@ -5509,18 +5499,10 @@ def difference(self, other):
55095499 diff : MultiIndex
55105500 """
55115501 self ._assert_can_do_setop (other )
5502+ other , result_names = self ._convert_can_do_setop (other )
55125503
5513- if not isinstance (other , MultiIndex ):
5514- if len (other ) == 0 :
5504+ if len (other ) == 0 :
55155505 return self
5516- try :
5517- other = MultiIndex .from_tuples (other )
5518- except :
5519- raise TypeError ('other must be a MultiIndex or a list of'
5520- ' tuples' )
5521- result_names = self .names
5522- else :
5523- result_names = self .names if self .names == other .names else None
55245506
55255507 if self .equals (other ):
55265508 return MultiIndex (levels = [[]] * self .nlevels ,
@@ -5537,15 +5519,29 @@ def difference(self, other):
55375519 return MultiIndex .from_tuples (difference , sortorder = 0 ,
55385520 names = result_names )
55395521
5540- def _assert_can_do_setop (self , other ):
5541- pass
5542-
55435522 def astype (self , dtype ):
55445523 if not is_object_dtype (np .dtype (dtype )):
55455524 raise TypeError ('Setting %s dtype to anything other than object '
55465525 'is not supported' % self .__class__ )
55475526 return self ._shallow_copy ()
55485527
5528+ def _convert_can_do_setop (self , other ):
5529+ if not isinstance (other , MultiIndex ):
5530+ if len (other ) == 0 :
5531+ other = MultiIndex (levels = [[]] * self .nlevels ,
5532+ labels = [[]] * self .nlevels ,
5533+ verify_integrity = False )
5534+ else :
5535+ msg = 'other must be a MultiIndex or a list of tuples'
5536+ try :
5537+ other = MultiIndex .from_tuples (other )
5538+ except :
5539+ raise TypeError (msg )
5540+ result_names = self .names
5541+ else :
5542+ result_names = self .names if self .names == other .names else None
5543+ return other , result_names
5544+
55495545 def insert (self , loc , item ):
55505546 """
55515547 Make new MultiIndex inserting new item at location
0 commit comments