@@ -476,9 +476,9 @@ func (m *Model) CtorRetVars() string {
476
476
func (m * Model ) SetFields (fields []* Field ) error {
477
477
var fs []* Field
478
478
var id * Field
479
- for _ , f := range fields {
479
+ for _ , f := range flattenFields ( fields ) {
480
480
f .Model = m
481
- if f .IsPrimaryKey () {
481
+ if f .IsPrimaryKey () && f . Type != BaseModel {
482
482
if id != nil {
483
483
return fmt .Errorf (
484
484
"kallax: found more than one primary key in model %s: %s and %s" ,
@@ -489,14 +489,56 @@ func (m *Model) SetFields(fields []*Field) error {
489
489
}
490
490
491
491
id = f
492
- m .ID = f
492
+ } else if f .IsPrimaryKey () {
493
+ if f .primaryKey == "" {
494
+ return fmt .Errorf (
495
+ "kallax: primary key defined in %s has no field name, but it must be specified" ,
496
+ f .Name ,
497
+ )
498
+ }
499
+
500
+ // the pk is defined in the model, we need to collect the model
501
+ // and we'll look for the field afterwards, when we have collected
502
+ // all fields. The model is appended to the field set, though,
503
+ // because it will not act as a primary key.
504
+ id = f
505
+ fs = append (fs , f )
493
506
} else {
494
507
fs = append (fs , f )
495
508
}
496
509
}
497
510
511
+ // if the id is a Model we need to look for the specified field
512
+ if id != nil && id .Type == BaseModel {
513
+ for i , f := range fs {
514
+ if f .columnName == id .primaryKey {
515
+ f .isPrimaryKey = true
516
+ f .isAutoincrement = id .isAutoincrement
517
+ id = f
518
+
519
+ if len (fs )- 1 == i {
520
+ fs = append (fs [:i ])
521
+ } else {
522
+ fs = append (fs [:i ], fs [i + 1 :]... )
523
+ }
524
+ break
525
+ }
526
+ }
527
+
528
+ // If the ID is still a base model, means we did not find the pk
529
+ // field.
530
+ if id .Type == BaseModel {
531
+ return fmt .Errorf (
532
+ "kallax: the primary key was supposed to be %s according to the pk definition in %s, but the field could not be found" ,
533
+ id .primaryKey ,
534
+ id .Name ,
535
+ )
536
+ }
537
+ }
538
+
498
539
if id != nil {
499
540
m .Fields = []* Field {id }
541
+ m .ID = id
500
542
}
501
543
m .Fields = append (m .Fields , fs ... )
502
544
return nil
@@ -556,6 +598,8 @@ func relationshipsOnFields(fields []*Field) []*Field {
556
598
return result
557
599
}
558
600
601
+ // ImplicitFK is a foreign key that is defined on just one side of the
602
+ // relationship and needs to be added on the other side.
559
603
type ImplicitFK struct {
560
604
Name string
561
605
Type string
@@ -590,6 +634,11 @@ type Field struct {
590
634
// A struct is considered embedded if and only if the struct was embedded
591
635
// as defined in Go.
592
636
IsEmbedded bool
637
+
638
+ primaryKey string
639
+ isPrimaryKey bool
640
+ isAutoincrement bool
641
+ columnName string
593
642
}
594
643
595
644
// FieldKind is the kind of a field.
@@ -645,13 +694,49 @@ func (t FieldKind) String() string {
645
694
646
695
// NewField creates a new field with its name, type and struct tag.
647
696
func NewField (n , t string , tag reflect.StructTag ) * Field {
697
+ pkName , autoincr , isPrimaryKey := pkProperties (tag )
698
+
648
699
return & Field {
649
700
Name : n ,
650
701
Type : t ,
651
702
Tag : tag ,
703
+
704
+ primaryKey : pkName ,
705
+ columnName : columnName (n , tag ),
706
+ isPrimaryKey : isPrimaryKey ,
707
+ isAutoincrement : autoincr ,
652
708
}
653
709
}
654
710
711
+ // pkProperties returns the primary key properties from a struct tag.
712
+ // Valid primary key definitions are the following:
713
+ // - pk:"" -> non-autoincr primary key without a field name.
714
+ // - pk:"autoincr" -> autoincr primary key without a field name.
715
+ // - pk:"foobar" -> non-autoincr primary key with a field name.
716
+ // - pk:"foobar,autoincr" -> autoincr primary key with a field name.
717
+ func pkProperties (tag reflect.StructTag ) (name string , autoincr , isPrimaryKey bool ) {
718
+ val , ok := tag .Lookup ("pk" )
719
+ if ! ok {
720
+ return
721
+ }
722
+
723
+ isPrimaryKey = true
724
+ if val == "autoincr" || val == "" {
725
+ if val == "autoincr" {
726
+ autoincr = true
727
+ }
728
+ return
729
+ }
730
+
731
+ parts := strings .Split (val , "," )
732
+ name = parts [0 ]
733
+ if len (parts ) > 1 && parts [1 ] == "autoincr" {
734
+ autoincr = true
735
+ }
736
+
737
+ return
738
+ }
739
+
655
740
// SetFields sets all the children fields and the current field as a parent of
656
741
// the children.
657
742
func (f * Field ) SetFields (sf []* Field ) {
@@ -667,16 +752,20 @@ func (f *Field) SetFields(sf []*Field) {
667
752
// is the field name converted to lower snake case.
668
753
// If the resultant name is a reserved keyword a _ will be prepended to the name.
669
754
func (f * Field ) ColumnName () string {
670
- name := strings .TrimSpace (strings .Split (f .Tag .Get ("kallax" ), "," )[0 ])
671
- if name == "" {
672
- name = toLowerSnakeCase (f .Name )
755
+ return f .columnName
756
+ }
757
+
758
+ func columnName (name string , tag reflect.StructTag ) string {
759
+ n := strings .TrimSpace (strings .Split (tag .Get ("kallax" ), "," )[0 ])
760
+ if n == "" {
761
+ n = toLowerSnakeCase (name )
673
762
}
674
763
675
- if _ , ok := reservedKeywords [strings .ToLower (name )]; ok {
676
- name = "_" + name
764
+ if _ , ok := reservedKeywords [strings .ToLower (n )]; ok {
765
+ n = "_" + n
677
766
}
678
767
679
- return name
768
+ return n
680
769
}
681
770
682
771
// ForeignKey returns the name of the foreign keys as specified in the struct
@@ -699,13 +788,12 @@ func (f *Field) ForeignKey() string {
699
788
700
789
// IsPrimaryKey reports whether the field is the primary key.
701
790
func (f * Field ) IsPrimaryKey () bool {
702
- _ , ok := f .Tag .Lookup ("pk" )
703
- return ok
791
+ return f .isPrimaryKey
704
792
}
705
793
706
794
// IsAutoIncrement reports whether the field is an autoincrementable primary key.
707
795
func (f * Field ) IsAutoIncrement () bool {
708
- return f .Tag . Get ( "pk" ) == "autoincr"
796
+ return f .isAutoincrement
709
797
}
710
798
711
799
// IsInverse returns whether the field is an inverse relationship.
@@ -1003,6 +1091,22 @@ func toLowerSnakeCase(s string) string {
1003
1091
return buf .String ()
1004
1092
}
1005
1093
1094
+ // flattenFields will recursively flatten all fields removing the embedded ones
1095
+ // from the field set.
1096
+ func flattenFields (fields []* Field ) []* Field {
1097
+ var result = make ([]* Field , 0 , len (fields ))
1098
+
1099
+ for _ , f := range fields {
1100
+ if f .IsEmbedded && f .Type != BaseModel {
1101
+ result = append (result , flattenFields (f .Fields )... )
1102
+ } else {
1103
+ result = append (result , f )
1104
+ }
1105
+ }
1106
+
1107
+ return result
1108
+ }
1109
+
1006
1110
// Event is the name of an event.
1007
1111
type Event string
1008
1112
0 commit comments