diff --git a/oracle/common.go b/oracle/common.go index 347c016..42571e6 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -424,6 +424,50 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) { builder.WriteByte('"') } +// writeTableRecordCollectionDecl writes the PL/SQL declarations needed to +// define a custom record type and a collection of that record type, +// based on the schema of the given table. +// +// Specifically, it generates: +// - A RECORD type (`t_record`) with fields corresponding to the table's columns. +// - A nested TABLE type (`t_records`) of `t_record`. +// +// The declarations are written into the provided strings.Builder in the +// correct PL/SQL syntax, so they can be used as part of a larger PL/SQL block. +// +// Example output: +// +// TYPE t_record IS RECORD ( +// "id" "users"."id"%TYPE, +// "created_at" "users"."created_at"%TYPE, +// ... +// ); +// TYPE t_records IS TABLE OF t_record; +// +// Parameters: +// - plsqlBuilder: The builder to write the PL/SQL code into. +// - dbNames: The slice containing the column names. +// - table: The table name +func writeTableRecordCollectionDecl(plsqlBuilder *strings.Builder, dbNames []string, table string) { + // Declare a record where each element has the same structure as a row from the given table + plsqlBuilder.WriteString(" TYPE t_record IS RECORD (\n") + for i, field := range dbNames { + if i > 0 { + plsqlBuilder.WriteString(",\n") + } + plsqlBuilder.WriteString(" ") + writeQuotedIdentifier(plsqlBuilder, field) + plsqlBuilder.WriteString(" ") + writeQuotedIdentifier(plsqlBuilder, table) + plsqlBuilder.WriteString(".") + writeQuotedIdentifier(plsqlBuilder, field) + plsqlBuilder.WriteString("%TYPE") + } + plsqlBuilder.WriteString("\n") + plsqlBuilder.WriteString(" );\n") + plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF t_record;\n") +} + // Helper function to check if a value represents NULL func isNullValue(value interface{}) bool { if value == nil { diff --git a/oracle/create.go b/oracle/create.go index df56828..2a23339 100644 --- a/oracle/create.go +++ b/oracle/create.go @@ -285,9 +285,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau // Start PL/SQL block plsqlBuilder.WriteString("DECLARE\n") - plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ") - writeQuotedIdentifier(&plsqlBuilder, stmt.Table) - plsqlBuilder.WriteString("%ROWTYPE;\n") + writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table) plsqlBuilder.WriteString(" l_affected_records t_records;\n") // Create array types and variables for each column @@ -526,9 +524,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) { // Start PL/SQL block plsqlBuilder.WriteString("DECLARE\n") - plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ") - writeQuotedIdentifier(&plsqlBuilder, stmt.Table) - plsqlBuilder.WriteString("%ROWTYPE;\n") + writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table) plsqlBuilder.WriteString(" l_inserted_records t_records;\n") // Create array types and variables for each column diff --git a/oracle/delete.go b/oracle/delete.go index ba48eb9..871ac41 100644 --- a/oracle/delete.go +++ b/oracle/delete.go @@ -239,9 +239,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) { // Start PL/SQL block plsqlBuilder.WriteString("DECLARE\n") - plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ") - writeQuotedIdentifier(&plsqlBuilder, stmt.Table) - plsqlBuilder.WriteString("%ROWTYPE;\n") + writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table) plsqlBuilder.WriteString(" l_deleted_records t_records;\n") plsqlBuilder.WriteString("BEGIN\n") diff --git a/oracle/update.go b/oracle/update.go index a504b1d..0a5b653 100644 --- a/oracle/update.go +++ b/oracle/update.go @@ -476,9 +476,7 @@ func buildUpdatePLSQL(db *gorm.DB) { // Start PL/SQL block plsqlBuilder.WriteString("DECLARE\n") - plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ") - writeQuotedIdentifier(&plsqlBuilder, stmt.Table) - plsqlBuilder.WriteString("%ROWTYPE;\n") + writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table) plsqlBuilder.WriteString(" l_updated_records t_records;\n") plsqlBuilder.WriteString("BEGIN\n") diff --git a/tests/gorm_test.go b/tests/gorm_test.go index 2f24159..b1935b8 100644 --- a/tests/gorm_test.go +++ b/tests/gorm_test.go @@ -54,7 +54,6 @@ func TestOpen(t *testing.T) { } func TestReturningWithNullToZeroValues(t *testing.T) { - t.Skip() // This user struct will leverage the existing users table, but override // the Name field to default to null. type user struct { @@ -72,7 +71,7 @@ func TestReturningWithNullToZeroValues(t *testing.T) { } got := user{} - results := DB.First(&got, "id = ?", u1.ID) + results := DB.First(&got, "\"id\" = ?", u1.ID) if results.Error != nil { t.Fatalf("errors happened on first: %v", results.Error) } else if results.RowsAffected != 1 { @@ -81,7 +80,7 @@ func TestReturningWithNullToZeroValues(t *testing.T) { t.Fatalf("first expects: %v, got %v", u1, got) } - results = DB.Select("id, name").Find(&got) + results = DB.Select("\"id\", \"name\"").Find(&got) if results.Error != nil { t.Fatalf("errors happened on first: %v", results.Error) } else if results.RowsAffected != 1 { @@ -112,7 +111,7 @@ func TestReturningWithNullToZeroValues(t *testing.T) { } var gotUsers []user - results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers) + results = DB.Where("\"id\" in (?, ?)", u1.ID, u2.ID).Order("\"id\" asc").Select("\"id\", \"name\"").Find(&gotUsers) if results.Error != nil { t.Fatalf("errors happened on first: %v", results.Error) } else if results.RowsAffected != 2 { diff --git a/tests/passed-tests.txt b/tests/passed-tests.txt index 783e39b..bd2f203 100644 --- a/tests/passed-tests.txt +++ b/tests/passed-tests.txt @@ -124,7 +124,7 @@ TestGenericsReuse TestGenericsWithTransaction TestGenericsToSQL TestOpen -#TestReturningWithNullToZeroValues +TestReturningWithNullToZeroValues TestGroupBy TestRunCallbacks TestCallbacksWithErrors