Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 156 additions & 3 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
continue
}
switch n := res.Val.(type) {

case *ast.A_Const:
name := ""
if res.Name != nil {
Expand All @@ -150,12 +149,15 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
if res.Name != nil {
name = *res.Name
}
switch op := astutils.Join(n.Name, ""); {
op := astutils.Join(n.Name, "")
switch {
case lang.IsComparisonOperator(op):
// TODO: Generate a name for these operations
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
case lang.IsMathematicalOperator(op):
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
// Improve type inference for mathematical expressions
dataType, notNull := c.inferMathExpressionType(n, tables, op)
cols = append(cols, &Column{Name: name, DataType: dataType, NotNull: notNull})
default:
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
}
Expand Down Expand Up @@ -770,3 +772,154 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List)

return nil
}

// inferMathExpressionType attempts to infer the data type of a mathematical expression
// by analyzing its operands and the operation being performed.
func (c *Compiler) inferMathExpressionType(expr *ast.A_Expr, tables []*Table, op string) (string, bool) {
leftType, leftNotNull := c.inferOperandType(expr.Lexpr, tables)
rightType, rightNotNull := c.inferOperandType(expr.Rexpr, tables)

// result is non-null only if both sides are non-null
notNull := leftNotNull && rightNotNull

resultType := c.combineTypes(leftType, rightType, op)
return resultType, notNull
}

// inferOperandType tries to determine the type and nullability of an operand in an expression.
func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) (string, bool) {
switch n := operand.(type) {
case *ast.ColumnRef:
parts := stringSlice(n.Fields)
var name string
if len(parts) >= 1 {
name = parts[len(parts)-1]
}
for _, table := range tables {
for _, col := range table.Columns {
if col.Name == name {
return col.DataType, col.NotNull
}
}
}
// Unknown column: assume non-null by default so generated code
// keeps the previous non-nullable behavior (avoids sql.Null*).
return "any", true
case *ast.A_Const:
// constants are non-nullable
switch n.Val.(type) {
case *ast.Integer:
return "int", true
case *ast.Float:
return "float", true
case *ast.String:
return "text", true
default:
return "any", true
}
case *ast.A_Expr:
// nested expression
nestedOp := ""
if n.Name != nil {
nestedOp = astutils.Join(n.Name, "")
}
if lang.IsMathematicalOperator(nestedOp) {
t, notNull := c.inferMathExpressionType(n, tables, nestedOp)
return t, notNull
}
return "any", true
default:
return "any", true
}
}

// combineTypes determines the result type when combining two operand types with an operation
func (c *Compiler) combineTypes(leftType, rightType, op string) string {
// Helper function to check if a type is a float variant
isFloatType := func(t string) bool {
return t == "float" || t == "float32" || t == "float64" || t == "double" || t == "double precision" || t == "real"
}

// Helper function to check if a type is an integer variant
isIntType := func(t string) bool {
return t == "int" || t == "int32" || t == "int64" || t == "integer" || t == "int4" || t == "int8" || t == "bigint" || t == "smallint"
}

// Normalize common DB types to standard types
normalizeType := func(t string) string {
switch t {
case "int4", "integer", "int32":
return "int"
case "int8", "bigint", "int64":
return "int"
case "smallint":
return "int"
case "float4", "real", "float32":
return "float"
case "float8", "double precision", "float64":
return "float"
case "any":
return "any"
default:
if isIntType(t) {
return "int"
}
if isFloatType(t) {
return "float"
}
return t
}
}

leftNorm := normalizeType(leftType)
rightNorm := normalizeType(rightType)

// treat MySQL "div" same as "/" for division semantics
if op == "/" || op == "div" {
if leftNorm == "float" || rightNorm == "float" {
return "float"
}
// If both are ints, return float for division (mathematical accuracy)
if leftNorm == "int" && rightNorm == "int" {
return "float"
}
// If at least one is numeric, prefer float
if (leftNorm == "int" || leftNorm == "float") && rightNorm == "any" {
return "float"
}
if leftNorm == "any" && (rightNorm == "int" || rightNorm == "float") {
return "float"
}
// For mixed types with at least one numeric, prefer float
if leftNorm != "text" && rightNorm != "text" {
return "float"
}
return "any"
}

// other math ops (* + -):
if leftNorm == "float" || rightNorm == "float" {
return "float"
}
if leftNorm == "int" && rightNorm == "int" {
return "int"
}
// If one side is numeric and other is any, prefer the numeric type
if leftNorm == "int" && rightNorm == "any" {
return "int"
}
if leftNorm == "any" && rightNorm == "int" {
return "int"
}
if leftNorm == "float" && rightNorm == "any" {
return "float"
}
if leftNorm == "any" && rightNorm == "float" {
return "float"
}
// If both are any, prefer int as a reasonable default for math
if leftNorm == "any" && rightNorm == "any" {
return "int"
}
return "any"
}
1 change: 1 addition & 0 deletions internal/sql/lang/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func IsMathematicalOperator(s string) bool {
case "-":
case "*":
case "/":
case "div": // MySQL division operator
case "%":
case "^":
case "|/":
Expand Down
Loading