diff --git a/shiny/ui/_layout_columns.py b/shiny/ui/_layout_columns.py index 77e551a8d..7bd2394db 100644 --- a/shiny/ui/_layout_columns.py +++ b/shiny/ui/_layout_columns.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Dict, Iterable, Optional, TypeVar, Union, cast +from typing import Dict, Iterable, Literal, Optional, Tuple, TypeVar, Union, cast from warnings import warn from htmltools import Tag, TagAttrs, TagAttrValue, TagChild, css @@ -15,24 +15,20 @@ T = TypeVar("T") -class Breakpoints(Enum): - """ - References - ---------- - * [Available Bootstrap breakpoints](https://getbootstrap.com/docs/5.3/layout/breakpoints/#available-breakpoints) - """ +Breakpoint = Literal["xs", "sm", "md", "lg", "xl", "xxl"] +""" +References +---------- +* [Available Bootstrap breakpoints](https://getbootstrap.com/docs/5.3/layout/breakpoints/#available-breakpoints) +""" + - xs = "xs" - sm = "sm" - md = "md" - lg = "lg" - xl = "xl" - xxl = "xxl" +breakpoints: Tuple[Breakpoint, ...] = ("xs", "sm", "md", "lg", "xl", "xxl") -BreakpointsSoft = Dict[Breakpoints, Union[Iterable[T], T, None]] -BreakpointsOptional = Dict[Breakpoints, Union[Iterable[T], None]] -BreakpointsComplete = Dict[Breakpoints, Iterable[T]] +BreakpointsSoft = Dict[Breakpoint, Union[Iterable[T], T, None]] +BreakpointsOptional = Dict[Breakpoint, Union[Iterable[T], None]] +BreakpointsComplete = Dict[Breakpoint, Iterable[T]] BreakpointsUser = Union[BreakpointsSoft[T], Iterable[T], T, None] @@ -155,16 +151,15 @@ def as_col_spec( return None if not isinstance(col_widths, Dict): - return {Breakpoints.md: validate_col_width(col_widths, n_kids, Breakpoints.md)} + return {"md": validate_col_width(col_widths, n_kids, "md")} ret: BreakpointsOptional[int] = {} col_widths_items = cast(BreakpointsSoft[int], col_widths).items() for brk, value in col_widths_items: - bs_breakpoints = [str(bp.value) for bp in Breakpoints] - if str(brk) not in bs_breakpoints: + if brk not in breakpoints: raise ValueError( - f"Breakpoint '{brk}' is not valid. Valid breakpoints are: {', '.join(bs_breakpoints)}'." + f"Breakpoint '{brk}' is not valid. Valid breakpoints are: {', '.join(breakpoints)}'." ) if value is None: @@ -180,7 +175,7 @@ def as_col_spec( def validate_col_width( - x: Iterable[int] | int, n_kids: int, break_name: Breakpoints + x: Iterable[int] | int, n_kids: int, break_name: Breakpoint ) -> Iterable[int]: if isinstance(x, int): y = [x] @@ -266,9 +261,7 @@ def row_heights_attrs( # row height is derived from xs or defaults to auto in the CSS, so we don't need the # class to activate it classes = [ - f"bslib-grid--row-heights--{brk}" - for brk in x_complete.keys() - if brk != Breakpoints.xs + f"bslib-grid--row-heights--{brk}" for brk in x_complete.keys() if brk != "xs" ] # Create CSS variables, treating numeric values as fractional units, passing strings