Skip to content

Commit fd806cb

Browse files
authored
simplify expression resolver (#722)
* wip simplify expression resolver * fix ci * wip * wip * updates * wip * update * fix ci
1 parent 2e29c2f commit fd806cb

20 files changed

+1352
-912
lines changed

sql/attribute.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2019 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package sql
15+
16+
import (
17+
"fmt"
18+
"strconv"
19+
"strings"
20+
)
21+
22+
// attribute is a single resolved attribute, identified by an optionally
// dotted key such as "estimator.hidden_units".
type attribute struct {
	// FullName is the key exactly as written.
	// Prefix is the part of the key before its first ".", or "" if there is none.
	// Name is the part of the key after its first ".", or the whole key if there is none.
	FullName, Prefix, Name string
	// Value is the resolved value; GenerateCode accepts a string or a
	// []interface{} here.
	Value interface{}
}
28+
29+
func (a *attribute) GenerateCode() (string, error) {
30+
if val, ok := a.Value.(string); ok {
31+
// auto convert to int first.
32+
if _, err := strconv.Atoi(val); err == nil {
33+
return fmt.Sprintf("%s=%s", a.Name, val), nil
34+
}
35+
return fmt.Sprintf("%s=\"%s\"", a.Name, val), nil
36+
}
37+
if val, ok := a.Value.([]interface{}); ok {
38+
intList, err := transformToIntList(val)
39+
if err != nil {
40+
return "", err
41+
}
42+
return fmt.Sprintf("%s=%s", a.Name,
43+
strings.Join(strings.Split(fmt.Sprint(intList), " "), ",")), nil
44+
}
45+
return "", fmt.Errorf("value of attribute must be string or list of int, given %s", a.Value)
46+
}
47+
48+
func filter(attrs map[string]*attribute, prefix string, remove bool) map[string]*attribute {
49+
ret := make(map[string]*attribute, 0)
50+
for _, a := range attrs {
51+
if strings.EqualFold(a.Prefix, prefix) {
52+
ret[a.Name] = a
53+
if remove {
54+
delete(attrs, a.FullName)
55+
}
56+
}
57+
}
58+
return ret
59+
}
60+
61+
func resolveAttribute(attrs *attrs) (map[string]*attribute, error) {
62+
ret := make(map[string]*attribute)
63+
for k, v := range *attrs {
64+
subs := strings.SplitN(k, ".", 2)
65+
name := subs[len(subs)-1]
66+
prefix := ""
67+
if len(subs) == 2 {
68+
prefix = subs[0]
69+
}
70+
r, _, err := resolveExpression(v)
71+
if err != nil {
72+
return nil, err
73+
}
74+
a := &attribute{
75+
FullName: k,
76+
Prefix: prefix,
77+
Name: name,
78+
Value: r}
79+
ret[a.FullName] = a
80+
}
81+
return ret, nil
82+
}

sql/attribute_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright 2019 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package sql
15+
16+
import (
17+
"testing"
18+
19+
"github.com/stretchr/testify/assert"
20+
)
21+
22+
func TestAttrs(t *testing.T) {
23+
a := assert.New(t)
24+
parser := newParser()
25+
26+
s := statementWithAttrs("estimator.hidden_units = [10, 20]")
27+
r, e := parser.Parse(s)
28+
a.NoError(e)
29+
attrs, err := resolveAttribute(&r.trainAttrs)
30+
a.NoError(err)
31+
attr := attrs["estimator.hidden_units"]
32+
a.Equal("estimator", attr.Prefix)
33+
a.Equal("hidden_units", attr.Name)
34+
a.Equal([]interface{}([]interface{}{10, 20}), attr.Value)
35+
36+
s = statementWithAttrs("dataset.name = hello")
37+
r, e = parser.Parse(s)
38+
a.NoError(e)
39+
attrs, err = resolveAttribute(&r.trainAttrs)
40+
a.NoError(err)
41+
attr = attrs["dataset.name"]
42+
a.Equal("dataset", attr.Prefix)
43+
a.Equal("name", attr.Name)
44+
a.Equal("hello", attr.Value)
45+
}

sql/bucket_column.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright 2019 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package sql
15+
16+
import (
17+
"fmt"
18+
"strings"
19+
)
20+
21+
type bucketColumn struct {
22+
SourceColumn *numericColumn
23+
Boundaries []int
24+
}
25+
26+
func (bc *bucketColumn) GenerateCode() (string, error) {
27+
sourceCode, _ := bc.SourceColumn.GenerateCode()
28+
return fmt.Sprintf(
29+
"tf.feature_column.bucketized_column(%s, boundaries=%s)",
30+
sourceCode,
31+
strings.Join(strings.Split(fmt.Sprint(bc.Boundaries), " "), ",")), nil
32+
}
33+
34+
func (bc *bucketColumn) GetDelimiter() string {
35+
return ""
36+
}
37+
38+
func (bc *bucketColumn) GetDtype() string {
39+
return ""
40+
}
41+
42+
func (bc *bucketColumn) GetKey() string {
43+
return bc.SourceColumn.Key
44+
}
45+
46+
func (bc *bucketColumn) GetInputShape() string {
47+
return bc.SourceColumn.GetInputShape()
48+
}
49+
50+
func (bc *bucketColumn) GetColumnType() int {
51+
return columnTypeBucket
52+
}
53+
54+
func resolveBucketColumn(el *exprlist) (*bucketColumn, error) {
55+
if len(*el) != 3 {
56+
return nil, fmt.Errorf("bad BUCKET expression format: %s", *el)
57+
}
58+
sourceExprList := (*el)[1]
59+
boundariesExprList := (*el)[2]
60+
if sourceExprList.typ != 0 {
61+
return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %v", sourceExprList)
62+
}
63+
source, _, err := resolveColumn(&sourceExprList.sexp)
64+
if err != nil {
65+
return nil, err
66+
}
67+
if source.GetColumnType() != columnTypeNumeric {
68+
return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source)
69+
}
70+
boundaries, _, err := resolveExpression(boundariesExprList)
71+
if err != nil {
72+
return nil, err
73+
}
74+
if _, ok := boundaries.([]interface{}); !ok {
75+
return nil, fmt.Errorf("bad BUCKET boundaries: %s", err)
76+
}
77+
b, err := transformToIntList(boundaries.([]interface{}))
78+
if err != nil {
79+
return nil, fmt.Errorf("bad BUCKET boundaries: %s", err)
80+
}
81+
return &bucketColumn{
82+
SourceColumn: source.(*numericColumn),
83+
Boundaries: b}, nil
84+
}

