diff --git a/datatree/formatting_html.py b/datatree/formatting_html.py index 4531f5ae..32927ef4 100644 --- a/datatree/formatting_html.py +++ b/datatree/formatting_html.py @@ -3,38 +3,76 @@ from typing import Any, Mapping from xarray.core.formatting_html import ( + _load_static_files, _mapping_section, - _obj_repr, attr_section, + collapsible_section, coord_section, datavar_section, dim_section, ) -from xarray.core.options import OPTIONS +from xarray.core.options import OPTIONS, _get_boolean_with_default OPTIONS["display_expand_groups"] = "default" +OPTIONS["display_expand_group_data"] = "default" + +additional_css_style = """ +.xr-tree { + display: inline-grid; + grid-template-columns: 100%; +} + +.xr-tree-item { + display: inline-grid; +} +.xr-tree-item-mid { + height: 100%; +} +.xr-tree-item-end { + height: 1.2em; +} + +.xr-tree-item-connection-vertical { + grid-column-start: 1; + border-right: 0.2em solid; + border-color: var(--xr-border-color); + width: 0px; +} +.xr-tree-item-connection-horizontal { + grid-column-start: 2; + grid-row-start: 1; + height: 1em; + width: 20px; + border-bottom: 0.2em solid; + border-color: var(--xr-border-color); +} + +.xr-tree-item-data { + grid-column-start: 3; +} +.xr-tree-item-data-sections { + margin-left: 0.6em; +} +""" def summarize_children(children: Mapping[str, Any]) -> str: - N_CHILDREN = len(children) - 1 + def is_last_item(index, n_total): + return index >= n_total - 1 - # Get result from node_repr and wrap it - lines_callback = lambda n, c, end: _wrap_repr(node_repr(n, c), end=end) + def format_child(name, child, end): + """format node and wrap it into a tree""" + formatted = node_repr(name, child) + return _wrap_repr(formatted, end=end) + + n_children = len(children) children_html = "".join( - lines_callback(n, c, end=False) # Long lines - if i < N_CHILDREN - else lines_callback(n, c, end=True) # Short lines - for i, (n, c) in enumerate(children.items()) + format_child(name, child, end=is_last_item(index, n_children)) + for index, (name, child) in enumerate(children.items()) ) - return "".join( - [ - "
{escape(repr(dt))}" + f"{node_repr(obj_type, dt)}" + "