diff --git a/datatree/formatting_html.py b/datatree/formatting_html.py index 4531f5ae..32927ef4 100644 --- a/datatree/formatting_html.py +++ b/datatree/formatting_html.py @@ -3,38 +3,76 @@ from typing import Any, Mapping from xarray.core.formatting_html import ( + _load_static_files, _mapping_section, - _obj_repr, attr_section, + collapsible_section, coord_section, datavar_section, dim_section, ) -from xarray.core.options import OPTIONS +from xarray.core.options import OPTIONS, _get_boolean_with_default OPTIONS["display_expand_groups"] = "default" +OPTIONS["display_expand_group_data"] = "default" + +additional_css_style = """ +.xr-tree { + display: inline-grid; + grid-template-columns: 100%; +} + +.xr-tree-item { + display: inline-grid; +} +.xr-tree-item-mid { + height: 100%; +} +.xr-tree-item-end { + height: 1.2em; +} + +.xr-tree-item-connection-vertical { + grid-column-start: 1; + border-right: 0.2em solid; + border-color: var(--xr-border-color); + width: 0px; +} +.xr-tree-item-connection-horizontal { + grid-column-start: 2; + grid-row-start: 1; + height: 1em; + width: 20px; + border-bottom: 0.2em solid; + border-color: var(--xr-border-color); +} + +.xr-tree-item-data { + grid-column-start: 3; +} +.xr-tree-item-data-sections { + margin-left: 0.6em; +} +""" def summarize_children(children: Mapping[str, Any]) -> str: - N_CHILDREN = len(children) - 1 + def is_last_item(index, n_total): + return index >= n_total - 1 - # Get result from node_repr and wrap it - lines_callback = lambda n, c, end: _wrap_repr(node_repr(n, c), end=end) + def format_child(name, child, end): + """format node and wrap it into a tree""" + formatted = node_repr(name, child) + return _wrap_repr(formatted, end=end) + + n_children = len(children) children_html = "".join( - lines_callback(n, c, end=False) # Long lines - if i < N_CHILDREN - else lines_callback(n, c, end=True) # Short lines - for i, (n, c) in enumerate(children.items()) + format_child(name, child, end=is_last_item(index, n_children)) + for index, (name, child) in enumerate(children.items()) ) - return "".join( - [ - "
", - children_html, - "
", - ] - ) + return f"
{children_html}
" children_section = partial( @@ -46,20 +84,65 @@ def summarize_children(children: Mapping[str, Any]) -> str: ) -def node_repr(group_title: str, dt: Any) -> str: - header_components = [f"
{escape(group_title)}
"] - - ds = dt.ds - +def summarize_data(node): + ds = node.ds sections = [ - children_section(dt.children), dim_section(ds), coord_section(ds.coords), datavar_section(ds.data_vars), attr_section(ds.attrs), ] - return _obj_repr(ds, header_components, sections) + sections_li = "".join( + f"
  • {section}
  • " for section in sections + ) + return ( + "
    " + f"" + "
    " + ) + + +def data_section(node): + name = "Data" + + details = summarize_data(node) + + n_items = 5 + expanded = _get_boolean_with_default( + "display_expand_group_data", + True, + ) + collapsed = not expanded + + return collapsible_section( + name=name, + details=details, + n_items=n_items, + enabled=True, + collapsed=collapsed, + ) + + +def join_sections(sections, header_components): + combined_sections = "".join( + f"
  • {s}
  • " for s in sections + ) + header = "".join(header_components) + return ( + "" + ) + + +def node_repr(group_title: str, dt: Any) -> str: + header_components = [f"
    {escape(group_title)}
    "] + + sections = [children_section(dt.children), data_section(dt)] + + return join_sections(sections, header_components) def _wrap_repr(r: str, end: bool = False) -> str: @@ -99,40 +182,27 @@ def _wrap_repr(r: str, end: bool = False) -> str: Tee color is set to the variable :code:`--xr-border-color`. """ - # height of line - end = bool(end) - height = "100%" if end is False else "1.2em" - return "".join( - [ - "
    ", - "
    ", - "
    ", - "
    ", - "
    ", - "
    ", - "" "
    ", - "
    ", - ] + item_class = "xr-tree-item-mid" if not end else "xr-tree-item-end" + return ( + "
    " + f"
    " + "
    " + f"
    " + "
    " ) def datatree_repr(dt: Any) -> str: obj_type = f"datatree.{type(dt).__name__}" - return node_repr(obj_type, dt) + + icons_svg, css_style = _load_static_files() + + return ( + "
    " + f"{icons_svg}" + f"" + f"" + f"
    {escape(repr(dt))}
    " + f"{node_repr(obj_type, dt)}" + "
    " + )