@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
 ) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
+
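+    # Drop any table left over from a previous run so create_table below starts clean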
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,
@@ -756,6 +762,55 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         tbl.append("not a df")
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_truncate_transform(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.truncate_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
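+    # arrow_table_with_null has 3 rows: every column carries 2 non-null values and 1 null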
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "spec",
@@ -767,18 +822,52 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
                 PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
             )
         ),
-        # none of non-identity is supported
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
-        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_identity_and_bucket_transform_spec(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.identity_and_bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
     ],
 )
@@ -801,11 +890,66 @@ def test_unsupported_transform(
 
     with pytest.raises(
         ValueError,
-        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
+        match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary",
     ):
         tbl.append(arrow_table_with_null)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec, expected_rows",
+    [
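+        # expected_rows = distinct partitions: the null row forms its own partition; the two non-null rows fall into one or two buckets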
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_bucket_transform(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    spec: PartitionSpec,
+    expected_rows: int,
+    format_version: int,
+) -> None:
+    identifier = "default.bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "transform,expected_rows",