Skip to content

Commit 0977ec4

Browse files
committed
update
1 parent 3457f32 commit 0977ec4

File tree

4 files changed

+9
-19
lines changed

4 files changed

+9
-19
lines changed

pkg/sql/codegen/tensorflow/codegen.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -238,20 +238,10 @@ func Pred(ir *codegen.PredictIR) (string, error) {
238238
}
239239
isKeras, estimatorStr := isKerasModel(ir.TrainIR.Estimator)
240240

241-
resultTableParts := strings.Split(ir.ResultTable, ".")
242-
resultTable := ""
243-
if len(resultTableParts) == 3 {
244-
resultTable = strings.Join(resultTableParts[0:2], ".")
245-
} else if len(resultTableParts) == 2 || len(resultTableParts) == 1 {
246-
resultTable = ir.ResultTable
247-
} else {
248-
return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
249-
}
250-
251241
filler := predFiller{
252242
DataSource: ir.DataSource,
253243
Select: ir.Select,
254-
ResultTable: resultTable,
244+
ResultTable: ir.ResultTable,
255245
Estimator: estimatorStr,
256246
IsKerasModel: isKeras,
257247
FieldMetas: fieldMetas,

pkg/sql/codegen/xgboost/codegen.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
2323

2424
"sqlflow.org/sqlflow/pkg/sql/codegen"
25-
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
2625
)
2726

2827
func newFloat32(f float32) *float32 {

pkg/sql/executor.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import (
2929
pb "sqlflow.org/sqlflow/pkg/server/proto"
3030
"sqlflow.org/sqlflow/pkg/sql/codegen/tensorflow"
3131
"sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
32-
xgb "sqlflow.org/sqlflow/pkg/sql/codegen/xgboost"
3332
)
3433

3534
// Run executes a SQL query and returns a stream of rows or messages
@@ -419,7 +418,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
419418
if err != nil {
420419
return err
421420
}
422-
code, err := xgb.Train(ir)
421+
code, err := xgboost.Train(ir)
423422
if err != nil {
424423
return err
425424
}

pkg/sql/ir_generator.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ func generatePredictIR(slct *extendedSelect, connStr string, cwd string, modelDi
9696
if err != nil {
9797
return nil, err
9898
}
99-
fmt.Printf("select %s, result table: %s\n", slct.standardSelect.String(), slct.into)
10099

101100
resultTable, err := parseResultTable(slct.into)
102101
if err != nil {
@@ -564,15 +563,18 @@ func parseShape(e *expr) ([]int, error) {
564563
return shape, nil
565564
}
566565

566+
// parseResultTable parse out the table name from the INTO statment
567+
// as the following 3 cases:
568+
// db.table.class_col -> db.table # cut the column name
569+
// db.table -> db.table # using the specified db
570+
// table -> table # using the default db
567571
func parseResultTable(intoStatement string) (string, error) {
568572
resultTableParts := strings.Split(intoStatement, ".")
569-
resultTable := ""
570573
if len(resultTableParts) == 3 {
571-
resultTable = strings.Join(resultTableParts[0:2], ".")
574+
return strings.Join(resultTableParts[0:2], "."), nil
572575
} else if len(resultTableParts) == 2 || len(resultTableParts) == 1 {
573-
resultTable = intoStatement
576+
return intoStatement, nil
574577
} else {
575578
return "", fmt.Errorf("error result table format, should be db.table.class_col or db.table or table")
576579
}
577-
return resultTable, nil
578580
}

0 commit comments

Comments
 (0)