23 changes: 22 additions & 1 deletion src/databricks/sql/backend/sea/backend.py
@@ -20,7 +20,12 @@
    MetadataCommands,
)
from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
from databricks.sql.backend.sea.utils.metadata_transforms import (
    create_table_catalog_transform,
)
from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
from databricks.sql.backend.sea.utils.result_column import ResultColumn
from databricks.sql.backend.sea.utils.conversion import SqlType
from databricks.sql.thrift_api.TCLIService import ttypes

if TYPE_CHECKING:
@@ -740,7 +745,23 @@ def get_schemas(
        assert isinstance(
            result, SeaResultSet
        ), "Expected SeaResultSet from SEA backend"
        result.prepare_metadata_columns(MetadataColumnMappings.SCHEMA_COLUMNS)

        # Create dynamic schema columns with catalog name bound to TABLE_CATALOG
        schema_columns = []
        for col in MetadataColumnMappings.SCHEMA_COLUMNS:
            if col.thrift_col_name == "TABLE_CATALOG":
                # Create a new column with the catalog transform bound
                dynamic_col = ResultColumn(
                    col.thrift_col_name,
                    col.sea_col_name,
                    col.thrift_col_type,
                    create_table_catalog_transform(catalog_name),
                )
                schema_columns.append(dynamic_col)
            else:
                schema_columns.append(col)

        result.prepare_metadata_columns(schema_columns)
        return result

    def get_tables(
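Reviewer note: a minimal sketch of what the factory-bound transform does once `prepare_metadata_columns` applies it. The module path and names are taken from this diff; the session and result plumbing are elided.

```python
from databricks.sql.backend.sea.utils.metadata_transforms import (
    create_table_catalog_transform,
)

# The closure captures the requested catalog name, so every value in the
# TABLE_CATALOG column is rewritten to it, regardless of what SEA returned.
transform = create_table_catalog_transform("main")
assert transform(None) == "main"
assert transform("whatever_sea_returned") == "main"
```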
17 changes: 13 additions & 4 deletions src/databricks/sql/backend/sea/result_set.py
@@ -320,7 +320,7 @@ def _prepare_column_mapping(self) -> None:
                None,
                None,
                None,
                True,
                None,
            )

        # Set the mapping
@@ -356,14 +356,20 @@ def _normalise_arrow_metadata_cols(self, table: "pyarrow.Table") -> "pyarrow.Table":
                if self._column_index_mapping
                else None
            )

            column = (
                pyarrow.nulls(table.num_rows)
                if old_idx is None
                else table.column(old_idx)
            )
            new_columns.append(column)

            # Apply transform if available
            if result_column.transform_value:
                # Convert to list, apply transform, and convert back
                values = column.to_pylist()
                transformed_values = [result_column.transform_value(v) for v in values]
                column = pyarrow.array(transformed_values)

            new_columns.append(column)
            column_names.append(result_column.thrift_col_name)

        return pyarrow.Table.from_arrays(new_columns, names=column_names)
@@ -382,8 +388,11 @@ def _normalise_json_metadata_cols(self, rows: List[List[str]]) -> List[List[Any]]:
                    if self._column_index_mapping
                    else None
                )

                value = None if old_idx is None else row[old_idx]

                # Apply transform if available
                if result_column.transform_value:
                    value = result_column.transform_value(value)
                new_row.append(value)
            transformed_rows.append(new_row)
        return transformed_rows
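Reviewer note: the Arrow path applies transforms via a `to_pylist()`/`pyarrow.array()` round-trip, which is O(n) Python-side work but acceptable for metadata-sized results. A standalone sketch of that step, with `transform_is_nullable` inlined by hand (the real code resolves the callback from the `ResultColumn`):

```python
import pyarrow

# Materialize the column, map the per-value callback, rebuild the array.
table = pyarrow.table({"isNullable": [True, False, None]})
values = table.column(0).to_pylist()
transformed = [("YES" if v else "NO") if isinstance(v, bool) else v for v in values]
column = pyarrow.array(transformed)
result = pyarrow.Table.from_arrays([column], names=["IS_NULLABLE"])
assert result.column(0).to_pylist() == ["YES", "NO", None]
```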
30 changes: 25 additions & 5 deletions src/databricks/sql/backend/sea/utils/metadata_mappings.py
@@ -1,5 +1,13 @@
from databricks.sql.backend.sea.utils.result_column import ResultColumn
from databricks.sql.backend.sea.utils.conversion import SqlType
from databricks.sql.backend.sea.utils.metadata_transforms import (
    transform_remarks,
    transform_is_autoincrement,
    transform_is_nullable,
    transform_nullable,
    transform_data_type,
    transform_ordinal_position,
)


class MetadataColumnMappings:
@@ -18,7 +26,9 @@ class MetadataColumnMappings:
    SCHEMA_COLUMN = ResultColumn("TABLE_SCHEM", "namespace", SqlType.STRING)
    TABLE_NAME_COLUMN = ResultColumn("TABLE_NAME", "tableName", SqlType.STRING)
    TABLE_TYPE_COLUMN = ResultColumn("TABLE_TYPE", "tableType", SqlType.STRING)
    REMARKS_COLUMN = ResultColumn("REMARKS", "remarks", SqlType.STRING)
    REMARKS_COLUMN = ResultColumn(
        "REMARKS", "remarks", SqlType.STRING, transform_remarks
    )
    TYPE_CATALOG_COLUMN = ResultColumn("TYPE_CAT", None, SqlType.STRING)
    TYPE_SCHEM_COLUMN = ResultColumn("TYPE_SCHEM", None, SqlType.STRING)
    TYPE_NAME_COLUMN = ResultColumn("TYPE_NAME", None, SqlType.STRING)
@@ -28,7 +38,9 @@
    REF_GENERATION_COLUMN = ResultColumn("REF_GENERATION", None, SqlType.STRING)

    COL_NAME_COLUMN = ResultColumn("COLUMN_NAME", "col_name", SqlType.STRING)
    DATA_TYPE_COLUMN = ResultColumn("DATA_TYPE", None, SqlType.INT)
    DATA_TYPE_COLUMN = ResultColumn(
        "DATA_TYPE", "columnType", SqlType.INT, transform_data_type
    )
    COLUMN_TYPE_COLUMN = ResultColumn("TYPE_NAME", "columnType", SqlType.STRING)
    COLUMN_SIZE_COLUMN = ResultColumn("COLUMN_SIZE", "columnSize", SqlType.INT)
    BUFFER_LENGTH_COLUMN = ResultColumn("BUFFER_LENGTH", None, SqlType.TINYINT)
@@ -43,22 +55,30 @@
"ORDINAL_POSITION",
"ordinalPosition",
SqlType.INT,
transform_ordinal_position,
)

NULLABLE_COLUMN = ResultColumn("NULLABLE", None, SqlType.INT)
NULLABLE_COLUMN = ResultColumn(
"NULLABLE", "isNullable", SqlType.INT, transform_nullable
)
COLUMN_DEF_COLUMN = ResultColumn("COLUMN_DEF", None, SqlType.STRING)
SQL_DATA_TYPE_COLUMN = ResultColumn("SQL_DATA_TYPE", None, SqlType.INT)
SQL_DATETIME_SUB_COLUMN = ResultColumn("SQL_DATETIME_SUB", None, SqlType.INT)
CHAR_OCTET_LENGTH_COLUMN = ResultColumn("CHAR_OCTET_LENGTH", None, SqlType.INT)
IS_NULLABLE_COLUMN = ResultColumn("IS_NULLABLE", "isNullable", SqlType.STRING)
IS_NULLABLE_COLUMN = ResultColumn(
"IS_NULLABLE", "isNullable", SqlType.STRING, transform_is_nullable
)

SCOPE_CATALOG_COLUMN = ResultColumn("SCOPE_CATALOG", None, SqlType.STRING)
SCOPE_SCHEMA_COLUMN = ResultColumn("SCOPE_SCHEMA", None, SqlType.STRING)
SCOPE_TABLE_COLUMN = ResultColumn("SCOPE_TABLE", None, SqlType.STRING)
SOURCE_DATA_TYPE_COLUMN = ResultColumn("SOURCE_DATA_TYPE", None, SqlType.SMALLINT)

IS_AUTO_INCREMENT_COLUMN = ResultColumn(
"IS_AUTOINCREMENT", "isAutoIncrement", SqlType.STRING
"IS_AUTO_INCREMENT",
"isAutoIncrement",
SqlType.STRING,
transform_is_autoincrement,
)
IS_GENERATED_COLUMN = ResultColumn(
"IS_GENERATEDCOLUMN", "isGenerated", SqlType.STRING
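Reviewer note: with this remapping, `DATA_TYPE` and `TYPE_NAME` both read the same SEA field (`columnType`); `DATA_TYPE` converts the string to its JDBC code while `TYPE_NAME`, which carries no transform, passes it through. Roughly:

```python
from databricks.sql.backend.sea.utils.metadata_transforms import transform_data_type

raw = "DECIMAL(10,2)"  # what SEA reports in columnType
assert transform_data_type(raw) == 3  # DATA_TYPE becomes java.sql.Types.DECIMAL
# TYPE_NAME surfaces the string "DECIMAL(10,2)" unchanged.
```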
83 changes: 83 additions & 0 deletions src/databricks/sql/backend/sea/utils/metadata_transforms.py
@@ -0,0 +1,83 @@
"""Simple transformation functions for metadata value normalization."""


def transform_is_autoincrement(value):
"""Transform IS_AUTOINCREMENT: boolean to YES/NO string."""
if isinstance(value, bool) or value is None:
return "YES" if value else "NO"
return value


def transform_is_nullable(value):
"""Transform IS_NULLABLE: true/false to YES/NO string."""
if value is True or value == "true":
return "YES"
elif value is False or value == "false":
return "NO"
return value


def transform_remarks(value):
if value is None:
return ""
return value


def transform_nullable(value):
"""Transform NULLABLE column: boolean/string to integer."""
if value is True or value == "true" or value == "YES":
return 1
elif value is False or value == "false" or value == "NO":
return 0
return value


# Type code mapping based on JDBC specification
TYPE_CODE_MAP = {
"STRING": 12, # VARCHAR
"VARCHAR": 12, # VARCHAR
"CHAR": 1, # CHAR
"INT": 4, # INTEGER
"INTEGER": 4, # INTEGER
"BIGINT": -5, # BIGINT
"SMALLINT": 5, # SMALLINT
"TINYINT": -6, # TINYINT
"DOUBLE": 8, # DOUBLE
"FLOAT": 6, # FLOAT
"REAL": 7, # REAL
"DECIMAL": 3, # DECIMAL
"NUMERIC": 2, # NUMERIC
"BOOLEAN": 16, # BOOLEAN
"DATE": 91, # DATE
"TIMESTAMP": 93, # TIMESTAMP
"BINARY": -2, # BINARY
"ARRAY": 2003, # ARRAY
"MAP": 2002, # JAVA_OBJECT
"STRUCT": 2002, # JAVA_OBJECT
}


def transform_data_type(value):
"""Transform DATA_TYPE: type name to JDBC type code."""
if isinstance(value, str):
# Handle parameterized types like DECIMAL(10,2)
base_type = value.split("(")[0].upper()
return TYPE_CODE_MAP.get(base_type, value)
return value


def transform_ordinal_position(value):
"""Transform ORDINAL_POSITION: decrement by 1 (1-based to 0-based)."""
if isinstance(value, int):
return value - 1
return value


def create_table_catalog_transform(catalog_name):
"""Factory function to create TABLE_CATALOG transform with bound catalog name."""

def transform_table_catalog(value):
"""Transform TABLE_CATALOG: return the catalog name for all rows."""
return catalog_name

return transform_table_catalog
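Reviewer note: a few spot checks for the transforms above, runnable as-is against this module:

```python
from databricks.sql.backend.sea.utils.metadata_transforms import (
    transform_data_type,
    transform_is_nullable,
    transform_nullable,
    transform_ordinal_position,
    transform_remarks,
)

assert transform_data_type("string") == 12  # base type is upper-cased
assert transform_data_type("DECIMAL(10,2)") == 3  # parameters stripped first
assert transform_nullable("true") == 1
assert transform_is_nullable(False) == "NO"
assert transform_ordinal_position(1) == 0  # 1-based -> 0-based
assert transform_remarks(None) == ""
```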
4 changes: 3 additions & 1 deletion src/databricks/sql/backend/sea/utils/result_column.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Callable, Any


@dataclass(frozen=True)
@@ -11,8 +11,10 @@ class ResultColumn:
        thrift_col_name: Column name as returned by Thrift (e.g., "TABLE_CAT")
        sea_col_name: Server result column name from SEA (e.g., "catalog")
        thrift_col_type: SQL type name
        transform_value: Optional callback to transform values for this column
    """

    thrift_col_name: str
    sea_col_name: Optional[str]  # None if SEA doesn't return this column
    thrift_col_type: str
    transform_value: Optional[Callable[[Any], Any]] = None
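Reviewer note: because `ResultColumn` is a frozen dataclass, the transform travels with the column definition and the result-set normalisers stay generic. For example:

```python
from databricks.sql.backend.sea.utils.conversion import SqlType
from databricks.sql.backend.sea.utils.metadata_transforms import transform_remarks
from databricks.sql.backend.sea.utils.result_column import ResultColumn

# The optional callback defaults to None, so existing columns are unaffected.
col = ResultColumn("REMARKS", "remarks", SqlType.STRING, transform_remarks)
assert col.transform_value(None) == ""
```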
13 changes: 11 additions & 2 deletions tests/e2e/test_driver.py
@@ -562,8 +562,17 @@ def test_get_schemas(self):
        finally:
            cursor.execute("DROP DATABASE IF EXISTS {}".format(database_name))

    def test_get_catalogs(self):
        with self.cursor({}) as cursor:
    @pytest.mark.parametrize(
        "backend_params",
        [
            {},
            {
                "use_sea": True,
            },
        ],
    )
    def test_get_catalogs(self, backend_params):
        with self.cursor(backend_params) as cursor:
            cursor.catalogs()
            cursor.fetchall()
            catalogs_desc = cursor.description
2 changes: 1 addition & 1 deletion tests/unit/test_metadata_mappings.py
@@ -89,7 +89,7 @@ def test_column_columns_mapping(self):
"TABLE_SCHEM": ("namespace", SqlType.STRING),
"TABLE_NAME": ("tableName", SqlType.STRING),
"COLUMN_NAME": ("col_name", SqlType.STRING),
"DATA_TYPE": (None, SqlType.INT),
"DATA_TYPE": ("columnType", SqlType.INT),
"TYPE_NAME": ("columnType", SqlType.STRING),
"COLUMN_SIZE": ("columnSize", SqlType.INT),
"DECIMAL_DIGITS": ("decimalDigits", SqlType.INT),
3 changes: 0 additions & 3 deletions tests/unit/test_sea_backend.py
@@ -826,9 +826,6 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):

        # Verify prepare_metadata_columns was called for successful cases
        assert mock_result_set.prepare_metadata_columns.call_count == 2
        mock_result_set.prepare_metadata_columns.assert_called_with(
            MetadataColumnMappings.SCHEMA_COLUMNS
        )

    def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
        """Test the get_tables method with various parameter combinations."""