diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index 77ece56163..593c6b54a1 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -31,7 +31,7 @@
     union,
     update,
 )
-from sqlalchemy.exc import IntegrityError, OperationalError
+from sqlalchemy.exc import IntegrityError, NoResultFound, OperationalError
 from sqlalchemy.orm import (
     DeclarativeBase,
     Mapped,
@@ -48,6 +48,7 @@
     PropertiesUpdateSummary,
 )
 from pyiceberg.exceptions import (
+    CommitFailedException,
     NamespaceAlreadyExistsError,
     NamespaceNotEmptyError,
     NoSuchNamespaceError,
@@ -59,7 +60,7 @@
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
 from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
@@ -268,16 +269,32 @@ def drop_table(self, identifier: Union[str, Identifier]) -> None:
         identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
         database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
         with Session(self.engine) as session:
-            res = session.execute(
-                delete(IcebergTables).where(
-                    IcebergTables.catalog_name == self.name,
-                    IcebergTables.table_namespace == database_name,
-                    IcebergTables.table_name == table_name,
+            if self.engine.dialect.supports_sane_rowcount:
+                res = session.execute(
+                    delete(IcebergTables).where(
+                        IcebergTables.catalog_name == self.name,
+                        IcebergTables.table_namespace == database_name,
+                        IcebergTables.table_name == table_name,
+                    )
                 )
-            )
+                if res.rowcount < 1:
+                    raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}")
+            else:
+                try:
+                    tbl = (
+                        session.query(IcebergTables)
+                        .with_for_update(of=IcebergTables)
+                        .filter(
+                            IcebergTables.catalog_name == self.name,
+                            IcebergTables.table_namespace == database_name,
+                            IcebergTables.table_name == table_name,
+                        )
+                        .one()
+                    )
+                    session.delete(tbl)
+                except NoResultFound as e:
+                    raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}") from e
             session.commit()
-        if res.rowcount < 1:
-            raise NoSuchTableError(f"Table does not exist: {database_name}.{table_name}")
 
     def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: Union[str, Identifier]) -> Table:
         """Rename a fully classified table name.
@@ -301,18 +318,35 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
             raise NoSuchNamespaceError(f"Namespace does not exist: {to_database_name}")
         with Session(self.engine) as session:
             try:
-                stmt = (
-                    update(IcebergTables)
-                    .where(
-                        IcebergTables.catalog_name == self.name,
-                        IcebergTables.table_namespace == from_database_name,
-                        IcebergTables.table_name == from_table_name,
+                if self.engine.dialect.supports_sane_rowcount:
+                    stmt = (
+                        update(IcebergTables)
+                        .where(
+                            IcebergTables.catalog_name == self.name,
+                            IcebergTables.table_namespace == from_database_name,
+                            IcebergTables.table_name == from_table_name,
+                        )
+                        .values(table_namespace=to_database_name, table_name=to_table_name)
                     )
-                    .values(table_namespace=to_database_name, table_name=to_table_name)
-                )
-                result = session.execute(stmt)
-                if result.rowcount < 1:
-                    raise NoSuchTableError(f"Table does not exist: {from_table_name}")
+                    result = session.execute(stmt)
+                    if result.rowcount < 1:
+                        raise NoSuchTableError(f"Table does not exist: {from_table_name}")
+                else:
+                    try:
+                        tbl = (
+                            session.query(IcebergTables)
+                            .with_for_update(of=IcebergTables)
+                            .filter(
+                                IcebergTables.catalog_name == self.name,
+                                IcebergTables.table_namespace == from_database_name,
+                                IcebergTables.table_name == from_table_name,
+                            )
+                            .one()
+                        )
+                        tbl.table_namespace = to_database_name
+                        tbl.table_name = to_table_name
+                    except NoResultFound as e:
+                        raise NoSuchTableError(f"Table does not exist: {from_table_name}") from e
                 session.commit()
             except IntegrityError as e:
                 raise TableAlreadyExistsError(f"Table {to_database_name}.{to_table_name} already exists") from e
@@ -329,8 +363,62 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
 
         Raises:
             NoSuchTableError: If a table with the given identifier does not exist.
+            CommitFailedException: If the commit failed.
""" - raise NotImplementedError + identifier_tuple = self.identifier_to_tuple_without_catalog( + tuple(table_request.identifier.namespace.root + [table_request.identifier.name]) + ) + current_table = self.load_table(identifier_tuple) + database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError) + base_metadata = current_table.metadata + for requirement in table_request.requirements: + requirement.validate(base_metadata) + + updated_metadata = update_table_metadata(base_metadata, table_request.updates) + if updated_metadata == base_metadata: + # no changes, do nothing + return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location) + + # write new metadata + new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1 + new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version) + self._write_metadata(updated_metadata, current_table.io, new_metadata_location) + + with Session(self.engine) as session: + if self.engine.dialect.supports_sane_rowcount: + stmt = ( + update(IcebergTables) + .where( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == database_name, + IcebergTables.table_name == table_name, + IcebergTables.metadata_location == current_table.metadata_location, + ) + .values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location) + ) + result = session.execute(stmt) + if result.rowcount < 1: + raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") + else: + try: + tbl = ( + session.query(IcebergTables) + .with_for_update(of=IcebergTables) + .filter( + IcebergTables.catalog_name == self.name, + IcebergTables.table_namespace == database_name, + IcebergTables.table_name == table_name, + IcebergTables.metadata_location == current_table.metadata_location, + ) + .one() + ) + tbl.metadata_location = new_metadata_location + tbl.previous_metadata_location = current_table.metadata_location + except NoResultFound as e: + raise CommitFailedException(f"Table has been updated by another process: {database_name}.{table_name}") from e + session.commit() + + return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location) def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool: namespace = self.identifier_to_database(identifier) diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 95dc24ad15..8bf921aa4d 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -42,6 +42,7 @@ SortOrder, ) from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import IntegerType @pytest.fixture(name="warehouse", scope="session") @@ -87,6 +88,19 @@ def catalog_sqlite(warehouse: Path) -> Generator[SqlCatalog, None, None]: catalog.destroy_tables() +@pytest.fixture(scope="module") +def catalog_sqlite_without_rowcount(warehouse: Path) -> Generator[SqlCatalog, None, None]: + props = { + "uri": "sqlite:////tmp/sql-catalog.db", + "warehouse": f"file://{warehouse}", + } + catalog = SqlCatalog("test_sql_catalog", **props) + catalog.engine.dialect.supports_sane_rowcount = False + catalog.create_tables() + yield catalog + catalog.destroy_tables() + + def test_creation_with_no_uri() -> None: with pytest.raises(NoSuchPropertyException): SqlCatalog("test_ddb_catalog", not_uri="unused") @@ -305,6 +319,7 @@ def test_load_table_from_self_identifier(catalog: 
SqlCatalog, table_schema_neste [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_drop_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None: @@ -322,6 +337,7 @@ def test_drop_table(catalog: SqlCatalog, table_schema_nested: Schema, random_ide [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_drop_table_from_self_identifier(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None: @@ -341,6 +357,7 @@ def test_drop_table_from_self_identifier(catalog: SqlCatalog, table_schema_neste [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_drop_table_that_does_not_exist(catalog: SqlCatalog, random_identifier: Identifier) -> None: @@ -353,6 +370,7 @@ def test_drop_table_that_does_not_exist(catalog: SqlCatalog, random_identifier: [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_rename_table( @@ -377,6 +395,7 @@ def test_rename_table( [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_rename_table_from_self_identifier( @@ -403,6 +422,7 @@ def test_rename_table_from_self_identifier( [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_rename_table_to_existing_one( @@ -425,6 +445,7 @@ def test_rename_table_to_existing_one( [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_rename_missing_table(catalog: SqlCatalog, random_identifier: Identifier, another_random_identifier: Identifier) -> None: @@ -439,6 +460,7 @@ def test_rename_missing_table(catalog: SqlCatalog, random_identifier: Identifier [ lazy_fixture('catalog_memory'), lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), ], ) def test_rename_table_to_missing_namespace( @@ -664,3 +686,36 @@ def test_update_namespace_properties(catalog: SqlCatalog, database_name: str) -> else: assert k in update_report.removed assert "updated test description" == catalog.load_namespace_properties(database_name)["comment"] + + +@pytest.mark.parametrize( + 'catalog', + [ + lazy_fixture('catalog_memory'), + lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), + ], +) +def test_commit_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None: + database_name, _table_name = random_identifier + catalog.create_namespace(database_name) + table = catalog.create_table(random_identifier, table_schema_nested) + + assert catalog._parse_metadata_version(table.metadata_location) == 0 + assert table.metadata.current_schema_id == 0 + + transaction = table.transaction() + update = transaction.update_schema() + update.add_column(path="b", field_type=IntegerType()) + update.commit() + transaction.commit_transaction() + + updated_table_metadata = table.metadata + + assert catalog._parse_metadata_version(table.metadata_location) == 1 + assert updated_table_metadata.current_schema_id == 1 + assert len(updated_table_metadata.schemas) == 2 + new_schema = next(schema for schema in updated_table_metadata.schemas if schema.schema_id == 1) + assert new_schema + assert new_schema 
== update._apply() + assert new_schema.find_field("b").field_type == IntegerType()
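
For reference, a minimal sketch of how the new code paths can be exercised outside the test suite. It mirrors the catalog_sqlite_without_rowcount fixture and test_commit_table above; the catalog name, SQLite path, warehouse path, and demo_db.demo_table identifier are made-up values for illustration, and forcing supports_sane_rowcount to False drives the query-then-mutate fallback instead of the rowcount-based path:

from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType, NestedField

# Illustrative values only; any writable local paths work.
catalog = SqlCatalog(
    "demo_catalog",
    uri="sqlite:////tmp/demo-catalog.db",
    warehouse="file:///tmp/demo-warehouse",
)
# Mimic a dialect without sane rowcount support so drop/rename/commit take the
# SELECT-then-mutate fallback added in this change.
catalog.engine.dialect.supports_sane_rowcount = False
catalog.create_tables()

catalog.create_namespace("demo_db")
table = catalog.create_table(
    ("demo_db", "demo_table"),
    Schema(NestedField(field_id=1, name="a", field_type=IntegerType(), required=False)),
)

# A schema change now goes through SqlCatalog._commit_table, which writes a new
# metadata file and swaps the catalog pointer under optimistic locking.
transaction = table.transaction()
update = transaction.update_schema()
update.add_column(path="b", field_type=IntegerType())
update.commit()
transaction.commit_transaction()

assert catalog._parse_metadata_version(table.metadata_location) == 1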