Skip to content

Commit 083f7c4

Browse files
committed
Working through issues between custom catalog and build in schema
1 parent 294a8a9 commit 083f7c4

File tree

4 files changed

+75
-6
lines changed

4 files changed

+75
-6
lines changed

python/datafusion/catalog.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from __future__ import annotations
2121

2222
from abc import ABC, abstractmethod
23-
from typing import TYPE_CHECKING
23+
from typing import TYPE_CHECKING, Protocol
2424

2525
import datafusion._internal as df_internal
2626

@@ -174,7 +174,9 @@ def schema(self, name: str) -> Schema | None:
174174
"""Retrieve a specific schema from this catalog."""
175175
...
176176

177-
def register_schema(self, name: str, schema: Schema) -> None: # noqa: B027
177+
def register_schema( # noqa: B027
178+
self, name: str, schema: SchemaProviderExportable | SchemaProvider | Schema
179+
) -> None:
178180
"""Add a schema to this catalog.
179181
180182
This method is optional. If your catalog provides a fixed list of schemas, you
@@ -229,3 +231,12 @@ def deregister_table(self, name, cascade: bool) -> None: # noqa: B027
229231
def table_exist(self, name: str) -> bool:
230232
"""Returns true if the table exists in this schema."""
231233
...
234+
235+
236+
class SchemaProviderExportable(Protocol):
237+
"""Type hint for object that has __datafusion_schema_provider__ PyCapsule.
238+
239+
https://docs.rs/datafusion/latest/datafusion/catalog/trait.SchemaProvider.html
240+
"""
241+
242+
def __datafusion_schema_provider__(self) -> object: ...

python/datafusion/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
except ImportError:
3030
from typing_extensions import deprecated # Python 3.12
3131

32-
from datafusion.catalog import Catalog, Table
32+
from datafusion.catalog import Catalog, CatalogProvider, Table
3333
from datafusion.dataframe import DataFrame
3434
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
@@ -763,7 +763,7 @@ def catalog_names(self) -> set[str]:
763763
return self.ctx.catalog_names()
764764

765765
def register_catalog_provider(
766-
self, name: str, provider: CatalogProviderExportable | Catalog
766+
self, name: str, provider: CatalogProviderExportable | CatalogProvider | Catalog
767767
) -> None:
768768
"""Register a catalog provider."""
769769
if isinstance(provider, Catalog):

python/tests/test_catalog.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,51 @@ def test_python_table_provider(ctx: SessionContext):
162162
schema.deregister_table("table3")
163163
schema.register_table("table4", create_dataset())
164164
assert schema.table_names() == {"table4"}
165+
166+
167+
def test_in_end_to_end_python_providers(ctx: SessionContext):
168+
"""Test registering all python providers and running a query against them."""
169+
170+
all_catalog_names = [
171+
"datafusion",
172+
"custom_catalog",
173+
"in_mem_catalog",
174+
]
175+
176+
all_schema_names = [
177+
"custom_schema",
178+
"in_mem_schema",
179+
]
180+
181+
ctx.register_catalog_provider(all_catalog_names[1], CustomCatalogProvider())
182+
ctx.register_catalog_provider(
183+
all_catalog_names[2], dfn.catalog.Catalog.memory_catalog()
184+
)
185+
186+
for catalog_name in all_catalog_names:
187+
catalog = ctx.catalog(catalog_name)
188+
189+
# Clean out previous schemas if they exist so we can start clean
190+
for schema_name in catalog.schema_names():
191+
catalog.deregister_schema(schema_name, cascade=False)
192+
193+
catalog.register_schema(all_schema_names[0], CustomSchemaProvider())
194+
catalog.register_schema(all_schema_names[1], dfn.catalog.Schema.memory_schema())
195+
196+
for schema_name in all_schema_names:
197+
schema = catalog.schema(schema_name)
198+
199+
for table_name in schema.table_names():
200+
schema.deregister_table(table_name)
201+
202+
schema.register_table("test_table", create_dataset())
203+
204+
for catalog_name in all_catalog_names:
205+
for schema_name in all_schema_names:
206+
table_full_name = f"{catalog_name}.{schema_name}.test_table"
207+
208+
batches = ctx.sql(f"select * from {table_full_name}").collect()
209+
210+
assert len(batches) == 1
211+
assert batches[0].column(0) == pa.array([1, 2, 3])
212+
assert batches[0].column(1) == pa.array([4, 5, 6])

src/catalog.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,19 @@ impl RustWrappedPySchemaProvider {
314314

315315
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
316316
} else {
317-
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
317+
if let Ok(inner_table) = py_table.getattr("table") {
318+
if let Ok(inner_table) = inner_table.extract::<PyTable>() {
319+
return Ok(Some(inner_table.table));
320+
}
321+
}
318322

319-
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
323+
match py_table.extract::<PyTable>() {
324+
Ok(py_table) => Ok(Some(py_table.table)),
325+
Err(_) => {
326+
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
327+
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
328+
}
329+
}
320330
}
321331
})
322332
}

0 commit comments

Comments
 (0)