From 07107f8494869e555e7d522757c4c61836024e7a Mon Sep 17 00:00:00 2001
From: Yancey1989
Date: Mon, 2 Sep 2019 16:09:29 +0800
Subject: [PATCH] fix write hive table with unordered row data

---
 sql/python/sqlflow_submitter/db_test.py        |  4 +--
 .../sqlflow_submitter/db_writer/hive.py        | 31 +++++++++++++++++--
 2 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/sql/python/sqlflow_submitter/db_test.py b/sql/python/sqlflow_submitter/db_test.py
index 6ebfcaf888..f3b9861cd4 100644
--- a/sql/python/sqlflow_submitter/db_test.py
+++ b/sql/python/sqlflow_submitter/db_test.py
@@ -85,8 +85,8 @@ def test_hive(self):
 
     def _do_test(self, driver, conn):
         table_name = "test_db"
-        table_schema = ["features", "label"]
-        values = [('5,6,1,2', 1)] * 10
+        table_schema = ["label", "features"]
+        values = [(1, '5,6,1,2')] * 10
 
         execute(driver, conn, self.drop_statement)
 
diff --git a/sql/python/sqlflow_submitter/db_writer/hive.py b/sql/python/sqlflow_submitter/db_writer/hive.py
index 3df735b1ce..bedf2c6045 100644
--- a/sql/python/sqlflow_submitter/db_writer/hive.py
+++ b/sql/python/sqlflow_submitter/db_writer/hive.py
@@ -17,16 +17,43 @@
 import tempfile
 import subprocess
 
+CSV_DELIMITER = '\001'
+
 class HiveDBWriter(BufferedDBWriter):
     def __init__(self, conn, table_name, table_schema, buff_size=10000):
         super().__init__(conn, table_name, table_schema, buff_size)
         self.tmp_f = tempfile.NamedTemporaryFile(dir="./")
         self.f = open(self.tmp_f.name, "w")
+        self.schema_idx = self._indexing_table_schema(table_schema)
+
+    def _indexing_table_schema(self, table_schema):
+        cursor = self.conn.cursor()
+        cursor.execute("describe %s" % self.table_name)
+        column_list = cursor.fetchall()
+        schema_idx = []
+        idx_map = {}
+        # column list: [(col1, type, desc), (col2, type, desc)...]
+        for i, e in enumerate(column_list):
+            idx_map[e[0]] = i
+
+        for s in table_schema:
+            if s not in idx_map:
+                raise ValueError("column: %s should be in table columns:%s" % (s, idx_map))
+            schema_idx.append(idx_map[s])
+
+        return schema_idx
+
+    def _ordered_row_data(self, row):
+        # Use NULL as the default value for hive columns
+        row_data = ["NULL" for i in range(len(self.table_schema))]
+        for idx, element in enumerate(row):
+            row_data[self.schema_idx[idx]] = str(element)
+        return CSV_DELIMITER.join(row_data)
+
     def flush(self):
         for row in self.rows:
-            line = "%s\n" % '\001'.join([str(v) for v in row])
-            self.f.write(line)
+            data = self._ordered_row_data(row)
+            self.f.write(data+'\n')
         self.rows = []
 
     def write_hive_table(self):
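
For reference, here is a minimal, self-contained sketch of the reordering
logic this patch introduces. It is illustrative only: FakeCursor is a
hypothetical stand-in for a real DB-API cursor against Hive, and the helper
names (indexing_table_schema, ordered_row_data) mirror the patch's methods
but are written as free functions. The patch joins fields with '\001'
(Ctrl-A) because that is Hive's default field delimiter for text tables.

    CSV_DELIMITER = '\001'  # Hive's default field delimiter (Ctrl-A)

    class FakeCursor:
        """Hypothetical stand-in for conn.cursor(); mimics `DESCRIBE <table>`."""
        def execute(self, sql):
            pass  # a real cursor would run the DESCRIBE statement here
        def fetchall(self):
            # (column, type, comment) tuples in the table's physical order
            return [("features", "string", ""), ("label", "int", "")]

    def indexing_table_schema(cursor, table_name, table_schema):
        # Map each writer column name onto its position in the actual table.
        cursor.execute("describe %s" % table_name)
        idx_map = {col[0]: i for i, col in enumerate(cursor.fetchall())}
        return [idx_map[name] for name in table_schema]

    def ordered_row_data(schema_idx, num_columns, row):
        # Emit the row in table order; columns the writer never sets stay NULL.
        row_data = ["NULL"] * num_columns
        for idx, element in enumerate(row):
            row_data[schema_idx[idx]] = str(element)
        return CSV_DELIMITER.join(row_data)

    schema_idx = indexing_table_schema(FakeCursor(), "test_db", ["label", "features"])
    print(schema_idx)                 # [1, 0]: "label" is the table's 2nd column
    line = ordered_row_data(schema_idx, 2, (1, "5,6,1,2"))
    print(line.split(CSV_DELIMITER))  # ['5,6,1,2', '1'], back in table order

This matches the updated test above: rows arrive as (label, features) but
land in the data file as features\001label, the physical column order the
Hive table expects when the file is loaded.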