diff --git a/sql/python/sqlflow_submitter/db.py b/sql/python/sqlflow_submitter/db.py
index 756587f16a..ac895dec70 100644
--- a/sql/python/sqlflow_submitter/db.py
+++ b/sql/python/sqlflow_submitter/db.py
@@ -11,12 +11,150 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
 import os
 import contextlib
 import numpy as np
 import tensorflow as tf
 import sqlflow_submitter.db_writer as db_writer
 
+
+def _parse_dsn_config(config_str):
+    """Parse a DSN query string such as ``?k1=v1&k2=v2`` into a dict.
+
+    An empty string or a lone ``?`` yields an empty dict.
+    """
+    config = {}
+    if len(config_str) > 1:
+        for c in config_str[1:].split("&"):
+            # Split on the first "=" only so a value may itself contain "=".
+            k, v = c.split("=", 1)
+            config[k] = v
+    return config
+
+
+def parseMySQLDSN(dsn):
+    """Split a MySQL DSN (``mysql://`` prefix already removed) into its parts.
+
+    Expected form:
+    ``username:password@tcp(host:port)/dbname[?param1=value1&...&paramN=valueN]``
+
+    Returns a tuple ``(user, passwd, host, port, database, config)`` in which
+    ``config`` is a dict and every other field is a string.
+    Raises ValueError when ``dsn`` does not match the expected form.
+    """
+    matches = re.findall(
+        r"^(\w*):(\w*)@tcp\(([-.a-zA-Z0-9]*):([0-9]*)\)/(\w*)(\?.*)?$", dsn)
+    if not matches:
+        # The DSN is deliberately not echoed: it contains the password.
+        raise ValueError("invalid MySQL data source name, expected "
+                         "user:passwd@tcp(host:port)/dbname[?k=v&...]")
+    user, passwd, host, port, database, config_str = matches[0]
+    return user, passwd, host, port, database, _parse_dsn_config(config_str)
+
+
+def parseHiveDSN(dsn):
+    """Split a Hive DSN (``hive://`` prefix already removed) into its parts.
+
+    Example:
+    ``usr:pswd@hiveserver:10000/mydb?auth=PLAIN&session.mapreduce_job_quenename=mr``
+    The ``:port``, ``/dbname`` and ``?config`` parts are optional.
+
+    Returns a tuple ``(user, passwd, host, port, database, config)``; ``port``
+    and ``database`` are None when they are absent from ``dsn``.
+    Raises ValueError when ``dsn`` does not match the expected form.
+    """
+    matches = re.findall(r"^(.*)@([-.a-zA-Z0-9/:]*)(\?.*)?$", dsn)
+    if not matches:
+        # The DSN is deliberately not echoed: it contains the password.
+        raise ValueError("invalid Hive data source name, expected "
+                         "user:passwd@host[:port][/dbname][?k=v&...]")
+    user_passwd, address_database, config_str = matches[0]
+    # Split on the first ":" only so the password may itself contain ":".
+    user, passwd = user_passwd.split(":", 1)
+    if "/" in address_database:
+        address, database = address_database.split("/")
+    else:
+        address, database = address_database, None
+    if ":" in address:
+        host, port = address.split(":")
+    else:
+        host, port = address, None
+    return user, passwd, host, port, database, _parse_dsn_config(config_str)
+
+
+def parseMaxComputeDSN(dsn):
+    """Split a MaxCompute DSN (``maxcompute://`` prefix removed) into its parts.
+
+    Example:
+    ``access_id:access_key@service.com/api?curr_project=test_ci&scheme=http``
+
+    Returns a tuple ``(user, passwd, address, project)``. When the config has
+    a ``scheme`` entry, ``address`` is prefixed with ``<scheme>://``.
+    Raises ValueError when ``dsn`` does not match the expected form or has no
+    ``curr_project`` config entry.
+    """
+    matches = re.findall(r"^(.*)@([-.a-zA-Z0-9/]*)(\?.*)?$", dsn)
+    if not matches:
+        # The DSN is deliberately not echoed: it contains the access key.
+        raise ValueError("invalid MaxCompute data source name, expected "
+                         "access_id:access_key@address?curr_project=...")
+    user_passwd, address, config_str = matches[0]
+    user, passwd = user_passwd.split(":", 1)
+    config = _parse_dsn_config(config_str)
+    if "curr_project" not in config:
+        raise ValueError("MaxCompute data source name must set curr_project")
+    if "scheme" in config:
+        address = config["scheme"] + "://" + address
+    return user, passwd, address, config["curr_project"]
+
+
+def connect_with_data_source(dsn):
+    """Open a database connection described by a single data source name.
+
+    ``dsn`` has the form ``<driver>://<driver-specific part>`` where the driver
+    is one of ``mysql``, ``hive`` or ``maxcompute``.
+
+    Returns the connection object, with its ``driver`` attribute set to the
+    driver name. Raises ValueError for an unsupported driver or a DSN that
+    cannot be parsed.
+    """
+    # Split on the first "://" only; the parsers expect the scheme removed.
+    driver, source = dsn.split("://", 1)
+    if driver == "mysql":
+        # NOTE: use MySQLdb to avoid bugs like infinite reading:
+        # https://bugs.mysql.com/bug.php?id=91971
+        from MySQLdb import connect
+        user, passwd, host, port, database, config = parseMySQLDSN(source)
+        conn = connect(user=user,
+                       passwd=passwd,
+                       db=database,
+                       host=host,
+                       port=int(port))
+    elif driver == "hive":
+        from impala.dbapi import connect
+        user, passwd, host, port, database, config = parseHiveDSN(source)
+        if port is None:
+            # parseHiveDSN reports a missing port as None; int(None) would
+            # fail with an unhelpful TypeError.
+            raise ValueError("hive data source name must specify a port")
+        conn = connect(user=user,
+                       password=passwd,
+                       database=database,
+                       host=host,
+                       port=int(port),
+                       auth_mechanism=config.get("auth", ""))
+    elif driver == "maxcompute":
+        from sqlflow_submitter.maxcompute import MaxCompute
+        user, passwd, address, database = parseMaxComputeDSN(source)
+        conn = MaxCompute.connect(database, user, passwd, address)
+    else:
+        raise ValueError("connect_with_data_source doesn't support driver type {}".format(driver))
+
+    conn.driver = driver
+    return conn
+
+
 def connect(driver, database, user, password, host, port, session_cfg={}, auth=""):
     if driver == "mysql":
         # NOTE: use MySQLdb to avoid bugs like infinite reading:
