diff --git a/config.go b/config.go index f2161b8f..e1efb992 100644 --- a/config.go +++ b/config.go @@ -33,6 +33,7 @@ type Config struct { OutPath string // query code path OutFile string // query code file name, default: gen.go ModelPkgPath string // generated model code's package name + QueryPkgPath string // generated model code's package name WithUnitTest bool // generate unit test for query code // generate model global configuration @@ -120,6 +121,9 @@ func (cfg *Config) Revise() (err error) { if strings.TrimSpace(cfg.ModelPkgPath) == "" { cfg.ModelPkgPath = model.DefaultModelPkg } + if strings.TrimSpace(cfg.QueryPkgPath) == "" { + cfg.QueryPkgPath = model.DefaultQueryPkg + } cfg.OutPath, err = filepath.Abs(cfg.OutPath) if err != nil { @@ -129,11 +133,11 @@ func (cfg *Config) Revise() (err error) { cfg.OutPath = fmt.Sprintf(".%squery%s", string(os.PathSeparator), string(os.PathSeparator)) } if cfg.OutFile == "" { - cfg.OutFile = filepath.Join(cfg.OutPath, "gen.go") + cfg.OutFile = filepath.Join(cfg.OutPath, cfg.QueryPkgPath, "gen.go") } else if !strings.Contains(cfg.OutFile, string(os.PathSeparator)) { - cfg.OutFile = filepath.Join(cfg.OutPath, cfg.OutFile) + cfg.OutFile = filepath.Join(cfg.OutPath, cfg.QueryPkgPath, cfg.OutFile) } - cfg.queryPkgName = filepath.Base(cfg.OutPath) + cfg.queryPkgName = filepath.Base(cfg.QueryPkgPath) if cfg.db == nil { cfg.db, _ = gorm.Open(tests.DummyDialector{}) diff --git a/generator.go b/generator.go index 6f09bf59..b7ad794b 100644 --- a/generator.go +++ b/generator.go @@ -289,8 +289,9 @@ func (g *Generator) generateQueryFile() (err error) { return nil } - if err = os.MkdirAll(g.OutPath, os.ModePerm); err != nil { - return fmt.Errorf("make dir outpath(%s) fail: %s", g.OutPath, err) + queryOutPath := g.getQueryOutputPath() + if err = os.MkdirAll(queryOutPath, os.ModePerm); err != nil { + return fmt.Errorf("create query pkg path(%s) fail: %s", queryOutPath, err) } errChan := make(chan error) @@ -379,6 +380,10 @@ func (g *Generator) generateQueryFile() (err error) { return nil } +func (g *Generator) getQueryOutputPath() (outPath string) { + return filepath.Join(g.OutPath, g.QueryPkgPath) + string(os.PathSeparator) +} + // generateSingleQueryFile generate query code and save to file func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) { var buf bytes.Buffer @@ -425,8 +430,10 @@ func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) { return err } - defer g.info(fmt.Sprintf("generate query file: %s%s%s.gen.go", g.OutPath, string(os.PathSeparator), data.FileName)) - return g.output(fmt.Sprintf("%s%s%s.gen.go", g.OutPath, string(os.PathSeparator), data.FileName), buf.Bytes()) + outputPath := filepath.Join(g.OutPath, g.QueryPkgPath) + + defer g.info(fmt.Sprintf("generate query file: %s%s%s.gen.go", outputPath, string(os.PathSeparator), data.FileName)) + return g.output(fmt.Sprintf("%s%s%s.gen.go", outputPath, string(os.PathSeparator), data.FileName), buf.Bytes()) } // generateQueryUnitTestFile generate unit test file for query @@ -457,8 +464,10 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) { } } - defer g.info(fmt.Sprintf("generate unit test file: %s%s%s.gen_test.go", g.OutPath, string(os.PathSeparator), data.FileName)) - return g.output(fmt.Sprintf("%s%s%s.gen_test.go", g.OutPath, string(os.PathSeparator), data.FileName), buf.Bytes()) + outputPath := filepath.Join(g.OutPath, g.QueryPkgPath) + + defer g.info(fmt.Sprintf("generate unit test file: %s%s%s.gen_test.go", outputPath, string(os.PathSeparator), data.FileName)) + return g.output(fmt.Sprintf("%s%s%s.gen_test.go", outputPath, string(os.PathSeparator), data.FileName), buf.Bytes()) } // generateModelFile generate model structures and save to file @@ -467,12 +476,8 @@ func (g *Generator) generateModelFile() error { return nil } - modelOutPath, err := g.getModelOutputPath() - if err != nil { - return err - } - - if err = os.MkdirAll(modelOutPath, os.ModePerm); err != nil { + modelOutPath := g.getModelOutputPath() + if err := os.MkdirAll(modelOutPath, os.ModePerm); err != nil { return fmt.Errorf("create model pkg path(%s) fail: %s", modelOutPath, err) } @@ -512,7 +517,7 @@ func (g *Generator) generateModelFile() error { }(data) } select { - case err = <-errChan: + case err := <-errChan: return err case <-pool.AsyncWaitAll(): g.fillModelPkgPath(modelOutPath) @@ -520,16 +525,8 @@ func (g *Generator) generateModelFile() error { return nil } -func (g *Generator) getModelOutputPath() (outPath string, err error) { - if strings.Contains(g.ModelPkgPath, string(os.PathSeparator)) { - outPath, err = filepath.Abs(g.ModelPkgPath) - if err != nil { - return "", fmt.Errorf("cannot parse model pkg path: %w", err) - } - } else { - outPath = filepath.Join(filepath.Dir(g.OutPath), g.ModelPkgPath) - } - return outPath + string(os.PathSeparator), nil +func (g *Generator) getModelOutputPath() (outPath string) { + return filepath.Join(g.OutPath, g.ModelPkgPath) + string(os.PathSeparator) } func (g *Generator) fillModelPkgPath(filePath string) { diff --git a/generator_test.go b/generator_test.go index 1ab13f2c..4c0a8cb2 100644 --- a/generator_test.go +++ b/generator_test.go @@ -22,8 +22,7 @@ func TestConfig(t *testing.T) { OutFile: "", ModelPkgPath: "models", - - queryPkgName: "query", + QueryPkgPath: "query", } } diff --git a/internal/model/base.go b/internal/model/base.go index e5cdcfd3..bd809f6c 100644 --- a/internal/model/base.go +++ b/internal/model/base.go @@ -10,6 +10,8 @@ import ( const ( // DefaultModelPkg ... DefaultModelPkg = "model" + // DefaultQueryPkg ... + DefaultQueryPkg = "query" ) // Status sql status diff --git a/tests/generate_test.go b/tests/generate_test.go index 9b1bec22..8095f5ac 100644 --- a/tests/generate_test.go +++ b/tests/generate_test.go @@ -25,7 +25,7 @@ var _ = os.Setenv("GORM_DIALECT", "mysql") var generateCase = map[string]func(dir string) *gen.Generator{ generateDirPrefix + "dal_1": func(dir string) *gen.Generator { g := gen.NewGenerator(gen.Config{ - OutPath: dir + "/query", + OutPath: dir, Mode: gen.WithDefaultQuery, }) g.UseDB(DB) @@ -34,7 +34,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{ }, generateDirPrefix + "dal_2": func(dir string) *gen.Generator { g := gen.NewGenerator(gen.Config{ - OutPath: dir + "/query", + OutPath: dir, Mode: gen.WithDefaultQuery, WithUnitTest: true, @@ -50,7 +50,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{ }, generateDirPrefix + "dal_3": func(dir string) *gen.Generator { g := gen.NewGenerator(gen.Config{ - OutPath: dir + "/query", + OutPath: dir, Mode: gen.WithDefaultQuery | gen.WithQueryInterface, WithUnitTest: true, @@ -70,7 +70,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{ }, generateDirPrefix + "dal_4": func(dir string) *gen.Generator { g := gen.NewGenerator(gen.Config{ - OutPath: dir + "/query", + OutPath: dir, Mode: gen.WithDefaultQuery | gen.WithQueryInterface, WithUnitTest: true, @@ -88,7 +88,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{ }, generateDirPrefix + "dal_5": func(dir string) *gen.Generator { g := gen.NewGenerator(gen.Config{ - OutPath: dir + "/query", + OutPath: dir, Mode: gen.WithDefaultQuery | gen.WithQueryInterface, WithUnitTest: true, @@ -104,7 +104,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{ }, generateDirPrefix + "dal_6": func(dir string) *gen.Generator { g := gen.NewGenerator(gen.Config{ - OutPath: dir + "/query", + OutPath: dir, Mode: gen.WithDefaultQuery | gen.WithQueryInterface, WithUnitTest: true, diff --git a/tools/gentool/README.ZH_CN.md b/tools/gentool/README.ZH_CN.md index 9b3a15b0..6bf1c8c5 100644 --- a/tools/gentool/README.ZH_CN.md +++ b/tools/gentool/README.ZH_CN.md @@ -13,9 +13,9 @@ ## 使用方式 ```shell - - gentool -h - + + gentool -h + Usage of gentool: -db string input mysql or postgres or sqlite or sqlserver. consult[https://gorm.io/docs/connecting_to_the_database.html] (default "mysql") @@ -30,11 +30,13 @@ -fieldWithTypeTag generate field with gorm column type tag -modelPkgName string - generated model code's package name + generated model code's package name (default "model") + -queryPkgName string + generated query code's package name (default "query") -outFile string query code file name, default: gen.go -outPath string - specify a directory for output (default "./dao/query") + specify a directory for output (default "./dao") -tables string enter the required data table or leave it blank -onlyModel @@ -84,9 +86,17 @@ default "" #### modelPkgName -默认值是数据表名称。 +默认为:model + + 生成的model代码的包名称。 + 设置“outPath”后的路径。 + +#### queryPkgName + +默认为:query 生成的model代码的包名称。 + 设置“outPath”后的路径。 #### outFile diff --git a/tools/gentool/README.md b/tools/gentool/README.md index a8543b70..8568c130 100644 --- a/tools/gentool/README.md +++ b/tools/gentool/README.md @@ -28,11 +28,13 @@ Install GEN as a binary tool -fieldWithTypeTag generate field with gorm column type tag -modelPkgName string - generated model code's package name + generated model code's package name (default "model") + -queryPkgName string + generated query code's package name (default "query") -outFile string query code file name, default: gen.go -outPath string - specify a directory for output (default "./dao/query") + specify a directory for output (default "./dao") -tables string enter the required data table or leave it blank -onlyModel @@ -82,9 +84,17 @@ generate field with gorm column type tag #### modelPkgName -defalut table name. +default "model" generated model code's package name. + set the path after "outPath". + +#### queryPkgName + +default "query" + + generated query code's package name. + set the path after "outPath". #### outFile @@ -92,7 +102,7 @@ defalut table name. #### outPath -specify a directory for output (default "./dao/query") +specify a directory for output (default "./dao") #### tables diff --git a/tools/gentool/gen.yml b/tools/gentool/gen.yml index 5fa00f07..b778a60b 100644 --- a/tools/gentool/gen.yml +++ b/tools/gentool/gen.yml @@ -13,13 +13,15 @@ database: # only generate models (without query file) onlyModel : false # specify a directory for output - outPath : "./dao/query" + outPath : "./dao" # query code file name, default: gen.go outFile : "" # generate unit test for query code withUnitTest : false # generated model code's package name - modelPkgName : "" + modelPkgName : "model" + # generated query code's package name + queryPkgName : "query" # generate with pointer when field is nullable fieldNullable : false # generate with pointer when field has default value diff --git a/tools/gentool/gentool.go b/tools/gentool/gentool.go index 3d438867..2c4f7119 100644 --- a/tools/gentool/gentool.go +++ b/tools/gentool/gentool.go @@ -29,7 +29,7 @@ const ( dbClickHouse DBType = "clickhouse" ) const ( - defaultQueryPath = "./dao/query" + defaultQueryPath = "./dao" ) // CmdParams is command line parameters @@ -42,6 +42,7 @@ type CmdParams struct { OutFile string `yaml:"outFile"` // query code file name, default: gen.go WithUnitTest bool `yaml:"withUnitTest"` // generate unit test for query code ModelPkgName string `yaml:"modelPkgName"` // generated model code's package name + QueryPkgName string `yaml:"queryPkgName"` // generated query code's package name FieldNullable bool `yaml:"fieldNullable"` // generate with pointer when field is nullable FieldCoverable bool `yaml:"fieldCoverable"` // generate with pointer when field has default value FieldWithIndexTag bool `yaml:"fieldWithIndexTag"` // generate field with gorm index tag @@ -149,6 +150,7 @@ func argParse() *CmdParams { outFile := flag.String("outFile", "", "query code file name, default: gen.go") withUnitTest := flag.Bool("withUnitTest", false, "generate unit test for query code") modelPkgName := flag.String("modelPkgName", "", "generated model code's package name") + queryPkgName := flag.String("queryPkgName", "", "generated query code's package name") fieldNullable := flag.Bool("fieldNullable", false, "generate with pointer when field is nullable") fieldCoverable := flag.Bool("fieldCoverable", false, "generate with pointer when field has default value") fieldWithIndexTag := flag.Bool("fieldWithIndexTag", false, "generate field with gorm index tag") @@ -186,6 +188,9 @@ func argParse() *CmdParams { if *modelPkgName != "" { cmdParse.ModelPkgName = *modelPkgName } + if *queryPkgName != "" { + cmdParse.QueryPkgName = *queryPkgName + } if *fieldNullable { cmdParse.FieldNullable = *fieldNullable } @@ -220,6 +225,7 @@ func main() { OutPath: config.OutPath, OutFile: config.OutFile, ModelPkgPath: config.ModelPkgName, + QueryPkgPath: config.QueryPkgName, WithUnitTest: config.WithUnitTest, FieldNullable: config.FieldNullable, FieldCoverable: config.FieldCoverable,