From 5184c32f44eb42b8fa90d142b98388b62ff516ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Tue, 27 Aug 2019 21:04:21 +0800 Subject: [PATCH 1/8] wip simplify expression resolver --- sql/attribute.go | 83 +++++ sql/bucket_column.go | 77 +++++ sql/category_id_column.go | 125 +++++++ sql/codegen_alps.go | 6 +- sql/column_spec.go | 24 ++ sql/cross_column.go | 88 +++++ sql/embedding_column.go | 95 ++++++ sql/engine_spec.go | 114 +++++++ sql/expression_resolver.go | 571 +------------------------------- sql/expression_resolver_test.go | 16 + sql/feature_column.go | 28 ++ sql/gitlab_module.go | 23 ++ sql/numeric_column.go | 87 +++++ 13 files changed, 771 insertions(+), 566 deletions(-) create mode 100644 sql/attribute.go create mode 100644 sql/bucket_column.go create mode 100644 sql/category_id_column.go create mode 100644 sql/column_spec.go create mode 100644 sql/cross_column.go create mode 100644 sql/embedding_column.go create mode 100644 sql/engine_spec.go create mode 100644 sql/feature_column.go create mode 100644 sql/gitlab_module.go create mode 100644 sql/numeric_column.go diff --git a/sql/attribute.go b/sql/attribute.go new file mode 100644 index 0000000000..2e9e9610e9 --- /dev/null +++ b/sql/attribute.go @@ -0,0 +1,83 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "fmt" + "strconv" + "strings" +) + +type attribute struct { + FullName string + Prefix string + Name string + Value interface{} +} + +func (a *attribute) GenerateCode() (string, error) { + if val, ok := a.Value.(string); ok { + // auto convert to int first. + if _, err := strconv.Atoi(val); err == nil { + return fmt.Sprintf("%s=%s", a.Name, val), nil + } + return fmt.Sprintf("%s=\"%s\"", a.Name, val), nil + } + if val, ok := a.Value.([]interface{}); ok { + intList, err := transformToIntList(val) + if err != nil { + return "", err + } + return fmt.Sprintf("%s=%s", a.Name, + strings.Join(strings.Split(fmt.Sprint(intList), " "), ",")), nil + } + return "", fmt.Errorf("value of attribute must be string or list of int, given %s", a.Value) +} + +func filter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute { + ret := make(map[string]*attribute, 0) + for _, a := range attrs { + if strings.EqualFold(a.Prefix, prefix) { + ret[a.Name] = a + if remove { + delete(attrs, a.FullName) + } + } + } + return ret +} + +func resolveAttribute(attrs *attrs) (map[string]*attribute, error) { + ret := make(map[string]*attribute) + for k, v := range *attrs { + subs := strings.SplitN(k, ".", 2) + name := subs[len(subs)-1] + prefix := "" + if len(subs) == 2 { + prefix = subs[0] + } + r, err := resolveExpression(v) + if err != nil { + return nil, err + } + + a := &attribute{ + FullName: k, + Prefix: prefix, + Name: name, + Value: r} + ret[a.FullName] = a + } + return ret, nil +} diff --git a/sql/bucket_column.go b/sql/bucket_column.go new file mode 100644 index 0000000000..cad9ed53b6 --- /dev/null +++ b/sql/bucket_column.go @@ -0,0 +1,77 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "strings" +) + +type bucketColumn struct { + SourceColumn *numericColumn + Boundaries []int +} + +func (bc *bucketColumn) GenerateCode() (string, error) { + sourceCode, _ := bc.SourceColumn.GenerateCode() + return fmt.Sprintf( + "tf.feature_column.bucketized_column(%s, boundaries=%s)", + sourceCode, + strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")), nil +} + +func (bc *bucketColumn) GetDelimiter() string { + return "" +} + +func (bc *bucketColumn) GetDtype() string { + return "" +} + +func (bc *bucketColumn) GetKey() string { + return bc.SourceColumn.Key +} + +func (bc *bucketColumn) GetInputShape() string { + return bc.SourceColumn.GetInputShape() +} + +func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { + if len(*el) != 3 { + return nil, fmt.Errorf("bad BUCKET expression format: %s", *el) + } + sourceExprList := (*el)[1] + boundariesExprList := (*el)[2] + source, err := resolveExpression(sourceExprList) + if err != nil { + return nil, err + } + if _, ok := source.(*numericColumn); !ok { + return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source) + } + boundaries, err := resolveExpression(boundariesExprList) + if err != nil { + return nil, err + } + if _, ok := boundaries.([]interface{}); !ok { + return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) + } + b, err := transformToIntList(boundaries.([]interface{})) + if err != nil { + return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) + } + return &bucketColumn{ + SourceColumn: source.(*numericColumn), + 
Boundaries: b}, nil +} diff --git a/sql/category_id_column.go b/sql/category_id_column.go new file mode 100644 index 0000000000..91dc41bd01 --- /dev/null +++ b/sql/category_id_column.go @@ -0,0 +1,125 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "strconv" +) + +type categoryIDColumn struct { + Key string + BucketSize int + Delimiter string + Dtype string +} + +type sequenceCategoryIDColumn struct { + Key string + BucketSize int + Delimiter string + Dtype string + IsSequence bool +} + +func (cc *categoryIDColumn) GenerateCode() (string, error) { + return fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", + cc.Key, cc.BucketSize), nil +} + +func (cc *categoryIDColumn) GetDelimiter() string { + return cc.Delimiter +} + +func (cc *categoryIDColumn) GetDtype() string { + return cc.Dtype +} + +func (cc *categoryIDColumn) GetKey() string { + return cc.Key +} + +func (cc *categoryIDColumn) GetInputShape() string { + return fmt.Sprintf("[%d]", cc.BucketSize) +} + +func (cc *sequenceCategoryIDColumn) GenerateCode() (string, error) { + return fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", + cc.Key, cc.BucketSize), nil +} + +func (cc *sequenceCategoryIDColumn) GetDelimiter() string { + return cc.Delimiter +} + +func (cc *sequenceCategoryIDColumn) GetDtype() string { + 
return cc.Dtype +} + +func (cc *sequenceCategoryIDColumn) GetKey() string { + return cc.Key +} + +func (cc *sequenceCategoryIDColumn) GetInputShape() string { + return fmt.Sprintf("[%d]", cc.BucketSize) +} + +func resolveSeqCategoryIDColumn(el *exprlist) (*sequenceCategoryIDColumn, error) { + key, bucketSize, delimiter, err := parseCategoryIDColumnExpr(el) + if err != nil { + return nil, err + } + return &sequenceCategoryIDColumn{ + Key: key, + BucketSize: bucketSize, + Delimiter: delimiter, + // TODO(typhoonzero): support config dtype + Dtype: "int64", + IsSequence: true}, nil +} + +func resolveCategoryIDColumn(el *exprlist) (*categoryIDColumn, error) { + key, bucketSize, delimiter, err := parseCategoryIDColumnExpr(el) + if err != nil { + return nil, err + } + return &categoryIDColumn{ + Key: key, + BucketSize: bucketSize, + Delimiter: delimiter, + // TODO(typhoonzero): support config dtype + Dtype: "int64"}, nil +} + +func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, error) { + if len(*el) != 3 && len(*el) != 4 { + return "", 0, "", fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) + } + key, err := expression2string((*el)[1]) + if err != nil { + return "", 0, "", fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) + } + bucketSize, err := strconv.Atoi((*el)[2].val) + if err != nil { + return "", 0, "", fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err) + } + delimiter := "" + if len(*el) == 4 { + delimiter, err = resolveDelimiter((*el)[3].val) + if err != nil { + return "", 0, "", fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) + } + } + return key, bucketSize, delimiter, nil +} diff --git a/sql/codegen_alps.go b/sql/codegen_alps.go index d4876b5e03..e4b0f387f2 100644 --- a/sql/codegen_alps.go +++ b/sql/codegen_alps.go @@ -663,9 +663,9 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c shape := make([]int, 1) shape[0] = len(fields) if userSpec, ok := 
refColumns[ct.Name()]; ok { - output[ct.Name()] = &columnSpec{ct.Name(), false, false, shape, userSpec.DType, userSpec.Delimiter, *meta.featureMap} + output[ct.Name()] = &columnSpec{ct.Name(), false, shape, userSpec.DType, userSpec.Delimiter, *meta.featureMap} } else { - output[ct.Name()] = &columnSpec{ct.Name(), false, false, shape, "float", ",", *meta.featureMap} + output[ct.Name()] = &columnSpec{ct.Name(), false, shape, "float", ",", *meta.featureMap} } } } @@ -712,7 +712,7 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columnSpec, error) { column, present := output[*name] if !present { shape := make([]int, 0, 1000) - column := &columnSpec{*name, false, true, shape, "int64", "", *meta.featureMap} + column := &columnSpec{*name, true, shape, "int64", "", *meta.featureMap} column.DType = "int64" output[*name] = column } diff --git a/sql/column_spec.go b/sql/column_spec.go new file mode 100644 index 0000000000..e7bd96d477 --- /dev/null +++ b/sql/column_spec.go @@ -0,0 +1,24 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +// columnSpec defines how to generate codes to parse column data to tensor/sparsetensor +type columnSpec struct { + ColumnName string + IsSparse bool + Shape []int + DType string + Delimiter string + FeatureMap featureMap +} diff --git a/sql/cross_column.go b/sql/cross_column.go new file mode 100644 index 0000000000..012a84a124 --- /dev/null +++ b/sql/cross_column.go @@ -0,0 +1,88 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "fmt" + "strconv" + "strings" +) + +// TODO(uuleon) specify the hash_key if needed +type crossColumn struct { + Keys []interface{} + HashBucketSize int +} + +func (cc *crossColumn) GenerateCode() (string, error) { + var keysGenerated = make([]string, len(cc.Keys)) + for idx, key := range cc.Keys { + if c, ok := key.(featureColumn); ok { + code, err := c.GenerateCode() + if err != nil { + return "", err + } + keysGenerated[idx] = code + continue + } + if str, ok := key.(string); ok { + keysGenerated[idx] = fmt.Sprintf("\"%s\"", str) + } else { + return "", fmt.Errorf("cross generate code error, key: %s", key) + } + } + return fmt.Sprintf( + "tf.feature_column.crossed_column([%s], hash_bucket_size=%d)", + strings.Join(keysGenerated, ","), cc.HashBucketSize), nil +} + +func (cc *crossColumn) GetDelimiter() string { + return "" +} + +func (cc *crossColumn) GetDtype() string { + return "" +} + +func (cc *crossColumn) GetKey() string { + // NOTE: cross column is a feature on multiple column keys. + return "" +} + +func (cc *crossColumn) GetInputShape() string { + // NOTE: return empty since crossed column input shape is already determined + // by the two crossed columns. 
+ return "" +} + +func resolveCrossColumn(el *exprlist) (*crossColumn, error) { + if len(*el) != 3 { + return nil, fmt.Errorf("bad CROSS expression format: %s", *el) + } + keysExpr := (*el)[1] + keys, err := resolveExpression(keysExpr) + if err != nil { + return nil, err + } + if _, ok := keys.([]interface{}); !ok { + return nil, fmt.Errorf("bad CROSS keys: %s", err) + } + bucketSize, err := strconv.Atoi((*el)[2].val) + if err != nil { + return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err) + } + return &crossColumn{ + Keys: keys.([]interface{}), + HashBucketSize: bucketSize}, nil +} diff --git a/sql/embedding_column.go b/sql/embedding_column.go new file mode 100644 index 0000000000..09fa06e35a --- /dev/null +++ b/sql/embedding_column.go @@ -0,0 +1,95 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "fmt" + "strconv" +) + +type embeddingColumn struct { + CategoryColumn interface{} + Dimension int + Combiner string + Initializer string +} + +func (ec *embeddingColumn) GetDelimiter() string { + return ec.CategoryColumn.(featureColumn).GetDelimiter() +} + +func (ec *embeddingColumn) GetDtype() string { + return ec.CategoryColumn.(featureColumn).GetDtype() +} + +func (ec *embeddingColumn) GetKey() string { + return ec.CategoryColumn.(featureColumn).GetKey() +} + +func (ec *embeddingColumn) GetInputShape() string { + return ec.CategoryColumn.(featureColumn).GetInputShape() +} + +func (ec *embeddingColumn) GenerateCode() (string, error) { + catColumn, ok := ec.CategoryColumn.(featureColumn) + if !ok { + return "", fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn) + } + sourceCode, err := catColumn.GenerateCode() + if err != nil { + return "", err + } + return fmt.Sprintf("tf.feature_column.embedding_column(%s, dimension=%d, combiner=\"%s\")", + sourceCode, ec.Dimension, ec.Combiner), nil +} + +func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { + if len(*el) != 4 && len(*el) != 5 { + return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) + } + sourceExprList := (*el)[1] + source, err := resolveExpression(sourceExprList) + if err != nil { + return nil, err + } + // TODO(uuleon) support other kinds of categorical column in the future + var catColumn interface{} + catColumn, ok := source.(*categoryIDColumn) + if !ok { + catColumn, ok = source.(*sequenceCategoryIDColumn) + if !ok { + return nil, fmt.Errorf("key of EMBEDDING must be categorical column") + } + } + dimension, err := strconv.Atoi((*el)[2].val) + if err != nil { + return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err) + } + combiner, err := expression2string((*el)[3]) + if err != nil { + return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err) + } + 
initializer := "" + if len(*el) == 5 { + initializer, err = expression2string((*el)[4]) + if err != nil { + return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err) + } + } + return &embeddingColumn{ + CategoryColumn: catColumn, + Dimension: dimension, + Combiner: combiner, + Initializer: initializer}, nil +} diff --git a/sql/engine_spec.go b/sql/engine_spec.go new file mode 100644 index 0000000000..fbcca0e069 --- /dev/null +++ b/sql/engine_spec.go @@ -0,0 +1,114 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "strconv" + "strings" +) + +type engineSpec struct { + etype string + ps resourceSpec + worker resourceSpec + cluster string + queue string + masterResourceRequest string + masterResourceLimit string + workerResourceRequest string + workerResourceLimit string + volume string + imagePullPolicy string + restartPolicy string + extraPypiIndex string + namespace string + minibatchSize int + masterPodPriority string + clusterSpec string + recordsPerTask int +} + +func getEngineSpec(attrs map[string]*attribute) engineSpec { + getInt := func(key string, defaultValue int) int { + if p, ok := attrs[key]; ok { + strVal, _ := p.Value.(string) + intVal, err := strconv.Atoi(strVal) + + if err == nil { + return intVal + } + } + return defaultValue + } + getString := func(key string, defaultValue string) string { + if p, ok := attrs[key]; ok { + strVal, ok := p.Value.(string) + if ok { + // TODO(joyyoj): use the parser to do those validations. + if strings.HasPrefix(strVal, "\"") && strings.HasSuffix(strVal, "\"") { + return strVal[1 : len(strVal)-1] + } + return strVal + } + } + return defaultValue + } + + psNum := getInt("ps_num", 1) + psMemory := getInt("ps_memory", 2400) + workerMemory := getInt("worker_memory", 1600) + workerNum := getInt("worker_num", 2) + engineType := getString("type", "local") + if (psNum > 0 || workerNum > 0) && engineType == "local" { + engineType = "yarn" + } + cluster := getString("cluster", "") + queue := getString("queue", "") + + // ElasticDL engine specs + masterResourceRequest := getString("master_resource_request", "cpu=0.1,memory=1024Mi") + masterResourceLimit := getString("master_resource_limit", "") + workerResourceRequest := getString("worker_resource_request", "cpu=1,memory=4096Mi") + workerResourceLimit := getString("worker_resource_limit", "") + volume := getString("volume", "") + imagePullPolicy := getString("image_pull_policy", "Always") + restartPolicy := getString("restart_policy", "Never") + 
extraPypiIndex := getString("extra_pypi_index", "") + namespace := getString("namespace", "default") + minibatchSize := getInt("minibatch_size", 64) + masterPodPriority := getString("master_pod_priority", "") + clusterSpec := getString("cluster_spec", "") + recordsPerTask := getInt("records_per_task", 100) + + return engineSpec{ + etype: engineType, + ps: resourceSpec{Num: psNum, Memory: psMemory}, + worker: resourceSpec{Num: workerNum, Memory: workerMemory}, + cluster: cluster, + queue: queue, + masterResourceRequest: masterResourceRequest, + masterResourceLimit: masterResourceLimit, + workerResourceRequest: workerResourceRequest, + workerResourceLimit: workerResourceLimit, + volume: volume, + imagePullPolicy: imagePullPolicy, + restartPolicy: restartPolicy, + extraPypiIndex: extraPypiIndex, + namespace: namespace, + minibatchSize: minibatchSize, + masterPodPriority: masterPodPriority, + clusterSpec: clusterSpec, + recordsPerTask: recordsPerTask, + } +} diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index 8da0b3b7db..bf7ed9098c 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -14,7 +14,6 @@ package sql import ( - "encoding/json" "fmt" "strconv" "strings" @@ -39,36 +38,6 @@ type resourceSpec struct { Core int } -type engineSpec struct { - etype string - ps resourceSpec - worker resourceSpec - cluster string - queue string - masterResourceRequest string - masterResourceLimit string - workerResourceRequest string - workerResourceLimit string - volume string - imagePullPolicy string - restartPolicy string - extraPypiIndex string - namespace string - minibatchSize int - masterPodPriority string - clusterSpec string - recordsPerTask int -} - -type gitLabModule struct { - ModuleName string - ProjectName string - Sha string - PrivateToken string - SourceRoot string - GitLabServer string -} - type resolvedTrainClause struct { IsPreMadeModel bool ModelName string @@ -106,157 +75,11 @@ type resolvedPredictClause struct { 
EngineParams engineSpec } -// featureColumn is an interface that all types of feature columns and -// attributes (WITH clause) should follow. -// featureColumn is used to generate feature column code. -type featureColumn interface { - GenerateCode() (string, error) - // Some feature columns accept input tensors directly, and the data - // may be a tensor string like: 12,32,4,58,0,0 - GetDelimiter() string - GetDtype() string - GetKey() string - // return input shape json string, like "[2,3]" - GetInputShape() string -} - type featureMap struct { Table string Partition string } -// featureSpec contains information to generate DENSE/SPARSE code -type columnSpec struct { - ColumnName string - AutoDerivation bool - IsSparse bool - Shape []int - DType string - Delimiter string - FeatureMap featureMap -} - -type attribute struct { - FullName string - Prefix string - Name string - Value interface{} -} - -type numericColumn struct { - Key string - Shape []int - Delimiter string - Dtype string -} - -type bucketColumn struct { - SourceColumn *numericColumn - Boundaries []int -} - -// TODO(uuleon) specify the hash_key if needed -type crossColumn struct { - Keys []interface{} - HashBucketSize int -} - -type categoryIDColumn struct { - Key string - BucketSize int - Delimiter string - Dtype string -} - -type sequenceCategoryIDColumn struct { - Key string - BucketSize int - Delimiter string - Dtype string - IsSequence bool -} - -type embeddingColumn struct { - CategoryColumn interface{} - Dimension int - Combiner string - Initializer string -} - -func getEngineSpec(attrs map[string]*attribute) engineSpec { - getInt := func(key string, defaultValue int) int { - if p, ok := attrs[key]; ok { - strVal, _ := p.Value.(string) - intVal, err := strconv.Atoi(strVal) - - if err == nil { - return intVal - } - } - return defaultValue - } - getString := func(key string, defaultValue string) string { - if p, ok := attrs[key]; ok { - strVal, ok := p.Value.(string) - if ok { - // TODO(joyyoj): 
use the parser to do those validations. - if strings.HasPrefix(strVal, "\"") && strings.HasSuffix(strVal, "\"") { - return strVal[1 : len(strVal)-1] - } - return strVal - } - } - return defaultValue - } - - psNum := getInt("ps_num", 1) - psMemory := getInt("ps_memory", 2400) - workerMemory := getInt("worker_memory", 1600) - workerNum := getInt("worker_num", 2) - engineType := getString("type", "local") - if (psNum > 0 || workerNum > 0) && engineType == "local" { - engineType = "yarn" - } - cluster := getString("cluster", "") - queue := getString("queue", "") - - // ElasticDL engine specs - masterResourceRequest := getString("master_resource_request", "cpu=0.1,memory=1024Mi") - masterResourceLimit := getString("master_resource_limit", "") - workerResourceRequest := getString("worker_resource_request", "cpu=1,memory=4096Mi") - workerResourceLimit := getString("worker_resource_limit", "") - volume := getString("volume", "") - imagePullPolicy := getString("image_pull_policy", "Always") - restartPolicy := getString("restart_policy", "Never") - extraPypiIndex := getString("extra_pypi_index", "") - namespace := getString("namespace", "default") - minibatchSize := getInt("minibatch_size", 64) - masterPodPriority := getString("master_pod_priority", "") - clusterSpec := getString("cluster_spec", "") - recordsPerTask := getInt("records_per_task", 100) - - return engineSpec{ - etype: engineType, - ps: resourceSpec{Num: psNum, Memory: psMemory}, - worker: resourceSpec{Num: workerNum, Memory: workerMemory}, - cluster: cluster, - queue: queue, - masterResourceRequest: masterResourceRequest, - masterResourceLimit: masterResourceLimit, - workerResourceRequest: workerResourceRequest, - workerResourceLimit: workerResourceLimit, - volume: volume, - imagePullPolicy: imagePullPolicy, - restartPolicy: restartPolicy, - extraPypiIndex: extraPypiIndex, - namespace: namespace, - minibatchSize: minibatchSize, - masterPodPriority: masterPodPriority, - clusterSpec: clusterSpec, - 
 recordsPerTask: recordsPerTask, - } -} - func trimQuotes(s string) string { if len(s) >= 2 { if s[0] == '"' && s[len(s)-1] == '"' { @@ -535,9 +358,9 @@ func resolveExpression(e interface{}) (interface{}, error) { case cross: return resolveCrossColumn(el) case categoryID: - return resolveCategoryIDColumn(el, false) + return resolveCategoryIDColumn(el) case seqCategoryID: - return resolveCategoryIDColumn(el, true) + return resolveSeqCategoryIDColumn(el) case embedding: return resolveEmbeddingColumn(el) case square: @@ -607,13 +430,12 @@ func resolveColumnSpec(el *exprlist, isSparse bool) (*columnSpec, error) { dtype, err = expression2string((*el)[4]) } return &columnSpec{ - ColumnName: name, - AutoDerivation: false, - IsSparse: isSparse, - Shape: shape, - DType: dtype, - Delimiter: delimiter, - FeatureMap: fm}, nil + ColumnName: name, + IsSparse: isSparse, + Shape: shape, + DType: dtype, + Delimiter: delimiter, + FeatureMap: fm}, nil } func expression2string(e interface{}) (string, error) { @@ -646,311 +468,6 @@ func (cs *columnSpec) ToString() string { cs.Delimiter) } -func resolveNumericColumn(el *exprlist) (*numericColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el) - } - key, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad NUMERIC key: %s, err: %s", (*el)[1], err) - } - var shape []int - intVal, err := strconv.Atoi((*el)[2].val) - if err != nil { - list, err := resolveExpression((*el)[2]) - if err != nil { - return nil, err - } - if list, ok := list.([]interface{}); ok { - shape, err = transformToIntList(list) - if err != nil { - return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) - } - } else { - return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) - } - } else { - shape = append(shape, intVal) - } - return &numericColumn{ - Key: key, - Shape: shape, - // FIXME(typhoonzero, tony): support config Delimiter and Dtype - Delimiter: 
",", - Dtype: "float32"}, nil -} - -func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad BUCKET expression format: %s", *el) - } - sourceExprList := (*el)[1] - boundariesExprList := (*el)[2] - source, err := resolveExpression(sourceExprList) - if err != nil { - return nil, err - } - if _, ok := source.(*numericColumn); !ok { - return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source) - } - boundaries, err := resolveExpression(boundariesExprList) - if err != nil { - return nil, err - } - if _, ok := boundaries.([]interface{}); !ok { - return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) - } - b, err := transformToIntList(boundaries.([]interface{})) - if err != nil { - return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) - } - return &bucketColumn{ - SourceColumn: source.(*numericColumn), - Boundaries: b}, nil -} - -func resolveCrossColumn(el *exprlist) (*crossColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad CROSS expression format: %s", *el) - } - keysExpr := (*el)[1] - keys, err := resolveExpression(keysExpr) - if err != nil { - return nil, err - } - if _, ok := keys.([]interface{}); !ok { - return nil, fmt.Errorf("bad CROSS keys: %s", err) - } - bucketSize, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err) - } - return &crossColumn{ - Keys: keys.([]interface{}), - HashBucketSize: bucketSize}, nil -} - -func resolveCategoryIDColumn(el *exprlist, isSequence bool) (interface{}, error) { - if len(*el) != 3 && len(*el) != 4 { - return nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) - } - key, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) - } - bucketSize, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, 
err) - } - delimiter := "" - if len(*el) == 4 { - delimiter, err = resolveDelimiter((*el)[3].val) - if err != nil { - return nil, fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) - } - } - if isSequence { - return &sequenceCategoryIDColumn{ - Key: key, - BucketSize: bucketSize, - Delimiter: delimiter, - // TODO(typhoonzero): support config dtype - Dtype: "int64", - IsSequence: true}, nil - } - return &categoryIDColumn{ - Key: key, - BucketSize: bucketSize, - Delimiter: delimiter, - // TODO(typhoonzero): support config dtype - Dtype: "int64"}, nil -} - -func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { - if len(*el) != 4 && len(*el) != 5 { - return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) - } - sourceExprList := (*el)[1] - source, err := resolveExpression(sourceExprList) - if err != nil { - return nil, err - } - // TODO(uuleon) support other kinds of categorical column in the future - var catColumn interface{} - catColumn, ok := source.(*categoryIDColumn) - if !ok { - catColumn, ok = source.(*sequenceCategoryIDColumn) - if !ok { - return nil, fmt.Errorf("key of EMBEDDING must be categorical column") - } - } - dimension, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err) - } - combiner, err := expression2string((*el)[3]) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err) - } - initializer := "" - if len(*el) == 5 { - initializer, err = expression2string((*el)[4]) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err) - } - } - return &embeddingColumn{ - CategoryColumn: catColumn, - Dimension: dimension, - Combiner: combiner, - Initializer: initializer}, nil -} - -func (nc *numericColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key, - 
strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ",")), nil -} - -func (nc *numericColumn) GetDelimiter() string { - return nc.Delimiter -} - -func (nc *numericColumn) GetDtype() string { - return nc.Dtype -} - -func (nc *numericColumn) GetKey() string { - return nc.Key -} - -func (nc *numericColumn) GetInputShape() string { - jsonBytes, err := json.Marshal(nc.Shape) - if err != nil { - return "" - } - return string(jsonBytes) -} - -func (bc *bucketColumn) GenerateCode() (string, error) { - sourceCode, _ := bc.SourceColumn.GenerateCode() - return fmt.Sprintf( - "tf.feature_column.bucketized_column(%s, boundaries=%s)", - sourceCode, - strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")), nil -} - -func (bc *bucketColumn) GetDelimiter() string { - return "" -} - -func (bc *bucketColumn) GetDtype() string { - return "" -} - -func (bc *bucketColumn) GetKey() string { - return bc.SourceColumn.Key -} - -func (bc *bucketColumn) GetInputShape() string { - return bc.SourceColumn.GetInputShape() -} - -func (cc *crossColumn) GenerateCode() (string, error) { - var keysGenerated = make([]string, len(cc.Keys)) - for idx, key := range cc.Keys { - if c, ok := key.(featureColumn); ok { - code, err := c.GenerateCode() - if err != nil { - return "", err - } - keysGenerated[idx] = code - continue - } - if str, ok := key.(string); ok { - keysGenerated[idx] = fmt.Sprintf("\"%s\"", str) - } else { - return "", fmt.Errorf("cross generate code error, key: %s", key) - } - } - return fmt.Sprintf( - "tf.feature_column.crossed_column([%s], hash_bucket_size=%d)", - strings.Join(keysGenerated, ","), cc.HashBucketSize), nil -} - -func (cc *crossColumn) GetDelimiter() string { - return "" -} - -func (cc *crossColumn) GetDtype() string { - return "" -} - -func (cc *crossColumn) GetKey() string { - // NOTE: cross column is a feature on multiple column keys. 
- return "" -} - -func (cc *crossColumn) GetInputShape() string { - // NOTE: return empty since crossed column input shape is already determined - // by the two crossed columns. - return "" -} - -func (cc *categoryIDColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", - cc.Key, cc.BucketSize), nil -} - -func (cc *categoryIDColumn) GetDelimiter() string { - return cc.Delimiter -} - -func (cc *categoryIDColumn) GetDtype() string { - return cc.Dtype -} - -func (cc *categoryIDColumn) GetKey() string { - return cc.Key -} - -func (cc *categoryIDColumn) GetInputShape() string { - return fmt.Sprintf("[%d]", cc.BucketSize) -} - -func (cc *sequenceCategoryIDColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", - cc.Key, cc.BucketSize), nil -} - -func (cc *sequenceCategoryIDColumn) GetDelimiter() string { - return cc.Delimiter -} - -func (cc *sequenceCategoryIDColumn) GetDtype() string { - return cc.Dtype -} - -func (cc *sequenceCategoryIDColumn) GetKey() string { - return cc.Key -} - -func (cc *sequenceCategoryIDColumn) GetInputShape() string { - return fmt.Sprintf("[%d]", cc.BucketSize) -} - -func (ec *embeddingColumn) GenerateCode() (string, error) { - catColumn, ok := ec.CategoryColumn.(featureColumn) - if !ok { - return "", fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn) - } - sourceCode, err := catColumn.GenerateCode() - if err != nil { - return "", err - } - return fmt.Sprintf("tf.feature_column.embedding_column(%s, dimension=%d, combiner=\"%s\")", - sourceCode, ec.Dimension, ec.Combiner), nil -} - func generateFeatureColumnCode(fcs []featureColumn) (string, error) { var codes = make([]string, 0, len(fcs)) for _, fc := range fcs { @@ -963,46 +480,6 @@ func generateFeatureColumnCode(fcs []featureColumn) (string, error) { return 
fmt.Sprintf("[%s]", strings.Join(codes, ",")), nil } -func (ec *embeddingColumn) GetDelimiter() string { - return ec.CategoryColumn.(featureColumn).GetDelimiter() -} - -func (ec *embeddingColumn) GetDtype() string { - return ec.CategoryColumn.(featureColumn).GetDtype() -} - -func (ec *embeddingColumn) GetKey() string { - return ec.CategoryColumn.(featureColumn).GetKey() -} - -func (ec *embeddingColumn) GetInputShape() string { - return ec.CategoryColumn.(featureColumn).GetInputShape() -} - -func resolveAttribute(attrs *attrs) (map[string]*attribute, error) { - ret := make(map[string]*attribute) - for k, v := range *attrs { - subs := strings.SplitN(k, ".", 2) - name := subs[len(subs)-1] - prefix := "" - if len(subs) == 2 { - prefix = subs[0] - } - r, err := resolveExpression(v) - if err != nil { - return nil, err - } - - a := &attribute{ - FullName: k, - Prefix: prefix, - Name: name, - Value: r} - ret[a.FullName] = a - } - return ret, nil -} - func resolveDelimiter(delimiter string) (string, error) { if strings.EqualFold(delimiter, comma) { return ",", nil @@ -1010,25 +487,6 @@ func resolveDelimiter(delimiter string) (string, error) { return "", fmt.Errorf("unsolved delimiter: %s", delimiter) } -func (a *attribute) GenerateCode() (string, error) { - if val, ok := a.Value.(string); ok { - // auto convert to int first. 
- if _, err := strconv.Atoi(val); err == nil { - return fmt.Sprintf("%s=%s", a.Name, val), nil - } - return fmt.Sprintf("%s=\"%s\"", a.Name, val), nil - } - if val, ok := a.Value.([]interface{}); ok { - intList, err := transformToIntList(val) - if err != nil { - return "", err - } - return fmt.Sprintf("%s=%s", a.Name, - strings.Join(strings.Split(fmt.Sprint(intList), " "), ",")), nil - } - return "", fmt.Errorf("value of attribute must be string or list of int, given %s", a.Value) -} - func transformToIntList(list []interface{}) ([]int, error) { var b = make([]int, len(list)) for idx, item := range list { @@ -1040,16 +498,3 @@ func transformToIntList(list []interface{}) ([]int, error) { } return b, nil } - -func filter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute { - ret := make(map[string]*attribute, 0) - for _, a := range attrs { - if strings.EqualFold(a.Prefix, prefix) { - ret[a.Name] = a - if remove { - delete(attrs, a.FullName) - } - } - } - return ret -} diff --git a/sql/expression_resolver_test.go b/sql/expression_resolver_test.go index 29b70506d1..0c7ca726c6 100644 --- a/sql/expression_resolver_test.go +++ b/sql/expression_resolver_test.go @@ -292,3 +292,19 @@ func TestExecResource(t *testing.T) { fmt.Println(attr) } + +func TestCatIdColumnWithColumnSpec(t *testing.T) { + a := assert.New(t) + parser := newParser() + + dense := statementWithColumn("CATEGORY_ID(DENSE(col1, 128), 100)") + // sparse := statementWithColumn("CATEGORY_ID(SPARSE(col2, 1000, COMMA))") + + r, e := parser.Parse(dense) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + _, ok := fcs[0].(*categoryIDColumn) + a.True(ok) +} diff --git a/sql/feature_column.go b/sql/feature_column.go new file mode 100644 index 0000000000..f8f5e17f69 --- /dev/null +++ b/sql/feature_column.go @@ -0,0 +1,28 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. 
+// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +// featureColumn is an interface that all types of feature columns and +// attributes (WITH clause) should follow. +// featureColumn is used to generate feature column code. +type featureColumn interface { + GenerateCode() (string, error) + // Some feature columns accept input tensors directly, and the data + // may be a tensor string like: 12,32,4,58,0,0 + GetDelimiter() string + GetDtype() string + GetKey() string + // return input shape json string, like "[2,3]" + GetInputShape() string +} diff --git a/sql/gitlab_module.go b/sql/gitlab_module.go new file mode 100644 index 0000000000..aa26d707fd --- /dev/null +++ b/sql/gitlab_module.go @@ -0,0 +1,23 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +type gitLabModule struct { + ModuleName string + ProjectName string + Sha string + PrivateToken string + SourceRoot string + GitLabServer string +} diff --git a/sql/numeric_column.go b/sql/numeric_column.go new file mode 100644 index 0000000000..906a64041f --- /dev/null +++ b/sql/numeric_column.go @@ -0,0 +1,87 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" +) + +type numericColumn struct { + Key string + Shape []int + Delimiter string + Dtype string +} + +func (nc *numericColumn) GenerateCode() (string, error) { + return fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key, + strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ",")), nil +} + +func (nc *numericColumn) GetDelimiter() string { + return nc.Delimiter +} + +func (nc *numericColumn) GetDtype() string { + return nc.Dtype +} + +func (nc *numericColumn) GetKey() string { + return nc.Key +} + +func (nc *numericColumn) GetInputShape() string { + jsonBytes, err := json.Marshal(nc.Shape) + if err != nil { + return "" + } + return string(jsonBytes) +} + +func resolveNumericColumn(el *exprlist) (*numericColumn, error) { + if len(*el) != 3 { + return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el) + } + key, err := expression2string((*el)[1]) + if err != nil { + return nil, fmt.Errorf("bad NUMERIC key: %s, err: %s", 
(*el)[1], err) + } + var shape []int + intVal, err := strconv.Atoi((*el)[2].val) + if err != nil { + list, err := resolveExpression((*el)[2]) + if err != nil { + return nil, err + } + if list, ok := list.([]interface{}); ok { + shape, err = transformToIntList(list) + if err != nil { + return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) + } + } else { + return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) + } + } else { + shape = append(shape, intVal) + } + return &numericColumn{ + Key: key, + Shape: shape, + // FIXME(typhoonzero, tony): support config Delimiter and Dtype + Delimiter: ",", + Dtype: "float32"}, nil +} From 34304617a1e0844bae814fb9109757440cdbc73b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Wed, 28 Aug 2019 13:41:23 +0800 Subject: [PATCH 2/8] fix ci --- sql/expression_resolver.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index bf7ed9098c..a573c94526 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -358,9 +358,9 @@ func resolveExpression(e interface{}) (interface{}, error) { case cross: return resolveCrossColumn(el) case categoryID: - return resolveSeqCategoryIDColumn(el) - case seqCategoryID: return resolveCategoryIDColumn(el) + case seqCategoryID: + return resolveSeqCategoryIDColumn(el) case embedding: return resolveEmbeddingColumn(el) case square: From 8c19863896fcf46b423b731220d61804bb663d6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Wed, 28 Aug 2019 23:15:26 +0800 Subject: [PATCH 3/8] wip --- sql/attribute.go | 4 +- sql/bucket_column.go | 11 +++- sql/category_id_column.go | 63 +++++++++++++----- sql/column_spec.go | 54 +++++++++++++++ sql/cross_column.go | 6 +- sql/embedding_column.go | 6 +- sql/expression_resolver.go | 132 ++++++++++--------------------------- sql/feature_column.go | 61 +++++++++++++++-- sql/numeric_column.go | 6 +- 9 files 
changed, 217 insertions(+), 126 deletions(-) diff --git a/sql/attribute.go b/sql/attribute.go index 2e9e9610e9..20dc128074 100644 --- a/sql/attribute.go +++ b/sql/attribute.go @@ -67,11 +67,11 @@ func resolveAttribute(attrs *attrs) (map[string]*attribute, error) { if len(subs) == 2 { prefix = subs[0] } - r, err := resolveExpression(v) + r, err := resolveLispExpression(v) if err != nil { + fmt.Printf("%v", err) return nil, err } - a := &attribute{ FullName: k, Prefix: prefix, diff --git a/sql/bucket_column.go b/sql/bucket_column.go index cad9ed53b6..39a634117b 100644 --- a/sql/bucket_column.go +++ b/sql/bucket_column.go @@ -47,20 +47,24 @@ func (bc *bucketColumn) GetInputShape() string { return bc.SourceColumn.GetInputShape() } +func (bc *bucketColumn) GetColumnType() int { + return columnTypeBucket +} + func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { if len(*el) != 3 { return nil, fmt.Errorf("bad BUCKET expression format: %s", *el) } sourceExprList := (*el)[1] boundariesExprList := (*el)[2] - source, err := resolveExpression(sourceExprList) + source, _, err := resolveColumn(&sourceExprList.sexp) if err != nil { return nil, err } - if _, ok := source.(*numericColumn); !ok { + if source.GetColumnType() != columnTypeNumeric { return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source) } - boundaries, err := resolveExpression(boundariesExprList) + boundaries, err := resolveLispExpression(boundariesExprList) if err != nil { return nil, err } @@ -72,6 +76,7 @@ func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) } return &bucketColumn{ + // SourceColumn: source.(*numericColumn), SourceColumn: source.(*numericColumn), Boundaries: b}, nil } diff --git a/sql/category_id_column.go b/sql/category_id_column.go index 91dc41bd01..a1bf193493 100644 --- a/sql/category_id_column.go +++ b/sql/category_id_column.go @@ -54,6 +54,10 @@ func (cc *categoryIDColumn) GetInputShape() string 
{ return fmt.Sprintf("[%d]", cc.BucketSize) } +func (cc *categoryIDColumn) GetColumnType() int { + return columnTypeCategoryID +} + func (cc *sequenceCategoryIDColumn) GenerateCode() (string, error) { return fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", cc.Key, cc.BucketSize), nil @@ -75,10 +79,24 @@ func (cc *sequenceCategoryIDColumn) GetInputShape() string { return fmt.Sprintf("[%d]", cc.BucketSize) } -func resolveSeqCategoryIDColumn(el *exprlist) (*sequenceCategoryIDColumn, error) { - key, bucketSize, delimiter, err := parseCategoryIDColumnExpr(el) +func (cc *sequenceCategoryIDColumn) GetColumnType() int { + return columnTypeSeqCategoryID +} + +func parseCategoryColumnKey(el *exprlist) (*columnSpec, error) { + if (*el)[1].typ == 0 { + // explist, maybe DENSE/SPARSE expressions + subExprList := (*el)[1].sexp + isSparse := subExprList[0].val == sparse + return resolveColumnSpec(&subExprList, isSparse) + } + return nil, nil +} + +func resolveSeqCategoryIDColumn(el *exprlist) (*sequenceCategoryIDColumn, *columnSpec, error) { + key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el) if err != nil { - return nil, err + return nil, nil, err } return &sequenceCategoryIDColumn{ Key: key, @@ -86,40 +104,53 @@ func resolveSeqCategoryIDColumn(el *exprlist) (*sequenceCategoryIDColumn, error) Delimiter: delimiter, // TODO(typhoonzero): support config dtype Dtype: "int64", - IsSequence: true}, nil + IsSequence: true}, cs, nil } -func resolveCategoryIDColumn(el *exprlist) (*categoryIDColumn, error) { - key, bucketSize, delimiter, err := parseCategoryIDColumnExpr(el) +func resolveCategoryIDColumn(el *exprlist) (*categoryIDColumn, *columnSpec, error) { + key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el) if err != nil { - return nil, err + return nil, nil, err } return &categoryIDColumn{ Key: key, BucketSize: bucketSize, Delimiter: delimiter, // TODO(typhoonzero): support config dtype - Dtype: 
"int64"}, nil + Dtype: "int64"}, cs, nil } -func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, error) { +func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columnSpec, error) { if len(*el) != 3 && len(*el) != 4 { - return "", 0, "", fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) } - key, err := expression2string((*el)[1]) - if err != nil { - return "", 0, "", fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) + var cs *columnSpec + key := "" + if (*el)[1].typ == 0 { + // explist, maybe DENSE/SPARSE expressions + subExprList := (*el)[1].sexp + isSparse := subExprList[0].val == sparse + cs, err := resolveColumnSpec(&subExprList, isSparse) + if err != nil { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) + } + key = cs.ColumnName + } else { + key, err := expression2string((*el)[1]) + if err != nil { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) + } } bucketSize, err := strconv.Atoi((*el)[2].val) if err != nil { - return "", 0, "", fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err) + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err) } delimiter := "" if len(*el) == 4 { delimiter, err = resolveDelimiter((*el)[3].val) if err != nil { - return "", 0, "", fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) } } - return key, bucketSize, delimiter, nil + return key, bucketSize, delimiter, cs, nil } diff --git a/sql/column_spec.go b/sql/column_spec.go index e7bd96d477..0a8662a925 100644 --- a/sql/column_spec.go +++ b/sql/column_spec.go @@ -13,6 +13,11 @@ package sql +import ( + "fmt" + "strconv" +) + // columnSpec defines how to generate codes to parse column data to tensor/sparsetensor type 
columnSpec struct { ColumnName string @@ -22,3 +27,52 @@ type columnSpec struct { Delimiter string FeatureMap featureMap } + +func resolveColumnSpec(el *exprlist, isSparse bool) (*columnSpec, error) { + if len(*el) < 4 { + return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el) + } + name, err := expression2string((*el)[1]) + if err != nil { + return nil, fmt.Errorf("bad FeatureSpec name: %s, err: %s", (*el)[1], err) + } + var shape []int + intShape, err := strconv.Atoi((*el)[2].val) + if err != nil { + strShape, err := expression2string((*el)[2]) + if err != nil { + return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) + } + if strShape != "none" { + return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) + } + } else { + shape = append(shape, intShape) + } + unresolvedDelimiter, err := expression2string((*el)[3]) + if err != nil { + return nil, fmt.Errorf("bad FeatureSpec delimiter: %s, err: %s", (*el)[1], err) + } + + delimiter, err := resolveDelimiter(unresolvedDelimiter) + if err != nil { + return nil, err + } + + // resolve feature map + fm := featureMap{} + dtype := "float" + if isSparse { + dtype = "int" + } + if len(*el) >= 5 { + dtype, err = expression2string((*el)[4]) + } + return &columnSpec{ + ColumnName: name, + IsSparse: isSparse, + Shape: shape, + DType: dtype, + Delimiter: delimiter, + FeatureMap: fm}, nil +} diff --git a/sql/cross_column.go b/sql/cross_column.go index 012a84a124..5922bfe44f 100644 --- a/sql/cross_column.go +++ b/sql/cross_column.go @@ -66,12 +66,16 @@ func (cc *crossColumn) GetInputShape() string { return "" } +func (cc *crossColumn) GetColumnType() int { + return columnTypeCross +} + func resolveCrossColumn(el *exprlist) (*crossColumn, error) { if len(*el) != 3 { return nil, fmt.Errorf("bad CROSS expression format: %s", *el) } keysExpr := (*el)[1] - keys, err := resolveExpression(keysExpr) + keys, err := resolveLispExpression(keysExpr) if err != nil { return nil, err 
} diff --git a/sql/embedding_column.go b/sql/embedding_column.go index 09fa06e35a..fa94a92d2d 100644 --- a/sql/embedding_column.go +++ b/sql/embedding_column.go @@ -41,6 +41,10 @@ func (ec *embeddingColumn) GetInputShape() string { return ec.CategoryColumn.(featureColumn).GetInputShape() } +func (ec *embeddingColumn) GetColumnType() int { + return columnTypeEmbedding +} + func (ec *embeddingColumn) GenerateCode() (string, error) { catColumn, ok := ec.CategoryColumn.(featureColumn) if !ok { @@ -59,7 +63,7 @@ func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) } sourceExprList := (*el)[1] - source, err := resolveExpression(sourceExprList) + source, _, err := resolveColumn(&sourceExprList.sexp) if err != nil { return nil, err } diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index a573c94526..242ac31874 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -285,32 +285,36 @@ func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, err var fcs = make([]featureColumn, 0) var css = make([]*columnSpec, 0) for _, expr := range *columns { - result, err := resolveExpression(expr) - if err != nil { - return nil, nil, err - } - if cs, ok := result.(*columnSpec); ok { - css = append(css, cs) - continue - } else if c, ok := result.(featureColumn); ok { - fcs = append(fcs, c) - } else if s, ok := result.(string); ok { - // simple string column, generate default numeric column + if expr.typ != 0 { + // only column identifier like "COLUMN a1,b1" + // FIXME(typhoonzero): infer the column spec here. 
c := &numericColumn{ - Key: s, + Key: expr.val, Shape: []int{1}, Dtype: "float32", } fcs = append(fcs, c) } else { - return nil, nil, fmt.Errorf("not recognized type: %s", result) + result, cs, err := resolveColumn(&expr.sexp) + if err != nil { + return nil, nil, err + } + // if cs, ok := result.(*columnSpec); ok { + if cs != nil { + css = append(css, cs) + continue + } else if c, ok := result.(featureColumn); ok { + fcs = append(fcs, c) + } else { + return nil, nil, fmt.Errorf("not recognized type: %s", result) + } } } return fcs, css, nil } func getExpressionFieldName(expr *expr) (string, error) { - result, err := resolveExpression(expr) + result, err := resolveLispExpression(expr) if err != nil { return "", err } @@ -326,56 +330,40 @@ func getExpressionFieldName(expr *expr) (string, error) { } } -// resolveExpression resolve a SQLFlow expression to the actual value -// see: sql.y:241 for the definition of expression. -func resolveExpression(e interface{}) (interface{}, error) { +// resolveLispExpression returns the actual value of the expression: +// IDENT -> string +// [1,2,3] -> []interface{} +// ["a","b","c"] -> []interface{} +// [[1,2,3], "b", "c"] -> []interface{} +func resolveLispExpression(e interface{}) (interface{}, error) { if expr, ok := e.(*expr); ok { - if expr.val != "" { + if expr.typ != 0 { return expr.val, nil } - return resolveExpression(&expr.sexp) + return resolveLispExpression(&expr.sexp) } el, ok := e.(*exprlist) if !ok { - return nil, fmt.Errorf("input of resolveExpression must be `expr` or `exprlist` given %s", e) - } - - head := (*el)[0].val - if head == "" { - return resolveExpression(&(*el)[0].sexp) + return nil, fmt.Errorf("input of resolveLispExpression must be `expr` or `exprlist` given %s", e) } - - switch strings.ToUpper(head) { - case dense: - return resolveColumnSpec(el, false) - case sparse: - return resolveColumnSpec(el, true) - case numeric: - return resolveNumericColumn(el) - case bucket: - return resolveBucketColumn(el) - 
case cross: - return resolveCrossColumn(el) - case categoryID: - return resolveCategoryIDColumn(el) - case seqCategoryID: - return resolveSeqCategoryIDColumn(el) - case embedding: - return resolveEmbeddingColumn(el) - case square: + headTyp := (*el)[0].typ + if headTyp == 0 { + return resolveLispExpression(&(*el)[0].sexp) + } else if headTyp == '[' { var list []interface{} for idx, expr := range *el { if idx > 0 { if expr.sexp == nil { intVal, err := strconv.Atoi(expr.val) + // TODO: support list of float etc. if err != nil { list = append(list, expr.val) } else { list = append(list, intVal) } } else { - value, err := resolveExpression(&expr.sexp) + value, err := resolveLispExpression(&expr.sexp) if err != nil { return nil, err } @@ -384,62 +372,12 @@ func resolveExpression(e interface{}) (interface{}, error) { } } return list, nil - default: - return nil, fmt.Errorf("not supported expr: %s", head) - } -} - -func resolveColumnSpec(el *exprlist, isSparse bool) (*columnSpec, error) { - if len(*el) < 4 { - return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el) - } - name, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec name: %s, err: %s", (*el)[1], err) - } - var shape []int - intShape, err := strconv.Atoi((*el)[2].val) - if err != nil { - strShape, err := expression2string((*el)[2]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) - } - if strShape != "none" { - return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) - } - } else { - shape = append(shape, intShape) - } - unresolvedDelimiter, err := expression2string((*el)[3]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec delimiter: %s, err: %s", (*el)[1], err) - } - - delimiter, err := resolveDelimiter(unresolvedDelimiter) - if err != nil { - return nil, err - } - - // resolve feature map - fm := featureMap{} - dtype := "float" - if isSparse { - dtype = "int" - } - if 
len(*el) >= 5 { - dtype, err = expression2string((*el)[4]) - } - return &columnSpec{ - ColumnName: name, - IsSparse: isSparse, - Shape: shape, - DType: dtype, - Delimiter: delimiter, - FeatureMap: fm}, nil + return nil, fmt.Errorf("not supported expr: %v", el) } func expression2string(e interface{}) (string, error) { - resolved, err := resolveExpression(e) + resolved, err := resolveLispExpression(e) if err != nil { return "", err } diff --git a/sql/feature_column.go b/sql/feature_column.go index f8f5e17f69..4c1fbbee66 100644 --- a/sql/feature_column.go +++ b/sql/feature_column.go @@ -13,16 +13,67 @@ package sql -// featureColumn is an interface that all types of feature columns and -// attributes (WITH clause) should follow. -// featureColumn is used to generate feature column code. +import ( + "fmt" + "strings" +) + +const ( + columnTypeBucket = 0 + columnTypeEmbedding = 1 + columnTypeNumeric = 2 + columnTypeCategoryID = 3 + columnTypeSeqCategoryID = 4 + columnTypeCross = 5 +) + +// featureColumn is an interface that all types of feature columns +// should follow. featureColumn is used to generate feature column code. type featureColumn interface { GenerateCode() (string, error) - // Some feature columns accept input tensors directly, and the data - // may be a tensor string like: 12,32,4,58,0,0 + // FIXME(typhoonzero): remove delimiter, dtype, shape from feature column + // get these from the column spec clause or by feature derivation. GetDelimiter() string GetDtype() string GetKey() string // return input shape json string, like "[2,3]" GetInputShape() string + GetColumnType() int +} + +// resolveColumn returns the actual feature column typed struct +// as well as the columnSpec information. 
+func resolveColumn(el *exprlist) (featureColumn, *columnSpec, error) { + head := (*el)[0].val + if head == "" { + return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %s", el) + } + + switch strings.ToUpper(head) { + case dense: + cs, err := resolveColumnSpec(el, false) + return nil, cs, err + case sparse: + cs, err := resolveColumnSpec(el, true) + return nil, cs, err + case numeric: + // TODO(typhoonzero): support NUMERIC(DENSE(col)) and NUMERIC(SPARSE(col)) + fc, err := resolveNumericColumn(el) + return fc, nil, err + case bucket: + fc, err := resolveBucketColumn(el) + return fc, nil, err + case cross: + fc, err := resolveCrossColumn(el) + return fc, nil, err + case categoryID: + return resolveCategoryIDColumn(el) + case seqCategoryID: + return resolveSeqCategoryIDColumn(el) + case embedding: + fc, err := resolveEmbeddingColumn(el) + return fc, nil, err + default: + return nil, nil, fmt.Errorf("not supported expr: %s", head) + } } diff --git a/sql/numeric_column.go b/sql/numeric_column.go index 906a64041f..55dd48ba31 100644 --- a/sql/numeric_column.go +++ b/sql/numeric_column.go @@ -52,6 +52,10 @@ func (nc *numericColumn) GetInputShape() string { return string(jsonBytes) } +func (nc *numericColumn) GetColumnType() int { + return columnTypeNumeric +} + func resolveNumericColumn(el *exprlist) (*numericColumn, error) { if len(*el) != 3 { return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el) @@ -63,7 +67,7 @@ func resolveNumericColumn(el *exprlist) (*numericColumn, error) { var shape []int intVal, err := strconv.Atoi((*el)[2].val) if err != nil { - list, err := resolveExpression((*el)[2]) + list, err := resolveLispExpression((*el)[2]) if err != nil { return nil, err } From 3927d1b19655008cc425d148ef0f4cf08e53646d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Wed, 28 Aug 2019 23:30:11 +0800 Subject: [PATCH 4/8] wip --- sql/category_id_column.go | 5 +++-- sql/expression_resolver.go | 28 
+++++++++++++++++----------- sql/feature_column.go | 2 +- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/category_id_column.go b/sql/category_id_column.go index a1bf193493..a469bd91d4 100644 --- a/sql/category_id_column.go +++ b/sql/category_id_column.go @@ -126,17 +126,18 @@ func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columnSpec, } var cs *columnSpec key := "" + var err error if (*el)[1].typ == 0 { // explist, maybe DENSE/SPARSE expressions subExprList := (*el)[1].sexp isSparse := subExprList[0].val == sparse - cs, err := resolveColumnSpec(&subExprList, isSparse) + cs, err = resolveColumnSpec(&subExprList, isSparse) if err != nil { return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) } key = cs.ColumnName } else { - key, err := expression2string((*el)[1]) + key, err = expression2string((*el)[1]) if err != nil { return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) } diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index 242ac31874..617c6a5d0f 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -285,6 +285,7 @@ func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, err var fcs = make([]featureColumn, 0) var css = make([]*columnSpec, 0) for _, expr := range *columns { + fmt.Printf("resolve columns: %v\n", expr) if expr.typ != 0 { // only column identifier like "COLUMN a1,b1" // FIXME(typhoonzero): infer the column spec here. 
@@ -314,20 +315,25 @@ func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, err } func getExpressionFieldName(expr *expr) (string, error) { - result, err := resolveLispExpression(expr) + if expr.typ != 0 { + return expr.val, nil + } + fc, _, err := resolveColumn(&expr.sexp) if err != nil { return "", err } - switch r := result.(type) { - case *columnSpec: - return r.ColumnName, nil - case featureColumn: - return r.GetKey(), nil - case string: - return r, nil - default: - return "", fmt.Errorf("getExpressionFieldName: unrecognized type %T", r) - } + return fc.GetKey(), nil + + // switch r := result.(type) { + // case *columnSpec: + // return r.ColumnName, nil + // case featureColumn: + // return r.GetKey(), nil + // case string: + // return r, nil + // default: + // return "", fmt.Errorf("getExpressionFieldName: unrecognized type %T", r) + // } } // resolveLispExpression returns the actual value of the expression: diff --git a/sql/feature_column.go b/sql/feature_column.go index 4c1fbbee66..1be8becfb6 100644 --- a/sql/feature_column.go +++ b/sql/feature_column.go @@ -46,7 +46,7 @@ type featureColumn interface { func resolveColumn(el *exprlist) (featureColumn, *columnSpec, error) { head := (*el)[0].val if head == "" { - return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %s", el) + return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %v", el) } switch strings.ToUpper(head) { From ed47dfae8df1e69f325a05049b7d86578462de67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Thu, 29 Aug 2019 10:17:23 +0800 Subject: [PATCH 5/8] updates --- sql/category_id_column.go | 2 +- sql/embedding_column.go | 12 +++-- sql/expression_resolver.go | 12 ++--- sql/expression_resolver_test.go | 87 --------------------------------- 4 files changed, 14 insertions(+), 99 deletions(-) diff --git a/sql/category_id_column.go b/sql/category_id_column.go index a469bd91d4..f083832dbd 
100644 --- a/sql/category_id_column.go +++ b/sql/category_id_column.go @@ -133,7 +133,7 @@ func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columnSpec, isSparse := subExprList[0].val == sparse cs, err = resolveColumnSpec(&subExprList, isSparse) if err != nil { - return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %v", subExprList) } key = cs.ColumnName } else { diff --git a/sql/embedding_column.go b/sql/embedding_column.go index fa94a92d2d..ddd6b85433 100644 --- a/sql/embedding_column.go +++ b/sql/embedding_column.go @@ -63,9 +63,15 @@ func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) } sourceExprList := (*el)[1] - source, _, err := resolveColumn(&sourceExprList.sexp) - if err != nil { - return nil, err + var source featureColumn + var err error + if sourceExprList.typ == 0 { + source, _, err = resolveColumn(&sourceExprList.sexp) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("key of EMBEDDING must be categorical column") } // TODO(uuleon) support other kinds of categorical column in the future var catColumn interface{} diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index 617c6a5d0f..b67b66438b 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -285,9 +285,8 @@ func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, err var fcs = make([]featureColumn, 0) var css = make([]*columnSpec, 0) for _, expr := range *columns { - fmt.Printf("resolve columns: %v\n", expr) if expr.typ != 0 { - // only column identifier like "COLUMN a1,b1" + // Column identifier like "COLUMN a1,b1" // FIXME(typhoonzero): infer the column spec here. 
c := &numericColumn{ Key: expr.val, @@ -300,14 +299,11 @@ func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, err if err != nil { return nil, nil, err } - // if cs, ok := result.(*columnSpec); ok { if cs != nil { css = append(css, cs) - continue - } else if c, ok := result.(featureColumn); ok { - fcs = append(fcs, c) - } else { - return nil, nil, fmt.Errorf("not recognized type: %s", result) + } + if result != nil { + fcs = append(fcs, result) } } } diff --git a/sql/expression_resolver_test.go b/sql/expression_resolver_test.go index 0c7ca726c6..8a47ee85bd 100644 --- a/sql/expression_resolver_test.go +++ b/sql/expression_resolver_test.go @@ -184,77 +184,6 @@ func TestCrossColumn(t *testing.T) { a.Error(e) } -func TestCatIdColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("CATEGORY_ID(c1, 100)") - badKey := statementWithColumn("CATEGORY_ID([100], 100)") - badBucket := statementWithColumn("CATEGORY_ID(c1, bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - cc, ok := fcs[0].(*categoryIDColumn) - a.True(ok) - code, e := cc.GenerateCode() - a.NoError(e) - a.Equal("c1", cc.Key) - a.Equal(100, cc.BucketSize) - a.Equal("tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100)", code) - - r, e = parser.Parse(badKey) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucket) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestEmbeddingColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), 200, mean)") - badInput := statementWithColumn("EMBEDDING(c1, 100, mean)") - badBucket := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), bad, mean)") - - r, e := 
parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - ec, ok := fcs[0].(*embeddingColumn) - a.True(ok) - code, e := ec.GenerateCode() - a.NoError(e) - cc, ok := ec.CategoryColumn.(*categoryIDColumn) - a.True(ok) - a.Equal("c1", cc.Key) - a.Equal(100, cc.BucketSize) - a.Equal(200, ec.Dimension) - a.Equal("tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100), dimension=200, combiner=\"mean\")", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucket) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - func TestAttrs(t *testing.T) { a := assert.New(t) parser := newParser() @@ -292,19 +221,3 @@ func TestExecResource(t *testing.T) { fmt.Println(attr) } - -func TestCatIdColumnWithColumnSpec(t *testing.T) { - a := assert.New(t) - parser := newParser() - - dense := statementWithColumn("CATEGORY_ID(DENSE(col1, 128), 100)") - // sparse := statementWithColumn("CATEGORY_ID(SPARSE(col2, 1000, COMMA))") - - r, e := parser.Parse(dense) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - _, ok := fcs[0].(*categoryIDColumn) - a.True(ok) -} From 5c848f2a9b21bd1aeed593f408c2cb6bf154ed51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Thu, 29 Aug 2019 10:52:08 +0800 Subject: [PATCH 6/8] wip --- sql/attribute_test.go | 45 +++++++++ sql/bucket_column_test.go | 55 ++++++++++ sql/category_id_column_test.go | 70 +++++++++++++ sql/column_spec_test.go | 59 +++++++++++ sql/cross_column_test.go | 59 +++++++++++ sql/embedding_column_test.go | 57 +++++++++++ sql/expression_resolver_test.go | 174 +------------------------------- sql/numeric_column_test.go | 54 ++++++++++ 8 files changed, 400 insertions(+), 173 
deletions(-) create mode 100644 sql/attribute_test.go create mode 100644 sql/bucket_column_test.go create mode 100644 sql/category_id_column_test.go create mode 100644 sql/column_spec_test.go create mode 100644 sql/cross_column_test.go create mode 100644 sql/embedding_column_test.go create mode 100644 sql/numeric_column_test.go diff --git a/sql/attribute_test.go b/sql/attribute_test.go new file mode 100644 index 0000000000..97399c46c7 --- /dev/null +++ b/sql/attribute_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAttrs(t *testing.T) { + a := assert.New(t) + parser := newParser() + + s := statementWithAttrs("estimator.hidden_units = [10, 20]") + r, e := parser.Parse(s) + a.NoError(e) + attrs, err := resolveAttribute(&r.trainAttrs) + a.NoError(err) + attr := attrs["estimator.hidden_units"] + a.Equal("estimator", attr.Prefix) + a.Equal("hidden_units", attr.Name) + a.Equal([]interface{}([]interface{}{10, 20}), attr.Value) + + s = statementWithAttrs("dataset.name = hello") + r, e = parser.Parse(s) + a.NoError(e) + attrs, err = resolveAttribute(&r.trainAttrs) + a.NoError(err) + attr = attrs["dataset.name"] + a.Equal("dataset", attr.Prefix) + a.Equal("name", attr.Name) + a.Equal("hello", attr.Value) +} diff --git a/sql/bucket_column_test.go b/sql/bucket_column_test.go new file mode 100644 index 0000000000..309ee34917 --- /dev/null +++ b/sql/bucket_column_test.go @@ -0,0 +1,55 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBucketColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])") + badInput := statementWithColumn("BUCKET(c1, [1, 10])") + badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + bc, ok := fcs[0].(*bucketColumn) + a.True(ok) + code, e := bc.GenerateCode() + a.NoError(e) + a.Equal("c1", bc.SourceColumn.Key) + a.Equal([]int{10}, bc.SourceColumn.Shape) + a.Equal([]int{1, 10}, bc.Boundaries) + a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code) + + r, e = parser.Parse(badInput) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBoundaries) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} diff --git a/sql/category_id_column_test.go b/sql/category_id_column_test.go new file mode 100644 index 0000000000..404ff2482c --- /dev/null +++ b/sql/category_id_column_test.go @@ -0,0 +1,70 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCatIdColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("CATEGORY_ID(c1, 100)") + badKey := statementWithColumn("CATEGORY_ID([100], 100)") + badBucket := statementWithColumn("CATEGORY_ID(c1, bad)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + cc, ok := fcs[0].(*categoryIDColumn) + a.True(ok) + code, e := cc.GenerateCode() + a.NoError(e) + a.Equal("c1", cc.Key) + a.Equal(100, cc.BucketSize) + a.Equal("tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100)", code) + + r, e = parser.Parse(badKey) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBucket) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} + +func TestCatIdColumnWithColumnSpec(t *testing.T) { + a := assert.New(t) + parser := newParser() + + dense := statementWithColumn("CATEGORY_ID(DENSE(col1, 128, COMMA), 100)") + + r, e := parser.Parse(dense) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, css, e := resolveTrainColumns(&c) + a.NoError(e) + _, ok := fcs[0].(*categoryIDColumn) + a.True(ok) + a.Equal(css[0].ColumnName, "col1") +} diff --git a/sql/column_spec_test.go b/sql/column_spec_test.go new file mode 100644 index 0000000000..904bfa0cfe --- /dev/null +++ b/sql/column_spec_test.go @@ -0,0 +1,59 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFeatureSpec(t *testing.T) { + a := assert.New(t) + parser := newParser() + + denseStatement := statementWithColumn("DENSE(c2, 5, comma)") + sparseStatement := statementWithColumn("SPARSE(c1, 100, comma)") + badStatement := statementWithColumn("DENSE(c3, bad, comma)") + + r, e := parser.Parse(denseStatement) + a.NoError(e) + c := r.columns["feature_columns"] + _, css, e := resolveTrainColumns(&c) + a.NoError(e) + cs := css[0] + a.Equal("c2", cs.ColumnName) + a.Equal(5, cs.Shape[0]) + a.Equal(",", cs.Delimiter) + a.Equal(false, cs.IsSparse) + a.Equal("DenseColumn(name=\"c2\", shape=[5], dtype=\"float\", separator=\",\")", cs.ToString()) + + r, e = parser.Parse(sparseStatement) + a.NoError(e) + c = r.columns["feature_columns"] + _, css, e = resolveTrainColumns(&c) + a.NoError(e) + cs = css[0] + a.Equal("c1", cs.ColumnName) + a.Equal(100, cs.Shape[0]) + a.Equal(true, cs.IsSparse) + a.Equal("SparseColumn(name=\"c1\", shape=[100], dtype=\"int\")", cs.ToString()) + + r, e = parser.Parse(badStatement) + a.NoError(e) + c = r.columns["feature_columns"] + _, _, e = resolveTrainColumns(&c) + a.Error(e) + +} diff --git a/sql/cross_column_test.go b/sql/cross_column_test.go new file mode 100644 index 0000000000..9fa5d0f83b --- /dev/null +++ b/sql/cross_column_test.go @@ -0,0 +1,59 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. 
+// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCrossColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], 20)") + badInput := statementWithColumn("cross(c1, 20)") + badBucketSize := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], bad)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcList, _, e := resolveTrainColumns(&c) + a.NoError(e) + cc, ok := fcList[0].(*crossColumn) + a.True(ok) + code, e := cc.GenerateCode() + a.NoError(e) + + bc := cc.Keys[0].(*bucketColumn) + a.Equal("c1", bc.SourceColumn.Key) + a.Equal([]int{10}, bc.SourceColumn.Shape) + a.Equal([]int{1, 10}, bc.Boundaries) + a.Equal("c5", cc.Keys[1].(string)) + a.Equal(20, cc.HashBucketSize) + a.Equal("tf.feature_column.crossed_column([tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10]),\"c5\"], hash_bucket_size=20)", code) + + r, e = parser.Parse(badInput) + a.NoError(e) + c = r.columns["feature_columns"] + fcList, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBucketSize) + a.NoError(e) + c = r.columns["feature_columns"] + fcList, _, e = resolveTrainColumns(&c) + a.Error(e) +} diff --git a/sql/embedding_column_test.go b/sql/embedding_column_test.go new file mode 100644 
index 0000000000..a4325080b1 --- /dev/null +++ b/sql/embedding_column_test.go @@ -0,0 +1,57 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), 200, mean)") + badInput := statementWithColumn("EMBEDDING(c1, 100, mean)") + badBucket := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), bad, mean)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + ec, ok := fcs[0].(*embeddingColumn) + a.True(ok) + code, e := ec.GenerateCode() + a.NoError(e) + cc, ok := ec.CategoryColumn.(*categoryIDColumn) + a.True(ok) + a.Equal("c1", cc.Key) + a.Equal(100, cc.BucketSize) + a.Equal(200, ec.Dimension) + a.Equal("tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100), dimension=200, combiner=\"mean\")", code) + + r, e = parser.Parse(badInput) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBucket) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} diff --git a/sql/expression_resolver_test.go 
b/sql/expression_resolver_test.go index 8a47ee85bd..263bb8f2a6 100644 --- a/sql/expression_resolver_test.go +++ b/sql/expression_resolver_test.go @@ -38,177 +38,6 @@ func statementWithAttrs(attrs string) string { return fmt.Sprintf(trainStatement, attrs, "DENSE(c2, 5, comma)") } -func TestFeatureSpec(t *testing.T) { - a := assert.New(t) - parser := newParser() - - denseStatement := statementWithColumn("DENSE(c2, 5, comma)") - sparseStatement := statementWithColumn("SPARSE(c1, 100, comma)") - badStatement := statementWithColumn("DENSE(c3, bad, comma)") - - r, e := parser.Parse(denseStatement) - a.NoError(e) - c := r.columns["feature_columns"] - _, css, e := resolveTrainColumns(&c) - a.NoError(e) - cs := css[0] - a.Equal("c2", cs.ColumnName) - a.Equal(5, cs.Shape[0]) - a.Equal(",", cs.Delimiter) - a.Equal(false, cs.IsSparse) - a.Equal("DenseColumn(name=\"c2\", shape=[5], dtype=\"float\", separator=\",\")", cs.ToString()) - - r, e = parser.Parse(sparseStatement) - a.NoError(e) - c = r.columns["feature_columns"] - _, css, e = resolveTrainColumns(&c) - a.NoError(e) - cs = css[0] - a.Equal("c1", cs.ColumnName) - a.Equal(100, cs.Shape[0]) - a.Equal(true, cs.IsSparse) - a.Equal("SparseColumn(name=\"c1\", shape=[100], dtype=\"int\")", cs.ToString()) - - r, e = parser.Parse(badStatement) - a.NoError(e) - c = r.columns["feature_columns"] - _, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestNumericColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("NUMERIC(c2, [5, 10])") - moreArgs := statementWithColumn("NUMERIC(c1, 100, args)") - badShape := statementWithColumn("NUMERIC(c1, bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - nc, ok := fcs[0].(*numericColumn) - a.True(ok) - code, e := nc.GenerateCode() - a.NoError(e) - a.Equal("c2", nc.Key) - a.Equal([]int{5, 10}, nc.Shape) - 
a.Equal("tf.feature_column.numeric_column(\"c2\", shape=[5,10])", code) - - r, e = parser.Parse(moreArgs) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badShape) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestBucketColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])") - badInput := statementWithColumn("BUCKET(c1, [1, 10])") - badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - bc, ok := fcs[0].(*bucketColumn) - a.True(ok) - code, e := bc.GenerateCode() - a.NoError(e) - a.Equal("c1", bc.SourceColumn.Key) - a.Equal([]int{10}, bc.SourceColumn.Shape) - a.Equal([]int{1, 10}, bc.Boundaries) - a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBoundaries) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestCrossColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], 20)") - badInput := statementWithColumn("cross(c1, 20)") - badBucketSize := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcList, _, e := resolveTrainColumns(&c) - a.NoError(e) - cc, ok := fcList[0].(*crossColumn) - a.True(ok) - code, e := cc.GenerateCode() - a.NoError(e) - - bc := cc.Keys[0].(*bucketColumn) - a.Equal("c1", 
bc.SourceColumn.Key) - a.Equal([]int{10}, bc.SourceColumn.Shape) - a.Equal([]int{1, 10}, bc.Boundaries) - a.Equal("c5", cc.Keys[1].(string)) - a.Equal(20, cc.HashBucketSize) - a.Equal("tf.feature_column.crossed_column([tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10]),\"c5\"], hash_bucket_size=20)", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcList, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucketSize) - a.NoError(e) - c = r.columns["feature_columns"] - fcList, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestAttrs(t *testing.T) { - a := assert.New(t) - parser := newParser() - - s := statementWithAttrs("estimator.hidden_units = [10, 20]") - r, e := parser.Parse(s) - a.NoError(e) - attrs, err := resolveAttribute(&r.trainAttrs) - a.NoError(err) - attr := attrs["estimator.hidden_units"] - a.Equal("estimator", attr.Prefix) - a.Equal("hidden_units", attr.Name) - a.Equal([]interface{}([]interface{}{10, 20}), attr.Value) - - s = statementWithAttrs("dataset.name = hello") - r, e = parser.Parse(s) - a.NoError(e) - attrs, err = resolveAttribute(&r.trainAttrs) - a.NoError(err) - attr = attrs["dataset.name"] - a.Equal("dataset", attr.Prefix) - a.Equal("name", attr.Name) - a.Equal("hello", attr.Value) -} - func TestExecResource(t *testing.T) { a := assert.New(t) parser := newParser() @@ -218,6 +47,5 @@ func TestExecResource(t *testing.T) { attrs, err := resolveAttribute(&r.trainAttrs) a.NoError(err) attr := attrs["exec.worker_num"] - fmt.Println(attr) - + a.Equal(attr.Value, "2") } diff --git a/sql/numeric_column_test.go b/sql/numeric_column_test.go new file mode 100644 index 0000000000..38bd65f05b --- /dev/null +++ b/sql/numeric_column_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. 
+// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNumericColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("NUMERIC(c2, [5, 10])") + moreArgs := statementWithColumn("NUMERIC(c1, 100, args)") + badShape := statementWithColumn("NUMERIC(c1, bad)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + nc, ok := fcs[0].(*numericColumn) + a.True(ok) + code, e := nc.GenerateCode() + a.NoError(e) + a.Equal("c2", nc.Key) + a.Equal([]int{5, 10}, nc.Shape) + a.Equal("tf.feature_column.numeric_column(\"c2\", shape=[5,10])", code) + + r, e = parser.Parse(moreArgs) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badShape) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} From 5332caa144c99b656492433a001b3c68a8c35688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Thu, 29 Aug 2019 14:59:27 +0800 Subject: [PATCH 7/8] update --- cmd/sqlflowserver/main_test.go | 14 ++++++++++ sql/attribute.go | 3 +-- sql/bucket_column.go | 6 +++-- sql/cross_column.go | 9 ++++--- sql/expression_resolver.go | 49 +++++++++++++++------------------- sql/numeric_column.go | 2 +- 6 files changed, 46 insertions(+), 37 
deletions(-) diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index 884d5997b9..c1fa150e26 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -23,6 +23,8 @@ import ( "os" "os/exec" "path" + "strconv" + "strings" "testing" "time" @@ -589,6 +591,18 @@ INTO sqlflow_models.my_dnn_model;` // CaseTrainTextClassificationCustomLSTM is a simple End-to-End testing for case training // text classification models. func CaseTrainTextClassificationCustomLSTM(t *testing.T) { + tfVersionStr, err := exec.Command("python", "-c", "\"import tensorflow;print(tensorflow.__version__)\"").Output() + if err != nil { + t.Fatalf("error getting tensorflow version: %v", err) + } + versionParts := strings.Split(string(tfVersionStr), ".") + mainVer, err := strconv.Atoi(versionParts[0]) + if err != nil { + t.Fatalf("err tensorflow version format: %s", tfVersionStr) + } + if mainVer < 2 { + t.Skip("skip on tf version < 2") + } a := assert.New(t) trainSQL := `SELECT * FROM text_cn.train_processed diff --git a/sql/attribute.go b/sql/attribute.go index 20dc128074..2e1b7a1421 100644 --- a/sql/attribute.go +++ b/sql/attribute.go @@ -67,9 +67,8 @@ func resolveAttribute(attrs *attrs) (map[string]*attribute, error) { if len(subs) == 2 { prefix = subs[0] } - r, err := resolveLispExpression(v) + r, _, err := resolveExpression(v) if err != nil { - fmt.Printf("%v", err) return nil, err } a := &attribute{ diff --git a/sql/bucket_column.go b/sql/bucket_column.go index 39a634117b..55f3893780 100644 --- a/sql/bucket_column.go +++ b/sql/bucket_column.go @@ -57,6 +57,9 @@ func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { } sourceExprList := (*el)[1] boundariesExprList := (*el)[2] + if sourceExprList.typ != 0 { + return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %v", sourceExprList) + } source, _, err := resolveColumn(&sourceExprList.sexp) if err != nil { return nil, err @@ -64,7 +67,7 @@ func resolveBucketColumn(el *exprlist) 
(*bucketColumn, error) { if source.GetColumnType() != columnTypeNumeric { return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source) } - boundaries, err := resolveLispExpression(boundariesExprList) + boundaries, _, err := resolveExpression(boundariesExprList) if err != nil { return nil, err } @@ -76,7 +79,6 @@ func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) } return &bucketColumn{ - // SourceColumn: source.(*numericColumn), SourceColumn: source.(*numericColumn), Boundaries: b}, nil } diff --git a/sql/cross_column.go b/sql/cross_column.go index 5922bfe44f..6cb2f66dbd 100644 --- a/sql/cross_column.go +++ b/sql/cross_column.go @@ -75,18 +75,19 @@ func resolveCrossColumn(el *exprlist) (*crossColumn, error) { return nil, fmt.Errorf("bad CROSS expression format: %s", *el) } keysExpr := (*el)[1] - keys, err := resolveLispExpression(keysExpr) + key, _, err := resolveExpression(keysExpr) if err != nil { return nil, err } - if _, ok := keys.([]interface{}); !ok { - return nil, fmt.Errorf("bad CROSS keys: %s", err) + if _, ok := key.([]interface{}); !ok { + return nil, fmt.Errorf("bad CROSS expression format: %s", *el) } + bucketSize, err := strconv.Atoi((*el)[2].val) if err != nil { return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err) } return &crossColumn{ - Keys: keys.([]interface{}), + Keys: key.([]interface{}), HashBucketSize: bucketSize}, nil } diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index b67b66438b..d25dfd426b 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -319,41 +319,33 @@ func getExpressionFieldName(expr *expr) (string, error) { return "", err } return fc.GetKey(), nil - - // switch r := result.(type) { - // case *columnSpec: - // return r.ColumnName, nil - // case featureColumn: - // return r.GetKey(), nil - // case string: - // return r, nil - // default: - // return "", 
fmt.Errorf("getExpressionFieldName: unrecognized type %T", r) - // } } -// resolveLispExpression returns the actual value of the expression: -// IDENT -> string -// [1,2,3] -> []interface{} -// ["a","b","c"] -> []interface{} -// [[1,2,3], "b", "c"] -> []interface{} -func resolveLispExpression(e interface{}) (interface{}, error) { +// resolveExpression parse the expression recursively and +// returns the actual value of the expression: +// featureColumns, columnSpecs, error +// e.g. +// column_1 -> "column_1", nil, nil +// [1,2,3,4] -> [1,2,3,4], nil, nil +// [NUMERIC(col1), col2] -> [*numericColumn, "col2"], nil, nil +func resolveExpression(e interface{}) (interface{}, interface{}, error) { if expr, ok := e.(*expr); ok { if expr.typ != 0 { - return expr.val, nil + return expr.val, nil, nil } - return resolveLispExpression(&expr.sexp) + return resolveExpression(&expr.sexp) } - el, ok := e.(*exprlist) if !ok { - return nil, fmt.Errorf("input of resolveLispExpression must be `expr` or `exprlist` given %s", e) + return nil, nil, fmt.Errorf("input of resolveExpression must be `expr` or `exprlist` given %s", e) } headTyp := (*el)[0].typ - if headTyp == 0 { - return resolveLispExpression(&(*el)[0].sexp) + if headTyp == IDENT { + // Expression is a function call + return resolveColumn(el) } else if headTyp == '[' { var list []interface{} + var columnSpecList []interface{} for idx, expr := range *el { if idx > 0 { if expr.sexp == nil { @@ -365,21 +357,22 @@ func resolveLispExpression(e interface{}) (interface{}, error) { list = append(list, intVal) } } else { - value, err := resolveLispExpression(&expr.sexp) + value, cs, err := resolveExpression(&expr.sexp) if err != nil { - return nil, err + return nil, nil, err } list = append(list, value) + columnSpecList = append(columnSpecList, cs) } } } - return list, nil + return list, columnSpecList, nil } - return nil, fmt.Errorf("not supported expr: %v", el) + return nil, nil, fmt.Errorf("not supported expr: %v", el) } func 
expression2string(e interface{}) (string, error) { - resolved, err := resolveLispExpression(e) + resolved, _, err := resolveExpression(e) if err != nil { return "", err } diff --git a/sql/numeric_column.go b/sql/numeric_column.go index 55dd48ba31..022ebf4204 100644 --- a/sql/numeric_column.go +++ b/sql/numeric_column.go @@ -67,7 +67,7 @@ func resolveNumericColumn(el *exprlist) (*numericColumn, error) { var shape []int intVal, err := strconv.Atoi((*el)[2].val) if err != nil { - list, err := resolveLispExpression((*el)[2]) + list, _, err := resolveExpression((*el)[2]) if err != nil { return nil, err } From de6559efa866d8534490c8340c57620b2fda7bb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=84=E6=9C=A8?= Date: Thu, 29 Aug 2019 15:36:41 +0800 Subject: [PATCH 8/8] fix ci --- cmd/sqlflowserver/main_test.go | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go index c1fa150e26..884d5997b9 100644 --- a/cmd/sqlflowserver/main_test.go +++ b/cmd/sqlflowserver/main_test.go @@ -23,8 +23,6 @@ import ( "os" "os/exec" "path" - "strconv" - "strings" "testing" "time" @@ -591,18 +589,6 @@ INTO sqlflow_models.my_dnn_model;` // CaseTrainTextClassificationCustomLSTM is a simple End-to-End testing for case training // text classification models. func CaseTrainTextClassificationCustomLSTM(t *testing.T) { - tfVersionStr, err := exec.Command("python", "-c", "\"import tensorflow;print(tensorflow.__version__)\"").Output() - if err != nil { - t.Fatalf("error getting tensorflow version: %v", err) - } - versionParts := strings.Split(string(tfVersionStr), ".") - mainVer, err := strconv.Atoi(versionParts[0]) - if err != nil { - t.Fatalf("err tensorflow version format: %s", tfVersionStr) - } - if mainVer < 2 { - t.Skip("skip on tf version < 2") - } a := assert.New(t) trainSQL := `SELECT * FROM text_cn.train_processed