From 55dba9f93c94232e2bac0e8121c04af35e26a57a Mon Sep 17 00:00:00 2001 From: rubin <86082354+rubensantoniorosa2704@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:58:51 -0300 Subject: [PATCH 1/2] feat: improve MySQL mathematical expression type inference --- internal/compiler/output_columns.go | 101 +++++++++++++++++++++++++++- internal/sql/lang/operator.go | 1 + 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index b0a15e6ac4..6621bbf16e 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -155,7 +155,9 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er // TODO: Generate a name for these operations cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) case lang.IsMathematicalOperator(op): - cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + // Improve type inference for mathematical expressions + dataType, notNull := c.inferMathExpressionType(n, tables, op) + cols = append(cols, &Column{Name: name, DataType: dataType, NotNull: notNull}) default: cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false}) } @@ -770,3 +772,100 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) return nil } + +// inferMathExpressionType attempts to infer the data type of a mathematical expression +// by analyzing its operands and the operation being performed +func (c *Compiler) inferMathExpressionType(expr *ast.A_Expr, tables []*Table, op string) (string, bool) { + // Try to infer types from left and right operands + leftType := c.inferOperandType(expr.Lexpr, tables) + rightType := c.inferOperandType(expr.Rexpr, tables) + + // Debug logging to understand what's happening + // fmt.Printf("DEBUG: Math expression %s: left=%s, right=%s\n", op, leftType, rightType) + + // Determine the result type based on operands and operation + resultType := c.combineTypes(leftType, rightType, op) + + // For now, assume nullable since we're dealing with database columns + // In a more sophisticated implementation, we could track nullability through the expression + notNull := false + + return resultType, notNull +} + +// inferOperandType tries to determine the type of an operand in an expression +func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) string { + switch n := operand.(type) { + case *ast.ColumnRef: + // Look up the column in the available tables + parts := stringSlice(n.Fields) + var name string + if len(parts) >= 1 { + name = parts[len(parts)-1] // Get the column name (last part) + } + + for _, table := range tables { + for _, col := range table.Columns { + if col.Name == name { + return col.DataType + } + } + } + return "any" + case *ast.A_Const: + // Determine type based on constant value + switch n.Val.(type) { + case *ast.Integer: + return "int" + case *ast.Float: + return "float" + case *ast.String: + return "text" + default: + return "any" + } + case *ast.A_Expr: + // Recursive case for nested expressions + if n.Name != nil { + nestedOp := astutils.Join(n.Name, "") + if lang.IsMathematicalOperator(nestedOp) { + resultType, _ := c.inferMathExpressionType(n, tables, nestedOp) + return resultType + } + } + return "any" + default: + return "any" + } +} + +// combineTypes determines the result type when combining two operand types with an operation +func (c *Compiler) combineTypes(leftType, rightType, op string) string { + // Handle division specially - division operations typically result in float + if op == "/" { + // If either operand is float, result is float + if leftType == "float" || rightType == "float" { + return "float" + } + // Even integer division might want to be float in many cases + // For safety, return float for division unless both operands are clearly non-numeric + if leftType != "text" && rightType != "text" { + return "float" + } + } + + // For other mathematical operations + switch { + case leftType == "float" || rightType == "float": + return "float" + case leftType == "int" && rightType == "int": + return "int" + case leftType == "int" && rightType == "any": + return "int" + case leftType == "any" && rightType == "int": + return "int" + default: + // Default fallback + return "any" + } +} diff --git a/internal/sql/lang/operator.go b/internal/sql/lang/operator.go index cd5ef50e38..7d2556c855 100644 --- a/internal/sql/lang/operator.go +++ b/internal/sql/lang/operator.go @@ -23,6 +23,7 @@ func IsMathematicalOperator(s string) bool { case "-": case "*": case "/": + case "div": // MySQL division operator case "%": case "^": case "|/": From 9a4960f36370f1a215365649b2917eac419af4c1 Mon Sep 17 00:00:00 2001 From: rubin <86082354+rubensantoniorosa2704@users.noreply.github.com> Date: Thu, 30 Oct 2025 21:34:08 -0300 Subject: [PATCH 2/2] fix: resolve nullability issues and improve type combinations --- internal/compiler/output_columns.go | 154 +++++++++++++++++++--------- 1 file changed, 104 insertions(+), 50 deletions(-) diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index 6621bbf16e..bbee9e513f 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -126,7 +126,6 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er continue } switch n := res.Val.(type) { - case *ast.A_Const: name := "" if res.Name != nil { @@ -150,7 +149,8 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er if res.Name != nil { name = *res.Name } - switch op := astutils.Join(n.Name, ""); { + op := astutils.Join(n.Name, "") + switch { case lang.IsComparisonOperator(op): // TODO: Generate a name for these operations cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) @@ -774,98 +774,152 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) } // inferMathExpressionType attempts to infer the data type of a mathematical expression -// by analyzing its operands and the operation being performed +// by analyzing its operands and the operation being performed. func (c *Compiler) inferMathExpressionType(expr *ast.A_Expr, tables []*Table, op string) (string, bool) { - // Try to infer types from left and right operands - leftType := c.inferOperandType(expr.Lexpr, tables) - rightType := c.inferOperandType(expr.Rexpr, tables) + leftType, leftNotNull := c.inferOperandType(expr.Lexpr, tables) + rightType, rightNotNull := c.inferOperandType(expr.Rexpr, tables) - // Debug logging to understand what's happening - // fmt.Printf("DEBUG: Math expression %s: left=%s, right=%s\n", op, leftType, rightType) + // result is non-null only if both sides are non-null + notNull := leftNotNull && rightNotNull - // Determine the result type based on operands and operation resultType := c.combineTypes(leftType, rightType, op) - - // For now, assume nullable since we're dealing with database columns - // In a more sophisticated implementation, we could track nullability through the expression - notNull := false - return resultType, notNull } -// inferOperandType tries to determine the type of an operand in an expression -func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) string { +// inferOperandType tries to determine the type and nullability of an operand in an expression. +func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) (string, bool) { switch n := operand.(type) { case *ast.ColumnRef: - // Look up the column in the available tables parts := stringSlice(n.Fields) var name string if len(parts) >= 1 { - name = parts[len(parts)-1] // Get the column name (last part) + name = parts[len(parts)-1] } - for _, table := range tables { for _, col := range table.Columns { if col.Name == name { - return col.DataType + return col.DataType, col.NotNull } } } - return "any" + // Unknown column: assume non-null by default so generated code + // keeps the previous non-nullable behavior (avoids sql.Null*). + return "any", true case *ast.A_Const: - // Determine type based on constant value + // constants are non-nullable switch n.Val.(type) { case *ast.Integer: - return "int" + return "int", true case *ast.Float: - return "float" + return "float", true case *ast.String: - return "text" + return "text", true default: - return "any" + return "any", true } case *ast.A_Expr: - // Recursive case for nested expressions + // nested expression + nestedOp := "" if n.Name != nil { - nestedOp := astutils.Join(n.Name, "") - if lang.IsMathematicalOperator(nestedOp) { - resultType, _ := c.inferMathExpressionType(n, tables, nestedOp) - return resultType - } + nestedOp = astutils.Join(n.Name, "") } - return "any" + if lang.IsMathematicalOperator(nestedOp) { + t, notNull := c.inferMathExpressionType(n, tables, nestedOp) + return t, notNull + } + return "any", true default: - return "any" + return "any", true } } // combineTypes determines the result type when combining two operand types with an operation func (c *Compiler) combineTypes(leftType, rightType, op string) string { - // Handle division specially - division operations typically result in float - if op == "/" { - // If either operand is float, result is float - if leftType == "float" || rightType == "float" { + // Helper function to check if a type is a float variant + isFloatType := func(t string) bool { + return t == "float" || t == "float32" || t == "float64" || t == "double" || t == "double precision" || t == "real" + } + + // Helper function to check if a type is an integer variant + isIntType := func(t string) bool { + return t == "int" || t == "int32" || t == "int64" || t == "integer" || t == "int4" || t == "int8" || t == "bigint" || t == "smallint" + } + + // Normalize common DB types to standard types + normalizeType := func(t string) string { + switch t { + case "int4", "integer", "int32": + return "int" + case "int8", "bigint", "int64": + return "int" + case "smallint": + return "int" + case "float4", "real", "float32": + return "float" + case "float8", "double precision", "float64": + return "float" + case "any": + return "any" + default: + if isIntType(t) { + return "int" + } + if isFloatType(t) { + return "float" + } + return t + } + } + + leftNorm := normalizeType(leftType) + rightNorm := normalizeType(rightType) + + // treat MySQL "div" same as "/" for division semantics + if op == "/" || op == "div" { + if leftNorm == "float" || rightNorm == "float" { + return "float" + } + // If both are ints, return float for division (mathematical accuracy) + if leftNorm == "int" && rightNorm == "int" { + return "float" + } + // If at least one is numeric, prefer float + if (leftNorm == "int" || leftNorm == "float") && rightNorm == "any" { return "float" } - // Even integer division might want to be float in many cases - // For safety, return float for division unless both operands are clearly non-numeric - if leftType != "text" && rightType != "text" { + if leftNorm == "any" && (rightNorm == "int" || rightNorm == "float") { return "float" } + // For mixed types with at least one numeric, prefer float + if leftNorm != "text" && rightNorm != "text" { + return "float" + } + return "any" } - // For other mathematical operations - switch { - case leftType == "float" || rightType == "float": + // other math ops (* + -): + if leftNorm == "float" || rightNorm == "float" { return "float" - case leftType == "int" && rightType == "int": + } + if leftNorm == "int" && rightNorm == "int" { return "int" - case leftType == "int" && rightType == "any": + } + // If one side is numeric and other is any, prefer the numeric type + if leftNorm == "int" && rightNorm == "any" { return "int" - case leftType == "any" && rightType == "int": + } + if leftNorm == "any" && rightNorm == "int" { + return "int" + } + if leftNorm == "float" && rightNorm == "any" { + return "float" + } + if leftNorm == "any" && rightNorm == "float" { + return "float" + } + // If both are any, prefer int as a reasonable default for math + if leftNorm == "any" && rightNorm == "any" { return "int" - default: - // Default fallback - return "any" } + return "any" }