Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions cmd/sqlflowserver/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func AssertGreaterEqualAny(a *assert.Assertions, actual *any.Any, expected inter
case "type.googleapis.com/google.protobuf.FloatValue":
b := wrappers.FloatValue{}
ptypes.UnmarshalAny(actual, &b)
a.GreaterOrEqual(float32(expected.(float64)), b.Value)
a.GreaterOrEqual(b.Value, float32(expected.(float64)))
}
}

Expand Down Expand Up @@ -189,7 +189,10 @@ func prepareTestData(dbStr string) error {
if err := testdata.Popularize(testDB.DB, testdata.IrisHiveSQL); err != nil {
return err
}
return testdata.Popularize(testDB.DB, testdata.ChurnHiveSQL)
if err = testdata.Popularize(testDB.DB, testdata.ChurnHiveSQL); err != nil {
return err
}
return testdata.Popularize(testDB.DB, testdata.HousingSQL)
case "maxcompute":
submitter := os.Getenv("SQLFLOW_submitter")
if submitter == "alps" {
Expand Down Expand Up @@ -381,6 +384,38 @@ func TestEnd2EndHive(t *testing.T) {
t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
}

// TestEnd2EndHiveIR exercises the end-to-end SQLFlow cases against a Hive
// backend with the experimental IR code generation path enabled. The test is
// skipped unless SQLFLOW_codegen=ir and SQLFLOW_TEST_DB=hive are both set.
func TestEnd2EndHiveIR(t *testing.T) {
	if os.Getenv("SQLFLOW_codegen") != "ir" {
		t.Skip("Skipping ir test")
	}
	if os.Getenv("SQLFLOW_TEST_DB") != "hive" {
		t.Skip("Skipping hive tests")
	}

	modelDir := ""
	tmpDir, caCrt, caKey, err := generateTempCA()
	// Register cleanup before the error check so the temporary CA directory
	// is removed even when a later step calls t.Fatalf (defers still run).
	defer os.RemoveAll(tmpDir)
	if err != nil {
		t.Fatalf("failed to generate CA pair %v", err)
	}

	dbConnStr = "hive://root:[email protected]:10000/iris?auth=NOSASL"
	// Serve in the background; waitPortReady blocks until the port answers.
	go start("", modelDir, caCrt, caKey, true, unitestPort)
	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
	if err = prepareTestData(dbConnStr); err != nil {
		t.Fatalf("prepare test dataset failed: %v", err)
	}

	t.Run("TestShowDatabases", CaseShowDatabases)
	t.Run("TestSelect", CaseSelect)
	t.Run("TestTrainSQL", CaseTrainSQL)
	t.Run("CaseTrainCustomModel", CaseTrainCustomModel)
	t.Run("CaseTrainDeepWideModel", CaseTrainDeepWideModel)
	t.Run("CaseTrainXGBoostRegression", CaseTrainXGBoostRegression)
	t.Run("CasePredictXGBoostRegression", CasePredictXGBoostRegression)
}

func TestEnd2EndMaxCompute(t *testing.T) {
testDBDriver := os.Getenv("SQLFLOW_TEST_DB")
modelDir, _ := ioutil.TempDir("/tmp", "sqlflow_ssl_")
Expand Down
8 changes: 7 additions & 1 deletion pkg/sql/codegen/tensorflow/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"
"text/template"

pb "sqlflow.org/sqlflow/pkg/server/proto"
"sqlflow.org/sqlflow/pkg/sql/codegen"
)

Expand Down Expand Up @@ -195,6 +196,7 @@ func Train(ir *codegen.TrainIR) (string, error) {
ModelParams: modelParams,
TrainParams: trainParams,
Save: "model_save", // TODO(typhoonzero): executor.go will save the working directory, should test later.

}
var program bytes.Buffer
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
Expand All @@ -210,7 +212,7 @@ func Train(ir *codegen.TrainIR) (string, error) {
}

// Pred generates a Python program for predict using a TensorFlow model.
func Pred(ir *codegen.PredictIR) (string, error) {
func Pred(ir *codegen.PredictIR, session *pb.Session) (string, error) {
modelParams := make(map[string]interface{})
for attrKey, attr := range ir.TrainIR.Attributes {
if strings.HasPrefix(attrKey, "model.") {
Expand Down Expand Up @@ -249,6 +251,10 @@ func Pred(ir *codegen.PredictIR) (string, error) {
Y: ir.TrainIR.Label.GetFieldMeta()[0],
ModelParams: modelParams,
Save: "model_save",
HDFSNameNodeAddr: session.HdfsNamenodeAddr,
HiveLocation: session.HiveLocation,
HDFSUser: session.HdfsUser,
HDFSPass: session.HdfsPass,
}
var program bytes.Buffer
var predTemplate = template.Must(template.New("Pred").Funcs(template.FuncMap{
Expand Down
41 changes: 33 additions & 8 deletions pkg/sql/codegen/tensorflow/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,43 @@ package tensorflow

import (
"fmt"
"regexp"
"testing"

"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
pb "sqlflow.org/sqlflow/pkg/server/proto"
"sqlflow.org/sqlflow/pkg/sql/codegen"
)

// TestTrainCodegen verifies that the TensorFlow code generator produces a
// training program without error, and that the generated prediction program
// carries the Hive/HDFS connection settings taken from the session.
func TestTrainCodegen(t *testing.T) {
	a := assert.New(t)
	tir := mockTrainIR()
	_, err := Train(tir)
	a.NoError(err)

	pir := mockPredIR(tir)

	sess := &pb.Session{
		Token:            "",
		DbConnStr:        "",
		ExitOnSubmit:     false,
		UserId:           "",
		HiveLocation:     "/sqlflowtmp",
		HdfsNamenodeAddr: "192.168.1.1:8020",
		HdfsUser:         "sqlflow_admin",
		HdfsPass:         "sqlflow_pass",
	}
	code, err := Pred(pir, sess)
	a.NoError(err)

	// MustCompile is appropriate for literal patterns; FindStringSubmatch
	// returns nil when there is no match, so assert the match exists before
	// indexing into it instead of risking an index-out-of-range panic.
	r := regexp.MustCompile(`hdfs_user="(.*)"`)
	m := r.FindStringSubmatch(code)
	if a.Len(m, 2) {
		a.Equal("sqlflow_admin", m[1])
	}
	r = regexp.MustCompile(`hdfs_pass="(.*)"`)
	m = r.FindStringSubmatch(code)
	if a.Len(m, 2) {
		a.Equal("sqlflow_pass", m[1])
	}
}

func mockTrainIR() *codegen.TrainIR {
cfg := &mysql.Config{
User: "root",
Passwd: "root",
Expand All @@ -42,7 +69,7 @@ func TestTrainCodegen(t *testing.T) {
COLUMN sepal_length, sepal_width, petal_length, petal_width
LABEL class
INTO sqlflow_models.my_xgboost_model;`
ir := codegen.TrainIR{
return &codegen.TrainIR{
DataSource: fmt.Sprintf("mysql://%s", cfg.FormatDSN()),
Select: "select * from iris.train;",
ValidationSelect: "select * from iris.test;",
Expand All @@ -59,16 +86,14 @@ func TestTrainCodegen(t *testing.T) {
&codegen.NumericColumn{&codegen.FieldMeta{"petal_length", codegen.Float, "", []int{1}, false, nil}},
&codegen.NumericColumn{&codegen.FieldMeta{"petal_width", codegen.Float, "", []int{1}, false, nil}}}},
Label: &codegen.NumericColumn{&codegen.FieldMeta{"class", codegen.Int, "", []int{1}, false, nil}}}
_, err := Train(&ir)
a.NoError(err)
}

predIR := codegen.PredictIR{
DataSource: fmt.Sprintf("mysql://%s", cfg.FormatDSN()),
func mockPredIR(trainIR *codegen.TrainIR) *codegen.PredictIR {
return &codegen.PredictIR{
DataSource: trainIR.DataSource,
Select: "select * from iris.test;",
ResultTable: "iris.predict",
Attributes: make(map[string]interface{}),
TrainIR: &ir,
TrainIR: trainIR,
}
_, err = Pred(&predIR)
a.NoError(err)
}
10 changes: 9 additions & 1 deletion pkg/sql/codegen/tensorflow/template_pred.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type predFiller struct {
Y *codegen.FieldMeta
ModelParams map[string]interface{}
Save string
HDFSNameNodeAddr string
HiveLocation string
HDFSUser string
HDFSPass string
}

const tfPredTemplateText = `
Expand Down Expand Up @@ -78,5 +82,9 @@ pred(is_keras_model="{{.IsKerasModel}}" == "true",
label_meta=label_meta,
model_params=model_params,
save="{{.Save}}",
batch_size=1)
batch_size=1,
hdfs_namenode_addr="{{.HDFSNameNodeAddr}}",
hive_location="{{.HiveLocation}}",
hdfs_user="{{.HDFSUser}}",
hdfs_pass="{{.HDFSPass}}")
`
20 changes: 12 additions & 8 deletions pkg/sql/codegen/xgboost/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ import (
"fmt"
"strings"

"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"

pb "sqlflow.org/sqlflow/pkg/server/proto"
"sqlflow.org/sqlflow/pkg/sql/codegen"
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
)

func newFloat32(f float32) *float32 {
Expand Down Expand Up @@ -151,7 +151,7 @@ func Train(ir *codegen.TrainIR) (string, error) {
}

// Pred generates a Python program for predict a xgboost model.
func Pred(ir *codegen.PredictIR) (string, error) {
func Pred(ir *codegen.PredictIR, session *pb.Session) (string, error) {
featureFieldMeta, labelFieldMeta, err := getFieldMeta(ir.TrainIR.Features["feature_columns"], ir.TrainIR.Label)
f, err := json.Marshal(featureFieldMeta)
if err != nil {
Expand All @@ -163,11 +163,15 @@ func Pred(ir *codegen.PredictIR) (string, error) {
}

r := predFiller{
DataSource: ir.DataSource,
PredSelect: ir.Select,
FeatureMetaJSON: string(f),
LabelMetaJSON: string(l),
ResultTable: ir.ResultTable,
DataSource: ir.DataSource,
PredSelect: ir.Select,
FeatureMetaJSON: string(f),
LabelMetaJSON: string(l),
ResultTable: ir.ResultTable,
HDFSNameNodeAddr: session.HdfsNamenodeAddr,
HiveLocation: session.HiveLocation,
HDFSUser: session.HdfsUser,
HDFSPass: session.HdfsPass,
}

var program bytes.Buffer
Expand Down
20 changes: 19 additions & 1 deletion pkg/sql/codegen/xgboost/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ package xgboost

import (
"fmt"
"regexp"
"testing"

"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
pb "sqlflow.org/sqlflow/pkg/server/proto"
"sqlflow.org/sqlflow/pkg/sql/codegen"
)

Expand All @@ -29,7 +31,23 @@ func TestTrainAndPredict(t *testing.T) {
a.NoError(err)

pir := mockPrdcIR(tir)
_, err = Pred(pir)
sess := &pb.Session{
Token: "",
DbConnStr: "",
ExitOnSubmit: false,
UserId: "",
HiveLocation: "/sqlflowtmp",
HdfsNamenodeAddr: "192.168.1.1:8020",
HdfsUser: "sqlflow_admin",
HdfsPass: "sqlflow_pass",
}
code, err := Pred(pir, sess)

r, _ := regexp.Compile(`hdfs_user="(.*)"`)
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_admin")
r, _ = regexp.Compile(`hdfs_pass="(.*)"`)
a.Equal(r.FindStringSubmatch(code)[1], "sqlflow_pass")

a.NoError(err)
}

Expand Down
25 changes: 18 additions & 7 deletions pkg/sql/codegen/xgboost/template_pred.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ import (
)

type predFiller struct {
DataSource string
PredSelect string
FeatureMetaJSON string
LabelMetaJSON string
ResultTable string
DataSource string
PredSelect string
FeatureMetaJSON string
LabelMetaJSON string
ResultTable string
HDFSNameNodeAddr string
HiveLocation string
HDFSUser string
HDFSPass string
}

const predTemplateText = `
Expand Down Expand Up @@ -61,13 +65,20 @@ preds = bst.predict(dpred)
if len(preds.shape) == 2:
# classifier result
preds = np.argmax(np.array(preds), axis=1)

feature_file_read = open("predict.txt", "r")

result_column_names = feature_column_names
result_column_names.append(label_name)
line_no = 0
with buffered_db_writer(conn.driver, conn, "{{.ResultTable}}", result_column_names, 100) as w:
with buffered_db_writer(conn.driver,
conn,
"{{.ResultTable}}",
result_column_names,
100,
hdfs_namenode_addr="{{.HDFSNameNodeAddr}}",
hive_location="{{.HiveLocation}}",
hdfs_user="{{.HDFSUser}}",
hdfs_pass="{{.HDFSPass}}") as w:
while True:
line = feature_file_read.readline()
if not line:
Expand Down
21 changes: 13 additions & 8 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ func train(wr *PipeWriter, tr *extendedSelect, db *DB, cwd string, modelDir stri
return e
}
var program bytes.Buffer
if strings.HasPrefix(strings.ToUpper(tr.estimator), `XGBOOST.`) {
if isXGBoostModel(tr.estimator) {
// FIXME(weiguoz): Remove the condition after the codegen refactor
if enableIR() {
ir, err := generateTrainIR(tr, db.String())
Expand Down Expand Up @@ -497,24 +497,20 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
return pr, fts, nil
}

func enableIR() bool {
return os.Getenv("SQLFLOW_codegen") == "ir"
}

func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir string, session *pb.Session) error {
pr, fts, e := loadModelMeta(pr, db, cwd, modelDir, pr.model)
if e != nil {
return fmt.Errorf("loadModelMeta %v", e)
}

var buf bytes.Buffer
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
if isXGBoostModel(pr.estimator) {
if enableIR() {
ir, err := generatePredictIR(pr, db.String(), cwd, modelDir)
if err != nil {
return err
}
code, err := xgboost.Pred(ir)
code, err := xgboost.Pred(ir, session)
if err != nil {
return err
}
Expand All @@ -534,7 +530,7 @@ func pred(wr *PipeWriter, pr *extendedSelect, db *DB, cwd string, modelDir strin
if err != nil {
return err
}
code, err := tensorflow.Pred(ir)
code, err := tensorflow.Pred(ir, session)
if err != nil {
return err
}
Expand Down Expand Up @@ -646,6 +642,15 @@ func createPredictionTable(predParsed *extendedSelect, db *DB, session *pb.Sessi
return nil
}

// -------------------------- utilities --------------------------------------
func isXGBoostModel(estimator string) bool {
return strings.HasPrefix(strings.ToUpper(estimator), `XGBOOST.`)
}

func enableIR() bool {
return os.Getenv("SQLFLOW_codegen") == "ir"
}

func parseTableColumn(s string) (string, string, error) {
pos := strings.LastIndex(s, ".")
if pos == -1 || pos == len(s)-1 {
Expand Down
4 changes: 2 additions & 2 deletions python/sqlflow_submitter/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def connect_with_data_source(driver_dsn):
port=int(port))
elif driver == "hive":
from impala.dbapi import connect
user, passwd, host, port, database, auth, session = parseHiveDSN(dsn)
user, passwd, host, port, database, auth, session_cfg = parseHiveDSN(dsn)
conn = connect(user=user,
password=passwd,
database=database,
host=host,
port=int(port),
auth_mechanism=auth)
conn.session = session
conn.session_cfg = session_cfg
elif driver == "maxcompute":
from sqlflow_submitter.maxcompute import MaxCompute
user, passwd, address, database = parseMaxComputeDSN(dsn)
Expand Down
Loading