From 76abe13274ae3b6f6b7ddd6253eff7b8ac12b35c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Mon, 2 Sep 2019 17:53:51 +0800 Subject: [PATCH 1/2] refine code structure for expression resolver --- sql/attribute.go | 129 +++++++++-- sql/attribute_test.go | 45 ---- sql/bucket_column.go | 84 ------- sql/bucket_column_test.go | 55 ----- sql/category_id_column.go | 157 ------------- sql/category_id_column_test.go | 70 ------ sql/codegen.go | 22 +- sql/codegen_alps.go | 198 +++++++++-------- sql/codegen_xgboost.go | 25 ++- sql/column_spec.go | 78 ------- sql/column_spec_test.go | 59 ----- sql/cross_column.go | 93 -------- sql/cross_column_test.go | 59 ----- sql/embedding_column.go | 105 --------- sql/embedding_column_test.go | 57 ----- sql/engine_spec.go | 114 ---------- sql/expression_resolver.go | 379 ++++++++++++++++++++++++++++---- sql/expression_resolver_test.go | 260 ++++++++++++++++++++++ sql/feature_column.go | 79 ------- sql/gitlab_module.go | 23 -- sql/numeric_column.go | 91 -------- sql/numeric_column_test.go | 54 ----- 22 files changed, 834 insertions(+), 1402 deletions(-) delete mode 100644 sql/attribute_test.go delete mode 100644 sql/bucket_column.go delete mode 100644 sql/bucket_column_test.go delete mode 100644 sql/category_id_column.go delete mode 100644 sql/category_id_column_test.go delete mode 100644 sql/column_spec.go delete mode 100644 sql/column_spec_test.go delete mode 100644 sql/cross_column.go delete mode 100644 sql/cross_column_test.go delete mode 100644 sql/embedding_column.go delete mode 100644 sql/embedding_column_test.go delete mode 100644 sql/engine_spec.go delete mode 100644 sql/feature_column.go delete mode 100644 sql/gitlab_module.go delete mode 100644 sql/numeric_column.go delete mode 100644 sql/numeric_column_test.go diff --git a/sql/attribute.go b/sql/attribute.go index 2e1b7a1421..8fa1c95fe4 100644 --- a/sql/attribute.go +++ b/sql/attribute.go @@ -26,6 +26,110 @@ type attribute struct { Value interface{} } +type gitLabModule struct { + ModuleName string + ProjectName string + Sha string + PrivateToken string + SourceRoot string + GitLabServer string +} + +type engineSpec struct { + etype string + ps resourceSpec + worker resourceSpec + cluster string + queue string + masterResourceRequest string + masterResourceLimit string + workerResourceRequest string + workerResourceLimit string + volume string + imagePullPolicy string + restartPolicy string + extraPypiIndex string + namespace string + minibatchSize int + masterPodPriority string + clusterSpec string + recordsPerTask int +} + +func getEngineSpec(attrs map[string]*attribute) engineSpec { + getInt := func(key string, defaultValue int) int { + if p, ok := attrs[key]; ok { + strVal, _ := p.Value.(string) + intVal, err := strconv.Atoi(strVal) + + if err == nil { + return intVal + } + } + return defaultValue + } + getString := func(key string, defaultValue string) string { + if p, ok := attrs[key]; ok { + strVal, ok := p.Value.(string) + if ok { + // TODO(joyyoj): use the parser to do those validations. 
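+			// The parsed attribute value may still carry its SQL string quotes
+			// (e.g. type = "yarn"), so strip one pair of surrounding double
+			// quotes before using the value.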
+ if strings.HasPrefix(strVal, "\"") && strings.HasSuffix(strVal, "\"") { + return strVal[1 : len(strVal)-1] + } + return strVal + } + } + return defaultValue + } + + psNum := getInt("ps_num", 1) + psMemory := getInt("ps_memory", 2400) + workerMemory := getInt("worker_memory", 1600) + workerNum := getInt("worker_num", 2) + engineType := getString("type", "local") + if (psNum > 0 || workerNum > 0) && engineType == "local" { + engineType = "yarn" + } + cluster := getString("cluster", "") + queue := getString("queue", "") + + // ElasticDL engine specs + masterResourceRequest := getString("master_resource_request", "cpu=0.1,memory=1024Mi") + masterResourceLimit := getString("master_resource_limit", "") + workerResourceRequest := getString("worker_resource_request", "cpu=1,memory=4096Mi") + workerResourceLimit := getString("worker_resource_limit", "") + volume := getString("volume", "") + imagePullPolicy := getString("image_pull_policy", "Always") + restartPolicy := getString("restart_policy", "Never") + extraPypiIndex := getString("extra_pypi_index", "") + namespace := getString("namespace", "default") + minibatchSize := getInt("minibatch_size", 64) + masterPodPriority := getString("master_pod_priority", "") + clusterSpec := getString("cluster_spec", "") + recordsPerTask := getInt("records_per_task", 100) + + return engineSpec{ + etype: engineType, + ps: resourceSpec{Num: psNum, Memory: psMemory}, + worker: resourceSpec{Num: workerNum, Memory: workerMemory}, + cluster: cluster, + queue: queue, + masterResourceRequest: masterResourceRequest, + masterResourceLimit: masterResourceLimit, + workerResourceRequest: workerResourceRequest, + workerResourceLimit: workerResourceLimit, + volume: volume, + imagePullPolicy: imagePullPolicy, + restartPolicy: restartPolicy, + extraPypiIndex: extraPypiIndex, + namespace: namespace, + minibatchSize: minibatchSize, + masterPodPriority: masterPodPriority, + clusterSpec: clusterSpec, + recordsPerTask: recordsPerTask, + } +} + func (a *attribute) GenerateCode() (string, error) { if val, ok := a.Value.(string); ok { // auto convert to int first. @@ -45,7 +149,7 @@ func (a *attribute) GenerateCode() (string, error) { return "", fmt.Errorf("value of attribute must be string or list of int, given %s", a.Value) } -func filter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute { +func attrFilter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute { ret := make(map[string]*attribute, 0) for _, a := range attrs { if strings.EqualFold(a.Prefix, prefix) { @@ -57,26 +161,3 @@ func filter(attrs map[string]*attribute, prefix string, remove bool) map[string] } return ret } - -func resolveAttribute(attrs *attrs) (map[string]*attribute, error) { - ret := make(map[string]*attribute) - for k, v := range *attrs { - subs := strings.SplitN(k, ".", 2) - name := subs[len(subs)-1] - prefix := "" - if len(subs) == 2 { - prefix = subs[0] - } - r, _, err := resolveExpression(v) - if err != nil { - return nil, err - } - a := &attribute{ - FullName: k, - Prefix: prefix, - Name: name, - Value: r} - ret[a.FullName] = a - } - return ret, nil -} diff --git a/sql/attribute_test.go b/sql/attribute_test.go deleted file mode 100644 index 97399c46c7..0000000000 --- a/sql/attribute_test.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestAttrs(t *testing.T) { - a := assert.New(t) - parser := newParser() - - s := statementWithAttrs("estimator.hidden_units = [10, 20]") - r, e := parser.Parse(s) - a.NoError(e) - attrs, err := resolveAttribute(&r.trainAttrs) - a.NoError(err) - attr := attrs["estimator.hidden_units"] - a.Equal("estimator", attr.Prefix) - a.Equal("hidden_units", attr.Name) - a.Equal([]interface{}([]interface{}{10, 20}), attr.Value) - - s = statementWithAttrs("dataset.name = hello") - r, e = parser.Parse(s) - a.NoError(e) - attrs, err = resolveAttribute(&r.trainAttrs) - a.NoError(err) - attr = attrs["dataset.name"] - a.Equal("dataset", attr.Prefix) - a.Equal("name", attr.Name) - a.Equal("hello", attr.Value) -} diff --git a/sql/bucket_column.go b/sql/bucket_column.go deleted file mode 100644 index 55f3893780..0000000000 --- a/sql/bucket_column.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package sql - -import ( - "fmt" - "strings" -) - -type bucketColumn struct { - SourceColumn *numericColumn - Boundaries []int -} - -func (bc *bucketColumn) GenerateCode() (string, error) { - sourceCode, _ := bc.SourceColumn.GenerateCode() - return fmt.Sprintf( - "tf.feature_column.bucketized_column(%s, boundaries=%s)", - sourceCode, - strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")), nil -} - -func (bc *bucketColumn) GetDelimiter() string { - return "" -} - -func (bc *bucketColumn) GetDtype() string { - return "" -} - -func (bc *bucketColumn) GetKey() string { - return bc.SourceColumn.Key -} - -func (bc *bucketColumn) GetInputShape() string { - return bc.SourceColumn.GetInputShape() -} - -func (bc *bucketColumn) GetColumnType() int { - return columnTypeBucket -} - -func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad BUCKET expression format: %s", *el) - } - sourceExprList := (*el)[1] - boundariesExprList := (*el)[2] - if sourceExprList.typ != 0 { - return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %v", sourceExprList) - } - source, _, err := resolveColumn(&sourceExprList.sexp) - if err != nil { - return nil, err - } - if source.GetColumnType() != columnTypeNumeric { - return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source) - } - boundaries, _, err := resolveExpression(boundariesExprList) - if err != nil { - return nil, err - } - if _, ok := boundaries.([]interface{}); !ok { - return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) - } - b, err := transformToIntList(boundaries.([]interface{})) - if err != nil { - return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) - } - return &bucketColumn{ - SourceColumn: source.(*numericColumn), - Boundaries: b}, nil -} diff --git a/sql/bucket_column_test.go b/sql/bucket_column_test.go deleted file mode 100644 index 309ee34917..0000000000 --- a/sql/bucket_column_test.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package sql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestBucketColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])") - badInput := statementWithColumn("BUCKET(c1, [1, 10])") - badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - bc, ok := fcs[0].(*bucketColumn) - a.True(ok) - code, e := bc.GenerateCode() - a.NoError(e) - a.Equal("c1", bc.SourceColumn.Key) - a.Equal([]int{10}, bc.SourceColumn.Shape) - a.Equal([]int{1, 10}, bc.Boundaries) - a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBoundaries) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} diff --git a/sql/category_id_column.go b/sql/category_id_column.go deleted file mode 100644 index f083832dbd..0000000000 --- a/sql/category_id_column.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package sql - -import ( - "fmt" - "strconv" -) - -type categoryIDColumn struct { - Key string - BucketSize int - Delimiter string - Dtype string -} - -type sequenceCategoryIDColumn struct { - Key string - BucketSize int - Delimiter string - Dtype string - IsSequence bool -} - -func (cc *categoryIDColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", - cc.Key, cc.BucketSize), nil -} - -func (cc *categoryIDColumn) GetDelimiter() string { - return cc.Delimiter -} - -func (cc *categoryIDColumn) GetDtype() string { - return cc.Dtype -} - -func (cc *categoryIDColumn) GetKey() string { - return cc.Key -} - -func (cc *categoryIDColumn) GetInputShape() string { - return fmt.Sprintf("[%d]", cc.BucketSize) -} - -func (cc *categoryIDColumn) GetColumnType() int { - return columnTypeCategoryID -} - -func (cc *sequenceCategoryIDColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", - cc.Key, cc.BucketSize), nil -} - -func (cc *sequenceCategoryIDColumn) GetDelimiter() string { - return cc.Delimiter -} - -func (cc *sequenceCategoryIDColumn) GetDtype() string { - return cc.Dtype -} - -func (cc *sequenceCategoryIDColumn) GetKey() string { - return cc.Key -} - -func (cc *sequenceCategoryIDColumn) GetInputShape() string { - return fmt.Sprintf("[%d]", cc.BucketSize) -} - -func (cc *sequenceCategoryIDColumn) GetColumnType() int { - return columnTypeSeqCategoryID -} - -func parseCategoryColumnKey(el *exprlist) (*columnSpec, error) { - if (*el)[1].typ == 0 { - // explist, maybe DENSE/SPARSE expressions - subExprList := (*el)[1].sexp - isSparse := subExprList[0].val == sparse - return resolveColumnSpec(&subExprList, isSparse) - } - return nil, nil -} - -func resolveSeqCategoryIDColumn(el *exprlist) (*sequenceCategoryIDColumn, *columnSpec, error) { - key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el) - if err != nil { - return nil, nil, err - } - return &sequenceCategoryIDColumn{ - Key: key, - BucketSize: bucketSize, - Delimiter: delimiter, - // TODO(typhoonzero): support config dtype - Dtype: "int64", - IsSequence: true}, cs, nil -} - -func resolveCategoryIDColumn(el *exprlist) (*categoryIDColumn, *columnSpec, error) { - key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el) - if err != nil { - return nil, nil, err - } - return &categoryIDColumn{ - Key: key, - BucketSize: bucketSize, - Delimiter: delimiter, - // TODO(typhoonzero): support config dtype - Dtype: "int64"}, cs, nil -} - -func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columnSpec, error) { - if len(*el) != 3 && len(*el) != 4 { - return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) - } - var cs *columnSpec - key := "" - var err error - if (*el)[1].typ == 0 { - // explist, maybe DENSE/SPARSE expressions - subExprList := (*el)[1].sexp - isSparse := subExprList[0].val == sparse - cs, err = resolveColumnSpec(&subExprList, isSparse) - if err != nil { - return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %v", subExprList) - } - key = cs.ColumnName - } else { - key, err = expression2string((*el)[1]) - if err != nil { - return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) - } - } - bucketSize, err := strconv.Atoi((*el)[2].val) - if err != nil { - return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err) - } - 
delimiter := "" - if len(*el) == 4 { - delimiter, err = resolveDelimiter((*el)[3].val) - if err != nil { - return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) - } - } - return key, bucketSize, delimiter, cs, nil -} diff --git a/sql/category_id_column_test.go b/sql/category_id_column_test.go deleted file mode 100644 index 404ff2482c..0000000000 --- a/sql/category_id_column_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCatIdColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("CATEGORY_ID(c1, 100)") - badKey := statementWithColumn("CATEGORY_ID([100], 100)") - badBucket := statementWithColumn("CATEGORY_ID(c1, bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - cc, ok := fcs[0].(*categoryIDColumn) - a.True(ok) - code, e := cc.GenerateCode() - a.NoError(e) - a.Equal("c1", cc.Key) - a.Equal(100, cc.BucketSize) - a.Equal("tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100)", code) - - r, e = parser.Parse(badKey) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucket) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestCatIdColumnWithColumnSpec(t *testing.T) { - a := assert.New(t) - parser := newParser() - - dense := statementWithColumn("CATEGORY_ID(DENSE(col1, 128, COMMA), 100)") - - r, e := parser.Parse(dense) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, css, e := resolveTrainColumns(&c) - a.NoError(e) - _, ok := fcs[0].(*categoryIDColumn) - a.True(ok) - a.Equal(css[0].ColumnName, "col1") -} diff --git a/sql/codegen.go b/sql/codegen.go index e92dfe97df..4c32761db8 100644 --- a/sql/codegen.go +++ b/sql/codegen.go @@ -20,6 +20,7 @@ import ( "text/template" "github.com/go-sql-driver/mysql" + "github.com/sql-machine-learning/sqlflow/sql/columns" "sqlflow.org/gohive" "sqlflow.org/gomaxcompute" ) @@ -111,8 +112,8 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D r.modelConfig.Epochs = trainResolved.Epoch featureColumnsCode := make(map[string][]string) - for target, columns := range pr.columns { - feaCols, colSpecs, err := resolveTrainColumns(&columns) + for target, columnsExpr := range pr.columns { + feaCols, colSpecs, err := resolveTrainColumns(&columnsExpr) if err != nil { return nil, err } @@ -120,25 +121,30 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D return nil, fmt.Errorf("newFiller doesn't support DENSE/SPARSE") } for _, col := range feaCols { - feaColCode, e := col.GenerateCode() + // TODO(typhoonzero): pass columnSpecs if needed. 
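+			// GenerateCode returns a list of code snippets: one snippet for a
+			// plain feature column, several for a grouped column, which is
+			// rejected below because newFiller cannot consume it yet.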
+			feaColCode, e := col.GenerateCode(nil)
 			if e != nil {
 				return nil, e
 			}
+			if len(feaColCode) > 1 {
+				return nil, fmt.Errorf("does not support grouped feature column yet, grouped column: %v", feaColCode)
+			}
+
 			// FIXME(typhoonzero): Use heuristic rules to determine whether a column should be transformed to a
 			// tf.SparseTensor. Currently the rules are:
 			// if a column has a delimiter and it's not a sequence_categorical_column, we'll treat it as a sparse column;
 			// else, use a dense column.
 			isSparse := false
 			var isEmb bool
-			_, ok := col.(*sequenceCategoryIDColumn)
+			_, ok := col.(*columns.SequenceCategoryIDColumn)
 			if !ok {
-				_, isEmb = col.(*embeddingColumn)
+				_, isEmb = col.(*columns.EmbeddingColumn)
 				if isEmb {
-					_, ok = col.(*embeddingColumn).CategoryColumn.(*sequenceCategoryIDColumn)
+					_, ok = col.(*columns.EmbeddingColumn).CategoryColumn.(*columns.SequenceCategoryIDColumn)
 				}
 			}
 			if !ok && col.GetDelimiter() != "" {
-				if _, ok := col.(*numericColumn); !ok {
+				if _, ok := col.(*columns.NumericColumn); !ok {
 					isSparse = true
 				}
 			}
@@ -152,7 +158,7 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
 			r.X = append(r.X, fm)
 			featureColumnsCode[target] = append(
 				featureColumnsCode[target],
-				feaColCode)
+				feaColCode[0])
 		}
 	}
 
diff --git a/sql/codegen_alps.go b/sql/codegen_alps.go
index e4b0f387f2..df4809472f 100644
--- a/sql/codegen_alps.go
+++ b/sql/codegen_alps.go
@@ -26,6 +26,7 @@ import (
 	"text/template"
 
 	pb "github.com/sql-machine-learning/sqlflow/server/proto"
+	"github.com/sql-machine-learning/sqlflow/sql/columns"
 	"sqlflow.org/gomaxcompute"
 )
 
@@ -76,9 +77,28 @@ type alpsFiller struct {
 	OSSEndpoint string
}
 
-type alpsFeatureColumn interface {
-	featureColumn
-	GenerateAlpsCode(metadata *metadata) ([]string, error)
+// Wrapper types that let the ALPS code generator override GenerateCode of the
+// corresponding feature columns defined in the columns package.
+type alpsBucketCol struct {
+	columns.BucketColumn
+}
+type alpsCategoryIDCol struct {
+	columns.CategoryIDColumn
+}
+type alpsSeqCategoryIDCol struct {
+	columns.SequenceCategoryIDColumn
+}
+type alpsCrossCol struct {
+	columns.CrossColumn
+}
+type alpsEmbeddingCol struct {
+	columns.EmbeddingColumn
+}
+type alpsNumericCol struct {
+	columns.NumericColumn
 }
 
 func engineCreatorCode(resolved *resolvedTrainClause) (string, error) {
@@ -164,11 +184,11 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
 	}
 
 	var odpsConfig = &gomaxcompute.Config{}
-	var columnInfo map[string]*columnSpec
+	var columnInfo map[string]*columns.ColumnSpec
 
 	// TODO(joyyoj) read feature mapping table's name from table attributes.
 	// TODO(joyyoj) pr may contain partitions.
-	fmap := featureMap{pr.tables[0] + "_feature_map", ""}
+	fmap := columns.FeatureMap{pr.tables[0] + "_feature_map", ""}
 	var meta metadata
 	fields := make([]string, 0)
 	if db != nil {
@@ -185,7 +205,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
 		meta.columnInfo = &columnInfo
 	} else {
 		meta = metadata{odpsConfig, pr.tables[0], nil, nil}
-		columnInfo = map[string]*columnSpec{}
+		columnInfo = map[string]*columns.ColumnSpec{}
 		for _, css := range resolved.ColumnSpecs {
 			for _, cs := range css {
 				columnInfo[cs.ColumnName] = cs
@@ -203,7 +223,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
 	for _, cs := range columnInfo {
 		csCode = append(csCode, cs.ToString())
 	}
-	y := &columnSpec{
+	y := &columns.ColumnSpec{
 		ColumnName: pr.label,
 		IsSparse:   false,
 		Shape:      []int{1},
@@ -411,55 +431,16 @@ func alpsPred(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb
 	return nil
 }
 
-func (nc *numericColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) {
-	output := make([]string, 0)
-	output = append(output,
-		fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key,
-			strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ",")))
-	return output, nil
-}
-
-func (bc *bucketColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) {
-	sourceCode, _ := bc.SourceColumn.GenerateCode()
-	output := make([]string, 0)
-	output = append(output, fmt.Sprintf(
-		"tf.feature_column.bucketized_column(%s, boundaries=%s)",
-		sourceCode,
-		strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")))
-	return output, nil
-}
-
-func (cc *crossColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) {
-	var keysGenerated = make([]string, len(cc.Keys))
-	var output []string
-	for idx, key := range cc.Keys {
-		if c, ok := key.(featureColumn); ok {
-			code, err := c.GenerateCode()
-			if err != nil {
-				return output, err
-			}
-			keysGenerated[idx] = code
-			continue
-		}
-		if str, ok := key.(string); ok {
-			keysGenerated[idx] = fmt.Sprintf("\"%s\"", str)
-		} else {
-			return output, fmt.Errorf("cross generate code error, key: %s", key)
-		}
-	}
-	output = append(output, fmt.Sprintf(
-		"tf.feature_column.crossed_column([%s], hash_bucket_size=%d)",
-		strings.Join(keysGenerated, ","), cc.HashBucketSize))
-	return output, nil
-}
-
-func (cc *categoryIDColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) {
+// GenerateCode overrides the default implementation of columns.CategoryIDColumn;
+// the ALPS variant reads the input shape from the resolved column spec.
+func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
 	output := make([]string, 0)
-	columnInfo, present := (*metadata.columnInfo)[cc.Key]
+	columnInfo := cs
 	var err error
-	if !present {
-		err = fmt.Errorf("Failed to get column info of %s", cc.Key)
-	} else if len(columnInfo.Shape) == 0 {
+	if len(columnInfo.Shape) == 0 {
 		err = fmt.Errorf("Shape is empty %s", cc.Key)
 	} else if len(columnInfo.Shape) == 1 {
 		// FIXME(Yancey1989): the suffix "_0" is only used in alps-rc5, would be fixed in the next release.
@@ -474,13 +455,15 @@ func (cc *categoryIDColumn) GenerateAlpsCode(metadata *metadata) ([]string, erro
 	return output, err
 }
 
-func (cc *sequenceCategoryIDColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) {
+func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
 	output := make([]string, 0)
-	columnInfo, present := (*metadata.columnInfo)[cc.Key]
+	columnInfo := cs
 	var err error
-	if !present {
-		err = fmt.Errorf("Failed to get column info of %s", cc.Key)
-	} else if len(columnInfo.Shape) == 0 {
+	if len(columnInfo.Shape) == 0 {
 		err = fmt.Errorf("Shape is empty %s", cc.Key)
 	} else if len(columnInfo.Shape) == 1 {
 		output = append(output, fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)",
@@ -494,15 +477,15 @@ func (cc *sequenceCategoryIDColumn) GenerateAlpsCode(metadata *metadata) ([]stri
 	return output, err
 }
 
-func (ec *embeddingColumn) GenerateAlpsCode(metadata *metadata) ([]string, error) {
+func (ec *alpsEmbeddingCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
 	var output []string
-	catColumn, ok := ec.CategoryColumn.(alpsFeatureColumn)
-	if !ok {
-		return output, fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn)
-	}
-	sourceCode, err := catColumn.GenerateAlpsCode(metadata)
+	// only plain CATEGORY_ID columns are supported as the embedding source here
+	catIDCol, ok := ec.CategoryColumn.(*columns.CategoryIDColumn)
+	if !ok {
+		return output, fmt.Errorf("embedding generate code error, input is not CategoryIDColumn: %v", ec.CategoryColumn)
+	}
+	catColumn := &alpsCategoryIDCol{*catIDCol}
+	sourceCode, err := catColumn.GenerateCode(cs)
 	if err != nil {
-		return output, err
+		return []string{}, err
 	}
 	output = make([]string, 0)
 	for _, elem := range sourceCode {
@@ -517,12 +500,35 @@ func (ec *embeddingColumn) GenerateAlpsCode(metadata *metadata) ([]string, error
 	return output, nil
 }
 
-func generateAlpsFeatureColumnCode(fcs []featureColumn, metadata *metadata) ([]string, error) {
+func generateAlpsFeatureColumnCode(fcs []columns.FeatureColumn, metadata *metadata) ([]string, error) {
 	var codes = make([]string, 0, 1000)
 	for _, fc := range fcs {
-		code, err := fc.(alpsFeatureColumn).GenerateAlpsCode(metadata)
-		if err != nil {
-			return codes, nil
+		var castedFC columns.FeatureColumn
+		// FIXME(typhoonzero): Find a better way to override the `GenerateCode` function
+		switch fc.GetColumnType() {
+		case columns.ColumnTypeCategoryID:
+			castedFC = &alpsCategoryIDCol{(*fc.(*columns.CategoryIDColumn))}
+		case columns.ColumnTypeEmbedding:
+			castedFC = &alpsEmbeddingCol{(*fc.(*columns.EmbeddingColumn))}
+		case columns.ColumnTypeSeqCategoryID:
+			castedFC = &alpsSeqCategoryIDCol{(*fc.(*columns.SequenceCategoryIDColumn))}
+		default:
+			castedFC = fc
+		}
+		var code []string
+		var err error
+		if fc.GetKey() == "" {
+			// cross columns have no single key, so there is no column spec to pass
+			code, err = castedFC.GenerateCode(nil)
+		} else {
+			cs, ok := (*metadata.columnInfo)[fc.GetKey()]
+			if !ok {
+				return nil, fmt.Errorf("no column spec found for column: %v", fc.GetKey())
+			}
+			code, err = castedFC.GenerateCode(cs)
+		}
+		if err != nil {
+			return nil, err
 		}
 		codes = append(codes, code...)
 	}
 
@@ -532,13 +538,13 @@ func generateAlpsFeatureColumnCode(fcs []featureColumn, metadata *metadata) ([]s
 type metadata struct {
 	odpsConfig *gomaxcompute.Config
 	table      string
-	featureMap *featureMap
-	columnInfo *map[string]*columnSpec
+	featureMap *columns.FeatureMap
+	columnInfo *map[string]*columns.ColumnSpec
 }
 
-func flattenColumnSpec(columns map[string][]*columnSpec) map[string]*columnSpec {
-	output := map[string]*columnSpec{}
-	for _, cols := range columns {
+func flattenColumnSpec(columnSpecs map[string][]*columns.ColumnSpec) map[string]*columns.ColumnSpec {
+	output := map[string]*columns.ColumnSpec{}
+	for _, cols := range columnSpecs {
 		for _, col := range cols {
 			output[col.ColumnName] = col
 		}
@@ -546,8 +552,8 @@ func flattenColumnSpec(columns map[string][]*columnSpec) map[string]*columnSpec
 	return output
 }
 
-func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []string) (map[string]*columnSpec, error) {
-	columns := map[string]*columnSpec{}
+func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []string) (map[string]*columns.ColumnSpec, error) {
+	columns := map[string]*columns.ColumnSpec{}
 	refColumns := flattenColumnSpec(resolved.ColumnSpecs)
 
 	sparseColumns, _ := meta.getSparseColumnInfo()
@@ -586,10 +592,10 @@ func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []stri
 }
 
 // get all referenced field names.
-func getAllKeys(fcs []featureColumn) []string {
+func getAllKeys(fcs []columns.FeatureColumn) []string {
 	output := make([]string, 0)
 	for _, fc := range fcs {
-		key := fc.(alpsFeatureColumn).GetKey()
+		key := fc.GetKey()
 		output = append(output, key)
 	}
 	return output
@@ -632,8 +638,8 @@ func getFields(meta *metadata, pr *extendedSelect) ([]string, error) {
 	return fields, nil
 }
 
-func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*columnSpec) (map[string]*columnSpec, error) {
-	output := map[string]*columnSpec{}
+func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*columns.ColumnSpec) (map[string]*columns.ColumnSpec, error) {
+	output := map[string]*columns.ColumnSpec{}
 	fields := strings.Join(keys, ",")
 	query := fmt.Sprintf("SELECT %s FROM %s LIMIT 1", fields, meta.table)
 	sqlDB, _ := sql.Open("maxcompute", meta.odpsConfig.FormatDSN())
@@ -643,8 +649,8 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
 	}
 	defer sqlDB.Close()
 	columnTypes, _ := rows.ColumnTypes()
-	columns, _ := rows.Columns()
-	count := len(columns)
+	columnNames, _ := rows.Columns()
+	count := len(columnNames)
 	for rows.Next() {
 		values := make([]interface{}, count)
 		for i, ct := range columnTypes {
@@ -663,17 +669,29 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
 		shape := make([]int, 1)
 		shape[0] = len(fields)
 		if userSpec, ok := refColumns[ct.Name()]; ok {
-			output[ct.Name()] = &columnSpec{ct.Name(), false, shape, userSpec.DType, userSpec.Delimiter, *meta.featureMap}
+			output[ct.Name()] = &columns.ColumnSpec{
+				ct.Name(),
+				false,
+				shape,
+				userSpec.DType,
+				userSpec.Delimiter,
+				*meta.featureMap}
 		} else {
-			output[ct.Name()] = &columnSpec{ct.Name(), false, shape, "float", ",", *meta.featureMap}
+			output[ct.Name()] = &columns.ColumnSpec{
+				ct.Name(),
+				false,
+				shape,
+				"float",
+				",",
+				*meta.featureMap}
 		}
 	}
 	return output, nil
 }
 
-func (meta *metadata) getSparseColumnInfo() (map[string]*columnSpec, error) {
-	output := map[string]*columnSpec{}
+func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, error) {
+	
output := map[string]*columns.ColumnSpec{} sqlDB, _ := sql.Open("maxcompute", meta.odpsConfig.FormatDSN()) filter := "feature_type != '' " @@ -689,8 +707,8 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columnSpec, error) { } defer sqlDB.Close() columnTypes, _ := rows.ColumnTypes() - columns, _ := rows.Columns() - count := len(columns) + columnNames, _ := rows.Columns() + count := len(columnNames) for rows.Next() { values := make([]interface{}, count) for i, ct := range columnTypes { @@ -712,7 +730,7 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columnSpec, error) { column, present := output[*name] if !present { shape := make([]int, 0, 1000) - column := &columnSpec{*name, true, shape, "int64", "", *meta.featureMap} + column := &columns.ColumnSpec{*name, true, shape, "int64", "", *meta.featureMap} column.DType = "int64" output[*name] = column } diff --git a/sql/codegen_xgboost.go b/sql/codegen_xgboost.go index 0fe5c9fb60..6a1c6856cb 100644 --- a/sql/codegen_xgboost.go +++ b/sql/codegen_xgboost.go @@ -24,6 +24,7 @@ import ( "text/template" "github.com/go-sql-driver/mysql" + "github.com/sql-machine-learning/sqlflow/sql/columns" "sqlflow.org/gohive" "sqlflow.org/gomaxcompute" ) @@ -355,7 +356,7 @@ func parseFeatureColumns(columns *exprlist, r *xgboostFiller) error { // parseSparseKeyValueFeatures, parse features which is identified by `SPARSE`. // ex: SPARSE(col1, [100], comma) -func parseSparseKeyValueFeatures(colSpecs []*columnSpec, r *xgboostFiller) error { +func parseSparseKeyValueFeatures(colSpecs []*columns.ColumnSpec, r *xgboostFiller) error { var colNames []string for _, spec := range colSpecs { colNames = append(colNames, spec.ColumnName) @@ -394,14 +395,14 @@ func parseSparseKeyValueFeatures(colSpecs []*columnSpec, r *xgboostFiller) error } // check whether column is raw column (no tf transformation need) -func isSimpleColumn(col featureColumn) bool { - if _, ok := col.(*numericColumn); ok { +func isSimpleColumn(col columns.FeatureColumn) bool { + if _, ok := col.(*columns.NumericColumn); ok { return col.GetDelimiter() == "" && col.GetInputShape() == "[1]" && col.GetDtype() == "float32" } return false } -func parseDenseFeatures(feaCols []featureColumn, r *xgboostFiller) error { +func parseDenseFeatures(feaCols []columns.FeatureColumn, r *xgboostFiller) error { allSimpleCol := true for _, col := range feaCols { if allSimpleCol && !isSimpleColumn(col) { @@ -410,30 +411,34 @@ func parseDenseFeatures(feaCols []featureColumn, r *xgboostFiller) error { isSparse := false var isEmb bool - _, ok := col.(*sequenceCategoryIDColumn) + _, ok := col.(*columns.SequenceCategoryIDColumn) if !ok { - _, isEmb = col.(*embeddingColumn) + _, isEmb = col.(*columns.EmbeddingColumn) if isEmb { - _, ok = col.(*embeddingColumn).CategoryColumn.(*sequenceCategoryIDColumn) + _, ok = col.(*columns.EmbeddingColumn).CategoryColumn.(*columns.SequenceCategoryIDColumn) } } if !ok && col.GetDelimiter() != "" { - if _, ok := col.(*numericColumn); !ok { + if _, ok := col.(*columns.NumericColumn); !ok { isSparse = true } } - feaColCode, e := col.GenerateCode() + // TODO(typhoonzero): pass columnSpec if needed. 
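+		// Calling GenerateCode with a nil spec uses the column's defaults;
+		// grouped columns would return more than one snippet and are
+		// rejected below.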
+ feaColCode, e := col.GenerateCode(nil) if e != nil { return e } + if len(feaColCode) > 1 { + return fmt.Errorf("does not support grouped column yet: %v", feaColCode) + } fm := &xgFeatureMeta{ FeatureName: col.GetKey(), Dtype: col.GetDtype(), Delimiter: col.GetDelimiter(), InputShape: col.GetInputShape(), - FeatureColumnCode: feaColCode, + FeatureColumnCode: feaColCode[0], IsSparse: isSparse, } r.X = append(r.X, fm) diff --git a/sql/column_spec.go b/sql/column_spec.go deleted file mode 100644 index 0a8662a925..0000000000 --- a/sql/column_spec.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "fmt" - "strconv" -) - -// columnSpec defines how to generate codes to parse column data to tensor/sparsetensor -type columnSpec struct { - ColumnName string - IsSparse bool - Shape []int - DType string - Delimiter string - FeatureMap featureMap -} - -func resolveColumnSpec(el *exprlist, isSparse bool) (*columnSpec, error) { - if len(*el) < 4 { - return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el) - } - name, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec name: %s, err: %s", (*el)[1], err) - } - var shape []int - intShape, err := strconv.Atoi((*el)[2].val) - if err != nil { - strShape, err := expression2string((*el)[2]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) - } - if strShape != "none" { - return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) - } - } else { - shape = append(shape, intShape) - } - unresolvedDelimiter, err := expression2string((*el)[3]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec delimiter: %s, err: %s", (*el)[1], err) - } - - delimiter, err := resolveDelimiter(unresolvedDelimiter) - if err != nil { - return nil, err - } - - // resolve feature map - fm := featureMap{} - dtype := "float" - if isSparse { - dtype = "int" - } - if len(*el) >= 5 { - dtype, err = expression2string((*el)[4]) - } - return &columnSpec{ - ColumnName: name, - IsSparse: isSparse, - Shape: shape, - DType: dtype, - Delimiter: delimiter, - FeatureMap: fm}, nil -} diff --git a/sql/column_spec_test.go b/sql/column_spec_test.go deleted file mode 100644 index 904bfa0cfe..0000000000 --- a/sql/column_spec_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestFeatureSpec(t *testing.T) { - a := assert.New(t) - parser := newParser() - - denseStatement := statementWithColumn("DENSE(c2, 5, comma)") - sparseStatement := statementWithColumn("SPARSE(c1, 100, comma)") - badStatement := statementWithColumn("DENSE(c3, bad, comma)") - - r, e := parser.Parse(denseStatement) - a.NoError(e) - c := r.columns["feature_columns"] - _, css, e := resolveTrainColumns(&c) - a.NoError(e) - cs := css[0] - a.Equal("c2", cs.ColumnName) - a.Equal(5, cs.Shape[0]) - a.Equal(",", cs.Delimiter) - a.Equal(false, cs.IsSparse) - a.Equal("DenseColumn(name=\"c2\", shape=[5], dtype=\"float\", separator=\",\")", cs.ToString()) - - r, e = parser.Parse(sparseStatement) - a.NoError(e) - c = r.columns["feature_columns"] - _, css, e = resolveTrainColumns(&c) - a.NoError(e) - cs = css[0] - a.Equal("c1", cs.ColumnName) - a.Equal(100, cs.Shape[0]) - a.Equal(true, cs.IsSparse) - a.Equal("SparseColumn(name=\"c1\", shape=[100], dtype=\"int\")", cs.ToString()) - - r, e = parser.Parse(badStatement) - a.NoError(e) - c = r.columns["feature_columns"] - _, _, e = resolveTrainColumns(&c) - a.Error(e) - -} diff --git a/sql/cross_column.go b/sql/cross_column.go deleted file mode 100644 index 6cb2f66dbd..0000000000 --- a/sql/cross_column.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "fmt" - "strconv" - "strings" -) - -// TODO(uuleon) specify the hash_key if needed -type crossColumn struct { - Keys []interface{} - HashBucketSize int -} - -func (cc *crossColumn) GenerateCode() (string, error) { - var keysGenerated = make([]string, len(cc.Keys)) - for idx, key := range cc.Keys { - if c, ok := key.(featureColumn); ok { - code, err := c.GenerateCode() - if err != nil { - return "", err - } - keysGenerated[idx] = code - continue - } - if str, ok := key.(string); ok { - keysGenerated[idx] = fmt.Sprintf("\"%s\"", str) - } else { - return "", fmt.Errorf("cross generate code error, key: %s", key) - } - } - return fmt.Sprintf( - "tf.feature_column.crossed_column([%s], hash_bucket_size=%d)", - strings.Join(keysGenerated, ","), cc.HashBucketSize), nil -} - -func (cc *crossColumn) GetDelimiter() string { - return "" -} - -func (cc *crossColumn) GetDtype() string { - return "" -} - -func (cc *crossColumn) GetKey() string { - // NOTE: cross column is a feature on multiple column keys. - return "" -} - -func (cc *crossColumn) GetInputShape() string { - // NOTE: return empty since crossed column input shape is already determined - // by the two crossed columns. 
- return "" -} - -func (cc *crossColumn) GetColumnType() int { - return columnTypeCross -} - -func resolveCrossColumn(el *exprlist) (*crossColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad CROSS expression format: %s", *el) - } - keysExpr := (*el)[1] - key, _, err := resolveExpression(keysExpr) - if err != nil { - return nil, err - } - if _, ok := key.([]interface{}); !ok { - return nil, fmt.Errorf("bad CROSS expression format: %s", *el) - } - - bucketSize, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err) - } - return &crossColumn{ - Keys: key.([]interface{}), - HashBucketSize: bucketSize}, nil -} diff --git a/sql/cross_column_test.go b/sql/cross_column_test.go deleted file mode 100644 index 9fa5d0f83b..0000000000 --- a/sql/cross_column_test.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCrossColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], 20)") - badInput := statementWithColumn("cross(c1, 20)") - badBucketSize := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcList, _, e := resolveTrainColumns(&c) - a.NoError(e) - cc, ok := fcList[0].(*crossColumn) - a.True(ok) - code, e := cc.GenerateCode() - a.NoError(e) - - bc := cc.Keys[0].(*bucketColumn) - a.Equal("c1", bc.SourceColumn.Key) - a.Equal([]int{10}, bc.SourceColumn.Shape) - a.Equal([]int{1, 10}, bc.Boundaries) - a.Equal("c5", cc.Keys[1].(string)) - a.Equal(20, cc.HashBucketSize) - a.Equal("tf.feature_column.crossed_column([tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10]),\"c5\"], hash_bucket_size=20)", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcList, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucketSize) - a.NoError(e) - c = r.columns["feature_columns"] - fcList, _, e = resolveTrainColumns(&c) - a.Error(e) -} diff --git a/sql/embedding_column.go b/sql/embedding_column.go deleted file mode 100644 index ddd6b85433..0000000000 --- a/sql/embedding_column.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "fmt" - "strconv" -) - -type embeddingColumn struct { - CategoryColumn interface{} - Dimension int - Combiner string - Initializer string -} - -func (ec *embeddingColumn) GetDelimiter() string { - return ec.CategoryColumn.(featureColumn).GetDelimiter() -} - -func (ec *embeddingColumn) GetDtype() string { - return ec.CategoryColumn.(featureColumn).GetDtype() -} - -func (ec *embeddingColumn) GetKey() string { - return ec.CategoryColumn.(featureColumn).GetKey() -} - -func (ec *embeddingColumn) GetInputShape() string { - return ec.CategoryColumn.(featureColumn).GetInputShape() -} - -func (ec *embeddingColumn) GetColumnType() int { - return columnTypeEmbedding -} - -func (ec *embeddingColumn) GenerateCode() (string, error) { - catColumn, ok := ec.CategoryColumn.(featureColumn) - if !ok { - return "", fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn) - } - sourceCode, err := catColumn.GenerateCode() - if err != nil { - return "", err - } - return fmt.Sprintf("tf.feature_column.embedding_column(%s, dimension=%d, combiner=\"%s\")", - sourceCode, ec.Dimension, ec.Combiner), nil -} - -func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { - if len(*el) != 4 && len(*el) != 5 { - return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) - } - sourceExprList := (*el)[1] - var source featureColumn - var err error - if sourceExprList.typ == 0 { - source, _, err = resolveColumn(&sourceExprList.sexp) - if err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("key of EMBEDDING must be categorical column") - } - // TODO(uuleon) support other kinds of categorical column in the future - var catColumn interface{} - catColumn, ok := source.(*categoryIDColumn) - if !ok { - catColumn, ok = source.(*sequenceCategoryIDColumn) - if !ok { - return nil, fmt.Errorf("key of EMBEDDING must be categorical column") - } - } - dimension, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err) - } - combiner, err := expression2string((*el)[3]) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err) - } - initializer := "" - if len(*el) == 5 { - initializer, err = expression2string((*el)[4]) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err) - } - } - return &embeddingColumn{ - CategoryColumn: catColumn, - Dimension: dimension, - Combiner: combiner, - Initializer: initializer}, nil -} diff --git a/sql/embedding_column_test.go b/sql/embedding_column_test.go deleted file mode 100644 index a4325080b1..0000000000 --- a/sql/embedding_column_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package sql - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEmbeddingColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), 200, mean)") - badInput := statementWithColumn("EMBEDDING(c1, 100, mean)") - badBucket := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), bad, mean)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - ec, ok := fcs[0].(*embeddingColumn) - a.True(ok) - code, e := ec.GenerateCode() - a.NoError(e) - cc, ok := ec.CategoryColumn.(*categoryIDColumn) - a.True(ok) - a.Equal("c1", cc.Key) - a.Equal(100, cc.BucketSize) - a.Equal(200, ec.Dimension) - a.Equal("tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100), dimension=200, combiner=\"mean\")", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucket) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} diff --git a/sql/engine_spec.go b/sql/engine_spec.go deleted file mode 100644 index fbcca0e069..0000000000 --- a/sql/engine_spec.go +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "strconv" - "strings" -) - -type engineSpec struct { - etype string - ps resourceSpec - worker resourceSpec - cluster string - queue string - masterResourceRequest string - masterResourceLimit string - workerResourceRequest string - workerResourceLimit string - volume string - imagePullPolicy string - restartPolicy string - extraPypiIndex string - namespace string - minibatchSize int - masterPodPriority string - clusterSpec string - recordsPerTask int -} - -func getEngineSpec(attrs map[string]*attribute) engineSpec { - getInt := func(key string, defaultValue int) int { - if p, ok := attrs[key]; ok { - strVal, _ := p.Value.(string) - intVal, err := strconv.Atoi(strVal) - - if err == nil { - return intVal - } - } - return defaultValue - } - getString := func(key string, defaultValue string) string { - if p, ok := attrs[key]; ok { - strVal, ok := p.Value.(string) - if ok { - // TODO(joyyoj): use the parser to do those validations. 
- if strings.HasPrefix(strVal, "\"") && strings.HasSuffix(strVal, "\"") { - return strVal[1 : len(strVal)-1] - } - return strVal - } - } - return defaultValue - } - - psNum := getInt("ps_num", 1) - psMemory := getInt("ps_memory", 2400) - workerMemory := getInt("worker_memory", 1600) - workerNum := getInt("worker_num", 2) - engineType := getString("type", "local") - if (psNum > 0 || workerNum > 0) && engineType == "local" { - engineType = "yarn" - } - cluster := getString("cluster", "") - queue := getString("queue", "") - - // ElasticDL engine specs - masterResourceRequest := getString("master_resource_request", "cpu=0.1,memory=1024Mi") - masterResourceLimit := getString("master_resource_limit", "") - workerResourceRequest := getString("worker_resource_request", "cpu=1,memory=4096Mi") - workerResourceLimit := getString("worker_resource_limit", "") - volume := getString("volume", "") - imagePullPolicy := getString("image_pull_policy", "Always") - restartPolicy := getString("restart_policy", "Never") - extraPypiIndex := getString("extra_pypi_index", "") - namespace := getString("namespace", "default") - minibatchSize := getInt("minibatch_size", 64) - masterPodPriority := getString("master_pod_priority", "") - clusterSpec := getString("cluster_spec", "") - recordsPerTask := getInt("records_per_task", 100) - - return engineSpec{ - etype: engineType, - ps: resourceSpec{Num: psNum, Memory: psMemory}, - worker: resourceSpec{Num: workerNum, Memory: workerMemory}, - cluster: cluster, - queue: queue, - masterResourceRequest: masterResourceRequest, - masterResourceLimit: masterResourceLimit, - workerResourceRequest: workerResourceRequest, - workerResourceLimit: workerResourceLimit, - volume: volume, - imagePullPolicy: imagePullPolicy, - restartPolicy: restartPolicy, - extraPypiIndex: extraPypiIndex, - namespace: namespace, - minibatchSize: minibatchSize, - masterPodPriority: masterPodPriority, - clusterSpec: clusterSpec, - recordsPerTask: recordsPerTask, - } -} diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index d25dfd426b..675d63bb92 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -17,6 +17,8 @@ import ( "fmt" "strconv" "strings" + + "github.com/sql-machine-learning/sqlflow/sql/columns" ) const ( @@ -61,8 +63,8 @@ type resolvedTrainClause struct { EvalStartDelay int EvalThrottle int EvalCheckpointFilenameForInit string - FeatureColumns map[string][]featureColumn - ColumnSpecs map[string][]*columnSpec + FeatureColumns map[string][]columns.FeatureColumn + ColumnSpecs map[string][]*columns.ColumnSpec EngineParams engineSpec CustomModule *gitLabModule } @@ -75,11 +77,6 @@ type resolvedPredictClause struct { EngineParams engineSpec } -type featureMap struct { - Table string - Partition string -} - func trimQuotes(s string) string { if len(s) >= 2 { if s[0] == '"' && s[len(s)-1] == '"' { @@ -138,8 +135,8 @@ func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) { } return defaultValue } - modelParams := filter(attrs, "model", true) - engineParams := filter(attrs, "engine", true) + modelParams := attrFilter(attrs, "model", true) + engineParams := attrFilter(attrs, "engine", true) batchSize := getIntAttr("train.batch_size", 512) dropRemainder := getBoolAttr("train.drop_remainder", true, false) @@ -203,8 +200,8 @@ func resolveTrainClause(tc *trainClause) (*resolvedTrainClause, error) { return nil, fmt.Errorf("unsupported parameters: %v", attrs) } - fcMap := map[string][]featureColumn{} - csMap := map[string][]*columnSpec{} + fcMap := 
map[string][]columns.FeatureColumn{}
+	csMap := map[string][]*columns.ColumnSpec{}
 	for target, columns := range tc.columns {
 		fcs, css, err := resolveTrainColumns(&columns)
 		if err != nil {
@@ -262,8 +259,8 @@ func resolvePredictClause(pc *predictClause) (*resolvedPredictClause, error) {
 		}
 		return defaultValue
 	}
-	modelParams := filter(attrs, "model", true)
-	engineParams := filter(attrs, "engine", true)
+	modelParams := attrFilter(attrs, "model", true)
+	engineParams := attrFilter(attrs, "engine", true)
 
 	checkpointFilenameForInit := getStringAttr("predict.checkpoint_filename_for_init", "")
@@ -281,14 +278,14 @@ func resolvePredictClause(pc *predictClause) (*resolvedPredictClause, error) {
 
 // resolveTrainColumns resolves columns from the SQL statement,
 // returns the feature column list and the column specs
-func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, error) {
-	var fcs = make([]featureColumn, 0)
-	var css = make([]*columnSpec, 0)
-	for _, expr := range *columns {
+func resolveTrainColumns(columnExprs *exprlist) ([]columns.FeatureColumn, []*columns.ColumnSpec, error) {
+	var fcs = make([]columns.FeatureColumn, 0)
+	var css = make([]*columns.ColumnSpec, 0)
+	for _, expr := range *columnExprs {
 		if expr.typ != 0 {
 			// Column identifier like "COLUMN a1,b1"
 			// FIXME(typhoonzero): infer the column spec here.
-			c := &numericColumn{
+			c := &columns.NumericColumn{
 				Key:   expr.val,
 				Shape: []int{1},
 				Dtype: "float32",
@@ -383,35 +380,17 @@ func expression2string(e interface{}) (string, error) {
 	return "", fmt.Errorf("expression expected to be string, actual: %s", resolved)
 }
 
-func (cs *columnSpec) ToString() string {
-	if cs.IsSparse {
-		shape := strings.Join(strings.Split(fmt.Sprint(cs.Shape), " "), ",")
-		if len(cs.Shape) > 1 {
-			groupCnt := len(cs.Shape)
-			return fmt.Sprintf("GroupedSparseColumn(name=\"%s\", shape=%s, dtype=\"%s\", group=%d, group_separator='\\002')",
-				cs.ColumnName, shape, cs.DType, groupCnt)
-		}
-		return fmt.Sprintf("SparseColumn(name=\"%s\", shape=%s, dtype=\"%s\")", cs.ColumnName, shape, cs.DType)
-
-	}
-	return fmt.Sprintf("DenseColumn(name=\"%s\", shape=%s, dtype=\"%s\", separator=\"%s\")",
-		cs.ColumnName,
-		strings.Join(strings.Split(fmt.Sprint(cs.Shape), " "), ","),
-		cs.DType,
-		cs.Delimiter)
-}
-
-func generateFeatureColumnCode(fcs []featureColumn) (string, error) {
-	var codes = make([]string, 0, len(fcs))
-	for _, fc := range fcs {
-		code, err := fc.GenerateCode()
-		if err != nil {
-			return "", nil
-		}
-		codes = append(codes, code)
-	}
-	return fmt.Sprintf("[%s]", strings.Join(codes, ",")), nil
-}
 
 func resolveDelimiter(delimiter string) (string, error) {
 	if strings.EqualFold(delimiter, comma) {
@@ -431,3 +410,309 @@ func transformToIntList(list []interface{}) ([]int, error) {
 	return b, nil
 }
+
+func resolveAttribute(attrs *attrs) (map[string]*attribute, error) {
+	ret := make(map[string]*attribute)
+	for k, v := range *attrs {
+		subs := strings.SplitN(k, ".", 2)
+		name := subs[len(subs)-1]
+		prefix := ""
+		if len(subs) == 2 {
+			prefix = subs[0]
+		}
+		r, _, err := resolveExpression(v)
+		if err != nil {
+			return nil, err
+		}
+		a := &attribute{
+			FullName: k,
+			Prefix:   prefix,
+			Name:     name,
+			Value:    r}
+		ret[a.FullName] = 
+	}
+	return ret, nil
+}
+
+// resolveBucketColumn parses an expression like BUCKET(NUMERIC(c1, 10), [1, 10])
+// into a BucketColumn.
+func resolveBucketColumn(el *exprlist) (*columns.BucketColumn, error) {
+	if len(*el) != 3 {
+		return nil, fmt.Errorf("bad BUCKET expression format: %s", *el)
+	}
+	sourceExprList := (*el)[1]
+	boundariesExprList := (*el)[2]
+	if sourceExprList.typ != 0 {
+		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, got %v", sourceExprList)
+	}
+	source, _, err := resolveColumn(&sourceExprList.sexp)
+	if err != nil {
+		return nil, err
+	}
+	if source.GetColumnType() != columns.ColumnTypeNumeric {
+		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, got %s", source)
+	}
+	boundaries, _, err := resolveExpression(boundariesExprList)
+	if err != nil {
+		return nil, err
+	}
+	if _, ok := boundaries.([]interface{}); !ok {
+		return nil, fmt.Errorf("bad BUCKET boundaries: %v", boundaries)
+	}
+	b, err := transformToIntList(boundaries.([]interface{}))
+	if err != nil {
+		return nil, fmt.Errorf("bad BUCKET boundaries: %s", err)
+	}
+	return &columns.BucketColumn{
+		SourceColumn: source.(*columns.NumericColumn),
+		Boundaries:   b}, nil
+}
+
+// resolveSeqCategoryIDColumn builds a SequenceCategoryIDColumn from a
+// sequence category ID expression.
+func resolveSeqCategoryIDColumn(el *exprlist) (*columns.SequenceCategoryIDColumn, *columns.ColumnSpec, error) {
+	key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el)
+	if err != nil {
+		return nil, nil, err
+	}
+	return &columns.SequenceCategoryIDColumn{
+		Key:        key,
+		BucketSize: bucketSize,
+		Delimiter:  delimiter,
+		// TODO(typhoonzero): support config dtype
+		Dtype:      "int64",
+		IsSequence: true}, cs, nil
+}
+
+// resolveCategoryIDColumn builds a CategoryIDColumn from an expression like
+// CATEGORY_ID(c1, 100).
+func resolveCategoryIDColumn(el *exprlist) (*columns.CategoryIDColumn, *columns.ColumnSpec, error) {
+	key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el)
+	if err != nil {
+		return nil, nil, err
+	}
+	return &columns.CategoryIDColumn{
+		Key:        key,
+		BucketSize: bucketSize,
+		Delimiter:  delimiter,
+		// TODO(typhoonzero): support config dtype
+		Dtype: "int64"}, cs, nil
+}
+
+// parseCategoryIDColumnExpr extracts the key, bucket size, delimiter and
+// optional column spec shared by the category ID expressions above.
+func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columns.ColumnSpec, error) {
+	if len(*el) != 3 && len(*el) != 4 {
+		return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el)
+	}
+	var cs *columns.ColumnSpec
+	key := ""
+	var err error
+	if (*el)[1].typ == 0 {
+		// exprlist, maybe DENSE/SPARSE expressions
+		subExprList := (*el)[1].sexp
+		isSparse := subExprList[0].val == sparse
+		cs, err = resolveColumnSpec(&subExprList, isSparse)
+		if err != nil {
+			return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %v", subExprList)
+		}
+		key = cs.ColumnName
+	} else {
+		key, err = expression2string((*el)[1])
+		if err != nil {
+			return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err)
+		}
+	}
+	bucketSize, err := strconv.Atoi((*el)[2].val)
+	if err != nil {
+		return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err)
+	}
+	delimiter := ""
+	if len(*el) == 4 {
+		delimiter, err = resolveDelimiter((*el)[3].val)
+		if err != nil {
+			return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err)
+		}
+	}
+	return key, bucketSize, delimiter, cs, nil
+}
+
+// resolveCrossColumn parses an expression like CROSS([c1, c2], 20) into a
+// CrossColumn.
+func resolveCrossColumn(el *exprlist) (*columns.CrossColumn, error) {
+	if len(*el) != 3 {
+		return nil, fmt.Errorf("bad CROSS expression format: %s", *el)
+	}
+	keysExpr := (*el)[1]
+	key, _, err := resolveExpression(keysExpr)
+	if err != nil {
+		return nil, err
+	}
+	if _, ok := key.([]interface{}); !ok {
+		return nil, fmt.Errorf("bad CROSS expression format: %s", *el)
+	}
+
+	bucketSize, err := strconv.Atoi((*el)[2].val)
+	if err != nil {
+		return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err)
+	}
+	return &columns.CrossColumn{
+		Keys:           key.([]interface{}),
+		HashBucketSize: bucketSize}, nil
+}
+
+// resolveEmbeddingColumn parses an expression like
+// EMBEDDING(CATEGORY_ID(c1, 100), 200, mean) into an EmbeddingColumn.
+func resolveEmbeddingColumn(el *exprlist) (*columns.EmbeddingColumn, error) {
+	if len(*el) != 4 && len(*el) != 5 {
+		return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el)
+	}
+	sourceExprList := (*el)[1]
+	var source columns.FeatureColumn
+	var err error
+	if sourceExprList.typ == 0 {
+		source, _, err = resolveColumn(&sourceExprList.sexp)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		return nil, fmt.Errorf("key of EMBEDDING must be a categorical column")
+	}
+	// TODO(uuleon) support other kinds of categorical column in the future
+	var catColumn interface{}
+	catColumn, ok := source.(*columns.CategoryIDColumn)
+	if !ok {
+		catColumn, ok = source.(*columns.SequenceCategoryIDColumn)
+		if !ok {
+			return nil, fmt.Errorf("key of EMBEDDING must be a categorical column")
+		}
+	}
+	dimension, err := strconv.Atoi((*el)[2].val)
+	if err != nil {
+		return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err)
+	}
+	combiner, err := expression2string((*el)[3])
+	if err != nil {
+		return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err)
+	}
+	initializer := ""
+	if len(*el) == 5 {
+		initializer, err = expression2string((*el)[4])
+		if err != nil {
+			return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err)
+		}
+	}
+	return &columns.EmbeddingColumn{
+		CategoryColumn: catColumn,
+		Dimension:      dimension,
+		Combiner:       combiner,
+		Initializer:    initializer}, nil
+}
+
+// resolveNumericColumn parses an expression like NUMERIC(c2, [5, 10]) into a
+// NumericColumn.
+func resolveNumericColumn(el *exprlist) (*columns.NumericColumn, error) {
+	if len(*el) != 3 {
+		return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el)
+	}
+	key, err := expression2string((*el)[1])
+	if err != nil {
+		return nil, fmt.Errorf("bad NUMERIC key: %s, err: %s", (*el)[1], err)
+	}
+	var shape []int
+	intVal, err := strconv.Atoi((*el)[2].val)
+	if err != nil {
+		list, _, err := resolveExpression((*el)[2])
+		if err != nil {
+			return nil, err
+		}
+		if list, ok := list.([]interface{}); ok {
+			shape, err = transformToIntList(list)
+			if err != nil {
+				return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err)
+			}
+		} else {
+			return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err)
+		}
+	} else {
+		shape = append(shape, intVal)
+	}
+	return &columns.NumericColumn{
+		Key:   key,
+		Shape: shape,
+		// FIXME(typhoonzero, tony): support config Delimiter and Dtype
+		Delimiter: ",",
+		Dtype:     "float32"}, nil
+}
+
+// resolveColumnSpec parses a DENSE or SPARSE expression, which describes how
+// to parse the raw column data.
+func resolveColumnSpec(el *exprlist, isSparse bool) (*columns.ColumnSpec, error) {
+	if len(*el) < 4 {
+		return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el)
+	}
+	name, err := expression2string((*el)[1])
+	if err != nil {
+		return nil, fmt.Errorf("bad FeatureSpec name: %s, err: %s", (*el)[1], err)
+	}
+	var shape []int
+	intShape, err := strconv.Atoi((*el)[2].val)
+	if err != nil {
+		strShape, err := expression2string((*el)[2])
+		if err != nil {
+			return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err)
+		}
+		if strShape != "none" {
+			return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err)
+		}
+	} else {
+		shape = append(shape, intShape)
+	}
+	unresolvedDelimiter, err := expression2string((*el)[3])
+	if err != nil {
+		return nil, fmt.Errorf("bad FeatureSpec delimiter: %s, err: %s", (*el)[3], err)
+	}
+
+	delimiter, err := resolveDelimiter(unresolvedDelimiter)
+	if err != nil {
+		return nil, err
+	}
+
+	// resolve feature map
+	fm := columns.FeatureMap{}
+	dtype := "float"
+	if isSparse {
+		dtype = "int"
+	}
+	if len(*el) >= 5 {
+		dtype, err = expression2string((*el)[4])
+		if err != nil {
+			return nil, fmt.Errorf("bad FeatureSpec dtype: %s, err: %s", (*el)[4].val, err)
+		}
+	}
+	return &columns.ColumnSpec{
+		ColumnName: name,
+		IsSparse:   isSparse,
+		Shape:      shape,
+		DType:      dtype,
+		Delimiter:  delimiter,
+		FeatureMap: fm}, nil
+}
+
+// resolveColumn returns the actual feature column typed struct
+// as well as the columnSpec information.
+func resolveColumn(el *exprlist) (columns.FeatureColumn, *columns.ColumnSpec, error) {
+	head := (*el)[0].val
+	if head == "" {
+		return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %v", el)
+	}
+
+	switch strings.ToUpper(head) {
+	case dense:
+		cs, err := resolveColumnSpec(el, false)
+		return nil, cs, err
+	case sparse:
+		cs, err := resolveColumnSpec(el, true)
+		return nil, cs, err
+	case numeric:
+		// TODO(typhoonzero): support NUMERIC(DENSE(col)) and NUMERIC(SPARSE(col))
+		fc, err := resolveNumericColumn(el)
+		return fc, nil, err
+	case bucket:
+		fc, err := resolveBucketColumn(el)
+		return fc, nil, err
+	case cross:
+		fc, err := resolveCrossColumn(el)
+		return fc, nil, err
+	case categoryID:
+		return resolveCategoryIDColumn(el)
+	case seqCategoryID:
+		return resolveSeqCategoryIDColumn(el)
+	case embedding:
+		fc, err := resolveEmbeddingColumn(el)
+		return fc, nil, err
+	default:
+		return nil, nil, fmt.Errorf("not supported expr: %s", head)
+	}
+}
diff --git a/sql/expression_resolver_test.go b/sql/expression_resolver_test.go
index 263bb8f2a6..349bb06c5e 100644
--- a/sql/expression_resolver_test.go
+++ b/sql/expression_resolver_test.go
@@ -17,6 +17,7 @@ import (
 	"fmt"
 	"testing"

+	"github.com/sql-machine-learning/sqlflow/sql/columns"
 	"github.com/stretchr/testify/assert"
 )

@@ -49,3 +50,262 @@ func TestExecResource(t *testing.T) {
 	attr := attrs["exec.worker_num"]
 	a.Equal(attr.Value, "2")
 }
+
+func TestResolveAttrs(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	s := statementWithAttrs("estimator.hidden_units = [10, 20]")
+	r, e := parser.Parse(s)
+	a.NoError(e)
+	attrs, err := resolveAttribute(&r.trainAttrs)
+	a.NoError(err)
+	attr := attrs["estimator.hidden_units"]
+	a.Equal("estimator", attr.Prefix)
+	a.Equal("hidden_units", attr.Name)
+	a.Equal([]interface{}{10, 20}, attr.Value)
+
+	s = statementWithAttrs("dataset.name = hello")
+	r, e = parser.Parse(s)
+	a.NoError(e)
+	attrs, err = resolveAttribute(&r.trainAttrs)
+	a.NoError(err)
+	attr = attrs["dataset.name"]
+	a.Equal("dataset", attr.Prefix)
+	a.Equal("name", attr.Name)
+	a.Equal("hello", attr.Value)
+}
+
+func TestBucketColumn(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])")
+	badInput := statementWithColumn("BUCKET(c1, [1, 10])")
+	badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)")
+
+	r, e := parser.Parse(normal)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcs, _, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	bc, ok := fcs[0].(*columns.BucketColumn)
+	a.True(ok)
+	code, e := bc.GenerateCode(nil)
+	a.NoError(e)
+	a.Equal("c1", bc.SourceColumn.Key)
+	a.Equal([]int{10}, bc.SourceColumn.Shape)
+	a.Equal([]int{1, 10}, bc.Boundaries)
+	a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code[0])
+
+	r, e = parser.Parse(badInput)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+
+	r, e = parser.Parse(badBoundaries)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
+
+func TestCrossColumn(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	normal := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], 20)")
+	badInput := statementWithColumn("cross(c1, 20)")
+	badBucketSize := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], bad)")
+
+	r, e := parser.Parse(normal)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcList, _, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	cc, ok := fcList[0].(*columns.CrossColumn)
+	a.True(ok)
+	code, e := cc.GenerateCode(nil)
+	a.NoError(e)
+
+	bc := cc.Keys[0].(*columns.BucketColumn)
+	a.Equal("c1", bc.SourceColumn.Key)
+	a.Equal([]int{10}, bc.SourceColumn.Shape)
+	a.Equal([]int{1, 10}, bc.Boundaries)
+	a.Equal("c5", cc.Keys[1].(string))
+	a.Equal(20, cc.HashBucketSize)
+	a.Equal("tf.feature_column.crossed_column([tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10]),\"c5\"], hash_bucket_size=20)", code[0])
+
+	r, e = parser.Parse(badInput)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcList, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+
+	r, e = parser.Parse(badBucketSize)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcList, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
+
+func TestCatIdColumn(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	normal := statementWithColumn("CATEGORY_ID(c1, 100)")
+	badKey := statementWithColumn("CATEGORY_ID([100], 100)")
+	badBucket := statementWithColumn("CATEGORY_ID(c1, bad)")
+
+	r, e := parser.Parse(normal)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcs, _, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	cc, ok := fcs[0].(*columns.CategoryIDColumn)
+	a.True(ok)
+	code, e := cc.GenerateCode(nil)
+	a.NoError(e)
+	a.Equal("c1", cc.Key)
+	a.Equal(100, cc.BucketSize)
+	a.Equal("tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100)", code[0])
+
+	r, e = parser.Parse(badKey)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+
+	r, e = parser.Parse(badBucket)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
+
+func TestCatIdColumnWithColumnSpec(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	dense := statementWithColumn("CATEGORY_ID(DENSE(col1, 128, COMMA), 100)")
+
+	r, e := parser.Parse(dense)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcs, css, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	_, ok := fcs[0].(*columns.CategoryIDColumn)
+	a.True(ok)
+	a.Equal(css[0].ColumnName, "col1")
+}
+
+func TestEmbeddingColumn(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	normal := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), 200, mean)")
+	badInput := statementWithColumn("EMBEDDING(c1, 100, mean)")
+	badDimension := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), bad, mean)")
+
+	r, e := parser.Parse(normal)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcs, _, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	ec, ok := fcs[0].(*columns.EmbeddingColumn)
+	a.True(ok)
+	code, e := ec.GenerateCode(nil)
+	a.NoError(e)
+	cc, ok := ec.CategoryColumn.(*columns.CategoryIDColumn)
+	a.True(ok)
+	a.Equal("c1", cc.Key)
+	a.Equal(100, cc.BucketSize)
+	a.Equal(200, ec.Dimension)
+	a.Equal("tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100), dimension=200, combiner=\"mean\")", code[0])
+
+	r, e = parser.Parse(badInput)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+
+	r, e = parser.Parse(badDimension)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
+
+func TestNumericColumn(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	normal := statementWithColumn("NUMERIC(c2, [5, 10])")
+	moreArgs := statementWithColumn("NUMERIC(c1, 100, args)")
+	badShape := statementWithColumn("NUMERIC(c1, bad)")
+
+	r, e := parser.Parse(normal)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcs, _, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	nc, ok := fcs[0].(*columns.NumericColumn)
+	a.True(ok)
+	code, e := nc.GenerateCode(nil)
+	a.NoError(e)
+	a.Equal("c2", nc.Key)
+	a.Equal([]int{5, 10}, nc.Shape)
+	a.Equal("tf.feature_column.numeric_column(\"c2\", shape=[5,10])", code[0])
+
+	r, e = parser.Parse(moreArgs)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+
+	r, e = parser.Parse(badShape)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
+
+func TestFeatureSpec(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	denseStatement := statementWithColumn("DENSE(c2, 5, comma)")
+	sparseStatement := statementWithColumn("SPARSE(c1, 100, comma)")
+	badStatement := statementWithColumn("DENSE(c3, bad, comma)")
+
+	r, e := parser.Parse(denseStatement)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	_, css, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	cs := css[0]
+	a.Equal("c2", cs.ColumnName)
+	a.Equal(5, cs.Shape[0])
+	a.Equal(",", cs.Delimiter)
+	a.Equal(false, cs.IsSparse)
+	a.Equal("DenseColumn(name=\"c2\", shape=[5], dtype=\"float\", separator=\",\")", cs.ToString())
+
+	r, e = parser.Parse(sparseStatement)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	_, css, e = resolveTrainColumns(&c)
+	a.NoError(e)
+	cs = css[0]
+	a.Equal("c1", cs.ColumnName)
+	a.Equal(100, cs.Shape[0])
+	a.Equal(true, cs.IsSparse)
+	a.Equal("SparseColumn(name=\"c1\", shape=[100], dtype=\"int\")", cs.ToString())
+
+	r, e = parser.Parse(badStatement)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	_, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
diff --git a/sql/feature_column.go b/sql/feature_column.go
deleted file mode 100644
index 1be8becfb6..0000000000
--- a/sql/feature_column.go
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2019 The SQLFlow Authors. All rights reserved.
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -package sql - -import ( - "fmt" - "strings" -) - -const ( - columnTypeBucket = 0 - columnTypeEmbedding = 1 - columnTypeNumeric = 2 - columnTypeCategoryID = 3 - columnTypeSeqCategoryID = 3 - columnTypeCross = 4 -) - -// featureColumn is an interface that all types of feature columns -// should follow. featureColumn is used to generate feature column code. -type featureColumn interface { - GenerateCode() (string, error) - // FIXME(typhoonzero): remove delimiter, dtype shape from feature column - // get these from column spec claused or by feature derivation. - GetDelimiter() string - GetDtype() string - GetKey() string - // return input shape json string, like "[2,3]" - GetInputShape() string - GetColumnType() int -} - -// resolveFeatureColumn returns the acutal feature column typed struct -// as well as the columnSpec infomation. -func resolveColumn(el *exprlist) (featureColumn, *columnSpec, error) { - head := (*el)[0].val - if head == "" { - return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %v", el) - } - - switch strings.ToUpper(head) { - case dense: - cs, err := resolveColumnSpec(el, false) - return nil, cs, err - case sparse: - cs, err := resolveColumnSpec(el, true) - return nil, cs, err - case numeric: - // TODO(typhoonzero): support NUMERIC(DENSE(col)) and NUMERIC(SPARSE(col)) - fc, err := resolveNumericColumn(el) - return fc, nil, err - case bucket: - fc, err := resolveBucketColumn(el) - return fc, nil, err - case cross: - fc, err := resolveCrossColumn(el) - return fc, nil, err - case categoryID: - return resolveCategoryIDColumn(el) - case seqCategoryID: - return resolveSeqCategoryIDColumn(el) - case embedding: - fc, err := resolveEmbeddingColumn(el) - return fc, nil, err - default: - return nil, nil, fmt.Errorf("not supported expr: %s", head) - } -} diff --git a/sql/gitlab_module.go b/sql/gitlab_module.go deleted file mode 100644 index aa26d707fd..0000000000 --- a/sql/gitlab_module.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -type gitLabModule struct { - ModuleName string - ProjectName string - Sha string - PrivateToken string - SourceRoot string - GitLabServer string -} diff --git a/sql/numeric_column.go b/sql/numeric_column.go deleted file mode 100644 index 022ebf4204..0000000000 --- a/sql/numeric_column.go +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "encoding/json" - "fmt" - "strconv" - "strings" -) - -type numericColumn struct { - Key string - Shape []int - Delimiter string - Dtype string -} - -func (nc *numericColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key, - strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ",")), nil -} - -func (nc *numericColumn) GetDelimiter() string { - return nc.Delimiter -} - -func (nc *numericColumn) GetDtype() string { - return nc.Dtype -} - -func (nc *numericColumn) GetKey() string { - return nc.Key -} - -func (nc *numericColumn) GetInputShape() string { - jsonBytes, err := json.Marshal(nc.Shape) - if err != nil { - return "" - } - return string(jsonBytes) -} - -func (nc *numericColumn) GetColumnType() int { - return columnTypeNumeric -} - -func resolveNumericColumn(el *exprlist) (*numericColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el) - } - key, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad NUMERIC key: %s, err: %s", (*el)[1], err) - } - var shape []int - intVal, err := strconv.Atoi((*el)[2].val) - if err != nil { - list, _, err := resolveExpression((*el)[2]) - if err != nil { - return nil, err - } - if list, ok := list.([]interface{}); ok { - shape, err = transformToIntList(list) - if err != nil { - return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) - } - } else { - return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) - } - } else { - shape = append(shape, intVal) - } - return &numericColumn{ - Key: key, - Shape: shape, - // FIXME(typhoonzero, tony): support config Delimiter and Dtype - Delimiter: ",", - Dtype: "float32"}, nil -} diff --git a/sql/numeric_column_test.go b/sql/numeric_column_test.go deleted file mode 100644 index 38bd65f05b..0000000000 --- a/sql/numeric_column_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2019 The SQLFlow Authors. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-package sql
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestNumericColumn(t *testing.T) {
-	a := assert.New(t)
-	parser := newParser()
-
-	normal := statementWithColumn("NUMERIC(c2, [5, 10])")
-	moreArgs := statementWithColumn("NUMERIC(c1, 100, args)")
-	badShape := statementWithColumn("NUMERIC(c1, bad)")
-
-	r, e := parser.Parse(normal)
-	a.NoError(e)
-	c := r.columns["feature_columns"]
-	fcs, _, e := resolveTrainColumns(&c)
-	a.NoError(e)
-	nc, ok := fcs[0].(*numericColumn)
-	a.True(ok)
-	code, e := nc.GenerateCode()
-	a.NoError(e)
-	a.Equal("c2", nc.Key)
-	a.Equal([]int{5, 10}, nc.Shape)
-	a.Equal("tf.feature_column.numeric_column(\"c2\", shape=[5,10])", code)
-
-	r, e = parser.Parse(moreArgs)
-	a.NoError(e)
-	c = r.columns["feature_columns"]
-	fcs, _, e = resolveTrainColumns(&c)
-	a.Error(e)
-
-	r, e = parser.Parse(badShape)
-	a.NoError(e)
-	c = r.columns["feature_columns"]
-	fcs, _, e = resolveTrainColumns(&c)
-	a.Error(e)
-}

From 6f8cea1169e5de2a5d854ebede4bd71a584ad26a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?=
Date: Mon, 2 Sep 2019 21:06:02 +0800
Subject: [PATCH 2/2] add columns package

---
 sql/columns/bucket_column.go      |  62 +++++++++++++++++
 sql/columns/category_id_column.go | 108 ++++++++++++++++++++++++++++++
 sql/columns/column_spec.go        |  55 +++++++++++++++
 sql/columns/cross_column.go       |  80 ++++++++++++++++++++++
 sql/columns/embedding_column.go   |  68 +++++++++++++++++++
 sql/columns/feature_column.go     |  46 +++++++++++++
 sql/columns/numeric_column.go     |  63 +++++++++++++++++
 7 files changed, 482 insertions(+)
 create mode 100644 sql/columns/bucket_column.go
 create mode 100644 sql/columns/category_id_column.go
 create mode 100644 sql/columns/column_spec.go
 create mode 100644 sql/columns/cross_column.go
 create mode 100644 sql/columns/embedding_column.go
 create mode 100644 sql/columns/feature_column.go
 create mode 100644 sql/columns/numeric_column.go

diff --git a/sql/columns/bucket_column.go b/sql/columns/bucket_column.go
new file mode 100644
index 0000000000..b0d31f801d
--- /dev/null
+++ b/sql/columns/bucket_column.go
@@ -0,0 +1,62 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package columns
+
+import (
+	"fmt"
+	"strings"
+)
+
+// BucketColumn is the wrapper of `tf.feature_column.bucketized_column`
+type BucketColumn struct {
+	SourceColumn *NumericColumn
+	Boundaries   []int
+}
+
+// GenerateCode implements the FeatureColumn interface.
+func (bc *BucketColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+	sourceCode, err := bc.SourceColumn.GenerateCode(cs)
+	if err != nil {
+		return []string{}, err
+	}
+	if len(sourceCode) > 1 {
+		return []string{}, fmt.Errorf("does not support grouped column: %v", sourceCode)
+	}
+	return []string{fmt.Sprintf(
+		"tf.feature_column.bucketized_column(%s, boundaries=%s)",
+		sourceCode[0],
+		strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ","))}, nil
+}
+
+// GetDelimiter implements the FeatureColumn interface.
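+// A bucketized column consumes no raw input of its own; parsing is driven by
+// the wrapped source NumericColumn, so there is no delimiter to report.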
+func (bc *BucketColumn) GetDelimiter() string { + return "" +} + +// GetDtype implements the FeatureColumn interface. +func (bc *BucketColumn) GetDtype() string { + return "" +} + +// GetKey implements the FeatureColumn interface. +func (bc *BucketColumn) GetKey() string { + return bc.SourceColumn.Key +} + +// GetInputShape implements the FeatureColumn interface. +func (bc *BucketColumn) GetInputShape() string { + return bc.SourceColumn.GetInputShape() +} + +// GetColumnType implements the FeatureColumn interface. +func (bc *BucketColumn) GetColumnType() int { + return ColumnTypeBucket +} diff --git a/sql/columns/category_id_column.go b/sql/columns/category_id_column.go new file mode 100644 index 0000000000..a4ddf224f9 --- /dev/null +++ b/sql/columns/category_id_column.go @@ -0,0 +1,108 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package columns + +import ( + "fmt" +) + +// CategoryIDColumn is the wrapper of `tf.feature_column.categorical_column_with_identity` +type CategoryIDColumn struct { + Key string + BucketSize int + Delimiter string + Dtype string +} + +// SequenceCategoryIDColumn is the wrapper of `tf.feature_column.sequence_categorical_column_with_identity` +// NOTE: only used in tf >= 2.0 versions. +type SequenceCategoryIDColumn struct { + Key string + BucketSize int + Delimiter string + Dtype string + IsSequence bool +} + +// GenerateCode implements the FeatureColumn interface. +func (cc *CategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) { + return []string{fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", + cc.Key, cc.BucketSize)}, nil +} + +// GetDelimiter implements the FeatureColumn interface. +func (cc *CategoryIDColumn) GetDelimiter() string { + return cc.Delimiter +} + +// GetDtype implements the FeatureColumn interface. +func (cc *CategoryIDColumn) GetDtype() string { + return cc.Dtype +} + +// GetKey implements the FeatureColumn interface. +func (cc *CategoryIDColumn) GetKey() string { + return cc.Key +} + +// GetInputShape implements the FeatureColumn interface. +func (cc *CategoryIDColumn) GetInputShape() string { + return fmt.Sprintf("[%d]", cc.BucketSize) +} + +// GetColumnType implements the FeatureColumn interface. +func (cc *CategoryIDColumn) GetColumnType() int { + return ColumnTypeCategoryID +} + +// GenerateCode implements the FeatureColumn interface. +func (cc *SequenceCategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) { + return []string{fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", + cc.Key, cc.BucketSize)}, nil +} + +// GetDelimiter implements the FeatureColumn interface. +func (cc *SequenceCategoryIDColumn) GetDelimiter() string { + return cc.Delimiter +} + +// GetDtype implements the FeatureColumn interface. +func (cc *SequenceCategoryIDColumn) GetDtype() string { + return cc.Dtype +} + +// GetKey implements the FeatureColumn interface. 
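+// The key names the source column in the input data that this feature reads.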
+func (cc *SequenceCategoryIDColumn) GetKey() string {
+	return cc.Key
+}
+
+// GetInputShape implements the FeatureColumn interface.
+func (cc *SequenceCategoryIDColumn) GetInputShape() string {
+	return fmt.Sprintf("[%d]", cc.BucketSize)
+}
+
+// GetColumnType implements the FeatureColumn interface.
+func (cc *SequenceCategoryIDColumn) GetColumnType() int {
+	return ColumnTypeSeqCategoryID
+}
+
+// func parseCategoryColumnKey(el *exprlist) (*columnSpec, error) {
+// 	if (*el)[1].typ == 0 {
+// 		// exprlist, maybe DENSE/SPARSE expressions
+// 		subExprList := (*el)[1].sexp
+// 		isSparse := subExprList[0].val == sparse
+// 		return resolveColumnSpec(&subExprList, isSparse)
+// 	}
+// 	return nil, nil
+// }
diff --git a/sql/columns/column_spec.go b/sql/columns/column_spec.go
new file mode 100644
index 0000000000..069eeaf28a
--- /dev/null
+++ b/sql/columns/column_spec.go
@@ -0,0 +1,55 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package columns
+
+import (
+	"fmt"
+	"strings"
+)
+
+// FeatureMap is only used by codegen_alps; it describes a table that holds
+// column parsing information.
+type FeatureMap struct {
+	Table     string
+	Partition string
+}
+
+// ColumnSpec defines how to generate code that parses column data into a
+// tensor or sparse tensor.
+type ColumnSpec struct {
+	ColumnName string
+	IsSparse   bool
+	Shape      []int
+	DType      string
+	Delimiter  string
+	FeatureMap FeatureMap
+}
+
+// ToString generates the debug string of ColumnSpec
+func (cs *ColumnSpec) ToString() string {
+	if cs.IsSparse {
+		shape := strings.Join(strings.Split(fmt.Sprint(cs.Shape), " "), ",")
+		if len(cs.Shape) > 1 {
+			groupCnt := len(cs.Shape)
+			return fmt.Sprintf("GroupedSparseColumn(name=\"%s\", shape=%s, dtype=\"%s\", group=%d, group_separator='\\002')",
+				cs.ColumnName, shape, cs.DType, groupCnt)
+		}
+		return fmt.Sprintf("SparseColumn(name=\"%s\", shape=%s, dtype=\"%s\")", cs.ColumnName, shape, cs.DType)
+	}
+	return fmt.Sprintf("DenseColumn(name=\"%s\", shape=%s, dtype=\"%s\", separator=\"%s\")",
+		cs.ColumnName,
+		strings.Join(strings.Split(fmt.Sprint(cs.Shape), " "), ","),
+		cs.DType,
+		cs.Delimiter)
+}
diff --git a/sql/columns/cross_column.go b/sql/columns/cross_column.go
new file mode 100644
index 0000000000..a0c0d8a504
--- /dev/null
+++ b/sql/columns/cross_column.go
@@ -0,0 +1,80 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package columns
+
+import (
+	"fmt"
+	"strings"
+)
+
+// CrossColumn is the wrapper of `tf.feature_column.crossed_column`
+// TODO(uuleon) specify the hash_key if needed
+type CrossColumn struct {
+	Keys           []interface{}
+	HashBucketSize int
+}
+
+// GenerateCode implements the FeatureColumn interface.
+func (cc *CrossColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+	var keysGenerated = make([]string, len(cc.Keys))
+	for idx, key := range cc.Keys {
+		if c, ok := key.(FeatureColumn); ok {
+			codeList, err := c.GenerateCode(cs)
+			if err != nil {
+				return []string{}, err
+			}
+			if len(codeList) > 1 {
+				return []string{}, fmt.Errorf("cross column does not support grouped columns as keys")
+			}
+			keysGenerated[idx] = codeList[0]
+			continue
+		}
+		if str, ok := key.(string); ok {
+			keysGenerated[idx] = fmt.Sprintf("\"%s\"", str)
+		} else {
+			return []string{}, fmt.Errorf("cross generate code error, key: %v", key)
+		}
+	}
+	return []string{fmt.Sprintf(
+		"tf.feature_column.crossed_column([%s], hash_bucket_size=%d)",
+		strings.Join(keysGenerated, ","), cc.HashBucketSize)}, nil
+}
+
+// GetDelimiter implements the FeatureColumn interface.
+func (cc *CrossColumn) GetDelimiter() string {
+	return ""
+}
+
+// GetDtype implements the FeatureColumn interface.
+func (cc *CrossColumn) GetDtype() string {
+	return ""
+}
+
+// GetKey implements the FeatureColumn interface.
+func (cc *CrossColumn) GetKey() string {
+	// NOTE: cross column is a feature on multiple column keys.
+	return ""
+}
+
+// GetInputShape implements the FeatureColumn interface.
+func (cc *CrossColumn) GetInputShape() string {
+	// NOTE: return empty since crossed column input shape is already determined
+	// by the two crossed columns.
+	return ""
+}
+
+// GetColumnType implements the FeatureColumn interface.
+func (cc *CrossColumn) GetColumnType() int {
+	return ColumnTypeCross
+}
diff --git a/sql/columns/embedding_column.go b/sql/columns/embedding_column.go
new file mode 100644
index 0000000000..33bf459ff1
--- /dev/null
+++ b/sql/columns/embedding_column.go
@@ -0,0 +1,68 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package columns
+
+import (
+	"fmt"
+)
+
+// EmbeddingColumn is the wrapper of `tf.feature_column.embedding_column`
+type EmbeddingColumn struct {
+	CategoryColumn interface{}
+	Dimension      int
+	Combiner       string
+	Initializer    string
+}
+
+// GetDelimiter implements the FeatureColumn interface.
+func (ec *EmbeddingColumn) GetDelimiter() string {
+	return ec.CategoryColumn.(FeatureColumn).GetDelimiter()
+}
+
+// GetDtype implements the FeatureColumn interface.
+func (ec *EmbeddingColumn) GetDtype() string {
+	return ec.CategoryColumn.(FeatureColumn).GetDtype()
+}
+
+// GetKey implements the FeatureColumn interface.
+func (ec *EmbeddingColumn) GetKey() string {
+	return ec.CategoryColumn.(FeatureColumn).GetKey()
+}
+
+// GetInputShape implements the FeatureColumn interface.
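+// An embedding consumes the same raw input as the categorical column it
+// wraps, so the shape is delegated to the wrapped column.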
+func (ec *EmbeddingColumn) GetInputShape() string {
+	return ec.CategoryColumn.(FeatureColumn).GetInputShape()
+}
+
+// GetColumnType implements the FeatureColumn interface.
+func (ec *EmbeddingColumn) GetColumnType() int {
+	return ColumnTypeEmbedding
+}
+
+// GenerateCode implements the FeatureColumn interface.
+func (ec *EmbeddingColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+	catColumn, ok := ec.CategoryColumn.(FeatureColumn)
+	if !ok {
+		return []string{}, fmt.Errorf("embedding generate code error, input is not a FeatureColumn: %v", ec.CategoryColumn)
+	}
+	sourceCode, err := catColumn.GenerateCode(cs)
+	if err != nil {
+		return []string{}, err
+	}
+	if len(sourceCode) > 1 {
+		return []string{}, fmt.Errorf("does not support grouped column: %v", sourceCode)
+	}
+	return []string{fmt.Sprintf("tf.feature_column.embedding_column(%s, dimension=%d, combiner=\"%s\")",
+		sourceCode[0], ec.Dimension, ec.Combiner)}, nil
+}
diff --git a/sql/columns/feature_column.go b/sql/columns/feature_column.go
new file mode 100644
index 0000000000..245327a2c8
--- /dev/null
+++ b/sql/columns/feature_column.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package columns
+
+const (
+	// ColumnTypeBucket is the `FeatureColumn` of type bucket_column
+	ColumnTypeBucket = 0
+	// ColumnTypeEmbedding is the `FeatureColumn` of type embedding_column
+	ColumnTypeEmbedding = 1
+	// ColumnTypeNumeric is the `FeatureColumn` of type numeric_column
+	ColumnTypeNumeric = 2
+	// ColumnTypeCategoryID is the `FeatureColumn` of type category_id_column
+	ColumnTypeCategoryID = 3
+	// ColumnTypeSeqCategoryID is the `FeatureColumn` of type sequence_category_id_column
+	ColumnTypeSeqCategoryID = 4
+	// ColumnTypeCross is the `FeatureColumn` of type cross_column
+	ColumnTypeCross = 5
+)
+
+// FeatureColumn is an interface that all types of feature columns
+// should follow. FeatureColumn is used to generate feature column code.
+type FeatureColumn interface {
+	// NOTE: submitters need to know the columnSpec when generating
+	// feature_column code, and we may use one compound column's data to generate
+	// multiple feature columns, so GenerateCode returns a list of strings.
+	GenerateCode(cs *ColumnSpec) ([]string, error)
+	// FIXME(typhoonzero): remove delimiter, dtype and shape from the feature column;
+	// get these from the COLUMN spec clause or by feature derivation.
+	GetDelimiter() string
+	GetDtype() string
+	GetKey() string
+	// return input shape json string, like "[2,3]"
+	GetInputShape() string
+	GetColumnType() int
+}
diff --git a/sql/columns/numeric_column.go b/sql/columns/numeric_column.go
new file mode 100644
index 0000000000..399e5712fd
--- /dev/null
+++ b/sql/columns/numeric_column.go
@@ -0,0 +1,63 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package columns
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+)
+
+// NumericColumn is the wrapper of `tf.feature_column.numeric_column`
+type NumericColumn struct {
+	Key       string
+	Shape     []int
+	Delimiter string
+	Dtype     string
+}
+
+// GenerateCode implements the FeatureColumn interface.
+func (nc *NumericColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+	return []string{fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key,
+		strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ","))}, nil
+}
+
+// GetDelimiter implements the FeatureColumn interface.
+func (nc *NumericColumn) GetDelimiter() string {
+	return nc.Delimiter
+}
+
+// GetDtype implements the FeatureColumn interface.
+func (nc *NumericColumn) GetDtype() string {
+	return nc.Dtype
+}
+
+// GetKey implements the FeatureColumn interface.
+func (nc *NumericColumn) GetKey() string {
+	return nc.Key
+}
+
+// GetInputShape implements the FeatureColumn interface.
+func (nc *NumericColumn) GetInputShape() string {
+	jsonBytes, err := json.Marshal(nc.Shape)
+	if err != nil {
+		return ""
+	}
+	return string(jsonBytes)
+}
+
+// GetColumnType implements the FeatureColumn interface.
+func (nc *NumericColumn) GetColumnType() int {
+	return ColumnTypeNumeric
+}
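
Usage sketch (illustration only, not part of the diff; the literal column
name, bucket count, and dimension below are made up): building a feature
column from the new columns package by hand and asking it to generate its
TensorFlow code, mirroring what the tests above do via resolveTrainColumns.

package main

import (
	"fmt"

	"github.com/sql-machine-learning/sqlflow/sql/columns"
)

func main() {
	// Hand-built equivalent of EMBEDDING(CATEGORY_ID(c1, 100), 8, mean).
	fc := &columns.EmbeddingColumn{
		CategoryColumn: &columns.CategoryIDColumn{
			Key:        "c1",
			BucketSize: 100,
			Dtype:      "int64",
		},
		Dimension: 8,
		Combiner:  "mean",
	}
	// GenerateCode receives the resolved ColumnSpec (nil here, as in the
	// tests) and returns one code snippet per generated feature column.
	code, err := fc.GenerateCode(nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(code[0])
	// Prints (one line, wrapped here for readability):
	// tf.feature_column.embedding_column(
	//     tf.feature_column.categorical_column_with_identity(
	//         key="c1", num_buckets=100), dimension=8, combiner="mean")
}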