@@ -28,8 +28,9 @@ func (e *Enum) isType() {
2828}
2929
3030type CompositeType struct {
31- Name string
32- Comment string
31+ Name string
32+ ColTypeNames []* ast.TypeName
33+ Comment string
3334}
3435
3536func (ct * CompositeType ) isType () {
@@ -101,6 +102,16 @@ func stringSlice(list *ast.List) []string {
101102 return items
102103}
103104
105+ func columnTypeNamesSlice (list * ast.List ) []* ast.TypeName {
106+ items := []* ast.TypeName {}
107+ for _ , item := range list .Items {
108+ if n , ok := item .(* ast.ColumnDef ); ok {
109+ items = append (items , n .TypeName )
110+ }
111+ }
112+ return items
113+ }
114+
104115func (c * Catalog ) getType (rel * ast.TypeName ) (Type , int , error ) {
105116 ns := rel .Schema
106117 if ns == "" {
@@ -136,7 +147,8 @@ func (c *Catalog) createCompositeType(stmt *ast.CompositeTypeStmt) error {
136147 return sqlerr .TypeExists (tbl .Name )
137148 }
138149 schema .Types = append (schema .Types , & CompositeType {
139- Name : stmt .TypeName .Name ,
150+ Name : stmt .TypeName .Name ,
151+ ColTypeNames : columnTypeNamesSlice (stmt .ColDefList ),
140152 })
141153 return nil
142154}
@@ -277,16 +289,11 @@ func (c *Catalog) alterTypeSetSchema(stmt *ast.AlterTypeSetSchemaStmt) error {
277289 oldSchema .Types = append (oldSchema .Types [:idx ], oldSchema .Types [idx + 1 :]... )
278290 newSchema .Types = append (newSchema .Types , typ )
279291
280- // Update all the table columns with the new type
281- for _ , schema := range c .Schemas {
282- for _ , table := range schema .Tables {
283- for _ , column := range table .Columns {
284- if column .Type == oldType {
285- column .Type .Schema = * stmt .NewSchema
286- }
287- }
292+ c .updateTypeNames (func (t * ast.TypeName ) {
293+ if * t == oldType {
294+ t .Schema = * stmt .NewSchema
288295 }
289- }
296+ })
290297 return nil
291298}
292299
@@ -343,8 +350,9 @@ func (c *Catalog) renameType(stmt *ast.RenameTypeStmt) error {
343350
344351 case * CompositeType :
345352 schema .Types [idx ] = & CompositeType {
346- Name : newName ,
347- Comment : typ .Comment ,
353+ Name : newName ,
354+ ColTypeNames : typ .ColTypeNames ,
355+ Comment : typ .Comment ,
348356 }
349357
350358 case * Enum :
@@ -359,16 +367,33 @@ func (c *Catalog) renameType(stmt *ast.RenameTypeStmt) error {
359367
360368 }
361369
362- // Update all the table columns with the new type
370+ c .updateTypeNames (func (t * ast.TypeName ) {
371+ if * t == * stmt .Type {
372+ t .Name = newName
373+ }
374+ })
375+
376+ return nil
377+ }
378+
379+ func (c * Catalog ) updateTypeNames (typeUpdater func (t * ast.TypeName )) error {
363380 for _ , schema := range c .Schemas {
381+ // Update all the table columns with the new type
364382 for _ , table := range schema .Tables {
365383 for _ , column := range table .Columns {
366- if column .Type == * stmt .Type {
367- column .Type .Name = newName
368- }
384+ typeUpdater (& column .Type )
385+ }
386+ }
387+ // Update all the composite fields with the new type
388+ for _ , typ := range schema .Types {
389+ composite , ok := typ .(* CompositeType )
390+ if ! ok {
391+ continue
392+ }
393+ for _ , fieldType := range composite .ColTypeNames {
394+ typeUpdater (fieldType )
369395 }
370396 }
371397 }
372-
373398 return nil
374399}
0 commit comments