@@ -49,6 +187,7 @@ def connect(driver, database, user, password, host, port, session_cfg={}, auth="
 def db_generator(driver, conn, statement,
                  feature_column_names, label_column_name,
                  feature_specs, fetch_size=128):
+
     def read_feature(raw_val, feature_spec, feature_name):
         # FIXME(typhoonzero): Should use correct dtype here.
         if feature_spec["is_sparse"]:
@@ -108,6 +247,7 @@ def reader():
         from sqlflow_submitter.maxcompute import MaxCompute
         return MaxCompute.db_generator(conn, statement, feature_column_names,
                                        label_column_name, feature_specs, fetch_size)
+
     return reader
 
 
diff --git a/sql/python/sqlflow_submitter/db_test.py b/sql/python/sqlflow_submitter/db_test.py
index 0f142320c0..84ccd15a33 100644
--- a/sql/python/sqlflow_submitter/db_test.py
+++ b/sql/python/sqlflow_submitter/db_test.py
@@ -157,8 +157,32 @@ def test_generate_fetch_size(self):
             "is_sparse": False,
             "shape": []
         }}
-
-
         gen = db_generator(driver, conn, 'SELECT * FROM iris.train limit 10',
-            ["sepal_length"], "class", column_name_to_type, fetch_size=4)
+                           ["sepal_length"], "class", column_name_to_type, fetch_size=4)
         self.assertEqual(len([g for g in gen()]), 10)
+
+from sqlflow_submitter.db import parseHiveDSN, parseMaxComputeDSN, parseMySQLDSN
+
+class TestConnectWithDataSource(TestCase):
+    def test_parse_mysql_dsn(self):
+        # [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
+        self.assertEqual(
+            ("usr", "pswd", "localhost", "8000", "mydb", {"param1":"value1"}),
+            parseMySQLDSN("usr:pswd@tcp(localhost:8000)/mydb?param1=value1"))
+
+    def test_parse_hive_dsn(self):
+        self.assertEqual(
+            ("usr", "pswd", "hiveserver", "1000", "mydb", {"auth":"PLAIN", "session.mapreduce_job_quenename": "mr"}),
+            parseHiveDSN("usr:pswd@hiveserver:1000/mydb?auth=PLAIN&session.mapreduce_job_quenename=mr"))
+        self.assertEqual(
+            ("root", "root", "127.0.0.1", None, "mnist", {"auth":"PLAIN"}),
+            parseHiveDSN("root:root@127.0.0.1/mnist?auth=PLAIN"))
+        self.assertEqual(
+            ("root", "root", "127.0.0.1", None, None, {}),
+            parseHiveDSN("root:root@127.0.0.1"))
+
+    def test_parse_maxcompute_dsn(self):
+        self.assertEqual(
+            ("access_id", "access_key", "http://service.com/api", "test_ci"),
+            parseMaxComputeDSN("access_id:access_key@service.com/api?curr_project=test_ci&scheme=http"))
+