diff --git a/sql/attribute.go b/sql/attribute.go
new file mode 100644
index 0000000000..2e1b7a1421
--- /dev/null
+++ b/sql/attribute.go
@@ -0,0 +1,82 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+)
+
+type attribute struct {
+	FullName string
+	Prefix   string
+	Name     string
+	Value    interface{}
+}
+
+func (a *attribute) GenerateCode() (string, error) {
+	if val, ok := a.Value.(string); ok {
+		// if the string parses as an integer, emit it unquoted.
+		if _, err := strconv.Atoi(val); err == nil {
+			return fmt.Sprintf("%s=%s", a.Name, val), nil
+		}
+		return fmt.Sprintf("%s=\"%s\"", a.Name, val), nil
+	}
+	if val, ok := a.Value.([]interface{}); ok {
+		intList, err := transformToIntList(val)
+		if err != nil {
+			return "", err
+		}
+		return fmt.Sprintf("%s=%s", a.Name,
+			strings.Join(strings.Split(fmt.Sprint(intList), " "), ",")), nil
+	}
+	return "", fmt.Errorf("value of attribute must be string or list of int, given %v", a.Value)
+}
+
+func filter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute {
+	ret := make(map[string]*attribute)
+	for _, a := range attrs {
+		if strings.EqualFold(a.Prefix, prefix) {
+			ret[a.Name] = a
+			if remove {
+				delete(attrs, a.FullName)
+			}
+		}
+	}
+	return ret
+}
+
+func resolveAttribute(attrs *attrs) (map[string]*attribute, error) {
+	ret := make(map[string]*attribute)
+	for k, v := range *attrs {
+		subs := strings.SplitN(k, ".", 2)
+		name := subs[len(subs)-1]
+		prefix := ""
+		if len(subs) == 2 {
+			prefix = subs[0]
+		}
+		r, _, err := resolveExpression(v)
+		if err != nil {
+			return nil, err
+		}
+		a := &attribute{
+			FullName: k,
+			Prefix:   prefix,
+			Name:     name,
+			Value:    r}
+		ret[a.FullName] = a
+	}
+	return ret, nil
+}
diff --git a/sql/attribute_test.go b/sql/attribute_test.go
new file mode 100644
index 0000000000..97399c46c7
--- /dev/null
+++ b/sql/attribute_test.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestAttrs(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	s := statementWithAttrs("estimator.hidden_units = [10, 20]")
+	r, e := parser.Parse(s)
+	a.NoError(e)
+	attrs, err := resolveAttribute(&r.trainAttrs)
+	a.NoError(err)
+	attr := attrs["estimator.hidden_units"]
+	a.Equal("estimator", attr.Prefix)
+	a.Equal("hidden_units", attr.Name)
+	a.Equal([]interface{}{10, 20}, attr.Value)
+
+	s = statementWithAttrs("dataset.name = hello")
+	r, e = parser.Parse(s)
+	a.NoError(e)
+	attrs, err = resolveAttribute(&r.trainAttrs)
+	a.NoError(err)
+	attr = attrs["dataset.name"]
+	a.Equal("dataset", attr.Prefix)
+	a.Equal("name", attr.Name)
+	a.Equal("hello", attr.Value)
+}
diff --git a/sql/bucket_column.go b/sql/bucket_column.go
new file mode 100644
index 0000000000..55f3893780
--- /dev/null
+++ b/sql/bucket_column.go
@@ -0,0 +1,84 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"fmt"
+	"strings"
+)
+
+type bucketColumn struct {
+	SourceColumn *numericColumn
+	Boundaries   []int
+}
+
+func (bc *bucketColumn) GenerateCode() (string, error) {
+	sourceCode, _ := bc.SourceColumn.GenerateCode()
+	return fmt.Sprintf(
+		"tf.feature_column.bucketized_column(%s, boundaries=%s)",
+		sourceCode,
+		strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")), nil
+}
+
+func (bc *bucketColumn) GetDelimiter() string {
+	return ""
+}
+
+func (bc *bucketColumn) GetDtype() string {
+	return ""
+}
+
+func (bc *bucketColumn) GetKey() string {
+	return bc.SourceColumn.Key
+}
+
+func (bc *bucketColumn) GetInputShape() string {
+	return bc.SourceColumn.GetInputShape()
+}
+
+func (bc *bucketColumn) GetColumnType() int {
+	return columnTypeBucket
+}
+
+func resolveBucketColumn(el *exprlist) (*bucketColumn, error) {
+	if len(*el) != 3 {
+		return nil, fmt.Errorf("bad BUCKET expression format: %s", *el)
+	}
+	sourceExprList := (*el)[1]
+	boundariesExprList := (*el)[2]
+	if sourceExprList.typ != 0 {
+		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %v", sourceExprList)
+	}
+	source, _, err := resolveColumn(&sourceExprList.sexp)
+	if err != nil {
+		return nil, err
+	}
+	if source.GetColumnType() != columnTypeNumeric {
+		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %v", source)
+	}
+	boundaries, _, err := resolveExpression(boundariesExprList)
+	if err != nil {
+		return nil, err
+	}
+	if _, ok := boundaries.([]interface{}); !ok {
+		return nil, fmt.Errorf("bad BUCKET boundaries: %v", boundaries)
+	}
+	b, err := transformToIntList(boundaries.([]interface{}))
+	if err != nil {
+		return nil, fmt.Errorf("bad BUCKET boundaries: %s", err)
+	}
+	return &bucketColumn{
+		SourceColumn: source.(*numericColumn),
+		Boundaries:   b}, nil
+}
diff --git a/sql/bucket_column_test.go b/sql/bucket_column_test.go
new file mode 100644
index 0000000000..309ee34917
--- /dev/null
+++ b/sql/bucket_column_test.go
@@
-0,0 +1,55 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBucketColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])") + badInput := statementWithColumn("BUCKET(c1, [1, 10])") + badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + bc, ok := fcs[0].(*bucketColumn) + a.True(ok) + code, e := bc.GenerateCode() + a.NoError(e) + a.Equal("c1", bc.SourceColumn.Key) + a.Equal([]int{10}, bc.SourceColumn.Shape) + a.Equal([]int{1, 10}, bc.Boundaries) + a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code) + + r, e = parser.Parse(badInput) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBoundaries) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} diff --git a/sql/category_id_column.go b/sql/category_id_column.go new file mode 100644 index 0000000000..f083832dbd --- /dev/null +++ b/sql/category_id_column.go @@ -0,0 +1,157 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "fmt" + "strconv" +) + +type categoryIDColumn struct { + Key string + BucketSize int + Delimiter string + Dtype string +} + +type sequenceCategoryIDColumn struct { + Key string + BucketSize int + Delimiter string + Dtype string + IsSequence bool +} + +func (cc *categoryIDColumn) GenerateCode() (string, error) { + return fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", + cc.Key, cc.BucketSize), nil +} + +func (cc *categoryIDColumn) GetDelimiter() string { + return cc.Delimiter +} + +func (cc *categoryIDColumn) GetDtype() string { + return cc.Dtype +} + +func (cc *categoryIDColumn) GetKey() string { + return cc.Key +} + +func (cc *categoryIDColumn) GetInputShape() string { + return fmt.Sprintf("[%d]", cc.BucketSize) +} + +func (cc *categoryIDColumn) GetColumnType() int { + return columnTypeCategoryID +} + +func (cc *sequenceCategoryIDColumn) GenerateCode() (string, error) { + return fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", + cc.Key, cc.BucketSize), nil +} + +func (cc *sequenceCategoryIDColumn) GetDelimiter() string { + return cc.Delimiter +} + +func (cc *sequenceCategoryIDColumn) GetDtype() string { + return cc.Dtype +} + +func (cc *sequenceCategoryIDColumn) GetKey() string { + return cc.Key +} + +func (cc *sequenceCategoryIDColumn) GetInputShape() string { + return fmt.Sprintf("[%d]", cc.BucketSize) +} + +func (cc *sequenceCategoryIDColumn) GetColumnType() int { + return columnTypeSeqCategoryID +} + +func parseCategoryColumnKey(el *exprlist) (*columnSpec, error) { + if (*el)[1].typ == 0 { + // explist, maybe DENSE/SPARSE expressions + subExprList := (*el)[1].sexp + isSparse := subExprList[0].val == sparse + return resolveColumnSpec(&subExprList, isSparse) + } + return nil, nil +} + +func resolveSeqCategoryIDColumn(el *exprlist) (*sequenceCategoryIDColumn, *columnSpec, error) { + key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el) + if err != nil { + return nil, nil, err + } + return &sequenceCategoryIDColumn{ + Key: key, + BucketSize: bucketSize, + Delimiter: delimiter, + // TODO(typhoonzero): support config dtype + Dtype: "int64", + IsSequence: true}, cs, nil +} + +func resolveCategoryIDColumn(el *exprlist) (*categoryIDColumn, *columnSpec, error) { + key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el) + if err != nil { + return nil, nil, err + } + return &categoryIDColumn{ + Key: key, + BucketSize: bucketSize, + Delimiter: delimiter, + // TODO(typhoonzero): support config dtype + Dtype: "int64"}, cs, nil +} + +func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columnSpec, error) { + if len(*el) != 3 && len(*el) != 4 { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) + } + var cs *columnSpec + key := "" + var err error + if (*el)[1].typ == 0 { + // explist, maybe DENSE/SPARSE expressions + subExprList := (*el)[1].sexp + isSparse := subExprList[0].val == sparse + cs, err = resolveColumnSpec(&subExprList, isSparse) + if err != nil { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %v", subExprList) + } + key = cs.ColumnName + } else { + key, err = expression2string((*el)[1]) + if err != nil { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) + } + } + bucketSize, err := strconv.Atoi((*el)[2].val) + if err != nil { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err) + } + 
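// the optional fourth argument names the delimiter keyword, e.g. COMMA, resolved below. +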
delimiter := "" + if len(*el) == 4 { + delimiter, err = resolveDelimiter((*el)[3].val) + if err != nil { + return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) + } + } + return key, bucketSize, delimiter, cs, nil +} diff --git a/sql/category_id_column_test.go b/sql/category_id_column_test.go new file mode 100644 index 0000000000..404ff2482c --- /dev/null +++ b/sql/category_id_column_test.go @@ -0,0 +1,70 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCatIdColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("CATEGORY_ID(c1, 100)") + badKey := statementWithColumn("CATEGORY_ID([100], 100)") + badBucket := statementWithColumn("CATEGORY_ID(c1, bad)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + cc, ok := fcs[0].(*categoryIDColumn) + a.True(ok) + code, e := cc.GenerateCode() + a.NoError(e) + a.Equal("c1", cc.Key) + a.Equal(100, cc.BucketSize) + a.Equal("tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100)", code) + + r, e = parser.Parse(badKey) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBucket) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} + +func TestCatIdColumnWithColumnSpec(t *testing.T) { + a := assert.New(t) + parser := newParser() + + dense := statementWithColumn("CATEGORY_ID(DENSE(col1, 128, COMMA), 100)") + + r, e := parser.Parse(dense) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, css, e := resolveTrainColumns(&c) + a.NoError(e) + _, ok := fcs[0].(*categoryIDColumn) + a.True(ok) + a.Equal(css[0].ColumnName, "col1") +} diff --git a/sql/codegen_alps.go b/sql/codegen_alps.go index d4876b5e03..e4b0f387f2 100644 --- a/sql/codegen_alps.go +++ b/sql/codegen_alps.go @@ -663,9 +663,9 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c shape := make([]int, 1) shape[0] = len(fields) if userSpec, ok := refColumns[ct.Name()]; ok { - output[ct.Name()] = &columnSpec{ct.Name(), false, false, shape, userSpec.DType, userSpec.Delimiter, *meta.featureMap} + output[ct.Name()] = &columnSpec{ct.Name(), false, shape, userSpec.DType, userSpec.Delimiter, *meta.featureMap} } else { - output[ct.Name()] = &columnSpec{ct.Name(), false, false, shape, "float", ",", *meta.featureMap} + output[ct.Name()] = &columnSpec{ct.Name(), false, shape, "float", ",", *meta.featureMap} } } } @@ -712,7 +712,7 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columnSpec, error) { column, present := output[*name] if !present { shape := make([]int, 0, 1000) - column := &columnSpec{*name, false, true, shape, "int64", "", *meta.featureMap} + column := &columnSpec{*name, true, 
shape, "int64", "", *meta.featureMap} column.DType = "int64" output[*name] = column } diff --git a/sql/column_spec.go b/sql/column_spec.go new file mode 100644 index 0000000000..0a8662a925 --- /dev/null +++ b/sql/column_spec.go @@ -0,0 +1,78 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "strconv" +) + +// columnSpec defines how to generate codes to parse column data to tensor/sparsetensor +type columnSpec struct { + ColumnName string + IsSparse bool + Shape []int + DType string + Delimiter string + FeatureMap featureMap +} + +func resolveColumnSpec(el *exprlist, isSparse bool) (*columnSpec, error) { + if len(*el) < 4 { + return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el) + } + name, err := expression2string((*el)[1]) + if err != nil { + return nil, fmt.Errorf("bad FeatureSpec name: %s, err: %s", (*el)[1], err) + } + var shape []int + intShape, err := strconv.Atoi((*el)[2].val) + if err != nil { + strShape, err := expression2string((*el)[2]) + if err != nil { + return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) + } + if strShape != "none" { + return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) + } + } else { + shape = append(shape, intShape) + } + unresolvedDelimiter, err := expression2string((*el)[3]) + if err != nil { + return nil, fmt.Errorf("bad FeatureSpec delimiter: %s, err: %s", (*el)[1], err) + } + + delimiter, err := resolveDelimiter(unresolvedDelimiter) + if err != nil { + return nil, err + } + + // resolve feature map + fm := featureMap{} + dtype := "float" + if isSparse { + dtype = "int" + } + if len(*el) >= 5 { + dtype, err = expression2string((*el)[4]) + } + return &columnSpec{ + ColumnName: name, + IsSparse: isSparse, + Shape: shape, + DType: dtype, + Delimiter: delimiter, + FeatureMap: fm}, nil +} diff --git a/sql/column_spec_test.go b/sql/column_spec_test.go new file mode 100644 index 0000000000..904bfa0cfe --- /dev/null +++ b/sql/column_spec_test.go @@ -0,0 +1,59 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFeatureSpec(t *testing.T) { + a := assert.New(t) + parser := newParser() + + denseStatement := statementWithColumn("DENSE(c2, 5, comma)") + sparseStatement := statementWithColumn("SPARSE(c1, 100, comma)") + badStatement := statementWithColumn("DENSE(c3, bad, comma)") + + r, e := parser.Parse(denseStatement) + a.NoError(e) + c := r.columns["feature_columns"] + _, css, e := resolveTrainColumns(&c) + a.NoError(e) + cs := css[0] + a.Equal("c2", cs.ColumnName) + a.Equal(5, cs.Shape[0]) + a.Equal(",", cs.Delimiter) + a.Equal(false, cs.IsSparse) + a.Equal("DenseColumn(name=\"c2\", shape=[5], dtype=\"float\", separator=\",\")", cs.ToString()) + + r, e = parser.Parse(sparseStatement) + a.NoError(e) + c = r.columns["feature_columns"] + _, css, e = resolveTrainColumns(&c) + a.NoError(e) + cs = css[0] + a.Equal("c1", cs.ColumnName) + a.Equal(100, cs.Shape[0]) + a.Equal(true, cs.IsSparse) + a.Equal("SparseColumn(name=\"c1\", shape=[100], dtype=\"int\")", cs.ToString()) + + r, e = parser.Parse(badStatement) + a.NoError(e) + c = r.columns["feature_columns"] + _, _, e = resolveTrainColumns(&c) + a.Error(e) + +} diff --git a/sql/cross_column.go b/sql/cross_column.go new file mode 100644 index 0000000000..6cb2f66dbd --- /dev/null +++ b/sql/cross_column.go @@ -0,0 +1,93 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "strconv" + "strings" +) + +// TODO(uuleon) specify the hash_key if needed +type crossColumn struct { + Keys []interface{} + HashBucketSize int +} + +func (cc *crossColumn) GenerateCode() (string, error) { + var keysGenerated = make([]string, len(cc.Keys)) + for idx, key := range cc.Keys { + if c, ok := key.(featureColumn); ok { + code, err := c.GenerateCode() + if err != nil { + return "", err + } + keysGenerated[idx] = code + continue + } + if str, ok := key.(string); ok { + keysGenerated[idx] = fmt.Sprintf("\"%s\"", str) + } else { + return "", fmt.Errorf("cross generate code error, key: %s", key) + } + } + return fmt.Sprintf( + "tf.feature_column.crossed_column([%s], hash_bucket_size=%d)", + strings.Join(keysGenerated, ","), cc.HashBucketSize), nil +} + +func (cc *crossColumn) GetDelimiter() string { + return "" +} + +func (cc *crossColumn) GetDtype() string { + return "" +} + +func (cc *crossColumn) GetKey() string { + // NOTE: cross column is a feature on multiple column keys. + return "" +} + +func (cc *crossColumn) GetInputShape() string { + // NOTE: return empty since crossed column input shape is already determined + // by the two crossed columns. 
+ return "" +} + +func (cc *crossColumn) GetColumnType() int { + return columnTypeCross +} + +func resolveCrossColumn(el *exprlist) (*crossColumn, error) { + if len(*el) != 3 { + return nil, fmt.Errorf("bad CROSS expression format: %s", *el) + } + keysExpr := (*el)[1] + key, _, err := resolveExpression(keysExpr) + if err != nil { + return nil, err + } + if _, ok := key.([]interface{}); !ok { + return nil, fmt.Errorf("bad CROSS expression format: %s", *el) + } + + bucketSize, err := strconv.Atoi((*el)[2].val) + if err != nil { + return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err) + } + return &crossColumn{ + Keys: key.([]interface{}), + HashBucketSize: bucketSize}, nil +} diff --git a/sql/cross_column_test.go b/sql/cross_column_test.go new file mode 100644 index 0000000000..9fa5d0f83b --- /dev/null +++ b/sql/cross_column_test.go @@ -0,0 +1,59 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCrossColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], 20)") + badInput := statementWithColumn("cross(c1, 20)") + badBucketSize := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], bad)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcList, _, e := resolveTrainColumns(&c) + a.NoError(e) + cc, ok := fcList[0].(*crossColumn) + a.True(ok) + code, e := cc.GenerateCode() + a.NoError(e) + + bc := cc.Keys[0].(*bucketColumn) + a.Equal("c1", bc.SourceColumn.Key) + a.Equal([]int{10}, bc.SourceColumn.Shape) + a.Equal([]int{1, 10}, bc.Boundaries) + a.Equal("c5", cc.Keys[1].(string)) + a.Equal(20, cc.HashBucketSize) + a.Equal("tf.feature_column.crossed_column([tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10]),\"c5\"], hash_bucket_size=20)", code) + + r, e = parser.Parse(badInput) + a.NoError(e) + c = r.columns["feature_columns"] + fcList, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBucketSize) + a.NoError(e) + c = r.columns["feature_columns"] + fcList, _, e = resolveTrainColumns(&c) + a.Error(e) +} diff --git a/sql/embedding_column.go b/sql/embedding_column.go new file mode 100644 index 0000000000..ddd6b85433 --- /dev/null +++ b/sql/embedding_column.go @@ -0,0 +1,105 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "strconv" +) + +type embeddingColumn struct { + CategoryColumn interface{} + Dimension int + Combiner string + Initializer string +} + +func (ec *embeddingColumn) GetDelimiter() string { + return ec.CategoryColumn.(featureColumn).GetDelimiter() +} + +func (ec *embeddingColumn) GetDtype() string { + return ec.CategoryColumn.(featureColumn).GetDtype() +} + +func (ec *embeddingColumn) GetKey() string { + return ec.CategoryColumn.(featureColumn).GetKey() +} + +func (ec *embeddingColumn) GetInputShape() string { + return ec.CategoryColumn.(featureColumn).GetInputShape() +} + +func (ec *embeddingColumn) GetColumnType() int { + return columnTypeEmbedding +} + +func (ec *embeddingColumn) GenerateCode() (string, error) { + catColumn, ok := ec.CategoryColumn.(featureColumn) + if !ok { + return "", fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn) + } + sourceCode, err := catColumn.GenerateCode() + if err != nil { + return "", err + } + return fmt.Sprintf("tf.feature_column.embedding_column(%s, dimension=%d, combiner=\"%s\")", + sourceCode, ec.Dimension, ec.Combiner), nil +} + +func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { + if len(*el) != 4 && len(*el) != 5 { + return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) + } + sourceExprList := (*el)[1] + var source featureColumn + var err error + if sourceExprList.typ == 0 { + source, _, err = resolveColumn(&sourceExprList.sexp) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("key of EMBEDDING must be categorical column") + } + // TODO(uuleon) support other kinds of categorical column in the future + var catColumn interface{} + catColumn, ok := source.(*categoryIDColumn) + if !ok { + catColumn, ok = source.(*sequenceCategoryIDColumn) + if !ok { + return nil, fmt.Errorf("key of EMBEDDING must be categorical column") + } + } + dimension, err := strconv.Atoi((*el)[2].val) + if err != nil { + return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err) + } + combiner, err := expression2string((*el)[3]) + if err != nil { + return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err) + } + initializer := "" + if len(*el) == 5 { + initializer, err = expression2string((*el)[4]) + if err != nil { + return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err) + } + } + return &embeddingColumn{ + CategoryColumn: catColumn, + Dimension: dimension, + Combiner: combiner, + Initializer: initializer}, nil +} diff --git a/sql/embedding_column_test.go b/sql/embedding_column_test.go new file mode 100644 index 0000000000..a4325080b1 --- /dev/null +++ b/sql/embedding_column_test.go @@ -0,0 +1,57 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbeddingColumn(t *testing.T) { + a := assert.New(t) + parser := newParser() + + normal := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), 200, mean)") + badInput := statementWithColumn("EMBEDDING(c1, 100, mean)") + badBucket := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), bad, mean)") + + r, e := parser.Parse(normal) + a.NoError(e) + c := r.columns["feature_columns"] + fcs, _, e := resolveTrainColumns(&c) + a.NoError(e) + ec, ok := fcs[0].(*embeddingColumn) + a.True(ok) + code, e := ec.GenerateCode() + a.NoError(e) + cc, ok := ec.CategoryColumn.(*categoryIDColumn) + a.True(ok) + a.Equal("c1", cc.Key) + a.Equal(100, cc.BucketSize) + a.Equal(200, ec.Dimension) + a.Equal("tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100), dimension=200, combiner=\"mean\")", code) + + r, e = parser.Parse(badInput) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) + + r, e = parser.Parse(badBucket) + a.NoError(e) + c = r.columns["feature_columns"] + fcs, _, e = resolveTrainColumns(&c) + a.Error(e) +} diff --git a/sql/engine_spec.go b/sql/engine_spec.go new file mode 100644 index 0000000000..fbcca0e069 --- /dev/null +++ b/sql/engine_spec.go @@ -0,0 +1,114 @@ +// Copyright 2019 The SQLFlow Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "strconv" + "strings" +) + +type engineSpec struct { + etype string + ps resourceSpec + worker resourceSpec + cluster string + queue string + masterResourceRequest string + masterResourceLimit string + workerResourceRequest string + workerResourceLimit string + volume string + imagePullPolicy string + restartPolicy string + extraPypiIndex string + namespace string + minibatchSize int + masterPodPriority string + clusterSpec string + recordsPerTask int +} + +func getEngineSpec(attrs map[string]*attribute) engineSpec { + getInt := func(key string, defaultValue int) int { + if p, ok := attrs[key]; ok { + strVal, _ := p.Value.(string) + intVal, err := strconv.Atoi(strVal) + + if err == nil { + return intVal + } + } + return defaultValue + } + getString := func(key string, defaultValue string) string { + if p, ok := attrs[key]; ok { + strVal, ok := p.Value.(string) + if ok { + // TODO(joyyoj): use the parser to do those validations. 
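+ // e.g. a value parsed as a quoted literal ("yarn" including the quotes) is returned with the quotes stripped.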
+ if strings.HasPrefix(strVal, "\"") && strings.HasSuffix(strVal, "\"") { + return strVal[1 : len(strVal)-1] + } + return strVal + } + } + return defaultValue + } + + psNum := getInt("ps_num", 1) + psMemory := getInt("ps_memory", 2400) + workerMemory := getInt("worker_memory", 1600) + workerNum := getInt("worker_num", 2) + engineType := getString("type", "local") + if (psNum > 0 || workerNum > 0) && engineType == "local" { + engineType = "yarn" + } + cluster := getString("cluster", "") + queue := getString("queue", "") + + // ElasticDL engine specs + masterResourceRequest := getString("master_resource_request", "cpu=0.1,memory=1024Mi") + masterResourceLimit := getString("master_resource_limit", "") + workerResourceRequest := getString("worker_resource_request", "cpu=1,memory=4096Mi") + workerResourceLimit := getString("worker_resource_limit", "") + volume := getString("volume", "") + imagePullPolicy := getString("image_pull_policy", "Always") + restartPolicy := getString("restart_policy", "Never") + extraPypiIndex := getString("extra_pypi_index", "") + namespace := getString("namespace", "default") + minibatchSize := getInt("minibatch_size", 64) + masterPodPriority := getString("master_pod_priority", "") + clusterSpec := getString("cluster_spec", "") + recordsPerTask := getInt("records_per_task", 100) + + return engineSpec{ + etype: engineType, + ps: resourceSpec{Num: psNum, Memory: psMemory}, + worker: resourceSpec{Num: workerNum, Memory: workerMemory}, + cluster: cluster, + queue: queue, + masterResourceRequest: masterResourceRequest, + masterResourceLimit: masterResourceLimit, + workerResourceRequest: workerResourceRequest, + workerResourceLimit: workerResourceLimit, + volume: volume, + imagePullPolicy: imagePullPolicy, + restartPolicy: restartPolicy, + extraPypiIndex: extraPypiIndex, + namespace: namespace, + minibatchSize: minibatchSize, + masterPodPriority: masterPodPriority, + clusterSpec: clusterSpec, + recordsPerTask: recordsPerTask, + } +} diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go index 8da0b3b7db..d25dfd426b 100644 --- a/sql/expression_resolver.go +++ b/sql/expression_resolver.go @@ -14,7 +14,6 @@ package sql import ( - "encoding/json" "fmt" "strconv" "strings" @@ -39,36 +38,6 @@ type resourceSpec struct { Core int } -type engineSpec struct { - etype string - ps resourceSpec - worker resourceSpec - cluster string - queue string - masterResourceRequest string - masterResourceLimit string - workerResourceRequest string - workerResourceLimit string - volume string - imagePullPolicy string - restartPolicy string - extraPypiIndex string - namespace string - minibatchSize int - masterPodPriority string - clusterSpec string - recordsPerTask int -} - -type gitLabModule struct { - ModuleName string - ProjectName string - Sha string - PrivateToken string - SourceRoot string - GitLabServer string -} - type resolvedTrainClause struct { IsPreMadeModel bool ModelName string @@ -106,157 +75,11 @@ type resolvedPredictClause struct { EngineParams engineSpec } -// featureColumn is an interface that all types of feature columns and -// attributes (WITH clause) should follow. -// featureColumn is used to generate feature column code. 
-type featureColumn interface { - GenerateCode() (string, error) - // Some feature columns accept input tensors directly, and the data - // may be a tensor string like: 12,32,4,58,0,0 - GetDelimiter() string - GetDtype() string - GetKey() string - // return input shape json string, like "[2,3]" - GetInputShape() string -} - type featureMap struct { Table string Partition string } -// featureSpec contains information to generate DENSE/SPARSE code -type columnSpec struct { - ColumnName string - AutoDerivation bool - IsSparse bool - Shape []int - DType string - Delimiter string - FeatureMap featureMap -} - -type attribute struct { - FullName string - Prefix string - Name string - Value interface{} -} - -type numericColumn struct { - Key string - Shape []int - Delimiter string - Dtype string -} - -type bucketColumn struct { - SourceColumn *numericColumn - Boundaries []int -} - -// TODO(uuleon) specify the hash_key if needed -type crossColumn struct { - Keys []interface{} - HashBucketSize int -} - -type categoryIDColumn struct { - Key string - BucketSize int - Delimiter string - Dtype string -} - -type sequenceCategoryIDColumn struct { - Key string - BucketSize int - Delimiter string - Dtype string - IsSequence bool -} - -type embeddingColumn struct { - CategoryColumn interface{} - Dimension int - Combiner string - Initializer string -} - -func getEngineSpec(attrs map[string]*attribute) engineSpec { - getInt := func(key string, defaultValue int) int { - if p, ok := attrs[key]; ok { - strVal, _ := p.Value.(string) - intVal, err := strconv.Atoi(strVal) - - if err == nil { - return intVal - } - } - return defaultValue - } - getString := func(key string, defaultValue string) string { - if p, ok := attrs[key]; ok { - strVal, ok := p.Value.(string) - if ok { - // TODO(joyyoj): use the parser to do those validations. 
- if strings.HasPrefix(strVal, "\"") && strings.HasSuffix(strVal, "\"") { - return strVal[1 : len(strVal)-1] - } - return strVal - } - } - return defaultValue - } - - psNum := getInt("ps_num", 1) - psMemory := getInt("ps_memory", 2400) - workerMemory := getInt("worker_memory", 1600) - workerNum := getInt("worker_num", 2) - engineType := getString("type", "local") - if (psNum > 0 || workerNum > 0) && engineType == "local" { - engineType = "yarn" - } - cluster := getString("cluster", "") - queue := getString("queue", "") - - // ElasticDL engine specs - masterResourceRequest := getString("master_resource_request", "cpu=0.1,memory=1024Mi") - masterResourceLimit := getString("master_resource_limit", "") - workerResourceRequest := getString("worker_resource_request", "cpu=1,memory=4096Mi") - workerResourceLimit := getString("worker_resource_limit", "") - volume := getString("volume", "") - imagePullPolicy := getString("image_pull_policy", "Always") - restartPolicy := getString("restart_policy", "Never") - extraPypiIndex := getString("extra_pypi_index", "") - namespace := getString("namespace", "default") - minibatchSize := getInt("minibatch_size", 64) - masterPodPriority := getString("master_pod_priority", "") - clusterSpec := getString("cluster_spec", "") - recordsPerTask := getInt("records_per_task", 100) - - return engineSpec{ - etype: engineType, - ps: resourceSpec{Num: psNum, Memory: psMemory}, - worker: resourceSpec{Num: workerNum, Memory: workerMemory}, - cluster: cluster, - queue: queue, - masterResourceRequest: masterResourceRequest, - masterResourceLimit: masterResourceLimit, - workerResourceRequest: workerResourceRequest, - workerResourceLimit: workerResourceLimit, - volume: volume, - imagePullPolicy: imagePullPolicy, - restartPolicy: restartPolicy, - extraPypiIndex: extraPypiIndex, - namespace: namespace, - minibatchSize: minibatchSize, - masterPodPriority: masterPodPriority, - clusterSpec: clusterSpec, - recordsPerTask: recordsPerTask, - } -} - func trimQuotes(s string) string { if len(s) >= 2 { if s[0] == '"' && s[len(s)-1] == '"' { @@ -462,162 +285,94 @@ func resolveTrainColumns(columns *exprlist) ([]featureColumn, []*columnSpec, err var fcs = make([]featureColumn, 0) var css = make([]*columnSpec, 0) for _, expr := range *columns { - result, err := resolveExpression(expr) - if err != nil { - return nil, nil, err - } - if cs, ok := result.(*columnSpec); ok { - css = append(css, cs) - continue - } else if c, ok := result.(featureColumn); ok { - fcs = append(fcs, c) - } else if s, ok := result.(string); ok { - // simple string column, generate default numeric column + if expr.typ != 0 { + // Column identifier like "COLUMN a1,b1" + // FIXME(typhoonzero): infer the column spec here. 
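+			// until then, treat the bare identifier as a scalar float32 numeric feature.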
 			c := &numericColumn{
-				Key:   s,
+				Key:   expr.val,
 				Shape: []int{1},
 				Dtype: "float32",
 			}
 			fcs = append(fcs, c)
 		} else {
-			return nil, nil, fmt.Errorf("not recognized type: %s", result)
+			result, cs, err := resolveColumn(&expr.sexp)
+			if err != nil {
+				return nil, nil, err
+			}
+			if cs != nil {
+				css = append(css, cs)
+			}
+			if result != nil {
+				fcs = append(fcs, result)
+			}
 		}
 	}
 	return fcs, css, nil
 }
 
 func getExpressionFieldName(expr *expr) (string, error) {
-	result, err := resolveExpression(expr)
+	if expr.typ != 0 {
+		return expr.val, nil
+	}
+	fc, _, err := resolveColumn(&expr.sexp)
 	if err != nil {
 		return "", err
 	}
-	switch r := result.(type) {
-	case *columnSpec:
-		return r.ColumnName, nil
-	case featureColumn:
-		return r.GetKey(), nil
-	case string:
-		return r, nil
-	default:
-		return "", fmt.Errorf("getExpressionFieldName: unrecognized type %T", r)
-	}
+	return fc.GetKey(), nil
 }
 
-// resolveExpression resolve a SQLFlow expression to the actual value
-// see: sql.y:241 for the definition of expression.
-func resolveExpression(e interface{}) (interface{}, error) {
+// resolveExpression parses the expression recursively and
+// returns the actual value of the expression:
+// featureColumns, columnSpecs, error
+// e.g.
+// column_1 -> "column_1", nil, nil
+// [1,2,3,4] -> [1,2,3,4], nil, nil
+// [NUMERIC(col1), col2] -> [*numericColumn, "col2"], nil, nil
+func resolveExpression(e interface{}) (interface{}, interface{}, error) {
 	if expr, ok := e.(*expr); ok {
-		if expr.val != "" {
-			return expr.val, nil
+		if expr.typ != 0 {
+			return expr.val, nil, nil
 		}
 		return resolveExpression(&expr.sexp)
 	}
-
 	el, ok := e.(*exprlist)
 	if !ok {
-		return nil, fmt.Errorf("input of resolveExpression must be `expr` or `exprlist` given %s", e)
-	}
-
-	head := (*el)[0].val
-	if head == "" {
-		return resolveExpression(&(*el)[0].sexp)
+		return nil, nil, fmt.Errorf("input of resolveExpression must be `expr` or `exprlist` given %v", e)
 	}
-
-	switch strings.ToUpper(head) {
-	case dense:
-		return resolveColumnSpec(el, false)
-	case sparse:
-		return resolveColumnSpec(el, true)
-	case numeric:
-		return resolveNumericColumn(el)
-	case bucket:
-		return resolveBucketColumn(el)
-	case cross:
-		return resolveCrossColumn(el)
-	case categoryID:
-		return resolveCategoryIDColumn(el, false)
-	case seqCategoryID:
-		return resolveCategoryIDColumn(el, true)
-	case embedding:
-		return resolveEmbeddingColumn(el)
-	case square:
+	headTyp := (*el)[0].typ
+	if headTyp == IDENT {
+		// Expression is a function call
+		return resolveColumn(el)
+	} else if headTyp == '[' {
 		var list []interface{}
+		var columnSpecList []interface{}
 		for idx, expr := range *el {
 			if idx > 0 {
 				if expr.sexp == nil {
 					intVal, err := strconv.Atoi(expr.val)
+					// TODO: support list of float etc.
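+					// items that fail integer parsing are kept as plain strings for now.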
if err != nil { list = append(list, expr.val) } else { list = append(list, intVal) } } else { - value, err := resolveExpression(&expr.sexp) + value, cs, err := resolveExpression(&expr.sexp) if err != nil { - return nil, err + return nil, nil, err } list = append(list, value) + columnSpecList = append(columnSpecList, cs) } } } - return list, nil - default: - return nil, fmt.Errorf("not supported expr: %s", head) + return list, columnSpecList, nil } -} - -func resolveColumnSpec(el *exprlist, isSparse bool) (*columnSpec, error) { - if len(*el) < 4 { - return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el) - } - name, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec name: %s, err: %s", (*el)[1], err) - } - var shape []int - intShape, err := strconv.Atoi((*el)[2].val) - if err != nil { - strShape, err := expression2string((*el)[2]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) - } - if strShape != "none" { - return nil, fmt.Errorf("bad FeatureSpec shape: %s, err: %s", (*el)[2].val, err) - } - } else { - shape = append(shape, intShape) - } - unresolvedDelimiter, err := expression2string((*el)[3]) - if err != nil { - return nil, fmt.Errorf("bad FeatureSpec delimiter: %s, err: %s", (*el)[1], err) - } - - delimiter, err := resolveDelimiter(unresolvedDelimiter) - if err != nil { - return nil, err - } - - // resolve feature map - fm := featureMap{} - dtype := "float" - if isSparse { - dtype = "int" - } - if len(*el) >= 5 { - dtype, err = expression2string((*el)[4]) - } - return &columnSpec{ - ColumnName: name, - AutoDerivation: false, - IsSparse: isSparse, - Shape: shape, - DType: dtype, - Delimiter: delimiter, - FeatureMap: fm}, nil + return nil, nil, fmt.Errorf("not supported expr: %v", el) } func expression2string(e interface{}) (string, error) { - resolved, err := resolveExpression(e) + resolved, _, err := resolveExpression(e) if err != nil { return "", err } @@ -646,311 +401,6 @@ func (cs *columnSpec) ToString() string { cs.Delimiter) } -func resolveNumericColumn(el *exprlist) (*numericColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el) - } - key, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad NUMERIC key: %s, err: %s", (*el)[1], err) - } - var shape []int - intVal, err := strconv.Atoi((*el)[2].val) - if err != nil { - list, err := resolveExpression((*el)[2]) - if err != nil { - return nil, err - } - if list, ok := list.([]interface{}); ok { - shape, err = transformToIntList(list) - if err != nil { - return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) - } - } else { - return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err) - } - } else { - shape = append(shape, intVal) - } - return &numericColumn{ - Key: key, - Shape: shape, - // FIXME(typhoonzero, tony): support config Delimiter and Dtype - Delimiter: ",", - Dtype: "float32"}, nil -} - -func resolveBucketColumn(el *exprlist) (*bucketColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad BUCKET expression format: %s", *el) - } - sourceExprList := (*el)[1] - boundariesExprList := (*el)[2] - source, err := resolveExpression(sourceExprList) - if err != nil { - return nil, err - } - if _, ok := source.(*numericColumn); !ok { - return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source) - } - boundaries, err := resolveExpression(boundariesExprList) - if err != nil { - return 
nil, err - } - if _, ok := boundaries.([]interface{}); !ok { - return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) - } - b, err := transformToIntList(boundaries.([]interface{})) - if err != nil { - return nil, fmt.Errorf("bad BUCKET boundaries: %s", err) - } - return &bucketColumn{ - SourceColumn: source.(*numericColumn), - Boundaries: b}, nil -} - -func resolveCrossColumn(el *exprlist) (*crossColumn, error) { - if len(*el) != 3 { - return nil, fmt.Errorf("bad CROSS expression format: %s", *el) - } - keysExpr := (*el)[1] - keys, err := resolveExpression(keysExpr) - if err != nil { - return nil, err - } - if _, ok := keys.([]interface{}); !ok { - return nil, fmt.Errorf("bad CROSS keys: %s", err) - } - bucketSize, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad CROSS bucketSize: %s, err: %s", (*el)[2].val, err) - } - return &crossColumn{ - Keys: keys.([]interface{}), - HashBucketSize: bucketSize}, nil -} - -func resolveCategoryIDColumn(el *exprlist, isSequence bool) (interface{}, error) { - if len(*el) != 3 && len(*el) != 4 { - return nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el) - } - key, err := expression2string((*el)[1]) - if err != nil { - return nil, fmt.Errorf("bad CATEGORY_ID key: %s, err: %s", (*el)[1], err) - } - bucketSize, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad CATEGORY_ID bucketSize: %s, err: %s", (*el)[2].val, err) - } - delimiter := "" - if len(*el) == 4 { - delimiter, err = resolveDelimiter((*el)[3].val) - if err != nil { - return nil, fmt.Errorf("bad CATEGORY_ID delimiter: %s, %s", (*el)[3].val, err) - } - } - if isSequence { - return &sequenceCategoryIDColumn{ - Key: key, - BucketSize: bucketSize, - Delimiter: delimiter, - // TODO(typhoonzero): support config dtype - Dtype: "int64", - IsSequence: true}, nil - } - return &categoryIDColumn{ - Key: key, - BucketSize: bucketSize, - Delimiter: delimiter, - // TODO(typhoonzero): support config dtype - Dtype: "int64"}, nil -} - -func resolveEmbeddingColumn(el *exprlist) (*embeddingColumn, error) { - if len(*el) != 4 && len(*el) != 5 { - return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el) - } - sourceExprList := (*el)[1] - source, err := resolveExpression(sourceExprList) - if err != nil { - return nil, err - } - // TODO(uuleon) support other kinds of categorical column in the future - var catColumn interface{} - catColumn, ok := source.(*categoryIDColumn) - if !ok { - catColumn, ok = source.(*sequenceCategoryIDColumn) - if !ok { - return nil, fmt.Errorf("key of EMBEDDING must be categorical column") - } - } - dimension, err := strconv.Atoi((*el)[2].val) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err) - } - combiner, err := expression2string((*el)[3]) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err) - } - initializer := "" - if len(*el) == 5 { - initializer, err = expression2string((*el)[4]) - if err != nil { - return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err) - } - } - return &embeddingColumn{ - CategoryColumn: catColumn, - Dimension: dimension, - Combiner: combiner, - Initializer: initializer}, nil -} - -func (nc *numericColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key, - strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ",")), nil -} - -func (nc *numericColumn) GetDelimiter() string { - return nc.Delimiter -} - 
-func (nc *numericColumn) GetDtype() string { - return nc.Dtype -} - -func (nc *numericColumn) GetKey() string { - return nc.Key -} - -func (nc *numericColumn) GetInputShape() string { - jsonBytes, err := json.Marshal(nc.Shape) - if err != nil { - return "" - } - return string(jsonBytes) -} - -func (bc *bucketColumn) GenerateCode() (string, error) { - sourceCode, _ := bc.SourceColumn.GenerateCode() - return fmt.Sprintf( - "tf.feature_column.bucketized_column(%s, boundaries=%s)", - sourceCode, - strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")), nil -} - -func (bc *bucketColumn) GetDelimiter() string { - return "" -} - -func (bc *bucketColumn) GetDtype() string { - return "" -} - -func (bc *bucketColumn) GetKey() string { - return bc.SourceColumn.Key -} - -func (bc *bucketColumn) GetInputShape() string { - return bc.SourceColumn.GetInputShape() -} - -func (cc *crossColumn) GenerateCode() (string, error) { - var keysGenerated = make([]string, len(cc.Keys)) - for idx, key := range cc.Keys { - if c, ok := key.(featureColumn); ok { - code, err := c.GenerateCode() - if err != nil { - return "", err - } - keysGenerated[idx] = code - continue - } - if str, ok := key.(string); ok { - keysGenerated[idx] = fmt.Sprintf("\"%s\"", str) - } else { - return "", fmt.Errorf("cross generate code error, key: %s", key) - } - } - return fmt.Sprintf( - "tf.feature_column.crossed_column([%s], hash_bucket_size=%d)", - strings.Join(keysGenerated, ","), cc.HashBucketSize), nil -} - -func (cc *crossColumn) GetDelimiter() string { - return "" -} - -func (cc *crossColumn) GetDtype() string { - return "" -} - -func (cc *crossColumn) GetKey() string { - // NOTE: cross column is a feature on multiple column keys. - return "" -} - -func (cc *crossColumn) GetInputShape() string { - // NOTE: return empty since crossed column input shape is already determined - // by the two crossed columns. 
- return "" -} - -func (cc *categoryIDColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", - cc.Key, cc.BucketSize), nil -} - -func (cc *categoryIDColumn) GetDelimiter() string { - return cc.Delimiter -} - -func (cc *categoryIDColumn) GetDtype() string { - return cc.Dtype -} - -func (cc *categoryIDColumn) GetKey() string { - return cc.Key -} - -func (cc *categoryIDColumn) GetInputShape() string { - return fmt.Sprintf("[%d]", cc.BucketSize) -} - -func (cc *sequenceCategoryIDColumn) GenerateCode() (string, error) { - return fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", - cc.Key, cc.BucketSize), nil -} - -func (cc *sequenceCategoryIDColumn) GetDelimiter() string { - return cc.Delimiter -} - -func (cc *sequenceCategoryIDColumn) GetDtype() string { - return cc.Dtype -} - -func (cc *sequenceCategoryIDColumn) GetKey() string { - return cc.Key -} - -func (cc *sequenceCategoryIDColumn) GetInputShape() string { - return fmt.Sprintf("[%d]", cc.BucketSize) -} - -func (ec *embeddingColumn) GenerateCode() (string, error) { - catColumn, ok := ec.CategoryColumn.(featureColumn) - if !ok { - return "", fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn) - } - sourceCode, err := catColumn.GenerateCode() - if err != nil { - return "", err - } - return fmt.Sprintf("tf.feature_column.embedding_column(%s, dimension=%d, combiner=\"%s\")", - sourceCode, ec.Dimension, ec.Combiner), nil -} - func generateFeatureColumnCode(fcs []featureColumn) (string, error) { var codes = make([]string, 0, len(fcs)) for _, fc := range fcs { @@ -963,46 +413,6 @@ func generateFeatureColumnCode(fcs []featureColumn) (string, error) { return fmt.Sprintf("[%s]", strings.Join(codes, ",")), nil } -func (ec *embeddingColumn) GetDelimiter() string { - return ec.CategoryColumn.(featureColumn).GetDelimiter() -} - -func (ec *embeddingColumn) GetDtype() string { - return ec.CategoryColumn.(featureColumn).GetDtype() -} - -func (ec *embeddingColumn) GetKey() string { - return ec.CategoryColumn.(featureColumn).GetKey() -} - -func (ec *embeddingColumn) GetInputShape() string { - return ec.CategoryColumn.(featureColumn).GetInputShape() -} - -func resolveAttribute(attrs *attrs) (map[string]*attribute, error) { - ret := make(map[string]*attribute) - for k, v := range *attrs { - subs := strings.SplitN(k, ".", 2) - name := subs[len(subs)-1] - prefix := "" - if len(subs) == 2 { - prefix = subs[0] - } - r, err := resolveExpression(v) - if err != nil { - return nil, err - } - - a := &attribute{ - FullName: k, - Prefix: prefix, - Name: name, - Value: r} - ret[a.FullName] = a - } - return ret, nil -} - func resolveDelimiter(delimiter string) (string, error) { if strings.EqualFold(delimiter, comma) { return ",", nil @@ -1010,25 +420,6 @@ func resolveDelimiter(delimiter string) (string, error) { return "", fmt.Errorf("unsolved delimiter: %s", delimiter) } -func (a *attribute) GenerateCode() (string, error) { - if val, ok := a.Value.(string); ok { - // auto convert to int first. 
- if _, err := strconv.Atoi(val); err == nil { - return fmt.Sprintf("%s=%s", a.Name, val), nil - } - return fmt.Sprintf("%s=\"%s\"", a.Name, val), nil - } - if val, ok := a.Value.([]interface{}); ok { - intList, err := transformToIntList(val) - if err != nil { - return "", err - } - return fmt.Sprintf("%s=%s", a.Name, - strings.Join(strings.Split(fmt.Sprint(intList), " "), ",")), nil - } - return "", fmt.Errorf("value of attribute must be string or list of int, given %s", a.Value) -} - func transformToIntList(list []interface{}) ([]int, error) { var b = make([]int, len(list)) for idx, item := range list { @@ -1040,16 +431,3 @@ func transformToIntList(list []interface{}) ([]int, error) { } return b, nil } - -func filter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute { - ret := make(map[string]*attribute, 0) - for _, a := range attrs { - if strings.EqualFold(a.Prefix, prefix) { - ret[a.Name] = a - if remove { - delete(attrs, a.FullName) - } - } - } - return ret -} diff --git a/sql/expression_resolver_test.go b/sql/expression_resolver_test.go index 29b70506d1..263bb8f2a6 100644 --- a/sql/expression_resolver_test.go +++ b/sql/expression_resolver_test.go @@ -38,248 +38,6 @@ func statementWithAttrs(attrs string) string { return fmt.Sprintf(trainStatement, attrs, "DENSE(c2, 5, comma)") } -func TestFeatureSpec(t *testing.T) { - a := assert.New(t) - parser := newParser() - - denseStatement := statementWithColumn("DENSE(c2, 5, comma)") - sparseStatement := statementWithColumn("SPARSE(c1, 100, comma)") - badStatement := statementWithColumn("DENSE(c3, bad, comma)") - - r, e := parser.Parse(denseStatement) - a.NoError(e) - c := r.columns["feature_columns"] - _, css, e := resolveTrainColumns(&c) - a.NoError(e) - cs := css[0] - a.Equal("c2", cs.ColumnName) - a.Equal(5, cs.Shape[0]) - a.Equal(",", cs.Delimiter) - a.Equal(false, cs.IsSparse) - a.Equal("DenseColumn(name=\"c2\", shape=[5], dtype=\"float\", separator=\",\")", cs.ToString()) - - r, e = parser.Parse(sparseStatement) - a.NoError(e) - c = r.columns["feature_columns"] - _, css, e = resolveTrainColumns(&c) - a.NoError(e) - cs = css[0] - a.Equal("c1", cs.ColumnName) - a.Equal(100, cs.Shape[0]) - a.Equal(true, cs.IsSparse) - a.Equal("SparseColumn(name=\"c1\", shape=[100], dtype=\"int\")", cs.ToString()) - - r, e = parser.Parse(badStatement) - a.NoError(e) - c = r.columns["feature_columns"] - _, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestNumericColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("NUMERIC(c2, [5, 10])") - moreArgs := statementWithColumn("NUMERIC(c1, 100, args)") - badShape := statementWithColumn("NUMERIC(c1, bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - nc, ok := fcs[0].(*numericColumn) - a.True(ok) - code, e := nc.GenerateCode() - a.NoError(e) - a.Equal("c2", nc.Key) - a.Equal([]int{5, 10}, nc.Shape) - a.Equal("tf.feature_column.numeric_column(\"c2\", shape=[5,10])", code) - - r, e = parser.Parse(moreArgs) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badShape) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestBucketColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])") - badInput := statementWithColumn("BUCKET(c1, 
[1, 10])") - badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - bc, ok := fcs[0].(*bucketColumn) - a.True(ok) - code, e := bc.GenerateCode() - a.NoError(e) - a.Equal("c1", bc.SourceColumn.Key) - a.Equal([]int{10}, bc.SourceColumn.Shape) - a.Equal([]int{1, 10}, bc.Boundaries) - a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBoundaries) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestCrossColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], 20)") - badInput := statementWithColumn("cross(c1, 20)") - badBucketSize := statementWithColumn("cross([BUCKET(NUMERIC(c1, 10), [1, 10]), c5], bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcList, _, e := resolveTrainColumns(&c) - a.NoError(e) - cc, ok := fcList[0].(*crossColumn) - a.True(ok) - code, e := cc.GenerateCode() - a.NoError(e) - - bc := cc.Keys[0].(*bucketColumn) - a.Equal("c1", bc.SourceColumn.Key) - a.Equal([]int{10}, bc.SourceColumn.Shape) - a.Equal([]int{1, 10}, bc.Boundaries) - a.Equal("c5", cc.Keys[1].(string)) - a.Equal(20, cc.HashBucketSize) - a.Equal("tf.feature_column.crossed_column([tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10]),\"c5\"], hash_bucket_size=20)", code) - - r, e = parser.Parse(badInput) - a.NoError(e) - c = r.columns["feature_columns"] - fcList, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucketSize) - a.NoError(e) - c = r.columns["feature_columns"] - fcList, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestCatIdColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("CATEGORY_ID(c1, 100)") - badKey := statementWithColumn("CATEGORY_ID([100], 100)") - badBucket := statementWithColumn("CATEGORY_ID(c1, bad)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - cc, ok := fcs[0].(*categoryIDColumn) - a.True(ok) - code, e := cc.GenerateCode() - a.NoError(e) - a.Equal("c1", cc.Key) - a.Equal(100, cc.BucketSize) - a.Equal("tf.feature_column.categorical_column_with_identity(key=\"c1\", num_buckets=100)", code) - - r, e = parser.Parse(badKey) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) - - r, e = parser.Parse(badBucket) - a.NoError(e) - c = r.columns["feature_columns"] - fcs, _, e = resolveTrainColumns(&c) - a.Error(e) -} - -func TestEmbeddingColumn(t *testing.T) { - a := assert.New(t) - parser := newParser() - - normal := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), 200, mean)") - badInput := statementWithColumn("EMBEDDING(c1, 100, mean)") - badBucket := statementWithColumn("EMBEDDING(CATEGORY_ID(c1, 100), bad, mean)") - - r, e := parser.Parse(normal) - a.NoError(e) - c := r.columns["feature_columns"] - fcs, _, e := resolveTrainColumns(&c) - a.NoError(e) - ec, ok := fcs[0].(*embeddingColumn) - a.True(ok) - code, e := ec.GenerateCode() - 
diff --git a/sql/feature_column.go b/sql/feature_column.go
new file mode 100644
index 0000000000..1be8becfb6
--- /dev/null
+++ b/sql/feature_column.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"fmt"
+	"strings"
+)
+
+const (
+	columnTypeBucket        = 0
+	columnTypeEmbedding     = 1
+	columnTypeNumeric       = 2
+	columnTypeCategoryID    = 3
+	columnTypeSeqCategoryID = 4
+	columnTypeCross         = 5
+)
+
+// featureColumn is an interface that all types of feature columns
+// should implement. featureColumn is used to generate feature column code.
+type featureColumn interface {
+	GenerateCode() (string, error)
+	// FIXME(typhoonzero): remove delimiter, dtype, and shape from the feature
+	// column; get these from the column spec clause or by feature derivation.
+	GetDelimiter() string
+	GetDtype() string
+	GetKey() string
+	// GetInputShape returns the input shape as a JSON string, e.g. "[2,3]".
+	GetInputShape() string
+	GetColumnType() int
+}
+
+// resolveColumn returns the actual feature column typed struct
+// as well as the columnSpec information.
+func resolveColumn(el *exprlist) (featureColumn, *columnSpec, error) {
+	head := (*el)[0].val
+	if head == "" {
+		return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %v", el)
+	}
+
+	switch strings.ToUpper(head) {
+	case dense:
+		cs, err := resolveColumnSpec(el, false)
+		return nil, cs, err
+	case sparse:
+		cs, err := resolveColumnSpec(el, true)
+		return nil, cs, err
+	case numeric:
+		// TODO(typhoonzero): support NUMERIC(DENSE(col)) and NUMERIC(SPARSE(col))
+		fc, err := resolveNumericColumn(el)
+		return fc, nil, err
+	case bucket:
+		fc, err := resolveBucketColumn(el)
+		return fc, nil, err
+	case cross:
+		fc, err := resolveCrossColumn(el)
+		return fc, nil, err
+	case categoryID:
+		return resolveCategoryIDColumn(el)
+	case seqCategoryID:
+		return resolveSeqCategoryIDColumn(el)
+	case embedding:
+		fc, err := resolveEmbeddingColumn(el)
+		return fc, nil, err
+	default:
+		return nil, nil, fmt.Errorf("not supported expr: %s", head)
+	}
+}
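Note: an illustrative sketch, not part of the patch. The point of the featureColumn interface is that the code generator never needs to know the concrete column type: resolve once via resolveColumn, then call GenerateCode uniformly. The helper name generateFeatureColumnCode is hypothetical.

// generateFeatureColumnCode renders each resolved feature column into its
// tf.feature_column.* snippet, stopping at the first error.
func generateFeatureColumnCode(fcs []featureColumn) ([]string, error) {
	var lines []string
	for _, fc := range fcs {
		code, err := fc.GenerateCode()
		if err != nil {
			return nil, err
		}
		lines = append(lines, code)
	}
	return lines, nil
}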
diff --git a/sql/gitlab_module.go b/sql/gitlab_module.go
new file mode 100644
index 0000000000..aa26d707fd
--- /dev/null
+++ b/sql/gitlab_module.go
@@ -0,0 +1,23 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+type gitLabModule struct {
+	ModuleName   string
+	ProjectName  string
+	Sha          string
+	PrivateToken string
+	SourceRoot   string
+	GitLabServer string
+}
diff --git a/sql/numeric_column.go b/sql/numeric_column.go
new file mode 100644
index 0000000000..022ebf4204
--- /dev/null
+++ b/sql/numeric_column.go
@@ -0,0 +1,91 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"encoding/json"
+	"fmt"
+	"strconv"
+	"strings"
+)
+
+type numericColumn struct {
+	Key       string
+	Shape     []int
+	Delimiter string
+	Dtype     string
+}
+
+func (nc *numericColumn) GenerateCode() (string, error) {
+	return fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key,
+		strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ",")), nil
+}
+
+func (nc *numericColumn) GetDelimiter() string {
+	return nc.Delimiter
+}
+
+func (nc *numericColumn) GetDtype() string {
+	return nc.Dtype
+}
+
+func (nc *numericColumn) GetKey() string {
+	return nc.Key
+}
+
+func (nc *numericColumn) GetInputShape() string {
+	jsonBytes, err := json.Marshal(nc.Shape)
+	if err != nil {
+		return ""
+	}
+	return string(jsonBytes)
+}
+
+func (nc *numericColumn) GetColumnType() int {
+	return columnTypeNumeric
+}
+
+func resolveNumericColumn(el *exprlist) (*numericColumn, error) {
+	if len(*el) != 3 {
+		return nil, fmt.Errorf("bad NUMERIC expression format: %s", *el)
+	}
+	key, err := expression2string((*el)[1])
+	if err != nil {
+		return nil, fmt.Errorf("bad NUMERIC key: %s, err: %s", (*el)[1], err)
+	}
+	// The shape argument is either a single int, e.g. NUMERIC(c1, 100),
+	// or a list of ints, e.g. NUMERIC(c2, [5, 10]).
+	var shape []int
+	intVal, err := strconv.Atoi((*el)[2].val)
+	if err != nil {
+		list, _, err := resolveExpression((*el)[2])
+		if err != nil {
+			return nil, err
+		}
+		if list, ok := list.([]interface{}); ok {
+			shape, err = transformToIntList(list)
+			if err != nil {
+				return nil, fmt.Errorf("bad NUMERIC shape: %s, err: %s", (*el)[2].val, err)
+			}
+		} else {
+			return nil, fmt.Errorf("bad NUMERIC shape: %s", (*el)[2].val)
+		}
+	} else {
+		shape = append(shape, intVal)
+	}
+	return &numericColumn{
+		Key:   key,
+		Shape: shape,
+		// FIXME(typhoonzero, tony): support config Delimiter and Dtype
+		Delimiter: ",",
+		Dtype:     "float32"}, nil
+}
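Note: a standalone, runnable sketch (hypothetical main package, not part of the patch) checking the shape formatting used by numericColumn.GenerateCode above: fmt.Sprint on an []int prints "[5 10]", and the Split/Join pair rewrites it into the Python-style "[5,10]".

package main

import (
	"fmt"
	"strings"
)

func main() {
	shape := []int{5, 10}
	// "[5 10]" -> "[5,10]"
	pyShape := strings.Join(strings.Split(fmt.Sprint(shape), " "), ",")
	fmt.Printf("tf.feature_column.numeric_column(%q, shape=%s)\n", "c2", pyShape)
	// Output: tf.feature_column.numeric_column("c2", shape=[5,10])
}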
diff --git a/sql/numeric_column_test.go b/sql/numeric_column_test.go
new file mode 100644
index 0000000000..38bd65f05b
--- /dev/null
+++ b/sql/numeric_column_test.go
@@ -0,0 +1,54 @@
+// Copyright 2019 The SQLFlow Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNumericColumn(t *testing.T) {
+	a := assert.New(t)
+	parser := newParser()
+
+	normal := statementWithColumn("NUMERIC(c2, [5, 10])")
+	moreArgs := statementWithColumn("NUMERIC(c1, 100, args)")
+	badShape := statementWithColumn("NUMERIC(c1, bad)")
+
+	r, e := parser.Parse(normal)
+	a.NoError(e)
+	c := r.columns["feature_columns"]
+	fcs, _, e := resolveTrainColumns(&c)
+	a.NoError(e)
+	nc, ok := fcs[0].(*numericColumn)
+	a.True(ok)
+	code, e := nc.GenerateCode()
+	a.NoError(e)
+	a.Equal("c2", nc.Key)
+	a.Equal([]int{5, 10}, nc.Shape)
+	a.Equal("tf.feature_column.numeric_column(\"c2\", shape=[5,10])", code)
+
+	r, e = parser.Parse(moreArgs)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+
+	r, e = parser.Parse(badShape)
+	a.NoError(e)
+	c = r.columns["feature_columns"]
+	fcs, _, e = resolveTrainColumns(&c)
+	a.Error(e)
+}
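Note: the test above exercises GenerateCode; GetInputShape takes the JSON route instead, so the same shape reaches the Python side as "[5,10]". A runnable sketch with a hypothetical helper mirroring the method (not part of the patch):

package main

import (
	"encoding/json"
	"fmt"
)

// inputShape mirrors numericColumn.GetInputShape: JSON-encode the shape so
// the generated Python code can consume it directly.
func inputShape(shape []int) string {
	b, err := json.Marshal(shape)
	if err != nil {
		return ""
	}
	return string(b)
}

func main() {
	fmt.Println(inputShape([]int{5, 10})) // prints [5,10]
}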