Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions sql/python/sqlflow_submitter/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,92 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import os
import contextlib
import numpy as np
import tensorflow as tf
import sqlflow_submitter.db_writer as db_writer


def parseMySQLDSN(dsn):
    """Parse a MySQL DSN into its components.

    Expected format (go-sql-driver style):
        [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]

    Returns a tuple (user, passwd, host, port, database, config) where
    ``config`` is a dict of the optional query-string parameters (empty
    when no '?...' suffix is present).  Raises IndexError when the DSN
    does not match the expected shape.
    """
    pattern = "^(\w*):(\w*)@tcp\(([.a-zA-Z0-9]*):([0-9]*)\)/(\w*)(\?.*)?$"
    user, passwd, host, port, database, config_str = re.findall(pattern, dsn)[0]
    config = {}
    if len(config_str) > 1:
        # Drop the leading '?' and turn "k1=v1&k2=v2" into a dict.
        config = dict(item.split("=") for item in config_str[1:].split("&"))
    return user, passwd, host, port, database, config


def parseHiveDSN(dsn):
# usr:pswd@hiveserver:10000/mydb?auth=PLAIN&session.mapreduce_job_quenename=mr
user_passwd, address_database, config_str = re.findall("^(.*)@([.a-zA-Z0-9/:]*)(\?.*)?", dsn)[0]
user, passwd = user_passwd.split(":")
if len(address_database.split("/")) > 1:
address, database = address_database.split("/")
else:
address, database = address_database, None
if len(address.split(":")) > 1:
host, port = address.split(":")
else:
host, port = address, None
config = {}
if len(config_str) > 1:
for c in config_str[1:].split("&"):
k, v = c.split("=")
config[k] = v
return user, passwd, host, port, database, config


def parseMaxComputeDSN(dsn):
# access_id:[email protected]/api?curr_project=test_ci&scheme=http
user_passwd, address, config_str = re.findall("^(.*)@([.a-zA-Z0-9/]*)(\?.*)?", dsn)[0]
user, passwd = user_passwd.split(":")
config = {}
if len(config_str) > 1:
for c in config_str[1:].split("&"):
k, v = c.split("=")
config[k] = v
if "scheme" in config:
address = config["scheme"] + "://" + address
return user, passwd, address, config["curr_project"]


def connect_with_data_source(dsn):
    """Open a database connection from a "driver://dsn" data-source string.

    Supported drivers: "mysql", "hive", "maxcompute".  The driver-specific
    part of the string is parsed by the matching parse*DSN helper above.
    The returned connection object gets an extra ``driver`` attribute so
    downstream code can branch on the backend in use.

    Raises:
        ValueError: if the driver prefix is not one of the supported types.
    """
    # BUGFIX: the original referenced an undefined name ``data_source`` and
    # then parsed the full ``dsn`` (driver prefix included); split the
    # argument itself and parse only the driver-specific remainder.
    driver, source = dsn.split("://", 1)
    if driver == "mysql":
        # NOTE: use MySQLdb to avoid bugs like infinite reading:
        # https://bugs.mysql.com/bug.php?id=91971
        from MySQLdb import connect
        user, passwd, host, port, database, config = parseMySQLDSN(source)
        # BUGFIX: the parsed variable is ``passwd``; ``password`` was undefined.
        conn = connect(user=user,
                       passwd=passwd,
                       db=database,
                       host=host,
                       port=int(port))
    elif driver == "hive":
        from impala.dbapi import connect
        user, passwd, host, port, database, config = parseHiveDSN(source)
        auth = config["auth"] if "auth" in config else ""
        conn = connect(user=user,
                       password=passwd,
                       database=database,
                       host=host,
                       port=int(port),
                       auth_mechanism=auth)
    elif driver == "maxcompute":
        from sqlflow_submitter.maxcompute import MaxCompute
        user, passwd, address, database = parseMaxComputeDSN(source)
        conn = MaxCompute.connect(database, user, passwd, address)
    else:
        raise ValueError("connect_with_data_source doesn't support driver type {}".format(driver))

    conn.driver = driver
    return conn


def connect(driver, database, user, password, host, port, session_cfg={}, auth=""):
if driver == "mysql":
# NOTE: use MySQLdb to avoid bugs like infinite reading:
Expand Down Expand Up @@ -49,6 +129,7 @@ def connect(driver, database, user, password, host, port, session_cfg={}, auth="
def db_generator(driver, conn, statement,
feature_column_names, label_column_name,
feature_specs, fetch_size=128):

def read_feature(raw_val, feature_spec, feature_name):
# FIXME(typhoonzero): Should use correct dtype here.
if feature_spec["is_sparse"]:
Expand Down Expand Up @@ -108,6 +189,7 @@ def reader():
from sqlflow_submitter.maxcompute import MaxCompute
return MaxCompute.db_generator(conn, statement, feature_column_names,
label_column_name, feature_specs, fetch_size)

return reader


Expand Down
30 changes: 27 additions & 3 deletions sql/python/sqlflow_submitter/db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,32 @@ def test_generate_fetch_size(self):
"is_sparse": False,
"shape": []
}}


gen = db_generator(driver, conn, 'SELECT * FROM iris.train limit 10',
["sepal_length"], "class", column_name_to_type, fetch_size=4)
["sepal_length"], "class", column_name_to_type, fetch_size=4)
self.assertEqual(len([g for g in gen()]), 10)

from sqlflow_submitter.db import parseHiveDSN, parseMaxComputeDSN,parseMySQLDSN

class TestConnectWithDataSource(TestCase):
    """Unit tests for the DSN parsing helpers in sqlflow_submitter.db."""

    def test_parse_mysql_dsn(self):
        """MySQL DSNs with a query string parse into all six components."""
        # [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
        self.assertEqual(
            ("usr", "pswd", "localhost", "8000", "mydb", {"param1":"value1"}),
            parseMySQLDSN("usr:pswd@tcp(localhost:8000)/mydb?param1=value1"))

    def test_parse_hive_dsn(self):
        """Hive DSNs parse with port/database/config all being optional."""
        self.assertEqual(
            ("usr", "pswd", "hiveserver", "1000", "mydb", {"auth":"PLAIN", "session.mapreduce_job_quenename": "mr"}),
            parseHiveDSN("usr:pswd@hiveserver:1000/mydb?auth=PLAIN&session.mapreduce_job_quenename=mr"))
        # Missing port: port comes back as None.
        self.assertEqual(
            ("root", "root", "127.0.0.1", None, "mnist", {"auth":"PLAIN"}),
            parseHiveDSN("root:[email protected]/mnist?auth=PLAIN"))
        # Host only: port and database are None, config is empty.
        self.assertEqual(
            ("root", "root", "127.0.0.1", None, None, {}),
            parseHiveDSN("root:[email protected]"))

    def test_parse_maxcompute_dsn(self):
        """MaxCompute DSNs yield (user, passwd, scheme-prefixed address, project)."""
        self.assertEqual(
            ("access_id", "access_key", "http://service.com/api", "test_ci"),
            parseMaxComputeDSN("access_id:[email protected]/api?curr_project=test_ci&scheme=http"))