From 5ce91407cb018632ca4479d6be8e22118bf815e5 Mon Sep 17 00:00:00 2001 From: tokoko Date: Thu, 23 Oct 2025 07:47:07 +0000 Subject: [PATCH] feat: narwhals-compliant dataframe --- examples/dataframe_example.py | 17 +++++++++ examples/narwhals_example.py | 36 ++++++++++++++++++ src/substrait/dataframe/__init__.py | 16 ++++++++ src/substrait/dataframe/dataframe.py | 36 ++++++++++++++++++ src/substrait/dataframe/expression.py | 36 ++++++++++++++++++ tests/dataframe/test_df_project.py | 54 +++++++++++++++++++++++++++ 6 files changed, 195 insertions(+) create mode 100644 examples/dataframe_example.py create mode 100644 examples/narwhals_example.py create mode 100644 src/substrait/dataframe/__init__.py create mode 100644 src/substrait/dataframe/dataframe.py create mode 100644 src/substrait/dataframe/expression.py create mode 100644 tests/dataframe/test_df_project.py diff --git a/examples/dataframe_example.py b/examples/dataframe_example.py new file mode 100644 index 0000000..f2d14d6 --- /dev/null +++ b/examples/dataframe_example.py @@ -0,0 +1,17 @@ +from substrait.builders.plan import read_named_table +from substrait.builders.type import i64, boolean, struct, named_struct +from substrait.extension_registry import ExtensionRegistry +import substrait.dataframe as sdf + +registry = ExtensionRegistry(load_default_extensions=True) + +ns = named_struct( + names=["id", "is_applicable"], + struct=struct(types=[i64(nullable=False), boolean()], nullable=False), +) + +table = read_named_table("example_table", ns) + +frame = sdf.DataFrame(read_named_table("example_table", ns)) +frame = frame.select(sdf.col("id")) +print(frame.to_substrait(registry)) diff --git a/examples/narwhals_example.py b/examples/narwhals_example.py new file mode 100644 index 0000000..736af04 --- /dev/null +++ b/examples/narwhals_example.py @@ -0,0 +1,36 @@ +# Install duckdb and pyarrow before running this example +# /// script +# dependencies = [ +# "narwhals==2.9.0", +# "substrait[extensions] @ file:///${PROJECT_ROOT}/" +# ] +# /// + +from substrait.builders.plan import read_named_table +from substrait.builders.type import i64, boolean, struct, named_struct +from substrait.extension_registry import ExtensionRegistry + +from narwhals.typing import FrameT +import narwhals as nw +import substrait.dataframe as sdf + + +registry = ExtensionRegistry(load_default_extensions=True) + +ns = named_struct( + names=["id", "is_applicable"], + struct=struct(types=[i64(nullable=False), boolean()], nullable=False), +) + +table = read_named_table("example_table", ns) + + +lazy_frame: FrameT = nw.from_native( + sdf.DataFrame(read_named_table("example_table", ns)) +) + +lazy_frame = lazy_frame.select(nw.col("id").abs(), new_id=nw.col("id")) + +df: sdf.DataFrame = lazy_frame.to_native() + +print(df.to_substrait(registry)) diff --git a/src/substrait/dataframe/__init__.py b/src/substrait/dataframe/__init__.py new file mode 100644 index 0000000..5236ac4 --- /dev/null +++ b/src/substrait/dataframe/__init__.py @@ -0,0 +1,16 @@ +import substrait.dataframe +from substrait.builders.extended_expression import column + +from substrait.dataframe.dataframe import DataFrame +from substrait.dataframe.expression import Expression + +__all__ = [DataFrame, Expression] + + +def col(name: str) -> Expression: + """Column selection.""" + return Expression(column(name)) + +# TODO +def parse_into_expr(expr, str_as_lit: bool): + return expr._to_compliant_expr(substrait.dataframe) diff --git a/src/substrait/dataframe/dataframe.py b/src/substrait/dataframe/dataframe.py new file mode 100644 index 0000000..98fb5be --- /dev/null +++ b/src/substrait/dataframe/dataframe.py @@ -0,0 +1,36 @@ +from typing import Union, Iterable +import substrait.dataframe +from substrait.builders.plan import project +from substrait.dataframe.expression import Expression + + +class DataFrame: + def __init__(self, plan): + self.plan = plan + self._native_frame = self + + def to_substrait(self, registry): + return self.plan(registry) + + def __narwhals_lazyframe__(self) -> "DataFrame": + """Return object implementing CompliantDataFrame protocol.""" + return self + + def __narwhals_namespace__(self): + """ + Return the namespace object that contains functions like col, lit, etc. + This is how Narwhals knows which backend's functions to use. + """ + return substrait.dataframe + + def select( + self, *exprs: Union[Expression, Iterable[Expression]], **named_exprs: Expression + ) -> "DataFrame": + expressions = [e.expr for e in exprs] + [ + expr.alias(alias).expr for alias, expr in named_exprs.items() + ] + return DataFrame(project(self.plan, expressions=expressions)) + + # TODO handle version + def _with_version(self, version): + return self diff --git a/src/substrait/dataframe/expression.py b/src/substrait/dataframe/expression.py new file mode 100644 index 0000000..011b625 --- /dev/null +++ b/src/substrait/dataframe/expression.py @@ -0,0 +1,36 @@ +from substrait.builders.extended_expression import ( + UnboundExtendedExpression, + ExtendedExpressionOrUnbound, + resolve_expression, + scalar_function +) +import substrait.gen.proto.type_pb2 as stp +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.extension_registry import ExtensionRegistry + + +def _alias( + expr: ExtendedExpressionOrUnbound, + alias: str = None, +): + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_expression = resolve_expression(expr, base_schema, registry) + bound_expression.referred_expr[0].output_names[0] = alias + return bound_expression + + return resolve + + +class Expression: + def __init__(self, expr: UnboundExtendedExpression): + self.expr = expr + + def alias(self, alias: str): + self.expr = _alias(self.expr, alias) + return self + + def abs(self): + self.expr = scalar_function("functions_arithmetic.yaml", "abs", expressions=[self.expr]) + return self diff --git a/tests/dataframe/test_df_project.py b/tests/dataframe/test_df_project.py new file mode 100644 index 0000000..6caaa1d --- /dev/null +++ b/tests/dataframe/test_df_project.py @@ -0,0 +1,54 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table +from substrait.extension_registry import ExtensionRegistry +import substrait.dataframe as sdf + + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct( + types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED +) + +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + + +def test_project(): + df = sdf.DataFrame(read_named_table("table", named_struct)) + + actual = df.select(id=sdf.col("id")).to_substrait(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + project=stalg.ProjectRel( + common=stalg.RelCommon( + emit=stalg.RelCommon.Emit(output_mapping=[2]) + ), + input=df.to_substrait(None).relations[-1].root.input, + expressions=[ + stalg.Expression( + selection=stalg.Expression.FieldReference( + direct_reference=stalg.Expression.ReferenceSegment( + struct_field=stalg.Expression.ReferenceSegment.StructField( + field=0 + ) + ), + root_reference=stalg.Expression.FieldReference.RootReference(), + ) + ) + ], + ) + ), + names=["id"], + ) + ) + ] + ) + + assert actual == expected