Skip to content

Commit 7623acd

Browse files
committed
allow definition of primary keys directly on kallax.Model tags
Signed-off-by: Miguel Molina <[email protected]>
1 parent 7e544e6 commit 7623acd

File tree

7 files changed

+229
-73
lines changed

7 files changed

+229
-73
lines changed

generator/common_test.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010
parseutil "gopkg.in/src-d/go-parse-utils.v1"
1111
)
1212

13-
func mkField(name, typ string, fields ...*Field) *Field {
14-
f := NewField(name, typ, reflect.StructTag(""))
13+
func mkField(name, typ, tag string, fields ...*Field) *Field {
14+
f := NewField(name, typ, reflect.StructTag(tag))
1515
f.SetFields(fields)
1616
return f
1717
}
@@ -46,13 +46,9 @@ func withNode(f *Field, name string, typ types.Type) *Field {
4646
return f
4747
}
4848

49-
func withTag(f *Field, tag string) *Field {
50-
f.Tag = reflect.StructTag(tag)
51-
return f
52-
}
53-
5449
func inline(f *Field) *Field {
55-
return withTag(f, `kallax:",inline"`)
50+
f.Tag = reflect.StructTag(`kallax:",inline"`)
51+
return f
5652
}
5753

5854
func processorFixture(source string) (*Processor, error) {

generator/processor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,11 @@ func (p *Processor) processModel(name string, s *types.Struct, t *types.Named) (
201201
return nil, nil
202202
}
203203

204+
p.processBaseField(m, fields[base])
204205
if err := m.SetFields(fields); err != nil {
205206
return nil, err
206207
}
207208

208-
p.processBaseField(m, fields[base])
209209
return m, nil
210210
}
211211

generator/processor_test.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -384,19 +384,23 @@ func (s *ProcessorSuite) TestIsEmbedded() {
384384
type Bar struct {
385385
kallax.Model
386386
ID int64 ` + "`pk:\"autoincr\"`" + `
387-
Bar string
387+
Baz string
388388
}
389389
390390
type Struct struct {
391-
Bar Bar
391+
Qux Bar
392+
}
393+
394+
type Struct2 struct {
395+
Mux string
392396
}
393397
394398
type Foo struct {
395399
kallax.Model
396400
ID int64 ` + "`pk:\"autoincr\"`" + `
397401
A Bar
398402
B *Bar
399-
Bar
403+
Struct2
400404
*Struct
401405
C struct {
402406
D int
@@ -405,21 +409,16 @@ func (s *ProcessorSuite) TestIsEmbedded() {
405409
`
406410
pkg := s.processFixture(src)
407411
m := findModel(pkg, "Foo")
408-
cases := []struct {
409-
field string
410-
embedded bool
411-
}{
412-
{"Model", true},
413-
{"A", false},
414-
{"B", false},
415-
{"Bar", true},
416-
{"Struct", true},
417-
{"C", false},
412+
expected := []string{
413+
"ID", "Model", "A", "B", "Mux", "Qux", "C",
418414
}
419415

420-
for _, c := range cases {
421-
s.Equal(c.embedded, findField(m, c.field).IsEmbedded, c.field)
416+
var names []string
417+
for _, f := range m.Fields {
418+
names = append(names, f.Name)
422419
}
420+
421+
s.Equal(expected, names)
423422
}
424423

425424
func TestProcessor(t *testing.T) {

generator/template.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,16 @@ func addTemplate(base *template.Template, name string, filename string) *templat
512512
return template.Must(base.New(name).Parse(text))
513513
}
514514

515-
var base *template.Template = makeTemplate("base", "templates/base.tgo")
516-
var schema *template.Template = addTemplate(base, "schema", "templates/schema.tgo")
517-
var model *template.Template = addTemplate(base, "model", "templates/model.tgo")
518-
var query *template.Template = addTemplate(model, "query", "templates/query.tgo")
519-
var resultset *template.Template = addTemplate(model, "resultset", "templates/resultset.tgo")
515+
var (
516+
base = makeTemplate("base", "templates/base.tgo")
517+
schema = addTemplate(base, "schema", "templates/schema.tgo")
518+
model = addTemplate(base, "model", "templates/model.tgo")
519+
query = addTemplate(model, "query", "templates/query.tgo")
520+
resultset = addTemplate(model, "resultset", "templates/resultset.tgo")
521+
)
520522

521523
// Base is the default Template instance with all templates preloaded.
522-
var Base *Template = &Template{template: base}
524+
var Base = &Template{template: base}
523525

524526
const (
525527
// tplFindByCollection is the template of the FindBy autogenerated for
@@ -709,10 +711,9 @@ func shortName(pkg *types.Package, typ types.Type) string {
709711

710712
if specialName, ok := specialTypeShortName(typ); ok {
711713
return prefix + specialName
712-
} else {
713-
shortName := typeString(typ, pkg)
714-
return prefix + strings.Replace(shortName, "*", "", -1)
715714
}
715+
shortName := typeString(typ, pkg)
716+
return prefix + strings.Replace(shortName, "*", "", -1)
716717
}
717718

718719
// isEqualizable returns true if the autogenerated FindBy will use an equal query

generator/types.go

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,9 @@ func (m *Model) CtorRetVars() string {
476476
func (m *Model) SetFields(fields []*Field) error {
477477
var fs []*Field
478478
var id *Field
479-
for _, f := range fields {
479+
for _, f := range flattenFields(fields) {
480480
f.Model = m
481-
if f.IsPrimaryKey() {
481+
if f.IsPrimaryKey() && f.Type != BaseModel {
482482
if id != nil {
483483
return fmt.Errorf(
484484
"kallax: found more than one primary key in model %s: %s and %s",
@@ -489,14 +489,56 @@ func (m *Model) SetFields(fields []*Field) error {
489489
}
490490

491491
id = f
492-
m.ID = f
492+
} else if f.IsPrimaryKey() {
493+
if f.primaryKey == "" {
494+
return fmt.Errorf(
495+
"kallax: primary key defined in %s has no field name, but it must be specified",
496+
f.Name,
497+
)
498+
}
499+
500+
// the pk is defined in the model, we need to collect the model
501+
// and we'll look for the field afterwards, when we have collected
502+
// all fields. The model is appended to the field set, though,
503+
// because it will not act as a primary key.
504+
id = f
505+
fs = append(fs, f)
493506
} else {
494507
fs = append(fs, f)
495508
}
496509
}
497510

511+
// if the id is a Model we need to look for the specified field
512+
if id != nil && id.Type == BaseModel {
513+
for i, f := range fs {
514+
if f.columnName == id.primaryKey {
515+
f.isPrimaryKey = true
516+
f.isAutoincrement = id.isAutoincrement
517+
id = f
518+
519+
if len(fs)-1 == i {
520+
fs = append(fs[:i])
521+
} else {
522+
fs = append(fs[:i], fs[i+1:]...)
523+
}
524+
break
525+
}
526+
}
527+
528+
// If the ID is still a base model, means we did not find the pk
529+
// field.
530+
if id.Type == BaseModel {
531+
return fmt.Errorf(
532+
"kallax: the primary key was supposed to be %s according to the pk definition in %s, but the field could not be found",
533+
id.primaryKey,
534+
id.Name,
535+
)
536+
}
537+
}
538+
498539
if id != nil {
499540
m.Fields = []*Field{id}
541+
m.ID = id
500542
}
501543
m.Fields = append(m.Fields, fs...)
502544
return nil
@@ -556,6 +598,8 @@ func relationshipsOnFields(fields []*Field) []*Field {
556598
return result
557599
}
558600

601+
// ImplicitFK is a foreign key that is defined on just one side of the
602+
// relationship and needs to be added on the other side.
559603
type ImplicitFK struct {
560604
Name string
561605
Type string
@@ -590,6 +634,11 @@ type Field struct {
590634
// A struct is considered embedded if and only if the struct was embedded
591635
// as defined in Go.
592636
IsEmbedded bool
637+
638+
primaryKey string
639+
isPrimaryKey bool
640+
isAutoincrement bool
641+
columnName string
593642
}
594643

595644
// FieldKind is the kind of a field.
@@ -645,13 +694,49 @@ func (t FieldKind) String() string {
645694

646695
// NewField creates a new field with its name, type and struct tag.
647696
func NewField(n, t string, tag reflect.StructTag) *Field {
697+
pkName, autoincr, isPrimaryKey := pkProperties(tag)
698+
648699
return &Field{
649700
Name: n,
650701
Type: t,
651702
Tag: tag,
703+
704+
primaryKey: pkName,
705+
columnName: columnName(n, tag),
706+
isPrimaryKey: isPrimaryKey,
707+
isAutoincrement: autoincr,
652708
}
653709
}
654710

711+
// pkProperties returns the primary key properties from a struct tag.
712+
// Valid primary key definitions are the following:
713+
// - pk:"" -> non-autoincr primary key without a field name.
714+
// - pk:"autoincr" -> autoincr primary key without a field name.
715+
// - pk:"foobar" -> non-autoincr primary key with a field name.
716+
// - pk:"foobar,autoincr" -> autoincr primary key with a field name.
717+
func pkProperties(tag reflect.StructTag) (name string, autoincr, isPrimaryKey bool) {
718+
val, ok := tag.Lookup("pk")
719+
if !ok {
720+
return
721+
}
722+
723+
isPrimaryKey = true
724+
if val == "autoincr" || val == "" {
725+
if val == "autoincr" {
726+
autoincr = true
727+
}
728+
return
729+
}
730+
731+
parts := strings.Split(val, ",")
732+
name = parts[0]
733+
if len(parts) > 1 && parts[1] == "autoincr" {
734+
autoincr = true
735+
}
736+
737+
return
738+
}
739+
655740
// SetFields sets all the children fields and the current field as a parent of
656741
// the children.
657742
func (f *Field) SetFields(sf []*Field) {
@@ -667,16 +752,20 @@ func (f *Field) SetFields(sf []*Field) {
667752
// is the field name converted to lower snake case.
668753
// If the resultant name is a reserved keyword a _ will be prepended to the name.
669754
func (f *Field) ColumnName() string {
670-
name := strings.TrimSpace(strings.Split(f.Tag.Get("kallax"), ",")[0])
671-
if name == "" {
672-
name = toLowerSnakeCase(f.Name)
755+
return f.columnName
756+
}
757+
758+
func columnName(name string, tag reflect.StructTag) string {
759+
n := strings.TrimSpace(strings.Split(tag.Get("kallax"), ",")[0])
760+
if n == "" {
761+
n = toLowerSnakeCase(name)
673762
}
674763

675-
if _, ok := reservedKeywords[strings.ToLower(name)]; ok {
676-
name = "_" + name
764+
if _, ok := reservedKeywords[strings.ToLower(n)]; ok {
765+
n = "_" + n
677766
}
678767

679-
return name
768+
return n
680769
}
681770

682771
// ForeignKey returns the name of the foreign keys as specified in the struct
@@ -699,13 +788,12 @@ func (f *Field) ForeignKey() string {
699788

700789
// IsPrimaryKey reports whether the field is the primary key.
701790
func (f *Field) IsPrimaryKey() bool {
702-
_, ok := f.Tag.Lookup("pk")
703-
return ok
791+
return f.isPrimaryKey
704792
}
705793

706794
// IsAutoIncrement reports whether the field is an autoincrementable primary key.
707795
func (f *Field) IsAutoIncrement() bool {
708-
return f.Tag.Get("pk") == "autoincr"
796+
return f.isAutoincrement
709797
}
710798

711799
// IsInverse returns whether the field is an inverse relationship.
@@ -1003,6 +1091,22 @@ func toLowerSnakeCase(s string) string {
10031091
return buf.String()
10041092
}
10051093

1094+
// flattenFields will recursively flatten all fields removing the embedded ones
1095+
// from the field set.
1096+
func flattenFields(fields []*Field) []*Field {
1097+
var result = make([]*Field, 0, len(fields))
1098+
1099+
for _, f := range fields {
1100+
if f.IsEmbedded && f.Type != BaseModel {
1101+
result = append(result, flattenFields(f.Fields)...)
1102+
} else {
1103+
result = append(result, f)
1104+
}
1105+
}
1106+
1107+
return result
1108+
}
1109+
10061110
// Event is the name of an event.
10071111
type Event string
10081112

0 commit comments

Comments
 (0)