77from typing import (
88 TYPE_CHECKING ,
99 Any ,
10+ List ,
1011 Literal ,
1112 Optional ,
1213 Protocol ,
6869# styles: list[CellStyle] = field(default_factory=list)
6970
7071
72+ def styles_to_jsonifiable (
73+ styles : list [StyleInfo ],
74+ * ,
75+ column_names : list [str ],
76+ ) -> list [dict [str , Jsonifiable ]]:
77+ return [
78+ style_info_to_jsonifiable (
79+ style ,
80+ column_names = column_names ,
81+ )
82+ for style in styles
83+ ]
84+
85+
7186def style_info_to_jsonifiable (
7287 style_info : StyleInfo ,
7388 * ,
@@ -86,7 +101,20 @@ def style_info_to_jsonifiable(
86101 }
87102
88103
89- Styles = list [StyleInfo ]
104+ def as_styles (styles : Styles | None ) -> Styles :
105+
106+ if styles is None :
107+ styles = []
108+ # if not isinstance(styles, list):
109+ # styles = [styles]
110+ if not all (isinstance (style , StyleInfo ) for style in styles ):
111+ raise TypeError (
112+ "Expected 'styles' to be `None` or a list of `StyleInfo` objects"
113+ )
114+ return styles
115+
116+
117+ Styles = List [StyleInfo ]
90118
91119
92120class AbstractTabularData (abc .ABC ):
@@ -211,17 +239,7 @@ def __init__(
211239 editable = self .editable ,
212240 row_selection_mode = row_selection_mode ,
213241 )
214-
215- if styles is None :
216- styles = []
217- # if not isinstance(styles, list):
218- # styles = [styles]
219- if not all (isinstance (style , StyleInfo ) for style in styles ):
220- raise TypeError (
221- "The DataGrid() constructor expected 'styles' to be a "
222- "list of `StyleInfo` objects"
223- )
224- self .styles = styles
242+ self .styles = as_styles (styles )
225243
226244 def to_payload (self ) -> dict [str , Jsonifiable ]:
227245 res = serialize_pandas_df (self .data )
@@ -234,13 +252,10 @@ def to_payload(self) -> dict[str, Jsonifiable]:
234252 editable = self .editable ,
235253 style = "grid" ,
236254 fill = self .height is None ,
237- styles = [
238- style_info_to_jsonifiable (
239- style ,
240- column_names = self .data .columns .tolist (),
241- )
242- for style in self .styles
243- ],
255+ styles = styles_to_jsonifiable (
256+ self .styles ,
257+ column_names = self .data .columns .tolist (),
258+ ),
244259 )
245260 return res
246261
@@ -326,6 +341,7 @@ def __init__(
326341 editable : bool = False ,
327342 selection_mode : SelectionModeInput = "none" ,
328343 row_selection_mode : Literal ["deprecated" ] = "deprecated" ,
344+ styles : Optional [Styles ] = None ,
329345 ):
330346
331347 self .data = cast_to_pandas (
@@ -344,6 +360,7 @@ def __init__(
344360 editable = self .editable ,
345361 row_selection_mode = row_selection_mode ,
346362 )
363+ self .styles = as_styles (styles )
347364
348365 def to_payload (self ) -> dict [str , Jsonifiable ]:
349366 res = serialize_pandas_df (self .data )
@@ -354,6 +371,10 @@ def to_payload(self) -> dict[str, Jsonifiable]:
354371 filters = self .filters ,
355372 editable = self .editable ,
356373 style = "table" ,
374+ styles = styles_to_jsonifiable (
375+ self .styles ,
376+ column_names = self .data .columns .tolist (),
377+ ),
357378 )
358379 return res
359380
0 commit comments