diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index b0a15e6ac4..bbee9e513f 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -126,7 +126,6 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er continue } switch n := res.Val.(type) { - case *ast.A_Const: name := "" if res.Name != nil { @@ -150,12 +149,15 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er if res.Name != nil { name = *res.Name } - switch op := astutils.Join(n.Name, ""); { + op := astutils.Join(n.Name, "") + switch { case lang.IsComparisonOperator(op): // TODO: Generate a name for these operations cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true}) case lang.IsMathematicalOperator(op): - cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true}) + // Improve type inference for mathematical expressions + dataType, notNull := c.inferMathExpressionType(n, tables, op) + cols = append(cols, &Column{Name: name, DataType: dataType, NotNull: notNull}) default: cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false}) } @@ -770,3 +772,154 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) return nil } + +// inferMathExpressionType attempts to infer the data type of a mathematical expression +// by analyzing its operands and the operation being performed. +func (c *Compiler) inferMathExpressionType(expr *ast.A_Expr, tables []*Table, op string) (string, bool) { + leftType, leftNotNull := c.inferOperandType(expr.Lexpr, tables) + rightType, rightNotNull := c.inferOperandType(expr.Rexpr, tables) + + // result is non-null only if both sides are non-null + notNull := leftNotNull && rightNotNull + + resultType := c.combineTypes(leftType, rightType, op) + return resultType, notNull +} + +// inferOperandType tries to determine the type and nullability of an operand in an expression. +func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) (string, bool) { + switch n := operand.(type) { + case *ast.ColumnRef: + parts := stringSlice(n.Fields) + var name string + if len(parts) >= 1 { + name = parts[len(parts)-1] + } + for _, table := range tables { + for _, col := range table.Columns { + if col.Name == name { + return col.DataType, col.NotNull + } + } + } + // Unknown column: assume non-null by default so generated code + // keeps the previous non-nullable behavior (avoids sql.Null*). + return "any", true + case *ast.A_Const: + // constants are non-nullable + switch n.Val.(type) { + case *ast.Integer: + return "int", true + case *ast.Float: + return "float", true + case *ast.String: + return "text", true + default: + return "any", true + } + case *ast.A_Expr: + // nested expression + nestedOp := "" + if n.Name != nil { + nestedOp = astutils.Join(n.Name, "") + } + if lang.IsMathematicalOperator(nestedOp) { + t, notNull := c.inferMathExpressionType(n, tables, nestedOp) + return t, notNull + } + return "any", true + default: + return "any", true + } +} + +// combineTypes determines the result type when combining two operand types with an operation +func (c *Compiler) combineTypes(leftType, rightType, op string) string { + // Helper function to check if a type is a float variant + isFloatType := func(t string) bool { + return t == "float" || t == "float32" || t == "float64" || t == "double" || t == "double precision" || t == "real" + } + + // Helper function to check if a type is an integer variant + isIntType := func(t string) bool { + return t == "int" || t == "int32" || t == "int64" || t == "integer" || t == "int4" || t == "int8" || t == "bigint" || t == "smallint" + } + + // Normalize common DB types to standard types + normalizeType := func(t string) string { + switch t { + case "int4", "integer", "int32": + return "int" + case "int8", "bigint", "int64": + return "int" + case "smallint": + return "int" + case "float4", "real", "float32": + return "float" + case "float8", "double precision", "float64": + return "float" + case "any": + return "any" + default: + if isIntType(t) { + return "int" + } + if isFloatType(t) { + return "float" + } + return t + } + } + + leftNorm := normalizeType(leftType) + rightNorm := normalizeType(rightType) + + // treat MySQL "div" same as "/" for division semantics + if op == "/" || op == "div" { + if leftNorm == "float" || rightNorm == "float" { + return "float" + } + // If both are ints, return float for division (mathematical accuracy) + if leftNorm == "int" && rightNorm == "int" { + return "float" + } + // If at least one is numeric, prefer float + if (leftNorm == "int" || leftNorm == "float") && rightNorm == "any" { + return "float" + } + if leftNorm == "any" && (rightNorm == "int" || rightNorm == "float") { + return "float" + } + // For mixed types with at least one numeric, prefer float + if leftNorm != "text" && rightNorm != "text" { + return "float" + } + return "any" + } + + // other math ops (* + -): + if leftNorm == "float" || rightNorm == "float" { + return "float" + } + if leftNorm == "int" && rightNorm == "int" { + return "int" + } + // If one side is numeric and other is any, prefer the numeric type + if leftNorm == "int" && rightNorm == "any" { + return "int" + } + if leftNorm == "any" && rightNorm == "int" { + return "int" + } + if leftNorm == "float" && rightNorm == "any" { + return "float" + } + if leftNorm == "any" && rightNorm == "float" { + return "float" + } + // If both are any, prefer int as a reasonable default for math + if leftNorm == "any" && rightNorm == "any" { + return "int" + } + return "any" +} diff --git a/internal/sql/lang/operator.go b/internal/sql/lang/operator.go index cd5ef50e38..7d2556c855 100644 --- a/internal/sql/lang/operator.go +++ b/internal/sql/lang/operator.go @@ -23,6 +23,7 @@ func IsMathematicalOperator(s string) bool { case "-": case "*": case "/": + case "div": // MySQL division operator case "%": case "^": case "|/":