
Commit 02e6430

create_table with a PyArrow Schema (#305)

1 parent a3e3683 commit 02e6430

17 files changed: +417 −139 lines changed

mkdocs/docs/api.md
Lines changed: 19 additions & 0 deletions

````diff
@@ -146,6 +146,25 @@ catalog.create_table(
 )
 ```
 
+To create a table using a pyarrow schema:
+
+```python
+import pyarrow as pa
+
+schema = pa.schema(
+    [
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ]
+)
+
+catalog.create_table(
+    identifier="docs_example.bids",
+    schema=schema,
+)
+```
+
 ## Load a table
 
 ### Catalog table
````
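Arrow nullability carries over to the Iceberg schema: `nullable=False` becomes a required field and `nullable=True` an optional one, with field IDs assigned when the table is created. A minimal sketch of inspecting the result (assuming pyarrow is installed and a catalog named `default` is configured; both are assumptions, not part of the commit):

```python
from pyiceberg.catalog import load_catalog

# Hypothetical catalog name; any configured catalog works.
catalog = load_catalog("default")

table = catalog.load_table("docs_example.bids")
# Expected shape: foo is an optional string, bar a required int,
# baz an optional boolean, with freshly assigned field ids.
print(table.schema())
```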

pyiceberg/catalog/__init__.py
Lines changed: 21 additions & 1 deletion

```diff
@@ -24,6 +24,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Dict,
     List,
@@ -56,6 +57,9 @@
 )
 from pyiceberg.utils.config import Config, merge_config
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 logger = logging.getLogger(__name__)
 
 _ENV_CONFIG = Config()
@@ -288,7 +292,7 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -512,6 +516,22 @@ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> Non
         if overlap:
             raise ValueError(f"Updates and deletes have an overlap: {overlap}")
 
+    @staticmethod
+    def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
+        if isinstance(schema, Schema):
+            return schema
+        try:
+            import pyarrow as pa
+
+            from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow
+
+            if isinstance(schema, pa.Schema):
+                schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())  # type: ignore
+                return schema
+        except ModuleNotFoundError:
+            pass
+        raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema")
+
     def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
         if not location:
             return self._get_default_warehouse_location(database_name, table_name)
```
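The helper has three paths: pass an Iceberg `Schema` through untouched, convert a `pa.Schema` via the visitor, and reject everything else. A minimal sketch of each path (assuming pyarrow is installed; `_convert_schema_if_needed` is a static method, so it can be exercised directly on `Catalog`):

```python
import pyarrow as pa

from pyiceberg.catalog import Catalog
from pyiceberg.schema import Schema

# 1. An Iceberg Schema is returned as-is.
iceberg_schema = Schema()
assert Catalog._convert_schema_if_needed(iceberg_schema) is iceberg_schema

# 2. A pyarrow Schema is converted via visit_pyarrow(..., _ConvertToIcebergWithoutIDs()).
arrow_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])
assert isinstance(Catalog._convert_schema_if_needed(arrow_schema), Schema)

# 3. Anything else raises ValueError.
try:
    Catalog._convert_schema_if_needed("not a schema")  # type: ignore
except ValueError as err:
    print(err)
```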

pyiceberg/catalog/dynamodb.py
Lines changed: 7 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
 import uuid
 from time import time
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -57,6 +58,9 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 DYNAMODB_CLIENT = "dynamodb"
 
 DYNAMODB_COL_IDENTIFIER = "identifier"
@@ -127,7 +131,7 @@ def _dynamodb_table_exists(self) -> bool:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -152,6 +156,8 @@ def create_table(
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
 
         location = self._resolve_table_location(location, database_name, table_name)
```

pyiceberg/catalog/glue.py
Lines changed: 7 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
 
 
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -88,6 +89,9 @@
     UUIDType,
 )
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 # If Glue should skip archiving an old table version when creating a new version in a commit. By
 # default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
 # max number of archived table versions (can be increased). So for streaming use case with lots
@@ -329,7 +333,7 @@ def _get_glue_table(self, database_name: str, table_name: str) -> TableTypeDef:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -354,6 +358,8 @@ def create_table(
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
 
         location = self._resolve_table_location(location, database_name, table_name)
```

pyiceberg/catalog/hive.py
Lines changed: 8 additions & 1 deletion

```diff
@@ -18,6 +18,7 @@
 import time
 from types import TracebackType
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -91,6 +92,10 @@
     UUIDType,
 )
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+
 # Replace by visitor
 hive_types = {
     BooleanType: "boolean",
@@ -250,7 +255,7 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -273,6 +278,8 @@ def create_table(
             AlreadyExistsError: If a table with the name already exists.
             ValueError: If the identifier is invalid.
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         properties = {**DEFAULT_PROPERTIES, **properties}
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         current_time_millis = int(time.time() * 1000)
```

pyiceberg/catalog/noop.py
Lines changed: 5 additions & 1 deletion

```diff
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from typing import (
+    TYPE_CHECKING,
     List,
     Optional,
     Set,
@@ -33,12 +34,15 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 class NoopCatalog(Catalog):
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
```

pyiceberg/catalog/rest.py
Lines changed: 7 additions & 1 deletion

```diff
@@ -16,6 +16,7 @@
 # under the License.
 from json import JSONDecodeError
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
@@ -68,6 +69,9 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 ICEBERG_REST_SPEC_VERSION = "0.14.1"
 
 
@@ -437,12 +441,14 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
         properties: Properties = EMPTY_DICT,
     ) -> Table:
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         namespace_and_table = self._split_identifier_for_path(identifier)
         request = CreateTableRequest(
             name=namespace_and_table["table"],
```

pyiceberg/catalog/sql.py
Lines changed: 7 additions & 1 deletion

```diff
@@ -16,6 +16,7 @@
 # under the License.
 
 from typing import (
+    TYPE_CHECKING,
     List,
     Optional,
     Set,
@@ -65,6 +66,9 @@
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase):
     pass
@@ -140,7 +144,7 @@ def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table:
     def create_table(
         self,
         identifier: Union[str, Identifier],
-        schema: Schema,
+        schema: Union[Schema, "pa.Schema"],
         location: Optional[str] = None,
         partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
         sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -165,6 +169,8 @@ def create_table(
             ValueError: If the identifier is invalid, or no path is given to store metadata.
 
         """
+        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
+
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         if not self._namespace_exists(database_name):
             raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}")
```
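Taken together with the docs change, the feature can be exercised end to end against a local SQLite-backed catalog. A sketch, assuming pyarrow is installed and that the `uri`/`warehouse` properties below suit your environment (both are assumptions, not part of the commit):

```python
import pyarrow as pa

from pyiceberg.catalog.sql import SqlCatalog

# Hypothetical local setup; /tmp/warehouse must exist and be writable.
catalog = SqlCatalog(
    "default",
    uri="sqlite:////tmp/warehouse/pyiceberg_catalog.db",
    warehouse="file:///tmp/warehouse",
)
catalog.create_namespace("docs_example")

schema = pa.schema(
    [
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
    ]
)
table = catalog.create_table("docs_example.bids", schema=schema)
print(table.schema())
```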

pyiceberg/io/pyarrow.py
Lines changed: 20 additions & 5 deletions

```diff
@@ -26,6 +26,7 @@
 from __future__ import annotations
 
 import concurrent.futures
+import itertools
 import logging
 import os
 import re
@@ -34,7 +35,6 @@
 from dataclasses import dataclass
 from enum import Enum
 from functools import lru_cache, singledispatch
-from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -637,7 +637,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
-        all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes]))
+        all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
     return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
 
 
@@ -912,6 +912,21 @@ def after_map_value(self, element: pa.Field) -> None:
         self._field_names.pop()
 
 
+class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
+    """
+    Converts PyArrowSchema to Iceberg Schema with all -1 ids.
+
+    The schema generated through this visitor should always be
+    used in conjunction with `new_table_metadata` function to
+    assign new field ids in order. This is currently used only
+    when creating an Iceberg Schema from a PyArrow schema when
+    creating a new Iceberg table.
+    """
+
+    def _field_id(self, field: pa.Field) -> int:
+        return -1
+
+
 def _task_to_table(
     fs: FileSystem,
     task: FileScanTask,
@@ -999,7 +1014,7 @@ def _task_to_table(
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
     deletes_per_file: Dict[str, List[ChunkedArray]] = {}
-    unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks]))
+    unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
     if len(unique_deletes) > 0:
         executor = ExecutorFactory.get_or_create()
         deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map(
@@ -1421,7 +1436,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsColl
     def struct(
         self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
     ) -> List[StatisticsCollector]:
-        return list(chain(*[result() for result in field_results]))
+        return list(itertools.chain(*[result() for result in field_results]))
 
     def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
         self._field_id = field.field_id
@@ -1513,7 +1528,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath
         return struct_result()
 
     def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
-        return list(chain(*[result() for result in field_results]))
+        return list(itertools.chain(*[result() for result in field_results]))
 
     def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
         self._field_id = field.field_id
```
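The visitor subclass only overrides `_field_id`, so every converted field comes back with the id -1; `new_table_metadata` then assigns real ids in order, which is why the docstring warns against using the raw output anywhere else. A small sketch (assuming pyarrow is installed; `_ConvertToIcebergWithoutIDs` is private, so this is for illustration only):

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

arrow_schema = pa.schema(
    [
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
    ]
)

# Every field id is the -1 placeholder until new_table_metadata reassigns them.
iceberg_schema = visit_pyarrow(arrow_schema, _ConvertToIcebergWithoutIDs())
assert all(field.field_id == -1 for field in iceberg_schema.fields)
```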
