diff --git a/generator.go b/generator.go index 6871193e..c9441f4b 100644 --- a/generator.go +++ b/generator.go @@ -340,6 +340,17 @@ func (g *Generator) generateQueryFile() (err error) { return err } + if g.judgeMode(WithQueryInterface) { + queryInterfaceTmpl := tmpl.QueryInterfaceWithContext + if g.judgeMode(WithoutContext) { + queryInterfaceTmpl = tmpl.QueryInterfaceWithoutContext + } + err = render(queryInterfaceTmpl, &buf, g) + if err != nil { + return err + } + } + err = g.output(g.OutFile, buf.Bytes()) if err != nil { return err diff --git a/internal/template/query.go b/internal/template/query.go index 390eb635..9d62d3e7 100644 --- a/internal/template/query.go +++ b/internal/template/query.go @@ -111,6 +111,42 @@ func (q *QueryTx) RollbackTo(name string) error { ` +// QueryInterfaceWithContext interface method template with context +const QueryInterfaceWithContext = ` +var _ IQuery = (*Query)(nil) + +type IQuery interface{ + {{range $name,$d :=.Data -}} + {{$d.ModelStructName}}Do(ctx context.Context) I{{$d.ModelStructName}}Do + {{end -}} +} + +{{range $name,$d :=.Data -}} +func (q *Query) {{$d.ModelStructName}}Do(ctx context.Context) I{{$d.ModelStructName}}Do { + return q.{{$d.ModelStructName}}.WithContext(ctx) +} + +{{end -}} +` + +// QueryInterfaceWithoutContext interface method template without context +const QueryInterfaceWithoutContext = ` +var _ IQuery = (*Query)(nil) + +type IQuery interface{ + {{range $name,$d :=.Data -}} + {{$d.ModelStructName}}Do() I{{$d.ModelStructName}}Do + {{end -}} +} + +{{range $name,$d :=.Data -}} +func (q *Query) {{$d.ModelStructName}}Do() I{{$d.ModelStructName}}Do { + return &q.{{$d.ModelStructName}}.{{$d.QueryStructName}}Do +} + +{{end -}} +` + // QueryMethodTest query method test template const QueryMethodTest = ` diff --git a/tests/.expect/dal_3/query/gen.go b/tests/.expect/dal_3/query/gen.go index 68fa07c1..e158be69 100644 --- a/tests/.expect/dal_3/query/gen.go +++ b/tests/.expect/dal_3/query/gen.go @@ -133,3 +133,33 @@ func (q *QueryTx) SavePoint(name string) error { func (q *QueryTx) RollbackTo(name string) error { return q.db.RollbackTo(name).Error } + +var _ IQuery = (*Query)(nil) + +type IQuery interface { + BankDo(ctx context.Context) IBankDo + CreditCardDo(ctx context.Context) ICreditCardDo + CustomerDo(ctx context.Context) ICustomerDo + PersonDo(ctx context.Context) IPersonDo + UserDo(ctx context.Context) IUserDo +} + +func (q *Query) BankDo(ctx context.Context) IBankDo { + return q.Bank.WithContext(ctx) +} + +func (q *Query) CreditCardDo(ctx context.Context) ICreditCardDo { + return q.CreditCard.WithContext(ctx) +} + +func (q *Query) CustomerDo(ctx context.Context) ICustomerDo { + return q.Customer.WithContext(ctx) +} + +func (q *Query) PersonDo(ctx context.Context) IPersonDo { + return q.Person.WithContext(ctx) +} + +func (q *Query) UserDo(ctx context.Context) IUserDo { + return q.User.WithContext(ctx) +} diff --git a/tests/.expect/dal_4/query/gen.go b/tests/.expect/dal_4/query/gen.go index 68fa07c1..e158be69 100644 --- a/tests/.expect/dal_4/query/gen.go +++ b/tests/.expect/dal_4/query/gen.go @@ -133,3 +133,33 @@ func (q *QueryTx) SavePoint(name string) error { func (q *QueryTx) RollbackTo(name string) error { return q.db.RollbackTo(name).Error } + +var _ IQuery = (*Query)(nil) + +type IQuery interface { + BankDo(ctx context.Context) IBankDo + CreditCardDo(ctx context.Context) ICreditCardDo + CustomerDo(ctx context.Context) ICustomerDo + PersonDo(ctx context.Context) IPersonDo + UserDo(ctx context.Context) IUserDo +} + +func (q *Query) BankDo(ctx context.Context) IBankDo { + return q.Bank.WithContext(ctx) +} + +func (q *Query) CreditCardDo(ctx context.Context) ICreditCardDo { + return q.CreditCard.WithContext(ctx) +} + +func (q *Query) CustomerDo(ctx context.Context) ICustomerDo { + return q.Customer.WithContext(ctx) +} + +func (q *Query) PersonDo(ctx context.Context) IPersonDo { + return q.Person.WithContext(ctx) +} + +func (q *Query) UserDo(ctx context.Context) IUserDo { + return q.User.WithContext(ctx) +} diff --git a/tests/.expect/dal_5/model/users.gen.go b/tests/.expect/dal_5/model/users.gen.go index d65d4f54..103aa298 100644 --- a/tests/.expect/dal_5/model/users.gen.go +++ b/tests/.expect/dal_5/model/users.gen.go @@ -14,12 +14,17 @@ const TableNameUser = "users" type User struct { ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` CreatedAt *time.Time `gorm:"column:created_at" json:"-"` - Name *string `gorm:"column:name;index:idx_name,priority:1;comment:oneline" json:"-"` + Name *string `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1;comment:oneline" json:"-"` // oneline Address *string `gorm:"column:address" json:"-"` RegisterTime *time.Time `gorm:"column:register_time" json:"-"` - Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` - CompanyID *int64 `gorm:"column:company_id;default:666" json:"-"` - PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` + /* + multiline + line1 + line2 + */ + Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` + CompanyID *int64 `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"` + PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` } func (m *User) IsEmpty() bool { diff --git a/tests/.expect/dal_5/query/gen.go b/tests/.expect/dal_5/query/gen.go index a19d931b..6b948dc7 100644 --- a/tests/.expect/dal_5/query/gen.go +++ b/tests/.expect/dal_5/query/gen.go @@ -101,3 +101,13 @@ func (q *QueryTx) SavePoint(name string) error { func (q *QueryTx) RollbackTo(name string) error { return q.db.RollbackTo(name).Error } + +var _ IQuery = (*Query)(nil) + +type IQuery interface { + UserDo(ctx context.Context) IUserDo +} + +func (q *Query) UserDo(ctx context.Context) IUserDo { + return q.User.WithContext(ctx) +} diff --git a/tests/.expect/dal_5/query/users.gen.go b/tests/.expect/dal_5/query/users.gen.go index 3d2ffb34..c2c456f1 100644 --- a/tests/.expect/dal_5/query/users.gen.go +++ b/tests/.expect/dal_5/query/users.gen.go @@ -94,6 +94,8 @@ func (u user) TableName() string { return u.userDo.TableName() } func (u user) Alias() string { return u.userDo.Alias() } +func (u user) Columns(cols ...field.Expr) gen.Columns { return u.userDo.Columns(cols...) } + func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) { _f, ok := u.fieldMap[fieldName] if !ok || _f == nil { diff --git a/tests/.expect/dal_6/model/users.gen.go b/tests/.expect/dal_6/model/users.gen.go index d65d4f54..103aa298 100644 --- a/tests/.expect/dal_6/model/users.gen.go +++ b/tests/.expect/dal_6/model/users.gen.go @@ -14,12 +14,17 @@ const TableNameUser = "users" type User struct { ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` CreatedAt *time.Time `gorm:"column:created_at" json:"-"` - Name *string `gorm:"column:name;index:idx_name,priority:1;comment:oneline" json:"-"` + Name *string `gorm:"column:name;index:idx_name,priority:1;index:idx_name_company_id,priority:1;comment:oneline" json:"-"` // oneline Address *string `gorm:"column:address" json:"-"` RegisterTime *time.Time `gorm:"column:register_time" json:"-"` - Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` - CompanyID *int64 `gorm:"column:company_id;default:666" json:"-"` - PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` + /* + multiline + line1 + line2 + */ + Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` + CompanyID *int64 `gorm:"column:company_id;index:idx_name_company_id,priority:2;default:666" json:"-"` + PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` } func (m *User) IsEmpty() bool { diff --git a/tests/.expect/dal_6/query/gen.go b/tests/.expect/dal_6/query/gen.go index a19d931b..6b948dc7 100644 --- a/tests/.expect/dal_6/query/gen.go +++ b/tests/.expect/dal_6/query/gen.go @@ -101,3 +101,13 @@ func (q *QueryTx) SavePoint(name string) error { func (q *QueryTx) RollbackTo(name string) error { return q.db.RollbackTo(name).Error } + +var _ IQuery = (*Query)(nil) + +type IQuery interface { + UserDo(ctx context.Context) IUserDo +} + +func (q *Query) UserDo(ctx context.Context) IUserDo { + return q.User.WithContext(ctx) +} diff --git a/tests/.expect/dal_6/query/users.gen.go b/tests/.expect/dal_6/query/users.gen.go index f79fff04..a44672e0 100644 --- a/tests/.expect/dal_6/query/users.gen.go +++ b/tests/.expect/dal_6/query/users.gen.go @@ -94,6 +94,8 @@ func (u user) TableName() string { return u.userDo.TableName() } func (u user) Alias() string { return u.userDo.Alias() } +func (u user) Columns(cols ...field.Expr) gen.Columns { return u.userDo.Columns(cols...) } + func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) { _f, ok := u.fieldMap[fieldName] if !ok || _f == nil { diff --git a/tests/go.mod b/tests/go.mod index e9e98943..c6d49b37 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -4,11 +4,10 @@ go 1.16 require ( github.com/mattn/go-sqlite3 v1.14.16 // indirect - golang.org/x/tools v0.5.0 // indirect - gorm.io/driver/mysql v1.4.5 + gorm.io/driver/mysql v1.5.1-0.20230509030346-3715c134c25b gorm.io/driver/sqlite v1.4.4 gorm.io/gen v0.3.19 - gorm.io/gorm v1.24.3 + gorm.io/gorm v1.25.1-0.20230505075827-e61b98d69677 gorm.io/hints v1.1.1 // indirect gorm.io/plugin/dbresolver v1.4.0 )