Skip to content
2 changes: 1 addition & 1 deletion oracle/clause_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func ReturningClauseBuilder(c clause.Clause, builder clause.Builder) {
var dest interface{}
if stmt.Schema != nil {
if field := findFieldByDBName(stmt.Schema, column.Name); field != nil {
dest = createTypedDestination(field.FieldType)
dest = createTypedDestination(field)
} else {
dest = new(string) // Default to string for unknown fields
}
Expand Down
153 changes: 92 additions & 61 deletions oracle/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,30 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field {
}

// Create typed destination for OUT parameters
func createTypedDestination(fieldType reflect.Type) interface{} {
// Handle pointer types
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
func createTypedDestination(f *schema.Field) interface{} {
if f == nil {
var s string
return &s
}

// Type-safe handling for known GORM types and SQL null types
switch fieldType {
case reflect.TypeOf(gorm.DeletedAt{}):
ft := f.FieldType
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}

if ft == reflect.TypeOf(gorm.DeletedAt{}) {
return new(sql.NullTime)
case reflect.TypeOf(time.Time{}):
}
if ft == reflect.TypeOf(time.Time{}) {
if !f.NotNull { // nullable column => keep NULLs
return new(sql.NullTime)
}
return new(time.Time)
}

switch ft {
case reflect.TypeOf(sql.NullTime{}):
return new(sql.NullTime)
case reflect.TypeOf(sql.NullInt64{}):
return new(sql.NullInt64)
case reflect.TypeOf(sql.NullInt32{}):
Expand All @@ -120,33 +132,28 @@ func createTypedDestination(fieldType reflect.Type) interface{} {
return new(sql.NullFloat64)
case reflect.TypeOf(sql.NullBool{}):
return new(sql.NullBool)
case reflect.TypeOf(sql.NullTime{}):
return new(sql.NullTime)
}

// Handle primitive types by Kind
switch fieldType.Kind() {
switch ft.Kind() {
case reflect.String:
return new(string)

case reflect.Bool:
return new(int64)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return new(int64) // Oracle returns NUMBER as int64
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return new(int64)

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return new(uint64)

case reflect.Float32, reflect.Float64:
return new(float64) // Oracle returns FLOAT as float64
case reflect.Bool:
return new(int64) // Oracle NUMBER(1) for boolean
case reflect.String:
return new(string)
case reflect.Struct:
// For time.Time specifically
if fieldType == reflect.TypeOf(time.Time{}) {
return new(time.Time)
}
// For other structs, use string as safe fallback
return new(string)
default:
// For unknown types, use string as safe fallback
return new(string)
return new(float64)
}

// Fallback
var s string
return &s
}

// Convert values for Oracle-specific types
Expand Down Expand Up @@ -182,7 +189,7 @@ func convertValue(val interface{}) interface{} {

// Convert Oracle values back to Go types
func convertFromOracleToField(value interface{}, field *schema.Field) interface{} {
if value == nil {
if value == nil || field == nil {
return nil
}

Expand All @@ -194,7 +201,6 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{

var converted interface{}

// Handle special types first using type-safe comparisons
switch targetType {
case reflect.TypeOf(gorm.DeletedAt{}):
if nullTime, ok := value.(sql.NullTime); ok {
Expand All @@ -203,7 +209,31 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
converted = gorm.DeletedAt{}
}
case reflect.TypeOf(time.Time{}):
converted = value
switch vv := value.(type) {
case time.Time:
converted = vv
case sql.NullTime:
if vv.Valid {
converted = vv.Time
} else {
// DB returned NULL
if isPtr {
return nil // -> *time.Time(nil)
}
// non-pointer time.Time: represent NULL as zero time
return time.Time{}
}
default:
converted = value
}

case reflect.TypeOf(sql.NullTime{}):
if nullTime, ok := value.(sql.NullTime); ok {
converted = nullTime
} else {
converted = sql.NullTime{}
}

case reflect.TypeOf(sql.NullInt64{}):
if nullInt, ok := value.(sql.NullInt64); ok {
converted = nullInt
Expand All @@ -228,48 +258,24 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
} else {
converted = sql.NullBool{}
}
case reflect.TypeOf(sql.NullTime{}):
if nullTime, ok := value.(sql.NullTime); ok {
converted = nullTime
} else {
converted = sql.NullTime{}
}
default:
// Handle primitive types
// primitives and everything else
converted = convertPrimitiveType(value, targetType)
}

// Handle pointer types
if isPtr && converted != nil {
if isZeroValueForPointer(converted, targetType) {
// Pointer targets: nil for "zero-ish", else allocate and set.
if isPtr {
if isZeroFor(targetType, converted) {
return nil
}
ptr := reflect.New(targetType)
ptr.Elem().Set(reflect.ValueOf(converted))
converted = ptr.Interface()
return ptr.Interface()
}

return converted
}

// Helper function to check if a value should be treated as nil for pointer fields
func isZeroValueForPointer(value interface{}, targetType reflect.Type) bool {
v := reflect.ValueOf(value)
if !v.IsValid() || v.Kind() != targetType.Kind() {
return false
}

switch targetType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0.0
}
return false
}

// Helper function to handle primitive type conversions
func convertPrimitiveType(value interface{}, targetType reflect.Type) interface{} {
switch targetType.Kind() {
Expand Down Expand Up @@ -442,3 +448,28 @@ func isNullValue(value interface{}) bool {
return false
}
}

func isZeroFor(t reflect.Type, v interface{}) bool {
if v == nil {
return true
}
rv := reflect.ValueOf(v)
if !rv.IsValid() {
return true
}
// exact type match?
if rv.Type() == t {
// special-case time.Time
if t == reflect.TypeOf(time.Time{}) {
return rv.Interface().(time.Time).IsZero()
}
// generic zero check
z := reflect.Zero(t)
return reflect.DeepEqual(rv.Interface(), z.Interface())
}
// If types differ (e.g., sql.NullTime), treat invalid as zero
if nt, ok := v.(sql.NullTime); ok {
return !nt.Valid
}
return false
}
4 changes: 2 additions & 2 deletions oracle/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
for _, column := range allColumns {
if field := findFieldByDBName(schema, column); field != nil {
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
plsqlBuilder.WriteString("; END IF;\n")
Expand Down Expand Up @@ -602,7 +602,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
quotedColumn := columnBuilder.String()

if field := findFieldByDBName(schema, column); field != nil {
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
rowIdx, outParamIndex+1, rowIdx+1, quotedColumn))
outParamIndex++
Expand Down
2 changes: 1 addition & 1 deletion oracle/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
for _, column := range allColumns {
field := findFieldByDBName(schema, column)
if field != nil {
dest := createTypedDestination(field.FieldType)
dest := createTypedDestination(field)
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})

plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
Expand Down
2 changes: 1 addition & 1 deletion oracle/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
for _, column := range allColumns {
field := findFieldByDBName(schema, column)
if field != nil {
dest := createTypedDestination(field.FieldType)
dest := createTypedDestination(field)
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
}
}
Expand Down
2 changes: 0 additions & 2 deletions tests/associations_many2many_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
}

func TestDuplicateMany2ManyAssociation(t *testing.T) {
t.Skip()
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
Expand Down Expand Up @@ -436,7 +435,6 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) {
}

func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
t.Skip()
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
ID: 1,
Expand Down
1 change: 0 additions & 1 deletion tests/generics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ func TestGenericsJoins(t *testing.T) {
}

func TestGenericsNestedJoins(t *testing.T) {
t.Skip()
users := []User{
{
Name: "generics-nested-joins-1",
Expand Down
4 changes: 1 addition & 3 deletions tests/joins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,6 @@ func TestJoinArgsWithDB(t *testing.T) {
}

func TestNestedJoins(t *testing.T) {
t.Skip()

users := []User{
{
Name: "nested-joins-1",
Expand Down Expand Up @@ -424,7 +422,7 @@ func TestNestedJoins(t *testing.T) {
Joins("Manager.NamedPet.Toy").
Joins("NamedPet").
Joins("NamedPet.Toy").
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
Find(&users2, "\"users\".\"id\" IN ?", userIDs).Error; err != nil {
t.Fatalf("Failed to load with joins, got error: %v", err)
} else if len(users2) != len(users) {
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
Expand Down
10 changes: 5 additions & 5 deletions tests/passed-tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ TestMany2ManyOmitAssociations
TestMany2ManyAssociationForSlice
#TestSingleTableMany2ManyAssociation
#TestSingleTableMany2ManyAssociationForSlice
#TestDuplicateMany2ManyAssociation
TestDuplicateMany2ManyAssociation
TestConcurrentMany2ManyAssociation
#TestMany2ManyDuplicateBelongsToAssociation
TestMany2ManyDuplicateBelongsToAssociation
TestInvalidAssociation
TestAssociationNotNullClear
#TestForeignKeyConstraints
Expand Down Expand Up @@ -112,7 +112,7 @@ TestGenericsDelete
TestGenericsFindInBatches
TestGenericsScopes
#TestGenericsJoins
#TestGenericsNestedJoins
TestGenericsNestedJoins
#TestGenericsPreloads
#TestGenericsNestedPreloads
TestGenericsDistinct
Expand Down Expand Up @@ -146,7 +146,7 @@ TestJoinCount
TestInnerJoins
TestJoinWithSameColumnName
TestJoinArgsWithDB
#TestNestedJoins
TestNestedJoins
TestJoinsPreload_Issue7013
TestJoinsPreload_Issue7013_RelationEmpty
TestJoinsPreload_Issue7013_NoEntries
Expand Down Expand Up @@ -267,7 +267,7 @@ TestSubQueryWithHaving
TestQueryWithTableAndConditions
TestQueryWithTableAndConditionsAndAllFields
#TestQueryScannerWithSingleColumn
#TestQueryResetNullValue
TestQueryResetNullValue
TestQueryError
TestQueryScanToArray
TestRownum
Expand Down
1 change: 0 additions & 1 deletion tests/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1393,7 +1393,6 @@ func TestQueryScannerWithSingleColumn(t *testing.T) {
}

func TestQueryResetNullValue(t *testing.T) {
t.Skip()
type QueryResetItem struct {
ID string `gorm:"type:varchar(5)"`
Name string
Expand Down
Loading