Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 110 additions & 22 deletions pyiceberg/catalog/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
union,
update,
)
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.exc import IntegrityError, NoResultFound, OperationalError
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
Expand All @@ -48,6 +48,7 @@
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
CommitFailedException,
NamespaceAlreadyExistsError,
NamespaceNotEmptyError,
NoSuchNamespaceError,
Expand All @@ -59,7 +60,7 @@
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT
Expand Down Expand Up @@ -268,16 +269,32 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
with Session(self.engine) as session:
res = session.execute(
delete(IcebergTables).where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
if self.engine.dialect.supports_sane_rowcount:
res = session.execute(
delete(IcebergTables).where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
)
)
)
if res.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}")
else:
try:
tbl = (
session.query(IcebergTables)
.with_for_update(of=IcebergTables)
.filter(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
)
.one()
)
session.delete(tbl)
except NoResultFound as e:
raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}") from e
session.commit()
if res.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}")

def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
"""Rename a fully classified table name.
Expand All @@ -301,18 +318,35 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
raise NoSuchNamespaceError(f"Namespace does not exist: {to_database_name}")
with Session(self.engine) as session:
try:
stmt = (
update(IcebergTables)
.where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == from_database_name,
IcebergTables.table_name == from_table_name,
if self.engine.dialect.supports_sane_rowcount:
stmt = (
update(IcebergTables)
.where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == from_database_name,
IcebergTables.table_name == from_table_name,
)
.values(table_namespace=to_database_name, table_name=to_table_name)
)
.values(table_namespace=to_database_name, table_name=to_table_name)
)
result = session.execute(stmt)
if result.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {from_table_name}")
result = session.execute(stmt)
if result.rowcount < 1:
raise NoSuchTableError(f"Table does not exist: {from_table_name}")
else:
try:
tbl = (
session.query(IcebergTables)
.with_for_update(of=IcebergTables)
.filter(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == from_database_name,
IcebergTables.table_name == from_table_name,
)
.one()
)
tbl.table_namespace = to_database_name
tbl.table_name = to_table_name
except NoResultFound as e:
raise NoSuchTableError(f"Table does not exist: {from_table_name}") from e
session.commit()
except IntegrityError as e:
raise TableAlreadyExistsError(f"Table {to_database_name}.{to_table_name} already exists") from e
Expand All @@ -329,8 +363,62 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons

Raises:
NoSuchTableError: If a table with the given identifier does not exist.
CommitFailedException: If the commit failed.
"""
raise NotImplementedError
identifier_tuple = self.identifier_to_tuple_without_catalog(
tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
)
current_table = self.load_table(identifier_tuple)
database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
base_metadata = current_table.metadata
for requirement in table_request.requirements:
requirement.validate(base_metadata)

updated_metadata = update_table_metadata(base_metadata, table_request.updates)
if updated_metadata == base_metadata:
# no changes, do nothing
return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)

# write new metadata
new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
self._write_metadata(updated_metadata, current_table.io, new_metadata_location)

with Session(self.engine) as session:
if self.engine.dialect.supports_sane_rowcount:
stmt = (
update(IcebergTables)
.where(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
IcebergTables.metadata_location == current_table.metadata_location,
)
.values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
)
result = session.execute(stmt)
if result.rowcount < 1:
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}")
else:
try:
tbl = (
session.query(IcebergTables)
.with_for_update(of=IcebergTables)
.filter(
IcebergTables.catalog_name == self.name,
IcebergTables.table_namespace == database_name,
IcebergTables.table_name == table_name,
IcebergTables.metadata_location == current_table.metadata_location,
)
.one()
)
tbl.metadata_location = new_metadata_location
tbl.previous_metadata_location = current_table.metadata_location
except NoResultFound as e:
raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e
session.commit()

return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)

def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool:
namespace = self.identifier_to_database(identifier)
Expand Down
55 changes: 55 additions & 0 deletions tests/catalog/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
SortOrder,
)
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType


@pytest.fixture(name="warehouse", scope="session")
Expand Down Expand Up @@ -87,6 +88,19 @@ def catalog_sqlite(warehouse: Path) -> Generator[SqlCatalog, None, None]:
catalog.destroy_tables()


@pytest.fixture(scope="module")
def catalog_sqlite_without_rowcount(warehouse: Path) -> Generator[SqlCatalog, None, None]:
props = {
"uri": "sqlite:////tmp/sql-catalog.db",
"warehouse": f"file://{warehouse}",
}
catalog = SqlCatalog("test_sql_catalog", **props)
catalog.engine.dialect.supports_sane_rowcount = False
catalog.create_tables()
yield catalog
catalog.destroy_tables()


def test_creation_with_no_uri() -> None:
with pytest.raises(NoSuchPropertyException):
SqlCatalog("test_ddb_catalog", not_uri="unused")
Expand Down Expand Up @@ -305,6 +319,7 @@ def test_load_table_from_self_identifier(catalog: SqlCatalog, table_schema_neste
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_drop_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
Expand All @@ -322,6 +337,7 @@ def test_drop_table(catalog: SqlCatalog, table_schema_nested: Schema, random_ide
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_drop_table_from_self_identifier(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
Expand All @@ -341,6 +357,7 @@ def test_drop_table_from_self_identifier(catalog: SqlCatalog, table_schema_neste
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_drop_table_that_does_not_exist(catalog: SqlCatalog, random_identifier: Identifier) -> None:
Expand All @@ -353,6 +370,7 @@ def test_drop_table_that_does_not_exist(catalog: SqlCatalog, random_identifier:
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table(
Expand All @@ -377,6 +395,7 @@ def test_rename_table(
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table_from_self_identifier(
Expand All @@ -403,6 +422,7 @@ def test_rename_table_from_self_identifier(
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table_to_existing_one(
Expand All @@ -425,6 +445,7 @@ def test_rename_table_to_existing_one(
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_missing_table(catalog: SqlCatalog, random_identifier: Identifier, another_random_identifier: Identifier) -> None:
Expand All @@ -439,6 +460,7 @@ def test_rename_missing_table(catalog: SqlCatalog, random_identifier: Identifier
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_rename_table_to_missing_namespace(
Expand Down Expand Up @@ -664,3 +686,36 @@ def test_update_namespace_properties(catalog: SqlCatalog, database_name: str) ->
else:
assert k in update_report.removed
assert "updated test description" == catalog.load_namespace_properties(database_name)["comment"]


@pytest.mark.parametrize(
'catalog',
[
lazy_fixture('catalog_memory'),
lazy_fixture('catalog_sqlite'),
lazy_fixture('catalog_sqlite_without_rowcount'),
],
)
def test_commit_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
database_name, _table_name = random_identifier
catalog.create_namespace(database_name)
table = catalog.create_table(random_identifier, table_schema_nested)

assert catalog._parse_metadata_version(table.metadata_location) == 0
assert table.metadata.current_schema_id == 0

transaction = table.transaction()
update = transaction.update_schema()
update.add_column(path="b", field_type=IntegerType())
update.commit()
transaction.commit_transaction()

updated_table_metadata = table.metadata

assert catalog._parse_metadata_version(table.metadata_location) == 1
assert updated_table_metadata.current_schema_id == 1
assert len(updated_table_metadata.schemas) == 2
new_schema = next(schema for schema in updated_table_metadata.schemas if schema.schema_id == 1)
assert new_schema
assert new_schema == update._apply()
assert new_schema.find_field("b").field_type == IntegerType()