Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 80 additions & 46 deletions sql/codegen_xgboost.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package sql

import (
"encoding/json"
"fmt"
"strconv"
"strings"
Expand All @@ -24,24 +25,13 @@ import (
)

type xgboostFiller struct {
isTrain bool
standardSelect string
modelPath string
modelPath string
xgboostFields
xgColumnFields
xgDataSourceFields
xgRuntimeFields
}

type xgRuntimeFields struct {
runLocal bool
xgRuntimeResourceFields
}

type xgRuntimeResourceFields struct {
WorkerNum uint `json:"worker_num,omitempty"`
MemorySize uint `json:"memory_size,omitempty"`
CPUSize uint `json:"cpu_size,omitempty"`
xgboostJSON string
xgDataSourceJSON string
xgColumnJSON string
}

type xgboostFields struct {
Expand Down Expand Up @@ -90,12 +80,15 @@ type xgFeatureFields struct {
}

// xgDataSourceFields describes where the XGBoost job reads its data from and
// which columns play which role. newXGBoostFiller serializes this struct with
// json.Marshal and stores the result in xgboostFiller.xgDataSourceJSON, so the
// JSON tags below define the wire format of that payload.
type xgDataSourceFields struct {
// IsTrain is copied from pr.train by newXGBoostFiller.
IsTrain bool `json:"is_train,omitempty"`
// StandardSelect is the string form of pr.standardSelect; xgFillDatabaseInfo
// strips its trailing ';' for at least one driver to avoid a ParseException.
StandardSelect string `json:"standard_select,omitempty"`
// NOTE(review): presumably marks that TensorFlow feature columns are in use;
// the code that sets it is not visible here — confirm.
IsTensorFlowIntegrated bool `json:"is_tf_integrated,omitempty"`
// X holds the feature column metadata.
// NOTE(review): presumably filled while parsing feature columns — confirm.
X []*xgFeatureMeta `json:"x,omitempty"`
// LabelField is set by xgParseColumns from pr.label in train mode only; it
// stays nil (and is omitted from the JSON) in predict mode.
LabelField *xgFeatureMeta `json:"label,omitempty"`
// WeightField is the metadata of the "weight" column target.
// NOTE(review): presumably ignored in predict mode like GroupField — confirm.
WeightField *xgFeatureMeta `json:"weight,omitempty"`
// GroupField is set by xgParseColumns from the "group" column target; that
// target is skipped in predict mode.
GroupField *xgFeatureMeta `json:"group,omitempty"`
// xgDataBaseField is embedded, but because it carries an explicit JSON name
// encoding/json emits it as a nested "db_config" object rather than
// flattening its fields. omitempty has no effect on a struct-typed value, so
// "db_config" is always present in the output.
xgDataBaseField `json:"db_config,omitempty"`
// WriteBatchSize is emitted as "write_batch_size" and omitted when zero.
// NOTE(review): presumably the batch size for writing results — confirm.
WriteBatchSize int `json:"write_batch_size,omitempty"`
}

type xgDataBaseField struct {
Expand Down Expand Up @@ -239,12 +232,7 @@ func sListPartial(key string, ptrFn func(*xgboostFiller) *[]string) func(*map[st
}
}

var xgbAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) error{
// runtime params
"run_local": boolPartial("run_local", func(r *xgboostFiller) *bool { return &(r.runLocal) }),
"workers": uIntPartial("workers", func(r *xgboostFiller) *uint { return &(r.WorkerNum) }),
"memory": uIntPartial("memory", func(r *xgboostFiller) *uint { return &(r.MemorySize) }),
"cpu": uIntPartial("cpu", func(r *xgboostFiller) *uint { return &(r.CPUSize) }),
var xgbTrainAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) error{
// booster params
"objective": strPartial("objective", func(r *xgboostFiller) *string { return &(r.Objective) }),
"booster": strPartial("booster", func(r *xgboostFiller) *string { return &(r.Booster) }),
Expand All @@ -262,6 +250,10 @@ var xgbAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) err
// xgboost train controllers
"num_round": uIntPartial("num_round", func(r *xgboostFiller) *uint { return &(r.NumRound) }),
"auto_train": boolPartial("auto_train", func(r *xgboostFiller) *bool { return &(r.AutoTrain) }),
// Label, Group, Weight and xgFeatureFields are parsed from columnClause
}

var xgbPredAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) error{
// xgboost output columns (for prediction)
"append_columns": sListPartial("append_columns", func(r *xgboostFiller) *[]string { return &(r.AppendColumns) }),
"result_column": strPartial("result_column", func(r *xgboostFiller) *string { return &(r.ResultColumn) }),
Expand All @@ -272,9 +264,16 @@ var xgbAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) err
}

func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
var rawAttrs map[string]*expr
if pr.train {
rawAttrs = pr.trainAttrs
} else {
rawAttrs = pr.predAttrs
}