sql/bucket_column_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright 2019 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package sql
15+
16+
import (
17+
"testing"
18+
19+
"github.com/stretchr/testify/assert"
20+
)
21+
22+
func TestBucketColumn(t *testing.T) {
23+
a := assert.New(t)
24+
parser := newParser()
25+
26+
normal := statementWithColumn("BUCKET(NUMERIC(c1, 10), [1, 10])")
27+
badInput := statementWithColumn("BUCKET(c1, [1, 10])")
28+
badBoundaries := statementWithColumn("BUCKET(NUMERIC(c1, 10), 100)")
29+
30+
r, e := parser.Parse(normal)
31+
a.NoError(e)
32+
c := r.columns["feature_columns"]
33+
fcs, _, e := resolveTrainColumns(&c)
34+
a.NoError(e)
35+
bc, ok := fcs[0].(*bucketColumn)
36+
a.True(ok)
37+
code, e := bc.GenerateCode()
38+
a.NoError(e)
39+
a.Equal("c1", bc.SourceColumn.Key)
40+
a.Equal([]int{10}, bc.SourceColumn.Shape)
41+
a.Equal([]int{1, 10}, bc.Boundaries)
42+
a.Equal("tf.feature_column.bucketized_column(tf.feature_column.numeric_column(\"c1\", shape=[10]), boundaries=[1,10])", code)
43+
44+
r, e = parser.Parse(badInput)
45+
a.NoError(e)
46+
c = r.columns["feature_columns"]
47+
fcs, _, e = resolveTrainColumns(&c)
48+
a.Error(e)
49+
50+
r, e = parser.Parse(badBoundaries)
51+
a.NoError(e)
52+
c = r.columns["feature_columns"]
53+
fcs, _, e = resolveTrainColumns(&c)
54+
a.Error(e)
55+
}

0 commit comments

Comments
 (0)