Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.
Closed
180 changes: 125 additions & 55 deletions datatree/formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,76 @@
from typing import Any, Mapping

from xarray.core.formatting_html import (
_load_static_files,
_mapping_section,
_obj_repr,
attr_section,
collapsible_section,
coord_section,
datavar_section,
dim_section,
)
from xarray.core.options import OPTIONS
from xarray.core.options import OPTIONS, _get_boolean_with_default

OPTIONS["display_expand_groups"] = "default"
OPTIONS["display_expand_group_data"] = "default"

additional_css_style = """
.xr-tree {
display: inline-grid;
grid-template-columns: 100%;
}

.xr-tree-item {
display: inline-grid;
}
.xr-tree-item-mid {
height: 100%;
}
.xr-tree-item-end {
height: 1.2em;
}

.xr-tree-item-connection-vertical {
grid-column-start: 1;
border-right: 0.2em solid;
border-color: var(--xr-border-color);
width: 0px;
}
.xr-tree-item-connection-horizontal {
grid-column-start: 2;
grid-row-start: 1;
height: 1em;
width: 20px;
border-bottom: 0.2em solid;
border-color: var(--xr-border-color);
}

.xr-tree-item-data {
grid-column-start: 3;
}
.xr-tree-item-data-sections {
margin-left: 0.6em;
}
"""


def summarize_children(children: Mapping[str, Any]) -> str:
N_CHILDREN = len(children) - 1
def is_last_item(index, n_total):
return index >= n_total - 1

# Get result from node_repr and wrap it
lines_callback = lambda n, c, end: _wrap_repr(node_repr(n, c), end=end)
def format_child(name, child, end):
"""format node and wrap it into a tree"""
formatted = node_repr(name, child)
return _wrap_repr(formatted, end=end)

n_children = len(children)

children_html = "".join(
lines_callback(n, c, end=False) # Long lines
if i < N_CHILDREN
else lines_callback(n, c, end=True) # Short lines
for i, (n, c) in enumerate(children.items())
format_child(name, child, end=is_last_item(index, n_children))
for index, (name, child) in enumerate(children.items())
)

return "".join(
[
"<div style='display: inline-grid; grid-template-columns: 100%'>",
children_html,
"</div>",
]
)
return f"<div class='xr-tree'>{children_html}</div>"


children_section = partial(
Expand All @@ -46,20 +84,65 @@ def summarize_children(children: Mapping[str, Any]) -> str:
)


def node_repr(group_title: str, dt: Any) -> str:
header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]

ds = dt.ds

def summarize_data(node):
ds = node.ds
sections = [
children_section(dt.children),
dim_section(ds),
coord_section(ds.coords),
datavar_section(ds.data_vars),
attr_section(ds.attrs),
]

return _obj_repr(ds, header_components, sections)
sections_li = "".join(
f"<li class='xr-section-item'>{section}</li>" for section in sections
)
return (
"<div class='xr-tree-item-data-sections'>"
f"<ul class='xr-sections' style='width: 1200px;'>{sections_li}</ul>"
"</div>"
)


def data_section(node):
name = "Data"

details = summarize_data(node)

n_items = 5
expanded = _get_boolean_with_default(
"display_expand_group_data",
True,
)
collapsed = not expanded

return collapsible_section(
name=name,
details=details,
n_items=n_items,
enabled=True,
collapsed=collapsed,
)


def join_sections(sections, header_components):
combined_sections = "".join(
f"<li class='xr-section-item'>{s}</li>" for s in sections
)
header = "".join(header_components)
return (
"<div class='xr-wrap' style='display: none'>"
f"<div class='xr-header'>{header}</div>"
f"<ul class='xr-sections'>{combined_sections}</ul>"
"</div>"
)


def node_repr(group_title: str, dt: Any) -> str:
header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]

sections = [children_section(dt.children), data_section(dt)]

return join_sections(sections, header_components)


def _wrap_repr(r: str, end: bool = False) -> str:
Expand Down Expand Up @@ -99,40 +182,27 @@ def _wrap_repr(r: str, end: bool = False) -> str:

Tee color is set to the variable :code:`--xr-border-color`.
"""
# height of line
end = bool(end)
height = "100%" if end is False else "1.2em"
return "".join(
[
"<div style='display: inline-grid;'>",
"<div style='",
"grid-column-start: 1;",
"border-right: 0.2em solid;",
"border-color: var(--xr-border-color);",
f"height: {height};",
"width: 0px;",
"'>",
"</div>",
"<div style='",
"grid-column-start: 2;",
"grid-row-start: 1;",
"height: 1em;",
"width: 20px;",
"border-bottom: 0.2em solid;",
"border-color: var(--xr-border-color);",
"'>",
"</div>",
"<div style='",
"grid-column-start: 3;",
"'>",
"<ul class='xr-sections'>",
r,
"</ul>" "</div>",
"</div>",
]
item_class = "xr-tree-item-mid" if not end else "xr-tree-item-end"
return (
"<div class='xr-tree-item'>"
f"<div class='xr-tree-item-connection-vertical {item_class}'></div>"
"<div class='xr-tree-item-connection-horizontal'></div>"
f"<div class='xr-tree-item-data'><ul class='xr-sections'>{r}</ul></div>"
"</div>"
)


def datatree_repr(dt: Any) -> str:
obj_type = f"datatree.{type(dt).__name__}"
return node_repr(obj_type, dt)

icons_svg, css_style = _load_static_files()

return (
"<div>"
f"{icons_svg}"
f"<style>{css_style}</style>"
f"<style>{additional_css_style}</style>"
f"<pre class='xr-text-repr-fallback'>{escape(repr(dt))}</pre>"
f"{node_repr(obj_type, dt)}"
"</div>"
)