Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sqeleton/abcs/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

class AbstractCompiler(ABC):
@abstractmethod
def compile(self, elem: Any, params: Dict[str, Any] = None) -> str:
...
def compile(self, elem: Any, params: Dict[str, Any] = None) -> str: ...


class Compilable(ABC):
Expand Down
4 changes: 1 addition & 3 deletions sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ class String_FixedAlphanum(String_Alphanum):

def make_value(self, value):
if len(value) != self.length:
raise ValueError(
f"Expected alphanumeric value of length {self.length}, but got '{value}'."
)
raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.")
return self.python_type(value, max_len=self.length)


Expand Down
3 changes: 2 additions & 1 deletion sqeleton/bound_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def with_schema(self, schema):
table_path = self.node.replace(schema=schema)
return self.replace(node=table_path)

def query_schema(self, *, refine: bool = True, refine_where = None, case_sensitive=True):
def query_schema(self, *, refine: bool = True, refine_where=None, case_sensitive=True):
table_path = self.node

if table_path.schema:
Expand All @@ -77,5 +77,6 @@ def bound_table(database: AbstractDatabase, table_path: Union[TablePath, str, tu


if TYPE_CHECKING:

class BoundTable(BoundTable, TablePath):
pass
9 changes: 3 additions & 6 deletions sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,16 +331,13 @@ def compile(self, sql_ast):
# logger.setLevel(level)

@overload
def query(self, query_input: QueryInput) -> Any:
...
def query(self, query_input: QueryInput) -> Any: ...

@overload
def query(self, query_input: QueryInput, res_type: None) -> Any:
...
def query(self, query_input: QueryInput, res_type: None) -> Any: ...

@overload
def query(self, query_input: QueryInput, res_type: Type[TRes]) -> TRes:
...
def query(self, query_input: QueryInput, res_type: Type[TRes]) -> TRes: ...

def query(self, query_input, res_type=None):
"""Query the given SQL code/AST, and attempt to convert the result to type 'res_type'
Expand Down
2 changes: 1 addition & 1 deletion sqeleton/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
class Dialect(BaseDialect):
name = "Clickhouse"
ROUNDS_ON_PREC_LOSS = False
ARG_SYMBOL = None # TODO Clickhouse only supports named parameters, not positional
ARG_SYMBOL = None # TODO Clickhouse only supports named parameters, not positional
TYPE_CLASSES = {
"Int8": Integer,
"Int16": Integer,
Expand Down
4 changes: 3 additions & 1 deletion sqeleton/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def _process_table_schema(
self._refine_coltypes(path, col_dict, where)
return col_dict

def process_query_table_schema(self, path: DbPath, raw_schema: Dict[str, Tuple], refine: bool = True, refine_where: Optional[str] = None) -> Tuple[Dict[str, ColType], Optional[list]]:
def process_query_table_schema(
self, path: DbPath, raw_schema: Dict[str, Tuple], refine: bool = True, refine_where: Optional[str] = None
) -> Tuple[Dict[str, ColType], Optional[list]]:
if not refine:
raise NotImplementedError()
return self._process_table_schema(path, raw_schema, list(raw_schema), refine_where), None
Expand Down
16 changes: 13 additions & 3 deletions sqeleton/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@
Native_UUID,
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, Database, QueryResult, import_helper, ThreadLocalInterpreter, Mixin_Schema, Mixin_RandomSample, SqlCode, logger
from .base import (
BaseDialect,
Database,
QueryResult,
import_helper,
ThreadLocalInterpreter,
Mixin_Schema,
Mixin_RandomSample,
SqlCode,
logger,
)
from .base import (
MD5_HEXDIGITS,
CHECKSUM_HEXDIGITS,
Expand Down Expand Up @@ -76,7 +86,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
class Dialect(BaseDialect, Mixin_Schema):
name = "Presto"
ROUNDS_ON_PREC_LOSS = True
ARG_SYMBOL = None # Not implemented by Presto
ARG_SYMBOL = None # Not implemented by Presto
TYPE_CLASSES = {
# Timestamps
"timestamp with time zone": TimestampTZ,
Expand Down Expand Up @@ -186,7 +196,7 @@ def _query(self, sql_code: SqlCode) -> Optional[QueryResult]:
if isinstance(sql_code, ThreadLocalInterpreter):
return sql_code.apply_queries(partial(query_cursor, c))
elif isinstance(sql_code, str):
sql_code = CompiledCode(sql_code, [], None) # Unknown type. #TODO: Should we guess?
sql_code = CompiledCode(sql_code, [], None) # Unknown type. #TODO: Should we guess?

return query_cursor(c, sql_code)

Expand Down
11 changes: 5 additions & 6 deletions sqeleton/databases/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
else:
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"

return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
return (
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
)

def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
if isinstance(coltype, String_UUID):
Expand Down Expand Up @@ -55,17 +57,14 @@ def __init__(self, **kw):
self.default_schema = kw.get("schema")

if kw.get("password"):
kw["auth"] = trino.auth.BasicAuthentication(
kw.pop("user"), kw.pop("password")
)
kw["auth"] = trino.auth.BasicAuthentication(kw.pop("user"), kw.pop("password"))
kw["http_scheme"] = "https"

cert = kw.pop("cert", None)
self._conn = trino.dbapi.connect(**kw)
if cert is not None:
self._conn._http_session.verify = cert


@property
def is_autocommit(self) -> bool:
return True
return True
4 changes: 0 additions & 4 deletions sqeleton/queries/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from .ast_classes import Expr, ExprNode, Concat, Code




@dataclass
class NormalizeAsString(ExprNode):
expr: ExprNode
Expand All @@ -35,7 +33,6 @@ def compile_node(c: Compiler, n: NormalizeAsString) -> str:
expr = c.compile(n.expr)
return c.dialect.normalize_value_by_type(expr, n.expr_type or n.expr.type)


@md
def compile_node(c: Compiler, n: ApplyFuncAndNormalizeAsString) -> str:
expr = n.expr
Expand All @@ -56,7 +53,6 @@ def compile_node(c: Compiler, n: ApplyFuncAndNormalizeAsString) -> str:

return c.compile(expr)


@md
def compile_node(c: Compiler, n: Checksum) -> str:
if len(n.exprs) > 1:
Expand Down
10 changes: 8 additions & 2 deletions sqeleton/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

Schema = CaseAwareMapping


class TableType:
pass
# TODO: This should replace the current Schema type
Expand All @@ -21,6 +22,7 @@ def is_superclass(cls, t):

SchemaInput = Union[Type[TableType], Schema, dict]


@dataclass
class Options:
default: Any = None
Expand All @@ -29,18 +31,21 @@ class Options:
# TODO: foreign_key, unique
# TODO: index?


@dataclass
class _Field:
type: type
options: Options


class _Schema(CaseAwareMapping[Union[type, _Field]]):
pass

@classmethod
def make(cls, schema: SchemaInput):
assert schema
if TableType.is_superclass(schema):

def _make_field(k: str, v: type):
field = getattr(schema, k)
if field:
Expand All @@ -49,7 +54,7 @@ def _make_field(k: str, v: type):
return _Field(v, field)
return v

schema = CaseSensitiveDict({k:_make_field(k, v) for k,v in schema.__annotations__.items()})
schema = CaseSensitiveDict({k: _make_field(k, v) for k, v in schema.__annotations__.items()})

elif isinstance(schema, CaseAwareMapping):
pass
Expand All @@ -59,7 +64,8 @@ def _make_field(k: str, v: type):

return schema

def options(**kw) -> Any: # Any, so that type-checking doesn't complain

def options(**kw) -> Any: # Any, so that type-checking doesn't complain
return Options(**kw)


Expand Down
6 changes: 2 additions & 4 deletions sqeleton/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,10 @@ def match_regexps(regexps: Dict[str, Any], s: str) -> Generator[tuple, None, Non

class CaseAwareMapping(MutableMapping[str, V]):
@abstractmethod
def get_key(self, key: str) -> str:
...
def get_key(self, key: str) -> str: ...

@abstractmethod
def __init__(self, initial):
...
def __init__(self, initial): ...

def new(self, initial=()):
return type(self)(initial)
Expand Down
Loading