@@ -274,26 +274,35 @@ def random_color(column):
274
274
n = len (data )
275
275
classes = set (data [class_column ])
276
276
class_col = data [class_column ]
277
- if cols == None :
278
- columns = [data [col ] for col in data .columns if (col != class_column )]
277
+
278
+ if cols is None :
279
+ df = data .drop (class_column , axis = 1 )
279
280
else :
280
- columns = [data [col ] for col in cols ]
281
+ df = data [cols ]
282
+
281
283
used_legends = set ([])
282
- x = range (len (columns ))
284
+
285
+ ncols = len (df .columns )
286
+ x = range (ncols )
287
+
283
288
if ax == None :
284
289
ax = plt .gca ()
290
+
285
291
for i in range (n ):
286
- row = [ columns [ c ][ i ] for c in range ( len ( columns ))]
292
+ row = df . irow ( i ). values
287
293
y = row
288
294
label = None
289
- if str (class_col [i ]) not in used_legends :
290
- label = str (class_col [i ])
295
+ kls = class_col .iget_value (i )
296
+ if str (kls ) not in used_legends :
297
+ label = str (kls )
291
298
used_legends .add (label )
292
- ax .plot (x , y , color = random_color (class_col [i ]), label = label , ** kwds )
293
- for i , col in enumerate (columns ):
299
+ ax .plot (x , y , color = random_color (kls ), label = label , ** kwds )
300
+
301
+ for i in range (ncols ):
294
302
ax .axvline (i , linewidth = 1 , color = 'black' )
295
- ax .set_xticks (range (len (columns )))
296
- ax .set_xticklabels ([col for col in data .columns if col != class_column ])
303
+
304
+ ax .set_xticks (x )
305
+ ax .set_xticklabels (df .columns )
297
306
ax .legend (loc = 'upper right' )
298
307
ax .grid ()
299
308
return ax
0 commit comments