// parse pr.attrs to map[string][]string
attrs := make(map[string][]string)
for k, exp := range pr.trainAttrs {
for k, exp := range rawAttrs {
strExp := exp.String()
if strings.HasPrefix(strExp, "[") && strings.HasSuffix(strExp, "]") {
attrs[k] = exp.cdr()
Expand All @@ -290,8 +289,14 @@ func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
}

// fill xgboostFiller with attrs
var setterMap map[string]func(*map[string][]string, *xgboostFiller) error
if pr.train {
setterMap = xgbTrainAttrSetterMap
} else {
setterMap = xgbPredAttrSetterMap
}
for k := range attrs {
if setter, ok := xgbAttrSetterMap[k]; ok {
if setter, ok := setterMap[k]; ok {
if e := setter(&attrs, r); e != nil {
return xgParseAttrError(e)
}
Expand All @@ -309,15 +314,14 @@ func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
return nil
}

/* parse feature columns, which are owned by the default column target ("feature_columns"), from the AST (pr.columns)

For now, two schemas are supported:
1. sparse-kv
schema: COLUMN SPARSE([feature_column], [1-dim shape], [single char delimiter])
data example: COLUMN SPARSE("0:1.5 1:100.1f 11:-1.2", [20], " ")
2. tf feature columns
roughly same as TFEstimator, except output shape of feaColumns are required to be 1-dim.
*/
// parseFeatureColumns parses feature columns from the AST (pr.columns).
// Feature columns are the columns owned by the default column target, whose key is "feature_columns".
// For now, two schemas are supported:
// 1. sparse-kv
// schema: COLUMN SPARSE([feature_column], [1-dim shape], [single char delimiter])
// data example: COLUMN SPARSE("0:1.5 1:100.1f 11:-1.2", [20], " ")
// 2. tf feature columns
// Roughly the same as TFEstimator, except that the output shape of feaColumns is required to be 1-dim.
func parseFeatureColumns(columns *exprlist, r *xgboostFiller) error {
feaCols, colSpecs, err := resolveTrainColumns(columns)
if err != nil {
Expand All @@ -337,7 +341,7 @@ func parseFeatureColumns(columns *exprlist, r *xgboostFiller) error {
return nil
}

// parse sparse kv feature, which identified by `SPARSE`.
// parseSparseKeyValueFeatures parses the features that are identified by `SPARSE`.
// ex: SPARSE(col1, [100], comma)
func parseSparseKeyValueFeatures(colSpecs []*columnSpec, r *xgboostFiller) error {
var colNames []string
Expand Down Expand Up @@ -391,10 +395,7 @@ func parseDenseFeatures(feaCols []featureColumn, r *xgboostFiller) error {
if allSimpleCol && !isSimpleColumn(col) {
allSimpleCol = false
}
// FIXME(typhoonzero): Use Heuristic rules to determine whether a column should be transformed to a
// tf.SparseTensor. Currently the rules are:
// if column have delimiter and it's not a sequence_catigorical_column, we'll treat it as a sparse column
// else, use dense column.

isSparse := false
var isEmb bool
_, ok := col.(*sequenceCategoryIDColumn)
Expand Down Expand Up @@ -482,13 +483,19 @@ func xgParseColumns(pr *extendedSelect, filler *xgboostFiller) error {
return xgParseColumnError(target, e)
}
case "group":
if !pr.train {
continue
}
colMeta, e := parseSimpleColumn("group", &columns)
if e != nil {
return xgParseColumnError(target, e)
}
filler.GroupField = colMeta
filler.Group = colMeta.FeatureName
case "weight":
if !pr.train {
continue
}
colMeta, e := parseSimpleColumn("weight", &columns)
if e != nil {
return xgParseColumnError(target, e)
Expand All @@ -499,10 +506,13 @@ func xgParseColumns(pr *extendedSelect, filler *xgboostFiller) error {
return xgParseColumnError(target, xgUnsupportedColTagError())
}
}
filler.LabelField = &xgFeatureMeta{
FeatureName: pr.label,
// in predict mode, ignore label info
if pr.train {
filler.LabelField = &xgFeatureMeta{
FeatureName: pr.label,
}
filler.Label = pr.label
}
filler.Label = pr.label

return nil
}
Expand Down Expand Up @@ -546,14 +556,14 @@ func xgParseEstimator(pr *extendedSelect, filler *xgboostFiller) error {

func newXGBoostFiller(pr *extendedSelect, fts fieldTypes, db *DB) (*xgboostFiller, error) {
filler := &xgboostFiller{
isTrain: pr.train,
modelPath: pr.save,
standardSelect: pr.standardSelect.String(),
modelPath: pr.save,
}
filler.IsTrain = pr.train
filler.StandardSelect = pr.standardSelect.String()

// solve keyword: WITH (attributes)
if e := xgParseAttr(pr, filler); e != nil {
return nil, fmt.Errorf("failed to set xgboost attributes: %exp", e)
return nil, fmt.Errorf("failed to set xgboost attributes: %v", e)
}

// solve keyword: TRAIN (estimator)
Expand All @@ -566,7 +576,31 @@ func newXGBoostFiller(pr *extendedSelect, fts fieldTypes, db *DB) (*xgboostFille
return nil, e
}

return xgFillDatabaseInfo(filler, db)
// fill data base info
if _, e := xgFillDatabaseInfo(filler, db); e != nil {
return nil, e
}

// serialize fields
jsonBuffer, e := json.Marshal(filler.xgboostFields)
if e != nil {
return nil, e
}
filler.xgboostJSON = string(jsonBuffer)

jsonBuffer, e = json.Marshal(filler.xgDataSourceFields)
if e != nil {
return nil, e
}
filler.xgDataSourceJSON = string(jsonBuffer)

jsonBuffer, e = json.Marshal(filler.xgColumnFields)
if e != nil {
return nil, e
}
filler.xgColumnJSON = string(jsonBuffer)

return filler, nil
}

func xgFillDatabaseInfo(r *xgboostFiller, db *DB) (*xgboostFiller, error) {
Expand All @@ -591,7 +625,7 @@ func xgFillDatabaseInfo(r *xgboostFiller, db *DB) (*xgboostFiller, error) {
r.Host, r.Port, r.Database = sa[0], sa[1], cfg.DBName
r.User, r.Password = cfg.User, cfg.Passwd
// remove the last ';' which leads to a ParseException
r.standardSelect = removeLastSemicolon(r.standardSelect)
r.StandardSelect = removeLastSemicolon(r.StandardSelect)
case "maxcompute":
cfg, err := gomaxcompute.ParseDSN(db.dataSourceName)
if err != nil {
Expand Down
Loading