From b9d798b4cb57e9ddc5bb3115b531828b4337627d Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 17 Apr 2024 16:21:48 +0200 Subject: [PATCH 01/18] Initial Import --- pyproject.toml | 2 +- src/substrait/sql/__init__.py | 1 + src/substrait/sql/extended_expression.py | 315 +++++++++++++++++++++++ 3 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 src/substrait/sql/__init__.py create mode 100644 src/substrait/sql/extended_expression.py diff --git a/pyproject.toml b/pyproject.toml index 7407070..d0faa2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.8.1" -dependencies = ["protobuf >= 3.20"] +dependencies = ["protobuf >= 3.20", "sqlglot >= 23.10.0"] dynamic = ["version"] [tool.setuptools_scm] diff --git a/src/substrait/sql/__init__.py b/src/substrait/sql/__init__.py new file mode 100644 index 0000000..4dea9c5 --- /dev/null +++ b/src/substrait/sql/__init__.py @@ -0,0 +1 @@ +from .extended_expression import parse_sql_extended_expression diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py new file mode 100644 index 0000000..b6af469 --- /dev/null +++ b/src/substrait/sql/extended_expression.py @@ -0,0 +1,315 @@ +import pathlib + +import sqlglot +import yaml + +from substrait import proto + + +SQL_BINARY_FUNCTIONS = { + # Arithmetic + "add": "add", + "div": "div", + "mul": "mul", + "sub": "sub", + # Comparisons + "eq": "equal", +} + + +def parse_sql_extended_expression(catalog, schema, sql): + select = sqlglot.parse_one(sql) + if not isinstance(select, sqlglot.expressions.Select): + raise ValueError("a SELECT statement was expected") + + invoked_functions_projection, projections = _substrait_projection_from_sqlglot( + catalog, schema, select.expressions + ) + extension_uris, extensions = catalog.extensions_for_functions( + invoked_functions_projection + ) + projection_extended_expr = proto.ExtendedExpression( + extension_uris=extension_uris, + extensions=extensions, + base_schema=schema, + referred_expr=projections, + ) + + invoked_functions_filter, filter_expr = _substrait_expression_from_sqlglot( + catalog, schema, select.find(sqlglot.expressions.Where).this + ) + extension_uris, extensions = catalog.extensions_for_functions( + invoked_functions_filter + ) + filter_extended_expr = proto.ExtendedExpression( + extension_uris=extension_uris, + extensions=extensions, + base_schema=schema, + referred_expr=[proto.ExpressionReference(expression=filter_expr)], + ) + + return projection_extended_expr, filter_extended_expr + + +def _substrait_projection_from_sqlglot(catalog, schema, expressions): + if not expressions: + return set(), [] + + # My understanding of ExtendedExpressions is that they are meant to directly + # point to the Expression that ProjectRel would contain, so we don't actually + # need a ProjectRel at all. + """ + projection_sub = proto.ProjectRel( + input=proto.Rel( + read=proto.ReadRel( + named_table=proto.ReadRel.NamedTable(names=["__table__"]), + base_schema=schema, + ) + ), + expressions=[], + ) + """ + + substrait_expressions = [] + invoked_functions = set() + for sqlexpr in expressions: + output_names = [] + if isinstance(sqlexpr, sqlglot.expressions.Alias): + output_names = [sqlexpr.output_name] + sqlexpr = sqlexpr.this + _, substrait_expr = _parse_expression( + catalog, schema, sqlexpr, invoked_functions + ) + substrait_expr_reference = proto.ExpressionReference( + expression=substrait_expr, output_names=output_names + ) + substrait_expressions.append(substrait_expr_reference) + + return invoked_functions, substrait_expressions + + +def _substrait_expression_from_sqlglot(catalog, schema, sqlglot_node): + if not sqlglot_node: + return set(), None + + invoked_functions = set() + _, substrait_expr = _parse_expression( + catalog, schema, sqlglot_node, invoked_functions + ) + return invoked_functions, substrait_expr + + +def _parse_expression(catalog, schema, expr, invoked_functions): + # TODO: Propagate up column names (output_names) so that the projections _always_ have an output_name + if isinstance(expr, sqlglot.expressions.Literal): + if expr.is_string: + return proto.Type(string=proto.Type.String()), proto.Expression( + literal=proto.Expression.Literal(string=expr.text) + ) + elif expr.is_int: + return proto.Type(i32=proto.Type.I32()), proto.Expression( + literal=proto.Expression.Literal(i32=int(expr.name)) + ) + elif sqlglot.helper.is_float(expr.name): + return proto.Type(fp32=proto.Type.FP32()), proto.Expression( + literal=proto.Expression.Literal(float=float(expr.name)) + ) + else: + raise ValueError(f"Unsupporter literal: {expr.text}") + elif isinstance(expr, sqlglot.expressions.Column): + column_name = expr.output_name + schema_field = list(schema.names).index(column_name) + schema_type = schema.struct.types[schema_field] + return schema_type, proto.Expression( + selection=proto.Expression.FieldReference( + direct_reference=proto.Expression.ReferenceSegment( + struct_field=proto.Expression.ReferenceSegment.StructField( + field=schema_field + ) + ) + ) + ) + elif expr.key in SQL_BINARY_FUNCTIONS: + left_type, left = _parse_expression( + catalog, schema, expr.left, invoked_functions + ) + right_type, right = _parse_expression( + catalog, schema, expr.right, invoked_functions + ) + function_name = SQL_BINARY_FUNCTIONS[expr.key] + signature, result_type, function_expression = _parse_function_invokation( + function_name, left_type, left, right_type, right + ) + invoked_functions.add(signature) + return result_type, function_expression + else: + raise ValueError( + f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" + ) + + +def _parse_function_invokation(function_name, left_type, left, right_type, right): + signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" + try: + function_anchor = catalog.function_anchor(signature) + except KeyError: + # not function found with the exact types, try any1_any1 version + signature = f"{function_name}:any1_any1" + function_anchor = catalog.function_anchor(signature) + return ( + signature, + left_type, + proto.Expression( + scalar_function=proto.Expression.ScalarFunction( + function_reference=function_anchor, + arguments=[ + proto.FunctionArgument(value=left), + proto.FunctionArgument(value=right), + ], + ) + ), + ) + + +class FunctionsCatalog: + STANDARD_EXTENSIONS = ( + "/functions_aggregate_approx.yaml", + "/functions_aggregate_generic.yaml", + "/functions_arithmetic.yaml", + "/functions_arithmetic_decimal.yaml", + "/functions_boolean.yaml", + "/functions_comparison.yaml", + "/functions_datetime.yaml", + "/functions_geometry.yaml", + "/functions_logarithmic.yaml", + "/functions_rounding.yaml", + "/functions_set.yaml", + "/functions_string.yaml", + ) + + def __init__(self): + self._declarations = {} + self._registered_extensions = {} + self._functions = {} + + def load_standard_extensions(self, dirpath): + for ext in self.STANDARD_EXTENSIONS: + self.load(dirpath, ext) + + def load(self, dirpath, filename): + with open(pathlib.Path(dirpath) / filename.strip("/")) as f: + sections = yaml.safe_load(f) + + loaded_functions = set() + for functions in sections.values(): + for function in functions: + function_name = function["name"] + for impl in function.get("impls", []): + argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] + if not argtypes: + signature = function_name + else: + signature = f"{function_name}:{'_'.join(argtypes)}" + self._declarations[signature] = filename + loaded_functions.add(signature) + + self._register_extensions(filename, loaded_functions) + + def _register_extensions(self, extension_uri, loaded_functions): + if extension_uri not in self._registered_extensions: + ext_anchor_id = len(self._registered_extensions) + 1 + self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( + extension_uri_anchor=ext_anchor_id, uri=extension_uri + ) + + for function in loaded_functions: + if function in self._functions: + extensions_by_anchor = self.extension_uris_by_anchor + function = self._functions[function] + function_extension = extensions_by_anchor[ + function.extension_uri_reference + ].uri + continue + raise ValueError( + f"Duplicate function definition: {function} from {extension_uri}, already loaded from {function_extension}" + ) + extension_anchor = self._registered_extensions[ + extension_uri + ].extension_uri_anchor + function_anchor = len(self._functions) + 1 + self._functions[function] = ( + proto.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=extension_anchor, + name=function, + function_anchor=function_anchor, + ) + ) + + @property + def extension_uris_by_anchor(self): + return { + ext.extension_uri_anchor: ext + for ext in self._registered_extensions.values() + } + + @property + def extension_uris(self): + return list(self._registered_extensions.values()) + + @property + def extensions(self): + return list(self._functions.values()) + + def function_anchor(self, function): + return self._functions[function].function_anchor + + def extensions_for_functions(self, functions): + uris_anchors = set() + extensions = [] + for f in functions: + ext = self._functions[f] + uris_anchors.add(ext.extension_uri_reference) + extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) + + uris_by_anchor = self.extension_uris_by_anchor + extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] + return extension_uris, extensions + + +catalog = FunctionsCatalog() +catalog.load_standard_extensions( + pathlib.Path(__file__).parent.parent / "third_party" / "substrait" / "extensions", +) +schema = proto.NamedStruct( + names=["first_name", "surname", "age"], + struct=proto.Type.Struct( + types=[ + proto.Type( + string=proto.Type.String( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + ), + proto.Type( + string=proto.Type.String( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + ), + proto.Type( + i32=proto.Type.I32( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + ), + ] + ), +) + +if __name__ == '__main__': + sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32" + projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql) + print("---- SQL INPUT ----") + print(sql) + print("---- PROJECTION ----") + print(projection_expr) + print("---- FILTER ----") + print(filter_expr) + # parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)") From 7358313f60f3f291b938942bd96eeb9af5e87bfc Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 17 Apr 2024 16:52:11 +0200 Subject: [PATCH 02/18] Split in submodules --- pyproject.toml | 2 +- src/substrait/sql/__main__.py | 42 +++++++ src/substrait/sql/extended_expression.py | 150 +---------------------- src/substrait/sql/functions_catalog.py | 120 ++++++++++++++++++ 4 files changed, 165 insertions(+), 149 deletions(-) create mode 100644 src/substrait/sql/__main__.py create mode 100644 src/substrait/sql/functions_catalog.py diff --git a/pyproject.toml b/pyproject.toml index d0faa2b..8aa3246 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = [{name = "Substrait contributors", email = "substrait@googlegroups.com license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.8.1" -dependencies = ["protobuf >= 3.20", "sqlglot >= 23.10.0"] +dependencies = ["protobuf >= 3.20", "sqlglot >= 23.10.0", "PyYAML"] dynamic = ["version"] [tool.setuptools_scm] diff --git a/src/substrait/sql/__main__.py b/src/substrait/sql/__main__.py new file mode 100644 index 0000000..119343b --- /dev/null +++ b/src/substrait/sql/__main__.py @@ -0,0 +1,42 @@ +import pathlib + +from substrait import proto +from .functions_catalog import FunctionsCatalog +from .extended_expression import parse_sql_extended_expression + +catalog = FunctionsCatalog() +catalog.load_standard_extensions( + pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions", +) +schema = proto.NamedStruct( + names=["first_name", "surname", "age"], + struct=proto.Type.Struct( + types=[ + proto.Type( + string=proto.Type.String( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + ), + proto.Type( + string=proto.Type.String( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + ), + proto.Type( + i32=proto.Type.I32( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + ), + ] + ), +) + +sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32" +projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql) +print("---- SQL INPUT ----") +print(sql) +print("---- PROJECTION ----") +print(projection_expr) +print("---- FILTER ----") +print(filter_expr) +# parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)") \ No newline at end of file diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index b6af469..d684507 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -1,7 +1,4 @@ -import pathlib - import sqlglot -import yaml from substrait import proto @@ -138,7 +135,7 @@ def _parse_expression(catalog, schema, expr, invoked_functions): ) function_name = SQL_BINARY_FUNCTIONS[expr.key] signature, result_type, function_expression = _parse_function_invokation( - function_name, left_type, left, right_type, right + catalog, function_name, left_type, left, right_type, right ) invoked_functions.add(signature) return result_type, function_expression @@ -148,7 +145,7 @@ def _parse_expression(catalog, schema, expr, invoked_functions): ) -def _parse_function_invokation(function_name, left_type, left, right_type, right): +def _parse_function_invokation(catalog, function_name, left_type, left, right_type, right): signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" try: function_anchor = catalog.function_anchor(signature) @@ -170,146 +167,3 @@ def _parse_function_invokation(function_name, left_type, left, right_type, right ), ) - -class FunctionsCatalog: - STANDARD_EXTENSIONS = ( - "/functions_aggregate_approx.yaml", - "/functions_aggregate_generic.yaml", - "/functions_arithmetic.yaml", - "/functions_arithmetic_decimal.yaml", - "/functions_boolean.yaml", - "/functions_comparison.yaml", - "/functions_datetime.yaml", - "/functions_geometry.yaml", - "/functions_logarithmic.yaml", - "/functions_rounding.yaml", - "/functions_set.yaml", - "/functions_string.yaml", - ) - - def __init__(self): - self._declarations = {} - self._registered_extensions = {} - self._functions = {} - - def load_standard_extensions(self, dirpath): - for ext in self.STANDARD_EXTENSIONS: - self.load(dirpath, ext) - - def load(self, dirpath, filename): - with open(pathlib.Path(dirpath) / filename.strip("/")) as f: - sections = yaml.safe_load(f) - - loaded_functions = set() - for functions in sections.values(): - for function in functions: - function_name = function["name"] - for impl in function.get("impls", []): - argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] - if not argtypes: - signature = function_name - else: - signature = f"{function_name}:{'_'.join(argtypes)}" - self._declarations[signature] = filename - loaded_functions.add(signature) - - self._register_extensions(filename, loaded_functions) - - def _register_extensions(self, extension_uri, loaded_functions): - if extension_uri not in self._registered_extensions: - ext_anchor_id = len(self._registered_extensions) + 1 - self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( - extension_uri_anchor=ext_anchor_id, uri=extension_uri - ) - - for function in loaded_functions: - if function in self._functions: - extensions_by_anchor = self.extension_uris_by_anchor - function = self._functions[function] - function_extension = extensions_by_anchor[ - function.extension_uri_reference - ].uri - continue - raise ValueError( - f"Duplicate function definition: {function} from {extension_uri}, already loaded from {function_extension}" - ) - extension_anchor = self._registered_extensions[ - extension_uri - ].extension_uri_anchor - function_anchor = len(self._functions) + 1 - self._functions[function] = ( - proto.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=extension_anchor, - name=function, - function_anchor=function_anchor, - ) - ) - - @property - def extension_uris_by_anchor(self): - return { - ext.extension_uri_anchor: ext - for ext in self._registered_extensions.values() - } - - @property - def extension_uris(self): - return list(self._registered_extensions.values()) - - @property - def extensions(self): - return list(self._functions.values()) - - def function_anchor(self, function): - return self._functions[function].function_anchor - - def extensions_for_functions(self, functions): - uris_anchors = set() - extensions = [] - for f in functions: - ext = self._functions[f] - uris_anchors.add(ext.extension_uri_reference) - extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) - - uris_by_anchor = self.extension_uris_by_anchor - extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] - return extension_uris, extensions - - -catalog = FunctionsCatalog() -catalog.load_standard_extensions( - pathlib.Path(__file__).parent.parent / "third_party" / "substrait" / "extensions", -) -schema = proto.NamedStruct( - names=["first_name", "surname", "age"], - struct=proto.Type.Struct( - types=[ - proto.Type( - string=proto.Type.String( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - proto.Type( - string=proto.Type.String( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - proto.Type( - i32=proto.Type.I32( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - ] - ), -) - -if __name__ == '__main__': - sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32" - projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql) - print("---- SQL INPUT ----") - print(sql) - print("---- PROJECTION ----") - print(projection_expr) - print("---- FILTER ----") - print(filter_expr) - # parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)") diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py new file mode 100644 index 0000000..54a089a --- /dev/null +++ b/src/substrait/sql/functions_catalog.py @@ -0,0 +1,120 @@ +import pathlib + +import yaml + +from substrait import proto + + +class FunctionsCatalog: + """Catalog of Substrait functions and extensions. + + Loads extensions from YAML files and records the declared functions. + Given a set of functions it can generate the necessary extension URIs + and extensions to be included in an ExtendedExpression or Plan. + """ + + # TODO: Find a way to support standard extensions in released distribution. + # IE: Include the standard extension yaml files in the package data and + # update them when gen_proto is used.. + STANDARD_EXTENSIONS = ( + "/functions_aggregate_approx.yaml", + "/functions_aggregate_generic.yaml", + "/functions_arithmetic.yaml", + "/functions_arithmetic_decimal.yaml", + "/functions_boolean.yaml", + "/functions_comparison.yaml", + "/functions_datetime.yaml", + "/functions_geometry.yaml", + "/functions_logarithmic.yaml", + "/functions_rounding.yaml", + "/functions_set.yaml", + "/functions_string.yaml", + ) + + def __init__(self): + self._declarations = {} + self._registered_extensions = {} + self._functions = {} + + def load_standard_extensions(self, dirpath): + for ext in self.STANDARD_EXTENSIONS: + self.load(dirpath, ext) + + def load(self, dirpath, filename): + with open(pathlib.Path(dirpath) / filename.strip("/")) as f: + sections = yaml.safe_load(f) + + loaded_functions = set() + for functions in sections.values(): + for function in functions: + function_name = function["name"] + for impl in function.get("impls", []): + argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] + if not argtypes: + signature = function_name + else: + signature = f"{function_name}:{'_'.join(argtypes)}" + self._declarations[signature] = filename + loaded_functions.add(signature) + + self._register_extensions(filename, loaded_functions) + + def _register_extensions(self, extension_uri, loaded_functions): + if extension_uri not in self._registered_extensions: + ext_anchor_id = len(self._registered_extensions) + 1 + self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( + extension_uri_anchor=ext_anchor_id, uri=extension_uri + ) + + for function in loaded_functions: + if function in self._functions: + extensions_by_anchor = self.extension_uris_by_anchor + function = self._functions[function] + function_extension = extensions_by_anchor[ + function.extension_uri_reference + ].uri + continue + raise ValueError( + f"Duplicate function definition: {function} from {extension_uri}, already loaded from {function_extension}" + ) + extension_anchor = self._registered_extensions[ + extension_uri + ].extension_uri_anchor + function_anchor = len(self._functions) + 1 + self._functions[function] = ( + proto.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=extension_anchor, + name=function, + function_anchor=function_anchor, + ) + ) + + @property + def extension_uris_by_anchor(self): + return { + ext.extension_uri_anchor: ext + for ext in self._registered_extensions.values() + } + + @property + def extension_uris(self): + return list(self._registered_extensions.values()) + + @property + def extensions(self): + return list(self._functions.values()) + + def function_anchor(self, function): + return self._functions[function].function_anchor + + def extensions_for_functions(self, functions): + uris_anchors = set() + extensions = [] + for f in functions: + ext = self._functions[f] + uris_anchors.add(ext.extension_uri_reference) + extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) + + uris_by_anchor = self.extension_uris_by_anchor + extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] + return extension_uris, extensions \ No newline at end of file From 69e6dcf3e44b518a9c0371e31729face447f8d3d Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 17 Apr 2024 17:48:05 +0200 Subject: [PATCH 03/18] Refactor parsing of expressions and propagate column names --- src/substrait/sql/__main__.py | 7 +- src/substrait/sql/extended_expression.py | 208 ++++++++++------------- 2 files changed, 97 insertions(+), 118 deletions(-) diff --git a/src/substrait/sql/__main__.py b/src/substrait/sql/__main__.py index 119343b..76556a2 100644 --- a/src/substrait/sql/__main__.py +++ b/src/substrait/sql/__main__.py @@ -8,6 +8,10 @@ catalog.load_standard_extensions( pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions", ) + +# TODO: Turn this into a command line tool to test more queries. +# We can probably have a quick way to declare schema using command line args. +# like first_name=String,surname=String,age=I32 etc... schema = proto.NamedStruct( names=["first_name", "surname", "age"], struct=proto.Type.Struct( @@ -38,5 +42,4 @@ print("---- PROJECTION ----") print(projection_expr) print("---- FILTER ----") -print(filter_expr) -# parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)") \ No newline at end of file +print(filter_expr) \ No newline at end of file diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index d684507..1913376 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -1,3 +1,5 @@ +import itertools + import sqlglot from substrait import proto @@ -19,21 +21,28 @@ def parse_sql_extended_expression(catalog, schema, sql): if not isinstance(select, sqlglot.expressions.Select): raise ValueError("a SELECT statement was expected") - invoked_functions_projection, projections = _substrait_projection_from_sqlglot( - catalog, schema, select.expressions - ) + sqlglot_parser = SQLGlotParser(catalog, schema) + + # Handle the projections in the SELECT statemenent. + project_expressions = [] + projection_invoked_functions = set() + for sqlexpr in select.expressions: + invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(sqlexpr) + projection_invoked_functions.update(invoked_functions) + project_expressions.append(proto.ExpressionReference(expression=expr, output_names=[output_name])) extension_uris, extensions = catalog.extensions_for_functions( - invoked_functions_projection + projection_invoked_functions ) projection_extended_expr = proto.ExtendedExpression( extension_uris=extension_uris, extensions=extensions, base_schema=schema, - referred_expr=projections, + referred_expr=project_expressions, ) - invoked_functions_filter, filter_expr = _substrait_expression_from_sqlglot( - catalog, schema, select.find(sqlglot.expressions.Where).this + # Handle WHERE clause in the SELECT statement. + invoked_functions_filter, _, filter_expr = sqlglot_parser.expression_from_sqlglot( + select.find(sqlglot.expressions.Where).this ) extension_uris, extensions = catalog.extensions_for_functions( invoked_functions_filter @@ -48,122 +57,89 @@ def parse_sql_extended_expression(catalog, schema, sql): return projection_extended_expr, filter_extended_expr -def _substrait_projection_from_sqlglot(catalog, schema, expressions): - if not expressions: - return set(), [] +class SQLGlotParser: + def __init__(self, functions_catalog, schema): + self._functions_catalog = functions_catalog + self._schema = schema + self._counter = itertools.count() - # My understanding of ExtendedExpressions is that they are meant to directly - # point to the Expression that ProjectRel would contain, so we don't actually - # need a ProjectRel at all. - """ - projection_sub = proto.ProjectRel( - input=proto.Rel( - read=proto.ReadRel( - named_table=proto.ReadRel.NamedTable(names=["__table__"]), - base_schema=schema, - ) - ), - expressions=[], - ) - """ - - substrait_expressions = [] - invoked_functions = set() - for sqlexpr in expressions: - output_names = [] - if isinstance(sqlexpr, sqlglot.expressions.Alias): - output_names = [sqlexpr.output_name] - sqlexpr = sqlexpr.this - _, substrait_expr = _parse_expression( - catalog, schema, sqlexpr, invoked_functions - ) - substrait_expr_reference = proto.ExpressionReference( - expression=substrait_expr, output_names=output_names + def expression_from_sqlglot(self, sqlglot_node): + invoked_functions = set() + output_name, _, substrait_expr = self._parse_expression( + sqlglot_node, invoked_functions ) - substrait_expressions.append(substrait_expr_reference) - - return invoked_functions, substrait_expressions - + return invoked_functions, output_name, substrait_expr -def _substrait_expression_from_sqlglot(catalog, schema, sqlglot_node): - if not sqlglot_node: - return set(), None - - invoked_functions = set() - _, substrait_expr = _parse_expression( - catalog, schema, sqlglot_node, invoked_functions - ) - return invoked_functions, substrait_expr - - -def _parse_expression(catalog, schema, expr, invoked_functions): - # TODO: Propagate up column names (output_names) so that the projections _always_ have an output_name - if isinstance(expr, sqlglot.expressions.Literal): - if expr.is_string: - return proto.Type(string=proto.Type.String()), proto.Expression( - literal=proto.Expression.Literal(string=expr.text) + def _parse_expression(self, expr, invoked_functions): + if isinstance(expr, sqlglot.expressions.Literal): + if expr.is_string: + return f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( + literal=proto.Expression.Literal(string=expr.text) + ) + elif expr.is_int: + return f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( + literal=proto.Expression.Literal(i32=int(expr.name)) + ) + elif sqlglot.helper.is_float(expr.name): + return f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( + literal=proto.Expression.Literal(float=float(expr.name)) + ) + else: + raise ValueError(f"Unsupporter literal: {expr.text}") + elif isinstance(expr, sqlglot.expressions.Column): + column_name = expr.output_name + schema_field = list(self._schema.names).index(column_name) + schema_type = self._schema.struct.types[schema_field] + return column_name, schema_type, proto.Expression( + selection=proto.Expression.FieldReference( + direct_reference=proto.Expression.ReferenceSegment( + struct_field=proto.Expression.ReferenceSegment.StructField( + field=schema_field + ) + ) + ) ) - elif expr.is_int: - return proto.Type(i32=proto.Type.I32()), proto.Expression( - literal=proto.Expression.Literal(i32=int(expr.name)) + elif isinstance(expr, sqlglot.expressions.Alias): + _, aliased_type, aliased_expr = self._parse_expression(expr.this, invoked_functions) + return expr.output_name, aliased_type, aliased_expr + elif expr.key in SQL_BINARY_FUNCTIONS: + left_name, left_type, left = self._parse_expression( + expr.left, invoked_functions ) - elif sqlglot.helper.is_float(expr.name): - return proto.Type(fp32=proto.Type.FP32()), proto.Expression( - literal=proto.Expression.Literal(float=float(expr.name)) + right_name, right_type, right = self._parse_expression( + expr.right, invoked_functions ) + function_name = SQL_BINARY_FUNCTIONS[expr.key] + signature, result_type, function_expression = self._parse_function_invokation( + function_name, left_type, left, right_type, right + ) + invoked_functions.add(signature) + result_name = f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + return result_name, result_type, function_expression else: - raise ValueError(f"Unsupporter literal: {expr.text}") - elif isinstance(expr, sqlglot.expressions.Column): - column_name = expr.output_name - schema_field = list(schema.names).index(column_name) - schema_type = schema.struct.types[schema_field] - return schema_type, proto.Expression( - selection=proto.Expression.FieldReference( - direct_reference=proto.Expression.ReferenceSegment( - struct_field=proto.Expression.ReferenceSegment.StructField( - field=schema_field - ) - ) + raise ValueError( + f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" ) - ) - elif expr.key in SQL_BINARY_FUNCTIONS: - left_type, left = _parse_expression( - catalog, schema, expr.left, invoked_functions - ) - right_type, right = _parse_expression( - catalog, schema, expr.right, invoked_functions - ) - function_name = SQL_BINARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = _parse_function_invokation( - catalog, function_name, left_type, left, right_type, right - ) - invoked_functions.add(signature) - return result_type, function_expression - else: - raise ValueError( - f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" - ) - -def _parse_function_invokation(catalog, function_name, left_type, left, right_type, right): - signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" - try: - function_anchor = catalog.function_anchor(signature) - except KeyError: - # not function found with the exact types, try any1_any1 version - signature = f"{function_name}:any1_any1" - function_anchor = catalog.function_anchor(signature) - return ( - signature, - left_type, - proto.Expression( - scalar_function=proto.Expression.ScalarFunction( - function_reference=function_anchor, - arguments=[ - proto.FunctionArgument(value=left), - proto.FunctionArgument(value=right), - ], - ) - ), - ) + def _parse_function_invokation(self, function_name, left_type, left, right_type, right): + signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" + try: + function_anchor = self._functions_catalog.function_anchor(signature) + except KeyError: + # not function found with the exact types, try any1_any1 version + signature = f"{function_name}:any1_any1" + function_anchor = self._functions_catalog.function_anchor(signature) + return ( + signature, + left_type, + proto.Expression( + scalar_function=proto.Expression.ScalarFunction( + function_reference=function_anchor, + arguments=[ + proto.FunctionArgument(value=left), + proto.FunctionArgument(value=right), + ], + ) + ), + ) From 600554221872ae423d7ffde5afd9439d43f0ff47 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 17 Apr 2024 18:07:38 +0200 Subject: [PATCH 04/18] track more TODOs --- src/substrait/sql/extended_expression.py | 5 +++-- src/substrait/sql/functions_catalog.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 1913376..1e43f55 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -126,12 +126,13 @@ def _parse_function_invokation(self, function_name, left_type, left, right_type, try: function_anchor = self._functions_catalog.function_anchor(signature) except KeyError: - # not function found with the exact types, try any1_any1 version + # No function found with the exact types, try any1_any1 version + # TODO: What about cases like i32_any1? What about any instead of any1? signature = f"{function_name}:any1_any1" function_anchor = self._functions_catalog.function_anchor(signature) return ( signature, - left_type, + left_type, # TODO: Get the actually returned type from the functions catalog. proto.Expression( scalar_function=proto.Expression.ScalarFunction( function_reference=function_anchor, diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 54a089a..8f7286d 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -73,9 +73,10 @@ def _register_extensions(self, extension_uri, loaded_functions): function_extension = extensions_by_anchor[ function.extension_uri_reference ].uri + # TODO: Support overloading of functions from different extensionUris. continue raise ValueError( - f"Duplicate function definition: {function} from {extension_uri}, already loaded from {function_extension}" + f"Duplicate function definition: {function.name} from {extension_uri}, already loaded from {function_extension}" ) extension_anchor = self._registered_extensions[ extension_uri From 96c583c0c8c355961b9027244e8a58a96f781e61 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 10:52:45 +0200 Subject: [PATCH 05/18] Command line tool to test queries --- src/substrait/sql/__main__.py | 98 +++++++++++++++--------- src/substrait/sql/extended_expression.py | 67 +++++++++++----- src/substrait/sql/functions_catalog.py | 5 +- 3 files changed, 110 insertions(+), 60 deletions(-) diff --git a/src/substrait/sql/__main__.py b/src/substrait/sql/__main__.py index 76556a2..f135e4a 100644 --- a/src/substrait/sql/__main__.py +++ b/src/substrait/sql/__main__.py @@ -1,45 +1,67 @@ import pathlib +import argparse from substrait import proto from .functions_catalog import FunctionsCatalog from .extended_expression import parse_sql_extended_expression -catalog = FunctionsCatalog() -catalog.load_standard_extensions( - pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions", -) - -# TODO: Turn this into a command line tool to test more queries. -# We can probably have a quick way to declare schema using command line args. -# like first_name=String,surname=String,age=I32 etc... -schema = proto.NamedStruct( - names=["first_name", "surname", "age"], - struct=proto.Type.Struct( - types=[ - proto.Type( - string=proto.Type.String( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - proto.Type( - string=proto.Type.String( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), + +def main(): + """Commandline tool to test the SQL to ExtendedExpression parser. + + Run as python -m substrait.sql first_name=String,surname=String,age=I32 "SELECT surname, age + 1 as next_birthday, age + 2 WHERE age = 32" + """ + parser = argparse.ArgumentParser( + description="Convert a SQL SELECT statement to an ExtendedExpression" + ) + parser.add_argument("schema", type=str, help="Schema of the input data") + parser.add_argument("sql", type=str, help="SQL SELECT statement") + args = parser.parse_args() + + catalog = FunctionsCatalog() + catalog.load_standard_extensions( + pathlib.Path(__file__).parent.parent.parent.parent + / "third_party" + / "substrait" + / "extensions", + ) + schema = parse_schema(args.schema) + projection_expr, filter_expr = parse_sql_extended_expression( + catalog, schema, args.sql + ) + + print("---- SQL INPUT ----") + print(args.sql) + print("---- PROJECTION ----") + print(projection_expr) + print("---- FILTER ----") + print(filter_expr) + + +def parse_schema(schema_string): + """Parse Schema from a comma separated string of fieldname=fieldtype pairs. + + For example: "first_name=String,surname=String,age=I32" + """ + types = [] + names = [] + + fields = schema_string.split(",") + for field in fields: + fieldname, fieldtype = field.split("=") + proto_type = getattr(proto.Type, fieldtype) + names.append(fieldname) + types.append( proto.Type( - i32=proto.Type.I32( - nullability=proto.Type.Nullability.NULLABILITY_REQUIRED - ) - ), - ] - ), -) - -sql = "SELECT surname, age + 1 as next_birthday WHERE age = 32" -projection_expr, filter_expr = parse_sql_extended_expression(catalog, schema, sql) -print("---- SQL INPUT ----") -print(sql) -print("---- PROJECTION ----") -print(projection_expr) -print("---- FILTER ----") -print(filter_expr) \ No newline at end of file + **{ + fieldtype.lower(): proto_type( + nullability=proto.Type.Nullability.NULLABILITY_REQUIRED + ) + } + ) + ) + return proto.NamedStruct(names=names, struct=proto.Type.Struct(types=types)) + + +if __name__ == "__main__": + main() diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 1e43f55..d53f00f 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -27,9 +27,13 @@ def parse_sql_extended_expression(catalog, schema, sql): project_expressions = [] projection_invoked_functions = set() for sqlexpr in select.expressions: - invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(sqlexpr) + invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot( + sqlexpr + ) projection_invoked_functions.update(invoked_functions) - project_expressions.append(proto.ExpressionReference(expression=expr, output_names=[output_name])) + project_expressions.append( + proto.ExpressionReference(expression=expr, output_names=[output_name]) + ) extension_uris, extensions = catalog.extensions_for_functions( projection_invoked_functions ) @@ -73,16 +77,28 @@ def expression_from_sqlglot(self, sqlglot_node): def _parse_expression(self, expr, invoked_functions): if isinstance(expr, sqlglot.expressions.Literal): if expr.is_string: - return f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( - literal=proto.Expression.Literal(string=expr.text) + return ( + f"literal_{next(self._counter)}", + proto.Type(string=proto.Type.String()), + proto.Expression( + literal=proto.Expression.Literal(string=expr.text) + ), ) elif expr.is_int: - return f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( - literal=proto.Expression.Literal(i32=int(expr.name)) + return ( + f"literal_{next(self._counter)}", + proto.Type(i32=proto.Type.I32()), + proto.Expression( + literal=proto.Expression.Literal(i32=int(expr.name)) + ), ) elif sqlglot.helper.is_float(expr.name): - return f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( - literal=proto.Expression.Literal(float=float(expr.name)) + return ( + f"literal_{next(self._counter)}", + proto.Type(fp32=proto.Type.FP32()), + proto.Expression( + literal=proto.Expression.Literal(float=float(expr.name)) + ), ) else: raise ValueError(f"Unsupporter literal: {expr.text}") @@ -90,17 +106,23 @@ def _parse_expression(self, expr, invoked_functions): column_name = expr.output_name schema_field = list(self._schema.names).index(column_name) schema_type = self._schema.struct.types[schema_field] - return column_name, schema_type, proto.Expression( - selection=proto.Expression.FieldReference( - direct_reference=proto.Expression.ReferenceSegment( - struct_field=proto.Expression.ReferenceSegment.StructField( - field=schema_field + return ( + column_name, + schema_type, + proto.Expression( + selection=proto.Expression.FieldReference( + direct_reference=proto.Expression.ReferenceSegment( + struct_field=proto.Expression.ReferenceSegment.StructField( + field=schema_field + ) ) ) - ) + ), ) elif isinstance(expr, sqlglot.expressions.Alias): - _, aliased_type, aliased_expr = self._parse_expression(expr.this, invoked_functions) + _, aliased_type, aliased_expr = self._parse_expression( + expr.this, invoked_functions + ) return expr.output_name, aliased_type, aliased_expr elif expr.key in SQL_BINARY_FUNCTIONS: left_name, left_type, left = self._parse_expression( @@ -110,18 +132,24 @@ def _parse_expression(self, expr, invoked_functions): expr.right, invoked_functions ) function_name = SQL_BINARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = self._parse_function_invokation( - function_name, left_type, left, right_type, right + signature, result_type, function_expression = ( + self._parse_function_invokation( + function_name, left_type, left, right_type, right + ) ) invoked_functions.add(signature) - result_name = f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + result_name = ( + f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + ) return result_name, result_type, function_expression else: raise ValueError( f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" ) - def _parse_function_invokation(self, function_name, left_type, left, right_type, right): + def _parse_function_invokation( + self, function_name, left_type, left, right_type, right + ): signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" try: function_anchor = self._functions_catalog.function_anchor(signature) @@ -143,4 +171,3 @@ def _parse_function_invokation(self, function_name, left_type, left, right_type, ) ), ) - diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 8f7286d..e591523 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -7,7 +7,7 @@ class FunctionsCatalog: """Catalog of Substrait functions and extensions. - + Loads extensions from YAML files and records the declared functions. Given a set of functions it can generate the necessary extension URIs and extensions to be included in an ExtendedExpression or Plan. @@ -49,6 +49,7 @@ def load(self, dirpath, filename): for function in functions: function_name = function["name"] for impl in function.get("impls", []): + # TODO: There seem to be some functions that have arguments without type. What to do? argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] if not argtypes: signature = function_name @@ -118,4 +119,4 @@ def extensions_for_functions(self, functions): uris_by_anchor = self.extension_uris_by_anchor extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] - return extension_uris, extensions \ No newline at end of file + return extension_uris, extensions From 48c1d82fad78f957236059a97c48a213020834d6 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 15:15:53 +0200 Subject: [PATCH 06/18] Register builtin functions and handle return types --- src/substrait/sql/extended_expression.py | 62 +++++++++++++++--- src/substrait/sql/functions_catalog.py | 83 +++++++++++++++++++++--- 2 files changed, 126 insertions(+), 19 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index d53f00f..0e82abd 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -4,15 +4,29 @@ from substrait import proto - +SQL_UNARY_FUNCTIONS = {"not": "not"} SQL_BINARY_FUNCTIONS = { # Arithmetic "add": "add", "div": "div", "mul": "mul", "sub": "sub", + "mod": "modulus", + "bitwiseand": "bitwise_and", + "bitwiseor": "bitwise_or", + "bitwisexor": "bitwise_xor", + "bitwiseor": "bitwise_or", # Comparisons "eq": "equal", + "nullsafeeq": "is_not_distinct_from", + "new": "not_equal", + "gt": "gt", + "gte": "gte", + "lt": "lt", + "lte": "lte", + # logical + "and": "and", + "or": "or", } @@ -124,6 +138,17 @@ def _parse_expression(self, expr, invoked_functions): expr.this, invoked_functions ) return expr.output_name, aliased_type, aliased_expr + elif expr.key in SQL_UNARY_FUNCTIONS: + argument_name, argument_type, argument = self._parse_expression( + expr.this, invoked_functions + ) + function_name = SQL_UNARY_FUNCTIONS[expr.key] + signature, result_type, function_expression = ( + self._parse_function_invokation(function_name, argument_type, argument) + ) + invoked_functions.add(signature) + result_name = f"{function_name}_{argument_name}_{next(self._counter)}" + return result_name, result_type, function_expression elif expr.key in SQL_BINARY_FUNCTIONS: left_name, left_type, left = self._parse_expression( expr.left, invoked_functions @@ -148,26 +173,45 @@ def _parse_expression(self, expr, invoked_functions): ) def _parse_function_invokation( - self, function_name, left_type, left, right_type, right + self, function_name, left_type, left, right_type=None, right=None ): - signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}" + binary = False + argtypes = [left_type] + if right_type or right: + binary = True + argtypes.append(right_type) + signature = self._functions_catalog.signature(function_name, argtypes) + try: function_anchor = self._functions_catalog.function_anchor(signature) except KeyError: # No function found with the exact types, try any1_any1 version # TODO: What about cases like i32_any1? What about any instead of any1? - signature = f"{function_name}:any1_any1" + if binary: + signature = f"{function_name}:any1_any1" + else: + signature = f"{function_name}:any1" function_anchor = self._functions_catalog.function_anchor(signature) + + function_return_type = self._functions_catalog.function_return_type(signature) + if function_return_type is None: + print("No return type for", signature) + # TODO: Is this the right way to handle this? + function_return_type = left_type return ( signature, - left_type, # TODO: Get the actually returned type from the functions catalog. + function_return_type, proto.Expression( scalar_function=proto.Expression.ScalarFunction( function_reference=function_anchor, - arguments=[ - proto.FunctionArgument(value=left), - proto.FunctionArgument(value=right), - ], + arguments=( + [ + proto.FunctionArgument(value=left), + proto.FunctionArgument(value=right), + ] + if binary + else [proto.FunctionArgument(value=left)] + ), ) ), ) diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index e591523..4bd214d 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -23,7 +23,7 @@ class FunctionsCatalog: "/functions_arithmetic_decimal.yaml", "/functions_boolean.yaml", "/functions_comparison.yaml", - "/functions_datetime.yaml", + # "/functions_datetime.yaml", for now skip, it has duplicated functions "/functions_geometry.yaml", "/functions_logarithmic.yaml", "/functions_rounding.yaml", @@ -32,9 +32,10 @@ class FunctionsCatalog: ) def __init__(self): - self._declarations = {} self._registered_extensions = {} self._functions = {} + self._functions_return_type = {} + self._register_builtins() def load_standard_extensions(self, dirpath): for ext in self.STANDARD_EXTENSIONS: @@ -45,6 +46,7 @@ def load(self, dirpath, filename): sections = yaml.safe_load(f) loaded_functions = set() + functions_return_type = {} for functions in sections.values(): for function in functions: function_name = function["name"] @@ -55,12 +57,16 @@ def load(self, dirpath, filename): signature = function_name else: signature = f"{function_name}:{'_'.join(argtypes)}" - self._declarations[signature] = filename loaded_functions.add(signature) + functions_return_type[signature] = self._type_from_name( + impl["return"] + ) - self._register_extensions(filename, loaded_functions) + self._register_extensions(filename, loaded_functions, functions_return_type) - def _register_extensions(self, extension_uri, loaded_functions): + def _register_extensions( + self, extension_uri, loaded_functions, functions_return_type + ): if extension_uri not in self._registered_extensions: ext_anchor_id = len(self._registered_extensions) + 1 self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( @@ -70,14 +76,12 @@ def _register_extensions(self, extension_uri, loaded_functions): for function in loaded_functions: if function in self._functions: extensions_by_anchor = self.extension_uris_by_anchor - function = self._functions[function] + existing_function = self._functions[function] function_extension = extensions_by_anchor[ - function.extension_uri_reference + existing_function.extension_uri_reference ].uri - # TODO: Support overloading of functions from different extensionUris. - continue raise ValueError( - f"Duplicate function definition: {function.name} from {extension_uri}, already loaded from {function_extension}" + f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}" ) extension_anchor = self._registered_extensions[ extension_uri @@ -90,6 +94,48 @@ def _register_extensions(self, extension_uri, loaded_functions): function_anchor=function_anchor, ) ) + self._functions_return_type[function] = functions_return_type[function] + + def _register_builtins(self): + self._functions["not:boolean"] = ( + proto.SimpleExtensionDeclaration.ExtensionFunction( + name="not", + function_anchor=len(self._functions) + 1, + ) + ) + self._functions_return_type["not:boolean"] = proto.Type( + bool=proto.Type.Boolean() + ) + + def _type_from_name(self, typename): + nullable = False + if typename.endswith("?"): + nullable = True + + typename = typename.strip("?") + if typename in ("any", "any1"): + return None + + if typename == "boolean": + # For some reason boolean is an exception to the naming convention + typename = "bool" + + try: + type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[ + typename + ].message_type + except KeyError: + # TODO: improve resolution of complext type like LIST? + print("Unsupported type", typename) + return None + + type_class = getattr(proto.Type, type_descriptor.name) + nullability = ( + proto.Type.Nullability.NULLABILITY_REQUIRED + if not nullable + else proto.Type.Nullability.NULLABILITY_NULLABLE + ) + return proto.Type(**{typename: type_class(nullability=nullability)}) @property def extension_uris_by_anchor(self): @@ -106,14 +152,31 @@ def extension_uris(self): def extensions(self): return list(self._functions.values()) + def signature(self, function_name, proto_argtypes): + def _normalize_arg_types(argtypes): + for argtype in argtypes: + kind = argtype.WhichOneof("kind") + if kind == "bool": + yield "boolean" + else: + yield kind + + return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}" + def function_anchor(self, function): return self._functions[function].function_anchor + def function_return_type(self, function): + return self._functions_return_type[function] + def extensions_for_functions(self, functions): uris_anchors = set() extensions = [] for f in functions: ext = self._functions[f] + if not ext.extension_uri_reference: + # Built-in function + continue uris_anchors.add(ext.extension_uri_reference) extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) From 6c930c804a1115fe95b10b6e90ff0730989066ed Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 15:18:59 +0200 Subject: [PATCH 07/18] Fix typo --- src/substrait/sql/extended_expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 0e82abd..7d04eeb 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -19,7 +19,7 @@ # Comparisons "eq": "equal", "nullsafeeq": "is_not_distinct_from", - "new": "not_equal", + "neq": "not_equal", "gt": "gt", "gte": "gte", "lt": "lt", From 03d4a2e342e27f398a49da4c393ecc440c6337c6 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 17:00:15 +0200 Subject: [PATCH 08/18] Fix loading of boolean functions --- src/substrait/sql/functions_catalog.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 4bd214d..5f6b0bb 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -35,7 +35,6 @@ def __init__(self): self._registered_extensions = {} self._functions = {} self._functions_return_type = {} - self._register_builtins() def load_standard_extensions(self, dirpath): for ext in self.STANDARD_EXTENSIONS: @@ -52,7 +51,8 @@ def load(self, dirpath, filename): function_name = function["name"] for impl in function.get("impls", []): # TODO: There seem to be some functions that have arguments without type. What to do? - argtypes = [t.get("value", "unknown") for t in impl.get("args", [])] + # TODO: improve support complext type like LIST? + argtypes = [t.get("value", "unknown").strip("?") for t in impl.get("args", [])] if not argtypes: signature = function_name else: @@ -96,17 +96,6 @@ def _register_extensions( ) self._functions_return_type[function] = functions_return_type[function] - def _register_builtins(self): - self._functions["not:boolean"] = ( - proto.SimpleExtensionDeclaration.ExtensionFunction( - name="not", - function_anchor=len(self._functions) + 1, - ) - ) - self._functions_return_type["not:boolean"] = proto.Type( - bool=proto.Type.Boolean() - ) - def _type_from_name(self, typename): nullable = False if typename.endswith("?"): From 2b4b1bdc26b759c1789a402a374817780e7a7667 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 17:23:05 +0200 Subject: [PATCH 09/18] Refactor passing around info about parsed expressions --- src/substrait/sql/extended_expression.py | 113 +++++++++++++++-------- src/substrait/sql/functions_catalog.py | 5 +- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 7d04eeb..89965f0 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -31,6 +31,10 @@ def parse_sql_extended_expression(catalog, schema, sql): + """Parse a SQL SELECT statement into an ExtendedExpression. + + Only supports SELECT statements with projections and WHERE clauses. + """ select = sqlglot.parse_one(sql) if not isinstance(select, sqlglot.expressions.Select): raise ValueError("a SELECT statement was expected") @@ -41,12 +45,13 @@ def parse_sql_extended_expression(catalog, schema, sql): project_expressions = [] projection_invoked_functions = set() for sqlexpr in select.expressions: - invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot( - sqlexpr - ) - projection_invoked_functions.update(invoked_functions) + parsed_expr = sqlglot_parser.expression_from_sqlglot(sqlexpr) + projection_invoked_functions.update(parsed_expr.invoked_functions) project_expressions.append( - proto.ExpressionReference(expression=expr, output_names=[output_name]) + proto.ExpressionReference( + expression=parsed_expr.expression, + output_names=[parsed_expr.output_name], + ) ) extension_uris, extensions = catalog.extensions_for_functions( projection_invoked_functions @@ -59,17 +64,19 @@ def parse_sql_extended_expression(catalog, schema, sql): ) # Handle WHERE clause in the SELECT statement. - invoked_functions_filter, _, filter_expr = sqlglot_parser.expression_from_sqlglot( + filter_parsed_expr = sqlglot_parser.expression_from_sqlglot( select.find(sqlglot.expressions.Where).this ) extension_uris, extensions = catalog.extensions_for_functions( - invoked_functions_filter + filter_parsed_expr.invoked_functions ) filter_extended_expr = proto.ExtendedExpression( extension_uris=extension_uris, extensions=extensions, base_schema=schema, - referred_expr=[proto.ExpressionReference(expression=filter_expr)], + referred_expr=[ + proto.ExpressionReference(expression=filter_parsed_expr.expression) + ], ) return projection_extended_expr, filter_extended_expr @@ -82,16 +89,13 @@ def __init__(self, functions_catalog, schema): self._counter = itertools.count() def expression_from_sqlglot(self, sqlglot_node): - invoked_functions = set() - output_name, _, substrait_expr = self._parse_expression( - sqlglot_node, invoked_functions - ) - return invoked_functions, output_name, substrait_expr + """Parse a SQLGlot expression into a Substrait Expression.""" + return self._parse_expression(sqlglot_node) - def _parse_expression(self, expr, invoked_functions): + def _parse_expression(self, expr): if isinstance(expr, sqlglot.expressions.Literal): if expr.is_string: - return ( + return ParsedSubstraitExpression( f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( @@ -99,7 +103,7 @@ def _parse_expression(self, expr, invoked_functions): ), ) elif expr.is_int: - return ( + return ParsedSubstraitExpression( f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( @@ -107,7 +111,7 @@ def _parse_expression(self, expr, invoked_functions): ), ) elif sqlglot.helper.is_float(expr.name): - return ( + return ParsedSubstraitExpression( f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( @@ -120,7 +124,7 @@ def _parse_expression(self, expr, invoked_functions): column_name = expr.output_name schema_field = list(self._schema.names).index(column_name) schema_type = self._schema.struct.types[schema_field] - return ( + return ParsedSubstraitExpression( column_name, schema_type, proto.Expression( @@ -134,39 +138,47 @@ def _parse_expression(self, expr, invoked_functions): ), ) elif isinstance(expr, sqlglot.expressions.Alias): - _, aliased_type, aliased_expr = self._parse_expression( - expr.this, invoked_functions - ) - return expr.output_name, aliased_type, aliased_expr + parsed_expression = self._parse_expression(expr.this) + return parsed_expression.duplicate(output_name=expr.output_name) elif expr.key in SQL_UNARY_FUNCTIONS: - argument_name, argument_type, argument = self._parse_expression( - expr.this, invoked_functions - ) + argument_parsed_expr = self._parse_expression(expr.this) function_name = SQL_UNARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( - self._parse_function_invokation(function_name, argument_type, argument) - ) - invoked_functions.add(signature) - result_name = f"{function_name}_{argument_name}_{next(self._counter)}" - return result_name, result_type, function_expression - elif expr.key in SQL_BINARY_FUNCTIONS: - left_name, left_type, left = self._parse_expression( - expr.left, invoked_functions + self._parse_function_invokation( + function_name, + argument_parsed_expr.type, + argument_parsed_expr.expression, + ) ) - right_name, right_type, right = self._parse_expression( - expr.right, invoked_functions + result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + argument_parsed_expr.invoked_functions | {signature}, ) + elif expr.key in SQL_BINARY_FUNCTIONS: + left_parsed_expr = self._parse_expression(expr.left) + right_parsed_expr = self._parse_expression(expr.right) function_name = SQL_BINARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( self._parse_function_invokation( - function_name, left_type, left, right_type, right + function_name, + left_parsed_expr.type, + left_parsed_expr.expression, + right_parsed_expr.type, + right_parsed_expr.expression, ) ) - invoked_functions.add(signature) - result_name = ( - f"{left_name}_{function_name}_{right_name}_{next(self._counter)}" + result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + left_parsed_expr.invoked_functions + | right_parsed_expr.invoked_functions + | {signature}, ) - return result_name, result_type, function_expression else: raise ValueError( f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" @@ -215,3 +227,24 @@ def _parse_function_invokation( ) ), ) + + +class ParsedSubstraitExpression: + def __init__(self, output_name, type, expression, invoked_functions=None): + self.expression = expression + self.output_name = output_name + self.type = type + + if invoked_functions is None: + invoked_functions = set() + self.invoked_functions = invoked_functions + + def duplicate( + self, output_name=None, type=None, expression=None, invoked_functions=None + ): + return ParsedSubstraitExpression( + output_name or self.output_name, + type or self.type, + expression or self.expression, + invoked_functions or self.invoked_functions, + ) diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 5f6b0bb..8d72871 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -52,7 +52,10 @@ def load(self, dirpath, filename): for impl in function.get("impls", []): # TODO: There seem to be some functions that have arguments without type. What to do? # TODO: improve support complext type like LIST? - argtypes = [t.get("value", "unknown").strip("?") for t in impl.get("args", [])] + argtypes = [ + t.get("value", "unknown").strip("?") + for t in impl.get("args", []) + ] if not argtypes: signature = function_name else: From ddad5e0b0f60519957379edc1b00b8b628942be6 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 18:10:28 +0200 Subject: [PATCH 10/18] Refactor function invocation generation --- src/substrait/sql/extended_expression.py | 70 +++++++++++++----------- 1 file changed, 38 insertions(+), 32 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 89965f0..57b8eb3 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -93,10 +93,16 @@ def expression_from_sqlglot(self, sqlglot_node): return self._parse_expression(sqlglot_node) def _parse_expression(self, expr): + """Parse a SQLGlot node and return a Substrait expression. + + This is the internal implementation, expected to be + invoked in a recursive manner to parse the whole + expression tree. + """ if isinstance(expr, sqlglot.expressions.Literal): if expr.is_string: return ParsedSubstraitExpression( - f"literal_{next(self._counter)}", + f"literal${next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression( literal=proto.Expression.Literal(string=expr.text) @@ -104,7 +110,7 @@ def _parse_expression(self, expr): ) elif expr.is_int: return ParsedSubstraitExpression( - f"literal_{next(self._counter)}", + f"literal${next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression( literal=proto.Expression.Literal(i32=int(expr.name)) @@ -112,7 +118,7 @@ def _parse_expression(self, expr): ) elif sqlglot.helper.is_float(expr.name): return ParsedSubstraitExpression( - f"literal_{next(self._counter)}", + f"literal${next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression( literal=proto.Expression.Literal(float=float(expr.name)) @@ -144,11 +150,7 @@ def _parse_expression(self, expr): argument_parsed_expr = self._parse_expression(expr.this) function_name = SQL_UNARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( - self._parse_function_invokation( - function_name, - argument_parsed_expr.type, - argument_parsed_expr.expression, - ) + self._parse_function_invokation(function_name, argument_parsed_expr) ) result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" return ParsedSubstraitExpression( @@ -163,11 +165,7 @@ def _parse_expression(self, expr): function_name = SQL_BINARY_FUNCTIONS[expr.key] signature, result_type, function_expression = ( self._parse_function_invokation( - function_name, - left_parsed_expr.type, - left_parsed_expr.expression, - right_parsed_expr.type, - right_parsed_expr.expression, + function_name, left_parsed_expr, right_parsed_expr ) ) result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" @@ -185,24 +183,27 @@ def _parse_expression(self, expr): ) def _parse_function_invokation( - self, function_name, left_type, left, right_type=None, right=None + self, function_name, argument_parsed_expr, *additional_arguments ): - binary = False - argtypes = [left_type] - if right_type or right: - binary = True - argtypes.append(right_type) - signature = self._functions_catalog.signature(function_name, argtypes) + """Generates a Substrait function invokation expression. + + The function invocation will be generated from the function name + and the arguments as ParsedSubstraitExpression. + + Returns the function signature, the return type and the + invokation expression itself. + """ + arguments = [argument_parsed_expr] + list(additional_arguments) + signature = self._functions_catalog.signature( + function_name, proto_argtypes=[arg.type for arg in arguments] + ) try: function_anchor = self._functions_catalog.function_anchor(signature) except KeyError: # No function found with the exact types, try any1_any1 version # TODO: What about cases like i32_any1? What about any instead of any1? - if binary: - signature = f"{function_name}:any1_any1" - else: - signature = f"{function_name}:any1" + signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}" function_anchor = self._functions_catalog.function_anchor(signature) function_return_type = self._functions_catalog.function_return_type(signature) @@ -216,20 +217,25 @@ def _parse_function_invokation( proto.Expression( scalar_function=proto.Expression.ScalarFunction( function_reference=function_anchor, - arguments=( - [ - proto.FunctionArgument(value=left), - proto.FunctionArgument(value=right), - ] - if binary - else [proto.FunctionArgument(value=left)] - ), + arguments=[ + proto.FunctionArgument(value=arg.expression) + for arg in arguments + ], ) ), ) class ParsedSubstraitExpression: + """A Substrait expression that was parsed from a SQLGlot node. + + This stores the expression itself, with an associated output name + in case it is required to emit projections. + + It also stores the type of the expression (i64, string, boolean, etc...) + and the functions that the expression in going to invoke. + """ + def __init__(self, output_name, type, expression, invoked_functions=None): self.expression = expression self.output_name = output_name From 8a91012a681b237551a04fa61e3382191b411eda Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 18:11:41 +0200 Subject: [PATCH 11/18] record TODO --- src/substrait/sql/extended_expression.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 57b8eb3..f3b428e 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -203,6 +203,7 @@ def _parse_function_invokation( except KeyError: # No function found with the exact types, try any1_any1 version # TODO: What about cases like i32_any1? What about any instead of any1? + # TODO: What about optional arguments? IE: "i32_i32?" signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}" function_anchor = self._functions_catalog.function_anchor(signature) From 7688fc56c4837b8e2be046076f9c0b4bb04f23fe Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Thu, 18 Apr 2024 19:20:59 +0200 Subject: [PATCH 12/18] Dynamic dispatch of parsing --- src/substrait/sql/extended_expression.py | 156 ++++++++++++----------- src/substrait/sql/utils.py | 16 +++ 2 files changed, 100 insertions(+), 72 deletions(-) create mode 100644 src/substrait/sql/utils.py diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index f3b428e..74ec3c4 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -3,6 +3,8 @@ import sqlglot from substrait import proto +from .utils import DispatchRegistry + SQL_UNARY_FUNCTIONS = {"not": "not"} SQL_BINARY_FUNCTIONS = { @@ -83,6 +85,8 @@ def parse_sql_extended_expression(catalog, schema, sql): class SQLGlotParser: + DISPATCH_REGISTRY = DispatchRegistry() + def __init__(self, functions_catalog, schema): self._functions_catalog = functions_catalog self._schema = schema @@ -99,88 +103,96 @@ def _parse_expression(self, expr): invoked in a recursive manner to parse the whole expression tree. """ - if isinstance(expr, sqlglot.expressions.Literal): - if expr.is_string: - return ParsedSubstraitExpression( - f"literal${next(self._counter)}", - proto.Type(string=proto.Type.String()), - proto.Expression( - literal=proto.Expression.Literal(string=expr.text) - ), - ) - elif expr.is_int: - return ParsedSubstraitExpression( - f"literal${next(self._counter)}", - proto.Type(i32=proto.Type.I32()), - proto.Expression( - literal=proto.Expression.Literal(i32=int(expr.name)) - ), - ) - elif sqlglot.helper.is_float(expr.name): - return ParsedSubstraitExpression( - f"literal${next(self._counter)}", - proto.Type(fp32=proto.Type.FP32()), - proto.Expression( - literal=proto.Expression.Literal(float=float(expr.name)) - ), - ) - else: - raise ValueError(f"Unsupporter literal: {expr.text}") - elif isinstance(expr, sqlglot.expressions.Column): - column_name = expr.output_name - schema_field = list(self._schema.names).index(column_name) - schema_type = self._schema.struct.types[schema_field] + expr_class = expr.__class__ + return self.DISPATCH_REGISTRY[expr_class](self, expr) + + @DISPATCH_REGISTRY.register(sqlglot.expressions.Literal) + def _parse_Literal(self, expr): + if expr.is_string: return ParsedSubstraitExpression( - column_name, - schema_type, + f"literal${next(self._counter)}", + proto.Type(string=proto.Type.String()), proto.Expression( - selection=proto.Expression.FieldReference( - direct_reference=proto.Expression.ReferenceSegment( - struct_field=proto.Expression.ReferenceSegment.StructField( - field=schema_field - ) - ) - ) + literal=proto.Expression.Literal(string=expr.text) ), ) - elif isinstance(expr, sqlglot.expressions.Alias): - parsed_expression = self._parse_expression(expr.this) - return parsed_expression.duplicate(output_name=expr.output_name) - elif expr.key in SQL_UNARY_FUNCTIONS: - argument_parsed_expr = self._parse_expression(expr.this) - function_name = SQL_UNARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = ( - self._parse_function_invokation(function_name, argument_parsed_expr) - ) - result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" + elif expr.is_int: return ParsedSubstraitExpression( - result_name, - result_type, - function_expression, - argument_parsed_expr.invoked_functions | {signature}, - ) - elif expr.key in SQL_BINARY_FUNCTIONS: - left_parsed_expr = self._parse_expression(expr.left) - right_parsed_expr = self._parse_expression(expr.right) - function_name = SQL_BINARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = ( - self._parse_function_invokation( - function_name, left_parsed_expr, right_parsed_expr - ) + f"literal${next(self._counter)}", + proto.Type(i32=proto.Type.I32()), + proto.Expression( + literal=proto.Expression.Literal(i32=int(expr.name)) + ), ) - result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" + elif sqlglot.helper.is_float(expr.name): return ParsedSubstraitExpression( - result_name, - result_type, - function_expression, - left_parsed_expr.invoked_functions - | right_parsed_expr.invoked_functions - | {signature}, + f"literal${next(self._counter)}", + proto.Type(fp32=proto.Type.FP32()), + proto.Expression( + literal=proto.Expression.Literal(float=float(expr.name)) + ), ) else: - raise ValueError( - f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}" + raise ValueError(f"Unsupporter literal: {expr.text}") + + @DISPATCH_REGISTRY.register(sqlglot.expressions.Column) + def _parse_Column(self, expr): + column_name = expr.output_name + schema_field = list(self._schema.names).index(column_name) + schema_type = self._schema.struct.types[schema_field] + return ParsedSubstraitExpression( + column_name, + schema_type, + proto.Expression( + selection=proto.Expression.FieldReference( + direct_reference=proto.Expression.ReferenceSegment( + struct_field=proto.Expression.ReferenceSegment.StructField( + field=schema_field + ) + ) + ) + ), + ) + + @DISPATCH_REGISTRY.register(sqlglot.expressions.Alias) + def _parse_Alias(self, expr): + parsed_expression = self._parse_expression(expr.this) + return parsed_expression.duplicate(output_name=expr.output_name) + + @DISPATCH_REGISTRY.register(sqlglot.expressions.Binary) + def _parser_Binary(self, expr): + left_parsed_expr = self._parse_expression(expr.left) + right_parsed_expr = self._parse_expression(expr.right) + function_name = SQL_BINARY_FUNCTIONS[expr.key] + signature, result_type, function_expression = ( + self._parse_function_invokation( + function_name, left_parsed_expr, right_parsed_expr ) + ) + result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + left_parsed_expr.invoked_functions + | right_parsed_expr.invoked_functions + | {signature}, + ) + + @DISPATCH_REGISTRY.register(sqlglot.expressions.Unary) + def _parse_Unary(self, expr): + argument_parsed_expr = self._parse_expression(expr.this) + function_name = SQL_UNARY_FUNCTIONS[expr.key] + signature, result_type, function_expression = ( + self._parse_function_invokation(function_name, argument_parsed_expr) + ) + result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + argument_parsed_expr.invoked_functions | {signature}, + ) def _parse_function_invokation( self, function_name, argument_parsed_expr, *additional_arguments diff --git a/src/substrait/sql/utils.py b/src/substrait/sql/utils.py new file mode 100644 index 0000000..13eeeeb --- /dev/null +++ b/src/substrait/sql/utils.py @@ -0,0 +1,16 @@ +class DispatchRegistry: + def __init__(self): + self._registry = {} + + def register(self, cls): + def decorator(func): + self._registry[cls] = func + return func + return decorator + + def __getitem__(self, cls): + for dispatch_cls, func in self._registry.items(): + if issubclass(cls, dispatch_cls): + return func + else: + raise ValueError(f"Unsupported SQL Node type: {cls}") \ No newline at end of file From 1a5fcc761db60e165bf4091663a63f4735417052 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Fri, 19 Apr 2024 16:26:43 +0200 Subject: [PATCH 13/18] Tweak dynamic dispatch and handle variadic and, or etc... --- src/substrait/sql/extended_expression.py | 36 ++++++++---------------- src/substrait/sql/functions_catalog.py | 5 ++++ src/substrait/sql/utils.py | 30 +++++++++++++++++--- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 74ec3c4..fe78e1a 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -92,37 +92,25 @@ def __init__(self, functions_catalog, schema): self._schema = schema self._counter = itertools.count() + self._parse_expression = self.DISPATCH_REGISTRY.bind(self) + def expression_from_sqlglot(self, sqlglot_node): """Parse a SQLGlot expression into a Substrait Expression.""" return self._parse_expression(sqlglot_node) - def _parse_expression(self, expr): - """Parse a SQLGlot node and return a Substrait expression. - - This is the internal implementation, expected to be - invoked in a recursive manner to parse the whole - expression tree. - """ - expr_class = expr.__class__ - return self.DISPATCH_REGISTRY[expr_class](self, expr) - @DISPATCH_REGISTRY.register(sqlglot.expressions.Literal) def _parse_Literal(self, expr): if expr.is_string: return ParsedSubstraitExpression( f"literal${next(self._counter)}", proto.Type(string=proto.Type.String()), - proto.Expression( - literal=proto.Expression.Literal(string=expr.text) - ), + proto.Expression(literal=proto.Expression.Literal(string=expr.text)), ) elif expr.is_int: return ParsedSubstraitExpression( f"literal${next(self._counter)}", proto.Type(i32=proto.Type.I32()), - proto.Expression( - literal=proto.Expression.Literal(i32=int(expr.name)) - ), + proto.Expression(literal=proto.Expression.Literal(i32=int(expr.name))), ) elif sqlglot.helper.is_float(expr.name): return ParsedSubstraitExpression( @@ -134,7 +122,7 @@ def _parse_Literal(self, expr): ) else: raise ValueError(f"Unsupporter literal: {expr.text}") - + @DISPATCH_REGISTRY.register(sqlglot.expressions.Column) def _parse_Column(self, expr): column_name = expr.output_name @@ -164,10 +152,8 @@ def _parser_Binary(self, expr): left_parsed_expr = self._parse_expression(expr.left) right_parsed_expr = self._parse_expression(expr.right) function_name = SQL_BINARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = ( - self._parse_function_invokation( - function_name, left_parsed_expr, right_parsed_expr - ) + signature, result_type, function_expression = self._parse_function_invokation( + function_name, left_parsed_expr, right_parsed_expr ) result_name = f"{function_name}_{left_parsed_expr.output_name}_{right_parsed_expr.output_name}_{next(self._counter)}" return ParsedSubstraitExpression( @@ -183,10 +169,12 @@ def _parser_Binary(self, expr): def _parse_Unary(self, expr): argument_parsed_expr = self._parse_expression(expr.this) function_name = SQL_UNARY_FUNCTIONS[expr.key] - signature, result_type, function_expression = ( - self._parse_function_invokation(function_name, argument_parsed_expr) + signature, result_type, function_expression = self._parse_function_invokation( + function_name, argument_parsed_expr + ) + result_name = ( + f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" ) - result_name = f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" return ParsedSubstraitExpression( result_name, result_type, diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 8d72871..0430e8e 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -56,11 +56,16 @@ def load(self, dirpath, filename): t.get("value", "unknown").strip("?") for t in impl.get("args", []) ] + if impl.get("variadic", False): + # TODO: Variadic functions. + argtypes *= 2 + if not argtypes: signature = function_name else: signature = f"{function_name}:{'_'.join(argtypes)}" loaded_functions.add(signature) + print("Loaded function", signature) functions_return_type[signature] = self._type_from_name( impl["return"] ) diff --git a/src/substrait/sql/utils.py b/src/substrait/sql/utils.py index 13eeeeb..c73a531 100644 --- a/src/substrait/sql/utils.py +++ b/src/substrait/sql/utils.py @@ -1,4 +1,19 @@ +import types + + class DispatchRegistry: + """Dispatch a function based on the class of the argument. + + This class allows to register a function to execute for a specific class + and expose this as a method of an object which will be dispatched + based on the argument. + + It is similar to functools.singledispatch but it allows more + customization in case the dispatch rules grow in complexity + and works for class methods as well + (singledispatch supports methods only in more recent versions) + """ + def __init__(self): self._registry = {} @@ -6,11 +21,18 @@ def register(self, cls): def decorator(func): self._registry[cls] = func return func + return decorator - - def __getitem__(self, cls): + + def bind(self, obj): + return types.MethodType(self, obj) + + def __getitem__(self, argument): for dispatch_cls, func in self._registry.items(): - if issubclass(cls, dispatch_cls): + if isinstance(argument, dispatch_cls): return func else: - raise ValueError(f"Unsupported SQL Node type: {cls}") \ No newline at end of file + raise ValueError(f"Unsupported SQL Node type: {cls}") + + def __call__(self, obj, dispatch_argument, *args, **kwargs): + return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs) From 3c9f6ebf4d03ed77451ffb115b2923a7a7f56c9d Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 8 May 2024 18:39:29 +0200 Subject: [PATCH 14/18] Refactor FunctionsCatalog and improve functions lookup --- src/substrait/sql/extended_expression.py | 29 +-- src/substrait/sql/functions_catalog.py | 257 ++++++++++++++++------- src/substrait/sql/utils.py | 4 +- 3 files changed, 194 insertions(+), 96 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index fe78e1a..06d621a 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -194,30 +194,20 @@ def _parse_function_invokation( invokation expression itself. """ arguments = [argument_parsed_expr] + list(additional_arguments) - signature = self._functions_catalog.signature( + signature = self._functions_catalog.make_signature( function_name, proto_argtypes=[arg.type for arg in arguments] ) - try: - function_anchor = self._functions_catalog.function_anchor(signature) - except KeyError: - # No function found with the exact types, try any1_any1 version - # TODO: What about cases like i32_any1? What about any instead of any1? - # TODO: What about optional arguments? IE: "i32_i32?" - signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}" - function_anchor = self._functions_catalog.function_anchor(signature) - - function_return_type = self._functions_catalog.function_return_type(signature) - if function_return_type is None: - print("No return type for", signature) - # TODO: Is this the right way to handle this? - function_return_type = left_type + registered_function = self._functions_catalog.lookup_function(signature) + if registered_function is None: + raise KeyError(f"Function not found: {signature}") + return ( - signature, - function_return_type, + registered_function.signature, + registered_function.return_type, proto.Expression( scalar_function=proto.Expression.ScalarFunction( - function_reference=function_anchor, + function_reference=registered_function.function_anchor, arguments=[ proto.FunctionArgument(value=arg.expression) for arg in arguments @@ -255,3 +245,6 @@ def duplicate( expression or self.expression, invoked_functions or self.invoked_functions, ) + + def __repr__(self): + return f"" diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index 0430e8e..d6b175d 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -1,8 +1,80 @@ +import os import pathlib +from collections.abc import Iterable import yaml -from substrait import proto +from substrait.gen.proto.type_pb2 import Type as SubstraitType +from substrait.gen.proto.extensions.extensions_pb2 import ( + SimpleExtensionURI, + SimpleExtensionDeclaration, +) + + +class RegisteredSubstraitFunction: + """A Substrait function loaded from an extension file. + + The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction + and will use them to generate the necessary extension URIs and extensions. + """ + + def __init__(self, signature: str, function_anchor: int | None, impl: dict): + self.signature = signature + self.function_anchor = function_anchor + self.variadic = impl.get("variadic", False) + + if "return" in impl: + self.return_type = self._type_from_name(impl["return"]) + else: + # We do always need a return type + # to know which type to propagate up to the invoker + _, argtypes = FunctionsCatalog.parse_signature(signature) + # TODO: Is this the right way to handle this? + self.return_type = self._type_from_name(argtypes[0]) + + @property + def name(self) -> str: + name, _ = FunctionsCatalog.parse_signature(self.signature) + return name + + @property + def arguments(self) -> list[str]: + _, argtypes = FunctionsCatalog.parse_signature(self.signature) + return argtypes + + @property + def arguments_type(self) -> list[SubstraitType | None]: + return [self._type_from_name(arg) for arg in self.arguments] + + def _type_from_name(self, typename: str) -> SubstraitType | None: + nullable = False + if typename.endswith("?"): + nullable = True + + typename = typename.strip("?") + if typename in ("any", "any1"): + return None + + if typename == "boolean": + # For some reason boolean is an exception to the naming convention + typename = "bool" + + try: + type_descriptor = SubstraitType.DESCRIPTOR.fields_by_name[ + typename + ].message_type + except KeyError: + # TODO: improve resolution of complext type like LIST? + print("Unsupported type", typename) + return None + + type_class = getattr(SubstraitType, type_descriptor.name) + nullability = ( + SubstraitType.Nullability.NULLABILITY_REQUIRED + if not nullable + else SubstraitType.Nullability.NULLABILITY_NULLABLE + ) + return SubstraitType(**{typename: type_class(nullability=nullability)}) class FunctionsCatalog: @@ -32,20 +104,21 @@ class FunctionsCatalog: ) def __init__(self): - self._registered_extensions = {} + self._substrait_extension_uris = {} + self._substrait_extension_functions = {} self._functions = {} - self._functions_return_type = {} - def load_standard_extensions(self, dirpath): + def load_standard_extensions(self, dirpath: str | os.PathLike): + """Load all standard substrait extensions from the target directory.""" for ext in self.STANDARD_EXTENSIONS: self.load(dirpath, ext) - def load(self, dirpath, filename): + def load(self, dirpath: str | os.PathLike, filename: str): + """Load an extension from a YAML file in a target directory.""" with open(pathlib.Path(dirpath) / filename.strip("/")) as f: sections = yaml.safe_load(f) - loaded_functions = set() - functions_return_type = {} + loaded_functions = {} for functions in sections.values(): for function in functions: function_name = function["name"] @@ -56,100 +129,80 @@ def load(self, dirpath, filename): t.get("value", "unknown").strip("?") for t in impl.get("args", []) ] - if impl.get("variadic", False): - # TODO: Variadic functions. - argtypes *= 2 - if not argtypes: signature = function_name else: signature = f"{function_name}:{'_'.join(argtypes)}" - loaded_functions.add(signature) - print("Loaded function", signature) - functions_return_type[signature] = self._type_from_name( - impl["return"] + loaded_functions[signature] = RegisteredSubstraitFunction( + signature, None, impl ) - self._register_extensions(filename, loaded_functions, functions_return_type) + self._register_extensions(filename, loaded_functions) def _register_extensions( - self, extension_uri, loaded_functions, functions_return_type + self, + extension_uri: str, + loaded_functions: dict[str, RegisteredSubstraitFunction], ): - if extension_uri not in self._registered_extensions: - ext_anchor_id = len(self._registered_extensions) + 1 - self._registered_extensions[extension_uri] = proto.SimpleExtensionURI( + if extension_uri not in self._substrait_extension_uris: + ext_anchor_id = len(self._substrait_extension_uris) + 1 + self._substrait_extension_uris[extension_uri] = SimpleExtensionURI( extension_uri_anchor=ext_anchor_id, uri=extension_uri ) - for function in loaded_functions: - if function in self._functions: + for signature, registered_function in loaded_functions.items(): + if signature in self._substrait_extension_functions: extensions_by_anchor = self.extension_uris_by_anchor - existing_function = self._functions[function] + existing_function = self._substrait_extension_functions[signature] function_extension = extensions_by_anchor[ existing_function.extension_uri_reference ].uri raise ValueError( f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}" ) - extension_anchor = self._registered_extensions[ + extension_anchor = self._substrait_extension_uris[ extension_uri ].extension_uri_anchor - function_anchor = len(self._functions) + 1 - self._functions[function] = ( - proto.SimpleExtensionDeclaration.ExtensionFunction( + function_anchor = len(self._substrait_extension_functions) + 1 + self._substrait_extension_functions[signature] = ( + SimpleExtensionDeclaration.ExtensionFunction( extension_uri_reference=extension_anchor, - name=function, + name=signature, function_anchor=function_anchor, ) ) - self._functions_return_type[function] = functions_return_type[function] - - def _type_from_name(self, typename): - nullable = False - if typename.endswith("?"): - nullable = True - - typename = typename.strip("?") - if typename in ("any", "any1"): - return None - - if typename == "boolean": - # For some reason boolean is an exception to the naming convention - typename = "bool" - - try: - type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[ - typename - ].message_type - except KeyError: - # TODO: improve resolution of complext type like LIST? - print("Unsupported type", typename) - return None - - type_class = getattr(proto.Type, type_descriptor.name) - nullability = ( - proto.Type.Nullability.NULLABILITY_REQUIRED - if not nullable - else proto.Type.Nullability.NULLABILITY_NULLABLE - ) - return proto.Type(**{typename: type_class(nullability=nullability)}) + registered_function.function_anchor = function_anchor + self._functions.setdefault(registered_function.name, []).append( + registered_function + ) @property - def extension_uris_by_anchor(self): + def extension_uris_by_anchor(self) -> dict[int, SimpleExtensionURI]: return { ext.extension_uri_anchor: ext - for ext in self._registered_extensions.values() + for ext in self._substrait_extension_uris.values() } @property - def extension_uris(self): - return list(self._registered_extensions.values()) + def extension_uris(self) -> list[SimpleExtensionURI]: + return list(self._substrait_extension_uris.values()) @property - def extensions(self): - return list(self._functions.values()) + def extensions_functions( + self, + ) -> list[SimpleExtensionDeclaration.ExtensionFunction]: + return list(self._substrait_extension_functions.values()) + + @classmethod + def make_signature( + cls, function_name: str, proto_argtypes: Iterable[SubstraitType] + ): + """Create a function signature from a function name and substrait types. + + The signature is generated according to Function Signature Compound Names + as described in the Substrait documentation. + """ - def signature(self, function_name, proto_argtypes): def _normalize_arg_types(argtypes): for argtype in argtypes: kind = argtype.WhichOneof("kind") @@ -160,23 +213,73 @@ def _normalize_arg_types(argtypes): return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}" - def function_anchor(self, function): - return self._functions[function].function_anchor + @classmethod + def parse_signature(cls, signature: str) -> tuple[str, list[str]]: + """Parse a function signature and returns name and type names""" + try: + function_name, signature_args = signature.split(":") + except ValueError: + function_name = signature + argtypes = [] + else: + argtypes = signature_args.split("_") + return function_name, argtypes - def function_return_type(self, function): - return self._functions_return_type[function] + def extensions_for_functions( + self, function_signatures: Iterable[str] + ) -> tuple[list[SimpleExtensionURI], list[SimpleExtensionDeclaration]]: + """Given a set of function signatures, return the necessary extensions. - def extensions_for_functions(self, functions): + The function will return the URIs of the extensions and the extension + that have to be declared in the plan to use the functions. + """ uris_anchors = set() extensions = [] - for f in functions: - ext = self._functions[f] - if not ext.extension_uri_reference: - # Built-in function - continue + for f in function_signatures: + ext = self._substrait_extension_functions[f] uris_anchors.add(ext.extension_uri_reference) - extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext)) + extensions.append(SimpleExtensionDeclaration(extension_function=ext)) uris_by_anchor = self.extension_uris_by_anchor extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors] return extension_uris, extensions + + def lookup_function(self, signature: str) -> RegisteredSubstraitFunction | None: + """Given the signature of a function invocation, return the matching function.""" + function_name, invocation_argtypes = self.parse_signature(signature) + + functions = self._functions.get(function_name) + if not functions: + # No function with such a name at all. + return None + + is_variadic = functions[0].variadic + if is_variadic: + # If it's variadic we care about only the first parameter. + invocation_argtypes = invocation_argtypes[:1] + + found_function = None + for function in functions: + accepted_function_arguments = function.arguments + for argidx, argtype in enumerate(invocation_argtypes): + try: + accepted_argument = accepted_function_arguments[argidx] + except IndexError: + # More arguments than available were provided + break + if accepted_argument != argtype and accepted_argument not in ( + "any", + "any1", + ): + break + else: + if argidx < len(accepted_function_arguments) - 1: + # Not enough arguments were provided + remainder = accepted_function_arguments[argidx + 1 :] + if all(arg.endswith("?") for arg in remainder): + # All remaining arguments are optional + found_function = function + else: + found_function = function + + return found_function diff --git a/src/substrait/sql/utils.py b/src/substrait/sql/utils.py index c73a531..9ffad36 100644 --- a/src/substrait/sql/utils.py +++ b/src/substrait/sql/utils.py @@ -32,7 +32,9 @@ def __getitem__(self, argument): if isinstance(argument, dispatch_cls): return func else: - raise ValueError(f"Unsupported SQL Node type: {cls}") + raise ValueError( + f"Unsupported SQL Node type: {argument.__class__.__name__} -> {argument}" + ) def __call__(self, obj, dispatch_argument, *args, **kwargs): return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs) From e9508a41241949a6d58feb381fafd482e3f460fd Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 8 May 2024 19:12:14 +0200 Subject: [PATCH 15/18] Migrate function resolution from keys to classes --- src/substrait/sql/extended_expression.py | 44 ++++++++++++------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 06d621a..44d62ae 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -6,29 +6,29 @@ from .utils import DispatchRegistry -SQL_UNARY_FUNCTIONS = {"not": "not"} -SQL_BINARY_FUNCTIONS = { +SQL_FUNCTIONS = { # Arithmetic - "add": "add", - "div": "div", - "mul": "mul", - "sub": "sub", - "mod": "modulus", - "bitwiseand": "bitwise_and", - "bitwiseor": "bitwise_or", - "bitwisexor": "bitwise_xor", - "bitwiseor": "bitwise_or", + sqlglot.expressions.Add: "add", + sqlglot.expressions.Div: "div", + sqlglot.expressions.Mul: "mul", + sqlglot.expressions.Sub: "sub", + sqlglot.expressions.Mod: "modulus", + sqlglot.expressions.BitwiseAnd: "bitwise_and", + sqlglot.expressions.BitwiseOr: "bitwise_or", + sqlglot.expressions.BitwiseXor: "bitwise_xor", + sqlglot.expressions.BitwiseNot: "bitwise_not", # Comparisons - "eq": "equal", - "nullsafeeq": "is_not_distinct_from", - "neq": "not_equal", - "gt": "gt", - "gte": "gte", - "lt": "lt", - "lte": "lte", + sqlglot.expressions.EQ: "equal", + sqlglot.expressions.NullSafeEQ: "is_not_distinct_from", + sqlglot.expressions.NEQ: "not_equal", + sqlglot.expressions.GT: "gt", + sqlglot.expressions.GTE: "gte", + sqlglot.expressions.LT: "lt", + sqlglot.expressions.LTE: "lte", # logical - "and": "and", - "or": "or", + sqlglot.expressions.And: "and", + sqlglot.expressions.Or: "or", + sqlglot.expressions.Not: "not", } @@ -151,7 +151,7 @@ def _parse_Alias(self, expr): def _parser_Binary(self, expr): left_parsed_expr = self._parse_expression(expr.left) right_parsed_expr = self._parse_expression(expr.right) - function_name = SQL_BINARY_FUNCTIONS[expr.key] + function_name = SQL_FUNCTIONS[type(expr)] signature, result_type, function_expression = self._parse_function_invokation( function_name, left_parsed_expr, right_parsed_expr ) @@ -168,7 +168,7 @@ def _parser_Binary(self, expr): @DISPATCH_REGISTRY.register(sqlglot.expressions.Unary) def _parse_Unary(self, expr): argument_parsed_expr = self._parse_expression(expr.this) - function_name = SQL_UNARY_FUNCTIONS[expr.key] + function_name = SQL_FUNCTIONS[type(expr)] signature, result_type, function_expression = self._parse_function_invokation( function_name, argument_parsed_expr ) From b90d55db02b0b499156b732676c4a2148303d23e Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 8 May 2024 19:20:21 +0200 Subject: [PATCH 16/18] Improve conversion of typenames to types --- src/substrait/sql/functions_catalog.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/substrait/sql/functions_catalog.py b/src/substrait/sql/functions_catalog.py index d6b175d..4e026b7 100644 --- a/src/substrait/sql/functions_catalog.py +++ b/src/substrait/sql/functions_catalog.py @@ -47,6 +47,10 @@ def arguments_type(self) -> list[SubstraitType | None]: return [self._type_from_name(arg) for arg in self.arguments] def _type_from_name(self, typename: str) -> SubstraitType | None: + # TODO: improve support complext type like LIST? + typename, *_ = typename.split("<", 1) + typename = typename.lower() + nullable = False if typename.endswith("?"): nullable = True From a112ff23d99c37a8a4b09bde49c3171113580168 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 8 May 2024 19:48:03 +0200 Subject: [PATCH 17/18] Handle IS NULL, IS NOT NULL and NaN --- src/substrait/sql/extended_expression.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 44d62ae..55c9663 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -25,6 +25,7 @@ sqlglot.expressions.GTE: "gte", sqlglot.expressions.LT: "lt", sqlglot.expressions.LTE: "lte", + sqlglot.expressions.IsNan: "is_nan", # logical sqlglot.expressions.And: "and", sqlglot.expressions.Or: "or", @@ -147,6 +148,28 @@ def _parse_Alias(self, expr): parsed_expression = self._parse_expression(expr.this) return parsed_expression.duplicate(output_name=expr.output_name) + @DISPATCH_REGISTRY.register(sqlglot.expressions.Is) + def _parse_IS(self, expr): + # IS NULL is a special case because in SQLGlot is a binary expression with argument + # while in Substrait there are only the is_null and is_not_null unary functions + argument_parsed_expr = self._parse_expression(expr.left) + if isinstance(expr.right, sqlglot.expressions.Null): + function_name = "is_null" + else: + raise ValueError(f"Unsupported IS expression: {expr}") + signature, result_type, function_expression = self._parse_function_invokation( + function_name, argument_parsed_expr + ) + result_name = ( + f"{function_name}_{argument_parsed_expr.output_name}_{next(self._counter)}" + ) + return ParsedSubstraitExpression( + result_name, + result_type, + function_expression, + argument_parsed_expr.invoked_functions | {signature}, + ) + @DISPATCH_REGISTRY.register(sqlglot.expressions.Binary) def _parser_Binary(self, expr): left_parsed_expr = self._parse_expression(expr.left) From 5d999be12fddd69da715d2101bcaaa3400b5e8f0 Mon Sep 17 00:00:00 2001 From: Alessandro Molina Date: Wed, 15 May 2024 16:27:22 +0200 Subject: [PATCH 18/18] Fix string literals --- src/substrait/sql/__init__.py | 1 + src/substrait/sql/extended_expression.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/substrait/sql/__init__.py b/src/substrait/sql/__init__.py index 4dea9c5..9dbd56a 100644 --- a/src/substrait/sql/__init__.py +++ b/src/substrait/sql/__init__.py @@ -1 +1,2 @@ from .extended_expression import parse_sql_extended_expression +from .functions_catalog import FunctionsCatalog diff --git a/src/substrait/sql/extended_expression.py b/src/substrait/sql/extended_expression.py index 55c9663..7d41b3c 100644 --- a/src/substrait/sql/extended_expression.py +++ b/src/substrait/sql/extended_expression.py @@ -105,7 +105,7 @@ def _parse_Literal(self, expr): return ParsedSubstraitExpression( f"literal${next(self._counter)}", proto.Type(string=proto.Type.String()), - proto.Expression(literal=proto.Expression.Literal(string=expr.text)), + proto.Expression(literal=proto.Expression.Literal(string=expr.name)), ) elif expr.is_int: return ParsedSubstraitExpression(