Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,14 +1776,7 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self._file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
target_schema = schema_to_pyarrow(
promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids
)
if self._use_large_types is False:
target_schema = _pyarrow_schema_ensure_small_types(target_schema)
return values.cast(target_schema)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
if (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
if field.field_type == TimestampType():
# Downcasting of nanoseconds to microseconds
if (
Expand All @@ -1802,13 +1795,22 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
pa.types.is_timestamp(target_type)
and target_type.tz == "UTC"
and pa.types.is_timestamp(values.type)
and values.type.tz in UTC_ALIASES
and (values.type.tz in UTC_ALIASES or values.type.tz is None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also allow the cast when `type.tz` is None; this mirrors the existing logic here:

if primitive.tz in UTC_ALIASES:
return TimestamptzType()
elif primitive.tz is None:
return TimestampType()

):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")

if field.field_type != file_field.field_type:
target_schema = schema_to_pyarrow(
promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids
)
if self._use_large_types is False:
target_schema = _pyarrow_schema_ensure_small_types(target_schema)
return values.cast(target_schema)

return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
Expand Down
39 changes: 39 additions & 0 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,6 +2521,45 @@ def test_initial_value() -> None:
assert val.as_py() == 22


def test__to_requested_schema_timestamp_to_timestamptz_projection() -> None:
    """Projecting a naive-timestamp file column to a TimestamptzType table field attaches UTC."""
    from datetime import datetime, timezone

    # The data file stores timestamps without a timezone.
    file_schema = Schema(NestedField(1, "ts_field", TimestampType(), required=False))
    naive_values = [
        datetime(2025, 8, 14, 12, 0, 0),
        datetime(2025, 8, 14, 13, 0, 0),
    ]
    source_batch = pa.record_batch(
        [pa.array(naive_values, type=pa.timestamp("us"))],
        names=["ts_field"],
    )

    # The table schema requests the same field as a timezone-aware timestamp.
    table_schema = Schema(NestedField(1, "ts_field", TimestamptzType(), required=False))

    actual_result = _to_requested_schema(table_schema, file_schema, source_batch, downcast_ns_timestamp_to_us=True)

    # The projected column should carry the UTC timezone while keeping the same wall-clock values.
    expected = pa.record_batch(
        [pa.array(naive_values, type=pa.timestamp("us", tz=timezone.utc))],
        names=["ts_field"],
    )
    assert expected.equals(actual_result)


def test__to_requested_schema_timestamps(
arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
arrow_table_with_all_timestamp_precisions: pa.Table,
Expand Down