Skip to content

Commit 933fb08

Browse files
Return error on LABEL with string type (#1000)
* Return error on LABEL with string type * fix CI
1 parent d3ec581 commit 933fb08

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

pkg/sql/codegen.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
228228
} else if v == "BIGINT" {
229229
labelDtype = "int64"
230230
} else {
231-
log.Fatalf("Unsupported label data type: %s", v)
231+
return nil, fmt.Errorf("unsupported label data type: %s", v)
232232
}
233233
}
234234
r.Y = &FeatureMeta{

pkg/sql/codegen_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ package sql
1515

1616
import (
1717
"io/ioutil"
18+
"strings"
1819
"testing"
1920

2021
"github.com/stretchr/testify/assert"
@@ -85,3 +86,22 @@ func TestCodeGenPredict(t *testing.T) {
8586

8687
a.NoError(genTF(ioutil.Discard, r, nil, fts, testDB))
8788
}
89+
90+
func TestLabelAsStringType(t *testing.T) {
91+
a := assert.New(t)
92+
r, e := newParser().Parse(`SELECT customerID, gender FROM churn.train
93+
TRAIN DNNClassifier
94+
WITH
95+
model.n_classes = 3,
96+
model.hidden_units = [10, 20]
97+
COLUMN customerID
98+
LABEL gender
99+
INTO sqlflow_models.my_dnn_model;`)
100+
a.NoError(e)
101+
102+
fts, e := verify(r, testDB)
103+
a.NoError(e)
104+
e = genTF(ioutil.Discard, r, nil, fts, testDB)
105+
a.NotNil(e)
106+
a.True(strings.HasPrefix(e.Error(), "unsupported label data type:"))
107+
}

0 commit comments

Comments
 (0)