Skip to content

Commit 88d86e9

Browse files
Merge pull request #1 from oracle-samples/null-returned-value
Null returned value
2 parents c86c817 + 488bbae commit 88d86e9

File tree

10 files changed

+103
-78
lines changed

10 files changed

+103
-78
lines changed

oracle/clause_builder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func ReturningClauseBuilder(c clause.Clause, builder clause.Builder) {
165165
var dest interface{}
166166
if stmt.Schema != nil {
167167
if field := findFieldByDBName(stmt.Schema, column.Name); field != nil {
168-
dest = createTypedDestination(field.FieldType)
168+
dest = createTypedDestination(field)
169169
} else {
170170
dest = new(string) // Default to string for unknown fields
171171
}

oracle/common.go

Lines changed: 92 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,30 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field {
100100
}
101101

102102
// Create typed destination for OUT parameters
103-
func createTypedDestination(fieldType reflect.Type) interface{} {
104-
// Handle pointer types
105-
if fieldType.Kind() == reflect.Ptr {
106-
fieldType = fieldType.Elem()
103+
func createTypedDestination(f *schema.Field) interface{} {
104+
if f == nil {
105+
var s string
106+
return &s
107107
}
108108

109-
// Type-safe handling for known GORM types and SQL null types
110-
switch fieldType {
111-
case reflect.TypeOf(gorm.DeletedAt{}):
109+
ft := f.FieldType
110+
for ft.Kind() == reflect.Ptr {
111+
ft = ft.Elem()
112+
}
113+
114+
if ft == reflect.TypeOf(gorm.DeletedAt{}) {
112115
return new(sql.NullTime)
113-
case reflect.TypeOf(time.Time{}):
116+
}
117+
if ft == reflect.TypeOf(time.Time{}) {
118+
if !f.NotNull { // nullable column => keep NULLs
119+
return new(sql.NullTime)
120+
}
114121
return new(time.Time)
122+
}
123+
124+
switch ft {
125+
case reflect.TypeOf(sql.NullTime{}):
126+
return new(sql.NullTime)
115127
case reflect.TypeOf(sql.NullInt64{}):
116128
return new(sql.NullInt64)
117129
case reflect.TypeOf(sql.NullInt32{}):
@@ -120,33 +132,28 @@ func createTypedDestination(fieldType reflect.Type) interface{} {
120132
return new(sql.NullFloat64)
121133
case reflect.TypeOf(sql.NullBool{}):
122134
return new(sql.NullBool)
123-
case reflect.TypeOf(sql.NullTime{}):
124-
return new(sql.NullTime)
125135
}
126136

127-
// Handle primitive types by Kind
128-
switch fieldType.Kind() {
137+
switch ft.Kind() {
138+
case reflect.String:
139+
return new(string)
140+
141+
case reflect.Bool:
142+
return new(int64)
143+
129144
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
130-
return new(int64) // Oracle returns NUMBER as int64
131-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
145+
return new(int64)
146+
147+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
132148
return new(uint64)
149+
133150
case reflect.Float32, reflect.Float64:
134-
return new(float64) // Oracle returns FLOAT as float64
135-
case reflect.Bool:
136-
return new(int64) // Oracle NUMBER(1) for boolean
137-
case reflect.String:
138-
return new(string)
139-
case reflect.Struct:
140-
// For time.Time specifically
141-
if fieldType == reflect.TypeOf(time.Time{}) {
142-
return new(time.Time)
143-
}
144-
// For other structs, use string as safe fallback
145-
return new(string)
146-
default:
147-
// For unknown types, use string as safe fallback
148-
return new(string)
151+
return new(float64)
149152
}
153+
154+
// Fallback
155+
var s string
156+
return &s
150157
}
151158

152159
// Convert values for Oracle-specific types
@@ -182,7 +189,7 @@ func convertValue(val interface{}) interface{} {
182189

183190
// Convert Oracle values back to Go types
184191
func convertFromOracleToField(value interface{}, field *schema.Field) interface{} {
185-
if value == nil {
192+
if value == nil || field == nil {
186193
return nil
187194
}
188195

@@ -194,7 +201,6 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
194201

195202
var converted interface{}
196203

197-
// Handle special types first using type-safe comparisons
198204
switch targetType {
199205
case reflect.TypeOf(gorm.DeletedAt{}):
200206
if nullTime, ok := value.(sql.NullTime); ok {
@@ -203,7 +209,31 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
203209
converted = gorm.DeletedAt{}
204210
}
205211
case reflect.TypeOf(time.Time{}):
206-
converted = value
212+
switch vv := value.(type) {
213+
case time.Time:
214+
converted = vv
215+
case sql.NullTime:
216+
if vv.Valid {
217+
converted = vv.Time
218+
} else {
219+
// DB returned NULL
220+
if isPtr {
221+
return nil // -> *time.Time(nil)
222+
}
223+
// non-pointer time.Time: represent NULL as zero time
224+
return time.Time{}
225+
}
226+
default:
227+
converted = value
228+
}
229+
230+
case reflect.TypeOf(sql.NullTime{}):
231+
if nullTime, ok := value.(sql.NullTime); ok {
232+
converted = nullTime
233+
} else {
234+
converted = sql.NullTime{}
235+
}
236+
207237
case reflect.TypeOf(sql.NullInt64{}):
208238
if nullInt, ok := value.(sql.NullInt64); ok {
209239
converted = nullInt
@@ -228,48 +258,24 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
228258
} else {
229259
converted = sql.NullBool{}
230260
}
231-
case reflect.TypeOf(sql.NullTime{}):
232-
if nullTime, ok := value.(sql.NullTime); ok {
233-
converted = nullTime
234-
} else {
235-
converted = sql.NullTime{}
236-
}
237261
default:
238-
// Handle primitive types
262+
// primitives and everything else
239263
converted = convertPrimitiveType(value, targetType)
240264
}
241265

242-
// Handle pointer types
243-
if isPtr && converted != nil {
244-
if isZeroValueForPointer(converted, targetType) {
266+
// Pointer targets: nil for "zero-ish", else allocate and set.
267+
if isPtr {
268+
if isZeroFor(targetType, converted) {
245269
return nil
246270
}
247271
ptr := reflect.New(targetType)
248272
ptr.Elem().Set(reflect.ValueOf(converted))
249-
converted = ptr.Interface()
273+
return ptr.Interface()
250274
}
251275

252276
return converted
253277
}
254278

255-
// Helper function to check if a value should be treated as nil for pointer fields
256-
func isZeroValueForPointer(value interface{}, targetType reflect.Type) bool {
257-
v := reflect.ValueOf(value)
258-
if !v.IsValid() || v.Kind() != targetType.Kind() {
259-
return false
260-
}
261-
262-
switch targetType.Kind() {
263-
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
264-
return v.Int() == 0
265-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
266-
return v.Uint() == 0
267-
case reflect.Float32, reflect.Float64:
268-
return v.Float() == 0.0
269-
}
270-
return false
271-
}
272-
273279
// Helper function to handle primitive type conversions
274280
func convertPrimitiveType(value interface{}, targetType reflect.Type) interface{} {
275281
switch targetType.Kind() {
@@ -442,3 +448,28 @@ func isNullValue(value interface{}) bool {
442448
return false
443449
}
444450
}
451+
452+
func isZeroFor(t reflect.Type, v interface{}) bool {
453+
if v == nil {
454+
return true
455+
}
456+
rv := reflect.ValueOf(v)
457+
if !rv.IsValid() {
458+
return true
459+
}
460+
// exact type match?
461+
if rv.Type() == t {
462+
// special-case time.Time
463+
if t == reflect.TypeOf(time.Time{}) {
464+
return rv.Interface().(time.Time).IsZero()
465+
}
466+
// generic zero check
467+
z := reflect.Zero(t)
468+
return reflect.DeepEqual(rv.Interface(), z.Interface())
469+
}
470+
// If types differ (e.g., sql.NullTime), treat invalid as zero
471+
if nt, ok := v.(sql.NullTime); ok {
472+
return !nt.Valid
473+
}
474+
return false
475+
}

oracle/create.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
490490
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
491491
for _, column := range allColumns {
492492
if field := findFieldByDBName(schema, column); field != nil {
493-
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
493+
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
494494
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
495495
writeQuotedIdentifier(&plsqlBuilder, column)
496496
plsqlBuilder.WriteString("; END IF;\n")
@@ -602,7 +602,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
602602
quotedColumn := columnBuilder.String()
603603

604604
if field := findFieldByDBName(schema, column); field != nil {
605-
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
605+
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
606606
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
607607
rowIdx, outParamIndex+1, rowIdx+1, quotedColumn))
608608
outParamIndex++

oracle/delete.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
278278
for _, column := range allColumns {
279279
field := findFieldByDBName(schema, column)
280280
if field != nil {
281-
dest := createTypedDestination(field.FieldType)
281+
dest := createTypedDestination(field)
282282
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
283283

284284
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))

oracle/update.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
544544
for _, column := range allColumns {
545545
field := findFieldByDBName(schema, column)
546546
if field != nil {
547-
dest := createTypedDestination(field.FieldType)
547+
dest := createTypedDestination(field)
548548
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
549549
}
550550
}

tests/associations_many2many_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,6 @@ func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) {
372372
}
373373

374374
func TestDuplicateMany2ManyAssociation(t *testing.T) {
375-
t.Skip()
376375
user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{
377376
{Code: "TestDuplicateMany2ManyAssociation-language-1"},
378377
{Code: "TestDuplicateMany2ManyAssociation-language-2"},
@@ -436,7 +435,6 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) {
436435
}
437436

438437
func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) {
439-
t.Skip()
440438
user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{
441439
{Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{
442440
ID: 1,

tests/generics_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,6 @@ func TestGenericsJoins(t *testing.T) {
422422
}
423423

424424
func TestGenericsNestedJoins(t *testing.T) {
425-
t.Skip()
426425
users := []User{
427426
{
428427
Name: "generics-nested-joins-1",

tests/joins_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,6 @@ func TestJoinArgsWithDB(t *testing.T) {
383383
}
384384

385385
func TestNestedJoins(t *testing.T) {
386-
t.Skip()
387-
388386
users := []User{
389387
{
390388
Name: "nested-joins-1",
@@ -424,7 +422,7 @@ func TestNestedJoins(t *testing.T) {
424422
Joins("Manager.NamedPet.Toy").
425423
Joins("NamedPet").
426424
Joins("NamedPet.Toy").
427-
Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
425+
Find(&users2, "\"users\".\"id\" IN ?", userIDs).Error; err != nil {
428426
t.Fatalf("Failed to load with joins, got error: %v", err)
429427
} else if len(users2) != len(users) {
430428
t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))

tests/passed-tests.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ TestMany2ManyOmitAssociations
2121
TestMany2ManyAssociationForSlice
2222
#TestSingleTableMany2ManyAssociation
2323
#TestSingleTableMany2ManyAssociationForSlice
24-
#TestDuplicateMany2ManyAssociation
24+
TestDuplicateMany2ManyAssociation
2525
TestConcurrentMany2ManyAssociation
26-
#TestMany2ManyDuplicateBelongsToAssociation
26+
TestMany2ManyDuplicateBelongsToAssociation
2727
TestInvalidAssociation
2828
TestAssociationNotNullClear
2929
#TestForeignKeyConstraints
@@ -112,7 +112,7 @@ TestGenericsDelete
112112
TestGenericsFindInBatches
113113
TestGenericsScopes
114114
#TestGenericsJoins
115-
#TestGenericsNestedJoins
115+
TestGenericsNestedJoins
116116
#TestGenericsPreloads
117117
#TestGenericsNestedPreloads
118118
TestGenericsDistinct
@@ -146,7 +146,7 @@ TestJoinWithSoftDeleted
146146
TestInnerJoins
147147
TestJoinWithSameColumnName
148148
TestJoinArgsWithDB
149-
#TestNestedJoins
149+
TestNestedJoins
150150
TestJoinsPreload_Issue7013
151151
TestJoinsPreload_Issue7013_RelationEmpty
152152
TestJoinsPreload_Issue7013_NoEntries
@@ -274,7 +274,7 @@ TestScanNullValue
274274
TestQueryWithTableAndConditions
275275
TestQueryWithTableAndConditionsAndAllFields
276276
#TestQueryScannerWithSingleColumn
277-
#TestQueryResetNullValue
277+
TestQueryResetNullValue
278278
TestQueryError
279279
TestQueryScanToArray
280280
TestRownum

tests/query_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,6 @@ func TestQueryScannerWithSingleColumn(t *testing.T) {
13861386
}
13871387

13881388
func TestQueryResetNullValue(t *testing.T) {
1389-
t.Skip()
13901389
type QueryResetItem struct {
13911390
ID string `gorm:"type:varchar(5)"`
13921391
Name string

0 commit comments

Comments
 (0)